author     Svilen Mihaylov <svilen.mihaylov@mongodb.com>    2022-01-31 21:05:27 +0000
committer  Evergreen Agent <no-reply@evergreen.mongodb.com>    2022-01-31 21:48:46 +0000
commit     50db8e9573e191ba2c193b4ef3dba6b5c6488f82 (patch)
tree       1d211e40920b5952af569bb6e9fa7dd830d5bbaa
parent     b696e034fe97e7699dd45ac2595422e1d510ba2c (diff)
download   mongo-50db8e9573e191ba2c193b4ef3dba6b5c6488f82.tar.gz
SERVER-62434 Implement query optimizer based on Path algebra and Cascades
-rw-r--r--  buildscripts/gdb/mongo_printers.py | 45
-rw-r--r--  buildscripts/resmokeconfig/fully_disabled_feature_flags.yml | 1
-rw-r--r--  buildscripts/resmokeconfig/suites/cqf.yml | 30
-rw-r--r--  buildscripts/resmokeconfig/suites/cqf_parallel.yml | 31
-rw-r--r--  etc/evergreen.yml | 22
-rw-r--r--  jstests/cqf/array_index.js | 33
-rw-r--r--  jstests/cqf/basic_agg.js | 42
-rw-r--r--  jstests/cqf/basic_find.js | 42
-rw-r--r--  jstests/cqf/basic_unwind.js | 25
-rw-r--r--  jstests/cqf/chess.js | 107
-rw-r--r--  jstests/cqf/empty_results.js | 18
-rw-r--r--  jstests/cqf/filter_order.js | 22
-rw-r--r--  jstests/cqf/find_sort.js | 41
-rw-r--r--  jstests/cqf/group.js | 27
-rw-r--r--  jstests/cqf/index_intersect.js | 47
-rw-r--r--  jstests/cqf/index_intersect1.js | 35
-rw-r--r--  jstests/cqf/no_collection.js | 15
-rw-r--r--  jstests/cqf/nonselective_index.js | 30
-rw-r--r--  jstests/cqf/object_elemMatch.js | 33
-rw-r--r--  jstests/cqf/partial_index.js | 34
-rw-r--r--  jstests/cqf/residual_pred_costing.js | 35
-rw-r--r--  jstests/cqf/sampling.js | 31
-rw-r--r--  jstests/cqf/selective_index.js | 34
-rw-r--r--  jstests/cqf/sort.js | 22
-rw-r--r--  jstests/cqf/sort_match.js | 33
-rw-r--r--  jstests/cqf/sort_project.js | 74
-rw-r--r--  jstests/cqf/type_bracket.js | 62
-rw-r--r--  jstests/cqf/type_predicate.js | 26
-rw-r--r--  jstests/cqf/unionWith.js | 54
-rw-r--r--  jstests/cqf/value_elemMatch.js | 52
-rw-r--r--  jstests/cqf_parallel/basic_exchange.js | 22
-rw-r--r--  jstests/cqf_parallel/groupby.js | 37
-rw-r--r--  jstests/cqf_parallel/index.js | 25
-rw-r--r--  jstests/libs/optimizer_utils.js | 8
-rw-r--r--  src/mongo/db/commands/SConscript | 6
-rw-r--r--  src/mongo/db/commands/cqf/cqf_aggregate.cpp | 431
-rw-r--r--  src/mongo/db/commands/cqf/cqf_aggregate.h | 44
-rw-r--r--  src/mongo/db/commands/find_cmd.cpp | 15
-rw-r--r--  src/mongo/db/commands/run_aggregate.cpp | 144
-rw-r--r--  src/mongo/db/exec/sbe/SConscript | 29
-rw-r--r--  src/mongo/db/exec/sbe/abt/abt_lower.cpp | 1014
-rw-r--r--  src/mongo/db/exec/sbe/abt/abt_lower.h | 204
-rw-r--r--  src/mongo/db/exec/sbe/abt/sbe_abt_diff_test.cpp | 369
-rw-r--r--  src/mongo/db/exec/sbe/abt/sbe_abt_test.cpp | 311
-rw-r--r--  src/mongo/db/exec/sbe/abt/sbe_abt_test_util.cpp | 63
-rw-r--r--  src/mongo/db/exec/sbe/abt/sbe_abt_test_util.h | 48
-rw-r--r--  src/mongo/db/exec/sbe/stages/exchange.cpp | 16
-rw-r--r--  src/mongo/db/exec/sbe/stages/scan.cpp | 50
-rw-r--r--  src/mongo/db/exec/sbe/stages/scan.h | 13
-rw-r--r--  src/mongo/db/exec/sbe/values/bson.h | 6
-rw-r--r--  src/mongo/db/exec/sbe/vm/vm.cpp | 4
-rw-r--r--  src/mongo/db/exec/sbe_cmd.cpp | 1
-rw-r--r--  src/mongo/db/matcher/match_expression_walker.h | 12
-rw-r--r--  src/mongo/db/pipeline/SConscript | 10
-rw-r--r--  src/mongo/db/pipeline/abt/abt_document_source_visitor.cpp | 878
-rw-r--r--  src/mongo/db/pipeline/abt/abt_document_source_visitor.h | 44
-rw-r--r--  src/mongo/db/pipeline/abt/agg_expression_visitor.cpp | 840
-rw-r--r--  src/mongo/db/pipeline/abt/agg_expression_visitor.h | 42
-rw-r--r--  src/mongo/db/pipeline/abt/expr_algebrizer_context.cpp | 73
-rw-r--r--  src/mongo/db/pipeline/abt/expr_algebrizer_context.h | 70
-rw-r--r--  src/mongo/db/pipeline/abt/match_expression_visitor.cpp | 479
-rw-r--r--  src/mongo/db/pipeline/abt/match_expression_visitor.h | 45
-rw-r--r--  src/mongo/db/pipeline/abt/pipeline_test.cpp | 2508
-rw-r--r--  src/mongo/db/pipeline/abt/utils.cpp | 104
-rw-r--r--  src/mongo/db/pipeline/abt/utils.h | 54
-rw-r--r--  src/mongo/db/pipeline/document_source_union_with.h | 4
-rw-r--r--  src/mongo/db/pipeline/visitors/document_source_visitor.h | 135
-rw-r--r--  src/mongo/db/pipeline/visitors/document_source_walker.cpp | 144
-rw-r--r--  src/mongo/db/pipeline/visitors/document_source_walker.h | 58
-rw-r--r--  src/mongo/db/pipeline/visitors/transformer_interface_visitor.h | 66
-rw-r--r--  src/mongo/db/pipeline/visitors/transformer_interface_walker.cpp | 72
-rw-r--r--  src/mongo/db/pipeline/visitors/transformer_interface_walker.h | 48
-rw-r--r--  src/mongo/db/query/SConscript | 2
-rw-r--r--  src/mongo/db/query/ce/SConscript | 16
-rw-r--r--  src/mongo/db/query/ce/ce_sampling.cpp | 290
-rw-r--r--  src/mongo/db/query/ce/ce_sampling.h | 52
-rw-r--r--  src/mongo/db/query/get_executor.cpp | 1
-rw-r--r--  src/mongo/db/query/optimizer/SConscript | 76
-rw-r--r--  src/mongo/db/query/optimizer/algebra/SConscript | 15
-rw-r--r--  src/mongo/db/query/optimizer/algebra/algebra_test.cpp | 570
-rw-r--r--  src/mongo/db/query/optimizer/algebra/operator.h | 341
-rw-r--r--  src/mongo/db/query/optimizer/algebra/polyvalue.h | 541
-rw-r--r--  src/mongo/db/query/optimizer/bool_expression.h | 140
-rw-r--r--  src/mongo/db/query/optimizer/cascades/ce_heuristic.cpp | 192
-rw-r--r--  src/mongo/db/query/optimizer/cascades/ce_heuristic.h | 48
-rw-r--r--  src/mongo/db/query/optimizer/cascades/ce_hinted.cpp | 96
-rw-r--r--  src/mongo/db/query/optimizer/cascades/ce_hinted.h | 56
-rw-r--r--  src/mongo/db/query/optimizer/cascades/cost_derivation.cpp | 429
-rw-r--r--  src/mongo/db/query/optimizer/cascades/cost_derivation.h | 49
-rw-r--r--  src/mongo/db/query/optimizer/cascades/enforcers.cpp | 269
-rw-r--r--  src/mongo/db/query/optimizer/cascades/enforcers.h | 47
-rw-r--r--  src/mongo/db/query/optimizer/cascades/implementers.cpp | 1441
-rw-r--r--  src/mongo/db/query/optimizer/cascades/implementers.h | 50
-rw-r--r--  src/mongo/db/query/optimizer/cascades/interfaces.h | 81
-rw-r--r--  src/mongo/db/query/optimizer/cascades/logical_props_derivation.cpp | 494
-rw-r--r--  src/mongo/db/query/optimizer/cascades/logical_props_derivation.h | 52
-rw-r--r--  src/mongo/db/query/optimizer/cascades/logical_rewriter.cpp | 1288
-rw-r--r--  src/mongo/db/query/optimizer/cascades/logical_rewriter.h | 132
-rw-r--r--  src/mongo/db/query/optimizer/cascades/logical_rewriter_rules.h | 89
-rw-r--r--  src/mongo/db/query/optimizer/cascades/memo.cpp | 794
-rw-r--r--  src/mongo/db/query/optimizer/cascades/memo.h | 248
-rw-r--r--  src/mongo/db/query/optimizer/cascades/physical_rewriter.cpp | 407
-rw-r--r--  src/mongo/db/query/optimizer/cascades/physical_rewriter.h | 93
-rw-r--r--  src/mongo/db/query/optimizer/cascades/rewrite_queues.cpp | 76
-rw-r--r--  src/mongo/db/query/optimizer/cascades/rewrite_queues.h | 149
-rw-r--r--  src/mongo/db/query/optimizer/containers.h | 81
-rw-r--r--  src/mongo/db/query/optimizer/defs.cpp | 257
-rw-r--r--  src/mongo/db/query/optimizer/defs.h | 254
-rw-r--r--  src/mongo/db/query/optimizer/explain.cpp | 2411
-rw-r--r--  src/mongo/db/query/optimizer/explain.h | 113
-rw-r--r--  src/mongo/db/query/optimizer/explain_interface.h | 48
-rw-r--r--  src/mongo/db/query/optimizer/index_bounds.cpp | 246
-rw-r--r--  src/mongo/db/query/optimizer/index_bounds.h | 206
-rw-r--r--  src/mongo/db/query/optimizer/interval_intersection_test.cpp | 667
-rw-r--r--  src/mongo/db/query/optimizer/logical_rewriter_optimizer_test.cpp | 1495
-rw-r--r--  src/mongo/db/query/optimizer/metadata.cpp | 161
-rw-r--r--  src/mongo/db/query/optimizer/metadata.h | 160
-rw-r--r--  src/mongo/db/query/optimizer/node.cpp | 775
-rw-r--r--  src/mongo/db/query/optimizer/node.h | 862
-rw-r--r--  src/mongo/db/query/optimizer/node_defs.h | 66
-rw-r--r--  src/mongo/db/query/optimizer/opt_phase_manager.cpp | 332
-rw-r--r--  src/mongo/db/query/optimizer/opt_phase_manager.h | 181
-rw-r--r--  src/mongo/db/query/optimizer/optimizer_test.cpp | 639
-rw-r--r--  src/mongo/db/query/optimizer/physical_rewriter_optimizer_test.cpp | 4659
-rw-r--r--  src/mongo/db/query/optimizer/props.cpp | 373
-rw-r--r--  src/mongo/db/query/optimizer/props.h | 481
-rw-r--r--  src/mongo/db/query/optimizer/reference_tracker.cpp | 689
-rw-r--r--  src/mongo/db/query/optimizer/reference_tracker.h | 113
-rw-r--r--  src/mongo/db/query/optimizer/rewrites/const_eval.cpp | 565
-rw-r--r--  src/mongo/db/query/optimizer/rewrites/const_eval.h | 121
-rw-r--r--  src/mongo/db/query/optimizer/rewrites/path.cpp | 346
-rw-r--r--  src/mongo/db/query/optimizer/rewrites/path.h | 94
-rw-r--r--  src/mongo/db/query/optimizer/rewrites/path_lower.cpp | 400
-rw-r--r--  src/mongo/db/query/optimizer/rewrites/path_lower.h | 158
-rw-r--r--  src/mongo/db/query/optimizer/rewrites/path_optimizer_test.cpp | 881
-rw-r--r--  src/mongo/db/query/optimizer/syntax/expr.cpp | 128
-rw-r--r--  src/mongo/db/query/optimizer/syntax/expr.h | 350
-rw-r--r--  src/mongo/db/query/optimizer/syntax/path.h | 349
-rw-r--r--  src/mongo/db/query/optimizer/syntax/syntax.h | 263
-rw-r--r--  src/mongo/db/query/optimizer/syntax/syntax_fwd_declare.h | 101
-rw-r--r--  src/mongo/db/query/optimizer/utils/abt_hash.cpp | 473
-rw-r--r--  src/mongo/db/query/optimizer/utils/abt_hash.h | 45
-rw-r--r--  src/mongo/db/query/optimizer/utils/interval_utils.cpp | 330
-rw-r--r--  src/mongo/db/query/optimizer/utils/interval_utils.h | 68
-rw-r--r--  src/mongo/db/query/optimizer/utils/memo_utils.cpp | 176
-rw-r--r--  src/mongo/db/query/optimizer/utils/memo_utils.h | 75
-rw-r--r--  src/mongo/db/query/optimizer/utils/printable_enum.h | 48
-rw-r--r--  src/mongo/db/query/optimizer/utils/unit_test_utils.cpp | 157
-rw-r--r--  src/mongo/db/query/optimizer/utils/unit_test_utils.h | 99
-rw-r--r--  src/mongo/db/query/optimizer/utils/utils.cpp | 1583
-rw-r--r--  src/mongo/db/query/optimizer/utils/utils.h | 317
-rw-r--r--  src/mongo/db/query/plan_executor_factory.cpp | 3
-rw-r--r--  src/mongo/db/query/plan_executor_factory.h | 3
-rw-r--r--  src/mongo/db/query/plan_executor_sbe.cpp | 5
-rw-r--r--  src/mongo/db/query/plan_executor_sbe.h | 2
-rw-r--r--  src/mongo/db/query/plan_explainer_factory.cpp | 22
-rw-r--r--  src/mongo/db/query/plan_explainer_factory.h | 3
-rw-r--r--  src/mongo/db/query/plan_explainer_sbe.cpp | 40
-rw-r--r--  src/mongo/db/query/plan_explainer_sbe.h | 7
-rw-r--r--  src/mongo/db/query/query_feature_flags.idl | 7
-rw-r--r--  src/mongo/db/query/query_knobs.idl | 63
-rw-r--r--  src/mongo/db/query/sbe_cached_solution_planner.cpp | 1
162 files changed, 40729 insertions, 96 deletions
diff --git a/buildscripts/gdb/mongo_printers.py b/buildscripts/gdb/mongo_printers.py
index 83289b33ffe..07ebc47bd64 100644
--- a/buildscripts/gdb/mongo_printers.py
+++ b/buildscripts/gdb/mongo_printers.py
@@ -815,6 +815,45 @@ class SbeCodeFragmentPrinter(object):
instr_count if not error else '? (successfully parsed {})'.format(instr_count)
+def eval_print_fn(val, print_fn_name):
+    """Evaluate a print function, and return the resulting string."""
+    print_fn_symbol = gdb.lookup_symbol(print_fn_name)[0]
+    print_fn = print_fn_symbol.value()
+    # The generated output from explain contains the literal string "\n" (two
+    # characters); replace it with a single EOL character so that GDB prints
+    # multi-line explains nicely.
+    pp_result = print_fn(val)
+    pp_str = str(pp_result).replace("\"", "").replace("\\n", "\n")
+    return pp_str
+
+
+class ABTPrinter(object):
+    """Pretty-printer for mongo::optimizer::ABT."""
+
+    def __init__(self, val):
+        """Initialize ABTPrinter."""
+        self.val = val
+
+    @staticmethod
+    def display_hint():
+        """Display hint."""
+        return 'ABT'
+
+    def to_string(self):
+        """Return ABT for printing."""
+        # Python will truncate/compress certain strings that contain many repeated characters.
+        # For an ABT, this is quite common when indenting nodes to represent children, so
+        # disable it for now.
+        prior_repeats = gdb.parameter("print repeats")
+        prior_elements = gdb.parameter("print elements")
+        gdb.execute("set print repeats 0")  # for "<repeats N times>"
+        gdb.execute("set print elements 0")  # for ... on long strings
+        res = eval_print_fn(self.val, "_printNode")
+        gdb.execute("set print repeats " + str(prior_repeats))
+        gdb.execute("set print elements " + str(prior_elements))
+        return res
+
+
def build_pretty_printer():
"""Build a pretty printer."""
pp = MongoPrettyPrinterCollection()
@@ -836,6 +875,12 @@ def build_pretty_printer():
pp.add('__wt_txn', '__wt_txn', False, WtTxnPrinter)
pp.add('__wt_update', '__wt_update', False, WtUpdateToBsonPrinter)
pp.add('CodeFragment', 'mongo::sbe::vm::CodeFragment', False, SbeCodeFragmentPrinter)
+
+    # TODO: enable with SERVER-62044.
+    # Optimizer/ABT related pretty printers that can be used only with a running process.
+    # abt_type = gdb.lookup_type("mongo::optimizer::ABT").strip_typedefs()
+    # pp.add("ABT", abt_type.name, True, ABTPrinter)
+
return pp
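Worth noting about the printer above: it renders an ABT by evaluating the optimizer's C++ print helper ("_printNode") inside the debugged process via gdb.lookup_symbol, which is why these printers only work against a running process and, presumably, why the registration stays commented out until SERVER-62044.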
diff --git a/buildscripts/resmokeconfig/fully_disabled_feature_flags.yml b/buildscripts/resmokeconfig/fully_disabled_feature_flags.yml
index 59025aba509..c08125ad738 100644
--- a/buildscripts/resmokeconfig/fully_disabled_feature_flags.yml
+++ b/buildscripts/resmokeconfig/fully_disabled_feature_flags.yml
@@ -4,6 +4,7 @@
# by modifying their respective definitions in evergreen.yml.
- featureFlagFryer
+- featureFlagCommonQueryFramework
# Disable featureFlagRequireTenantID until all paths pass tenant id to TenantNamespace
# and TenantDatabase constructors.
- featureFlagRequireTenantID
diff --git a/buildscripts/resmokeconfig/suites/cqf.yml b/buildscripts/resmokeconfig/suites/cqf.yml
new file mode 100644
index 00000000000..5c4415228b7
--- /dev/null
+++ b/buildscripts/resmokeconfig/suites/cqf.yml
@@ -0,0 +1,30 @@
+test_kind: js_test
+
+selector:
+  roots:
+  - jstests/cqf/**/*.js
+
+executor:
+  archive:
+    hooks:
+      - ValidateCollections
+  config:
+    shell_options:
+      crashOnInvalidBSONError: ""
+      objcheck: ""
+      eval: load("jstests/libs/override_methods/detect_spawning_own_mongod.js");
+  hooks:
+  - class: ValidateCollections
+    shell_options:
+      global_vars:
+        TestData:
+          skipValidationOnNamespaceNotFound: false
+  - class: CleanEveryN
+    n: 20
+  fixture:
+    class: MongoDFixture
+    mongod_options:
+      set_parameters:
+        enableTestCommands: 1
+        featureFlagCommonQueryFramework: true
+        internalQueryEnableCascadesOptimizer: true
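The two set_parameters above, featureFlagCommonQueryFramework and internalQueryEnableCascadesOptimizer, are what route these suites through the new optimizer. The suite can presumably be run locally with resmoke, e.g. `python buildscripts/resmoke.py run --suites=cqf` (invocation shown for convenience; it is not part of this commit).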
diff --git a/buildscripts/resmokeconfig/suites/cqf_parallel.yml b/buildscripts/resmokeconfig/suites/cqf_parallel.yml
new file mode 100644
index 00000000000..57d55f023a3
--- /dev/null
+++ b/buildscripts/resmokeconfig/suites/cqf_parallel.yml
@@ -0,0 +1,31 @@
+test_kind: js_test
+
+selector:
+  roots:
+  - jstests/cqf_parallel/**/*.js
+
+executor:
+  archive:
+    hooks:
+      - ValidateCollections
+  config:
+    shell_options:
+      crashOnInvalidBSONError: ""
+      objcheck: ""
+      eval: load("jstests/libs/override_methods/detect_spawning_own_mongod.js");
+  hooks:
+  - class: ValidateCollections
+    shell_options:
+      global_vars:
+        TestData:
+          skipValidationOnNamespaceNotFound: false
+  - class: CleanEveryN
+    n: 20
+  fixture:
+    class: MongoDFixture
+    mongod_options:
+      set_parameters:
+        enableTestCommands: 1
+        featureFlagCommonQueryFramework: true
+        internalQueryEnableCascadesOptimizer: true
+        internalQueryDefaultDOP: 5
diff --git a/etc/evergreen.yml b/etc/evergreen.yml
index 10557a651a7..48b467f05bb 100644
--- a/etc/evergreen.yml
+++ b/etc/evergreen.yml
@@ -7081,6 +7081,20 @@ tasks:
vars:
resmoke_jobs_max: 1
+- <<: *task_template
+  name: cqf
+  tags: []
+  commands:
+  - func: "do setup"
+  - func: "run tests"
+
+- <<: *task_template
+  name: cqf_parallel
+  tags: []
+  commands:
+  - func: "do setup"
+  - func: "run tests"
+
- name: shared_scons_cache_pruning
tags: []
exec_timeout_secs: 7200 # 2 hour timeout for the task overall
@@ -8918,6 +8932,8 @@ buildvariants:
--excludeWithAnyTags=incompatible_with_windows_tls
--excludeWithAnyTags=incompatible_with_shard_merge
tasks:
+  - name: cqf
+  - name: cqf_parallel
- name: compile_and_archive_dist_test_then_package_TG
distros:
- windows-vsCurrent-xlarge
@@ -9790,6 +9806,8 @@ buildvariants:
--runAllFeatureFlagTests
--excludeWithAnyTags=incompatible_with_shard_merge
tasks: &enterprise-rhel-80-64-bit-dynamic-all-feature-flags-tasks
+  - name: cqf
+  - name: cqf_parallel
- name: compile_test_and_package_parallel_core_stream_TG
distros:
- rhel80-large
@@ -11614,6 +11632,8 @@ buildvariants:
--excludeWithAnyTags=incompatible_with_shard_merge
separate_debug: off
tasks:
+  - name: cqf
+  - name: cqf_parallel
- name: compile_and_archive_dist_test_then_package_TG
- name: compile_benchmarks
- name: build_variant_gen
@@ -11856,6 +11876,8 @@ buildvariants:
scons_cache_scope: shared
separate_debug: off
tasks:
+  - name: cqf
+  - name: cqf_parallel
- name: compile_and_archive_dist_test_then_package_TG
- name: compile_benchmarks
- name: build_variant_gen
diff --git a/jstests/cqf/array_index.js b/jstests/cqf/array_index.js
new file mode 100644
index 00000000000..5a40a3040fb
--- /dev/null
+++ b/jstests/cqf/array_index.js
@@ -0,0 +1,33 @@
+(function() {
+"use strict";
+
+load("jstests/libs/optimizer_utils.js"); // For checkCascadesOptimizerEnabled.
+if (!checkCascadesOptimizerEnabled(db)) {
+    jsTestLog("Skipping test because the optimizer is not enabled");
+    return;
+}
+
+const t = db.cqf_array_index;
+t.drop();
+
+assert.commandWorked(t.insert({a: [1, 2, 3, 4]}));
+assert.commandWorked(t.insert({a: [2, 3, 4]}));
+assert.commandWorked(t.insert({a: [2]}));
+assert.commandWorked(t.insert({a: 2}));
+assert.commandWorked(t.insert({a: [1, 3]}));
+
+// Generate enough documents for index to be preferable.
+for (let i = 0; i < 100; i++) {
+    assert.commandWorked(t.insert({a: i + 10}));
+}
+
+assert.commandWorked(t.createIndex({a: 1}));
+
+let res = t.explain("executionStats").aggregate([{$match: {a: 2}}]);
+assert.eq(4, res.executionStats.nReturned);
+assert.eq("IndexScan", res.queryPlanner.winningPlan.optimizerPlan.child.leftChild.nodeType);
+
+res = t.explain("executionStats").aggregate([{$match: {a: {$lt: 2}}}]);
+assert.eq(2, res.executionStats.nReturned);
+assert.eq("IndexScan", res.queryPlanner.winningPlan.optimizerPlan.child.leftChild.child.nodeType);
+}());
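The assertions above, and in the tests that follow, walk the optimizerPlan explain tree through hard-coded child/leftChild paths. A small helper in this spirit (hypothetical, not part of this commit; it assumes nothing beyond the plain-object explain shape these tests already rely on) makes such checks easier to read:

    // Hypothetical helper: follow a list of child-field names down the
    // optimizerPlan explain tree and return the nodeType found at the end.
    function planNodeType(optimizerPlan, path) {
        let node = optimizerPlan;
        for (const step of path) {
            assert(node.hasOwnProperty(step), "missing plan step: " + step);
            node = node[step];
        }
        return node.nodeType;
    }

    // Equivalent to the first IndexScan assertion above:
    // assert.eq("IndexScan",
    //           planNodeType(res.queryPlanner.winningPlan.optimizerPlan,
    //                        ["child", "leftChild"]));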
diff --git a/jstests/cqf/basic_agg.js b/jstests/cqf/basic_agg.js
new file mode 100644
index 00000000000..3165b4403d0
--- /dev/null
+++ b/jstests/cqf/basic_agg.js
@@ -0,0 +1,42 @@
+(function() {
+"use strict";
+
+load("jstests/libs/optimizer_utils.js"); // For checkCascadesOptimizerEnabled.
+if (!checkCascadesOptimizerEnabled(db)) {
+    jsTestLog("Skipping test because the optimizer is not enabled");
+    return;
+}
+
+const coll = db.cqf_basic_index;
+coll.drop();
+
+assert.commandWorked(
+    coll.insert([{a: {b: 1}}, {a: {b: 2}}, {a: {b: 3}}, {a: {b: 4}}, {a: {b: 5}}]));
+
+const extraDocCount = 50;
+// Add extra docs to make sure indexes can be picked.
+for (let i = 0; i < extraDocCount; i++) {
+    assert.commandWorked(coll.insert({a: {b: i + 10}}));
+}
+assert.commandWorked(coll.createIndex({'a.b': 1}));
+
+let res = coll.explain("executionStats").aggregate([{$match: {'a.b': 2}}]);
+assert.eq(1, res.executionStats.nReturned);
+assert.eq("IndexScan", res.queryPlanner.winningPlan.optimizerPlan.child.leftChild.nodeType);
+
+res = coll.explain("executionStats").aggregate([{$match: {'a.b': {$gt: 2}}}]);
+assert.eq(3 + extraDocCount, res.executionStats.nReturned);
+assert.eq("PhysicalScan", res.queryPlanner.winningPlan.optimizerPlan.child.child.nodeType);
+
+res = coll.explain("executionStats").aggregate([{$match: {'a.b': {$gte: 2}}}]);
+assert.eq(4 + extraDocCount, res.executionStats.nReturned);
+assert.eq("PhysicalScan", res.queryPlanner.winningPlan.optimizerPlan.child.child.nodeType);
+
+res = coll.explain("executionStats").aggregate([{$match: {'a.b': {$lt: 2}}}]);
+assert.eq(1, res.executionStats.nReturned);
+assert.eq("IndexScan", res.queryPlanner.winningPlan.optimizerPlan.child.leftChild.nodeType);
+
+res = coll.explain("executionStats").aggregate([{$match: {'a.b': {$lte: 2}}}]);
+assert.eq(2, res.executionStats.nReturned);
+assert.eq("IndexScan", res.queryPlanner.winningPlan.optimizerPlan.child.leftChild.nodeType);
+}());
diff --git a/jstests/cqf/basic_find.js b/jstests/cqf/basic_find.js
new file mode 100644
index 00000000000..37abd4b5eaa
--- /dev/null
+++ b/jstests/cqf/basic_find.js
@@ -0,0 +1,42 @@
+(function() {
+"use strict";
+
+load("jstests/libs/optimizer_utils.js"); // For checkCascadesOptimizerEnabled.
+if (!checkCascadesOptimizerEnabled(db)) {
+    jsTestLog("Skipping test because the optimizer is not enabled");
+    return;
+}
+
+const coll = db.cqf_basic_find;
+coll.drop();
+
+assert.commandWorked(
+    coll.insert([{a: {b: 1}}, {a: {b: 2}}, {a: {b: 3}}, {a: {b: 4}}, {a: {b: 5}}]));
+
+const extraDocCount = 50;
+// Add extra docs to make sure indexes can be picked.
+for (let i = 0; i < extraDocCount; i++) {
+    assert.commandWorked(coll.insert({a: {b: i + 10}}));
+}
+assert.commandWorked(coll.createIndex({'a.b': 1}));
+
+let res = coll.explain("executionStats").find({'a.b': 2}).finish();
+assert.eq(1, res.executionStats.nReturned);
+assert.eq("IndexScan", res.queryPlanner.winningPlan.optimizerPlan.child.leftChild.nodeType);
+
+res = coll.explain("executionStats").find({'a.b': {$gt: 2}}).finish();
+assert.eq(3 + extraDocCount, res.executionStats.nReturned);
+assert.eq("PhysicalScan", res.queryPlanner.winningPlan.optimizerPlan.child.child.nodeType);
+
+res = coll.explain("executionStats").find({'a.b': {$gte: 2}}).finish();
+assert.eq(4 + extraDocCount, res.executionStats.nReturned);
+assert.eq("PhysicalScan", res.queryPlanner.winningPlan.optimizerPlan.child.child.nodeType);
+
+res = coll.explain("executionStats").find({'a.b': {$lt: 2}}).finish();
+assert.eq(1, res.executionStats.nReturned);
+assert.eq("IndexScan", res.queryPlanner.winningPlan.optimizerPlan.child.leftChild.nodeType);
+
+res = coll.explain("executionStats").find({'a.b': {$lte: 2}}).finish();
+assert.eq(2, res.executionStats.nReturned);
+assert.eq("IndexScan", res.queryPlanner.winningPlan.optimizerPlan.child.leftChild.nodeType);
+}());
diff --git a/jstests/cqf/basic_unwind.js b/jstests/cqf/basic_unwind.js
new file mode 100644
index 00000000000..c0faa5c0e0d
--- /dev/null
+++ b/jstests/cqf/basic_unwind.js
@@ -0,0 +1,25 @@
+(function() {
+"use strict";
+
+load("jstests/libs/optimizer_utils.js"); // For checkCascadesOptimizerEnabled.
+if (!checkCascadesOptimizerEnabled(db)) {
+    jsTestLog("Skipping test because the optimizer is not enabled");
+    return;
+}
+
+const coll = db.cqf_basic_unwind;
+coll.drop();
+
+assert.commandWorked(coll.insert([
+    {_id: 1},
+    {_id: 2, x: null},
+    {_id: 3, x: []},
+    {_id: 4, x: [1, 2]},
+    {_id: 5, x: [3]},
+    {_id: 6, x: 4}
+]));
+
+let res = coll.explain("executionStats").aggregate([{$unwind: '$x'}]);
+assert.eq(4, res.executionStats.nReturned);
+assert.eq("Unwind", res.queryPlanner.winningPlan.optimizerPlan.child.child.nodeType);
+}());
diff --git a/jstests/cqf/chess.js b/jstests/cqf/chess.js
new file mode 100644
index 00000000000..30124ea99b8
--- /dev/null
+++ b/jstests/cqf/chess.js
@@ -0,0 +1,107 @@
+(function() {
+"use strict";
+
+load("jstests/libs/optimizer_utils.js"); // For checkCascadesOptimizerEnabled.
+if (!checkCascadesOptimizerEnabled(db)) {
+    jsTestLog("Skipping test because the optimizer is not enabled");
+    return;
+}
+
+const coll = db.cqf_chess;
+Random.srand(0);
+
+const players = [
+ "penguingim1", "aladdin65", "aleksey472", "azuaga", "benpig",
+ "blackboarder", "bockosrb555", "bogdan_low_player", "charlytb", "chchbbuur",
+ "chessexplained", "cmcookiemonster", "crptone", "cselhu3", "darkzam",
+ "dmitri31", "dorado99", "ericrosen", "fast-tsunami", "flaneur"
+];
+const sources = [1, 2, 3, 4, 5, 6, 7, 8];
+const variants = [1, 2, 3, 4, 5, 6, 7, 8];
+const results = [1, 2, 3, 4, 5, 6, 7, 8];
+const winColor = [true, false, null];
+
+const nbGames = 1000; // 1000 * 1000;
+
+function intRandom(max) {
+    return Random.randInt(max);
+}
+function anyOf(as) {
+    return as[intRandom(as.length)];
+}
+
+coll.drop();
+
+print(`Adding ${nbGames} games`);
+const bulk = coll.initializeUnorderedBulkOp();
+for (let i = 0; i < nbGames; i++) {
+    const users = [anyOf(players), anyOf(players)];
+    const winnerIndex = intRandom(2);
+    bulk.insert({
+        users: users,
+        winner: users[winnerIndex],
+        loser: users[1 - winnerIndex],
+        winColor: anyOf(winColor),
+        avgRating: NumberInt(600 + intRandom(2400)),
+        source: NumberInt(anyOf(sources)),
+        variants: NumberInt(anyOf(variants)),
+        mode: !!intRandom(2),
+        turns: NumberInt(1 + intRandom(300)),
+        minutes: NumberInt(30 + intRandom(3600 * 3)),
+        clock: {init: NumberInt(0 + intRandom(10800)), inc: NumberInt(0 + intRandom(180))},
+        result: anyOf(results),
+        date: new Date(Date.now() - intRandom(118719488)),
+        analysed: !!intRandom(2)
+    });
+    if (i % 1000 == 0) {
+        print(`${i} / ${nbGames}`);
+    }
+}
+assert.commandWorked(bulk.execute());
+
+const indexes = [
+    {users: 1},
+    {winner: 1},
+    {loser: 1},
+    {winColor: 1},
+    {avgRating: 1},
+    {source: 1},
+    {variants: 1},
+    {mode: 1},
+    {turns: 1},
+    {minutes: 1},
+    {'clock.init': 1},
+    {'clock.inc': 1},
+    {result: 1},
+    {date: 1},
+    {analysed: 1}
+];
+
+print("Adding indexes");
+indexes.forEach(index => {
+    printjson(index);
+    coll.createIndex(index);
+});
+
+print("Searching");
+
+const res = coll.explain("executionStats").aggregate([
+    {
+        $match: {
+            avgRating: {$gt: 1000},
+            turns: {$lt: 250},
+            'clock.init': {$gt: 1},
+            minutes: {$gt: 2, $lt: 150}
+        }
+    },
+    {$sort: {date: -1}},
+    {$limit: 20}
+]);
+
+// TODO: verify expected results.
+
+// Verify we are using the index on "minutes".
+const indexNode = res.queryPlanner.winningPlan.optimizerPlan.child.child.leftChild;
+assert.eq("IndexScan", indexNode.nodeType);
+assert.eq("minutes_1", indexNode.indexDefName);
+}());
diff --git a/jstests/cqf/empty_results.js b/jstests/cqf/empty_results.js
new file mode 100644
index 00000000000..5eed556189c
--- /dev/null
+++ b/jstests/cqf/empty_results.js
@@ -0,0 +1,18 @@
+(function() {
+"use strict";
+
+load("jstests/libs/optimizer_utils.js"); // For checkCascadesOptimizerEnabled.
+if (!checkCascadesOptimizerEnabled(db)) {
+    jsTestLog("Skipping test because the optimizer is not enabled");
+    return;
+}
+
+const t = db.cqf_empty_results;
+t.drop();
+
+assert.commandWorked(t.insert([{a: 1}, {a: 2}]));
+
+const res = t.explain("executionStats").aggregate([{$match: {'a': 2}}, {$limit: 1}, {$skip: 10}]);
+assert.eq(0, res.executionStats.nReturned);
+assert.eq("CoScan", res.queryPlanner.winningPlan.optimizerPlan.child.child.child.nodeType);
+}());
diff --git a/jstests/cqf/filter_order.js b/jstests/cqf/filter_order.js
new file mode 100644
index 00000000000..e33e45c661e
--- /dev/null
+++ b/jstests/cqf/filter_order.js
@@ -0,0 +1,22 @@
+(function() {
+"use strict";
+
+load("jstests/libs/optimizer_utils.js"); // For checkCascadesOptimizerEnabled.
+if (!checkCascadesOptimizerEnabled(db)) {
+    jsTestLog("Skipping test because the optimizer is not enabled");
+    return;
+}
+
+const coll = db.cqf_filter_order;
+coll.drop();
+
+const bulk = coll.initializeUnorderedBulkOp();
+for (let i = 0; i < 10000; i++) {
+ // "a" has the most ones, then "b", then "c".
+ bulk.insert({a: (i % 2), b: (i % 3), c: (i % 4)});
+}
+assert.commandWorked(bulk.execute());
+
+let res = coll.aggregate([{$match: {'a': {$eq: 1}, 'b': {$eq: 1}, 'c': {$eq: 1}}}]).toArray();
+// TODO: verify via explain that the predicate on "c" (most selective) is applied first, then "b",
+// then "a".
+}());
diff --git a/jstests/cqf/find_sort.js b/jstests/cqf/find_sort.js
new file mode 100644
index 00000000000..5ead920b48c
--- /dev/null
+++ b/jstests/cqf/find_sort.js
@@ -0,0 +1,41 @@
+(function() {
+"use strict";
+
+load("jstests/libs/optimizer_utils.js"); // For checkCascadesOptimizerEnabled.
+if (!checkCascadesOptimizerEnabled(db)) {
+    jsTestLog("Skipping test because the optimizer is not enabled");
+    return;
+}
+
+const coll = db.cqf_find_sort;
+coll.drop();
+
+const bulk = coll.initializeUnorderedBulkOp();
+const nDocs = 10000;
+let numResults = 0;
+
+Random.srand(0);
+for (let i = 0; i < nDocs; i++) {
+    const va = 100.0 * Random.rand();
+    const vb = 100.0 * Random.rand();
+    if (va < 5.0 && vb < 5.0) {
+        numResults++;
+    }
+    bulk.insert({a: va, b: vb});
+}
+assert.gt(numResults, 0);
+
+assert.commandWorked(bulk.execute());
+
+assert.commandWorked(coll.createIndex({a: 1, b: 1}));
+
+const res = coll.explain("executionStats")
+                .find({a: {$lt: 5}, b: {$lt: 5}}, {a: 1, b: 1})
+                .sort({b: 1})
+                .finish();
+assert.eq(numResults, res.executionStats.nReturned);
+
+const indexScanNode = res.queryPlanner.winningPlan.optimizerPlan.child.child.child.leftChild.child;
+assert.eq("IndexScan", indexScanNode.nodeType);
+assert.eq(5, indexScanNode.interval[0].highBound.bound.value);
+}());
diff --git a/jstests/cqf/group.js b/jstests/cqf/group.js
new file mode 100644
index 00000000000..4af1ad6b021
--- /dev/null
+++ b/jstests/cqf/group.js
@@ -0,0 +1,27 @@
+(function() {
+"use strict";
+
+load("jstests/libs/optimizer_utils.js"); // For checkCascadesOptimizerEnabled.
+if (!checkCascadesOptimizerEnabled(db)) {
+    jsTestLog("Skipping test because the optimizer is not enabled");
+    return;
+}
+
+const coll = db.cqf_group;
+coll.drop();
+
+assert.commandWorked(coll.insert([
+    {a: 1, b: 1, c: 1},
+    {a: 1, b: 2, c: 2},
+    {a: 1, b: 2, c: 3},
+    {a: 2, b: 1, c: 4},
+    {a: 2, b: 1, c: 5},
+    {a: 2, b: 2, c: 6},
+]));
+
+const res = coll.explain("executionStats").aggregate([
+    {$group: {_id: {a: '$a', b: '$b'}, sum: {$sum: '$c'}, avg: {$avg: '$c'}}}
+]);
+assert.eq("GroupBy", res.queryPlanner.winningPlan.optimizerPlan.child.child.nodeType);
+assert.eq(4, res.executionStats.nReturned);
+}());
diff --git a/jstests/cqf/index_intersect.js b/jstests/cqf/index_intersect.js
new file mode 100644
index 00000000000..66ad1935996
--- /dev/null
+++ b/jstests/cqf/index_intersect.js
@@ -0,0 +1,47 @@
+(function() {
+"use strict";
+
+load("jstests/libs/optimizer_utils.js"); // For checkCascadesOptimizerEnabled.
+if (!checkCascadesOptimizerEnabled(db)) {
+    jsTestLog("Skipping test because the optimizer is not enabled");
+    return;
+}
+
+const t = db.cqf_index_intersect;
+t.drop();
+
+const nMatches = 60;
+
+assert.commandWorked(t.insert({a: 1, b: 1, c: 1}));
+assert.commandWorked(t.insert({a: 3, b: 2, c: 1}));
+for (let i = 0; i < nMatches; i++) {
+    assert.commandWorked(t.insert({a: 3, b: 3, c: i}));
+}
+assert.commandWorked(t.insert({a: 4, b: 3, c: 2}));
+assert.commandWorked(t.insert({a: 5, b: 5, c: 2}));
+
+for (let i = 1; i < nMatches + 100; i++) {
+    assert.commandWorked(t.insert({a: i + nMatches, b: i + nMatches, c: i + nMatches}));
+}
+
+assert.commandWorked(t.createIndex({'a': 1}));
+assert.commandWorked(t.createIndex({'b': 1}));
+
+let res = t.explain("executionStats").aggregate([{$match: {'a': 3, 'b': 3}}]);
+assert.eq(nMatches, res.executionStats.nReturned);
+
+// Verify we can place a MergeJoin.
+let joinNode = res.queryPlanner.winningPlan.optimizerPlan.child.leftChild;
+assert.eq("MergeJoin", joinNode.nodeType);
+assert.eq("IndexScan", joinNode.leftChild.nodeType);
+assert.eq("IndexScan", joinNode.rightChild.children[0].child.nodeType);
+
+// One side is not equality, and we use a HashJoin.
+res = t.explain("executionStats").aggregate([{$match: {'a': {$lte: 3}, 'b': 3}}]);
+assert.eq(nMatches, res.executionStats.nReturned);
+
+joinNode = res.queryPlanner.winningPlan.optimizerPlan.child.leftChild;
+assert.eq("HashJoin", joinNode.nodeType);
+assert.eq("IndexScan", joinNode.leftChild.nodeType);
+assert.eq("IndexScan", joinNode.rightChild.children[0].child.nodeType);
+}());
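The distinction the two plan checks draw (a reading inferred from the comments and plan shapes above): with equality predicates on both "a" and "b", each index scan can produce its row IDs in sorted order, so the optimizer can intersect the two streams with a MergeJoin; once one side is a range ($lte), that ordering no longer comes for free and a HashJoin is chosen instead.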
diff --git a/jstests/cqf/index_intersect1.js b/jstests/cqf/index_intersect1.js
new file mode 100644
index 00000000000..fcf0036c974
--- /dev/null
+++ b/jstests/cqf/index_intersect1.js
@@ -0,0 +1,35 @@
+(function() {
+"use strict";
+
+load("jstests/libs/optimizer_utils.js"); // For checkCascadesOptimizerEnabled.
+if (!checkCascadesOptimizerEnabled(db)) {
+ jsTestLog("Skipping test because the optimizer is not enabled");
+ return;
+}
+
+const t = db.cqf_index_intersect1;
+t.drop();
+
+assert.commandWorked(t.insert({a: 50}));
+assert.commandWorked(t.insert({a: 70}));
+assert.commandWorked(t.insert({a: 90}));
+assert.commandWorked(t.insert({a: 110}));
+assert.commandWorked(t.insert({a: 130}));
+
+// Generate enough documents for index to be preferable.
+for (let i = 0; i < 100; i++) {
+    assert.commandWorked(t.insert({a: 200 + i}));
+}
+
+assert.commandWorked(t.createIndex({'a': 1}));
+
+let res = t.explain("executionStats").aggregate([{$match: {'a': {$gt: 60, $lt: 100}}}]);
+assert.eq(2, res.executionStats.nReturned);
+assert.eq("IndexScan", res.queryPlanner.winningPlan.optimizerPlan.child.leftChild.nodeType);
+
+// Should get a covered plan.
+res = t.explain("executionStats")
+          .aggregate([{$project: {'_id': 0, 'a': 1}}, {$match: {'a': {$gt: 60, $lt: 100}}}]);
+assert.eq(2, res.executionStats.nReturned);
+assert.eq("IndexScan", res.queryPlanner.winningPlan.optimizerPlan.child.child.nodeType);
+}());
\ No newline at end of file
diff --git a/jstests/cqf/no_collection.js b/jstests/cqf/no_collection.js
new file mode 100644
index 00000000000..3c7ecae4c32
--- /dev/null
+++ b/jstests/cqf/no_collection.js
@@ -0,0 +1,15 @@
+(function() {
+"use strict";
+
+load("jstests/libs/optimizer_utils.js"); // For checkCascadesOptimizerEnabled.
+if (!checkCascadesOptimizerEnabled(db)) {
+    jsTestLog("Skipping test because the optimizer is not enabled");
+    return;
+}
+
+let t = db.cqf_no_collection;
+t.drop();
+
+const res = t.explain("executionStats").aggregate([{$match: {'a': 2}}]);
+assert.eq(0, res.executionStats.nReturned);
+}());
\ No newline at end of file
diff --git a/jstests/cqf/nonselective_index.js b/jstests/cqf/nonselective_index.js
new file mode 100644
index 00000000000..f951ae7dc40
--- /dev/null
+++ b/jstests/cqf/nonselective_index.js
@@ -0,0 +1,30 @@
+/**
+ * Tests scenario related to SERVER-13065.
+ */
+(function() {
+"use strict";
+
+load("jstests/libs/optimizer_utils.js"); // For checkCascadesOptimizerEnabled.
+if (!checkCascadesOptimizerEnabled(db)) {
+    jsTestLog("Skipping test because the optimizer is not enabled");
+    return;
+}
+
+const t = db.cqf_nonselective_index;
+t.drop();
+
+const bulk = t.initializeUnorderedBulkOp();
+const nDocs = 1000;
+for (let i = 0; i < nDocs; i++) {
+    bulk.insert({a: i});
+}
+assert.commandWorked(bulk.execute());
+
+assert.commandWorked(t.createIndex({a: 1}));
+
+// We pick collection scan since the query is not selective.
+const res = t.explain("executionStats").aggregate([{$match: {a: {$gte: 0}}}]);
+assert.eq(nDocs, res.executionStats.nReturned);
+
+assert.eq("PhysicalScan", res.queryPlanner.winningPlan.optimizerPlan.child.child.nodeType);
+}());
\ No newline at end of file
diff --git a/jstests/cqf/object_elemMatch.js b/jstests/cqf/object_elemMatch.js
new file mode 100644
index 00000000000..f402c590658
--- /dev/null
+++ b/jstests/cqf/object_elemMatch.js
@@ -0,0 +1,33 @@
+(function() {
+"use strict";
+
+load("jstests/libs/optimizer_utils.js"); // For checkCascadesOptimizerEnabled.
+if (!checkCascadesOptimizerEnabled(db)) {
+    jsTestLog("Skipping test because the optimizer is not enabled");
+    return;
+}
+
+const t = db.cqf_object_elemMatch;
+
+t.drop();
+assert.commandWorked(t.insert({a: [{a: 1, b: 1}, {a: 1, b: 2}]}));
+assert.commandWorked(t.insert({a: [{a: 2, b: 1}, {a: 2, b: 2}]}));
+assert.commandWorked(t.insert({a: {a: 2, b: 1}}));
+assert.commandWorked(t.insert({a: [{b: [1, 2], c: [3, 4]}]}));
+
+{
+    // Object elemMatch. Currently we do not support index here.
+    const res = t.explain("executionStats").aggregate([{$match: {a: {$elemMatch: {a: 2, b: 1}}}}]);
+    assert.eq(1, res.executionStats.nReturned);
+    assert.eq("PhysicalScan",
+              res.queryPlanner.winningPlan.optimizerPlan.child.child.child.nodeType);
+}
+
+{
+    // Should not be getting any results.
+    const res = t.explain("executionStats").aggregate([
+        {$match: {a: {$elemMatch: {b: {$elemMatch: {}}, c: {$elemMatch: {}}}}}}
+    ]);
+    assert.eq(0, res.executionStats.nReturned);
+}
+}());
diff --git a/jstests/cqf/partial_index.js b/jstests/cqf/partial_index.js
new file mode 100644
index 00000000000..d8196c8cea8
--- /dev/null
+++ b/jstests/cqf/partial_index.js
@@ -0,0 +1,34 @@
+(function() {
+"use strict";
+
+load("jstests/libs/optimizer_utils.js"); // For checkCascadesOptimizerEnabled.
+if (!checkCascadesOptimizerEnabled(db)) {
+    jsTestLog("Skipping test because the optimizer is not enabled");
+    return;
+}
+
+const t = db.cqf_partial_index;
+t.drop();
+
+assert.commandWorked(t.insert({a: 1, b: 1, c: 1}));
+assert.commandWorked(t.insert({a: 3, b: 2, c: 1}));
+assert.commandWorked(t.insert({a: 3, b: 3, c: 1}));
+assert.commandWorked(t.insert({a: 3, b: 3, c: 2}));
+assert.commandWorked(t.insert({a: 4, b: 3, c: 2}));
+assert.commandWorked(t.insert({a: 5, b: 5, c: 2}));
+
+for (let i = 0; i < 40; i++) {
+    assert.commandWorked(t.insert({a: i + 10, b: i + 10, c: i + 10}));
+}
+
+assert.commandWorked(t.createIndex({'a': 1}, {partialFilterExpression: {'b': 2}}));
+// assert.commandWorked(t.createIndex({'a': 1}));
+
+// TODO: verify with explain the plan should use the index.
+let res = t.aggregate([{$match: {'a': 3, 'b': 2}}]).toArray();
+assert.eq(1, res.length);
+
+// TODO: verify with explain the plan should not use the index.
+res = t.aggregate([{$match: {'a': 3, 'b': 3}}]).toArray();
+assert.eq(2, res.length);
+}());
\ No newline at end of file
diff --git a/jstests/cqf/residual_pred_costing.js b/jstests/cqf/residual_pred_costing.js
new file mode 100644
index 00000000000..07bc7211836
--- /dev/null
+++ b/jstests/cqf/residual_pred_costing.js
@@ -0,0 +1,35 @@
+/**
+ * Tests scenario related to SERVER-21697.
+ */
+(function() {
+"use strict";
+
+load("jstests/libs/optimizer_utils.js"); // For checkCascadesOptimizerEnabled.
+if (!checkCascadesOptimizerEnabled(db)) {
+    jsTestLog("Skipping test because the optimizer is not enabled");
+    return;
+}
+
+const t = db.cqf_residual_pred_costing;
+t.drop();
+
+const bulk = t.initializeUnorderedBulkOp();
+const nDocs = 2000;
+for (let i = 0; i < nDocs; i++) {
+    bulk.insert({a: i % 10, b: i % 10, c: i % 10, d: i % 10});
+}
+assert.commandWorked(bulk.execute());
+
+assert.commandWorked(t.createIndex({a: 1, b: 1, c: 1, d: 1}));
+assert.commandWorked(t.createIndex({a: 1, b: 1, d: 1}));
+assert.commandWorked(t.createIndex({a: 1, d: 1}));
+
+let res = t.explain("executionStats")
+              .aggregate([{$match: {a: {$eq: 0}, b: {$eq: 0}, c: {$eq: 0}}}, {$sort: {d: 1}}]);
+assert.eq(nDocs * 0.1, res.executionStats.nReturned);
+
+// Demonstrate we can pick the index covering the most fields.
+const indexNode = res.queryPlanner.winningPlan.optimizerPlan.child.leftChild;
+assert.eq("IndexScan", indexNode.nodeType);
+assert.eq("a_1_b_1_c_1_d_1", indexNode.indexDefName);
+}());
diff --git a/jstests/cqf/sampling.js b/jstests/cqf/sampling.js
new file mode 100644
index 00000000000..37dd0ae0e44
--- /dev/null
+++ b/jstests/cqf/sampling.js
@@ -0,0 +1,31 @@
+(function() {
+"use strict";
+
+load("jstests/libs/optimizer_utils.js"); // For checkCascadesOptimizerEnabled.
+if (!checkCascadesOptimizerEnabled(db)) {
+    jsTestLog("Skipping test because the optimizer is not enabled");
+    return;
+}
+
+const coll = db.cqf_sampling;
+coll.drop();
+
+const bulk = coll.initializeUnorderedBulkOp();
+const nDocs = 10000;
+
+Random.srand(0);
+for (let i = 0; i < nDocs; i++) {
+    const valA = 10.0 * Random.rand();
+    const valB = 10.0 * Random.rand();
+    bulk.insert({a: valA, b: valB});
+}
+assert.commandWorked(bulk.execute());
+
+const res = coll.explain().aggregate([{$match: {'a': {$lt: 2}}}]);
+assert(res.queryPlanner.winningPlan.optimizerPlan.hasOwnProperty("properties"));
+const props = res.queryPlanner.winningPlan.optimizerPlan.properties;
+
+// Verify the winning plan's cardinality estimate is within roughly 25% of the expected count.
+assert.lt(nDocs * 0.2 * 0.75, props.adjustedCE);
+assert.gt(nDocs * 0.2 * 1.25, props.adjustedCE);
+}());
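For reference, the arithmetic behind those bounds (all values taken from the test above): with "a" uniform on [0, 10), the predicate a < 2 selects about 20% of documents, so

    // expected CE = nDocs * 0.2  = 10000 * 0.2 = 2000
    // lower bound = 2000 * 0.75  = 1500   (assert.lt)
    // upper bound = 2000 * 1.25  = 2500   (assert.gt)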
diff --git a/jstests/cqf/selective_index.js b/jstests/cqf/selective_index.js
new file mode 100644
index 00000000000..722f04e75c7
--- /dev/null
+++ b/jstests/cqf/selective_index.js
@@ -0,0 +1,34 @@
+/**
+ * Tests scenario related to SERVER-20616.
+ */
+(function() {
+"use strict";
+
+load("jstests/libs/optimizer_utils.js"); // For checkCascadesOptimizerEnabled.
+if (!checkCascadesOptimizerEnabled(db)) {
+    jsTestLog("Skipping test because the optimizer is not enabled");
+    return;
+}
+
+const t = db.cqf_selective_index;
+t.drop();
+
+const bulk = t.initializeUnorderedBulkOp();
+const nDocs = 1000;
+for (let i = 0; i < nDocs; i++) {
+    bulk.insert({a: i % 10, b: i});
+}
+assert.commandWorked(bulk.execute());
+
+assert.commandWorked(t.createIndex({a: 1}));
+assert.commandWorked(t.createIndex({b: 1}));
+
+// Predicate on "b" is more selective than the one on "a": 0.1% vs 10%.
+const res = t.explain("executionStats").aggregate([{$match: {a: {$eq: 0}, b: {$eq: 0}}}]);
+assert.eq(1, res.executionStats.nReturned);
+
+// Demonstrate we can pick index on "b".
+const indexNode = res.queryPlanner.winningPlan.optimizerPlan.child.leftChild;
+assert.eq("IndexScan", indexNode.nodeType);
+assert.eq("b_1", indexNode.indexDefName);
+}());
\ No newline at end of file
diff --git a/jstests/cqf/sort.js b/jstests/cqf/sort.js
new file mode 100644
index 00000000000..1a9f4582262
--- /dev/null
+++ b/jstests/cqf/sort.js
@@ -0,0 +1,22 @@
+(function() {
+"use strict";
+
+load("jstests/libs/optimizer_utils.js"); // For checkCascadesOptimizerEnabled.
+if (!checkCascadesOptimizerEnabled(db)) {
+    jsTestLog("Skipping test because the optimizer is not enabled");
+    return;
+}
+
+const t = db.cqf_basic_unwind1;
+t.drop();
+
+assert.commandWorked(t.insert({_id: 1}));
+assert.commandWorked(t.insert({_id: 2, x: null}));
+assert.commandWorked(t.insert({_id: 3, x: []}));
+assert.commandWorked(t.insert({_id: 4, x: [1, 2]}));
+assert.commandWorked(t.insert({_id: 5, x: [10]}));
+assert.commandWorked(t.insert({_id: 6, x: 4}));
+
+const res = t.aggregate([{$unwind: '$x'}, {$sort: {'x': 1}}]).toArray();
+assert.eq(4, res.length);
+}());
\ No newline at end of file
diff --git a/jstests/cqf/sort_match.js b/jstests/cqf/sort_match.js
new file mode 100644
index 00000000000..54a22a64071
--- /dev/null
+++ b/jstests/cqf/sort_match.js
@@ -0,0 +1,33 @@
+/**
+ * Tests scenario related to SERVER-12923.
+ */
+(function() {
+"use strict";
+
+load("jstests/libs/optimizer_utils.js"); // For checkCascadesOptimizerEnabled.
+if (!checkCascadesOptimizerEnabled(db)) {
+    jsTestLog("Skipping test because the optimizer is not enabled");
+    return;
+}
+
+const t = db.cqf_sort_match;
+t.drop();
+
+const bulk = t.initializeUnorderedBulkOp();
+const nDocs = 1000;
+for (let i = 0; i < nDocs; i++) {
+    bulk.insert({a: i, b: i % 10});
+}
+assert.commandWorked(bulk.execute());
+
+assert.commandWorked(t.createIndex({a: 1}));
+assert.commandWorked(t.createIndex({b: 1}));
+
+let res = t.explain("executionStats").aggregate([{$sort: {b: 1}}, {$match: {a: {$eq: 0}}}]);
+assert.eq(1, res.executionStats.nReturned);
+
+// Index on "a" is preferred.
+const indexNode = res.queryPlanner.winningPlan.optimizerPlan.child.child.leftChild;
+assert.eq("IndexScan", indexNode.nodeType);
+assert.eq("a_1", indexNode.indexDefName);
+}());
\ No newline at end of file
diff --git a/jstests/cqf/sort_project.js b/jstests/cqf/sort_project.js
new file mode 100644
index 00000000000..49beb912191
--- /dev/null
+++ b/jstests/cqf/sort_project.js
@@ -0,0 +1,74 @@
+(function() {
+"use strict";
+
+load("jstests/libs/optimizer_utils.js"); // For checkCascadesOptimizerEnabled.
+if (!checkCascadesOptimizerEnabled(db)) {
+    jsTestLog("Skipping test because the optimizer is not enabled");
+    return;
+}
+
+var coll = db.cqf_testCovIndxScan;
+
+coll.drop();
+
+coll.createIndex({f_0: 1, f_1: 1, f_2: 1, f_3: 1, f_4: 1});
+coll.getIndexes();
+
+coll.insertMany([
+    {f_0: 2, f_1: 8, f_2: 2, f_3: 0, f_4: 2}, {f_0: 7, f_1: 9, f_2: 8, f_3: 3, f_4: 3},
+    {f_0: 6, f_1: 6, f_2: 2, f_3: 8, f_4: 3}, {f_0: 9, f_1: 2, f_2: 3, f_3: 5, f_4: 7},
+    {f_0: 7, f_1: 8, f_2: 8, f_3: 2, f_4: 9}, {f_0: 7, f_1: 1, f_2: 7, f_3: 3, f_4: 1},
+    {f_0: 7, f_1: 3, f_2: 4, f_3: 0, f_4: 7}, {f_0: 8, f_1: 4, f_2: 5, f_3: 6, f_4: 0},
+    {f_0: 5, f_1: 2, f_2: 0, f_3: 7, f_4: 0}, {f_0: 0, f_1: 2, f_2: 1, f_3: 9, f_4: 2},
+    {f_0: 6, f_1: 0, f_2: 5, f_3: 9, f_4: 1}, {f_0: 0, f_1: 1, f_2: 6, f_3: 8, f_4: 6},
+    {f_0: 6, f_1: 5, f_2: 3, f_3: 8, f_4: 5}, {f_0: 2, f_1: 9, f_2: 7, f_3: 2, f_4: 3},
+    {f_0: 0, f_1: 6, f_2: 9, f_3: 6, f_4: 8}, {f_0: 5, f_1: 7, f_2: 8, f_3: 1, f_4: 4},
+    {f_0: 8, f_1: 5, f_2: 1, f_3: 4, f_4: 6}, {f_0: 6, f_1: 2, f_2: 8, f_3: 4, f_4: 3},
+    {f_0: 1, f_1: 6, f_2: 2, f_3: 0, f_4: 3}, {f_0: 1, f_1: 8, f_2: 2, f_3: 5, f_4: 2}
+]);
+
+const nDocs = 20;
+
+var pln0 = [{'$project': {_id: 0, f_0: 1, f_1: 1, f_2: 1, f_3: 1, f_4: 1}}];
+
+var pln1 = [{'$sort': {f_0: 1, f_1: 1, f_2: 1, f_3: 1, f_4: 1}}];
+
+var pln2 = [
+    {'$project': {_id: 0, f_0: 1, f_1: 1, f_2: 1, f_3: 1, f_4: 1}},
+    {'$sort': {f_0: 1, f_1: 1, f_2: 1, f_3: 1, f_4: 1}}
+];
+
+var pln3 = [
+    {'$sort': {f_0: 1, f_1: 1, f_2: 1, f_3: 1, f_4: 1}},
+    {'$project': {_id: 0, f_0: 1, f_1: 1, f_2: 1, f_3: 1, f_4: 1}}
+];
+
+{
+    // Covered plan. Still chooses collection scan because there are no field size/count
+    // statistics. Also an index scan on all fields is not cheaper than a collection scan.
+    let res = coll.explain("executionStats").aggregate(pln0);
+    assert.eq(nDocs, res.executionStats.nReturned);
+    assert.eq("IndexScan", res.queryPlanner.winningPlan.optimizerPlan.child.child.nodeType);
+}
+
+{
+    // Covered plan.
+    let res = coll.explain("executionStats").aggregate(pln1);
+    assert.eq(nDocs, res.executionStats.nReturned);
+    assert.eq("IndexScan", res.queryPlanner.winningPlan.optimizerPlan.child.leftChild.nodeType);
+}
+
+{
+    // Covered plan.
+    let res = coll.explain("executionStats").aggregate(pln2);
+    assert.eq(nDocs, res.executionStats.nReturned);
+    assert.eq("IndexScan", res.queryPlanner.winningPlan.optimizerPlan.child.child.nodeType);
+}
+
+{
+    // Covered plan.
+    let res = coll.explain("executionStats").aggregate(pln3);
+    assert.eq(nDocs, res.executionStats.nReturned);
+    assert.eq("IndexScan", res.queryPlanner.winningPlan.optimizerPlan.child.child.nodeType);
+}
+}());
diff --git a/jstests/cqf/type_bracket.js b/jstests/cqf/type_bracket.js
new file mode 100644
index 00000000000..1cacba0df2e
--- /dev/null
+++ b/jstests/cqf/type_bracket.js
@@ -0,0 +1,62 @@
+(function() {
+"use strict";
+
+load("jstests/libs/optimizer_utils.js"); // For checkCascadesOptimizerEnabled.
+if (!checkCascadesOptimizerEnabled(db)) {
+    jsTestLog("Skipping test because the optimizer is not enabled");
+    return;
+}
+
+const t = db.cqf_type_bracket;
+t.drop();
+
+// Generate enough documents for index to be preferable if it exists.
+for (let i = 0; i < 100; i++) {
+    assert.commandWorked(t.insert({a: i}));
+    assert.commandWorked(t.insert({a: i.toString()}));
+}
+
+{
+    const res = t.explain("executionStats").aggregate([{$match: {a: {$lt: "2"}}}]);
+    assert.eq(12, res.executionStats.nReturned);
+    assert.eq("PhysicalScan", res.queryPlanner.winningPlan.optimizerPlan.child.child.nodeType);
+}
+{
+    const res = t.explain("executionStats").aggregate([{$match: {a: {$gt: "95"}}}]);
+    assert.eq(4, res.executionStats.nReturned);
+    assert.eq("PhysicalScan", res.queryPlanner.winningPlan.optimizerPlan.child.child.nodeType);
+}
+{
+    const res = t.explain("executionStats").aggregate([{$match: {a: {$lt: 2}}}]);
+    assert.eq(2, res.executionStats.nReturned);
+    assert.eq("PhysicalScan", res.queryPlanner.winningPlan.optimizerPlan.child.child.nodeType);
+}
+{
+    const res = t.explain("executionStats").aggregate([{$match: {a: {$gt: 95}}}]);
+    assert.eq(4, res.executionStats.nReturned);
+    assert.eq("PhysicalScan", res.queryPlanner.winningPlan.optimizerPlan.child.child.nodeType);
+}
+
+assert.commandWorked(t.createIndex({a: 1}));
+
+{
+    const res = t.explain("executionStats").aggregate([{$match: {a: {$lt: "2"}}}]);
+    assert.eq(12, res.executionStats.nReturned);
+    assert.eq("IndexScan", res.queryPlanner.winningPlan.optimizerPlan.child.leftChild.nodeType);
+}
+{
+    const res = t.explain("executionStats").aggregate([{$match: {a: {$gt: "95"}}}]);
+    assert.eq(4, res.executionStats.nReturned);
+    assert.eq("IndexScan", res.queryPlanner.winningPlan.optimizerPlan.child.leftChild.nodeType);
+}
+{
+    const res = t.explain("executionStats").aggregate([{$match: {a: {$lt: 2}}}]);
+    assert.eq(2, res.executionStats.nReturned);
+    assert.eq("IndexScan", res.queryPlanner.winningPlan.optimizerPlan.child.leftChild.nodeType);
+}
+{
+    const res = t.explain("executionStats").aggregate([{$match: {a: {$gt: 95}}}]);
+    assert.eq(4, res.executionStats.nReturned);
+    assert.eq("IndexScan", res.queryPlanner.winningPlan.optimizerPlan.child.leftChild.nodeType);
+}
+}());
\ No newline at end of file
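The expected counts follow from type bracketing: a comparison only matches values of the same type, so {$lt: "2"} selects just the strings "0", "1", and "10" through "19" (12 documents) while ignoring every number, and {$lt: 2} selects only the numbers 0 and 1. The second half of the test checks that, once the index exists, the same type-bracketed bounds are expressed as index intervals (IndexScan) rather than residual predicates.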
diff --git a/jstests/cqf/type_predicate.js b/jstests/cqf/type_predicate.js
new file mode 100644
index 00000000000..eb8de44b3f6
--- /dev/null
+++ b/jstests/cqf/type_predicate.js
@@ -0,0 +1,26 @@
+(function() {
+"use strict";
+
+load("jstests/libs/optimizer_utils.js"); // For checkCascadesOptimizerEnabled.
+if (!checkCascadesOptimizerEnabled(db)) {
+    jsTestLog("Skipping test because the optimizer is not enabled");
+    return;
+}
+
+const t = db.cqf_type_predicate;
+t.drop();
+
+for (let i = 0; i < 10; i++) {
+    assert.commandWorked(t.insert({a: i}));
+    assert.commandWorked(t.insert({a: i.toString()}));
+}
+
+{
+    const res = t.explain("executionStats").aggregate([{$match: {a: {$type: "string"}}}]);
+    assert.eq(10, res.executionStats.nReturned);
+}
+{
+    const res = t.explain("executionStats").aggregate([{$match: {a: {$type: "double"}}}]);
+    assert.eq(10, res.executionStats.nReturned);
+}
+}());
\ No newline at end of file
diff --git a/jstests/cqf/unionWith.js b/jstests/cqf/unionWith.js
new file mode 100644
index 00000000000..63dedc9d750
--- /dev/null
+++ b/jstests/cqf/unionWith.js
@@ -0,0 +1,54 @@
+(function() {
+"use strict";
+
+load("jstests/libs/optimizer_utils.js"); // For checkCascadesOptimizerEnabled.
+if (!checkCascadesOptimizerEnabled(db)) {
+    jsTestLog("Skipping test because the optimizer is not enabled");
+    return;
+}
+
+load("jstests/aggregation/extras/utils.js");
+
+const collA = db.collA;
+collA.drop();
+
+const collB = db.collB;
+collB.drop();
+
+assert.commandWorked(collA.insert({_id: 0, a: 1}));
+assert.commandWorked(collB.insert({_id: 0, a: 2}));
+
+let res = collA.aggregate([{$unionWith: "collB"}]).toArray();
+assert.eq(2, res.length);
+assert.eq([{_id: 0, a: 1}, {_id: 0, a: 2}], res);
+
+// Test a filter after the union which can be pushed down to each branch.
+res = collA.aggregate([{$unionWith: "collB"}, {$match: {a: {$lt: 2}}}]).toArray();
+assert.eq(1, res.length);
+assert.eq([{_id: 0, a: 1}], res);
+
+// Test a non-simple inner pipeline.
+res = collA.aggregate([{$unionWith: {coll: "collB", pipeline: [{$match: {a: 2}}]}}]).toArray();
+assert.eq(2, res.length);
+assert.eq([{_id: 0, a: 1}, {_id: 0, a: 2}], res);
+
+// Test a union with non-existent collection.
+res = collA.aggregate([{$unionWith: "non_existent"}]).toArray();
+assert.eq(1, res.length);
+assert.eq([{_id: 0, a: 1}], res);
+
+// Test union alongside projections. This is meant to test the pipeline translation logic that adds
+// a projection to the inner pipeline when necessary.
+res = collA.aggregate([{$project: {_id: 0, a: 1}}, {$unionWith: "collB"}]).toArray();
+assert.eq(2, res.length);
+assert.eq([{a: 1}, {_id: 0, a: 2}], res);
+
+res = collA.aggregate([{$unionWith: {coll: "collB", pipeline: [{$project: {_id: 0, a: 1}}]}}])
+          .toArray();
+assert.eq(2, res.length);
+assert.eq([{_id: 0, a: 1}, {a: 2}], res);
+
+res = collA.aggregate([{$unionWith: "collB"}, {$project: {_id: 0, a: 1}}]).toArray();
+assert.eq(2, res.length);
+assert.eq([{a: 1}, {a: 2}], res);
+}());
diff --git a/jstests/cqf/value_elemMatch.js b/jstests/cqf/value_elemMatch.js
new file mode 100644
index 00000000000..4bb46e6f1a7
--- /dev/null
+++ b/jstests/cqf/value_elemMatch.js
@@ -0,0 +1,52 @@
+(function() {
+"use strict";
+
+load("jstests/libs/optimizer_utils.js"); // For checkCascadesOptimizerEnabled.
+if (!checkCascadesOptimizerEnabled(db)) {
+    jsTestLog("Skipping test because the optimizer is not enabled");
+    return;
+}
+
+const t = db.cqf_value_elemMatch;
+t.drop();
+
+assert.commandWorked(t.insert({a: [1, 2, 3, 4, 5, 6]}));
+assert.commandWorked(t.insert({a: [5, 6, 7, 8, 9]}));
+assert.commandWorked(t.insert({a: [1, 2, 3]}));
+assert.commandWorked(t.insert({a: []}));
+assert.commandWorked(t.insert({a: [1]}));
+assert.commandWorked(t.insert({a: [10]}));
+assert.commandWorked(t.insert({a: 5}));
+assert.commandWorked(t.insert({a: 6}));
+
+// Generate enough documents for index to be preferable.
+const nDocs = 400;
+for (let i = 0; i < nDocs; i++) {
+    assert.commandWorked(t.insert({a: i + 10}));
+}
+
+assert.commandWorked(t.createIndex({a: 1}));
+
+{
+    // Value elemMatch. Demonstrate we can use an index.
+    const res =
+        t.explain("executionStats").aggregate([{$match: {a: {$elemMatch: {$gte: 5, $lte: 6}}}}]);
+    assert.eq(2, res.executionStats.nReturned);
+    assert.eq("IndexScan",
+              res.queryPlanner.winningPlan.optimizerPlan.child.child.leftChild.child.nodeType);
+}
+{
+    const res =
+        t.explain("executionStats").aggregate([{$match: {a: {$elemMatch: {$lt: 11, $gt: 9}}}}]);
+    assert.eq(1, res.executionStats.nReturned);
+    assert.eq("IndexScan",
+              res.queryPlanner.winningPlan.optimizerPlan.child.child.leftChild.child.nodeType);
+}
+{
+    // Contradiction.
+    const res =
+        t.explain("executionStats").aggregate([{$match: {a: {$elemMatch: {$lt: 5, $gt: 6}}}}]);
+    assert.eq(0, res.executionStats.nReturned);
+    assert.eq("CoScan", res.queryPlanner.winningPlan.optimizerPlan.child.child.child.nodeType);
+}
+}());
diff --git a/jstests/cqf_parallel/basic_exchange.js b/jstests/cqf_parallel/basic_exchange.js
new file mode 100644
index 00000000000..3be8768b0de
--- /dev/null
+++ b/jstests/cqf_parallel/basic_exchange.js
@@ -0,0 +1,22 @@
+(function() {
+"use strict";
+
+load("jstests/libs/optimizer_utils.js"); // For checkCascadesOptimizerEnabled.
+if (!checkCascadesOptimizerEnabled(db)) {
+    jsTestLog("Skipping test because the optimizer is not enabled");
+    return;
+}
+
+const t = db.cqf_exchange;
+t.drop();
+
+assert.commandWorked(t.insert({a: {b: 1}}));
+assert.commandWorked(t.insert({a: {b: 2}}));
+assert.commandWorked(t.insert({a: {b: 3}}));
+assert.commandWorked(t.insert({a: {b: 4}}));
+assert.commandWorked(t.insert({a: {b: 5}}));
+
+const res = t.explain("executionStats").aggregate([{$match: {'a.b': 2}}]);
+assert.eq(1, res.executionStats.nReturned);
+assert.eq("Exchange", res.queryPlanner.winningPlan.optimizerPlan.child.nodeType);
+}());
diff --git a/jstests/cqf_parallel/groupby.js b/jstests/cqf_parallel/groupby.js
new file mode 100644
index 00000000000..91fa3fc80fa
--- /dev/null
+++ b/jstests/cqf_parallel/groupby.js
@@ -0,0 +1,37 @@
+(function() {
+"use strict";
+
+load("jstests/libs/optimizer_utils.js"); // For checkCascadesOptimizerEnabled.
+if (!checkCascadesOptimizerEnabled(db)) {
+    jsTestLog("Skipping test because the optimizer is not enabled");
+    return;
+}
+
+const t = db.cqf_exchange;
+t.drop();
+
+assert.commandWorked(t.insert({a: 1}));
+assert.commandWorked(t.insert({a: 2}));
+assert.commandWorked(t.insert({a: 3}));
+assert.commandWorked(t.insert({a: 4}));
+assert.commandWorked(t.insert({a: 5}));
+
+// Demonstrate local-global optimization.
+const res = t.explain("executionStats").aggregate([{$group: {_id: "$a", cnt: {$sum: 1}}}]);
+assert.eq(5, res.executionStats.nReturned);
+
+assert.eq("Exchange", res.queryPlanner.winningPlan.optimizerPlan.child.nodeType);
+assert.eq("Centralized", res.queryPlanner.winningPlan.optimizerPlan.child.distribution.type);
+
+assert.eq("GroupBy", res.queryPlanner.winningPlan.optimizerPlan.child.child.child.nodeType);
+
+assert.eq("Exchange", res.queryPlanner.winningPlan.optimizerPlan.child.child.child.child.nodeType);
+assert.eq("HashPartitioning",
+          res.queryPlanner.winningPlan.optimizerPlan.child.child.child.child.distribution.type);
+
+assert.eq("GroupBy",
+ res.queryPlanner.winningPlan.optimizerPlan.child.child.child.child.child.nodeType);
+assert.eq("UnknownPartitioning",
+ res.queryPlanner.winningPlan.optimizerPlan.child.child.child.child.child.properties
+ .physicalProperties.distribution.type);
+}()); \ No newline at end of file
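Read bottom-up, the assertions sketch the classic local/global aggregation shape; roughly (eliding intermediate nodes, and reconstructed only from the assertions above, not from separate explain output):

    Exchange [Centralized]
      GroupBy                  (global)
        Exchange [HashPartitioning]
          GroupBy              (local, running under UnknownPartitioning)
            <scan>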
diff --git a/jstests/cqf_parallel/index.js b/jstests/cqf_parallel/index.js
new file mode 100644
index 00000000000..c155e9b9f61
--- /dev/null
+++ b/jstests/cqf_parallel/index.js
@@ -0,0 +1,25 @@
+(function() {
+"use strict";
+
+load("jstests/libs/optimizer_utils.js"); // For checkCascadesOptimizerEnabled.
+if (!checkCascadesOptimizerEnabled(db)) {
+    jsTestLog("Skipping test because the optimizer is not enabled");
+    return;
+}
+
+const coll = db.cqf_parallel_index;
+coll.drop();
+
+const bulk = coll.initializeUnorderedBulkOp();
+for (let i = 0; i < 1000; i++) {
+    bulk.insert({a: i});
+}
+assert.commandWorked(bulk.execute());
+
+assert.commandWorked(coll.createIndex({a: 1}));
+
+let res = coll.explain("executionStats").aggregate([{$match: {a: {$lt: 10}}}]);
+assert.eq(10, res.executionStats.nReturned);
+assert.eq("IndexScan",
+          res.queryPlanner.winningPlan.optimizerPlan.child.child.leftChild.child.nodeType);
+}());
diff --git a/jstests/libs/optimizer_utils.js b/jstests/libs/optimizer_utils.js
new file mode 100644
index 00000000000..ff2a179388a
--- /dev/null
+++ b/jstests/libs/optimizer_utils.js
@@ -0,0 +1,8 @@
+/*
+ * Utility for checking if the query optimizer is enabled.
+ */
+function checkCascadesOptimizerEnabled(theDB) {
+    const param = theDB.adminCommand({getParameter: 1, featureFlagCommonQueryFramework: 1});
+    return param.hasOwnProperty("featureFlagCommonQueryFramework") &&
+        param.featureFlagCommonQueryFramework.value;
+}
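The suites above also set internalQueryEnableCascadesOptimizer. A test wanting to check that knob directly could reuse the same getParameter pattern; the sketch below is hypothetical and not part of this commit, and it assumes a plain server parameter comes back as a bare value rather than the {value: ...} document used for feature flags:

    // Hypothetical companion check for the Cascades knob (assumed result shape).
    function checkCascadesKnobEnabled(theDB) {
        const param =
            theDB.adminCommand({getParameter: 1, internalQueryEnableCascadesOptimizer: 1});
        return param.hasOwnProperty("internalQueryEnableCascadesOptimizer") &&
            param.internalQueryEnableCascadesOptimizer;
    }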
diff --git a/src/mongo/db/commands/SConscript b/src/mongo/db/commands/SConscript
index a5c8c6ce453..7551a05e8a6 100644
--- a/src/mongo/db/commands/SConscript
+++ b/src/mongo/db/commands/SConscript
@@ -317,6 +317,7 @@ env.Library(
target="standalone",
source=[
"count_cmd.cpp",
+ "cqf/cqf_aggregate.cpp",
"create_command.cpp",
"create_indexes.cpp",
"current_op.cpp",
@@ -364,13 +365,16 @@ env.Library(
'$BUILD_DIR/mongo/db/concurrency/lock_manager',
'$BUILD_DIR/mongo/db/concurrency/write_conflict_exception',
'$BUILD_DIR/mongo/db/curop_failpoint_helpers',
+ '$BUILD_DIR/mongo/db/exec/sbe/query_sbe_abt',
'$BUILD_DIR/mongo/db/index_builds_coordinator_interface',
'$BUILD_DIR/mongo/db/index_commands_idl',
'$BUILD_DIR/mongo/db/ops/write_ops_exec',
'$BUILD_DIR/mongo/db/pipeline/aggregation_request_helper',
'$BUILD_DIR/mongo/db/pipeline/process_interface/mongo_process_interface',
+ '$BUILD_DIR/mongo/db/query/ce/query_ce',
'$BUILD_DIR/mongo/db/query/command_request_response',
'$BUILD_DIR/mongo/db/query/cursor_response_idl',
+ '$BUILD_DIR/mongo/db/query/optimizer/optimizer',
'$BUILD_DIR/mongo/db/query_exec',
'$BUILD_DIR/mongo/db/repl/replica_set_messages',
'$BUILD_DIR/mongo/db/repl/tenant_migration_access_blocker',
@@ -766,4 +770,4 @@ env.CppUnitTest(
"servers",
"standalone",
],
-)
+)
\ No newline at end of file
diff --git a/src/mongo/db/commands/cqf/cqf_aggregate.cpp b/src/mongo/db/commands/cqf/cqf_aggregate.cpp
new file mode 100644
index 00000000000..f092233de11
--- /dev/null
+++ b/src/mongo/db/commands/cqf/cqf_aggregate.cpp
@@ -0,0 +1,431 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/db/commands/cqf/cqf_aggregate.h"
+
+#include "mongo/db/exec/sbe/abt/abt_lower.h"
+#include "mongo/db/pipeline/abt/abt_document_source_visitor.h"
+#include "mongo/db/pipeline/abt/match_expression_visitor.h"
+#include "mongo/db/query/ce/ce_sampling.h"
+#include "mongo/db/query/optimizer/cascades/ce_heuristic.h"
+#include "mongo/db/query/optimizer/cascades/cost_derivation.h"
+#include "mongo/db/query/optimizer/explain.h"
+#include "mongo/db/query/optimizer/node.h"
+#include "mongo/db/query/optimizer/opt_phase_manager.h"
+#include "mongo/db/query/plan_executor_factory.h"
+#include "mongo/db/query/query_knobs_gen.h"
+#include "mongo/db/query/query_planner_params.h"
+#include "mongo/db/query/sbe_stage_builder.h"
+#include "mongo/db/query/yield_policy_callbacks_impl.h"
+
+namespace mongo {
+
+using namespace optimizer;
+
+static opt::unordered_map<std::string, optimizer::IndexDefinition> buildIndexSpecsOptimizer(
+ boost::intrusive_ptr<ExpressionContext> expCtx,
+ OperationContext* opCtx,
+ const CollectionPtr& collection,
+ const optimizer::ProjectionName& scanProjName,
+ const DisableIndexOptions disableIndexOptions) {
+ using namespace optimizer;
+
+ if (disableIndexOptions == DisableIndexOptions::DisableAll) {
+ return {};
+ }
+
+ const IndexCatalog& indexCatalog = *collection->getIndexCatalog();
+ opt::unordered_map<std::string, IndexDefinition> result;
+ auto indexIterator = indexCatalog.getIndexIterator(opCtx, false /*includeUnfinished*/);
+
+ while (indexIterator->more()) {
+ const IndexCatalogEntry& catalogEntry = *indexIterator->next();
+
+ const bool isMultiKey = catalogEntry.isMultikey(opCtx, collection);
+ const MultikeyPaths& multiKeyPaths = catalogEntry.getMultikeyPaths(opCtx, collection);
+ uassert(6624251, "Multikey paths cannot be empty.", !multiKeyPaths.empty());
+
+ const IndexDescriptor& descriptor = *catalogEntry.descriptor();
+ if (descriptor.hidden() || descriptor.isSparse() ||
+ descriptor.getIndexType() != IndexType::INDEX_BTREE) {
+ // Not supported for now.
+ continue;
+ }
+
+        // The SBE index version is zero-based, while the descriptor version is one-based.
+ const int64_t version = static_cast<int>(descriptor.version()) - 1;
+
+ uint32_t orderingBits = 0;
+ {
+ const Ordering ordering = catalogEntry.ordering();
+ for (int i = 0; i < descriptor.getNumFields(); i++) {
+            if (ordering.get(i) == 1) {
+ orderingBits |= (1ull << i);
+ }
+ }
+ }
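+        // For example, under this scheme a key pattern of {a: 1, b: -1} sets
+        // only bit 0 (the ascending "a" component), so orderingBits == 0x1.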
+
+ IndexCollationSpec indexCollationSpec;
+ bool useIndex = true;
+ size_t elementIdx = 0;
+ for (const auto& element : descriptor.keyPattern()) {
+ FieldPathType fieldPath;
+ FieldPath path(element.fieldName());
+
+ for (size_t i = 0; i < path.getPathLength(); i++) {
+ const std::string& fieldName = path.getFieldName(i).toString();
+ if (fieldName == "$**") {
+                    // TODO: Wildcard indexes are not supported for now.
+ useIndex = false;
+ break;
+ }
+ fieldPath.emplace_back(fieldName);
+ }
+ if (!useIndex) {
+ break;
+ }
+
+ const int direction = element.numberInt();
+ if (direction != -1 && direction != 1) {
+                // The key pattern direction is invalid; do not use this index.
+ useIndex = false;
+ break;
+ }
+
+ const CollationOp collationOp =
+ (direction == 1) ? CollationOp::Ascending : CollationOp::Descending;
+
+ // Construct an ABT path for each index component (field path).
+ const MultikeyComponents& elementMultiKeyInfo = multiKeyPaths[elementIdx];
+ ABT abtPath = make<PathIdentity>();
+ for (size_t i = fieldPath.size(); i-- > 0;) {
+ if (isMultiKey && elementMultiKeyInfo.find(i) != elementMultiKeyInfo.cend()) {
+ // This is a multikey element of the path.
+ abtPath = make<PathTraverse>(std::move(abtPath));
+ }
+ abtPath = make<PathGet>(fieldPath.at(i), std::move(abtPath));
+ }
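+        // Sketch of the result: for field path "a.b" with a multikey
+        // component 0 ("a"), the loop above assembles the path inside-out,
+        // conceptually Get "a" -> Traverse -> Get "b" -> Identity.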
+ indexCollationSpec.emplace_back(std::move(abtPath), collationOp);
+ ++elementIdx;
+ }
+ if (!useIndex) {
+ continue;
+ }
+
+ PartialSchemaRequirements partialIndexReqMap;
+ if (descriptor.isPartial() &&
+ disableIndexOptions != DisableIndexOptions::DisablePartialOnly) {
+ auto expr = MatchExpressionParser::parseAndNormalize(
+ descriptor.partialFilterExpression(),
+ expCtx,
+ ExtensionsCallbackNoop(),
+ MatchExpressionParser::kBanAllSpecialFeatures);
+
+ ABT exprABT = generateMatchExpression(expr.get(), false /*allowAggExpression*/, "", "");
+ exprABT = make<EvalFilter>(std::move(exprABT), make<Variable>(scanProjName));
+
+ // TODO: simplify expression.
+
+ PartialSchemaReqConversion conversion = convertExprToPartialSchemaReq(exprABT);
+ if (!conversion._success || conversion._hasEmptyInterval) {
+                // The conversion failed, or the partial index filter is
+                // unsatisfiable; do not use this index.
+ continue;
+ }
+ partialIndexReqMap = std::move(conversion._reqMap);
+ }
+
+ // For now we assume distribution is Centralized.
+ result.emplace(descriptor.indexName(),
+ IndexDefinition(std::move(indexCollationSpec),
+ version,
+ orderingBits,
+ isMultiKey,
+ DistributionType::Centralized,
+ std::move(partialIndexReqMap)));
+ }
+
+ return result;
+}
+
+static QueryHints getHintsFromQueryKnobs() {
+ QueryHints hints;
+
+ hints._disableScan = internalCascadesOptimizerDisableScan.load();
+ hints._disableIndexes = internalCascadesOptimizerDisableIndexes.load()
+ ? DisableIndexOptions::DisableAll
+ : DisableIndexOptions::Enabled;
+ hints._disableHashJoinRIDIntersect =
+ internalCascadesOptimizerDisableHashJoinRIDIntersect.load();
+ hints._disableMergeJoinRIDIntersect =
+ internalCascadesOptimizerDisableMergeJoinRIDIntersect.load();
+ hints._disableGroupByAndUnionRIDIntersect =
+ internalCascadesOptimizerDisableGroupByAndUnionRIDIntersect.load();
+ hints._keepRejectedPlans = internalCascadesOptimizerKeepRejectedPlans.load();
+ hints._disableBranchAndBound = internalCascadesOptimizerDisableBranchAndBound.load();
+
+ return hints;
+}
+
+static std::unique_ptr<PlanExecutor, PlanExecutor::Deleter> optimizeAndCreateExecutor(
+ OptPhaseManager& phaseManager,
+ ABT abtTree,
+ OperationContext* opCtx,
+ boost::intrusive_ptr<ExpressionContext> expCtx,
+ const NamespaceString& nss,
+ const CollectionPtr& collection) {
+
+ const bool optimizationResult = phaseManager.optimize(abtTree);
+ uassert(6624252, "Optimization failed", optimizationResult);
+
+ // TODO: SERVER-62648. std::cerr is used for debugging. Consider structured logging.
+ std::cerr << "********* Optimizer Stats *********\n";
+ {
+ const auto& memo = phaseManager.getMemo();
+ std::cerr << "Memo groups: " << memo.getGroupCount() << "\n";
+ std::cerr << "Memo logical nodes: " << memo.getLogicalNodeCount() << "\n";
+ std::cerr << "Memo phys. nodes: " << memo.getPhysicalNodeCount() << "\n";
+
+ const auto& memoStats = memo.getStats();
+ std::cerr << "Memo integrations: " << memoStats._numIntegrations << "\n";
+ std::cerr << "Phys. plans explored: " << memoStats._physPlanExplorationCount << "\n";
+ std::cerr << "Phys. memo checks: " << memoStats._physMemoCheckCount << "\n";
+ }
+ std::cerr << "********* Optimizer Stats *********\n";
+
+ std::cerr << "********* Optimized ABT *********\n";
+ std::cerr << ExplainGenerator::explainV2(
+ make<MemoPhysicalDelegatorNode>(phaseManager.getPhysicalNodeId()),
+ true /*displayPhysicalProperties*/,
+ &phaseManager.getMemo());
+ std::cerr << "********* Optimized ABT *********\n";
+
+ auto env = VariableEnvironment::build(abtTree);
+ SlotVarMap slotMap;
+ sbe::value::SlotIdGenerator ids;
+ SBENodeLowering g{env,
+ slotMap,
+ ids,
+ phaseManager.getMetadata(),
+ phaseManager.getNodeToGroupPropsMap(),
+ phaseManager.getRIDProjections()};
+ auto sbePlan = g.optimize(abtTree);
+
+ uassert(6624253, "Lowering failed: did not produce a plan.", sbePlan != nullptr);
+ uassert(6624254, "Lowering failed: did not produce any output slots.", !slotMap.empty());
+
+ {
+ std::cerr << "********* SBE *********\n";
+ sbe::DebugPrinter p;
+ std::cerr << p.print(*sbePlan.get()) << "\n";
+ std::cerr << "********* SBE *********\n";
+ }
+
+ stage_builder::PlanStageData data{std::make_unique<sbe::RuntimeEnvironment>()};
+ data.outputs.set(stage_builder::PlanStageSlots::kResult, slotMap.begin()->second);
+
+ sbePlan->attachToOperationContext(opCtx);
+ if (expCtx->explain || expCtx->mayDbProfile) {
+ sbePlan->markShouldCollectTimingInfo();
+ }
+
+ auto yieldPolicy =
+ std::make_unique<PlanYieldPolicySBE>(PlanYieldPolicy::YieldPolicy::YIELD_AUTO,
+ opCtx->getServiceContext()->getFastClockSource(),
+ internalQueryExecYieldIterations.load(),
+ Milliseconds{internalQueryExecYieldPeriodMS.load()},
+ nullptr,
+ std::make_unique<YieldPolicyCallbacksImpl>(nss),
+ false /*useExperimentalCommitTxnBehavior*/);
+
+ auto planExec = uassertStatusOK(plan_executor_factory::make(
+ opCtx,
+ nullptr /*cq*/,
+ nullptr /*solution*/,
+ {std::move(sbePlan), std::move(data)},
+ std::make_unique<ABTPrinter>(std::move(abtTree), phaseManager.getNodeToGroupPropsMap()),
+ &collection,
+ QueryPlannerParams::Options::DEFAULT,
+ nss,
+ std::move(yieldPolicy)));
+ return planExec;
+}
+
+static void populateAdditionalScanDefs(OperationContext* opCtx,
+ boost::intrusive_ptr<ExpressionContext> expCtx,
+ const Pipeline& pipeline,
+ const size_t numberOfPartitions,
+ PrefixId& prefixId,
+ opt::unordered_map<std::string, ScanDefinition>& scanDefs,
+ const DisableIndexOptions disableIndexOptions) {
+ for (const auto& involvedNss : pipeline.getInvolvedCollections()) {
+        // TODO: Handle views?
+ AutoGetCollectionForReadCommandMaybeLockFree ctx(
+ opCtx, involvedNss, AutoGetCollectionViewMode::kViewsForbidden);
+ const CollectionPtr& collection = ctx ? ctx.getCollection() : CollectionPtr::null;
+ const bool collectionExists = collection != nullptr;
+ const std::string uuidStr =
+ collectionExists ? collection->uuid().toString() : "<missing_uuid>";
+
+ const std::string collNameStr = involvedNss.coll().toString();
+        // TODO: We cannot add the uuidStr suffix because the pipeline translation does not have
+        // access to the metadata, so it generates a scan over just the collection name.
+ const std::string scanDefName = collNameStr;
+
+ opt::unordered_map<std::string, optimizer::IndexDefinition> indexDefs;
+ const ProjectionName& scanProjName = prefixId.getNextId("scan");
+ if (collectionExists) {
+ // TODO: add locks on used indexes?
+ indexDefs = buildIndexSpecsOptimizer(
+ expCtx, opCtx, collection, scanProjName, disableIndexOptions);
+ }
+
+ // For now handle only local parallelism (no over-the-network exchanges).
+ DistributionAndPaths distribution{(numberOfPartitions == 1)
+ ? DistributionType::Centralized
+ : DistributionType::UnknownPartitioning};
+
+ const CEType collectionCE = collectionExists ? collection->numRecords(opCtx) : -1.0;
+ scanDefs[scanDefName] =
+ ScanDefinition({{"type", "mongod"},
+ {"database", involvedNss.db().toString()},
+ {"uuid", uuidStr},
+ {ScanNode::kDefaultCollectionNameSpec, collNameStr}},
+ std::move(indexDefs),
+ std::move(distribution),
+ collectionExists,
+ collectionCE);
+ }
+}
+
+std::unique_ptr<PlanExecutor, PlanExecutor::Deleter> getSBEExecutorViaCascadesOptimizer(
+ OperationContext* opCtx,
+ boost::intrusive_ptr<ExpressionContext> expCtx,
+ const NamespaceString& nss,
+ const CollectionPtr& collection,
+ const Pipeline& pipeline) {
+ const bool collectionExists = collection != nullptr;
+ const std::string uuidStr = collectionExists ? collection->uuid().toString() : "<missing_uuid>";
+ const std::string collNameStr = nss.coll().toString();
+ const std::string scanDefName = collNameStr + "_" + uuidStr;
+
+ QueryHints queryHints = getHintsFromQueryKnobs();
+
+ PrefixId prefixId;
+ const ProjectionName& scanProjName = prefixId.getNextId("scan");
+
+ // Add the base collection metadata.
+ opt::unordered_map<std::string, optimizer::IndexDefinition> indexDefs;
+ if (collectionExists) {
+ // TODO: add locks on used indexes?
+ indexDefs = buildIndexSpecsOptimizer(
+ expCtx, opCtx, collection, scanProjName, queryHints._disableIndexes);
+ }
+
+ const size_t numberOfPartitions = internalQueryDefaultDOP.load();
+ // For now handle only local parallelism (no over-the-network exchanges).
+ DistributionAndPaths distribution{(numberOfPartitions == 1)
+ ? DistributionType::Centralized
+ : DistributionType::UnknownPartitioning};
+
+ opt::unordered_map<std::string, ScanDefinition> scanDefs;
+ const int64_t numRecords = collectionExists ? collection->numRecords(opCtx) : -1;
+ scanDefs.emplace(scanDefName,
+ ScanDefinition({{"type", "mongod"},
+ {"database", nss.db().toString()},
+ {"uuid", uuidStr},
+ {ScanNode::kDefaultCollectionNameSpec, collNameStr}},
+ std::move(indexDefs),
+ std::move(distribution),
+ collectionExists,
+ static_cast<CEType>(numRecords)));
+
+ // Add a scan definition for all involved collections. Note that the base namespace has already
+ // been accounted for above and isn't included here.
+ populateAdditionalScanDefs(opCtx,
+ expCtx,
+ pipeline,
+ numberOfPartitions,
+ prefixId,
+ scanDefs,
+ queryHints._disableIndexes);
+
+ Metadata metadata(std::move(scanDefs), numberOfPartitions);
+
+ ABT abtTree = collectionExists ? make<ScanNode>(scanProjName, scanDefName)
+ : make<ValueScanNode>(ProjectionNameVector{scanProjName});
+ abtTree =
+ translatePipelineToABT(metadata, pipeline, scanProjName, std::move(abtTree), prefixId);
+
+ std::cerr << "******* Translated ABT **********\n";
+ std::cerr << ExplainGenerator::explainV2(abtTree) << std::endl;
+ std::cerr << "******* Translated ABT **********\n";
+
+ if (collectionExists && numRecords > 0 &&
+ internalQueryEnableSamplingCardinalityEstimator.load()) {
+ Metadata metadataForSampling = metadata;
+ // Do not use indexes for sampling.
+ for (auto& entry : metadataForSampling._scanDefs) {
+ entry.second.getIndexDefs().clear();
+ }
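+
+        // The sampling phase manager below executes the sampling queries themselves,
+        // using heuristic CE and no indexes; CESamplingTransport then consults it to
+        // estimate cardinalities for the main phase manager.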
+
+ // TODO: consider a limited rewrite set.
+ OptPhaseManager phaseManagerForSampling(OptPhaseManager::getAllRewritesSet(),
+ prefixId,
+ false /*requireRID*/,
+ std::move(metadataForSampling),
+ std::make_unique<HeuristicCE>(),
+ std::make_unique<DefaultCosting>(),
+ DebugInfo::kDefaultForProd);
+
+ OptPhaseManager phaseManager{
+ OptPhaseManager::getAllRewritesSet(),
+ prefixId,
+ false /*requireRID*/,
+ std::move(metadata),
+ std::make_unique<CESamplingTransport>(opCtx, phaseManagerForSampling, numRecords),
+ std::make_unique<DefaultCosting>(),
+ DebugInfo::kDefaultForProd};
+ phaseManager.getHints() = queryHints;
+
+ return optimizeAndCreateExecutor(
+ phaseManager, std::move(abtTree), opCtx, expCtx, nss, collection);
+ }
+
+ // Use heuristics.
+ OptPhaseManager phaseManager{OptPhaseManager::getAllRewritesSet(),
+ prefixId,
+ std::move(metadata),
+ DebugInfo::kDefaultForProd};
+ phaseManager.getHints() = queryHints;
+
+ return optimizeAndCreateExecutor(
+ phaseManager, std::move(abtTree), opCtx, expCtx, nss, collection);
+}
+
+} // namespace mongo
diff --git a/src/mongo/db/commands/cqf/cqf_aggregate.h b/src/mongo/db/commands/cqf/cqf_aggregate.h
new file mode 100644
index 00000000000..b98716400f2
--- /dev/null
+++ b/src/mongo/db/commands/cqf/cqf_aggregate.h
@@ -0,0 +1,44 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include "mongo/db/catalog/collection.h"
+#include "mongo/db/query/plan_executor.h"
+
+namespace mongo {
+
+std::unique_ptr<PlanExecutor, PlanExecutor::Deleter> getSBEExecutorViaCascadesOptimizer(
+ OperationContext* opCtx,
+ boost::intrusive_ptr<ExpressionContext> expCtx,
+ const NamespaceString& nss,
+ const CollectionPtr& collection,
+ const Pipeline& pipeline);
+
+} // namespace mongo
diff --git a/src/mongo/db/commands/find_cmd.cpp b/src/mongo/db/commands/find_cmd.cpp
index aaf86a480ef..d1fbe6fa8c7 100644
--- a/src/mongo/db/commands/find_cmd.cpp
+++ b/src/mongo/db/commands/find_cmd.cpp
@@ -51,6 +51,7 @@
#include "mongo/db/query/find.h"
#include "mongo/db/query/find_common.h"
#include "mongo/db/query/get_executor.h"
+#include "mongo/db/query/query_knobs_gen.h"
#include "mongo/db/repl/replication_coordinator.h"
#include "mongo/db/service_context.h"
#include "mongo/db/stats/counters.h"
@@ -317,7 +318,12 @@ public:
extensionsCallback,
MatchExpressionParser::kAllowAllSpecialFeatures));
- if (ctx->getView()) {
+ // If we are running a query against a view, or if we are trying to test the new
+ // optimizer, redirect this query through the aggregation system.
+ if (ctx->getView() ||
+ (feature_flags::gfeatureFlagCommonQueryFramework.isEnabled(
+ serverGlobalParams.featureCompatibility) &&
+ internalQueryEnableCascadesOptimizer.load())) {
// Relinquish locks. The aggregation command will re-acquire them.
ctx.reset();
@@ -521,7 +527,12 @@ public:
extensionsCallback,
MatchExpressionParser::kAllowAllSpecialFeatures));
- if (ctx->getView()) {
+ // If we are running a query against a view, or if we are trying to test the new
+ // optimizer, redirect this query through the aggregation system.
+ if (ctx->getView() ||
+ (feature_flags::gfeatureFlagCommonQueryFramework.isEnabled(
+ serverGlobalParams.featureCompatibility) &&
+ internalQueryEnableCascadesOptimizer.load())) {
// Relinquish locks. The aggregation command will re-acquire them.
ctx.reset();
diff --git a/src/mongo/db/commands/run_aggregate.cpp b/src/mongo/db/commands/run_aggregate.cpp
index 2bb3d6f9d3f..b1b80cc7827 100644
--- a/src/mongo/db/commands/run_aggregate.cpp
+++ b/src/mongo/db/commands/run_aggregate.cpp
@@ -41,6 +41,7 @@
#include "mongo/db/auth/authorization_session.h"
#include "mongo/db/catalog/database.h"
#include "mongo/db/catalog/database_holder.h"
+#include "mongo/db/commands/cqf/cqf_aggregate.h"
#include "mongo/db/curop.h"
#include "mongo/db/cursor_manager.h"
#include "mongo/db/db_raii.h"
@@ -66,6 +67,7 @@
#include "mongo/db/query/get_executor.h"
#include "mongo/db/query/plan_executor_factory.h"
#include "mongo/db/query/plan_summary_stats.h"
+#include "mongo/db/query/query_feature_flags_gen.h"
#include "mongo/db/query/query_planner_common.h"
#include "mongo/db/read_concern.h"
#include "mongo/db/repl/oplog.h"
@@ -129,7 +131,6 @@ bool handleCursorCommand(OperationContext* opCtx,
request.getCursor().getBatchSize().value_or(aggregation_request_helper::kDefaultBatchSize);
if (cursors.size() > 1) {
-
uassert(
ErrorCodes::BadValue, "the exchange initial batch size must be zero", batchSize == 0);
@@ -535,6 +536,74 @@ void performValidationChecks(const OperationContext* opCtx,
aggregation_request_helper::validateRequestForAPIVersion(opCtx, request);
}
+std::vector<std::unique_ptr<PlanExecutor, PlanExecutor::Deleter>> createLegacyExecutor(
+ std::unique_ptr<Pipeline, PipelineDeleter> pipeline,
+ const LiteParsedPipeline& liteParsedPipeline,
+ const NamespaceString& nss,
+ const MultiCollection& collections,
+ const AggregateCommandRequest& request,
+ CurOp* curOp,
+ const std::function<void(void)>& resetContextFn) {
+ const auto expCtx = pipeline->getContext();
+ // Check if the pipeline has a $geoNear stage, as it will be ripped away during the build query
+ // executor phase below (to be replaced with a $geoNearCursorStage later during the executor
+ // attach phase).
+ auto hasGeoNearStage = !pipeline->getSources().empty() &&
+ dynamic_cast<DocumentSourceGeoNear*>(pipeline->peekFront());
+
+ // Prepare a PlanExecutor to provide input into the pipeline, if needed.
+ auto attachExecutorCallback =
+ PipelineD::buildInnerQueryExecutor(collections, nss, &request, pipeline.get());
+
+ std::vector<std::unique_ptr<PlanExecutor, PlanExecutor::Deleter>> execs;
+ if (canOptimizeAwayPipeline(pipeline.get(),
+ attachExecutorCallback.second.get(),
+ request,
+ hasGeoNearStage,
+ liteParsedPipeline.hasChangeStream())) {
+ // Mark that this query does not use DocumentSource.
+ curOp->debug().documentSourceUsed = false;
+
+ // This pipeline is currently empty, but once completed it will have only one source,
+ // which is a DocumentSourceCursor. Instead of creating a whole pipeline to do nothing
+ // more than forward the results of its cursor document source, we can use the
+ // PlanExecutor by itself. The resulting cursor will look like what the client would
+        // have gotten from a find command.
+ execs.emplace_back(std::move(attachExecutorCallback.second));
+ } else {
+ // Mark that this query uses DocumentSource.
+ curOp->debug().documentSourceUsed = true;
+
+ // Complete creation of the initial $cursor stage, if needed.
+ PipelineD::attachInnerQueryExecutorToPipeline(collections,
+ attachExecutorCallback.first,
+ std::move(attachExecutorCallback.second),
+ pipeline.get());
+
+ auto pipelines = createExchangePipelinesIfNeeded(
+ expCtx->opCtx, expCtx, request, std::move(pipeline), expCtx->uuid);
+ for (auto&& pipelineIt : pipelines) {
+ // There are separate ExpressionContexts for each exchange pipeline, so make sure to
+ // pass the pipeline's ExpressionContext to the plan executor factory.
+ auto pipelineExpCtx = pipelineIt->getContext();
+
+ execs.emplace_back(
+ plan_executor_factory::make(std::move(pipelineExpCtx),
+ std::move(pipelineIt),
+ aggregation_request_helper::getResumableScanType(
+ request, liteParsedPipeline.hasChangeStream())));
+ }
+
+ // With the pipelines created, we can relinquish locks as they will manage the locks
+ // internally further on. We still need to keep the lock for an optimized away pipeline
+ // though, as we will be changing its lock policy to 'kLockExternally' (see details
+ // below), and in order to execute the initial getNext() call in 'handleCursorCommand',
+ // we need to hold the collection lock.
+ resetContextFn();
+ }
+ return execs;
+}
+
} // namespace
Status runAggregate(OperationContext* opCtx,
@@ -604,6 +673,7 @@ Status runAggregate(OperationContext* opCtx,
std::vector<unique_ptr<PlanExecutor, PlanExecutor::Deleter>> execs;
boost::intrusive_ptr<ExpressionContext> expCtx;
auto curOp = CurOp::get(opCtx);
+
{
// If we are in a transaction, check whether the parsed pipeline supports being in
// a transaction and if the transaction's read concern is supported.
@@ -812,59 +882,29 @@ Status runAggregate(OperationContext* opCtx,
constexpr bool alreadyOptimized = true;
pipeline->validateCommon(alreadyOptimized);
- // Check if the pipeline has a $geoNear stage, as it will be ripped away during the build
- // query executor phase below (to be replaced with a $geoNearCursorStage later during the
- // executor attach phase).
- auto hasGeoNearStage = !pipeline->getSources().empty() &&
- dynamic_cast<DocumentSourceGeoNear*>(pipeline->peekFront());
-
- // Prepare a PlanExecutor to provide input into the pipeline, if needed.
- auto attachExecutorCallback =
- PipelineD::buildInnerQueryExecutor(collections, nss, &request, pipeline.get());
-
- if (canOptimizeAwayPipeline(pipeline.get(),
- attachExecutorCallback.second.get(),
- request,
- hasGeoNearStage,
- liteParsedPipeline.hasChangeStream())) {
- // This pipeline is currently empty, but once completed it will have only one source,
- // which is a DocumentSourceCursor. Instead of creating a whole pipeline to do nothing
- // more than forward the results of its cursor document source, we can use the
- // PlanExecutor by itself. The resulting cursor will look like what the client would
- // have gotten from find command.
- execs.emplace_back(std::move(attachExecutorCallback.second));
- // Mark that this query does not use DocumentSource.
- curOp->debug().documentSourceUsed = false;
+ if (feature_flags::gfeatureFlagCommonQueryFramework.isEnabled(
+ serverGlobalParams.featureCompatibility) &&
+ internalQueryEnableCascadesOptimizer.load()) {
+ uassert(6624344,
+ "Exchanging is not supported in the Cascades optimizer",
+ !request.getExchange().has_value());
+
+ auto timeBegin = Date_t::now();
+ execs.emplace_back(getSBEExecutorViaCascadesOptimizer(
+ opCtx, expCtx, nss, collections.getMainCollection(), *pipeline));
+ auto elapsed =
+ (Date_t::now().toMillisSinceEpoch() - timeBegin.toMillisSinceEpoch()) / 1000.0;
+ std::cerr << "Optimization took: " << elapsed << " s.\n";
} else {
- // Mark that this query uses DocumentSource.
- curOp->debug().documentSourceUsed = true;
- // Complete creation of the initial $cursor stage, if needed.
- PipelineD::attachInnerQueryExecutorToPipeline(collections,
- attachExecutorCallback.first,
- std::move(attachExecutorCallback.second),
- pipeline.get());
-
- auto pipelines =
- createExchangePipelinesIfNeeded(opCtx, expCtx, request, std::move(pipeline), uuid);
- for (auto&& pipelineIt : pipelines) {
- // There are separate ExpressionContexts for each exchange pipeline, so make sure to
- // pass the pipeline's ExpressionContext to the plan executor factory.
- auto pipelineExpCtx = pipelineIt->getContext();
-
- execs.emplace_back(plan_executor_factory::make(
- std::move(pipelineExpCtx),
- std::move(pipelineIt),
- aggregation_request_helper::getResumableScanType(
- request, liteParsedPipeline.hasChangeStream())));
- }
-
- // With the pipelines created, we can relinquish locks as they will manage the locks
- // internally further on. We still need to keep the lock for an optimized away pipeline
- // though, as we will be changing its lock policy to 'kLockExternally' (see details
- // below), and in order to execute the initial getNext() call in 'handleCursorCommand',
- // we need to hold the collection lock.
- resetContext();
+ execs = createLegacyExecutor(std::move(pipeline),
+ liteParsedPipeline,
+ nss,
+ collections,
+ request,
+ curOp,
+ resetContext);
}
+ tassert(6624353, "No executors", !execs.empty());
{
auto planSummary = execs[0]->getPlanExplainer().getPlanSummary();
diff --git a/src/mongo/db/exec/sbe/SConscript b/src/mongo/db/exec/sbe/SConscript
index 0e220aafed2..c9395ca88e6 100644
--- a/src/mongo/db/exec/sbe/SConscript
+++ b/src/mongo/db/exec/sbe/SConscript
@@ -114,6 +114,18 @@ env.Library(
)
env.Library(
+ target='query_sbe_abt',
+ source=[
+ 'abt/abt_lower.cpp',
+ ],
+ LIBDEPS=[
+ '$BUILD_DIR/mongo/db/query/optimizer/optimizer',
+ 'query_sbe',
+ 'query_sbe_storage'
+ ]
+ )
+
+env.Library(
target='sbe_plan_stage_test',
source=[
'sbe_plan_stage_test.cpp',
@@ -191,3 +203,20 @@ env.CppUnitTest(
'sbe_plan_stage_test',
],
)
+
+env.CppUnitTest(
+ target='sbe_abt_test',
+ source=[
+ 'abt/sbe_abt_diff_test.cpp',
+ 'abt/sbe_abt_test.cpp',
+ 'abt/sbe_abt_test_util.cpp',
+ ],
+ LIBDEPS=[
+ "$BUILD_DIR/mongo/db/auth/authmocks",
+ '$BUILD_DIR/mongo/db/query/query_test_service_context',
+ '$BUILD_DIR/mongo/db/query_exec',
+ '$BUILD_DIR/mongo/db/service_context_test_fixture',
+ '$BUILD_DIR/mongo/unittest/unittest',
+ 'query_sbe_abt',
+ ],
+)
diff --git a/src/mongo/db/exec/sbe/abt/abt_lower.cpp b/src/mongo/db/exec/sbe/abt/abt_lower.cpp
new file mode 100644
index 00000000000..737f94064da
--- /dev/null
+++ b/src/mongo/db/exec/sbe/abt/abt_lower.cpp
@@ -0,0 +1,1014 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/db/exec/sbe/abt/abt_lower.h"
+#include "mongo/db/exec/sbe/stages/bson_scan.h"
+#include "mongo/db/exec/sbe/stages/co_scan.h"
+#include "mongo/db/exec/sbe/stages/exchange.h"
+#include "mongo/db/exec/sbe/stages/filter.h"
+#include "mongo/db/exec/sbe/stages/hash_agg.h"
+#include "mongo/db/exec/sbe/stages/hash_join.h"
+#include "mongo/db/exec/sbe/stages/ix_scan.h"
+#include "mongo/db/exec/sbe/stages/limit_skip.h"
+#include "mongo/db/exec/sbe/stages/loop_join.h"
+#include "mongo/db/exec/sbe/stages/merge_join.h"
+#include "mongo/db/exec/sbe/stages/project.h"
+#include "mongo/db/exec/sbe/stages/scan.h"
+#include "mongo/db/exec/sbe/stages/sort.h"
+#include "mongo/db/exec/sbe/stages/union.h"
+#include "mongo/db/exec/sbe/stages/unique.h"
+#include "mongo/db/exec/sbe/stages/unwind.h"
+#include "mongo/db/query/optimizer/utils/utils.h"
+
+namespace mongo::optimizer {
+
+static sbe::EExpression::Vector toInlinedVector(
+ std::vector<std::unique_ptr<sbe::EExpression>> args) {
+ sbe::EExpression::Vector inlined;
+ for (auto&& arg : args) {
+ inlined.emplace_back(std::move(arg));
+ }
+ return inlined;
+}
+
+std::unique_ptr<sbe::EExpression> SBEExpressionLowering::optimize(const ABT& n) {
+ return algebra::transport<false>(n, *this);
+}
+
+std::unique_ptr<sbe::EExpression> SBEExpressionLowering::transport(const Constant& c) {
+ auto [tag, val] = c.get();
+ auto [copyTag, copyVal] = sbe::value::copyValue(tag, val);
+ sbe::value::ValueGuard guard(copyTag, copyVal);
+
+ auto result = sbe::makeE<sbe::EConstant>(copyTag, copyVal);
+
+ guard.reset();
+ return result;
+}
+
+std::unique_ptr<sbe::EExpression> SBEExpressionLowering::transport(const Source&) {
+ uasserted(6624202, "not yet implemented");
+ return nullptr;
+}
+
+std::unique_ptr<sbe::EExpression> SBEExpressionLowering::transport(const Variable& var) {
+ auto def = _env.getDefinition(&var);
+
+ if (!def.definedBy.empty()) {
+ if (auto let = def.definedBy.cast<Let>(); let) {
+ auto it = _letMap.find(let);
+ uassert(6624203, "incorrect let map", it != _letMap.end());
+
+ return sbe::makeE<sbe::EVariable>(it->second, 0, _env.isLastRef(&var));
+ } else if (auto lam = def.definedBy.cast<LambdaAbstraction>(); lam) {
+ // This is a lambda parameter.
+ auto it = _lambdaMap.find(lam);
+ uassert(6624204, "incorrect lambda map", it != _lambdaMap.end());
+
+ return sbe::makeE<sbe::EVariable>(it->second, 0, _env.isLastRef(&var));
+ }
+ }
+ if (auto it = _slotMap.find(var.name()); it != _slotMap.end()) {
+ // Found the slot.
+ return sbe::makeE<sbe::EVariable>(it->second);
+ }
+ uasserted(6624205, str::stream() << "undefined variable: " << var.name());
+ return nullptr;
+}
+
+std::unique_ptr<sbe::EExpression> SBEExpressionLowering::transport(
+ const BinaryOp& op,
+ std::unique_ptr<sbe::EExpression> lhs,
+ std::unique_ptr<sbe::EExpression> rhs) {
+
+ sbe::EPrimBinary::Op sbeOp = [](const auto abtOp) {
+ switch (abtOp) {
+ case Operations::Eq:
+ return sbe::EPrimBinary::eq;
+ case Operations::Neq:
+ return sbe::EPrimBinary::neq;
+ case Operations::Gt:
+ return sbe::EPrimBinary::greater;
+ case Operations::Gte:
+ return sbe::EPrimBinary::greaterEq;
+ case Operations::Lt:
+ return sbe::EPrimBinary::less;
+ case Operations::Lte:
+ return sbe::EPrimBinary::lessEq;
+ case Operations::Add:
+ return sbe::EPrimBinary::add;
+ case Operations::Sub:
+ return sbe::EPrimBinary::sub;
+ case Operations::And:
+ return sbe::EPrimBinary::logicAnd;
+ case Operations::Or:
+ return sbe::EPrimBinary::logicOr;
+ case Operations::Cmp3w:
+ return sbe::EPrimBinary::cmp3w;
+ case Operations::Div:
+ return sbe::EPrimBinary::div;
+ case Operations::Mult:
+ return sbe::EPrimBinary::mul;
+ default:
+ MONGO_UNREACHABLE;
+ }
+ }(op.op());
+
+ return sbe::makeE<sbe::EPrimBinary>(sbeOp, std::move(lhs), std::move(rhs));
+}
+
+std::unique_ptr<sbe::EExpression> SBEExpressionLowering::transport(
+ const UnaryOp& op, std::unique_ptr<sbe::EExpression> arg) {
+
+ sbe::EPrimUnary::Op sbeOp = [](const auto abtOp) {
+ switch (abtOp) {
+ case Operations::Neg:
+ return sbe::EPrimUnary::negate;
+ case Operations::Not:
+ return sbe::EPrimUnary::logicNot;
+ default:
+ MONGO_UNREACHABLE;
+ }
+ }(op.op());
+
+ return sbe::makeE<sbe::EPrimUnary>(sbeOp, std::move(arg));
+}
+
+std::unique_ptr<sbe::EExpression> SBEExpressionLowering::transport(
+ const If&,
+ std::unique_ptr<sbe::EExpression> cond,
+ std::unique_ptr<sbe::EExpression> thenBranch,
+ std::unique_ptr<sbe::EExpression> elseBranch) {
+ return sbe::makeE<sbe::EIf>(std::move(cond), std::move(thenBranch), std::move(elseBranch));
+}
+
+void SBEExpressionLowering::prepare(const Let& let) {
+ _letMap[&let] = ++_frameCounter;
+}
+
+std::unique_ptr<sbe::EExpression> SBEExpressionLowering::transport(
+ const Let& let, std::unique_ptr<sbe::EExpression> bind, std::unique_ptr<sbe::EExpression> in) {
+ auto it = _letMap.find(&let);
+ uassert(6624206, "incorrect let map", it != _letMap.end());
+ auto frameId = it->second;
+ _letMap.erase(it);
+
+    // ABT 'Let' binds only a single variable. When we extend it to support multiple binds, we
+    // will have to revisit how we map variable names to SBE slot ids.
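+    // For illustration: "Let [x] <bind> <in>" lowers to ELocalBind(frameId,
+    // [<bind>], <in>), with each reference to x becoming EVariable(frameId, 0)
+    // in the Variable case above.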
+ return sbe::makeE<sbe::ELocalBind>(frameId, sbe::makeEs(std::move(bind)), std::move(in));
+}
+
+void SBEExpressionLowering::prepare(const LambdaAbstraction& lam) {
+ _lambdaMap[&lam] = ++_frameCounter;
+}
+
+std::unique_ptr<sbe::EExpression> SBEExpressionLowering::transport(
+ const LambdaAbstraction& lam, std::unique_ptr<sbe::EExpression> body) {
+ auto it = _lambdaMap.find(&lam);
+ uassert(6624207, "incorrect lambda map", it != _lambdaMap.end());
+ auto frameId = it->second;
+ _lambdaMap.erase(it);
+
+ return sbe::makeE<sbe::ELocalLambda>(frameId, std::move(body));
+}
+
+std::unique_ptr<sbe::EExpression> SBEExpressionLowering::transport(
+ const LambdaApplication&,
+ std::unique_ptr<sbe::EExpression> lam,
+ std::unique_ptr<sbe::EExpression> arg) {
+    // Lambda applications are not directly supported by SBE (yet) and must not be present here.
+ uasserted(6624208, "lambda application is not implemented");
+ return nullptr;
+}
+
+std::unique_ptr<sbe::EExpression> SBEExpressionLowering::transport(
+ const FunctionCall& fn, std::vector<std::unique_ptr<sbe::EExpression>> args) {
+ auto name = fn.name();
+
+ if (name == "typeMatch") {
+ uassert(6624209, "Invalid number of typeMatch arguments", fn.nodes().size() == 2);
+ if (const auto* constPtr = fn.nodes().at(1).cast<Constant>();
+ constPtr != nullptr && constPtr->isValueInt64()) {
+ return sbe::makeE<sbe::ETypeMatch>(std::move(args.at(0)), constPtr->getValueInt64());
+ } else {
+ uasserted(6624210, "Second argument of typeMatch must be an integer constant");
+ }
+ }
+
+    // TODO: How to perform the name mappings is an open question.
+ if (name == "$sum") {
+ name = "sum";
+ } else if (name == "$first") {
+ name = "first";
+ } else if (name == "$last") {
+ name = "last";
+ } else if (name == "$min") {
+ name = "min";
+ } else if (name == "$max") {
+ name = "max";
+ } else if (name == "$addToSet") {
+ name = "addToSet";
+ }
+
+ return sbe::makeE<sbe::EFunction>(name, toInlinedVector(std::move(args)));
+}
+
+sbe::value::SlotVector SBENodeLowering::convertProjectionsToSlots(
+ const ProjectionNameVector& projectionNames) {
+ sbe::value::SlotVector result;
+ for (const ProjectionName& projectionName : projectionNames) {
+ auto it = _slotMap.find(projectionName);
+ uassert(6624211,
+ str::stream() << "undefined variable: " << projectionName,
+ it != _slotMap.end());
+ result.push_back(it->second);
+ }
+ return result;
+}
+
+sbe::value::SlotVector SBENodeLowering::convertRequiredProjectionsToSlots(
+ const NodeProps& props, const bool addRIDProjection, const ProjectionNameVector& toExclude) {
+ using namespace properties;
+
+ const PhysProps& physProps = props._physicalProps;
+ auto projections = getPropertyConst<ProjectionRequirement>(physProps).getProjections();
+
+ if (addRIDProjection && hasProperty<IndexingRequirement>(physProps) &&
+ getPropertyConst<IndexingRequirement>(physProps).getNeedsRID()) {
+ const auto& scanDefName =
+ getPropertyConst<IndexingAvailability>(props._logicalProps).getScanDefName();
+ projections.emplace_back(_ridProjections.at(scanDefName));
+ }
+
+ for (const ProjectionName& projName : toExclude) {
+ projections.erase(projName);
+ }
+
+ return convertProjectionsToSlots(projections.getVector());
+}
+
+std::unique_ptr<sbe::PlanStage> SBENodeLowering::optimize(const ABT& n) {
+ return generateInternal(n);
+}
+
+std::unique_ptr<sbe::PlanStage> SBENodeLowering::generateInternal(const ABT& n) {
+ return algebra::walk<false>(n, *this);
+}
+
+std::unique_ptr<sbe::PlanStage> SBENodeLowering::walk(const RootNode& n,
+ const ABT& child,
+ const ABT& refs) {
+
+ auto input = generateInternal(child);
+
+ auto output = refs.cast<References>();
+ uassert(6624212, "refs expected", output);
+
+ SlotVarMap finalMap;
+ for (auto& o : output->nodes()) {
+ auto var = o.cast<Variable>();
+ uassert(6624213, "var expected", var);
+ if (auto it = _slotMap.find(var->name()); it != _slotMap.end()) {
+ finalMap.emplace(var->name(), it->second);
+ }
+ }
+
+ _slotMap = finalMap;
+
+ return input;
+}
+
+std::unique_ptr<sbe::PlanStage> SBENodeLowering::walk(const EvaluationNode& n,
+ const ABT& child,
+ const ABT& binds) {
+ auto input = generateInternal(child);
+
+ if (auto varPtr = n.getProjection().cast<Variable>(); varPtr != nullptr) {
+ // Evaluation node is only renaming a variable. Do not place a project stage.
+ _slotMap.emplace(n.getProjectionName(), _slotMap.at(varPtr->name()));
+ return input;
+ }
+
+ auto binder = binds.cast<ExpressionBinder>();
+ uassert(6624214, "binder expected", binder);
+
+ auto& names = binder->names();
+ auto& exprs = binder->exprs();
+
+ sbe::value::SlotMap<std::unique_ptr<sbe::EExpression>> projects;
+
+ for (size_t idx = 0; idx < exprs.size(); ++idx) {
+ auto expr = SBEExpressionLowering{_env, _slotMap}.optimize(exprs[idx]);
+ auto slot = _slotIdGenerator.generate();
+
+ _slotMap.emplace(names[idx], slot);
+ projects.emplace(slot, std::move(expr));
+ }
+
+ const PlanNodeId planNodeId = _nodeToGroupPropsMap.at(&n)._planNodeId;
+ return sbe::makeS<sbe::ProjectStage>(std::move(input), std::move(projects), planNodeId);
+}
+
+std::unique_ptr<sbe::PlanStage> SBENodeLowering::walk(const FilterNode& n,
+ const ABT& child,
+ const ABT& filter) {
+ auto input = generateInternal(child);
+ auto expr = SBEExpressionLowering{_env, _slotMap}.optimize(filter);
+ const PlanNodeId planNodeId = _nodeToGroupPropsMap.at(&n)._planNodeId;
+
+    // Check if the filter expression is 'constant' (i.e. does not depend on any variables); if
+    // so, create a FilterStage<true>, otherwise a FilterStage<false>.
+ if (_env.getVariables(filter)._variables.empty()) {
+ return sbe::makeS<sbe::FilterStage<true>>(std::move(input), std::move(expr), planNodeId);
+ } else {
+ return sbe::makeS<sbe::FilterStage<false>>(std::move(input), std::move(expr), planNodeId);
+ }
+}
+
+std::unique_ptr<sbe::PlanStage> SBENodeLowering::walk(const LimitSkipNode& n, const ABT& child) {
+ auto input = generateInternal(child);
+
+ const PlanNodeId planNodeId = _nodeToGroupPropsMap.at(&n)._planNodeId;
+ return sbe::makeS<sbe::LimitSkipStage>(
+ std::move(input), n.getProperty().getLimit(), n.getProperty().getSkip(), planNodeId);
+}
+
+std::unique_ptr<sbe::PlanStage> SBENodeLowering::walk(const ExchangeNode& n,
+ const ABT& child,
+ const ABT& refs) {
+ using namespace std::literals;
+ // TODO: Implement all types of distributions.
+ using namespace properties;
+
+ // The DOP is obtained from the child (number of producers).
+ const auto& childProps = _nodeToGroupPropsMap.at(n.getChild().cast<Node>())._physicalProps;
+ const auto& childDistribution = getPropertyConst<DistributionRequirement>(childProps);
+ uassert(6624330,
+ "Parent and child distributions are the same",
+ !(childDistribution == n.getProperty()));
+
+ const size_t localDOP =
+ (childDistribution.getDistributionAndProjections()._type == DistributionType::Centralized)
+ ? 1
+ : _metadata._numberOfPartitions;
+ uassert(6624215, "invalid DOP", localDOP >= 1);
+
+ auto input = generateInternal(child);
+
+    // Initialized to an arbitrary placeholder; the switch below sets the actual policy.
+ sbe::ExchangePolicy localPolicy{};
+ std::unique_ptr<sbe::EExpression> partitionExpr;
+
+ const auto& distribAndProjections = n.getProperty().getDistributionAndProjections();
+ switch (distribAndProjections._type) {
+ case DistributionType::Centralized:
+ case DistributionType::Replicated:
+ localPolicy = sbe::ExchangePolicy::broadcast;
+ break;
+
+ case DistributionType::RoundRobin:
+ localPolicy = sbe::ExchangePolicy::roundrobin;
+ break;
+
+ case DistributionType::RangePartitioning:
+ localPolicy = sbe::ExchangePolicy::rangepartition;
+ break;
+
+ case DistributionType::HashPartitioning: {
+ localPolicy = sbe::ExchangePolicy::hashpartition;
+ std::vector<std::unique_ptr<sbe::EExpression>> args;
+ for (const ProjectionName& proj : distribAndProjections._projectionNames) {
+ auto it = _slotMap.find(proj);
+ uassert(6624216, str::stream() << "undefined var: " << proj, it != _slotMap.end());
+
+ args.emplace_back(sbe::makeE<sbe::EVariable>(it->second));
+ }
+ partitionExpr = sbe::makeE<sbe::EFunction>("hash"_sd, toInlinedVector(std::move(args)));
+ break;
+ }
+
+ case DistributionType::UnknownPartitioning:
+ uasserted(6624217, "Cannot partition into unknown distribution");
+
+ default:
+ MONGO_UNREACHABLE;
+ }
+
+ const auto& nodeProps = _nodeToGroupPropsMap.at(&n);
+ auto fields = convertRequiredProjectionsToSlots(nodeProps, true /*addRIDProjection*/);
+
+ return sbe::makeS<sbe::ExchangeConsumer>(std::move(input),
+ localDOP,
+ std::move(fields),
+ localPolicy,
+ std::move(partitionExpr),
+ nullptr,
+ nodeProps._planNodeId);
+}
+
+std::unique_ptr<sbe::PlanStage> SBENodeLowering::walk(const CollationNode& n,
+ const ABT& child,
+ const ABT& refs) {
+ auto input = generateInternal(child);
+
+ sbe::value::SlotVector orderBySlots;
+ std::vector<sbe::value::SortDirection> directions;
+ ProjectionNameVector collationProjections;
+ for (const auto& entry : n.getProperty().getCollationSpec()) {
+ collationProjections.push_back(entry.first);
+ auto it = _slotMap.find(entry.first);
+
+ uassert(6624219,
+ str::stream() << "undefined orderBy var: " << entry.first,
+ it != _slotMap.end());
+ orderBySlots.push_back(it->second);
+
+ switch (entry.second) {
+ case CollationOp::Ascending:
+ case CollationOp::Clustered:
+                // TODO: Is there a more efficient way to compute the clustered
+                // collation op than sorting?
+ directions.push_back(sbe::value::SortDirection::Ascending);
+ break;
+
+ case CollationOp::Descending:
+ directions.push_back(sbe::value::SortDirection::Descending);
+ break;
+
+ default:
+ MONGO_UNREACHABLE;
+ }
+ }
+
+ const auto& nodeProps = _nodeToGroupPropsMap.at(&n);
+ const auto& physProps = nodeProps._physicalProps;
+
+ size_t limit = std::numeric_limits<std::size_t>::max();
+ if (properties::hasProperty<properties::LimitSkipRequirement>(physProps)) {
+ const auto& limitSkipReq =
+ properties::getPropertyConst<properties::LimitSkipRequirement>(physProps);
+ uassert(6624221, "We should not have skip set here", limitSkipReq.getSkip() == 0);
+ limit = limitSkipReq.getLimit();
+ }
+
+ // TODO: obtain defaults for these.
+ const size_t memoryLimit = 100 * (1ul << 20); // 100MB
+ const bool allowDiskUse = false;
+
+ auto vals = convertRequiredProjectionsToSlots(
+ nodeProps, true /*addRIDProjection*/, collationProjections);
+ return sbe::makeS<sbe::SortStage>(std::move(input),
+ std::move(orderBySlots),
+ std::move(directions),
+ std::move(vals),
+ limit,
+ memoryLimit,
+ allowDiskUse,
+ nodeProps._planNodeId);
+}
+
+std::unique_ptr<sbe::PlanStage> SBENodeLowering::walk(const UniqueNode& n,
+ const ABT& child,
+ const ABT& refs) {
+ auto input = generateInternal(child);
+
+ sbe::value::SlotVector keySlots;
+ for (const ProjectionName& projectionName : n.getProjections()) {
+ auto it = _slotMap.find(projectionName);
+ uassert(6624222,
+ str::stream() << "undefined variable: " << projectionName,
+ it != _slotMap.end());
+ keySlots.push_back(it->second);
+ }
+
+ const PlanNodeId planNodeId = _nodeToGroupPropsMap.at(&n)._planNodeId;
+ return sbe::makeS<sbe::UniqueStage>(std::move(input), std::move(keySlots), planNodeId);
+}
+
+std::unique_ptr<sbe::PlanStage> SBENodeLowering::walk(const GroupByNode& n,
+ const ABT& child,
+ const ABT& aggBinds,
+ const ABT& aggRefs,
+ const ABT& gbBind,
+ const ABT& gbRefs) {
+
+ auto input = generateInternal(child);
+
+    // Ideally, we should make a distinction between gbBind and gbRefs; i.e. internal references
+    // used by the hash agg to determine the group by values from its input, and group by values
+    // as outputted by the hash agg after the grouping. However, the SBE hash agg uses the same
+    // slot id to represent both, so that distinction is moot.
+ sbe::value::SlotVector gbs;
+ auto gbCols = gbRefs.cast<References>();
+ uassert(6624223, "refs expected", gbCols);
+ for (auto& o : gbCols->nodes()) {
+ auto var = o.cast<Variable>();
+ uassert(6624224, "var expected", var);
+ auto it = _slotMap.find(var->name());
+ uassert(6624225, str::stream() << "undefined var: " << var->name(), it != _slotMap.end());
+ gbs.push_back(it->second);
+ }
+
+ // Similar considerations apply to the agg expressions as to the group by columns.
+ auto binderAgg = aggBinds.cast<ExpressionBinder>();
+ uassert(6624226, "binder expected", binderAgg);
+ auto refsAgg = aggRefs.cast<References>();
+ uassert(6624227, "refs expected", refsAgg);
+
+ auto& names = binderAgg->names();
+ auto& exprs = refsAgg->nodes();
+
+ sbe::value::SlotMap<std::unique_ptr<sbe::EExpression>> aggs;
+
+ for (size_t idx = 0; idx < exprs.size(); ++idx) {
+ auto expr = SBEExpressionLowering{_env, _slotMap}.optimize(exprs[idx]);
+ auto slot = _slotIdGenerator.generate();
+
+ _slotMap.emplace(names[idx], slot);
+ aggs.emplace(slot, std::move(expr));
+ }
+
+ // TODO: use collator slot.
+ boost::optional<sbe::value::SlotId> collatorSlot;
+    // Unused.
+ sbe::value::SlotVector seekKeysSlots;
+
+ const PlanNodeId planNodeId = _nodeToGroupPropsMap.at(&n)._planNodeId;
+ return sbe::makeS<sbe::HashAggStage>(std::move(input),
+ std::move(gbs),
+ std::move(aggs),
+ std::move(seekKeysSlots),
+ true /*optimizedClose*/,
+ collatorSlot,
+ false /*allowDiskUse*/,
+ planNodeId);
+}
+
+std::unique_ptr<sbe::PlanStage> SBENodeLowering::walk(const BinaryJoinNode& n,
+ const ABT& leftChild,
+ const ABT& rightChild,
+ const ABT& filter) {
+ auto outerStage = generateInternal(leftChild);
+ auto innerStage = generateInternal(rightChild);
+
+    // List of correlated projections (bound on the outer side and referred to on the inner side).
+ sbe::value::SlotVector correlatedSlots;
+ for (const ProjectionName& projectionName : n.getCorrelatedProjectionNames()) {
+ correlatedSlots.push_back(_slotMap.at(projectionName));
+ }
+
+ const auto& leftChildProps = _nodeToGroupPropsMap.at(n.getLeftChild().cast<Node>());
+ auto expr = SBEExpressionLowering{_env, _slotMap}.optimize(filter);
+
+ auto outerProjects =
+ convertRequiredProjectionsToSlots(leftChildProps, true /*addRIDProjection*/);
+
+ const PlanNodeId planNodeId = _nodeToGroupPropsMap.at(&n)._planNodeId;
+ return sbe::makeS<sbe::LoopJoinStage>(std::move(outerStage),
+ std::move(innerStage),
+ std::move(outerProjects),
+ std::move(correlatedSlots),
+ std::move(expr),
+ planNodeId);
+}
+
+std::unique_ptr<sbe::PlanStage> SBENodeLowering::walk(const HashJoinNode& n,
+ const ABT& leftChild,
+ const ABT& rightChild,
+ const ABT& refs) {
+ auto outerStage = generateInternal(leftChild);
+ auto innerStage = generateInternal(rightChild);
+
+ uassert(6624228, "Only inner joins supported for now", n.getJoinType() == JoinType::Inner);
+
+ const auto& leftProps = _nodeToGroupPropsMap.at(n.getLeftChild().cast<Node>());
+ const auto& rightProps = _nodeToGroupPropsMap.at(n.getRightChild().cast<Node>());
+
+ // Add RID projection only from outer side.
+ auto outerKeys = convertProjectionsToSlots(n.getLeftKeys());
+ auto outerProjects =
+ convertRequiredProjectionsToSlots(leftProps, true /*addRIDProjection*/, n.getLeftKeys());
+ auto innerKeys = convertProjectionsToSlots(n.getRightKeys());
+ auto innerProjects =
+ convertRequiredProjectionsToSlots(rightProps, false /*addRIDProjection*/, n.getRightKeys());
+
+ // TODO: use collator slot.
+ boost::optional<sbe::value::SlotId> collatorSlot;
+ const PlanNodeId planNodeId = _nodeToGroupPropsMap.at(&n)._planNodeId;
+ return sbe::makeS<sbe::HashJoinStage>(std::move(outerStage),
+ std::move(innerStage),
+ std::move(outerKeys),
+ std::move(outerProjects),
+ std::move(innerKeys),
+ std::move(innerProjects),
+ collatorSlot,
+ planNodeId);
+}
+
+std::unique_ptr<sbe::PlanStage> SBENodeLowering::walk(const MergeJoinNode& n,
+ const ABT& leftChild,
+ const ABT& rightChild,
+ const ABT& refs) {
+ auto outerStage = generateInternal(leftChild);
+ auto innerStage = generateInternal(rightChild);
+
+ const auto& leftProps = _nodeToGroupPropsMap.at(n.getLeftChild().cast<Node>());
+ const auto& rightProps = _nodeToGroupPropsMap.at(n.getRightChild().cast<Node>());
+
+ std::vector<sbe::value::SortDirection> sortDirs;
+ for (const CollationOp op : n.getCollation()) {
+ switch (op) {
+ case CollationOp::Ascending:
+ case CollationOp::Clustered:
+ sortDirs.push_back(sbe::value::SortDirection::Ascending);
+ break;
+
+ case CollationOp::Descending:
+ sortDirs.push_back(sbe::value::SortDirection::Descending);
+ break;
+
+ default:
+ MONGO_UNREACHABLE;
+ }
+ }
+
+ // Add RID projection only from outer side.
+ auto outerKeys = convertProjectionsToSlots(n.getLeftKeys());
+ auto outerProjects =
+ convertRequiredProjectionsToSlots(leftProps, true /*addRIDProjection*/, n.getLeftKeys());
+ auto innerKeys = convertProjectionsToSlots(n.getRightKeys());
+ auto innerProjects =
+ convertRequiredProjectionsToSlots(rightProps, false /*addRIDProjection*/, n.getRightKeys());
+
+ const PlanNodeId planNodeId = _nodeToGroupPropsMap.at(&n)._planNodeId;
+ return sbe::makeS<sbe::MergeJoinStage>(std::move(outerStage),
+ std::move(innerStage),
+ std::move(outerKeys),
+ std::move(outerProjects),
+ std::move(innerKeys),
+ std::move(innerProjects),
+ std::move(sortDirs),
+ planNodeId);
+}
+
+std::unique_ptr<sbe::PlanStage> SBENodeLowering::walk(const UnionNode& n,
+ const ABTVector& children,
+ const ABT& binder,
+ const ABT& refs) {
+ auto unionBinder = binder.cast<ExpressionBinder>();
+ uassert(6624229, "binder expected", unionBinder);
+ const auto& names = unionBinder->names();
+
+ sbe::PlanStage::Vector loweredChildren;
+ std::vector<sbe::value::SlotVector> inputVals;
+
+ for (const ABT& child : children) {
+        // Use a fresh map so that the same projection names, bound by every
+        // child, do not overwrite one another.
+ SlotVarMap localMap;
+ SBENodeLowering localLowering(_env,
+ localMap,
+ _slotIdGenerator,
+ _metadata,
+ _nodeToGroupPropsMap,
+ _ridProjections,
+ _randomScan);
+ auto loweredChild = localLowering.optimize(child);
+
+ if (children.size() == 1) {
+ // Union with one child is used to restrict projections. Do not place a union stage.
+ for (const auto& name : names) {
+ _slotMap.emplace(name, localMap.at(name));
+ }
+ return loweredChild;
+ }
+ loweredChildren.push_back(std::move(loweredChild));
+
+ sbe::value::SlotVector childSlots;
+ for (const auto& name : names) {
+ childSlots.push_back(localMap.at(name));
+ }
+ inputVals.emplace_back(std::move(childSlots));
+ }
+
+ sbe::value::SlotVector outputVals;
+ for (const auto& name : names) {
+ const auto outputSlot = _slotIdGenerator.generate();
+ _slotMap.emplace(name, outputSlot);
+ outputVals.push_back(outputSlot);
+ }
+
+ const PlanNodeId planNodeId = _nodeToGroupPropsMap.at(&n)._planNodeId;
+ return sbe::makeS<sbe::UnionStage>(
+ std::move(loweredChildren), std::move(inputVals), std::move(outputVals), planNodeId);
+}
+
+std::unique_ptr<sbe::PlanStage> SBENodeLowering::walk(const UnwindNode& n,
+ const ABT& child,
+ const ABT& pidBind,
+ const ABT& refs) {
+ auto input = generateInternal(child);
+
+ auto it = _slotMap.find(n.getProjectionName());
+ uassert(6624230,
+ str::stream() << "undefined unwind variable: " << n.getProjectionName(),
+ it != _slotMap.end());
+
+ auto inputSlot = it->second;
+ auto outputSlot = _slotIdGenerator.generate();
+ auto outputPidSlot = _slotIdGenerator.generate();
+
+ _slotMap[n.getProjectionName()] = outputSlot;
+ _slotMap[n.getPIDProjectionName()] = outputPidSlot;
+
+ const PlanNodeId planNodeId = _nodeToGroupPropsMap.at(&n)._planNodeId;
+ return sbe::makeS<sbe::UnwindStage>(
+ std::move(input), inputSlot, outputSlot, outputPidSlot, n.getRetainNonArrays(), planNodeId);
+}
+
+void SBENodeLowering::generateSlots(const FieldProjectionMap& fieldProjectionMap,
+ boost::optional<sbe::value::SlotId>& ridSlot,
+ boost::optional<sbe::value::SlotId>& rootSlot,
+ std::vector<std::string>& fields,
+ sbe::value::SlotVector& vars) {
+ if (!fieldProjectionMap._ridProjection.empty()) {
+ ridSlot = _slotIdGenerator.generate();
+ _slotMap.emplace(fieldProjectionMap._ridProjection, ridSlot.value());
+ }
+ if (!fieldProjectionMap._rootProjection.empty()) {
+ rootSlot = _slotIdGenerator.generate();
+ _slotMap.emplace(fieldProjectionMap._rootProjection, rootSlot.value());
+ }
+ for (const auto& [fieldName, projectionName] : fieldProjectionMap._fieldProjections) {
+ vars.push_back(_slotIdGenerator.generate());
+ _slotMap.emplace(projectionName, vars.back());
+ fields.push_back(fieldName);
+ }
+}
+
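+// Extracts the target namespace from a scan definition. The options map is
+// expected to carry "database" and "uuid" entries (a "type" entry is consumed
+// by lowerScanNode below), e.g. {database: "test", uuid: "<uuid>", type: "mongod"}.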
+static NamespaceStringOrUUID parseFromScanDef(const ScanDefinition& def) {
+ const auto& dbName = def.getOptionsMap().at("database");
+ const auto& uuidStr = def.getOptionsMap().at("uuid");
+ return {dbName, UUID::parse(uuidStr).getValue()};
+}
+
+std::unique_ptr<sbe::PlanStage> SBENodeLowering::lowerScanNode(
+ const Node& n,
+ const std::string& scanDefName,
+ const FieldProjectionMap& fieldProjectionMap,
+ const bool useParallelScan) {
+ const ScanDefinition& def = _metadata._scanDefs.at(scanDefName);
+ uassert(6624231, "Collection must exist to lower Scan", def.exists());
+ auto& typeSpec = def.getOptionsMap().at("type");
+
+ boost::optional<sbe::value::SlotId> ridSlot;
+ boost::optional<sbe::value::SlotId> rootSlot;
+ std::vector<std::string> fields;
+ sbe::value::SlotVector vars;
+ generateSlots(fieldProjectionMap, ridSlot, rootSlot, fields, vars);
+
+ const PlanNodeId planNodeId = _nodeToGroupPropsMap.at(&n)._planNodeId;
+ if (typeSpec == "mongod") {
+ NamespaceStringOrUUID nss = parseFromScanDef(def);
+
+ // Unused.
+ boost::optional<sbe::value::SlotId> seekKeySlot;
+
+ sbe::ScanCallbacks callbacks({}, {}, {});
+ if (useParallelScan) {
+ return sbe::makeS<sbe::ParallelScanStage>(nss.uuid().get(),
+ rootSlot,
+ ridSlot,
+ boost::none,
+ boost::none,
+ boost::none,
+ boost::none,
+ fields,
+ vars,
+ nullptr /*yieldPolicy*/,
+ planNodeId,
+ callbacks);
+ } else {
+ return sbe::makeS<sbe::ScanStage>(nss.uuid().get(),
+ rootSlot,
+ ridSlot,
+ boost::none,
+ boost::none,
+ boost::none,
+ boost::none,
+ boost::none,
+ fields,
+ vars,
+ seekKeySlot,
+ true /*forward*/,
+ nullptr /*yieldPolicy*/,
+ planNodeId,
+ callbacks,
+ _randomScan);
+ }
+ } else {
+ uasserted(6624355, "Unknown scan type.");
+ }
+ return nullptr;
+}
+
+std::unique_ptr<sbe::PlanStage> SBENodeLowering::walk(const ScanNode& n, const ABT& /*binds*/) {
+ FieldProjectionMap fieldProjectionMap;
+ fieldProjectionMap._rootProjection = n.getProjectionName();
+ return lowerScanNode(n, n.getScanDefName(), fieldProjectionMap, false /*useParallelScan*/);
+}
+
+std::unique_ptr<sbe::PlanStage> SBENodeLowering::walk(const PhysicalScanNode& n,
+ const ABT& /*binds*/) {
+ return lowerScanNode(n, n.getScanDefName(), n.getFieldProjectionMap(), n.useParallelScan());
+}
+
+std::unique_ptr<sbe::PlanStage> SBENodeLowering::walk(const CoScanNode& n) {
+ const PlanNodeId planNodeId = _nodeToGroupPropsMap.at(&n)._planNodeId;
+ return sbe::makeS<sbe::CoScanStage>(planNodeId);
+}
+
+std::unique_ptr<sbe::EExpression> SBENodeLowering::convertBoundsToExpr(
+ const bool isLower,
+ const IndexDefinition& indexDef,
+ const MultiKeyIntervalRequirement& interval) {
+ std::vector<std::unique_ptr<sbe::EExpression>> ksFnArgs;
+ ksFnArgs.emplace_back(
+ sbe::makeE<sbe::EConstant>(sbe::value::TypeTags::NumberInt64,
+ sbe::value::bitcastFrom<int64_t>(indexDef.getVersion())));
+
+ // TODO: Confirm whether the index ordering should be serialized as an unsigned int32.
+ ksFnArgs.emplace_back(
+ sbe::makeE<sbe::EConstant>(sbe::value::TypeTags::NumberInt32,
+ sbe::value::bitcastFrom<uint32_t>(indexDef.getOrdering())));
+
+ auto exprLower = SBEExpressionLowering{_env, _slotMap};
+ bool inclusive = true;
+ bool fullyInfinite = true;
+ for (const auto& entry : interval) {
+ const BoundRequirement& entryBound = isLower ? entry.getLowBound() : entry.getHighBound();
+ const bool isInfinite = entryBound.isInfinite();
+ if (!isInfinite) {
+ fullyInfinite = false;
+ if (!entryBound.isInclusive()) {
+ inclusive = false;
+ }
+ }
+
+ ABT bound = isInfinite ? (isLower ? Constant::minKey() : Constant::maxKey())
+ : entryBound.getBound();
+ auto boundExpr = exprLower.optimize(std::move(bound));
+ ksFnArgs.emplace_back(std::move(boundExpr));
+ }
+ if (fullyInfinite && !isLower) {
+ // Only the upper bound may be skipped when fully infinite; for the lower
+ // bound we still need to generate MinKeys.
+ return nullptr;
+ }
+
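+ // The final argument is the key string discriminator: an inclusive lower bound
+ // or an exclusive upper bound positions before equal keys (value 1); the other
+ // combinations position after them (value 2). The exact enum mapping is an
+ // assumption here.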
+ ksFnArgs.emplace_back(sbe::makeE<sbe::EConstant>(
+ sbe::value::TypeTags::NumberInt64,
+ sbe::value::bitcastFrom<int64_t>((isLower == inclusive) ? 1 : 2)));
+ return sbe::makeE<sbe::EFunction>("ks", toInlinedVector(std::move(ksFnArgs)));
+}
+
+std::unique_ptr<sbe::PlanStage> SBENodeLowering::walk(const IndexScanNode& n, const ABT&) {
+ const auto& fieldProjectionMap = n.getFieldProjectionMap();
+ const auto& indexSpec = n.getIndexSpecification();
+ const auto& interval = indexSpec.getInterval();
+
+ const std::string& indexDefName = n.getIndexSpecification().getIndexDefName();
+ const ScanDefinition& scanDef = _metadata._scanDefs.at(indexSpec.getScanDefName());
+ uassert(6624232, "Collection must exist to lower IndexScan", scanDef.exists());
+ const IndexDefinition& indexDef = scanDef.getIndexDefs().at(indexDefName);
+
+ NamespaceStringOrUUID nss = parseFromScanDef(scanDef);
+
+ boost::optional<sbe::value::SlotId> ridSlot;
+ boost::optional<sbe::value::SlotId> rootSlot;
+ std::vector<std::string> fields;
+ sbe::value::SlotVector vars;
+ generateSlots(fieldProjectionMap, ridSlot, rootSlot, fields, vars);
+ uassert(6624233, "Cannot deliver root projection in this context", !rootSlot.has_value());
+
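+ // Field names over an index scan name positions in the compound index key;
+ // decodeIndexKeyName recovers each ordinal for the inclusion bitset.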
+ sbe::IndexKeysInclusionSet indexKeysToInclude;
+ for (const std::string& fieldName : fields) {
+ indexKeysToInclude.set(decodeIndexKeyName(fieldName), true);
+ }
+
+ auto lowerBoundExpr = convertBoundsToExpr(true /*isLower*/, indexDef, interval);
+ auto upperBoundExpr = convertBoundsToExpr(false /*isLower*/, indexDef, interval);
+ const bool hasLowerBound = lowerBoundExpr != nullptr;
+ const bool hasUpperBound = upperBoundExpr != nullptr;
+ uassert(6624234, "Invalid bounds combination", hasLowerBound || !hasUpperBound);
+
+ boost::optional<sbe::value::SlotId> seekKeySlotLower;
+ boost::optional<sbe::value::SlotId> seekKeySlotUpper;
+ sbe::value::SlotVector correlatedSlotsForJoin;
+
+ const PlanNodeId planNodeId = _nodeToGroupPropsMap.at(&n)._planNodeId;
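+ // The Limit-1 over CoScan below produces a single input row on which the key
+ // string bound expressions are projected; the LoopJoin at the end correlates
+ // those slots into the index scan.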
+ auto projectForKeyStringBounds = sbe::makeS<sbe::LimitSkipStage>(
+ sbe::makeS<sbe::CoScanStage>(planNodeId), 1, boost::none, planNodeId);
+ if (hasLowerBound) {
+ seekKeySlotLower = _slotIdGenerator.generate();
+ correlatedSlotsForJoin.push_back(seekKeySlotLower.value());
+ projectForKeyStringBounds = sbe::makeProjectStage(std::move(projectForKeyStringBounds),
+ planNodeId,
+ seekKeySlotLower.value(),
+ std::move(lowerBoundExpr));
+ }
+ if (hasUpperBound) {
+ seekKeySlotUpper = _slotIdGenerator.generate();
+ correlatedSlotsForJoin.push_back(seekKeySlotUpper.value());
+ projectForKeyStringBounds = sbe::makeProjectStage(std::move(projectForKeyStringBounds),
+ planNodeId,
+ seekKeySlotUpper.value(),
+ std::move(upperBoundExpr));
+ }
+
+ // Unused.
+ boost::optional<sbe::value::SlotId> resultSlot;
+
+ auto result = sbe::makeS<sbe::IndexScanStage>(nss.uuid().get(),
+ indexDefName,
+ !indexSpec.isReverseOrder(),
+ resultSlot,
+ ridSlot,
+ boost::none,
+ indexKeysToInclude,
+ vars,
+ seekKeySlotLower,
+ seekKeySlotUpper,
+ nullptr /*yieldPolicy*/,
+ planNodeId);
+
+ return sbe::makeS<sbe::LoopJoinStage>(std::move(projectForKeyStringBounds),
+ std::move(result),
+ sbe::makeSV(),
+ std::move(correlatedSlotsForJoin),
+ nullptr,
+ planNodeId);
+}
+
+std::unique_ptr<sbe::PlanStage> SBENodeLowering::walk(const SeekNode& n,
+ const ABT& /*binds*/,
+ const ABT& /*refs*/) {
+ const ScanDefinition& def = _metadata._scanDefs.at(n.getScanDefName());
+ uassert(6624235, "Collection must exist to lower Seek", def.exists());
+
+ auto& typeSpec = def.getOptionsMap().at("type");
+ uassert(6624236, "SeekNode only supports mongod collections", typeSpec == "mongod");
+ NamespaceStringOrUUID nss = parseFromScanDef(def);
+
+ boost::optional<sbe::value::SlotId> ridSlot;
+ boost::optional<sbe::value::SlotId> rootSlot;
+ std::vector<std::string> fields;
+ sbe::value::SlotVector vars;
+ generateSlots(n.getFieldProjectionMap(), ridSlot, rootSlot, fields, vars);
+
+ boost::optional<sbe::value::SlotId> seekKeySlot = _slotMap.at(n.getRIDProjectionName());
+
+ sbe::ScanCallbacks callbacks({}, {}, {});
+ const PlanNodeId planNodeId = _nodeToGroupPropsMap.at(&n)._planNodeId;
+ return sbe::makeS<sbe::ScanStage>(nss.uuid().get(),
+ rootSlot,
+ ridSlot,
+ boost::none,
+ boost::none,
+ boost::none,
+ boost::none,
+ boost::none,
+ fields,
+ vars,
+ seekKeySlot,
+ true /*forward*/,
+ nullptr /*yieldPolicy*/,
+ planNodeId,
+ callbacks);
+}
+
+} // namespace mongo::optimizer
diff --git a/src/mongo/db/exec/sbe/abt/abt_lower.h b/src/mongo/db/exec/sbe/abt/abt_lower.h
new file mode 100644
index 00000000000..17365f7925f
--- /dev/null
+++ b/src/mongo/db/exec/sbe/abt/abt_lower.h
@@ -0,0 +1,204 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include "mongo/db/exec/sbe/expressions/expression.h"
+#include "mongo/db/exec/sbe/stages/stages.h"
+#include "mongo/db/query/optimizer/node_defs.h"
+#include "mongo/db/query/optimizer/reference_tracker.h"
+#include "mongo/db/query/optimizer/utils/utils.h"
+
+namespace mongo::optimizer {
+using SlotVarMap = stdx::unordered_map<std::string, sbe::value::SlotId>;
+
+class SBEExpressionLowering {
+public:
+ SBEExpressionLowering(const VariableEnvironment& env, SlotVarMap& slotMap)
+ : _env(env), _slotMap(slotMap) {}
+
+ // The default noop transport.
+ template <typename T, typename... Ts>
+ std::unique_ptr<sbe::EExpression> transport(const T&, Ts&&...) {
+ uasserted(6624237, "abt tree is not lowered correctly");
+ return nullptr;
+ }
+
+ std::unique_ptr<sbe::EExpression> transport(const Constant&);
+ std::unique_ptr<sbe::EExpression> transport(const Variable& var);
+ std::unique_ptr<sbe::EExpression> transport(const Source&);
+ std::unique_ptr<sbe::EExpression> transport(const BinaryOp& op,
+ std::unique_ptr<sbe::EExpression> lhs,
+ std::unique_ptr<sbe::EExpression> rhs);
+ std::unique_ptr<sbe::EExpression> transport(const UnaryOp& op,
+ std::unique_ptr<sbe::EExpression> arg);
+ std::unique_ptr<sbe::EExpression> transport(const If&,
+ std::unique_ptr<sbe::EExpression> cond,
+ std::unique_ptr<sbe::EExpression> thenBranch,
+ std::unique_ptr<sbe::EExpression> elseBranch);
+
+ void prepare(const Let& let);
+ std::unique_ptr<sbe::EExpression> transport(const Let& let,
+ std::unique_ptr<sbe::EExpression> bind,
+ std::unique_ptr<sbe::EExpression> in);
+ void prepare(const LambdaAbstraction& lam);
+ std::unique_ptr<sbe::EExpression> transport(const LambdaAbstraction& lam,
+ std::unique_ptr<sbe::EExpression> body);
+ std::unique_ptr<sbe::EExpression> transport(const LambdaApplication&,
+ std::unique_ptr<sbe::EExpression> lam,
+ std::unique_ptr<sbe::EExpression> arg);
+ std::unique_ptr<sbe::EExpression> transport(
+ const FunctionCall& fn, std::vector<std::unique_ptr<sbe::EExpression>> args);
+
+ std::unique_ptr<sbe::EExpression> optimize(const ABT& n);
+
+private:
+ const VariableEnvironment& _env;
+ SlotVarMap& _slotMap;
+
+ sbe::FrameId _frameCounter{100};
+ stdx::unordered_map<const Let*, sbe::FrameId> _letMap;
+ stdx::unordered_map<const LambdaAbstraction*, sbe::FrameId> _lambdaMap;
+};
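+
+// A minimal expression-lowering sketch (illustrative only):
+//   auto env = VariableEnvironment::build(abt);
+//   SlotVarMap slotMap;
+//   auto expr = SBEExpressionLowering{env, slotMap}.optimize(abt);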
+
+class SBENodeLowering {
+public:
+ SBENodeLowering(const VariableEnvironment& env,
+ SlotVarMap& slotMap,
+ sbe::value::SlotIdGenerator& ids,
+ const Metadata& metadata,
+ const NodeToGroupPropsMap& nodeToGroupPropsMap,
+ const opt::unordered_map<std::string, ProjectionName>& ridProjections,
+ const bool randomScan = false)
+ : _env(env),
+ _slotMap(slotMap),
+ _slotIdGenerator(ids),
+ _metadata(metadata),
+ _nodeToGroupPropsMap(nodeToGroupPropsMap),
+ _ridProjections(ridProjections),
+ _randomScan(randomScan) {}
+
+ // The default noop transport.
+ template <typename T, typename... Ts>
+ std::unique_ptr<sbe::PlanStage> walk(const T&, Ts&&...) {
+ if constexpr (std::is_base_of_v<LogicalNode, T>) {
+ uasserted(6624238, "A physical plan should not contain purely logical nodes.");
+ }
+ return nullptr;
+ }
+
+ std::unique_ptr<sbe::PlanStage> walk(const RootNode& n, const ABT& child, const ABT& refs);
+ std::unique_ptr<sbe::PlanStage> walk(const EvaluationNode& n,
+ const ABT& child,
+ const ABT& binds);
+
+ std::unique_ptr<sbe::PlanStage> walk(const FilterNode& n, const ABT& child, const ABT& filter);
+
+ std::unique_ptr<sbe::PlanStage> walk(const LimitSkipNode& n, const ABT& child);
+ std::unique_ptr<sbe::PlanStage> walk(const ExchangeNode& n, const ABT& child, const ABT& refs);
+ std::unique_ptr<sbe::PlanStage> walk(const CollationNode& n, const ABT& child, const ABT& refs);
+
+ std::unique_ptr<sbe::PlanStage> walk(const UniqueNode& n, const ABT& child, const ABT& refs);
+
+ std::unique_ptr<sbe::PlanStage> walk(const GroupByNode& n,
+ const ABT& child,
+ const ABT& aggBinds,
+ const ABT& aggRefs,
+ const ABT& gbBind,
+ const ABT& gbRefs);
+
+ std::unique_ptr<sbe::PlanStage> walk(const BinaryJoinNode& n,
+ const ABT& leftChild,
+ const ABT& rightChild,
+ const ABT& filter);
+ std::unique_ptr<sbe::PlanStage> walk(const HashJoinNode& n,
+ const ABT& leftChild,
+ const ABT& rightChild,
+ const ABT& refs);
+ std::unique_ptr<sbe::PlanStage> walk(const MergeJoinNode& n,
+ const ABT& leftChild,
+ const ABT& rightChild,
+ const ABT& refs);
+
+ std::unique_ptr<sbe::PlanStage> walk(const UnionNode& n,
+ const ABTVector& children,
+ const ABT& binder,
+ const ABT& refs);
+
+ std::unique_ptr<sbe::PlanStage> walk(const UnwindNode& n,
+ const ABT& child,
+ const ABT& pidBind,
+ const ABT& refs);
+
+ std::unique_ptr<sbe::PlanStage> walk(const ScanNode& n, const ABT& /*binds*/);
+ std::unique_ptr<sbe::PlanStage> walk(const PhysicalScanNode& n, const ABT& /*binds*/);
+ std::unique_ptr<sbe::PlanStage> walk(const CoScanNode& n);
+
+ std::unique_ptr<sbe::PlanStage> walk(const IndexScanNode& n, const ABT& /*binds*/);
+ std::unique_ptr<sbe::PlanStage> walk(const SeekNode& n,
+ const ABT& /*binds*/,
+ const ABT& /*refs*/);
+
+ std::unique_ptr<sbe::PlanStage> optimize(const ABT& n);
+
+private:
+ std::unique_ptr<sbe::PlanStage> lowerScanNode(const Node& n,
+ const std::string& scanDefName,
+ const FieldProjectionMap& fieldProjectionMap,
+ bool useParallelScan);
+ void generateSlots(const FieldProjectionMap& fieldProjectionMap,
+ boost::optional<sbe::value::SlotId>& ridSlot,
+ boost::optional<sbe::value::SlotId>& rootSlot,
+ std::vector<std::string>& fields,
+ sbe::value::SlotVector& vars);
+
+ sbe::value::SlotVector convertProjectionsToSlots(const ProjectionNameVector& projectionNames);
+ sbe::value::SlotVector convertRequiredProjectionsToSlots(
+ const NodeProps& props, bool addRIDProjection, const ProjectionNameVector& toExclude = {});
+
+ std::unique_ptr<sbe::EExpression> convertBoundsToExpr(
+ bool isLower, const IndexDefinition& indexDef, const MultiKeyIntervalRequirement& interval);
+
+ std::unique_ptr<sbe::PlanStage> generateInternal(const ABT& n);
+
+ const VariableEnvironment& _env;
+ SlotVarMap& _slotMap;
+ sbe::value::SlotIdGenerator& _slotIdGenerator;
+
+ const Metadata& _metadata;
+ const NodeToGroupPropsMap& _nodeToGroupPropsMap;
+ const opt::unordered_map<std::string, ProjectionName>& _ridProjections;
+
+ // If true, will create scan nodes using a random cursor to support sampling.
+ // Currently only supported for single-threaded (non parallel-scanned) mongod collections.
+ // TODO: handle cases where we have more than one collection scan.
+ const bool _randomScan;
+};
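+
+// A minimal node-lowering sketch (illustrative; mirrors the usage in the unit
+// tests):
+//   auto env = VariableEnvironment::build(abt);
+//   SlotVarMap map;
+//   sbe::value::SlotIdGenerator ids;
+//   SBENodeLowering lowering{env, map, ids, metadata, nodeToGroupPropsMap, ridProjections};
+//   auto plan = lowering.optimize(abt);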
+
+} // namespace mongo::optimizer
diff --git a/src/mongo/db/exec/sbe/abt/sbe_abt_diff_test.cpp b/src/mongo/db/exec/sbe/abt/sbe_abt_diff_test.cpp
new file mode 100644
index 00000000000..3758bf6ac18
--- /dev/null
+++ b/src/mongo/db/exec/sbe/abt/sbe_abt_diff_test.cpp
@@ -0,0 +1,369 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/db/concurrency/lock_state.h"
+#include "mongo/db/exec/sbe/abt/abt_lower.h"
+#include "mongo/db/exec/sbe/abt/sbe_abt_test_util.h"
+#include "mongo/db/pipeline/abt/abt_document_source_visitor.h"
+#include "mongo/db/pipeline/document_source_queue.h"
+#include "mongo/db/query/optimizer/explain.h"
+#include "mongo/db/query/optimizer/opt_phase_manager.h"
+#include "mongo/db/query/plan_executor.h"
+#include "mongo/db/query/plan_executor_factory.h"
+#include "mongo/unittest/temp_dir.h"
+
+namespace mongo::optimizer {
+namespace {
+
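+// Wraps each input JSON document in its own single-element inner array and
+// collects these into an outer array, the shape expected by ValueScanNode:
+// inputs {a: 1} and {b: 2} become [[{a: 1}], [{b: 2}]].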
+ABT createValueArray(const std::vector<std::string>& jsonVector) {
+ const auto [tag, val] = sbe::value::makeNewArray();
+ auto outerArrayPtr = sbe::value::getArrayView(val);
+
+ for (const std::string& s : jsonVector) {
+ const auto [tag1, val1] = sbe::value::makeNewArray();
+ auto innerArrayPtr = sbe::value::getArrayView(val1);
+
+ const BSONObj& bsonObj = fromjson(s);
+ const auto [tag2, val2] =
+ sbe::value::copyValue(sbe::value::TypeTags::bsonObject,
+ sbe::value::bitcastFrom<const char*>(bsonObj.objdata()));
+ innerArrayPtr->push_back(tag2, val2);
+
+ outerArrayPtr->push_back(tag1, val1);
+ }
+
+ return make<Constant>(tag, val);
+}
+
+using ContextFn = std::function<ServiceContext::UniqueOperationContext()>;
+using ResultSet = std::vector<BSONObj>;
+
+ResultSet runSBEAST(const ContextFn& fn,
+ const std::string& pipelineStr,
+ const std::vector<std::string>& jsonVector) {
+ PrefixId prefixId;
+ Metadata metadata{{}};
+
+ auto pipeline = parsePipeline(pipelineStr, NamespaceString("test"));
+
+ ABT tree = createValueArray(jsonVector);
+
+ const ProjectionName scanProjName = prefixId.getNextId("scan");
+ tree = translatePipelineToABT(
+ metadata,
+ *pipeline,
+ scanProjName,
+ make<ValueScanNode>(ProjectionNameVector{scanProjName}, std::move(tree)),
+ prefixId);
+
+ std::cerr << "********* Translated ABT *********\n";
+ std::cerr << ExplainGenerator::explainV2(tree);
+ std::cerr << "********* Translated ABT *********\n";
+
+ OptPhaseManager phaseManager(
+ OptPhaseManager::getAllRewritesSet(), prefixId, {{}}, DebugInfo::kDefaultForTests);
+ ASSERT_TRUE(phaseManager.optimize(tree));
+
+ std::cerr << "********* Optimized ABT *********\n";
+ std::cerr << ExplainGenerator::explainV2(tree);
+ std::cerr << "********* Optimized ABT *********\n";
+
+ SlotVarMap map;
+ sbe::value::SlotIdGenerator ids;
+
+ auto env = VariableEnvironment::build(tree);
+ SBENodeLowering g{env,
+ map,
+ ids,
+ phaseManager.getMetadata(),
+ phaseManager.getNodeToGroupPropsMap(),
+ phaseManager.getRIDProjections()};
+ auto sbePlan = g.optimize(tree);
+ uassert(6624249, "Cannot optimize SBE plan", sbePlan != nullptr);
+
+ sbe::CompileCtx ctx(std::make_unique<sbe::RuntimeEnvironment>());
+ sbePlan->prepare(ctx);
+
+ std::vector<sbe::value::SlotAccessor*> accessors;
+ for (auto& [name, slot] : map) {
+ accessors.emplace_back(sbePlan->getAccessor(ctx, slot));
+ }
+ // For now assert we only have one final projection.
+ ASSERT_EQ(1, accessors.size());
+
+ sbePlan->attachToOperationContext(fn().get());
+ sbePlan->open(false);
+
+ ResultSet results;
+ while (sbePlan->getNext() != sbe::PlanState::IS_EOF) {
+ if (results.size() > 1000) {
+ uasserted(6624250, "Too many results!");
+ }
+
+ std::ostringstream os;
+ os << accessors.at(0)->getViewOfValue();
+ results.push_back(fromjson(os.str()));
+ }
+ sbePlan->close();
+
+ return results;
+}
+
+ResultSet runPipeline(const ContextFn& fn,
+ const std::string& pipelineStr,
+ const std::vector<std::string>& jsonVector) {
+ NamespaceString nss("test");
+ std::unique_ptr<mongo::Pipeline, mongo::PipelineDeleter> pipeline =
+ parsePipeline(pipelineStr, nss);
+
+ const auto queueStage = DocumentSourceQueue::create(pipeline->getContext());
+ for (const std::string& s : jsonVector) {
+ BSONObj bsonObj = fromjson(s);
+ queueStage->emplace_back(Document{bsonObj});
+ }
+
+ pipeline->addInitialSource(queueStage);
+
+ boost::intrusive_ptr<ExpressionContext> expCtx(new ExpressionContext(fn().get(), nullptr, nss));
+
+ std::unique_ptr<PlanExecutor, PlanExecutor::Deleter> planExec =
+ plan_executor_factory::make(expCtx, std::move(pipeline));
+
+ ResultSet results;
+ bool done = false;
+ while (!done) {
+ BSONObj outObj;
+ auto result = planExec->getNext(&outObj, nullptr);
+ switch (result) {
+ case PlanExecutor::ADVANCED:
+ results.push_back(outObj);
+ break;
+ case PlanExecutor::IS_EOF:
+ done = true;
+ break;
+ }
+ }
+
+ return results;
+}
+
+bool compareBSONObj(const BSONObj& actual, const BSONObj& expected, const bool preserveFieldOrder) {
+ BSONObj::ComparisonRulesSet rules = BSONObj::ComparisonRules::kConsiderFieldName;
+ if (!preserveFieldOrder) {
+ rules |= BSONObj::ComparisonRules::kIgnoreFieldOrder;
+ }
+ return actual.woCompare(expected, BSONObj(), rules) == 0;
+}
+
+bool compareResults(const ResultSet& expected,
+ const ResultSet& actual,
+ const bool preserveFieldOrder) {
+ if (expected.size() != actual.size()) {
+ std::cout << "Different result size: expected: " << expected.size()
+ << " vs actual: " << actual.size() << "\n";
+ if (!expected.empty()) {
+ std::cout << "First expected result: " << expected.front() << "\n";
+ }
+ if (!actual.empty()) {
+ std::cout << "First actual result: " << actual.front() << "\n";
+ }
+ return false;
+ }
+
+ for (size_t i = 0; i < expected.size(); i++) {
+ if (!compareBSONObj(actual.at(i), expected.at(i), preserveFieldOrder)) {
+ std::cout << "Result at position " << i << "/" << expected.size()
+ << " mismatch: expected: " << expected.at(i) << " vs actual: " << actual.at(i)
+ << "\n";
+ return false;
+ }
+ }
+
+ return true;
+}
+
+bool compareSBEABTAgainstExpected(const ContextFn& fn,
+ const std::string& pipelineStr,
+ const std::vector<std::string>& jsonVector,
+ const ResultSet& expected) {
+ const ResultSet& actual = runSBEAST(fn, pipelineStr, jsonVector);
+ return compareResults(expected, actual, true /*preserveFieldOrder*/);
+}
+
+bool comparePipelineAgainstExpected(const ContextFn& fn,
+ const std::string& pipelineStr,
+ const std::vector<std::string>& jsonVector,
+ const ResultSet& expected) {
+ const ResultSet& actual = runPipeline(fn, pipelineStr, jsonVector);
+ return compareResults(expected, actual, true /*preserveFieldOrder*/);
+}
+
+bool compareSBEABTAgainstPipeline(const ContextFn& fn,
+ const std::string& pipelineStr,
+ const std::vector<std::string>& jsonVector,
+ const bool preserveFieldOrder = true) {
+ const ResultSet& pipelineResults = runPipeline(fn, pipelineStr, jsonVector);
+ const ResultSet& sbeResults = runSBEAST(fn, pipelineStr, jsonVector);
+
+ std::cout << "Pipeline: " << pipelineStr << ", input size: " << jsonVector.size() << "\n";
+ const bool result = compareResults(pipelineResults, sbeResults, preserveFieldOrder);
+ if (result) {
+ std::cout << "Success. Result count: " << pipelineResults.size() << "\n";
+ constexpr size_t maxResults = 1;
+ for (size_t i = 0; i < std::min(pipelineResults.size(), maxResults); i++) {
+ std::cout << "Result " << (i + 1) << "/" << pipelineResults.size()
+ << ": expected (pipeline): " << pipelineResults.at(i)
+ << " vs actual (SBE): " << sbeResults.at(i) << "\n";
+ }
+ }
+
+ return result;
+}
+
+ResultSet toResultSet(const std::vector<std::string>& jsonVector) {
+ ResultSet results;
+ for (const std::string& jsonStr : jsonVector) {
+ results.emplace_back(fromjson(jsonStr));
+ }
+ return results;
+}
+
+class TestObserver : public ServiceContext::ClientObserver {
+public:
+ TestObserver() = default;
+ ~TestObserver() = default;
+
+ void onCreateClient(Client* client) final {}
+
+ void onDestroyClient(Client* client) final {}
+
+ void onCreateOperationContext(OperationContext* opCtx) override {
+ opCtx->setLockState(std::make_unique<LockerImpl>());
+ }
+
+ void onDestroyOperationContext(OperationContext* opCtx) final {}
+};
+
+const ServiceContext::ConstructorActionRegisterer clientObserverRegisterer{
+ "TestObserver",
+ [](ServiceContext* service) {
+ service->registerClientObserver(std::make_unique<TestObserver>());
+ },
+ [](ServiceContext* serviceContext) {}};
+
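+// Diff tests run the same pipeline twice: once through the regular Pipeline
+// executor and once through ABT translation, optimization, and SBE lowering,
+// then compare the two result sets.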
+TEST_F(NodeSBE, DiffTestBasic) {
+ const auto contextFn = [this]() { return makeOperationContext(); };
+ const auto compare = [&contextFn](const std::string& pipelineStr,
+ const std::vector<std::string>& jsonVector) {
+ return compareSBEABTAgainstPipeline(
+ contextFn, pipelineStr, jsonVector, true /*preserveFieldOrder*/);
+ };
+
+ ASSERT_TRUE(compareSBEABTAgainstExpected(
+ contextFn, "[]", {"{a:1, b:2, c:3}"}, toResultSet({"{ a: 1, b: 2, c: 3 }"})));
+ ASSERT_TRUE(compareSBEABTAgainstExpected(contextFn,
+ "[{$addFields: {c: {$literal: 3}}}]",
+ {"{a:1, b:2}"},
+ toResultSet({"{ a: 1, b: 2, c: 3 }"})));
+
+ ASSERT_TRUE(comparePipelineAgainstExpected(
+ contextFn, "[]", {"{a:1, b:2, c:3}"}, toResultSet({"{ a: 1, b: 2, c: 3 }"})));
+ ASSERT_TRUE(comparePipelineAgainstExpected(contextFn,
+ "[{$addFields: {c: {$literal: 3}}}]",
+ {"{a:1, b:2}"},
+ toResultSet({"{ a: 1, b: 2, c: 3 }"})));
+
+ ASSERT_TRUE(compare("[]", {"{a:1, b:2, c:3}"}));
+ ASSERT_TRUE(compare("[{$addFields: {c: {$literal: 3}}}]", {"{a:1, b:2}"}));
+}
+
+TEST_F(NodeSBE, DiffTest) {
+ const auto contextFn = [this]() { return makeOperationContext(); };
+ const auto compare = [&contextFn](const std::string& pipelineStr,
+ const std::vector<std::string>& jsonVector) {
+ return compareSBEABTAgainstPipeline(
+ contextFn, pipelineStr, jsonVector, true /*preserveFieldOrder*/);
+ };
+
+ // If a comparison fails here, consider first checking whether the ordered compare() passes.
+ const auto compareUnordered = [&contextFn](const std::string& pipelineStr,
+ const std::vector<std::string>& jsonVector) {
+ return compareSBEABTAgainstPipeline(
+ contextFn, pipelineStr, jsonVector, false /*preserveFieldOrder*/);
+ };
+
+ ASSERT_TRUE(compare("[]", {}));
+
+ ASSERT_TRUE(compare("[{$project: {a: 1, b: 1}}]", {"{a: 10, b: 20, c: 30}"}));
+ ASSERT_TRUE(compare("[{$match: {a: 2}}]", {"{a: [1, 2, 3, 4]}"}));
+ ASSERT_TRUE(compare("[{$match: {a: 5}}]", {"{a: [1, 2, 3, 4]}"}));
+ ASSERT_TRUE(compare("[{$match: {a: {$gte: 3}}}]", {"{a: [1, 2, 3, 4]}"}));
+ ASSERT_TRUE(compare("[{$match: {a: {$gte: 30}}}]", {"{a: [1, 2, 3, 4]}"}));
+ ASSERT_TRUE(
+ compare("[{$match: {a: {$elemMatch: {$gte: 2, $lte: 3}}}}]", {"{a: [1, 2, 3, 4]}"}));
+ ASSERT_TRUE(
+ compare("[{$match: {a: {$elemMatch: {$gte: 20, $lte: 30}}}}]", {"{a: [1, 2, 3, 4]}"}));
+
+ ASSERT_TRUE(compare("[{$project: {'a.b': '$c'}}]", {"{a: {d: 1}, c: 2}"}));
+ ASSERT_TRUE(compare("[{$project: {'a.b': '$c'}}]", {"{a: [{d: 1}, {d: 2}, {b: 10}], c: 2}"}));
+
+ ASSERT_TRUE(compareUnordered("[{$project: {'a.b': '$c', c: 1}}]", {"{a: {d: 1}, c: 2}"}));
+ ASSERT_TRUE(compareUnordered("[{$project: {'a.b': '$c', 'a.d': 1, c: 1}}]",
+ {"{a: [{d: 1}, {d: 2}, {b: 10}], c: 2}"}));
+
+ ASSERT_TRUE(
+ compare("[{$project: {a: {$filter: {input: '$b', as: 'num', cond: {$and: [{$gte: ['$$num', "
+ "2]}, {$lte: ['$$num', 3]}]}}}}}]",
+ {"{b: [1, 2, 3, 4]}"}));
+ ASSERT_TRUE(
+ compare("[{$project: {a: {$filter: {input: '$b', as: 'num', cond: {$and: [{$gte: ['$$num', "
+ "3]}, {$lte: ['$$num', 2]}]}}}}}]",
+ {"{b: [1, 2, 3, 4]}"}));
+
+ ASSERT_TRUE(compare("[{$unwind: {path: '$a'}}]", {"{a: [1, 2, 3, 4]}"}));
+ ASSERT_TRUE(compare("[{$unwind: {path: '$a.b'}}]", {"{a: {b: [1, 2, 3, 4]}}"}));
+
+ ASSERT_TRUE(compare("[{$match:{'a.b.c':'aaa'}}]", {"{a: {b: {c: 'aaa'}}}"}));
+ ASSERT_TRUE(
+ compare("[{$match:{'a.b.c':'aaa'}}]", {"{a: {b: {c: 'aaa'}}}", "{a: {b: {c: 'aaa'}}}"}));
+
+ ASSERT_TRUE(compare("[{$match: {a: {$lt: 5, $gt: 5}}}]", {"{_id: 1, a: [4, 6]}"}));
+ ASSERT_TRUE(compare("[{$match: {a: {$gt: null}}}]", {"{_id: 1, a: 1}"}));
+
+ ASSERT_TRUE(compare("[{$match: {a: {$elemMatch: {$lt: 6, $gt: 4}}}}]", {"{a: [5]}"}));
+ ASSERT_TRUE(compare("[{$match: {'a.b': {$elemMatch: {$lt: 6, $gt: 4}}}}]",
+ {"{a: {b: [5]}}", "{a: [{b: 5}]}"}));
+
+ ASSERT_TRUE(compare("[{$match: {a: {$elemMatch: {$elemMatch: {$lt: 6, $gt: 4}}}}}]",
+ {"{a: [[4, 5, 6], [5]]}", "{a: [4, 5, 6]}"}));
+}
+
+} // namespace
+} // namespace mongo::optimizer
diff --git a/src/mongo/db/exec/sbe/abt/sbe_abt_test.cpp b/src/mongo/db/exec/sbe/abt/sbe_abt_test.cpp
new file mode 100644
index 00000000000..a8dbb516982
--- /dev/null
+++ b/src/mongo/db/exec/sbe/abt/sbe_abt_test.cpp
@@ -0,0 +1,311 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/db/exec/sbe/abt/abt_lower.h"
+#include "mongo/db/exec/sbe/abt/sbe_abt_test_util.h"
+#include "mongo/db/pipeline/abt/abt_document_source_visitor.h"
+#include "mongo/db/query/optimizer/explain.h"
+#include "mongo/db/query/optimizer/opt_phase_manager.h"
+#include "mongo/db/query/optimizer/rewrites/const_eval.h"
+#include "mongo/db/query/optimizer/rewrites/path_lower.h"
+#include "mongo/unittest/unittest.h"
+
+namespace mongo::optimizer {
+namespace {
+
+TEST_F(ABTSBE, Lower1) {
+ auto tree = Constant::int64(100);
+ auto env = VariableEnvironment::build(tree);
+ SlotVarMap map;
+
+ auto expr = SBEExpressionLowering{env, map}.optimize(tree);
+
+ ASSERT(expr);
+
+ auto compiledExpr = compileExpression(*expr);
+ auto [resultTag, resultVal] = runCompiledExpression(compiledExpr.get());
+
+ ASSERT_EQ(sbe::value::TypeTags::NumberInt64, resultTag);
+ ASSERT_EQ(sbe::value::bitcastTo<int64_t>(resultVal), 100);
+}
+
+TEST_F(ABTSBE, Lower2) {
+ auto tree =
+ make<Let>("x",
+ Constant::int64(100),
+ make<BinaryOp>(Operations::Add, make<Variable>("x"), Constant::int64(100)));
+
+ auto env = VariableEnvironment::build(tree);
+ SlotVarMap map;
+
+ auto expr = SBEExpressionLowering{env, map}.optimize(tree);
+
+ ASSERT(expr);
+
+ auto compiledExpr = compileExpression(*expr);
+ auto [resultTag, resultVal] = runCompiledExpression(compiledExpr.get());
+
+ ASSERT_EQ(sbe::value::TypeTags::NumberInt64, resultTag);
+ ASSERT_EQ(sbe::value::bitcastTo<int64_t>(resultVal), 200);
+}
+
+TEST_F(ABTSBE, Lower3) {
+ auto tree = make<FunctionCall>("isNumber", makeSeq(Constant::int64(10)));
+ auto env = VariableEnvironment::build(tree);
+ SlotVarMap map;
+
+ auto expr = SBEExpressionLowering{env, map}.optimize(tree);
+
+ ASSERT(expr);
+
+ auto compiledExpr = compileExpression(*expr);
+ auto result = runCompiledExpressionPredicate(compiledExpr.get());
+
+ ASSERT(result);
+}
+
+TEST_F(ABTSBE, Lower4) {
+ auto [tagArr, valArr] = sbe::value::makeNewArray();
+ auto arr = sbe::value::getArrayView(valArr);
+ arr->push_back(sbe::value::TypeTags::NumberInt64, 1);
+ arr->push_back(sbe::value::TypeTags::NumberInt64, 2);
+ auto [tagArrNest, valArrNest] = sbe::value::makeNewArray();
+ auto arrNest = sbe::value::getArrayView(valArrNest);
+ arrNest->push_back(sbe::value::TypeTags::NumberInt64, 21);
+ arrNest->push_back(sbe::value::TypeTags::NumberInt64, 22);
+ arr->push_back(tagArrNest, valArrNest);
+ arr->push_back(sbe::value::TypeTags::NumberInt64, 3);
+
+ auto tree = make<FunctionCall>(
+ "traverseP",
+ makeSeq(
+ make<Constant>(tagArr, valArr),
+ make<LambdaAbstraction>(
+ "x", make<BinaryOp>(Operations::Add, make<Variable>("x"), Constant::int64(10)))));
+ auto env = VariableEnvironment::build(tree);
+ SlotVarMap map;
+
+ auto expr = SBEExpressionLowering{env, map}.optimize(tree);
+
+ ASSERT(expr);
+
+ auto compiledExpr = compileExpression(*expr);
+ auto [resultTag, resultVal] = runCompiledExpression(compiledExpr.get());
+ sbe::value::ValueGuard guard(resultTag, resultVal);
+
+ ASSERT_EQ(sbe::value::TypeTags::Array, resultTag);
+}
+
+TEST_F(ABTSBE, Lower5) {
+ auto tree = make<FunctionCall>(
+ "setField", makeSeq(Constant::nothing(), Constant::str("fieldA"), Constant::int64(10)));
+
+ auto env = VariableEnvironment::build(tree);
+ SlotVarMap map;
+
+ auto expr = SBEExpressionLowering{env, map}.optimize(tree);
+
+ ASSERT(expr);
+
+ auto compiledExpr = compileExpression(*expr);
+ auto [resultTag, resultVal] = runCompiledExpression(compiledExpr.get());
+ sbe::value::ValueGuard guard(resultTag, resultVal);
+}
+
+TEST_F(ABTSBE, Lower6) {
+ PrefixId prefixId;
+
+ auto [tagObj, valObj] = sbe::value::makeNewObject();
+ auto obj = sbe::value::getObjectView(valObj);
+
+ auto [tagObjIn, valObjIn] = sbe::value::makeNewObject();
+ auto objIn = sbe::value::getObjectView(valObjIn);
+ objIn->push_back("fieldB", sbe::value::TypeTags::NumberInt64, 100);
+ obj->push_back("fieldA", tagObjIn, valObjIn);
+
+ sbe::value::OwnedValueAccessor accessor;
+ auto slotId = bindAccessor(&accessor);
+ SlotVarMap map;
+ map["root"] = slotId;
+
+ accessor.reset(tagObj, valObj);
+
+ auto tree = make<EvalPath>(
+ make<PathField>("fieldA",
+ make<PathTraverse>(make<PathComposeM>(
+ make<PathField>("fieldB", make<PathDefault>(Constant::int64(0))),
+ make<PathField>("fieldC", make<PathConstant>(Constant::int64(50)))))),
+ make<Variable>("root"));
+ auto env = VariableEnvironment::build(tree);
+
+ // Run the rewrites until no more changes are made (fixpoint).
+ bool changed = false;
+ do {
+ changed = false;
+ if (PathLowering{prefixId, env}.optimize(tree)) {
+ changed = true;
+ }
+ if (ConstEval{env}.optimize(tree)) {
+ changed = true;
+ }
+ } while (changed);
+
+ // std::cout << ExplainGenerator::explain(tree);
+
+ auto expr = SBEExpressionLowering{env, map}.optimize(tree);
+
+ ASSERT(expr);
+
+ auto compiledExpr = compileExpression(*expr);
+ auto [resultTag, resultVal] = runCompiledExpression(compiledExpr.get());
+ sbe::value::ValueGuard guard(resultTag, resultVal);
+
+ // std::cout << std::pair{resultTag, resultVal} << "\n";
+
+ ASSERT_EQ(sbe::value::TypeTags::Object, resultTag);
+}
+
+TEST_F(ABTSBE, Lower7) {
+ PrefixId prefixId;
+
+ auto [tagArr, valArr] = sbe::value::makeNewArray();
+ auto arr = sbe::value::getArrayView(valArr);
+ arr->push_back(sbe::value::TypeTags::NumberInt64, 1);
+ arr->push_back(sbe::value::TypeTags::NumberInt64, 2);
+ arr->push_back(sbe::value::TypeTags::NumberInt64, 3);
+
+ auto [tagObj, valObj] = sbe::value::makeNewObject();
+ auto obj = sbe::value::getObjectView(valObj);
+ obj->push_back("fieldA", tagArr, valArr);
+
+ sbe::value::OwnedValueAccessor accessor;
+ auto slotId = bindAccessor(&accessor);
+ SlotVarMap map;
+ map["root"] = slotId;
+
+ accessor.reset(tagObj, valObj);
+ auto tree = make<EvalFilter>(
+ make<PathGet>("fieldA",
+ make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(2)))),
+ make<Variable>("root"));
+
+ auto env = VariableEnvironment::build(tree);
+
+ // Run the rewrites until no more changes are made (fixpoint).
+ bool changed = false;
+ do {
+ changed = false;
+ if (PathLowering{prefixId, env}.optimize(tree)) {
+ changed = true;
+ }
+ if (ConstEval{env}.optimize(tree)) {
+ changed = true;
+ }
+ } while (changed);
+
+ auto expr = SBEExpressionLowering{env, map}.optimize(tree);
+
+ ASSERT(expr);
+ auto compiledExpr = compileExpression(*expr);
+ auto result = runCompiledExpressionPredicate(compiledExpr.get());
+
+ ASSERT(result);
+}
+
+TEST_F(NodeSBE, Lower1) {
+ PrefixId prefixId;
+ Metadata metadata{{}};
+
+ auto pipeline =
+ parsePipeline("[{$project:{'a.b.c.d':{$literal:'abc'}}}]", NamespaceString("test"));
+
+ const auto [tag, val] = sbe::value::makeNewArray();
+ {
+ // Create an array of arrays with a single empty document.
+ auto outerArrayPtr = sbe::value::getArrayView(val);
+
+ const auto [tag1, val1] = sbe::value::makeNewArray();
+ auto innerArrayPtr = sbe::value::getArrayView(val1);
+
+ const auto [tag2, val2] = sbe::value::makeNewObject();
+ innerArrayPtr->push_back(tag2, val2);
+
+ outerArrayPtr->push_back(tag1, val1);
+ }
+ ABT tree = make<Constant>(tag, val);
+
+ const ProjectionName scanProjName = prefixId.getNextId("scan");
+ tree = translatePipelineToABT(
+ metadata,
+ *pipeline,
+ scanProjName,
+ make<ValueScanNode>(ProjectionNameVector{scanProjName}, std::move(tree)),
+ prefixId);
+
+ OptPhaseManager phaseManager(
+ OptPhaseManager::getAllRewritesSet(), prefixId, {{}}, DebugInfo::kDefaultForTests);
+
+ ASSERT_TRUE(phaseManager.optimize(tree));
+ auto env = VariableEnvironment::build(tree);
+ SlotVarMap map;
+ sbe::value::SlotIdGenerator ids;
+
+ SBENodeLowering g{env,
+ map,
+ ids,
+ phaseManager.getMetadata(),
+ phaseManager.getNodeToGroupPropsMap(),
+ phaseManager.getRIDProjections()};
+
+ auto sbePlan = g.optimize(tree);
+
+ auto opCtx = makeOperationContext();
+
+ sbe::CompileCtx ctx(std::make_unique<sbe::RuntimeEnvironment>());
+ sbePlan->prepare(ctx);
+
+ std::vector<sbe::value::SlotAccessor*> accessors;
+ for (auto& [name, slot] : map) {
+ std::cout << name << " ";
+ accessors.emplace_back(sbePlan->getAccessor(ctx, slot));
+ }
+ std::cout << "\n";
+ sbePlan->attachToOperationContext(opCtx.get());
+ sbePlan->open(false);
+ while (sbePlan->getNext() != sbe::PlanState::IS_EOF) {
+ for (auto acc : accessors) {
+ std::cout << acc->getViewOfValue() << " ";
+ }
+ std::cout << "\n";
+ }
+ sbePlan->close();
+}
+
+} // namespace
+} // namespace mongo::optimizer
diff --git a/src/mongo/db/exec/sbe/abt/sbe_abt_test_util.cpp b/src/mongo/db/exec/sbe/abt/sbe_abt_test_util.cpp
new file mode 100644
index 00000000000..95b9796c324
--- /dev/null
+++ b/src/mongo/db/exec/sbe/abt/sbe_abt_test_util.cpp
@@ -0,0 +1,63 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/db/exec/sbe/abt/sbe_abt_test_util.h"
+#include "mongo/db/pipeline/aggregate_command_gen.h"
+#include "mongo/db/pipeline/expression_context_for_test.h"
+#include "mongo/unittest/temp_dir.h"
+
+namespace mongo::optimizer {
+
+static std::unique_ptr<mongo::Pipeline, mongo::PipelineDeleter> parsePipelineInternal(
+ NamespaceString nss, const std::string& inputPipeline, OperationContextNoop& opCtx) {
+ const BSONObj inputBson = fromjson("{pipeline: " + inputPipeline + "}");
+
+ std::vector<BSONObj> rawPipeline;
+ for (auto&& stageElem : inputBson["pipeline"].Array()) {
+ ASSERT_EQUALS(stageElem.type(), BSONType::Object);
+ rawPipeline.push_back(stageElem.embeddedObject());
+ }
+
+ AggregateCommandRequest request(std::move(nss), rawPipeline);
+ boost::intrusive_ptr<ExpressionContextForTest> ctx(
+ new ExpressionContextForTest(&opCtx, request));
+
+ unittest::TempDir tempDir("ABTPipelineTest");
+ ctx->tempDir = tempDir.path();
+
+ return Pipeline::parse(request.getPipeline(), ctx);
+}
+
+std::unique_ptr<mongo::Pipeline, mongo::PipelineDeleter> parsePipeline(
+ const std::string& pipelineStr, NamespaceString nss) {
+ OperationContextNoop opCtx;
+ return parsePipelineInternal(std::move(nss), pipelineStr, opCtx);
+}
+
+} // namespace mongo::optimizer
diff --git a/src/mongo/db/exec/sbe/abt/sbe_abt_test_util.h b/src/mongo/db/exec/sbe/abt/sbe_abt_test_util.h
new file mode 100644
index 00000000000..c56d035b6ff
--- /dev/null
+++ b/src/mongo/db/exec/sbe/abt/sbe_abt_test_util.h
@@ -0,0 +1,48 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/db/exec/sbe/expression_test_base.h"
+#include "mongo/db/operation_context_noop.h"
+#include "mongo/db/pipeline/pipeline.h"
+#include "mongo/db/query/optimizer/node.h"
+#include "mongo/db/query/optimizer/utils/utils.h"
+#include "mongo/db/service_context_test_fixture.h"
+
+#pragma once
+
+namespace mongo::optimizer {
+
+class NodeSBE : public ServiceContextTest {};
+
+std::unique_ptr<mongo::Pipeline, mongo::PipelineDeleter> parsePipeline(
+ const std::string& pipelineStr, NamespaceString nss);
+
+using ABTSBE = sbe::EExpressionTestFixture;
+
+} // namespace mongo::optimizer
diff --git a/src/mongo/db/exec/sbe/stages/exchange.cpp b/src/mongo/db/exec/sbe/stages/exchange.cpp
index 95b143ebec5..8cd7b065559 100644
--- a/src/mongo/db/exec/sbe/stages/exchange.cpp
+++ b/src/mongo/db/exec/sbe/stages/exchange.cpp
@@ -33,6 +33,7 @@
#include "mongo/base/init.h"
#include "mongo/db/client.h"
+#include "mongo/db/concurrency/d_concurrency.h"
#include "mongo/db/exec/sbe/size_estimator.h"
namespace mongo::sbe {
@@ -389,7 +390,10 @@ void ExchangeConsumer::close() {
std::unique_ptr<PlanStageStats> ExchangeConsumer::getStats(bool includeDebugInfo) const {
auto ret = std::make_unique<PlanStageStats>(_commonStats);
- ret->children.emplace_back(_children[0]->getStats(includeDebugInfo));
+ if (!_children.empty()) {
+ // TODO: handle empty _children.
+ ret->children.emplace_back(_children[0]->getStats(includeDebugInfo));
+ }
return ret;
}
@@ -428,8 +432,11 @@ std::vector<DebugPrinter::Block> ExchangeConsumer::debugPrint() const {
uasserted(4822835, "policy not yet implemented");
}
- DebugPrinter::addNewLine(ret);
- DebugPrinter::addBlocks(ret, _children[0]->debugPrint());
+ if (!_children.empty()) {
+ // TODO: handle empty _children.
+ DebugPrinter::addNewLine(ret);
+ DebugPrinter::addBlocks(ret, _children[0]->debugPrint());
+ }
return ret;
}
@@ -497,6 +504,9 @@ void ExchangeProducer::start(OperationContext* opCtx,
std::unique_ptr<PlanStage> producer) {
ExchangeProducer* p = static_cast<ExchangeProducer*>(producer.get());
+ // TODO: SERVER-62925. Rationalize this lock.
+ Lock::GlobalLock lock(opCtx, MODE_IS);
+
p->attachToOperationContext(opCtx);
try {
diff --git a/src/mongo/db/exec/sbe/stages/scan.cpp b/src/mongo/db/exec/sbe/stages/scan.cpp
index ad3d1b7e06b..cf1e40bf6cd 100644
--- a/src/mongo/db/exec/sbe/stages/scan.cpp
+++ b/src/mongo/db/exec/sbe/stages/scan.cpp
@@ -55,7 +55,8 @@ ScanStage::ScanStage(UUID collectionUuid,
bool forward,
PlanYieldPolicy* yieldPolicy,
PlanNodeId nodeId,
- ScanCallbacks scanCallbacks)
+ ScanCallbacks scanCallbacks,
+ bool useRandomCursor)
: PlanStage(seekKeySlot ? "seek"_sd : "scan"_sd, yieldPolicy, nodeId),
_collUuid(collectionUuid),
_recordSlot(recordSlot),
@@ -69,7 +70,8 @@ ScanStage::ScanStage(UUID collectionUuid,
_vars(std::move(vars)),
_seekKeySlot(seekKeySlot),
_forward(forward),
- _scanCallbacks(std::move(scanCallbacks)) {
+ _scanCallbacks(std::move(scanCallbacks)),
+ _useRandomCursor(useRandomCursor) {
invariant(_fields.size() == _vars.size());
invariant(!_seekKeySlot || _forward);
tassert(5567202,
@@ -77,6 +79,8 @@ ScanStage::ScanStage(UUID collectionUuid,
!_oplogTsSlot ||
(std::find(_fields.begin(), _fields.end(), repl::OpTime::kTimestampFieldName) !=
_fields.end()));
+ // We cannot use a random cursor if we are seeking or requesting a reverse scan.
+ invariant(!_useRandomCursor || (!_seekKeySlot && _forward));
}
std::unique_ptr<PlanStage> ScanStage::clone() const {
@@ -197,12 +201,12 @@ void ScanStage::doSaveState(bool relinquishCursor) {
#endif
- if (_cursor && relinquishCursor) {
- _cursor->save();
+ if (auto cursor = getActiveCursor(); cursor != nullptr && relinquishCursor) {
+ cursor->save();
}
- if (_cursor) {
- _cursor->setSaveStorageCursorOnDetachFromOperationContext(!relinquishCursor);
+ if (auto cursor = getActiveCursor()) {
+ cursor->setSaveStorageCursorOnDetachFromOperationContext(!relinquishCursor);
}
_coll.reset();
@@ -220,10 +224,10 @@ void ScanStage::doRestoreState(bool relinquishCursor) {
tassert(5777408, "Catalog epoch should be initialized", _catalogEpoch);
_coll = restoreCollection(_opCtx, *_collName, _collUuid, *_catalogEpoch);
- if (_cursor) {
+ if (auto cursor = getActiveCursor(); cursor != nullptr) {
if (relinquishCursor) {
const auto tolerateCappedCursorRepositioning = false;
- const bool couldRestore = _cursor->restore(tolerateCappedCursorRepositioning);
+ const bool couldRestore = cursor->restore(tolerateCappedCursorRepositioning);
uassert(
ErrorCodes::CappedPositionLost,
str::stream()
@@ -262,14 +266,14 @@ void ScanStage::doRestoreState(bool relinquishCursor) {
}
void ScanStage::doDetachFromOperationContext() {
- if (_cursor) {
- _cursor->detachFromOperationContext();
+ if (auto cursor = getActiveCursor()) {
+ cursor->detachFromOperationContext();
}
}
void ScanStage::doAttachToOperationContext(OperationContext* opCtx) {
- if (_cursor) {
- _cursor->reattachToOperationContext(opCtx);
+ if (auto cursor = getActiveCursor()) {
+ cursor->reattachToOperationContext(opCtx);
}
}
@@ -283,6 +287,10 @@ PlanStage::TrialRunTrackerAttachResultMask ScanStage::doAttachToTrialRunTracker(
return childrenAttachResult | TrialRunTrackerAttachResultFlags::AttachedToStreamingStage;
}
+RecordCursor* ScanStage::getActiveCursor() const {
+ return _useRandomCursor ? _randomCursor.get() : _cursor.get();
+}
+
void ScanStage::open(bool reOpen) {
auto optTimer(getOptTimer(_opCtx));
@@ -292,13 +300,13 @@ void ScanStage::open(bool reOpen) {
if (_open) {
tassert(5071001, "reopened ScanStage but reOpen=false", reOpen);
tassert(5071002, "ScanStage is open but _coll is not null", _coll);
- tassert(5071003, "ScanStage is open but don't have _cursor", _cursor);
+ tassert(5071003, "ScanStage is open but doesn't have a cursor", getActiveCursor());
} else {
tassert(5071004, "first open to ScanStage but reOpen=true", !reOpen);
if (!_coll) {
// We're being opened after 'close()'. We need to re-acquire '_coll' in this case and
// make some validity checks (the collection has not been dropped, renamed, etc.).
- tassert(5071005, "ScanStage is not open but have _cursor", !_cursor);
+ tassert(5071005, "ScanStage is not open but has a cursor", !getActiveCursor());
tassert(5777401, "Collection name should be initialized", _collName);
tassert(5777402, "Catalog epoch should be initialized", _catalogEpoch);
_coll = restoreCollection(_opCtx, *_collName, _collUuid, *_catalogEpoch);
@@ -321,7 +329,11 @@ void ScanStage::open(bool reOpen) {
}
if (!_cursor || !_seekKeyAccessor) {
- _cursor = _coll->getCursor(_opCtx, _forward);
+ if (_useRandomCursor) {
+ _randomCursor = _coll->getRecordStore()->getRandomCursor(_opCtx);
+ } else {
+ _cursor = _coll->getCursor(_opCtx, _forward);
+ }
}
} else {
MONGO_UNREACHABLE_TASSERT(5959701);
@@ -355,7 +367,8 @@ PlanState ScanStage::getNext() {
}
auto res = _firstGetNext && _seekKeyAccessor;
- auto nextRecord = res ? _cursor->seekExact(_key) : _cursor->next();
+ auto nextRecord = _useRandomCursor ? _randomCursor->next()
+ : (res ? _cursor->seekExact(_key) : _cursor->next());
_firstGetNext = false;
if (!nextRecord) {
@@ -445,6 +458,7 @@ void ScanStage::close() {
trackClose();
_cursor.reset();
+ _randomCursor.reset();
_coll.reset();
_open = false;
}
@@ -532,6 +546,10 @@ std::vector<DebugPrinter::Block> ScanStage::debugPrint() const {
DebugPrinter::addIdentifier(ret, DebugPrinter::kNoneKeyword);
}
+ if (_useRandomCursor) {
+ DebugPrinter::addKeyword(ret, "random");
+ }
+
ret.emplace_back(DebugPrinter::Block("[`"));
for (size_t idx = 0; idx < _fields.size(); ++idx) {
if (idx) {
diff --git a/src/mongo/db/exec/sbe/stages/scan.h b/src/mongo/db/exec/sbe/stages/scan.h
index 95982e6eb0c..37462ac5e14 100644
--- a/src/mongo/db/exec/sbe/stages/scan.h
+++ b/src/mongo/db/exec/sbe/stages/scan.h
@@ -107,7 +107,8 @@ public:
bool forward,
PlanYieldPolicy* yieldPolicy,
PlanNodeId nodeId,
- ScanCallbacks scanCallbacks);
+ ScanCallbacks scanCallbacks,
+ bool useRandomCursor = false);
std::unique_ptr<PlanStage> clone() const final;
@@ -132,6 +133,9 @@ protected:
TrialRunTracker* tracker, TrialRunTrackerAttachResultMask childrenAttachResult) override;
private:
+ // Returns the primary cursor or the random cursor depending on whether _useRandomCursor is set.
+ RecordCursor* getActiveCursor() const;
+
const UUID _collUuid;
const boost::optional<value::SlotId> _recordSlot;
const boost::optional<value::SlotId> _recordIdSlot;
@@ -168,6 +172,9 @@ private:
value::SlotAccessor* _indexKeyPatternAccessor{nullptr};
RuntimeEnvironment::Accessor* _oplogTsAccessor{nullptr};
+ // Used to return a random sample of the collection.
+ const bool _useRandomCursor;
+
value::FieldAccessorMap _fieldAccessors;
value::SlotAccessorMap _varAccessors;
value::SlotAccessor* _seekKeyAccessor{nullptr};
@@ -177,6 +184,10 @@ private:
bool _open{false};
std::unique_ptr<SeekableRecordCursor> _cursor;
+
+ // TODO: SERVER-62647. Consider removing random cursor when no longer needed.
+ std::unique_ptr<RecordCursor> _randomCursor;
+
RecordId _key;
bool _firstGetNext{false};
diff --git a/src/mongo/db/exec/sbe/values/bson.h b/src/mongo/db/exec/sbe/values/bson.h
index 73e0cd272e8..f5c429d6890 100644
--- a/src/mongo/db/exec/sbe/values/bson.h
+++ b/src/mongo/db/exec/sbe/values/bson.h
@@ -40,6 +40,12 @@ std::pair<value::TypeTags, value::Value> convertFrom(const char* be,
const char* end,
size_t fieldNameSize);
+template <bool View>
+std::pair<value::TypeTags, value::Value> convertFrom(const BSONElement& elem) {
+ return convertFrom<View>(
+ elem.rawdata(), elem.rawdata() + elem.size(), elem.fieldNameSize() - 1);
+}
+
const char* advance(const char* be, size_t fieldNameSize);
inline auto fieldNameView(const char* be) noexcept {
diff --git a/src/mongo/db/exec/sbe/vm/vm.cpp b/src/mongo/db/exec/sbe/vm/vm.cpp
index 14dbb993def..b4646cefc17 100644
--- a/src/mongo/db/exec/sbe/vm/vm.cpp
+++ b/src/mongo/db/exec/sbe/vm/vm.cpp
@@ -1387,7 +1387,7 @@ std::tuple<bool, value::TypeTags, value::Value> ByteCode::builtinDropFields(Arit
}
// Build the set of fields to drop.
- StringDataSet restrictFieldsSet;
+ StringSet restrictFieldsSet;
for (ArityType idx = 1; idx < arity; ++idx) {
auto [owned, tag, val] = getFromStack(idx);
@@ -1463,7 +1463,7 @@ std::tuple<bool, value::TypeTags, value::Value> ByteCode::builtinKeepFields(Arit
}
// Build the set of fields to keep.
- StringDataSet keepFieldsSet;
+ StringSet keepFieldsSet;
for (uint8_t idx = 1; idx < arity; ++idx) {
auto [owned, tag, val] = getFromStack(idx);
diff --git a/src/mongo/db/exec/sbe_cmd.cpp b/src/mongo/db/exec/sbe_cmd.cpp
index 5468c9a54e7..c76885bec0d 100644
--- a/src/mongo/db/exec/sbe_cmd.cpp
+++ b/src/mongo/db/exec/sbe_cmd.cpp
@@ -124,6 +124,7 @@ public:
std::move(cq),
nullptr,
{std::move(root), std::move(data)},
+ {},
&CollectionPtr::null,
false, /* returnOwnedBson */
nss,
diff --git a/src/mongo/db/matcher/match_expression_walker.h b/src/mongo/db/matcher/match_expression_walker.h
index ccef9cb5106..3313a2948ea 100644
--- a/src/mongo/db/matcher/match_expression_walker.h
+++ b/src/mongo/db/matcher/match_expression_walker.h
@@ -44,15 +44,21 @@ public:
: _preVisitor{preVisitor}, _inVisitor{inVisitor}, _postVisitor{postVisitor} {}
void preVisit(const MatchExpression* expr) {
- expr->acceptVisitor(_preVisitor);
+ if (_preVisitor != nullptr) {
+ expr->acceptVisitor(_preVisitor);
+ }
}
void postVisit(const MatchExpression* expr) {
- expr->acceptVisitor(_postVisitor);
+ if (_postVisitor != nullptr) {
+ expr->acceptVisitor(_postVisitor);
+ }
}
void inVisit(long count, const MatchExpression* expr) {
- expr->acceptVisitor(_inVisitor);
+ if (_inVisitor != nullptr) {
+ expr->acceptVisitor(_inVisitor);
+ }
}
private:
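With the null checks in place, a walker can be constructed with only the visitors a caller needs, mirroring how the ABT translation below walks DocumentSources with a post-visitor only. A sketch, where MyPostVisitor is a hypothetical MatchExpressionConstVisitor and tree_walker::walk is assumed as the usual entry point:

    MyPostVisitor postVisitor;
    MatchExpressionWalker walker(nullptr /*preVisitor*/, nullptr /*inVisitor*/, &postVisitor);
    tree_walker::walk<true, MatchExpression>(matchExpr, &walker);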
diff --git a/src/mongo/db/pipeline/SConscript b/src/mongo/db/pipeline/SConscript
index fd3b9ecbb49..ffbda680ee3 100644
--- a/src/mongo/db/pipeline/SConscript
+++ b/src/mongo/db/pipeline/SConscript
@@ -262,6 +262,11 @@ pipelineEnv.InjectThirdParty(libraries=['snappy'])
pipelineEnv.Library(
target='pipeline',
source=[
+ 'abt/abt_document_source_visitor.cpp',
+ 'abt/agg_expression_visitor.cpp',
+ 'abt/expr_algebrizer_context.cpp',
+ 'abt/match_expression_visitor.cpp',
+ 'abt/utils.cpp',
'document_source.cpp',
'document_source_add_fields.cpp',
'document_source_bucket.cpp',
@@ -315,6 +320,8 @@ pipelineEnv.Library(
'sequential_document_cache.cpp',
'skip_and_limit.cpp',
'tee_buffer.cpp',
+ 'visitors/document_source_walker.cpp',
+ 'visitors/transformer_interface_walker.cpp',
'window_function/partition_iterator.cpp',
'window_function/spillable_cache.cpp',
'window_function/window_function_exec.cpp',
@@ -344,6 +351,7 @@ pipelineEnv.Library(
'$BUILD_DIR/mongo/db/query/collation/collator_factory_interface',
'$BUILD_DIR/mongo/db/query/collation/collator_interface',
'$BUILD_DIR/mongo/db/query/datetime/date_time_support',
+ '$BUILD_DIR/mongo/db/query/optimizer/optimizer',
'$BUILD_DIR/mongo/db/query/query_knobs',
'$BUILD_DIR/mongo/db/query/sort_pattern',
'$BUILD_DIR/mongo/db/repl/oplog_entry',
@@ -487,6 +495,7 @@ env.Library(
env.CppUnitTest(
target='db_pipeline_test',
source=[
+ 'abt/pipeline_test.cpp',
'accumulator_js_test.cpp' if get_option('js-engine') != 'none' else [],
'accumulator_test.cpp',
'aggregation_request_test.cpp',
@@ -598,6 +607,7 @@ env.CppUnitTest(
'$BUILD_DIR/mongo/db/exec/document_value/document_value_test_util',
'$BUILD_DIR/mongo/db/mongohasher',
'$BUILD_DIR/mongo/db/query/collation/collator_interface_mock',
+ '$BUILD_DIR/mongo/db/query/optimizer/unit_test_utils',
'$BUILD_DIR/mongo/db/query/query_test_service_context',
'$BUILD_DIR/mongo/db/repl/image_collection_entry',
'$BUILD_DIR/mongo/db/repl/oplog_entry',
diff --git a/src/mongo/db/pipeline/abt/abt_document_source_visitor.cpp b/src/mongo/db/pipeline/abt/abt_document_source_visitor.cpp
new file mode 100644
index 00000000000..913866c7364
--- /dev/null
+++ b/src/mongo/db/pipeline/abt/abt_document_source_visitor.cpp
@@ -0,0 +1,878 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/db/pipeline/abt/abt_document_source_visitor.h"
+#include "mongo/db/exec/add_fields_projection_executor.h"
+#include "mongo/db/exec/exclusion_projection_executor.h"
+#include "mongo/db/exec/inclusion_projection_executor.h"
+#include "mongo/db/pipeline/abt/agg_expression_visitor.h"
+#include "mongo/db/pipeline/abt/match_expression_visitor.h"
+#include "mongo/db/pipeline/abt/utils.h"
+#include "mongo/db/pipeline/document_source_bucket_auto.h"
+#include "mongo/db/pipeline/document_source_coll_stats.h"
+#include "mongo/db/pipeline/document_source_current_op.h"
+#include "mongo/db/pipeline/document_source_cursor.h"
+#include "mongo/db/pipeline/document_source_exchange.h"
+#include "mongo/db/pipeline/document_source_facet.h"
+#include "mongo/db/pipeline/document_source_geo_near.h"
+#include "mongo/db/pipeline/document_source_geo_near_cursor.h"
+#include "mongo/db/pipeline/document_source_graph_lookup.h"
+#include "mongo/db/pipeline/document_source_group.h"
+#include "mongo/db/pipeline/document_source_index_stats.h"
+#include "mongo/db/pipeline/document_source_internal_inhibit_optimization.h"
+#include "mongo/db/pipeline/document_source_internal_shard_filter.h"
+#include "mongo/db/pipeline/document_source_internal_split_pipeline.h"
+#include "mongo/db/pipeline/document_source_limit.h"
+#include "mongo/db/pipeline/document_source_list_cached_and_active_users.h"
+#include "mongo/db/pipeline/document_source_list_local_sessions.h"
+#include "mongo/db/pipeline/document_source_list_sessions.h"
+#include "mongo/db/pipeline/document_source_lookup.h"
+#include "mongo/db/pipeline/document_source_match.h"
+#include "mongo/db/pipeline/document_source_merge.h"
+#include "mongo/db/pipeline/document_source_operation_metrics.h"
+#include "mongo/db/pipeline/document_source_out.h"
+#include "mongo/db/pipeline/document_source_plan_cache_stats.h"
+#include "mongo/db/pipeline/document_source_queue.h"
+#include "mongo/db/pipeline/document_source_redact.h"
+#include "mongo/db/pipeline/document_source_replace_root.h"
+#include "mongo/db/pipeline/document_source_sample.h"
+#include "mongo/db/pipeline/document_source_sample_from_random_cursor.h"
+#include "mongo/db/pipeline/document_source_sequential_document_cache.h"
+#include "mongo/db/pipeline/document_source_single_document_transformation.h"
+#include "mongo/db/pipeline/document_source_skip.h"
+#include "mongo/db/pipeline/document_source_sort.h"
+#include "mongo/db/pipeline/document_source_tee_consumer.h"
+#include "mongo/db/pipeline/document_source_union_with.h"
+#include "mongo/db/pipeline/document_source_unwind.h"
+#include "mongo/db/pipeline/visitors/document_source_walker.h"
+#include "mongo/db/pipeline/visitors/transformer_interface_walker.h"
+#include "mongo/s/query/document_source_merge_cursors.h"
+#include "mongo/util/string_map.h"
+
+namespace mongo::optimizer {
+
+class DSAlgebrizerContext {
+public:
+ struct NodeWithRootProjection {
+ NodeWithRootProjection(ProjectionName rootProjection, ABT node)
+ : _rootProjection(std::move(rootProjection)), _node(std::move(node)) {}
+
+ ProjectionName _rootProjection;
+ ABT _node;
+ };
+
+ DSAlgebrizerContext(PrefixId& prefixId, NodeWithRootProjection node)
+ : _node(std::move(node)), _scanProjName(_node._rootProjection), _prefixId(prefixId) {
+ assertNodeSort(_node._node);
+ }
+
+    template <typename T, typename... Args>
+    inline void setNode(ProjectionName rootProjection, Args&&... args) {
+        // ABT::make<T>(...) is already an rvalue; wrapping it in std::move is redundant.
+        setNode(std::move(rootProjection), ABT::make<T>(std::forward<Args>(args)...));
+    }
+
+ void setNode(ProjectionName rootProjection, ABT node) {
+ assertNodeSort(node);
+ _node._node = std::move(node);
+ _node._rootProjection = std::move(rootProjection);
+ }
+
+ NodeWithRootProjection& getNode() {
+ return _node;
+ }
+
+ std::string getNextId(const std::string& key) {
+ return _prefixId.getNextId(key);
+ }
+
+ PrefixId& getPrefixId() {
+ return _prefixId;
+ }
+
+ const ProjectionName& getScanProjName() const {
+ return _scanProjName;
+ }
+
+private:
+ NodeWithRootProjection _node;
+ ProjectionName _scanProjName;
+
+ // We don't own this.
+ PrefixId& _prefixId;
+};
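DSAlgebrizerContext threads the partially built ABT and the projection naming its output through the visitors. Every visitor below follows the same rebuild pattern: read the current entry, wrap its node, and store the result back. Schematically (somePath stands in for any translated path):

    auto entry = ctx.getNode();
    ctx.setNode<FilterNode>(
        entry._rootProjection,  // filtering does not change the root projection
        make<EvalFilter>(somePath, make<Variable>(entry._rootProjection)),
        std::move(entry._node));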
+
+class ABTTransformerVisitor : public TransformerInterfaceConstVisitor {
+ static constexpr const char* kRootElement = "$root";
+
+public:
+ ABTTransformerVisitor(DSAlgebrizerContext& ctx) : _ctx(ctx) {}
+
+ void visit(const projection_executor::AddFieldsProjectionExecutor* transformer) override {
+ visitInclusionNode(transformer->getRoot(), true /*isAddingFields*/);
+ }
+
+ void visit(const projection_executor::ExclusionProjectionExecutor* transformer) override {
+ visitExclusionNode(*transformer->getRoot());
+ }
+
+ void visit(const projection_executor::InclusionProjectionExecutor* transformer) override {
+ visitInclusionNode(*transformer->getRoot(), false /*isAddingFields*/);
+ }
+
+ void visit(const GroupFromFirstDocumentTransformation* transformer) override {
+ // TODO: Is this internal-only?
+ unsupportedTransformer(transformer);
+ }
+
+ void visit(const ReplaceRootTransformation* transformer) override {
+ auto entry = _ctx.getNode();
+ const std::string& projName = _ctx.getNextId("newRoot");
+ ABT expr = generateAggExpression(
+ transformer->getExpression().get(), entry._rootProjection, projName);
+
+ _ctx.setNode<EvaluationNode>(projName, projName, std::move(expr), std::move(entry._node));
+ }
+
+ void generateCombinedProjection() const {
+ auto it = _fieldMap.find(kRootElement);
+ if (it == _fieldMap.cend()) {
+ return;
+ }
+
+ ABT result = generateABTForField(it->second);
+ auto entry = _ctx.getNode();
+ const ProjectionName projName = _ctx.getNextId("combinedProjection");
+ _ctx.setNode<EvaluationNode>(projName, projName, std::move(result), std::move(entry._node));
+ }
+
+private:
+ struct FieldMapEntry {
+ FieldMapEntry(std::string fieldName) : _fieldName(std::move(fieldName)) {
+ uassert(6624200, "Empty field name", !_fieldName.empty());
+ }
+
+ std::string _fieldName;
+ bool _hasKeep = false;
+ bool _hasLeadingObj = false;
+ bool _hasTrailingDefault = false;
+ bool _hasDrop = false;
+ std::string _constVarName;
+
+ std::set<std::string> _childPaths;
+ };
+
+ ABT generateABTForField(const FieldMapEntry& entry) const {
+ const bool isRootEntry = entry._fieldName == kRootElement;
+
+ bool hasLeadingObj = false;
+ bool hasTrailingDefault = false;
+ std::set<std::string> keepSet;
+ std::set<std::string> dropSet;
+ std::map<std::string, std::string> varMap;
+
+ for (const std::string& childField : entry._childPaths) {
+ const FieldMapEntry& childEntry = _fieldMap.at(childField);
+ const std::string& childFieldName = childEntry._fieldName;
+
+ if (childEntry._hasKeep) {
+ keepSet.insert(childFieldName);
+ }
+ if (childEntry._hasDrop) {
+ dropSet.insert(childFieldName);
+ }
+ if (childEntry._hasLeadingObj) {
+ hasLeadingObj = true;
+ }
+ if (childEntry._hasTrailingDefault) {
+ hasTrailingDefault = true;
+ }
+ if (!childEntry._constVarName.empty()) {
+ varMap.emplace(childFieldName, childEntry._constVarName);
+ }
+ }
+
+ const auto& ctxEntry = _ctx.getNode();
+ const ProjectionName& rootProjName = ctxEntry._rootProjection;
+
+ ABT result = make<PathIdentity>();
+ if (hasLeadingObj && (!isRootEntry || rootProjName != _ctx.getScanProjName())) {
+ // We do not need a leading Obj if we are using the scan projection directly (scan
+ // delivers Objects).
+ maybeComposePath(result, make<PathObj>());
+ }
+ if (!keepSet.empty()) {
+ maybeComposePath(result, make<PathKeep>(toUnorderedFieldNameSet(std::move(keepSet))));
+ }
+ if (!dropSet.empty()) {
+ maybeComposePath(result, make<PathDrop>(toUnorderedFieldNameSet(std::move(dropSet))));
+ }
+
+ for (const auto& varMapEntry : varMap) {
+ maybeComposePath(
+ result,
+ make<PathField>(varMapEntry.first,
+ make<PathConstant>(make<Variable>(varMapEntry.second))));
+ }
+
+ for (const std::string& childPath : entry._childPaths) {
+ const FieldMapEntry& childEntry = _fieldMap.at(childPath);
+
+ ABT childResult = generateABTForField(childEntry);
+ if (!childResult.is<PathIdentity>()) {
+ maybeComposePath(result,
+ make<PathField>(childEntry._fieldName,
+ make<PathTraverse>(std::move(childResult))));
+ }
+ }
+
+ if (hasTrailingDefault) {
+ maybeComposePath(result, make<PathDefault>(Constant::emptyObject()));
+ }
+ if (!isRootEntry) {
+ return result;
+ }
+ return make<EvalPath>(std::move(result), make<Variable>(rootProjName));
+ }
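A rough worked example: for an inclusion of a single top-level field "a", the root entry collects keepSet = {"a"} and the leading-Obj flag from its child, the child itself reduces to PathIdentity (so no PathField is emitted for it), and the result is approximately

    EvalPath(PathComposeM(PathObj, PathKeep({"a"})), Variable(rootProjName))

with the PathObj omitted when the root projection is the scan projection, as the comment above notes.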
+
+ void integrateFieldPath(
+ const FieldPath& fieldPath,
+ const std::function<void(const bool isLastElement, FieldMapEntry& entry)>& fn) {
+ std::string path = kRootElement;
+ auto it = _fieldMap.emplace(path, kRootElement);
+ const size_t fieldPathLength = fieldPath.getPathLength();
+
+ for (size_t i = 0; i < fieldPathLength; i++) {
+ const std::string& fieldName = fieldPath.getFieldName(i).toString();
+ path += '.' + fieldName;
+
+ it.first->second._childPaths.insert(path);
+ it = _fieldMap.emplace(path, fieldName);
+ fn(i == fieldPathLength - 1, it.first->second);
+ }
+ }
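For a dotted path such as "a.b", integrateFieldPath leaves the map with one entry per prefix, each parent linked to its child (sketch):

    // "$root"     -> childPaths: {"$root.a"}
    // "$root.a"   -> fieldName "a", childPaths: {"$root.a.b"}
    // "$root.a.b" -> fieldName "b"
    // fn() ran on both non-root entries; isLastElement was true only on the last.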
+
+ void unsupportedTransformer(const TransformerInterface* transformer) const {
+ uasserted(ErrorCodes::InternalErrorNotSupported,
+ str::stream() << "Transformer is not supported (code: "
+ << static_cast<int>(transformer->getType()) << ")");
+ }
+
+ void processProjectedPaths(const projection_executor::InclusionNode& node) {
+ std::set<std::string> preservedPaths;
+ node.reportProjectedPaths(&preservedPaths);
+
+ for (const std::string& preservedPathStr : preservedPaths) {
+ integrateFieldPath(FieldPath(preservedPathStr),
+ [](const bool isLastElement, FieldMapEntry& entry) {
+ entry._hasLeadingObj = true;
+ entry._hasKeep = true;
+ });
+ }
+ }
+
+ void processComputedPaths(const projection_executor::InclusionNode& node,
+ const std::string& rootProjection,
+ const bool isAddingFields) {
+ std::set<std::string> computedPaths;
+ StringMap<std::string> renamedPaths;
+ node.reportComputedPaths(&computedPaths, &renamedPaths);
+
+        // Handle path renames: essentially a single-element FieldPath expression.
+ for (const auto& renamedPathEntry : renamedPaths) {
+ ABT path = translateFieldPath(
+ FieldPath(renamedPathEntry.second),
+ make<PathIdentity>(),
+ [](const std::string& fieldName, const bool isLastElement, ABT input) {
+ return make<PathGet>(fieldName,
+ isLastElement ? std::move(input)
+ : make<PathTraverse>(std::move(input)));
+ });
+
+ auto entry = _ctx.getNode();
+ const std::string& renamedProjName = _ctx.getNextId("projRenamedPath");
+ _ctx.setNode<EvaluationNode>(
+ entry._rootProjection,
+ renamedProjName,
+ make<EvalPath>(std::move(path), make<Variable>(entry._rootProjection)),
+ std::move(entry._node));
+
+ integrateFieldPath(FieldPath(renamedPathEntry.first),
+ [&renamedProjName, &isAddingFields](const bool isLastElement,
+ FieldMapEntry& entry) {
+ if (!isAddingFields) {
+ entry._hasKeep = true;
+ }
+ if (isLastElement) {
+ entry._constVarName = renamedProjName;
+ entry._hasTrailingDefault = true;
+ }
+ });
+ }
+
+ // Handle general expression projection.
+ for (const std::string& computedPathStr : computedPaths) {
+ const FieldPath computedPath(computedPathStr);
+
+ auto entry = _ctx.getNode();
+ const std::string& getProjName = _ctx.getNextId("projGetPath");
+ ABT getExpr = generateAggExpression(
+ node.getExpressionForPath(computedPath).get(), rootProjection, getProjName);
+
+ _ctx.setNode<EvaluationNode>(std::move(entry._rootProjection),
+ getProjName,
+ std::move(getExpr),
+ std::move(entry._node));
+
+ integrateFieldPath(
+ computedPath,
+ [&getProjName, &isAddingFields](const bool isLastElement, FieldMapEntry& entry) {
+ if (!isAddingFields) {
+ entry._hasKeep = true;
+ }
+ if (isLastElement) {
+ entry._constVarName = getProjName;
+ entry._hasTrailingDefault = true;
+ }
+ });
+ }
+ }
+
+ void visitInclusionNode(const projection_executor::InclusionNode& node,
+ const bool isAddingFields) {
+ auto entry = _ctx.getNode();
+ const std::string rootProjection = entry._rootProjection;
+
+ processProjectedPaths(node);
+ processComputedPaths(node, rootProjection, isAddingFields);
+ }
+
+ void visitExclusionNode(const projection_executor::ExclusionNode& node) {
+ std::set<std::string> preservedPaths;
+ node.reportProjectedPaths(&preservedPaths);
+
+ for (const std::string& preservedPathStr : preservedPaths) {
+ integrateFieldPath(FieldPath(preservedPathStr),
+ [](const bool isLastElement, FieldMapEntry& entry) {
+ if (isLastElement) {
+ entry._hasDrop = true;
+ }
+ });
+ }
+ }
+
+ DSAlgebrizerContext& _ctx;
+
+ opt::unordered_map<std::string, FieldMapEntry> _fieldMap;
+};
+
+class ABTDocumentSourceVisitor : public DocumentSourceConstVisitor {
+public:
+ ABTDocumentSourceVisitor(DSAlgebrizerContext& ctx, const Metadata& metadata)
+ : _ctx(ctx), _metadata(metadata) {}
+
+ void visit(const DocumentSourceBucketAuto* source) override {
+ unsupportedStage(source);
+ }
+
+ void visit(const DocumentSourceCollStats* source) override {
+ unsupportedStage(source);
+ }
+
+ void visit(const DocumentSourceCurrentOp* source) override {
+ unsupportedStage(source);
+ }
+
+ void visit(const DocumentSourceCursor* source) override {
+ unsupportedStage(source);
+ }
+
+ void visit(const DocumentSourceExchange* source) override {
+ unsupportedStage(source);
+ }
+
+ void visit(const DocumentSourceFacet* source) override {
+ unsupportedStage(source);
+ }
+
+ void visit(const DocumentSourceGeoNear* source) override {
+ unsupportedStage(source);
+ }
+
+ void visit(const DocumentSourceGeoNearCursor* source) override {
+ unsupportedStage(source);
+ }
+
+ void visit(const DocumentSourceGraphLookUp* source) override {
+ unsupportedStage(source);
+ }
+
+ void visit(const DocumentSourceGroup* source) override {
+ const StringMap<boost::intrusive_ptr<Expression>>& idFields = source->getIdFields();
+ uassert(6624201, "Empty idFields map", !idFields.empty());
+
+ std::vector<FieldNameType> groupByFieldNames;
+ for (const auto& [fieldName, expr] : idFields) {
+ groupByFieldNames.push_back(fieldName);
+ }
+ // Sort in order to generate consistent plans.
+ std::sort(groupByFieldNames.begin(), groupByFieldNames.end());
+
+ ProjectionNameVector groupByProjNames;
+ auto entry = _ctx.getNode();
+ for (const FieldNameType& fieldName : groupByFieldNames) {
+ const ProjectionName groupByProjName = _ctx.getNextId("groupByProj");
+ groupByProjNames.push_back(groupByProjName);
+
+ ABT groupByExpr = generateAggExpression(
+ idFields.at(fieldName).get(), entry._rootProjection, groupByProjName);
+
+ _ctx.setNode<EvaluationNode>(entry._rootProjection,
+ groupByProjName,
+ std::move(groupByExpr),
+ std::move(entry._node));
+ entry = _ctx.getNode();
+ }
+
+        // Fields corresponding to each accumulator.
+        ProjectionNameVector aggProjFieldNames;
+        // Projection names corresponding to each high-level accumulator ($avg is broken down
+        // into a sum and a count).
+        ProjectionNameVector aggOutputProjNames;
+        // Projection names corresponding to each low-level accumulator (no $avg).
+        ProjectionNameVector aggLowLevelOutputProjNames;
+
+ ABTVector aggregationProjections;
+ const std::vector<AccumulationStatement>& accumulatedFields =
+ source->getAccumulatedFields();
+
+        // Keeps track of which $sum and $count projections are combined to compute each $avg.
+ struct AvgProjNames {
+ ProjectionName _output;
+ ProjectionName _sum;
+ ProjectionName _count;
+ };
+ std::vector<AvgProjNames> avgProjNames;
+
+ for (const AccumulationStatement& stmt : accumulatedFields) {
+ const FieldNameType& fieldName = stmt.fieldName;
+ aggProjFieldNames.push_back(fieldName);
+
+ ProjectionName aggOutputProjName = _ctx.getNextId(fieldName + "_agg");
+
+ ABT aggInputExpr = generateAggExpression(
+ stmt.expr.argument.get(), entry._rootProjection, aggOutputProjName);
+ if (!aggInputExpr.is<Constant>() && !aggInputExpr.is<Variable>()) {
+ // Generate nodes for complex projections, otherwise inline constants and variables
+ // into the group.
+ const ProjectionName aggInputProjName = _ctx.getNextId("groupByInputProj");
+ _ctx.setNode<EvaluationNode>(entry._rootProjection,
+ aggInputProjName,
+ std::move(aggInputExpr),
+ std::move(entry._node));
+ entry = _ctx.getNode();
+ aggInputExpr = make<Variable>(aggInputProjName);
+ }
+
+ aggOutputProjNames.push_back(aggOutputProjName);
+ if (stmt.makeAccumulator()->getOpName() == "$avg"_sd) {
+ // Express $avg as sum / count.
+ ProjectionName sumProjName = _ctx.getNextId(fieldName + "_sum_agg");
+ aggLowLevelOutputProjNames.push_back(sumProjName);
+ ProjectionName countProjName = _ctx.getNextId(fieldName + "_count_agg");
+ aggLowLevelOutputProjNames.push_back(countProjName);
+ avgProjNames.emplace_back(AvgProjNames{std::move(aggOutputProjName),
+ std::move(sumProjName),
+ std::move(countProjName)});
+
+ aggregationProjections.emplace_back(
+ make<FunctionCall>("$sum", makeSeq(aggInputExpr)));
+ aggregationProjections.emplace_back(
+ make<FunctionCall>("$sum", makeSeq(Constant::int64(1))));
+ } else {
+ aggLowLevelOutputProjNames.push_back(std::move(aggOutputProjName));
+ aggregationProjections.emplace_back(make<FunctionCall>(
+ stmt.makeAccumulator()->getOpName(), makeSeq(std::move(aggInputExpr))));
+ }
+ }
+
+ ABT result = make<GroupByNode>(ProjectionNameVector{groupByProjNames},
+ aggLowLevelOutputProjNames,
+ aggregationProjections,
+ std::move(entry._node));
+
+ for (auto&& [outputProjName, sumProjName, countProjName] : avgProjNames) {
+ result = make<EvaluationNode>(
+ std::move(outputProjName),
+ make<If>(make<BinaryOp>(
+ Operations::Gt, make<Variable>(countProjName), Constant::int64(0)),
+ make<BinaryOp>(Operations::Div,
+ make<Variable>(std::move(sumProjName)),
+ make<Variable>(countProjName)),
+ Constant::nothing()),
+ std::move(result));
+ }
+
+ ABT integrationPath = make<PathIdentity>();
+ for (size_t i = 0; i < groupByFieldNames.size(); i++) {
+ maybeComposePath(integrationPath,
+ make<PathField>(std::move(groupByFieldNames.at(i)),
+ make<PathConstant>(make<Variable>(
+ std::move(groupByProjNames.at(i))))));
+ }
+ for (size_t i = 0; i < aggProjFieldNames.size(); i++) {
+ maybeComposePath(
+ integrationPath,
+ make<PathField>(aggProjFieldNames.at(i),
+ make<PathConstant>(make<Variable>(aggOutputProjNames.at(i)))));
+ }
+
+ entry = _ctx.getNode();
+ const std::string& mergeProject = _ctx.getNextId("agg_project");
+ _ctx.setNode<EvaluationNode>(
+ mergeProject,
+ mergeProject,
+ make<EvalPath>(std::move(integrationPath), Constant::emptyObject()),
+ std::move(result));
+ }
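To make the $avg decomposition concrete, {$group: {_id: "$a", m: {$avg: "$b"}}} translates schematically as follows (generated names are illustrative; the real ones carry PrefixId suffixes), with the input at the bottom:

    // EvaluationNode agg_project = EvalPath(PathField "_id" ... PathField "m" ..., {})
    //   EvaluationNode m_agg = If(m_count_agg > 0, m_sum_agg / m_count_agg, Nothing)
    //     GroupByNode [groupByProj] computing m_sum_agg = $sum(<b>), m_count_agg = $sum(1)
    //       <input node>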
+
+ void visit(const DocumentSourceIndexStats* source) override {
+ unsupportedStage(source);
+ }
+
+ void visit(const DocumentSourceInternalInhibitOptimization* source) override {
+ // Can be ignored.
+ }
+
+ void visit(const DocumentSourceInternalShardFilter* source) override {
+ unsupportedStage(source);
+ }
+
+ void visit(const DocumentSourceInternalSplitPipeline* source) override {
+ unsupportedStage(source);
+ }
+
+ void visit(const DocumentSourceLimit* source) override {
+ pushLimitSkip(source->getLimit(), 0);
+ }
+
+ void visit(const DocumentSourceListCachedAndActiveUsers* source) override {
+ unsupportedStage(source);
+ }
+
+ void visit(const DocumentSourceListLocalSessions* source) override {
+ unsupportedStage(source);
+ }
+
+ void visit(const DocumentSourceListSessions* source) override {
+ unsupportedStage(source);
+ }
+
+ void visit(const DocumentSourceLookUp* source) override {
+ unsupportedStage(source);
+ }
+
+ void visit(const DocumentSourceMatch* source) override {
+ auto entry = _ctx.getNode();
+ ABT matchExpr = generateMatchExpression(source->getMatchExpression(),
+ true /*allowAggExpressions*/,
+ entry._rootProjection,
+ _ctx.getNextId("matchExpression"));
+
+ // If we have a top-level composition, create separate filters.
+ const auto& composition = collectComposed(matchExpr);
+ for (const auto& path : composition) {
+ _ctx.setNode<FilterNode>(entry._rootProjection,
+ make<EvalFilter>(path, make<Variable>(entry._rootProjection)),
+ std::move(entry._node));
+ entry = _ctx.getNode();
+ }
+ }
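Since collectComposed splits a top-level multiplicative composition into its conjuncts, a match such as {$match: {a: 1, b: 2}} yields two stacked FilterNodes rather than one composite filter, roughly (conjunct order unspecified):

    // FilterNode (EvalFilter (PathGet "a" (PathTraverse (PathCompare Eq 1))) root)
    //   FilterNode (EvalFilter (PathGet "b" (PathTraverse (PathCompare Eq 2))) root)
    //     <input node>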
+
+ void visit(const DocumentSourceMerge* source) override {
+ unsupportedStage(source);
+ }
+
+ void visit(const DocumentSourceMergeCursors* source) override {
+ unsupportedStage(source);
+ }
+
+ void visit(const DocumentSourceOperationMetrics* source) override {
+ unsupportedStage(source);
+ }
+
+ void visit(const DocumentSourceOut* source) override {
+ unsupportedStage(source);
+ }
+
+ void visit(const DocumentSourcePlanCacheStats* source) override {
+ unsupportedStage(source);
+ }
+
+ void visit(const DocumentSourceQueue* source) override {
+ unsupportedStage(source);
+ }
+
+ void visit(const DocumentSourceRedact* source) override {
+ unsupportedStage(source);
+ }
+
+ void visit(const DocumentSourceSample* source) override {
+ unsupportedStage(source);
+ }
+
+ void visit(const DocumentSourceSampleFromRandomCursor* source) override {
+ unsupportedStage(source);
+ }
+
+ void visit(const DocumentSourceSequentialDocumentCache* source) override {
+ unsupportedStage(source);
+ }
+
+ void visit(const DocumentSourceSingleDocumentTransformation* source) override {
+ ABTTransformerVisitor visitor(_ctx);
+ TransformerInterfaceWalker walker(&visitor);
+ walker.walk(&source->getTransformer());
+ visitor.generateCombinedProjection();
+ }
+
+ void visit(const DocumentSourceSkip* source) override {
+ pushLimitSkip(-1, source->getSkip());
+ }
+
+ void visit(const DocumentSourceSort* source) override {
+ ProjectionCollationSpec collationSpec;
+ const SortPattern& pattern = source->getSortKeyPattern();
+ for (size_t i = 0; i < pattern.size(); i++) {
+ const SortPattern::SortPatternPart& part = pattern[i];
+ if (!part.fieldPath.has_value()) {
+ // TODO: consider metadata expression.
+ continue;
+ }
+
+ const std::string& sortProjName = _ctx.getNextId("sort");
+ collationSpec.emplace_back(
+ sortProjName, part.isAscending ? CollationOp::Ascending : CollationOp::Descending);
+
+ const FieldPath& fieldPath = part.fieldPath.value();
+ ABT sortPath = make<PathIdentity>();
+ for (size_t j = 0; j < fieldPath.getPathLength(); j++) {
+ sortPath = make<PathGet>(fieldPath.getFieldName(j).toString(), std::move(sortPath));
+ }
+
+ auto entry = _ctx.getNode();
+ _ctx.setNode<EvaluationNode>(
+ entry._rootProjection,
+ sortProjName,
+ make<EvalPath>(std::move(sortPath), make<Variable>(entry._rootProjection)),
+ std::move(entry._node));
+ }
+
+ if (!collationSpec.empty()) {
+ auto entry = _ctx.getNode();
+ _ctx.setNode<CollationNode>(std::move(entry._rootProjection),
+ properties::CollationRequirement(std::move(collationSpec)),
+ std::move(entry._node));
+ }
+
+ if (source->getLimit().has_value()) {
+ // We need to limit the result of the collation.
+ pushLimitSkip(source->getLimit().value(), 0);
+ }
+ }
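For {$sort: {a: 1}} the visitor therefore emits an EvaluationNode that extracts the sort key, followed by a CollationNode ordering on it, roughly:

    // CollationNode (collation: {sort_0: Ascending})
    //   EvaluationNode [sort_0] = EvalPath (PathGet "a" PathIdentity) root
    //     <input node>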
+
+ void visit(const DocumentSourceTeeConsumer* source) override {
+ unsupportedStage(source);
+ }
+
+ void visit(const DocumentSourceUnionWith* source) override {
+ auto entry = _ctx.getNode();
+ ProjectionName unionProjName = entry._rootProjection;
+
+ const Pipeline& pipeline = source->getPipeline();
+
+ NamespaceString involvedNss = pipeline.getContext()->ns;
+ std::string scanDefName = involvedNss.coll().toString();
+ const ProjectionName& scanProjName = _ctx.getNextId("scan");
+
+ PrefixId prefixId;
+ ABT initialNode = _metadata._scanDefs.at(scanDefName).exists()
+ ? make<ScanNode>(scanProjName, scanDefName)
+ : make<ValueScanNode>(ProjectionNameVector{scanProjName});
+ ABT pipelineABT = translatePipelineToABT(
+ _metadata, pipeline, scanProjName, std::move(initialNode), prefixId);
+
+ uassert(6624425, "Expected root node for union pipeline", pipelineABT.is<RootNode>());
+ ABT pipelineABTWithoutRoot = pipelineABT.cast<RootNode>()->getChild();
+ // Pull out the root projection(s) from the inner pipeline.
+ const ProjectionNameVector& rootProjections =
+ pipelineABT.cast<RootNode>()->getProperty().getProjections().getVector();
+ uassert(6624426,
+ "Expected a single projection for inner union branch",
+ rootProjections.size() == 1);
+
+ // Add an evaluation node such that it shares a projection with the outer pipeline. If the
+ // same projection name is already defined in the inner pipeline then there's no need for
+ // the extra eval node.
+ const ProjectionName& innerProjection = rootProjections[0];
+ ProjectionName newRootProj = unionProjName;
+ if (innerProjection != unionProjName) {
+ ABT evalNodeInner = make<EvaluationNode>(
+ unionProjName,
+ make<EvalPath>(make<PathIdentity>(), make<Variable>(innerProjection)),
+ std::move(pipelineABTWithoutRoot));
+ _ctx.setNode(
+ std::move(newRootProj),
+ make<UnionNode>(ProjectionNameVector{std::move(unionProjName)},
+ makeSeq(std::move(entry._node), std::move(evalNodeInner))));
+ } else {
+ _ctx.setNode(std::move(newRootProj),
+ make<UnionNode>(
+ ProjectionNameVector{std::move(unionProjName)},
+ makeSeq(std::move(entry._node), std::move(pipelineABTWithoutRoot))));
+ }
+ }
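A sketch of the shape produced when the inner root projection differs from the outer one (the common case):

    // UnionNode [unionProj]
    //   <outer branch>
    //   EvaluationNode [unionProj] = EvalPath(PathIdentity, Variable(innerProj))
    //     <inner branch, RootNode stripped>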
+
+ void visit(const DocumentSourceUnwind* source) override {
+ const FieldPath& unwindFieldPath = source->getUnwindPath();
+ const bool preserveNullAndEmpty = source->preserveNullAndEmptyArrays();
+
+ const std::string pidProjName = _ctx.getNextId("unwoundPid");
+ const std::string unwoundProjName = _ctx.getNextId("unwoundProj");
+
+ const auto generatePidGteZeroTest = [&pidProjName](ABT thenCond, ABT elseCond) {
+ return make<If>(
+ make<BinaryOp>(Operations::Gte, make<Variable>(pidProjName), Constant::int64(0)),
+ std::move(thenCond),
+ std::move(elseCond));
+ };
+
+ ABT embedPath = make<Variable>(unwoundProjName);
+ if (preserveNullAndEmpty) {
+ const std::string unwindLambdaVarName = _ctx.getNextId("unwoundLambdaVarName");
+ embedPath = make<PathLambda>(make<LambdaAbstraction>(
+ unwindLambdaVarName,
+ generatePidGteZeroTest(std::move(embedPath), make<Variable>(unwindLambdaVarName))));
+ } else {
+ embedPath = make<PathConstant>(std::move(embedPath));
+ }
+ embedPath = translateFieldPath(
+ unwindFieldPath,
+ std::move(embedPath),
+ [](const std::string& fieldName, const bool isLastElement, ABT input) {
+ return make<PathField>(fieldName,
+ isLastElement ? std::move(input)
+ : make<PathTraverse>(std::move(input)));
+ });
+
+ ABT unwoundPath = translateFieldPath(
+ unwindFieldPath,
+ make<PathIdentity>(),
+ [](const std::string& fieldName, const bool isLastElement, ABT input) {
+ return make<PathGet>(fieldName, std::move(input));
+ });
+
+ auto entry = _ctx.getNode();
+ _ctx.setNode<EvaluationNode>(
+ entry._rootProjection,
+ unwoundProjName,
+ make<EvalPath>(std::move(unwoundPath), make<Variable>(entry._rootProjection)),
+ std::move(entry._node));
+
+ entry = _ctx.getNode();
+ _ctx.setNode<UnwindNode>(std::move(entry._rootProjection),
+ unwoundProjName,
+ pidProjName,
+ preserveNullAndEmpty,
+ std::move(entry._node));
+
+ entry = _ctx.getNode();
+ const std::string embedProjName = _ctx.getNextId("embedProj");
+ _ctx.setNode<EvaluationNode>(
+ embedProjName,
+ embedProjName,
+ make<EvalPath>(std::move(embedPath), make<Variable>(entry._rootProjection)),
+ std::move(entry._node));
+
+ if (source->indexPath().has_value()) {
+ const FieldPath indexFieldPath = source->indexPath().get();
+ if (indexFieldPath.getPathLength() > 0) {
+ ABT indexPath = translateFieldPath(
+ indexFieldPath,
+ make<PathConstant>(
+ generatePidGteZeroTest(make<Variable>(pidProjName), Constant::null())),
+ [](const std::string& fieldName, const bool isLastElement, ABT input) {
+ return make<PathField>(fieldName, std::move(input));
+ });
+
+ entry = _ctx.getNode();
+ const std::string embedPidProjName = _ctx.getNextId("embedPidProj");
+ _ctx.setNode<EvaluationNode>(
+ embedPidProjName,
+ embedPidProjName,
+ make<EvalPath>(std::move(indexPath), make<Variable>(entry._rootProjection)),
+ std::move(entry._node));
+ }
+ }
+ }
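Putting the pieces together for a plain {$unwind: "$a"} (no index field, no preserveNullAndEmptyArrays), the resulting chain is roughly:

    // EvaluationNode [embedProj] = EvalPath(PathField "a" (PathConstant unwoundProj), root)
    //   UnwindNode [unwoundProj / unwoundPid]
    //     EvaluationNode [unwoundProj] = EvalPath(PathGet "a" PathIdentity, root)
    //       <input node>

That is, the array is projected out, unwound, and each unwound element is embedded back at its original path.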
+
+private:
+ void unsupportedStage(const DocumentSource* source) const {
+ uasserted(ErrorCodes::InternalErrorNotSupported,
+ str::stream() << "Stage is not supported: " << source->getSourceName());
+ }
+
+ void pushLimitSkip(const int64_t limit, const int64_t skip) {
+ auto entry = _ctx.getNode();
+ _ctx.setNode<LimitSkipNode>(std::move(entry._rootProjection),
+ properties::LimitSkipRequirement(limit, skip),
+ std::move(entry._node));
+ }
+
+ DSAlgebrizerContext& _ctx;
+ const Metadata& _metadata;
+};
+
+ABT translatePipelineToABT(const Metadata& metadata,
+ const Pipeline& pipeline,
+ ProjectionName scanProjName,
+ ABT initialNode,
+ PrefixId& prefixId) {
+ DSAlgebrizerContext ctx(prefixId, {scanProjName, std::move(initialNode)});
+ ABTDocumentSourceVisitor visitor(ctx, metadata);
+
+ DocumentSourceWalker walker(nullptr /*preVisitor*/, &visitor);
+ walker.walk(pipeline);
+
+ auto entry = ctx.getNode();
+ return make<RootNode>(
+ properties::ProjectionRequirement{ProjectionNameVector{std::move(entry._rootProjection)}},
+ std::move(entry._node));
+}
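Callers seed the translation with a scan (or value scan) and a fresh projection name, as the $unionWith visitor above already demonstrates. A minimal sketch, assuming metadata and pipeline are in scope:

    PrefixId prefixId;
    ProjectionName scanProjName = prefixId.getNextId("scan");
    ABT initialNode = make<ScanNode>(scanProjName, "collName" /*scanDefName*/);
    ABT translated = translatePipelineToABT(
        metadata, *pipeline, scanProjName, std::move(initialNode), prefixId);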
+
+} // namespace mongo::optimizer
diff --git a/src/mongo/db/pipeline/abt/abt_document_source_visitor.h b/src/mongo/db/pipeline/abt/abt_document_source_visitor.h
new file mode 100644
index 00000000000..d52d1fe8b23
--- /dev/null
+++ b/src/mongo/db/pipeline/abt/abt_document_source_visitor.h
@@ -0,0 +1,44 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include "mongo/db/pipeline/pipeline.h"
+#include "mongo/db/query/optimizer/node.h"
+#include "mongo/db/query/optimizer/utils/utils.h"
+
+namespace mongo::optimizer {
+
+ABT translatePipelineToABT(const Metadata& metadata,
+ const Pipeline& pipeline,
+ ProjectionName scanProjName,
+ ABT initialNode,
+ PrefixId& prefixId);
+
+} // namespace mongo::optimizer
diff --git a/src/mongo/db/pipeline/abt/agg_expression_visitor.cpp b/src/mongo/db/pipeline/abt/agg_expression_visitor.cpp
new file mode 100644
index 00000000000..ff6dc5e3b45
--- /dev/null
+++ b/src/mongo/db/pipeline/abt/agg_expression_visitor.cpp
@@ -0,0 +1,840 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include <stack>
+
+#include "mongo/base/error_codes.h"
+#include "mongo/db/pipeline/abt/agg_expression_visitor.h"
+#include "mongo/db/pipeline/abt/expr_algebrizer_context.h"
+#include "mongo/db/pipeline/abt/utils.h"
+#include "mongo/db/pipeline/accumulator.h"
+#include "mongo/db/pipeline/accumulator_multi.h"
+#include "mongo/db/pipeline/expression_walker.h"
+#include "mongo/db/query/optimizer/utils/utils.h"
+
+namespace mongo::optimizer {
+
+class ABTAggExpressionVisitor final : public ExpressionConstVisitor {
+public:
+    ABTAggExpressionVisitor(ExpressionAlgebrizerContext& ctx) : _ctx(ctx) {}
+
+ void visit(const ExpressionConstant* expr) override final {
+ auto [tag, val] = convertFrom(expr->getValue());
+ _ctx.push<Constant>(tag, val);
+ }
+
+ void visit(const ExpressionAbs* expr) override final {
+ pushSingleArgFunctionFromTop("abs");
+ }
+
+ void visit(const ExpressionAdd* expr) override final {
+ pushArithmeticBinaryExpr(expr, Operations::Add);
+ }
+
+ void visit(const ExpressionAllElementsTrue* expr) override final {
+ unsupportedExpression(expr->getOpName());
+ }
+
+ void visit(const ExpressionAnd* expr) override final {
+ visitMultiBranchLogicExpression(expr, Operations::And);
+ }
+
+ void visit(const ExpressionAnyElementTrue* expr) override final {
+ unsupportedExpression(expr->getOpName());
+ }
+
+ void visit(const ExpressionArray* expr) override final {
+ const size_t childCount = expr->getChildren().size();
+ _ctx.ensureArity(childCount);
+
+ // Need to process in reverse order because of the stack.
+ ABTVector args;
+ for (size_t i = 0; i < childCount; i++) {
+ args.emplace_back(_ctx.pop());
+ }
+
+ std::reverse(args.begin(), args.end());
+ _ctx.push<FunctionCall>("newArray", std::move(args));
+ }
+
+ void visit(const ExpressionArrayElemAt* expr) override final {
+ unsupportedExpression(expr->getOpName());
+ }
+
+ void visit(const ExpressionFirst* expr) override final {
+ unsupportedExpression(expr->getOpName());
+ }
+
+ void visit(const ExpressionLast* expr) override final {
+ unsupportedExpression(expr->getOpName());
+ }
+
+ void visit(const ExpressionObjectToArray* expr) override final {
+ unsupportedExpression(expr->getOpName());
+ }
+
+ void visit(const ExpressionArrayToObject* expr) override final {
+ unsupportedExpression(expr->getOpName());
+ }
+
+ void visit(const ExpressionBsonSize* expr) override final {
+ pushSingleArgFunctionFromTop("bsonSize");
+ }
+
+ void visit(const ExpressionCeil* expr) override final {
+ pushSingleArgFunctionFromTop("ceil");
+ }
+
+ void visit(const ExpressionCoerceToBool* expr) override final {
+ // Since $coerceToBool is internal-only and there are not yet any input expressions that
+ // generate an ExpressionCoerceToBool expression, we will leave it as unreachable for now.
+ MONGO_UNREACHABLE;
+ }
+
+ void visit(const ExpressionCompare* expr) override final {
+ _ctx.ensureArity(2);
+ ABT right = _ctx.pop();
+ ABT left = _ctx.pop();
+
+ switch (expr->getOp()) {
+ case ExpressionCompare::CmpOp::EQ:
+ _ctx.push<BinaryOp>(Operations::Eq, std::move(left), std::move(right));
+ break;
+
+ case ExpressionCompare::CmpOp::NE:
+ _ctx.push<BinaryOp>(Operations::Neq, std::move(left), std::move(right));
+ break;
+
+ case ExpressionCompare::CmpOp::GT:
+ _ctx.push<BinaryOp>(Operations::Gt, std::move(left), std::move(right));
+ break;
+
+ case ExpressionCompare::CmpOp::GTE:
+ _ctx.push<BinaryOp>(Operations::Gte, std::move(left), std::move(right));
+ break;
+
+ case ExpressionCompare::CmpOp::LT:
+ _ctx.push<BinaryOp>(Operations::Lt, std::move(left), std::move(right));
+ break;
+
+ case ExpressionCompare::CmpOp::LTE:
+ _ctx.push<BinaryOp>(Operations::Lte, std::move(left), std::move(right));
+ break;
+
+ case ExpressionCompare::CmpOp::CMP:
+ _ctx.push<FunctionCall>("cmp3w", makeSeq(std::move(left), std::move(right)));
+ break;
+
+ default:
+ MONGO_UNREACHABLE;
+ }
+ }
+
+ void visit(const ExpressionConcat* expr) override final {
+ pushMultiArgFunctionFromTop("concat", expr->getChildren().size());
+ }
+
+ void visit(const ExpressionConcatArrays* expr) override final {
+ unsupportedExpression(expr->getOpName());
+ }
+
+    void visit(const ExpressionCond* expr) override final {
+        _ctx.ensureArity(3);
+        // Children were pushed in (if, then, else) order, so they pop in reverse.
+        ABT elseCase = _ctx.pop();
+        ABT thenCase = _ctx.pop();
+        ABT cond = generateCoerceToBoolPopInput();
+        _ctx.push<If>(std::move(cond), std::move(thenCase), std::move(elseCase));
+    }
+
+ void visit(const ExpressionDateFromString* expr) override final {
+ unsupportedExpression("$dateFromString");
+ }
+
+ void visit(const ExpressionDateFromParts* expr) override final {
+ unsupportedExpression("$dateFromParts");
+ }
+
+ void visit(const ExpressionDateDiff* expr) override final {
+ unsupportedExpression("$dateDiff");
+ }
+
+ void visit(const ExpressionDateToParts* expr) override final {
+ unsupportedExpression("$dateToParts");
+ }
+
+ void visit(const ExpressionDateToString* expr) override final {
+ unsupportedExpression("$dateToString");
+ }
+
+ void visit(const ExpressionDateTrunc* expr) override final {
+ unsupportedExpression("$dateTrunc");
+ }
+
+ void visit(const ExpressionDivide* expr) override final {
+ pushArithmeticBinaryExpr(expr, Operations::Div);
+ }
+
+ void visit(const ExpressionExp* expr) override final {
+ pushSingleArgFunctionFromTop("exp");
+ }
+
+ void visit(const ExpressionFieldPath* expr) override final {
+ const auto& varId = expr->getVariableId();
+ if (Variables::isUserDefinedVariable(varId)) {
+ _ctx.push<Variable>(generateVariableName(varId));
+ return;
+ }
+
+ const FieldPath& fieldPath = expr->getFieldPath();
+ const size_t pathLength = fieldPath.getPathLength();
+ if (pathLength < 1) {
+ return;
+ }
+
+ const auto& firstFieldName = fieldPath.getFieldName(0);
+ if (pathLength == 1 && firstFieldName == "ROOT") {
+ _ctx.push<Variable>(_ctx.getRootProjection());
+ return;
+ }
+ uassert(6624239, "Unexpected leading path element.", firstFieldName == "CURRENT");
+
+        // Skip over the leading "CURRENT" path element; it is represented by the
+        // rootProjection variable.
+ ABT path = translateFieldPath(
+ fieldPath,
+ make<PathIdentity>(),
+ [](const std::string& fieldName, const bool isLastElement, ABT input) {
+ return make<PathGet>(fieldName,
+ isLastElement ? std::move(input)
+ : make<PathTraverse>(std::move(input)));
+ },
+ 1ul);
+
+ _ctx.push<EvalPath>(std::move(path), make<Variable>(_ctx.getRootProjection()));
+ }
+
+ void visit(const ExpressionFilter* expr) override final {
+ const auto& varId = expr->getVariableId();
+ uassert(6624427,
+ "Filter variable must be user-defined.",
+ Variables::isUserDefinedVariable(varId));
+ const std::string& varName = generateVariableName(varId);
+
+ _ctx.ensureArity(2);
+ ABT filter = _ctx.pop();
+ ABT input = _ctx.pop();
+
+ _ctx.push<EvalPath>(make<PathTraverse>(make<PathLambda>(make<LambdaAbstraction>(
+ varName,
+ make<If>(generateCoerceToBoolInternal(std::move(filter)),
+ make<Variable>(varName),
+ Constant::nothing())))),
+ std::move(input));
+ }
+
+ void visit(const ExpressionFloor* expr) override final {
+ pushSingleArgFunctionFromTop("floor");
+ }
+
+ void visit(const ExpressionIfNull* expr) override final {
+ pushMultiArgFunctionFromTop("ifNull", 2);
+ }
+
+ void visit(const ExpressionIn* expr) override final {
+ unsupportedExpression(expr->getOpName());
+ }
+
+ void visit(const ExpressionIndexOfArray* expr) override final {
+ unsupportedExpression(expr->getOpName());
+ }
+
+ void visit(const ExpressionIndexOfBytes* expr) override final {
+ unsupportedExpression(expr->getOpName());
+ }
+
+ void visit(const ExpressionIndexOfCP* expr) override final {
+ unsupportedExpression(expr->getOpName());
+ }
+
+ void visit(const ExpressionIsNumber* expr) override final {
+ unsupportedExpression(expr->getOpName());
+ }
+
+ void visit(const ExpressionLet* expr) override final {
+ unsupportedExpression("$let");
+ }
+
+ void visit(const ExpressionLn* expr) override final {
+ unsupportedExpression(expr->getOpName());
+ }
+
+ void visit(const ExpressionLog* expr) override final {
+ unsupportedExpression(expr->getOpName());
+ }
+
+ void visit(const ExpressionLog10* expr) override final {
+ unsupportedExpression(expr->getOpName());
+ }
+
+ void visit(const ExpressionMap* expr) override final {
+ unsupportedExpression("$map");
+ }
+
+ void visit(const ExpressionMeta* expr) override final {
+ unsupportedExpression("$meta");
+ }
+
+ void visit(const ExpressionMod* expr) override final {
+ unsupportedExpression(expr->getOpName());
+ }
+
+ void visit(const ExpressionMultiply* expr) override final {
+ pushArithmeticBinaryExpr(expr, Operations::Mult);
+ }
+
+ void visit(const ExpressionNot* expr) override final {
+ unsupportedExpression(expr->getOpName());
+ }
+
+ void visit(const ExpressionObject* expr) override final {
+ const auto& expressions = expr->getChildExpressions();
+ const size_t childCount = expressions.size();
+ _ctx.ensureArity(childCount);
+
+ // Need to process in reverse order because of the stack.
+ ABTVector children;
+ for (size_t i = 0; i < childCount; i++) {
+ children.emplace_back(_ctx.pop());
+ }
+
+ sbe::value::Object object;
+ for (size_t i = 0; i < childCount; i++) {
+ ABT& child = children.at(childCount - i - 1);
+ uassert(
+ 6624345, "Only constants are supported as object fields.", child.is<Constant>());
+
+ auto [tag, val] = child.cast<Constant>()->get();
+ // Copy the value before inserting into the object
+ auto [tagCopy, valCopy] = sbe::value::copyValue(tag, val);
+ object.push_back(expressions.at(i).first, tagCopy, valCopy);
+ }
+
+ auto [tag, val] = sbe::value::makeCopyObject(object);
+ _ctx.push<Constant>(tag, val);
+ }
+
+ void visit(const ExpressionOr* expr) override final {
+ visitMultiBranchLogicExpression(expr, Operations::Or);
+ }
+
+ void visit(const ExpressionPow* expr) override final {
+ unsupportedExpression("$pow");
+ }
+
+ void visit(const ExpressionRange* expr) override final {
+ unsupportedExpression(expr->getOpName());
+ }
+
+ void visit(const ExpressionReduce* expr) override final {
+ unsupportedExpression("$reduce");
+ }
+
+ void visit(const ExpressionReplaceOne* expr) override final {
+ unsupportedExpression(expr->getOpName());
+ }
+
+ void visit(const ExpressionReplaceAll* expr) override final {
+ unsupportedExpression(expr->getOpName());
+ }
+
+ void visit(const ExpressionSetDifference* expr) override final {
+ unsupportedExpression(expr->getOpName());
+ }
+
+ void visit(const ExpressionSetEquals* expr) override final {
+ unsupportedExpression(expr->getOpName());
+ }
+
+ void visit(const ExpressionSetIntersection* expr) override final {
+ unsupportedExpression(expr->getOpName());
+ }
+
+ void visit(const ExpressionSetIsSubset* expr) override final {
+ unsupportedExpression(expr->getOpName());
+ }
+
+ void visit(const ExpressionSetUnion* expr) override final {
+ unsupportedExpression(expr->getOpName());
+ }
+
+ void visit(const ExpressionSize* expr) override final {
+ unsupportedExpression(expr->getOpName());
+ }
+
+ void visit(const ExpressionReverseArray* expr) override final {
+ unsupportedExpression(expr->getOpName());
+ }
+
+ void visit(const ExpressionSortArray* expr) override final {
+ unsupportedExpression(expr->getOpName());
+ }
+
+ void visit(const ExpressionSlice* expr) override final {
+ unsupportedExpression(expr->getOpName());
+ }
+
+ void visit(const ExpressionIsArray* expr) override final {
+ unsupportedExpression(expr->getOpName());
+ }
+
+ void visit(const ExpressionRound* expr) override final {
+ unsupportedExpression(expr->getOpName());
+ }
+
+ void visit(const ExpressionSplit* expr) override final {
+ pushMultiArgFunctionFromTop("split", expr->getChildren().size());
+ }
+
+ void visit(const ExpressionSqrt* expr) override final {
+ unsupportedExpression(expr->getOpName());
+ }
+
+ void visit(const ExpressionStrcasecmp* expr) override final {
+ unsupportedExpression(expr->getOpName());
+ }
+
+ void visit(const ExpressionSubstrBytes* expr) override final {
+ unsupportedExpression(expr->getOpName());
+ }
+
+ void visit(const ExpressionSubstrCP* expr) override final {
+ unsupportedExpression(expr->getOpName());
+ }
+
+ void visit(const ExpressionStrLenBytes* expr) override final {
+ unsupportedExpression(expr->getOpName());
+ }
+
+ void visit(const ExpressionBinarySize* expr) override final {
+ unsupportedExpression(expr->getOpName());
+ }
+
+ void visit(const ExpressionStrLenCP* expr) override final {
+ unsupportedExpression(expr->getOpName());
+ }
+
+ void visit(const ExpressionSubtract* expr) override final {
+ pushArithmeticBinaryExpr(expr, Operations::Sub);
+ }
+
+    void visit(const ExpressionSwitch* expr) override final {
+        const size_t arity = expr->getChildren().size();
+        _ctx.ensureArity(arity);
+        const size_t numCases = (arity - 1) / 2;
+
+        // Pop in reverse (stack) order: the optional default branch first, then each
+        // ("then" branch, coerced case condition) pair; reverse at the end to restore
+        // source order for the "switch" builtin.
+        ABTVector children;
+        if (expr->getChildren().back() != nullptr) {
+            children.push_back(_ctx.pop());
+        }
+        for (size_t i = 0; i < numCases; i++) {
+            children.push_back(_ctx.pop());
+            children.push_back(generateCoerceToBoolPopInput());
+        }
+        std::reverse(children.begin(), children.end());
+
+        _ctx.push<FunctionCall>("switch", std::move(children));
+    }
+
+ void visit(const ExpressionTestApiVersion* expr) override final {
+ unsupportedExpression("$_testApiVersion");
+ }
+
+ void visit(const ExpressionToLower* expr) override final {
+ pushSingleArgFunctionFromTop("toLower");
+ }
+
+ void visit(const ExpressionToUpper* expr) override final {
+ pushSingleArgFunctionFromTop("toUpper");
+ }
+
+ void visit(const ExpressionTrim* expr) override final {
+ unsupportedExpression("$trim");
+ }
+
+ void visit(const ExpressionTrunc* expr) override final {
+ unsupportedExpression(expr->getOpName());
+ }
+
+ void visit(const ExpressionType* expr) override final {
+ unsupportedExpression(expr->getOpName());
+ }
+
+ void visit(const ExpressionZip* expr) override final {
+ unsupportedExpression("$zip");
+ }
+
+ void visit(const ExpressionConvert* expr) override final {
+ unsupportedExpression("$convert");
+ }
+
+ void visit(const ExpressionRegexFind* expr) override final {
+ unsupportedExpression("$regexFind");
+ }
+
+ void visit(const ExpressionRegexFindAll* expr) override final {
+ unsupportedExpression("$regexFindAll");
+ }
+
+ void visit(const ExpressionRegexMatch* expr) override final {
+ unsupportedExpression("$regexMatch");
+ }
+
+ void visit(const ExpressionCosine* expr) override final {
+ pushSingleArgFunctionFromTop("cosine");
+ }
+
+ void visit(const ExpressionSine* expr) override final {
+ pushSingleArgFunctionFromTop("sine");
+ }
+
+ void visit(const ExpressionTangent* expr) override final {
+ pushSingleArgFunctionFromTop("tangent");
+ }
+
+ void visit(const ExpressionArcCosine* expr) override final {
+ pushSingleArgFunctionFromTop("arcCosine");
+ }
+
+ void visit(const ExpressionArcSine* expr) override final {
+ pushSingleArgFunctionFromTop("arcSine");
+ }
+
+ void visit(const ExpressionArcTangent* expr) override final {
+ pushSingleArgFunctionFromTop("arcTangent");
+ }
+
+ void visit(const ExpressionArcTangent2* expr) override final {
+ pushSingleArgFunctionFromTop("arcTangent2");
+ }
+
+ void visit(const ExpressionHyperbolicArcTangent* expr) override final {
+ pushSingleArgFunctionFromTop("arcTangentH");
+ }
+
+ void visit(const ExpressionHyperbolicArcCosine* expr) override final {
+ pushSingleArgFunctionFromTop("arcCosineH");
+ }
+
+ void visit(const ExpressionHyperbolicArcSine* expr) override final {
+ pushSingleArgFunctionFromTop("arcSineH");
+ }
+
+ void visit(const ExpressionHyperbolicTangent* expr) override final {
+ pushSingleArgFunctionFromTop("tangentH");
+ }
+
+ void visit(const ExpressionHyperbolicCosine* expr) override final {
+ pushSingleArgFunctionFromTop("cosineH");
+ }
+
+ void visit(const ExpressionHyperbolicSine* expr) override final {
+ pushSingleArgFunctionFromTop("sineH");
+ }
+
+ void visit(const ExpressionDegreesToRadians* expr) override final {
+ pushSingleArgFunctionFromTop("degreesToRadians");
+ }
+
+ void visit(const ExpressionRadiansToDegrees* expr) override final {
+ pushSingleArgFunctionFromTop("radiansToDegrees");
+ }
+
+ void visit(const ExpressionDayOfMonth* expr) override final {
+ unsupportedExpression("$dayOfMonth");
+ }
+
+ void visit(const ExpressionDayOfWeek* expr) override final {
+ unsupportedExpression("$dayOfWeek");
+ }
+
+ void visit(const ExpressionDayOfYear* expr) override final {
+ unsupportedExpression("$dayOfYear");
+ }
+
+ void visit(const ExpressionHour* expr) override final {
+ unsupportedExpression("$hour");
+ }
+
+ void visit(const ExpressionMillisecond* expr) override final {
+ unsupportedExpression("$millisecond");
+ }
+
+ void visit(const ExpressionMinute* expr) override final {
+ unsupportedExpression("$minute");
+ }
+
+ void visit(const ExpressionMonth* expr) override final {
+ unsupportedExpression("$month");
+ }
+
+ void visit(const ExpressionSecond* expr) override final {
+ unsupportedExpression("$second");
+ }
+
+ void visit(const ExpressionWeek* expr) override final {
+ unsupportedExpression("$week");
+ }
+
+ void visit(const ExpressionIsoWeekYear* expr) override final {
+ unsupportedExpression("$isoWeekYear");
+ }
+
+ void visit(const ExpressionIsoDayOfWeek* expr) override final {
+ unsupportedExpression("$isoDayOfWeek");
+ }
+
+ void visit(const ExpressionIsoWeek* expr) override final {
+ unsupportedExpression("$isoWeek");
+ }
+
+ void visit(const ExpressionYear* expr) override final {
+ unsupportedExpression("$year");
+ }
+
+ void visit(const ExpressionFromAccumulator<AccumulatorAvg>* expr) override final {
+ unsupportedExpression(expr->getOpName());
+ }
+
+ void visit(const ExpressionFromAccumulatorN<AccumulatorFirstN>* expr) override final {
+ unsupportedExpression(expr->getOpName());
+ }
+
+ void visit(const ExpressionFromAccumulatorN<AccumulatorLastN>* expr) override final {
+ unsupportedExpression(expr->getOpName());
+ }
+
+ void visit(const ExpressionFromAccumulator<AccumulatorMax>* expr) override final {
+ unsupportedExpression(expr->getOpName());
+ }
+
+ void visit(const ExpressionFromAccumulator<AccumulatorMin>* expr) override final {
+ unsupportedExpression(expr->getOpName());
+ }
+
+ void visit(const ExpressionFromAccumulatorN<AccumulatorMaxN>* expr) override final {
+ unsupportedExpression(expr->getOpName());
+ }
+
+ void visit(const ExpressionFromAccumulatorN<AccumulatorMinN>* expr) override final {
+ unsupportedExpression(expr->getOpName());
+ }
+
+ void visit(const ExpressionFromAccumulator<AccumulatorStdDevPop>* expr) override final {
+ unsupportedExpression(expr->getOpName());
+ }
+
+ void visit(const ExpressionFromAccumulator<AccumulatorStdDevSamp>* expr) override final {
+ unsupportedExpression(expr->getOpName());
+ }
+
+ void visit(const ExpressionFromAccumulator<AccumulatorSum>* expr) override final {
+ unsupportedExpression(expr->getOpName());
+ }
+
+ void visit(const ExpressionFromAccumulator<AccumulatorMergeObjects>* expr) override final {
+ unsupportedExpression(expr->getOpName());
+ }
+
+ void visit(const ExpressionTests::Testable* expr) override final {
+ unsupportedExpression("$test");
+ }
+
+ void visit(const ExpressionInternalJsEmit* expr) override final {
+ unsupportedExpression("$internalJsEmit");
+ }
+
+ void visit(const ExpressionInternalFindSlice* expr) override final {
+ unsupportedExpression("$internalFindSlice");
+ }
+
+ void visit(const ExpressionInternalFindPositional* expr) override final {
+ unsupportedExpression("$internalFindPositional");
+ }
+
+ void visit(const ExpressionInternalFindElemMatch* expr) override final {
+ unsupportedExpression("$internalFindElemMatch");
+ }
+
+ void visit(const ExpressionFunction* expr) override final {
+ unsupportedExpression("$function");
+ }
+
+ void visit(const ExpressionRandom* expr) override final {
+ unsupportedExpression(expr->getOpName());
+ }
+
+ void visit(const ExpressionToHashedIndexKey* expr) override final {
+ unsupportedExpression("$toHashedIndexKey");
+ }
+
+    void visit(const ExpressionDateAdd* expr) override final {
+        unsupportedExpression("$dateAdd");
+    }
+
+    void visit(const ExpressionDateSubtract* expr) override final {
+        unsupportedExpression("$dateSubtract");
+    }
+
+ void visit(const ExpressionSetField* expr) override final {
+ unsupportedExpression("$setField");
+ }
+
+ void visit(const ExpressionGetField* expr) override final {
+ unsupportedExpression("$getField");
+ }
+
+    void visit(const ExpressionTsSecond* expr) override final {
+        unsupportedExpression("$tsSecond");
+    }
+
+    void visit(const ExpressionTsIncrement* expr) override final {
+        unsupportedExpression("$tsIncrement");
+    }
+
+private:
+    /**
+     * Shared logic for $and and $or. Converts each child into an ABT expression that evaluates
+     * to Boolean true or false, based on MQL rules for $and and $or branches, and then chains
+     * the branches together with binary and/or operations so that the result has MQL's
+     * short-circuit semantics.
+     */
+ void visitMultiBranchLogicExpression(const Expression* expr, Operations logicOp) {
+ invariant(logicOp == Operations::And || logicOp == Operations::Or);
+ const size_t arity = expr->getChildren().size();
+ _ctx.ensureArity(arity);
+
+ if (arity == 0) {
+ // Empty $and and $or always evaluate to their logical operator's identity value: true
+ // and false, respectively.
+ const bool logicIdentityVal = (logicOp == Operations::And);
+ _ctx.push<Constant>(sbe::value::TypeTags::Boolean,
+ sbe::value::bitcastFrom<bool>(logicIdentityVal));
+ return;
+ }
+
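+        // The children were pushed in visitation order, so they pop back to front: for children
+        // [e0, e1, e2] the result is BinaryOp(op, BinaryOp(op, e2', e1'), e0'), where each ei'
+        // denotes the corresponding child coerced to a Boolean.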
+ ABT current = generateCoerceToBoolPopInput();
+ for (size_t i = 0; i < arity - 1; i++) {
+ current = make<BinaryOp>(logicOp, std::move(current), generateCoerceToBoolPopInput());
+ }
+ _ctx.push(std::move(current));
+ }
+
+ void pushMultiArgFunctionFromTop(const std::string& functionName, const size_t argCount) {
+ _ctx.ensureArity(argCount);
+
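+        // Note that arguments are popped from the top of the stack, so for argCount > 1 the
+        // 'children' vector receives them in reverse of their visitation order.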
+ ABTVector children;
+ for (size_t i = 0; i < argCount; i++) {
+ children.emplace_back(_ctx.pop());
+ }
+ _ctx.push<FunctionCall>(functionName, children);
+ }
+
+ void pushSingleArgFunctionFromTop(const std::string& functionName) {
+ pushMultiArgFunctionFromTop(functionName, 1);
+ }
+
+ void pushArithmeticBinaryExpr(const Expression* expr, const Operations op) {
+ const size_t arity = expr->getChildren().size();
+ _ctx.ensureArity(arity);
+        if (arity < 2) {
+            // Nothing to do for arity 0 and 1: with a single operand, the child's translation
+            // already on the stack is the result.
+            return;
+        }
+
+ ABT current = _ctx.pop();
+ for (size_t i = 0; i < arity - 1; i++) {
+ current = make<BinaryOp>(op, std::move(current), _ctx.pop());
+ }
+ _ctx.push(std::move(current));
+ }
+
+ ABT generateCoerceToBoolInternal(ABT input) {
+ return generateCoerceToBool(std::move(input), getNextId("coerceToBool"));
+ }
+
+ ABT generateCoerceToBoolPopInput() {
+ return generateCoerceToBoolInternal(_ctx.pop());
+ }
+
+ std::string generateVariableName(const Variables::Id varId) {
+ std::ostringstream os;
+ os << _ctx.getUniqueIdPrefix() << "_var_" << varId;
+ return os.str();
+ }
+
+ std::string getNextId(const std::string& key) {
+ return _ctx.getUniqueIdPrefix() + "_" + _prefixId.getNextId(key);
+ }
+
+ void unsupportedExpression(const char* op) const {
+ uasserted(ErrorCodes::InternalErrorNotSupported,
+ str::stream() << "Expression is not supported: " << op);
+ }
+
+ PrefixId _prefixId;
+
+ // We don't own this.
+ ExpressionAlgebrizerContext& _ctx;
+};
+
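+/**
+ * Adapter which lets expression_walker drive the visitor. Translation happens bottom-up in
+ * postVisit, so the translations of a node's children are already on the context stack by the
+ * time the node itself is visited.
+ */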
+class AggExpressionWalker final {
+public:
+ AggExpressionWalker(ABTAggExpressionVisitor* visitor) : _visitor{visitor} {}
+
+ void postVisit(const Expression* expr) {
+ expr->acceptVisitor(_visitor);
+ }
+
+private:
+ ABTAggExpressionVisitor* _visitor;
+};
+
+ABT generateAggExpression(const Expression* expr,
+ const std::string& rootProjection,
+ const std::string& uniqueIdPrefix) {
+ ExpressionAlgebrizerContext ctx(
+ true /*assertExprSort*/, false /*assertPathSort*/, rootProjection, uniqueIdPrefix);
+ ABTAggExpressionVisitor visitor(ctx);
+
+ AggExpressionWalker walker(&visitor);
+ expression_walker::walk<const Expression>(expr, &walker);
+ return ctx.pop();
+}
+
+} // namespace mongo::optimizer
diff --git a/src/mongo/db/pipeline/abt/agg_expression_visitor.h b/src/mongo/db/pipeline/abt/agg_expression_visitor.h
new file mode 100644
index 00000000000..b0300dceb63
--- /dev/null
+++ b/src/mongo/db/pipeline/abt/agg_expression_visitor.h
@@ -0,0 +1,42 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include "mongo/db/operation_context.h"
+#include "mongo/db/pipeline/expression.h"
+#include "mongo/db/query/optimizer/node.h"
+
+namespace mongo::optimizer {
+
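+/**
+ * Translates an agg Expression into an ABT. 'rootProjection' names the projection from which
+ * the translated expression reads the current document, and 'uniqueIdPrefix' is prepended to
+ * the names of any variables generated during translation so that repeated invocations do not
+ * collide.
+ */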
+ABT generateAggExpression(const Expression* expr,
+ const std::string& rootProjection,
+ const std::string& uniqueIdPrefix);
+
+} // namespace mongo::optimizer
diff --git a/src/mongo/db/pipeline/abt/expr_algebrizer_context.cpp b/src/mongo/db/pipeline/abt/expr_algebrizer_context.cpp
new file mode 100644
index 00000000000..af838a74b06
--- /dev/null
+++ b/src/mongo/db/pipeline/abt/expr_algebrizer_context.cpp
@@ -0,0 +1,73 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/db/pipeline/abt/expr_algebrizer_context.h"
+
+namespace mongo::optimizer {
+
+ExpressionAlgebrizerContext::ExpressionAlgebrizerContext(const bool assertExprSort,
+ const bool assertPathSort,
+ const std::string& rootProjection,
+ const std::string& uniqueIdPrefix)
+ : _assertExprSort(assertExprSort),
+ _assertPathSort(assertPathSort),
+ _rootProjection(rootProjection),
+ _uniqueIdPrefix(uniqueIdPrefix) {}
+
+void ExpressionAlgebrizerContext::push(ABT node) {
+    if (_assertExprSort) {
+        assertExprSort(node);
+    } else if (_assertPathSort) {
+        assertPathSort(node);
+    }
+
+    _stack.emplace(std::move(node));
+}
+
+ABT ExpressionAlgebrizerContext::pop() {
+    uassert(6624428, "Arity violation", !_stack.empty());
+
+    ABT node = std::move(_stack.top());
+    _stack.pop();
+    return node;
+}
+
+void ExpressionAlgebrizerContext::ensureArity(const size_t arity) {
+ uassert(6624429, "Arity violation", _stack.size() >= arity);
+}
+
+const std::string& ExpressionAlgebrizerContext::getRootProjection() const {
+ return _rootProjection;
+}
+
+const std::string& ExpressionAlgebrizerContext::getUniqueIdPrefix() const {
+ return _uniqueIdPrefix;
+}
+
+} // namespace mongo::optimizer
diff --git a/src/mongo/db/pipeline/abt/expr_algebrizer_context.h b/src/mongo/db/pipeline/abt/expr_algebrizer_context.h
new file mode 100644
index 00000000000..7a04da91da1
--- /dev/null
+++ b/src/mongo/db/pipeline/abt/expr_algebrizer_context.h
@@ -0,0 +1,70 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include <stack>
+
+#include "mongo/db/query/optimizer/node.h"
+
+namespace mongo::optimizer {
+
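+/**
+ * Holds the state shared between the agg expression and match expression algebrizers: a stack
+ * of partially translated ABTs. Visitors push the translation of each leaf and pop the
+ * translations of a node's children when visiting the node itself, so a complete bottom-up
+ * traversal leaves exactly one ABT on the stack. A minimal sketch of the protocol, assuming a
+ * hypothetical two-child Add node:
+ *
+ *     ctx.ensureArity(2);
+ *     ABT rhs = ctx.pop();
+ *     ABT lhs = ctx.pop();
+ *     ctx.push<BinaryOp>(Operations::Add, std::move(lhs), std::move(rhs));
+ */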
+class ExpressionAlgebrizerContext {
+public:
+ ExpressionAlgebrizerContext(bool assertExprSort,
+ bool assertPathSort,
+ const std::string& rootProjection,
+ const std::string& uniqueIdPrefix);
+
+    template <typename T, typename... Args>
+    inline auto push(Args&&... args) {
+        push(ABT::make<T>(std::forward<Args>(args)...));
+    }
+
+ void push(ABT node);
+
+ ABT pop();
+
+ void ensureArity(size_t arity);
+
+ const std::string& getRootProjection() const;
+
+ const std::string& getUniqueIdPrefix() const;
+
+private:
+ const bool _assertExprSort;
+ const bool _assertPathSort;
+
+ const std::string _rootProjection;
+ const std::string _uniqueIdPrefix;
+
+ std::stack<ABT> _stack;
+};
+
+} // namespace mongo::optimizer
diff --git a/src/mongo/db/pipeline/abt/match_expression_visitor.cpp b/src/mongo/db/pipeline/abt/match_expression_visitor.cpp
new file mode 100644
index 00000000000..445c2e719ad
--- /dev/null
+++ b/src/mongo/db/pipeline/abt/match_expression_visitor.cpp
@@ -0,0 +1,479 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/db/pipeline/abt/match_expression_visitor.h"
+#include "mongo/db/matcher/expression_always_boolean.h"
+#include "mongo/db/matcher/expression_array.h"
+#include "mongo/db/matcher/expression_expr.h"
+#include "mongo/db/matcher/expression_geo.h"
+#include "mongo/db/matcher/expression_internal_bucket_geo_within.h"
+#include "mongo/db/matcher/expression_internal_expr_comparison.h"
+#include "mongo/db/matcher/expression_leaf.h"
+#include "mongo/db/matcher/expression_text.h"
+#include "mongo/db/matcher/expression_text_noop.h"
+#include "mongo/db/matcher/expression_tree.h"
+#include "mongo/db/matcher/expression_type.h"
+#include "mongo/db/matcher/expression_visitor.h"
+#include "mongo/db/matcher/expression_where.h"
+#include "mongo/db/matcher/expression_where_noop.h"
+#include "mongo/db/matcher/match_expression_walker.h"
+#include "mongo/db/matcher/schema/expression_internal_schema_all_elem_match_from_index.h"
+#include "mongo/db/matcher/schema/expression_internal_schema_allowed_properties.h"
+#include "mongo/db/matcher/schema/expression_internal_schema_cond.h"
+#include "mongo/db/matcher/schema/expression_internal_schema_eq.h"
+#include "mongo/db/matcher/schema/expression_internal_schema_fmod.h"
+#include "mongo/db/matcher/schema/expression_internal_schema_match_array_index.h"
+#include "mongo/db/matcher/schema/expression_internal_schema_max_items.h"
+#include "mongo/db/matcher/schema/expression_internal_schema_max_length.h"
+#include "mongo/db/matcher/schema/expression_internal_schema_max_properties.h"
+#include "mongo/db/matcher/schema/expression_internal_schema_min_items.h"
+#include "mongo/db/matcher/schema/expression_internal_schema_min_length.h"
+#include "mongo/db/matcher/schema/expression_internal_schema_min_properties.h"
+#include "mongo/db/matcher/schema/expression_internal_schema_object_match.h"
+#include "mongo/db/matcher/schema/expression_internal_schema_root_doc_eq.h"
+#include "mongo/db/matcher/schema/expression_internal_schema_unique_items.h"
+#include "mongo/db/matcher/schema/expression_internal_schema_xor.h"
+#include "mongo/db/pipeline/abt/agg_expression_visitor.h"
+#include "mongo/db/pipeline/abt/expr_algebrizer_context.h"
+#include "mongo/db/pipeline/abt/utils.h"
+#include "mongo/db/query/optimizer/utils/utils.h"
+
+namespace mongo::optimizer {
+
+class ABTMatchExpressionVisitor : public MatchExpressionConstVisitor {
+public:
+ ABTMatchExpressionVisitor(ExpressionAlgebrizerContext& ctx, const bool allowAggExpressions)
+ : _prefixId(), _allowAggExpressions(allowAggExpressions), _ctx(ctx) {}
+
+ void visit(const AlwaysFalseMatchExpression* expr) override {
+ generateBoolConstant(false);
+ }
+
+ void visit(const AlwaysTrueMatchExpression* expr) override {
+ generateBoolConstant(true);
+ }
+
+ void visit(const AndMatchExpression* expr) override {
+ visitAndOrExpression<PathComposeM, true>(expr);
+ }
+
+ void visit(const BitsAllClearMatchExpression* expr) override {
+ unsupportedExpression(expr);
+ }
+
+ void visit(const BitsAllSetMatchExpression* expr) override {
+ unsupportedExpression(expr);
+ }
+
+ void visit(const BitsAnyClearMatchExpression* expr) override {
+ unsupportedExpression(expr);
+ }
+
+ void visit(const BitsAnySetMatchExpression* expr) override {
+ unsupportedExpression(expr);
+ }
+
+ void visit(const ElemMatchObjectMatchExpression* expr) override {
+ generateElemMatch<false /*isValueElemMatch*/>(expr);
+ }
+
+ void visit(const ElemMatchValueMatchExpression* expr) override {
+ generateElemMatch<true /*isValueElemMatch*/>(expr);
+ }
+
+ void visit(const EqualityMatchExpression* expr) override {
+ generateSimpleComparison(expr, Operations::Eq);
+ }
+
+ void visit(const ExistsMatchExpression* expr) override {
+ unsupportedExpression(expr);
+ }
+
+ void visit(const ExprMatchExpression* expr) override {
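+        // $expr delegates to the agg expression translator; the resulting expression is coerced
+        // to a Boolean and wrapped in a PathConstant so it can participate in path composition.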
+ uassert(6624246, "Cannot generate an agg expression in this context", _allowAggExpressions);
+
+ ABT result = generateAggExpression(
+ expr->getExpression().get(), _ctx.getRootProjection(), _ctx.getUniqueIdPrefix());
+ _ctx.push<PathConstant>(generateCoerceToBool(std::move(result), getNextId("coerceToBool")));
+ }
+
+ void visit(const GTEMatchExpression* expr) override {
+ generateSimpleComparison(expr, Operations::Gte);
+ }
+
+ void visit(const GTMatchExpression* expr) override {
+ generateSimpleComparison(expr, Operations::Gt);
+ }
+
+ void visit(const GeoMatchExpression* expr) override {
+ unsupportedExpression(expr);
+ }
+
+ void visit(const GeoNearMatchExpression* expr) override {
+ unsupportedExpression(expr);
+ }
+
+ void visit(const InMatchExpression* expr) override {
+ unsupportedExpression(expr);
+ }
+
+ void visit(const InternalBucketGeoWithinMatchExpression* expr) override {
+ unsupportedExpression(expr);
+ }
+
+ void visit(const InternalExprEqMatchExpression* expr) override {
+ unsupportedExpression(expr);
+ }
+
+ void visit(const InternalExprGTMatchExpression* expr) override {
+ unsupportedExpression(expr);
+ }
+
+ void visit(const InternalExprGTEMatchExpression* expr) override {
+ unsupportedExpression(expr);
+ }
+
+ void visit(const InternalExprLTMatchExpression* expr) override {
+ unsupportedExpression(expr);
+ }
+
+ void visit(const InternalExprLTEMatchExpression* expr) override {
+ unsupportedExpression(expr);
+ }
+
+ void visit(const InternalSchemaAllElemMatchFromIndexMatchExpression* expr) override {
+ unsupportedExpression(expr);
+ }
+
+ void visit(const InternalSchemaAllowedPropertiesMatchExpression* expr) override {
+ unsupportedExpression(expr);
+ }
+
+ void visit(const InternalSchemaBinDataEncryptedTypeExpression* expr) override {
+ unsupportedExpression(expr);
+ }
+
+ void visit(const InternalSchemaBinDataSubTypeExpression* expr) override {
+ unsupportedExpression(expr);
+ }
+
+ void visit(const InternalSchemaCondMatchExpression* expr) override {
+ unsupportedExpression(expr);
+ }
+
+ void visit(const InternalSchemaEqMatchExpression* expr) override {
+ unsupportedExpression(expr);
+ }
+
+ void visit(const InternalSchemaFmodMatchExpression* expr) override {
+ unsupportedExpression(expr);
+ }
+
+ void visit(const InternalSchemaMatchArrayIndexMatchExpression* expr) override {
+ unsupportedExpression(expr);
+ }
+
+ void visit(const InternalSchemaMaxItemsMatchExpression* expr) override {
+ unsupportedExpression(expr);
+ }
+
+ void visit(const InternalSchemaMaxLengthMatchExpression* expr) override {
+ unsupportedExpression(expr);
+ }
+
+ void visit(const InternalSchemaMaxPropertiesMatchExpression* expr) override {
+ unsupportedExpression(expr);
+ }
+
+ void visit(const InternalSchemaMinItemsMatchExpression* expr) override {
+ unsupportedExpression(expr);
+ }
+
+ void visit(const InternalSchemaMinLengthMatchExpression* expr) override {
+ unsupportedExpression(expr);
+ }
+
+ void visit(const InternalSchemaMinPropertiesMatchExpression* expr) override {
+ unsupportedExpression(expr);
+ }
+
+ void visit(const InternalSchemaObjectMatchExpression* expr) override {
+ unsupportedExpression(expr);
+ }
+
+ void visit(const InternalSchemaRootDocEqMatchExpression* expr) override {
+ unsupportedExpression(expr);
+ }
+
+ void visit(const InternalSchemaTypeExpression* expr) override {
+ unsupportedExpression(expr);
+ }
+
+ void visit(const InternalSchemaUniqueItemsMatchExpression* expr) override {
+ unsupportedExpression(expr);
+ }
+
+ void visit(const InternalSchemaXorMatchExpression* expr) override {
+ unsupportedExpression(expr);
+ }
+
+ void visit(const LTEMatchExpression* expr) override {
+ generateSimpleComparison(expr, Operations::Lte);
+ }
+
+ void visit(const LTMatchExpression* expr) override {
+ generateSimpleComparison(expr, Operations::Lt);
+ }
+
+ void visit(const ModMatchExpression* expr) override {
+ unsupportedExpression(expr);
+ }
+
+ void visit(const NorMatchExpression* expr) override {
+ unsupportedExpression(expr);
+ }
+
+ void visit(const NotMatchExpression* expr) override {
+ unsupportedExpression(expr);
+ }
+
+ void visit(const OrMatchExpression* expr) override {
+ visitAndOrExpression<PathComposeA, false>(expr);
+ }
+
+ void visit(const RegexMatchExpression* expr) override {
+ unsupportedExpression(expr);
+ }
+
+ void visit(const SizeMatchExpression* expr) override {
+ unsupportedExpression(expr);
+ }
+
+ void visit(const TextMatchExpression* expr) override {
+ unsupportedExpression(expr);
+ }
+
+ void visit(const TextNoOpMatchExpression* expr) override {
+ unsupportedExpression(expr);
+ }
+
+ void visit(const TwoDPtInAnnulusExpression* expr) override {
+ unsupportedExpression(expr);
+ }
+
+ void visit(const TypeMatchExpression* expr) override {
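+        // Lower $type to the 'typeMatch' builtin: bind the input value with a lambda and test
+        // it against the bitmask of the requested BSON types.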
+ const std::string lambdaProjName = _prefixId.getNextId("lambda_typeMatch");
+ ABT result = make<PathLambda>(make<LambdaAbstraction>(
+ lambdaProjName,
+ make<FunctionCall>("typeMatch",
+ makeSeq(make<Variable>(lambdaProjName),
+ Constant::int64(expr->typeSet().getBSONTypeMask())))));
+
+ if (!expr->path().empty()) {
+ result = generateFieldPath(FieldPath(expr->path().toString()), std::move(result));
+ }
+ _ctx.push(std::move(result));
+ }
+
+ void visit(const WhereMatchExpression* expr) override {
+ unsupportedExpression(expr);
+ }
+
+ void visit(const WhereNoOpMatchExpression* expr) override {
+ unsupportedExpression(expr);
+ }
+
+private:
+ void generateBoolConstant(const bool value) {
+ _ctx.push<PathConstant>(Constant::boolean(value));
+ }
+
+    template <bool isValueElemMatch>
+    void generateElemMatch(const ArrayMatchingMatchExpression* expr) {
+        // Returns true if at least one element of the array matches the condition.
+
+        const size_t childCount = expr->numChildren();
+        if (childCount == 0) {
+            // With no children the per-element predicate is trivially true. Push it as a
+            // path-sorted constant so that the array checks composed below still apply.
+            _ctx.push<PathConstant>(Constant::boolean(true));
+        }
+
+ _ctx.ensureArity(childCount);
+ ABT result = _ctx.pop();
+ for (size_t i = 1; i < childCount; i++) {
+ maybeComposePath(result, _ctx.pop());
+ }
+ if constexpr (!isValueElemMatch) {
+ // Make sure we consider only objects as elements of the array.
+ maybeComposePath(result, make<PathObj>());
+ }
+ result = make<PathTraverse>(std::move(result));
+
+ // Make sure we consider only arrays fields on the path.
+ maybeComposePath(result, make<PathArr>());
+
+ if (!expr->path().empty()) {
+ result = translateFieldPath(
+ FieldPath{expr->path().toString()},
+ std::move(result),
+ [&](const std::string& fieldName, const bool isLastElement, ABT input) {
+ if (!isLastElement) {
+ input = make<PathTraverse>(std::move(input));
+ }
+ return make<PathGet>(fieldName, std::move(input));
+ });
+ }
+
+ _ctx.push(std::move(result));
+ }
+
+    /**
+     * Returns the minimum or maximum value for the "class" of values represented by the input
+     * constant's type, paired with a flag indicating whether the bound is inclusive. Used to
+     * support type bracketing.
+     */
+    template <bool isMin>
+    std::pair<boost::optional<ABT>, bool> getMinMaxBoundForType(const sbe::value::TypeTags& tag) {
+        if (isNumber(tag)) {
+            if constexpr (isMin) {
+                return {Constant::fromDouble(std::numeric_limits<double>::quiet_NaN()), true};
+            } else {
+                // The empty string is the smallest value of the string class, which sorts
+                // immediately after the numbers.
+                return {Constant::str(""), false};
+            }
+        } else if (isStringOrSymbol(tag)) {
+            if constexpr (isMin) {
+                return {Constant::str(""), true};
+            } else {
+                // TODO: we need a bound limiting strings from above.
+                return {boost::none, false};
+            }
+        } else if (tag == sbe::value::TypeTags::Null) {
+            // Same bound above and below.
+            return {Constant::null(), true};
+        } else {
+            // TODO: compute bounds for other types based on bsonobjbuilder.cpp.
+            return {boost::none, false};
+        }
+    }
+
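+    // Wraps 'initial' in a PathGet/PathTraverse pair for each component of 'fieldPath',
+    // mirroring MQL's implicit array traversal at every path element.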
+ ABT generateFieldPath(const FieldPath& fieldPath, ABT initial) {
+ return translateFieldPath(
+ fieldPath,
+ std::move(initial),
+ [](const std::string& fieldName, const bool /*isLastElement*/, ABT input) {
+ return make<PathGet>(fieldName, make<PathTraverse>(std::move(input)));
+ });
+ }
+
+ void generateSimpleComparison(const ComparisonMatchExpressionBase* expr, const Operations op) {
+ auto [tag, val] = convertFrom(Value(expr->getData()));
+ ABT result = make<PathCompare>(op, make<Constant>(tag, val));
+
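+        // MQL comparisons are type-bracketed: for example {$lt: 5} matches only numeric values
+        // less than 5. Encode this by composing an extra bound which delimits the type class of
+        // the constant (e.g. Gte NaN from below and Lt "" from above for numbers).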
+ switch (op) {
+ case Operations::Lt:
+ case Operations::Lte: {
+ auto&& [constant, inclusive] = getMinMaxBoundForType<true /*isMin*/>(tag);
+ if (constant) {
+ maybeComposePath(result,
+ make<PathCompare>(inclusive ? Operations::Gte : Operations::Gt,
+ std::move(constant.get())));
+ }
+ break;
+ }
+
+ case Operations::Gt:
+ case Operations::Gte: {
+ auto&& [constant, inclusive] = getMinMaxBoundForType<false /*isMin*/>(tag);
+ if (constant) {
+ maybeComposePath(result,
+ make<PathCompare>(inclusive ? Operations::Lte : Operations::Lt,
+ std::move(constant.get())));
+ }
+ break;
+ }
+
+ default:
+ break;
+ }
+
+ if (!expr->path().empty()) {
+ result = generateFieldPath(FieldPath(expr->path().toString()), std::move(result));
+ }
+ _ctx.push(std::move(result));
+ }
+
+ template <class Composition, bool defaultResult>
+ void visitAndOrExpression(const ListOfMatchExpression* expr) {
+ const size_t childCount = expr->numChildren();
+ if (childCount == 0) {
+ generateBoolConstant(defaultResult);
+ return;
+ }
+ if (childCount == 1) {
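+            // The single child's translation, already on the stack, is the result.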
+ return;
+ }
+
+ ABT node = _ctx.pop();
+ for (size_t i = 0; i < childCount - 1; i++) {
+ node = make<Composition>(_ctx.pop(), std::move(node));
+ }
+ _ctx.push(std::move(node));
+ }
+
+ std::string getNextId(const std::string& key) {
+ return _ctx.getUniqueIdPrefix() + "_" + _prefixId.getNextId(key);
+ }
+
+ void unsupportedExpression(const MatchExpression* expr) const {
+ uasserted(ErrorCodes::InternalErrorNotSupported,
+ str::stream() << "Match expression is not supported: " << expr->matchType());
+ }
+
+ PrefixId _prefixId;
+
+    // If we are parsing a partial index filter, we don't allow agg ($expr) expressions.
+    const bool _allowAggExpressions;
+
+    // We don't own this.
+ ExpressionAlgebrizerContext& _ctx;
+};
+
+ABT generateMatchExpression(const MatchExpression* expr,
+ const bool allowAggExpressions,
+ const std::string& rootProjection,
+ const std::string& uniqueIdPrefix) {
+ ExpressionAlgebrizerContext ctx(
+ false /*assertExprSort*/, true /*assertPathSort*/, rootProjection, uniqueIdPrefix);
+ ABTMatchExpressionVisitor visitor(ctx, allowAggExpressions);
+ MatchExpressionWalker walker(nullptr /*preVisitor*/, nullptr /*inVisitor*/, &visitor);
+ tree_walker::walk<true, MatchExpression>(expr, &walker);
+ return ctx.pop();
+}
+
+} // namespace mongo::optimizer
diff --git a/src/mongo/db/pipeline/abt/match_expression_visitor.h b/src/mongo/db/pipeline/abt/match_expression_visitor.h
new file mode 100644
index 00000000000..d971ffd5ec6
--- /dev/null
+++ b/src/mongo/db/pipeline/abt/match_expression_visitor.h
@@ -0,0 +1,45 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include "mongo/db/matcher/expression.h"
+#include "mongo/db/query/optimizer/node.h"
+
+namespace mongo::optimizer {
+
+/**
+ * Returns a path (an ABT of path sort) encoding the match expression. If 'allowAggExpressions'
+ * is false, $expr sub-expressions are rejected (e.g. when translating a partial index filter).
+ */
+ABT generateMatchExpression(const MatchExpression* expr,
+ bool allowAggExpressions,
+ const std::string& rootProjection,
+ const std::string& uniqueIdPrefix);
+
+} // namespace mongo::optimizer
diff --git a/src/mongo/db/pipeline/abt/pipeline_test.cpp b/src/mongo/db/pipeline/abt/pipeline_test.cpp
new file mode 100644
index 00000000000..36dc2235db8
--- /dev/null
+++ b/src/mongo/db/pipeline/abt/pipeline_test.cpp
@@ -0,0 +1,2508 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include <boost/intrusive_ptr.hpp>
+
+#include "mongo/db/pipeline/aggregate_command_gen.h"
+#include "mongo/db/pipeline/pipeline.h"
+#include "mongo/db/query/optimizer/explain.h"
+#include "mongo/db/query/optimizer/opt_phase_manager.h"
+#include "mongo/db/query/optimizer/utils/unit_test_utils.h"
+#include "mongo/unittest/unittest.h"
+
+namespace mongo {
+namespace {
+
+using namespace optimizer;
+
+TEST(ABTTranslate, SortLimitSkip) {
+ ABT translated = translatePipeline(
+ "[{$limit: 5}, "
+ "{$skip: 3}, "
+ "{$sort: {a: 1, b: -1}}]");
+
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | scan_0\n"
+ "| RefBlock: \n"
+ "| Variable [scan_0]\n"
+ "Collation []\n"
+ "| | collation: \n"
+ "| | sort_0: Ascending\n"
+ "| | sort_1: Descending\n"
+ "| RefBlock: \n"
+ "| Variable [sort_0]\n"
+ "| Variable [sort_1]\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [sort_1]\n"
+ "| EvalPath []\n"
+ "| | Variable [scan_0]\n"
+ "| PathGet [b]\n"
+ "| PathIdentity []\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [sort_0]\n"
+ "| EvalPath []\n"
+ "| | Variable [scan_0]\n"
+ "| PathGet [a]\n"
+ "| PathIdentity []\n"
+ "LimitSkip []\n"
+ "| limitSkip:\n"
+ "| limit: (none)\n"
+ "| skip: 3\n"
+ "LimitSkip []\n"
+ "| limitSkip:\n"
+ "| limit: 5\n"
+ "| skip: 0\n"
+ "Scan [collection]\n"
+ " BindBlock:\n"
+ " [scan_0]\n"
+ " Source []\n",
+ translated);
+}
+
+TEST(ABTTranslate, ProjectRetain) {
+ PrefixId prefixId;
+ std::string scanDefName = "collection";
+ Metadata metadata = {{{scanDefName, ScanDefinition{{}, {}}}}};
+ ABT translated = translatePipeline(
+ metadata, "[{$project: {a: 1, b: 1}}, {$match: {a: 2}}]", scanDefName, prefixId);
+
+ OptPhaseManager phaseManager({OptPhaseManager::OptPhase::ConstEvalPre,
+ OptPhaseManager::OptPhase::PathFuse,
+ OptPhaseManager::OptPhase::MemoSubstitutionPhase,
+ OptPhaseManager::OptPhase::MemoExplorationPhase,
+ OptPhaseManager::OptPhase::MemoImplementationPhase},
+ prefixId,
+ metadata,
+ DebugInfo::kDefaultForTests);
+
+ ABT optimized = translated;
+ ASSERT_TRUE(phaseManager.optimize(optimized));
+
+    // Observe that the Filter has been reordered below the Eval node.
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | combinedProjection_0\n"
+ "| RefBlock: \n"
+ "| Variable [combinedProjection_0]\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [combinedProjection_0]\n"
+ "| EvalPath []\n"
+ "| | Const [{}]\n"
+ "| PathComposeM []\n"
+ "| | PathField [b]\n"
+ "| | PathConstant []\n"
+ "| | Variable [fieldProj_2]\n"
+ "| PathComposeM []\n"
+ "| | PathField [a]\n"
+ "| | PathConstant []\n"
+ "| | Variable [fieldProj_1]\n"
+ "| PathField [_id]\n"
+ "| PathConstant []\n"
+ "| Variable [fieldProj_0]\n"
+ "Filter []\n"
+ "| EvalFilter []\n"
+ "| | Variable [fieldProj_1]\n"
+ "| PathTraverse []\n"
+ "| PathCompare [Eq]\n"
+ "| Const [2]\n"
+ "PhysicalScan [{'_id': fieldProj_0, 'a': fieldProj_1, 'b': fieldProj_2}, collection]\n"
+ " BindBlock:\n"
+ " [fieldProj_0]\n"
+ " Source []\n"
+ " [fieldProj_1]\n"
+ " Source []\n"
+ " [fieldProj_2]\n"
+ " Source []\n",
+ optimized);
+}
+
+TEST(ABTTranslate, ProjectRetain1) {
+ ABT translated = translatePipeline("[{$project: {a1: 1, a2: 1, a3: 1, a4: 1, a5: 1, a6: 1}}]");
+
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | combinedProjection_0\n"
+ "| RefBlock: \n"
+ "| Variable [combinedProjection_0]\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [combinedProjection_0]\n"
+ "| EvalPath []\n"
+ "| | Variable [scan_0]\n"
+ "| PathKeep [_id, a1, a2, a3, a4, a5, a6]\n"
+ "Scan [collection]\n"
+ " BindBlock:\n"
+ " [scan_0]\n"
+ " Source []\n",
+ translated);
+}
+
+TEST(ABTTranslate, AddFields) {
+ // Since '$z' is a single element, it will be considered a renamed path.
+ ABT translated = translatePipeline("[{$addFields: {a: '$z'}}]");
+
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | combinedProjection_0\n"
+ "| RefBlock: \n"
+ "| Variable [combinedProjection_0]\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [combinedProjection_0]\n"
+ "| EvalPath []\n"
+ "| | Variable [scan_0]\n"
+ "| PathComposeM []\n"
+ "| | PathDefault []\n"
+ "| | Const [{}]\n"
+ "| PathField [a]\n"
+ "| PathConstant []\n"
+ "| Variable [projRenamedPath_0]\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [projRenamedPath_0]\n"
+ "| EvalPath []\n"
+ "| | Variable [scan_0]\n"
+ "| PathGet [z]\n"
+ "| PathIdentity []\n"
+ "Scan [collection]\n"
+ " BindBlock:\n"
+ " [scan_0]\n"
+ " Source []\n",
+ translated);
+}
+
+TEST(ABTTranslate, ProjectRenames) {
+ // Since '$c' is a single element, it will be considered a renamed path.
+ ABT translated = translatePipeline("[{$project: {'a.b': '$c'}}]");
+
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | combinedProjection_0\n"
+ "| RefBlock: \n"
+ "| Variable [combinedProjection_0]\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [combinedProjection_0]\n"
+ "| EvalPath []\n"
+ "| | Variable [scan_0]\n"
+ "| PathComposeM []\n"
+ "| | PathField [a]\n"
+ "| | PathTraverse []\n"
+ "| | PathComposeM []\n"
+ "| | | PathDefault []\n"
+ "| | | Const [{}]\n"
+ "| | PathComposeM []\n"
+ "| | | PathField [b]\n"
+ "| | | PathConstant []\n"
+ "| | | Variable [projRenamedPath_0]\n"
+ "| | PathKeep [b]\n"
+ "| PathKeep [_id, a]\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [projRenamedPath_0]\n"
+ "| EvalPath []\n"
+ "| | Variable [scan_0]\n"
+ "| PathGet [c]\n"
+ "| PathIdentity []\n"
+ "Scan [collection]\n"
+ " BindBlock:\n"
+ " [scan_0]\n"
+ " Source []\n",
+ translated);
+}
+
+TEST(ABTTranslate, ProjectPaths) {
+ ABT translated = translatePipeline("[{$project: {'a.b.c': '$x.y.z'}}]");
+
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | combinedProjection_0\n"
+ "| RefBlock: \n"
+ "| Variable [combinedProjection_0]\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [combinedProjection_0]\n"
+ "| EvalPath []\n"
+ "| | Variable [scan_0]\n"
+ "| PathComposeM []\n"
+ "| | PathField [a]\n"
+ "| | PathTraverse []\n"
+ "| | PathComposeM []\n"
+ "| | | PathField [b]\n"
+ "| | | PathTraverse []\n"
+ "| | | PathComposeM []\n"
+ "| | | | PathDefault []\n"
+ "| | | | Const [{}]\n"
+ "| | | PathComposeM []\n"
+ "| | | | PathField [c]\n"
+ "| | | | PathConstant []\n"
+ "| | | | Variable [projGetPath_0]\n"
+ "| | | PathKeep [c]\n"
+ "| | PathKeep [b]\n"
+ "| PathKeep [_id, a]\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [projGetPath_0]\n"
+ "| EvalPath []\n"
+ "| | Variable [scan_0]\n"
+ "| PathGet [x]\n"
+ "| PathTraverse []\n"
+ "| PathGet [y]\n"
+ "| PathTraverse []\n"
+ "| PathGet [z]\n"
+ "| PathIdentity []\n"
+ "Scan [collection]\n"
+ " BindBlock:\n"
+ " [scan_0]\n"
+ " Source []\n",
+ translated);
+}
+
+TEST(ABTTranslate, ProjectPaths1) {
+ ABT translated = translatePipeline("[{$project: {'a.b':1, 'a.c':1, 'b':1}}]");
+
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | combinedProjection_0\n"
+ "| RefBlock: \n"
+ "| Variable [combinedProjection_0]\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [combinedProjection_0]\n"
+ "| EvalPath []\n"
+ "| | Variable [scan_0]\n"
+ "| PathComposeM []\n"
+ "| | PathField [a]\n"
+ "| | PathTraverse []\n"
+ "| | PathComposeM []\n"
+ "| | | PathKeep [b, c]\n"
+ "| | PathObj []\n"
+ "| PathKeep [_id, a, b]\n"
+ "Scan [collection]\n"
+ " BindBlock:\n"
+ " [scan_0]\n"
+ " Source []\n",
+ translated);
+}
+
+TEST(ABTTranslate, ProjectInclusion) {
+ ABT translated = translatePipeline("[{$project: {a: {$add: ['$c.d', 2]}, b: 1}}]");
+
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | combinedProjection_0\n"
+ "| RefBlock: \n"
+ "| Variable [combinedProjection_0]\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [combinedProjection_0]\n"
+ "| EvalPath []\n"
+ "| | Variable [scan_0]\n"
+ "| PathComposeM []\n"
+ "| | PathDefault []\n"
+ "| | Const [{}]\n"
+ "| PathComposeM []\n"
+ "| | PathField [a]\n"
+ "| | PathConstant []\n"
+ "| | Variable [projGetPath_0]\n"
+ "| PathKeep [_id, a, b]\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [projGetPath_0]\n"
+ "| BinaryOp [Add]\n"
+ "| | EvalPath []\n"
+ "| | | Variable [scan_0]\n"
+ "| | PathGet [c]\n"
+ "| | PathTraverse []\n"
+ "| | PathGet [d]\n"
+ "| | PathIdentity []\n"
+ "| Const [2]\n"
+ "Scan [collection]\n"
+ " BindBlock:\n"
+ " [scan_0]\n"
+ " Source []\n",
+ translated);
+}
+
+TEST(ABTTranslate, ProjectExclusion) {
+ ABT translated = translatePipeline("[{$project: {a: 0, b: 0}}]");
+
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | combinedProjection_0\n"
+ "| RefBlock: \n"
+ "| Variable [combinedProjection_0]\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [combinedProjection_0]\n"
+ "| EvalPath []\n"
+ "| | Variable [scan_0]\n"
+ "| PathDrop [a, b]\n"
+ "Scan [collection]\n"
+ " BindBlock:\n"
+ " [scan_0]\n"
+ " Source []\n",
+ translated);
+}
+
+TEST(ABTTranslate, ProjectReplaceRoot) {
+ ABT translated = translatePipeline("[{$replaceRoot: {newRoot: '$a'}}]");
+
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | newRoot_0\n"
+ "| RefBlock: \n"
+ "| Variable [newRoot_0]\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [newRoot_0]\n"
+ "| EvalPath []\n"
+ "| | Variable [scan_0]\n"
+ "| PathGet [a]\n"
+ "| PathIdentity []\n"
+ "Scan [collection]\n"
+ " BindBlock:\n"
+ " [scan_0]\n"
+ " Source []\n",
+ translated);
+}
+
+TEST(ABTTranslate, MatchBasic) {
+ PrefixId prefixId;
+ std::string scanDefName = "collection";
+
+ ABT translated = translatePipeline("[{$match: {a: 1, b: 2}}]");
+
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | scan_0\n"
+ "| RefBlock: \n"
+ "| Variable [scan_0]\n"
+ "Filter []\n"
+ "| EvalFilter []\n"
+ "| | Variable [scan_0]\n"
+ "| PathGet [b]\n"
+ "| PathTraverse []\n"
+ "| PathCompare [Eq]\n"
+ "| Const [2]\n"
+ "Filter []\n"
+ "| EvalFilter []\n"
+ "| | Variable [scan_0]\n"
+ "| PathGet [a]\n"
+ "| PathTraverse []\n"
+ "| PathCompare [Eq]\n"
+ "| Const [1]\n"
+ "Scan [collection]\n"
+ " BindBlock:\n"
+ " [scan_0]\n"
+ " Source []\n",
+ translated);
+
+ OptPhaseManager phaseManager({OptPhaseManager::OptPhase::MemoSubstitutionPhase,
+ OptPhaseManager::OptPhase::MemoExplorationPhase,
+ OptPhaseManager::OptPhase::MemoImplementationPhase},
+ prefixId,
+ {{{scanDefName, ScanDefinition{{}, {}}}}},
+ DebugInfo::kDefaultForTests);
+
+ ABT optimized = translated;
+ ASSERT_TRUE(phaseManager.optimize(optimized));
+
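+    // After optimization the PhysicalScan binds fields 'a' and 'b' to temporary projections, so
+    // the Filter nodes test the extracted values directly instead of applying PathGet to the
+    // document.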
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | scan_0\n"
+ "| RefBlock: \n"
+ "| Variable [scan_0]\n"
+ "Filter []\n"
+ "| EvalFilter []\n"
+ "| | Variable [evalTemp_1]\n"
+ "| PathTraverse []\n"
+ "| PathCompare [Eq]\n"
+ "| Const [2]\n"
+ "Filter []\n"
+ "| EvalFilter []\n"
+ "| | Variable [evalTemp_0]\n"
+ "| PathTraverse []\n"
+ "| PathCompare [Eq]\n"
+ "| Const [1]\n"
+ "PhysicalScan [{'<root>': scan_0, 'a': evalTemp_0, 'b': evalTemp_1}, collection]\n"
+ " BindBlock:\n"
+ " [evalTemp_0]\n"
+ " Source []\n"
+ " [evalTemp_1]\n"
+ " Source []\n"
+ " [scan_0]\n"
+ " Source []\n",
+ optimized);
+}
+
+TEST(ABTTranslate, MatchPath) {
+ ABT translated = translatePipeline("[{$match: {$expr: {$eq: ['$a.b', 1]}}}]");
+
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | scan_0\n"
+ "| RefBlock: \n"
+ "| Variable [scan_0]\n"
+ "Filter []\n"
+ "| EvalFilter []\n"
+ "| | Variable [scan_0]\n"
+ "| PathConstant []\n"
+ "| Let [matchExpression_0_coerceToBool_0]\n"
+ "| | BinaryOp [And]\n"
+ "| | | BinaryOp [And]\n"
+ "| | | | BinaryOp [And]\n"
+ "| | | | | FunctionCall [exists]\n"
+ "| | | | | BinaryOp [Eq]\n"
+ "| | | | | | Const [1]\n"
+ "| | | | | EvalPath []\n"
+ "| | | | | | Variable [scan_0]\n"
+ "| | | | | PathGet [a]\n"
+ "| | | | | PathTraverse []\n"
+ "| | | | | PathGet [b]\n"
+ "| | | | | PathIdentity []\n"
+ "| | | | UnaryOp [Not]\n"
+ "| | | | FunctionCall [isNull]\n"
+ "| | | | BinaryOp [Eq]\n"
+ "| | | | | Const [1]\n"
+ "| | | | EvalPath []\n"
+ "| | | | | Variable [scan_0]\n"
+ "| | | | PathGet [a]\n"
+ "| | | | PathTraverse []\n"
+ "| | | | PathGet [b]\n"
+ "| | | | PathIdentity []\n"
+ "| | | BinaryOp [Neq]\n"
+ "| | | | Const [0]\n"
+ "| | | BinaryOp [Cmp3w]\n"
+ "| | | | Const [false]\n"
+ "| | | Variable [matchExpression_0_coerceToBool_0]\n"
+ "| | BinaryOp [Neq]\n"
+ "| | | Const [0]\n"
+ "| | BinaryOp [Cmp3w]\n"
+ "| | | Const [0]\n"
+ "| | Variable [matchExpression_0_coerceToBool_0]\n"
+ "| BinaryOp [Eq]\n"
+ "| | Const [1]\n"
+ "| EvalPath []\n"
+ "| | Variable [scan_0]\n"
+ "| PathGet [a]\n"
+ "| PathTraverse []\n"
+ "| PathGet [b]\n"
+ "| PathIdentity []\n"
+ "Scan [collection]\n"
+ " BindBlock:\n"
+ " [scan_0]\n"
+ " Source []\n",
+ translated);
+}
+
+TEST(ABTTranslate, ElemMatchPath) {
+ ABT translated = translatePipeline(
+ "[{$project: {a: {$literal: [1, 2, 3, 4]}}}, {$match: {a: {$elemMatch: {$gte: 2, $lte: "
+ "3}}}}]");
+
+ // Observe type bracketing in the filter.
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | combinedProjection_0\n"
+ "| RefBlock: \n"
+ "| Variable [combinedProjection_0]\n"
+ "Filter []\n"
+ "| EvalFilter []\n"
+ "| | Variable [combinedProjection_0]\n"
+ "| PathGet [a]\n"
+ "| PathComposeM []\n"
+ "| | PathArr []\n"
+ "| PathTraverse []\n"
+ "| PathComposeM []\n"
+ "| | PathComposeM []\n"
+ "| | | PathCompare [Lt]\n"
+ "| | | Const [\"\"]\n"
+ "| | PathCompare [Gte]\n"
+ "| | Const [2]\n"
+ "| PathComposeM []\n"
+ "| | PathCompare [Gte]\n"
+ "| | Const [nan]\n"
+ "| PathCompare [Lte]\n"
+ "| Const [3]\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [combinedProjection_0]\n"
+ "| EvalPath []\n"
+ "| | Variable [scan_0]\n"
+ "| PathComposeM []\n"
+ "| | PathDefault []\n"
+ "| | Const [{}]\n"
+ "| PathComposeM []\n"
+ "| | PathField [a]\n"
+ "| | PathConstant []\n"
+ "| | Variable [projGetPath_0]\n"
+ "| PathKeep [_id, a]\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [projGetPath_0]\n"
+ "| Const [[1, 2, 3, 4]]\n"
+ "Scan [collection]\n"
+ " BindBlock:\n"
+ " [scan_0]\n"
+ " Source []\n",
+ translated);
+}
+
+TEST(ABTTranslate, MatchProject) {
+ ABT translated = translatePipeline(
+ "[{$project: {s: {$add: ['$a', '$b']}, c: 1}}, "
+ "{$match: {$or: [{c: 2}, {s: {$gte: 10}}]}}]");
+
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | combinedProjection_0\n"
+ "| RefBlock: \n"
+ "| Variable [combinedProjection_0]\n"
+ "Filter []\n"
+ "| EvalFilter []\n"
+ "| | Variable [combinedProjection_0]\n"
+ "| PathComposeA []\n"
+ "| | PathGet [s]\n"
+ "| | PathTraverse []\n"
+ "| | PathComposeM []\n"
+ "| | | PathCompare [Lt]\n"
+ "| | | Const [\"\"]\n"
+ "| | PathCompare [Gte]\n"
+ "| | Const [10]\n"
+ "| PathGet [c]\n"
+ "| PathTraverse []\n"
+ "| PathCompare [Eq]\n"
+ "| Const [2]\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [combinedProjection_0]\n"
+ "| EvalPath []\n"
+ "| | Variable [scan_0]\n"
+ "| PathComposeM []\n"
+ "| | PathDefault []\n"
+ "| | Const [{}]\n"
+ "| PathComposeM []\n"
+ "| | PathField [s]\n"
+ "| | PathConstant []\n"
+ "| | Variable [projGetPath_0]\n"
+ "| PathKeep [_id, c, s]\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [projGetPath_0]\n"
+ "| BinaryOp [Add]\n"
+ "| | EvalPath []\n"
+ "| | | Variable [scan_0]\n"
+ "| | PathGet [a]\n"
+ "| | PathIdentity []\n"
+ "| EvalPath []\n"
+ "| | Variable [scan_0]\n"
+ "| PathGet [b]\n"
+ "| PathIdentity []\n"
+ "Scan [collection]\n"
+ " BindBlock:\n"
+ " [scan_0]\n"
+ " Source []\n",
+ translated);
+}
+
+TEST(ABTTranslate, ProjectComplex) {
+ ABT translated = translatePipeline("[{$project: {'a1.b.c':1, 'a.b.c.d.e':'str'}}]");
+
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | combinedProjection_0\n"
+ "| RefBlock: \n"
+ "| Variable [combinedProjection_0]\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [combinedProjection_0]\n"
+ "| EvalPath []\n"
+ "| | Variable [scan_0]\n"
+ "| PathComposeM []\n"
+ "| | PathField [a1]\n"
+ "| | PathTraverse []\n"
+ "| | PathComposeM []\n"
+ "| | | PathField [b]\n"
+ "| | | PathTraverse []\n"
+ "| | | PathComposeM []\n"
+ "| | | | PathKeep [c]\n"
+ "| | | PathObj []\n"
+ "| | PathComposeM []\n"
+ "| | | PathKeep [b]\n"
+ "| | PathObj []\n"
+ "| PathComposeM []\n"
+ "| | PathField [a]\n"
+ "| | PathTraverse []\n"
+ "| | PathComposeM []\n"
+ "| | | PathField [b]\n"
+ "| | | PathTraverse []\n"
+ "| | | PathComposeM []\n"
+ "| | | | PathField [c]\n"
+ "| | | | PathTraverse []\n"
+ "| | | | PathComposeM []\n"
+ "| | | | | PathField [d]\n"
+ "| | | | | PathTraverse []\n"
+ "| | | | | PathComposeM []\n"
+ "| | | | | | PathDefault []\n"
+ "| | | | | | Const [{}]\n"
+ "| | | | | PathComposeM []\n"
+ "| | | | | | PathField [e]\n"
+ "| | | | | | PathConstant []\n"
+ "| | | | | | Variable [projGetPath_0]\n"
+ "| | | | | PathKeep [e]\n"
+ "| | | | PathKeep [d]\n"
+ "| | | PathKeep [c]\n"
+ "| | PathKeep [b]\n"
+ "| PathKeep [_id, a, a1]\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [projGetPath_0]\n"
+ "| Const [\"str\"]\n"
+ "Scan [collection]\n"
+ " BindBlock:\n"
+ " [scan_0]\n"
+ " Source []\n",
+ translated);
+}
+
+TEST(ABTTranslate, ExprFilter) {
+ ABT translated = translatePipeline(
+ "[{$project: {a: {$filter: {input: [1, 2, 'str', {a: 2.0, b:'s'}, 3, 4], as: 'num', cond: "
+ "{$and: [{$gte: ['$$num', 2]}, {$lte: ['$$num', 3]}]}}}}}]");
+
+ PrefixId prefixId;
+ std::string scanDefName = "collection";
+ OptPhaseManager phaseManager({OptPhaseManager::OptPhase::ConstEvalPre},
+ prefixId,
+ {{{scanDefName, ScanDefinition{{}, {}}}}},
+ DebugInfo::kDefaultForTests);
+
+ ABT optimized = translated;
+ ASSERT_TRUE(phaseManager.optimize(optimized));
+
+ // Make sure we have a single array constant for (1, 2, 'str', ...).
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | combinedProjection_0\n"
+ "| RefBlock: \n"
+ "| Variable [combinedProjection_0]\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [combinedProjection_0]\n"
+ "| EvalPath []\n"
+ "| | Variable [scan_0]\n"
+ "| PathComposeM []\n"
+ "| | PathDefault []\n"
+ "| | Const [{}]\n"
+ "| PathComposeM []\n"
+ "| | PathField [a]\n"
+ "| | PathConstant []\n"
+ "| | EvalPath []\n"
+ "| | | Const [[1, 2, \"str\", {\"a\" : 2, \"b\" : \"s\"}, 3, 4]]\n"
+ "| | PathTraverse []\n"
+ "| | PathLambda []\n"
+ "| | LambdaAbstraction [projGetPath_0_var_1]\n"
+ "| | If []\n"
+ "| | | | Const [Nothing]\n"
+ "| | | Variable [projGetPath_0_var_1]\n"
+ "| | Let [projGetPath_0_coerceToBool_2]\n"
+ "| | | BinaryOp [And]\n"
+ "| | | | BinaryOp [And]\n"
+ "| | | | | BinaryOp [And]\n"
+ "| | | | | | FunctionCall [exists]\n"
+ "| | | | | | BinaryOp [And]\n"
+ "| | | | | | | Let [projGetPath_0_coerceToBool_1]\n"
+ "| | | | | | | | BinaryOp [And]\n"
+ "| | | | | | | | | BinaryOp [And]\n"
+ "| | | | | | | | | | BinaryOp [And]\n"
+ "| | | | | | | | | | | FunctionCall [exists]\n"
+ "| | | | | | | | | | | BinaryOp [Gte]\n"
+ "| | | | | | | | | | | | Const [2]\n"
+ "| | | | | | | | | | | Variable [projGetPath_0_var_1]\n"
+ "| | | | | | | | | | UnaryOp [Not]\n"
+ "| | | | | | | | | | FunctionCall [isNull]\n"
+ "| | | | | | | | | | BinaryOp [Gte]\n"
+ "| | | | | | | | | | | Const [2]\n"
+ "| | | | | | | | | | Variable [projGetPath_0_var_1]\n"
+ "| | | | | | | | | BinaryOp [Neq]\n"
+ "| | | | | | | | | | Const [0]\n"
+ "| | | | | | | | | BinaryOp [Cmp3w]\n"
+ "| | | | | | | | | | Const [false]\n"
+ "| | | | | | | | | Variable [projGetPath_0_coerceToBool_1]\n"
+ "| | | | | | | | BinaryOp [Neq]\n"
+ "| | | | | | | | | Const [0]\n"
+ "| | | | | | | | BinaryOp [Cmp3w]\n"
+ "| | | | | | | | | Const [0]\n"
+ "| | | | | | | | Variable [projGetPath_0_coerceToBool_1]\n"
+ "| | | | | | | BinaryOp [Gte]\n"
+ "| | | | | | | | Const [2]\n"
+ "| | | | | | | Variable [projGetPath_0_var_1]\n"
+ "| | | | | | Let [projGetPath_0_coerceToBool_0]\n"
+ "| | | | | | | BinaryOp [And]\n"
+ "| | | | | | | | BinaryOp [And]\n"
+ "| | | | | | | | | BinaryOp [And]\n"
+ "| | | | | | | | | | FunctionCall [exists]\n"
+ "| | | | | | | | | | BinaryOp [Lte]\n"
+ "| | | | | | | | | | | Const [3]\n"
+ "| | | | | | | | | | Variable [projGetPath_0_var_1]\n"
+ "| | | | | | | | | UnaryOp [Not]\n"
+ "| | | | | | | | | FunctionCall [isNull]\n"
+ "| | | | | | | | | BinaryOp [Lte]\n"
+ "| | | | | | | | | | Const [3]\n"
+ "| | | | | | | | | Variable [projGetPath_0_var_1]\n"
+ "| | | | | | | | BinaryOp [Neq]\n"
+ "| | | | | | | | | Const [0]\n"
+ "| | | | | | | | BinaryOp [Cmp3w]\n"
+ "| | | | | | | | | Const [false]\n"
+ "| | | | | | | | Variable [projGetPath_0_coerceToBool_0]\n"
+ "| | | | | | | BinaryOp [Neq]\n"
+ "| | | | | | | | Const [0]\n"
+ "| | | | | | | BinaryOp [Cmp3w]\n"
+ "| | | | | | | | Const [0]\n"
+ "| | | | | | | Variable [projGetPath_0_coerceToBool_0]\n"
+ "| | | | | | BinaryOp [Lte]\n"
+ "| | | | | | | Const [3]\n"
+ "| | | | | | Variable [projGetPath_0_var_1]\n"
+ "| | | | | UnaryOp [Not]\n"
+ "| | | | | FunctionCall [isNull]\n"
+ "| | | | | BinaryOp [And]\n"
+ "| | | | | | Let [projGetPath_0_coerceToBool_1]\n"
+ "| | | | | | | BinaryOp [And]\n"
+ "| | | | | | | | BinaryOp [And]\n"
+ "| | | | | | | | | BinaryOp [And]\n"
+ "| | | | | | | | | | FunctionCall [exists]\n"
+ "| | | | | | | | | | BinaryOp [Gte]\n"
+ "| | | | | | | | | | | Const [2]\n"
+ "| | | | | | | | | | Variable [projGetPath_0_var_1]\n"
+ "| | | | | | | | | UnaryOp [Not]\n"
+ "| | | | | | | | | FunctionCall [isNull]\n"
+ "| | | | | | | | | BinaryOp [Gte]\n"
+ "| | | | | | | | | | Const [2]\n"
+ "| | | | | | | | | Variable [projGetPath_0_var_1]\n"
+ "| | | | | | | | BinaryOp [Neq]\n"
+ "| | | | | | | | | Const [0]\n"
+ "| | | | | | | | BinaryOp [Cmp3w]\n"
+ "| | | | | | | | | Const [false]\n"
+ "| | | | | | | | Variable [projGetPath_0_coerceToBool_1]\n"
+ "| | | | | | | BinaryOp [Neq]\n"
+ "| | | | | | | | Const [0]\n"
+ "| | | | | | | BinaryOp [Cmp3w]\n"
+ "| | | | | | | | Const [0]\n"
+ "| | | | | | | Variable [projGetPath_0_coerceToBool_1]\n"
+ "| | | | | | BinaryOp [Gte]\n"
+ "| | | | | | | Const [2]\n"
+ "| | | | | | Variable [projGetPath_0_var_1]\n"
+ "| | | | | Let [projGetPath_0_coerceToBool_0]\n"
+ "| | | | | | BinaryOp [And]\n"
+ "| | | | | | | BinaryOp [And]\n"
+ "| | | | | | | | BinaryOp [And]\n"
+ "| | | | | | | | | FunctionCall [exists]\n"
+ "| | | | | | | | | BinaryOp [Lte]\n"
+ "| | | | | | | | | | Const [3]\n"
+ "| | | | | | | | | Variable [projGetPath_0_var_1]\n"
+ "| | | | | | | | UnaryOp [Not]\n"
+ "| | | | | | | | FunctionCall [isNull]\n"
+ "| | | | | | | | BinaryOp [Lte]\n"
+ "| | | | | | | | | Const [3]\n"
+ "| | | | | | | | Variable [projGetPath_0_var_1]\n"
+ "| | | | | | | BinaryOp [Neq]\n"
+ "| | | | | | | | Const [0]\n"
+ "| | | | | | | BinaryOp [Cmp3w]\n"
+ "| | | | | | | | Const [false]\n"
+ "| | | | | | | Variable [projGetPath_0_coerceToBool_0]\n"
+ "| | | | | | BinaryOp [Neq]\n"
+ "| | | | | | | Const [0]\n"
+ "| | | | | | BinaryOp [Cmp3w]\n"
+ "| | | | | | | Const [0]\n"
+ "| | | | | | Variable [projGetPath_0_coerceToBool_0]\n"
+ "| | | | | BinaryOp [Lte]\n"
+ "| | | | | | Const [3]\n"
+ "| | | | | Variable [projGetPath_0_var_1]\n"
+ "| | | | BinaryOp [Neq]\n"
+ "| | | | | Const [0]\n"
+ "| | | | BinaryOp [Cmp3w]\n"
+ "| | | | | Const [false]\n"
+ "| | | | Variable [projGetPath_0_coerceToBool_2]\n"
+ "| | | BinaryOp [Neq]\n"
+ "| | | | Const [0]\n"
+ "| | | BinaryOp [Cmp3w]\n"
+ "| | | | Const [0]\n"
+ "| | | Variable [projGetPath_0_coerceToBool_2]\n"
+ "| | BinaryOp [And]\n"
+ "| | | Let [projGetPath_0_coerceToBool_1]\n"
+ "| | | | BinaryOp [And]\n"
+ "| | | | | BinaryOp [And]\n"
+ "| | | | | | BinaryOp [And]\n"
+ "| | | | | | | FunctionCall [exists]\n"
+ "| | | | | | | BinaryOp [Gte]\n"
+ "| | | | | | | | Const [2]\n"
+ "| | | | | | | Variable [projGetPath_0_var_1]\n"
+ "| | | | | | UnaryOp [Not]\n"
+ "| | | | | | FunctionCall [isNull]\n"
+ "| | | | | | BinaryOp [Gte]\n"
+ "| | | | | | | Const [2]\n"
+ "| | | | | | Variable [projGetPath_0_var_1]\n"
+ "| | | | | BinaryOp [Neq]\n"
+ "| | | | | | Const [0]\n"
+ "| | | | | BinaryOp [Cmp3w]\n"
+ "| | | | | | Const [false]\n"
+ "| | | | | Variable [projGetPath_0_coerceToBool_1]\n"
+ "| | | | BinaryOp [Neq]\n"
+ "| | | | | Const [0]\n"
+ "| | | | BinaryOp [Cmp3w]\n"
+ "| | | | | Const [0]\n"
+ "| | | | Variable [projGetPath_0_coerceToBool_1]\n"
+ "| | | BinaryOp [Gte]\n"
+ "| | | | Const [2]\n"
+ "| | | Variable [projGetPath_0_var_1]\n"
+ "| | Let [projGetPath_0_coerceToBool_0]\n"
+ "| | | BinaryOp [And]\n"
+ "| | | | BinaryOp [And]\n"
+ "| | | | | BinaryOp [And]\n"
+ "| | | | | | FunctionCall [exists]\n"
+ "| | | | | | BinaryOp [Lte]\n"
+ "| | | | | | | Const [3]\n"
+ "| | | | | | Variable [projGetPath_0_var_1]\n"
+ "| | | | | UnaryOp [Not]\n"
+ "| | | | | FunctionCall [isNull]\n"
+ "| | | | | BinaryOp [Lte]\n"
+ "| | | | | | Const [3]\n"
+ "| | | | | Variable [projGetPath_0_var_1]\n"
+ "| | | | BinaryOp [Neq]\n"
+ "| | | | | Const [0]\n"
+ "| | | | BinaryOp [Cmp3w]\n"
+ "| | | | | Const [false]\n"
+ "| | | | Variable [projGetPath_0_coerceToBool_0]\n"
+ "| | | BinaryOp [Neq]\n"
+ "| | | | Const [0]\n"
+ "| | | BinaryOp [Cmp3w]\n"
+ "| | | | Const [0]\n"
+ "| | | Variable [projGetPath_0_coerceToBool_0]\n"
+ "| | BinaryOp [Lte]\n"
+ "| | | Const [3]\n"
+ "| | Variable [projGetPath_0_var_1]\n"
+ "| PathKeep [_id, a]\n"
+ "Scan [collection]\n"
+ " BindBlock:\n"
+ " [scan_0]\n"
+ " Source []\n",
+ optimized);
+}
+
+TEST(ABTTranslate, GroupBasic) {
+ ABT translated =
+ translatePipeline("[{$group: {_id: '$a.b', s: {$sum: {$multiply: ['$b', '$c']}}}}]");
+
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | agg_project_0\n"
+ "| RefBlock: \n"
+ "| Variable [agg_project_0]\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [agg_project_0]\n"
+ "| EvalPath []\n"
+ "| | Const [{}]\n"
+ "| PathComposeM []\n"
+ "| | PathField [s]\n"
+ "| | PathConstant []\n"
+ "| | Variable [s_agg_0]\n"
+ "| PathField [_id]\n"
+ "| PathConstant []\n"
+ "| Variable [groupByProj_0]\n"
+ "GroupBy []\n"
+ "| | groupings: \n"
+ "| | RefBlock: \n"
+ "| | Variable [groupByProj_0]\n"
+ "| aggregations: \n"
+ "| [s_agg_0]\n"
+ "| FunctionCall [$sum]\n"
+ "| Variable [groupByInputProj_0]\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [groupByInputProj_0]\n"
+ "| BinaryOp [Mult]\n"
+ "| | EvalPath []\n"
+ "| | | Variable [scan_0]\n"
+ "| | PathGet [b]\n"
+ "| | PathIdentity []\n"
+ "| EvalPath []\n"
+ "| | Variable [scan_0]\n"
+ "| PathGet [c]\n"
+ "| PathIdentity []\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [groupByProj_0]\n"
+ "| EvalPath []\n"
+ "| | Variable [scan_0]\n"
+ "| PathGet [a]\n"
+ "| PathTraverse []\n"
+ "| PathGet [b]\n"
+ "| PathIdentity []\n"
+ "Scan [collection]\n"
+ " BindBlock:\n"
+ " [scan_0]\n"
+ " Source []\n",
+ translated);
+}
+
+TEST(ABTTranslate, GroupLocalGlobal) {
+ ABT translated = translatePipeline("[{$group: {_id: '$a', c: {$sum: '$b'}}}]");
+
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | agg_project_0\n"
+ "| RefBlock: \n"
+ "| Variable [agg_project_0]\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [agg_project_0]\n"
+ "| EvalPath []\n"
+ "| | Const [{}]\n"
+ "| PathComposeM []\n"
+ "| | PathField [c]\n"
+ "| | PathConstant []\n"
+ "| | Variable [c_agg_0]\n"
+ "| PathField [_id]\n"
+ "| PathConstant []\n"
+ "| Variable [groupByProj_0]\n"
+ "GroupBy []\n"
+ "| | groupings: \n"
+ "| | RefBlock: \n"
+ "| | Variable [groupByProj_0]\n"
+ "| aggregations: \n"
+ "| [c_agg_0]\n"
+ "| FunctionCall [$sum]\n"
+ "| Variable [groupByInputProj_0]\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [groupByInputProj_0]\n"
+ "| EvalPath []\n"
+ "| | Variable [scan_0]\n"
+ "| PathGet [b]\n"
+ "| PathIdentity []\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [groupByProj_0]\n"
+ "| EvalPath []\n"
+ "| | Variable [scan_0]\n"
+ "| PathGet [a]\n"
+ "| PathIdentity []\n"
+ "Scan [collection]\n"
+ " BindBlock:\n"
+ " [scan_0]\n"
+ " Source []\n",
+ translated);
+
+ PrefixId prefixId;
+ std::string scanDefName = "collection";
+ OptPhaseManager phaseManager(
+ {OptPhaseManager::OptPhase::MemoSubstitutionPhase,
+ OptPhaseManager::OptPhase::MemoExplorationPhase,
+ OptPhaseManager::OptPhase::MemoImplementationPhase},
+ prefixId,
+ {{{scanDefName, ScanDefinition{{}, {}, {DistributionType::UnknownPartitioning}}}},
+ 5 /*numberOfPartitions*/},
+ DebugInfo::kDefaultForTests);
+
+ ABT optimized = std::move(translated);
+ ASSERT_TRUE(phaseManager.optimize(optimized));
+
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | agg_project_0\n"
+ "| RefBlock: \n"
+ "| Variable [agg_project_0]\n"
+ "Exchange []\n"
+ "| | distribution: \n"
+ "| | type: Centralized\n"
+ "| RefBlock: \n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [agg_project_0]\n"
+ "| EvalPath []\n"
+ "| | Const [{}]\n"
+ "| PathComposeM []\n"
+ "| | PathField [c]\n"
+ "| | PathConstant []\n"
+ "| | Variable [c_agg_0]\n"
+ "| PathField [_id]\n"
+ "| PathConstant []\n"
+ "| Variable [groupByProj_0]\n"
+ "GroupBy [Global]\n"
+ "| | groupings: \n"
+ "| | RefBlock: \n"
+ "| | Variable [groupByProj_0]\n"
+ "| aggregations: \n"
+ "| [c_agg_0]\n"
+ "| FunctionCall [$sum]\n"
+ "| Variable [preagg_0]\n"
+ "Exchange []\n"
+ "| | distribution: \n"
+ "| | type: HashPartitioning\n"
+ "| | projections: \n"
+ "| | groupByProj_0\n"
+ "| RefBlock: \n"
+ "| Variable [groupByProj_0]\n"
+ "GroupBy [Local]\n"
+ "| | groupings: \n"
+ "| | RefBlock: \n"
+ "| | Variable [groupByProj_0]\n"
+ "| aggregations: \n"
+ "| [preagg_0]\n"
+ "| FunctionCall [$sum]\n"
+ "| Variable [groupByInputProj_0]\n"
+ "PhysicalScan [{'a': groupByProj_0, 'b': groupByInputProj_0}, collection, parallel]\n"
+ " BindBlock:\n"
+ " [groupByInputProj_0]\n"
+ " Source []\n"
+ " [groupByProj_0]\n"
+ " Source []\n",
+ optimized);
+}
+
+TEST(ABTTranslate, UnwindBasic) {
+ ABT translated = translatePipeline("[{$unwind: {path: '$a.b.c'}}]");
+
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | embedProj_0\n"
+ "| RefBlock: \n"
+ "| Variable [embedProj_0]\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [embedProj_0]\n"
+ "| EvalPath []\n"
+ "| | Variable [scan_0]\n"
+ "| PathField [a]\n"
+ "| PathTraverse []\n"
+ "| PathField [b]\n"
+ "| PathTraverse []\n"
+ "| PathField [c]\n"
+ "| PathConstant []\n"
+ "| Variable [unwoundProj_0]\n"
+ "Unwind []\n"
+ "| BindBlock:\n"
+ "| [unwoundPid_0]\n"
+ "| Source []\n"
+ "| [unwoundProj_0]\n"
+ "| Source []\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [unwoundProj_0]\n"
+ "| EvalPath []\n"
+ "| | Variable [scan_0]\n"
+ "| PathGet [a]\n"
+ "| PathGet [b]\n"
+ "| PathGet [c]\n"
+ "| PathIdentity []\n"
+ "Scan [collection]\n"
+ " BindBlock:\n"
+ " [scan_0]\n"
+ " Source []\n",
+ translated);
+}
+
+TEST(ABTTranslate, UnwindComplex) {
+ ABT translated = translatePipeline(
+ "[{$unwind: {path: '$a.b.c', includeArrayIndex: 'p1.pid', preserveNullAndEmptyArrays: "
+ "true}}]");
+
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | embedPidProj_0\n"
+ "| RefBlock: \n"
+ "| Variable [embedPidProj_0]\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [embedPidProj_0]\n"
+ "| EvalPath []\n"
+ "| | Variable [embedProj_0]\n"
+ "| PathField [p1]\n"
+ "| PathField [pid]\n"
+ "| PathConstant []\n"
+ "| If []\n"
+ "| | | Const [null]\n"
+ "| | Variable [unwoundPid_0]\n"
+ "| BinaryOp [Gte]\n"
+ "| | Const [0]\n"
+ "| Variable [unwoundPid_0]\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [embedProj_0]\n"
+ "| EvalPath []\n"
+ "| | Variable [scan_0]\n"
+ "| PathField [a]\n"
+ "| PathTraverse []\n"
+ "| PathField [b]\n"
+ "| PathTraverse []\n"
+ "| PathField [c]\n"
+ "| PathLambda []\n"
+ "| LambdaAbstraction [unwoundLambdaVarName_0]\n"
+ "| If []\n"
+ "| | | Variable [unwoundLambdaVarName_0]\n"
+ "| | Variable [unwoundProj_0]\n"
+ "| BinaryOp [Gte]\n"
+ "| | Const [0]\n"
+ "| Variable [unwoundPid_0]\n"
+ "Unwind [retainNonArrays]\n"
+ "| BindBlock:\n"
+ "| [unwoundPid_0]\n"
+ "| Source []\n"
+ "| [unwoundProj_0]\n"
+ "| Source []\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [unwoundProj_0]\n"
+ "| EvalPath []\n"
+ "| | Variable [scan_0]\n"
+ "| PathGet [a]\n"
+ "| PathGet [b]\n"
+ "| PathGet [c]\n"
+ "| PathIdentity []\n"
+ "Scan [collection]\n"
+ " BindBlock:\n"
+ " [scan_0]\n"
+ " Source []\n",
+ translated);
+}
+
+TEST(ABTTranslate, UnwindAndGroup) {
+ ABT translated = translatePipeline(
+ "[{$unwind:{path: '$a.b', preserveNullAndEmptyArrays: true}}, "
+ "{$group:{_id: '$a.b'}}]");
+
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | agg_project_0\n"
+ "| RefBlock: \n"
+ "| Variable [agg_project_0]\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [agg_project_0]\n"
+ "| EvalPath []\n"
+ "| | Const [{}]\n"
+ "| PathField [_id]\n"
+ "| PathConstant []\n"
+ "| Variable [groupByProj_0]\n"
+ "GroupBy []\n"
+ "| | groupings: \n"
+ "| | RefBlock: \n"
+ "| | Variable [groupByProj_0]\n"
+ "| aggregations: \n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [groupByProj_0]\n"
+ "| EvalPath []\n"
+ "| | Variable [embedProj_0]\n"
+ "| PathGet [a]\n"
+ "| PathTraverse []\n"
+ "| PathGet [b]\n"
+ "| PathIdentity []\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [embedProj_0]\n"
+ "| EvalPath []\n"
+ "| | Variable [scan_0]\n"
+ "| PathField [a]\n"
+ "| PathTraverse []\n"
+ "| PathField [b]\n"
+ "| PathLambda []\n"
+ "| LambdaAbstraction [unwoundLambdaVarName_0]\n"
+ "| If []\n"
+ "| | | Variable [unwoundLambdaVarName_0]\n"
+ "| | Variable [unwoundProj_0]\n"
+ "| BinaryOp [Gte]\n"
+ "| | Const [0]\n"
+ "| Variable [unwoundPid_0]\n"
+ "Unwind [retainNonArrays]\n"
+ "| BindBlock:\n"
+ "| [unwoundPid_0]\n"
+ "| Source []\n"
+ "| [unwoundProj_0]\n"
+ "| Source []\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [unwoundProj_0]\n"
+ "| EvalPath []\n"
+ "| | Variable [scan_0]\n"
+ "| PathGet [a]\n"
+ "| PathGet [b]\n"
+ "| PathIdentity []\n"
+ "Scan [collection]\n"
+ " BindBlock:\n"
+ " [scan_0]\n"
+ " Source []\n",
+ translated);
+}
+
+TEST(ABTTranslate, UnwindSort) {
+ ABT translated = translatePipeline("[{$unwind: '$x'}, {$sort: {'x': 1}}]");
+
+ PrefixId prefixId;
+ std::string scanDefName = "collection";
+ OptPhaseManager phaseManager(OptPhaseManager::getAllRewritesSet(),
+ prefixId,
+ {{{scanDefName, ScanDefinition{{}, {}}}}},
+ DebugInfo::kDefaultForTests);
+
+ ABT optimized = translated;
+ ASSERT_TRUE(phaseManager.optimize(optimized));
+
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | embedProj_0\n"
+ "| RefBlock: \n"
+ "| Variable [embedProj_0]\n"
+ "Collation []\n"
+ "| | collation: \n"
+ "| | sort_0: Ascending\n"
+ "| RefBlock: \n"
+ "| Variable [sort_0]\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [sort_0]\n"
+ "| Variable [unwoundProj_0]\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [embedProj_0]\n"
+ "| If []\n"
+ "| | | Variable [scan_0]\n"
+ "| | FunctionCall [setField]\n"
+ "| | | | Variable [unwoundProj_0]\n"
+ "| | | Const [\"x\"]\n"
+ "| | Variable [scan_0]\n"
+ "| BinaryOp [Or]\n"
+ "| | FunctionCall [isObject]\n"
+ "| | Variable [scan_0]\n"
+ "| FunctionCall [exists]\n"
+ "| Variable [unwoundProj_0]\n"
+ "Unwind []\n"
+ "| BindBlock:\n"
+ "| [unwoundPid_0]\n"
+ "| Source []\n"
+ "| [unwoundProj_0]\n"
+ "| Source []\n"
+ "PhysicalScan [{'<root>': scan_0, 'x': unwoundProj_0}, collection]\n"
+ " BindBlock:\n"
+ " [scan_0]\n"
+ " Source []\n"
+ " [unwoundProj_0]\n"
+ " Source []\n",
+ optimized);
+}
+
+TEST(ABTTranslate, MatchIndex) {
+ PrefixId prefixId;
+ std::string scanDefName = "collection";
+
+ Metadata metadata = {
+ {{scanDefName,
+ ScanDefinition{{}, {{"index1", makeIndexDefinition("a", CollationOp::Ascending)}}}}}};
+ ABT translated = translatePipeline(metadata, "[{$match: {'a': 10}}]", scanDefName, prefixId);
+
+ OptPhaseManager phaseManager({OptPhaseManager::OptPhase::MemoSubstitutionPhase,
+ OptPhaseManager::OptPhase::MemoExplorationPhase,
+ OptPhaseManager::OptPhase::MemoImplementationPhase},
+ prefixId,
+ metadata,
+ DebugInfo::kDefaultForTests);
+
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | scan_0\n"
+ "| RefBlock: \n"
+ "| Variable [scan_0]\n"
+ "Filter []\n"
+ "| EvalFilter []\n"
+ "| | Variable [scan_0]\n"
+ "| PathGet [a]\n"
+ "| PathTraverse []\n"
+ "| PathCompare [Eq]\n"
+ "| Const [10]\n"
+ "Scan [collection]\n"
+ " BindBlock:\n"
+ " [scan_0]\n"
+ " Source []\n",
+ translated);
+
+ ABT optimized = std::move(translated);
+ ASSERT_TRUE(phaseManager.optimize(optimized));
+
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | scan_0\n"
+ "| RefBlock: \n"
+ "| Variable [scan_0]\n"
+ "BinaryJoin [joinType: Inner, {rid_0}]\n"
+ "| | Const [true]\n"
+ "| LimitSkip []\n"
+ "| | limitSkip:\n"
+ "| | limit: 1\n"
+ "| | skip: 0\n"
+ "| Seek [ridProjection: rid_0, {'<root>': scan_0}, collection]\n"
+ "| | BindBlock:\n"
+ "| | [scan_0]\n"
+ "| | Source []\n"
+ "| RefBlock: \n"
+ "| Variable [rid_0]\n"
+ "IndexScan [{'<rid>': rid_0}, scanDefName: collection, indexDefName: index1, interval: "
+ "{[Const [10], Const [10]]}]\n"
+ " BindBlock:\n"
+ " [rid_0]\n"
+ " Source []\n",
+ optimized);
+}
+
+TEST(ABTTranslate, MatchIndexCovered) {
+ PrefixId prefixId;
+ std::string scanDefName = "collection";
+
+ Metadata metadata = {
+ {{scanDefName,
+ ScanDefinition{
+ {},
+ {{"index1",
+ IndexDefinition{{{{makeNonMultikeyIndexPath("a"), CollationOp::Ascending}}},
+ false /*multiKey*/}}}}}}};
+ ABT translated = translatePipeline(
+ metadata, "[{$project: {_id: 0, a: 1}}, {$match: {'a': 10}}]", scanDefName, prefixId);
+
+ OptPhaseManager phaseManager({OptPhaseManager::OptPhase::ConstEvalPre,
+ OptPhaseManager::OptPhase::PathFuse,
+ OptPhaseManager::OptPhase::MemoSubstitutionPhase,
+ OptPhaseManager::OptPhase::MemoExplorationPhase,
+ OptPhaseManager::OptPhase::MemoImplementationPhase},
+ prefixId,
+ metadata,
+ DebugInfo::kDefaultForTests);
+
+ ABT optimized = std::move(translated);
+ ASSERT_TRUE(phaseManager.optimize(optimized));
+
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | combinedProjection_0\n"
+ "| RefBlock: \n"
+ "| Variable [combinedProjection_0]\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [combinedProjection_0]\n"
+ "| EvalPath []\n"
+ "| | Const [{}]\n"
+ "| PathField [a]\n"
+ "| PathConstant []\n"
+ "| Variable [fieldProj_0]\n"
+ "IndexScan [{'<indexKey> 0': fieldProj_0}, scanDefName: collection, indexDefName: index1, "
+ "interval: {[Const [10], Const [10]]}]\n"
+ " BindBlock:\n"
+ " [fieldProj_0]\n"
+ " Source []\n",
+ optimized);
+}
+
+TEST(ABTTranslate, MatchIndexCovered1) {
+ PrefixId prefixId;
+ std::string scanDefName = "collection";
+
+ Metadata metadata = {
+ {{scanDefName,
+ ScanDefinition{
+ {},
+ {{"index1",
+ IndexDefinition{{{{makeNonMultikeyIndexPath("a"), CollationOp::Ascending}}},
+ false /*multiKey*/}}}}}}};
+ ABT translated = translatePipeline(
+ metadata, "[{$match: {'a': 10}}, {$project: {_id: 0, a: 1}}]", scanDefName, prefixId);
+
+ OptPhaseManager phaseManager({OptPhaseManager::OptPhase::ConstEvalPre,
+ OptPhaseManager::OptPhase::PathFuse,
+ OptPhaseManager::OptPhase::MemoSubstitutionPhase,
+ OptPhaseManager::OptPhase::MemoExplorationPhase,
+ OptPhaseManager::OptPhase::MemoImplementationPhase},
+ prefixId,
+ metadata,
+ DebugInfo::kDefaultForTests);
+
+ ABT optimized = std::move(translated);
+ ASSERT_TRUE(phaseManager.optimize(optimized));
+
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | combinedProjection_0\n"
+ "| RefBlock: \n"
+ "| Variable [combinedProjection_0]\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [combinedProjection_0]\n"
+ "| EvalPath []\n"
+ "| | Const [{}]\n"
+ "| PathField [a]\n"
+ "| PathConstant []\n"
+ "| Variable [fieldProj_0]\n"
+ "IndexScan [{'<indexKey> 0': fieldProj_0}, scanDefName: collection, indexDefName: index1, "
+ "interval: {[Const [10], Const [10]]}]\n"
+ " BindBlock:\n"
+ " [fieldProj_0]\n"
+ " Source []\n",
+ optimized);
+}
+
+TEST(ABTTranslate, MatchIndexCovered2) {
+ PrefixId prefixId;
+ std::string scanDefName = "collection";
+
+ Metadata metadata = {
+ {{scanDefName,
+ ScanDefinition{
+ {},
+ {{"index1",
+ IndexDefinition{{{{makeNonMultikeyIndexPath("a"), CollationOp::Ascending},
+ {makeNonMultikeyIndexPath("b"), CollationOp::Ascending}}},
+ false /*multiKey*/}}}}}}};
+ ABT translated = translatePipeline(metadata,
+ "[{$match: {'a': 10, 'b': 20}}, {$project: {_id: 0, a: 1}}]",
+ scanDefName,
+ prefixId);
+
+ OptPhaseManager phaseManager({OptPhaseManager::OptPhase::ConstEvalPre,
+ OptPhaseManager::OptPhase::PathFuse,
+ OptPhaseManager::OptPhase::MemoSubstitutionPhase,
+ OptPhaseManager::OptPhase::MemoExplorationPhase,
+ OptPhaseManager::OptPhase::MemoImplementationPhase},
+ prefixId,
+ metadata,
+ DebugInfo::kDefaultForTests);
+
+ ABT optimized = std::move(translated);
+ ASSERT_TRUE(phaseManager.optimize(optimized));
+
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | combinedProjection_0\n"
+ "| RefBlock: \n"
+ "| Variable [combinedProjection_0]\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [combinedProjection_0]\n"
+ "| EvalPath []\n"
+ "| | Const [{}]\n"
+ "| PathField [a]\n"
+ "| PathConstant []\n"
+ "| Variable [fieldProj_0]\n"
+ "IndexScan [{'<indexKey> 0': fieldProj_0}, scanDefName: collection, indexDefName: index1, "
+ "interval: {[Const [10], Const [10]], [Const [20], Const [20]]}]\n"
+ " BindBlock:\n"
+ " [fieldProj_0]\n"
+ " Source []\n",
+ optimized);
+}
+
+TEST(ABTTranslate, MatchIndexCovered3) {
+ PrefixId prefixId;
+ std::string scanDefName = "collection";
+
+ Metadata metadata = {
+ {{scanDefName,
+ ScanDefinition{
+ {},
+ {{"index1",
+ IndexDefinition{{{{makeNonMultikeyIndexPath("a"), CollationOp::Ascending},
+ {makeNonMultikeyIndexPath("b"), CollationOp::Ascending},
+ {makeNonMultikeyIndexPath("c"), CollationOp::Ascending}}},
+ false /*multiKey*/}}}}}}};
+ ABT translated = translatePipeline(
+ metadata,
+ "[{$match: {'a': 10, 'b': 20, 'c': 30}}, {$project: {_id: 0, a: 1, b: 1, c: 1}}]",
+ scanDefName,
+ prefixId);
+
+ OptPhaseManager phaseManager({OptPhaseManager::OptPhase::ConstEvalPre,
+ OptPhaseManager::OptPhase::PathFuse,
+ OptPhaseManager::OptPhase::MemoSubstitutionPhase,
+ OptPhaseManager::OptPhase::MemoExplorationPhase,
+ OptPhaseManager::OptPhase::MemoImplementationPhase},
+ prefixId,
+ metadata,
+ DebugInfo::kDefaultForTests);
+
+ ABT optimized = std::move(translated);
+ ASSERT_TRUE(phaseManager.optimize(optimized));
+
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | combinedProjection_0\n"
+ "| RefBlock: \n"
+ "| Variable [combinedProjection_0]\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [combinedProjection_0]\n"
+ "| EvalPath []\n"
+ "| | Const [{}]\n"
+ "| PathComposeM []\n"
+ "| | PathField [c]\n"
+ "| | PathConstant []\n"
+ "| | Variable [fieldProj_2]\n"
+ "| PathComposeM []\n"
+ "| | PathField [b]\n"
+ "| | PathConstant []\n"
+ "| | Variable [fieldProj_1]\n"
+ "| PathField [a]\n"
+ "| PathConstant []\n"
+ "| Variable [fieldProj_0]\n"
+ "IndexScan [{'<indexKey> 0': fieldProj_0, '<indexKey> 1': fieldProj_1, '<indexKey> 2': "
+ "fieldProj_2}, scanDefName: collection, indexDefName: index1, interval: {[Const [10], "
+ "Const [10]], [Const [20], Const [20]], [Const [30], Const [30]]}]\n"
+ " BindBlock:\n"
+ " [fieldProj_0]\n"
+ " Source []\n"
+ " [fieldProj_1]\n"
+ " Source []\n"
+ " [fieldProj_2]\n"
+ " Source []\n",
+ optimized);
+}
+
+TEST(ABTTranslate, MatchIndexCovered4) {
+ PrefixId prefixId;
+ std::string scanDefName = "collection";
+
+ Metadata metadata = {
+ {{scanDefName,
+ ScanDefinition{
+ {},
+ {{"index1",
+ IndexDefinition{{{{makeNonMultikeyIndexPath("a"), CollationOp::Ascending},
+ {makeNonMultikeyIndexPath("b"), CollationOp::Ascending},
+ {makeNonMultikeyIndexPath("c"), CollationOp::Ascending}}},
+ false /*multiKey*/}}}}}}};
+ ABT translated = translatePipeline(
+ metadata,
+ "[{$project: {_id: 0, a: 1, b: 1, c: 1}}, {$match: {'a': 10, 'b': 20, 'c': 30}}]",
+ scanDefName,
+ prefixId);
+
+ OptPhaseManager phaseManager({OptPhaseManager::OptPhase::ConstEvalPre,
+ OptPhaseManager::OptPhase::PathFuse,
+ OptPhaseManager::OptPhase::MemoSubstitutionPhase,
+ OptPhaseManager::OptPhase::MemoExplorationPhase,
+ OptPhaseManager::OptPhase::MemoImplementationPhase},
+ prefixId,
+ metadata,
+ DebugInfo::kDefaultForTests);
+
+ ABT optimized = std::move(translated);
+ ASSERT_TRUE(phaseManager.optimize(optimized));
+
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | combinedProjection_0\n"
+ "| RefBlock: \n"
+ "| Variable [combinedProjection_0]\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [combinedProjection_0]\n"
+ "| EvalPath []\n"
+ "| | Const [{}]\n"
+ "| PathComposeM []\n"
+ "| | PathField [c]\n"
+ "| | PathConstant []\n"
+ "| | Variable [fieldProj_2]\n"
+ "| PathComposeM []\n"
+ "| | PathField [b]\n"
+ "| | PathConstant []\n"
+ "| | Variable [fieldProj_1]\n"
+ "| PathField [a]\n"
+ "| PathConstant []\n"
+ "| Variable [fieldProj_0]\n"
+ "IndexScan [{'<indexKey> 0': fieldProj_0, '<indexKey> 1': fieldProj_1, '<indexKey> 2': "
+ "fieldProj_2}, scanDefName: collection, indexDefName: index1, interval: {[Const [10], "
+ "Const [10]], [Const [20], Const [20]], [Const [30], Const [30]]}]\n"
+ " BindBlock:\n"
+ " [fieldProj_0]\n"
+ " Source []\n"
+ " [fieldProj_1]\n"
+ " Source []\n"
+ " [fieldProj_2]\n"
+ " Source []\n",
+ optimized);
+}
+
+TEST(ABTTranslate, MatchSortIndex) {
+ PrefixId prefixId;
+ std::string scanDefName = "collection";
+
+ Metadata metadata = {
+ {{scanDefName,
+ ScanDefinition{{}, {{"index1", makeIndexDefinition("a", CollationOp::Ascending)}}}}}};
+ ABT translated = translatePipeline(
+ metadata, "[{$match: {'a': 10}}, {$sort: {'a': 1}}]", scanDefName, prefixId);
+
+ OptPhaseManager phaseManager({OptPhaseManager::OptPhase::MemoSubstitutionPhase,
+ OptPhaseManager::OptPhase::MemoExplorationPhase,
+ OptPhaseManager::OptPhase::MemoImplementationPhase},
+ prefixId,
+ metadata,
+ DebugInfo::kDefaultForTests);
+
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | scan_0\n"
+ "| RefBlock: \n"
+ "| Variable [scan_0]\n"
+ "Collation []\n"
+ "| | collation: \n"
+ "| | sort_0: Ascending\n"
+ "| RefBlock: \n"
+ "| Variable [sort_0]\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [sort_0]\n"
+ "| EvalPath []\n"
+ "| | Variable [scan_0]\n"
+ "| PathGet [a]\n"
+ "| PathIdentity []\n"
+ "Filter []\n"
+ "| EvalFilter []\n"
+ "| | Variable [scan_0]\n"
+ "| PathGet [a]\n"
+ "| PathTraverse []\n"
+ "| PathCompare [Eq]\n"
+ "| Const [10]\n"
+ "Scan [collection]\n"
+ " BindBlock:\n"
+ " [scan_0]\n"
+ " Source []\n",
+ translated);
+
+ ABT optimized = std::move(translated);
+ ASSERT_TRUE(phaseManager.optimize(optimized));
+
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | scan_0\n"
+ "| RefBlock: \n"
+ "| Variable [scan_0]\n"
+ "Collation []\n"
+ "| | collation: \n"
+ "| | sort_0: Ascending\n"
+ "| RefBlock: \n"
+ "| Variable [sort_0]\n"
+ "BinaryJoin [joinType: Inner, {rid_0}]\n"
+ "| | Const [true]\n"
+ "| LimitSkip []\n"
+ "| | limitSkip:\n"
+ "| | limit: 1\n"
+ "| | skip: 0\n"
+ "| Seek [ridProjection: rid_0, {'<root>': scan_0, 'a': sort_0}, collection]\n"
+ "| | BindBlock:\n"
+ "| | [scan_0]\n"
+ "| | Source []\n"
+ "| | [sort_0]\n"
+ "| | Source []\n"
+ "| RefBlock: \n"
+ "| Variable [rid_0]\n"
+ "IndexScan [{'<rid>': rid_0}, scanDefName: collection, indexDefName: index1, interval: "
+ "{[Const [10], Const [10]]}]\n"
+ " BindBlock:\n"
+ " [rid_0]\n"
+ " Source []\n",
+ optimized);
+}
+
+TEST(ABTTranslate, RangeIndex) {
+ PrefixId prefixId;
+ std::string scanDefName = "collection";
+ Metadata metadata = {
+ {{scanDefName,
+ ScanDefinition{{}, {{"index1", makeIndexDefinition("a", CollationOp::Ascending)}}}}}};
+ ABT translated =
+ translatePipeline(metadata, "[{$match: {'a': {$gt: 70, $lt: 90}}}]", scanDefName, prefixId);
+
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | scan_0\n"
+ "| RefBlock: \n"
+ "| Variable [scan_0]\n"
+ "Filter []\n"
+ "| EvalFilter []\n"
+ "| | Variable [scan_0]\n"
+ "| PathGet [a]\n"
+ "| PathTraverse []\n"
+ "| PathComposeM []\n"
+ "| | PathCompare [Gte]\n"
+ "| | Const [nan]\n"
+ "| PathCompare [Lt]\n"
+ "| Const [90]\n"
+ "Filter []\n"
+ "| EvalFilter []\n"
+ "| | Variable [scan_0]\n"
+ "| PathGet [a]\n"
+ "| PathTraverse []\n"
+ "| PathComposeM []\n"
+ "| | PathCompare [Lt]\n"
+ "| | Const [\"\"]\n"
+ "| PathCompare [Gt]\n"
+ "| Const [70]\n"
+ "Scan [collection]\n"
+ " BindBlock:\n"
+ " [scan_0]\n"
+ " Source []\n",
+ translated);
+
+ OptPhaseManager phaseManager({OptPhaseManager::OptPhase::MemoSubstitutionPhase,
+ OptPhaseManager::OptPhase::MemoExplorationPhase,
+ OptPhaseManager::OptPhase::MemoImplementationPhase},
+ prefixId,
+ metadata,
+ DebugInfo::kDefaultForTests);
+
+    // Demonstrate that we can get an index-intersection plan, even though it might not be the
+    // best one under heuristic cardinality estimation (CE).
+ phaseManager.getHints()._disableScan = true;
+
+ ABT optimized = std::move(translated);
+ ASSERT_TRUE(phaseManager.optimize(optimized));
+
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | scan_0\n"
+ "| RefBlock: \n"
+ "| Variable [scan_0]\n"
+ "BinaryJoin [joinType: Inner, {rid_0}]\n"
+ "| | Const [true]\n"
+ "| LimitSkip []\n"
+ "| | limitSkip:\n"
+ "| | limit: 1\n"
+ "| | skip: 0\n"
+ "| Seek [ridProjection: rid_0, {'<root>': scan_0}, collection]\n"
+ "| | BindBlock:\n"
+ "| | [scan_0]\n"
+ "| | Source []\n"
+ "| RefBlock: \n"
+ "| Variable [rid_0]\n"
+ "Filter []\n"
+ "| EvalFilter []\n"
+ "| | FunctionCall [getArraySize]\n"
+ "| | Variable [sides_0]\n"
+ "| PathCompare [Eq]\n"
+ "| Const [2]\n"
+ "GroupBy []\n"
+ "| | groupings: \n"
+ "| | RefBlock: \n"
+ "| | Variable [rid_0]\n"
+ "| aggregations: \n"
+ "| [sides_0]\n"
+ "| FunctionCall [$addToSet]\n"
+ "| Variable [sideId_0]\n"
+ "Union []\n"
+ "| | BindBlock:\n"
+ "| | [rid_0]\n"
+ "| | Source []\n"
+ "| | [sideId_0]\n"
+ "| | Source []\n"
+ "| Evaluation []\n"
+ "| | BindBlock:\n"
+ "| | [sideId_0]\n"
+ "| | Const [1]\n"
+ "| IndexScan [{'<rid>': rid_0}, scanDefName: collection, indexDefName: index1, interval: "
+ "{[Const [nan], Const [90])}]\n"
+ "| BindBlock:\n"
+ "| [rid_0]\n"
+ "| Source []\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [sideId_0]\n"
+ "| Const [0]\n"
+ "IndexScan [{'<rid>': rid_0}, scanDefName: collection, indexDefName: index1, interval: "
+ "{(Const [70], Const [\"\"])}]\n"
+ " BindBlock:\n"
+ " [rid_0]\n"
+ " Source []\n",
+ optimized);
+}
+
+TEST(ABTTranslate, Index1) {
+ {
+ PrefixId prefixId;
+ std::string scanDefName = "collection";
+ Metadata metadata = {
+ {{scanDefName,
+ ScanDefinition{{},
+ {{"index1",
+ IndexDefinition{{{makeIndexPath("a"), CollationOp::Ascending},
+ {makeIndexPath("b"), CollationOp::Ascending}},
+ true /*multiKey*/}}}}}}};
+
+ ABT translated =
+ translatePipeline(metadata, "[{$match: {'a': 2, 'b': 2}}]", scanDefName, prefixId);
+
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | scan_0\n"
+ "| RefBlock: \n"
+ "| Variable [scan_0]\n"
+ "Filter []\n"
+ "| EvalFilter []\n"
+ "| | Variable [scan_0]\n"
+ "| PathGet [b]\n"
+ "| PathTraverse []\n"
+ "| PathCompare [Eq]\n"
+ "| Const [2]\n"
+ "Filter []\n"
+ "| EvalFilter []\n"
+ "| | Variable [scan_0]\n"
+ "| PathGet [a]\n"
+ "| PathTraverse []\n"
+ "| PathCompare [Eq]\n"
+ "| Const [2]\n"
+ "Scan [collection]\n"
+ " BindBlock:\n"
+ " [scan_0]\n"
+ " Source []\n",
+ translated);
+
+ OptPhaseManager phaseManager({OptPhaseManager::OptPhase::MemoSubstitutionPhase,
+ OptPhaseManager::OptPhase::MemoExplorationPhase,
+ OptPhaseManager::OptPhase::MemoImplementationPhase},
+ prefixId,
+ metadata,
+ DebugInfo::kDefaultForTests);
+
+ ABT optimized = translated;
+ ASSERT_TRUE(phaseManager.optimize(optimized));
+
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | scan_0\n"
+ "| RefBlock: \n"
+ "| Variable [scan_0]\n"
+ "BinaryJoin [joinType: Inner, {rid_0}]\n"
+ "| | Const [true]\n"
+ "| LimitSkip []\n"
+ "| | limitSkip:\n"
+ "| | limit: 1\n"
+ "| | skip: 0\n"
+ "| Seek [ridProjection: rid_0, {'<root>': scan_0}, collection]\n"
+ "| | BindBlock:\n"
+ "| | [scan_0]\n"
+ "| | Source []\n"
+ "| RefBlock: \n"
+ "| Variable [rid_0]\n"
+ "IndexScan [{'<rid>': rid_0}, scanDefName: collection, indexDefName: index1, interval: "
+ "{[Const [2], Const [2]], [Const [2], Const [2]]}]\n"
+ " BindBlock:\n"
+ " [rid_0]\n"
+ " Source []\n",
+ optimized);
+ }
+
+ {
+ PrefixId prefixId;
+ std::string scanDefName = "collection";
+ Metadata metadata = {
+ {{scanDefName,
+ ScanDefinition{{}, {{"index1", makeIndexDefinition("a", CollationOp::Ascending)}}}}}};
+
+ ABT translated =
+ translatePipeline(metadata, "[{$match: {'a': 2, 'b': 2}}]", scanDefName, prefixId);
+
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | scan_0\n"
+ "| RefBlock: \n"
+ "| Variable [scan_0]\n"
+ "Filter []\n"
+ "| EvalFilter []\n"
+ "| | Variable [scan_0]\n"
+ "| PathGet [b]\n"
+ "| PathTraverse []\n"
+ "| PathCompare [Eq]\n"
+ "| Const [2]\n"
+ "Filter []\n"
+ "| EvalFilter []\n"
+ "| | Variable [scan_0]\n"
+ "| PathGet [a]\n"
+ "| PathTraverse []\n"
+ "| PathCompare [Eq]\n"
+ "| Const [2]\n"
+ "Scan [collection]\n"
+ " BindBlock:\n"
+ " [scan_0]\n"
+ " Source []\n",
+ translated);
+
+        // Demonstrate that we can satisfy the query with an index on only one of the two fields;
+        // the predicate on 'b' is applied as a residual filter above the Seek.
+ OptPhaseManager phaseManager({OptPhaseManager::OptPhase::MemoSubstitutionPhase,
+ OptPhaseManager::OptPhase::MemoExplorationPhase,
+ OptPhaseManager::OptPhase::MemoImplementationPhase,
+ OptPhaseManager::OptPhase::ConstEvalPost},
+ prefixId,
+ metadata,
+ DebugInfo::kDefaultForTests);
+
+ ABT optimized = translated;
+ ASSERT_TRUE(phaseManager.optimize(optimized));
+
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | scan_0\n"
+ "| RefBlock: \n"
+ "| Variable [scan_0]\n"
+ "BinaryJoin [joinType: Inner, {rid_0}]\n"
+ "| | Const [true]\n"
+ "| Filter []\n"
+ "| | EvalFilter []\n"
+ "| | | Variable [evalTemp_2]\n"
+ "| | PathTraverse []\n"
+ "| | PathCompare [Eq]\n"
+ "| | Const [2]\n"
+ "| LimitSkip []\n"
+ "| | limitSkip:\n"
+ "| | limit: 1\n"
+ "| | skip: 0\n"
+ "| Seek [ridProjection: rid_0, {'<root>': scan_0, 'b': evalTemp_2}, collection]\n"
+ "| | BindBlock:\n"
+ "| | [evalTemp_2]\n"
+ "| | Source []\n"
+ "| | [scan_0]\n"
+ "| | Source []\n"
+ "| RefBlock: \n"
+ "| Variable [rid_0]\n"
+ "IndexScan [{'<rid>': rid_0}, scanDefName: collection, indexDefName: index1, interval: "
+ "{[Const [2], Const [2]]}]\n"
+ " BindBlock:\n"
+ " [rid_0]\n"
+ " Source []\n",
+ optimized);
+ }
+}
+
+TEST(ABTTranslate, GroupMultiKey) {
+ ABT translated = translatePipeline(
+ "[{$group: {_id: {'isin': '$isin', 'year': '$year'}, 'count': {$sum: 1}, 'open': {$first: "
+ "'$$ROOT'}}}]");
+
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | agg_project_0\n"
+ "| RefBlock: \n"
+ "| Variable [agg_project_0]\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [agg_project_0]\n"
+ "| EvalPath []\n"
+ "| | Const [{}]\n"
+ "| PathComposeM []\n"
+ "| | PathField [open]\n"
+ "| | PathConstant []\n"
+ "| | Variable [open_agg_0]\n"
+ "| PathComposeM []\n"
+ "| | PathField [count]\n"
+ "| | PathConstant []\n"
+ "| | Variable [count_agg_0]\n"
+ "| PathComposeM []\n"
+ "| | PathField [_id.year]\n"
+ "| | PathConstant []\n"
+ "| | Variable [groupByProj_1]\n"
+ "| PathField [_id.isin]\n"
+ "| PathConstant []\n"
+ "| Variable [groupByProj_0]\n"
+ "GroupBy []\n"
+ "| | groupings: \n"
+ "| | RefBlock: \n"
+ "| | Variable [groupByProj_0]\n"
+ "| | Variable [groupByProj_1]\n"
+ "| aggregations: \n"
+ "| [count_agg_0]\n"
+ "| FunctionCall [$sum]\n"
+ "| Const [1]\n"
+ "| [open_agg_0]\n"
+ "| FunctionCall [$first]\n"
+ "| Variable [scan_0]\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [groupByProj_1]\n"
+ "| EvalPath []\n"
+ "| | Variable [scan_0]\n"
+ "| PathGet [year]\n"
+ "| PathIdentity []\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [groupByProj_0]\n"
+ "| EvalPath []\n"
+ "| | Variable [scan_0]\n"
+ "| PathGet [isin]\n"
+ "| PathIdentity []\n"
+ "Scan [collection]\n"
+ " BindBlock:\n"
+ " [scan_0]\n"
+ " Source []\n",
+ translated);
+}
+
+TEST(ABTTranslate, GroupEvalNoInline) {
+ ABT translated = translatePipeline("[{$group: {_id: null, a: {$first: '$b'}}}]");
+
+ PrefixId prefixId;
+ std::string scanDefName = "collection";
+ OptPhaseManager phaseManager(OptPhaseManager::getAllRewritesSet(),
+ prefixId,
+ {{{scanDefName, ScanDefinition{{}, {}}}}},
+ DebugInfo::kDefaultForTests);
+
+ ABT optimized = translated;
+ ASSERT_TRUE(phaseManager.optimize(optimized));
+
+    // Verify that "b" is not inlined in the group expression but comes from the physical scan.
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | agg_project_0\n"
+ "| RefBlock: \n"
+ "| Variable [agg_project_0]\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [agg_project_0]\n"
+ "| Let [inputField_1]\n"
+ "| | If []\n"
+ "| | | | Variable [inputField_1]\n"
+ "| | | FunctionCall [setField]\n"
+ "| | | | | Variable [a_agg_0]\n"
+ "| | | | Const [\"a\"]\n"
+ "| | | Variable [inputField_1]\n"
+ "| | BinaryOp [Or]\n"
+ "| | | FunctionCall [isObject]\n"
+ "| | | Variable [inputField_1]\n"
+ "| | FunctionCall [exists]\n"
+ "| | Variable [a_agg_0]\n"
+ "| If []\n"
+ "| | | Const [{}]\n"
+ "| | FunctionCall [setField]\n"
+ "| | | | Variable [groupByProj_0]\n"
+ "| | | Const [\"_id\"]\n"
+ "| | Const [{}]\n"
+ "| BinaryOp [Or]\n"
+ "| | FunctionCall [isObject]\n"
+ "| | Const [{}]\n"
+ "| FunctionCall [exists]\n"
+ "| Variable [groupByProj_0]\n"
+ "GroupBy []\n"
+ "| | groupings: \n"
+ "| | RefBlock: \n"
+ "| | Variable [groupByProj_0]\n"
+ "| aggregations: \n"
+ "| [a_agg_0]\n"
+ "| FunctionCall [$first]\n"
+ "| Variable [groupByInputProj_0]\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [groupByProj_0]\n"
+ "| Const [null]\n"
+ "PhysicalScan [{'b': groupByInputProj_0}, collection]\n"
+ " BindBlock:\n"
+ " [groupByInputProj_0]\n"
+ " Source []\n",
+ optimized);
+}
+
+TEST(ABTTranslate, ArrayExpr) {
+ ABT translated = translatePipeline("[{$project: {a: ['$b', '$c']}}]");
+
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | combinedProjection_0\n"
+ "| RefBlock: \n"
+ "| Variable [combinedProjection_0]\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [combinedProjection_0]\n"
+ "| EvalPath []\n"
+ "| | Variable [scan_0]\n"
+ "| PathComposeM []\n"
+ "| | PathDefault []\n"
+ "| | Const [{}]\n"
+ "| PathComposeM []\n"
+ "| | PathField [a]\n"
+ "| | PathConstant []\n"
+ "| | Variable [projGetPath_0]\n"
+ "| PathKeep [_id, a]\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [projGetPath_0]\n"
+ "| FunctionCall [newArray]\n"
+ "| | EvalPath []\n"
+ "| | | Variable [scan_0]\n"
+ "| | PathGet [c]\n"
+ "| | PathIdentity []\n"
+ "| EvalPath []\n"
+ "| | Variable [scan_0]\n"
+ "| PathGet [b]\n"
+ "| PathIdentity []\n"
+ "Scan [collection]\n"
+ " BindBlock:\n"
+ " [scan_0]\n"
+ " Source []\n",
+ translated);
+}
+
+TEST(ABTTranslate, Union) {
+ PrefixId prefixId;
+ std::string scanDefA = "collA";
+ std::string scanDefB = "collB";
+
+ Metadata metadata{{{scanDefA, {}}, {scanDefB, {}}}};
+ ABT translated = translatePipeline(metadata,
+ "[{$unionWith: 'collB'}, {$match: {_id: 1}}]",
+ prefixId.getNextId("scan"),
+ scanDefA,
+ prefixId,
+ {{NamespaceString("a." + scanDefB), {}}});
+
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | scan_0\n"
+ "| RefBlock: \n"
+ "| Variable [scan_0]\n"
+ "Filter []\n"
+ "| EvalFilter []\n"
+ "| | Variable [scan_0]\n"
+ "| PathGet [_id]\n"
+ "| PathTraverse []\n"
+ "| PathCompare [Eq]\n"
+ "| Const [1]\n"
+ "Union []\n"
+ "| | BindBlock:\n"
+ "| | [scan_0]\n"
+ "| | Source []\n"
+ "| Evaluation []\n"
+ "| | BindBlock:\n"
+ "| | [scan_0]\n"
+ "| | EvalPath []\n"
+ "| | | Variable [scan_1]\n"
+ "| | PathIdentity []\n"
+ "| Scan [collB]\n"
+ "| BindBlock:\n"
+ "| [scan_1]\n"
+ "| Source []\n"
+ "Scan [collA]\n"
+ " BindBlock:\n"
+ " [scan_0]\n"
+ " Source []\n",
+ translated);
+
+ OptPhaseManager phaseManager(
+ {OptPhaseManager::OptPhase::MemoSubstitutionPhase,
+ OptPhaseManager::OptPhase::MemoExplorationPhase,
+ OptPhaseManager::OptPhase::MemoImplementationPhase},
+ prefixId,
+ {{{scanDefA, ScanDefinition{{}, {}}}, {scanDefB, ScanDefinition{{}, {}}}}},
+ DebugInfo::kDefaultForTests);
+
+ ABT optimized = translated;
+ ASSERT_TRUE(phaseManager.optimize(optimized));
+
+    // Note that the optimized ABT shows the filter pushed down into both union branches.
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | scan_0\n"
+ "| RefBlock: \n"
+ "| Variable [scan_0]\n"
+ "Union []\n"
+ "| | BindBlock:\n"
+ "| | [scan_0]\n"
+ "| | Source []\n"
+ "| Filter []\n"
+ "| | EvalFilter []\n"
+ "| | | Variable [scan_0]\n"
+ "| | PathGet [_id]\n"
+ "| | PathTraverse []\n"
+ "| | PathCompare [Eq]\n"
+ "| | Const [1]\n"
+ "| Evaluation []\n"
+ "| | BindBlock:\n"
+ "| | [scan_0]\n"
+ "| | EvalPath []\n"
+ "| | | Variable [scan_1]\n"
+ "| | PathIdentity []\n"
+ "| PhysicalScan [{'<root>': scan_1}, collB]\n"
+ "| BindBlock:\n"
+ "| [scan_1]\n"
+ "| Source []\n"
+ "Filter []\n"
+ "| EvalFilter []\n"
+ "| | Variable [evalTemp_0]\n"
+ "| PathTraverse []\n"
+ "| PathCompare [Eq]\n"
+ "| Const [1]\n"
+ "PhysicalScan [{'<root>': scan_0, '_id': evalTemp_0}, collA]\n"
+ " BindBlock:\n"
+ " [evalTemp_0]\n"
+ " Source []\n"
+ " [scan_0]\n"
+ " Source []\n",
+ optimized);
+}
+
+TEST(ABTTranslate, PartialIndex) {
+ PrefixId prefixId;
+ std::string scanDefName = "collection";
+ ProjectionName scanProjName = prefixId.getNextId("scan");
+
+    // The partial index filter expression matches the $match predicate in the pipeline.
+    // By default the constant is translated as "int32".
+ auto conversionResult = convertExprToPartialSchemaReq(make<EvalFilter>(
+ make<PathGet>("b",
+ make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int32(2)))),
+ make<Variable>(scanProjName)));
+ ASSERT_TRUE(conversionResult._success);
+ ASSERT_FALSE(conversionResult._hasEmptyInterval);
+
+ Metadata metadata = {
+ {{scanDefName,
+ ScanDefinition{{},
+ {{"index1",
+ IndexDefinition{{{makeIndexPath("a"), CollationOp::Ascending}},
+ true /*multiKey*/,
+ {DistributionType::Centralized},
+ std::move(conversionResult._reqMap)}}}}}}};
+
+ ABT translated = translatePipeline(
+ metadata, "[{$match: {'a': 3, 'b': 2}}]", scanProjName, scanDefName, prefixId);
+
+ OptPhaseManager phaseManager(
+ OptPhaseManager::getAllRewritesSet(), prefixId, metadata, DebugInfo::kDefaultForTests);
+
+ ABT optimized = translated;
+ ASSERT_TRUE(phaseManager.optimize(optimized));
+
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | scan_0\n"
+ "| RefBlock: \n"
+ "| Variable [scan_0]\n"
+ "BinaryJoin [joinType: Inner, {rid_0}]\n"
+ "| | Const [true]\n"
+ "| Filter []\n"
+ "| | FunctionCall [traverseF]\n"
+ "| | | | Const [false]\n"
+ "| | | LambdaAbstraction [valCmp_0]\n"
+ "| | | BinaryOp [Eq]\n"
+ "| | | | Const [2]\n"
+ "| | | Variable [valCmp_0]\n"
+ "| | Variable [evalTemp_2]\n"
+ "| LimitSkip []\n"
+ "| | limitSkip:\n"
+ "| | limit: 1\n"
+ "| | skip: 0\n"
+ "| Seek [ridProjection: rid_0, {'<root>': scan_0, 'b': evalTemp_2}, collection]\n"
+ "| | BindBlock:\n"
+ "| | [evalTemp_2]\n"
+ "| | Source []\n"
+ "| | [scan_0]\n"
+ "| | Source []\n"
+ "| RefBlock: \n"
+ "| Variable [rid_0]\n"
+ "IndexScan [{'<rid>': rid_0}, scanDefName: collection, indexDefName: index1, interval: "
+ "{[Const [3], Const [3]]}]\n"
+ " BindBlock:\n"
+ " [rid_0]\n"
+ " Source []\n",
+ optimized);
+}
+
+TEST(ABTTranslate, PartialIndexNegative) {
+ PrefixId prefixId;
+ std::string scanDefName = "collection";
+ ProjectionName scanProjName = prefixId.getNextId("scan");
+
+    // The partial index filter expression does not match the $match predicate in the pipeline.
+ auto conversionResult = convertExprToPartialSchemaReq(make<EvalFilter>(
+ make<PathGet>("b",
+ make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int32(2)))),
+ make<Variable>(scanProjName)));
+ ASSERT_TRUE(conversionResult._success);
+ ASSERT_FALSE(conversionResult._hasEmptyInterval);
+
+ Metadata metadata = {
+ {{scanDefName,
+ ScanDefinition{{},
+ {{"index1",
+ IndexDefinition{{{makeIndexPath("a"), CollationOp::Ascending}},
+ true /*multiKey*/,
+ {DistributionType::Centralized},
+ std::move(conversionResult._reqMap)}}}}}}};
+
+ ABT translated = translatePipeline(
+ metadata, "[{$match: {'a': 3, 'b': 3}}]", scanProjName, scanDefName, prefixId);
+
+ OptPhaseManager phaseManager(
+ OptPhaseManager::getAllRewritesSet(), prefixId, metadata, DebugInfo::kDefaultForTests);
+
+ ABT optimized = translated;
+ ASSERT_TRUE(phaseManager.optimize(optimized));
+
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | scan_0\n"
+ "| RefBlock: \n"
+ "| Variable [scan_0]\n"
+ "Filter []\n"
+ "| FunctionCall [traverseF]\n"
+ "| | | Const [false]\n"
+ "| | LambdaAbstraction [valCmp_1]\n"
+ "| | BinaryOp [Eq]\n"
+ "| | | Const [3]\n"
+ "| | Variable [valCmp_1]\n"
+ "| Variable [evalTemp_1]\n"
+ "Filter []\n"
+ "| FunctionCall [traverseF]\n"
+ "| | | Const [false]\n"
+ "| | LambdaAbstraction [valCmp_0]\n"
+ "| | BinaryOp [Eq]\n"
+ "| | | Const [3]\n"
+ "| | Variable [valCmp_0]\n"
+ "| Variable [evalTemp_0]\n"
+ "PhysicalScan [{'<root>': scan_0, 'a': evalTemp_0, 'b': evalTemp_1}, collection]\n"
+ " BindBlock:\n"
+ " [evalTemp_0]\n"
+ " Source []\n"
+ " [evalTemp_1]\n"
+ " Source []\n"
+ " [scan_0]\n"
+ " Source []\n",
+ optimized);
+}
+
+TEST(ABTTranslate, CommonExpressionElimination) {
+ PrefixId prefixId;
+ Metadata metadata = {{{"test", {{}, {}}}}};
+
+ auto rootNode =
+ translatePipeline(metadata,
+ "[{$project: {foo: {$add: ['$b', 1]}, bar: {$add: ['$b', 1]}}}]",
+ "test",
+ prefixId);
+
+ OptPhaseManager phaseManager(
+ {OptPhaseManager::OptPhase::ConstEvalPre}, prefixId, metadata, DebugInfo::kDefaultForTests);
+
+ ASSERT_TRUE(phaseManager.optimize(rootNode));
+
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | combinedProjection_0\n"
+ "| RefBlock: \n"
+ "| Variable [combinedProjection_0]\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [combinedProjection_0]\n"
+ "| EvalPath []\n"
+ "| | Variable [scan_0]\n"
+ "| PathComposeM []\n"
+ "| | PathDefault []\n"
+ "| | Const [{}]\n"
+ "| PathComposeM []\n"
+ "| | PathField [foo]\n"
+ "| | PathConstant []\n"
+ "| | Variable [projGetPath_0]\n"
+ "| PathComposeM []\n"
+ "| | PathField [bar]\n"
+ "| | PathConstant []\n"
+ "| | Variable [projGetPath_0]\n"
+ "| PathKeep [_id, bar, foo]\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [projGetPath_0]\n"
+ "| BinaryOp [Add]\n"
+ "| | EvalPath []\n"
+ "| | | Variable [scan_0]\n"
+ "| | PathGet [b]\n"
+ "| | PathIdentity []\n"
+ "| Const [1]\n"
+ "Scan [test]\n"
+ " BindBlock:\n"
+ " [scan_0]\n"
+ " Source []\n",
+ rootNode);
+}
+
+} // namespace
+} // namespace mongo
diff --git a/src/mongo/db/pipeline/abt/utils.cpp b/src/mongo/db/pipeline/abt/utils.cpp
new file mode 100644
index 00000000000..18e22eb265d
--- /dev/null
+++ b/src/mongo/db/pipeline/abt/utils.cpp
@@ -0,0 +1,104 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/db/pipeline/abt/utils.h"
+
+namespace mongo::optimizer {
+
+template <bool isConjunction, typename... Args>
+ABT generateConjunctionOrDisjunction(Args&... args) {
+ ABTVector elements;
+ (elements.emplace_back(args), ...);
+
+ if (elements.size() == 0) {
+ return Constant::boolean(isConjunction);
+ }
+
+ ABT result = std::move(elements.at(0));
+ for (size_t i = 1; i < elements.size(); i++) {
+ result = make<BinaryOp>(isConjunction ? Operations::And : Operations::Or,
+ std::move(elements.at(i)),
+ std::move(result));
+ }
+ return result;
+}
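
The helper above folds its parameter pack into a chain of BinaryOp nodes. Because each later element becomes the left operand of the accumulated result, the first argument ends up deepest on the right of the tree. A minimal standalone analogue, using strings in place of ABTs (all names here are illustrative, not part of the patch), makes the nesting order visible:

    #include <cstddef>
    #include <iostream>
    #include <string>
    #include <utility>
    #include <vector>

    template <bool isConjunction, typename... Args>
    std::string combine(Args&... args) {
        std::vector<std::string> elements;
        (elements.emplace_back(args), ...);  // C++17 fold: append each argument in order.

        if (elements.empty()) {
            // Neutral element: an empty conjunction is true, an empty disjunction is false.
            return isConjunction ? "true" : "false";
        }

        std::string result = std::move(elements.at(0));
        for (std::size_t i = 1; i < elements.size(); i++) {
            // The later element becomes the left operand, mirroring
            // make<BinaryOp>(op, std::move(elements.at(i)), std::move(result)) above.
            result = "(" + elements.at(i) + (isConjunction ? " And " : " Or ") + result + ")";
        }
        return result;
    }

    int main() {
        std::string a = "a", b = "b", c = "c";
        std::cout << combine<true>(a, b, c) << "\n";  // Prints: (c And (b And a))
    }
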
+
+ABT generateCoerceToBool(ABT input, const std::string& varName) {
+ // Adapted from sbe_stage_builder_expression.cpp::generateCoerceToBoolExpression.
+
+ const auto makeNeqCheckFn = [&varName](ABT valExpr) {
+ return make<BinaryOp>(
+ Operations::Neq,
+ make<BinaryOp>(Operations::Cmp3w, make<Variable>(varName), std::move(valExpr)),
+ Constant::int64(0));
+ };
+
+    // If any of these checks is false, the value is considered false for the purposes of
+    // any enclosing logical expression.
+ ABT checkExists = make<FunctionCall>("exists", makeSeq(input));
+ ABT checkNotNull = make<UnaryOp>(Operations::Not, make<FunctionCall>("isNull", makeSeq(input)));
+ ABT checkNotFalse = makeNeqCheckFn(Constant::boolean(false));
+ ABT checkNotZero = makeNeqCheckFn(Constant::int64(0));
+
+ return make<Let>(varName,
+ std::move(input),
+ generateConjunctionOrDisjunction<true /*isConjunction*/>(
+ checkExists, checkNotNull, checkNotFalse, checkNotZero));
+}
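
The four checks encode the truthiness rules: a value coerces to true only if it exists, is not null, is not the boolean false, and does not compare equal to zero. A plain-C++ sketch of the same truth table, restricted to the null/bool/integer cases (the real Cmp3w comparisons also handle other types via three-way type ordering; types and names here are illustrative only):

    #include <optional>
    #include <variant>

    // A value is either missing (empty optional), null (monostate), a bool, or an integer.
    using Scalar = std::variant<std::monostate /*null*/, bool, long long>;

    bool coercesToBool(const std::optional<Scalar>& v) {
        if (!v) return false;                                          // "exists" check
        if (std::holds_alternative<std::monostate>(*v)) return false;  // "isNull" check
        if (const bool* b = std::get_if<bool>(&*v)) return *b;         // Cmp3w against false
        return std::get<long long>(*v) != 0;                           // Cmp3w against 0
    }

    // coercesToBool(Scalar{0LL}) == false; coercesToBool(Scalar{42LL}) == true.
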
+
+std::pair<sbe::value::TypeTags, sbe::value::Value> convertFrom(const Value val) {
+ // TODO: Either make this conversion unnecessary by changing the value representation in
+ // ExpressionConstant, or provide a nicer way to convert directly from Document/Value to
+ // sbe::Value.
+ BSONObjBuilder bob;
+ val.addToBsonObj(&bob, ""_sd);
+ auto obj = bob.done();
+ auto be = obj.objdata();
+ auto end = be + ConstDataView(be).read<LittleEndian<uint32_t>>();
+ return sbe::bson::convertFrom<false>(be + 4, end, 0);
+}
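
The conversion works by serializing the Value as the single, empty-named field of a temporary BSONObj and then handing the SBE BSON reader a pointer just past the object's four-byte length prefix. A hypothetical call site (the expected tag is an assumption based on BSON's int32 encoding of small integers, not something the patch asserts):

    // Turn an aggregation Value into an SBE tag/value pair.
    auto [tag, val] = convertFrom(Value(42));
    // 'tag' should be sbe::value::TypeTags::NumberInt32, with 'val' holding 42.
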
+
+ABT translateFieldPath(const FieldPath& fieldPath,
+ ABT initial,
+ const ABTFieldNameFn& fieldNameFn,
+ const size_t skipFromStart) {
+ ABT result = std::move(initial);
+
+ const size_t fieldPathLength = fieldPath.getPathLength();
+ bool isLastElement = true;
+ for (size_t i = fieldPathLength; i-- > skipFromStart;) {
+ result =
+ fieldNameFn(fieldPath.getFieldName(i).toString(), isLastElement, std::move(result));
+ isLastElement = false;
+ }
+
+ return result;
+}
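
translateFieldPath builds the path inside-out: the loop starts at the last component, so the innermost node is produced first and the first component ends up outermost, matching the PathGet [a] / PathGet [b] / PathGet [c] / PathIdentity shapes in the tests above. A standalone string-based analogue (illustrative only; the real fieldNameFn decides which ABT node to emit, e.g. PathGet vs. PathField):

    #include <cstddef>
    #include <functional>
    #include <iostream>
    #include <string>
    #include <utility>
    #include <vector>

    using FieldNameFn = std::function<std::string(const std::string&, bool, std::string)>;

    std::string translate(const std::vector<std::string>& path,
                          std::string initial,
                          const FieldNameFn& fn,
                          std::size_t skipFromStart = 0) {
        std::string result = std::move(initial);
        bool isLastElement = true;
        // Walk from the last path component down to 'skipFromStart'.
        for (std::size_t i = path.size(); i-- > skipFromStart;) {
            result = fn(path[i], isLastElement, std::move(result));
            isLastElement = false;
        }
        return result;
    }

    int main() {
        auto get = [](const std::string& field, bool /*isLast*/, std::string inner) {
            return "PathGet[" + field + "](" + std::move(inner) + ")";
        };
        std::cout << translate({"a", "b", "c"}, "PathIdentity", get) << "\n";
        // Prints: PathGet[a](PathGet[b](PathGet[c](PathIdentity)))
    }
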
+
+} // namespace mongo::optimizer
diff --git a/src/mongo/db/pipeline/abt/utils.h b/src/mongo/db/pipeline/abt/utils.h
new file mode 100644
index 00000000000..85c1a643764
--- /dev/null
+++ b/src/mongo/db/pipeline/abt/utils.h
@@ -0,0 +1,54 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include "mongo/db/exec/document_value/value.h"
+#include "mongo/db/exec/sbe/values/bson.h"
+#include "mongo/db/pipeline/field_path.h"
+#include "mongo/db/query/optimizer/node.h"
+
+namespace mongo::optimizer {
+
+/**
+ * Generate an AST to compute "coerceToBool" on the input. Expects a variable name to use as a let
+ * variable.
+ */
+ABT generateCoerceToBool(ABT input, const std::string& varName);
+
+std::pair<sbe::value::TypeTags, sbe::value::Value> convertFrom(Value val);
+
+using ABTFieldNameFn =
+ std::function<ABT(const std::string& fieldName, const bool isLastElement, ABT input)>;
+ABT translateFieldPath(const FieldPath& fieldPath,
+ ABT initial,
+ const ABTFieldNameFn& fieldNameFn,
+ size_t skipFromStart = 0);
+
+} // namespace mongo::optimizer
diff --git a/src/mongo/db/pipeline/document_source_union_with.h b/src/mongo/db/pipeline/document_source_union_with.h
index 1d9068a9f4c..239a41c9235 100644
--- a/src/mongo/db/pipeline/document_source_union_with.h
+++ b/src/mongo/db/pipeline/document_source_union_with.h
@@ -134,6 +134,10 @@ public:
return &_stats;
}
+ const Pipeline& getPipeline() const {
+ return *_pipeline;
+ }
+
boost::intrusive_ptr<DocumentSource> clone() const final;
protected:
diff --git a/src/mongo/db/pipeline/visitors/document_source_visitor.h b/src/mongo/db/pipeline/visitors/document_source_visitor.h
new file mode 100644
index 00000000000..a0158147e38
--- /dev/null
+++ b/src/mongo/db/pipeline/visitors/document_source_visitor.h
@@ -0,0 +1,135 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include "mongo/db/query/tree_walker.h"
+
+namespace mongo {
+
+class DocumentSourceBucketAuto;
+class DocumentSourceCollStats;
+class DocumentSourceCurrentOp;
+class DocumentSourceCursor;
+class DocumentSourceExchange;
+class DocumentSourceFacet;
+class DocumentSourceGeoNear;
+class DocumentSourceGeoNearCursor;
+class DocumentSourceGraphLookUp;
+class DocumentSourceGroup;
+class DocumentSourceIndexStats;
+class DocumentSourceInternalInhibitOptimization;
+class DocumentSourceInternalShardFilter;
+class DocumentSourceInternalSplitPipeline;
+class DocumentSourceLimit;
+class DocumentSourceListCachedAndActiveUsers;
+class DocumentSourceListLocalSessions;
+class DocumentSourceListSessions;
+class DocumentSourceLookUp;
+class DocumentSourceMatch;
+class DocumentSourceMerge;
+class DocumentSourceMergeCursors;
+class DocumentSourceOperationMetrics;
+class DocumentSourceOut;
+class DocumentSourcePlanCacheStats;
+class DocumentSourceQueue;
+class DocumentSourceRedact;
+class DocumentSourceSample;
+class DocumentSourceSampleFromRandomCursor;
+class DocumentSourceSequentialDocumentCache;
+class DocumentSourceSingleDocumentTransformation;
+class DocumentSourceSkip;
+class DocumentSourceSort;
+class DocumentSourceTeeConsumer;
+class DocumentSourceUnionWith;
+class DocumentSourceUnwind;
+
+/**
+ * Visitor pattern for pipeline document sources.
+ *
+ * This code is not responsible for traversing the tree, only for performing the double-dispatch.
+ *
+ * If the visitor doesn't intend to modify the tree, then the template argument 'IsConst' should be
+ * set to 'true'. In this case all 'visit()' methods will take a const pointer to a visiting node.
+ */
+template <bool IsConst = false>
+class DocumentSourceVisitor {
+public:
+ virtual void visit(tree_walker::MaybeConstPtr<IsConst, DocumentSourceBucketAuto> source) = 0;
+ virtual void visit(tree_walker::MaybeConstPtr<IsConst, DocumentSourceCollStats> source) = 0;
+ virtual void visit(tree_walker::MaybeConstPtr<IsConst, DocumentSourceCurrentOp> source) = 0;
+ virtual void visit(tree_walker::MaybeConstPtr<IsConst, DocumentSourceCursor> source) = 0;
+ virtual void visit(tree_walker::MaybeConstPtr<IsConst, DocumentSourceExchange> source) = 0;
+ virtual void visit(tree_walker::MaybeConstPtr<IsConst, DocumentSourceFacet> source) = 0;
+ virtual void visit(tree_walker::MaybeConstPtr<IsConst, DocumentSourceGeoNear> source) = 0;
+ virtual void visit(tree_walker::MaybeConstPtr<IsConst, DocumentSourceGeoNearCursor> source) = 0;
+ virtual void visit(tree_walker::MaybeConstPtr<IsConst, DocumentSourceGraphLookUp> source) = 0;
+ virtual void visit(tree_walker::MaybeConstPtr<IsConst, DocumentSourceGroup> source) = 0;
+ virtual void visit(tree_walker::MaybeConstPtr<IsConst, DocumentSourceIndexStats> source) = 0;
+ virtual void visit(
+ tree_walker::MaybeConstPtr<IsConst, DocumentSourceInternalInhibitOptimization> source) = 0;
+ virtual void visit(
+ tree_walker::MaybeConstPtr<IsConst, DocumentSourceInternalShardFilter> source) = 0;
+ virtual void visit(
+ tree_walker::MaybeConstPtr<IsConst, DocumentSourceInternalSplitPipeline> source) = 0;
+ virtual void visit(tree_walker::MaybeConstPtr<IsConst, DocumentSourceLimit> source) = 0;
+ virtual void visit(
+ tree_walker::MaybeConstPtr<IsConst, DocumentSourceListCachedAndActiveUsers> source) = 0;
+ virtual void visit(
+ tree_walker::MaybeConstPtr<IsConst, DocumentSourceListLocalSessions> source) = 0;
+ virtual void visit(tree_walker::MaybeConstPtr<IsConst, DocumentSourceListSessions> source) = 0;
+ virtual void visit(tree_walker::MaybeConstPtr<IsConst, DocumentSourceLookUp> source) = 0;
+ virtual void visit(tree_walker::MaybeConstPtr<IsConst, DocumentSourceMatch> source) = 0;
+ virtual void visit(tree_walker::MaybeConstPtr<IsConst, DocumentSourceMerge> source) = 0;
+ virtual void visit(tree_walker::MaybeConstPtr<IsConst, DocumentSourceMergeCursors> source) = 0;
+ virtual void visit(
+ tree_walker::MaybeConstPtr<IsConst, DocumentSourceOperationMetrics> source) = 0;
+ virtual void visit(tree_walker::MaybeConstPtr<IsConst, DocumentSourceOut> source) = 0;
+ virtual void visit(
+ tree_walker::MaybeConstPtr<IsConst, DocumentSourcePlanCacheStats> source) = 0;
+ virtual void visit(tree_walker::MaybeConstPtr<IsConst, DocumentSourceQueue> source) = 0;
+ virtual void visit(tree_walker::MaybeConstPtr<IsConst, DocumentSourceRedact> source) = 0;
+ virtual void visit(tree_walker::MaybeConstPtr<IsConst, DocumentSourceSample> source) = 0;
+ virtual void visit(
+ tree_walker::MaybeConstPtr<IsConst, DocumentSourceSampleFromRandomCursor> source) = 0;
+ virtual void visit(
+ tree_walker::MaybeConstPtr<IsConst, DocumentSourceSequentialDocumentCache> source) = 0;
+ virtual void visit(
+ tree_walker::MaybeConstPtr<IsConst, DocumentSourceSingleDocumentTransformation> source) = 0;
+ virtual void visit(tree_walker::MaybeConstPtr<IsConst, DocumentSourceSkip> source) = 0;
+ virtual void visit(tree_walker::MaybeConstPtr<IsConst, DocumentSourceSort> source) = 0;
+ virtual void visit(tree_walker::MaybeConstPtr<IsConst, DocumentSourceTeeConsumer> source) = 0;
+ virtual void visit(tree_walker::MaybeConstPtr<IsConst, DocumentSourceUnionWith> source) = 0;
+ virtual void visit(tree_walker::MaybeConstPtr<IsConst, DocumentSourceUnwind> source) = 0;
+};
+
+using DocumentSourceMutableVisitor = DocumentSourceVisitor<false>;
+using DocumentSourceConstVisitor = DocumentSourceVisitor<true>;
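
The IsConst parameter works because tree_walker::MaybeConstPtr picks the pointer's constness at compile time. Conceptually it amounts to the following sketch (the real alias lives in mongo/db/query/tree_walker.h; this shows the idea, not the actual definition):

    #include <type_traits>

    template <bool IsConst, typename T>
    using MaybeConstPtr = std::conditional_t<IsConst, const T*, T*>;

    // So DocumentSourceVisitor<true>::visit receives 'const DocumentSourceMatch*', while
    // DocumentSourceVisitor<false>::visit receives a mutable 'DocumentSourceMatch*'.
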
+
+} // namespace mongo
diff --git a/src/mongo/db/pipeline/visitors/document_source_walker.cpp b/src/mongo/db/pipeline/visitors/document_source_walker.cpp
new file mode 100644
index 00000000000..b0ea004cae9
--- /dev/null
+++ b/src/mongo/db/pipeline/visitors/document_source_walker.cpp
@@ -0,0 +1,144 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/db/pipeline/visitors/document_source_walker.h"
+
+#include "mongo/base/error_codes.h"
+#include "mongo/db/pipeline/document_source_bucket_auto.h"
+#include "mongo/db/pipeline/document_source_coll_stats.h"
+#include "mongo/db/pipeline/document_source_current_op.h"
+#include "mongo/db/pipeline/document_source_cursor.h"
+#include "mongo/db/pipeline/document_source_exchange.h"
+#include "mongo/db/pipeline/document_source_facet.h"
+#include "mongo/db/pipeline/document_source_geo_near.h"
+#include "mongo/db/pipeline/document_source_geo_near_cursor.h"
+#include "mongo/db/pipeline/document_source_graph_lookup.h"
+#include "mongo/db/pipeline/document_source_group.h"
+#include "mongo/db/pipeline/document_source_index_stats.h"
+#include "mongo/db/pipeline/document_source_internal_inhibit_optimization.h"
+#include "mongo/db/pipeline/document_source_internal_shard_filter.h"
+#include "mongo/db/pipeline/document_source_internal_split_pipeline.h"
+#include "mongo/db/pipeline/document_source_limit.h"
+#include "mongo/db/pipeline/document_source_list_cached_and_active_users.h"
+#include "mongo/db/pipeline/document_source_list_local_sessions.h"
+#include "mongo/db/pipeline/document_source_list_sessions.h"
+#include "mongo/db/pipeline/document_source_lookup.h"
+#include "mongo/db/pipeline/document_source_match.h"
+#include "mongo/db/pipeline/document_source_merge.h"
+#include "mongo/db/pipeline/document_source_operation_metrics.h"
+#include "mongo/db/pipeline/document_source_out.h"
+#include "mongo/db/pipeline/document_source_plan_cache_stats.h"
+#include "mongo/db/pipeline/document_source_queue.h"
+#include "mongo/db/pipeline/document_source_redact.h"
+#include "mongo/db/pipeline/document_source_sample.h"
+#include "mongo/db/pipeline/document_source_sample_from_random_cursor.h"
+#include "mongo/db/pipeline/document_source_sequential_document_cache.h"
+#include "mongo/db/pipeline/document_source_single_document_transformation.h"
+#include "mongo/db/pipeline/document_source_skip.h"
+#include "mongo/db/pipeline/document_source_sort.h"
+#include "mongo/db/pipeline/document_source_tee_consumer.h"
+#include "mongo/db/pipeline/document_source_union_with.h"
+#include "mongo/db/pipeline/document_source_unwind.h"
+#include "mongo/s/query/document_source_merge_cursors.h"
+
+namespace mongo {
+
+template <class T>
+bool DocumentSourceWalker::visitHelper(const DocumentSource* source) {
+ const T* concrete = dynamic_cast<const T*>(source);
+ if (concrete == nullptr) {
+ return false;
+ }
+
+ _postVisitor->visit(concrete);
+ return true;
+}
+
+void DocumentSourceWalker::walk(const Pipeline& pipeline) {
+ const Pipeline::SourceContainer& sources = pipeline.getSources();
+
+ if (_postVisitor != nullptr) {
+ for (auto it = sources.begin(); it != sources.end(); it++) {
+ // TODO: use the acceptVisitor() method once DocumentSources gain the ability to
+ // accept visitors, e.g.:
+ //     source->acceptVisitor(*_preVisitor);
+ //
+ // For now we use a crutch walker which performs a series of dynamic casts. Some
+ // types are commented out because of dependency issues (e.g. they are not in the
+ // pipeline build target but in the query_exec target).
+ const DocumentSource* ds = it->get();
+ const bool visited = visitHelper<DocumentSourceBucketAuto>(ds) ||
+ visitHelper<DocumentSourceCollStats>(ds) ||
+ visitHelper<DocumentSourceCurrentOp>(ds) ||
+ // TODO: uncomment after fixing dependency
+ // visitHelper<DocumentSourceCursor>(ds) ||
+ visitHelper<DocumentSourceExchange>(ds) || visitHelper<DocumentSourceFacet>(ds) ||
+ visitHelper<DocumentSourceGeoNear>(ds) ||
+ // TODO: uncomment after fixing dependency
+ // visitHelper<DocumentSourceGeoNearCursor>(ds) ||
+ visitHelper<DocumentSourceGraphLookUp>(ds) ||
+ visitHelper<DocumentSourceGroup>(ds) || visitHelper<DocumentSourceIndexStats>(ds) ||
+ visitHelper<DocumentSourceInternalInhibitOptimization>(ds) ||
+ visitHelper<DocumentSourceInternalShardFilter>(ds) ||
+ visitHelper<DocumentSourceInternalSplitPipeline>(ds) ||
+ visitHelper<DocumentSourceLimit>(ds) ||
+ visitHelper<DocumentSourceListCachedAndActiveUsers>(ds) ||
+ visitHelper<DocumentSourceListLocalSessions>(ds) ||
+ visitHelper<DocumentSourceListSessions>(ds) ||
+ visitHelper<DocumentSourceLookUp>(ds) || visitHelper<DocumentSourceMatch>(ds) ||
+ visitHelper<DocumentSourceMerge>(ds) ||
+ // TODO: uncomment after fixing dependency
+ // visitHelper<DocumentSourceMergeCursors>(ds) ||
+ visitHelper<DocumentSourceOperationMetrics>(ds) ||
+ visitHelper<DocumentSourceOut>(ds) ||
+ visitHelper<DocumentSourcePlanCacheStats>(ds) ||
+ visitHelper<DocumentSourceQueue>(ds) || visitHelper<DocumentSourceRedact>(ds) ||
+ visitHelper<DocumentSourceSample>(ds) ||
+ visitHelper<DocumentSourceSampleFromRandomCursor>(ds) ||
+ visitHelper<DocumentSourceSequentialDocumentCache>(ds) ||
+ visitHelper<DocumentSourceSingleDocumentTransformation>(ds) ||
+ visitHelper<DocumentSourceSkip>(ds) || visitHelper<DocumentSourceSort>(ds) ||
+ visitHelper<DocumentSourceTeeConsumer>(ds) ||
+ visitHelper<DocumentSourceUnionWith>(ds) || visitHelper<DocumentSourceUnwind>(ds)
+ // TODO: uncomment after fixing dependency
+ // || visitHelper<DocumentSourceUpdateOnAddShard>(ds)
+ ;
+
+ if (!visited) {
+ uasserted(ErrorCodes::InternalErrorNotSupported,
+ str::stream() << "Stage is not supported: " << ds->getSourceName());
+ }
+ }
+ }
+
+ // TODO: reverse for pre-visitor
+}
+
+} // namespace mongo
diff --git a/src/mongo/db/pipeline/visitors/document_source_walker.h b/src/mongo/db/pipeline/visitors/document_source_walker.h
new file mode 100644
index 00000000000..cb4a844de4d
--- /dev/null
+++ b/src/mongo/db/pipeline/visitors/document_source_walker.h
@@ -0,0 +1,58 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include "mongo/db/pipeline/document_source.h"
+#include "mongo/db/pipeline/pipeline.h"
+#include "mongo/db/pipeline/visitors/document_source_visitor.h"
+
+namespace mongo {
+
+/**
+ * A document source walker.
+ * TODO: SERVER-62657. Implement a hash-table based resolution instead of sequential dynamic casts.
+ */
+class DocumentSourceWalker final {
+public:
+ DocumentSourceWalker(DocumentSourceConstVisitor* preVisitor,
+ DocumentSourceConstVisitor* postVisitor)
+ : _preVisitor{preVisitor}, _postVisitor{postVisitor} {}
+
+ void walk(const Pipeline& pipeline);
+
+private:
+ template <class T>
+ bool visitHelper(const DocumentSource* source);
+
+ DocumentSourceConstVisitor* _preVisitor;
+ DocumentSourceConstVisitor* _postVisitor;
+};
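+
+// A minimal usage sketch ('MyConstVisitor' is a hypothetical concrete
+// DocumentSourceConstVisitor implementation; it is not part of this commit):
+//
+//     MyConstVisitor visitor;
+//     DocumentSourceWalker walker(nullptr /*preVisitor*/, &visitor /*postVisitor*/);
+//     walker.walk(*pipeline);
+//
+// Note that walk() currently only consults the post-visitor; see the TODO in
+// document_source_walker.cpp regarding pre-visitor traversal.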
+
+} // namespace mongo
diff --git a/src/mongo/db/pipeline/visitors/transformer_interface_visitor.h b/src/mongo/db/pipeline/visitors/transformer_interface_visitor.h
new file mode 100644
index 00000000000..e81075e8e9f
--- /dev/null
+++ b/src/mongo/db/pipeline/visitors/transformer_interface_visitor.h
@@ -0,0 +1,66 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include "mongo/db/query/tree_walker.h"
+
+namespace mongo {
+
+namespace projection_executor {
+class AddFieldsProjectionExecutor;
+class ExclusionProjectionExecutor;
+class InclusionProjectionExecutor;
+} // namespace projection_executor
+class GroupFromFirstDocumentTransformation;
+class ReplaceRootTransformation;
+
+/**
+ * Visitor pattern for Transformer Interface instances
+ */
+template <bool IsConst = false>
+class TransformerInterfaceVisitor {
+public:
+ virtual void visit(
+ tree_walker::MaybeConstPtr<IsConst, projection_executor::AddFieldsProjectionExecutor>
+ visitor) = 0;
+ virtual void visit(
+ tree_walker::MaybeConstPtr<IsConst, projection_executor::ExclusionProjectionExecutor>
+ visitor) = 0;
+ virtual void visit(
+ tree_walker::MaybeConstPtr<IsConst, projection_executor::InclusionProjectionExecutor>
+ visitor) = 0;
+ virtual void visit(
+ tree_walker::MaybeConstPtr<IsConst, GroupFromFirstDocumentTransformation> visitor) = 0;
+ virtual void visit(tree_walker::MaybeConstPtr<IsConst, ReplaceRootTransformation> visitor) = 0;
+};
+
+using TransformerInterfaceMutableVisitor = TransformerInterfaceVisitor<false>;
+using TransformerInterfaceConstVisitor = TransformerInterfaceVisitor<true>;
+} // namespace mongo
diff --git a/src/mongo/db/pipeline/visitors/transformer_interface_walker.cpp b/src/mongo/db/pipeline/visitors/transformer_interface_walker.cpp
new file mode 100644
index 00000000000..11555f764b9
--- /dev/null
+++ b/src/mongo/db/pipeline/visitors/transformer_interface_walker.cpp
@@ -0,0 +1,72 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/db/pipeline/visitors/transformer_interface_walker.h"
+#include "mongo/db/exec/add_fields_projection_executor.h"
+#include "mongo/db/exec/exclusion_projection_executor.h"
+#include "mongo/db/exec/inclusion_projection_executor.h"
+#include "mongo/db/pipeline/document_source_group.h"
+#include "mongo/db/pipeline/document_source_replace_root.h"
+
+namespace mongo {
+
+TransformerInterfaceWalker::TransformerInterfaceWalker(TransformerInterfaceConstVisitor* visitor)
+ : _visitor(visitor) {}
+
+void TransformerInterfaceWalker::walk(const TransformerInterface* transformer) {
+ switch (transformer->getType()) {
+ case TransformerInterface::TransformerType::kExclusionProjection:
+ _visitor->visit(
+ static_cast<const projection_executor::ExclusionProjectionExecutor*>(transformer));
+ break;
+
+ case TransformerInterface::TransformerType::kInclusionProjection:
+ _visitor->visit(
+ static_cast<const projection_executor::InclusionProjectionExecutor*>(transformer));
+ break;
+
+ case TransformerInterface::TransformerType::kComputedProjection:
+ _visitor->visit(
+ static_cast<const projection_executor::AddFieldsProjectionExecutor*>(transformer));
+ break;
+
+ case TransformerInterface::TransformerType::kReplaceRoot:
+ _visitor->visit(static_cast<const ReplaceRootTransformation*>(transformer));
+ break;
+
+ case TransformerInterface::TransformerType::kGroupFromFirstDocument:
+ _visitor->visit(static_cast<const GroupFromFirstDocumentTransformation*>(transformer));
+ break;
+
+ default:
+ MONGO_UNREACHABLE;
+ }
+}
+
+} // namespace mongo
diff --git a/src/mongo/db/pipeline/visitors/transformer_interface_walker.h b/src/mongo/db/pipeline/visitors/transformer_interface_walker.h
new file mode 100644
index 00000000000..32c8248e00e
--- /dev/null
+++ b/src/mongo/db/pipeline/visitors/transformer_interface_walker.h
@@ -0,0 +1,48 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include "mongo/db/pipeline/transformer_interface.h"
+#include "mongo/db/pipeline/visitors/transformer_interface_visitor.h"
+
+namespace mongo {
+
+class TransformerInterfaceWalker final {
+public:
+ TransformerInterfaceWalker(TransformerInterfaceConstVisitor* visitor);
+
+ void walk(const TransformerInterface* transformer);
+
+private:
+ TransformerInterfaceConstVisitor* _visitor;
+};
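+
+// A minimal usage sketch ('MyTransformerVisitor' is a hypothetical concrete
+// TransformerInterfaceConstVisitor; 'transformer' is any const TransformerInterface*
+// obtained from the calling context):
+//
+//     MyTransformerVisitor visitor;
+//     TransformerInterfaceWalker walker(&visitor);
+//     walker.walk(transformer);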
+
+} // namespace mongo
diff --git a/src/mongo/db/query/SConscript b/src/mongo/db/query/SConscript
index 67b992fe02a..77d2fabc948 100644
--- a/src/mongo/db/query/SConscript
+++ b/src/mongo/db/query/SConscript
@@ -6,8 +6,10 @@ env = env.Clone()
env.SConscript(
dirs=[
+ "ce",
"collation",
"datetime",
+ "optimizer",
],
exports=[
'env'
diff --git a/src/mongo/db/query/ce/SConscript b/src/mongo/db/query/ce/SConscript
new file mode 100644
index 00000000000..8ab2ca62f51
--- /dev/null
+++ b/src/mongo/db/query/ce/SConscript
@@ -0,0 +1,16 @@
+# -*- mode: python -*-
+
+Import("env")
+
+env = env.Clone()
+
+env.Library(
+ target="query_ce",
+ source=[
+ 'ce_sampling.cpp',
+ ],
+ LIBDEPS_PRIVATE=[
+ '$BUILD_DIR/mongo/db/exec/sbe/query_sbe_abt',
+ '$BUILD_DIR/mongo/db/query/optimizer/optimizer',
+ ]
+)
\ No newline at end of file
diff --git a/src/mongo/db/query/ce/ce_sampling.cpp b/src/mongo/db/query/ce/ce_sampling.cpp
new file mode 100644
index 00000000000..e95b836bd57
--- /dev/null
+++ b/src/mongo/db/query/ce/ce_sampling.cpp
@@ -0,0 +1,290 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/db/query/ce/ce_sampling.h"
+
+#include "mongo/db/exec/sbe/abt/abt_lower.h"
+#include "mongo/db/query/optimizer/cascades/ce_heuristic.h"
+#include "mongo/db/query/optimizer/explain.h"
+#include "mongo/db/query/optimizer/utils/abt_hash.h"
+#include "mongo/db/query/optimizer/utils/memo_utils.h"
+
+namespace mongo::optimizer::cascades {
+
+using namespace properties;
+
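+/**
+ * Extracts a plan suitable for sampling from a logical plan in the memo: scan nodes
+ * are capped with a LimitSkip of the sample size, Filter and Sargable nodes are
+ * skipped over, and Evaluation nodes are kept (the sampled predicates may refer to
+ * the projections they define).
+ */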
+class SamplingPlanExtractor {
+public:
+ SamplingPlanExtractor(const Memo& memo, const size_t sampleSize)
+ : _memo(memo), _sampleSize(sampleSize) {}
+
+ void transport(ABT& n, const MemoLogicalDelegatorNode& node) {
+ n = extract(_memo.getGroup(node.getGroupId())._logicalNodes.at(0));
+ }
+
+ void transport(ABT& n, const ScanNode& /*node*/, ABT& /*binder*/) {
+ // We will lower the scan node in a sampling context here.
+ // TODO: for now just return the documents in random order.
+ n = make<LimitSkipNode>(LimitSkipRequirement(_sampleSize, 0), std::move(n));
+ }
+
+ void transport(ABT& n, const FilterNode& /*node*/, ABT& childResult, ABT& /*exprResult*/) {
+ // Skip over filters.
+ n = childResult;
+ }
+
+ void transport(ABT& /*n*/,
+ const EvaluationNode& /*node*/,
+ ABT& /*childResult*/,
+ ABT& /*exprResult*/) {
+ // Keep Eval nodes.
+ }
+
+ void transport(
+ ABT& n, const SargableNode& /*node*/, ABT& childResult, ABT& /*refs*/, ABT& /*binds*/) {
+ // Skip over sargable nodes.
+ n = childResult;
+ }
+
+ template <typename T, typename... Ts>
+ void transport(ABT& /*n*/, const T& /*node*/, Ts&&...) {
+ if constexpr (std::is_base_of_v<Node, T>) {
+ uasserted(6624242, "Should not be seeing other types of nodes here.");
+ }
+ }
+
+ ABT extract(ABT node) {
+ algebra::transport<true>(node, *this);
+ return node;
+ }
+
+private:
+ const Memo& _memo;
+ const size_t _sampleSize;
+};
+
+class CESamplingTransportImpl {
+ static constexpr size_t kMaxSampleSize = 1000;
+
+public:
+ CESamplingTransportImpl(OperationContext* opCtx,
+ OptPhaseManager& phaseManager,
+ const int64_t numRecords)
+ : _opCtx(opCtx),
+ _phaseManager(phaseManager),
+ _heuristicCE(),
+ _sampleSize(std::min<int64_t>(numRecords, kMaxSampleSize)) {}
+
+ CEType transport(const ABT& n,
+ const FilterNode& node,
+ const Memo& memo,
+ const LogicalProps& logicalProps,
+ CEType childResult,
+ CEType /*exprResult*/) {
+ if (!hasProperty<IndexingAvailability>(logicalProps)) {
+ return _heuristicCE.deriveCE(memo, logicalProps, n.ref());
+ }
+
+ SamplingPlanExtractor planExtractor(memo, _sampleSize);
+ // Create a plan with all eval nodes so far and the filter last.
+ ABT abtTree = make<FilterNode>(node.getFilter(), planExtractor.extract(n));
+
+ return estimateFilterCE(memo, logicalProps, n, std::move(abtTree), childResult);
+ }
+
+ CEType transport(const ABT& n,
+ const SargableNode& node,
+ const Memo& memo,
+ const LogicalProps& logicalProps,
+ CEType childResult,
+ CEType /*bindResult*/,
+ CEType /*refsResult*/) {
+ if (!hasProperty<IndexingAvailability>(logicalProps)) {
+ return _heuristicCE.deriveCE(memo, logicalProps, n.ref());
+ }
+
+ SamplingPlanExtractor planExtractor(memo, _sampleSize);
+ ABT extracted = planExtractor.extract(n);
+
+ // Estimate individual requirements separately by potentially re-using cached results.
+ // Here we assume that each requirement is independent.
+ // TODO: consider estimating together the entire set of requirements (but caching!)
+ CEType result = childResult;
+ for (const auto& [key, req] : node.getReqMap()) {
+ if (!isIntervalReqFullyOpenDNF(req.getIntervals())) {
+ ABT lowered = extracted;
+ lowerPartialSchemaRequirement(key, req, lowered);
+ uassert(6624243, "Expected a filter node", lowered.is<FilterNode>());
+ result = estimateFilterCE(memo, logicalProps, n, std::move(lowered), result);
+ }
+ }
+
+ return result;
+ }
+
+ /**
+ * Other ABT types.
+ */
+ template <typename T, typename... Ts>
+ CEType transport(const ABT& n,
+ const T& /*node*/,
+ const Memo& memo,
+ const LogicalProps& logicalProps,
+ Ts&&...) {
+ if (canBeLogicalNode<T>()) {
+ return _heuristicCE.deriveCE(memo, logicalProps, n.ref());
+ }
+ return 0.0;
+ }
+
+ CEType derive(const Memo& memo,
+ const properties::LogicalProps& logicalProps,
+ const ABT::reference_type logicalNodeRef) {
+ return algebra::transport<true>(logicalNodeRef, *this, memo, logicalProps);
+ }
+
+private:
+ CEType estimateFilterCE(const Memo& memo,
+ const LogicalProps& logicalProps,
+ const ABT& n,
+ ABT abtTree,
+ CEType childResult) {
+ auto it = _selectivityCacheMap.find(abtTree);
+ if (it != _selectivityCacheMap.cend()) {
+ // Cache hit.
+ return it->second * childResult;
+ }
+
+ const auto [success, selectivity] = estimateSelectivity(abtTree);
+ if (!success) {
+ return _heuristicCE.deriveCE(memo, logicalProps, n.ref());
+ }
+
+ _selectivityCacheMap.emplace(std::move(abtTree), selectivity);
+ std::cerr << "Sampling sel.: " << selectivity << "\n";
+ return selectivity * childResult;
+ }
+
+ std::pair<bool, SelectivityType> estimateSelectivity(ABT abtTree) {
+ // Add a group by to count number of documents.
+ const ProjectionName sampleSumProjection = "sum";
+ abtTree =
+ make<GroupByNode>(ProjectionNameVector{},
+ ProjectionNameVector{sampleSumProjection},
+ makeSeq(make<FunctionCall>("$sum", makeSeq(Constant::int64(1)))),
+ std::move(abtTree));
+ abtTree = make<RootNode>(
+ properties::ProjectionRequirement{ProjectionNameVector{sampleSumProjection}},
+ std::move(abtTree));
+
+ std::cerr << "********* Sampling ABT *********\n";
+ std::cerr << ExplainGenerator::explainV2(abtTree);
+ std::cerr << "********* Sampling ABT *********\n";
+
+ if (!_phaseManager.optimize(abtTree)) {
+ return {false, {}};
+ }
+
+ auto env = VariableEnvironment::build(abtTree);
+ SlotVarMap slotMap;
+ sbe::value::SlotIdGenerator ids;
+ SBENodeLowering g{env,
+ slotMap,
+ ids,
+ _phaseManager.getMetadata(),
+ _phaseManager.getNodeToGroupPropsMap(),
+ _phaseManager.getRIDProjections(),
+ true /*randomScan*/};
+ auto sbePlan = g.optimize(abtTree);
+
+ // TODO: return errors instead of exceptions?
+ uassert(6624244, "Lowering failed", sbePlan != nullptr);
+ uassert(6624245, "Invalid slot map size", slotMap.size() == 1);
+
+ sbePlan->attachToOperationContext(_opCtx);
+ sbe::CompileCtx ctx(std::make_unique<sbe::RuntimeEnvironment>());
+ sbePlan->prepare(ctx);
+
+ std::vector<sbe::value::SlotAccessor*> accessors;
+ for (auto& [name, slot] : slotMap) {
+ accessors.emplace_back(sbePlan->getAccessor(ctx, slot));
+ }
+
+ sbePlan->open(false);
+ ON_BLOCK_EXIT([&] { sbePlan->close(); });
+
+ while (sbePlan->getNext() != sbe::PlanState::IS_EOF) {
+ const auto [tag, value] = accessors.at(0)->getViewOfValue();
+ if (tag == sbe::value::TypeTags::NumberInt64) {
+ // TODO: check if we get exactly one result from the groupby?
+ return {true, static_cast<double>(value) / _sampleSize};
+ }
+ return {false, {}};
+ }
+
+ // If nothing passes the filter, estimate 0.0 selectivity. HashGroup will return 0 results.
+ return {true, 0.0};
+ }
+
+ struct NodeRefHash {
+ size_t operator()(const ABT& node) const {
+ return ABTHashGenerator::generate(node);
+ }
+ };
+
+ struct NodeRefCompare {
+ bool operator()(const ABT& left, const ABT& right) const {
+ return left == right;
+ }
+ };
+
+ // Cache a logical node reference to computed selectivity. Used for Filter and Sargable nodes.
+ opt::unordered_map<ABT, SelectivityType, NodeRefHash, NodeRefCompare> _selectivityCacheMap;
+
+ // We don't own those.
+ OperationContext* _opCtx;
+ OptPhaseManager& _phaseManager;
+
+ HeuristicCE _heuristicCE;
+ const int64_t _sampleSize;
+};
+
+CESamplingTransport::CESamplingTransport(OperationContext* opCtx,
+ OptPhaseManager& phaseManager,
+ const int64_t numRecords)
+ : _impl(std::make_unique<CESamplingTransportImpl>(opCtx, phaseManager, numRecords)) {}
+
+CESamplingTransport::~CESamplingTransport() {}
+
+CEType CESamplingTransport::deriveCE(const Memo& memo,
+ const LogicalProps& logicalProps,
+ const ABT::reference_type logicalNodeRef) const {
+ return _impl->derive(memo, logicalProps, logicalNodeRef);
+}
+
+} // namespace mongo::optimizer::cascades
diff --git a/src/mongo/db/query/ce/ce_sampling.h b/src/mongo/db/query/ce/ce_sampling.h
new file mode 100644
index 00000000000..68eada866a1
--- /dev/null
+++ b/src/mongo/db/query/ce/ce_sampling.h
@@ -0,0 +1,52 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include "mongo/db/query/optimizer/cascades/interfaces.h"
+#include "mongo/db/query/optimizer/opt_phase_manager.h"
+
+namespace mongo::optimizer::cascades {
+
+class CESamplingTransportImpl;
+
+class CESamplingTransport : public CEInterface {
+public:
+ CESamplingTransport(OperationContext* opCtx, OptPhaseManager& phaseManager, int64_t numRecords);
+ ~CESamplingTransport();
+
+ CEType deriveCE(const Memo& memo,
+ const properties::LogicalProps& logicalProps,
+ ABT::reference_type logicalNodeRef) const final;
+
+private:
+ std::unique_ptr<CESamplingTransportImpl> _impl;
+};
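+
+// A construction sketch (names are illustrative; the OptPhaseManager and the
+// collection's record count come from the calling context):
+//
+//     CESamplingTransport ce(opCtx, phaseManager, numRecords);
+//     const CEType cardinality = ce.deriveCE(memo, logicalProps, nodeRef);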
+
+} // namespace mongo::optimizer::cascades
diff --git a/src/mongo/db/query/get_executor.cpp b/src/mongo/db/query/get_executor.cpp
index a1a0716d29e..cd61dfefe7c 100644
--- a/src/mongo/db/query/get_executor.cpp
+++ b/src/mongo/db/query/get_executor.cpp
@@ -1318,6 +1318,7 @@ StatusWith<std::unique_ptr<PlanExecutor, PlanExecutor::Deleter>> getSlotBasedExe
std::move(cq),
std::move(solutions[0]),
std::move(roots[0]),
+ {},
mainColl,
plannerOptions,
std::move(nss),
diff --git a/src/mongo/db/query/optimizer/SConscript b/src/mongo/db/query/optimizer/SConscript
new file mode 100644
index 00000000000..8843d549d8e
--- /dev/null
+++ b/src/mongo/db/query/optimizer/SConscript
@@ -0,0 +1,76 @@
+# -*- mode: python -*-
+
+Import("env")
+
+env = env.Clone()
+
+env.SConscript(
+ dirs=[
+ "algebra",
+ ],
+ exports=[
+ 'env'
+ ],
+)
+
+env.Library(
+ target="optimizer",
+ source=[
+ "cascades/ce_heuristic.cpp",
+ "cascades/ce_hinted.cpp",
+ "cascades/cost_derivation.cpp",
+ "cascades/enforcers.cpp",
+ "cascades/implementers.cpp",
+ "cascades/logical_props_derivation.cpp",
+ "cascades/logical_rewriter.cpp",
+ "cascades/memo.cpp",
+ "cascades/physical_rewriter.cpp",
+ "cascades/rewrite_queues.cpp",
+ "defs.cpp",
+ "explain.cpp",
+ "index_bounds.cpp",
+ "metadata.cpp",
+ "node.cpp",
+ "opt_phase_manager.cpp",
+ "props.cpp",
+ "reference_tracker.cpp",
+ "rewrites/const_eval.cpp",
+ "rewrites/path.cpp",
+ "rewrites/path_lower.cpp",
+ "syntax/expr.cpp",
+ "utils/abt_hash.cpp",
+ "utils/interval_utils.cpp",
+ "utils/memo_utils.cpp",
+ "utils/utils.cpp"
+ ],
+ LIBDEPS=[
+ "$BUILD_DIR/mongo/db/exec/sbe/query_sbe_values",
+ ],
+)
+
+env.Library(
+ target="unit_test_utils",
+ source=[
+ "utils/unit_test_utils.cpp",
+ ],
+ LIBDEPS=[
+ "$BUILD_DIR/mongo/db/pipeline/pipeline",
+ "$BUILD_DIR/mongo/db/query/query_test_service_context",
+ "$BUILD_DIR/mongo/unittest/unittest",
+ ],
+)
+
+env.CppUnitTest(
+ target='optimizer_test',
+ source=[
+ "logical_rewriter_optimizer_test.cpp",
+ "optimizer_test.cpp",
+ "physical_rewriter_optimizer_test.cpp",
+ "rewrites/path_optimizer_test.cpp",
+ "interval_intersection_test.cpp",
+ ],
+ LIBDEPS=[
+ "optimizer",
+ "unit_test_utils",
+ ]
+)
diff --git a/src/mongo/db/query/optimizer/algebra/SConscript b/src/mongo/db/query/optimizer/algebra/SConscript
new file mode 100644
index 00000000000..0d2a48c24d3
--- /dev/null
+++ b/src/mongo/db/query/optimizer/algebra/SConscript
@@ -0,0 +1,15 @@
+# -*- mode: python -*-
+
+Import("env")
+
+env = env.Clone()
+
+env.CppUnitTest(
+ target='algebra_test',
+ source=[
+ 'algebra_test.cpp',
+ ],
+ LIBDEPS=[
+
+ ]
+)
diff --git a/src/mongo/db/query/optimizer/algebra/algebra_test.cpp b/src/mongo/db/query/optimizer/algebra/algebra_test.cpp
new file mode 100644
index 00000000000..48e668a6e32
--- /dev/null
+++ b/src/mongo/db/query/optimizer/algebra/algebra_test.cpp
@@ -0,0 +1,570 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/db/query/optimizer/algebra/operator.h"
+#include "mongo/db/query/optimizer/algebra/polyvalue.h"
+#include "mongo/unittest/unittest.h"
+
+namespace mongo::optimizer::algebra {
+
+namespace {
+
+class Leaf;
+class BinaryNode;
+class NaryNode;
+class AtLeastBinaryNode;
+using Tree = PolyValue<Leaf, BinaryNode, NaryNode, AtLeastBinaryNode>;
+
+/**
+ * A leaf in the tree. Just contains data - in this case a double.
+ */
+class Leaf : public OpSpecificArity<Tree, Leaf, 0> {
+public:
+ Leaf(double x) : x(x) {}
+ double x;
+};
+
+/**
+ * An inner node in the tree with exactly two children.
+ */
+class BinaryNode : public OpSpecificArity<Tree, BinaryNode, 2> {
+public:
+ BinaryNode(Tree left, Tree right)
+ : OpSpecificArity<Tree, BinaryNode, 2>(std::move(left), std::move(right)) {}
+};
+
+/**
+ * An inner node in the tree with any number of children, zero or greater.
+ */
+class NaryNode : public OpSpecificDynamicArity<Tree, NaryNode, 0> {
+public:
+ NaryNode(std::vector<Tree> children)
+ : OpSpecificDynamicArity<Tree, NaryNode, 0>(std::move(children)) {}
+};
+
+/**
+ * An inner node in the tree with 2 or more nodes.
+ */
+class AtLeastBinaryNode : public OpSpecificDynamicArity<Tree, AtLeastBinaryNode, 2> {
+public:
+ /**
+ * Notice that the two required children are given as arguments separate from the vector.
+ */
+ AtLeastBinaryNode(std::vector<Tree> children, Tree left, Tree right)
+ : OpSpecificDynamicArity<Tree, AtLeastBinaryNode, 2>(
+ std::move(children), std::move(left), std::move(right)) {}
+};
+
+/**
+ * A visitor of the tree with methods to visit each kind of node.
+ *
+ * This is a very basic visitor to just demonstrate the transport() API - all it does is sum up
+ * doubles in the leaf nodes of the tree.
+ *
+ * Notice that each kind of node did not need to fill out some boilerplate "visit()" method or
+ * anything like that. The PolyValue templating magic took care of all the boilerplate for us, and
+ * the operator classes (e.g. OpSpecificArity) expose the tree structure and children.
+ */
+class NodeTransporter {
+public:
+ double transport(Leaf& leaf) {
+ return leaf.x;
+ }
+ double transport(BinaryNode& node, double child0, double child1) {
+ return child0 + child1;
+ }
+ double transport(NaryNode& node, std::vector<double> children) {
+ return std::accumulate(children.begin(), children.end(), 0.0);
+ }
+ double transport(AtLeastBinaryNode& node,
+ std::vector<double> children,
+ double child0,
+ double child1) {
+ return child0 + child1 + std::accumulate(children.begin(), children.end(), 0.0);
+ }
+};
+
+/**
+ * A visitor of the tree with methods to visit each kind of node. This visitor also takes a
+ * reference to the Tree itself. Unused here, this reference can be used to mutate or replace the
+ * node itself while the walking takes place.
+ */
+class TreeTransporter {
+public:
+ double transport(Tree& tree, Leaf& leaf) {
+ return leaf.x;
+ }
+ double transport(Tree& tree, BinaryNode& node, double child0, double child1) {
+ return child0 + child1;
+ }
+ double transport(Tree& tree, NaryNode& node, std::vector<double> children) {
+ return std::accumulate(children.begin(), children.end(), 0.0);
+ }
+ double transport(Tree& tree,
+ AtLeastBinaryNode& node,
+ std::vector<double> children,
+ double child0,
+ double child1) {
+ return child0 + child1 + std::accumulate(children.begin(), children.end(), 0.0);
+ }
+};
+
+TEST(PolyValueTest, SumTransportFixedArity) {
+ NodeTransporter nodeTransporter;
+ TreeTransporter treeTransporter;
+ {
+ Tree simple = Tree::make<BinaryNode>(Tree::make<Leaf>(2.0), Tree::make<Leaf>(1.0));
+ // Notice the template parameter true or false matches whether the walker expects to have a
+ // Tree& parameter first in the transport implementations.
+ double result = transport<false>(simple, nodeTransporter);
+ ASSERT_EQ(result, 3.0);
+ // This 'true' template means we expect the 'Tree&' argument to come first in all the
+ // 'transport()' implementations.
+ result = transport<true>(simple, treeTransporter);
+ ASSERT_EQ(result, 3.0);
+ }
+
+ {
+ Tree deeper = Tree::make<BinaryNode>(
+ Tree::make<BinaryNode>(Tree::make<Leaf>(2.0), Tree::make<Leaf>(1.0)),
+ Tree::make<BinaryNode>(Tree::make<Leaf>(2.0), Tree::make<Leaf>(1.0)));
+ double result = transport<false>(deeper, nodeTransporter);
+ ASSERT_EQ(result, 6.0);
+ result = transport<true>(deeper, treeTransporter);
+ ASSERT_EQ(result, 6.0);
+ }
+}
+
+/**
+ * Prove out that the walking/visiting can hit the variadic NaryNode.
+ */
+TEST(PolyValueTest, SumTransportVariadic) {
+ NodeTransporter nodeTransporter;
+ TreeTransporter treeTransporter;
+ Tree naryDemoTree = Tree::make<NaryNode>(
+ std::vector<Tree>{Tree::make<Leaf>(6.0),
+ Tree::make<Leaf>(5.0),
+ Tree::make<NaryNode>(std::vector<Tree>{
+ Tree::make<Leaf>(4.0), Tree::make<Leaf>(3.0), Tree::make<Leaf>(2.0)}),
+ Tree::make<Leaf>(1.0)});
+
+ double result = transport<false>(naryDemoTree, nodeTransporter);
+ ASSERT_EQ(result, 21.0);
+ result = transport<true>(naryDemoTree, treeTransporter);
+ ASSERT_EQ(result, 21.0);
+}
+
+TEST(PolyValueTest, SumTransportAtLeast2Children) {
+ NodeTransporter nodeTransporter;
+ TreeTransporter treeTransporter;
+ Tree demoTree = Tree::make<AtLeastBinaryNode>(
+ std::vector<Tree>{Tree::make<Leaf>(7.0), Tree::make<Leaf>(6.0)},
+ Tree::make<Leaf>(5.0),
+ Tree::make<AtLeastBinaryNode>(
+ std::vector<Tree>{Tree::make<Leaf>(4.0), Tree::make<Leaf>(3.0)},
+ Tree::make<Leaf>(2.0),
+ Tree::make<Leaf>(1.0)));
+ double result = transport<false>(demoTree, nodeTransporter);
+ ASSERT_EQ(result, 28.0);
+ result = transport<true>(demoTree, treeTransporter);
+ ASSERT_EQ(result, 28.0);
+}
+
+/**
+ * A visitor of the tree like those above but which takes const references so is forbidden from
+ * modifying the tree or nodes.
+ *
+ * This visitor creates a copy of the tree but with the values at the leaves doubled.
+ */
+class ConstTransporterCopyAndDouble {
+public:
+ Tree transport(const Leaf& leaf) {
+ return Tree::make<Leaf>(2 * leaf.x);
+ }
+ Tree transport(const BinaryNode& node, Tree child0, Tree child1) {
+ return Tree::make<BinaryNode>(std::move(child0), std::move(child1));
+ }
+ Tree transport(const NaryNode& node, std::vector<Tree> children) {
+ return Tree::make<NaryNode>(std::move(children));
+ }
+ Tree transport(const AtLeastBinaryNode& node,
+ std::vector<Tree> children,
+ Tree child0,
+ Tree child1) {
+ return Tree::make<AtLeastBinaryNode>(
+ std::move(children), std::move(child0), std::move(child1));
+ }
+
+ // Add all the same walkers with the optional 'tree' argument. Note this is also const.
+ Tree transport(const Tree& tree, const Leaf& leaf) {
+ return Tree::make<Leaf>(2 * leaf.x);
+ }
+ Tree transport(const Tree& tree, const BinaryNode& node, Tree child0, Tree child1) {
+ return Tree::make<BinaryNode>(std::move(child0), std::move(child1));
+ }
+ Tree transport(const Tree& tree, const NaryNode& node, std::vector<Tree> children) {
+ return Tree::make<NaryNode>(std::move(children));
+ }
+ Tree transport(const Tree& tree,
+ const AtLeastBinaryNode& node,
+ std::vector<Tree> children,
+ Tree child0,
+ Tree child1) {
+ return Tree::make<AtLeastBinaryNode>(
+ std::move(children), std::move(child0), std::move(child1));
+ }
+};
+
+TEST(PolyValueTest, CopyAndDoubleTreeConst) {
+ // Test that we can create a copy of a tree and walk with a const transporter to provide extra
+ // proof that it's actually a deep copy.
+ ConstTransporterCopyAndDouble transporter;
+ {
+ const Tree simple = Tree::make<BinaryNode>(Tree::make<Leaf>(2.0), Tree::make<Leaf>(1.0));
+ // Notice 'simple' is const.
+ Tree result = transport<false>(simple, transporter);
+ BinaryNode* newRoot = result.cast<BinaryNode>();
+ ASSERT(newRoot);
+ Leaf* newLeafLeft = newRoot->get<0>().cast<Leaf>();
+ ASSERT(newLeafLeft);
+ ASSERT_EQ(newLeafLeft->x, 4.0);
+
+ Leaf* newLeafRight = newRoot->get<1>().cast<Leaf>();
+ ASSERT(newLeafRight);
+ ASSERT_EQ(newLeafRight->x, 2.0);
+ }
+ {
+ // Do the same test but walk with the tree reference (pass 'true' to transport).
+ const Tree simple = Tree::make<BinaryNode>(Tree::make<Leaf>(2.0), Tree::make<Leaf>(1.0));
+ // Notice 'simple' is const.
+ Tree result = transport<true>(simple, transporter);
+ BinaryNode* newRoot = result.cast<BinaryNode>();
+ ASSERT(newRoot);
+ Leaf* newLeafLeft = newRoot->get<0>().cast<Leaf>();
+ ASSERT(newLeafLeft);
+ ASSERT_EQ(newLeafLeft->x, 4.0);
+
+ Leaf* newLeafRight = newRoot->get<1>().cast<Leaf>();
+ ASSERT(newLeafRight);
+ ASSERT_EQ(newLeafRight->x, 2.0);
+ }
+}
+
+/**
+ * A walker which accumulates all nodes into a std::set to demonstrate which nodes are visited.
+ *
+ * The order of visitation is not guaranteed, except that we visit "bottom-up", so leaves are
+ * visited before their parents. This must be true since the API to visit a node depends on the
+ * results of its children being pre-computed.
+ */
+class AccumulateToSetTransporter {
+public:
+ std::set<double> transport(Leaf& leaf) {
+ return {leaf.x};
+ }
+
+ std::set<double> transport(BinaryNode& node,
+ std::set<double> visitedChild0,
+ std::set<double> visitedChild1) {
+ // 'visitedChild0' and 'visitedChild1' represent the accumulated results of their visited
+ // numbers. Here we just merge the two.
+ std::set<double> merged;
+ std::merge(visitedChild0.begin(),
+ visitedChild0.end(),
+ visitedChild1.begin(),
+ visitedChild1.end(),
+ std::inserter(merged, merged.begin()));
+ return merged;
+ }
+
+ std::set<double> transport(NaryNode& node, std::vector<std::set<double>> childrenVisitedSets) {
+ return std::accumulate(childrenVisitedSets.begin(),
+ childrenVisitedSets.end(),
+ std::set<double>{},
+ [](auto&& visited1, auto&& visited2) {
+ std::set<double> merged;
+ std::merge(visited1.begin(),
+ visited1.end(),
+ visited2.begin(),
+ visited2.end(),
+ std::inserter(merged, merged.begin()));
+ return merged;
+ });
+ }
+
+ std::set<double> transport(AtLeastBinaryNode& node,
+ std::vector<std::set<double>> childrenVisitedSets,
+ std::set<double> visitedChild0,
+ std::set<double> visitedChild1) {
+ std::set<double> merged;
+ std::merge(visitedChild0.begin(),
+ visitedChild0.end(),
+ visitedChild1.begin(),
+ visitedChild1.end(),
+ std::inserter(merged, merged.begin()));
+
+ return std::accumulate(childrenVisitedSets.begin(),
+ childrenVisitedSets.end(),
+ merged,
+ [](auto&& visited1, auto&& visited2) {
+ std::set<double> merged;
+ std::merge(visited1.begin(),
+ visited1.end(),
+ visited2.begin(),
+ visited2.end(),
+ std::inserter(merged, merged.begin()));
+ return merged;
+ });
+ }
+};
+
+/**
+ * Here we see a test which walks all the various types of nodes at once, and in this case
+ * accumulates into a std::set any visited leaves.
+ */
+TEST(PolyValueTest, AccumulateAllDoubles) {
+ AccumulateToSetTransporter nodeTransporter;
+
+ {
+ Tree simple = Tree::make<AtLeastBinaryNode>(
+ std::vector<Tree>{Tree::make<Leaf>(4.0)},
+ Tree::make<Leaf>(3.0),
+ Tree::make<BinaryNode>(Tree::make<Leaf>(2.0),
+ Tree::make<NaryNode>(std::vector<Tree>{Tree::make<Leaf>(1.0)})));
+ std::set<double> result = transport<false>(simple, nodeTransporter);
+ ASSERT_EQ(result.size(), 4UL);
+ ASSERT_EQ(result.count(1.0), 1UL);
+ ASSERT_EQ(result.count(2.0), 1UL);
+ ASSERT_EQ(result.count(3.0), 1UL);
+ ASSERT_EQ(result.count(4.0), 1UL);
+ }
+ {
+ Tree complex = Tree::make<AtLeastBinaryNode>(
+ std::vector<Tree>{Tree::make<Leaf>(1.0), Tree::make<Leaf>(2.0)},
+ Tree::make<Leaf>(3.0),
+ Tree::make<BinaryNode>(
+ Tree::make<Leaf>(4.0),
+ Tree::make<NaryNode>(std::vector<Tree>{
+ Tree::make<Leaf>(5.0),
+ Tree::make<BinaryNode>(Tree::make<Leaf>(6.0), Tree::make<Leaf>(7.0))})));
+ std::set<double> result = transport<false>(complex, nodeTransporter);
+ ASSERT_EQ(result.size(), 7UL);
+ ASSERT_EQ(result.count(1.0), 1UL);
+ ASSERT_EQ(result.count(2.0), 1UL);
+ ASSERT_EQ(result.count(3.0), 1UL);
+ ASSERT_EQ(result.count(4.0), 1UL);
+ ASSERT_EQ(result.count(5.0), 1UL);
+ ASSERT_EQ(result.count(6.0), 1UL);
+ ASSERT_EQ(result.count(7.0), 1UL);
+ }
+}
+
+
+/**
+ * A walker which accepts an extra 'multiplier' argument to each transport call.
+ */
+class NodeTransporterWithExtraArg {
+public:
+ double transport(Leaf& leaf, double multiplier) {
+ return leaf.x * multiplier;
+ }
+ double transport(BinaryNode& node, double multiplier, double child0, double child1) {
+ // No need to apply the multiplier here; it was already applied in the children.
+ return child0 + child1;
+ }
+ double transport(NaryNode& node, double multiplier, std::vector<double> children) {
+ return std::accumulate(children.begin(), children.end(), 0.0);
+ }
+ double transport(AtLeastBinaryNode& node,
+ double multiplier,
+ std::vector<double> children,
+ double child0,
+ double child1) {
+ return child0 + child1 + std::accumulate(children.begin(), children.end(), 0.0);
+ }
+};
+
+TEST(PolyValueTest, TransporterWithAnExtraArgument) {
+ NodeTransporterWithExtraArg nodeTransporter;
+
+ Tree simple = Tree::make<AtLeastBinaryNode>(
+ std::vector<Tree>{Tree::make<Leaf>(4.0)},
+ Tree::make<Leaf>(3.0),
+ Tree::make<BinaryNode>(Tree::make<Leaf>(2.0),
+ Tree::make<NaryNode>(std::vector<Tree>{Tree::make<Leaf>(1.0)})));
+ double result = transport<false>(simple, nodeTransporter, 2.0);
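+ // Each leaf is doubled by the multiplier: (4 + 3 + 2 + 1) * 2 = 20.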
+ ASSERT_EQ(result, 20.0);
+}
+
+/**
+ * A simple walker which tracks whether it has seen a zero. While the task is simple, this walker
+ * demonstrates:
+ * - A walker with state attached ('iHaveSeenAZero'). Note it could be done without tracking state
+ * also.
+ * - The capability of 'transport' to return void.
+ * - You can add a templated 'transport()' to avoid needing to fill in each and every instantiation
+ * for the PolyValue.
+ */
+class TemplatedNodeTransporterWithContext {
+public:
+ bool iHaveSeenAZero = false;
+
+ void transport(Leaf& leaf) {
+ if (leaf.x == 0.0) {
+ iHaveSeenAZero = true;
+ }
+ }
+
+ /**
+ * Template to handle all other cases - we don't care or need to do anything here, so we knock
+ * out all the other required implementations at once with this template.
+ */
+ template <typename T, typename... Args>
+ void transport(T&& node, Args&&... args) {
+ return;
+ }
+};
+
+TEST(PolyValueTest, TransporterTrackingState) {
+ TemplatedNodeTransporterWithContext templatedNodeTransporter;
+
+ Tree noZero = Tree::make<AtLeastBinaryNode>(
+ std::vector<Tree>{Tree::make<Leaf>(4.0)},
+ Tree::make<Leaf>(3.0),
+ Tree::make<BinaryNode>(Tree::make<Leaf>(2.0),
+ Tree::make<NaryNode>(std::vector<Tree>{Tree::make<Leaf>(1.0)})));
+ transport<false>(noZero, templatedNodeTransporter);
+ ASSERT_EQ(templatedNodeTransporter.iHaveSeenAZero, false);
+
+ Tree yesZero = Tree::make<AtLeastBinaryNode>(
+ std::vector<Tree>{Tree::make<Leaf>(3.0)},
+ Tree::make<Leaf>(2.0),
+ Tree::make<BinaryNode>(Tree::make<Leaf>(1.0),
+ Tree::make<NaryNode>(std::vector<Tree>{Tree::make<Leaf>(0.0)})));
+ transport<false>(yesZero, templatedNodeTransporter);
+ ASSERT_EQ(templatedNodeTransporter.iHaveSeenAZero, true);
+}
+
+/**
+ * A walker demonstrating the 'prepare()' API which tracks the depth and weights things deeper in
+ * the tree at factors of 10 higher. So the top level is worth 1x, second level 10x, third level
+ * 100x, etc.
+ */
+class NodeTransporterTrackingDepth {
+ int _depthMultiplier = 1;
+
+public:
+ double transport(Leaf& leaf) {
+ return leaf.x * _depthMultiplier;
+ }
+
+ void prepare(Leaf&) {
+ // Noop. Just here to prevent yet another 10x multiplication if we were to fall into
+ // the generic 'prepare()'.
+ }
+
+ /**
+ * 'prepare()' is called as we descend the tree before we walk/visit the children.
+ */
+ template <typename T, typename... Args>
+ void prepare(T&& node, Args&&... args) {
+ _depthMultiplier *= 10;
+ }
+
+ double transport(BinaryNode& node, double child0, double child1) {
+ _depthMultiplier /= 10;
+ return child0 + child1;
+ }
+ double transport(NaryNode& node, std::vector<double> children) {
+ _depthMultiplier /= 10;
+ return std::accumulate(children.begin(), children.end(), 0.0);
+ }
+ double transport(AtLeastBinaryNode& node,
+ std::vector<double> children,
+ double child0,
+ double child1) {
+ _depthMultiplier /= 10;
+ return child0 + child1 + std::accumulate(children.begin(), children.end(), 0.0);
+ }
+};
+
+TEST(PolyValueTest, TransporterUsingPrepare) {
+ NodeTransporterTrackingDepth nodeTransporter;
+
+ Tree demoTree = Tree::make<AtLeastBinaryNode>(
+ std::vector<Tree>{Tree::make<Leaf>(4.0)},
+ Tree::make<Leaf>(3.0),
+ Tree::make<BinaryNode>(Tree::make<Leaf>(2.0),
+ Tree::make<NaryNode>(std::vector<Tree>{Tree::make<Leaf>(1.0)})));
+ const double result = transport<false>(demoTree, nodeTransporter);
+ /*
+ demoTree
+ 1x level: root
+ / | \
+ 10x level: 4 3 binary
+ / \
+ 100x level: 2 nary
+ \
+ 1000x level: 1
+ */
+ ASSERT_EQ(result, 1270.0);
+}
+
+class NodeWalkerIsLeaf {
+public:
+ bool walk(Leaf& leaf) {
+ return true;
+ }
+
+ bool walk(BinaryNode& node, Tree& leftChild, Tree& rightChild) {
+ return false;
+ }
+
+ bool walk(AtLeastBinaryNode& node,
+ std::vector<Tree>& extraChildren,
+ Tree& leftChild,
+ Tree& rightChild) {
+ return false;
+ }
+
+ bool walk(NaryNode& node, std::vector<Tree>& children) {
+ return false;
+ }
+};
+
+TEST(PolyValueTest, WalkerBasic) {
+ NodeWalkerIsLeaf walker;
+ auto tree = Tree::make<BinaryNode>(Tree::make<Leaf>(1.0), Tree::make<Leaf>(2.0));
+ ASSERT(!walk<false>(tree, walker));
+ ASSERT(walk<false>(tree.cast<BinaryNode>()->get<0>(), walker));
+ ASSERT(walk<false>(tree.cast<BinaryNode>()->get<1>(), walker));
+}
+
+} // namespace
+} // namespace mongo::optimizer::algebra
diff --git a/src/mongo/db/query/optimizer/algebra/operator.h b/src/mongo/db/query/optimizer/algebra/operator.h
new file mode 100644
index 00000000000..fb6dbc4d474
--- /dev/null
+++ b/src/mongo/db/query/optimizer/algebra/operator.h
@@ -0,0 +1,341 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include <vector>
+
+#include "mongo/db/query/optimizer/algebra/polyvalue.h"
+
+namespace mongo::optimizer {
+namespace algebra {
+
+template <typename T, int S>
+struct OpNodeStorage {
+ T _nodes[S];
+
+ template <typename... Ts>
+ OpNodeStorage(Ts&&... vals) : _nodes{std::forward<Ts>(vals)...} {}
+};
+
+template <typename T>
+struct OpNodeStorage<T, 0> {};
+
+/*=====-----
+ *
+ * Arity of operator can be:
+ * 1. statically known - A, A, A, ...
+ * 2. dynamic prefix with an optional statically known part - vector<A>, A, A, A, ...
+ *
+ * Denotations map A to some B.
+ * So static arity <A,A,A> is mapped to <B,B,B>.
+ * Similarly, arity <vector<A>,A> is mapped to <vector<B>,B>
+ *
+ * There is a wrinkle when B is a reference (if allowed at all)
+ * Arity <vector<A>, A, A> is mapped to <vector<B>&, B&, B&> - note that the reference is lifted
+ * outside of the vector.
+ *
+ */
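+
+// For example (mirroring the unit tests), a node with exactly two children derives
+// from OpSpecificArity, while a node with a dynamic child vector plus two required
+// children derives from OpSpecificDynamicArity:
+//
+//     class BinaryNode : public OpSpecificArity<Tree, BinaryNode, 2> { ... };
+//     class AtLeastBinaryNode
+//         : public OpSpecificDynamicArity<Tree, AtLeastBinaryNode, 2> { ... };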
+template <typename Slot, typename Derived, int Arity>
+class OpSpecificArity : public OpNodeStorage<Slot, Arity> {
+ using Base = OpNodeStorage<Slot, Arity>;
+
+public:
+ template <typename... Ts>
+ OpSpecificArity(Ts&&... vals) : Base({std::forward<Ts>(vals)...}) {
+ static_assert(sizeof...(Ts) == Arity, "constructor parameters do not match");
+ }
+
+ template <int I, std::enable_if_t<(I >= 0 && I < Arity), int> = 0>
+ auto& get() noexcept {
+ return this->_nodes[I];
+ }
+
+ template <int I, std::enable_if_t<(I >= 0 && I < Arity), int> = 0>
+ const auto& get() const noexcept {
+ return this->_nodes[I];
+ }
+};
+/*=====-----
+ *
+ * Operator with dynamic arity
+ *
+ */
+template <typename Slot, typename Derived, int Arity>
+class OpSpecificDynamicArity : public OpSpecificArity<Slot, Derived, Arity> {
+ using Base = OpSpecificArity<Slot, Derived, Arity>;
+
+ std::vector<Slot> _dyNodes;
+
+public:
+ template <typename... Ts>
+ OpSpecificDynamicArity(std::vector<Slot>&& nodes, Ts&&... vals)
+ : Base({std::forward<Ts>(vals)...}), _dyNodes(std::move(nodes)) {}
+
+ auto& nodes() {
+ return _dyNodes;
+ }
+ const auto& nodes() const {
+ return _dyNodes;
+ }
+};
+
+/*=====-----
+ *
+ * Semantic transport interface
+ *
+ */
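+
+// The transport() machinery below performs a bottom-up traversal: each transport()
+// overload on the domain receives a node together with the already-computed results
+// of its children (see algebra_test.cpp for worked examples). If the domain defines
+// a matching prepare() overload, detected via has_prepare_v, it is invoked top-down
+// before the node's children are visited.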
+namespace detail {
+template <typename D, typename T, typename... Args>
+using call_prepare_t =
+ decltype(std::declval<D>().prepare(std::declval<T&>(), std::declval<Args>()...));
+
+template <typename N, typename D, typename T, typename... Args>
+using call_prepare_slot_t = decltype(
+ std::declval<D>().prepare(std::declval<N&>(), std::declval<T&>(), std::declval<Args>()...));
+
+template <typename Void, template <class...> class Op, class... Args>
+struct has_prepare : std::false_type {};
+
+template <template <class...> class Op, class... Args>
+struct has_prepare<std::void_t<Op<Args...>>, Op, Args...> : std::true_type {};
+
+template <bool withSlot, typename N, typename D, typename T, typename... Args>
+inline constexpr auto has_prepare_v =
+ std::conditional_t<withSlot,
+ has_prepare<void, call_prepare_slot_t, N, D, T, Args...>,
+ has_prepare<void, call_prepare_t, D, T, Args...>>::value;
+
+template <typename Slot, typename Derived, int Arity>
+inline constexpr int get_arity(const OpSpecificArity<Slot, Derived, Arity>*) {
+ return Arity;
+}
+
+template <typename Slot, typename Derived, int Arity>
+inline constexpr bool is_dynamic(const OpSpecificArity<Slot, Derived, Arity>*) {
+ return false;
+}
+
+template <typename Slot, typename Derived, int Arity>
+inline constexpr bool is_dynamic(const OpSpecificDynamicArity<Slot, Derived, Arity>*) {
+ return true;
+}
+
+template <typename T>
+using OpConcreteType = typename std::remove_reference_t<T>::template get_t<0>;
+} // namespace detail
+
+template <typename D, bool withSlot>
+class OpTransporter {
+ D& _domain;
+
+ template <typename T, bool B, typename... Args>
+ struct Deducer {};
+ template <typename T, typename... Args>
+ struct Deducer<T, true, Args...> {
+ using type =
+ decltype(std::declval<D>().transport(std::declval<T>(),
+ std::declval<detail::OpConcreteType<T>&>(),
+ std::declval<Args>()...));
+ };
+ template <typename T, typename... Args>
+ struct Deducer<T, false, Args...> {
+ using type = decltype(std::declval<D>().transport(
+ std::declval<detail::OpConcreteType<T>&>(), std::declval<Args>()...));
+ };
+ template <typename T, typename... Args>
+ using deduced_t = typename Deducer<T, withSlot, Args...>::type;
+
+ template <typename N, typename T, typename... Ts>
+ auto transformStep(N&& slot, T&& op, Ts&&... args) {
+ if constexpr (withSlot) {
+ return _domain.transport(
+ std::forward<N>(slot), std::forward<T>(op), std::forward<Ts>(args)...);
+ } else {
+ return _domain.transport(std::forward<T>(op), std::forward<Ts>(args)...);
+ }
+ }
+
+ template <typename N, typename T, typename... Args, size_t... I>
+ auto transportUnpack(N&& slot, T&& op, std::index_sequence<I...>, Args&&... args) {
+ return transformStep(std::forward<N>(slot),
+ std::forward<T>(op),
+ std::forward<Args>(args)...,
+ op.template get<I>().visit(*this, std::forward<Args>(args)...)...);
+ }
+ template <typename N, typename T, typename... Args, size_t... I>
+ auto transportDynamicUnpack(N&& slot, T&& op, std::index_sequence<I...>, Args&&... args) {
+ std::vector<decltype(slot.visit(*this, std::forward<Args>(args)...))> v;
+ for (auto& node : op.nodes()) {
+ v.emplace_back(node.visit(*this, std::forward<Args>(args)...));
+ }
+ return transformStep(std::forward<N>(slot),
+ std::forward<T>(op),
+ std::forward<Args>(args)...,
+ std::move(v),
+ op.template get<I>().visit(*this, std::forward<Args>(args)...)...);
+ }
+ template <typename N, typename T, typename... Args, size_t... I>
+ void transportUnpackVoid(N&& slot, T&& op, std::index_sequence<I...>, Args&&... args) {
+ (op.template get<I>().visit(*this, std::forward<Args>(args)...), ...);
+ return transformStep(std::forward<N>(slot),
+ std::forward<T>(op),
+ std::forward<Args>(args)...,
+ op.template get<I>()...);
+ }
+ template <typename N, typename T, typename... Args, size_t... I>
+ void transportDynamicUnpackVoid(N&& slot, T&& op, std::index_sequence<I...>, Args&&... args) {
+ for (auto& node : op.nodes()) {
+ node.visit(*this, std::forward<Args>(args)...);
+ }
+ (op.template get<I>().visit(*this, std::forward<Args>(args)...), ...);
+ return transformStep(std::forward<N>(slot),
+ std::forward<T>(op),
+ std::forward<Args>(args)...,
+ op.nodes(),
+ op.template get<I>()...);
+ }
+
+public:
+ OpTransporter(D& domain) : _domain(domain) {}
+
+ template <typename N, typename T, typename... Args, typename R = deduced_t<N, Args...>>
+ R operator()(N&& slot, T&& op, Args&&... args) {
+ // N is either `PolyValue<Ts...>&` or `const PolyValue<Ts...>&`, i.e. a reference.
+ // T is either `A&` or `const A&`, where A is one of Ts.
+ using type = std::remove_reference_t<T>;
+
+ constexpr int arity = detail::get_arity(static_cast<type*>(nullptr));
+ constexpr bool is_dynamic = detail::is_dynamic(static_cast<type*>(nullptr));
+
+ if constexpr (detail::has_prepare_v<withSlot, N, D, type, Args...>) {
+ if constexpr (withSlot) {
+ _domain.prepare(
+ std::forward<N>(slot), std::forward<T>(op), std::forward<Args>(args)...);
+ } else {
+ _domain.prepare(std::forward<T>(op), std::forward<Args>(args)...);
+ }
+ }
+
+ if constexpr (is_dynamic) {
+ if constexpr (std::is_same_v<R, void>) {
+ return transportDynamicUnpackVoid(std::forward<N>(slot),
+ std::forward<T>(op),
+ std::make_index_sequence<arity>{},
+ std::forward<Args>(args)...);
+ } else {
+ return transportDynamicUnpack(std::forward<N>(slot),
+ std::forward<T>(op),
+ std::make_index_sequence<arity>{},
+ std::forward<Args>(args)...);
+ }
+ } else {
+ if constexpr (std::is_same_v<R, void>) {
+ return transportUnpackVoid(std::forward<N>(slot),
+ std::forward<T>(op),
+ std::make_index_sequence<arity>{},
+ std::forward<Args>(args)...);
+ } else {
+ return transportUnpack(std::forward<N>(slot),
+ std::forward<T>(op),
+ std::make_index_sequence<arity>{},
+ std::forward<Args>(args)...);
+ }
+ }
+ }
+};
+
+template <typename D, bool withSlot>
+class OpWalker {
+ D& _domain;
+
+ template <typename N, typename T, typename... Ts>
+ auto walkStep(N&& slot, T&& op, Ts&&... args) {
+ if constexpr (withSlot) {
+ return _domain.walk(
+ std::forward<N>(slot), std::forward<T>(op), std::forward<Ts>(args)...);
+ } else {
+ return _domain.walk(std::forward<T>(op), std::forward<Ts>(args)...);
+ }
+ }
+
+ template <typename N, typename T, typename... Args, size_t... I>
+ auto walkUnpack(N&& slot, T&& op, std::index_sequence<I...>, Args&&... args) {
+ return walkStep(std::forward<N>(slot),
+ std::forward<T>(op),
+ std::forward<Args>(args)...,
+ op.template get<I>()...);
+ }
+ template <typename N, typename T, typename... Args, size_t... I>
+ auto walkDynamicUnpack(N&& slot, T&& op, std::index_sequence<I...>, Args&&... args) {
+ return walkStep(std::forward<N>(slot),
+ std::forward<T>(op),
+ std::forward<Args>(args)...,
+ op.nodes(),
+ op.template get<I>()...);
+ }
+
+public:
+ OpWalker(D& domain) : _domain(domain) {}
+
+ template <typename N, typename T, typename... Args>
+ auto operator()(N&& slot, T&& op, Args&&... args) {
+ // N is either `PolyValue<Ts...>&` or `const PolyValue<Ts...>&`, i.e. a reference.
+ // T is either `A&` or `const A&`, where A is one of Ts.
+ using type = std::remove_reference_t<T>;
+
+ constexpr int arity = detail::get_arity(static_cast<type*>(nullptr));
+ constexpr bool is_dynamic = detail::is_dynamic(static_cast<type*>(nullptr));
+
+ if constexpr (is_dynamic) {
+ return walkDynamicUnpack(std::forward<N>(slot),
+ std::forward<T>(op),
+ std::make_index_sequence<arity>{},
+ std::forward<Args>(args)...);
+ } else {
+ return walkUnpack(std::forward<N>(slot),
+ std::forward<T>(op),
+ std::make_index_sequence<arity>{},
+ std::forward<Args>(args)...);
+ }
+ }
+};
+
+template <bool withSlot = false, typename D, typename N, typename... Args>
+auto transport(N&& node, D& domain, Args&&... args) {
+ return node.visit(OpTransporter<D, withSlot>{domain}, std::forward<Args>(args)...);
+}
+
+template <bool withSlot = false, typename D, typename N, typename... Args>
+auto walk(N&& node, D& domain, Args&&... args) {
+ return node.visit(OpWalker<D, withSlot>{domain}, std::forward<Args>(args)...);
+}
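+
+/*=====-----
+ *
+ * Usage sketch (illustrative only; the domain and operator names below are
+ * hypothetical, not part of this patch). A domain implements one `transport`
+ * overload per operator type; children are visited bottom-up and their
+ * results are appended to the argument list:
+ *
+ *   struct CountingDomain {
+ *       int transport(const LeafOp&) { return 1; }
+ *       int transport(const BinaryOp&, int lhs, int rhs) { return 1 + lhs + rhs; }
+ *   };
+ *
+ *   CountingDomain d;
+ *   const int nodeCount = algebra::transport<false>(tree, d);
+ *
+ * `walk`, by contrast, passes the child nodes themselves to the domain's
+ * `walk` overloads and leaves any recursion to the domain.
+ */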
+
+} // namespace algebra
+} // namespace mongo::optimizer
diff --git a/src/mongo/db/query/optimizer/algebra/polyvalue.h b/src/mongo/db/query/optimizer/algebra/polyvalue.h
new file mode 100644
index 00000000000..63f5965c50c
--- /dev/null
+++ b/src/mongo/db/query/optimizer/algebra/polyvalue.h
@@ -0,0 +1,541 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include <array>
+#include <stdexcept>
+#include <type_traits>
+
+namespace mongo::optimizer {
+namespace algebra {
+namespace detail {
+
+template <typename T, typename... Args>
+inline constexpr bool is_one_of_v = std::disjunction_v<std::is_same<T, Args>...>;
+
+template <typename T, typename... Args>
+inline constexpr bool is_one_of_f() {
+ return is_one_of_v<T, Args...>;
+}
+
+template <typename... Args>
+struct is_unique_t : std::true_type {};
+
+template <typename H, typename... T>
+struct is_unique_t<H, T...>
+ : std::bool_constant<!is_one_of_f<H, T...>() && is_unique_t<T...>::value> {};
+
+template <typename... Args>
+inline constexpr bool is_unique_v = is_unique_t<Args...>::value;
+
+// Given the type T find its index in Ts
+template <typename T, typename... Ts>
+static inline constexpr int find_index() {
+ static_assert(detail::is_unique_v<Ts...>, "Types must be unique");
+ constexpr bool matchVector[] = {std::is_same<T, Ts>::value...};
+
+ for (int index = 0; index < static_cast<int>(sizeof...(Ts)); ++index) {
+ if (matchVector[index]) {
+ return index;
+ }
+ }
+
+ return -1;
+}
+
+template <int N, typename T, typename... Ts>
+struct get_type_by_index_impl {
+ using type = typename get_type_by_index_impl<N - 1, Ts...>::type;
+};
+template <typename T, typename... Ts>
+struct get_type_by_index_impl<0, T, Ts...> {
+ using type = T;
+};
+
+// Given the index I return the type from Ts
+template <int I, typename... Ts>
+using get_type_by_index = typename get_type_by_index_impl<I, Ts...>::type;
+
+} // namespace detail
+
+/*=====-----
+ *
+ * The overload trick to construct visitors from lambdas.
+ *
+ */
+template <class... Ts>
+struct overload : Ts... {
+ using Ts::operator()...;
+};
+template <class... Ts>
+overload(Ts...)->overload<Ts...>;
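+
+// Usage sketch (illustrative): `overload` merges lambdas into a single
+// visitor, avoiding a dedicated functor type:
+//
+//   auto visitor = overload{
+//       [](const int& i) { return 1; },
+//       [](const std::string& s) { return 2; },
+//   };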
+
+/*=====-----
+ *
+ * Forward declarations
+ *
+ */
+template <typename... Ts>
+class PolyValue;
+
+template <typename T, typename... Ts>
+class ControlBlockVTable;
+
+/*=====-----
+ *
+ * The base control block that PolyValue holds.
+ *
+ * It does not contain anything else but the runtime tag.
+ *
+ */
+template <typename... Ts>
+class ControlBlock {
+ const int _tag;
+
+protected:
+ ControlBlock(int tag) noexcept : _tag(tag) {}
+
+public:
+ auto getRuntimeTag() const noexcept {
+ return _tag;
+ }
+};
+
+/*=====-----
+ *
+ * The concrete control block VTable generator.
+ *
+ * It must be empty, as PolyValue derives from the generators
+ * and we want the empty base optimization (EBO) to kick in.
+ *
+ */
+template <typename T, typename... Ts>
+class ControlBlockVTable {
+protected:
+ static constexpr int _staticTag = detail::find_index<T, Ts...>();
+ static_assert(_staticTag != -1, "Type must be on the list");
+
+ using AbstractType = ControlBlock<Ts...>;
+
+ /*=====-----
+ *
+ * The concrete control block for every type T of Ts.
+ *
+ * It derives from the ControlBlock. All methods are private and only
+ * the friend class ControlBlockVTable can call them.
+ *
+ */
+ class ConcreteType : public AbstractType {
+ T _t;
+
+ public:
+ template <typename... Args>
+ ConcreteType(Args&&... args) : AbstractType(_staticTag), _t(std::forward<Args>(args)...) {}
+
+ const T* getPtr() const noexcept {
+ return &_t;
+ }
+
+ T* getPtr() noexcept {
+ return &_t;
+ }
+ };
+
+ static constexpr auto concrete(AbstractType* block) noexcept {
+ return static_cast<ConcreteType*>(block);
+ }
+
+ static constexpr auto concrete(const AbstractType* block) noexcept {
+ return static_cast<const ConcreteType*>(block);
+ }
+
+public:
+ template <typename... Args>
+ static AbstractType* make(Args&&... args) {
+ return new ConcreteType(std::forward<Args>(args)...);
+ }
+
+ static AbstractType* clone(const AbstractType* block) {
+ return new ConcreteType(*concrete(block));
+ }
+
+ static void destroy(AbstractType* block) noexcept {
+ delete concrete(block);
+ }
+
+ static bool compareEq(AbstractType* blockLhs, AbstractType* blockRhs) noexcept {
+ if (blockLhs->getRuntimeTag() == blockRhs->getRuntimeTag()) {
+ return *castConst<T>(blockLhs) == *castConst<T>(blockRhs);
+ }
+ return false;
+ }
+
+ template <typename U>
+ static constexpr bool is_v = std::is_base_of_v<U, T>;
+
+ template <typename U>
+ static U* cast(AbstractType* block) noexcept {
+ if constexpr (is_v<U>) {
+ return static_cast<U*>(concrete(block)->getPtr());
+ } else {
+ // gcc bug 81676
+ (void)block;
+ return nullptr;
+ }
+ }
+
+ template <typename U>
+ static const U* castConst(const AbstractType* block) noexcept {
+ if constexpr (is_v<U>) {
+ return static_cast<const U*>(concrete(block)->getPtr());
+ } else {
+ // gcc bug 81676
+ (void)block;
+ return nullptr;
+ }
+ }
+
+ template <typename V, typename N, typename... Args>
+ static auto visit(V&& v, N& holder, AbstractType* block, Args&&... args) {
+ return v(holder, *cast<T>(block), std::forward<Args>(args)...);
+ }
+
+ template <typename V, typename N, typename... Args>
+ static auto visitConst(V&& v, const N& holder, const AbstractType* block, Args&&... args) {
+ return v(holder, *castConst<T>(block), std::forward<Args>(args)...);
+ }
+};
+
+/*=====-----
+ *
+ * This is a variation on the variant and polymorphic value themes.
+ *
+ * Dispatch is tag-based: each held type gets a compile-time tag, and a
+ * per-type vtable is generated from the type list.
+ *
+ * Supported operations:
+ * - construction
+ * - destruction
+ * - clone: a = b;
+ * - cast: a.cast<T>()
+ * - multi-method cast to a common base: a.cast<B>()
+ * - multi-method visit
+ */
+template <typename... Ts>
+class PolyValue : private ControlBlockVTable<Ts, Ts...>... {
+public:
+ using key_type = int;
+
+private:
+ static_assert(detail::is_unique_v<Ts...>, "Types must be unique");
+ static_assert(std::conjunction_v<std::is_empty<ControlBlockVTable<Ts, Ts...>>...>,
+ "VTable base classes must be empty");
+
+ ControlBlock<Ts...>* _object{nullptr};
+
+ PolyValue(ControlBlock<Ts...>* object) noexcept : _object(object) {}
+
+ auto tag() const noexcept {
+ return _object->getRuntimeTag();
+ }
+
+ static void check(const ControlBlock<Ts...>* object) {
+ if (!object) {
+ throw std::logic_error("PolyValue is empty");
+ }
+ }
+
+ static void destroy(ControlBlock<Ts...>* object) noexcept {
+ static constexpr std::array destroyTbl = {&ControlBlockVTable<Ts, Ts...>::destroy...};
+
+ destroyTbl[object->getRuntimeTag()](object);
+ }
+
+ template <typename T>
+ static T* cast(ControlBlock<Ts...>* object) {
+ check(object);
+ static constexpr std::array castTbl = {&ControlBlockVTable<Ts, Ts...>::template cast<T>...};
+ return castTbl[object->getRuntimeTag()](object);
+ }
+
+ template <typename T>
+ static const T* castConst(ControlBlock<Ts...>* object) {
+ check(object);
+ static constexpr std::array castTbl = {
+ &ControlBlockVTable<Ts, Ts...>::template castConst<T>...};
+ return castTbl[object->getRuntimeTag()](object);
+ }
+
+ template <typename T>
+ static bool is(ControlBlock<Ts...>* object) {
+ check(object);
+ static constexpr std::array isTbl = {ControlBlockVTable<Ts, Ts...>::template is_v<T>...};
+ return isTbl[object->getRuntimeTag()];
+ }
+
+ class CompareHelper {
+ ControlBlock<Ts...>* _object{nullptr};
+
+ auto tag() const noexcept {
+ return _object->getRuntimeTag();
+ }
+
+ public:
+ CompareHelper() = default;
+ CompareHelper(ControlBlock<Ts...>* object) : _object(object) {}
+
+ bool operator==(const CompareHelper& rhs) const noexcept {
+ static constexpr std::array cmp = {ControlBlockVTable<Ts, Ts...>::compareEq...};
+ return cmp[tag()](_object, rhs._object);
+ }
+ };
+
+ class Reference {
+ ControlBlock<Ts...>* _object{nullptr};
+
+ auto tag() const noexcept {
+ return _object->getRuntimeTag();
+ }
+
+ public:
+ Reference() = default;
+ Reference(ControlBlock<Ts...>* object) : _object(object) {}
+
+ template <int I>
+ using get_t = detail::get_type_by_index<I, Ts...>;
+
+ key_type tagOf() const {
+ check(_object);
+
+ return tag();
+ }
+
+
+ template <typename V, typename... Args>
+ auto visit(V&& v, Args&&... args) {
+ // Unfortunately gcc rejects the much nicer code below, which clang and msvc accept:
+ // static constexpr std::array visitTbl = { &ControlBlockVTable<Ts, Ts...>::template
+ // visit<V>... };
+
+ using FunPtrType = decltype(
+ &ControlBlockVTable<get_t<0>, Ts...>::template visit<V, Reference, Args...>);
+ static constexpr FunPtrType visitTbl[] = {
+ &ControlBlockVTable<Ts, Ts...>::template visit<V, Reference, Args...>...};
+
+ check(_object);
+ return visitTbl[tag()](std::forward<V>(v), *this, _object, std::forward<Args>(args)...);
+ }
+
+ template <typename V, typename... Args>
+ auto visit(V&& v, Args&&... args) const {
+ // Unfortunately gcc rejects the much nicer code below, which clang and msvc accept:
+ // static constexpr std::array visitTbl = { &ControlBlockVTable<Ts, Ts...>::template
+ // visitConst<V>... };
+
+ using FunPtrType = decltype(
+ &ControlBlockVTable<get_t<0>, Ts...>::template visitConst<V, Reference, Args...>);
+ static constexpr FunPtrType visitTbl[] = {
+ &ControlBlockVTable<Ts, Ts...>::template visitConst<V, Reference, Args...>...};
+
+ check(_object);
+ return visitTbl[tag()](std::forward<V>(v), *this, _object, std::forward<Args>(args)...);
+ }
+
+ template <typename T>
+ T* cast() {
+ return PolyValue<Ts...>::template cast<T>(_object);
+ }
+
+ template <typename T>
+ const T* cast() const {
+ return PolyValue<Ts...>::template castConst<T>(_object);
+ }
+
+ template <typename T>
+ bool is() const {
+ return PolyValue<Ts...>::template is<T>(_object);
+ }
+
+ bool empty() const noexcept {
+ return !_object;
+ }
+
+ void swap(Reference& other) noexcept {
+ std::swap(other._object, _object);
+ }
+
+ // Compare references, not the objects themselves.
+ bool operator==(const Reference& rhs) const noexcept {
+ return _object == rhs._object;
+ }
+
+ bool operator==(const PolyValue& rhs) const noexcept {
+ return rhs == (*this);
+ }
+
+ auto hash() const noexcept {
+ return std::hash<const void*>{}(_object);
+ }
+
+ auto follow() const {
+ return CompareHelper(_object);
+ }
+
+ friend class PolyValue;
+ };
+
+public:
+ using reference_type = Reference;
+
+ template <typename T>
+ static constexpr key_type tagOf() {
+ return ControlBlockVTable<T, Ts...>::_staticTag;
+ }
+
+ key_type tagOf() const {
+ check(_object);
+
+ return tag();
+ }
+
+ PolyValue() = delete;
+
+ PolyValue(const PolyValue& other) {
+ static constexpr std::array cloneTbl = {&ControlBlockVTable<Ts, Ts...>::clone...};
+ if (other._object) {
+ _object = cloneTbl[other.tag()](other._object);
+ }
+ }
+
+ PolyValue(const Reference& other) {
+ static constexpr std::array cloneTbl = {&ControlBlockVTable<Ts, Ts...>::clone...};
+ if (other._object) {
+ _object = cloneTbl[other.tag()](other._object);
+ }
+ }
+
+ PolyValue(PolyValue&& other) noexcept {
+ swap(other);
+ }
+
+ ~PolyValue() noexcept {
+ if (_object) {
+ destroy(_object);
+ }
+ }
+
+ PolyValue& operator=(PolyValue other) noexcept {
+ swap(other);
+ return *this;
+ }
+
+ template <typename T, typename... Args>
+ static PolyValue make(Args&&... args) {
+ return PolyValue{ControlBlockVTable<T, Ts...>::make(std::forward<Args>(args)...)};
+ }
+
+ template <int I>
+ using get_t = detail::get_type_by_index<I, Ts...>;
+
+ template <typename V, typename... Args>
+ auto visit(V&& v, Args&&... args) {
+ // Unfortunately gcc rejects the much nicer code below, which clang and msvc accept:
+ // static constexpr std::array visitTbl = { &ControlBlockVTable<Ts, Ts...>::template
+ // visit<V>... };
+
+ using FunPtrType =
+ decltype(&ControlBlockVTable<get_t<0>, Ts...>::template visit<V, PolyValue, Args...>);
+ static constexpr FunPtrType visitTbl[] = {
+ &ControlBlockVTable<Ts, Ts...>::template visit<V, PolyValue, Args...>...};
+
+ check(_object);
+ return visitTbl[tag()](std::forward<V>(v), *this, _object, std::forward<Args>(args)...);
+ }
+
+ template <typename V, typename... Args>
+ auto visit(V&& v, Args&&... args) const {
+ // Unfortunately gcc rejects the much nicer code below, which clang and msvc accept:
+ // static constexpr std::array visitTbl = { &ControlBlockVTable<Ts, Ts...>::template
+ // visitConst<V>... };
+
+ using FunPtrType = decltype(
+ &ControlBlockVTable<get_t<0>, Ts...>::template visitConst<V, PolyValue, Args...>);
+ static constexpr FunPtrType visitTbl[] = {
+ &ControlBlockVTable<Ts, Ts...>::template visitConst<V, PolyValue, Args...>...};
+
+ check(_object);
+ return visitTbl[tag()](std::forward<V>(v), *this, _object, std::forward<Args>(args)...);
+ }
+
+ template <typename T>
+ T* cast() {
+ return cast<T>(_object);
+ }
+
+ template <typename T>
+ const T* cast() const {
+ return castConst<T>(_object);
+ }
+
+ template <typename T>
+ bool is() const {
+ return is<T>(_object);
+ }
+
+ bool empty() const noexcept {
+ return !_object;
+ }
+
+ void swap(PolyValue& other) noexcept {
+ std::swap(other._object, _object);
+ }
+
+ bool operator==(const PolyValue& rhs) const noexcept {
+ static constexpr std::array cmp = {ControlBlockVTable<Ts, Ts...>::compareEq...};
+ return cmp[tag()](_object, rhs._object);
+ }
+
+ bool operator==(const Reference& rhs) const noexcept {
+ static constexpr std::array cmp = {ControlBlockVTable<Ts, Ts...>::compareEq...};
+ return cmp[tag()](_object, rhs._object);
+ }
+
+ auto ref() {
+ check(_object);
+ return Reference(_object);
+ }
+
+ auto ref() const {
+ check(_object);
+ return Reference(_object);
+ }
+};
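+
+/*=====-----
+ *
+ * Usage sketch (illustrative only; the element types are hypothetical):
+ *
+ *   using Value = PolyValue<int, std::string>;
+ *   Value v = Value::make<std::string>("abc");
+ *
+ *   if (v.is<std::string>()) {
+ *       const std::string* s = v.cast<std::string>();  // non-null here
+ *   }
+ *
+ *   // Visitors receive the holder first, then the concrete element.
+ *   auto r = v.visit(overload{
+ *       [](const Value&, const int&) { return 1; },
+ *       [](const Value&, const std::string&) { return 2; },
+ *   });
+ *
+ *   auto ref = v.ref();  // non-owning Reference to the same control block
+ */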
+
+} // namespace algebra
+} // namespace mongo::optimizer
diff --git a/src/mongo/db/query/optimizer/bool_expression.h b/src/mongo/db/query/optimizer/bool_expression.h
new file mode 100644
index 00000000000..bf00f907504
--- /dev/null
+++ b/src/mongo/db/query/optimizer/bool_expression.h
@@ -0,0 +1,140 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+
+#pragma once
+
+#include "boost/optional.hpp"
+#include <vector>
+
+#include "mongo/db/query/optimizer/algebra/operator.h"
+#include "mongo/db/query/optimizer/algebra/polyvalue.h"
+#include "mongo/util/assert_util.h"
+
+namespace mongo::optimizer {
+
+/**
+ * Represents a generic boolean expression with arbitrarily nested conjunction
+ * and disjunction elements.
+ */
+template <class T>
+struct BoolExpr {
+ class Atom;
+ class Conjunction;
+ class Disjunction;
+
+ using Node = algebra::PolyValue<Atom, Conjunction, Disjunction>;
+ using NodeVector = std::vector<Node>;
+
+
+ class Atom final : public algebra::OpSpecificArity<Node, Atom, 0> {
+ using Base = algebra::OpSpecificArity<Node, Atom, 0>;
+
+ public:
+ Atom(T expr) : Base(), _expr(std::move(expr)) {}
+
+ bool operator==(const Atom& other) const {
+ return _expr == other._expr;
+ }
+
+ const T& getExpr() const {
+ return _expr;
+ }
+ T& getExpr() {
+ return _expr;
+ }
+
+ private:
+ T _expr;
+ };
+
+ class Conjunction final : public algebra::OpSpecificDynamicArity<Node, Conjunction, 0> {
+ using Base = algebra::OpSpecificDynamicArity<Node, Conjunction, 0>;
+
+ public:
+ Conjunction(NodeVector children) : Base(std::move(children)) {
+ uassert(6624351, "Must have at least one child", !Base::nodes().empty());
+ }
+
+ bool operator==(const Conjunction& other) const {
+ return Base::nodes() == other.nodes();
+ }
+ };
+
+ class Disjunction final : public algebra::OpSpecificDynamicArity<Node, Disjunction, 0> {
+ using Base = algebra::OpSpecificDynamicArity<Node, Disjunction, 0>;
+
+ public:
+ Disjunction(NodeVector children) : Base(std::move(children)) {
+ uassert(6624301, "Must have at least one child", !Base::nodes().empty());
+ }
+
+ bool operator==(const Disjunction& other) const {
+ return Base::nodes() == other.nodes();
+ }
+ };
+
+
+ /**
+ * Utility functions.
+ */
+ template <typename T1, typename... Args>
+ static auto make(Args&&... args) {
+ return Node::template make<T1>(std::forward<Args>(args)...);
+ }
+
+ template <typename... Args>
+ static auto makeSeq(Args&&... args) {
+ NodeVector seq;
+ (seq.emplace_back(std::forward<Args>(args)), ...);
+ return seq;
+ }
+
+ template <typename... Args>
+ static Node makeSingularDNF(Args&&... args) {
+ return make<Disjunction>(
+ makeSeq(make<Conjunction>(makeSeq(make<Atom>(T{std::forward<Args>(args)...})))));
+ }
+
+ static boost::optional<const T&> getSingularDNF(const Node& n) {
+ if (auto disjunction = n.template cast<Disjunction>();
+ disjunction != nullptr && disjunction->nodes().size() == 1) {
+ if (auto conjunction = disjunction->nodes().front().template cast<Conjunction>();
+ conjunction != nullptr && conjunction->nodes().size() == 1) {
+ if (auto atom = conjunction->nodes().front().template cast<Atom>();
+ atom != nullptr) {
+ return {atom->getExpr()};
+ }
+ }
+ }
+ return {};
+ }
+};
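+
+// Usage sketch (illustrative; the atom type is hypothetical):
+//
+//   using StrExpr = BoolExpr<std::string>;
+//   // Builds Disjunction(Conjunction(Atom("pred"))).
+//   auto n = StrExpr::makeSingularDNF("pred");
+//   // Yields the atom's expression for a 1x1 DNF; empty for any other shape.
+//   boost::optional<const std::string&> e = StrExpr::getSingularDNF(n);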
+
+} // namespace mongo::optimizer
diff --git a/src/mongo/db/query/optimizer/cascades/ce_heuristic.cpp b/src/mongo/db/query/optimizer/cascades/ce_heuristic.cpp
new file mode 100644
index 00000000000..263fab109a5
--- /dev/null
+++ b/src/mongo/db/query/optimizer/cascades/ce_heuristic.cpp
@@ -0,0 +1,192 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/db/query/optimizer/cascades/ce_heuristic.h"
+#include "mongo/db/query/optimizer/utils/memo_utils.h"
+
+namespace mongo::optimizer::cascades {
+
+using namespace properties;
+
+class CEHeuristicTransport {
+public:
+ CEType transport(const ScanNode& node, CEType /*bindResult*/) {
+ // Default cardinality estimate.
+ const CEType metadataCE = _memo.getMetadata()._scanDefs.at(node.getScanDefName()).getCE();
+ return (metadataCE < 0.0) ? 1000.00 : metadataCE;
+ }
+
+ CEType transport(const ValueScanNode& node, CEType /*bindResult*/) {
+ return node.getArraySize();
+ }
+
+ CEType transport(const MemoLogicalDelegatorNode& node) {
+ return getPropertyConst<CardinalityEstimate>(
+ _memo.getGroup(node.getGroupId())._logicalProperties)
+ .getEstimate();
+ }
+
+ CEType transport(const FilterNode& node, CEType childResult, CEType /*exprResult*/) {
+ if (node.getFilter() == Constant::boolean(true)) {
+ // Trivially true filter.
+ return childResult;
+ } else if (node.getFilter() == Constant::boolean(false)) {
+ // Trivially false filter.
+ return 0.0;
+ } else {
+ // Estimate filter selectivity at 0.1.
+ return 0.1 * childResult;
+ }
+ }
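+
+ // Worked example (illustrative): a non-trivial filter over a scan estimated
+ // at 1000 documents yields 0.1 * 1000 = 100; stacking a second filter yields
+ // 0.1 * 100 = 10, since the estimates compose bottom-up through the
+ // transport.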
+
+ CEType transport(const EvaluationNode& node, CEType childResult, CEType /*exprResult*/) {
+ // Evaluations do not change cardinality.
+ return childResult;
+ }
+
+ CEType transport(const SargableNode& node,
+ CEType /*childResult*/,
+ CEType /*bindsResult*/,
+ CEType /*refsResult*/) {
+ ABT lowered = node.getChild();
+ for (const auto& [key, req] : node.getReqMap()) {
+ // TODO: consider issuing one filter node per interval.
+ lowerPartialSchemaRequirement(key, req, lowered);
+ }
+ return algebra::transport<false>(lowered, *this);
+ }
+
+ CEType transport(const RIDIntersectNode& node,
+ CEType /*leftChildResult*/,
+ CEType /*rightChildResult*/) {
+ // CE for the group should already be derived via the underlying Filter or Evaluation
+ // logical nodes.
+ uasserted(6624038, "Should not be necessary to derive CE for RIDIntersectNode");
+ }
+
+ CEType transport(const BinaryJoinNode& node,
+ CEType /*leftChildResult*/,
+ CEType /*rightChildResult*/,
+ CEType /*exprResult*/) {
+ uasserted(6624039, "CE derivation not implemented.");
+ }
+
+ CEType transport(const UnionNode& node,
+ std::vector<CEType> childResults,
+ CEType /*bindResult*/,
+ CEType /*refsResult*/) {
+ // Combine the CE of each child.
+ CEType result = 0;
+ for (auto&& child : childResults) {
+ result += child;
+ }
+ return result;
+ }
+
+ CEType transport(const GroupByNode& node,
+ CEType childResult,
+ CEType /*bindAggResult*/,
+ CEType /*refsAggResult*/,
+ CEType /*bindGbResult*/,
+ CEType /*refsGbResult*/) {
+ // TODO: estimate number of groups.
+ switch (node.getType()) {
+ case GroupNodeType::Complete:
+ return 0.01 * childResult;
+
+ // Global and Local selectivity should multiply to Complete selectivity.
+ case GroupNodeType::Global:
+ return 0.5 * childResult;
+ case GroupNodeType::Local:
+ return 0.02 * childResult;
+
+ default:
+ MONGO_UNREACHABLE;
+ }
+ }
+
+ CEType transport(const UnwindNode& node,
+ CEType childResult,
+ CEType /*bindResult*/,
+ CEType /*refsResult*/) {
+ // Estimate unwind fan-out at 10.0: assume each input document produces ten outputs.
+ return 10.0 * childResult;
+ }
+
+ CEType transport(const CollationNode& node, CEType childResult, CEType /*refsResult*/) {
+ // Collations do not change cardinality.
+ return childResult;
+ }
+
+ CEType transport(const LimitSkipNode& node, CEType childResult) {
+ const auto limit = node.getProperty().getLimit();
+ if (limit < childResult) {
+ return limit;
+ }
+ return childResult;
+ }
+
+ CEType transport(const ExchangeNode& node, CEType childResult, CEType /*refsResult*/) {
+ // Exchanges do not change cardinality.
+ return childResult;
+ }
+
+ CEType transport(const RootNode& node, CEType childResult, CEType /*refsResult*/) {
+ // Root node does not change cardinality.
+ return childResult;
+ }
+
+ /**
+ * Other ABT types.
+ */
+ template <typename T, typename... Ts>
+ CEType transport(const T& /*node*/, Ts&&...) {
+ static_assert(!canBeLogicalNode<T>(), "Logical node must implement its CE derivation.");
+ return 0.0;
+ }
+
+ static CEType derive(const Memo& memo, const ABT::reference_type logicalNodeRef) {
+ CEHeuristicTransport instance(memo);
+ return algebra::transport<false>(logicalNodeRef, instance);
+ }
+
+private:
+ CEHeuristicTransport(const Memo& memo) : _memo(memo) {}
+
+ // We don't own this.
+ const Memo& _memo;
+};
+
+CEType HeuristicCE::deriveCE(const Memo& memo,
+ const LogicalProps& /*logicalProps*/,
+ const ABT::reference_type logicalNodeRef) const {
+ return CEHeuristicTransport::derive(memo, logicalNodeRef);
+}
+
+} // namespace mongo::optimizer::cascades
diff --git a/src/mongo/db/query/optimizer/cascades/ce_heuristic.h b/src/mongo/db/query/optimizer/cascades/ce_heuristic.h
new file mode 100644
index 00000000000..40547b90867
--- /dev/null
+++ b/src/mongo/db/query/optimizer/cascades/ce_heuristic.h
@@ -0,0 +1,48 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include "mongo/db/query/optimizer/cascades/interfaces.h"
+
+namespace mongo::optimizer::cascades {
+
+/**
+ * Default cardinality estimation in the absence of statistics.
+ * Relies purely on heuristics.
+ * We currently do not use logical properties for heuristic CE.
+ */
+class HeuristicCE : public CEInterface {
+public:
+ CEType deriveCE(const Memo& memo,
+ const properties::LogicalProps& /*logicalProps*/,
+ ABT::reference_type logicalNodeRef) const override final;
+};
+
+} // namespace mongo::optimizer::cascades
diff --git a/src/mongo/db/query/optimizer/cascades/ce_hinted.cpp b/src/mongo/db/query/optimizer/cascades/ce_hinted.cpp
new file mode 100644
index 00000000000..ed3f4ce9335
--- /dev/null
+++ b/src/mongo/db/query/optimizer/cascades/ce_hinted.cpp
@@ -0,0 +1,96 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/db/query/optimizer/cascades/ce_hinted.h"
+#include "mongo/db/query/optimizer/cascades/ce_heuristic.h"
+
+namespace mongo::optimizer::cascades {
+
+using namespace properties;
+
+class CEHintedTransport {
+public:
+ CEType transport(const ABT& n,
+ const SargableNode& node,
+ const Memo& memo,
+ const LogicalProps& logicalProps,
+ CEType childResult,
+ CEType /*bindsResult*/,
+ CEType /*refsResult*/) {
+ CEType result = childResult;
+ for (const auto& [key, req] : node.getReqMap()) {
+ if (!isIntervalReqFullyOpenDNF(req.getIntervals())) {
+ auto it = _hints.find(key);
+ if (it != _hints.cend()) {
+ // Assume independence.
+ result *= it->second;
+ }
+ }
+ }
+
+ return result;
+ }
+
+ template <typename T, typename... Ts>
+ CEType transport(const ABT& n,
+ const T& /*node*/,
+ const Memo& memo,
+ const LogicalProps& logicalProps,
+ Ts&&...) {
+ if (canBeLogicalNode<T>()) {
+ return _heuristicCE.deriveCE(memo, logicalProps, n.ref());
+ }
+ return 0.0;
+ }
+
+ static CEType derive(const Memo& memo,
+ const PartialSchemaSelHints& hints,
+ const LogicalProps& logicalProps,
+ const ABT::reference_type logicalNodeRef) {
+ CEHintedTransport instance(memo, hints);
+ return algebra::transport<true>(logicalNodeRef, instance, memo, logicalProps);
+ }
+
+private:
+ CEHintedTransport(const Memo& memo, const PartialSchemaSelHints& hints)
+ : _heuristicCE(), _hints(hints) {}
+
+ HeuristicCE _heuristicCE;
+
+ // We don't own this.
+ const PartialSchemaSelHints& _hints;
+};
+
+CEType HintedCE::deriveCE(const Memo& memo,
+ const LogicalProps& logicalProps,
+ const ABT::reference_type logicalNodeRef) const {
+ return CEHintedTransport::derive(memo, _hints, logicalProps, logicalNodeRef);
+}
+
+} // namespace mongo::optimizer::cascades
diff --git a/src/mongo/db/query/optimizer/cascades/ce_hinted.h b/src/mongo/db/query/optimizer/cascades/ce_hinted.h
new file mode 100644
index 00000000000..422bfcdfa73
--- /dev/null
+++ b/src/mongo/db/query/optimizer/cascades/ce_hinted.h
@@ -0,0 +1,56 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include "mongo/db/query/optimizer/cascades/interfaces.h"
+
+namespace mongo::optimizer::cascades {
+
+using PartialSchemaSelHints =
+ std::map<PartialSchemaKey, SelectivityType, PartialSchemaKeyLessComparator>;
+
+/**
+ * Estimation based on hints. The hints are organized in a PartialSchemaSelHints map.
+ * SargableNodes are estimated based on the matching PartialSchemaKeys.
+ */
+class HintedCE : public CEInterface {
+public:
+ HintedCE(PartialSchemaSelHints hints) : _hints(std::move(hints)) {}
+
+ CEType deriveCE(const Memo& memo,
+ const properties::LogicalProps& logicalProps,
+ ABT::reference_type logicalNodeRef) const override final;
+
+private:
+ // Selectivity hints per PartialSchemaKey.
+ PartialSchemaSelHints _hints;
+};
+
+} // namespace mongo::optimizer::cascades
diff --git a/src/mongo/db/query/optimizer/cascades/cost_derivation.cpp b/src/mongo/db/query/optimizer/cascades/cost_derivation.cpp
new file mode 100644
index 00000000000..20da31e0d80
--- /dev/null
+++ b/src/mongo/db/query/optimizer/cascades/cost_derivation.cpp
@@ -0,0 +1,429 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/db/query/optimizer/cascades/cost_derivation.h"
+#include "mongo/db/query/optimizer/defs.h"
+
+namespace mongo::optimizer::cascades {
+
+using namespace properties;
+
+struct CostAndCEInternal {
+ CostAndCEInternal(double cost, CEType ce) : _cost(cost), _ce(ce) {
+ uassert(8423334, "Invalid cost.", !std::isnan(cost) && cost >= 0.0);
+ uassert(8423332, "Invalid cardinality", std::isfinite(ce) && ce >= 0.0);
+ }
+ double _cost;
+ CEType _ce;
+};
+
+class CostDerivation {
+ // These costs should reflect estimated aggregate execution time in milliseconds.
+ static constexpr double ms = 1.0e-3;
+
+ // Startup cost of an operator. This is the minimal cost of an operator since it is
+ // present even if it doesn't process any input.
+ // TODO: calibrate the cost individually for each operator
+ static constexpr double kStartupCost = 0.000001;
+
+ // TODO: collection scan should depend on the width of the doc.
+ // TODO: the actual measured cost is (0.4 * ms); however, we increase it here because
+ // currently it is not possible to estimate the cost of a collection scan vs. a full index scan.
+ static constexpr double kScanIncrementalCost = 0.6 * ms;
+
+ // TODO: cost(N fields) ~ (0.55 + 0.025 * N)
+ static constexpr double kIndexScanIncrementalCost = 0.5 * ms;
+
+ // TODO: cost(N fields) ~ 0.7 + 0.19 * N
+ static constexpr double kSeekCost = 2.0 * ms;
+
+ // TODO: take the expression into account.
+ // cost(N conditions) = 0.2 + N * ???
+ static constexpr double kFilterIncrementalCost = 0.2 * ms;
+ // TODO: the cost of projection depends on number of fields: cost(N fields) ~ 0.1 + 0.2 * N
+ static constexpr double kEvalIncrementalCost = 2.0 * ms;
+
+ // TODO: cost(N fields) ~ 0.04 + 0.03*(N^2)
+ static constexpr double kGroupByIncrementalCost = 0.07 * ms;
+ static constexpr double kUnwindIncrementalCost = 0.03 * ms; // TODO: not yet calibrated
+ // TODO: not yet calibrated, should be at least as expensive as a filter
+ static constexpr double kBinaryJoinIncrementalCost = 0.2 * ms;
+ static constexpr double kHashJoinIncrementalCost = 0.05 * ms; // TODO: not yet calibrated
+ static constexpr double kMergeJoinIncrementalCost = 0.02 * ms; // TODO: not yet calibrated
+
+ static constexpr double kUniqueIncrementalCost = 0.7 * ms;
+
+ // TODO: implement collation cost that depends on number and size of sorted fields
+ // Based on a mix of int and str(64) fields:
+ // 1 sort field: sort_cost(N) = 1.0/10 * N * log(N)
+ // 5 sort fields: sort_cost(N) = 2.5/10 * N * log(N)
+ // 10 sort fields: sort_cost(N) = 3.0/10 * N * log(N)
+ // field_cost_coeff(F) ~ 0.75 + 0.2 * F
+ static constexpr double kCollationIncrementalCost = 2.5 * ms; // 5 fields avg
+ static constexpr double kCollationWithLimitIncrementalCost =
+ 1.0 * ms; // TODO: not yet calibrated
+
+ static constexpr double kUnionIncrementalCost = 0.02 * ms;
+
+ static constexpr double kExchangeIncrementalCost = 0.1 * ms; // TODO: not yet calibrated
+
+public:
+ CostAndCEInternal operator()(const ABT& /*n*/, const PhysicalScanNode& /*node*/) {
+ // Default estimate for scan.
+ const double collectionScanCost =
+ kStartupCost + kScanIncrementalCost * _cardinalityEstimate;
+ return {collectionScanCost, _cardinalityEstimate};
+ }
+
+ CostAndCEInternal operator()(const ABT& /*n*/, const CoScanNode& /*node*/) {
+ // Assumed to be free.
+ return {kStartupCost, _cardinalityEstimate};
+ }
+
+ CostAndCEInternal operator()(const ABT& /*n*/, const IndexScanNode& node) {
+ const double indexScanCost =
+ kStartupCost + kIndexScanIncrementalCost * _cardinalityEstimate;
+ return {indexScanCost, _cardinalityEstimate};
+ }
+
+ CostAndCEInternal operator()(const ABT& /*n*/, const SeekNode& /*node*/) {
+ // SeekNode should deliver one result via cardinality estimate override.
+ // TODO: consider using node.getProjectionMap()._fieldProjections.size() to make the cost
+ // dependent on the size of the projection
+ const double seekCost = kStartupCost + kSeekCost * _cardinalityEstimate;
+ return {seekCost, _cardinalityEstimate};
+ }
+
+ CostAndCEInternal operator()(const ABT& /*n*/, const MemoLogicalDelegatorNode& node) {
+ const LogicalProps& childLogicalProps =
+ _memo.getGroup(node.getGroupId())._logicalProperties;
+ // Notice that unlike all physical nodes, this logical node takes its cardinality directly
+ // from the memo group's logical properties, ignoring _cardinalityEstimate.
+ CEType baseCE = getPropertyConst<CardinalityEstimate>(childLogicalProps).getEstimate();
+
+ if (hasProperty<IndexingRequirement>(_physProps)) {
+ const auto& indexingReq = getPropertyConst<IndexingRequirement>(_physProps);
+ if (indexingReq.getIndexReqTarget() == IndexReqTarget::Seek) {
+ // If we are performing a seek, normalize against the scan group cardinality.
+ const GroupIdType scanGroupId =
+ getPropertyConst<IndexingAvailability>(childLogicalProps).getScanGroupId();
+ if (scanGroupId == node.getGroupId()) {
+ baseCE = 1.0;
+ } else {
+ const CEType scanGroupCE = getPropertyConst<CardinalityEstimate>(
+ _memo.getGroup(scanGroupId)._logicalProperties)
+ .getEstimate();
+ if (scanGroupCE > 0.0) {
+ baseCE /= scanGroupCE;
+ }
+ }
+ }
+ }
+
+ return {0.0, getAdjustedCE(baseCE, _physProps)};
+ }
+
+ CostAndCEInternal operator()(const ABT& /*n*/, const MemoPhysicalDelegatorNode& /*node*/) {
+ uasserted(6624040, "Should not be costing physical delegator nodes.");
+ }
+
+ CostAndCEInternal operator()(const ABT& /*n*/, const FilterNode& node) {
+ CostAndCEInternal childResult = deriveChild(node.getChild(), 0);
+ double filterCost = childResult._cost;
+ if (!node.getFilter().is<Constant>() && !node.getFilter().is<Variable>()) {
+ // Non-trivial filter.
+ filterCost += kStartupCost + kFilterIncrementalCost * childResult._ce;
+ }
+ return {filterCost, _cardinalityEstimate};
+ }
+
+ CostAndCEInternal operator()(const ABT& /*n*/, const EvaluationNode& node) {
+ CostAndCEInternal childResult = deriveChild(node.getChild(), 0);
+ double evalCost = childResult._cost;
+ if (!node.getProjection().is<Constant>() && !node.getProjection().is<Variable>()) {
+ // Non-trivial projection.
+ evalCost += kStartupCost + kEvalIncrementalCost * childResult._ce;
+ }
+ return {evalCost, childResult._ce};
+ }
+
+ CostAndCEInternal operator()(const ABT& /*n*/, const BinaryJoinNode& node) {
+ CostAndCEInternal leftChildResult = deriveChild(node.getLeftChild(), 0);
+ CostAndCEInternal rightChildResult = deriveChild(node.getRightChild(), 1);
+ const double joinCost = kStartupCost +
+ kBinaryJoinIncrementalCost * (leftChildResult._ce + rightChildResult._ce) +
+ leftChildResult._cost + rightChildResult._cost;
+ return {joinCost, _cardinalityEstimate};
+ }
+
+ CostAndCEInternal operator()(const ABT& /*n*/, const HashJoinNode& node) {
+ CostAndCEInternal leftChildResult = deriveChild(node.getLeftChild(), 0);
+ CostAndCEInternal rightChildResult = deriveChild(node.getRightChild(), 1);
+
+ // TODO: distinguish build side and probe side.
+ const double hashJoinCost = kStartupCost +
+ kHashJoinIncrementalCost * (leftChildResult._ce + rightChildResult._ce) +
+ leftChildResult._cost + rightChildResult._cost;
+ return {hashJoinCost, _cardinalityEstimate};
+ }
+
+ CostAndCEInternal operator()(const ABT& /*n*/, const MergeJoinNode& node) {
+ CostAndCEInternal leftChildResult = deriveChild(node.getLeftChild(), 0);
+ CostAndCEInternal rightChildResult = deriveChild(node.getRightChild(), 1);
+
+ const double mergeJoinCost = kStartupCost +
+ kMergeJoinIncrementalCost * (leftChildResult._ce + rightChildResult._ce) +
+ leftChildResult._cost + rightChildResult._cost;
+
+ return {mergeJoinCost, _cardinalityEstimate};
+ }
+
+ CostAndCEInternal operator()(const ABT& /*n*/, const UnionNode& node) {
+ const ABTVector& children = node.nodes();
+ // A UnionNode with a single child is optimized away before lowering, so its
+ // cost is simply the cost of that child.
+ if (children.size() == 1) {
+ CostAndCEInternal childResult = deriveChild(children[0], 0);
+ return {childResult._cost, _cardinalityEstimate};
+ }
+
+ double totalCost = kStartupCost;
+ // The cost is the sum of the costs of its children and the cost to union each child.
+ for (size_t childIdx = 0; childIdx < children.size(); childIdx++) {
+ CostAndCEInternal childResult = deriveChild(children[childIdx], childIdx);
+ const double childCost =
+ childResult._cost + (childIdx > 0 ? kUnionIncrementalCost * childResult._ce : 0);
+ totalCost += childCost;
+ }
+ return {totalCost, _cardinalityEstimate};
+ }
+
+ CostAndCEInternal operator()(const ABT& /*n*/, const GroupByNode& node) {
+ CostAndCEInternal childResult = deriveChild(node.getChild(), 0);
+ double groupByCost = kStartupCost;
+
+ // TODO: for now pretend global group by is free.
+ if (node.getType() == GroupNodeType::Global) {
+ groupByCost += childResult._cost;
+ } else {
+ // TODO: consider RepetitionEstimate since this is a stateful operation.
+ groupByCost += kGroupByIncrementalCost * childResult._ce + childResult._cost;
+ }
+ return {groupByCost, _cardinalityEstimate};
+ }
+
+ CostAndCEInternal operator()(const ABT& /*n*/, const UnwindNode& node) {
+ CostAndCEInternal childResult = deriveChild(node.getChild(), 0);
+ // Unwind's cost probably depends mostly on its output size.
+ const double unwindCost = kUnwindIncrementalCost * _cardinalityEstimate + childResult._cost;
+ return {unwindCost, _cardinalityEstimate};
+ }
+
+ CostAndCEInternal operator()(const ABT& /*n*/, const UniqueNode& node) {
+ CostAndCEInternal childResult = deriveChild(node.getChild(), 0);
+ const double uniqueCost =
+ kStartupCost + kUniqueIncrementalCost * childResult._ce + childResult._cost;
+ return {uniqueCost, _cardinalityEstimate};
+ }
+
+ CostAndCEInternal operator()(const ABT& /*n*/, const CollationNode& node) {
+ CostAndCEInternal childResult = deriveChild(node.getChild(), 0);
+ // TODO: consider RepetitionEstimate since this is a stateful operation.
+
+ double logFactor = childResult._ce;
+ double incrConst = kCollationIncrementalCost;
+ if (hasProperty<LimitSkipRequirement>(_physProps)) {
+ if (auto limit = getPropertyConst<LimitSkipRequirement>(_physProps).getAbsoluteLimit();
+ limit < logFactor) {
+ logFactor = limit;
+ incrConst = kCollationWithLimitIncrementalCost;
+ }
+ }
+
+ // Notice that log2(x) < 0 for any x < 1, and log2(1) = 0. Generally it makes sense that
+ // there is no cost to sort 1 document, so the only cost left is the startup cost.
+ const double sortCost = kStartupCost + childResult._cost +
+ ((logFactor <= 1.0)
+ ? 0.0
+ // TODO: The cost formula below is based on 1 field, mix of int and str. Instead we
+ // have to take into account the number and size of sorted fields.
+ : incrConst * childResult._ce * std::log2(logFactor));
+ return {sortCost, _cardinalityEstimate};
+ }
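+
+ // Worked example (illustrative): with childResult._ce = 1024 and no limit,
+ // the incremental term is kCollationIncrementalCost * 1024 * log2(1024),
+ // i.e. the coefficient times 10240.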
+
+ CostAndCEInternal operator()(const ABT& /*n*/, const LimitSkipNode& node) {
+ // Assumed to be free.
+ CostAndCEInternal childResult = deriveChild(node.getChild(), 0);
+ const double limitCost = kStartupCost + childResult._cost;
+ return {limitCost, _cardinalityEstimate};
+ }
+
+ CostAndCEInternal operator()(const ABT& /*n*/, const ExchangeNode& node) {
+ CostAndCEInternal childResult = deriveChild(node.getChild(), 0);
+ double localCost = kStartupCost + kExchangeIncrementalCost * _cardinalityEstimate;
+
+ switch (node.getProperty().getDistributionAndProjections()._type) {
+ case DistributionType::Replicated:
+ localCost *= 2.0;
+ break;
+
+ case DistributionType::HashPartitioning:
+ case DistributionType::RangePartitioning:
+ localCost *= 1.1;
+ break;
+
+ default:
+ break;
+ }
+
+ return {localCost + childResult._cost, _cardinalityEstimate};
+ }
+
+ CostAndCEInternal operator()(const ABT& /*n*/, const RootNode& node) {
+ return deriveChild(node.getChild(), 0);
+ }
+
+ /**
+ * Other ABT types.
+ */
+ template <typename T, typename... Ts>
+ CostAndCEInternal operator()(const ABT& /*n*/, const T& /*node*/, Ts&&...) {
+ static_assert(!canBePhysicalNode<T>(), "Physical node must implement its cost derivation.");
+ return {0.0, 0.0};
+ }
+
+ static CostAndCEInternal derive(const Memo& memo,
+ const PhysProps& physProps,
+ const ABT::reference_type physNodeRef,
+ const ChildPropsType& childProps,
+ const NodeCEMap& nodeCEMap) {
+ CostAndCEInternal result =
+ deriveInternal(memo, physProps, physNodeRef, childProps, nodeCEMap);
+
+ switch (getPropertyConst<DistributionRequirement>(physProps)
+ .getDistributionAndProjections()
+ ._type) {
+ case DistributionType::Centralized:
+ case DistributionType::Replicated:
+ break;
+
+ case DistributionType::RoundRobin:
+ case DistributionType::HashPartitioning:
+ case DistributionType::RangePartitioning:
+ case DistributionType::UnknownPartitioning:
+ result._cost /= memo.getMetadata()._numberOfPartitions;
+ break;
+
+ default:
+ MONGO_UNREACHABLE;
+ }
+
+ return result;
+ }
+
+private:
+ CostDerivation(const Memo& memo,
+ const CEType ce,
+ const PhysProps& physProps,
+ const ChildPropsType& childProps,
+ const NodeCEMap& nodeCEMap)
+ : _memo(memo),
+ _physProps(physProps),
+ _cardinalityEstimate(getAdjustedCE(ce, _physProps)),
+ _childProps(childProps),
+ _nodeCEMap(nodeCEMap) {}
+
+ static CostAndCEInternal deriveInternal(const Memo& memo,
+ const PhysProps& physProps,
+ const ABT::reference_type physNodeRef,
+ const ChildPropsType& childProps,
+ const NodeCEMap& nodeCEMap) {
+ auto it = nodeCEMap.find(physNodeRef.cast<Node>());
+ bool found = (it != nodeCEMap.cend());
+ uassert(8423330,
+ "Only MemoLogicalDelegatorNode can be missing from nodeCEMap.",
+ found || physNodeRef.is<MemoLogicalDelegatorNode>());
+ const CEType ce = (found ? it->second : 0.0);
+
+ CostDerivation instance(memo, ce, physProps, childProps, nodeCEMap);
+ CostAndCEInternal costCEestimates = physNodeRef.visit(instance);
+ return costCEestimates;
+ }
+
+ CostAndCEInternal deriveChild(const ABT& child, const size_t childIndex) {
+ PhysProps physProps = _childProps.empty() ? _physProps : _childProps.at(childIndex).second;
+ return deriveInternal(_memo, physProps, child.ref(), {}, _nodeCEMap);
+ }
+
+ static CEType getAdjustedCE(CEType baseCE, const PhysProps& physProps) {
+ CEType result = baseCE;
+
+ // First: correct for un-enforced limit.
+ if (hasProperty<LimitSkipRequirement>(physProps)) {
+ const auto limit = getPropertyConst<LimitSkipRequirement>(physProps).getAbsoluteLimit();
+ if (result > limit) {
+ result = limit;
+ }
+ }
+
+ // Second: correct for enforced limit.
+ if (hasProperty<LimitEstimate>(physProps)) {
+ const auto limit = getPropertyConst<LimitEstimate>(physProps).getEstimate();
+ if (result > limit) {
+ result = limit;
+ }
+ }
+
+ // Third: correct for repetition.
+ if (hasProperty<RepetitionEstimate>(physProps)) {
+ result *= getPropertyConst<RepetitionEstimate>(physProps).getEstimate();
+ }
+
+ return result;
+ }
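+
+ // Worked example (illustrative): a base CE of 500 under an un-enforced
+ // limit of 20 is clamped to 20; a RepetitionEstimate of 3 then scales it to
+ // 60, since the subtree is expected to be re-opened three times.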
+
+ // We don't own this.
+ const Memo& _memo;
+ const PhysProps& _physProps;
+ const CEType _cardinalityEstimate;
+ const ChildPropsType& _childProps;
+ const NodeCEMap& _nodeCEMap;
+};
+
+CostAndCE DefaultCosting::deriveCost(const Memo& memo,
+ const PhysProps& physProps,
+ const ABT::reference_type physNodeRef,
+ const ChildPropsType& childProps,
+ const NodeCEMap& nodeCEMap) const {
+ const CostAndCEInternal result =
+ CostDerivation::derive(memo, physProps, physNodeRef, childProps, nodeCEMap);
+ return {CostType::fromDouble(result._cost), result._ce};
+}
+
+} // namespace mongo::optimizer::cascades
diff --git a/src/mongo/db/query/optimizer/cascades/cost_derivation.h b/src/mongo/db/query/optimizer/cascades/cost_derivation.h
new file mode 100644
index 00000000000..2d85db5f63f
--- /dev/null
+++ b/src/mongo/db/query/optimizer/cascades/cost_derivation.h
@@ -0,0 +1,49 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include "mongo/db/query/optimizer/cascades/interfaces.h"
+#include "mongo/db/query/optimizer/cascades/memo.h"
+
+namespace mongo::optimizer::cascades {
+
+/**
+ * Default costing for physical nodes with logical delegator (not-yet-optimized) inputs.
+ */
+class DefaultCosting : public CostingInterface {
+public:
+ CostAndCE deriveCost(const Memo& memo,
+ const properties::PhysProps& physProps,
+ ABT::reference_type physNodeRef,
+ const ChildPropsType& childProps,
+ const NodeCEMap& nodeCEMap) const override final;
+};
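+
+// Minimal usage sketch (hypothetical driver code; in the optimizer the memo, properties, and CE
+// map come from the Cascades search):
+//   DefaultCosting costing;
+//   CostAndCE result = costing.deriveCost(memo, physProps, physNodeRef, childProps, nodeCEMap);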
+
+} // namespace mongo::optimizer::cascades
diff --git a/src/mongo/db/query/optimizer/cascades/enforcers.cpp b/src/mongo/db/query/optimizer/cascades/enforcers.cpp
new file mode 100644
index 00000000000..ba5413126bf
--- /dev/null
+++ b/src/mongo/db/query/optimizer/cascades/enforcers.cpp
@@ -0,0 +1,269 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/db/query/optimizer/cascades/enforcers.h"
+
+#include "mongo/db/query/optimizer/utils/memo_utils.h"
+
+namespace mongo::optimizer::cascades {
+
+using namespace properties;
+
+// Maximum limit we consider implementing as part of the sort stage (via a min-heap internally).
+static constexpr int64_t kMaxLimitForSort = 100;
+
+static bool isDistributionCentralizedOrReplicated(const PhysProps& physProps) {
+ switch (getPropertyConst<DistributionRequirement>(physProps)
+ .getDistributionAndProjections()
+ ._type) {
+ case DistributionType::Centralized:
+ case DistributionType::Replicated:
+ return true;
+
+ default:
+ return false;
+ }
+}
+
+/**
+ * Checks if we are trying to satisfy the requirement without using the entire collection: we are
+ * either aiming for a covered index, or for a seek.
+ */
+static bool hasIncompleteScanIndexingRequirement(const PhysProps& physProps) {
+ return hasProperty<IndexingRequirement>(physProps) &&
+ getPropertyConst<IndexingRequirement>(physProps).getIndexReqTarget() !=
+ IndexReqTarget::Complete;
+}
+
+class PropEnforcerVisitor {
+public:
+ PropEnforcerVisitor(const GroupIdType groupId,
+ const Metadata& metadata,
+ PrefixId& prefixId,
+ PhysRewriteQueue& queue,
+ const PhysProps& physProps,
+ const LogicalProps& logicalProps)
+ : _groupId(groupId),
+ _metadata(metadata),
+ _prefixId(prefixId),
+ _queue(queue),
+ _physProps(physProps),
+ _logicalProps(logicalProps) {}
+
+ void operator()(const PhysProperty&, const CollationRequirement& prop) {
+ if (hasIncompleteScanIndexingRequirement(_physProps)) {
+ // If we have indexing requirements, we do not enforce collation separately.
+ // It will be satisfied as part of the index collation.
+ return;
+ }
+
+ PhysProps childProps = _physProps;
+ removeProperty<CollationRequirement>(childProps);
+ addProjectionsToProperties(childProps, prop.getAffectedProjectionNames());
+
+ // TODO: also remove RepetitionEstimate if the subtree does not use bound variables.
+ removeProperty<LimitEstimate>(childProps);
+
+ if (hasProperty<LimitSkipRequirement>(_physProps)) {
+ const auto& limitSkipReq = getPropertyConst<LimitSkipRequirement>(_physProps);
+ if (prop.hasClusteredOp() || limitSkipReq.getSkip() != 0 ||
+ limitSkipReq.getLimit() > kMaxLimitForSort) {
+ // We cannot enforce collation+skip or collation+large limit.
+ return;
+ }
+
+            // We can satisfy both the collation and the limit-skip requirement. During lowering,
+            // the physical properties will indicate the presence of the limit-skip, and thus we
+            // set the limit on the SBE stage.
+ removeProperty<LimitSkipRequirement>(childProps);
+ }
+
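+        // The enforcer fragment is a CollationNode over a logical delegator for this group;
+        // sketch: Collation [spec] over MemoLogicalDelegator [groupId], optimized under the
+        // relaxed child properties computed above.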
+ ABT enforcer = make<CollationNode>(prop, make<MemoLogicalDelegatorNode>(_groupId));
+ optimizeChild<CollationNode>(
+ _queue, kDefaultPriority, std::move(enforcer), std::move(childProps));
+ }
+
+ void operator()(const PhysProperty&, const LimitSkipRequirement& prop) {
+ if (hasIncompleteScanIndexingRequirement(_physProps)) {
+ // If we have indexing requirements, we do not enforce limit skip.
+ return;
+ }
+ if (!isDistributionCentralizedOrReplicated(_physProps)) {
+ // Can only enforce limit-skip under centralized or replicated distribution.
+ return;
+ }
+
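+        // Relax the child properties: the child no longer needs to enforce the limit-skip
+        // itself, but we pass down a LimitEstimate so costing can account for the early-out.
+        // Illustrative example (assuming the absolute limit folds the skip into the limit):
+        // limit 10 with skip 5 yields a LimitEstimate of 15.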
+ PhysProps childProps = _physProps;
+ removeProperty<LimitSkipRequirement>(childProps);
+ setPropertyOverwrite<LimitEstimate>(
+ childProps, LimitEstimate{static_cast<CEType>(prop.getAbsoluteLimit())});
+
+ ABT enforcer = make<LimitSkipNode>(prop, make<MemoLogicalDelegatorNode>(_groupId));
+ optimizeChild<LimitSkipNode>(
+ _queue, kDefaultPriority, std::move(enforcer), std::move(childProps));
+ }
+
+ void operator()(const PhysProperty&, const DistributionRequirement& prop) {
+ if (!_metadata.isParallelExecution()) {
+ // We're running in serial mode.
+ return;
+ }
+ if (prop.getDisableExchanges()) {
+ // We cannot change distributions.
+ return;
+ }
+ if (hasProperty<IndexingRequirement>(_physProps) &&
+ getPropertyConst<IndexingRequirement>(_physProps).getIndexReqTarget() ==
+ IndexReqTarget::Seek) {
+ // Cannot change distributions while under Seek requirement.
+ return;
+ }
+
+ if (prop.getDistributionAndProjections()._type == DistributionType::UnknownPartitioning) {
+ // Cannot exchange into unknown partitioning.
+ return;
+ }
+
+ // TODO: consider hash partition on RID if under IndexingAvailability.
+
+ const bool hasCollation = hasProperty<CollationRequirement>(_physProps);
+ if (hasCollation) {
+            // For now we cannot enforce if we have a collation requirement.
+ // TODO: try enforcing into partitioning distributions which form prefixes over the
+ // collation, with ordered exchange.
+ return;
+ }
+
+ const auto& distributions =
+ getPropertyConst<DistributionAvailability>(_logicalProps).getDistributionSet();
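+        // For each distribution the group can deliver, try to enforce the required distribution
+        // by placing an Exchange on top; sketch: Exchange [required] over
+        // MemoLogicalDelegator [groupId], with the child re-optimized under the alternative
+        // distribution (exchanges are disabled below to avoid enforcing recursively).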
+ for (const auto& distribution : distributions) {
+ if (distribution == prop.getDistributionAndProjections()) {
+ // Same distribution.
+ continue;
+ }
+ if (distribution._type == DistributionType::Replicated) {
+ // Cannot switch "away" from replicated distribution.
+ continue;
+ }
+
+ PhysProps childProps = _physProps;
+ setPropertyOverwrite<DistributionRequirement>(childProps, distribution);
+
+ addProjectionsToProperties(childProps, distribution._projectionNames);
+ getProperty<DistributionRequirement>(childProps).setDisableExchanges(true);
+
+ ABT enforcer = make<ExchangeNode>(prop, make<MemoLogicalDelegatorNode>(_groupId));
+ optimizeChild<ExchangeNode>(
+ _queue, kDefaultPriority, std::move(enforcer), std::move(childProps));
+ }
+ }
+
+ void operator()(const PhysProperty&, const ProjectionRequirement& prop) {
+ const ProjectionNameSet& availableProjections =
+ getPropertyConst<ProjectionAvailability>(_logicalProps).getProjections();
+
+ // Verify we can satisfy the required projections using the logical projections.
+ for (const ProjectionName& projectionName : prop.getProjections().getVector()) {
+ if (availableProjections.find(projectionName) == availableProjections.cend()) {
+ uasserted(6624100, "Cannot satisfy all projections");
+ }
+ }
+ }
+
+ void operator()(const PhysProperty&, const IndexingRequirement& prop) {
+ if (prop.getIndexReqTarget() != IndexReqTarget::Complete) {
+ return;
+ }
+
+ uassert(6624101,
+ "IndexingRequirement without indexing availability",
+ hasProperty<IndexingAvailability>(_logicalProps));
+ const IndexingAvailability& indexingAvailability =
+ getPropertyConst<IndexingAvailability>(_logicalProps);
+
+ // TODO: consider left outer joins. We can propagate rid from the outer side.
+ if (_metadata._scanDefs.at(indexingAvailability.getScanDefName()).getIndexDefs().empty()) {
+ // No indexes on the collection.
+ return;
+ }
+
+ const ProjectionNameOrderPreservingSet& requiredProjections =
+ getPropertyConst<ProjectionRequirement>(_physProps).getProjections();
+ const ProjectionName& scanProjection = indexingAvailability.getScanProjection();
+ const bool requiresScanProjection = requiredProjections.find(scanProjection).second;
+
+ if (!requiresScanProjection) {
+            // Try indexScanOnly (covered index) if we do not require the scan projection.
+ PhysProps newProps = _physProps;
+ setPropertyOverwrite<IndexingRequirement>(newProps,
+ {IndexReqTarget::Index,
+ prop.getNeedsRID(),
+ prop.getDedupRID(),
+ prop.getSatisfiedPartialIndexesGroupId()});
+
+ optimizeUnderNewProperties(_queue,
+ kDefaultPriority,
+ make<MemoLogicalDelegatorNode>(_groupId),
+ std::move(newProps));
+ }
+ }
+
+ void operator()(const PhysProperty&, const RepetitionEstimate& prop) {
+ // Noop. We do not currently enforce this property. It only affects costing.
+ // TODO: consider materializing the subtree if we estimate a lot of repetitions.
+ }
+
+ void operator()(const PhysProperty&, const LimitEstimate& prop) {
+ // Noop. We do not currently enforce this property. It only affects costing.
+ }
+
+private:
+ const GroupIdType _groupId;
+
+    // We don't own any of these.
+ const Metadata& _metadata;
+ PrefixId& _prefixId;
+ PhysRewriteQueue& _queue;
+ const PhysProps& _physProps;
+ const LogicalProps& _logicalProps;
+};
+
+void addEnforcers(const GroupIdType groupId,
+ const Metadata& metadata,
+ PrefixId& prefixId,
+ PhysRewriteQueue& queue,
+ const PhysProps& physProps,
+ const LogicalProps& logicalProps) {
+ PropEnforcerVisitor visitor(groupId, metadata, prefixId, queue, physProps, logicalProps);
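+    // Each physical property dispatches via visit() to the matching PropEnforcerVisitor
+    // overload, which may queue zero or more enforcer rewrites for this group.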
+ for (const auto& entry : physProps) {
+ entry.second.visit(visitor);
+ }
+}
+
+} // namespace mongo::optimizer::cascades
diff --git a/src/mongo/db/query/optimizer/cascades/enforcers.h b/src/mongo/db/query/optimizer/cascades/enforcers.h
new file mode 100644
index 00000000000..6ad1ace8bce
--- /dev/null
+++ b/src/mongo/db/query/optimizer/cascades/enforcers.h
@@ -0,0 +1,47 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include "mongo/db/query/optimizer/cascades/rewrite_queues.h"
+#include "mongo/db/query/optimizer/utils/utils.h"
+
+namespace mongo::optimizer::cascades {
+
+/**
+ * Adds property enforcement rules for particular group and physical properties.
+ */
+void addEnforcers(GroupIdType groupId,
+ const Metadata& metadata,
+ PrefixId& prefixId,
+ PhysRewriteQueue& queue,
+ const properties::PhysProps& physProps,
+ const properties::LogicalProps& logicalProps);
+
+} // namespace mongo::optimizer::cascades
diff --git a/src/mongo/db/query/optimizer/cascades/implementers.cpp b/src/mongo/db/query/optimizer/cascades/implementers.cpp
new file mode 100644
index 00000000000..54dc07eca53
--- /dev/null
+++ b/src/mongo/db/query/optimizer/cascades/implementers.cpp
@@ -0,0 +1,1441 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/db/query/optimizer/cascades/implementers.h"
+
+#include "mongo/db/query/optimizer/utils/memo_utils.h"
+
+namespace mongo::optimizer::cascades {
+using namespace properties;
+
+template <class P, class T>
+static bool propertyAffectsProjections(const PhysProps& props, const T& projections) {
+ if (!hasProperty<P>(props)) {
+ return false;
+ }
+
+ const ProjectionNameSet& propProjections =
+ getPropertyConst<P>(props).getAffectedProjectionNames();
+ for (const ProjectionName& projectionName : projections) {
+ if (propProjections.find(projectionName) != propProjections.cend()) {
+ return true;
+ }
+ }
+
+ return false;
+}
+
+template <class P>
+static bool propertyAffectsProjection(const PhysProps& props,
+ const ProjectionName& projectionName) {
+ return propertyAffectsProjections<P>(props, ProjectionNameVector{projectionName});
+}
+
+/**
+ * Implement physical nodes based on existing logical nodes.
+ */
+class ImplementationVisitor {
+public:
+ void operator()(const ABT& /*n*/, const ScanNode& node) {
+ if (hasProperty<LimitSkipRequirement>(_physProps)) {
+ // Cannot satisfy limit-skip.
+ return;
+ }
+ if (hasProperty<CollationRequirement>(_physProps)) {
+ // Regular scan cannot satisfy any collation requirement.
+ // TODO: consider rid?
+ return;
+ }
+
+ const auto& indexReq = getPropertyConst<IndexingRequirement>(_physProps);
+ const IndexReqTarget indexReqTarget = indexReq.getIndexReqTarget();
+ switch (indexReqTarget) {
+ case IndexReqTarget::Index:
+                // A plain collection scan cannot satisfy an index-only requirement.
+ return;
+
+ case IndexReqTarget::Seek:
+ if (_hints._disableIndexes == DisableIndexOptions::DisableAll) {
+ return;
+ }
+ uassert(6624102, "RID projection is required for seek", indexReq.getNeedsRID());
+ // Fall through to code below.
+ break;
+
+ case IndexReqTarget::Complete:
+ if (_hints._disableScan) {
+ return;
+ }
+ // Fall through to code below.
+ break;
+
+ default:
+ MONGO_UNREACHABLE;
+ }
+
+ // Handle complete indexing requirement.
+ bool canUseParallelScan = false;
+ if (!distributionsCompatible(
+ indexReqTarget,
+ _memo.getMetadata()._scanDefs.at(node.getScanDefName()).getDistributionAndPaths(),
+ node.getProjectionName(),
+ _logicalProps,
+ {},
+ canUseParallelScan)) {
+ return;
+ }
+
+ FieldProjectionMap fieldProjectionMap;
+ for (const ProjectionName& required :
+ getPropertyConst<ProjectionRequirement>(_physProps).getAffectedProjectionNames()) {
+ if (required == node.getProjectionName()) {
+ fieldProjectionMap._rootProjection = node.getProjectionName();
+ } else {
+                // A regular scan node can satisfy requirements only via its root projection; it
+                // does not bind individual fields.
+ return;
+ }
+ }
+
+ const ProjectionName& ridProjName = _ridProjections.at(
+ getPropertyConst<IndexingAvailability>(_logicalProps).getScanDefName());
+ if (indexReqTarget == IndexReqTarget::Seek) {
+ NodeCEMap nodeCEMap;
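+            // A Seek fetches at most one document for a given RID, so we cap it below with
+            // LimitSkip [limit: 1, skip: 0] and pin the CE of both nodes to 1.0 in the local
+            // CE map.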
+
+ ABT physicalSeek =
+ make<SeekNode>(ridProjName, std::move(fieldProjectionMap), node.getScanDefName());
+ // If optimizing a Seek, override CE to 1.0.
+ nodeCEMap.emplace(physicalSeek.cast<Node>(), 1.0);
+
+ ABT limitSkip =
+ make<LimitSkipNode>(LimitSkipRequirement{1, 0}, std::move(physicalSeek));
+ nodeCEMap.emplace(limitSkip.cast<Node>(), 1.0);
+
+ optimizeChildrenNoAssert(
+ _queue, kDefaultPriority, std::move(limitSkip), {}, std::move(nodeCEMap));
+ } else {
+ if (indexReq.getNeedsRID()) {
+ fieldProjectionMap._ridProjection = ridProjName;
+ }
+ ABT physicalScan = make<PhysicalScanNode>(
+ std::move(fieldProjectionMap), node.getScanDefName(), canUseParallelScan);
+ optimizeChild<PhysicalScanNode>(_queue, kDefaultPriority, std::move(physicalScan));
+ }
+ }
+
+ void operator()(const ABT& n, const ValueScanNode& node) {
+ if (hasProperty<LimitSkipRequirement>(_physProps)) {
+ // Cannot satisfy limit-skip.
+ return;
+ }
+ if (hasProperty<CollationRequirement>(_physProps)) {
+ // Cannot satisfy any collation requirement.
+ return;
+ }
+
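+        // Lower the value scan to a plan of the following shape, built bottom-up below
+        // (non-empty array case): CoScan, then LimitSkip [1, 0], then an Evaluation binding the
+        // array constant, then an Unwind, then one Evaluation per required projection invoking
+        // getElement(). The empty-array case instead places LimitSkip [0, 0] and binds Nothing
+        // for each required projection.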
+ NodeCEMap nodeCEMap;
+ ABT physNode = make<CoScanNode>();
+ const auto& requiredProjections =
+ getPropertyConst<ProjectionRequirement>(_physProps).getProjections();
+
+ if (node.getArraySize() == 0) {
+ nodeCEMap.emplace(physNode.cast<Node>(), 0.0);
+
+ physNode =
+ make<LimitSkipNode>(properties::LimitSkipRequirement{0, 0}, std::move(physNode));
+ nodeCEMap.emplace(physNode.cast<Node>(), 0.0);
+
+ for (const ProjectionName& projectionName : requiredProjections.getVector()) {
+ physNode =
+ make<EvaluationNode>(projectionName, Constant::nothing(), std::move(physNode));
+ nodeCEMap.emplace(physNode.cast<Node>(), 0.0);
+ }
+ } else {
+ nodeCEMap.emplace(physNode.cast<Node>(), 1.0);
+
+ physNode =
+ make<LimitSkipNode>(properties::LimitSkipRequirement{1, 0}, std::move(physNode));
+ nodeCEMap.emplace(physNode.cast<Node>(), 1.0);
+
+ const ProjectionName valueScanProj = _prefixId.getNextId("valueScan");
+ physNode =
+ make<EvaluationNode>(valueScanProj, node.getValueArray(), std::move(physNode));
+ nodeCEMap.emplace(physNode.cast<Node>(), 1.0);
+
+ // Unwind the combined array constant and pick an element for each required projection
+ // in sequence.
+ physNode = make<UnwindNode>(valueScanProj,
+ _prefixId.getNextId("valueScanPid"),
+ false /*retainNonArrays*/,
+ std::move(physNode));
+ nodeCEMap.emplace(physNode.cast<Node>(), node.getArraySize());
+
+            // Iterate over the bound projections here, as opposed to the required projections,
+            // since the array elements are ordered accordingly.
+ const ProjectionNameVector& boundProjNames = node.binder().names();
+ for (size_t i = 0; i < boundProjNames.size(); i++) {
+ const ProjectionName& boundProjName = boundProjNames.at(i);
+ if (requiredProjections.find(boundProjName).second) {
+ physNode = make<EvaluationNode>(
+ boundProjName,
+ make<FunctionCall>(
+ "getElement",
+ makeSeq(make<Variable>(valueScanProj), Constant::int32(i))),
+ std::move(physNode));
+ nodeCEMap.emplace(physNode.cast<Node>(), node.getArraySize());
+ }
+ }
+ }
+
+ optimizeChildrenNoAssert(
+ _queue, kDefaultPriority, std::move(physNode), {}, std::move(nodeCEMap));
+ }
+
+ void operator()(const ABT& /*n*/, const MemoLogicalDelegatorNode& /*node*/) {
+ uasserted(6624041,
+ "Must not have logical delegator nodes in the list of the logical nodes");
+ }
+
+ void operator()(const ABT& n, const FilterNode& node) {
+ if (hasProperty<LimitSkipRequirement>(_physProps)) {
+ // We cannot satisfy here.
+ return;
+ }
+
+ VariableNameSetType references = collectVariableReferences(n);
+ if (checkIntroducesScanProjectionUnderIndexOnly(references)) {
+            // Reject if we are under index-only requirements and this node introduces a
+            // dependency on the scan projection.
+ return;
+ }
+
+ PhysProps newProps = _physProps;
+ // Add projections we depend on to the requirement.
+ addProjectionsToProperties(newProps, std::move(references));
+ getProperty<DistributionRequirement>(newProps).setDisableExchanges(true);
+
+ ABT physicalFilter = n;
+ optimizeChild<FilterNode>(
+ _queue, kDefaultPriority, std::move(physicalFilter), std::move(newProps));
+ }
+
+ void operator()(const ABT& n, const EvaluationNode& node) {
+ const ProjectionName& projectionName = node.getProjectionName();
+
+ if (const auto* varPtr = node.getProjection().cast<Variable>(); varPtr != nullptr) {
+ // Special case of evaluation node: rebinds to a different variable.
+ const ProjectionName& newProjName = varPtr->name();
+ PhysProps newProps = _physProps;
+
+ {
+ // Update required projections.
+ auto& reqProjections =
+ getProperty<ProjectionRequirement>(newProps).getProjections();
+ reqProjections.erase(projectionName);
+ reqProjections.emplace_back(newProjName);
+ }
+
+ if (hasProperty<CollationRequirement>(newProps)) {
+ // Update the collation specification to use the input variable.
+ auto& collationReq = getProperty<CollationRequirement>(newProps);
+ for (auto& [projName, op] : collationReq.getCollationSpec()) {
+ if (projName == projectionName) {
+ projName = newProjName;
+ }
+ }
+ }
+
+ {
+                // Update the distribution specification to use the input variable.
+ auto& distribReq = getProperty<DistributionRequirement>(newProps);
+ for (auto& projName : distribReq.getDistributionAndProjections()._projectionNames) {
+ if (projName == projectionName) {
+ projName = newProjName;
+ }
+ }
+ }
+
+ ABT physicalEval = n;
+ optimizeChild<EvaluationNode>(
+ _queue, kDefaultPriority, std::move(physicalEval), std::move(newProps));
+ return;
+ }
+
+ if (propertyAffectsProjection<DistributionRequirement>(_physProps, projectionName)) {
+ // We cannot satisfy distribution on the projection we output.
+ return;
+ }
+ if (propertyAffectsProjection<CollationRequirement>(_physProps, projectionName)) {
+ // In general, we cannot satisfy collation on the projection we output.
+            // TODO: consider x = y + 1; we could propagate the collation requirement from x to y.
+ return;
+ }
+ if (!propertyAffectsProjection<ProjectionRequirement>(_physProps, projectionName)) {
+ // We do not require the projection. Do not place a physical evaluation node and
+ // continue optimizing the child.
+ optimizeUnderNewProperties(_queue, kDefaultPriority, node.getChild(), _physProps);
+ return;
+ }
+
+ // Remove our projection from requirement, and add projections we depend on to the
+ // requirement.
+ PhysProps newProps = _physProps;
+
+ VariableNameSetType references = collectVariableReferences(n);
+ if (checkIntroducesScanProjectionUnderIndexOnly(references)) {
+            // Reject if we are under index-only requirements and this node introduces a
+            // dependency on the scan projection.
+ return;
+ }
+
+ addRemoveProjectionsToProperties(
+ newProps, std::move(references), ProjectionNameVector{projectionName});
+ getProperty<DistributionRequirement>(newProps).setDisableExchanges(true);
+
+ ABT physicalEval = n;
+ optimizeChild<EvaluationNode>(
+ _queue, kDefaultPriority, std::move(physicalEval), std::move(newProps));
+ }
+
+ void operator()(const ABT& n, const SargableNode& node) {
+ if (hasProperty<LimitSkipRequirement>(_physProps)) {
+ // Cannot satisfy limit-skip.
+ return;
+ }
+
+ const IndexingAvailability& indexingAvailability =
+ getPropertyConst<IndexingAvailability>(_logicalProps);
+ if (node.getChild().cast<MemoLogicalDelegatorNode>()->getGroupId() !=
+ indexingAvailability.getScanGroupId()) {
+ // To optimize a sargable predicate, we must have the scan group as a child.
+ return;
+ }
+
+ const std::string& scanDefName = indexingAvailability.getScanDefName();
+ const auto& scanDef = _memo.getMetadata()._scanDefs.at(scanDefName);
+
+        // We do not check whether indexDefs is empty here: we want to allow evaluations to be
+        // covered via a physical scan even in the absence of indexes.
+
+ const IndexingRequirement& requirements = getPropertyConst<IndexingRequirement>(_physProps);
+ const IndexReqTarget indexReqTarget = requirements.getIndexReqTarget();
+ switch (indexReqTarget) {
+ case IndexReqTarget::Complete:
+ if (_hints._disableScan) {
+ return;
+ }
+ break;
+
+ case IndexReqTarget::Index:
+ case IndexReqTarget::Seek:
+ if (_hints._disableIndexes == DisableIndexOptions::DisableAll) {
+ return;
+ }
+ break;
+
+ default:
+ MONGO_UNREACHABLE;
+ }
+ const auto& satisfiedPartialIndexes =
+ getPropertyConst<IndexingAvailability>(
+ _memo.getGroup(requirements.getSatisfiedPartialIndexesGroupId())._logicalProperties)
+ .getSatisfiedPartialIndexes();
+
+ const bool needsRID = requirements.getNeedsRID();
+ const ProjectionName& ridProjName = _ridProjections.at(scanDefName);
+
+ const ProjectionName& scanProjectionName = indexingAvailability.getScanProjection();
+ const GroupIdType scanGroupId = indexingAvailability.getScanGroupId();
+ const LogicalProps& scanLogicalProps = _memo.getGroup(scanGroupId)._logicalProperties;
+ const CEType scanGroupCE =
+ getPropertyConst<CardinalityEstimate>(scanLogicalProps).getEstimate();
+
+ const PartialSchemaRequirements& reqMap = node.getReqMap();
+
+ if (indexReqTarget != IndexReqTarget::Index &&
+ hasProperty<CollationRequirement>(_physProps)) {
+ // PhysicalScan or Seek cannot satisfy any collation requirement.
+ // TODO: consider rid?
+ return;
+ }
+
+ for (const auto& [key, req] : reqMap) {
+ if (key.emptyPath()) {
+ // We cannot satisfy without a field.
+ return;
+ }
+ if (key._projectionName != scanProjectionName) {
+ // We can only satisfy partial schema requirements using our root projection.
+ return;
+ }
+ }
+
+ const auto& requiredProjections =
+ getPropertyConst<ProjectionRequirement>(_physProps).getProjections();
+ bool requiresRootProjection = false;
+ {
+ auto projectionsLeftToSatisfy = requiredProjections;
+ if (indexReqTarget != IndexReqTarget::Index) {
+ // Deliver root projection if required.
+ requiresRootProjection = projectionsLeftToSatisfy.erase(scanProjectionName);
+ }
+
+ for (const auto& entry : reqMap) {
+ if (entry.second.hasBoundProjectionName()) {
+                    // Project the field only if it is required.
+ const ProjectionName& projectionName = entry.second.getBoundProjectionName();
+ projectionsLeftToSatisfy.erase(projectionName);
+ }
+ }
+ if (!projectionsLeftToSatisfy.getVector().empty()) {
+ // Unknown projections remain. Reject.
+ return;
+ }
+ }
+
+ const auto& ceProperty = getPropertyConst<CardinalityEstimate>(_logicalProps);
+ const CEType currentCE = ceProperty.getEstimate();
+ const PartialSchemaKeyCE& partialSchemaKeyCEMap = ceProperty.getPartialSchemaKeyCEMap();
+
+ ProjectionRenames projectionRenames;
+ if (indexReqTarget == IndexReqTarget::Index) {
+ ProjectionCollationSpec requiredCollation;
+ if (hasProperty<CollationRequirement>(_physProps)) {
+ requiredCollation =
+ getPropertyConst<CollationRequirement>(_physProps).getCollationSpec();
+ }
+
+ // Consider all candidate indexes, and check if they satisfy the collation and
+ // distribution requirements.
+ for (const auto& [indexDefName, candidateIndexEntry] : node.getCandidateIndexMap()) {
+ const auto& indexDef = scanDef.getIndexDefs().at(indexDefName);
+ if (!indexDef.getPartialReqMap().empty() &&
+ (_hints._disableIndexes == DisableIndexOptions::DisablePartialOnly ||
+ satisfiedPartialIndexes.count(indexDefName) == 0)) {
+ // Consider only indexes for which we satisfy partial requirements.
+ continue;
+ }
+
+ {
+ bool canUseParallelScanUnused = false;
+ if (!distributionsCompatible(IndexReqTarget::Index,
+ indexDef.getDistributionAndPaths(),
+ scanProjectionName,
+ scanLogicalProps,
+ reqMap,
+ canUseParallelScanUnused)) {
+ return;
+ }
+ }
+
+ const auto availableDirections = indexSatisfiesCollation(
+ indexDef.getCollationSpec(), candidateIndexEntry, requiredCollation);
+ if (!availableDirections._forward && !availableDirections._backward) {
+ // Failed to satisfy collation.
+ continue;
+ }
+
+ uassert(6624103,
+ "Either forward or backward direction must be available.",
+ availableDirections._forward || availableDirections._backward);
+
+ auto indexProjectionMap = candidateIndexEntry._fieldProjectionMap;
+ indexProjectionMap._ridProjection = needsRID ? ridProjName : "";
+
+ {
+ // Remove unused projections from the field projection map.
+ auto& fieldProjMap = indexProjectionMap._fieldProjections;
+ for (auto it = fieldProjMap.begin(); it != fieldProjMap.end();) {
+ const ProjectionName& projName = it->second;
+ if (!requiredProjections.find(projName).second &&
+ candidateIndexEntry._residualRequirementsTempProjections.count(
+ projName) == 0) {
+ fieldProjMap.erase(it++);
+ } else {
+ it++;
+ }
+ }
+ }
+
+ CEType indexCE = currentCE;
+ ResidualRequirements residualRequirements;
+ if (!candidateIndexEntry._residualRequirements.empty()) {
+ SelectivityType residualSelectivity = 1.0;
+ SelectivityType currentSelectivity = currentCE / scanGroupCE;
+
+ for (const auto& [residualKey, residualReq] :
+ candidateIndexEntry._residualRequirements) {
+ const auto& queryKey = candidateIndexEntry._residualKeyMap.at(residualKey);
+ const CEType ce = partialSchemaKeyCEMap.at(queryKey);
+ if (scanGroupCE > 0.0) {
+ residualSelectivity *= ce / scanGroupCE;
+ }
+ residualRequirements.emplace_back(residualKey, residualReq, ce);
+ }
+ if (residualSelectivity > 0.0) {
+ indexCE = scanGroupCE * (currentSelectivity / residualSelectivity);
+ }
+ }
+
+ const auto& intervals = candidateIndexEntry._intervals;
+ ABT physNode = make<Blackhole>();
+ NodeCEMap nodeCEMap;
+
+ // TODO: consider pre-computing as part of the candidateIndexes structure.
+ const auto singularInterval = MultiKeyIntervalReqExpr::getSingularDNF(intervals);
+ const bool needsUniqueStage = singularInterval &&
+ !areMultiKeyIntervalsEqualities(*singularInterval) && indexDef.isMultiKey() &&
+ requirements.getDedupRID();
+
+ if (singularInterval) {
+ physNode =
+ make<IndexScanNode>(std::move(indexProjectionMap),
+ IndexSpecification{scanDefName,
+ indexDefName,
+ *singularInterval,
+ !availableDirections._forward});
+ nodeCEMap.emplace(physNode.cast<Node>(), indexCE);
+ } else {
+ physNode = lowerIntervals(_prefixId,
+ ridProjName,
+ std::move(indexProjectionMap),
+ scanDefName,
+ indexDefName,
+ intervals,
+ !availableDirections._forward,
+ indexCE,
+ scanGroupCE,
+ nodeCEMap);
+ }
+
+ applyProjectionRenames(projectionRenames, physNode, [&](const ABT& node) {
+ nodeCEMap.emplace(node.cast<Node>(), indexCE);
+ });
+
+ lowerPartialSchemaRequirements(
+ indexCE, scanGroupCE, residualRequirements, physNode, nodeCEMap);
+
+ if (needsUniqueStage) {
+ // Insert unique stage if we need to, after the residual requirements.
+ physNode =
+ make<UniqueNode>(ProjectionNameVector{ridProjName}, std::move(physNode));
+ nodeCEMap.emplace(physNode.cast<Node>(), currentCE);
+ }
+
+ optimizeChildrenNoAssert(
+ _queue, kDefaultPriority, std::move(physNode), {}, std::move(nodeCEMap));
+ }
+ } else {
+ bool canUseParallelScan = false;
+ if (!distributionsCompatible(indexReqTarget,
+ scanDef.getDistributionAndPaths(),
+ scanProjectionName,
+ scanLogicalProps,
+ reqMap,
+ canUseParallelScan)) {
+ return;
+ }
+
+ FieldProjectionMap fieldProjectionMap;
+ if (indexReqTarget == IndexReqTarget::Complete && needsRID) {
+ fieldProjectionMap._ridProjection = ridProjName;
+ }
+
+ ResidualRequirements residualRequirements;
+ computePhysicalScanParams(_prefixId,
+ reqMap,
+ partialSchemaKeyCEMap,
+ requiredProjections,
+ residualRequirements,
+ projectionRenames,
+ fieldProjectionMap,
+ requiresRootProjection);
+ if (requiresRootProjection) {
+ fieldProjectionMap._rootProjection = scanProjectionName;
+ }
+
+ NodeCEMap nodeCEMap;
+ ABT physNode = make<Blackhole>();
+ CEType baseCE = 0.0;
+
+ if (indexReqTarget == IndexReqTarget::Complete) {
+ baseCE = scanGroupCE;
+
+ // Return a physical scan with field map.
+ physNode = make<PhysicalScanNode>(
+ std::move(fieldProjectionMap), scanDefName, canUseParallelScan);
+ nodeCEMap.emplace(physNode.cast<Node>(), baseCE);
+ } else {
+ baseCE = 1.0;
+
+ // Try Seek with Limit 1.
+ physNode = make<SeekNode>(ridProjName, std::move(fieldProjectionMap), scanDefName);
+ nodeCEMap.emplace(physNode.cast<Node>(), baseCE);
+
+ physNode = make<LimitSkipNode>(LimitSkipRequirement{1, 0}, std::move(physNode));
+ nodeCEMap.emplace(physNode.cast<Node>(), baseCE);
+ }
+
+ applyProjectionRenames(std::move(projectionRenames), physNode, [&](const ABT& node) {
+ nodeCEMap.emplace(node.cast<Node>(), baseCE);
+ });
+
+ lowerPartialSchemaRequirements(
+ baseCE, scanGroupCE, residualRequirements, physNode, nodeCEMap);
+ optimizeChildrenNoAssert(
+ _queue, kDefaultPriority, std::move(physNode), {}, std::move(nodeCEMap));
+ }
+ }
+
+ void operator()(const ABT& /*n*/, const RIDIntersectNode& node) {
+ const auto& indexingAvailability = getPropertyConst<IndexingAvailability>(_logicalProps);
+ const std::string& scanDefName = indexingAvailability.getScanDefName();
+ {
+ const auto& scanDef = _memo.getMetadata()._scanDefs.at(scanDefName);
+ if (scanDef.getIndexDefs().empty()) {
+ // Reject if we do not have any indexes.
+ return;
+ }
+ }
+
+ if (hasProperty<LimitSkipRequirement>(_physProps)) {
+ // Cannot satisfy limit-skip.
+ return;
+ }
+
+ const IndexingRequirement& requirements = getPropertyConst<IndexingRequirement>(_physProps);
+ const bool dedupRID = requirements.getDedupRID();
+ const IndexReqTarget indexReqTarget = requirements.getIndexReqTarget();
+ if (indexReqTarget == IndexReqTarget::Seek) {
+ return;
+ }
+ const bool isIndex = indexReqTarget == IndexReqTarget::Index;
+ if (isIndex && (!node.hasLeftIntervals() || !node.hasRightIntervals())) {
+ // We need to have proper intervals on both sides.
+ return;
+ }
+
+ const auto& distribRequirement = getPropertyConst<DistributionRequirement>(_physProps);
+ const auto& distrAndProjections = distribRequirement.getDistributionAndProjections();
+ if (isIndex) {
+ switch (distrAndProjections._type) {
+ case DistributionType::UnknownPartitioning:
+ case DistributionType::RoundRobin:
+ // Cannot satisfy unknown or round-robin distributions.
+ return;
+
+ default:
+ break;
+ }
+ }
+
+ const GroupIdType leftGroupId =
+ node.getLeftChild().cast<MemoLogicalDelegatorNode>()->getGroupId();
+ const GroupIdType rightGroupId =
+ node.getRightChild().cast<MemoLogicalDelegatorNode>()->getGroupId();
+
+ const LogicalProps& leftLogicalProps = _memo.getGroup(leftGroupId)._logicalProperties;
+ const LogicalProps& rightLogicalProps = _memo.getGroup(rightGroupId)._logicalProperties;
+
+ const CEType intersectedCE =
+ getPropertyConst<CardinalityEstimate>(_logicalProps).getEstimate();
+ const CEType leftCE = getPropertyConst<CardinalityEstimate>(leftLogicalProps).getEstimate();
+ const CEType rightCE =
+ getPropertyConst<CardinalityEstimate>(rightLogicalProps).getEstimate();
+ const ProjectionNameSet& leftProjections =
+ getPropertyConst<ProjectionAvailability>(leftLogicalProps).getProjections();
+ const ProjectionNameSet& rightProjections =
+ getPropertyConst<ProjectionAvailability>(rightLogicalProps).getProjections();
+
+ // Split required projections between inner and outer side.
+ ProjectionNameOrderPreservingSet leftChildProjections;
+ ProjectionNameOrderPreservingSet rightChildProjections;
+ for (const ProjectionName& projectionName :
+ getPropertyConst<ProjectionRequirement>(_physProps).getProjections().getVector()) {
+ if (projectionName != node.getScanProjectionName() &&
+ leftProjections.count(projectionName) > 0) {
+ leftChildProjections.emplace_back(projectionName);
+ } else if (rightProjections.count(projectionName) > 0) {
+ if (isIndex && projectionName == node.getScanProjectionName()) {
+ return;
+ }
+ rightChildProjections.emplace_back(projectionName);
+ } else {
+ uasserted(6624104,
+ "Required projection must appear in either the left or the right child "
+ "projections");
+ return;
+ }
+ }
+
+ ProjectionCollationSpec collationSpec;
+ if (hasProperty<CollationRequirement>(_physProps)) {
+ collationSpec = getPropertyConst<CollationRequirement>(_physProps).getCollationSpec();
+ }
+
+ // Split collation between inner and outer side.
+ const CollationSplitResult& collationLeftRightSplit =
+ splitCollationSpec(collationSpec, leftProjections, rightProjections);
+ const CollationSplitResult& collationRightLeftSplit =
+ splitCollationSpec(collationSpec, rightProjections, leftProjections);
+
+ // We are propagating the distribution requirements to both sides.
+ PhysProps leftPhysProps = _physProps;
+ PhysProps rightPhysProps = _physProps;
+
+ // Specifically do not propagate limit-skip.
+ // TODO: handle similarly to physical join.
+ removeProperty<LimitSkipRequirement>(leftPhysProps);
+ removeProperty<LimitSkipRequirement>(rightPhysProps);
+
+ getProperty<DistributionRequirement>(leftPhysProps).setDisableExchanges(false);
+ getProperty<DistributionRequirement>(rightPhysProps).setDisableExchanges(false);
+
+ const ProjectionName& ridProjName = _ridProjections.at(
+ getPropertyConst<IndexingAvailability>(_logicalProps).getScanDefName());
+ setPropertyOverwrite<IndexingRequirement>(
+ leftPhysProps,
+ {IndexReqTarget::Index,
+ true /*needRID*/,
+ !isIndex && dedupRID,
+ requirements.getSatisfiedPartialIndexesGroupId()});
+ setPropertyOverwrite<IndexingRequirement>(
+ rightPhysProps,
+ {isIndex ? IndexReqTarget::Index : IndexReqTarget::Seek,
+ true /*needRID*/,
+ !isIndex && dedupRID,
+ requirements.getSatisfiedPartialIndexesGroupId()});
+
+ setPropertyOverwrite<ProjectionRequirement>(leftPhysProps, std::move(leftChildProjections));
+ setPropertyOverwrite<ProjectionRequirement>(rightPhysProps,
+ std::move(rightChildProjections));
+
+ if (!isIndex) {
+ // Add repeated execution property to inner side.
+ CEType estimatedRepetitions = hasProperty<RepetitionEstimate>(_physProps)
+ ? getPropertyConst<RepetitionEstimate>(_physProps).getEstimate()
+ : 1.0;
+ estimatedRepetitions *=
+ getPropertyConst<CardinalityEstimate>(leftLogicalProps).getEstimate();
+ setPropertyOverwrite<RepetitionEstimate>(rightPhysProps,
+ RepetitionEstimate{estimatedRepetitions});
+ }
+
+ const auto& optimizeFn = std::bind(&ImplementationVisitor::optimizeRIDIntersect,
+ this,
+ isIndex,
+ dedupRID,
+ indexingAvailability.getPossiblyEqPredsOnly(),
+ std::cref(ridProjName),
+ std::cref(collationLeftRightSplit),
+ std::cref(collationRightLeftSplit),
+ intersectedCE,
+ leftCE,
+ rightCE,
+ std::cref(leftPhysProps),
+ std::cref(rightPhysProps),
+ std::cref(node.getLeftChild()),
+ std::cref(node.getRightChild()));
+
+ // Always optimize under same distributions on left and on right.
+ optimizeFn();
+
+ if (isIndex) {
+ switch (distrAndProjections._type) {
+ case DistributionType::HashPartitioning:
+ case DistributionType::RangePartitioning: {
+ // Specifically for index intersection, try propagating the requirement on one
+ // side and replicating the other.
+
+ const auto& leftDistributions =
+ getPropertyConst<DistributionAvailability>(leftLogicalProps)
+ .getDistributionSet();
+ const auto& rightDistributions =
+ getPropertyConst<DistributionAvailability>(rightLogicalProps)
+ .getDistributionSet();
+
+ if (leftDistributions.count(distrAndProjections) > 0) {
+ setPropertyOverwrite<DistributionRequirement>(leftPhysProps,
+ distribRequirement);
+ setPropertyOverwrite<DistributionRequirement>(
+ rightPhysProps, DistributionRequirement{DistributionType::Replicated});
+ optimizeFn();
+ }
+
+ if (rightDistributions.count(distrAndProjections) > 0) {
+ setPropertyOverwrite<DistributionRequirement>(
+ leftPhysProps, DistributionRequirement{DistributionType::Replicated});
+ setPropertyOverwrite<DistributionRequirement>(rightPhysProps,
+ distribRequirement);
+ optimizeFn();
+ }
+ break;
+ }
+
+ default:
+ break;
+ }
+ }
+ }
+
+ void operator()(const ABT& /*n*/, const BinaryJoinNode& node) {
+ // TODO: optimize binary joins
+ uasserted(6624105, "not implemented");
+ }
+
+ void operator()(const ABT& n, const UnionNode& node) {
+ if (hasProperty<LimitSkipRequirement>(_physProps)) {
+ // We cannot satisfy limit-skip requirements.
+ return;
+ }
+ if (hasProperty<CollationRequirement>(_physProps)) {
+ // In general we cannot satisfy collation requirements.
+ // TODO: This may be possible with a merge sort type of node.
+ return;
+ }
+
+ // Only need to propagate the required projection set.
+ ABT physicalUnion = make<UnionNode>(
+ getPropertyConst<ProjectionRequirement>(_physProps).getProjections().getVector(),
+ node.nodes());
+
+ // Optimize each child under the same physical properties.
+ ChildPropsType childProps;
+ for (auto& child : physicalUnion.cast<UnionNode>()->nodes()) {
+ PhysProps newProps = _physProps;
+ childProps.emplace_back(&child, std::move(newProps));
+ }
+
+ optimizeChildren<UnionNode>(
+ _queue, kDefaultPriority, std::move(physicalUnion), std::move(childProps));
+ }
+
+ void operator()(const ABT& n, const GroupByNode& node) {
+ if (hasProperty<LimitSkipRequirement>(_physProps)) {
+ // We cannot satisfy limit-skip requirements.
+ // TODO: consider an optimization where we keep track of at most "limit" groups.
+ return;
+ }
+ if (hasProperty<CollationRequirement>(_physProps)) {
+ // In general we cannot satisfy collation requirements.
+ // TODO: consider stream group-by.
+ return;
+ }
+
+ if (propertyAffectsProjections<DistributionRequirement>(
+ _physProps, node.getAggregationProjectionNames())) {
+ // We cannot satisfy distribution on the aggregations.
+ return;
+ }
+
+ const ProjectionNameVector& groupByProjections = node.getGroupByProjectionNames();
+
+ const bool isLocal = node.getType() == GroupNodeType::Local;
+ if (!isLocal) {
+ // We are constrained in terms of distribution only if we are a global or complete agg.
+
+ const auto& distribAndProjections =
+ getPropertyConst<DistributionRequirement>(_physProps)
+ .getDistributionAndProjections();
+ switch (distribAndProjections._type) {
+ case DistributionType::UnknownPartitioning:
+ case DistributionType::RoundRobin:
+ // Cannot satisfy unknown or round-robin partitioning.
+ return;
+
+ case DistributionType::HashPartitioning: {
+ ProjectionNameSet groupByProjectionSet;
+ for (const ProjectionName& projectionName : groupByProjections) {
+ groupByProjectionSet.insert(projectionName);
+ }
+ for (const ProjectionName& projectionName :
+ distribAndProjections._projectionNames) {
+ if (groupByProjectionSet.count(projectionName) == 0) {
+ // We can only be partitioned on projections on which we group.
+ return;
+ }
+ }
+ break;
+ }
+
+ case DistributionType::RangePartitioning:
+ if (distribAndProjections._projectionNames != groupByProjections) {
+ // For range partitioning we need to be partitioned exactly in the same
+ // order as our group-by projections.
+ return;
+ }
+ break;
+
+ default:
+ break;
+ }
+ }
+
+ PhysProps newProps = _physProps;
+
+ // TODO: remove RepetitionEstimate if the subtree does not use bound variables.
+ // TODO: this is not the case for stream group-by.
+
+ // Specifically do not propagate limit-skip.
+ removeProperty<LimitSkipRequirement>(newProps);
+
+ getProperty<DistributionRequirement>(newProps).setDisableExchanges(isLocal);
+
+ // Iterate over the aggregation expressions and only add those required.
+ ABTVector aggregationProjections;
+ ProjectionNameVector aggregationProjectionNames;
+ VariableNameSetType projectionsToAdd;
+ for (const ProjectionName& groupByProjName : groupByProjections) {
+ projectionsToAdd.insert(groupByProjName);
+ }
+
+ const auto& requiredProjections =
+ getPropertyConst<ProjectionRequirement>(_physProps).getProjections();
+ for (size_t aggIndex = 0; aggIndex < node.getAggregationExpressions().size(); aggIndex++) {
+ const ProjectionName& aggProjectionName =
+ node.getAggregationProjectionNames().at(aggIndex);
+
+ if (requiredProjections.find(aggProjectionName).second) {
+ // We require this agg expression.
+ aggregationProjectionNames.push_back(aggProjectionName);
+ const ABT& aggExpr = node.getAggregationExpressions().at(aggIndex);
+ aggregationProjections.push_back(aggExpr);
+
+ for (const auto& var : VariableEnvironment::getVariables(aggExpr)._variables) {
+ // Add all references this expression requires.
+ projectionsToAdd.insert(var->name());
+ }
+ }
+ }
+
+ addRemoveProjectionsToProperties(
+ newProps, projectionsToAdd, node.getAggregationProjectionNames());
+
+ ABT physicalGroupBy = make<GroupByNode>(groupByProjections,
+ std::move(aggregationProjectionNames),
+ std::move(aggregationProjections),
+ node.getType(),
+ node.getChild());
+ optimizeChild<GroupByNode>(
+ _queue, kDefaultPriority, std::move(physicalGroupBy), std::move(newProps));
+ }
+
+ void operator()(const ABT& n, const UnwindNode& node) {
+ const ProjectionName& pidProjectionName = node.getPIDProjectionName();
+        const ProjectionNameVector projectionNames = {node.getProjectionName(),
+                                                      pidProjectionName};
+
+ if (propertyAffectsProjections<DistributionRequirement>(_physProps, projectionNames)) {
+            // We cannot satisfy a distribution requirement on the unwound output or on the pid
+            // projection.
+ return;
+ }
+ if (propertyAffectsProjections<CollationRequirement>(_physProps, projectionNames)) {
+ // We cannot satisfy collation on the output.
+ return;
+ }
+ if (hasProperty<LimitSkipRequirement>(_physProps)) {
+ // Cannot satisfy limit-skip.
+ return;
+ }
+
+ PhysProps newProps = _physProps;
+ addRemoveProjectionsToProperties(
+ newProps, collectVariableReferences(n), ProjectionNameVector{pidProjectionName});
+
+ // Specifically do not propagate limit-skip.
+ removeProperty<LimitSkipRequirement>(newProps);
+        // Keep the collation property, given that it does not affect the output.
+
+ getProperty<DistributionRequirement>(newProps).setDisableExchanges(false);
+
+ ABT physicalUnwind = n;
+ optimizeChild<UnwindNode>(
+ _queue, kDefaultPriority, std::move(physicalUnwind), std::move(newProps));
+ }
+
+ void operator()(const ABT& /*n*/, const CollationNode& node) {
+ if (getPropertyConst<DistributionRequirement>(_physProps)
+ .getDistributionAndProjections()
+ ._type != DistributionType::Centralized) {
+            // We can only pick up collation under a centralized distribution (but we can
+            // enforce it under any distribution).
+ return;
+ }
+
+ optimizeSimplePropertyNode<CollationNode, CollationRequirement>(node);
+ }
+
+ void operator()(const ABT& /*n*/, const LimitSkipNode& node) {
+        // We can pick up limit-skip under any distribution (but we enforce it only under
+        // centralized or replicated distribution).
+
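+        // Fold the node's own limit-skip into any outstanding requirement and continue
+        // optimizing the child under the merged, not-yet-enforced requirement.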
+ PhysProps newProps = _physProps;
+ LimitSkipRequirement newProp = node.getProperty();
+
+ removeProperty<LimitEstimate>(newProps);
+
+ if (hasProperty<LimitSkipRequirement>(_physProps)) {
+ const LimitSkipRequirement& required =
+ getPropertyConst<LimitSkipRequirement>(_physProps);
+ LimitSkipRequirement merged(required.getLimit(), required.getSkip());
+
+ combineLimitSkipProperties(merged, newProp);
+ // Continue with new unenforced requirement.
+ newProp = std::move(merged);
+ }
+
+ setPropertyOverwrite<LimitSkipRequirement>(newProps, std::move(newProp));
+ getProperty<DistributionRequirement>(newProps).setDisableExchanges(false);
+
+ optimizeUnderNewProperties(_queue, kDefaultPriority, node.getChild(), std::move(newProps));
+ }
+
+ void operator()(const ABT& /*n*/, const ExchangeNode& node) {
+ optimizeSimplePropertyNode<ExchangeNode, DistributionRequirement>(node);
+ }
+
+ void operator()(const ABT& n, const RootNode& node) {
+ PhysProps newProps = _physProps;
+ setPropertyOverwrite<ProjectionRequirement>(newProps, node.getProperty());
+ getProperty<DistributionRequirement>(newProps).setDisableExchanges(false);
+
+ ABT rootNode = n;
+ optimizeChild<RootNode>(_queue, kDefaultPriority, std::move(rootNode), std::move(newProps));
+ }
+
+ template <typename T>
+ void operator()(const ABT& /*n*/, const T& /*node*/) {
+ static_assert(!canBeLogicalNode<T>(), "Logical node must implement its visitor.");
+ }
+
+ ImplementationVisitor(const Memo& memo,
+ const QueryHints& hints,
+ const opt::unordered_map<std::string, ProjectionName>& ridProjections,
+ PrefixId& prefixId,
+ PhysRewriteQueue& queue,
+ const PhysProps& physProps,
+ const LogicalProps& logicalProps)
+ : _memo(memo),
+ _hints(hints),
+ _ridProjections(ridProjections),
+ _prefixId(prefixId),
+ _queue(queue),
+ _physProps(physProps),
+ _logicalProps(logicalProps) {}
+
+private:
+ template <class NodeType, class PropType>
+ void optimizeSimplePropertyNode(const NodeType& node) {
+ const PropType& nodeProp = node.getProperty();
+ PhysProps newProps = _physProps;
+ setPropertyOverwrite<PropType>(newProps, nodeProp);
+
+ getProperty<DistributionRequirement>(newProps).setDisableExchanges(false);
+ optimizeUnderNewProperties(_queue, kDefaultPriority, node.getChild(), std::move(newProps));
+ }
+
+ struct IndexAvailableDirections {
+ // Keep track if we can match against forward or backward direction.
+ bool _forward = true;
+ bool _backward = true;
+ };
+
+ IndexAvailableDirections indexSatisfiesCollation(
+ const IndexCollationSpec& indexCollationSpec,
+ const CandidateIndexEntry& candidateIndexEntry,
+ const ProjectionCollationSpec& requiredCollationSpec) {
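+        // Illustrative example: for an index on (a, b) where 'a' is constrained by an equality
+        // (and thus needs no collation), a required collation on the projection bound to 'b' can
+        // be satisfied forward if the collation ops are compatible, or backward if they are
+        // compatible after reversing the index's op.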
+ if (requiredCollationSpec.empty()) {
+ return {true, true};
+ }
+
+ IndexAvailableDirections result;
+ size_t collationSpecIndex = 0;
+ bool indexSuitable = true;
+ const auto& fieldProjections = candidateIndexEntry._fieldProjectionMap._fieldProjections;
+
+ // Verify the index is compatible with our collation requirement, and can deliver the right
+ // order of paths.
+ for (size_t indexField = 0; indexField < indexCollationSpec.size(); indexField++) {
+ const bool needsCollation = candidateIndexEntry._fieldsToCollate.count(indexField) > 0;
+
+ auto it = fieldProjections.find(encodeIndexKeyName(indexField));
+ if (it == fieldProjections.cend()) {
+ // No bound projection for this index field.
+ if (needsCollation) {
+ // We cannot satisfy the rest of the collation requirements.
+ indexSuitable = false;
+ break;
+ }
+ continue;
+ }
+ const ProjectionName& projName = it->second;
+
+ if (!needsCollation) {
+ // We do not need to collate this field because of equality.
+ if (requiredCollationSpec.at(collationSpecIndex).first == projName) {
+                    // We can satisfy the next collation requirement independently of the
+                    // collation op.
+ if (++collationSpecIndex >= requiredCollationSpec.size()) {
+ break;
+ }
+ }
+ continue;
+ }
+
+ // Check if we can satisfy the next collation requirement.
+ const auto& collationEntry = requiredCollationSpec.at(collationSpecIndex);
+ if (collationEntry.first != projName) {
+ indexSuitable = false;
+ break;
+ }
+
+ const auto& indexCollationEntry = indexCollationSpec.at(indexField);
+ if (result._forward &&
+ !collationOpsCompatible(indexCollationEntry._op, collationEntry.second)) {
+ result._forward = false;
+ }
+ if (result._backward &&
+ !collationOpsCompatible(reverseCollationOp(indexCollationEntry._op),
+ collationEntry.second)) {
+ result._backward = false;
+ }
+ if (!result._forward && !result._backward) {
+ indexSuitable = false;
+ break;
+ }
+ if (++collationSpecIndex >= requiredCollationSpec.size()) {
+ break;
+ }
+ }
+
+ if (!indexSuitable || collationSpecIndex < requiredCollationSpec.size()) {
+ return {false, false};
+ }
+ return result;
+ }
+
+ /**
+     * Check if we are under index-only requirements and the expression introduces a dependency
+     * on the scan projection.
+ */
+ bool checkIntroducesScanProjectionUnderIndexOnly(const VariableNameSetType& references) {
+ return hasProperty<IndexingAvailability>(_logicalProps) &&
+ getPropertyConst<IndexingRequirement>(_physProps).getIndexReqTarget() ==
+ IndexReqTarget::Index &&
+ references.find(
+ getPropertyConst<IndexingAvailability>(_logicalProps).getScanProjection()) !=
+ references.cend();
+ }
+
+ bool distributionsCompatible(const IndexReqTarget target,
+ const DistributionAndPaths& distributionAndPaths,
+ const ProjectionName& scanProjection,
+ const LogicalProps& scanLogicalProps,
+ const PartialSchemaRequirements& reqMap,
+ bool& canUseParallelScan) {
+ const DistributionRequirement& required =
+ getPropertyConst<DistributionRequirement>(_physProps);
+ const auto& distribAndProjections = required.getDistributionAndProjections();
+
+ const auto& scanDistributions =
+ getPropertyConst<DistributionAvailability>(scanLogicalProps).getDistributionSet();
+
+ switch (distribAndProjections._type) {
+ case DistributionType::Centralized:
+ return scanDistributions.count({DistributionType::Centralized}) > 0 ||
+ scanDistributions.count({DistributionType::Replicated}) > 0;
+
+ case DistributionType::Replicated:
+ return scanDistributions.count({DistributionType::Replicated}) > 0;
+
+ case DistributionType::RoundRobin:
+ if (target == IndexReqTarget::Seek) {
+ // We can satisfy Seek with RoundRobin if we can scan the collection in
+ // parallel.
+ return scanDistributions.count({DistributionType::UnknownPartitioning}) > 0;
+ }
+
+ // TODO: Are two round robin distributions compatible?
+ return false;
+
+ case DistributionType::UnknownPartitioning:
+ if (target == IndexReqTarget::Index) {
+ // We cannot satisfy unknown partitioning with an index as (unlike parallel
+ // collection scan) we currently cannot perform a parallel index scan.
+ return false;
+ }
+
+ if (scanDistributions.count({DistributionType::UnknownPartitioning}) > 0) {
+ canUseParallelScan = true;
+ return true;
+ }
+ return false;
+
+ case DistributionType::HashPartitioning:
+ case DistributionType::RangePartitioning: {
+ if (distribAndProjections._type != distributionAndPaths._type) {
+ return false;
+ }
+
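+                // Each partitioning path must be bound by an entry in the requirement map, and
+                // the bound projections must line up positionally with the required
+                // distribution projections.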
+ size_t distributionPartitionIndex = 0;
+ const ProjectionNameVector& requiredProjections =
+ distribAndProjections._projectionNames;
+
+ for (const ABT& partitioningPath : distributionAndPaths._paths) {
+ auto it = reqMap.find(PartialSchemaKey{scanProjection, partitioningPath});
+ if (it == reqMap.cend()) {
+ return false;
+ }
+
+ if (it->second.getBoundProjectionName() !=
+ requiredProjections.at(distributionPartitionIndex)) {
+ return false;
+ }
+ distributionPartitionIndex++;
+ }
+
+ return distributionPartitionIndex == requiredProjections.size();
+ }
+
+ default:
+ MONGO_UNREACHABLE;
+ }
+ }
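+
+    // For example (illustrative): a Centralized requirement is satisfiable by a scan available
+    // as Centralized or Replicated; a RoundRobin requirement with a Seek target is satisfiable
+    // only over an UnknownPartitioning collection; and Hash/Range requirements must bind each
+    // partitioning path through the requirements map, as encoded in the cases above.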
+
+ void setCollationForRIDIntersect(const CollationSplitResult& collationSplit,
+ PhysProps& leftPhysProps,
+ PhysProps& rightPhysProps) {
+ if (collationSplit._leftCollation.empty()) {
+ removeProperty<CollationRequirement>(leftPhysProps);
+ } else {
+ setPropertyOverwrite<CollationRequirement>(leftPhysProps,
+ collationSplit._leftCollation);
+ }
+
+ if (collationSplit._rightCollation.empty()) {
+ removeProperty<CollationRequirement>(rightPhysProps);
+ } else {
+ setPropertyOverwrite<CollationRequirement>(rightPhysProps,
+ collationSplit._rightCollation);
+ }
+ }
+
+ void optimizeRIDIntersect(const bool isIndex,
+ const bool dedupRID,
+ const bool useMergeJoin,
+ const ProjectionName& ridProjectionName,
+ const CollationSplitResult& collationLeftRightSplit,
+ const CollationSplitResult& collationRightLeftSplit,
+ const CEType intersectedCE,
+ const CEType leftCE,
+ const CEType rightCE,
+ const PhysProps& leftPhysProps,
+ const PhysProps& rightPhysProps,
+ const ABT& leftChild,
+ const ABT& rightChild) {
+ if (isIndex && collationRightLeftSplit._validSplit &&
+ (!collationLeftRightSplit._validSplit || leftCE > rightCE)) {
+            // Swap the left and right sides: either the left-to-right collation split is not
+            // valid, or we want the side with the larger CE on the right.
+ optimizeRIDIntersect(true /*isIndex*/,
+ dedupRID,
+ useMergeJoin,
+ ridProjectionName,
+ collationRightLeftSplit,
+ {},
+ intersectedCE,
+ rightCE,
+ leftCE,
+ rightPhysProps,
+ leftPhysProps,
+ rightChild,
+ leftChild);
+ return;
+ }
+ if (!collationLeftRightSplit._validSplit) {
+ return;
+ }
+
+ if (isIndex) {
+ if (useMergeJoin && !_hints._disableMergeJoinRIDIntersect) {
+ // Try a merge join on RID since both of our children only have equality
+ // predicates.
+
+ NodeCEMap nodeCEMap;
+ ChildPropsType childProps;
+ PhysProps leftPhysPropsLocal = leftPhysProps;
+ PhysProps rightPhysPropsLocal = rightPhysProps;
+
+ setCollationForRIDIntersect(
+ collationLeftRightSplit, leftPhysPropsLocal, rightPhysPropsLocal);
+ if (dedupRID) {
+ getProperty<IndexingRequirement>(leftPhysPropsLocal)
+ .setDedupRID(true /*dedupRID*/);
+ getProperty<IndexingRequirement>(rightPhysPropsLocal)
+ .setDedupRID(true /*dedupRID*/);
+ }
+
+ ABT physNode = lowerRIDIntersectMergeJoin(_prefixId,
+ ridProjectionName,
+ intersectedCE,
+ leftCE,
+ rightCE,
+ leftPhysPropsLocal,
+ rightPhysPropsLocal,
+ leftChild,
+ rightChild,
+ nodeCEMap,
+ childProps);
+ optimizeChildrenNoAssert(_queue,
+ kDefaultPriority,
+ std::move(physNode),
+ std::move(childProps),
+ std::move(nodeCEMap));
+ } else {
+ if (!_hints._disableHashJoinRIDIntersect) {
+ // Try a HashJoin. Propagate dedupRID on left and right indexing
+ // requirements.
+
+ NodeCEMap nodeCEMap;
+ ChildPropsType childProps;
+ PhysProps leftPhysPropsLocal = leftPhysProps;
+ PhysProps rightPhysPropsLocal = rightPhysProps;
+
+ setCollationForRIDIntersect(
+ collationLeftRightSplit, leftPhysPropsLocal, rightPhysPropsLocal);
+ if (dedupRID) {
+ getProperty<IndexingRequirement>(leftPhysPropsLocal)
+ .setDedupRID(true /*dedupRID*/);
+ getProperty<IndexingRequirement>(rightPhysPropsLocal)
+ .setDedupRID(true /*dedupRID*/);
+ }
+
+ ABT physNode = lowerRIDIntersectHashJoin(_prefixId,
+ ridProjectionName,
+ intersectedCE,
+ leftCE,
+ rightCE,
+ leftPhysPropsLocal,
+ rightPhysPropsLocal,
+ leftChild,
+ rightChild,
+ nodeCEMap,
+ childProps);
+ optimizeChildrenNoAssert(_queue,
+ kDefaultPriority,
+ std::move(physNode),
+ std::move(childProps),
+ std::move(nodeCEMap));
+ }
+
+ // We can only attempt this strategy if we have no collation requirements.
+ if (!_hints._disableGroupByAndUnionRIDIntersect && dedupRID &&
+ collationLeftRightSplit._leftCollation.empty() &&
+ collationLeftRightSplit._rightCollation.empty()) {
+                    // Try a Union+GroupBy. The left and right indexing requirements are already
+                    // initialized to not dedup.
+
+ NodeCEMap nodeCEMap;
+ ChildPropsType childProps;
+ PhysProps leftPhysPropsLocal = leftPhysProps;
+ PhysProps rightPhysPropsLocal = rightPhysProps;
+
+ setCollationForRIDIntersect(
+ collationLeftRightSplit, leftPhysPropsLocal, rightPhysPropsLocal);
+
+ ABT physNode = lowerRIDIntersectGroupBy(_prefixId,
+ ridProjectionName,
+ intersectedCE,
+ leftCE,
+ rightCE,
+ _physProps,
+ leftPhysPropsLocal,
+ rightPhysPropsLocal,
+ leftChild,
+ rightChild,
+ nodeCEMap,
+ childProps);
+ optimizeChildrenNoAssert(_queue,
+ kDefaultPriority,
+ std::move(physNode),
+ std::move(childProps),
+ std::move(nodeCEMap));
+ }
+ }
+ } else {
+ ABT physicalJoin = make<BinaryJoinNode>(JoinType::Inner,
+ ProjectionNameSet{ridProjectionName},
+ Constant::boolean(true),
+ leftChild,
+ rightChild);
+
+ PhysProps leftPhysPropsLocal = leftPhysProps;
+ PhysProps rightPhysPropsLocal = rightPhysProps;
+ setCollationForRIDIntersect(
+ collationLeftRightSplit, leftPhysPropsLocal, rightPhysPropsLocal);
+
+ optimizeChildren<BinaryJoinNode>(_queue,
+ kDefaultPriority,
+ std::move(physicalJoin),
+ std::move(leftPhysPropsLocal),
+ std::move(rightPhysPropsLocal));
+ }
+ }
+
+    // We don't own any of these.
+ const Memo& _memo;
+ const QueryHints& _hints;
+ const opt::unordered_map<std::string, ProjectionName>& _ridProjections;
+ PrefixId& _prefixId;
+ PhysRewriteQueue& _queue;
+ const PhysProps& _physProps;
+ const LogicalProps& _logicalProps;
+};
+
+void addImplementers(const Memo& memo,
+ const QueryHints& hints,
+ const opt::unordered_map<std::string, ProjectionName>& ridProjections,
+ PrefixId& prefixId,
+ PhysOptimizationResult& bestResult,
+ const properties::LogicalProps& logicalProps,
+ const OrderPreservingABTSet& logicalNodes) {
+ ImplementationVisitor visitor(memo,
+ hints,
+ ridProjections,
+ prefixId,
+ bestResult._queue,
+ bestResult._physProps,
+ logicalProps);
+ while (bestResult._lastImplementedNodePos < logicalNodes.size()) {
+ logicalNodes.at(bestResult._lastImplementedNodePos++).visit(visitor);
+ }
+}
+
+} // namespace mongo::optimizer::cascades
diff --git a/src/mongo/db/query/optimizer/cascades/implementers.h b/src/mongo/db/query/optimizer/cascades/implementers.h
new file mode 100644
index 00000000000..c26bbbbdf69
--- /dev/null
+++ b/src/mongo/db/query/optimizer/cascades/implementers.h
@@ -0,0 +1,50 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include "mongo/db/query/optimizer/cascades/memo.h"
+#include "mongo/db/query/optimizer/cascades/rewrite_queues.h"
+#include "mongo/db/query/optimizer/utils/utils.h"
+
+namespace mongo::optimizer::cascades {
+
+/**
+ * Adds operator implementation rules for a particular group and set of physical properties.
+ * Only adds rewrites for newly added logical nodes.
+ */
+void addImplementers(const Memo& memo,
+ const QueryHints& hints,
+ const opt::unordered_map<std::string, ProjectionName>& ridProjections,
+ PrefixId& prefixId,
+ PhysOptimizationResult& bestResult,
+ const properties::LogicalProps& logicalProps,
+ const OrderPreservingABTSet& logicalNodes);
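+
+// For example (illustrative sketch, assuming the memo group exposes its logical node set and
+// logical properties as written below): the physical optimization driver could invoke this per
+// group roughly as:
+//
+//     addImplementers(memo, hints, ridProjections, prefixId,
+//                     bestResult, group._logicalProperties, group._logicalNodes);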
+
+} // namespace mongo::optimizer::cascades
diff --git a/src/mongo/db/query/optimizer/cascades/interfaces.h b/src/mongo/db/query/optimizer/cascades/interfaces.h
new file mode 100644
index 00000000000..2a09720bed5
--- /dev/null
+++ b/src/mongo/db/query/optimizer/cascades/interfaces.h
@@ -0,0 +1,81 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include "mongo/db/query/optimizer/node_defs.h"
+#include "mongo/db/query/optimizer/props.h"
+
+namespace mongo::optimizer::cascades {
+
+class Memo;
+
+/**
+ * Interface for deriving logical properties. Typically we supply a memo, but we may derive
+ * without one as long as metadata is provided.
+ */
+class LogicalPropsInterface {
+public:
+ virtual ~LogicalPropsInterface() = default;
+
+ using NodePropsMap = opt::unordered_map<const Node*, properties::LogicalProps>;
+
+ virtual properties::LogicalProps deriveProps(const Metadata& metadata,
+ ABT::reference_type nodeRef,
+ NodePropsMap* nodePropsMap = nullptr,
+ const Memo* memo = nullptr,
+ GroupIdType groupId = -1) const = 0;
+};
+
+/**
+ * Interface for deriving CE for a newly added logical node in a new memo group.
+ */
+class CEInterface {
+public:
+ virtual ~CEInterface() = default;
+
+ virtual CEType deriveCE(const Memo& memo,
+ const properties::LogicalProps& logicalProps,
+ ABT::reference_type logicalNodeRef) const = 0;
+};
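+
+// For example (illustrative sketch; FixedCE is a hypothetical name): a trivial estimator that
+// returns a fixed cardinality for every logical node could implement this interface as follows:
+//
+//     class FixedCE final : public CEInterface {
+//     public:
+//         CEType deriveCE(const Memo& /*memo*/,
+//                         const properties::LogicalProps& /*logicalProps*/,
+//                         ABT::reference_type /*logicalNodeRef*/) const override {
+//             return 1000.0;  // fixed estimate, for illustration only
+//         }
+//     };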
+
+/**
+ * Interface for deriving costs and adjusted CE (based on physical props) for a physical node.
+ */
+class CostingInterface {
+public:
+ virtual ~CostingInterface() = default;
+ virtual CostAndCE deriveCost(const Memo& memo,
+ const properties::PhysProps& physProps,
+ ABT::reference_type physNodeRef,
+ const ChildPropsType& childProps,
+ const NodeCEMap& nodeCEMap) const = 0;
+};
+
+} // namespace mongo::optimizer::cascades
diff --git a/src/mongo/db/query/optimizer/cascades/logical_props_derivation.cpp b/src/mongo/db/query/optimizer/cascades/logical_props_derivation.cpp
new file mode 100644
index 00000000000..4f2a655a363
--- /dev/null
+++ b/src/mongo/db/query/optimizer/cascades/logical_props_derivation.cpp
@@ -0,0 +1,494 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/db/query/optimizer/cascades/logical_props_derivation.h"
+#include "mongo/db/query/optimizer/utils/utils.h"
+
+namespace mongo::optimizer::cascades {
+
+using namespace properties;
+
+static void populateInitialDistributions(const DistributionAndPaths& distributionAndPaths,
+ const bool isMultiPartition,
+ DistributionSet& distributions) {
+ switch (distributionAndPaths._type) {
+ case DistributionType::Centralized:
+ distributions.insert({DistributionType::Centralized});
+ break;
+
+ case DistributionType::Replicated:
+ uassert(6624106, "Invalid distribution specification", isMultiPartition);
+
+ distributions.insert({DistributionType::Centralized});
+ distributions.insert({DistributionType::Replicated});
+ break;
+
+ case DistributionType::HashPartitioning:
+ case DistributionType::RangePartitioning:
+ case DistributionType::UnknownPartitioning:
+ uassert(6624107, "Invalid distribution specification", isMultiPartition);
+
+ distributions.insert({DistributionType::UnknownPartitioning});
+ break;
+
+ default:
+ uasserted(6624108, "Invalid collection distribution");
+ }
+}
+
+static void populateDistributionPaths(const PartialSchemaRequirements& req,
+ const ProjectionName& scanProjectionName,
+ const DistributionAndPaths& distributionAndPaths,
+ DistributionSet& distributions) {
+ switch (distributionAndPaths._type) {
+ case DistributionType::HashPartitioning:
+ case DistributionType::RangePartitioning: {
+ ProjectionNameVector distributionProjections;
+
+ for (const ABT& path : distributionAndPaths._paths) {
+ auto it = req.find(PartialSchemaKey{scanProjectionName, path});
+ if (it == req.cend()) {
+ break;
+ }
+ if (it->second.hasBoundProjectionName()) {
+ distributionProjections.push_back(it->second.getBoundProjectionName());
+ }
+ }
+
+ if (distributionProjections.size() == distributionAndPaths._paths.size()) {
+ distributions.emplace(distributionAndPaths._type,
+ std::move(distributionProjections));
+ }
+            break;
+        }
+
+ default:
+ break;
+ }
+}
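+
+// For example (illustrative): for a collection hash-partitioned on path Get "a" Id, with a
+// requirement map entry binding that path to projection "pa", the loop above collects {"pa"} and
+// makes the distribution HashPartitioning{"pa"} available.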
+
+static bool computePossiblyEqPredsOnly(const PartialSchemaRequirements& reqMap) {
+ PartialSchemaRequirements equalitiesReqMap;
+ PartialSchemaRequirements fullyOpenReqMap;
+
+ for (const auto& [key, req] : reqMap) {
+ const auto& intervals = req.getIntervals();
+ if (auto singularInterval = IntervalReqExpr::getSingularDNF(intervals)) {
+ if (singularInterval->isFullyOpen()) {
+ fullyOpenReqMap.emplace(key, req);
+ } else if (singularInterval->isEquality()) {
+ equalitiesReqMap.emplace(key, req);
+ } else {
+                // Encountered an interval which is neither an equality nor fully open.
+ return false;
+ }
+ } else {
+ // Encountered a non-trivial interval.
+ return false;
+ }
+ }
+
+ PartialSchemaKeySet resultKeySet;
+ PartialSchemaRequirement req_unused;
+ for (const auto& [key, req] : fullyOpenReqMap) {
+ findMatchingSchemaRequirement(
+ key, equalitiesReqMap, resultKeySet, req_unused, false /*setIntervalsAndBoundProj*/);
+ if (resultKeySet.empty()) {
+ // No possible match for fully open requirement.
+ return false;
+ }
+ }
+
+ return true;
+}
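+
+// For example (illustrative): a requirement map with the single entry {Get "a" Id: [= 5]} yields
+// true; {Get "a" Id: [5, 10]} yields false (a range); and a fully-open interval on a key is
+// accepted only if some equality requirement matches the same key.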
+
+class DeriveLogicalProperties {
+public:
+ LogicalProps transport(const ScanNode& node, LogicalProps /*bindResult*/) {
+ DistributionSet distributions;
+
+ const auto& scanDef = _metadata._scanDefs.at(node.getScanDefName());
+ populateInitialDistributions(
+ scanDef.getDistributionAndPaths(), _metadata.isParallelExecution(), distributions);
+ for (const auto& entry : scanDef.getIndexDefs()) {
+ populateInitialDistributions(entry.second.getDistributionAndPaths(),
+ _metadata.isParallelExecution(),
+ distributions);
+ }
+
+ return maybeUpdateNodePropsMap(
+ node,
+ makeLogicalProps(IndexingAvailability(_groupId,
+ node.getProjectionName(),
+ node.getScanDefName(),
+ true /*possiblyEqPredsOnly*/,
+ {} /*satisfiedPartialIndexes*/),
+ CollectionAvailability({node.getScanDefName()}),
+ DistributionAvailability(std::move(distributions))));
+ }
+
+ LogicalProps transport(const ValueScanNode& node, LogicalProps /*bindResult*/) {
+        // We do not originate indexing availability, and we have empty collection availability.
+        // Distribution availability contains Centralized (plus RoundRobin under parallel
+        // execution). During physical optimization we accept optimization under any distribution.
+ LogicalProps result =
+ makeLogicalProps(CollectionAvailability{{}}, DistributionAvailability{{}});
+ addCentralizedAndRoundRobinDistributions(result);
+ return maybeUpdateNodePropsMap(node, std::move(result));
+ }
+
+ LogicalProps transport(const MemoLogicalDelegatorNode& node) {
+ uassert(6624109, "Uninitialized memo", _memo != nullptr);
+ return maybeUpdateNodePropsMap(node, _memo->getGroup(node.getGroupId())._logicalProperties);
+ }
+
+ LogicalProps transport(const FilterNode& node,
+ LogicalProps childResult,
+ LogicalProps /*exprResult*/) {
+ // Propagate indexing, collection, and distribution availabilities.
+ LogicalProps result = std::move(childResult);
+ if (hasProperty<IndexingAvailability>(result)) {
+ getProperty<IndexingAvailability>(result).setPossiblyEqPredsOnly(false);
+ }
+ addCentralizedAndRoundRobinDistributions(result);
+ return maybeUpdateNodePropsMap(node, std::move(result));
+ }
+
+ LogicalProps transport(const EvaluationNode& node,
+ LogicalProps childResult,
+ LogicalProps /*exprResult*/) {
+ // We are specifically not adding the node's projection to ProjectionAvailability here.
+        // The logical properties already contain projection availability, which is derived
+        // first when the memo group is created.
+ LogicalProps result = std::move(childResult);
+ if (hasProperty<IndexingAvailability>(result)) {
+ getProperty<IndexingAvailability>(result).setPossiblyEqPredsOnly(false);
+ }
+ addCentralizedAndRoundRobinDistributions(result);
+ return maybeUpdateNodePropsMap(node, std::move(result));
+ }
+
+ LogicalProps transport(const SargableNode& node,
+ LogicalProps childResult,
+ LogicalProps /*bindsResult*/,
+ LogicalProps /*refsResult*/) {
+ LogicalProps result = std::move(childResult);
+
+ auto& indexingAvailability = getProperty<IndexingAvailability>(result);
+ const ProjectionName& scanProjectionName = indexingAvailability.getScanProjection();
+ const std::string& scanDefName = indexingAvailability.getScanDefName();
+ const auto& scanDef = _metadata._scanDefs.at(scanDefName);
+
+ auto& distributions = getProperty<DistributionAvailability>(result).getDistributionSet();
+ addCentralizedAndRoundRobinDistributions(result);
+
+ populateDistributionPaths(
+ node.getReqMap(), scanProjectionName, scanDef.getDistributionAndPaths(), distributions);
+ for (const auto& entry : scanDef.getIndexDefs()) {
+ populateDistributionPaths(node.getReqMap(),
+ scanProjectionName,
+ entry.second.getDistributionAndPaths(),
+ distributions);
+ }
+
+ if (indexingAvailability.getPossiblyEqPredsOnly()) {
+ indexingAvailability.setPossiblyEqPredsOnly(
+ computePossiblyEqPredsOnly(node.getReqMap()));
+ }
+
+ auto& satisfiedPartialIndexes =
+ getProperty<IndexingAvailability>(result).getSatisfiedPartialIndexes();
+ for (const auto& [indexDefName, indexDef] : scanDef.getIndexDefs()) {
+ if (!indexDef.getPartialReqMap().empty()) {
+ auto intersection = node.getReqMap();
+ // We specifically ignore projectionRenames here.
+ ProjectionRenames projectionRenames_unused;
+ if (intersectPartialSchemaReq(
+ intersection, indexDef.getPartialReqMap(), projectionRenames_unused) &&
+ intersection == node.getReqMap()) {
+ satisfiedPartialIndexes.insert(indexDefName);
+ }
+ }
+ }
+
+ return maybeUpdateNodePropsMap(node, std::move(result));
+ }
+
+ LogicalProps transport(const RIDIntersectNode& node,
+ LogicalProps /*leftChildResult*/,
+ LogicalProps /*rightChildResult*/) {
+ // Properties for the group should already be derived via the underlying Filter or
+ // Evaluation logical nodes.
+ uasserted(6624042, "Should not be necessary to derive properties for RIDIntersectNode");
+ }
+
+ LogicalProps transport(const BinaryJoinNode& node,
+ LogicalProps /*leftChildResult*/,
+ LogicalProps /*rightChildResult*/,
+ LogicalProps /*exprResult*/) {
+ // TODO: remove indexing availability property when implemented.
+ // TODO: combine scan defs from all children for CollectionAvailability.
+ uasserted(6624043, "Logical property derivation not implemented.");
+ }
+
+ LogicalProps transport(const UnionNode& node,
+ std::vector<LogicalProps> childResults,
+ LogicalProps bindResult,
+ LogicalProps refsResult) {
+ uassert(6624044, "Unexpected empty child results for union node", !childResults.empty());
+
+ // We are specifically not adding the node's projection to ProjectionAvailability here.
+        // The logical properties already contain projection availability, which is derived
+        // first when the memo group is created.
+ LogicalProps result = std::move(childResults[0]);
+ auto& mergedScanDefs = getProperty<CollectionAvailability>(result).getScanDefSet();
+ auto& mergedDistributionSet =
+ getProperty<DistributionAvailability>(result).getDistributionSet();
+ for (size_t childIdx = 1; childIdx < childResults.size(); childIdx++) {
+ auto childScanDefs =
+ getProperty<CollectionAvailability>(childResults[childIdx]).getScanDefSet();
+ mergedScanDefs.merge(std::move(childScanDefs));
+
+ // Only keep the distribution properties which are common across all children
+ // distributions.
+ const auto& childDistributionSet =
+ getProperty<DistributionAvailability>(childResults[childIdx]).getDistributionSet();
+
+            for (auto it = mergedDistributionSet.begin(); it != mergedDistributionSet.end();) {
+                if (childDistributionSet.find(*it) == childDistributionSet.end()) {
+                    // erase() returns the next valid iterator; incrementing an iterator that was
+                    // just erased would be undefined behavior.
+                    it = mergedDistributionSet.erase(it);
+                } else {
+                    it++;
+                }
+            }
+ }
+
+ // Verify that there is at least one common distribution available.
+ uassert(6624045, "No common distributions for union", !mergedDistributionSet.empty());
+
+ removeProperty<IndexingAvailability>(result);
+ return maybeUpdateNodePropsMap(node, std::move(result));
+ }
+
+ LogicalProps transport(const GroupByNode& node,
+ LogicalProps childResult,
+ LogicalProps /*bindAggResult*/,
+ LogicalProps /*refsAggResult*/,
+ LogicalProps /*bindGbResult*/,
+ LogicalProps /*refsGbResult*/) {
+ LogicalProps result = std::move(childResult);
+ removeProperty<IndexingAvailability>(result);
+
+ auto& distributions = getProperty<DistributionAvailability>(result).getDistributionSet();
+ addCentralizedAndRoundRobinDistributions<false /*addRoundRobin*/>(distributions);
+
+ if (_metadata.isParallelExecution() && node.getType() != GroupNodeType::Local) {
+ distributions.erase({DistributionType::UnknownPartitioning});
+ distributions.erase({DistributionType::RoundRobin});
+
+            // We propagate hash and range partitioning only if we are a global aggregation.
+ const ProjectionNameVector& groupByProjections = node.getGroupByProjectionNames();
+ if (!groupByProjections.empty()) {
+ DistributionRequirement allowedRangePartitioning{
+ {DistributionType::RangePartitioning, groupByProjections}};
+ for (auto it = distributions.begin(); it != distributions.end();) {
+ switch (it->_type) {
+ case DistributionType::HashPartitioning:
+ // Erase all hash partition distributions. New ones will be generated
+ // after.
+ distributions.erase(it++);
+ break;
+
+ case DistributionType::RangePartitioning:
+ // Retain only the range partition which contains the group by
+ // projections in the node order.
+ if (*it == allowedRangePartitioning.getDistributionAndProjections()) {
+ it++;
+ } else {
+ distributions.erase(it++);
+                            }
+                            break;
+
+ default:
+ it++;
+ break;
+ }
+ }
+
+ // Generate hash distributions using the power set of group-by projections.
+ for (size_t mask = 1; mask < (1ull << groupByProjections.size()); mask++) {
+ ProjectionNameVector projectionNames;
+ for (size_t index = 0; index < groupByProjections.size(); index++) {
+ if ((mask & (1ull << index)) != 0) {
+ projectionNames.push_back(groupByProjections.at(index));
+ }
+ }
+ distributions.emplace(DistributionType::HashPartitioning,
+ std::move(projectionNames));
+ }
+ }
+ }
+
+ return maybeUpdateNodePropsMap(node, std::move(result));
+ }
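+
+    // For example (illustrative): with group-by projections {a, b}, the power-set loop above
+    // enumerates masks 1..3 and generates the hash distributions {a}, {b}, and {a, b}.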
+
+ LogicalProps transport(const UnwindNode& node,
+ LogicalProps childResult,
+ LogicalProps /*bindResult*/,
+ LogicalProps /*refsResult*/) {
+ LogicalProps result = std::move(childResult);
+ removeProperty<IndexingAvailability>(result);
+
+ const ProjectionName& unwoundProjectionName = node.getProjectionName();
+ auto& distributions = getProperty<DistributionAvailability>(result).getDistributionSet();
+ addCentralizedAndRoundRobinDistributions(distributions);
+
+ if (_metadata.isParallelExecution()) {
+ for (auto it = distributions.begin(); it != distributions.end();) {
+ switch (it->_type) {
+ case DistributionType::HashPartitioning:
+ case DistributionType::RangePartitioning: {
+ // Erase partitioned distributions which contain the projection to unwind.
+ bool containsProjection = false;
+ for (const ProjectionName& projectionName : it->_projectionNames) {
+ if (projectionName == unwoundProjectionName) {
+ containsProjection = true;
+ break;
+ }
+ }
+                        if (containsProjection) {
+                            // erase() returns the next valid iterator; 'it' would otherwise be
+                            // invalidated.
+                            it = distributions.erase(it);
+                        } else {
+                            it++;
+                        }
+ break;
+ }
+
+ default:
+ it++;
+ break;
+ }
+ }
+ }
+
+ return maybeUpdateNodePropsMap(node, std::move(result));
+ }
+
+ LogicalProps transport(const CollationNode& node,
+ LogicalProps childResult,
+ LogicalProps /*refsResult*/) {
+ LogicalProps result = std::move(childResult);
+ // We propagate indexing availability.
+
+ addCentralizedAndRoundRobinDistributions<false /*addRoundRobin*/>(result);
+ return maybeUpdateNodePropsMap(node, std::move(result));
+ }
+
+ LogicalProps transport(const LimitSkipNode& node, LogicalProps childResult) {
+ LogicalProps result = std::move(childResult);
+ removeProperty<IndexingAvailability>(result);
+ addCentralizedAndRoundRobinDistributions<false /*addRoundRobin*/>(result);
+ return maybeUpdateNodePropsMap(node, std::move(result));
+ }
+
+ LogicalProps transport(const ExchangeNode& node,
+ LogicalProps childResult,
+ LogicalProps /*refsResult*/) {
+ LogicalProps result = std::move(childResult);
+ removeProperty<IndexingAvailability>(result);
+ addCentralizedAndRoundRobinDistributions<false /*addRoundRobin*/>(result);
+ return maybeUpdateNodePropsMap(node, std::move(result));
+ }
+
+ LogicalProps transport(const RootNode& node,
+ LogicalProps childResult,
+ LogicalProps /*refsResult*/) {
+ return maybeUpdateNodePropsMap(node, std::move(childResult));
+ }
+
+ /**
+ * Other ABT types.
+ */
+ template <typename T, typename... Ts>
+ LogicalProps transport(const T& /*node*/, Ts&&...) {
+ static_assert(!canBeLogicalNode<T>(),
+ "Logical node must implement its logical property derivation.");
+ return {};
+ }
+
+ static LogicalProps derive(const Metadata& metadata,
+ const ABT::reference_type nodeRef,
+ LogicalPropsInterface::NodePropsMap* nodePropsMap,
+ const Memo* memo,
+ const GroupIdType groupId) {
+ DeriveLogicalProperties instance(memo, metadata, groupId, nodePropsMap);
+ return algebra::transport<false>(nodeRef, instance);
+ }
+
+private:
+ DeriveLogicalProperties(const Memo* memo,
+ const Metadata& metadata,
+ const GroupIdType groupId,
+ LogicalPropsInterface::NodePropsMap* nodePropsMap)
+ : _groupId(groupId), _memo(memo), _metadata(metadata), _nodePropsMap(nodePropsMap) {}
+
+ template <bool addRoundRobin = true>
+ void addCentralizedAndRoundRobinDistributions(DistributionSet& distributions) {
+ distributions.emplace(DistributionType::Centralized);
+ if (addRoundRobin && _metadata.isParallelExecution()) {
+ distributions.emplace(DistributionType::RoundRobin);
+ }
+ }
+
+ template <bool addRoundRobin = true>
+ void addCentralizedAndRoundRobinDistributions(LogicalProps& properties) {
+ addCentralizedAndRoundRobinDistributions<addRoundRobin>(
+ getProperty<DistributionAvailability>(properties).getDistributionSet());
+ }
+
+ LogicalProps maybeUpdateNodePropsMap(const Node& node, LogicalProps props) {
+ if (_nodePropsMap != nullptr) {
+ _nodePropsMap->emplace(&node, props);
+ }
+ return props;
+ }
+
+ const GroupIdType _groupId;
+
+ // We don't own any of those.
+ const Memo* _memo;
+ const Metadata& _metadata;
+ LogicalPropsInterface::NodePropsMap* _nodePropsMap;
+};
+
+properties::LogicalProps DefaultLogicalPropsDerivation::deriveProps(
+ const Metadata& metadata,
+ const ABT::reference_type nodeRef,
+ LogicalPropsInterface::NodePropsMap* nodePropsMap,
+ const Memo* memo,
+ const GroupIdType groupId) const {
+ return DeriveLogicalProperties::derive(metadata, nodeRef, nodePropsMap, memo, groupId);
+}
+
+} // namespace mongo::optimizer::cascades
diff --git a/src/mongo/db/query/optimizer/cascades/logical_props_derivation.h b/src/mongo/db/query/optimizer/cascades/logical_props_derivation.h
new file mode 100644
index 00000000000..d8b4d8ec349
--- /dev/null
+++ b/src/mongo/db/query/optimizer/cascades/logical_props_derivation.h
@@ -0,0 +1,52 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include "mongo/db/query/optimizer/cascades/interfaces.h"
+#include "mongo/db/query/optimizer/cascades/memo.h"
+
+namespace mongo::optimizer::cascades {
+
+/**
+ * Logical property derivation framework.
+ * If nodePropsMap is supplied, populate per-node properties.
+ */
+class DefaultLogicalPropsDerivation : public LogicalPropsInterface {
+public:
+ properties::LogicalProps deriveProps(const Metadata& metadata,
+ ABT::reference_type nodeRef,
+ NodePropsMap* nodePropsMap = nullptr,
+ const Memo* memo = nullptr,
+ GroupIdType groupId = -1) const override final;
+};
+
+} // namespace mongo::optimizer::cascades
diff --git a/src/mongo/db/query/optimizer/cascades/logical_rewriter.cpp b/src/mongo/db/query/optimizer/cascades/logical_rewriter.cpp
new file mode 100644
index 00000000000..b4a52153736
--- /dev/null
+++ b/src/mongo/db/query/optimizer/cascades/logical_rewriter.cpp
@@ -0,0 +1,1288 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/db/query/optimizer/cascades/logical_rewriter.h"
+
+namespace mongo::optimizer::cascades {
+
+LogicalRewriter::RewriteSet LogicalRewriter::_explorationSet = {
+ {LogicalRewriteType::GroupByExplore, 1},
+ {LogicalRewriteType::SargableSplit, 2},
+ {LogicalRewriteType::FilterRIDIntersectReorder, 2},
+ {LogicalRewriteType::EvaluationRIDIntersectReorder, 2}};
+
+LogicalRewriter::RewriteSet LogicalRewriter::_substitutionSet = {
+ {LogicalRewriteType::FilterEvaluationReorder, 1},
+ {LogicalRewriteType::FilterCollationReorder, 1},
+ {LogicalRewriteType::EvaluationCollationReorder, 1},
+ {LogicalRewriteType::EvaluationLimitSkipReorder, 1},
+
+ {LogicalRewriteType::FilterGroupByReorder, 1},
+ {LogicalRewriteType::GroupCollationReorder, 1},
+
+ {LogicalRewriteType::FilterUnwindReorder, 1},
+ {LogicalRewriteType::EvaluationUnwindReorder, 1},
+ {LogicalRewriteType::UnwindCollationReorder, 1},
+
+ {LogicalRewriteType::FilterExchangeReorder, 1},
+ {LogicalRewriteType::ExchangeEvaluationReorder, 1},
+
+ {LogicalRewriteType::FilterUnionReorder, 1},
+
+ {LogicalRewriteType::CollationMerge, 1},
+ {LogicalRewriteType::LimitSkipMerge, 1},
+
+ {LogicalRewriteType::SargableFilterReorder, 1},
+ {LogicalRewriteType::SargableEvaluationReorder, 1},
+
+ {LogicalRewriteType::FilterValueScanPropagate, 1},
+ {LogicalRewriteType::EvaluationValueScanPropagate, 1},
+ {LogicalRewriteType::SargableValueScanPropagate, 1},
+ {LogicalRewriteType::CollationValueScanPropagate, 1},
+ {LogicalRewriteType::LimitSkipValueScanPropagate, 1},
+ {LogicalRewriteType::ExchangeValueScanPropagate, 1},
+
+ {LogicalRewriteType::LimitSkipSubstitute, 1},
+
+ {LogicalRewriteType::FilterSubstitute, 2},
+ {LogicalRewriteType::EvaluationSubstitute, 2},
+ {LogicalRewriteType::SargableMerge, 2}};
+
+LogicalRewriter::LogicalRewriter(Memo& memo, PrefixId& prefixId, RewriteSet rewriteSet)
+ : _activeRewriteSet(std::move(rewriteSet)), _groupsPending(), _memo(memo), _prefixId(prefixId) {
+ initializeRewrites();
+
+ if (_activeRewriteSet.count(LogicalRewriteType::SargableSplit) > 0) {
+        // If we are performing the SargableSplit exploration rewrite, populate the helper map.
+ for (const auto& [scanDefName, scanDef] : _memo.getMetadata()._scanDefs) {
+ for (const auto& [indexDefName, indexDef] : scanDef.getIndexDefs()) {
+ for (const IndexCollationEntry& entry : indexDef.getCollationSpec()) {
+ if (auto pathPtr = entry._path.cast<PathGet>(); pathPtr != nullptr) {
+ _indexFieldPrefixMap[scanDefName].insert(pathPtr->name());
+ }
+ }
+ }
+ }
+ }
+}
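+
+// For example (illustrative): an index on {a: 1, b: 1} in scan definition "coll" contributes its
+// top-level field names, so _indexFieldPrefixMap["coll"] would contain {"a", "b"}.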
+
+GroupIdType LogicalRewriter::addRootNode(const ABT& node) {
+ return addNode(node, -1, false /*addExistingNodeWithNewChild*/).first;
+}
+
+std::pair<GroupIdType, NodeIdSet> LogicalRewriter::addNode(const ABT& node,
+ const GroupIdType targetGroupId,
+ const bool addExistingNodeWithNewChild) {
+ NodeIdSet insertNodeIds;
+
+ Memo::NodeTargetGroupMap targetGroupMap;
+ if (targetGroupId >= 0) {
+ targetGroupMap = {{node.ref(), targetGroupId}};
+ }
+
+ const GroupIdType resultGroupId = _memo.integrate(
+ node, std::move(targetGroupMap), insertNodeIds, addExistingNodeWithNewChild);
+
+ uassert(6624046,
+ "Result group is not the same as target group",
+ targetGroupId < 0 || targetGroupId == resultGroupId);
+
+ for (const MemoLogicalNodeId& nodeMemoId : insertNodeIds) {
+ if (addExistingNodeWithNewChild && nodeMemoId._groupId == targetGroupId) {
+ continue;
+ }
+
+ for (const auto [type, priority] : _activeRewriteSet) {
+ auto& groupQueue = _memo.getGroup(nodeMemoId._groupId)._logicalRewriteQueue;
+ groupQueue.push(std::make_unique<LogicalRewriteEntry>(priority, type, nodeMemoId));
+
+ _groupsPending.insert(nodeMemoId._groupId);
+ }
+ }
+
+ return {resultGroupId, std::move(insertNodeIds)};
+}
+
+void LogicalRewriter::clearGroup(const GroupIdType groupId) {
+ _memo.clearLogicalNodes(groupId);
+}
+
+class RewriteContext {
+public:
+ RewriteContext(LogicalRewriter& rewriter,
+ const MemoLogicalNodeId aboveNodeId,
+ const MemoLogicalNodeId belowNodeId)
+        : RewriteContext(rewriter, aboveNodeId, true /*hasBelowNodeId*/, belowNodeId) {}
+
+ RewriteContext(LogicalRewriter& rewriter, const MemoLogicalNodeId aboveNodeId)
+        : RewriteContext(rewriter, aboveNodeId, false /*hasBelowNodeId*/, {}) {}
+
+ std::pair<GroupIdType, NodeIdSet> addNode(const ABT& node,
+ const bool substitute,
+ const bool addExistingNodeWithNewChild = false) {
+ if (substitute) {
+ uassert(6624110, "Cannot substitute twice", !_hasSubstituted);
+ _hasSubstituted = true;
+
+ _rewriter.clearGroup(_aboveNodeId._groupId);
+ if (_hasBelowNodeId) {
+ _rewriter.clearGroup(_belowNodeId._groupId);
+ }
+ }
+ return _rewriter.addNode(node, _aboveNodeId._groupId, addExistingNodeWithNewChild);
+ }
+
+ Memo& getMemo() const {
+ return _rewriter._memo;
+ }
+
+ const Metadata& getMetadata() const {
+ return _rewriter._memo.getMetadata();
+ }
+
+ PrefixId& getPrefixId() const {
+ return _rewriter._prefixId;
+ }
+
+ auto& getIndexFieldPrefixMap() const {
+ return _rewriter._indexFieldPrefixMap;
+ }
+
+ const properties::LogicalProps& getAboveLogicalProps() const {
+ return getMemo().getGroup(_aboveNodeId._groupId)._logicalProperties;
+ }
+
+ bool hasSubstituted() const {
+ return _hasSubstituted;
+ }
+
+ MemoLogicalNodeId getAboveNodeId() const {
+ return _aboveNodeId;
+ }
+
+ auto& getSargableSplitCountMap() const {
+ return _rewriter._sargableSplitCountMap;
+ }
+
+private:
+ RewriteContext(LogicalRewriter& rewriter,
+ const MemoLogicalNodeId aboveNodeId,
+ const bool hasBelowNodeId,
+ const MemoLogicalNodeId belowNodeId)
+ : _aboveNodeId(aboveNodeId),
+ _hasBelowNodeId(hasBelowNodeId),
+ _belowNodeId(belowNodeId),
+ _rewriter(rewriter),
+          _hasSubstituted(false) {}
+
+ const MemoLogicalNodeId _aboveNodeId;
+ const bool _hasBelowNodeId;
+ const MemoLogicalNodeId _belowNodeId;
+
+ // We don't own this.
+ LogicalRewriter& _rewriter;
+
+ bool _hasSubstituted;
+};
+
+struct ReorderDependencies {
+ bool _hasNodeRef = false;
+ bool _hasChildRef = false;
+ bool _hasNodeAndChildRef = false;
+};
+
+template <class NodeType>
+struct DefaultChildAccessor {
+ const ABT& operator()(const ABT& node) const {
+ return node.cast<NodeType>()->getChild();
+ }
+
+ ABT& operator()(ABT& node) const {
+ return node.cast<NodeType>()->getChild();
+ }
+};
+
+template <class NodeType>
+struct LeftChildAccessor {
+ const ABT& operator()(const ABT& node) const {
+ return node.cast<NodeType>()->getLeftChild();
+ }
+
+ ABT& operator()(ABT& node) const {
+ return node.cast<NodeType>()->getLeftChild();
+ }
+};
+
+template <class NodeType>
+struct RightChildAccessor {
+ const ABT& operator()(const ABT& node) const {
+ return node.cast<NodeType>()->getRightChild();
+ }
+
+ ABT& operator()(ABT& node) const {
+ return node.cast<NodeType>()->getRightChild();
+ }
+};
+
+template <class AboveType,
+ class BelowType,
+ template <class> class BelowChildAccessor = DefaultChildAccessor>
+ReorderDependencies computeDependencies(ABT::reference_type aboveNodeRef,
+ ABT::reference_type belowNodeRef,
+ RewriteContext& ctx) {
+ // Get variables from above node and check if they are bound at below node, or at below node's
+ // child.
+ const auto aboveNodeVarNames = collectVariableReferences(aboveNodeRef);
+
+ ABT belowNode = belowNodeRef;
+ VariableEnvironment env = VariableEnvironment::build(belowNode, &ctx.getMemo());
+ const DefinitionsMap belowNodeDefs = env.hasDefinitions(belowNode.ref())
+ ? env.getDefinitions(belowNode.ref())
+ : DefinitionsMap{};
+ ABT::reference_type belowChild = BelowChildAccessor<BelowType>()(belowNode).ref();
+ const DefinitionsMap belowChildNodeDefs =
+ env.hasDefinitions(belowChild) ? env.getDefinitions(belowChild) : DefinitionsMap{};
+
+ ReorderDependencies dependencies;
+ for (const std::string& varName : aboveNodeVarNames) {
+ auto it = belowNodeDefs.find(varName);
+ // Variable is exclusively defined in the below node.
+ const bool refersToNode = it != belowNodeDefs.cend() && it->second.definedBy == belowNode;
+ // Variable is defined in the belowNode's child subtree.
+ const bool refersToChild = belowChildNodeDefs.find(varName) != belowChildNodeDefs.cend();
+
+ if (refersToNode) {
+ if (refersToChild) {
+ dependencies._hasNodeAndChildRef = true;
+ } else {
+ dependencies._hasNodeRef = true;
+ }
+ } else if (refersToChild) {
+ dependencies._hasChildRef = true;
+ } else {
+ // Lambda variable. Ignore.
+ }
+ }
+
+ return dependencies;
+}
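+
+// For example (illustrative): if the above node references a variable bound by the below node
+// itself, _hasNodeRef is set, and defaultReorderWithDependenceCheck below rejects the reorder.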
+
+static ABT createEmptyValueScanNode(const RewriteContext& ctx) {
+ using namespace properties;
+
+ const ProjectionNameSet& projNameSet =
+ getPropertyConst<ProjectionAvailability>(ctx.getAboveLogicalProps()).getProjections();
+ ProjectionNameVector projNameVector;
+ projNameVector.insert(projNameVector.begin(), projNameSet.cbegin(), projNameSet.cend());
+ return make<ValueScanNode>(std::move(projNameVector));
+}
+
+static void addEmptyValueScanNode(RewriteContext& ctx) {
+ ABT newNode = createEmptyValueScanNode(ctx);
+ ctx.addNode(newNode, true /*substitute*/);
+}
+
+static void defaultPropagateEmptyValueScanNode(const ABT& n, RewriteContext& ctx) {
+ if (n.cast<ValueScanNode>()->getArraySize() == 0) {
+ addEmptyValueScanNode(ctx);
+ }
+}
+
+template <class AboveType,
+ class BelowType,
+ template <class> class AboveChildAccessor = DefaultChildAccessor,
+ template <class> class BelowChildAccessor = DefaultChildAccessor,
+ bool substitute = true>
+void defaultReorder(ABT::reference_type aboveNode,
+ ABT::reference_type belowNode,
+ RewriteContext& ctx) {
+ ABT newParent = belowNode;
+ ABT newChild = aboveNode;
+
+ std::swap(BelowChildAccessor<BelowType>()(newParent),
+ AboveChildAccessor<AboveType>()(newChild));
+ BelowChildAccessor<BelowType>()(newParent) = std::move(newChild);
+
+ ctx.addNode(newParent, substitute);
+}
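+
+// For example (illustrative): applying defaultReorder with a Filter above an Evaluation turns
+// Filter(Evaluation(C)) into Evaluation(Filter(C)) by swapping the two nodes' child slots.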
+
+template <class AboveType, class BelowType>
+void defaultReorderWithDependenceCheck(ABT::reference_type aboveNode,
+ ABT::reference_type belowNode,
+ RewriteContext& ctx) {
+ const ReorderDependencies dependencies =
+ computeDependencies<AboveType, BelowType>(aboveNode, belowNode, ctx);
+ if (dependencies._hasNodeRef) {
+ // Above node refers to a variable bound by below node.
+ return;
+ }
+
+ defaultReorder<AboveType, BelowType>(aboveNode, belowNode, ctx);
+}
+
+template <class AboveType, class BelowType>
+struct SubstituteReorder {
+ void operator()(ABT::reference_type aboveNode,
+ ABT::reference_type belowNode,
+ RewriteContext& ctx) const {
+ defaultReorderWithDependenceCheck<AboveType, BelowType>(aboveNode, belowNode, ctx);
+ }
+};
+
+template <>
+struct SubstituteReorder<FilterNode, FilterNode> {
+ void operator()(ABT::reference_type aboveNode,
+ ABT::reference_type belowNode,
+ RewriteContext& ctx) const {
+ defaultReorder<FilterNode, FilterNode>(aboveNode, belowNode, ctx);
+ }
+};
+
+template <>
+struct SubstituteReorder<FilterNode, UnionNode> {
+ void operator()(ABT::reference_type aboveNode,
+ ABT::reference_type belowNode,
+ RewriteContext& ctx) const {
+ ABT newParent = belowNode;
+
+ for (auto& childOfChild : newParent.cast<UnionNode>()->nodes()) {
+ ABT aboveCopy = aboveNode;
+ std::swap(aboveCopy.cast<FilterNode>()->getChild(), childOfChild);
+ std::swap(childOfChild, aboveCopy);
+ }
+
+ ctx.addNode(newParent, true /*substitute*/);
+ }
+};
+
+template <class AboveType>
+void unwindBelowReorder(ABT::reference_type aboveNode,
+ ABT::reference_type unwindNode,
+ RewriteContext& ctx) {
+ const ReorderDependencies dependencies =
+ computeDependencies<AboveType, UnwindNode>(aboveNode, unwindNode, ctx);
+ if (dependencies._hasNodeRef || dependencies._hasNodeAndChildRef) {
+ // Above node refers to projection being unwound. Reject rewrite.
+ return;
+ }
+
+ defaultReorder<AboveType, UnwindNode>(aboveNode, unwindNode, ctx);
+}
+
+template <>
+struct SubstituteReorder<FilterNode, UnwindNode> {
+ void operator()(ABT::reference_type aboveNode,
+ ABT::reference_type belowNode,
+ RewriteContext& ctx) const {
+ unwindBelowReorder<FilterNode>(aboveNode, belowNode, ctx);
+ }
+};
+
+template <>
+struct SubstituteReorder<EvaluationNode, UnwindNode> {
+ void operator()(ABT::reference_type aboveNode,
+ ABT::reference_type belowNode,
+ RewriteContext& ctx) const {
+ unwindBelowReorder<EvaluationNode>(aboveNode, belowNode, ctx);
+ }
+};
+
+template <>
+struct SubstituteReorder<UnwindNode, CollationNode> {
+ void operator()(ABT::reference_type aboveNode,
+ ABT::reference_type belowNode,
+ RewriteContext& ctx) const {
+ const ProjectionNameSet& collationProjections =
+ belowNode.cast<CollationNode>()->getProperty().getAffectedProjectionNames();
+ if (collationProjections.find(aboveNode.cast<UnwindNode>()->getProjectionName()) !=
+ collationProjections.cend()) {
+ // A projection being affected by the collation is being unwound. Reject rewrite.
+ return;
+ }
+
+ defaultReorder<UnwindNode, CollationNode>(aboveNode, belowNode, ctx);
+ }
+};
+
+template <>
+struct SubstituteReorder<FilterNode, ValueScanNode> {
+ void operator()(ABT::reference_type aboveNode,
+ ABT::reference_type belowNode,
+ RewriteContext& ctx) const {
+ defaultPropagateEmptyValueScanNode(belowNode, ctx);
+ }
+};
+
+template <>
+struct SubstituteReorder<EvaluationNode, ValueScanNode> {
+ void operator()(ABT::reference_type aboveNode,
+ ABT::reference_type belowNode,
+ RewriteContext& ctx) const {
+ defaultPropagateEmptyValueScanNode(belowNode, ctx);
+ }
+};
+
+template <>
+struct SubstituteReorder<SargableNode, ValueScanNode> {
+ void operator()(ABT::reference_type aboveNode,
+ ABT::reference_type belowNode,
+ RewriteContext& ctx) const {
+ defaultPropagateEmptyValueScanNode(belowNode, ctx);
+ }
+};
+
+template <>
+struct SubstituteReorder<CollationNode, ValueScanNode> {
+ void operator()(ABT::reference_type aboveNode,
+ ABT::reference_type belowNode,
+ RewriteContext& ctx) const {
+ defaultPropagateEmptyValueScanNode(belowNode, ctx);
+ }
+};
+
+template <>
+struct SubstituteReorder<LimitSkipNode, ValueScanNode> {
+ void operator()(ABT::reference_type aboveNode,
+ ABT::reference_type belowNode,
+ RewriteContext& ctx) const {
+ defaultPropagateEmptyValueScanNode(belowNode, ctx);
+ }
+};
+
+template <>
+struct SubstituteReorder<ExchangeNode, ValueScanNode> {
+ void operator()(ABT::reference_type aboveNode,
+ ABT::reference_type belowNode,
+ RewriteContext& ctx) const {
+ defaultPropagateEmptyValueScanNode(belowNode, ctx);
+ }
+};
+
+template <class AboveType, class BelowType>
+struct SubstituteMerge {
+ void operator()(ABT::reference_type aboveNode,
+ ABT::reference_type belowNode,
+ RewriteContext& ctx) = delete;
+};
+
+template <>
+struct SubstituteMerge<CollationNode, CollationNode> {
+ void operator()(ABT::reference_type aboveNode,
+ ABT::reference_type belowNode,
+ RewriteContext& ctx) const {
+ ABT newRoot = aboveNode;
+ // Retain above property.
+ newRoot.cast<CollationNode>()->getChild() = belowNode.cast<CollationNode>()->getChild();
+
+ ctx.addNode(newRoot, true /*substitute*/);
+ }
+};
+
+template <>
+struct SubstituteMerge<LimitSkipNode, LimitSkipNode> {
+ void operator()(ABT::reference_type aboveNode,
+ ABT::reference_type belowNode,
+ RewriteContext& ctx) const {
+ using namespace properties;
+
+ ABT newRoot = aboveNode;
+        LimitSkipNode& aboveLimitSkipNode = *newRoot.cast<LimitSkipNode>();
+        const LimitSkipNode& belowLimitSkipNode = *belowNode.cast<LimitSkipNode>();
+
+        aboveLimitSkipNode.getChild() = belowLimitSkipNode.getChild();
+        combineLimitSkipProperties(aboveLimitSkipNode.getProperty(),
+                                   belowLimitSkipNode.getProperty());
+
+ ctx.addNode(newRoot, true /*substitute*/);
+ }
+};
+
+static boost::optional<ABT> mergeSargableNodes(
+ const properties::IndexingAvailability& indexingAvailability,
+ const SargableNode& aboveNode,
+ const SargableNode& belowNode,
+ RewriteContext& ctx) {
+ if (indexingAvailability.getScanGroupId() !=
+ belowNode.getChild().cast<MemoLogicalDelegatorNode>()->getGroupId()) {
+ // Do not merge if child is not another Sargable node, or the child's child is not a
+ // ScanNode.
+ return {};
+ }
+
+ PartialSchemaRequirements mergedReqs = belowNode.getReqMap();
+ ProjectionRenames projectionRenames;
+ if (!intersectPartialSchemaReq(mergedReqs, aboveNode.getReqMap(), projectionRenames)) {
+ return {};
+ }
+ if (mergedReqs.size() > LogicalRewriter::kMaxPartialSchemaReqCount) {
+ return {};
+ }
+
+ const ScanDefinition& scanDef =
+ ctx.getMetadata()._scanDefs.at(indexingAvailability.getScanDefName());
+ bool hasEmptyInterval = false;
+ auto candidateIndexMap = computeCandidateIndexMap(ctx.getPrefixId(),
+ indexingAvailability.getScanProjection(),
+ mergedReqs,
+ scanDef,
+ hasEmptyInterval);
+
+ if (hasEmptyInterval) {
+ return createEmptyValueScanNode(ctx);
+ }
+
+ ABT result = make<SargableNode>(std::move(mergedReqs),
+ std::move(candidateIndexMap),
+ IndexReqTarget::Complete,
+ belowNode.getChild());
+ applyProjectionRenames(std::move(projectionRenames), result);
+ return result;
+}
+
+template <>
+struct SubstituteMerge<SargableNode, SargableNode> {
+ void operator()(ABT::reference_type aboveNode,
+ ABT::reference_type belowNode,
+ RewriteContext& ctx) const {
+ using namespace properties;
+ const auto& result =
+ mergeSargableNodes(getPropertyConst<IndexingAvailability>(ctx.getAboveLogicalProps()),
+ *aboveNode.cast<SargableNode>(),
+ *belowNode.cast<SargableNode>(),
+ ctx);
+ if (result) {
+ ctx.addNode(*result, true /*substitute*/);
+ }
+ }
+};
+
+template <class Type>
+struct SubstituteConvert {
+ void operator()(ABT::reference_type nodeRef, RewriteContext& ctx) = delete;
+};
+
+template <>
+struct SubstituteConvert<LimitSkipNode> {
+ void operator()(ABT::reference_type node, RewriteContext& ctx) {
+ if (node.cast<LimitSkipNode>()->getProperty().getLimit() == 0) {
+ addEmptyValueScanNode(ctx);
+ }
+ }
+};
+
+static void addElemMatchAndSargableNode(const ABT& node, ABT sargableNode, RewriteContext& ctx) {
+ ABT newNode = node;
+ newNode.cast<FilterNode>()->getChild() = std::move(sargableNode);
+ ctx.addNode(newNode, false /*substitute*/, true /*addExistingNodeWithNewChild*/);
+}
+
+void convertFilterToSargableNode(ABT::reference_type node,
+ const FilterNode& filterNode,
+ RewriteContext& ctx) {
+ using namespace properties;
+
+ const LogicalProps& props = ctx.getAboveLogicalProps();
+ if (!hasProperty<IndexingAvailability>(props)) {
+ // Can only convert to sargable node if we have indexing availability.
+ return;
+ }
+
+ const auto& indexingAvailability = getPropertyConst<IndexingAvailability>(props);
+ const ScanDefinition& scanDef =
+ ctx.getMetadata()._scanDefs.at(indexingAvailability.getScanDefName());
+ if (!scanDef.exists()) {
+ // Do not attempt to optimize for non-existing collections.
+ return;
+ }
+
+ PartialSchemaReqConversion conversion = convertExprToPartialSchemaReq(filterNode.getFilter());
+ if (!conversion._success) {
+ return;
+ }
+ if (conversion._hasEmptyInterval) {
+ addEmptyValueScanNode(ctx);
+ return;
+ }
+
+ for (const auto& entry : conversion._reqMap) {
+ uassert(6624111,
+ "Filter partial schema requirement must contain a variable name.",
+ !entry.first._projectionName.empty());
+ uassert(6624112,
+ "Filter partial schema requirement cannot bind.",
+ !entry.second.hasBoundProjectionName());
+ uassert(6624113,
+ "Filter partial schema requirement must have a range.",
+ !isIntervalReqFullyOpenDNF(entry.second.getIntervals()));
+ }
+
+ bool hasEmptyInterval = false;
+ auto candidateIndexMap = computeCandidateIndexMap(ctx.getPrefixId(),
+ indexingAvailability.getScanProjection(),
+ conversion._reqMap,
+ scanDef,
+ hasEmptyInterval);
+
+ if (hasEmptyInterval) {
+ addEmptyValueScanNode(ctx);
+ } else {
+ ABT sargableNode = make<SargableNode>(std::move(conversion._reqMap),
+ std::move(candidateIndexMap),
+ IndexReqTarget::Complete,
+ filterNode.getChild());
+
+ ctx.addNode(sargableNode, true /*substitute*/);
+ }
+}
+
+static ABT appendFieldPath(const FieldPathType& fieldPath, ABT input) {
+ for (size_t index = fieldPath.size(); index-- > 0;) {
+ input = make<PathGet>(fieldPath.at(index), std::move(input));
+ }
+ return input;
+}
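+
+// For example (illustrative): appendFieldPath({"a", "b"}, make<PathIdentity>()) produces
+// PathGet "a" (PathGet "b" (PathIdentity)), restoring the stripped PathGet prefix.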
+
+template <>
+struct SubstituteConvert<FilterNode> {
+ void operator()(ABT::reference_type node, RewriteContext& ctx) {
+ const FilterNode& filterNode = *node.cast<FilterNode>();
+
+        // Sub-rewrite: attempt to decompose the filter. If we have a path with a prefix of
+        // PathGet elements followed by a PathComposeM, then split into two filter nodes at the
+        // composition and retain the prefix for each.
+ // TODO: consider using a standalone rewrite.
+ if (auto evalFilter = filterNode.getFilter().cast<EvalFilter>(); evalFilter != nullptr) {
+ ABT::reference_type pathRef = evalFilter->getPath().ref();
+ FieldPathType fieldPath;
+ for (;;) {
+ if (auto newPath = pathRef.cast<PathGet>(); newPath != nullptr) {
+ fieldPath.push_back(newPath->name());
+ pathRef = newPath->getPath().ref();
+ } else {
+ break;
+ }
+ }
+
+ if (auto composition = pathRef.cast<PathComposeM>(); composition != nullptr) {
+ // Remove the path composition and insert two filter nodes.
+ ABT filterNode1 = make<FilterNode>(
+ make<EvalFilter>(appendFieldPath(fieldPath, composition->getPath1()),
+ evalFilter->getInput()),
+ filterNode.getChild());
+ ABT filterNode2 = make<FilterNode>(
+ make<EvalFilter>(appendFieldPath(fieldPath, composition->getPath2()),
+ evalFilter->getInput()),
+ std::move(filterNode1));
+
+ ctx.addNode(filterNode2, true /*substitute*/);
+ return;
+ }
+ }
+
+ convertFilterToSargableNode(node, filterNode, ctx);
+ }
+};
+
+template <>
+struct SubstituteConvert<EvaluationNode> {
+ void operator()(ABT::reference_type node, RewriteContext& ctx) {
+ using namespace properties;
+
+        const LogicalProps& props = ctx.getAboveLogicalProps();
+ if (!hasProperty<IndexingAvailability>(props)) {
+ // Can only convert to sargable node if we have indexing availability.
+ return;
+ }
+
+ const auto& indexingAvailability = getPropertyConst<IndexingAvailability>(props);
+ const ProjectionName& scanProjName = indexingAvailability.getScanProjection();
+
+ const ScanDefinition& scanDef =
+ ctx.getMetadata()._scanDefs.at(indexingAvailability.getScanDefName());
+ if (!scanDef.exists()) {
+ // Do not attempt to optimize for non-existing collections.
+ return;
+ }
+
+ const EvaluationNode& evalNode = *node.cast<EvaluationNode>();
+
+ // Sub-rewrite: attempt to convert Keep to a chain of individual evaluations.
+ // TODO: consider using a standalone rewrite.
+ if (auto evalPathPtr = evalNode.getProjection().cast<EvalPath>(); evalPathPtr != nullptr) {
+ if (auto inputPtr = evalPathPtr->getInput().cast<Variable>();
+ inputPtr != nullptr && inputPtr->name() == scanProjName) {
+ if (auto pathKeepPtr = evalPathPtr->getPath().cast<PathKeep>();
+ pathKeepPtr != nullptr) {
+ // Optimization. If we are retaining fields on the root level, generate
+ // EvalNodes with the intention of converting later to a SargableNode after
+ // reordering, in order to be able to cover the fields using a physical scan or
+ // index.
+
+ ABT result = evalNode.getChild();
+ ABT keepPath = make<PathIdentity>();
+
+ std::set<std::string> orderedSet;
+ for (const std::string& fieldName : pathKeepPtr->getNames()) {
+ orderedSet.insert(fieldName);
+ }
+ for (const std::string& fieldName : orderedSet) {
+ ProjectionName projName = ctx.getPrefixId().getNextId("fieldProj");
+ result = make<EvaluationNode>(
+ projName,
+ make<EvalPath>(make<PathGet>(fieldName, make<PathIdentity>()),
+ evalPathPtr->getInput()),
+ std::move(result));
+
+ maybeComposePath(keepPath,
+ make<PathField>(fieldName,
+ make<PathConstant>(
+ make<Variable>(std::move(projName)))));
+ }
+
+ result = make<EvaluationNode>(
+ evalNode.getProjectionName(),
+ make<EvalPath>(std::move(keepPath), Constant::emptyObject()),
+ std::move(result));
+ ctx.addNode(result, true /*substitute*/);
+ return;
+ }
+ }
+ }
+
+ // We still want to extract sargable nodes from EvalNode to use for PhysicalScans.
+ PartialSchemaReqConversion conversion =
+ convertExprToPartialSchemaReq(evalNode.getProjection());
+ if (!conversion._success || conversion._reqMap.size() != 1) {
+ // For evaluation nodes we expect to create a single entry.
+ return;
+ }
+ if (conversion._hasEmptyInterval) {
+ addEmptyValueScanNode(ctx);
+ return;
+ }
+
+ for (auto& entry : conversion._reqMap) {
+ PartialSchemaRequirement& req = entry.second;
+ req.setBoundProjectionName(evalNode.getProjectionName());
+
+ uassert(6624114,
+ "Eval partial schema requirement must contain a variable name.",
+ !entry.first._projectionName.empty());
+ uassert(6624115,
+                "Eval partial schema requirement cannot have a range.",
+ isIntervalReqFullyOpenDNF(req.getIntervals()));
+ }
+
+ bool hasEmptyInterval = false;
+ auto candidateIndexMap = computeCandidateIndexMap(
+ ctx.getPrefixId(), scanProjName, conversion._reqMap, scanDef, hasEmptyInterval);
+
+ if (hasEmptyInterval) {
+ addEmptyValueScanNode(ctx);
+ } else {
+ ABT newNode = make<SargableNode>(std::move(conversion._reqMap),
+ std::move(candidateIndexMap),
+ IndexReqTarget::Complete,
+ evalNode.getChild());
+ ctx.addNode(newNode, true /*substitute*/);
+ }
+ }
+};
+
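+// Editorial note: when a SargableNode cannot participate in the split rewrite (its child is not
+// the scan group), it is lowered back into the equivalent chain of Filter/Evaluation nodes, one
+// per partial schema requirement.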
+static void lowerSargableNode(const SargableNode& node, RewriteContext& ctx) {
+ ABT n = node.getChild();
+ const auto reqMap = node.getReqMap();
+ for (const auto& [key, req] : reqMap) {
+ lowerPartialSchemaRequirement(key, req, n);
+ }
+    ctx.addNode(n, true /*substitute*/);
+}
+
+template <class Type>
+struct ExploreConvert {
+ void operator()(ABT::reference_type nodeRef, RewriteContext& ctx) = delete;
+};
+
+template <>
+struct ExploreConvert<SargableNode> {
+ void operator()(ABT::reference_type node, RewriteContext& ctx) {
+ using namespace properties;
+
+ const SargableNode& sargableNode = *node.cast<SargableNode>();
+ const IndexReqTarget target = sargableNode.getTarget();
+ if (target == IndexReqTarget::Seek) {
+ return;
+ }
+
+ const LogicalProps& props = ctx.getAboveLogicalProps();
+ const auto& indexingAvailability = getPropertyConst<IndexingAvailability>(props);
+ const GroupIdType scanGroupId = indexingAvailability.getScanGroupId();
+ if (sargableNode.getChild().cast<MemoLogicalDelegatorNode>()->getGroupId() != scanGroupId) {
+ lowerSargableNode(sargableNode, ctx);
+ return;
+ }
+
+ const std::string& scanDefName = indexingAvailability.getScanDefName();
+ const ScanDefinition& scanDef = ctx.getMetadata()._scanDefs.at(scanDefName);
+ const size_t indexCount = scanDef.getIndexDefs().size();
+ if (indexCount == 0) {
+ // Do not insert RIDIntersect if we do not have indexes available.
+ return;
+ }
+
+ const auto aboveNodeId = ctx.getAboveNodeId();
+ auto& sargableSplitCountMap = ctx.getSargableSplitCountMap();
+ const size_t splitCount = sargableSplitCountMap[aboveNodeId];
+ if ((1ull << splitCount) >
+ roundUpToNextPow2(indexCount, LogicalRewriter::kMaxSargableNodeSplitCount)) {
+ // We cannot split this node further.
+ return;
+ }
+
+ const ProjectionName& scanProjectionName = indexingAvailability.getScanProjection();
+ if (collectVariableReferences(node) != VariableNameSetType{scanProjectionName}) {
+            // The rewrite does not apply if we refer to projections other than the scan
+            // projection.
+ return;
+ }
+
+ const bool isIndex = target == IndexReqTarget::Index;
+
+ const auto& indexFieldPrefixMap = ctx.getIndexFieldPrefixMap();
+ const auto indexFieldPrefixMapIt =
+ isIndex ? indexFieldPrefixMap.cend() : indexFieldPrefixMap.find(scanDefName);
+ const bool indexFieldMapHasScanDef = indexFieldPrefixMapIt != indexFieldPrefixMap.cend();
+
+ const auto& reqMap = sargableNode.getReqMap();
+ const size_t reqSize = reqMap.size();
+ const size_t highMask = isIndex ? (1ull << (reqSize - 1)) : (1ull << reqSize);
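+        // Editorial example: bit i of 'mask' assigns requirement i to the left side (bit set)
+        // or the right side (bit clear). When both sides target an index we enumerate only
+        // half of the masks, since a split and its complement produce symmetric plans.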
+ for (size_t mask = 1; mask < highMask; mask++) {
+ PartialSchemaRequirements leftReqs;
+ PartialSchemaRequirements rightReqs;
+ bool hasFieldCoverage = true;
+ bool hasLeftIntervals = false;
+ bool hasRightIntervals = false;
+
+ size_t index = 0;
+ for (const auto& [key, req] : reqMap) {
+ const bool fullyOpenInterval = isIntervalReqFullyOpenDNF(req.getIntervals());
+
+ if (((1ull << index) & mask) != 0) {
+ leftReqs.emplace(key, req);
+
+ if (!fullyOpenInterval) {
+ hasLeftIntervals = true;
+ }
+ if (indexFieldMapHasScanDef) {
+ if (auto pathPtr = key._path.cast<PathGet>(); pathPtr != nullptr &&
+ indexFieldPrefixMapIt->second.count(pathPtr->name()) == 0) {
+ // We have found a left requirement which cannot be covered with an
+ // index.
+ hasFieldCoverage = false;
+ break;
+ }
+ }
+ } else {
+ rightReqs.emplace(key, req);
+
+ if (!fullyOpenInterval) {
+ hasRightIntervals = true;
+ }
+ }
+ index++;
+ }
+
+ if (isIndex && (!hasLeftIntervals || !hasRightIntervals)) {
+                // Reject: when targeting an index, each side must have at least one
+                // non-trivial interval.
+ continue;
+ }
+ if (!hasFieldCoverage) {
+ // Reject rewrite. No suitable indexes.
+ continue;
+ }
+
+ bool hasEmptyLeftInterval = false;
+ auto leftCandidateIndexMap = computeCandidateIndexMap(
+ ctx.getPrefixId(), scanProjectionName, leftReqs, scanDef, hasEmptyLeftInterval);
+ if (isIndex && leftCandidateIndexMap.empty()) {
+ // Reject rewrite.
+ continue;
+ }
+
+ bool hasEmptyRightInterval = false;
+ auto rightCandidateIndexMap = computeCandidateIndexMap(
+ ctx.getPrefixId(), scanProjectionName, rightReqs, scanDef, hasEmptyRightInterval);
+ if (isIndex && rightCandidateIndexMap.empty()) {
+                // With an empty candidate map, reject only when targeting an index; otherwise
+                // the right side can still be implemented as a Seek.
+ continue;
+ }
+ uassert(6624116,
+ "Empty intervals should already be rewritten to empty ValueScan nodes",
+ !hasEmptyLeftInterval && !hasEmptyRightInterval);
+
+ ABT scanDelegator = make<MemoLogicalDelegatorNode>(scanGroupId);
+ ABT leftChild = make<SargableNode>(std::move(leftReqs),
+ std::move(leftCandidateIndexMap),
+ IndexReqTarget::Index,
+ scanDelegator);
+ ABT rightChild = rightReqs.empty()
+ ? scanDelegator
+ : make<SargableNode>(std::move(rightReqs),
+ std::move(rightCandidateIndexMap),
+ isIndex ? IndexReqTarget::Index : IndexReqTarget::Seek,
+ scanDelegator);
+
+ ABT newRoot = make<RIDIntersectNode>(scanProjectionName,
+ hasLeftIntervals,
+ hasRightIntervals,
+ std::move(leftChild),
+ std::move(rightChild));
+
+ const auto& result = ctx.addNode(newRoot, false /*substitute*/);
+ for (const MemoLogicalNodeId nodeId : result.second) {
+ if (!(nodeId == aboveNodeId)) {
+ sargableSplitCountMap[nodeId] = splitCount + 1;
+ }
+ }
+ }
+ }
+};
+
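+/**
+ * Editorial sketch: a Complete GroupBy computing s = $sum(x) is rewritten into a Local GroupBy
+ * computing preagg_0 = $sum(x), followed by a Global GroupBy computing s = $sum(preagg_0).
+ * This enables partial (local-global) aggregation.
+ */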
+template <>
+struct ExploreConvert<GroupByNode> {
+ void operator()(ABT::reference_type node, RewriteContext& ctx) {
+ const GroupByNode& groupByNode = *node.cast<GroupByNode>();
+ if (groupByNode.getType() != GroupNodeType::Complete) {
+ return;
+ }
+
+ ProjectionNameVector preaggVariableNames;
+ ABTVector preaggExpressions;
+
+ const ABTVector& aggExpressions = groupByNode.getAggregationExpressions();
+ for (const ABT& expr : aggExpressions) {
+ const FunctionCall* aggPtr = expr.cast<FunctionCall>();
+ if (aggPtr == nullptr) {
+ return;
+ }
+
+            // For now, in order to be able to pre-aggregate, we expect a simple aggregate
+            // such as SUM(x).
+ const auto& aggFnName = aggPtr->name();
+ if (aggFnName != "$sum" && aggFnName != "$min" && aggFnName != "$max") {
+ // TODO: allow more functions.
+ return;
+ }
+ uassert(6624117, "Invalid argument count", aggPtr->nodes().size() == 1);
+
+ preaggVariableNames.push_back(ctx.getPrefixId().getNextId("preagg"));
+ preaggExpressions.emplace_back(
+ make<FunctionCall>(aggFnName, makeSeq(make<Variable>(preaggVariableNames.back()))));
+ }
+
+ ABT localGroupBy = make<GroupByNode>(groupByNode.getGroupByProjectionNames(),
+ std::move(preaggVariableNames),
+ aggExpressions,
+ GroupNodeType::Local,
+ groupByNode.getChild());
+
+ ABT newRoot = make<GroupByNode>(groupByNode.getGroupByProjectionNames(),
+ groupByNode.getAggregationProjectionNames(),
+ std::move(preaggExpressions),
+ GroupNodeType::Global,
+ std::move(localGroupBy));
+
+ ctx.addNode(newRoot, false /*substitute*/);
+ }
+};
+
+template <class AboveType, class BelowType>
+struct ExploreReorder {
+ void operator()(ABT::reference_type aboveNode,
+ ABT::reference_type belowNode,
+ RewriteContext& ctx) const = delete;
+};
+
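+/**
+ * Editorial note: attempts to push the above node (Filter or Evaluation) below an RIDIntersect,
+ * into whichever child supplies the projections it references, provided that side has proper
+ * intervals. If both children (or neither) are referenced, the reorder does not apply.
+ */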
+template <class AboveNode>
+void reorderAgainstRIDIntersectNode(ABT::reference_type aboveNode,
+ ABT::reference_type belowNode,
+ RewriteContext& ctx) {
+ const ReorderDependencies leftDeps =
+ computeDependencies<AboveNode, RIDIntersectNode, LeftChildAccessor>(
+ aboveNode, belowNode, ctx);
+ uassert(6624118, "RIDIntersect cannot bind projections", !leftDeps._hasNodeRef);
+ const bool hasLeftRef = leftDeps._hasChildRef;
+
+ const ReorderDependencies rightDeps =
+ computeDependencies<AboveNode, RIDIntersectNode, RightChildAccessor>(
+ aboveNode, belowNode, ctx);
+ uassert(6624119, "RIDIntersect cannot bind projections", !rightDeps._hasNodeRef);
+ const bool hasRightRef = rightDeps._hasChildRef;
+
+ if (hasLeftRef == hasRightRef) {
+        // We refer to projections from both sides (or neither), so we cannot push the node
+        // into a single side.
+ return;
+ }
+
+ const RIDIntersectNode& node = *belowNode.cast<RIDIntersectNode>();
+ if (node.hasLeftIntervals() && hasLeftRef) {
+ defaultReorder<AboveNode,
+ RIDIntersectNode,
+ DefaultChildAccessor,
+ LeftChildAccessor,
+ false /*substitute*/>(aboveNode, belowNode, ctx);
+ }
+ if (node.hasRightIntervals() && hasRightRef) {
+ defaultReorder<AboveNode,
+ RIDIntersectNode,
+ DefaultChildAccessor,
+ RightChildAccessor,
+ false /*substitute*/>(aboveNode, belowNode, ctx);
+ }
+}
+
+template <>
+struct ExploreReorder<FilterNode, RIDIntersectNode> {
+ void operator()(ABT::reference_type aboveNode,
+ ABT::reference_type belowNode,
+ RewriteContext& ctx) const {
+ reorderAgainstRIDIntersectNode<FilterNode>(aboveNode, belowNode, ctx);
+ }
+};
+
+template <>
+struct ExploreReorder<EvaluationNode, RIDIntersectNode> {
+ void operator()(ABT::reference_type aboveNode,
+ ABT::reference_type belowNode,
+ RewriteContext& ctx) const {
+ reorderAgainstRIDIntersectNode<EvaluationNode>(aboveNode, belowNode, ctx);
+ }
+};
+
+void LogicalRewriter::registerRewrite(const LogicalRewriteType rewriteType, RewriteFn fn) {
+ if (_activeRewriteSet.find(rewriteType) != _activeRewriteSet.cend()) {
+ _rewriteMap.emplace(rewriteType, fn);
+ }
+}
+
+void LogicalRewriter::initializeRewrites() {
+ registerRewrite(
+ LogicalRewriteType::FilterEvaluationReorder,
+ &LogicalRewriter::bindAboveBelow<FilterNode, EvaluationNode, SubstituteReorder>);
+ registerRewrite(LogicalRewriteType::FilterCollationReorder,
+ &LogicalRewriter::bindAboveBelow<FilterNode, CollationNode, SubstituteReorder>);
+ registerRewrite(
+ LogicalRewriteType::EvaluationCollationReorder,
+ &LogicalRewriter::bindAboveBelow<EvaluationNode, CollationNode, SubstituteReorder>);
+ registerRewrite(
+ LogicalRewriteType::EvaluationLimitSkipReorder,
+ &LogicalRewriter::bindAboveBelow<EvaluationNode, LimitSkipNode, SubstituteReorder>);
+ registerRewrite(LogicalRewriteType::FilterGroupByReorder,
+ &LogicalRewriter::bindAboveBelow<FilterNode, GroupByNode, SubstituteReorder>);
+ registerRewrite(
+ LogicalRewriteType::GroupCollationReorder,
+ &LogicalRewriter::bindAboveBelow<GroupByNode, CollationNode, SubstituteReorder>);
+ registerRewrite(LogicalRewriteType::FilterUnwindReorder,
+ &LogicalRewriter::bindAboveBelow<FilterNode, UnwindNode, SubstituteReorder>);
+ registerRewrite(
+ LogicalRewriteType::EvaluationUnwindReorder,
+ &LogicalRewriter::bindAboveBelow<EvaluationNode, UnwindNode, SubstituteReorder>);
+ registerRewrite(LogicalRewriteType::UnwindCollationReorder,
+ &LogicalRewriter::bindAboveBelow<UnwindNode, CollationNode, SubstituteReorder>);
+
+ registerRewrite(LogicalRewriteType::FilterExchangeReorder,
+ &LogicalRewriter::bindAboveBelow<FilterNode, ExchangeNode, SubstituteReorder>);
+ registerRewrite(
+ LogicalRewriteType::ExchangeEvaluationReorder,
+ &LogicalRewriter::bindAboveBelow<ExchangeNode, EvaluationNode, SubstituteReorder>);
+
+ registerRewrite(LogicalRewriteType::FilterUnionReorder,
+ &LogicalRewriter::bindAboveBelow<FilterNode, UnionNode, SubstituteReorder>);
+
+ registerRewrite(
+ LogicalRewriteType::CollationMerge,
+ &LogicalRewriter::bindAboveBelow<CollationNode, CollationNode, SubstituteMerge>);
+ registerRewrite(
+ LogicalRewriteType::LimitSkipMerge,
+ &LogicalRewriter::bindAboveBelow<LimitSkipNode, LimitSkipNode, SubstituteMerge>);
+
+ registerRewrite(LogicalRewriteType::SargableFilterReorder,
+ &LogicalRewriter::bindAboveBelow<SargableNode, FilterNode, SubstituteReorder>);
+ registerRewrite(
+ LogicalRewriteType::SargableEvaluationReorder,
+ &LogicalRewriter::bindAboveBelow<SargableNode, EvaluationNode, SubstituteReorder>);
+
+ registerRewrite(LogicalRewriteType::LimitSkipSubstitute,
+ &LogicalRewriter::bindSingleNode<LimitSkipNode, SubstituteConvert>);
+
+ registerRewrite(LogicalRewriteType::SargableMerge,
+ &LogicalRewriter::bindAboveBelow<SargableNode, SargableNode, SubstituteMerge>);
+ registerRewrite(LogicalRewriteType::FilterSubstitute,
+ &LogicalRewriter::bindSingleNode<FilterNode, SubstituteConvert>);
+ registerRewrite(LogicalRewriteType::EvaluationSubstitute,
+ &LogicalRewriter::bindSingleNode<EvaluationNode, SubstituteConvert>);
+
+ registerRewrite(LogicalRewriteType::FilterValueScanPropagate,
+ &LogicalRewriter::bindAboveBelow<FilterNode, ValueScanNode, SubstituteReorder>);
+ registerRewrite(
+ LogicalRewriteType::EvaluationValueScanPropagate,
+ &LogicalRewriter::bindAboveBelow<EvaluationNode, ValueScanNode, SubstituteReorder>);
+ registerRewrite(
+ LogicalRewriteType::SargableValueScanPropagate,
+ &LogicalRewriter::bindAboveBelow<SargableNode, ValueScanNode, SubstituteReorder>);
+ registerRewrite(
+ LogicalRewriteType::CollationValueScanPropagate,
+ &LogicalRewriter::bindAboveBelow<CollationNode, ValueScanNode, SubstituteReorder>);
+ registerRewrite(
+ LogicalRewriteType::LimitSkipValueScanPropagate,
+ &LogicalRewriter::bindAboveBelow<LimitSkipNode, ValueScanNode, SubstituteReorder>);
+ registerRewrite(
+ LogicalRewriteType::ExchangeValueScanPropagate,
+ &LogicalRewriter::bindAboveBelow<ExchangeNode, ValueScanNode, SubstituteReorder>);
+
+ registerRewrite(LogicalRewriteType::GroupByExplore,
+ &LogicalRewriter::bindSingleNode<GroupByNode, ExploreConvert>);
+ registerRewrite(LogicalRewriteType::SargableSplit,
+ &LogicalRewriter::bindSingleNode<SargableNode, ExploreConvert>);
+
+ registerRewrite(LogicalRewriteType::FilterRIDIntersectReorder,
+ &LogicalRewriter::bindAboveBelow<FilterNode, RIDIntersectNode, ExploreReorder>);
+ registerRewrite(
+ LogicalRewriteType::EvaluationRIDIntersectReorder,
+ &LogicalRewriter::bindAboveBelow<EvaluationNode, RIDIntersectNode, ExploreReorder>);
+}
+
+bool LogicalRewriter::rewriteToFixPoint() {
+ int iterationCount = 0;
+
+ while (!_groupsPending.empty()) {
+ iterationCount++;
+ if (_memo.getDebugInfo().exceedsIterationLimit(iterationCount)) {
+ // Iteration limit exceeded.
+ return false;
+ }
+
+ const GroupIdType groupId = *_groupsPending.begin();
+ rewriteGroup(groupId);
+ _groupsPending.erase(groupId);
+ }
+
+ return true;
+}
+
+void LogicalRewriter::rewriteGroup(const GroupIdType groupId) {
+ auto& queue = _memo.getGroup(groupId)._logicalRewriteQueue;
+ while (!queue.empty()) {
+ LogicalRewriteEntry rewriteEntry = std::move(*queue.top());
+        // TODO: check whether rewriteEntry differs from the previous one (remove duplicates).
+ queue.pop();
+
+ _rewriteMap.at(rewriteEntry._type)(this, rewriteEntry._nodeId);
+ }
+}
+
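+/**
+ * Editorial note: a node can participate in an above-below rewrite in two ways: as the parent,
+ * binding against the nodes of its child group, or as the child, binding against the parents
+ * which use its group as input. Both directions are attempted so that the rewrite fires
+ * regardless of the order in which the nodes were inserted into the memo.
+ */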
+template <class AboveType, class BelowType, template <class, class> class R>
+void LogicalRewriter::bindAboveBelow(const MemoLogicalNodeId nodeMemoId) {
+ // Get a reference to the node instead of the node itself.
+ // Rewrites insert into the memo and can move it.
+ ABT::reference_type node = _memo.getNode(nodeMemoId);
+ const GroupIdType currentGroupId = nodeMemoId._groupId;
+
+ if (node.is<AboveType>()) {
+ // Try to bind as parent.
+ const GroupIdType targetGroupId = node.cast<AboveType>()
+ ->getChild()
+ .template cast<MemoLogicalDelegatorNode>()
+ ->getGroupId();
+
+ for (size_t i = 0; i < _memo.getGroup(targetGroupId)._logicalNodes.size(); i++) {
+ const MemoLogicalNodeId targetNodeId{targetGroupId, i};
+ auto targetNode = _memo.getNode(targetNodeId);
+ if (targetNode.is<BelowType>()) {
+ RewriteContext ctx(*this, nodeMemoId, targetNodeId);
+ R<AboveType, BelowType>()(node, targetNode, ctx);
+ if (ctx.hasSubstituted()) {
+ return;
+ }
+ }
+ }
+ }
+
+ if (node.is<BelowType>()) {
+ // Try to bind as child.
+ NodeIdSet usageNodeIdSet;
+ {
+ const auto& inputGroupsToNodeId = _memo.getInputGroupsToNodeIdMap();
+ auto it = inputGroupsToNodeId.find({currentGroupId});
+ if (it != inputGroupsToNodeId.cend()) {
+ usageNodeIdSet = it->second;
+ }
+ }
+
+ for (const MemoLogicalNodeId& parentNodeId : usageNodeIdSet) {
+ auto targetNode = _memo.getNode(parentNodeId);
+ if (targetNode.is<AboveType>()) {
+ uassert(6624047,
+ "Parent child groupId mismatch (usage map index incorrect?)",
+ targetNode.cast<AboveType>()
+ ->getChild()
+ .template cast<MemoLogicalDelegatorNode>()
+ ->getGroupId() == currentGroupId);
+
+ RewriteContext ctx(*this, parentNodeId, nodeMemoId);
+ R<AboveType, BelowType>()(targetNode, node, ctx);
+ if (ctx.hasSubstituted()) {
+ return;
+ }
+ }
+ }
+ }
+}
+
+template <class Type, template <class> class R>
+void LogicalRewriter::bindSingleNode(const MemoLogicalNodeId nodeMemoId) {
+ // Get a reference to the node instead of the node itself.
+ // Rewrites insert into the memo and can move it.
+ ABT::reference_type node = _memo.getNode(nodeMemoId);
+ if (node.is<Type>()) {
+ RewriteContext ctx(*this, nodeMemoId);
+ R<Type>()(node, ctx);
+ }
+}
+
+const LogicalRewriter::RewriteSet& LogicalRewriter::getExplorationSet() {
+ return _explorationSet;
+}
+
+const LogicalRewriter::RewriteSet& LogicalRewriter::getSubstitutionSet() {
+ return _substitutionSet;
+}
+
+} // namespace mongo::optimizer::cascades
diff --git a/src/mongo/db/query/optimizer/cascades/logical_rewriter.h b/src/mongo/db/query/optimizer/cascades/logical_rewriter.h
new file mode 100644
index 00000000000..642fdae1842
--- /dev/null
+++ b/src/mongo/db/query/optimizer/cascades/logical_rewriter.h
@@ -0,0 +1,132 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include <queue>
+
+#include "mongo/db/query/optimizer/cascades/logical_rewriter_rules.h"
+#include "mongo/db/query/optimizer/cascades/memo.h"
+#include "mongo/db/query/optimizer/utils/utils.h"
+
+namespace mongo::optimizer::cascades {
+
+class LogicalRewriter {
+ friend class RewriteContext;
+
+public:
+ /**
+ * Maximum size of PartialSchemaRequirements for a SargableNode.
+ * This limits the number of splits for index intersection.
+ */
+ static constexpr size_t kMaxPartialSchemaReqCount = 10;
+
+    /**
+     * How many times we are allowed to split a sargable node to facilitate index intersection.
+     * Results in at most 2^N index intersections.
+     */
+ static constexpr size_t kMaxSargableNodeSplitCount = 2;
+
+    /**
+     * Map of rewrite type to rewrite priority.
+     */
+ using RewriteSet = opt::unordered_map<LogicalRewriteType, double>;
+
+ LogicalRewriter(Memo& memo, PrefixId& prefixId, RewriteSet rewriteSet);
+
+ LogicalRewriter() = delete;
+ LogicalRewriter(const LogicalRewriter& other) = delete;
+ LogicalRewriter(LogicalRewriter&& other) = default;
+
+ GroupIdType addRootNode(const ABT& node);
+ std::pair<GroupIdType, NodeIdSet> addNode(const ABT& node,
+ GroupIdType targetGroupId,
+ bool addExistingNodeWithNewChild);
+ void clearGroup(GroupIdType groupId);
+
+ /**
+ * Performs logical rewrites across all groups until a fix point is reached.
+ * Use this method to perform "standalone" rewrites.
+ */
+ bool rewriteToFixPoint();
+
+ /**
+ * Performs rewrites only for a particular group. Use this method to perform rewrites driven by
+ * top-down optimization.
+ */
+ void rewriteGroup(GroupIdType groupId);
+
+ static const RewriteSet& getExplorationSet();
+ static const RewriteSet& getSubstitutionSet();
+
+private:
+ using RewriteFn =
+ std::function<void(LogicalRewriter* rewriter, const MemoLogicalNodeId nodeId)>;
+ using RewriteFnMap = opt::unordered_map<LogicalRewriteType, RewriteFn>;
+
+ /**
+ * Attempts to perform a reordering rewrite specified by the R template argument.
+ */
+ template <class AboveType, class BelowType, template <class, class> class R>
+ void bindAboveBelow(MemoLogicalNodeId nodeMemoId);
+
+ /**
+ * Attempts to perform a simple rewrite specified by the R template argument.
+ */
+ template <class Type, template <class> class R>
+ void bindSingleNode(MemoLogicalNodeId nodeMemoId);
+
+ void registerRewrite(LogicalRewriteType rewriteType, RewriteFn fn);
+ void initializeRewrites();
+
+ static RewriteSet _explorationSet;
+ static RewriteSet _substitutionSet;
+
+ const RewriteSet _activeRewriteSet;
+
+    // For the standalone logical rewrite phase, keeps track of which groups still have
+    // rewrites pending.
+ std::set<int> _groupsPending;
+
+    // We do not own these:
+ Memo& _memo;
+ PrefixId& _prefixId;
+
+ RewriteFnMap _rewriteMap;
+
+    // Contains the set of top-level index fields for a given scanDef. For example, the index
+    // field "a.b" is encoded as "a". This is used to constrain the possible splits of a
+    // sargable node.
+ opt::unordered_map<std::string, opt::unordered_set<FieldNameType>> _indexFieldPrefixMap;
+
+ // Track number of times a SargableNode at a given position in the memo has been split.
+ opt::unordered_map<MemoLogicalNodeId, size_t, NodeIdHash> _sargableSplitCountMap;
+};
+
+} // namespace mongo::optimizer::cascades
diff --git a/src/mongo/db/query/optimizer/cascades/logical_rewriter_rules.h b/src/mongo/db/query/optimizer/cascades/logical_rewriter_rules.h
new file mode 100644
index 00000000000..ab562c685c8
--- /dev/null
+++ b/src/mongo/db/query/optimizer/cascades/logical_rewriter_rules.h
@@ -0,0 +1,89 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include "mongo/db/query/optimizer/defs.h"
+
+namespace mongo::optimizer::cascades {
+
+#define LOGICALREWRITER_NAMES(F) \
+ /* "Linear" reordering rewrites. */ \
+ F(FilterEvaluationReorder) \
+ F(FilterCollationReorder) \
+ F(EvaluationCollationReorder) \
+ F(EvaluationLimitSkipReorder) \
+ \
+ F(FilterGroupByReorder) \
+ F(GroupCollationReorder) \
+ \
+ F(FilterUnwindReorder) \
+ F(EvaluationUnwindReorder) \
+ F(UnwindCollationReorder) \
+ \
+ F(FilterExchangeReorder) \
+ F(ExchangeEvaluationReorder) \
+ \
+ F(FilterUnionReorder) \
+ \
+ /* Merging rewrites. */ \
+ F(CollationMerge) \
+ F(LimitSkipMerge) \
+ F(SargableMerge) \
+ \
+    /* Local-global optimization for GroupBy. */                                                  \
+ F(GroupByExplore) \
+ \
+ F(SargableFilterReorder) \
+ F(SargableEvaluationReorder) \
+ \
+    /* Propagate ValueScan nodes. */                                                              \
+ F(FilterValueScanPropagate) \
+ F(EvaluationValueScanPropagate) \
+ F(SargableValueScanPropagate) \
+ F(CollationValueScanPropagate) \
+ F(LimitSkipValueScanPropagate) \
+ F(ExchangeValueScanPropagate) \
+ \
+ F(LimitSkipSubstitute) \
+ \
+    /* Convert filter and evaluation nodes into sargable nodes. */                                \
+ F(FilterSubstitute) \
+ F(EvaluationSubstitute) \
+ F(SargableSplit) \
+ F(FilterRIDIntersectReorder) \
+ F(EvaluationRIDIntersectReorder)
+
+MAKE_PRINTABLE_ENUM(LogicalRewriteType, LOGICALREWRITER_NAMES);
+MAKE_PRINTABLE_ENUM_STRING_ARRAY(LogicalRewriterTypeEnum,
+ LogicalRewriterType,
+ LOGICALREWRITER_NAMES);
+#undef LOGICALREWRITER_NAMES
+
+} // namespace mongo::optimizer::cascades
diff --git a/src/mongo/db/query/optimizer/cascades/memo.cpp b/src/mongo/db/query/optimizer/cascades/memo.cpp
new file mode 100644
index 00000000000..d39dc7ccbc5
--- /dev/null
+++ b/src/mongo/db/query/optimizer/cascades/memo.cpp
@@ -0,0 +1,794 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include <set>
+
+#include "mongo/db/query/optimizer/cascades/memo.h"
+#include "mongo/db/query/optimizer/explain.h"
+#include "mongo/db/query/optimizer/utils/abt_hash.h"
+#include "mongo/db/query/optimizer/utils/utils.h"
+
+namespace mongo::optimizer::cascades {
+
+size_t MemoNodeRefHash::operator()(const ABT::reference_type& nodeRef) const {
+ // Compare delegator as well.
+ return ABTHashGenerator::generate(nodeRef);
+}
+
+bool MemoNodeRefCompare::operator()(const ABT::reference_type& left,
+ const ABT::reference_type& right) const {
+ // Deep comparison.
+ return left.follow() == right.follow();
+}
+
+ABT::reference_type OrderPreservingABTSet::at(const size_t index) const {
+ return _vector.at(index).ref();
+}
+
+std::pair<size_t, bool> OrderPreservingABTSet::emplace_back(ABT node) {
+ auto [index, found] = find(node.ref());
+ if (found) {
+ return {index, false};
+ }
+
+ const size_t id = _vector.size();
+ _vector.emplace_back(std::move(node));
+ _map.emplace(_vector.back().ref(), id);
+ return {id, true};
+}
+
+std::pair<size_t, bool> OrderPreservingABTSet::find(ABT::reference_type node) const {
+ auto it = _map.find(node);
+ if (it == _map.end()) {
+ return {0, false};
+ }
+
+ return {it->second, true};
+}
+
+void OrderPreservingABTSet::clear() {
+ _map.clear();
+ _vector.clear();
+}
+
+size_t OrderPreservingABTSet::size() const {
+ return _vector.size();
+}
+
+const ABTVector& OrderPreservingABTSet::getVector() const {
+ return _vector;
+}
+
+PhysRewriteEntry::PhysRewriteEntry(const double priority,
+ ABT node,
+ std::vector<std::pair<ABT*, properties::PhysProps>> childProps,
+ NodeCEMap nodeCEMap)
+ : _priority(priority),
+ _node(std::move(node)),
+ _childProps(std::move(childProps)),
+ _nodeCEMap(std::move(nodeCEMap)) {}
+
+PhysOptimizationResult::PhysOptimizationResult()
+ : PhysOptimizationResult(0, {}, CostType::kInfinity) {}
+
+PhysOptimizationResult::PhysOptimizationResult(size_t index,
+ properties::PhysProps physProps,
+ CostType costLimit)
+ : _index(index),
+ _physProps(std::move(physProps)),
+ _costLimit(std::move(costLimit)),
+ _nodeInfo(),
+ _lastImplementedNodePos(0),
+ _queue() {}
+
+bool PhysOptimizationResult::isOptimized() const {
+ return _queue.empty();
+}
+
+void PhysOptimizationResult::raiseCostLimit(CostType costLimit) {
+ _costLimit = costLimit;
+ // Allow for re-optimization under the higher cost limit.
+ _lastImplementedNodePos = 0;
+}
+
+bool PhysRewriteEntryComparator::operator()(const std::unique_ptr<PhysRewriteEntry>& x,
+ const std::unique_ptr<PhysRewriteEntry>& y) const {
+    // A lower numerical priority compares as greater under this ordering, so the entry with
+    // the lowest numerical priority is de-queued first.
+}
+
+static ABT createBinderMap(const properties::LogicalProps& logicalProperties) {
+ const properties::ProjectionAvailability& projSet =
+ properties::getPropertyConst<properties::ProjectionAvailability>(logicalProperties);
+
+ ProjectionNameVector projectionVector;
+ ABTVector expressions;
+
+ ProjectionNameOrderedSet ordered = convertToOrderedSet(projSet.getProjections());
+ for (const ProjectionName& projection : ordered) {
+ projectionVector.push_back(projection);
+ expressions.emplace_back(make<Source>());
+ }
+
+ return make<ExpressionBinder>(std::move(projectionVector), std::move(expressions));
+}
+
+Group::Group(ProjectionNameSet projections)
+ : _logicalNodes(),
+ _logicalProperties(
+ properties::makeLogicalProps(properties::ProjectionAvailability(std::move(projections)))),
+ _binder(createBinderMap(_logicalProperties)),
+ _logicalRewriteQueue(),
+ _physicalNodes() {}
+
+const ExpressionBinder& Group::binder() const {
+ auto pointer = _binder.cast<ExpressionBinder>();
+ uassert(6624048, "Invalid binder type", pointer);
+
+ return *pointer;
+}
+
+PhysOptimizationResult& PhysNodes::addOptimizationResult(properties::PhysProps properties,
+ CostType costLimit) {
+ const size_t index = _physicalNodes.size();
+ _physPropsToPhysNodeMap.emplace(properties, index);
+ return *_physicalNodes.emplace_back(std::make_unique<PhysOptimizationResult>(
+ index, std::move(properties), std::move(costLimit)));
+}
+
+const PhysOptimizationResult& PhysNodes::at(const size_t index) const {
+ return *_physicalNodes.at(index);
+}
+
+PhysOptimizationResult& PhysNodes::at(const size_t index) {
+ return *_physicalNodes.at(index);
+}
+
+std::pair<size_t, bool> PhysNodes::find(const properties::PhysProps& props) const {
+ auto it = _physPropsToPhysNodeMap.find(props);
+ if (it == _physPropsToPhysNodeMap.cend()) {
+ return {0, false};
+ }
+ return {it->second, true};
+}
+
+const PhysNodes::PhysNodeVector& PhysNodes::getNodes() const {
+ return _physicalNodes;
+}
+
+size_t PhysNodes::PhysPropsHasher::operator()(const properties::PhysProps& physProps) const {
+ return ABTHashGenerator::generateForPhysProps(physProps);
+}
+
+class MemoIntegrator {
+public:
+ explicit MemoIntegrator(Memo& memo,
+ Memo::NodeTargetGroupMap targetGroupMap,
+ NodeIdSet& insertedNodeIds,
+ const bool addExistingNodeWithNewChild)
+ : _memo(memo),
+ _insertedNodeIds(insertedNodeIds),
+ _targetGroupMap(std::move(targetGroupMap)),
+ _addExistingNodeWithNewChild(addExistingNodeWithNewChild) {}
+
+ /**
+ * Nodes
+ */
+ void prepare(const ABT& n, const ScanNode& node, const VariableEnvironment& /*env*/) {
+ // noop
+ }
+
+ GroupIdType transport(const ABT& n,
+ const ScanNode& node,
+ const VariableEnvironment& env,
+ GroupIdType /*binder*/) {
+ return addNodes(n, node, n, env, {});
+ }
+
+ void prepare(const ABT& n, const ValueScanNode& node, const VariableEnvironment& /*env*/) {
+ // noop
+ }
+
+ GroupIdType transport(const ABT& n,
+ const ValueScanNode& node,
+ const VariableEnvironment& env,
+ GroupIdType /*binder*/) {
+ return addNodes(n, node, n, env, {});
+ }
+
+ void prepare(const ABT& n,
+ const MemoLogicalDelegatorNode& node,
+ const VariableEnvironment& /*env*/) {
+ // noop
+ }
+
+ GroupIdType transport(const ABT& /*n*/,
+ const MemoLogicalDelegatorNode& node,
+ const VariableEnvironment& /*env*/) {
+ return node.getGroupId();
+ }
+
+ void prepare(const ABT& n, const FilterNode& node, const VariableEnvironment& /*env*/) {
+ updateTargetGroupMapUnary(n, node);
+ }
+
+ GroupIdType transport(const ABT& n,
+ const FilterNode& node,
+ const VariableEnvironment& env,
+ GroupIdType child,
+ GroupIdType /*binder*/) {
+ return addNode(n, node, env, child);
+ }
+
+ void prepare(const ABT& n, const EvaluationNode& node, const VariableEnvironment& /*env*/) {
+ updateTargetGroupMapUnary(n, node);
+ }
+
+ GroupIdType transport(const ABT& n,
+ const EvaluationNode& node,
+ const VariableEnvironment& env,
+ GroupIdType child,
+ GroupIdType /*binder*/) {
+ return addNode(n, node, env, child);
+ }
+
+ void prepare(const ABT& n, const SargableNode& node, const VariableEnvironment& /*env*/) {
+ updateTargetGroupMapUnary(n, node);
+ }
+
+ GroupIdType transport(const ABT& n,
+ const SargableNode& node,
+ const VariableEnvironment& env,
+ GroupIdType child,
+ GroupIdType /*binder*/,
+ GroupIdType /*references*/) {
+ return addNode(n, node, env, child);
+ }
+
+ void prepare(const ABT& n, const RIDIntersectNode& node, const VariableEnvironment& /*env*/) {
+ // noop.
+ }
+
+ GroupIdType transport(const ABT& n,
+ const RIDIntersectNode& node,
+ const VariableEnvironment& env,
+ GroupIdType leftChild,
+ GroupIdType rightChild) {
+ return addNodes(n, node, env, leftChild, rightChild);
+ }
+
+ void prepare(const ABT& n, const BinaryJoinNode& node, const VariableEnvironment& /*env*/) {
+ updateTargetGroupMapBinary(n, node);
+ }
+
+ GroupIdType transport(const ABT& n,
+ const BinaryJoinNode& node,
+ const VariableEnvironment& env,
+ GroupIdType leftChild,
+ GroupIdType rightChild,
+ GroupIdType /*filter*/) {
+ return addNodes(n, node, env, leftChild, rightChild);
+ }
+
+ void prepare(const ABT& n, const UnionNode& node, const VariableEnvironment& /*env*/) {
+ updateTargetGroupMapNary(n, node);
+ }
+
+ GroupIdType transport(const ABT& n,
+ const UnionNode& node,
+ const VariableEnvironment& env,
+ Memo::GroupIdVector children,
+ GroupIdType /*binder*/,
+ GroupIdType /*refs*/) {
+ return addNodes(n, node, env, std::move(children));
+ }
+
+ void prepare(const ABT& n, const GroupByNode& node, const VariableEnvironment& /*env*/) {
+ updateTargetGroupMapUnary(n, node);
+ }
+
+ GroupIdType transport(const ABT& n,
+ const GroupByNode& node,
+ const VariableEnvironment& env,
+ GroupIdType child,
+ GroupIdType /*binderAgg*/,
+ GroupIdType /*refsAgg*/,
+ GroupIdType /*binderGb*/,
+ GroupIdType /*refsGb*/) {
+ return addNode(n, node, env, child);
+ }
+
+ void prepare(const ABT& n, const UnwindNode& node, const VariableEnvironment& /*env*/) {
+ updateTargetGroupMapUnary(n, node);
+ }
+
+ GroupIdType transport(const ABT& n,
+ const UnwindNode& node,
+ const VariableEnvironment& env,
+ GroupIdType child,
+ GroupIdType /*binder*/,
+ GroupIdType /*refs*/) {
+ return addNode(n, node, env, child);
+ }
+
+ void prepare(const ABT& n, const CollationNode& node, const VariableEnvironment& /*env*/) {
+ updateTargetGroupMapUnary(n, node);
+ }
+
+ GroupIdType transport(const ABT& n,
+ const CollationNode& node,
+ const VariableEnvironment& env,
+ GroupIdType child,
+ GroupIdType /*refs*/) {
+ return addNode(n, node, env, child);
+ }
+
+ void prepare(const ABT& n, const LimitSkipNode& node, const VariableEnvironment& /*env*/) {
+ updateTargetGroupMapUnary(n, node);
+ }
+
+ GroupIdType transport(const ABT& n,
+ const LimitSkipNode& node,
+ const VariableEnvironment& env,
+ GroupIdType child) {
+ return addNode(n, node, env, child);
+ }
+
+ void prepare(const ABT& n, const ExchangeNode& node, const VariableEnvironment& /*env*/) {
+ updateTargetGroupMapUnary(n, node);
+ }
+
+ GroupIdType transport(const ABT& n,
+ const ExchangeNode& node,
+ const VariableEnvironment& env,
+ GroupIdType child,
+ GroupIdType /*refs*/) {
+ return addNode(n, node, env, child);
+ }
+
+ void prepare(const ABT& n, const RootNode& node, const VariableEnvironment& /*env*/) {
+ updateTargetGroupMapUnary(n, node);
+ }
+
+ GroupIdType transport(const ABT& n,
+ const RootNode& node,
+ const VariableEnvironment& env,
+ GroupIdType child,
+ GroupIdType /*refs*/) {
+ return addNode(n, node, env, child);
+ }
+
+ /**
+ * Other ABT types.
+ */
+ template <typename T, typename... Ts>
+ GroupIdType transport(const ABT& /*n*/,
+ const T& /*node*/,
+ const VariableEnvironment& /*env*/,
+ Ts&&...) {
+ static_assert(!canBeLogicalNode<T>(), "Logical node must implement its transport.");
+ return -1;
+ }
+
+ template <typename T, typename... Ts>
+ void prepare(const ABT& n, const T& /*node*/, const VariableEnvironment& /*env*/) {
+ static_assert(!canBeLogicalNode<T>(), "Logical node must implement its prepare.");
+ }
+
+ GroupIdType integrate(const ABT& n) {
+ return algebra::transport<true>(n, *this, VariableEnvironment::build(n, &_memo));
+ }
+
+private:
+ GroupIdType addNodes(const ABT& n,
+ const Node& node,
+ ABT forMemo,
+ const VariableEnvironment& env,
+ Memo::GroupIdVector childGroupIds) {
+ auto it = _targetGroupMap.find(n.ref());
+ const GroupIdType targetGroupId = (it == _targetGroupMap.cend()) ? -1 : it->second;
+ const auto result = _memo.addNode(std::move(childGroupIds),
+ env.getProjections(&node),
+ targetGroupId,
+ _insertedNodeIds,
+ std::move(forMemo));
+ return result._groupId;
+ }
+
+ template <class T, typename... Args>
+ GroupIdType addNodes(const ABT& n,
+ const T& node,
+ const VariableEnvironment& env,
+ Memo::GroupIdVector childGroupIds) {
+ ABT forMemo = n;
+ auto& childNodes = forMemo.template cast<T>()->nodes();
+ for (size_t i = 0; i < childNodes.size(); i++) {
+ const GroupIdType childGroupId = childGroupIds.at(i);
+ uassert(6624121, "Invalid child group", childGroupId >= 0);
+ childNodes.at(i) = make<MemoLogicalDelegatorNode>(childGroupId);
+ }
+
+ return addNodes(n, node, std::move(forMemo), env, std::move(childGroupIds));
+ }
+
+ template <class T>
+ GroupIdType addNode(const ABT& n,
+ const T& node,
+ const VariableEnvironment& env,
+ GroupIdType childGroupId) {
+ ABT forMemo = n;
+ uassert(6624122, "Invalid child group", childGroupId >= 0);
+ forMemo.cast<T>()->getChild() = make<MemoLogicalDelegatorNode>(childGroupId);
+ return addNodes(n, node, std::move(forMemo), env, {childGroupId});
+ }
+
+ template <class T>
+ GroupIdType addNodes(const ABT& n,
+ const T& node,
+ const VariableEnvironment& env,
+ GroupIdType leftGroupId,
+ GroupIdType rightGroupId) {
+ ABT forMemo = n;
+ uassert(6624123, "Invalid left child group", leftGroupId >= 0);
+ uassert(6624124, "Invalid right child group", rightGroupId >= 0);
+
+ forMemo.cast<T>()->getLeftChild() = make<MemoLogicalDelegatorNode>(leftGroupId);
+ forMemo.cast<T>()->getRightChild() = make<MemoLogicalDelegatorNode>(rightGroupId);
+ return addNodes(n, node, std::move(forMemo), env, {leftGroupId, rightGroupId});
+ }
+
+ template <class T>
+ ABT::reference_type findExistingNodeFromTargetGroupMap(const ABT& n, const T& node) {
+ auto it = _targetGroupMap.find(n.ref());
+ if (it == _targetGroupMap.cend()) {
+ return nullptr;
+ }
+ const auto [index, found] = _memo.findNodeInGroup(it->second, n.ref());
+ if (!found) {
+ return nullptr;
+ }
+
+ ABT::reference_type result = _memo.getNode({it->second, index});
+ uassert(6624049, "Node type in memo does not match target type", result.is<T>());
+ return result;
+ }
+
+ void updateTargetGroupRefs(
+ const std::vector<std::pair<ABT::reference_type, GroupIdType>>& childGroups) {
+ for (auto [childRef, targetGroupId] : childGroups) {
+ auto it = _targetGroupMap.find(childRef);
+ if (it == _targetGroupMap.cend()) {
+ _targetGroupMap.emplace(childRef, targetGroupId);
+ } else if (it->second != targetGroupId) {
+ uasserted(6624050, "Incompatible target groups for parent and child");
+ }
+ }
+ }
+
+ template <class T>
+ void updateTargetGroupMapUnary(const ABT& n, const T& node) {
+ if (_addExistingNodeWithNewChild) {
+ return;
+ }
+
+ ABT::reference_type existing = findExistingNodeFromTargetGroupMap(n, node);
+ if (!existing.empty()) {
+ const GroupIdType targetGroupId = existing.cast<T>()
+ ->getChild()
+ .template cast<MemoLogicalDelegatorNode>()
+ ->getGroupId();
+ updateTargetGroupRefs({{node.getChild().ref(), targetGroupId}});
+ }
+ }
+
+ template <class T>
+ void updateTargetGroupMapNary(const ABT& n, const T& node) {
+ ABT::reference_type existing = findExistingNodeFromTargetGroupMap(n, node);
+ if (!existing.empty()) {
+ const ABTVector& existingChildren = existing.cast<T>()->nodes();
+ const ABTVector& targetChildren = node.nodes();
+ uassert(6624051,
+ "Different number of children between existing and target node",
+ existingChildren.size() == targetChildren.size());
+
+ std::vector<std::pair<ABT::reference_type, GroupIdType>> childGroups;
+ for (size_t i = 0; i < existingChildren.size(); i++) {
+ const ABT& existingChild = existingChildren.at(i);
+ const ABT& targetChild = targetChildren.at(i);
+ childGroups.emplace_back(
+ targetChild.ref(),
+ existingChild.cast<MemoLogicalDelegatorNode>()->getGroupId());
+ }
+ updateTargetGroupRefs(childGroups);
+ }
+ }
+
+ template <class T>
+ void updateTargetGroupMapBinary(const ABT& n, const T& node) {
+ ABT::reference_type existing = findExistingNodeFromTargetGroupMap(n, node);
+ if (existing.empty()) {
+ return;
+ }
+
+ const T& existingNode = *existing.cast<T>();
+ const GroupIdType leftGroupId =
+ existingNode.getLeftChild().template cast<MemoLogicalDelegatorNode>()->getGroupId();
+ const GroupIdType rightGroupId =
+ existingNode.getRightChild().template cast<MemoLogicalDelegatorNode>()->getGroupId();
+ updateTargetGroupRefs(
+ {{node.getLeftChild().ref(), leftGroupId}, {node.getRightChild().ref(), rightGroupId}});
+ }
+
+ /**
+ * We do not own any of these.
+ */
+ Memo& _memo;
+ NodeIdSet& _insertedNodeIds;
+
+ /**
+ * We own this.
+ */
+ Memo::NodeTargetGroupMap _targetGroupMap;
+
+    // If set, we skip updating the target group based on existing nodes. In practical terms,
+    // we would not assume that if F(x) = F(y) then x = y. This is currently used in
+    // conjunction with the $elemMatch rewrite (PathTraverse over PathCompose).
+ bool _addExistingNodeWithNewChild;
+};
+
+size_t Memo::GroupIdVectorHash::operator()(const Memo::GroupIdVector& v) const {
+ size_t result = 17;
+ for (const GroupIdType id : v) {
+ updateHash(result, std::hash<GroupIdType>()(id));
+ }
+ return result;
+}
+
+size_t Memo::NodeTargetGroupHash::operator()(const ABT::reference_type& nodeRef) const {
+ return std::hash<const Node*>()(nodeRef.cast<Node>());
+}
+
+Memo::Memo(DebugInfo debugInfo,
+ const Metadata& metadata,
+ std::unique_ptr<LogicalPropsInterface> logicalPropsDerivation,
+ std::unique_ptr<CEInterface> ceDerivation)
+ : _groups(),
+ _inputGroupsToNodeIdMap(),
+ _nodeIdToInputGroupsMap(),
+ _metadata(metadata),
+ _logicalPropsDerivation(std::move(logicalPropsDerivation)),
+ _ceDerivation(std::move(ceDerivation)),
+ _debugInfo(std::move(debugInfo)),
+ _stats() {
+ uassert(6624125, "Empty logical properties derivation", _logicalPropsDerivation.get());
+ uassert(6624126, "Empty CE derivation", _ceDerivation.get());
+}
+
+const Group& Memo::getGroup(const GroupIdType groupId) const {
+ return *_groups.at(groupId);
+}
+
+Group& Memo::getGroup(const GroupIdType groupId) {
+ return *_groups.at(groupId);
+}
+
+std::pair<size_t, bool> Memo::findNodeInGroup(GroupIdType groupId, ABT::reference_type node) const {
+ return getGroup(groupId)._logicalNodes.find(node);
+}
+
+GroupIdType Memo::addGroup(ProjectionNameSet projections) {
+ _groups.emplace_back(std::make_unique<Group>(std::move(projections)));
+ return _groups.size() - 1;
+}
+
+std::pair<MemoLogicalNodeId, bool> Memo::addNode(GroupIdType groupId, ABT n) {
+ uassert(6624052, "Attempting to insert a physical node", !n.is<PhysicalNode>());
+ uassert(6624053,
+ "Attempting to insert a logical delegator node",
+ !n.is<MemoLogicalDelegatorNode>());
+
+ OrderPreservingABTSet& nodes = _groups.at(groupId)->_logicalNodes;
+ auto [index, inserted] = nodes.emplace_back(std::move(n));
+ return {{groupId, index}, inserted};
+}
+
+ABT::reference_type Memo::getNode(const MemoLogicalNodeId nodeMemoId) const {
+ return getGroup(nodeMemoId._groupId)._logicalNodes.at(nodeMemoId._index);
+}
+
+std::pair<MemoLogicalNodeId, bool> Memo::findNode(const GroupIdVector& groups, const ABT& node) {
+ const auto it = _inputGroupsToNodeIdMap.find(groups);
+ if (it != _inputGroupsToNodeIdMap.cend()) {
+ for (const MemoLogicalNodeId& nodeMemoId : it->second) {
+ if (getNode(nodeMemoId) == node) {
+ return {nodeMemoId, true};
+ }
+ }
+ }
+ return {{0, 0}, false};
+}
+
+void Memo::estimateCE(const GroupIdType groupId) {
+    // If inserted into a new group, derive logical properties and cardinality estimation for
+    // the new group.
+ Group& group = getGroup(groupId);
+ properties::LogicalProps& props = group._logicalProperties;
+
+ const ABT::reference_type nodeRef = group._logicalNodes.at(0);
+ properties::LogicalProps logicalProps =
+ _logicalPropsDerivation->deriveProps(_metadata, nodeRef, nullptr, this, groupId);
+ props.merge(logicalProps);
+
+ const CEType estimate = _ceDerivation->deriveCE(*this, props, nodeRef);
+ auto ceProp = properties::CardinalityEstimate(estimate);
+
+ if (auto sargablePtr = nodeRef.cast<SargableNode>(); sargablePtr != nullptr) {
+ auto& partialSchemaKeyCEMap = ceProp.getPartialSchemaKeyCEMap();
+ for (const auto& [key, req] : sargablePtr->getReqMap()) {
+ ABT singularReq = make<SargableNode>(PartialSchemaRequirements{{key, req}},
+ CandidateIndexMap{},
+ sargablePtr->getTarget(),
+ sargablePtr->getChild());
+ const CEType singularEst = _ceDerivation->deriveCE(*this, props, singularReq.ref());
+ partialSchemaKeyCEMap.emplace(key, singularEst);
+ }
+ }
+
+ properties::setPropertyOverwrite(props, std::move(ceProp));
+ if (_debugInfo.hasDebugLevel(2)) {
+ std::cout << "Group " << groupId << ": "
+ << ExplainGenerator::explainLogicalProps("Logical properties", props);
+ }
+}
+
+MemoLogicalNodeId Memo::addNode(GroupIdVector groupVector,
+ ProjectionNameSet projections,
+ const GroupIdType targetGroupId,
+ NodeIdSet& insertedNodeIds,
+ ABT n) {
+ for (const GroupIdType groupId : groupVector) {
+ // Invalid tree: node is its own child.
+ uassert(6624127, "Target group appears inside group vector", groupId != targetGroupId);
+ }
+
+ auto [existingId, foundNode] = findNode(groupVector, n);
+
+ if (foundNode) {
+ uassert(6624054,
+ "Found node outside target group",
+ targetGroupId < 0 || targetGroupId == existingId._groupId);
+ return existingId;
+ }
+
+ const bool noTargetGroup = targetGroupId < 0;
+ // Only for debugging.
+ ProjectionNameSet projectionsCopy;
+ if (!noTargetGroup && _debugInfo.isDebugMode()) {
+ projectionsCopy = projections;
+ }
+
+ // Current node is not in the memo. Insert unchanged.
+ const GroupIdType groupId = noTargetGroup ? addGroup(std::move(projections)) : targetGroupId;
+ auto [newId, inserted] = addNode(groupId, std::move(n));
+ if (inserted || noTargetGroup) {
+ insertedNodeIds.insert(newId);
+ _inputGroupsToNodeIdMap[groupVector].insert(newId);
+ _nodeIdToInputGroupsMap[newId] = groupVector;
+
+ if (noTargetGroup) {
+ estimateCE(groupId);
+ } else if (_debugInfo.isDebugMode()) {
+ const Group& group = getGroup(groupId);
+ // If inserted into an existing group, verify we deliver all expected projections.
+ for (const ProjectionName& groupProjection : group.binder().names()) {
+ uassert(6624055,
+ "Node does not project all specified group projections",
+ projectionsCopy.find(groupProjection) != projectionsCopy.cend());
+ }
+
+ // TODO: possibly verify cardinality estimation
+ }
+ }
+
+ return newId;
+}
+
+GroupIdType Memo::integrate(const ABT& node,
+ NodeTargetGroupMap targetGroupMap,
+ NodeIdSet& insertedNodeIds,
+ const bool addExistingNodeWithNewChild) {
+ _stats._numIntegrations++;
+ MemoIntegrator integrator(
+ *this, std::move(targetGroupMap), insertedNodeIds, addExistingNodeWithNewChild);
+ return integrator.integrate(node);
+}
+
+size_t Memo::getGroupCount() const {
+ return _groups.size();
+}
+
+void Memo::clearLogicalNodes(const GroupIdType groupId) {
+ auto& group = getGroup(groupId);
+ auto& logicalNodes = group._logicalNodes;
+
+ for (size_t index = 0; index < logicalNodes.size(); index++) {
+ const MemoLogicalNodeId nodeId{groupId, index};
+ const auto& groupVector = _nodeIdToInputGroupsMap.at(nodeId);
+ _inputGroupsToNodeIdMap.at(groupVector).erase(nodeId);
+ _nodeIdToInputGroupsMap.erase(nodeId);
+ }
+
+ logicalNodes.clear();
+ group._logicalRewriteQueue = {};
+}
+
+const Memo::InputGroupsToNodeIdMap& Memo::getInputGroupsToNodeIdMap() const {
+ return _inputGroupsToNodeIdMap;
+}
+
+const DebugInfo& Memo::getDebugInfo() const {
+ return _debugInfo;
+}
+
+void Memo::clear() {
+ _stats = {};
+ _groups.clear();
+ _inputGroupsToNodeIdMap.clear();
+ _nodeIdToInputGroupsMap.clear();
+}
+
+const Memo::Stats& Memo::getStats() const {
+ return _stats;
+}
+
+size_t Memo::getLogicalNodeCount() const {
+ size_t result = 0;
+ for (const auto& group : _groups) {
+ result += group->_logicalNodes.size();
+ }
+ return result;
+}
+
+size_t Memo::getPhysicalNodeCount() const {
+ size_t result = 0;
+ for (const auto& group : _groups) {
+ result += group->_physicalNodes.getNodes().size();
+ }
+ return result;
+}
+
+const Metadata& Memo::getMetadata() const {
+ return _metadata;
+}
+
+const CEInterface& Memo::getCEDerivation() const {
+ return *_ceDerivation;
+}
+
+} // namespace mongo::optimizer::cascades
diff --git a/src/mongo/db/query/optimizer/cascades/memo.h b/src/mongo/db/query/optimizer/cascades/memo.h
new file mode 100644
index 00000000000..4d0a943ac56
--- /dev/null
+++ b/src/mongo/db/query/optimizer/cascades/memo.h
@@ -0,0 +1,248 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include <map>
+#include <set>
+#include <unordered_map>
+#include <vector>
+
+#include "mongo/db/query/optimizer/cascades/interfaces.h"
+#include "mongo/db/query/optimizer/cascades/rewrite_queues.h"
+#include "mongo/db/query/optimizer/reference_tracker.h"
+
+namespace mongo::optimizer::cascades {
+
+struct MemoNodeRefHash {
+ size_t operator()(const ABT::reference_type& nodeRef) const;
+};
+
+struct MemoNodeRefCompare {
+ bool operator()(const ABT::reference_type& left, const ABT::reference_type& right) const;
+};
+
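+/**
+ * Editorial note: an insertion-ordered, deduplicating set of ABT nodes. emplace_back() returns
+ * the index of a structurally equal existing node (paired with 'false') instead of inserting a
+ * duplicate.
+ */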
+class OrderPreservingABTSet {
+public:
+ OrderPreservingABTSet() = default;
+ OrderPreservingABTSet(const OrderPreservingABTSet&) = delete;
+ OrderPreservingABTSet(OrderPreservingABTSet&&) = default;
+
+ ABT::reference_type at(size_t index) const;
+ std::pair<size_t, bool> emplace_back(ABT node);
+ std::pair<size_t, bool> find(ABT::reference_type node) const;
+
+ void clear();
+
+ size_t size() const;
+ const ABTVector& getVector() const;
+
+private:
+ opt::unordered_map<ABT::reference_type, size_t, MemoNodeRefHash, MemoNodeRefCompare> _map;
+ ABTVector _vector;
+};
+
+struct PhysNodeInfo {
+ ABT _node;
+
+ // Total cost for the entire subtree.
+ CostType _cost;
+
+ // Operator cost (without including the subtree).
+ CostType _localCost;
+
+ // For display purposes, adjusted cardinality based on physical properties (e.g. Repetition and
+ // Limit-Skip).
+ CEType _adjustedCE;
+};
+
+struct PhysOptimizationResult {
+ PhysOptimizationResult();
+ PhysOptimizationResult(size_t index, properties::PhysProps physProps, CostType costLimit);
+
+ bool isOptimized() const;
+ void raiseCostLimit(CostType costLimit);
+
+ const size_t _index;
+ const properties::PhysProps _physProps;
+
+ CostType _costLimit;
+ // If set, we have successfully optimized.
+ boost::optional<PhysNodeInfo> _nodeInfo;
+ // Rejected physical plans.
+ std::vector<PhysNodeInfo> _rejectedNodeInfo;
+
+    // Index of the last logical node in our group that we have implemented.
+ size_t _lastImplementedNodePos;
+
+ PhysRewriteQueue _queue;
+};
+
+struct PhysNodes {
+ using PhysNodeVector = std::vector<std::unique_ptr<PhysOptimizationResult>>;
+
+ PhysNodes() = default;
+
+ PhysOptimizationResult& addOptimizationResult(properties::PhysProps properties,
+ CostType costLimit);
+
+ const PhysOptimizationResult& at(size_t index) const;
+ PhysOptimizationResult& at(size_t index);
+
+ std::pair<size_t, bool> find(const properties::PhysProps& props) const;
+
+ const PhysNodeVector& getNodes() const;
+
+private:
+ PhysNodeVector _physicalNodes;
+
+ struct PhysPropsHasher {
+ size_t operator()(const properties::PhysProps& physProps) const;
+ };
+
+ // Used to speed up lookups into the winner's circle using physical properties.
+ opt::unordered_map<properties::PhysProps, size_t, PhysPropsHasher> _physPropsToPhysNodeMap;
+};
+
+struct Group {
+ explicit Group(ProjectionNameSet projections);
+
+ Group(const Group&) = delete;
+ Group(Group&&) = default;
+
+ const ExpressionBinder& binder() const;
+
+ // Associated logical nodes.
+ OrderPreservingABTSet _logicalNodes;
+ // Group logical properties.
+ properties::LogicalProps _logicalProperties;
+ ABT _binder;
+
+ LogicalRewriteQueue _logicalRewriteQueue;
+
+ // Best physical plan for given physical properties: aka "Winner's circle".
+ PhysNodes _physicalNodes;
+};
+
+class Memo {
+ friend class PhysicalRewriter;
+
+public:
+ using GroupIdVector = std::vector<GroupIdType>;
+
+ struct Stats {
+ // Number of calls to integrate()
+ size_t _numIntegrations = 0;
+ // Number of recursive physical optimization calls.
+ size_t _physPlanExplorationCount = 0;
+ // Number of checks to winner's circle.
+ size_t _physMemoCheckCount = 0;
+ };
+
+ struct GroupIdVectorHash {
+ size_t operator()(const GroupIdVector& v) const;
+ };
+ using InputGroupsToNodeIdMap = opt::unordered_map<GroupIdVector, NodeIdSet, GroupIdVectorHash>;
+
+ /**
+ * Inverse map.
+ */
+ using NodeIdToInputGroupsMap = opt::unordered_map<MemoLogicalNodeId, GroupIdVector, NodeIdHash>;
+
+ struct NodeTargetGroupHash {
+ size_t operator()(const ABT::reference_type& nodeRef) const;
+ };
+ using NodeTargetGroupMap =
+ opt::unordered_map<ABT::reference_type, GroupIdType, NodeTargetGroupHash>;
+
+ Memo(DebugInfo debugInfo,
+ const Metadata& metadata,
+ std::unique_ptr<LogicalPropsInterface> logicalPropsDerivation,
+ std::unique_ptr<CEInterface> ceDerivation);
+
+ const Group& getGroup(GroupIdType groupId) const;
+ Group& getGroup(GroupIdType groupId);
+ size_t getGroupCount() const;
+
+ std::pair<size_t, bool> findNodeInGroup(GroupIdType groupId, ABT::reference_type node) const;
+
+ ABT::reference_type getNode(MemoLogicalNodeId nodeMemoId) const;
+
+ void estimateCE(GroupIdType groupId);
+
+ MemoLogicalNodeId addNode(GroupIdVector groupVector,
+ ProjectionNameSet projections,
+ GroupIdType targetGroupId,
+ NodeIdSet& insertedNodeIds,
+ ABT n);
+
+ GroupIdType integrate(const ABT& node,
+ NodeTargetGroupMap targetGroupMap,
+ NodeIdSet& insertedNodeIds,
+ bool addExistingNodeWithNewChild = false);
+
+ void clearLogicalNodes(GroupIdType groupId);
+
+ const InputGroupsToNodeIdMap& getInputGroupsToNodeIdMap() const;
+
+ const DebugInfo& getDebugInfo() const;
+
+ void clear();
+
+ const Stats& getStats() const;
+ size_t getLogicalNodeCount() const;
+ size_t getPhysicalNodeCount() const;
+
+ const Metadata& getMetadata() const;
+ const CEInterface& getCEDerivation() const;
+
+private:
+ GroupIdType addGroup(ProjectionNameSet projections);
+
+ std::pair<MemoLogicalNodeId, bool> addNode(GroupIdType groupId, ABT n);
+
+ std::pair<MemoLogicalNodeId, bool> findNode(const GroupIdVector& groups, const ABT& node);
+
+ std::vector<std::unique_ptr<Group>> _groups;
+
+ // Used to find nodes using particular groups as inputs.
+ InputGroupsToNodeIdMap _inputGroupsToNodeIdMap;
+
+ NodeIdToInputGroupsMap _nodeIdToInputGroupsMap;
+
+ const Metadata& _metadata;
+ std::unique_ptr<LogicalPropsInterface> _logicalPropsDerivation;
+ std::unique_ptr<CEInterface> _ceDerivation;
+
+ const DebugInfo _debugInfo;
+
+ Stats _stats;
+};
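+
+// Typical flow (illustrative): integrate() copies a logical plan into the memo,
+// creating or reusing groups bottom-up and recording the inserted node ids in
+// 'insertedNodeIds'; estimateCE() then derives cardinality for a new group, and
+// the physical rewriter consults getGroup()/findNodeInGroup() while exploring.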
+
+} // namespace mongo::optimizer::cascades
diff --git a/src/mongo/db/query/optimizer/cascades/physical_rewriter.cpp b/src/mongo/db/query/optimizer/cascades/physical_rewriter.cpp
new file mode 100644
index 00000000000..fb3d9185220
--- /dev/null
+++ b/src/mongo/db/query/optimizer/cascades/physical_rewriter.cpp
@@ -0,0 +1,407 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/db/query/optimizer/cascades/physical_rewriter.h"
+
+#include "mongo/db/query/optimizer/cascades/enforcers.h"
+#include "mongo/db/query/optimizer/cascades/implementers.h"
+#include "mongo/db/query/optimizer/explain.h"
+#include "mongo/db/query/optimizer/utils/memo_utils.h"
+
+namespace mongo::optimizer::cascades {
+
+using namespace properties;
+
+/**
+ * Helper class used to check if two physical property sets are compatible by testing each
+ * constituent property for compatibility. This is used to check if a winner's circle entry can be
+ * reused.
+ */
+class PropCompatibleVisitor {
+public:
+ bool operator()(const PhysProperty&, const CollationRequirement& requiredProp) {
+ return collationsCompatible(
+ getPropertyConst<CollationRequirement>(_availableProps).getCollationSpec(),
+ requiredProp.getCollationSpec());
+ }
+
+ bool operator()(const PhysProperty&, const LimitSkipRequirement& requiredProp) {
+ const auto& available = getPropertyConst<LimitSkipRequirement>(_availableProps);
+ return available.getSkip() >= requiredProp.getSkip() &&
+ available.getAbsoluteLimit() <= requiredProp.getAbsoluteLimit();
+ }
+
+ bool operator()(const PhysProperty&, const ProjectionRequirement& requiredProp) {
+ const auto& availableProjections =
+ getPropertyConst<ProjectionRequirement>(_availableProps).getProjections();
+ // Do we have a projection superset (not necessarily strict superset)?
+ for (const ProjectionName& projectionName : requiredProp.getProjections().getVector()) {
+ if (!availableProjections.find(projectionName).second) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ bool operator()(const PhysProperty&, const IndexingRequirement& requiredProp) {
+ const auto& available = getPropertyConst<IndexingRequirement>(_availableProps);
+ return available.getIndexReqTarget() == requiredProp.getIndexReqTarget() &&
+ available.getNeedsRID() == requiredProp.getNeedsRID() &&
+ (available.getDedupRID() || !requiredProp.getDedupRID()) &&
+ available.getSatisfiedPartialIndexesGroupId() ==
+ requiredProp.getSatisfiedPartialIndexesGroupId();
+ }
+
+ template <class T>
+ bool operator()(const PhysProperty&, const T& requiredProp) {
+ return getPropertyConst<T>(_availableProps) == requiredProp;
+ }
+
+ static bool propertiesCompatible(const PhysProps& requiredProps,
+ const PhysProps& availableProps) {
+ if (requiredProps.size() != availableProps.size()) {
+ return false;
+ }
+
+ PropCompatibleVisitor visitor(availableProps);
+ for (const auto& [key, prop] : requiredProps) {
+ if (availableProps.find(key) == availableProps.cend() || !prop.visit(visitor)) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+private:
+ PropCompatibleVisitor(const PhysProps& availableProp) : _availableProps(availableProp) {}
+
+ // We don't own this.
+ const PhysProps& _availableProps;
+};
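+
+// Illustrative example (property values are placeholders): a winner's circle
+// entry whose available collation is [(a, Ascending), (b, Ascending)] can serve
+// a required collation of [(a, Ascending)], since the required spec is a
+// compatible prefix of the available one:
+//
+//     PropCompatibleVisitor::propertiesCompatible(requiredProps, availableProps);
+//
+// Note that both sets must contain the same kinds of properties to be comparable.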
+
+PhysicalRewriter::PhysicalRewriter(
+ Memo& memo,
+ const QueryHints& hints,
+ const opt::unordered_map<std::string, ProjectionName>& ridProjections,
+ const CostingInterface& costDerivation,
+ std::unique_ptr<LogicalRewriter>& logicalRewriter)
+ : _memo(memo),
+ _costDerivation(costDerivation),
+ _hints(hints),
+ _ridProjections(ridProjections),
+ _logicalRewriter(logicalRewriter) {}
+
+static void printCandidateInfo(const ABT& node,
+ const GroupIdType groupId,
+ const CostType nodeCost,
+ const ChildPropsType& childProps,
+ const PhysOptimizationResult& bestResult) {
+ std::cout
+ << "group: " << groupId << ", id: " << bestResult._index
+ << ", nodeCost: " << nodeCost.toString() << ", best cost: "
+ << (bestResult._nodeInfo ? bestResult._nodeInfo->_cost : CostType::kInfinity).toString()
+ << "\n";
+ std::cout << ExplainGenerator::explainPhysProps("Physical properties", bestResult._physProps)
+ << "\n";
+ std::cout << "Node: \n" << ExplainGenerator::explainV2(node) << "\n";
+
+ for (const auto& childProp : childProps) {
+ std::cout << ExplainGenerator::explainPhysProps("Child properties", childProp.second);
+ }
+}
+
+void PhysicalRewriter::costAndRetainBestNode(ABT node,
+ ChildPropsType childProps,
+ NodeCEMap nodeCEMap,
+ const GroupIdType groupId,
+ const PrefixId& prefixId,
+ PhysOptimizationResult& bestResult) {
+ const CostAndCE nodeCostAndCE =
+ _costDerivation.deriveCost(_memo, bestResult._physProps, node.ref(), childProps, nodeCEMap);
+ const CostType nodeCost = nodeCostAndCE._cost;
+ uassert(6624056, "Must get non-infinity cost for physical node.", !nodeCost.isInfinite());
+
+ if (_memo.getDebugInfo().hasDebugLevel(3)) {
+ std::cout << "Requesting optimization\n";
+ printCandidateInfo(node, groupId, nodeCost, childProps, bestResult);
+ }
+
+ const CostType childCostLimit =
+ bestResult._nodeInfo ? bestResult._nodeInfo->_cost : bestResult._costLimit;
+ const auto [success, cost] = optimizeChildren(nodeCost, childProps, prefixId, childCostLimit);
+ const bool improvement =
+ success && (!bestResult._nodeInfo || cost < bestResult._nodeInfo->_cost);
+
+ if (_memo.getDebugInfo().hasDebugLevel(3)) {
+ std::cout << (success ? (improvement ? "Improved" : "Did not improve")
+ : "Failed optimizing")
+ << "\n";
+ printCandidateInfo(node, groupId, nodeCost, childProps, bestResult);
+ }
+
+ PhysNodeInfo candidateNodeInfo{
+ unwrapConstFilter(std::move(node)), cost, nodeCost, nodeCostAndCE._ce};
+ const bool keepRejectedPlans = _hints._keepRejectedPlans;
+ if (improvement) {
+ if (keepRejectedPlans && bestResult._nodeInfo) {
+ bestResult._rejectedNodeInfo.push_back(std::move(*bestResult._nodeInfo));
+ }
+ bestResult._nodeInfo = std::move(candidateNodeInfo);
+ } else if (keepRejectedPlans) {
+ bestResult._rejectedNodeInfo.push_back(std::move(candidateNodeInfo));
+ }
+}
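+
+// Example (illustrative costs): with _keepRejectedPlans set, a candidate costing 7
+// that beats an incumbent costing 9 moves the incumbent into _rejectedNodeInfo; a
+// candidate costing 12 is itself recorded as rejected instead.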
+
+/**
+ * Convert child nodes from logical to physical memo delegators.
+ * Performs a branch-and-bound search over the children.
+ */
+std::pair<bool, CostType> PhysicalRewriter::optimizeChildren(const CostType nodeCost,
+ ChildPropsType childProps,
+ const PrefixId& prefixId,
+ const CostType costLimit) {
+ const bool disableBranchAndBound = _hints._disableBranchAndBound;
+
+ CostType totalCost = nodeCost;
+ if (costLimit < totalCost && !disableBranchAndBound) {
+ return {false, CostType::kInfinity};
+ }
+
+ for (auto& [node, props] : childProps) {
+ const GroupIdType groupId = node->cast<MemoLogicalDelegatorNode>()->getGroupId();
+
+ const CostType childCostLimit =
+ disableBranchAndBound ? CostType::kInfinity : (costLimit - totalCost);
+ auto optGroupResult = optimizeGroup(groupId, std::move(props), prefixId, childCostLimit);
+ if (!optGroupResult._success) {
+ return {false, CostType::kInfinity};
+ }
+
+ totalCost += optGroupResult._cost;
+ if (costLimit < totalCost && !disableBranchAndBound) {
+ return {false, CostType::kInfinity};
+ }
+
+ ABT optimizedChild =
+ make<MemoPhysicalDelegatorNode>(MemoPhysicalNodeId{groupId, optGroupResult._index});
+ std::swap(*node, optimizedChild);
+ }
+
+ return {true, totalCost};
+}
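+
+// Branch-and-bound example (illustrative numbers): under a cost limit of 10 with a
+// node cost of 4, the first child is optimized with a limit of 6. If it returns a
+// cost of 5, the second child gets a limit of 1; exceeding that limit abandons the
+// candidate early, unless branch-and-bound is disabled via hints.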
+
+PhysicalRewriter::OptimizeGroupResult::OptimizeGroupResult()
+ : _success(false), _index(0), _cost(CostType::kInfinity) {}
+
+PhysicalRewriter::OptimizeGroupResult::OptimizeGroupResult(const size_t index, const CostType cost)
+ : _success(true), _index(index), _cost(std::move(cost)) {
+ uassert(6624347,
+ "Cannot have successful optimization with infinite cost",
+ _cost < CostType::kInfinity);
+}
+
+PhysicalRewriter::OptimizeGroupResult PhysicalRewriter::optimizeGroup(const GroupIdType groupId,
+ PhysProps physProps,
+ PrefixId prefixId,
+ CostType costLimit) {
+ const size_t localPlanExplorationCount = ++_memo._stats._physPlanExplorationCount;
+ if (_memo.getDebugInfo().hasDebugLevel(2)) {
+ std::cout << "#" << localPlanExplorationCount << " Optimizing group " << groupId
+ << ", cost limit: " << costLimit.toString() << "\n";
+ std::cout << ExplainGenerator::explainPhysProps("Physical properties", physProps) << "\n";
+ }
+
+ Group& group = _memo.getGroup(groupId);
+ const LogicalProps& logicalProps = group._logicalProperties;
+ if (hasProperty<IndexingAvailability>(logicalProps) &&
+ !hasProperty<IndexingRequirement>(physProps)) {
+ // Re-optimize under complete scan indexing requirements.
+ setPropertyOverwrite(
+ physProps,
+ IndexingRequirement{
+ IndexReqTarget::Complete, false /*needRID*/, true /*dedupRID*/, groupId});
+ }
+
+ auto& physicalNodes = group._physicalNodes;
+ // Establish whether we have found an exact match of the physical properties in the winner's
+ // circle.
+ const auto [exactPropsIndex, hasExactProps] = physicalNodes.find(physProps);
+ // If true, we have found compatible (but not equal) props with cost under our cost limit.
+ bool hasCompatibleProps = false;
+
+ if (hasExactProps) {
+ PhysOptimizationResult& physNode = physicalNodes.at(exactPropsIndex);
+ if (!physNode.isOptimized()) {
+ // Currently optimizing under the same properties higher up the stack (recursive loop).
+ return {};
+ }
+ // At this point we have an optimized entry.
+
+ if (!physNode._nodeInfo) {
+ if (physNode._costLimit < costLimit) {
+ physNode.raiseCostLimit(costLimit);
+ // Fall through and continue optimizing.
+ } else {
+ // Previously failed to optimize under less strict cost limit.
+ return {};
+ }
+ } else if (costLimit < physNode._nodeInfo->_cost) {
+ // We have a stricter limit than our previous optimization's cost.
+ return {};
+ } else {
+ // Reuse result under identical properties.
+ if (_memo.getDebugInfo().hasDebugLevel(3)) {
+ std::cout << "Reusing winner's circle entry: group: " << groupId
+ << ", id: " << physNode._index
+ << ", cost: " << physNode._nodeInfo->_cost.toString()
+ << ", limit: " << costLimit.toString() << "\n";
+ std::cout << "Existing props: "
+ << ExplainGenerator::explainPhysProps("existing", physNode._physProps)
+ << "\n";
+ std::cout << "New props: " << ExplainGenerator::explainPhysProps("new", physProps)
+ << "\n";
+ std::cout << "Reused plan: "
+ << ExplainGenerator::explainV2(physNode._nodeInfo->_node) << "\n";
+ }
+ return {physNode._index, physNode._nodeInfo->_cost};
+ }
+ } else {
+ // Check winner's circle for compatible properties.
+ for (const auto& physNode : physicalNodes.getNodes()) {
+ _memo._stats._physMemoCheckCount++;
+
+ if (!physNode->_nodeInfo) {
+ continue;
+ }
+ // At this point we have an optimized entry.
+
+ if (costLimit < physNode->_nodeInfo->_cost) {
+ // Properties are not identical. Continue exploring even if limit was stricter.
+ continue;
+ }
+
+ if (!PropCompatibleVisitor::propertiesCompatible(physProps, physNode->_physProps)) {
+ // We are stricter than what is available.
+ continue;
+ }
+
+ if (physNode->_nodeInfo->_cost < costLimit) {
+ if (_memo.getDebugInfo().hasDebugLevel(3)) {
+ std::cout << "Reducing cost limit: group: " << groupId
+ << ", id: " << physNode->_index
+ << ", cost: " << physNode->_nodeInfo->_cost.toString()
+ << ", limit: " << costLimit.toString() << "\n";
+ std::cout << ExplainGenerator::explainPhysProps("Existing props",
+ physNode->_physProps)
+ << "\n";
+ std::cout << ExplainGenerator::explainPhysProps("New props", physProps) << "\n";
+ }
+
+ // Reduce our cost limit to the cost of the existing result under compatible properties.
+ hasCompatibleProps = true;
+ costLimit = physNode->_nodeInfo->_cost;
+ }
+ }
+ }
+
+ // If we found an exact match for the properties, reuse the entry and continue optimizing under
+ // the higher cost limit. Otherwise create a new entry for the current properties.
+ PhysOptimizationResult& bestResult = hasExactProps
+ ? physicalNodes.at(exactPropsIndex)
+ : physicalNodes.addOptimizationResult(physProps, costLimit);
+
+ // Enforcement rewrites run just once, and are independent of the logical nodes.
+ if (hasProperty<ProjectionRequirement>(bestResult._physProps)) {
+ // Verify properties can be enforced and add enforcers if necessary.
+ addEnforcers(groupId,
+ _memo.getMetadata(),
+ prefixId,
+ bestResult._queue,
+ bestResult._physProps,
+ logicalProps);
+ }
+
+ // Iterate until we have performed all logical rewrites for the group and all physical rewrites
+ // for our best plan.
+ const OrderPreservingABTSet& logicalNodes = group._logicalNodes;
+ while (bestResult._lastImplementedNodePos < logicalNodes.size() || !bestResult._queue.empty()) {
+ if (_logicalRewriter) {
+ // Attempt to perform logical rewrites.
+ _logicalRewriter->rewriteGroup(groupId);
+ }
+
+ // Add rewrites to convert logical into physical nodes. Only add rewrites for newly added
+ // logical nodes.
+ addImplementers(
+ _memo, _hints, _ridProjections, prefixId, bestResult, logicalProps, logicalNodes);
+
+ // Perform physical rewrites, use branch-and-bound.
+ while (!bestResult._queue.empty()) {
+ PhysRewriteEntry rewrite = std::move(*bestResult._queue.top());
+ bestResult._queue.pop();
+
+ NodeCEMap nodeCEMap = std::move(rewrite._nodeCEMap);
+ if (nodeCEMap.empty()) {
+ nodeCEMap.emplace(
+ rewrite._node.cast<Node>(),
+ getPropertyConst<CardinalityEstimate>(logicalProps).getEstimate());
+ }
+
+ costAndRetainBestNode(std::move(rewrite._node),
+ std::move(rewrite._childProps),
+ std::move(nodeCEMap),
+ groupId,
+ prefixId,
+ bestResult);
+ }
+ }
+
+ uassert(6624128, "Result is not optimized!", bestResult.isOptimized());
+ if (!bestResult._nodeInfo) {
+ uassert(6624348,
+ "Must optimize successfully if found compatible properties!",
+ !hasCompatibleProps);
+ return {};
+ }
+
+ // We have a successful rewrite.
+ if (_memo.getDebugInfo().hasDebugLevel(2)) {
+ std::cout << "#" << localPlanExplorationCount << " Optimized group: " << groupId
+ << ", id: " << bestResult._index
+ << ", cost: " << bestResult._nodeInfo->_cost.toString() << "\n";
+ std::cout << ExplainGenerator::explainPhysProps("Physical properties",
+ bestResult._physProps)
+ << "\n";
+ std::cout << "Node: \n"
+ << ExplainGenerator::explainV2(
+ bestResult._nodeInfo->_node, false /*displayProperties*/, &_memo);
+ }
+
+ return {bestResult._index, bestResult._nodeInfo->_cost};
+}
+
+} // namespace mongo::optimizer::cascades
diff --git a/src/mongo/db/query/optimizer/cascades/physical_rewriter.h b/src/mongo/db/query/optimizer/cascades/physical_rewriter.h
new file mode 100644
index 00000000000..1463c4a7858
--- /dev/null
+++ b/src/mongo/db/query/optimizer/cascades/physical_rewriter.h
@@ -0,0 +1,93 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include "mongo/db/query/optimizer/cascades/logical_rewriter.h"
+#include "mongo/db/query/optimizer/cascades/memo.h"
+#include "mongo/db/query/optimizer/utils/utils.h"
+
+namespace mongo::optimizer::cascades {
+
+class PhysicalRewriter {
+ friend class PropEnforcerVisitor;
+ friend class ImplementationVisitor;
+
+public:
+ struct OptimizeGroupResult {
+ OptimizeGroupResult();
+ OptimizeGroupResult(size_t index, CostType cost);
+
+ OptimizeGroupResult(const OptimizeGroupResult& other) = default;
+ OptimizeGroupResult(OptimizeGroupResult&& other) = default;
+
+ bool _success;
+ size_t _index;
+ CostType _cost;
+ };
+
+ PhysicalRewriter(Memo& memo,
+ const QueryHints& hints,
+ const opt::unordered_map<std::string, ProjectionName>& ridProjections,
+ const CostingInterface& costDerivation,
+ std::unique_ptr<LogicalRewriter>& logicalRewriter);
+
+ /**
+ * Main entry point for physical optimization.
+ * Optimize a logical plan rooted at a RootNode, and return an index into the winner's circle if
+ * successful.
+ */
+ OptimizeGroupResult optimizeGroup(GroupIdType groupId,
+ properties::PhysProps physProps,
+ PrefixId prefixId,
+ CostType costLimit);
+
+private:
+ void costAndRetainBestNode(ABT node,
+ ChildPropsType childProps,
+ NodeCEMap nodeCEMap,
+ GroupIdType groupId,
+ const PrefixId& prefixId,
+ PhysOptimizationResult& bestResult);
+
+ std::pair<bool, CostType> optimizeChildren(CostType nodeCost,
+ ChildPropsType childProps,
+ const PrefixId& prefixId,
+ CostType costLimit);
+
+ // We don't own any of this.
+ Memo& _memo;
+ const CostingInterface& _costDerivation;
+ const QueryHints& _hints;
+ const opt::unordered_map<std::string, ProjectionName>& _ridProjections;
+ // If set, we'll perform logical rewrites as part of optimizeGroup().
+ std::unique_ptr<LogicalRewriter>& _logicalRewriter;
+};
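+
+// Usage sketch (illustrative; 'costModel', 'rootGroupId' and 'requiredProps' are
+// placeholders for caller-provided state):
+//
+//     PhysicalRewriter rewriter(memo, hints, ridProjections, costModel, logicalRewriter);
+//     auto res = rewriter.optimizeGroup(
+//         rootGroupId, std::move(requiredProps), prefixId, CostType::kInfinity);
+//     if (res._success) {
+//         // res._index points into the root group's winner's circle.
+//     }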
+
+} // namespace mongo::optimizer::cascades
diff --git a/src/mongo/db/query/optimizer/cascades/rewrite_queues.cpp b/src/mongo/db/query/optimizer/cascades/rewrite_queues.cpp
new file mode 100644
index 00000000000..8d0ef6809b7
--- /dev/null
+++ b/src/mongo/db/query/optimizer/cascades/rewrite_queues.cpp
@@ -0,0 +1,76 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/db/query/optimizer/cascades/rewrite_queues.h"
+#include "mongo/db/query/optimizer/utils/memo_utils.h"
+#include <mongo/db/query/optimizer/defs.h>
+
+namespace mongo::optimizer::cascades {
+
+LogicalRewriteEntry::LogicalRewriteEntry(const double priority,
+ LogicalRewriteType type,
+ MemoLogicalNodeId nodeId)
+ : _priority(priority), _type(type), _nodeId(nodeId) {}
+
+bool LogicalRewriteEntryComparator::operator()(
+ const std::unique_ptr<LogicalRewriteEntry>& x,
+ const std::unique_ptr<LogicalRewriteEntry>& y) const {
+ // Lower numerical priority sorts last in the heap order (and is thus de-queued first).
+ if (x->_priority > y->_priority) {
+ return true;
+ } else if (x->_priority < y->_priority) {
+ return false;
+ }
+
+ // Make sure entries in the queue are consistently ordered.
+ if (x->_nodeId._groupId < y->_nodeId._groupId) {
+ return true;
+ } else if (x->_nodeId._groupId > y->_nodeId._groupId) {
+ return false;
+ }
+ return x->_nodeId._index < y->_nodeId._index;
+}
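+
+// Example: entries with priorities {10.0, 1.0, 5.0} are de-queued in the order
+// 1.0, 5.0, 10.0; ties on priority fall back to (groupId, index), keeping the
+// queue order deterministic across runs.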
+
+void optimizeChildrenNoAssert(PhysRewriteQueue& queue,
+ const double priority,
+ ABT node,
+ ChildPropsType childProps,
+ NodeCEMap nodeCEMap) {
+ queue.emplace(std::make_unique<PhysRewriteEntry>(
+ priority, std::move(node), std::move(childProps), std::move(nodeCEMap)));
+}
+
+void optimizeUnderNewProperties(cascades::PhysRewriteQueue& queue,
+ const double priority,
+ ABT child,
+ properties::PhysProps props) {
+ optimizeChild<FilterNode>(queue, priority, wrapConstFilter(std::move(child)), std::move(props));
+}
+
+} // namespace mongo::optimizer::cascades
diff --git a/src/mongo/db/query/optimizer/cascades/rewrite_queues.h b/src/mongo/db/query/optimizer/cascades/rewrite_queues.h
new file mode 100644
index 00000000000..a64b7d000a1
--- /dev/null
+++ b/src/mongo/db/query/optimizer/cascades/rewrite_queues.h
@@ -0,0 +1,149 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include <queue>
+
+#include "mongo/db/query/optimizer/cascades/logical_rewriter_rules.h"
+#include "mongo/db/query/optimizer/node_defs.h"
+
+namespace mongo::optimizer::cascades {
+
+/**
+ * Keeps track of candidate logical rewrites.
+ */
+struct LogicalRewriteEntry {
+ LogicalRewriteEntry(double priority, LogicalRewriteType type, MemoLogicalNodeId nodeId);
+
+ LogicalRewriteEntry() = delete;
+ LogicalRewriteEntry(const LogicalRewriteEntry& other) = delete;
+ LogicalRewriteEntry(LogicalRewriteEntry&& other) = default;
+
+ // Numerically lower priority gets applied first.
+ double _priority;
+
+ LogicalRewriteType _type;
+ MemoLogicalNodeId _nodeId;
+};
+
+struct LogicalRewriteEntryComparator {
+ bool operator()(const std::unique_ptr<LogicalRewriteEntry>& x,
+ const std::unique_ptr<LogicalRewriteEntry>& y) const;
+};
+
+using LogicalRewriteQueue = std::priority_queue<std::unique_ptr<LogicalRewriteEntry>,
+ std::vector<std::unique_ptr<LogicalRewriteEntry>>,
+ LogicalRewriteEntryComparator>;
+
+/**
+ * For now all physical rules use the same priority.
+ * TODO: use specific priorities (may depend on node parameters).
+ */
+static constexpr double kDefaultPriority = 10.0;
+
+/**
+ * Keeps track of candidate physical rewrites.
+ */
+struct PhysRewriteEntry {
+ PhysRewriteEntry(double priority, ABT node, ChildPropsType childProps, NodeCEMap nodeCEMap);
+
+ PhysRewriteEntry() = delete;
+ PhysRewriteEntry(const PhysRewriteEntry& other) = delete;
+ PhysRewriteEntry(PhysRewriteEntry&& other) = default;
+
+ // Numerically lower priority gets applied first.
+ double _priority;
+
+ ABT _node;
+ ChildPropsType _childProps;
+
+ NodeCEMap _nodeCEMap;
+};
+
+struct PhysRewriteEntryComparator {
+ bool operator()(const std::unique_ptr<PhysRewriteEntry>& x,
+ const std::unique_ptr<PhysRewriteEntry>& y) const;
+};
+
+using PhysRewriteQueue = std::priority_queue<std::unique_ptr<PhysRewriteEntry>,
+ std::vector<std::unique_ptr<PhysRewriteEntry>>,
+ PhysRewriteEntryComparator>;
+
+void optimizeChildrenNoAssert(PhysRewriteQueue& queue,
+ double priority,
+ ABT node,
+ ChildPropsType childProps,
+ NodeCEMap nodeCEMap = {});
+
+template <class T>
+static void optimizeChildren(PhysRewriteQueue& queue,
+ double priority,
+ ABT node,
+ ChildPropsType childProps) {
+ static_assert(canBePhysicalNode<T>(), "Can only optimize a physical node.");
+ optimizeChildrenNoAssert(queue, priority, std::move(node), std::move(childProps));
+}
+
+template <class T>
+static void optimizeChild(PhysRewriteQueue& queue,
+ double priority,
+ ABT node,
+ properties::PhysProps childProps) {
+ ABT& childRef = node.cast<T>()->getChild();
+ optimizeChildren<T>(
+ queue, priority, std::move(node), ChildPropsType{{&childRef, std::move(childProps)}});
+}
+
+template <class T>
+static void optimizeChild(PhysRewriteQueue& queue, const double priority, ABT node) {
+ optimizeChildren<T>(queue, priority, std::move(node), {});
+}
+
+void optimizeUnderNewProperties(PhysRewriteQueue& queue,
+ double priority,
+ ABT child,
+ properties::PhysProps props);
+
+template <class T>
+static void optimizeChildren(PhysRewriteQueue& queue,
+ double priority,
+ ABT node,
+ properties::PhysProps leftProps,
+ properties::PhysProps rightProps) {
+ ABT& leftChildRef = node.cast<T>()->getLeftChild();
+ ABT& rightChildRef = node.cast<T>()->getRightChild();
+ optimizeChildren<T>(
+ queue,
+ priority,
+ std::move(node),
+ {{&leftChildRef, std::move(leftProps)}, {&rightChildRef, std::move(rightProps)}});
+}
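+
+// Usage sketch (illustrative; 'JoinNode' stands for any node type that satisfies
+// canBePhysicalNode() and exposes getLeftChild()/getRightChild()):
+//
+//     optimizeChildren<JoinNode>(queue, kDefaultPriority, std::move(join),
+//                                std::move(leftProps), std::move(rightProps));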
+
+} // namespace mongo::optimizer::cascades
diff --git a/src/mongo/db/query/optimizer/containers.h b/src/mongo/db/query/optimizer/containers.h
new file mode 100644
index 00000000000..5a9c11e9c12
--- /dev/null
+++ b/src/mongo/db/query/optimizer/containers.h
@@ -0,0 +1,81 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include <unordered_map>
+#include <unordered_set>
+
+#include "mongo/stdx/trusted_hasher.h"
+#include "mongo/stdx/unordered_map.h"
+#include "mongo/stdx/unordered_set.h"
+
+namespace mongo::optimizer::opt {
+
+namespace {
+enum class ContainerImpl { STD, STDX };
+
+// For debugging, switch between STD and STDX containers.
+static constexpr ContainerImpl kContainerImpl = ContainerImpl::STDX;
+
+template <ContainerImpl>
+struct OptContainers {};
+
+template <>
+struct OptContainers<ContainerImpl::STDX> {
+ template <class K>
+ using Hasher = mongo::DefaultHasher<K>;
+
+ template <class T, class H, typename... Args>
+ using unordered_set = stdx::unordered_set<T, H, Args...>;
+ template <class K, class V, class H, typename... Args>
+ using unordered_map = stdx::unordered_map<K, V, H, Args...>;
+};
+
+template <>
+struct OptContainers<ContainerImpl::STD> {
+ template <class K>
+ using Hasher = std::hash<K>;
+
+ template <class T, class H, typename... Args>
+ using unordered_set = std::unordered_set<T, H, Args...>; // NOLINT
+ template <class K, class V, class H, typename... Args>
+ using unordered_map = std::unordered_map<K, V, H, Args...>; // NOLINT
+};
+
+using ActiveContainers = OptContainers<kContainerImpl>;
+} // namespace
+
+template <class T, class H = ActiveContainers::Hasher<T>, typename... Args>
+using unordered_set = ActiveContainers::unordered_set<T, H, Args...>;
+
+template <class K, class V, class H = ActiveContainers::Hasher<K>, typename... Args>
+using unordered_map = ActiveContainers::unordered_map<K, V, H, Args...>;
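+
+// Example: with kContainerImpl == ContainerImpl::STDX (the default above),
+// opt::unordered_map<K, V> resolves to stdx::unordered_map<K, V, DefaultHasher<K>>;
+// flipping the constant to ContainerImpl::STD swaps in std::unordered_map<K, V,
+// std::hash<K>> at every call site without further code changes.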
+
+} // namespace mongo::optimizer::opt
diff --git a/src/mongo/db/query/optimizer/defs.cpp b/src/mongo/db/query/optimizer/defs.cpp
new file mode 100644
index 00000000000..9949bb07ab6
--- /dev/null
+++ b/src/mongo/db/query/optimizer/defs.cpp
@@ -0,0 +1,257 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/db/query/optimizer/defs.h"
+#include "mongo/db/query/optimizer/utils/utils.h"
+#include "mongo/util/assert_util.h"
+
+namespace mongo::optimizer {
+
+static opt::unordered_map<ProjectionName, size_t> createMapFromVector(
+ const ProjectionNameVector& v) {
+ opt::unordered_map<ProjectionName, size_t> result;
+ for (size_t i = 0; i < v.size(); i++) {
+ result.emplace(v.at(i), i);
+ }
+ return result;
+}
+
+ProjectionNameOrderPreservingSet::ProjectionNameOrderPreservingSet(ProjectionNameVector v)
+ : _map(createMapFromVector(v)), _vector(std::move(v)) {}
+
+ProjectionNameOrderPreservingSet::ProjectionNameOrderPreservingSet(
+ const ProjectionNameOrderPreservingSet& other)
+ : _map(other._map), _vector(other._vector) {}
+
+ProjectionNameOrderPreservingSet::ProjectionNameOrderPreservingSet(
+ ProjectionNameOrderPreservingSet&& other) noexcept
+ : _map(std::move(other._map)), _vector(std::move(other._vector)) {}
+
+bool ProjectionNameOrderPreservingSet::operator==(
+ const ProjectionNameOrderPreservingSet& other) const {
+ return _vector == other._vector;
+}
+
+std::pair<size_t, bool> ProjectionNameOrderPreservingSet::emplace_back(
+ ProjectionName projectionName) {
+ auto [index, found] = find(projectionName);
+ if (found) {
+ return {index, false};
+ }
+
+ const size_t id = _vector.size();
+ _vector.emplace_back(std::move(projectionName));
+ _map.emplace(_vector.back(), id);
+ return {id, true};
+}
+
+std::pair<size_t, bool> ProjectionNameOrderPreservingSet::find(
+ const ProjectionName& projectionName) const {
+ auto it = _map.find(projectionName);
+ if (it == _map.end()) {
+ return {0, false};
+ }
+
+ return {it->second, true};
+}
+
+bool ProjectionNameOrderPreservingSet::erase(const ProjectionName& projectionName) {
+ auto [index, found] = find(projectionName);
+ if (!found) {
+ return false;
+ }
+
+ if (index < _vector.size() - 1) {
+ // Repoint map.
+ _map.at(_vector.back()) = index;
+ // Fill gap with last element.
+ _vector.at(index) = std::move(_vector.back());
+ }
+
+ _map.erase(projectionName);
+ _vector.resize(_vector.size() - 1);
+
+ return true;
+}
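+
+// Example: erasing "b" from ["a", "b", "c", "d"] fills the gap with the last
+// element, yielding ["a", "d", "c"]; erase() is therefore O(1) but does not
+// preserve the relative order of the element moved into the gap.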
+
+bool ProjectionNameOrderPreservingSet::isEqualIgnoreOrder(
+ const ProjectionNameOrderPreservingSet& other) const {
+ size_t numMatches = 0;
+ for (const auto& projectionName : _vector) {
+ if (other.find(projectionName).second) {
+ numMatches++;
+ } else {
+ return false;
+ }
+ }
+
+ return numMatches == other._vector.size();
+}
+
+const ProjectionNameVector& ProjectionNameOrderPreservingSet::getVector() const {
+ return _vector;
+}
+
+bool FieldProjectionMap::operator==(const FieldProjectionMap& other) const {
+ return _ridProjection == other._ridProjection && _rootProjection == other._rootProjection &&
+ _fieldProjections == other._fieldProjections;
+}
+
+bool MemoLogicalNodeId::operator==(const MemoLogicalNodeId& other) const {
+ return _groupId == other._groupId && _index == other._index;
+}
+
+size_t NodeIdHash::operator()(const MemoLogicalNodeId& id) const {
+ size_t result = 17;
+ updateHash(result, std::hash<GroupIdType>()(id._groupId));
+ updateHash(result, std::hash<size_t>()(id._index));
+ return result;
+}
+
+bool MemoPhysicalNodeId::operator==(const MemoPhysicalNodeId& other) const {
+ return _groupId == other._groupId && _index == other._index;
+}
+
+DebugInfo DebugInfo::kDefaultForTests =
+ DebugInfo(true, DebugInfo::kDefaultDebugLevelForTests, DebugInfo::kIterationLimitForTests);
+DebugInfo DebugInfo::kDefaultForProd = DebugInfo(false, 0, -1);
+
+DebugInfo::DebugInfo(const bool debugMode, const int debugLevel, const int iterationLimit)
+ : _debugMode(debugMode), _debugLevel(debugLevel), _iterationLimit(iterationLimit) {}
+
+bool DebugInfo::isDebugMode() const {
+ return _debugMode;
+}
+
+bool DebugInfo::hasDebugLevel(const int debugLevel) const {
+ return _debugLevel >= debugLevel;
+}
+
+bool DebugInfo::exceedsIterationLimit(const int iterations) const {
+ return _iterationLimit >= 0 && iterations > _iterationLimit;
+}
+
+CostType CostType::kInfinity = CostType(true /*isInfinite*/, 0.0);
+CostType CostType::kZero = CostType(false /*isInfinite*/, 0.0);
+
+CostType::CostType(const bool isInfinite, const double cost)
+ : _isInfinite(isInfinite), _cost(cost) {
+ uassert(6624346, "Cost is negative", _cost >= 0.0);
+}
+
+bool CostType::operator==(const CostType& other) const {
+ return _isInfinite == other._isInfinite && (_isInfinite || _cost == other._cost);
+}
+
+bool CostType::operator!=(const CostType& other) const {
+ return !(*this == other);
+}
+
+bool CostType::operator<(const CostType& other) const {
+ return !_isInfinite && (other._isInfinite || _cost < other._cost);
+}
+
+CostType CostType::operator+(const CostType& other) const {
+ return (_isInfinite || other._isInfinite) ? kInfinity : fromDouble(_cost + other._cost);
+}
+
+CostType CostType::operator-(const CostType& other) const {
+ uassert(6624001, "Cannot subtract an infinite cost", other != kInfinity);
+ return _isInfinite ? kInfinity : fromDouble(_cost - other._cost);
+}
+
+CostType& CostType::operator+=(const CostType& other) {
+ *this = (*this + other);
+ return *this;
+}
+
+CostType CostType::fromDouble(const double cost) {
+ uassert(8423327, "Invalid cost.", !std::isnan(cost) && cost >= 0.0);
+ return CostType(false /*isInfinite*/, cost);
+}
+
+std::string CostType::toString() const {
+ std::ostringstream os;
+ if (_isInfinite) {
+ os << "{Infinite cost}";
+ } else {
+ os << _cost;
+ }
+ return os.str();
+}
+
+double CostType::getCost() const {
+ uassert(6624002, "Attempted to coerce infinite cost to a double", !_isInfinite);
+ return _cost;
+}
+
+bool CostType::isInfinite() const {
+ return _isInfinite;
+}
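+
+// Arithmetic examples (illustrative): kInfinity absorbs addition, so
+// CostType::fromDouble(2.0) + CostType::kInfinity == CostType::kInfinity, and any
+// finite cost compares strictly less than kInfinity; subtracting kInfinity asserts.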
+
+CollationOp reverseCollationOp(const CollationOp op) {
+ switch (op) {
+ case CollationOp::Ascending:
+ return CollationOp::Descending;
+ case CollationOp::Descending:
+ return CollationOp::Ascending;
+ case CollationOp::Clustered:
+ return CollationOp::Clustered;
+
+ default:
+ MONGO_UNREACHABLE;
+ }
+}
+
+bool collationOpsCompatible(const CollationOp availableOp, const CollationOp requiredOp) {
+ return requiredOp == CollationOp::Clustered || requiredOp == availableOp;
+}
+
+bool collationsCompatible(const ProjectionCollationSpec& available,
+ const ProjectionCollationSpec& required) {
+ // Check if required is more restrictive than available. If yes, reject.
+ if (available.size() < required.size()) {
+ return false;
+ }
+
+ for (size_t i = 0; i < required.size(); i++) {
+ const auto& requiredEntry = required.at(i);
+ const auto& availableEntry = available.at(i);
+
+ if (requiredEntry.first != availableEntry.first ||
+ !collationOpsCompatible(availableEntry.second, requiredEntry.second)) {
+ return false;
+ }
+ }
+
+ // Available is at least as restrictive as required.
+ return true;
+}
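+
+// Examples: available [(a, Ascending), (b, Descending)] satisfies required
+// [(a, Ascending)] (a compatible prefix), and a required Clustered op accepts any
+// available op on the same projection. Required [(b, Descending)] alone fails,
+// because the required spec must match as a prefix, not merely as a subset.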
+
+} // namespace mongo::optimizer
diff --git a/src/mongo/db/query/optimizer/defs.h b/src/mongo/db/query/optimizer/defs.h
new file mode 100644
index 00000000000..12c153ef4e9
--- /dev/null
+++ b/src/mongo/db/query/optimizer/defs.h
@@ -0,0 +1,254 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include <set>
+#include <sstream>
+#include <string>
+#include <vector>
+
+#include "mongo/db/query/optimizer/containers.h"
+#include "mongo/db/query/optimizer/utils/printable_enum.h"
+
+
+namespace mongo::optimizer {
+
+using FieldNameType = std::string;
+using FieldPathType = std::vector<FieldNameType>;
+
+using CollectionNameType = std::string;
+
+using ProjectionName = std::string;
+using ProjectionNameSet = opt::unordered_set<ProjectionName>;
+using ProjectionNameOrderedSet = std::set<ProjectionName>;
+using ProjectionNameVector = std::vector<ProjectionName>;
+using ProjectionRenames = opt::unordered_map<ProjectionName, ProjectionName>;
+
+class ProjectionNameOrderPreservingSet {
+public:
+ ProjectionNameOrderPreservingSet() = default;
+ ProjectionNameOrderPreservingSet(ProjectionNameVector v);
+
+ ProjectionNameOrderPreservingSet(const ProjectionNameOrderPreservingSet& other);
+ ProjectionNameOrderPreservingSet(ProjectionNameOrderPreservingSet&& other) noexcept;
+
+ bool operator==(const ProjectionNameOrderPreservingSet& other) const;
+
+ std::pair<size_t, bool> emplace_back(ProjectionName projectionName);
+ std::pair<size_t, bool> find(const ProjectionName& projectionName) const;
+ bool erase(const ProjectionName& projectionName);
+
+ bool isEqualIgnoreOrder(const ProjectionNameOrderPreservingSet& other) const;
+
+ const ProjectionNameVector& getVector() const;
+
+private:
+ opt::unordered_map<ProjectionName, size_t> _map;
+ ProjectionNameVector _vector;
+};
+
+#define INDEXREQTARGET_NAMES(F) \
+ F(Index) \
+ F(Seek) \
+ F(Complete)
+
+MAKE_PRINTABLE_ENUM(IndexReqTarget, INDEXREQTARGET_NAMES);
+MAKE_PRINTABLE_ENUM_STRING_ARRAY(IndexReqTargetEnum, IndexReqTarget, INDEXREQTARGET_NAMES);
+#undef INDEXREQTARGET_NAMES
+
+#define DISTRIBUTIONTYPE_NAMES(F) \
+ F(Centralized) \
+ F(Replicated) \
+ F(RoundRobin) \
+ F(HashPartitioning) \
+ F(RangePartitioning) \
+ F(UnknownPartitioning)
+
+MAKE_PRINTABLE_ENUM(DistributionType, DISTRIBUTIONTYPE_NAMES);
+MAKE_PRINTABLE_ENUM_STRING_ARRAY(DistributionTypeEnum, DistributionType, DISTRIBUTIONTYPE_NAMES);
+#undef DISTRIBUTIONTYPE_NAMES
+
+// In the case of a covering scan, index, or fetch, specifies the names of the bound projections
+// for each field. Also, where applicable, optionally specifies the rid and record (root)
+// projections.
+struct FieldProjectionMap {
+ ProjectionName _ridProjection;
+ ProjectionName _rootProjection;
+ opt::unordered_map<FieldNameType, ProjectionName> _fieldProjections;
+
+ bool operator==(const FieldProjectionMap& other) const;
+};
+
+// Used to generate field names encoding index keys for covered indexes.
+static constexpr const char* kIndexKeyPrefix = "<indexKey>";
+
+/**
+ * Memo-related types.
+ */
+using GroupIdType = int64_t;
+
+// Logical node id.
+struct MemoLogicalNodeId {
+ GroupIdType _groupId;
+ size_t _index;
+
+ bool operator==(const MemoLogicalNodeId& other) const;
+};
+
+struct NodeIdHash {
+ size_t operator()(const MemoLogicalNodeId& id) const;
+};
+using NodeIdSet = opt::unordered_set<MemoLogicalNodeId, NodeIdHash>;
+
+// Physical node id.
+struct MemoPhysicalNodeId {
+ GroupIdType _groupId;
+ size_t _index;
+
+ bool operator==(const MemoPhysicalNodeId& other) const;
+};
+
+class DebugInfo {
+public:
+ static const int kIterationLimitForTests = 10000;
+ static const int kDefaultDebugLevelForTests = 1;
+
+ static DebugInfo kDefaultForTests;
+ static DebugInfo kDefaultForProd;
+
+ DebugInfo(bool debugMode, int debugLevel, int iterationLimit);
+
+ bool isDebugMode() const;
+
+ bool hasDebugLevel(int debugLevel) const;
+
+ bool exceedsIterationLimit(int iterations) const;
+
+private:
+ // Are we in debug mode? Can we do additional logging, etc?
+ const bool _debugMode;
+
+ const int _debugLevel;
+
+ // Maximum number of rewrite iterations.
+ const int _iterationLimit;
+};
+
+using CEType = double;
+using SelectivityType = double;
+
+class CostType {
+public:
+ static CostType kInfinity;
+ static CostType kZero;
+
+ static CostType fromDouble(double cost);
+
+ CostType(const CostType& other) = default;
+ CostType(CostType&& other) = default;
+ CostType& operator=(const CostType& other) = default;
+
+ bool operator==(const CostType& other) const;
+ bool operator!=(const CostType& other) const;
+ bool operator<(const CostType& other) const;
+
+ CostType operator+(const CostType& other) const;
+ CostType operator-(const CostType& other) const;
+ CostType& operator+=(const CostType& other);
+
+ std::string toString() const;
+
+ /**
+ * Returns the cost as a double, or asserts if infinite.
+ */
+ double getCost() const;
+
+ bool isInfinite() const;
+
+private:
+ CostType(bool isInfinite, double cost);
+
+ bool _isInfinite;
+ double _cost;
+};
+
+struct CostAndCE {
+ CostType _cost;
+ CEType _ce;
+};
+
+#define COLLATIONOP_OPNAMES(F) \
+ F(Ascending) \
+ F(Descending) \
+ F(Clustered)
+
+MAKE_PRINTABLE_ENUM(CollationOp, COLLATIONOP_OPNAMES);
+MAKE_PRINTABLE_ENUM_STRING_ARRAY(CollationOpEnum, CollationOp, COLLATIONOP_OPNAMES);
+#undef COLLATIONOP_OPNAMES
+
+using ProjectionCollationEntry = std::pair<ProjectionName, CollationOp>;
+using ProjectionCollationSpec = std::vector<ProjectionCollationEntry>;
+
+CollationOp reverseCollationOp(CollationOp op);
+
+bool collationOpsCompatible(CollationOp availableOp, CollationOp requiredOp);
+bool collationsCompatible(const ProjectionCollationSpec& available,
+ const ProjectionCollationSpec& required);
+
+enum class DisableIndexOptions {
+ Enabled, // All types of indexes are enabled.
+ DisableAll, // Disable all indexes.
+ DisablePartialOnly // Only disable partial indexes.
+};
+
+struct QueryHints {
+ // Disable full collection scans.
+ bool _disableScan = false;
+
+ // Disable index scans.
+ DisableIndexOptions _disableIndexes = DisableIndexOptions::Enabled;
+
+ // Disable placing a hash-join during RIDIntersect implementation.
+ bool _disableHashJoinRIDIntersect = false;
+
+ // Disable placing a merge-join during RIDIntersect implementation.
+ bool _disableMergeJoinRIDIntersect = false;
+
+ // Disable placing a group-by and union based RIDIntersect implementation.
+ bool _disableGroupByAndUnionRIDIntersect = false;
+
+ // If set, keep track of rejected plans in the memo.
+ bool _keepRejectedPlans = false;
+
+ // Disable Cascades branch-and-bound strategy, and fully evaluate all plans. Used in conjunction
+ // with keeping rejected plans.
+ bool _disableBranchAndBound = false;
+};
+
+} // namespace mongo::optimizer
diff --git a/src/mongo/db/query/optimizer/explain.cpp b/src/mongo/db/query/optimizer/explain.cpp
new file mode 100644
index 00000000000..064ffc1c0aa
--- /dev/null
+++ b/src/mongo/db/query/optimizer/explain.cpp
@@ -0,0 +1,2411 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/db/query/optimizer/explain.h"
+
+#include "mongo/db/exec/sbe/values/bson.h"
+#include "mongo/db/query/optimizer/cascades/memo.h"
+#include "mongo/db/query/optimizer/node.h"
+#include "mongo/util/assert_util.h"
+
+namespace mongo::optimizer {
+
+BSONObj ABTPrinter::explainBSON() const {
+ auto [tag, val] = optimizer::ExplainGenerator::explainBSON(
+ _abtTree, true /*displayProperties*/, nullptr /*Memo*/, _nodeToPropsMap);
+ uassert(6624070, "Expected an object", tag == sbe::value::TypeTags::Object);
+ sbe::value::ValueGuard vg(tag, val);
+
+ BSONObjBuilder builder;
+ sbe::bson::convertToBsonObj(builder, sbe::value::getObjectView(val));
+ return builder.done().getOwned();
+}
+
+enum class ExplainVersion { V1, V2, V3, Vmax };
+
+bool constexpr operator<(const ExplainVersion v1, const ExplainVersion v2) {
+ return static_cast<int>(v1) < static_cast<int>(v2);
+}
+bool constexpr operator<=(const ExplainVersion v1, const ExplainVersion v2) {
+ return static_cast<int>(v1) <= static_cast<int>(v2);
+}
+bool constexpr operator>(const ExplainVersion v1, const ExplainVersion v2) {
+ return static_cast<int>(v1) > static_cast<int>(v2);
+}
+bool constexpr operator>=(const ExplainVersion v1, const ExplainVersion v2) {
+ return static_cast<int>(v1) >= static_cast<int>(v2);
+}
+
+static constexpr ExplainVersion kDefaultExplainVersion = ExplainVersion::V1;
+
+enum class CommandType { Indent, Unindent, AddLine };
+
+struct CommandStruct {
+ CommandStruct() = default;
+ CommandStruct(const CommandType type, std::string str) : _type(type), _str(std::move(str)) {}
+
+ CommandType _type;
+ std::string _str;
+};
+
+using CommandVector = std::vector<CommandStruct>;
+
+template <const ExplainVersion version = kDefaultExplainVersion>
+class ExplainPrinterImpl {
+public:
+ ExplainPrinterImpl()
+ : _cmd(),
+ _os(),
+ _osDirty(false),
+ _indentCount(0),
+ _childrenRemaining(0),
+ _cmdInsertPos(-1) {}
+
+ ~ExplainPrinterImpl() {
+ uassert(6624003, "Unmatched indentations", _indentCount == 0);
+ uassert(6624004, "Incorrect child count mark", _childrenRemaining == 0);
+ }
+
+ ExplainPrinterImpl(const ExplainPrinterImpl& other) = delete;
+ ExplainPrinterImpl& operator=(const ExplainPrinterImpl& other) = delete;
+
+ explicit ExplainPrinterImpl(const std::string& initialStr) : ExplainPrinterImpl() {
+ print(initialStr);
+ }
+
+ ExplainPrinterImpl(ExplainPrinterImpl&& other) noexcept
+ : _cmd(std::move(other._cmd)),
+ _os(std::move(other._os)),
+ _osDirty(other._osDirty),
+ _indentCount(other._indentCount),
+ _childrenRemaining(other._childrenRemaining),
+ _cmdInsertPos(other._cmdInsertPos) {}
+
+ template <class T>
+ ExplainPrinterImpl& print(const T& t) {
+ _os << t;
+ _osDirty = true;
+ return *this;
+ }
+
+ /**
+ * Here and below: "other" printer(s) may be siphoned out.
+ */
+ ExplainPrinterImpl& print(ExplainPrinterImpl& other) {
+ return print(other, false /*singleLevel*/);
+ }
+
+ template <class P>
+ ExplainPrinterImpl& printSingleLevel(P& other, const std::string& singleLevelSpacer = " ") {
+ return print(other, true /*singleLevel*/, singleLevelSpacer);
+ }
+
+ ExplainPrinterImpl& printAppend(ExplainPrinterImpl& other) {
+ // Ignore append.
+ return print(other);
+ }
+
+ ExplainPrinterImpl& print(std::vector<ExplainPrinterImpl>& other) {
+ for (auto&& element : other) {
+ print(element);
+ }
+ return *this;
+ }
+
+ ExplainPrinterImpl& printAppend(std::vector<ExplainPrinterImpl>& other) {
+ // Ignore append.
+ return print(other);
+ }
+
+ ExplainPrinterImpl& setChildCount(const size_t childCount) {
+ if (version > ExplainVersion::V1) {
+ _childrenRemaining = childCount;
+ indent("");
+ for (int i = 0; i < _childrenRemaining - 1; i++) {
+ indent("|");
+ }
+ }
+ return *this;
+ }
+
+ ExplainPrinterImpl& maybeReverse() {
+ if (version > ExplainVersion::V1) {
+ _cmdInsertPos = _cmd.size();
+ }
+ return *this;
+ }
+
+ ExplainPrinterImpl& fieldName(const std::string& name,
+ const ExplainVersion minVersion = ExplainVersion::V1,
+ const ExplainVersion maxVersion = ExplainVersion::Vmax) {
+ if (minVersion <= version && maxVersion >= version) {
+ print(name);
+ print(": ");
+ }
+ return *this;
+ }
+
+ ExplainPrinterImpl& separator(const std::string& separator) {
+ return print(separator);
+ }
+
+ std::string str() {
+ newLine();
+
+ std::ostringstream os;
+ std::vector<std::string> linePrefix;
+
+ for (const auto& cmd : _cmd) {
+ switch (cmd._type) {
+ case CommandType::Indent:
+ linePrefix.push_back(cmd._str);
+ break;
+
+ case CommandType::Unindent: {
+ linePrefix.pop_back();
+ break;
+ }
+
+ case CommandType::AddLine: {
+ for (const std::string& element : linePrefix) {
+ if (!element.empty()) {
+ os << element << ((version == ExplainVersion::V1) ? " " : "   ");
+ }
+ }
+ os << cmd._str << "\n";
+ break;
+ }
+
+ default: { MONGO_UNREACHABLE; }
+ }
+ }
+
+ return os.str();
+ }
+
+ void newLine() {
+ if (!_osDirty) {
+ return;
+ }
+ const std::string& str = _os.str();
+ _cmd.emplace_back(CommandType::AddLine, str);
+ _os.str("");
+ _os.clear();
+ _osDirty = false;
+ }
+
+ const CommandVector& getCommands() const {
+ return _cmd;
+ }
+
+private:
+ template <class P>
+ ExplainPrinterImpl& print(P& other,
+ const bool singleLevel,
+ const std::string& singleLevelSpacer = " ") {
+ CommandVector toAppend;
+ if (_cmdInsertPos >= 0) {
+ toAppend = CommandVector(_cmd.cbegin() + _cmdInsertPos, _cmd.cend());
+ _cmd.resize(static_cast<size_t>(_cmdInsertPos));
+ }
+
+ const bool hadChildrenRemaining = _childrenRemaining > 0;
+ if (hadChildrenRemaining) {
+ _childrenRemaining--;
+ }
+ other.newLine();
+
+ if (singleLevel) {
+ uassert(6624071, "Unexpected dirty status", _osDirty);
+
+ bool first = true;
+ for (const auto& element : other.getCommands()) {
+ if (element._type == CommandType::AddLine) {
+ if (first) {
+ first = false;
+ } else {
+ _os << singleLevelSpacer;
+ }
+ _os << element._str;
+ }
+ }
+ } else {
+ newLine();
+ if (!hadChildrenRemaining) {
+ indent();
+ }
+ for (const auto& element : other.getCommands()) {
+ _cmd.push_back(element);
+ }
+ unIndent();
+ }
+
+ if (_cmdInsertPos >= 0) {
+ std::copy(toAppend.cbegin(), toAppend.cend(), std::back_inserter(_cmd));
+ }
+
+ return *this;
+ }
+
+ void indent(std::string s = " ") {
+ newLine();
+ _indentCount++;
+ _cmd.emplace_back(CommandType::Indent, std::move(s));
+ }
+
+ void unIndent() {
+ newLine();
+ _indentCount--;
+ _cmd.emplace_back(CommandType::Unindent, "");
+ }
+
+ CommandVector _cmd;
+ std::ostringstream _os;
+ bool _osDirty;
+ int _indentCount;
+ int _childrenRemaining;
+ int _cmdInsertPos;
+};
+
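+/**
+ * BSON-producing explain printer (V3). Instead of buffering text commands it
+ * incrementally builds an SBE object value: fieldName() stages the key for
+ * the next print() call, and nested printers are consumed via moveValue().
+ *
+ * Illustrative usage (hypothetical values):
+ *   ExplainPrinterImpl<ExplainVersion::V3> p("Scan"); // {nodeType: "Scan"}
+ *   p.fieldName("scanDefName").print("coll"); // adds scanDefName: "coll"
+ */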
+template <>
+class ExplainPrinterImpl<ExplainVersion::V3> {
+ static constexpr ExplainVersion version = ExplainVersion::V3;
+
+public:
+ ExplainPrinterImpl() {
+ reset();
+ }
+
+ ~ExplainPrinterImpl() {
+ if (_initialized) {
+ releaseValue(_tag, _val);
+ }
+ }
+
+ ExplainPrinterImpl(const ExplainPrinterImpl& other) = delete;
+ ExplainPrinterImpl& operator=(const ExplainPrinterImpl& other) = delete;
+
+ ExplainPrinterImpl(ExplainPrinterImpl&& other) noexcept {
+ _nextFieldName = std::move(other._nextFieldName);
+ _initialized = other._initialized;
+ _tag = other._tag;
+ _val = other._val;
+ _fieldNameSet = std::move(other._fieldNameSet);
+
+ other.reset();
+ }
+
+ explicit ExplainPrinterImpl(const std::string& nodeName) : ExplainPrinterImpl() {
+ fieldName("nodeType").print(nodeName);
+ }
+
+ auto moveValue() {
+ auto result = std::pair<sbe::value::TypeTags, sbe::value::Value>(_tag, _val);
+ reset();
+ return result;
+ }
+
+ ExplainPrinterImpl& print(const bool v) {
+ addValue(sbe::value::TypeTags::Boolean, sbe::value::bitcastFrom<bool>(v));
+ return *this;
+ }
+
+ ExplainPrinterImpl& print(const int64_t v) {
+ addValue(sbe::value::TypeTags::NumberInt64, sbe::value::bitcastFrom<int64_t>(v));
+ return *this;
+ }
+
+ ExplainPrinterImpl& print(const int32_t v) {
+ addValue(sbe::value::TypeTags::NumberInt32, sbe::value::bitcastFrom<int32_t>(v));
+ return *this;
+ }
+
+ ExplainPrinterImpl& print(const size_t v) {
+ addValue(sbe::value::TypeTags::NumberInt64, sbe::value::bitcastFrom<size_t>(v));
+ return *this;
+ }
+
+ ExplainPrinterImpl& print(const double v) {
+ addValue(sbe::value::TypeTags::NumberDouble, sbe::value::bitcastFrom<double>(v));
+ return *this;
+ }
+
+ ExplainPrinterImpl& print(const std::pair<sbe::value::TypeTags, sbe::value::Value> v) {
+ auto [tag, val] = sbe::value::copyValue(v.first, v.second);
+ addValue(tag, val);
+ return *this;
+ }
+
+ ExplainPrinterImpl& print(const std::string& s) {
+ auto [tag, val] = sbe::value::makeNewString(s);
+ addValue(tag, val);
+ return *this;
+ }
+
+ ExplainPrinterImpl& print(const char* s) {
+ return print(static_cast<std::string>(s));
+ }
+
+ /**
+ * Here and below: the "other" printer(s) may be consumed (emptied) by the
+ * call and should not be reused afterwards.
+ */
+ ExplainPrinterImpl& print(ExplainPrinterImpl& other) {
+ return print(other, false /*append*/);
+ }
+
+ ExplainPrinterImpl& printSingleLevel(ExplainPrinterImpl& other,
+ const std::string& /*singleLevelSpacer*/ = " ") {
+ // Ignore single level.
+ return print(other);
+ }
+
+ ExplainPrinterImpl& printAppend(ExplainPrinterImpl& other) {
+ return print(other, true /*append*/);
+ }
+
+ ExplainPrinterImpl& print(std::vector<ExplainPrinterImpl>& other) {
+ return print(other, false /*append*/);
+ }
+
+ ExplainPrinterImpl& printAppend(std::vector<ExplainPrinterImpl>& other) {
+ return print(other, true /*append*/);
+ }
+
+ ExplainPrinterImpl& setChildCount(const size_t /*childCount*/) {
+ // Ignored.
+ return *this;
+ }
+
+ ExplainPrinterImpl& maybeReverse() {
+ // Ignored.
+ return *this;
+ }
+
+ ExplainPrinterImpl& fieldName(const std::string& name,
+ const ExplainVersion minVersion = ExplainVersion::V1,
+ const ExplainVersion maxVersion = ExplainVersion::Vmax) {
+ if (minVersion <= version && maxVersion >= version) {
+ _nextFieldName = name;
+ }
+ return *this;
+ }
+
+ ExplainPrinterImpl& separator(const std::string& /*separator*/) {
+ // Ignored.
+ return *this;
+ }
+
+private:
+ ExplainPrinterImpl& print(ExplainPrinterImpl& other, const bool append) {
+ auto [tag, val] = other.moveValue();
+ addValue(tag, val, append);
+ if (append) {
+ sbe::value::releaseValue(tag, val);
+ }
+ return *this;
+ }
+
+ ExplainPrinterImpl& print(std::vector<ExplainPrinterImpl>& other, const bool append) {
+ auto [tag, val] = sbe::value::makeNewArray();
+ sbe::value::Array* arr = sbe::value::getArrayView(val);
+ for (auto&& element : other) {
+ auto [tag1, val1] = element.moveValue();
+ arr->push_back(tag1, val1);
+ }
+ addValue(tag, val, append);
+ return *this;
+ }
+
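+ // Takes ownership of 'val' except when 'append' is set, in which case the
+ // fields of the given Object are copied in and the caller remains
+ // responsible for releasing the original value.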
+ void addValue(sbe::value::TypeTags tag, sbe::value::Value val, const bool append = false) {
+ if (!_initialized) {
+ _initialized = true;
+ _canAppend = !_nextFieldName.empty();
+ if (_canAppend) {
+ std::tie(_tag, _val) = sbe::value::makeNewObject();
+ } else {
+ _tag = tag;
+ _val = val;
+ return;
+ }
+ }
+
+ if (!_canAppend) {
+ uasserted(6624072, "Cannot append to scalar");
+ return;
+ }
+
+ if (append) {
+ uassert(6624073, "Field name is not empty", _nextFieldName.empty());
+ uassert(6624349,
+ "Other printer does not contain Object",
+ tag == sbe::value::TypeTags::Object);
+ sbe::value::Object* obj = sbe::value::getObjectView(val);
+ for (size_t i = 0; i < obj->size(); i++) {
+ const auto field = obj->getAt(i);
+ auto [fieldTag, fieldVal] = sbe::value::copyValue(field.first, field.second);
+ addField(obj->field(i), fieldTag, fieldVal);
+ }
+ } else {
+ addField(_nextFieldName, tag, val);
+ _nextFieldName.clear();
+ }
+ }
+
+ void addField(const std::string& fieldName, sbe::value::TypeTags tag, sbe::value::Value val) {
+ uassert(6624074, "Field name is empty", !fieldName.empty());
+ uassert(6624075, "Duplicate field name", _fieldNameSet.insert(fieldName).second);
+ sbe::value::getObjectView(_val)->push_back(fieldName, tag, val);
+ }
+
+ void reset() {
+ _nextFieldName.clear();
+ _initialized = false;
+ _canAppend = false;
+ _tag = sbe::value::TypeTags::Nothing;
+ _val = 0;
+ _fieldNameSet.clear();
+ }
+
+ std::string _nextFieldName;
+ bool _initialized;
+ bool _canAppend;
+ sbe::value::TypeTags _tag;
+ sbe::value::Value _val;
+ // For debugging.
+ opt::unordered_set<std::string> _fieldNameSet;
+};
+
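+/**
+ * Walks an ABT bottom-up via algebra::transport and produces an
+ * ExplainPrinter: each transport() overload below receives the
+ * already-generated printers for the node's children and stitches them under
+ * the parent node's printer.
+ *
+ * Illustrative usage (hypothetical 'abt' value):
+ *   ExplainGeneratorTransporter<ExplainVersion::V2> gen;
+ *   std::string s = gen.generate(abt).str();
+ */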
+template <const ExplainVersion version = kDefaultExplainVersion>
+class ExplainGeneratorTransporter {
+public:
+ using ExplainPrinter = ExplainPrinterImpl<version>;
+
+ ExplainGeneratorTransporter(bool displayProperties = false,
+ const cascades::Memo* memo = nullptr,
+ const NodeToGroupPropsMap& nodeMap = {})
+ : _displayProperties(displayProperties), _memo(memo), _nodeMap(nodeMap) {
+ uassert(6624005,
+ "Memo must be provided in order to display properties.",
+ !_displayProperties || (_memo != nullptr || version == ExplainVersion::V3));
+ }
+
+ /**
+ * Helper function that appends the logical and physical properties of 'node' nested under a new
+ * field named 'properties'. Only applicable to BSON explain (V3); for other
+ * versions this is a no-op.
+ */
+ void maybePrintProps(ExplainPrinter& nodePrinter, const Node& node) {
+ if (!_displayProperties || version != ExplainVersion::V3 || _nodeMap.empty()) {
+ return;
+ }
+ auto it = _nodeMap.find(&node);
+ uassert(6624006, "Failed to find node properties", it != _nodeMap.end());
+
+ const NodeProps& props = it->second;
+
+ ExplainPrinter logPropPrinter = printLogicalProps("logical", props._logicalProps);
+ ExplainPrinter physPropPrinter = printPhysProps("physical", props._physicalProps);
+
+ ExplainPrinter propsPrinter;
+ propsPrinter.fieldName("cost")
+ .print(props._cost.getCost())
+ .fieldName("localCost")
+ .print(props._localCost.getCost())
+ .fieldName("adjustedCE")
+ .print(props._adjustedCE)
+ .fieldName("planNodeID")
+ .print(props._planNodeId)
+ .fieldName("logicalProperties")
+ .print(logPropPrinter)
+ .fieldName("physicalProperties")
+ .print(physPropPrinter);
+ ExplainPrinter res;
+ res.fieldName("properties").print(propsPrinter);
+ nodePrinter.printAppend(res);
+ }
+
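+ // In the text versions a flag is printed by name only when set (e.g.
+ // ", parallel"); in V3 it is always emitted as a boolean field.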
+ static void printBooleanFlag(ExplainPrinter& printer,
+ const std::string& name,
+ const bool flag,
+ const bool addComma = true) {
+ if constexpr (version < ExplainVersion::V3) {
+ if (flag) {
+ if (addComma) {
+ printer.print(", ");
+ }
+ printer.print(name);
+ }
+ } else if constexpr (version == ExplainVersion::V3) {
+ printer.fieldName(name).print(flag);
+ } else {
+ static_assert(version <= ExplainVersion::V3, "Unknown version");
+ }
+ }
+
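+ // Runs 'fn' either directly against 'parent' (fields land inline) or
+ // against a fresh printer that is then appended, depending on how the
+ // caller nests the property being printed.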
+ static void printDirectToParentHelper(const bool directToParent,
+ ExplainPrinter& parent,
+ std::function<void(ExplainPrinter& printer)> fn) {
+ if (directToParent) {
+ fn(parent);
+ } else {
+ ExplainPrinter printer;
+ fn(printer);
+ parent.printAppend(printer);
+ }
+ }
+
+ /**
+ * Nodes
+ */
+ ExplainPrinter transport(const References& references, std::vector<ExplainPrinter> inResults) {
+ ExplainPrinter printer;
+ printer.separator("RefBlock: ").printAppend(inResults);
+ return printer;
+ }
+
+ ExplainPrinter transport(const ExpressionBinder& binders,
+ std::vector<ExplainPrinter> inResults) {
+ std::map<std::string, ExplainPrinter> ordered;
+ for (size_t idx = 0; idx < inResults.size(); ++idx) {
+ ordered.emplace(binders.names()[idx], std::move(inResults[idx]));
+ }
+
+ ExplainPrinter printer;
+ printer.separator("BindBlock:");
+
+ for (auto& [name, child] : ordered) {
+ if constexpr (version < ExplainVersion::V3) {
+ ExplainPrinter local;
+ local.print("[").print(name).print("]").print(child);
+ printer.print(local);
+ } else if constexpr (version == ExplainVersion::V3) {
+ printer.separator(" ").fieldName(name).print(child);
+ } else {
+ static_assert(version <= ExplainVersion::V3, "Unknown version");
+ }
+ }
+
+ return printer;
+ }
+
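+ // Prints the field-projection map with the special <rid> and <root> entries
+ // first, followed by the remaining fields in sorted order.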
+ static void printFieldProjectionMap(ExplainPrinter& printer, const FieldProjectionMap& map) {
+ std::map<FieldNameType, ProjectionName> ordered;
+ if (!map._ridProjection.empty()) {
+ ordered["<rid>"] = map._ridProjection;
+ }
+ if (!map._rootProjection.empty()) {
+ ordered["<root>"] = map._rootProjection;
+ }
+ for (const auto& entry : map._fieldProjections) {
+ ordered.insert(entry);
+ }
+
+ if constexpr (version < ExplainVersion::V3) {
+ bool first = true;
+ for (const auto& [fieldName, projectionName] : ordered) {
+ if (first) {
+ first = false;
+ } else {
+ printer.print(", ");
+ }
+ printer.print("'").print(fieldName).print("': ").print(projectionName);
+ }
+ } else if constexpr (version == ExplainVersion::V3) {
+ ExplainPrinter local;
+ for (const auto& [fieldName, projectionName] : ordered) {
+ local.fieldName(fieldName).print(projectionName);
+ }
+ printer.fieldName("fieldProjectionMap").print(local);
+ } else {
+ static_assert(version <= ExplainVersion::V3, "Unknown version");
+ }
+ }
+
+ ExplainPrinter transport(const ScanNode& node, ExplainPrinter bindResult) {
+ ExplainPrinter printer("Scan");
+ maybePrintProps(printer, node);
+
+ printer.separator(" [")
+ .fieldName("scanDefName", ExplainVersion::V3)
+ .print(node.getScanDefName())
+ .separator("]")
+ .fieldName("bindings", ExplainVersion::V3)
+ .print(bindResult);
+ return printer;
+ }
+
+ ExplainPrinter transport(const PhysicalScanNode& node, ExplainPrinter bindResult) {
+ ExplainPrinter printer("PhysicalScan");
+ maybePrintProps(printer, node);
+
+ printer.separator(" [{");
+ printFieldProjectionMap(printer, node.getFieldProjectionMap());
+ printer.separator("}, ")
+ .fieldName("scanDefName", ExplainVersion::V3)
+ .print(node.getScanDefName());
+
+ printBooleanFlag(printer, "parallel", node.useParallelScan());
+
+ printer.separator("]").fieldName("bindings", ExplainVersion::V3).print(bindResult);
+
+ return printer;
+ }
+
+ ExplainPrinter transport(const ValueScanNode& node, ExplainPrinter bindResult) {
+ ExplainPrinter valuePrinter = generate(node.getValueArray());
+
+ ExplainPrinter printer("ValueScan");
+ maybePrintProps(printer, node);
+ printer.separator(" [")
+ .fieldName("arraySize")
+ .print(node.getArraySize())
+ .separator("]")
+ .fieldName("values", ExplainVersion::V3)
+ .print(valuePrinter)
+ .fieldName("bindings", ExplainVersion::V3)
+ .print(bindResult);
+ return printer;
+ }
+
+ ExplainPrinter transport(const CoScanNode& node) {
+ ExplainPrinter printer("CoScan");
+ maybePrintProps(printer, node);
+ printer.separator(" []");
+ return printer;
+ }
+
+ void printInterval(ExplainPrinter& printer, const IntervalRequirement& interval) {
+ const BoundRequirement& lowBound = interval.getLowBound();
+ const BoundRequirement& highBound = interval.getHighBound();
+
+ if constexpr (version < ExplainVersion::V3) {
+ const auto printBoundFn = [](ExplainPrinter& printer, const ABT& bound) {
+ // Since we are printing on a single level, use V1 printer in order to avoid
+ // children being reversed.
+ ExplainGeneratorTransporter<ExplainVersion::V1> gen;
+ auto boundPrinter = gen.generate(bound);
+ printer.printSingleLevel(boundPrinter);
+ };
+
+ printer.print(lowBound.isInclusive() ? "[" : "(");
+ if (lowBound.isInfinite()) {
+ printer.print("-inf");
+ } else {
+ printBoundFn(printer, lowBound.getBound());
+ }
+
+ printer.print(", ");
+ if (highBound.isInfinite()) {
+ printer.print("+inf");
+ } else {
+ printBoundFn(printer, highBound.getBound());
+ }
+
+ printer.print(highBound.isInclusive() ? "]" : ")");
+ } else if constexpr (version == ExplainVersion::V3) {
+ ExplainPrinter lowBoundPrinter;
+ lowBoundPrinter.fieldName("inclusive")
+ .print(lowBound.isInclusive())
+ .fieldName("infinite")
+ .print(lowBound.isInfinite());
+ if (!lowBound.isInfinite()) {
+ ExplainPrinter boundPrinter = generate(lowBound.getBound());
+ lowBoundPrinter.fieldName("bound").print(boundPrinter);
+ }
+
+ ExplainPrinter highBoundPrinter;
+ highBoundPrinter.fieldName("inclusive")
+ .print(highBound.isInclusive())
+ .fieldName("infinite")
+ .print(highBound.isInfinite());
+ if (!highBound.isInfinite()) {
+ ExplainPrinter boundPrinter = generate(highBound.getBound());
+ highBoundPrinter.fieldName("bound").print(boundPrinter);
+ }
+
+ printer.fieldName("lowBound")
+ .print(lowBoundPrinter)
+ .fieldName("highBound")
+ .print(highBoundPrinter);
+ } else {
+ static_assert(version <= ExplainVersion::V3, "Version not implemented");
+ }
+ }
+
+ std::string printInterval(const IntervalRequirement& interval) {
+ ExplainPrinter printer;
+ printInterval(printer, interval);
+ return printer.str();
+ }
+
+ ExplainPrinter printIntervalExpr(const IntervalReqExpr::Node& intervalExpr) {
+ IntervalPrinter<IntervalReqExpr> intervalPrinter(*this);
+ return intervalPrinter.print(intervalExpr);
+ }
+
+ void printInterval(ExplainPrinter& printer, const MultiKeyIntervalRequirement& interval) {
+ if constexpr (version < ExplainVersion::V3) {
+ bool first = true;
+ for (const auto& entry : interval) {
+ if (first) {
+ first = false;
+ } else {
+ printer.print(", ");
+ }
+ printInterval(printer, entry);
+ }
+ } else if constexpr (version == ExplainVersion::V3) {
+ std::vector<ExplainPrinter> printers;
+ for (const auto& entry : interval) {
+ ExplainPrinter local;
+ printInterval(local, entry);
+ printers.push_back(std::move(local));
+ }
+ printer.print(printers);
+ } else {
+ static_assert(version <= ExplainVersion::V3, "Version not implemented");
+ }
+ }
+
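+ /**
+ * Prints a conjunction/disjunction tree of intervals by transporting over an
+ * interval expression (T is IntervalReqExpr or MultiKeyIntervalReqExpr). The
+ * text versions join children with " ^ " and " U "; V3 nests them under
+ * "conjunction"/"disjunction" fields.
+ */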
+ template <class T>
+ class IntervalPrinter {
+ public:
+ IntervalPrinter(ExplainGeneratorTransporter& instance) : _instance(instance) {}
+
+ ExplainPrinter transport(const typename T::Atom& node) {
+ ExplainPrinter printer;
+ printer.separator("{");
+ _instance.printInterval(printer, node.getExpr());
+ printer.separator("}");
+ return printer;
+ }
+
+ template <bool isConjunction>
+ ExplainPrinter print(std::vector<ExplainPrinter> childResults) {
+ if constexpr (version < ExplainVersion::V3) {
+ ExplainPrinter printer;
+ printer.separator("{");
+
+ bool first = true;
+ for (ExplainPrinter& child : childResults) {
+ if (first) {
+ first = false;
+ } else if constexpr (isConjunction) {
+ printer.print(" ^ ");
+ } else {
+ printer.print(" U ");
+ }
+ printer.print(child);
+ }
+ printer.separator("}");
+
+ return printer;
+ } else if constexpr (version == ExplainVersion::V3) {
+ ExplainPrinter printer;
+ if constexpr (isConjunction) {
+ printer.fieldName("conjunction");
+ } else {
+ printer.fieldName("disjunction");
+ }
+ printer.print(childResults);
+ return printer;
+ } else {
+ static_assert(version <= ExplainVersion::V3, "Version not implemented");
+ }
+ }
+
+ ExplainPrinter transport(const typename T::Conjunction& node,
+ std::vector<ExplainPrinter> childResults) {
+ return print<true /*isConjunction*/>(std::move(childResults));
+ }
+
+ ExplainPrinter transport(const typename T::Disjunction& node,
+ std::vector<ExplainPrinter> childResults) {
+ return print<false /*isConjunction*/>(std::move(childResults));
+ }
+
+ ExplainPrinter print(const typename T::Node& intervals) {
+ return algebra::transport<false>(intervals, *this);
+ }
+
+ private:
+ ExplainGeneratorTransporter& _instance;
+ };
+
+ ExplainPrinter transport(const IndexScanNode& node, ExplainPrinter bindResult) {
+ ExplainPrinter printer("IndexScan");
+ maybePrintProps(printer, node);
+
+ printer.separator(" [{");
+ printFieldProjectionMap(printer, node.getFieldProjectionMap());
+ printer.separator("}, ");
+
+ const auto& spec = node.getIndexSpecification();
+ printer.fieldName("scanDefName")
+ .print(spec.getScanDefName())
+ .separator(", ")
+ .fieldName("indexDefName")
+ .print(spec.getIndexDefName())
+ .separator(", ");
+
+ printer.fieldName("interval").separator("{");
+ printInterval(printer, spec.getInterval());
+ printer.separator("}");
+
+ printBooleanFlag(printer, "reversed", spec.isReverseOrder());
+
+ printer.separator("]").fieldName("bindings", ExplainVersion::V3).print(bindResult);
+ return printer;
+ }
+
+ ExplainPrinter transport(const SeekNode& node,
+ ExplainPrinter bindResult,
+ ExplainPrinter refsResult) {
+ ExplainPrinter printer("Seek");
+ maybePrintProps(printer, node);
+
+ printer.separator(" [")
+ .fieldName("ridProjection")
+ .print(node.getRIDProjectionName())
+ .separator(", {");
+ printFieldProjectionMap(printer, node.getFieldProjectionMap());
+ printer.separator("}, ")
+ .fieldName("scanDefName", ExplainVersion::V3)
+ .print(node.getScanDefName())
+ .separator("]")
+ .setChildCount(2)
+ .fieldName("bindings", ExplainVersion::V3)
+ .print(bindResult)
+ .fieldName("references", ExplainVersion::V3)
+ .print(refsResult);
+
+ return printer;
+ }
+
+ ExplainPrinter transport(const MemoLogicalDelegatorNode& node) {
+ ExplainPrinter printer("MemoLogicalDelegator");
+ maybePrintProps(printer, node);
+ printer.separator(" [").fieldName("groupId").print(node.getGroupId()).separator("]");
+ return printer;
+ }
+
+ ExplainPrinter transport(const MemoPhysicalDelegatorNode& node) {
+ const auto id = node.getNodeId();
+
+ if (_displayProperties) {
+ const auto& group = _memo->getGroup(id._groupId);
+ const auto& result = group._physicalNodes.at(id._index);
+ uassert(6624076,
+ "Physical delegator must be pointing to an optimized result.",
+ result._nodeInfo.has_value());
+
+ const auto& nodeInfo = *result._nodeInfo;
+ const ABT& n = nodeInfo._node;
+
+ ExplainPrinter nodePrinter = generate(n);
+ if (n.template is<MemoPhysicalDelegatorNode>()) {
+ // Handle delegation.
+ return nodePrinter;
+ }
+
+ ExplainPrinter logPropPrinter = printLogicalProps("Logical", group._logicalProperties);
+ ExplainPrinter physPropPrinter = printPhysProps("Physical", result._physProps);
+
+ ExplainPrinter printer("Properties");
+ printer.separator(" [")
+ .fieldName("cost")
+ .print(nodeInfo._cost.getCost())
+ .separator(", ")
+ .fieldName("localCost")
+ .print(nodeInfo._localCost.getCost())
+ .separator(", ")
+ .fieldName("adjustedCE")
+ .print(nodeInfo._adjustedCE)
+ .separator("]")
+ .setChildCount(3)
+ .fieldName("logicalProperties", ExplainVersion::V3)
+ .print(logPropPrinter)
+ .fieldName("physicalProperties", ExplainVersion::V3)
+ .print(physPropPrinter)
+ .fieldName("node", ExplainVersion::V3)
+ .print(nodePrinter);
+ return printer;
+ }
+
+ ExplainPrinter printer("MemoPhysicalDelegator");
+ printer.separator(" [")
+ .fieldName("groupId")
+ .print(id._groupId)
+ .separator(", ")
+ .fieldName("index")
+ .print(id._index)
+ .separator("]");
+ return printer;
+ }
+
+ ExplainPrinter transport(const FilterNode& node,
+ ExplainPrinter childResult,
+ ExplainPrinter filterResult) {
+ ExplainPrinter printer("Filter");
+ maybePrintProps(printer, node);
+ printer.separator(" []")
+ .setChildCount(2)
+ .fieldName("filter", ExplainVersion::V3)
+ .print(filterResult)
+ .fieldName("child", ExplainVersion::V3)
+ .print(childResult);
+ return printer;
+ }
+
+ ExplainPrinter transport(const EvaluationNode& node,
+ ExplainPrinter childResult,
+ ExplainPrinter projectionResult) {
+ ExplainPrinter printer("Evaluation");
+ maybePrintProps(printer, node);
+ printer.separator(" []")
+ .setChildCount(2)
+ .fieldName("projection", ExplainVersion::V3)
+ .print(projectionResult)
+ .fieldName("child", ExplainVersion::V3)
+ .print(childResult);
+ return printer;
+ }
+
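+ // Prints each partial-schema requirement as (referenced projection, path,
+ // optional bound projection, interval expression).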
+ void printPartialSchemaReqMap(ExplainPrinter& parent, const PartialSchemaRequirements& reqMap) {
+ std::vector<ExplainPrinter> printers;
+ for (const auto& [key, req] : reqMap) {
+ ExplainPrinter local;
+
+ local.fieldName("refProjection").print(key._projectionName).separator(", ");
+ ExplainPrinter pathPrinter = generate(key._path);
+ local.fieldName("path").separator("'").printSingleLevel(pathPrinter).separator("', ");
+
+ if (req.hasBoundProjectionName()) {
+ local.fieldName("boundProjection")
+ .print(req.getBoundProjectionName())
+ .separator(", ");
+ }
+
+ local.fieldName("intervals");
+ {
+ ExplainPrinter intervals = printIntervalExpr(req.getIntervals());
+ local.printSingleLevel(intervals, "" /*singleLevelSpacer*/);
+ }
+
+ printers.push_back(std::move(local));
+ }
+
+ parent.fieldName("requirementsMap").print(printers);
+ }
+
+ ExplainPrinter transport(const SargableNode& node,
+ ExplainPrinter childResult,
+ ExplainPrinter bindResult,
+ ExplainPrinter refsResult) {
+ ExplainPrinter printer("Sargable");
+ maybePrintProps(printer, node);
+ printer.separator(" [")
+ .fieldName("target", ExplainVersion::V3)
+ .print(IndexReqTargetEnum::toString[static_cast<int>(node.getTarget())])
+ .separator("]")
+ .setChildCount(5);
+
+ if constexpr (version < ExplainVersion::V3) {
+ ExplainPrinter local;
+ printPartialSchemaReqMap(local, node.getReqMap());
+ printer.print(local);
+ } else if constexpr (version == ExplainVersion::V3) {
+ printPartialSchemaReqMap(printer, node.getReqMap());
+ } else {
+ static_assert(version <= ExplainVersion::V3, "Unknown version");
+ }
+
+ std::set<std::string> orderedIndexDefName;
+ for (const auto& entry : node.getCandidateIndexMap()) {
+ orderedIndexDefName.insert(entry.first);
+ }
+
+ std::vector<ExplainPrinter> candidateIndexesPrinters;
+ size_t candidateIndex = 0;
+ for (const auto& indexDefName : orderedIndexDefName) {
+ candidateIndex++;
+ ExplainPrinter local;
+ local.fieldName("candidateId")
+ .print(candidateIndex)
+ .separator(", ")
+ .fieldName("indexDefName", ExplainVersion::V3)
+ .print(indexDefName)
+ .separator(", ");
+
+ const auto& candidateIndexEntry = node.getCandidateIndexMap().at(indexDefName);
+ local.separator("{");
+ printFieldProjectionMap(local, candidateIndexEntry._fieldProjectionMap);
+ local.separator("}, {");
+
+ {
+ std::set<size_t> orderedFields;
+ for (const size_t fieldId : candidateIndexEntry._fieldsToCollate) {
+ orderedFields.insert(fieldId);
+ }
+
+ if constexpr (version < ExplainVersion::V3) {
+ bool first = true;
+ for (const size_t fieldId : orderedFields) {
+ if (first) {
+ first = false;
+ } else {
+ local.print(", ");
+ }
+ local.print(fieldId);
+ }
+ } else if constexpr (version == ExplainVersion::V3) {
+ std::vector<ExplainPrinter> printers;
+ for (const size_t fieldId : orderedFields) {
+ ExplainPrinter local1;
+ local1.print(fieldId);
+ printers.push_back(std::move(local1));
+ }
+ local.fieldName("fieldsToCollate").print(printers);
+ } else {
+ static_assert(version <= ExplainVersion::V3, "Unknown version");
+ }
+ }
+
+ local.separator("}, ").fieldName("intervals", ExplainVersion::V3);
+ {
+ IntervalPrinter<MultiKeyIntervalReqExpr> intervalPrinter(*this);
+ ExplainPrinter intervals = intervalPrinter.print(candidateIndexEntry._intervals);
+ local.printSingleLevel(intervals, "" /*singleLevelSpacer*/);
+ }
+
+ if (!candidateIndexEntry._residualRequirements.empty()) {
+ if constexpr (version < ExplainVersion::V3) {
+ ExplainPrinter residualReqMapPrinter;
+ printPartialSchemaReqMap(residualReqMapPrinter,
+ candidateIndexEntry._residualRequirements);
+ local.print(residualReqMapPrinter);
+ } else if constexpr (version == ExplainVersion::V3) {
+ printPartialSchemaReqMap(local, candidateIndexEntry._residualRequirements);
+ } else {
+ static_assert(version <= ExplainVersion::V3, "Unknown version");
+ }
+ }
+
+ if (!candidateIndexEntry._residualKeyMap.empty()) {
+ std::vector<ExplainPrinter> residualKeyMapPrinters;
+ for (const auto& [queryKey, residualKey] : candidateIndexEntry._residualKeyMap) {
+ ExplainPrinter local1;
+
+ ExplainPrinter pathPrinter = generate(queryKey._path);
+ local1.fieldName("queryRefProjection")
+ .print(queryKey._projectionName)
+ .separator(", ")
+ .fieldName("queryPath")
+ .separator("'")
+ .printSingleLevel(pathPrinter)
+ .separator("', ")
+ .fieldName("residualRefProjection")
+ .print(residualKey._projectionName)
+ .separator(", ");
+
+ ExplainPrinter pathPrinter1 = generate(residualKey._path);
+ local1.fieldName("residualPath")
+ .separator("'")
+ .printSingleLevel(pathPrinter1)
+ .separator("'");
+
+ residualKeyMapPrinters.push_back(std::move(local1));
+ }
+
+ local.fieldName("residualKeyMap").print(residualKeyMapPrinters);
+
+ std::vector<ExplainPrinter> projNamePrinters;
+ for (const ProjectionName& projName :
+ candidateIndexEntry._residualRequirementsTempProjections) {
+ ExplainPrinter local1;
+ local1.print(projName);
+ projNamePrinters.push_back(std::move(local1));
+ }
+ local.fieldName("tempProjections").print(projNamePrinters);
+ }
+
+ candidateIndexesPrinters.push_back(std::move(local));
+ }
+ ExplainPrinter candidateIndexesPrinter;
+ candidateIndexesPrinter.fieldName("candidateIndexes").print(candidateIndexesPrinters);
+
+ printer.printAppend(candidateIndexesPrinter)
+ .fieldName("bindings", ExplainVersion::V3)
+ .print(bindResult)
+ .fieldName("references", ExplainVersion::V3)
+ .print(refsResult)
+ .fieldName("child", ExplainVersion::V3)
+ .print(childResult);
+ return printer;
+ }
+
+ ExplainPrinter transport(const RIDIntersectNode& node,
+ ExplainPrinter leftChildResult,
+ ExplainPrinter rightChildResult) {
+ ExplainPrinter printer("RIDIntersect");
+ maybePrintProps(printer, node);
+
+ printer.separator(" [")
+ .fieldName("scanProjectionName", ExplainVersion::V3)
+ .print(node.getScanProjectionName());
+ printBooleanFlag(printer, "hasLeftIntervals", node.hasLeftIntervals());
+ printBooleanFlag(printer, "hasRightIntervals", node.hasRightIntervals());
+
+ printer.separator("]")
+ .setChildCount(2)
+ .maybeReverse()
+ .fieldName("leftChild", ExplainVersion::V3)
+ .print(leftChildResult)
+ .fieldName("rightChild", ExplainVersion::V3)
+ .print(rightChildResult);
+ return printer;
+ }
+
+ ExplainPrinter transport(const BinaryJoinNode& node,
+ ExplainPrinter leftChildResult,
+ ExplainPrinter rightChildResult,
+ ExplainPrinter filterResult) {
+ ExplainPrinter printer("BinaryJoin");
+ maybePrintProps(printer, node);
+
+ printer.separator(" [")
+ .fieldName("joinType")
+ .print(JoinTypeEnum::toString[static_cast<int>(node.getJoinType())]);
+
+ if constexpr (version < ExplainVersion::V3) {
+ if (!node.getCorrelatedProjectionNames().empty()) {
+ printer.print(", {");
+ bool first = true;
+ for (const ProjectionName& projectionName : node.getCorrelatedProjectionNames()) {
+ if (first) {
+ first = false;
+ } else {
+ printer.print(", ");
+ }
+ printer.print(projectionName);
+ }
+ printer.print("}");
+ }
+ } else if constexpr (version == ExplainVersion::V3) {
+ std::vector<ExplainPrinter> printers;
+ for (const ProjectionName& projectionName : node.getCorrelatedProjectionNames()) {
+ ExplainPrinter local;
+ local.print(projectionName);
+ printers.push_back(std::move(local));
+ }
+ printer.fieldName("correlatedProjections").print(printers);
+ } else {
+ static_assert(version <= ExplainVersion::V3, "Version not implemented");
+ }
+
+ printer.separator("]")
+ .setChildCount(3)
+ .fieldName("expression", ExplainVersion::V3)
+ .print(filterResult)
+ .maybeReverse()
+ .fieldName("leftChild", ExplainVersion::V3)
+ .print(leftChildResult)
+ .fieldName("rightChild", ExplainVersion::V3)
+ .print(rightChildResult);
+ return printer;
+ }
+
+ void printEqualityJoinCondition(ExplainPrinter& printer,
+ const ProjectionNameVector& leftKeys,
+ const ProjectionNameVector& rightKeys) {
+ if constexpr (version < ExplainVersion::V3) {
+ printer.print("Condition");
+ for (size_t i = 0; i < leftKeys.size(); i++) {
+ ExplainPrinter local;
+ local.print(leftKeys.at(i)).print(" = ").print(rightKeys.at(i));
+ printer.print(local);
+ }
+ } else if constexpr (version == ExplainVersion::V3) {
+ std::vector<ExplainPrinter> printers;
+ for (size_t i = 0; i < leftKeys.size(); i++) {
+ ExplainPrinter local;
+ local.fieldName("leftKey")
+ .print(leftKeys.at(i))
+ .fieldName("rightKey")
+ .print(rightKeys.at(i));
+ printers.push_back(std::move(local));
+ }
+ printer.print(printers);
+ } else {
+ static_assert(version <= ExplainVersion::V3, "Version not implemented");
+ }
+ }
+
+ ExplainPrinter transport(const HashJoinNode& node,
+ ExplainPrinter leftChildResult,
+ ExplainPrinter rightChildResult,
+ ExplainPrinter /*refsResult*/) {
+ ExplainPrinter printer("HashJoin");
+ maybePrintProps(printer, node);
+
+ printer.separator(" [")
+ .fieldName("joinType")
+ .print(JoinTypeEnum::toString[static_cast<int>(node.getJoinType())])
+ .separator("]");
+
+ ExplainPrinter joinConditionPrinter;
+ printEqualityJoinCondition(joinConditionPrinter, node.getLeftKeys(), node.getRightKeys());
+
+ printer.setChildCount(3)
+ .fieldName("joinCondition", ExplainVersion::V3)
+ .print(joinConditionPrinter)
+ .maybeReverse()
+ .fieldName("leftChild", ExplainVersion::V3)
+ .print(leftChildResult)
+ .fieldName("rightChild", ExplainVersion::V3)
+ .print(rightChildResult);
+ return printer;
+ }
+
+ ExplainPrinter transport(const MergeJoinNode& node,
+ ExplainPrinter leftChildResult,
+ ExplainPrinter rightChildResult,
+ ExplainPrinter /*refsResult*/) {
+ ExplainPrinter printer("MergeJoin");
+ maybePrintProps(printer, node);
+ printer.separator(" []");
+
+ ExplainPrinter joinConditionPrinter;
+ printEqualityJoinCondition(joinConditionPrinter, node.getLeftKeys(), node.getRightKeys());
+
+ ExplainPrinter collationPrinter;
+ if constexpr (version < ExplainVersion::V3) {
+ collationPrinter.print("Collation");
+ for (const CollationOp op : node.getCollation()) {
+ ExplainPrinter local;
+ local.print(CollationOpEnum::toString[static_cast<int>(op)]);
+ collationPrinter.print(local);
+ }
+ } else if constexpr (version == ExplainVersion::V3) {
+ std::vector<ExplainPrinter> printers;
+ for (const CollationOp op : node.getCollation()) {
+ ExplainPrinter local;
+ local.print(CollationOpEnum::toString[static_cast<int>(op)]);
+ printers.push_back(std::move(local));
+ }
+ collationPrinter.print(printers);
+ } else {
+ static_assert(version <= ExplainVersion::V3, "Version not implemented");
+ }
+
+ printer.setChildCount(4)
+ .fieldName("joinCondition", ExplainVersion::V3)
+ .print(joinConditionPrinter)
+ .fieldName("collation", ExplainVersion::V3)
+ .print(collationPrinter)
+ .maybeReverse()
+ .fieldName("leftChild", ExplainVersion::V3)
+ .print(leftChildResult)
+ .fieldName("rightChild", ExplainVersion::V3)
+ .print(rightChildResult);
+ return printer;
+ }
+
+ ExplainPrinter transport(const UnionNode& node,
+ std::vector<ExplainPrinter> childResults,
+ ExplainPrinter bindResult,
+ ExplainPrinter /*refsResult*/) {
+ ExplainPrinter printer("Union");
+ maybePrintProps(printer, node);
+ printer.separator(" []")
+ .setChildCount(childResults.size() + 1)
+ .fieldName("bindings", ExplainVersion::V3)
+ .print(bindResult)
+ .maybeReverse()
+ .fieldName("children", ExplainVersion::V3)
+ .print(childResults);
+ return printer;
+ }
+
+ ExplainPrinter transport(const GroupByNode& node,
+ ExplainPrinter childResult,
+ ExplainPrinter bindAggResult,
+ ExplainPrinter refsAggResult,
+ ExplainPrinter bindGbResult,
+ ExplainPrinter refsGbResult) {
+ std::map<ProjectionName, size_t> ordered;
+ const ProjectionNameVector& aggProjectionNames = node.getAggregationProjectionNames();
+ for (size_t i = 0; i < aggProjectionNames.size(); i++) {
+ ordered.emplace(aggProjectionNames.at(i), i);
+ }
+
+ ExplainPrinter printer("GroupBy");
+ maybePrintProps(printer, node);
+ printer.separator(" [");
+ if (version >= ExplainVersion::V3 || node.getType() != GroupNodeType::Complete) {
+ printer.fieldName("type", ExplainVersion::V3)
+ .print(GroupNodeTypeEnum::toString[static_cast<int>(node.getType())]);
+ }
+ printer.separator("]");
+
+ std::vector<ExplainPrinter> aggPrinters;
+ for (const auto& [projectionName, index] : ordered) {
+ ExplainPrinter local;
+ local.separator("[")
+ .fieldName("projectionName", ExplainVersion::V3)
+ .print(projectionName)
+ .separator("]");
+ ExplainPrinter aggExpr = generate(node.getAggregationExpressions().at(index));
+ local.fieldName("aggregation", ExplainVersion::V3).print(aggExpr);
+ aggPrinters.push_back(std::move(local));
+ }
+
+ ExplainPrinter gbPrinter;
+ gbPrinter.fieldName("groupings").print(refsGbResult);
+
+ ExplainPrinter aggPrinter;
+ aggPrinter.fieldName("aggregations").print(aggPrinters);
+
+ printer.setChildCount(3)
+ .printAppend(gbPrinter)
+ .printAppend(aggPrinter)
+ .fieldName("child", ExplainVersion::V3)
+ .print(childResult);
+ return printer;
+ }
+
+ ExplainPrinter transport(const UnwindNode& node,
+ ExplainPrinter childResult,
+ ExplainPrinter bindResult,
+ ExplainPrinter refsResult) {
+ ExplainPrinter printer("Unwind");
+ maybePrintProps(printer, node);
+
+ printer.separator(" [");
+ printBooleanFlag(printer, "retainNonArrays", node.getRetainNonArrays(), false /*addComma*/);
+ printer.separator("]");
+
+ printer.setChildCount(2)
+ .fieldName("bind", ExplainVersion::V3)
+ .print(bindResult)
+ .fieldName("child", ExplainVersion::V3)
+ .print(childResult);
+ return printer;
+ }
+
+ static void printCollationProperty(ExplainPrinter& parent,
+ const properties::CollationRequirement& property,
+ const bool directToParent) {
+ std::vector<ExplainPrinter> propPrinters;
+ for (const auto& entry : property.getCollationSpec()) {
+ ExplainPrinter local;
+ local.fieldName("projectionName", ExplainVersion::V3)
+ .print(entry.first)
+ .separator(": ")
+ .fieldName("collationOp", ExplainVersion::V3)
+ .print(CollationOpEnum::toString[static_cast<int>(entry.second)]);
+ propPrinters.push_back(std::move(local));
+ }
+
+ printDirectToParentHelper(directToParent, parent, [&](ExplainPrinter& printer) {
+ printer.fieldName("collation").print(propPrinters);
+ });
+ }
+
+ ExplainPrinter transport(const UniqueNode& node,
+ ExplainPrinter childResult,
+ ExplainPrinter /*refsResult*/) {
+ ExplainPrinter printer("Unique");
+ maybePrintProps(printer, node);
+
+ printer.separator(" []").setChildCount(2);
+ printPropertyProjections(printer, node.getProjections(), false /*directToParent*/);
+ printer.fieldName("child", ExplainVersion::V3).print(childResult);
+
+ return printer;
+ }
+
+ ExplainPrinter transport(const CollationNode& node,
+ ExplainPrinter childResult,
+ ExplainPrinter refsResult) {
+ ExplainPrinter printer("Collation");
+ maybePrintProps(printer, node);
+
+ printer.separator(" []").setChildCount(3);
+ printCollationProperty(printer, node.getProperty(), false /*directToParent*/);
+ printer.fieldName("references", ExplainVersion::V3)
+ .print(refsResult)
+ .fieldName("child", ExplainVersion::V3)
+ .print(childResult);
+
+ return printer;
+ }
+
+ static void printLimitSkipProperty(ExplainPrinter& propPrinter,
+ ExplainPrinter& limitPrinter,
+ ExplainPrinter& skipPrinter,
+ const properties::LimitSkipRequirement& property) {
+ propPrinter.fieldName("propType", ExplainVersion::V3)
+ .print("limitSkip")
+ .separator(":")
+ .printAppend(limitPrinter)
+ .printAppend(skipPrinter);
+ }
+
+ static void printLimitSkipProperty(ExplainPrinter& parent,
+ const properties::LimitSkipRequirement& property,
+ const bool directToParent) {
+ ExplainPrinter limitPrinter;
+ limitPrinter.fieldName("limit");
+ if (property.hasLimit()) {
+ limitPrinter.print(property.getLimit());
+ } else {
+ limitPrinter.print("(none)");
+ }
+
+ ExplainPrinter skipPrinter;
+ skipPrinter.fieldName("skip").print(property.getSkip());
+
+ printDirectToParentHelper(directToParent, parent, [&](ExplainPrinter& printer) {
+ printLimitSkipProperty(printer, limitPrinter, skipPrinter, property);
+ });
+ }
+
+ ExplainPrinter transport(const LimitSkipNode& node, ExplainPrinter childResult) {
+ ExplainPrinter printer("LimitSkip");
+ maybePrintProps(printer, node);
+
+ printer.separator(" []").setChildCount(2);
+ printLimitSkipProperty(printer, node.getProperty(), false /*directToParent*/);
+ printer.fieldName("child", ExplainVersion::V3).print(childResult);
+
+ return printer;
+ }
+
+ static void printPropertyProjections(ExplainPrinter& parent,
+ const ProjectionNameVector& projections,
+ const bool directToParent) {
+ std::vector<ExplainPrinter> printers;
+ for (const ProjectionName& projection : projections) {
+ ExplainPrinter local;
+ local.print(projection);
+ printers.push_back(std::move(local));
+ }
+
+ printDirectToParentHelper(directToParent, parent, [&](ExplainPrinter& printer) {
+ printer.fieldName("projections");
+ if (printers.empty()) {
+ ExplainPrinter dummy;
+ printer.print(dummy);
+ } else {
+ printer.print(printers);
+ }
+ });
+ }
+
+ static void printDistributionProperty(ExplainPrinter& parent,
+ const properties::DistributionRequirement& property,
+ const bool directToParent) {
+ const auto& distribAndProjections = property.getDistributionAndProjections();
+
+ ExplainPrinter typePrinter;
+ typePrinter.fieldName("type").print(
+ DistributionTypeEnum::toString[static_cast<int>(distribAndProjections._type)]);
+
+ printBooleanFlag(typePrinter, "disableExchanges", property.getDisableExchanges());
+
+ const bool hasProjections = !distribAndProjections._projectionNames.empty();
+ ExplainPrinter projectionPrinter;
+ if (hasProjections) {
+ printPropertyProjections(
+ projectionPrinter, distribAndProjections._projectionNames, true /*directToParent*/);
+ typePrinter.printAppend(projectionPrinter);
+ }
+
+ printDirectToParentHelper(directToParent, parent, [&](ExplainPrinter& printer) {
+ printer.fieldName("distribution").print(typePrinter);
+ });
+ }
+
+ static void printProjectionRequirementProperty(
+ ExplainPrinter& parent,
+ const properties::ProjectionRequirement& property,
+ const bool directToParent) {
+ printPropertyProjections(parent, property.getProjections().getVector(), directToParent);
+ }
+
+ ExplainPrinter transport(const ExchangeNode& node,
+ ExplainPrinter childResult,
+ ExplainPrinter refsResult) {
+ ExplainPrinter printer("Exchange");
+ maybePrintProps(printer, node);
+
+ printer.separator(" []").setChildCount(3);
+ printDistributionProperty(printer, node.getProperty(), false /*directToParent*/);
+ printer.fieldName("references", ExplainVersion::V3)
+ .print(refsResult)
+ .fieldName("child", ExplainVersion::V3)
+ .print(childResult);
+
+ return printer;
+ }
+
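+ /**
+ * Property visitors used with printProps() below: each operator() renders
+ * one logical (respectively physical) property into the shared parent
+ * printer.
+ */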
+ struct LogicalPropPrintVisitor {
+ LogicalPropPrintVisitor(ExplainPrinter& parent) : _parent(parent) {}
+
+ void operator()(const properties::LogicalProperty&,
+ const properties::ProjectionAvailability& prop) {
+ ProjectionNameOrderedSet ordered;
+ for (const ProjectionName& projection : prop.getProjections()) {
+ ordered.insert(projection);
+ }
+
+ std::vector<ExplainPrinter> printers;
+ for (const ProjectionName& projection : ordered) {
+ ExplainPrinter local;
+ local.print(projection);
+ printers.push_back(std::move(local));
+ }
+ _parent.fieldName("projections").print(printers);
+ }
+
+ void operator()(const properties::LogicalProperty&,
+ const properties::CardinalityEstimate& prop) {
+ std::vector<ExplainPrinter> fieldPrinters;
+
+ ExplainPrinter cePrinter;
+ cePrinter.fieldName("ce").print(prop.getEstimate());
+ fieldPrinters.push_back(std::move(cePrinter));
+
+ if (!prop.getPartialSchemaKeyCEMap().empty()) {
+ std::vector<ExplainPrinter> reqPrinters;
+ for (const auto& [key, ce] : prop.getPartialSchemaKeyCEMap()) {
+ ExplainGeneratorTransporter<version> gen;
+ ExplainPrinter pathPrinter = gen.generate(key._path);
+
+ ExplainPrinter local;
+ local.fieldName("refProjection")
+ .print(key._projectionName)
+ .separator(", ")
+ .fieldName("path")
+ .separator("'")
+ .printSingleLevel(pathPrinter)
+ .separator("', ")
+ .fieldName("ce")
+ .print(ce);
+ reqPrinters.push_back(std::move(local));
+ }
+ ExplainPrinter requirementsPrinter;
+ requirementsPrinter.fieldName("requirementCEs").print(reqPrinters);
+ fieldPrinters.push_back(std::move(requirementsPrinter));
+ }
+
+ _parent.fieldName("cardinalityEstimate").print(fieldPrinters);
+ }
+
+ void operator()(const properties::LogicalProperty&,
+ const properties::IndexingAvailability& prop) {
+ ExplainPrinter printer;
+ printer.separator("[")
+ .fieldName("groupId")
+ .print(prop.getScanGroupId())
+ .separator(", ")
+ .fieldName("scanProjection")
+ .print(prop.getScanProjection())
+ .separator(", ")
+ .fieldName("scanDefName")
+ .print(prop.getScanDefName());
+ printBooleanFlag(printer, "possiblyEqPredsOnly", prop.getPossiblyEqPredsOnly());
+ printer.separator("]");
+
+ if (!prop.getSatisfiedPartialIndexes().empty()) {
+ std::set<std::string> ordered;
+ for (const auto& indexName : prop.getSatisfiedPartialIndexes()) {
+ ordered.insert(indexName);
+ }
+
+ std::vector<ExplainPrinter> printers;
+ for (const auto& indexName : ordered) {
+ ExplainPrinter local;
+ local.print(indexName);
+ printers.push_back(std::move(local));
+ }
+ printer.fieldName("satisfiedPartialIndexes").print(printers);
+ }
+
+ _parent.fieldName("indexingAvailability").print(printer);
+ }
+
+ void operator()(const properties::LogicalProperty&,
+ const properties::CollectionAvailability& prop) {
+ std::set<std::string> orderedSet;
+ for (const std::string& scanDef : prop.getScanDefSet()) {
+ orderedSet.insert(scanDef);
+ }
+
+ std::vector<ExplainPrinter> printers;
+ for (const std::string& scanDef : orderedSet) {
+ ExplainPrinter local;
+ local.print(scanDef);
+ printers.push_back(std::move(local));
+ }
+ if (printers.empty()) {
+ ExplainPrinter dummy;
+ printers.push_back(std::move(dummy));
+ }
+
+ _parent.fieldName("collectionAvailability").print(printers);
+ }
+
+ void operator()(const properties::LogicalProperty&,
+ const properties::DistributionAvailability& prop) {
+ struct Comparator {
+ bool operator()(const properties::DistributionRequirement& d1,
+ const properties::DistributionRequirement& d2) const {
+ const auto& distr1 = d1.getDistributionAndProjections();
+ const auto& distr2 = d2.getDistributionAndProjections();
+
+ if (distr1._type < distr2._type) {
+ return true;
+ }
+ if (distr1._type > distr2._type) {
+ return false;
+ }
+ return distr1._projectionNames < distr2._projectionNames;
+ }
+ };
+
+ std::set<properties::DistributionRequirement, Comparator> ordered;
+ for (const auto& distributionProp : prop.getDistributionSet()) {
+ ordered.insert(distributionProp);
+ }
+
+ std::vector<ExplainPrinter> printers;
+ for (const auto& distributionProp : ordered) {
+ ExplainPrinter local;
+ printDistributionProperty(local, distributionProp, true /*directToParent*/);
+ printers.push_back(std::move(local));
+ }
+ _parent.fieldName("distributionAvailability").print(printers);
+ }
+
+ private:
+ // We don't own this.
+ ExplainPrinter& _parent;
+ };
+
+ struct PhysPropPrintVisitor {
+ PhysPropPrintVisitor(ExplainPrinter& parent) : _parent(parent) {}
+
+ void operator()(const properties::PhysProperty&,
+ const properties::CollationRequirement& prop) {
+ printCollationProperty(_parent, prop, true /*directToParent*/);
+ }
+
+ void operator()(const properties::PhysProperty&,
+ const properties::LimitSkipRequirement& prop) {
+ printLimitSkipProperty(_parent, prop, true /*directToParent*/);
+ }
+
+ void operator()(const properties::PhysProperty&,
+ const properties::ProjectionRequirement& prop) {
+ printProjectionRequirementProperty(_parent, prop, true /*directToParent*/);
+ }
+
+ void operator()(const properties::PhysProperty&,
+ const properties::DistributionRequirement& prop) {
+ printDistributionProperty(_parent, prop, true /*directToParent*/);
+ }
+
+ void operator()(const properties::PhysProperty&,
+ const properties::IndexingRequirement& prop) {
+ ExplainPrinter printer;
+
+ printer.fieldName("target", ExplainVersion::V3)
+ .print(IndexReqTargetEnum::toString[static_cast<int>(prop.getIndexReqTarget())]);
+ printBooleanFlag(printer, "needsRID", prop.getNeedsRID());
+ printBooleanFlag(printer, "dedupRID", prop.getDedupRID());
+
+ // TODO: consider printing satisfied partial indexes.
+ _parent.fieldName("indexingRequirement").print(printer);
+ }
+
+ void operator()(const properties::PhysProperty&,
+ const properties::RepetitionEstimate& prop) {
+ _parent.fieldName("repetitionEstimate").print(prop.getEstimate());
+ }
+
+ void operator()(const properties::PhysProperty&, const properties::LimitEstimate& prop) {
+ _parent.fieldName("limitEstimate").print(prop.getEstimate());
+ }
+
+ private:
+ // We don't own this.
+ ExplainPrinter& _parent;
+ };
+
+ template <class P, class V, class C>
+ static ExplainPrinter printProps(const std::string& description, const C& props) {
+ ExplainPrinter printer;
+ if (version < ExplainVersion::V3) {
+ printer.print(description).print(":");
+ }
+
+ std::map<typename P::key_type, P> ordered;
+ for (const auto& entry : props) {
+ ordered.insert(entry);
+ }
+
+ ExplainPrinter local;
+ V visitor(local);
+ for (const auto& entry : ordered) {
+ entry.second.visit(visitor);
+ }
+ printer.print(local);
+
+ return printer;
+ }
+
+ static ExplainPrinter printLogicalProps(const std::string& description,
+ const properties::LogicalProps& props) {
+ return printProps<properties::LogicalProperty, LogicalPropPrintVisitor>(description, props);
+ }
+
+ static ExplainPrinter printPhysProps(const std::string& description,
+ const properties::PhysProps& props) {
+ return printProps<properties::PhysProperty, PhysPropPrintVisitor>(description, props);
+ }
+
+ ExplainPrinter transport(const RootNode& node,
+ ExplainPrinter childResult,
+ ExplainPrinter refsResult) {
+ ExplainPrinter printer("Root");
+ maybePrintProps(printer, node);
+
+ printer.separator(" []").setChildCount(3);
+ printProjectionRequirementProperty(printer, node.getProperty(), false /*directToParent*/);
+ printer.fieldName("references", ExplainVersion::V3)
+ .print(refsResult)
+ .fieldName("child", ExplainVersion::V3)
+ .print(childResult);
+
+ return printer;
+ }
+
+ /**
+ * Expressions
+ */
+ ExplainPrinter transport(const Blackhole& expr) {
+ ExplainPrinter printer("Blackhole");
+ printer.separator(" []");
+ return printer;
+ }
+
+ ExplainPrinter transport(const Constant& expr) {
+ ExplainPrinter printer("Const");
+ printer.separator(" [")
+ .fieldName("value", ExplainVersion::V3)
+ .print(expr.get())
+ .separator("]");
+ return printer;
+ }
+
+ ExplainPrinter transport(const Variable& expr) {
+ ExplainPrinter printer("Variable");
+ printer.separator(" [")
+ .fieldName("name", ExplainVersion::V3)
+ .print(expr.name())
+ .separator("]");
+ return printer;
+ }
+
+ ExplainPrinter transport(const UnaryOp& expr, ExplainPrinter inResult) {
+ ExplainPrinter printer("UnaryOp");
+ printer.separator(" [")
+ .fieldName("op", ExplainVersion::V3)
+ .print(OperationsEnum::toString[static_cast<int>(expr.op())])
+ .separator("]")
+ .setChildCount(1)
+ .fieldName("input", ExplainVersion::V3)
+ .print(inResult);
+ return printer;
+ }
+
+ ExplainPrinter transport(const BinaryOp& expr,
+ ExplainPrinter leftResult,
+ ExplainPrinter rightResult) {
+ ExplainPrinter printer("BinaryOp");
+ printer.separator(" [")
+ .fieldName("op", ExplainVersion::V3)
+ .print(OperationsEnum::toString[static_cast<int>(expr.op())])
+ .separator("]")
+ .setChildCount(2)
+ .maybeReverse()
+ .fieldName("left", ExplainVersion::V3)
+ .print(leftResult)
+ .fieldName("right", ExplainVersion::V3)
+ .print(rightResult);
+ return printer;
+ }
+
+ ExplainPrinter transport(const If& expr,
+ ExplainPrinter condResult,
+ ExplainPrinter thenResult,
+ ExplainPrinter elseResult) {
+ ExplainPrinter printer("If");
+ printer.separator(" []")
+ .setChildCount(3)
+ .maybeReverse()
+ .fieldName("condition", ExplainVersion::V3)
+ .print(condResult)
+ .fieldName("then", ExplainVersion::V3)
+ .print(thenResult)
+ .fieldName("else", ExplainVersion::V3)
+ .print(elseResult);
+ return printer;
+ }
+
+ ExplainPrinter transport(const Let& expr,
+ ExplainPrinter bindResult,
+ ExplainPrinter exprResult) {
+ ExplainPrinter printer("Let");
+ printer.separator(" [")
+ .fieldName("variable", ExplainVersion::V3)
+ .print(expr.varName())
+ .separator("]")
+ .setChildCount(2)
+ .maybeReverse()
+ .fieldName("bind", ExplainVersion::V3)
+ .print(bindResult)
+ .fieldName("expression", ExplainVersion::V3)
+ .print(exprResult);
+ return printer;
+ }
+
+ ExplainPrinter transport(const LambdaAbstraction& expr, ExplainPrinter inResult) {
+ ExplainPrinter printer("LambdaAbstraction");
+ printer.separator(" [")
+ .fieldName("variable", ExplainVersion::V3)
+ .print(expr.varName())
+ .separator("]")
+ .setChildCount(1)
+ .fieldName("input", ExplainVersion::V3)
+ .print(inResult);
+ return printer;
+ }
+
+ ExplainPrinter transport(const LambdaApplication& expr,
+ ExplainPrinter lambdaResult,
+ ExplainPrinter argumentResult) {
+ ExplainPrinter printer("LambdaApplication");
+ printer.separator(" []")
+ .setChildCount(2)
+ .maybeReverse()
+ .fieldName("lambda", ExplainVersion::V3)
+ .print(lambdaResult)
+ .fieldName("argument", ExplainVersion::V3)
+ .print(argumentResult);
+ return printer;
+ }
+
+ ExplainPrinter transport(const FunctionCall& expr, std::vector<ExplainPrinter> argResults) {
+ ExplainPrinter printer("FunctionCall");
+ printer.separator(" [")
+ .fieldName("name", ExplainVersion::V3)
+ .print(expr.name())
+ .separator("]");
+ if (!argResults.empty()) {
+ printer.setChildCount(argResults.size())
+ .maybeReverse()
+ .fieldName("arguments", ExplainVersion::V3)
+ .print(argResults);
+ }
+ return printer;
+ }
+
+ ExplainPrinter transport(const EvalPath& expr,
+ ExplainPrinter pathResult,
+ ExplainPrinter inputResult) {
+ ExplainPrinter printer("EvalPath");
+ printer.separator(" []")
+ .setChildCount(2)
+ .maybeReverse()
+ .fieldName("path", ExplainVersion::V3)
+ .print(pathResult)
+ .fieldName("input", ExplainVersion::V3)
+ .print(inputResult);
+ return printer;
+ }
+
+ ExplainPrinter transport(const EvalFilter& expr,
+ ExplainPrinter pathResult,
+ ExplainPrinter inputResult) {
+ ExplainPrinter printer("EvalFilter");
+ printer.separator(" []")
+ .setChildCount(2)
+ .maybeReverse()
+ .fieldName("path", ExplainVersion::V3)
+ .print(pathResult)
+ .fieldName("input", ExplainVersion::V3)
+ .print(inputResult);
+ return printer;
+ }
+
+ /**
+ * Paths
+ */
+ ExplainPrinter transport(const PathConstant& path, ExplainPrinter inResult) {
+ ExplainPrinter printer("PathConstant");
+ printer.separator(" []")
+ .setChildCount(1)
+ .fieldName("input", ExplainVersion::V3)
+ .print(inResult);
+ return printer;
+ }
+
+ ExplainPrinter transport(const PathLambda& path, ExplainPrinter inResult) {
+ ExplainPrinter printer("PathLambda");
+ printer.separator(" []")
+ .setChildCount(1)
+ .fieldName("input", ExplainVersion::V3)
+ .print(inResult);
+ return printer;
+ }
+
+ ExplainPrinter transport(const PathIdentity& path) {
+ ExplainPrinter printer("PathIdentity");
+ printer.separator(" []");
+ return printer;
+ }
+
+ ExplainPrinter transport(const PathDefault& path, ExplainPrinter inResult) {
+ ExplainPrinter printer("PathDefault");
+ printer.separator(" []")
+ .setChildCount(1)
+ .fieldName("input", ExplainVersion::V3)
+ .print(inResult);
+ return printer;
+ }
+
+ ExplainPrinter transport(const PathCompare& path, ExplainPrinter valueResult) {
+ ExplainPrinter printer("PathCompare");
+ printer.separator(" [")
+ .fieldName("op", ExplainVersion::V3)
+ .print(OperationsEnum::toString[static_cast<int>(path.op())])
+ .separator("]")
+ .setChildCount(1)
+ .fieldName("value", ExplainVersion::V3)
+ .print(valueResult);
+ return printer;
+ }
+
+ static void printPathProjections(ExplainPrinter& printer,
+ const opt::unordered_set<std::string>& names) {
+ std::set<std::string> ordered;
+ for (const std::string& s : names) {
+ ordered.insert(s);
+ }
+
+ if constexpr (version < ExplainVersion::V3) {
+ bool first = true;
+ for (const std::string& s : ordered) {
+ if (first) {
+ first = false;
+ } else {
+ printer.print(", ");
+ }
+ printer.print(s);
+ }
+ } else if constexpr (version == ExplainVersion::V3) {
+ std::vector<ExplainPrinter> printers;
+ for (const std::string& s : ordered) {
+ ExplainPrinter local;
+ local.print(s);
+ printers.push_back(std::move(local));
+ }
+ printer.fieldName("projections").print(printers);
+ } else {
+ static_assert(version <= ExplainVersion::V3, "Unknown version");
+ }
+ }
+
+ ExplainPrinter transport(const PathDrop& path) {
+ ExplainPrinter printer("PathDrop");
+ printer.separator(" [");
+ printPathProjections(printer, path.getNames());
+ printer.separator("]");
+ return printer;
+ }
+
+ ExplainPrinter transport(const PathKeep& path) {
+ ExplainPrinter printer("PathKeep");
+ printer.separator(" [");
+ printPathProjections(printer, path.getNames());
+ printer.separator("]");
+ return printer;
+ }
+
+ ExplainPrinter transport(const PathObj& path) {
+ ExplainPrinter printer("PathObj");
+ printer.separator(" []");
+ return printer;
+ }
+
+ ExplainPrinter transport(const PathArr& path) {
+ ExplainPrinter printer("PathArr");
+ printer.separator(" []");
+ return printer;
+ }
+
+ ExplainPrinter transport(const PathTraverse& path, ExplainPrinter inResult) {
+ ExplainPrinter printer("PathTraverse");
+ printer.separator(" []")
+ .setChildCount(1)
+ .fieldName("input", ExplainVersion::V3)
+ .print(inResult);
+ return printer;
+ }
+
+ ExplainPrinter transport(const PathField& path, ExplainPrinter inResult) {
+ ExplainPrinter printer("PathField");
+ printer.separator(" [")
+ .fieldName("path", ExplainVersion::V3)
+ .print(path.name())
+ .separator("]")
+ .setChildCount(1)
+ .fieldName("input", ExplainVersion::V3)
+ .print(inResult);
+ return printer;
+ }
+
+ ExplainPrinter transport(const PathGet& path, ExplainPrinter inResult) {
+ ExplainPrinter printer("PathGet");
+ printer.separator(" [")
+ .fieldName("path", ExplainVersion::V3)
+ .print(path.name())
+ .separator("]")
+ .setChildCount(1)
+ .fieldName("input", ExplainVersion::V3)
+ .print(inResult);
+ return printer;
+ }
+
+ ExplainPrinter transport(const PathComposeM& path,
+ ExplainPrinter leftResult,
+ ExplainPrinter rightResult) {
+ ExplainPrinter printer("PathComposeM");
+ printer.separator(" []")
+ .setChildCount(2)
+ .maybeReverse()
+ .fieldName("leftInput", ExplainVersion::V3)
+ .print(leftResult)
+ .fieldName("rightInput", ExplainVersion::V3)
+ .print(rightResult);
+ return printer;
+ }
+
+ ExplainPrinter transport(const PathComposeA& path,
+ ExplainPrinter leftResult,
+ ExplainPrinter rightResult) {
+ ExplainPrinter printer("PathComposeA");
+ printer.separator(" []")
+ .setChildCount(2)
+ .maybeReverse()
+ .fieldName("leftInput", ExplainVersion::V3)
+ .print(leftResult)
+ .fieldName("rightInput", ExplainVersion::V3)
+ .print(rightResult);
+ return printer;
+ }
+
+ ExplainPrinter transport(const Source& expr) {
+ ExplainPrinter printer("Source");
+ printer.separator(" []");
+ return printer;
+ }
+
+ ExplainPrinter generate(const ABT& node) {
+ return algebra::transport<false>(node, *this);
+ }
+
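+    // Prints the cost, local cost, and adjusted cardinality estimate of one physically optimized
+    // node, followed by the node itself.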
+ void printPhysNodeInfo(ExplainPrinter& printer, const cascades::PhysNodeInfo& nodeInfo) {
+ printer.fieldName("cost");
+ if (nodeInfo._cost.isInfinite()) {
+ printer.print(nodeInfo._cost.toString());
+ } else {
+ printer.print(nodeInfo._cost.getCost());
+ }
+ printer.separator(", ")
+ .fieldName("localCost")
+ .print(nodeInfo._localCost.getCost())
+ .separator(", ")
+ .fieldName("adjustedCE")
+ .print(nodeInfo._adjustedCE);
+
+ ExplainPrinter nodePrinter = generate(nodeInfo._node);
+ printer.separator(", ").fieldName("node").print(nodePrinter);
+ }
+
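+    // Prints the entire memo: for each group, its logical properties, its vector of logical
+    // nodes, and each physical optimization result (cost limit, physical properties, and the
+    // best plan along with any rejected alternatives).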
+ ExplainPrinter printMemo() {
+ std::vector<ExplainPrinter> groupPrinters;
+ for (size_t groupId = 0; groupId < _memo->getGroupCount(); groupId++) {
+ const cascades::Group& group = _memo->getGroup(groupId);
+
+ ExplainPrinter groupPrinter;
+ groupPrinter.fieldName("groupId").print(groupId).setChildCount(3);
+ {
+ ExplainPrinter logicalPropPrinter =
+ printLogicalProps("Logical properties", group._logicalProperties);
+ groupPrinter.fieldName("logicalProperties", ExplainVersion::V3)
+ .print(logicalPropPrinter);
+ }
+
+ {
+ std::vector<ExplainPrinter> logicalNodePrinters;
+ const ABTVector& logicalNodes = group._logicalNodes.getVector();
+ for (size_t i = 0; i < logicalNodes.size(); i++) {
+ ExplainPrinter local;
+ local.fieldName("logicalNodeId").print(i);
+ ExplainPrinter nodePrinter = generate(logicalNodes.at(i));
+ local.fieldName("node", ExplainVersion::V3).print(nodePrinter);
+
+ logicalNodePrinters.push_back(std::move(local));
+ }
+ ExplainPrinter logicalNodePrinter;
+ logicalNodePrinter.print(logicalNodePrinters);
+
+ groupPrinter.fieldName("logicalNodes").print(logicalNodePrinter);
+ }
+
+ {
+ std::vector<ExplainPrinter> physicalNodePrinters;
+ for (const auto& physOptResult : group._physicalNodes.getNodes()) {
+ ExplainPrinter local;
+ local.fieldName("physicalNodeId")
+ .print(physOptResult->_index)
+ .separator(", ")
+ .fieldName("costLimit");
+
+ if (physOptResult->_costLimit.isInfinite()) {
+ local.print(physOptResult->_costLimit.toString());
+ } else {
+ local.print(physOptResult->_costLimit.getCost());
+ }
+
+ ExplainPrinter propPrinter =
+ printPhysProps("Physical properties", physOptResult->_physProps);
+ local.fieldName("physicalProperties", ExplainVersion::V3).print(propPrinter);
+
+ if (physOptResult->_nodeInfo) {
+ ExplainPrinter local1;
+ printPhysNodeInfo(local1, *physOptResult->_nodeInfo);
+
+ if (!physOptResult->_rejectedNodeInfo.empty()) {
+ std::vector<ExplainPrinter> rejectedPrinters;
+ for (const auto& rejectedPlan : physOptResult->_rejectedNodeInfo) {
+ ExplainPrinter local2;
+ printPhysNodeInfo(local2, rejectedPlan);
+ rejectedPrinters.emplace_back(std::move(local2));
+ }
+ local1.fieldName("rejectedPlans").print(rejectedPrinters);
+ }
+
+ local.fieldName("nodeInfo", ExplainVersion::V3).print(local1);
+ } else {
+ local.separator(" (failed to optimize)");
+ }
+
+ physicalNodePrinters.push_back(std::move(local));
+ }
+ ExplainPrinter physNodePrinter;
+ physNodePrinter.print(physicalNodePrinters);
+
+ groupPrinter.fieldName("physicalNodes").print(physNodePrinter);
+ }
+
+ groupPrinters.push_back(std::move(groupPrinter));
+ }
+
+ ExplainPrinter printer;
+ printer.fieldName("Memo").print(groupPrinters);
+ return printer;
+ }
+
+private:
+ const bool _displayProperties;
+
+ // We don't own this.
+ const cascades::Memo* _memo;
+ const NodeToGroupPropsMap& _nodeMap;
+};
+
+std::string ExplainGenerator::explain(const ABT& node,
+ const bool displayProperties,
+ const cascades::Memo* memo,
+ const NodeToGroupPropsMap& nodeMap) {
+ ExplainGeneratorTransporter gen(displayProperties, memo, nodeMap);
+ return gen.generate(node).str();
+}
+
+std::string ExplainGenerator::explainV2(const ABT& node,
+ const bool displayProperties,
+ const cascades::Memo* memo,
+ const NodeToGroupPropsMap& nodeMap) {
+ ExplainGeneratorTransporter<ExplainVersion::V2> gen(displayProperties, memo, nodeMap);
+ return gen.generate(node).str();
+}
+
+std::pair<sbe::value::TypeTags, sbe::value::Value> ExplainGenerator::explainBSON(
+ const ABT& node,
+ const bool displayProperties,
+ const cascades::Memo* memo,
+ const NodeToGroupPropsMap& nodeMap) {
+ ExplainGeneratorTransporter<ExplainVersion::V3> gen(displayProperties, memo, nodeMap);
+ return gen.generate(node).moveValue();
+}
+
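+// Recursively prints an SBE value: arrays and objects are traversed element by element, while all
+// other type tags are delegated to the value's stream operator.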
+template <class PrinterType>
+static void printBSONstr(PrinterType& printer,
+ const sbe::value::TypeTags tag,
+ const sbe::value::Value val) {
+ switch (tag) {
+ case sbe::value::TypeTags::Array: {
+ const auto* array = sbe::value::getArrayView(val);
+
+ PrinterType local;
+ for (size_t index = 0; index < array->size(); index++) {
+ if (index > 0) {
+ local.print(", ");
+ local.newLine();
+ }
+ const auto [tag1, val1] = array->getAt(index);
+ printBSONstr(local, tag1, val1);
+ }
+ printer.print("[").print(local).print("]");
+
+ break;
+ }
+
+ case sbe::value::TypeTags::Object: {
+ const auto* obj = sbe::value::getObjectView(val);
+
+ PrinterType local;
+ for (size_t index = 0; index < obj->size(); index++) {
+ if (index > 0) {
+ local.print(", ");
+ local.newLine();
+ }
+ local.fieldName(obj->field(index));
+ const auto [tag1, val1] = obj->getAt(index);
+ printBSONstr(local, tag1, val1);
+ }
+ printer.print("{").print(local).print("}");
+
+ break;
+ }
+
+ default: {
+ std::ostringstream os;
+ os << std::make_pair(tag, val);
+ printer.print(os.str());
+ }
+ }
+}
+
+std::string ExplainGenerator::printBSON(const sbe::value::TypeTags tag,
+ const sbe::value::Value val) {
+ ExplainPrinterImpl<ExplainVersion::V2> printer;
+ printBSONstr(printer, tag, val);
+ return printer.str();
+}
+
+std::string ExplainGenerator::explainLogicalProps(const std::string& description,
+ const properties::LogicalProps& props) {
+ return ExplainGeneratorTransporter<ExplainVersion::V2>::printLogicalProps(description, props)
+ .str();
+}
+
+std::string ExplainGenerator::explainPhysProps(const std::string& description,
+ const properties::PhysProps& props) {
+ return ExplainGeneratorTransporter<ExplainVersion::V2>::printPhysProps(description, props)
+ .str();
+}
+
+std::string ExplainGenerator::explainMemo(const cascades::Memo& memo) {
+ ExplainGeneratorTransporter<ExplainVersion::V2> gen(false /*displayProperties*/, &memo);
+ return gen.printMemo().str();
+}
+
+std::pair<sbe::value::TypeTags, sbe::value::Value> ExplainGenerator::explainMemoBSON(
+ const cascades::Memo& memo) {
+ ExplainGeneratorTransporter<ExplainVersion::V3> gen(false /*displayProperties*/, &memo);
+ return gen.printMemo().moveValue();
+}
+
+std::string ExplainGenerator::explainPartialSchemaReqMap(const PartialSchemaRequirements& reqMap) {
+ ExplainGeneratorTransporter<ExplainVersion::V2> gen;
+ ExplainGeneratorTransporter<ExplainVersion::V2>::ExplainPrinter result;
+ gen.printPartialSchemaReqMap(result, reqMap);
+ return result.str();
+}
+
+std::string ExplainGenerator::explainInterval(const IntervalRequirement& interval) {
+ ExplainGeneratorTransporter<ExplainVersion::V2> gen;
+ return gen.printInterval(interval);
+}
+
+std::string ExplainGenerator::explainIntervalExpr(const IntervalReqExpr::Node& intervalExpr) {
+ ExplainGeneratorTransporter<ExplainVersion::V2> gen;
+ return gen.printIntervalExpr(intervalExpr).str();
+}
+
+std::string _printNode(const ABT& node) {
+ if (node.empty()) {
+ return "Empty\n";
+ }
+ return ExplainGenerator::explainV2(node);
+}
+
+std::string _printInterval(const IntervalRequirement& interval) {
+ return ExplainGenerator::explainInterval(interval);
+}
+
+std::string _printLogicalProps(const properties::LogicalProps& props) {
+ return ExplainGenerator::explainLogicalProps("Logical Properties", props);
+}
+
+std::string _printPhysProps(const properties::PhysProps& props) {
+ return ExplainGenerator::explainPhysProps("Physical Properties", props);
+}
+
+} // namespace mongo::optimizer
diff --git a/src/mongo/db/query/optimizer/explain.h b/src/mongo/db/query/optimizer/explain.h
new file mode 100644
index 00000000000..970d912cdbc
--- /dev/null
+++ b/src/mongo/db/query/optimizer/explain.h
@@ -0,0 +1,113 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include "mongo/db/exec/sbe/values/value.h"
+#include "mongo/db/query/optimizer/explain_interface.h"
+#include "mongo/db/query/optimizer/metadata.h"
+#include "mongo/db/query/optimizer/node_defs.h"
+#include "mongo/db/query/optimizer/props.h"
+#include "mongo/db/query/optimizer/syntax/syntax.h"
+
+namespace mongo::optimizer {
+
+namespace cascades {
+class Memo;
+}
+
+/**
+ * This structure holds any data that is required by the BSON version of explain. It is
+ * self-sufficient and separate because it must outlive the other optimizer state as it is used by
+ * the runtime plan executor.
+ */
+class ABTPrinter : public AbstractABTPrinter {
+public:
+ ABTPrinter(ABT abtTree, NodeToGroupPropsMap nodeToPropsMap)
+ : _abtTree(std::move(abtTree)), _nodeToPropsMap(std::move(nodeToPropsMap)) {}
+
+ BSONObj explainBSON() const override final;
+
+private:
+ ABT _abtTree;
+ NodeToGroupPropsMap _nodeToPropsMap;
+};
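+
+// A minimal usage sketch (illustrative names): the optimized ABT and its node properties are
+// moved in, so that explain can outlive the rest of the optimizer state.
+//   ABTPrinter printer(std::move(optimized), std::move(nodeToPropsMap));
+//   BSONObj explained = printer.explainBSON();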
+
+class ExplainGenerator {
+public:
+    // Optionally display logical and physical properties using the memo
+    // whenever memo delegators are printed.
+ static std::string explain(const ABT& node,
+ bool displayProperties = false,
+ const cascades::Memo* memo = nullptr,
+ const NodeToGroupPropsMap& nodeMap = {});
+
+    // Optionally display logical and physical properties using the memo
+    // whenever memo delegators are printed.
+ static std::string explainV2(const ABT& node,
+ bool displayProperties = false,
+ const cascades::Memo* memo = nullptr,
+ const NodeToGroupPropsMap& nodeMap = {});
+
+ static std::pair<sbe::value::TypeTags, sbe::value::Value> explainBSON(
+ const ABT& node,
+ bool displayProperties = false,
+ const cascades::Memo* memo = nullptr,
+ const NodeToGroupPropsMap& nodeMap = {});
+
+ static std::string printBSON(sbe::value::TypeTags tag, sbe::value::Value val);
+
+ static std::string explainLogicalProps(const std::string& description,
+ const properties::LogicalProps& props);
+ static std::string explainPhysProps(const std::string& description,
+ const properties::PhysProps& props);
+
+ static std::string explainMemo(const cascades::Memo& memo);
+
+ static std::pair<sbe::value::TypeTags, sbe::value::Value> explainMemoBSON(
+ const cascades::Memo& memo);
+
+ static std::string explainPartialSchemaReqMap(const PartialSchemaRequirements& reqMap);
+
+ static std::string explainInterval(const IntervalRequirement& interval);
+
+ static std::string explainIntervalExpr(const IntervalReqExpr::Node& intervalExpr);
+};
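+
+// A minimal usage sketch (illustrative; any valid ABT works):
+//   ABT tree = make<ScanNode>("p0", "coll");
+//   std::string text = ExplainGenerator::explainV2(tree);
+//   auto [tag, val] = ExplainGenerator::explainBSON(tree);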
+
+// Functions used by GDB for pretty-printing. For whatever reason GDB cannot find the static
+// methods of ExplainGenerator, while it can find the functions below.
+
+std::string _printNode(const ABT& node);
+
+std::string _printInterval(const IntervalRequirement& interval);
+
+std::string _printLogicalProps(const properties::LogicalProps& props);
+std::string _printPhysProps(const properties::PhysProps& props);
+
+} // namespace mongo::optimizer
diff --git a/src/mongo/db/query/optimizer/explain_interface.h b/src/mongo/db/query/optimizer/explain_interface.h
new file mode 100644
index 00000000000..a007d1c2b49
--- /dev/null
+++ b/src/mongo/db/query/optimizer/explain_interface.h
@@ -0,0 +1,48 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include "mongo/bson/bsonobj.h"
+
+
+namespace mongo::optimizer {
+
+/**
+ * Used to abstract away the optimizer's explain output.
+ * Note: this header should not depend on anything else (certainly not on the rest of the
+ * optimizer).
+ */
+class AbstractABTPrinter {
+public:
+ virtual ~AbstractABTPrinter() = default;
+
+ virtual BSONObj explainBSON() const = 0;
+};
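+
+// ABTPrinter (declared in explain.h) is the concrete implementation which the runtime plan
+// executor uses.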
+
+} // namespace mongo::optimizer
diff --git a/src/mongo/db/query/optimizer/index_bounds.cpp b/src/mongo/db/query/optimizer/index_bounds.cpp
new file mode 100644
index 00000000000..5c2ff9fd66b
--- /dev/null
+++ b/src/mongo/db/query/optimizer/index_bounds.cpp
@@ -0,0 +1,246 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/db/query/optimizer/index_bounds.h"
+
+#include "mongo/db/query/optimizer/node.h"
+
+
+namespace mongo::optimizer {
+
+BoundRequirement::BoundRequirement() : _inclusive(false), _bound() {}
+
+BoundRequirement::BoundRequirement(bool inclusive, boost::optional<ABT> bound)
+ : _inclusive(inclusive), _bound(std::move(bound)) {
+ uassert(6624077, "Infinite bound cannot be inclusive", !inclusive || !isInfinite());
+}
+
+bool BoundRequirement::operator==(const BoundRequirement& other) const {
+ return _inclusive == other._inclusive && _bound == other._bound;
+}
+
+bool BoundRequirement::isInclusive() const {
+ return _inclusive;
+}
+
+void BoundRequirement::setInclusive(bool value) {
+ _inclusive = value;
+}
+
+bool BoundRequirement::isInfinite() const {
+ return !_bound.has_value();
+}
+
+const ABT& BoundRequirement::getBound() const {
+ uassert(6624078, "Cannot retrieve infinite bound", !isInfinite());
+ return _bound.get();
+}
+
+IntervalRequirement::IntervalRequirement(BoundRequirement lowBound, BoundRequirement highBound)
+ : _lowBound(std::move(lowBound)), _highBound(std::move(highBound)) {}
+
+bool IntervalRequirement::operator==(const IntervalRequirement& other) const {
+ return _lowBound == other._lowBound && _highBound == other._highBound;
+}
+
+bool IntervalRequirement::isFullyOpen() const {
+ return _lowBound.isInfinite() && _highBound.isInfinite();
+}
+
+bool IntervalRequirement::isEquality() const {
+ return _lowBound.isInclusive() && _highBound.isInclusive() && _lowBound == _highBound;
+}
+
+const BoundRequirement& IntervalRequirement::getLowBound() const {
+ return _lowBound;
+}
+
+BoundRequirement& IntervalRequirement::getLowBound() {
+ return _lowBound;
+}
+
+const BoundRequirement& IntervalRequirement::getHighBound() const {
+ return _highBound;
+}
+
+BoundRequirement& IntervalRequirement::getHighBound() {
+ return _highBound;
+}
+
+PartialSchemaKey::PartialSchemaKey() : PartialSchemaKey({}, make<PathIdentity>()) {}
+
+PartialSchemaKey::PartialSchemaKey(ProjectionName projectionName)
+ : PartialSchemaKey(std::move(projectionName), make<PathIdentity>()) {}
+
+PartialSchemaKey::PartialSchemaKey(ProjectionName projectionName, ABT path)
+ : _projectionName(std::move(projectionName)), _path(std::move(path)) {
+ assertPathSort(_path);
+}
+
+bool PartialSchemaKey::operator==(const PartialSchemaKey& other) const {
+ return _projectionName == other._projectionName && _path == other._path;
+}
+
+bool PartialSchemaKey::emptyPath() const {
+ return _path.is<PathIdentity>();
+}
+
+bool isIntervalReqFullyOpenDNF(const IntervalReqExpr::Node& n) {
+ if (auto singular = IntervalReqExpr::getSingularDNF(n); singular && singular->isFullyOpen()) {
+ return true;
+ }
+ return false;
+}
+
+PartialSchemaRequirement::PartialSchemaRequirement()
+ : _intervals(IntervalReqExpr::makeSingularDNF()) {}
+
+PartialSchemaRequirement::PartialSchemaRequirement(ProjectionName boundProjectionName,
+ IntervalReqExpr::Node intervals)
+ : _boundProjectionName(std::move(boundProjectionName)), _intervals(std::move(intervals)) {}
+
+bool PartialSchemaRequirement::operator==(const PartialSchemaRequirement& other) const {
+ return _boundProjectionName == other._boundProjectionName && _intervals == other._intervals;
+}
+
+bool PartialSchemaRequirement::hasBoundProjectionName() const {
+ return !_boundProjectionName.empty();
+}
+
+const ProjectionName& PartialSchemaRequirement::getBoundProjectionName() const {
+ return _boundProjectionName;
+}
+
+void PartialSchemaRequirement::setBoundProjectionName(ProjectionName boundProjectionName) {
+ _boundProjectionName = std::move(boundProjectionName);
+}
+
+const IntervalReqExpr::Node& PartialSchemaRequirement::getIntervals() const {
+ return _intervals;
+}
+
+IntervalReqExpr::Node& PartialSchemaRequirement::getIntervals() {
+ return _intervals;
+}
+
+/**
+ * Helper class used to perform a three-way comparison of path ABTs; it is used to order
+ * PartialSchemaKey objects.
+ */
+class Path3WCompare {
+public:
+    Path3WCompare() = default;
+
+ int compareTags(const ABT& n, const ABT& other) {
+ const auto t1 = n.tagOf();
+ const auto t2 = other.tagOf();
+ return (t1 == t2) ? 0 : ((t1 < t2) ? -1 : 1);
+ }
+
+ int operator()(const ABT& n, const PathGet& node, const ABT& other) {
+ if (auto otherGet = other.cast<PathGet>(); otherGet != nullptr) {
+ const int varCmp = node.name().compare(otherGet->name());
+ return (varCmp == 0) ? node.getPath().visit(*this, otherGet->getPath()) : varCmp;
+ }
+ return compareTags(n, other);
+ }
+
+ int operator()(const ABT& n, const PathTraverse& node, const ABT& other) {
+ if (auto otherTraverse = other.cast<PathTraverse>(); otherTraverse != nullptr) {
+ return node.getPath().visit(*this, otherTraverse->getPath());
+ }
+ return compareTags(n, other);
+ }
+
+ int operator()(const ABT& n, const PathIdentity& node, const ABT& other) {
+ return compareTags(n, other);
+ }
+
+ template <typename T, typename... Ts>
+ int operator()(const ABT& /*n*/, const T& /*node*/, Ts&&...) {
+ uasserted(6624079, "Unexpected node type");
+ return 0;
+ }
+
+ static int compare(const ABT& node, const ABT& other) {
+ Path3WCompare instance;
+ return node.visit(instance, other);
+ }
+};
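+
+// For example, PathGet ["a"] is ordered before PathGet ["b"] (by field name, then recursively by
+// the inner path), and nodes of different types are ordered by their type tags.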
+
+bool PartialSchemaKeyLessComparator::operator()(const PartialSchemaKey& k1,
+ const PartialSchemaKey& k2) const {
+ const int projCmp = k1._projectionName.compare(k2._projectionName);
+ if (projCmp != 0) {
+ return projCmp < 0;
+ }
+ return Path3WCompare::compare(k1._path, k2._path) < 0;
+}
+
+ResidualRequirement::ResidualRequirement(PartialSchemaKey key,
+ PartialSchemaRequirement req,
+ CEType ce)
+ : _key(std::move(key)), _req(std::move(req)), _ce(ce) {}
+
+bool CandidateIndexEntry::operator==(const CandidateIndexEntry& other) const {
+ return _fieldProjectionMap == other._fieldProjectionMap && _intervals == other._intervals &&
+ _residualRequirements == other._residualRequirements &&
+ _fieldsToCollate == other._fieldsToCollate;
+}
+
+IndexSpecification::IndexSpecification(std::string scanDefName,
+ std::string indexDefName,
+ MultiKeyIntervalRequirement interval,
+ bool reverseOrder)
+ : _scanDefName(std::move(scanDefName)),
+ _indexDefName(std::move(indexDefName)),
+ _interval(std::move(interval)),
+ _reverseOrder(reverseOrder) {}
+
+bool IndexSpecification::operator==(const IndexSpecification& other) const {
+ return _scanDefName == other._scanDefName && _indexDefName == other._indexDefName &&
+ _interval == other._interval && _reverseOrder == other._reverseOrder;
+}
+
+const std::string& IndexSpecification::getScanDefName() const {
+ return _scanDefName;
+}
+
+const std::string& IndexSpecification::getIndexDefName() const {
+ return _indexDefName;
+}
+
+const MultiKeyIntervalRequirement& IndexSpecification::getInterval() const {
+ return _interval;
+}
+
+bool IndexSpecification::isReverseOrder() const {
+ return _reverseOrder;
+}
+
+} // namespace mongo::optimizer
diff --git a/src/mongo/db/query/optimizer/index_bounds.h b/src/mongo/db/query/optimizer/index_bounds.h
new file mode 100644
index 00000000000..f172c361292
--- /dev/null
+++ b/src/mongo/db/query/optimizer/index_bounds.h
@@ -0,0 +1,206 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include "mongo/db/query/optimizer/bool_expression.h"
+#include "mongo/db/query/optimizer/defs.h"
+#include "mongo/db/query/optimizer/syntax/syntax.h"
+
+
+namespace mongo::optimizer {
+
+class BoundRequirement {
+public:
+ static BoundRequirement makeInfinite() {
+ return BoundRequirement(false, boost::none);
+ }
+
+ BoundRequirement();
+ BoundRequirement(bool inclusive, boost::optional<ABT> bound);
+
+ bool operator==(const BoundRequirement& other) const;
+
+ bool isInclusive() const;
+ void setInclusive(bool value);
+
+ bool isInfinite() const;
+ const ABT& getBound() const;
+
+private:
+ bool _inclusive;
+
+ // If we do not have a bound ABT, the bound is considered infinite.
+ boost::optional<ABT> _bound;
+};
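+
+// A minimal sketch (illustrative): an inclusive constant bound versus an infinite bound.
+//   BoundRequirement closed(true /*inclusive*/, Constant::int32(5));
+//   BoundRequirement open = BoundRequirement::makeInfinite();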
+
+class IntervalRequirement {
+public:
+ IntervalRequirement() = default;
+ IntervalRequirement(BoundRequirement lowBound, BoundRequirement highBound);
+
+ bool operator==(const IntervalRequirement& other) const;
+
+ bool isFullyOpen() const;
+ bool isEquality() const;
+
+ const BoundRequirement& getLowBound() const;
+ BoundRequirement& getLowBound();
+ const BoundRequirement& getHighBound() const;
+ BoundRequirement& getHighBound();
+
+private:
+ BoundRequirement _lowBound;
+ BoundRequirement _highBound;
+};
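+
+// A minimal sketch (illustrative): the interval [5, +inf) expressed via its two bounds.
+//   IntervalRequirement geq5{BoundRequirement(true /*inclusive*/, Constant::int32(5)),
+//                            BoundRequirement::makeInfinite()};
+//   invariant(!geq5.isFullyOpen() && !geq5.isEquality());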
+
+struct PartialSchemaKey {
+ PartialSchemaKey();
+ PartialSchemaKey(ProjectionName projectionName);
+ PartialSchemaKey(ProjectionName projectionName, ABT path);
+
+ bool operator==(const PartialSchemaKey& other) const;
+
+ bool emptyPath() const;
+
+ // Referred, or input projection name.
+ ProjectionName _projectionName;
+
+ // (Partially determined) path.
+ ABT _path;
+};
+
+using IntervalReqExpr = BoolExpr<IntervalRequirement>;
+bool isIntervalReqFullyOpenDNF(const IntervalReqExpr::Node& n);
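+
+// A minimal sketch (illustrative), mirroring how the unit tests build a singular DNF interval:
+//   auto dnf = IntervalReqExpr::make<IntervalReqExpr::Disjunction>(IntervalReqExpr::NodeVector{
+//       IntervalReqExpr::make<IntervalReqExpr::Conjunction>(IntervalReqExpr::NodeVector{
+//           IntervalReqExpr::make<IntervalReqExpr::Atom>(IntervalRequirement{})})});
+//   invariant(isIntervalReqFullyOpenDNF(dnf));  // A default interval is (-inf, +inf).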
+
+class PartialSchemaRequirement {
+public:
+ PartialSchemaRequirement();
+ PartialSchemaRequirement(ProjectionName boundProjectionName, IntervalReqExpr::Node intervals);
+
+ bool operator==(const PartialSchemaRequirement& other) const;
+
+ bool hasBoundProjectionName() const;
+ const ProjectionName& getBoundProjectionName() const;
+ void setBoundProjectionName(ProjectionName boundProjectionName);
+
+ const IntervalReqExpr::Node& getIntervals() const;
+ IntervalReqExpr::Node& getIntervals();
+
+private:
+ // Bound, or output projection name.
+ ProjectionName _boundProjectionName;
+
+ IntervalReqExpr::Node _intervals;
+};
+
+struct PartialSchemaKeyLessComparator {
+ bool operator()(const PartialSchemaKey& k1, const PartialSchemaKey& k2) const;
+};
+
+// Map from a partial schema key (referred, or input, projection name plus path) to the
+// requirement imposed on that key.
+using PartialSchemaRequirements =
+ std::map<PartialSchemaKey, PartialSchemaRequirement, PartialSchemaKeyLessComparator>;
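+
+// For example, a $match predicate such as {a: {$gt: 5}} over the scan projection would (roughly)
+// map the key (scanProjection, PathGet "a" -> PathTraverse -> PathIdentity) to a requirement
+// whose interval is (Const [5], +inf).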
+
+using PartialSchemaKeyCE = std::map<PartialSchemaKey, CEType, PartialSchemaKeyLessComparator>;
+using ResidualKeyMap = std::map<PartialSchemaKey, PartialSchemaKey, PartialSchemaKeyLessComparator>;
+
+using PartialSchemaKeySet = std::set<PartialSchemaKey, PartialSchemaKeyLessComparator>;
+
+// Requirements which are not satisfied directly by an IndexScan, PhysicalScan, or Seek (that is,
+// via an index field or a scan field). They are intended to be sorted in their containing vector
+// from most to least selective.
+struct ResidualRequirement {
+ ResidualRequirement(PartialSchemaKey key, PartialSchemaRequirement req, CEType ce);
+
+ PartialSchemaKey _key;
+ PartialSchemaRequirement _req;
+ CEType _ce;
+};
+using ResidualRequirements = std::vector<ResidualRequirement>;
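+
+// For example, a predicate over a field which is not among the keys of the chosen index cannot be
+// answered by the IndexScan itself; it is kept as a residual requirement and applied as a filter
+// over the fetched documents.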
+
+// A sequence of intervals, one for each index key.
+using MultiKeyIntervalRequirement = std::vector<IntervalRequirement>;
+
+// A multi-key interval expression represents unions and conjunctions of individual multi-key
+// intervals.
+using MultiKeyIntervalReqExpr = BoolExpr<MultiKeyIntervalRequirement>;
+
+// Used to pre-compute candidate indexes for SargableNodes.
+struct CandidateIndexEntry {
+ bool operator==(const CandidateIndexEntry& other) const;
+
+ FieldProjectionMap _fieldProjectionMap;
+ MultiKeyIntervalReqExpr::Node _intervals;
+
+ PartialSchemaRequirements _residualRequirements;
+ // Projections needed to evaluate residual requirements.
+ ProjectionNameSet _residualRequirementsTempProjections;
+
+ // Used for CE. Mapping for residual requirement key to query key.
+ ResidualKeyMap _residualKeyMap;
+
+    // We have equalities on those index fields, and thus do not consider them for collation
+    // requirements.
+ // TODO: consider a bitset.
+ opt::unordered_set<size_t> _fieldsToCollate;
+};
+
+using CandidateIndexMap = opt::unordered_map<std::string /*index name*/, CandidateIndexEntry>;
+
+class IndexSpecification {
+public:
+ IndexSpecification(std::string scanDefName,
+ std::string indexDefName,
+ MultiKeyIntervalRequirement interval,
+ bool reverseOrder);
+
+ bool operator==(const IndexSpecification& other) const;
+
+ const std::string& getScanDefName() const;
+ const std::string& getIndexDefName() const;
+
+ const MultiKeyIntervalRequirement& getInterval() const;
+
+ bool isReverseOrder() const;
+
+private:
+ // Name of the collection.
+ const std::string _scanDefName;
+
+ // The name of the index.
+ const std::string _indexDefName;
+
+ // The index interval.
+ MultiKeyIntervalRequirement _interval;
+
+    // Whether to reverse the index scan order.
+ const bool _reverseOrder;
+};
+
+} // namespace mongo::optimizer
diff --git a/src/mongo/db/query/optimizer/interval_intersection_test.cpp b/src/mongo/db/query/optimizer/interval_intersection_test.cpp
new file mode 100644
index 00000000000..7593e199846
--- /dev/null
+++ b/src/mongo/db/query/optimizer/interval_intersection_test.cpp
@@ -0,0 +1,667 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include <string>
+#include <vector>
+
+#include "mongo/db/query/optimizer/explain.h"
+#include "mongo/db/query/optimizer/metadata.h"
+#include "mongo/db/query/optimizer/opt_phase_manager.h"
+#include "mongo/db/query/optimizer/utils/interval_utils.h"
+#include "mongo/db/query/optimizer/utils/unit_test_utils.h"
+#include "mongo/platform/atomic_word.h"
+#include "mongo/unittest/unittest.h"
+#include "mongo/util/processinfo.h"
+
+namespace mongo::optimizer {
+namespace {
+
+struct QueryTest {
+ std::string query;
+ std::string expectedPlan;
+};
+
+std::string optimizedQueryPlan(const std::string& query,
+ const opt::unordered_map<std::string, IndexDefinition>& indexes) {
+ PrefixId prefixId;
+ std::string scanDefName = "coll";
+ Metadata metadata = {{{scanDefName, ScanDefinition{{}, indexes}}}};
+ ABT translated =
+ translatePipeline(metadata, "[{$match: " + query + "}]", scanDefName, prefixId);
+
+ OptPhaseManager phaseManager(
+ OptPhaseManager::getAllRewritesSet(), prefixId, metadata, DebugInfo::kDefaultForTests);
+
+ ABT optimized = translated;
+ ASSERT_TRUE(phaseManager.optimize(optimized));
+ return ExplainGenerator::explainV2(optimized);
+}
+
+TEST(IntervalIntersection, SingleFieldIntersection) {
+ opt::unordered_map<std::string, IndexDefinition> testIndex = {
+ {"index1", makeIndexDefinition("a0", CollationOp::Ascending, /*Not multikey*/ false)}};
+
+ std::vector<QueryTest> queryTests = {
+ {"{a0: {$gt:14, $lt:21}}",
+ "Root []\n"
+ "| | projections: \n"
+ "| | scan_0\n"
+ "| RefBlock: \n"
+ "| Variable [scan_0]\n"
+ "BinaryJoin [joinType: Inner, {rid_0}]\n"
+ "| | Const [true]\n"
+ "| LimitSkip []\n"
+ "| | limitSkip:\n"
+ "| | limit: 1\n"
+ "| | skip: 0\n"
+ "| Seek [ridProjection: rid_0, {'<root>': scan_0}, coll]\n"
+ "| | BindBlock:\n"
+ "| | [scan_0]\n"
+ "| | Source []\n"
+ "| RefBlock: \n"
+ "| Variable [rid_0]\n"
+ "IndexScan [{'<rid>': rid_0}, scanDefName: coll, indexDefName: index1, interval: "
+ "{(Const [14], Const [21])}]\n"
+ " BindBlock:\n"
+ " [rid_0]\n"
+ " Source []\n"},
+ {"{$and: [{a0: {$gt:14}}, {a0: {$lt: 21}}]}",
+ "Root []\n"
+ "| | projections: \n"
+ "| | scan_0\n"
+ "| RefBlock: \n"
+ "| Variable [scan_0]\n"
+ "BinaryJoin [joinType: Inner, {rid_0}]\n"
+ "| | Const [true]\n"
+ "| LimitSkip []\n"
+ "| | limitSkip:\n"
+ "| | limit: 1\n"
+ "| | skip: 0\n"
+ "| Seek [ridProjection: rid_0, {'<root>': scan_0}, coll]\n"
+ "| | BindBlock:\n"
+ "| | [scan_0]\n"
+ "| | Source []\n"
+ "| RefBlock: \n"
+ "| Variable [rid_0]\n"
+ "IndexScan [{'<rid>': rid_0}, scanDefName: coll, indexDefName: index1, interval: "
+ "{(Const [14], Const [21])}]\n"
+ " BindBlock:\n"
+ " [rid_0]\n"
+ " Source []\n"},
+ {"{$or: [{$and: [{a0: {$gt:9}}, {a0: {$lt: 12}}]}, {$and: [{a0: {$gt:40}}, {a0: {$lt: "
+ "44}}]}]}",
+ "Root []\n"
+ "| | projections: \n"
+ "| | scan_0\n"
+ "| RefBlock: \n"
+ "| Variable [scan_0]\n"
+ "BinaryJoin [joinType: Inner, {rid_0}]\n"
+ "| | Const [true]\n"
+ "| LimitSkip []\n"
+ "| | limitSkip:\n"
+ "| | limit: 1\n"
+ "| | skip: 0\n"
+ "| Seek [ridProjection: rid_0, {'<root>': scan_0}, coll]\n"
+ "| | BindBlock:\n"
+ "| | [scan_0]\n"
+ "| | Source []\n"
+ "| RefBlock: \n"
+ "| Variable [rid_0]\n"
+ "GroupBy []\n"
+ "| | groupings: \n"
+ "| | RefBlock: \n"
+ "| | Variable [rid_0]\n"
+ "| aggregations: \n"
+ "Union []\n"
+ "| | BindBlock:\n"
+ "| | [rid_0]\n"
+ "| | Source []\n"
+ "| IndexScan [{'<rid>': rid_0}, scanDefName: coll, indexDefName: index1, "
+ "interval: {(Const [40], Const [44])}]\n"
+ "| BindBlock:\n"
+ "| [rid_0]\n"
+ "| Source []\n"
+ "IndexScan [{'<rid>': rid_0}, scanDefName: coll, indexDefName: index1, interval: "
+ "{(Const [9], Const [12])}]\n"
+ " BindBlock:\n"
+ " [rid_0]\n"
+ " Source []\n"},
+ // Contradictions
+ // Empty interval
+ {"{$and: [{a0: {$gt:20}}, {a0: {$lt: 20}}]}",
+ "Root []\n"
+ "| | projections: \n"
+ "| | scan_0\n"
+ "| RefBlock: \n"
+ "| Variable [scan_0]\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [scan_0]\n"
+ "| Const [Nothing]\n"
+ "LimitSkip []\n"
+ "| limitSkip:\n"
+ "| limit: 0\n"
+ "| skip: 0\n"
+ "CoScan []\n"},
+ // One conjunct non-empty, one conjunct empty
+ {"{$or: [{$and: [{a0: {$gt:9}}, {a0: {$lt: 12}}]}, {$and: [{a0: {$gt:44}}, {a0: {$lt: "
+ "40}}]}]}",
+ "Root []\n"
+ "| | projections: \n"
+ "| | scan_0\n"
+ "| RefBlock: \n"
+ "| Variable [scan_0]\n"
+ "BinaryJoin [joinType: Inner, {rid_0}]\n"
+ "| | Const [true]\n"
+ "| LimitSkip []\n"
+ "| | limitSkip:\n"
+ "| | limit: 1\n"
+ "| | skip: 0\n"
+ "| Seek [ridProjection: rid_0, {'<root>': scan_0}, coll]\n"
+ "| | BindBlock:\n"
+ "| | [scan_0]\n"
+ "| | Source []\n"
+ "| RefBlock: \n"
+ "| Variable [rid_0]\n"
+ "IndexScan [{'<rid>': rid_0}, scanDefName: coll, indexDefName: index1, interval: "
+ "{(Const [9], Const [12])}]\n"
+ " BindBlock:\n"
+ " [rid_0]\n"
+ " Source []\n"},
+ // Both conjuncts empty, whole disjunct empty
+ {"{$or: [{$and: [{a0: {$gt:15}}, {a0: {$lt: 10}}]}, {$and: [{a0: {$gt:44}}, {a0: {$lt: "
+ "40}}]}]}",
+ "Root []\n"
+ "| | projections: \n"
+ "| | scan_0\n"
+ "| RefBlock: \n"
+ "| Variable [scan_0]\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [scan_0]\n"
+ "| Const [Nothing]\n"
+ "LimitSkip []\n"
+ "| limitSkip:\n"
+ "| limit: 0\n"
+ "| skip: 0\n"
+ "CoScan []\n"},
+ {"{$or: [{$and: [{a0: {$gt:12}}, {a0: {$lt: 12}}]}, {$and: [{a0: {$gte:42}}, {a0: {$lt: "
+ "42}}]}]}",
+ "Root []\n"
+ "| | projections: \n"
+ "| | scan_0\n"
+ "| RefBlock: \n"
+ "| Variable [scan_0]\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [scan_0]\n"
+ "| Const [Nothing]\n"
+ "LimitSkip []\n"
+ "| limitSkip:\n"
+ "| limit: 0\n"
+ "| skip: 0\n"
+ "CoScan []\n"},
+ };
+
+ /*
+ std::cout << "\nExpected query plans:\n\n";
+ for (auto& qt : queryTests) {
+ std::cout << optimizedQueryPlan(qt.query, testIndex) << std::endl << std::endl;
+ }
+ */
+    for (size_t i = 0; i < queryTests.size(); i++) {
+        ASSERT_EQ(queryTests[i].expectedPlan, optimizedQueryPlan(queryTests[i].query, testIndex));
+    }
+}
+
+TEST(IntervalIntersection, MultiFieldIntersection) {
+ std::vector<TestIndexField> indexFields{{"a0", CollationOp::Ascending, false},
+ {"b0", CollationOp::Ascending, false}};
+
+ opt::unordered_map<std::string, IndexDefinition> testIndex = {
+ {"index1", makeCompositeIndexDefinition(indexFields, false /*isMultiKey*/)}};
+
+ std::vector<QueryTest> queryTests = {
+ {"{$and: [{a0: {$gt:11}}, {a0: {$lt:14}}, {b0: {$gt: 21}}, {b0: {$lt: 12}}]}",
+ "Root []\n"
+ "| | projections: \n"
+ "| | scan_0\n"
+ "| RefBlock: \n"
+ "| Variable [scan_0]\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [scan_0]\n"
+ "| Const [Nothing]\n"
+ "LimitSkip []\n"
+ "| limitSkip:\n"
+ "| limit: 0\n"
+ "| skip: 0\n"
+ "CoScan []\n"},
+ {"{$and: [{a0: {$gt:14}}, {a0: {$lt:11}}, {b0: {$gt: 12}}, {b0: {$lt: 21}}]}",
+ "Root []\n"
+ "| | projections: \n"
+ "| | scan_0\n"
+ "| RefBlock: \n"
+ "| Variable [scan_0]\n"
+ "Filter []\n"
+ "| FunctionCall [traverseF]\n"
+ "| | | Const [false]\n"
+ "| | LambdaAbstraction [valCmp4]\n"
+ "| | BinaryOp [Lt]\n"
+ "| | | Const [0]\n"
+ "| | BinaryOp [Cmp3w]\n"
+ "| | | Const [21]\n"
+ "| | Variable [valCmp4]\n"
+ "| FunctionCall [getField]\n"
+ "| | Const [\"b0\"]\n"
+ "| Const [Nothing]\n"
+ "Filter []\n"
+ "| FunctionCall [traverseF]\n"
+ "| | | Const [false]\n"
+ "| | LambdaAbstraction [valCmp1]\n"
+ "| | BinaryOp [Gt]\n"
+ "| | | Const [0]\n"
+ "| | BinaryOp [Cmp3w]\n"
+ "| | | Const [12]\n"
+ "| | Variable [valCmp1]\n"
+ "| FunctionCall [getField]\n"
+ "| | Const [\"b0\"]\n"
+ "| Const [Nothing]\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [scan_0]\n"
+ "| Const [Nothing]\n"
+ "LimitSkip []\n"
+ "| limitSkip:\n"
+ "| limit: 0\n"
+ "| skip: 0\n"
+ "CoScan []\n"},
+ {"{$and: [{a0: {$gt:14}}, {a0: {$lt:11}}, {b0: {$gt: 21}}, {b0: {$lt: 12}}]}",
+ "Root []\n"
+ "| | projections: \n"
+ "| | scan_0\n"
+ "| RefBlock: \n"
+ "| Variable [scan_0]\n"
+ "Filter []\n"
+ "| FunctionCall [traverseF]\n"
+ "| | | Const [false]\n"
+ "| | LambdaAbstraction [valCmp10]\n"
+ "| | BinaryOp [Lt]\n"
+ "| | | Const [0]\n"
+ "| | BinaryOp [Cmp3w]\n"
+ "| | | Const [12]\n"
+ "| | Variable [valCmp10]\n"
+ "| FunctionCall [getField]\n"
+ "| | Const [\"b0\"]\n"
+ "| Const [Nothing]\n"
+ "Filter []\n"
+ "| FunctionCall [traverseF]\n"
+ "| | | Const [false]\n"
+ "| | LambdaAbstraction [valCmp7]\n"
+ "| | BinaryOp [Gt]\n"
+ "| | | Const [0]\n"
+ "| | BinaryOp [Cmp3w]\n"
+ "| | | Const [21]\n"
+ "| | Variable [valCmp7]\n"
+ "| FunctionCall [getField]\n"
+ "| | Const [\"b0\"]\n"
+ "| Const [Nothing]\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [scan_0]\n"
+ "| Const [Nothing]\n"
+ "LimitSkip []\n"
+ "| limitSkip:\n"
+ "| limit: 0\n"
+ "| skip: 0\n"
+ "CoScan []\n"},
+ {"{$and: [{a0: 42}, {b0: {$gt: 21}}, {b0: {$lt: 12}}]}",
+ "Root []\n"
+ "| | projections: \n"
+ "| | scan_0\n"
+ "| RefBlock: \n"
+ "| Variable [scan_0]\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [scan_0]\n"
+ "| Const [Nothing]\n"
+ "LimitSkip []\n"
+ "| limitSkip:\n"
+ "| limit: 0\n"
+ "| skip: 0\n"
+ "CoScan []\n"},
+ };
+
+ /*std::cout << "\nExpected query plans:\n\n";
+ for (auto& qt : queryTests) {
+ std::cout << optimizedQueryPlan(qt.query, testIndex) << std::endl << std::endl;
+ }*/
+
+ ASSERT_EQ(queryTests[0].expectedPlan, optimizedQueryPlan(queryTests[0].query, testIndex));
+    // TODO: these tests produce an escaped string literal in the explain output, which then
+    // causes other tests to fail because string literals in their expected explain output are
+    // not escaped the same way.
+ // Most likely there is some shared state in the SBE printer that gets messed up.
+ // ASSERT_EQ(queryTests[1].expectedPlan, optimizedQueryPlan(queryTests[1].query, testIndex));
+ // ASSERT_EQ(queryTests[2].expectedPlan, optimizedQueryPlan(queryTests[2].query, testIndex));
+ // ASSERT_EQ(queryTests[3].expectedPlan, optimizedQueryPlan(queryTests[3].query, testIndex));
+}
+
+TEST(IntervalIntersection, VariableIntervals) {
+ {
+ auto interval =
+ IntervalReqExpr::make<IntervalReqExpr::Disjunction>(IntervalReqExpr::NodeVector{
+ IntervalReqExpr::make<IntervalReqExpr::Conjunction>(IntervalReqExpr::NodeVector{
+ IntervalReqExpr::make<IntervalReqExpr::Atom>(IntervalRequirement{
+ BoundRequirement(true /*inclusive*/, make<Variable>("v1")),
+ BoundRequirement::makeInfinite()}),
+ IntervalReqExpr::make<IntervalReqExpr::Atom>(IntervalRequirement{
+ BoundRequirement(false /*inclusive*/, make<Variable>("v2")),
+ BoundRequirement::makeInfinite()})})});
+
+ auto result = intersectDNFIntervals(interval);
+ ASSERT_TRUE(result);
+
+        // [v2 >= v1 ? MaxKey : v1, max(v1, v2)] U (max(v1, v2), +inf)
+ ASSERT_EQ(
+ "{\n"
+ " {\n"
+ " {[If [] BinaryOp [Gte] Variable [v2] Variable [v1] Const [maxKey] Variable "
+ "[v1], If [] BinaryOp [Gte] Variable [v1] Variable [v2] Variable [v1] Variable [v2]]}\n"
+ " }\n"
+ " U \n"
+ " {\n"
+ " {(If [] BinaryOp [Gte] Variable [v1] Variable [v2] Variable [v1] Variable "
+ "[v2], +inf)}\n"
+ " }\n"
+ "}\n",
+ ExplainGenerator::explainIntervalExpr(*result));
+
+ // Make sure repeated intersection does not change the result.
+ auto result1 = intersectDNFIntervals(*result);
+ ASSERT_TRUE(result1);
+ ASSERT_TRUE(*result == *result1);
+ }
+
+ {
+ auto interval =
+ IntervalReqExpr::make<IntervalReqExpr::Disjunction>(IntervalReqExpr::NodeVector{
+ IntervalReqExpr::make<IntervalReqExpr::Conjunction>(IntervalReqExpr::NodeVector{
+ IntervalReqExpr::make<IntervalReqExpr::Atom>(IntervalRequirement{
+ BoundRequirement(true /*inclusive*/, make<Variable>("v1")),
+ BoundRequirement(true /*inclusive*/, make<Variable>("v3"))}),
+ IntervalReqExpr::make<IntervalReqExpr::Atom>(IntervalRequirement{
+ BoundRequirement(true /*inclusive*/, make<Variable>("v2")),
+ BoundRequirement(true /*inclusive*/, make<Variable>("v4"))})})});
+
+ auto result = intersectDNFIntervals(interval);
+ ASSERT_TRUE(result);
+
+ // [v1, v3] ^ [v2, v4] -> [max(v1, v2), min(v3, v4)]
+ ASSERT_EQ(
+ "{\n"
+ " {\n"
+ " {[If [] BinaryOp [Gte] Variable [v1] Variable [v2] Variable [v1] Variable "
+ "[v2], If [] BinaryOp [Lte] Variable [v3] Variable [v4] Variable [v3] Variable [v4]]}\n"
+ " }\n"
+ "}\n",
+ ExplainGenerator::explainIntervalExpr(*result));
+
+ // Make sure repeated intersection does not change the result.
+ auto result1 = intersectDNFIntervals(*result);
+ ASSERT_TRUE(result1);
+ ASSERT_TRUE(*result == *result1);
+ }
+
+ {
+ auto interval =
+ IntervalReqExpr::make<IntervalReqExpr::Disjunction>(IntervalReqExpr::NodeVector{
+ IntervalReqExpr::make<IntervalReqExpr::Conjunction>(IntervalReqExpr::NodeVector{
+ IntervalReqExpr::make<IntervalReqExpr::Atom>(IntervalRequirement{
+ BoundRequirement(false /*inclusive*/, make<Variable>("v1")),
+ BoundRequirement(true /*inclusive*/, make<Variable>("v3"))}),
+ IntervalReqExpr::make<IntervalReqExpr::Atom>(IntervalRequirement{
+ BoundRequirement(true /*inclusive*/, make<Variable>("v2")),
+ BoundRequirement(true /*inclusive*/, make<Variable>("v4"))})})});
+
+ auto result = intersectDNFIntervals(interval);
+ ASSERT_TRUE(result);
+
+ ASSERT_EQ(
+ "{\n"
+ " {\n"
+ " {[If [] BinaryOp [Gte] Variable [v1] Variable [v2] Const [maxKey] Variable "
+ "[v2], If [] BinaryOp [Lte] If [] BinaryOp [Gte] Variable [v1] Variable [v2] Variable "
+ "[v1] Variable [v2] If [] BinaryOp [Lte] Variable [v3] Variable [v4] Variable [v3] "
+ "Variable [v4] If [] BinaryOp [Gte] Variable [v1] Variable [v2] Variable [v1] Variable "
+ "[v2] If [] BinaryOp [Lte] Variable [v3] Variable [v4] Variable [v3] Variable [v4]]}\n"
+ " }\n"
+ " U \n"
+ " {\n"
+ " {(If [] BinaryOp [Gte] Variable [v1] Variable [v2] Variable [v1] Variable "
+ "[v2], If [] BinaryOp [Lte] Variable [v3] Variable [v4] Variable [v3] Variable [v4]]}\n"
+ " }\n"
+ "}\n",
+ ExplainGenerator::explainIntervalExpr(*result));
+
+ // Make sure repeated intersection does not change the result.
+ auto result1 = intersectDNFIntervals(*result);
+ ASSERT_TRUE(result1);
+ ASSERT_TRUE(*result == *result1);
+ }
+
+ {
+ auto interval =
+ IntervalReqExpr::make<IntervalReqExpr::Disjunction>(IntervalReqExpr::NodeVector{
+ IntervalReqExpr::make<IntervalReqExpr::Conjunction>(IntervalReqExpr::NodeVector{
+ IntervalReqExpr::make<IntervalReqExpr::Atom>(IntervalRequirement{
+ BoundRequirement(false /*inclusive*/, make<Variable>("v1")),
+ BoundRequirement(true /*inclusive*/, make<Variable>("v3"))}),
+ IntervalReqExpr::make<IntervalReqExpr::Atom>(IntervalRequirement{
+ BoundRequirement(true /*inclusive*/, make<Variable>("v2")),
+ BoundRequirement(false /*inclusive*/, make<Variable>("v4"))})})});
+
+ auto result = intersectDNFIntervals(interval);
+ ASSERT_TRUE(result);
+
+ ASSERT_EQ(
+ "{\n"
+ " {\n"
+ " {[If [] BinaryOp [Gte] Variable [v1] Variable [v2] Const [maxKey] Variable "
+ "[v2], If [] BinaryOp [Lte] If [] BinaryOp [Gte] Variable [v1] Variable [v2] Variable "
+ "[v1] Variable [v2] If [] BinaryOp [Lte] Variable [v3] Variable [v4] Variable [v3] "
+ "Variable [v4] If [] BinaryOp [Gte] Variable [v1] Variable [v2] Variable [v1] Variable "
+ "[v2] If [] BinaryOp [Lte] Variable [v3] Variable [v4] Variable [v3] Variable [v4]]}\n"
+ " }\n"
+ " U \n"
+ " {\n"
+ " {[If [] BinaryOp [Gte] If [] BinaryOp [Gte] Variable [v1] Variable [v2] "
+ "Variable [v1] Variable [v2] If [] BinaryOp [Lte] Variable [v3] Variable [v4] Variable "
+ "[v3] Variable [v4] If [] BinaryOp [Gte] Variable [v1] Variable [v2] Variable [v1] "
+ "Variable [v2] If [] BinaryOp [Lte] Variable [v3] Variable [v4] Variable [v3] Variable "
+ "[v4], If [] BinaryOp [Lte] Variable [v4] Variable [v3] Const [minKey] Variable "
+ "[v3]]}\n"
+ " }\n"
+ " U \n"
+ " {\n"
+ " {(If [] BinaryOp [Gte] Variable [v1] Variable [v2] Variable [v1] Variable "
+ "[v2], If [] BinaryOp [Lte] Variable [v3] Variable [v4] Variable [v3] Variable [v4])}\n"
+ " }\n"
+ "}\n",
+ ExplainGenerator::explainIntervalExpr(*result));
+
+ // Make sure repeated intersection does not change the result.
+ auto result1 = intersectDNFIntervals(*result);
+ ASSERT_TRUE(result1);
+ ASSERT_TRUE(*result == *result1);
+ }
+}
+
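+// Clears the bits of 'inclusion' for every value in [0, N) which lies outside the interval
+// defined by [low, high] and the endpoint inclusivity; set bits remain only for included values.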
+template <int N>
+void updateResults(const bool lowInc,
+ const int low,
+ const bool highInc,
+ const int high,
+ std::bitset<N>& inclusion) {
+ for (int v = 0; v < low + (lowInc ? 0 : 1); v++) {
+ inclusion.set(v, false);
+ }
+ for (int v = high + (highInc ? 1 : 0); v < N; v++) {
+ inclusion.set(v, false);
+ }
+}
+
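+// Folds an interval expression into a bitset over [0, N): an atom yields the bitset of integers
+// inside its interval, a conjunction ANDs its children, and a disjunction ORs them.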
+template <int N>
+class IntervalInclusionTransport {
+public:
+ using ResultType = std::bitset<N>;
+
+ ResultType transport(const IntervalReqExpr::Atom& node) {
+ const auto& expr = node.getExpr();
+ const auto& lb = expr.getLowBound();
+ const auto& hb = expr.getHighBound();
+
+ std::bitset<N> result;
+ result.flip();
+ updateResults<N>(lb.isInclusive(),
+ lb.getBound().cast<Constant>()->getValueInt32(),
+ hb.isInclusive(),
+ hb.getBound().cast<Constant>()->getValueInt32(),
+ result);
+ return result;
+ }
+
+ ResultType transport(const IntervalReqExpr::Conjunction& node,
+ std::vector<ResultType> childResults) {
+ for (size_t index = 1; index < childResults.size(); index++) {
+ childResults.front() &= childResults.at(index);
+ }
+ return childResults.front();
+ }
+
+ ResultType transport(const IntervalReqExpr::Disjunction& node,
+ std::vector<ResultType> childResults) {
+ for (size_t index = 1; index < childResults.size(); index++) {
+ childResults.front() |= childResults.at(index);
+ }
+ return childResults.front();
+ }
+
+ ResultType computeInclusion(const IntervalReqExpr::Node& intervals) {
+ return algebra::transport<false>(intervals, *this);
+ }
+};
+
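+// Consumes and returns the next base-V digit of the permutation index.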
+template <int V>
+int decode(int& permutation) {
+ const int result = permutation % V;
+ permutation /= V;
+ return result;
+}
+
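+// Decodes one permutation index into two intervals over [0, N) (bounds plus endpoint
+// inclusivity), intersects them via intersectDNFIntervals, and verifies the result against a
+// ground-truth inclusion bitset computed directly from the inputs.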
+template <int N>
+void testInterval(int permutation) {
+ const bool low1Inc = decode<2>(permutation);
+ const int low1 = decode<N>(permutation);
+ const bool high1Inc = decode<2>(permutation);
+ const int high1 = decode<N>(permutation);
+ const bool low2Inc = decode<2>(permutation);
+ const int low2 = decode<N>(permutation);
+ const bool high2Inc = decode<2>(permutation);
+ const int high2 = decode<N>(permutation);
+
+ auto interval = IntervalReqExpr::make<IntervalReqExpr::Disjunction>(IntervalReqExpr::NodeVector{
+ IntervalReqExpr::make<IntervalReqExpr::Conjunction>(IntervalReqExpr::NodeVector{
+ IntervalReqExpr::make<IntervalReqExpr::Atom>(IntervalRequirement{
+ {low1Inc, Constant::int32(low1)}, {high1Inc, Constant::int32(high1)}}),
+ IntervalReqExpr::make<IntervalReqExpr::Atom>(IntervalRequirement{
+ {low2Inc, Constant::int32(low2)}, {high2Inc, Constant::int32(high2)}})})});
+
+ auto result = intersectDNFIntervals(interval);
+ std::bitset<N> inclusionActual;
+ if (result) {
+ // Since we are testing with constants, we should have at most one interval.
+ ASSERT_TRUE(IntervalReqExpr::getSingularDNF(*result));
+
+ IntervalInclusionTransport<N> transport;
+ // Compute actual inclusion bitset based on the interval intersection code.
+ inclusionActual = transport.computeInclusion(*result);
+ }
+
+ std::bitset<N> inclusionExpected;
+ inclusionExpected.flip();
+
+ // Compute ground truth.
+ updateResults<N>(low1Inc, low1, high1Inc, high1, inclusionExpected);
+ updateResults<N>(low2Inc, low2, high2Inc, high2, inclusionExpected);
+
+ ASSERT_EQ(inclusionExpected, inclusionActual);
+}
+
+TEST(IntervalIntersection, IntervalPermutations) {
+ static constexpr int N = 10;
+ static constexpr int numPermutations = N * N * N * N * 2 * 2 * 2 * 2;
+
+    /**
+     * Test for interval intersection. Enumerate all pairs of intervals over constants in the
+     * range [0, N), with every combination of endpoint inclusion/exclusion. Intersect each pair
+     * and verify against ground truth.
+     */
+ const size_t numThreads = ProcessInfo::getNumCores();
+ std::cout << "Testing " << numPermutations << " interval permutations using " << numThreads
+ << " cores...\n";
+ auto timeBegin = Date_t::now();
+
+ AtomicWord<int> permutation(0);
+ std::vector<stdx::thread> threads;
+ for (size_t i = 0; i < numThreads; i++) {
+ threads.emplace_back([&permutation]() {
+ for (;;) {
+ const int nextP = permutation.fetchAndAdd(1);
+ if (nextP >= numPermutations) {
+ break;
+ }
+ testInterval<N>(nextP);
+ }
+ });
+ }
+ for (auto& thread : threads) {
+ thread.join();
+ }
+
+ const auto elapsed =
+ (Date_t::now().toMillisSinceEpoch() - timeBegin.toMillisSinceEpoch()) / 1000.0;
+ std::cout << "...done. Took: " << elapsed << " s.\n";
+}
+
+} // namespace
+} // namespace mongo::optimizer
diff --git a/src/mongo/db/query/optimizer/logical_rewriter_optimizer_test.cpp b/src/mongo/db/query/optimizer/logical_rewriter_optimizer_test.cpp
new file mode 100644
index 00000000000..4c912046d3a
--- /dev/null
+++ b/src/mongo/db/query/optimizer/logical_rewriter_optimizer_test.cpp
@@ -0,0 +1,1495 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/db/query/optimizer/cascades/ce_heuristic.h"
+#include "mongo/db/query/optimizer/cascades/logical_props_derivation.h"
+#include "mongo/db/query/optimizer/explain.h"
+#include "mongo/db/query/optimizer/node.h"
+#include "mongo/db/query/optimizer/opt_phase_manager.h"
+#include "mongo/db/query/optimizer/utils/unit_test_utils.h"
+#include "mongo/unittest/unittest.h"
+
+namespace mongo::optimizer {
+namespace {
+
+TEST(LogicalRewriter, RootNodeMerge) {
+ PrefixId prefixId;
+
+ ABT scanNode = make<ScanNode>("a", "test");
+ ABT limitSkipNode1 =
+ make<LimitSkipNode>(properties::LimitSkipRequirement(-1, 10), std::move(scanNode));
+ ABT limitSkipNode2 =
+ make<LimitSkipNode>(properties::LimitSkipRequirement(5, 0), std::move(limitSkipNode1));
+
+ ABT rootNode = make<RootNode>(properties::ProjectionRequirement{ProjectionNameVector{"a"}},
+ std::move(limitSkipNode2));
+
+ ASSERT_EXPLAIN(
+ "Root []\n"
+ " projections: \n"
+ " a\n"
+ " RefBlock: \n"
+ " Variable [a]\n"
+ " LimitSkip []\n"
+ " limitSkip:\n"
+ " limit: 5\n"
+ " skip: 0\n"
+ " LimitSkip []\n"
+ " limitSkip:\n"
+ " limit: (none)\n"
+ " skip: 10\n"
+ " Scan [test]\n"
+ " BindBlock:\n"
+ " [a]\n"
+ " Source []\n",
+ rootNode);
+
+ OptPhaseManager phaseManager({OptPhaseManager::OptPhase::MemoSubstitutionPhase},
+ prefixId,
+ {{{"test", {{}, {}}}}},
+ DebugInfo::kDefaultForTests);
+ ABT rewritten = std::move(rootNode);
+ ASSERT_TRUE(phaseManager.optimize(rewritten));
+
+ ASSERT_EXPLAIN(
+ "Root []\n"
+ " projections: \n"
+ " a\n"
+ " RefBlock: \n"
+ " Variable [a]\n"
+ " LimitSkip []\n"
+ " limitSkip:\n"
+ " limit: 5\n"
+ " skip: 10\n"
+ " Scan [test]\n"
+ " BindBlock:\n"
+ " [a]\n"
+ " Source []\n",
+ rewritten);
+}
+
+TEST(LogicalRewriter, Memo) {
+ using namespace cascades;
+ using namespace properties;
+
+ Metadata metadata{{{"test", {}}}};
+ Memo memo(DebugInfo::kDefaultForTests,
+ metadata,
+ std::make_unique<DefaultLogicalPropsDerivation>(),
+ std::make_unique<HeuristicCE>());
+
+ ABT scanNode = make<ScanNode>("ptest", "test");
+ ABT filterNode = make<FilterNode>(
+ make<EvalFilter>(make<PathConstant>(make<UnaryOp>(Operations::Neg, Constant::int64(1))),
+ make<Variable>("ptest")),
+ std::move(scanNode));
+ ABT evalNode = make<EvaluationNode>(
+ "P1",
+ make<EvalPath>(make<PathConstant>(Constant::int64(2)), make<Variable>("ptest")),
+ std::move(filterNode));
+
+ NodeIdSet insertedNodeIds;
+ const GroupIdType rootGroupId = memo.integrate(evalNode, {}, insertedNodeIds);
+ ASSERT_EQ(2, rootGroupId);
+ ASSERT_EQ(3, memo.getGroupCount());
+
+ NodeIdSet expectedInsertedNodeIds = {{0, 0}, {1, 0}, {2, 0}};
+ ASSERT_TRUE(insertedNodeIds == expectedInsertedNodeIds);
+
+ ASSERT_EXPLAIN_MEMO(
+ "Memo: \n"
+ " groupId: 0\n"
+ " | | Logical properties:\n"
+ " | | cardinalityEstimate: \n"
+ " | | ce: 1000\n"
+ " | | projections: \n"
+ " | | ptest\n"
+ " | | indexingAvailability: \n"
+ " | | [groupId: 0, scanProjection: ptest, scanDefName: test, "
+ "possiblyEqPredsOnly]\n"
+ " | | collectionAvailability: \n"
+ " | | test\n"
+ " | | distributionAvailability: \n"
+ " | | distribution: \n"
+ " | | type: Centralized\n"
+ " | logicalNodes: \n"
+ " | logicalNodeId: 0\n"
+ " | Scan [test]\n"
+ " | BindBlock:\n"
+ " | [ptest]\n"
+ " | Source []\n"
+ " physicalNodes: \n"
+ " groupId: 1\n"
+ " | | Logical properties:\n"
+ " | | cardinalityEstimate: \n"
+ " | | ce: 100\n"
+ " | | projections: \n"
+ " | | ptest\n"
+ " | | indexingAvailability: \n"
+ " | | [groupId: 0, scanProjection: ptest, scanDefName: test]\n"
+ " | | collectionAvailability: \n"
+ " | | test\n"
+ " | | distributionAvailability: \n"
+ " | | distribution: \n"
+ " | | type: Centralized\n"
+ " | logicalNodes: \n"
+ " | logicalNodeId: 0\n"
+ " | Filter []\n"
+ " | | EvalFilter []\n"
+ " | | | Variable [ptest]\n"
+ " | | PathConstant []\n"
+ " | | UnaryOp [Neg]\n"
+ " | | Const [1]\n"
+ " | MemoLogicalDelegator [groupId: 0]\n"
+ " physicalNodes: \n"
+ " groupId: 2\n"
+ " | | Logical properties:\n"
+ " | | cardinalityEstimate: \n"
+ " | | ce: 100\n"
+ " | | projections: \n"
+ " | | P1\n"
+ " | | ptest\n"
+ " | | indexingAvailability: \n"
+ " | | [groupId: 0, scanProjection: ptest, scanDefName: test]\n"
+ " | | collectionAvailability: \n"
+ " | | test\n"
+ " | | distributionAvailability: \n"
+ " | | distribution: \n"
+ " | | type: Centralized\n"
+ " | logicalNodes: \n"
+ " | logicalNodeId: 0\n"
+ " | Evaluation []\n"
+ " | | BindBlock:\n"
+ " | | [P1]\n"
+ " | | EvalPath []\n"
+ " | | | Variable [ptest]\n"
+ " | | PathConstant []\n"
+ " | | Const [2]\n"
+ " | MemoLogicalDelegator [groupId: 1]\n"
+ " physicalNodes: \n",
+ memo);
+
+ {
+ // Try to insert into the memo again.
+ NodeIdSet insertedNodeIds;
+ const GroupIdType group = memo.integrate(evalNode, {}, insertedNodeIds);
+ ASSERT_EQ(2, group);
+ ASSERT_EQ(3, memo.getGroupCount());
+
+ // Nothing was inserted.
+ ASSERT_EQ(1, memo.getGroup(0)._logicalNodes.size());
+ ASSERT_EQ(1, memo.getGroup(1)._logicalNodes.size());
+ ASSERT_EQ(1, memo.getGroup(2)._logicalNodes.size());
+ }
+
+    // Insert a different tree, this time with only a scan and a projection.
+ ABT scanNode1 = make<ScanNode>("ptest", "test");
+ ABT evalNode1 = make<EvaluationNode>(
+ "P1",
+ make<EvalPath>(make<PathConstant>(Constant::int64(2)), make<Variable>("ptest")),
+ std::move(scanNode1));
+
+ {
+ NodeIdSet insertedNodeIds1;
+ const GroupIdType rootGroupId1 = memo.integrate(evalNode1, {}, insertedNodeIds1);
+ ASSERT_EQ(3, rootGroupId1);
+ ASSERT_EQ(4, memo.getGroupCount());
+
+        // Nothing was inserted into the first three groups.
+ ASSERT_EQ(1, memo.getGroup(0)._logicalNodes.size());
+ ASSERT_EQ(1, memo.getGroup(1)._logicalNodes.size());
+ ASSERT_EQ(1, memo.getGroup(2)._logicalNodes.size());
+ }
+
+ {
+ const Group& group = memo.getGroup(3);
+ ASSERT_EQ(1, group._logicalNodes.size());
+
+ ASSERT_EXPLAIN(
+ "Evaluation []\n"
+ " BindBlock:\n"
+ " [P1]\n"
+ " EvalPath []\n"
+ " PathConstant []\n"
+ " Const [2]\n"
+ " Variable [ptest]\n"
+ " MemoLogicalDelegator [groupId: 0]\n",
+ group._logicalNodes.at(0));
+ }
+}
+
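+// Verifies that the filter/evaluation pair is reordered below the collation
+// node during the substitution phase.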
+TEST(LogicalRewriter, FilterProjectRewrite) {
+ using namespace properties;
+ PrefixId prefixId;
+
+ ABT scanNode = make<ScanNode>("ptest", "test");
+ ABT collationNode = make<CollationNode>(
+ CollationRequirement({{"ptest", CollationOp::Ascending}}), std::move(scanNode));
+ ABT evalNode =
+ make<EvaluationNode>("P1",
+ make<EvalPath>(make<PathIdentity>(), make<Variable>("ptest")),
+ std::move(collationNode));
+ ABT filterNode = make<FilterNode>(make<EvalFilter>(make<PathIdentity>(), make<Variable>("P1")),
+ std::move(evalNode));
+
+ ABT rootNode = make<RootNode>(properties::ProjectionRequirement{{}}, std::move(filterNode));
+
+ ASSERT_EXPLAIN(
+ "Root []\n"
+ " projections: \n"
+ " RefBlock: \n"
+ " Filter []\n"
+ " EvalFilter []\n"
+ " PathIdentity []\n"
+ " Variable [P1]\n"
+ " Evaluation []\n"
+ " BindBlock:\n"
+ " [P1]\n"
+ " EvalPath []\n"
+ " PathIdentity []\n"
+ " Variable [ptest]\n"
+ " Collation []\n"
+ " collation: \n"
+ " ptest: Ascending\n"
+ " RefBlock: \n"
+ " Variable [ptest]\n"
+ " Scan [test]\n"
+ " BindBlock:\n"
+ " [ptest]\n"
+ " Source []\n",
+ rootNode);
+
+ OptPhaseManager phaseManager({OptPhaseManager::OptPhase::MemoSubstitutionPhase},
+ prefixId,
+ {{{"test", {{}, {}}}}},
+ DebugInfo::kDefaultForTests);
+ ABT latest = std::move(rootNode);
+ ASSERT_TRUE(phaseManager.optimize(latest));
+
+ ASSERT_EXPLAIN(
+ "Root []\n"
+ " projections: \n"
+ " RefBlock: \n"
+ " Collation []\n"
+ " collation: \n"
+ " ptest: Ascending\n"
+ " RefBlock: \n"
+ " Variable [ptest]\n"
+ " Filter []\n"
+ " EvalFilter []\n"
+ " PathIdentity []\n"
+ " Variable [P1]\n"
+ " Evaluation []\n"
+ " BindBlock:\n"
+ " [P1]\n"
+ " EvalPath []\n"
+ " PathIdentity []\n"
+ " Variable [ptest]\n"
+ " Scan [test]\n"
+ " BindBlock:\n"
+ " [ptest]\n"
+ " Source []\n",
+ latest);
+}
+
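+// Verifies reordering with a longer chain of filters and projections: the
+// collation node is pulled above the entire filter/evaluation chain.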
+TEST(LogicalRewriter, FilterProjectComplexRewrite) {
+ using namespace properties;
+ PrefixId prefixId;
+
+ ABT scanNode = make<ScanNode>("ptest", "test");
+
+ ABT projection2Node = make<EvaluationNode>(
+ "p2", make<EvalPath>(make<PathIdentity>(), make<Variable>("ptest")), std::move(scanNode));
+
+ ABT projection3Node =
+ make<EvaluationNode>("p3",
+ make<EvalPath>(make<PathIdentity>(), make<Variable>("ptest")),
+ std::move(projection2Node));
+
+ ABT collationNode = make<CollationNode>(
+ CollationRequirement({{"ptest", CollationOp::Ascending}}), std::move(projection3Node));
+
+ ABT projection1Node =
+ make<EvaluationNode>("p1",
+ make<EvalPath>(make<PathIdentity>(), make<Variable>("ptest")),
+ std::move(collationNode));
+
+ ABT filter1Node = make<FilterNode>(make<EvalFilter>(make<PathIdentity>(), make<Variable>("p1")),
+ std::move(projection1Node));
+
+ ABT filterScanNode = make<FilterNode>(
+ make<EvalFilter>(make<PathIdentity>(), make<Variable>("ptest")), std::move(filter1Node));
+
+ ABT filter2Node = make<FilterNode>(make<EvalFilter>(make<PathIdentity>(), make<Variable>("p2")),
+ std::move(filterScanNode));
+
+ ABT rootNode = make<RootNode>(properties::ProjectionRequirement{{}}, std::move(filter2Node));
+
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| RefBlock: \n"
+ "Filter []\n"
+ "| EvalFilter []\n"
+ "| | Variable [p2]\n"
+ "| PathIdentity []\n"
+ "Filter []\n"
+ "| EvalFilter []\n"
+ "| | Variable [ptest]\n"
+ "| PathIdentity []\n"
+ "Filter []\n"
+ "| EvalFilter []\n"
+ "| | Variable [p1]\n"
+ "| PathIdentity []\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [p1]\n"
+ "| EvalPath []\n"
+ "| | Variable [ptest]\n"
+ "| PathIdentity []\n"
+ "Collation []\n"
+ "| | collation: \n"
+ "| | ptest: Ascending\n"
+ "| RefBlock: \n"
+ "| Variable [ptest]\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [p3]\n"
+ "| EvalPath []\n"
+ "| | Variable [ptest]\n"
+ "| PathIdentity []\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [p2]\n"
+ "| EvalPath []\n"
+ "| | Variable [ptest]\n"
+ "| PathIdentity []\n"
+ "Scan [test]\n"
+ " BindBlock:\n"
+ " [ptest]\n"
+ " Source []\n",
+ rootNode);
+
+ OptPhaseManager phaseManager({OptPhaseManager::OptPhase::MemoSubstitutionPhase},
+ prefixId,
+ {{{"test", {{}, {}}}}},
+ DebugInfo::kDefaultForTests);
+ ABT latest = std::move(rootNode);
+ ASSERT_TRUE(phaseManager.optimize(latest));
+
+    // Note: this assertion depends on the order in which we consider rewrites.
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| RefBlock: \n"
+ "Collation []\n"
+ "| | collation: \n"
+ "| | ptest: Ascending\n"
+ "| RefBlock: \n"
+ "| Variable [ptest]\n"
+ "Filter []\n"
+ "| EvalFilter []\n"
+ "| | Variable [p2]\n"
+ "| PathIdentity []\n"
+ "Filter []\n"
+ "| EvalFilter []\n"
+ "| | Variable [ptest]\n"
+ "| PathIdentity []\n"
+ "Filter []\n"
+ "| EvalFilter []\n"
+ "| | Variable [p1]\n"
+ "| PathIdentity []\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [p1]\n"
+ "| EvalPath []\n"
+ "| | Variable [ptest]\n"
+ "| PathIdentity []\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [p3]\n"
+ "| EvalPath []\n"
+ "| | Variable [ptest]\n"
+ "| PathIdentity []\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [p2]\n"
+ "| EvalPath []\n"
+ "| | Variable [ptest]\n"
+ "| PathIdentity []\n"
+ "Scan [test]\n"
+ " BindBlock:\n"
+ " [ptest]\n"
+ " Source []\n",
+ latest);
+}
+
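+// Verifies that a filter on a grouping projection is pushed below the
+// group-by node.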
+TEST(LogicalRewriter, FilterProjectGroupRewrite) {
+ using namespace properties;
+ PrefixId prefixId;
+
+ ABT scanNode = make<ScanNode>("ptest", "test");
+
+ ABT projectionANode = make<EvaluationNode>(
+ "a", make<EvalPath>(make<PathIdentity>(), make<Variable>("ptest")), std::move(scanNode));
+ ABT projectionBNode =
+ make<EvaluationNode>("b",
+ make<EvalPath>(make<PathIdentity>(), make<Variable>("ptest")),
+ std::move(projectionANode));
+
+ ABT groupByNode = make<GroupByNode>(ProjectionNameVector{"a"},
+ ProjectionNameVector{"c"},
+ makeSeq(make<Variable>("b")),
+ std::move(projectionBNode));
+
+ ABT filterANode = make<FilterNode>(make<EvalFilter>(make<PathIdentity>(), make<Variable>("a")),
+ std::move(groupByNode));
+
+ ABT rootNode = make<RootNode>(properties::ProjectionRequirement{ProjectionNameVector{"c"}},
+ std::move(filterANode));
+
+ OptPhaseManager phaseManager({OptPhaseManager::OptPhase::MemoSubstitutionPhase},
+ prefixId,
+ {{{"test", {{}, {}}}}},
+ DebugInfo::kDefaultForTests);
+ ABT latest = std::move(rootNode);
+ ASSERT_TRUE(phaseManager.optimize(latest));
+
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | c\n"
+ "| RefBlock: \n"
+ "| Variable [c]\n"
+ "GroupBy []\n"
+ "| | groupings: \n"
+ "| | RefBlock: \n"
+ "| | Variable [a]\n"
+ "| aggregations: \n"
+ "| [c]\n"
+ "| Variable [b]\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [b]\n"
+ "| EvalPath []\n"
+ "| | Variable [ptest]\n"
+ "| PathIdentity []\n"
+ "Filter []\n"
+ "| EvalFilter []\n"
+ "| | Variable [a]\n"
+ "| PathIdentity []\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [a]\n"
+ "| EvalPath []\n"
+ "| | Variable [ptest]\n"
+ "| PathIdentity []\n"
+ "Scan [test]\n"
+ " BindBlock:\n"
+ " [ptest]\n"
+ " Source []\n",
+ latest);
+}
+
+TEST(LogicalRewriter, FilterProjectUnwindRewrite) {
+ using namespace properties;
+ PrefixId prefixId;
+
+ ABT scanNode = make<ScanNode>("ptest", "test");
+
+ ABT projectionANode = make<EvaluationNode>(
+ "a", make<EvalPath>(make<PathIdentity>(), make<Variable>("ptest")), std::move(scanNode));
+ ABT projectionBNode =
+ make<EvaluationNode>("b",
+ make<EvalPath>(make<PathIdentity>(), make<Variable>("ptest")),
+ std::move(projectionANode));
+
+ ABT unwindNode =
+ make<UnwindNode>("a", "a_pid", false /*retainNonArrays*/, std::move(projectionBNode));
+
+ // This filter should stay above the unwind.
+ ABT filterANode = make<FilterNode>(make<EvalFilter>(make<PathIdentity>(), make<Variable>("a")),
+ std::move(unwindNode));
+
+ // This filter should be pushed down below the unwind.
+ ABT filterBNode = make<FilterNode>(make<EvalFilter>(make<PathIdentity>(), make<Variable>("b")),
+ std::move(filterANode));
+
+ ABT rootNode = make<RootNode>(properties::ProjectionRequirement{ProjectionNameVector{"a", "b"}},
+ std::move(filterBNode));
+
+ OptPhaseManager phaseManager({OptPhaseManager::OptPhase::MemoSubstitutionPhase},
+ prefixId,
+ {{{"test", {{}, {}}}}},
+ DebugInfo::kDefaultForTests);
+ ABT latest = std::move(rootNode);
+ ASSERT_TRUE(phaseManager.optimize(latest));
+
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | a\n"
+ "| | b\n"
+ "| RefBlock: \n"
+ "| Variable [a]\n"
+ "| Variable [b]\n"
+ "Filter []\n"
+ "| EvalFilter []\n"
+ "| | Variable [b]\n"
+ "| PathIdentity []\n"
+ "Filter []\n"
+ "| EvalFilter []\n"
+ "| | Variable [a]\n"
+ "| PathIdentity []\n"
+ "Unwind []\n"
+ "| BindBlock:\n"
+ "| [a]\n"
+ "| Source []\n"
+ "| [a_pid]\n"
+ "| Source []\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [b]\n"
+ "| EvalPath []\n"
+ "| | Variable [ptest]\n"
+ "| PathIdentity []\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [a]\n"
+ "| EvalPath []\n"
+ "| | Variable [ptest]\n"
+ "| PathIdentity []\n"
+ "Scan [test]\n"
+ " BindBlock:\n"
+ " [ptest]\n"
+ " Source []\n",
+ latest);
+}
+
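+// Verifies that a filter on the partitioning projection is pushed below the
+// exchange, while the evaluation of the unrelated projection is pulled above
+// it.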
+TEST(LogicalRewriter, FilterProjectExchangeRewrite) {
+ using namespace properties;
+ PrefixId prefixId;
+
+ ABT scanNode = make<ScanNode>("ptest", "test");
+
+ ABT projectionANode = make<EvaluationNode>(
+ "a", make<EvalPath>(make<PathIdentity>(), make<Variable>("ptest")), std::move(scanNode));
+ ABT projectionBNode =
+ make<EvaluationNode>("b",
+ make<EvalPath>(make<PathIdentity>(), make<Variable>("ptest")),
+ std::move(projectionANode));
+
+ ABT exchangeNode = make<ExchangeNode>(
+ properties::DistributionRequirement({DistributionType::HashPartitioning, {"a"}}),
+ std::move(projectionBNode));
+
+ ABT filterANode = make<FilterNode>(make<EvalFilter>(make<PathIdentity>(), make<Variable>("a")),
+ std::move(exchangeNode));
+
+ ABT rootNode = make<RootNode>(properties::ProjectionRequirement{ProjectionNameVector{"a", "b"}},
+ std::move(filterANode));
+
+ OptPhaseManager phaseManager({OptPhaseManager::OptPhase::MemoSubstitutionPhase},
+ prefixId,
+ {{{"test", {{}, {}}}}},
+ DebugInfo::kDefaultForTests);
+ ABT latest = std::move(rootNode);
+ ASSERT_TRUE(phaseManager.optimize(latest));
+
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | a\n"
+ "| | b\n"
+ "| RefBlock: \n"
+ "| Variable [a]\n"
+ "| Variable [b]\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [b]\n"
+ "| EvalPath []\n"
+ "| | Variable [ptest]\n"
+ "| PathIdentity []\n"
+ "Exchange []\n"
+ "| | distribution: \n"
+ "| | type: HashPartitioning\n"
+ "| | projections: \n"
+ "| | a\n"
+ "| RefBlock: \n"
+ "| Variable [a]\n"
+ "Filter []\n"
+ "| EvalFilter []\n"
+ "| | Variable [a]\n"
+ "| PathIdentity []\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [a]\n"
+ "| EvalPath []\n"
+ "| | Variable [ptest]\n"
+ "| PathIdentity []\n"
+ "Scan [test]\n"
+ " BindBlock:\n"
+ " [ptest]\n"
+ " Source []\n",
+ latest);
+}
+
+TEST(LogicalRewriter, UnwindCollationRewrite) {
+ using namespace properties;
+ PrefixId prefixId;
+
+ ABT scanNode = make<ScanNode>("ptest", "test");
+
+ ABT projectionANode = make<EvaluationNode>(
+ "a", make<EvalPath>(make<PathIdentity>(), make<Variable>("ptest")), std::move(scanNode));
+ ABT projectionBNode =
+ make<EvaluationNode>("b",
+ make<EvalPath>(make<PathIdentity>(), make<Variable>("ptest")),
+ std::move(projectionANode));
+
+ // This collation node should stay below the unwind.
+ ABT collationANode = make<CollationNode>(CollationRequirement({{"a", CollationOp::Ascending}}),
+ std::move(projectionBNode));
+
+ // This collation node should go above the unwind.
+ ABT collationBNode = make<CollationNode>(CollationRequirement({{"b", CollationOp::Ascending}}),
+ std::move(collationANode));
+
+ ABT unwindNode =
+ make<UnwindNode>("a", "a_pid", false /*retainNonArrays*/, std::move(collationBNode));
+
+ ABT rootNode = make<RootNode>(properties::ProjectionRequirement{ProjectionNameVector{"a", "b"}},
+ std::move(unwindNode));
+
+ OptPhaseManager phaseManager({OptPhaseManager::OptPhase::MemoSubstitutionPhase},
+ prefixId,
+ {{{"test", {{}, {}}}}},
+ DebugInfo::kDefaultForTests);
+ ABT latest = std::move(rootNode);
+ ASSERT_TRUE(phaseManager.optimize(latest));
+
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | a\n"
+ "| | b\n"
+ "| RefBlock: \n"
+ "| Variable [a]\n"
+ "| Variable [b]\n"
+ "Collation []\n"
+ "| | collation: \n"
+ "| | b: Ascending\n"
+ "| RefBlock: \n"
+ "| Variable [b]\n"
+ "Unwind []\n"
+ "| BindBlock:\n"
+ "| [a]\n"
+ "| Source []\n"
+ "| [a_pid]\n"
+ "| Source []\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [b]\n"
+ "| EvalPath []\n"
+ "| | Variable [ptest]\n"
+ "| PathIdentity []\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [a]\n"
+ "| EvalPath []\n"
+ "| | Variable [ptest]\n"
+ "| PathIdentity []\n"
+ "Scan [test]\n"
+ " BindBlock:\n"
+ " [ptest]\n"
+ " Source []\n",
+ latest);
+}
+
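+// Verifies that a filter over a union is cloned and pushed into each branch
+// when both branches bind the filtered projection.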
+TEST(LogicalRewriter, FilterUnionReorderSingleProjection) {
+ PrefixId prefixId;
+ ABT scanNode1 = make<ScanNode>("ptest1", "test1");
+ ABT scanNode2 = make<ScanNode>("ptest2", "test2");
+ // Create two eval nodes such that the two branches of the union share a projection.
+ ABT evalNode1 =
+ make<EvaluationNode>("pUnion",
+ make<EvalPath>(make<PathIdentity>(), make<Variable>("ptest1")),
+ std::move(scanNode1));
+ ABT evalNode2 =
+ make<EvaluationNode>("pUnion",
+ make<EvalPath>(make<PathIdentity>(), make<Variable>("ptest2")),
+ std::move(scanNode2));
+
+ ABT unionNode = make<UnionNode>(ProjectionNameVector{"pUnion"}, makeSeq(evalNode1, evalNode2));
+
+ ABT filter = make<FilterNode>(
+ make<EvalFilter>(
+ make<PathGet>(
+ "a", make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(1)))),
+ make<Variable>("pUnion")),
+ std::move(unionNode));
+ ABT rootNode = make<RootNode>(properties::ProjectionRequirement{ProjectionNameVector{"pUnion"}},
+ std::move(filter));
+
+ ABT latest = std::move(rootNode);
+
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | pUnion\n"
+ "| RefBlock: \n"
+ "| Variable [pUnion]\n"
+ "Filter []\n"
+ "| EvalFilter []\n"
+ "| | Variable [pUnion]\n"
+ "| PathGet [a]\n"
+ "| PathTraverse []\n"
+ "| PathCompare [Eq]\n"
+ "| Const [1]\n"
+ "Union []\n"
+ "| | BindBlock:\n"
+ "| | [pUnion]\n"
+ "| | Source []\n"
+ "| Evaluation []\n"
+ "| | BindBlock:\n"
+ "| | [pUnion]\n"
+ "| | EvalPath []\n"
+ "| | | Variable [ptest2]\n"
+ "| | PathIdentity []\n"
+ "| Scan [test2]\n"
+ "| BindBlock:\n"
+ "| [ptest2]\n"
+ "| Source []\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [pUnion]\n"
+ "| EvalPath []\n"
+ "| | Variable [ptest1]\n"
+ "| PathIdentity []\n"
+ "Scan [test1]\n"
+ " BindBlock:\n"
+ " [ptest1]\n"
+ " Source []\n",
+ latest);
+
+ OptPhaseManager phaseManager({OptPhaseManager::OptPhase::MemoSubstitutionPhase,
+ OptPhaseManager::OptPhase::MemoExplorationPhase},
+ prefixId,
+ {{{"test1", {{}, {}}}, {"test2", {{}, {}}}}},
+ DebugInfo::kDefaultForTests);
+ ASSERT_TRUE(phaseManager.optimize(latest));
+
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | pUnion\n"
+ "| RefBlock: \n"
+ "| Variable [pUnion]\n"
+ "Union []\n"
+ "| | BindBlock:\n"
+ "| | [pUnion]\n"
+ "| | Source []\n"
+ "| Filter []\n"
+ "| | EvalFilter []\n"
+ "| | | Variable [pUnion]\n"
+ "| | PathGet [a]\n"
+ "| | PathTraverse []\n"
+ "| | PathCompare [Eq]\n"
+ "| | Const [1]\n"
+ "| Evaluation []\n"
+ "| | BindBlock:\n"
+ "| | [pUnion]\n"
+ "| | EvalPath []\n"
+ "| | | Variable [ptest2]\n"
+ "| | PathIdentity []\n"
+ "| Scan [test2]\n"
+ "| BindBlock:\n"
+ "| [ptest2]\n"
+ "| Source []\n"
+ "Filter []\n"
+ "| EvalFilter []\n"
+ "| | Variable [pUnion]\n"
+ "| PathGet [a]\n"
+ "| PathTraverse []\n"
+ "| PathCompare [Eq]\n"
+ "| Const [1]\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [pUnion]\n"
+ "| EvalPath []\n"
+ "| | Variable [ptest1]\n"
+ "| PathIdentity []\n"
+ "Scan [test1]\n"
+ " BindBlock:\n"
+ " [ptest1]\n"
+ " Source []\n",
+ latest);
+}
+
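+// Verifies that multiple filters over shared union projections are each
+// pushed into both branches of the union.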
+TEST(LogicalRewriter, MultipleFilterUnionReorder) {
+ PrefixId prefixId;
+ ABT scanNode1 = make<ScanNode>("ptest1", "test1");
+ ABT scanNode2 = make<ScanNode>("ptest2", "test2");
+
+ // Create multiple shared projections for each child.
+ ABT pUnion11 =
+ make<EvaluationNode>("pUnion1",
+ make<EvalPath>(make<PathIdentity>(), make<Variable>("ptest1")),
+ std::move(scanNode1));
+ ABT pUnion12 =
+ make<EvaluationNode>("pUnion2",
+ make<EvalPath>(make<PathIdentity>(), make<Variable>("ptest1")),
+ std::move(pUnion11));
+
+ ABT pUnion21 =
+ make<EvaluationNode>("pUnion1",
+ make<EvalPath>(make<PathIdentity>(), make<Variable>("ptest2")),
+ std::move(scanNode2));
+ ABT pUnion22 =
+ make<EvaluationNode>("pUnion2",
+ make<EvalPath>(make<PathIdentity>(), make<Variable>("ptest2")),
+ std::move(pUnion21));
+
+ ABT unionNode =
+ make<UnionNode>(ProjectionNameVector{"pUnion1", "pUnion2"}, makeSeq(pUnion12, pUnion22));
+
+ // Create two filters, one for each of the two common projections.
+ ABT filterUnion1 = make<FilterNode>(
+ make<EvalFilter>(
+ make<PathGet>(
+ "a", make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(1)))),
+ make<Variable>("pUnion1")),
+ std::move(unionNode));
+ ABT filterUnion2 = make<FilterNode>(
+ make<EvalFilter>(
+ make<PathGet>(
+ "a", make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(1)))),
+ make<Variable>("pUnion2")),
+ std::move(filterUnion1));
+ ABT rootNode = make<RootNode>(
+ properties::ProjectionRequirement{ProjectionNameVector{"pUnion1", "pUnion2"}},
+ std::move(filterUnion2));
+
+ ABT latest = std::move(rootNode);
+
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | pUnion1\n"
+ "| | pUnion2\n"
+ "| RefBlock: \n"
+ "| Variable [pUnion1]\n"
+ "| Variable [pUnion2]\n"
+ "Filter []\n"
+ "| EvalFilter []\n"
+ "| | Variable [pUnion2]\n"
+ "| PathGet [a]\n"
+ "| PathTraverse []\n"
+ "| PathCompare [Eq]\n"
+ "| Const [1]\n"
+ "Filter []\n"
+ "| EvalFilter []\n"
+ "| | Variable [pUnion1]\n"
+ "| PathGet [a]\n"
+ "| PathTraverse []\n"
+ "| PathCompare [Eq]\n"
+ "| Const [1]\n"
+ "Union []\n"
+ "| | BindBlock:\n"
+ "| | [pUnion1]\n"
+ "| | Source []\n"
+ "| | [pUnion2]\n"
+ "| | Source []\n"
+ "| Evaluation []\n"
+ "| | BindBlock:\n"
+ "| | [pUnion2]\n"
+ "| | EvalPath []\n"
+ "| | | Variable [ptest2]\n"
+ "| | PathIdentity []\n"
+ "| Evaluation []\n"
+ "| | BindBlock:\n"
+ "| | [pUnion1]\n"
+ "| | EvalPath []\n"
+ "| | | Variable [ptest2]\n"
+ "| | PathIdentity []\n"
+ "| Scan [test2]\n"
+ "| BindBlock:\n"
+ "| [ptest2]\n"
+ "| Source []\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [pUnion2]\n"
+ "| EvalPath []\n"
+ "| | Variable [ptest1]\n"
+ "| PathIdentity []\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [pUnion1]\n"
+ "| EvalPath []\n"
+ "| | Variable [ptest1]\n"
+ "| PathIdentity []\n"
+ "Scan [test1]\n"
+ " BindBlock:\n"
+ " [ptest1]\n"
+ " Source []\n",
+ latest);
+
+ OptPhaseManager phaseManager({OptPhaseManager::OptPhase::MemoSubstitutionPhase,
+ OptPhaseManager::OptPhase::MemoExplorationPhase},
+ prefixId,
+ {{{"test1", {{}, {}}}, {"test2", {{}, {}}}}},
+ DebugInfo::kDefaultForTests);
+ ASSERT_TRUE(phaseManager.optimize(latest));
+
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | pUnion1\n"
+ "| | pUnion2\n"
+ "| RefBlock: \n"
+ "| Variable [pUnion1]\n"
+ "| Variable [pUnion2]\n"
+ "Union []\n"
+ "| | BindBlock:\n"
+ "| | [pUnion1]\n"
+ "| | Source []\n"
+ "| | [pUnion2]\n"
+ "| | Source []\n"
+ "| Filter []\n"
+ "| | EvalFilter []\n"
+ "| | | Variable [pUnion2]\n"
+ "| | PathGet [a]\n"
+ "| | PathTraverse []\n"
+ "| | PathCompare [Eq]\n"
+ "| | Const [1]\n"
+ "| Evaluation []\n"
+ "| | BindBlock:\n"
+ "| | [pUnion2]\n"
+ "| | EvalPath []\n"
+ "| | | Variable [ptest2]\n"
+ "| | PathIdentity []\n"
+ "| Filter []\n"
+ "| | EvalFilter []\n"
+ "| | | Variable [pUnion1]\n"
+ "| | PathGet [a]\n"
+ "| | PathTraverse []\n"
+ "| | PathCompare [Eq]\n"
+ "| | Const [1]\n"
+ "| Evaluation []\n"
+ "| | BindBlock:\n"
+ "| | [pUnion1]\n"
+ "| | EvalPath []\n"
+ "| | | Variable [ptest2]\n"
+ "| | PathIdentity []\n"
+ "| Scan [test2]\n"
+ "| BindBlock:\n"
+ "| [ptest2]\n"
+ "| Source []\n"
+ "Filter []\n"
+ "| EvalFilter []\n"
+ "| | Variable [pUnion2]\n"
+ "| PathGet [a]\n"
+ "| PathTraverse []\n"
+ "| PathCompare [Eq]\n"
+ "| Const [1]\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [pUnion2]\n"
+ "| EvalPath []\n"
+ "| | Variable [ptest1]\n"
+ "| PathIdentity []\n"
+ "Filter []\n"
+ "| EvalFilter []\n"
+ "| | Variable [pUnion1]\n"
+ "| PathGet [a]\n"
+ "| PathTraverse []\n"
+ "| PathCompare [Eq]\n"
+ "| Const [1]\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [pUnion1]\n"
+ "| EvalPath []\n"
+ "| | Variable [ptest1]\n"
+ "| PathIdentity []\n"
+ "Scan [test1]\n"
+ " BindBlock:\n"
+ " [ptest1]\n"
+ " Source []\n",
+ latest);
+}
+
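+// Verifies that a filter is pushed through nested unions and converted into a
+// SargableNode on top of each leaf scan.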
+TEST(LogicalRewriter, FilterUnionUnionPushdown) {
+ PrefixId prefixId;
+ ABT scanNode1 = make<ScanNode>("ptest", "test1");
+ ABT scanNode2 = make<ScanNode>("ptest", "test2");
+ ABT unionNode = make<UnionNode>(ProjectionNameVector{"ptest"}, makeSeq(scanNode1, scanNode2));
+
+ ABT scanNode3 = make<ScanNode>("ptest", "test3");
+ ABT parentUnionNode =
+ make<UnionNode>(ProjectionNameVector{"ptest"}, makeSeq(unionNode, scanNode3));
+
+ ABT filter = make<FilterNode>(
+ make<EvalFilter>(
+ make<PathGet>(
+ "a", make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(1)))),
+ make<Variable>("ptest")),
+ std::move(parentUnionNode));
+ ABT rootNode = make<RootNode>(properties::ProjectionRequirement{ProjectionNameVector{"ptest"}},
+ std::move(filter));
+
+ OptPhaseManager phaseManager({OptPhaseManager::OptPhase::MemoSubstitutionPhase},
+ prefixId,
+ {{{"test1", {{}, {}}}, {"test2", {{}, {}}}, {"test3", {{}, {}}}}},
+ DebugInfo::kDefaultForTests);
+ ABT latest = std::move(rootNode);
+
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | ptest\n"
+ "| RefBlock: \n"
+ "| Variable [ptest]\n"
+ "Filter []\n"
+ "| EvalFilter []\n"
+ "| | Variable [ptest]\n"
+ "| PathGet [a]\n"
+ "| PathTraverse []\n"
+ "| PathCompare [Eq]\n"
+ "| Const [1]\n"
+ "Union []\n"
+ "| | BindBlock:\n"
+ "| | [ptest]\n"
+ "| | Source []\n"
+ "| Scan [test3]\n"
+ "| BindBlock:\n"
+ "| [ptest]\n"
+ "| Source []\n"
+ "Union []\n"
+ "| | BindBlock:\n"
+ "| | [ptest]\n"
+ "| | Source []\n"
+ "| Scan [test2]\n"
+ "| BindBlock:\n"
+ "| [ptest]\n"
+ "| Source []\n"
+ "Scan [test1]\n"
+ " BindBlock:\n"
+ " [ptest]\n"
+ " Source []\n",
+ latest);
+
+ ASSERT_TRUE(phaseManager.optimize(latest));
+
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | ptest\n"
+ "| RefBlock: \n"
+ "| Variable [ptest]\n"
+ "Union []\n"
+ "| | BindBlock:\n"
+ "| | [ptest]\n"
+ "| | Source []\n"
+ "| Sargable [Complete]\n"
+ "| | | | | requirementsMap: \n"
+ "| | | | | refProjection: ptest, path: 'PathGet [a] PathTraverse [] "
+ "PathIdentity []', intervals: {{{[Const [1], Const [1]]}}}\n"
+ "| | | | candidateIndexes: \n"
+ "| | | BindBlock:\n"
+ "| | RefBlock: \n"
+ "| | Variable [ptest]\n"
+ "| Scan [test3]\n"
+ "| BindBlock:\n"
+ "| [ptest]\n"
+ "| Source []\n"
+ "Union []\n"
+ "| | BindBlock:\n"
+ "| | [ptest]\n"
+ "| | Source []\n"
+ "| Sargable [Complete]\n"
+ "| | | | | requirementsMap: \n"
+ "| | | | | refProjection: ptest, path: 'PathGet [a] PathTraverse [] "
+ "PathIdentity []', intervals: {{{[Const [1], Const [1]]}}}\n"
+ "| | | | candidateIndexes: \n"
+ "| | | BindBlock:\n"
+ "| | RefBlock: \n"
+ "| | Variable [ptest]\n"
+ "| Scan [test2]\n"
+ "| BindBlock:\n"
+ "| [ptest]\n"
+ "| Source []\n"
+ "Sargable [Complete]\n"
+ "| | | | requirementsMap: \n"
+ "| | | | refProjection: ptest, path: 'PathGet [a] PathTraverse [] PathIdentity "
+ "[]', intervals: {{{[Const [1], Const [1]]}}}\n"
+ "| | | candidateIndexes: \n"
+ "| | BindBlock:\n"
+ "| RefBlock: \n"
+ "| Variable [ptest]\n"
+ "Scan [test1]\n"
+ " BindBlock:\n"
+ " [ptest]\n"
+ " Source []\n",
+ latest);
+}
+
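+// Verifies that a union group keeps only the logical properties common to
+// both children: collection and distribution availability are merged, while
+// indexing availability is dropped because the children scan different
+// collections.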
+TEST(LogicalRewriter, UnionPreservesCommonLogicalProps) {
+ ABT scanNode1 = make<ScanNode>("ptest1", "test1");
+ ABT scanNode2 = make<ScanNode>("ptest2", "test2");
+ ABT evalNode1 = make<EvaluationNode>(
+ "a",
+ make<EvalPath>(make<PathGet>("a", make<PathIdentity>()), make<Variable>("ptest1")),
+ std::move(scanNode1));
+
+ ABT evalNode2 = make<EvaluationNode>(
+ "a",
+ make<EvalPath>(make<PathGet>("a", make<PathIdentity>()), make<Variable>("ptest2")),
+ std::move(scanNode2));
+ ABT unionNode = make<UnionNode>(ProjectionNameVector{"a"}, makeSeq(evalNode1, evalNode2));
+
+ ABT rootNode = make<RootNode>(properties::ProjectionRequirement{ProjectionNameVector{"a"}},
+ std::move(unionNode));
+
+ Metadata metadata{{{"test1",
+ ScanDefinition{{},
+ {},
+ {DistributionType::HashPartitioning,
+ makeSeq(make<PathGet>("a", make<PathIdentity>()))}}},
+ {"test2",
+ ScanDefinition{{},
+ {},
+ {DistributionType::HashPartitioning,
+ makeSeq(make<PathGet>("a", make<PathIdentity>()))}}}},
+ 2};
+
+ // Run the reordering rewrite such that the scan produces a hash partition.
+ PrefixId prefixId;
+ OptPhaseManager phaseManager({OptPhaseManager::OptPhase::MemoSubstitutionPhase,
+ OptPhaseManager::OptPhase::MemoExplorationPhase},
+ prefixId,
+ metadata,
+ DebugInfo::kDefaultForTests);
+
+ ABT optimized = rootNode;
+ ASSERT_TRUE(phaseManager.optimize(optimized));
+
+ ASSERT_EXPLAIN_MEMO(
+ "Memo: \n"
+ " groupId: 0\n"
+ " | | Logical properties:\n"
+ " | | cardinalityEstimate: \n"
+ " | | ce: 1000\n"
+ " | | projections: \n"
+ " | | ptest1\n"
+ " | | indexingAvailability: \n"
+ " | | [groupId: 0, scanProjection: ptest1, scanDefName: test1, "
+ "possiblyEqPredsOnly]\n"
+ " | | collectionAvailability: \n"
+ " | | test1\n"
+ " | | distributionAvailability: \n"
+ " | | distribution: \n"
+ " | | type: UnknownPartitioning\n"
+ " | logicalNodes: \n"
+ " | logicalNodeId: 0\n"
+ " | Scan [test1]\n"
+ " | BindBlock:\n"
+ " | [ptest1]\n"
+ " | Source []\n"
+ " physicalNodes: \n"
+ " groupId: 1\n"
+ " | | Logical properties:\n"
+ " | | cardinalityEstimate: \n"
+ " | | ce: 1000\n"
+ " | | requirementCEs: \n"
+ " | | refProjection: ptest1, path: 'PathGet [a] PathIdentity []', ce: "
+ "1000\n"
+ " | | projections: \n"
+ " | | a\n"
+ " | | ptest1\n"
+ " | | indexingAvailability: \n"
+ " | | [groupId: 0, scanProjection: ptest1, scanDefName: test1]\n"
+ " | | collectionAvailability: \n"
+ " | | test1\n"
+ " | | distributionAvailability: \n"
+ " | | distribution: \n"
+ " | | type: Centralized\n"
+ " | | distribution: \n"
+ " | | type: RoundRobin\n"
+ " | | distribution: \n"
+ " | | type: HashPartitioning\n"
+ " | | projections: \n"
+ " | | a\n"
+ " | | distribution: \n"
+ " | | type: UnknownPartitioning\n"
+ " | logicalNodes: \n"
+ " | logicalNodeId: 0\n"
+ " | Sargable [Complete]\n"
+ " | | | | | requirementsMap: \n"
+ " | | | | | refProjection: ptest1, path: 'PathGet [a] "
+ "PathIdentity []', boundProjection: a, intervals: {{{(-inf, +inf)}}}\n"
+ " | | | | candidateIndexes: \n"
+ " | | | BindBlock:\n"
+ " | | | [a]\n"
+ " | | | Source []\n"
+ " | | RefBlock: \n"
+ " | | Variable [ptest1]\n"
+ " | MemoLogicalDelegator [groupId: 0]\n"
+ " physicalNodes: \n"
+ " groupId: 2\n"
+ " | | Logical properties:\n"
+ " | | cardinalityEstimate: \n"
+ " | | ce: 1000\n"
+ " | | projections: \n"
+ " | | ptest2\n"
+ " | | indexingAvailability: \n"
+ " | | [groupId: 2, scanProjection: ptest2, scanDefName: test2, "
+ "possiblyEqPredsOnly]\n"
+ " | | collectionAvailability: \n"
+ " | | test2\n"
+ " | | distributionAvailability: \n"
+ " | | distribution: \n"
+ " | | type: UnknownPartitioning\n"
+ " | logicalNodes: \n"
+ " | logicalNodeId: 0\n"
+ " | Scan [test2]\n"
+ " | BindBlock:\n"
+ " | [ptest2]\n"
+ " | Source []\n"
+ " physicalNodes: \n"
+ " groupId: 3\n"
+ " | | Logical properties:\n"
+ " | | cardinalityEstimate: \n"
+ " | | ce: 1000\n"
+ " | | requirementCEs: \n"
+ " | | refProjection: ptest2, path: 'PathGet [a] PathIdentity []', ce: "
+ "1000\n"
+ " | | projections: \n"
+ " | | a\n"
+ " | | ptest2\n"
+ " | | indexingAvailability: \n"
+ " | | [groupId: 2, scanProjection: ptest2, scanDefName: test2]\n"
+ " | | collectionAvailability: \n"
+ " | | test2\n"
+ " | | distributionAvailability: \n"
+ " | | distribution: \n"
+ " | | type: Centralized\n"
+ " | | distribution: \n"
+ " | | type: RoundRobin\n"
+ " | | distribution: \n"
+ " | | type: HashPartitioning\n"
+ " | | projections: \n"
+ " | | a\n"
+ " | | distribution: \n"
+ " | | type: UnknownPartitioning\n"
+ " | logicalNodes: \n"
+ " | logicalNodeId: 0\n"
+ " | Sargable [Complete]\n"
+ " | | | | | requirementsMap: \n"
+ " | | | | | refProjection: ptest2, path: 'PathGet [a] "
+ "PathIdentity []', boundProjection: a, intervals: {{{(-inf, +inf)}}}\n"
+ " | | | | candidateIndexes: \n"
+ " | | | BindBlock:\n"
+ " | | | [a]\n"
+ " | | | Source []\n"
+ " | | RefBlock: \n"
+ " | | Variable [ptest2]\n"
+ " | MemoLogicalDelegator [groupId: 2]\n"
+ " physicalNodes: \n"
+ " groupId: 4\n"
+ " | | Logical properties:\n"
+ " | | cardinalityEstimate: \n"
+ " | | ce: 2000\n"
+ " | | projections: \n"
+ " | | a\n"
+ " | | collectionAvailability: \n"
+ " | | test1\n"
+ " | | test2\n"
+ " | | distributionAvailability: \n"
+ " | | distribution: \n"
+ " | | type: Centralized\n"
+ " | | distribution: \n"
+ " | | type: RoundRobin\n"
+ " | | distribution: \n"
+ " | | type: HashPartitioning\n"
+ " | | projections: \n"
+ " | | a\n"
+ " | | distribution: \n"
+ " | | type: UnknownPartitioning\n"
+ " | logicalNodes: \n"
+ " | logicalNodeId: 0\n"
+ " | Union []\n"
+ " | | | BindBlock:\n"
+ " | | | [a]\n"
+ " | | | Source []\n"
+ " | | MemoLogicalDelegator [groupId: 3]\n"
+ " | MemoLogicalDelegator [groupId: 1]\n"
+ " physicalNodes: \n"
+ " groupId: 5\n"
+ " | | Logical properties:\n"
+ " | | cardinalityEstimate: \n"
+ " | | ce: 2000\n"
+ " | | projections: \n"
+ " | | a\n"
+ " | | collectionAvailability: \n"
+ " | | test1\n"
+ " | | test2\n"
+ " | | distributionAvailability: \n"
+ " | | distribution: \n"
+ " | | type: Centralized\n"
+ " | | distribution: \n"
+ " | | type: RoundRobin\n"
+ " | | distribution: \n"
+ " | | type: HashPartitioning\n"
+ " | | projections: \n"
+ " | | a\n"
+ " | | distribution: \n"
+ " | | type: UnknownPartitioning\n"
+ " | logicalNodes: \n"
+ " | logicalNodeId: 0\n"
+ " | Root []\n"
+ " | | | projections: \n"
+ " | | | a\n"
+ " | | RefBlock: \n"
+ " | | Variable [a]\n"
+ " | MemoLogicalDelegator [groupId: 4]\n"
+ " physicalNodes: \n",
+ phaseManager.getMemo());
+}
+
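+// Verifies the heuristic per-requirement estimates: with a base CE of 1000,
+// each equality predicate is estimated at 100, and the combined SargableNode
+// group at 10.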
+TEST(LogicalRewriter, SargableCE) {
+ using namespace properties;
+ PrefixId prefixId;
+
+ ABT scanNode = make<ScanNode>("ptest", "test");
+
+ ABT filterANode = make<FilterNode>(
+ make<EvalFilter>(make<PathGet>("a", make<PathCompare>(Operations::Eq, Constant::int64(1))),
+ make<Variable>("ptest")),
+ std::move(scanNode));
+ ABT filterBNode = make<FilterNode>(
+ make<EvalFilter>(make<PathGet>("b", make<PathCompare>(Operations::Eq, Constant::int64(2))),
+ make<Variable>("ptest")),
+ std::move(filterANode));
+
+ ABT rootNode = make<RootNode>(properties::ProjectionRequirement{ProjectionNameVector{"ptest"}},
+ std::move(filterBNode));
+
+ OptPhaseManager phaseManager({OptPhaseManager::OptPhase::MemoSubstitutionPhase,
+ OptPhaseManager::OptPhase::MemoExplorationPhase},
+ prefixId,
+ {{{"test", {{}, {}}}}},
+ DebugInfo::kDefaultForTests);
+ ABT latest = std::move(rootNode);
+ ASSERT_TRUE(phaseManager.optimize(latest));
+
+ // Displays SargableNode-specific per-key estimates.
+ ASSERT_EXPLAIN_MEMO(
+ "Memo: \n"
+ " groupId: 0\n"
+ " | | Logical properties:\n"
+ " | | cardinalityEstimate: \n"
+ " | | ce: 1000\n"
+ " | | projections: \n"
+ " | | ptest\n"
+ " | | indexingAvailability: \n"
+ " | | [groupId: 0, scanProjection: ptest, scanDefName: test, "
+ "possiblyEqPredsOnly]\n"
+ " | | collectionAvailability: \n"
+ " | | test\n"
+ " | | distributionAvailability: \n"
+ " | | distribution: \n"
+ " | | type: Centralized\n"
+ " | logicalNodes: \n"
+ " | logicalNodeId: 0\n"
+ " | Scan [test]\n"
+ " | BindBlock:\n"
+ " | [ptest]\n"
+ " | Source []\n"
+ " physicalNodes: \n"
+ " groupId: 1\n"
+ " | | Logical properties:\n"
+ " | | cardinalityEstimate: \n"
+ " | | ce: 10\n"
+ " | | requirementCEs: \n"
+ " | | refProjection: ptest, path: 'PathGet [a] PathIdentity []', ce: "
+ "100\n"
+ " | | refProjection: ptest, path: 'PathGet [b] PathIdentity []', ce: "
+ "100\n"
+ " | | projections: \n"
+ " | | ptest\n"
+ " | | indexingAvailability: \n"
+ " | | [groupId: 0, scanProjection: ptest, scanDefName: test, "
+ "possiblyEqPredsOnly]\n"
+ " | | collectionAvailability: \n"
+ " | | test\n"
+ " | | distributionAvailability: \n"
+ " | | distribution: \n"
+ " | | type: Centralized\n"
+ " | logicalNodes: \n"
+ " | logicalNodeId: 0\n"
+ " | Sargable [Complete]\n"
+ " | | | | | requirementsMap: \n"
+ " | | | | | refProjection: ptest, path: 'PathGet [a] PathIdentity "
+ "[]', intervals: {{{[Const [1], Const [1]]}}}\n"
+ " | | | | | refProjection: ptest, path: 'PathGet [b] PathIdentity "
+ "[]', intervals: {{{[Const [2], Const [2]]}}}\n"
+ " | | | | candidateIndexes: \n"
+ " | | | BindBlock:\n"
+ " | | RefBlock: \n"
+ " | | Variable [ptest]\n"
+ " | MemoLogicalDelegator [groupId: 0]\n"
+ " physicalNodes: \n"
+ " groupId: 2\n"
+ " | | Logical properties:\n"
+ " | | cardinalityEstimate: \n"
+ " | | ce: 10\n"
+ " | | projections: \n"
+ " | | ptest\n"
+ " | | indexingAvailability: \n"
+ " | | [groupId: 0, scanProjection: ptest, scanDefName: test, "
+ "possiblyEqPredsOnly]\n"
+ " | | collectionAvailability: \n"
+ " | | test\n"
+ " | | distributionAvailability: \n"
+ " | | distribution: \n"
+ " | | type: Centralized\n"
+ " | logicalNodes: \n"
+ " | logicalNodeId: 0\n"
+ " | Root []\n"
+ " | | | projections: \n"
+ " | | | ptest\n"
+ " | | RefBlock: \n"
+ " | | Variable [ptest]\n"
+ " | MemoLogicalDelegator [groupId: 1]\n"
+ " physicalNodes: \n",
+ phaseManager.getMemo());
+}
+
+} // namespace
+} // namespace mongo::optimizer
diff --git a/src/mongo/db/query/optimizer/metadata.cpp b/src/mongo/db/query/optimizer/metadata.cpp
new file mode 100644
index 00000000000..d81cec8b64c
--- /dev/null
+++ b/src/mongo/db/query/optimizer/metadata.cpp
@@ -0,0 +1,161 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/db/query/optimizer/metadata.h"
+
+#include "mongo/db/query/optimizer/node.h"
+
+
+namespace mongo::optimizer {
+
+DistributionAndPaths::DistributionAndPaths(DistributionType type)
+ : DistributionAndPaths(type, {}) {}
+
+DistributionAndPaths::DistributionAndPaths(DistributionType type, ABTVector paths)
+ : _type(type), _paths(std::move(paths)) {
+ uassert(6624080,
+ "Invalid distribution type",
+ _paths.empty() || _type == DistributionType::HashPartitioning ||
+ _type == DistributionType::RangePartitioning);
+}
+
+bool IndexCollationEntry::operator==(const IndexCollationEntry& other) const {
+ return _path == other._path && _op == other._op;
+}
+
+IndexCollationEntry::IndexCollationEntry(ABT path, CollationOp op)
+ : _path(std::move(path)), _op(op) {}
+
+IndexDefinition::IndexDefinition(IndexCollationSpec collationSpec, bool isMultiKey)
+ : IndexDefinition(std::move(collationSpec), isMultiKey, {DistributionType::Centralized}, {}) {}
+
+IndexDefinition::IndexDefinition(IndexCollationSpec collationSpec,
+ bool isMultiKey,
+ DistributionAndPaths distributionAndPaths,
+ PartialSchemaRequirements partialReqMap)
+ : IndexDefinition(std::move(collationSpec),
+ 2 /*version*/,
+ 0 /*orderingBits*/,
+ isMultiKey,
+ std::move(distributionAndPaths),
+ std::move(partialReqMap)) {}
+
+IndexDefinition::IndexDefinition(IndexCollationSpec collationSpec,
+ int64_t version,
+ uint32_t orderingBits,
+ bool isMultiKey,
+ DistributionAndPaths distributionAndPaths,
+ PartialSchemaRequirements partialReqMap)
+ : _collationSpec(std::move(collationSpec)),
+ _version(version),
+ _orderingBits(orderingBits),
+ _isMultiKey(isMultiKey),
+ _distributionAndPaths(distributionAndPaths),
+ _partialReqMap(std::move(partialReqMap)) {}
+
+const IndexCollationSpec& IndexDefinition::getCollationSpec() const {
+ return _collationSpec;
+}
+
+int64_t IndexDefinition::getVersion() const {
+ return _version;
+}
+
+uint32_t IndexDefinition::getOrdering() const {
+ return _orderingBits;
+}
+
+bool IndexDefinition::isMultiKey() const {
+ return _isMultiKey;
+}
+
+const DistributionAndPaths& IndexDefinition::getDistributionAndPaths() const {
+ return _distributionAndPaths;
+}
+
+const PartialSchemaRequirements& IndexDefinition::getPartialReqMap() const {
+ return _partialReqMap;
+}
+
+ScanDefinition::ScanDefinition() : ScanDefinition(OptionsMapType{}, {}) {}
+
+ScanDefinition::ScanDefinition(OptionsMapType options,
+ opt::unordered_map<std::string, IndexDefinition> indexDefs)
+ : ScanDefinition(std::move(options),
+ std::move(indexDefs),
+ {DistributionType::Centralized},
+ true /*exists*/) {}
+
+ScanDefinition::ScanDefinition(OptionsMapType options,
+ opt::unordered_map<std::string, IndexDefinition> indexDefs,
+ DistributionAndPaths distributionAndPaths,
+ const bool exists,
+ const CEType ce)
+ : _options(std::move(options)),
+ _distributionAndPaths(std::move(distributionAndPaths)),
+ _indexDefs(std::move(indexDefs)),
+ _exists(exists),
+ _ce(ce) {}
+
+const ScanDefinition::OptionsMapType& ScanDefinition::getOptionsMap() const {
+ return _options;
+}
+
+const DistributionAndPaths& ScanDefinition::getDistributionAndPaths() const {
+ return _distributionAndPaths;
+}
+
+const opt::unordered_map<std::string, IndexDefinition>& ScanDefinition::getIndexDefs() const {
+ return _indexDefs;
+}
+
+opt::unordered_map<std::string, IndexDefinition>& ScanDefinition::getIndexDefs() {
+ return _indexDefs;
+}
+
+bool ScanDefinition::exists() const {
+ return _exists;
+}
+
+CEType ScanDefinition::getCE() const {
+ return _ce;
+}
+
+Metadata::Metadata(opt::unordered_map<std::string, ScanDefinition> scanDefs)
+ : Metadata(std::move(scanDefs), 1 /*numberOfPartitions*/) {}
+
+Metadata::Metadata(opt::unordered_map<std::string, ScanDefinition> scanDefs,
+ size_t numberOfPartitions)
+ : _scanDefs(std::move(scanDefs)), _numberOfPartitions(numberOfPartitions) {}
+
+bool Metadata::isParallelExecution() const {
+ return _numberOfPartitions > 1;
+}
+
+} // namespace mongo::optimizer
diff --git a/src/mongo/db/query/optimizer/metadata.h b/src/mongo/db/query/optimizer/metadata.h
new file mode 100644
index 00000000000..62f18327d40
--- /dev/null
+++ b/src/mongo/db/query/optimizer/metadata.h
@@ -0,0 +1,160 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include <map>
+
+#include "mongo/db/query/optimizer/index_bounds.h"
+
+
+namespace mongo::optimizer {
+
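+/**
+ * Describes how data is distributed: a distribution type plus, for hash and
+ * range partitioning, the paths on which the data is partitioned.
+ */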
+struct DistributionAndPaths {
+ DistributionAndPaths(DistributionType type);
+ DistributionAndPaths(DistributionType type, ABTVector paths);
+
+ DistributionType _type;
+ ABTVector _paths;
+};
+
+
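+/**
+ * A single component of an index collation: an indexed path and its collation
+ * operation.
+ */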
+struct IndexCollationEntry {
+ IndexCollationEntry(ABT path, CollationOp op);
+
+ bool operator==(const IndexCollationEntry& other) const;
+
+ ABT _path;
+ CollationOp _op;
+};
+
+using IndexCollationSpec = std::vector<IndexCollationEntry>;
+
+/**
+ * Defines an available system index.
+ */
+class IndexDefinition {
+public:
+ // For testing.
+ IndexDefinition(IndexCollationSpec collationSpec, bool isMultiKey);
+
+ IndexDefinition(IndexCollationSpec collationSpec,
+ bool isMultiKey,
+ DistributionAndPaths distributionAndPaths,
+ PartialSchemaRequirements partialReqMap);
+
+ IndexDefinition(IndexCollationSpec collationSpec,
+ int64_t version,
+ uint32_t orderingBits,
+ bool isMultiKey,
+ DistributionAndPaths distributionAndPaths,
+ PartialSchemaRequirements partialReqMap);
+
+ const IndexCollationSpec& getCollationSpec() const;
+
+ int64_t getVersion() const;
+ uint32_t getOrdering() const;
+ bool isMultiKey() const;
+
+ const DistributionAndPaths& getDistributionAndPaths() const;
+
+ const PartialSchemaRequirements& getPartialReqMap() const;
+
+private:
+ const IndexCollationSpec _collationSpec;
+
+ const int64_t _version;
+ const uint32_t _orderingBits;
+ const bool _isMultiKey;
+
+ const DistributionAndPaths _distributionAndPaths;
+
+ // Requirements map for partial filter expression.
+ const PartialSchemaRequirements _partialReqMap;
+};
+
+// Specifies the parameters of a scan node, such as the collection name or the
+// file the collection is read from.
+class ScanDefinition {
+public:
+ using OptionsMapType = opt::unordered_map<std::string, std::string>;
+
+ ScanDefinition();
+ ScanDefinition(OptionsMapType options,
+ opt::unordered_map<std::string, IndexDefinition> indexDefs);
+ ScanDefinition(OptionsMapType options,
+ opt::unordered_map<std::string, IndexDefinition> indexDefs,
+ DistributionAndPaths distributionAndPaths,
+ bool exists = true,
+ CEType ce = -1.0);
+
+ const OptionsMapType& getOptionsMap() const;
+
+ const DistributionAndPaths& getDistributionAndPaths() const;
+
+ const opt::unordered_map<std::string, IndexDefinition>& getIndexDefs() const;
+ opt::unordered_map<std::string, IndexDefinition>& getIndexDefs();
+
+ bool exists() const;
+
+ CEType getCE() const;
+
+private:
+ OptionsMapType _options;
+ DistributionAndPaths _distributionAndPaths;
+
+ /**
+ * Indexes associated with this collection.
+ */
+ opt::unordered_map<std::string, IndexDefinition> _indexDefs;
+
+ /**
+ * True if the collection exists.
+ */
+ bool _exists;
+
+ // If positive, estimated number of docs in the collection.
+ CEType _ce;
+};
+
+struct Metadata {
+ Metadata(opt::unordered_map<std::string, ScanDefinition> scanDefs);
+ Metadata(opt::unordered_map<std::string, ScanDefinition> scanDefs, size_t numberOfPartitions);
+
+ opt::unordered_map<std::string, ScanDefinition> _scanDefs;
+
+ // Degree of parallelism.
+ size_t _numberOfPartitions;
+
+ bool isParallelExecution() const;
+
+ // TODO: generalize cluster spec.
+};
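+
+// A minimal construction sketch, mirroring the unit tests: a centralized
+// catalog with a single collection "test", and a two-partition variant used
+// to exercise parallel rewrites.
+//   Metadata centralized{{{"test", {}}}};
+//   Metadata partitioned{{{"test", {}}}, 2 /*numberOfPartitions*/};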
+
+} // namespace mongo::optimizer
diff --git a/src/mongo/db/query/optimizer/node.cpp b/src/mongo/db/query/optimizer/node.cpp
new file mode 100644
index 00000000000..b3d656930c3
--- /dev/null
+++ b/src/mongo/db/query/optimizer/node.cpp
@@ -0,0 +1,775 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include <functional>
+#include <stack>
+
+#include "mongo/db/query/optimizer/node.h"
+#include "mongo/db/query/optimizer/utils/utils.h"
+
+namespace mongo::optimizer {
+
+/**
+ * A simple helper that creates a vector of Sources and binds names.
+ */
+static ABT buildSimpleBinder(const ProjectionNameVector& names) {
+ ABTVector sources;
+ for (size_t idx = 0; idx < names.size(); ++idx) {
+ sources.emplace_back(make<Source>());
+ }
+
+ return make<ExpressionBinder>(names, std::move(sources));
+}
+
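+/**
+ * Creates a References node with one Variable per projection, in a
+ * deterministic (ordered) sequence.
+ */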
+static ABT buildReferences(const ProjectionNameSet& projections) {
+ ABTVector variables;
+ ProjectionNameOrderedSet ordered = convertToOrderedSet(projections);
+ for (const ProjectionName& projection : ordered) {
+ variables.emplace_back(make<Variable>(projection));
+ }
+ return make<References>(std::move(variables));
+}
+
+ScanNode::ScanNode(ProjectionName projectionName, std::string scanDefName)
+ : Base(buildSimpleBinder({std::move(projectionName)})), _scanDefName(std::move(scanDefName)) {}
+
+const ProjectionName& ScanNode::getProjectionName() const {
+ return binder().names()[0];
+}
+
+const ProjectionType& ScanNode::getProjection() const {
+ return binder().exprs()[0];
+}
+
+const std::string& ScanNode::getScanDefName() const {
+ return _scanDefName;
+}
+
+bool ScanNode::operator==(const ScanNode& other) const {
+ return getProjectionName() == other.getProjectionName() && _scanDefName == other._scanDefName;
+}
+
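+/**
+ * Collects the projection names bound by a scan: the optional RID and root
+ * projections, followed by one projection per scanned field.
+ */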
+static ProjectionNameVector extractProjectionNamesForScan(
+ const FieldProjectionMap& fieldProjectionMap) {
+ ProjectionNameVector result;
+
+ if (!fieldProjectionMap._ridProjection.empty()) {
+ result.push_back(fieldProjectionMap._ridProjection);
+ }
+ if (!fieldProjectionMap._rootProjection.empty()) {
+ result.push_back(fieldProjectionMap._rootProjection);
+ }
+ for (const auto& entry : fieldProjectionMap._fieldProjections) {
+ result.push_back(entry.second);
+ }
+
+ return result;
+}
+
+PhysicalScanNode::PhysicalScanNode(FieldProjectionMap fieldProjectionMap,
+ std::string scanDefName,
+ bool useParallelScan)
+ : Base(buildSimpleBinder(extractProjectionNamesForScan(fieldProjectionMap))),
+ _fieldProjectionMap(std::move(fieldProjectionMap)),
+ _scanDefName(std::move(scanDefName)),
+ _useParallelScan(useParallelScan) {}
+
+bool PhysicalScanNode::operator==(const PhysicalScanNode& other) const {
+ return _fieldProjectionMap == other._fieldProjectionMap && _scanDefName == other._scanDefName &&
+ _useParallelScan == other._useParallelScan;
+}
+
+const FieldProjectionMap& PhysicalScanNode::getFieldProjectionMap() const {
+ return _fieldProjectionMap;
+}
+
+const std::string& PhysicalScanNode::getScanDefName() const {
+ return _scanDefName;
+}
+
+bool PhysicalScanNode::useParallelScan() const {
+ return _useParallelScan;
+}
+
+ValueScanNode::ValueScanNode(ProjectionNameVector projections)
+ : ValueScanNode(std::move(projections), Constant::emptyArray()) {}
+
+ValueScanNode::ValueScanNode(ProjectionNameVector projections, ABT valueArray)
+ : Base(buildSimpleBinder(std::move(projections))), _valueArray(std::move(valueArray)) {
+ const auto constPtr = _valueArray.cast<Constant>();
+ uassert(6624081, "Expected a constant", constPtr != nullptr);
+
+ const auto [tag, val] = constPtr->get();
+ uassert(6624082, "Expected an array constant.", tag == sbe::value::TypeTags::Array);
+
+ const auto arr = sbe::value::getArrayView(val);
+ _arraySize = arr->size();
+ const size_t projectionCount = binder().names().size();
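+    // Each row of the value array must itself be an array with one element
+    // per bound projection.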
+ for (size_t i = 0; i < _arraySize; i++) {
+ const auto [tag1, val1] = arr->getAt(i);
+ uassert(6624083, "Expected an array element.", tag1 == sbe::value::TypeTags::Array);
+ const size_t innerSize = sbe::value::getArrayView(val1)->size();
+ uassert(6624084, "Invalid array size.", innerSize == projectionCount);
+ }
+}
+
+bool ValueScanNode::operator==(const ValueScanNode& other) const {
+ return binder() == other.binder() && _arraySize == other._arraySize &&
+ _valueArray == other._valueArray;
+}
+
+const ABT& ValueScanNode::getValueArray() const {
+ return _valueArray;
+}
+
+size_t ValueScanNode::getArraySize() const {
+ return _arraySize;
+}
+
+CoScanNode::CoScanNode() : Base() {}
+
+bool CoScanNode::operator==(const CoScanNode& other) const {
+ return true;
+}
+
+IndexScanNode::IndexScanNode(FieldProjectionMap fieldProjectionMap, IndexSpecification indexSpec)
+ : Base(buildSimpleBinder(extractProjectionNamesForScan(fieldProjectionMap))),
+ _fieldProjectionMap(std::move(fieldProjectionMap)),
+ _indexSpec(std::move(indexSpec)) {}
+
+bool IndexScanNode::operator==(const IndexScanNode& other) const {
+    // The scan spec does not participate; the index spec by itself should determine equality.
+ return _fieldProjectionMap == other._fieldProjectionMap && _indexSpec == other._indexSpec;
+}
+
+const FieldProjectionMap& IndexScanNode::getFieldProjectionMap() const {
+ return _fieldProjectionMap;
+}
+
+const IndexSpecification& IndexScanNode::getIndexSpecification() const {
+ return _indexSpec;
+}
+
+SeekNode::SeekNode(ProjectionName ridProjectionName,
+ FieldProjectionMap fieldProjectionMap,
+ std::string scanDefName)
+ : Base(buildSimpleBinder(extractProjectionNamesForScan(fieldProjectionMap)),
+ make<References>(ProjectionNameVector{ridProjectionName})),
+ _ridProjectionName(std::move(ridProjectionName)),
+ _fieldProjectionMap(std::move(fieldProjectionMap)),
+ _scanDefName(std::move(scanDefName)) {}
+
+bool SeekNode::operator==(const SeekNode& other) const {
+ return _ridProjectionName == other._ridProjectionName &&
+ _fieldProjectionMap == other._fieldProjectionMap && _scanDefName == other._scanDefName;
+}
+
+const FieldProjectionMap& SeekNode::getFieldProjectionMap() const {
+ return _fieldProjectionMap;
+}
+
+const std::string& SeekNode::getScanDefName() const {
+ return _scanDefName;
+}
+
+const ProjectionName& SeekNode::getRIDProjectionName() const {
+ return _ridProjectionName;
+}
+
+MemoLogicalDelegatorNode::MemoLogicalDelegatorNode(const GroupIdType groupId)
+ : Base(), _groupId(groupId) {}
+
+GroupIdType MemoLogicalDelegatorNode::getGroupId() const {
+ return _groupId;
+}
+
+bool MemoLogicalDelegatorNode::operator==(const MemoLogicalDelegatorNode& other) const {
+ return _groupId == other._groupId;
+}
+
+MemoPhysicalDelegatorNode::MemoPhysicalDelegatorNode(const MemoPhysicalNodeId nodeId)
+ : Base(), _nodeId(nodeId) {}
+
+bool MemoPhysicalDelegatorNode::operator==(const MemoPhysicalDelegatorNode& other) const {
+ return _nodeId == other._nodeId;
+}
+
+MemoPhysicalNodeId MemoPhysicalDelegatorNode::getNodeId() const {
+ return _nodeId;
+}
+
+FilterNode::FilterNode(FilterType filter, ABT child) : Base(std::move(child), std::move(filter)) {
+ assertExprSort(getFilter());
+ assertNodeSort(getChild());
+}
+
+bool FilterNode::operator==(const FilterNode& other) const {
+ return getFilter() == other.getFilter() && getChild() == other.getChild();
+}
+
+const FilterType& FilterNode::getFilter() const {
+ return get<1>();
+}
+
+FilterType& FilterNode::getFilter() {
+ return get<1>();
+}
+
+const ABT& FilterNode::getChild() const {
+ return get<0>();
+}
+
+ABT& FilterNode::getChild() {
+ return get<0>();
+}
+
+EvaluationNode::EvaluationNode(ProjectionName projectionName, ProjectionType projection, ABT child)
+ : Base(std::move(child),
+ make<ExpressionBinder>(std::move(projectionName), std::move(projection))) {
+ assertNodeSort(getChild());
+}
+
+bool EvaluationNode::operator==(const EvaluationNode& other) const {
+ return binder() == other.binder() && getProjection() == other.getProjection() &&
+ getChild() == other.getChild();
+}
+
+RIDIntersectNode::RIDIntersectNode(ProjectionName scanProjectionName,
+ const bool hasLeftIntervals,
+ const bool hasRightIntervals,
+ ABT leftChild,
+ ABT rightChild)
+ : Base(std::move(leftChild), std::move(rightChild)),
+ _scanProjectionName(std::move(scanProjectionName)),
+ _hasLeftIntervals(hasLeftIntervals),
+ _hasRightIntervals(hasRightIntervals) {
+ assertNodeSort(getLeftChild());
+ assertNodeSort(getRightChild());
+}
+
+const ABT& RIDIntersectNode::getLeftChild() const {
+ return get<0>();
+}
+
+ABT& RIDIntersectNode::getLeftChild() {
+ return get<0>();
+}
+
+const ABT& RIDIntersectNode::getRightChild() const {
+ return get<1>();
+}
+
+ABT& RIDIntersectNode::getRightChild() {
+ return get<1>();
+}
+
+bool RIDIntersectNode::operator==(const RIDIntersectNode& other) const {
+ return _scanProjectionName == other._scanProjectionName &&
+ _hasLeftIntervals == other._hasLeftIntervals &&
+ _hasRightIntervals == other._hasRightIntervals && getLeftChild() == other.getLeftChild() &&
+ getRightChild() == other.getRightChild();
+}
+
+const ProjectionName& RIDIntersectNode::getScanProjectionName() const {
+ return _scanProjectionName;
+}
+
+bool RIDIntersectNode::hasLeftIntervals() const {
+ return _hasLeftIntervals;
+}
+
+bool RIDIntersectNode::hasRightIntervals() const {
+ return _hasRightIntervals;
+}
+
+static ProjectionNameVector createSargableBindings(const PartialSchemaRequirements& reqMap) {
+ ProjectionNameVector result;
+ for (const auto& entry : reqMap) {
+ if (entry.second.hasBoundProjectionName()) {
+ result.push_back(entry.second.getBoundProjectionName());
+ }
+ }
+ return result;
+}
+
+static ProjectionNameVector createSargableReferences(const PartialSchemaRequirements& reqMap) {
+ ProjectionNameOrderPreservingSet result;
+ for (const auto& entry : reqMap) {
+ result.emplace_back(entry.first._projectionName);
+ }
+ return result.getVector();
+}
+
+SargableNode::SargableNode(PartialSchemaRequirements reqMap,
+ CandidateIndexMap candidateIndexMap,
+ const IndexReqTarget target,
+ ABT child)
+ : Base(std::move(child),
+ buildSimpleBinder(createSargableBindings(reqMap)),
+ make<References>(createSargableReferences(reqMap))),
+ _reqMap(std::move(reqMap)),
+ _candidateIndexMap(std::move(candidateIndexMap)),
+ _target(target) {
+ assertNodeSort(getChild());
+ uassert(6624085, "Empty requirements map", !_reqMap.empty());
+ // We currently use a 64-bit mask when splitting into left and right requirements.
+ uassert(6624086, "Requirements map too large", _reqMap.size() < 64);
+
+ // Assert merged map does not contain duplicate bound projections.
+ ProjectionNameSet boundsProjectionNameSet;
+ for (const auto& entry : _reqMap) {
+ if (entry.second.hasBoundProjectionName() &&
+ !boundsProjectionNameSet.insert(entry.second.getBoundProjectionName()).second) {
+ uasserted(6624087, "Duplicate bound projection");
+ }
+ }
+
+ // Assert there are no references to internally bound projections.
+ for (const auto& entry : _reqMap) {
+ if (boundsProjectionNameSet.find(entry.first._projectionName) !=
+ boundsProjectionNameSet.cend()) {
+ uasserted(6624088, "We are binding to an internal projection");
+ }
+ }
+}
+
+bool SargableNode::operator==(const SargableNode& other) const {
+ return _reqMap == other._reqMap && _candidateIndexMap == other._candidateIndexMap &&
+ _target == other._target && getChild() == other.getChild();
+}
+
+const PartialSchemaRequirements& SargableNode::getReqMap() const {
+ return _reqMap;
+}
+
+const CandidateIndexMap& SargableNode::getCandidateIndexMap() const {
+ return _candidateIndexMap;
+}
+
+IndexReqTarget SargableNode::getTarget() const {
+ return _target;
+}
+
+BinaryJoinNode::BinaryJoinNode(JoinType joinType,
+ ProjectionNameSet correlatedProjectionNames,
+ FilterType filter,
+ ABT leftChild,
+ ABT rightChild)
+ : Base(std::move(leftChild), std::move(rightChild), std::move(filter)),
+ _joinType(joinType),
+ _correlatedProjectionNames(std::move(correlatedProjectionNames)) {
+ assertExprSort(getFilter());
+ assertNodeSort(getLeftChild());
+ assertNodeSort(getRightChild());
+}
+
+JoinType BinaryJoinNode::getJoinType() const {
+ return _joinType;
+}
+
+const ProjectionNameSet& BinaryJoinNode::getCorrelatedProjectionNames() const {
+ return _correlatedProjectionNames;
+}
+
+bool BinaryJoinNode::operator==(const BinaryJoinNode& other) const {
+ return _joinType == other._joinType &&
+ _correlatedProjectionNames == other._correlatedProjectionNames &&
+ getLeftChild() == other.getLeftChild() && getRightChild() == other.getRightChild();
+}
+
+const ABT& BinaryJoinNode::getLeftChild() const {
+ return get<0>();
+}
+
+ABT& BinaryJoinNode::getLeftChild() {
+ return get<0>();
+}
+
+const ABT& BinaryJoinNode::getRightChild() const {
+ return get<1>();
+}
+
+ABT& BinaryJoinNode::getRightChild() {
+ return get<1>();
+}
+
+const ABT& BinaryJoinNode::getFilter() const {
+ return get<2>();
+}
+
+static ABT buildHashJoinReferences(const ProjectionNameVector& leftKeys,
+ const ProjectionNameVector& rightKeys) {
+ ABTVector variables;
+ for (const ProjectionName& projection : leftKeys) {
+ variables.emplace_back(make<Variable>(projection));
+ }
+ for (const ProjectionName& projection : rightKeys) {
+ variables.emplace_back(make<Variable>(projection));
+ }
+
+ return make<References>(std::move(variables));
+}
+
+HashJoinNode::HashJoinNode(JoinType joinType,
+ ProjectionNameVector leftKeys,
+ ProjectionNameVector rightKeys,
+ ABT leftChild,
+ ABT rightChild)
+ : Base(std::move(leftChild),
+ std::move(rightChild),
+ buildHashJoinReferences(leftKeys, rightKeys)),
+ _joinType(joinType),
+ _leftKeys(std::move(leftKeys)),
+ _rightKeys(std::move(rightKeys)) {
+ uassert(
+ 6624089, "Invalid key sizes", !_leftKeys.empty() && _leftKeys.size() == _rightKeys.size());
+ assertNodeSort(getLeftChild());
+ assertNodeSort(getRightChild());
+}
+
+bool HashJoinNode::operator==(const HashJoinNode& other) const {
+ return _joinType == other._joinType && _leftKeys == other._leftKeys &&
+ _rightKeys == other._rightKeys && getLeftChild() == other.getLeftChild() &&
+ getRightChild() == other.getRightChild();
+}
+
+JoinType HashJoinNode::getJoinType() const {
+ return _joinType;
+}
+
+const ProjectionNameVector& HashJoinNode::getLeftKeys() const {
+ return _leftKeys;
+}
+
+const ProjectionNameVector& HashJoinNode::getRightKeys() const {
+ return _rightKeys;
+}
+
+const ABT& HashJoinNode::getLeftChild() const {
+ return get<0>();
+}
+
+ABT& HashJoinNode::getLeftChild() {
+ return get<0>();
+}
+
+const ABT& HashJoinNode::getRightChild() const {
+ return get<1>();
+}
+
+ABT& HashJoinNode::getRightChild() {
+ return get<1>();
+}
+
+MergeJoinNode::MergeJoinNode(ProjectionNameVector leftKeys,
+ ProjectionNameVector rightKeys,
+ std::vector<CollationOp> collation,
+ ABT leftChild,
+ ABT rightChild)
+ : Base(std::move(leftChild),
+ std::move(rightChild),
+ buildHashJoinReferences(leftKeys, rightKeys)),
+ _collation(std::move(collation)),
+ _leftKeys(std::move(leftKeys)),
+ _rightKeys(std::move(rightKeys)) {
+ uassert(
+ 6624090, "Invalid key sizes", !_leftKeys.empty() && _leftKeys.size() == _rightKeys.size());
+ uassert(6624091, "Invalid collation size", _collation.size() == _leftKeys.size());
+ assertNodeSort(getLeftChild());
+ assertNodeSort(getRightChild());
+}
+
+bool MergeJoinNode::operator==(const MergeJoinNode& other) const {
+ return _leftKeys == other._leftKeys && _rightKeys == other._rightKeys &&
+ _collation == other._collation && getLeftChild() == other.getLeftChild() &&
+ getRightChild() == other.getRightChild();
+}
+
+const ProjectionNameVector& MergeJoinNode::getLeftKeys() const {
+ return _leftKeys;
+}
+
+const ProjectionNameVector& MergeJoinNode::getRightKeys() const {
+ return _rightKeys;
+}
+
+const std::vector<CollationOp>& MergeJoinNode::getCollation() const {
+ return _collation;
+}
+
+const ABT& MergeJoinNode::getLeftChild() const {
+ return get<0>();
+}
+
+ABT& MergeJoinNode::getLeftChild() {
+ return get<0>();
+}
+
+const ABT& MergeJoinNode::getRightChild() const {
+ return get<1>();
+}
+
+ABT& MergeJoinNode::getRightChild() {
+ return get<1>();
+}
+
+/**
+ * A helper that builds the References object of a UnionNode for reference tracking purposes.
+ *
+ * Example: the union outputs 3 projections (A, B, C) and has 4 children. Then the References
+ * object is a vector of variables A,B,C,A,B,C,A,B,C,A,B,C: one group of variables per child.
+ */
+static ABT buildUnionReferences(const ProjectionNameVector& names, const size_t numOfChildren) {
+ ABTVector variables;
+ for (size_t outerIdx = 0; outerIdx < numOfChildren; ++outerIdx) {
+ for (size_t idx = 0; idx < names.size(); ++idx) {
+ variables.emplace_back(make<Variable>(names[idx]));
+ }
+ }
+
+ return make<References>(std::move(variables));
+}
+
+UnionNode::UnionNode(ProjectionNameVector unionProjectionNames, ABTVector children)
+ : Base(std::move(children),
+ buildSimpleBinder(unionProjectionNames),
+ buildUnionReferences(unionProjectionNames, children.size())) {
+ uassert(6624007, "Empty union", !unionProjectionNames.empty());
+
+ for (auto& n : nodes()) {
+ assertNodeSort(n);
+ }
+}
+
+bool UnionNode::operator==(const UnionNode& other) const {
+ return binder() == other.binder() && nodes() == other.nodes();
+}
+
+GroupByNode::GroupByNode(ProjectionNameVector groupByProjectionNames,
+ ProjectionNameVector aggregationProjectionNames,
+ ABTVector aggregationExpressions,
+ ABT child)
+ : GroupByNode(std::move(groupByProjectionNames),
+ std::move(aggregationProjectionNames),
+ std::move(aggregationExpressions),
+ GroupNodeType::Complete,
+ std::move(child)) {}
+
+GroupByNode::GroupByNode(ProjectionNameVector groupByProjectionNames,
+ ProjectionNameVector aggregationProjectionNames,
+ ABTVector aggregationExpressions,
+ GroupNodeType type,
+ ABT child)
+ : Base(std::move(child),
+ buildSimpleBinder(aggregationProjectionNames),
+ make<References>(std::move(aggregationExpressions)),
+ buildSimpleBinder(groupByProjectionNames),
+ make<References>(groupByProjectionNames)),
+ _type(type) {
+ assertNodeSort(getChild());
+ uassert(6624300,
+ "Mismatched number of agg expressions and names",
+ getAggregationExpressions().size() == getAggregationProjectionNames().size());
+}
+
+bool GroupByNode::operator==(const GroupByNode& other) const {
+ return getAggregationProjectionNames() == other.getAggregationProjectionNames() &&
+ getAggregationProjections() == other.getAggregationProjections() &&
+ getGroupByProjectionNames() == other.getGroupByProjectionNames() && _type == other._type &&
+ getChild() == other.getChild();
+}
+
+const ABTVector& GroupByNode::getAggregationExpressions() const {
+ return get<2>().cast<References>()->nodes();
+}
+
+const ABT& GroupByNode::getChild() const {
+ return get<0>();
+}
+
+ABT& GroupByNode::getChild() {
+ return get<0>();
+}
+
+GroupNodeType GroupByNode::getType() const {
+ return _type;
+}
+
+UnwindNode::UnwindNode(ProjectionName projectionName,
+ ProjectionName pidProjectionName,
+ const bool retainNonArrays,
+ ABT child)
+ : Base(std::move(child),
+ buildSimpleBinder(ProjectionNameVector{projectionName, std::move(pidProjectionName)}),
+ make<References>(ProjectionNameVector{projectionName})),
+ _retainNonArrays(retainNonArrays) {
+ assertNodeSort(getChild());
+}
+
+bool UnwindNode::getRetainNonArrays() const {
+ return _retainNonArrays;
+}
+
+const ABT& UnwindNode::getChild() const {
+ return get<0>();
+}
+
+ABT& UnwindNode::getChild() {
+ return get<0>();
+}
+
+bool UnwindNode::operator==(const UnwindNode& other) const {
+ return binder() == other.binder() && _retainNonArrays == other._retainNonArrays &&
+ getChild() == other.getChild();
+}
+
+UniqueNode::UniqueNode(ProjectionNameVector projections, ABT child)
+ : Base(std::move(child), make<References>(ProjectionNameVector{projections})),
+ _projections(std::move(projections)) {
+ assertNodeSort(getChild());
+ uassert(6624092, "Empty projections", !_projections.empty());
+}
+
+bool UniqueNode::operator==(const UniqueNode& other) const {
+ return _projections == other._projections;
+}
+
+const ProjectionNameVector& UniqueNode::getProjections() const {
+ return _projections;
+}
+
+const ABT& UniqueNode::getChild() const {
+ return get<0>();
+}
+
+CollationNode::CollationNode(properties::CollationRequirement property, ABT child)
+ : Base(std::move(child),
+ buildReferences(extractReferencedColumns(properties::makePhysProps(property)))),
+ _property(std::move(property)) {
+ assertNodeSort(getChild());
+}
+
+const properties::CollationRequirement& CollationNode::getProperty() const {
+ return _property;
+}
+
+properties::CollationRequirement& CollationNode::getProperty() {
+ return _property;
+}
+
+bool CollationNode::operator==(const CollationNode& other) const {
+ return _property == other._property && getChild() == other.getChild();
+}
+
+const ABT& CollationNode::getChild() const {
+ return get<0>();
+}
+
+ABT& CollationNode::getChild() {
+ return get<0>();
+}
+
+LimitSkipNode::LimitSkipNode(properties::LimitSkipRequirement property, ABT child)
+ : Base(std::move(child)), _property(std::move(property)) {
+ assertNodeSort(getChild());
+}
+
+const properties::LimitSkipRequirement& LimitSkipNode::getProperty() const {
+ return _property;
+}
+
+properties::LimitSkipRequirement& LimitSkipNode::getProperty() {
+ return _property;
+}
+
+bool LimitSkipNode::operator==(const LimitSkipNode& other) const {
+ return _property == other._property && getChild() == other.getChild();
+}
+
+const ABT& LimitSkipNode::getChild() const {
+ return get<0>();
+}
+
+ABT& LimitSkipNode::getChild() {
+ return get<0>();
+}
+
+ExchangeNode::ExchangeNode(const properties::DistributionRequirement distribution, ABT child)
+ : Base(std::move(child), buildReferences(distribution.getAffectedProjectionNames())),
+ _distribution(std::move(distribution)) {
+ assertNodeSort(getChild());
+ uassert(6624008,
+ "Cannot exchange towards an unknown distribution",
+ distribution.getDistributionAndProjections()._type !=
+ DistributionType::UnknownPartitioning);
+}
+
+bool ExchangeNode::operator==(const ExchangeNode& other) const {
+ return _distribution == other._distribution && getChild() == other.getChild();
+}
+
+const ABT& ExchangeNode::getChild() const {
+ return get<0>();
+}
+
+ABT& ExchangeNode::getChild() {
+ return get<0>();
+}
+
+const properties::DistributionRequirement& ExchangeNode::getProperty() const {
+ return _distribution;
+}
+
+properties::DistributionRequirement& ExchangeNode::getProperty() {
+ return _distribution;
+}
+
+RootNode::RootNode(properties::ProjectionRequirement property, ABT child)
+ : Base(std::move(child), buildReferences(property.getAffectedProjectionNames())),
+ _property(std::move(property)) {
+ assertNodeSort(getChild());
+}
+
+bool RootNode::operator==(const RootNode& other) const {
+ return getChild() == other.getChild() && _property == other._property;
+}
+
+const properties::ProjectionRequirement& RootNode::getProperty() const {
+ return _property;
+}
+
+const ABT& RootNode::getChild() const {
+ return get<0>();
+}
+
+ABT& RootNode::getChild() {
+ return get<0>();
+}
+
+} // namespace mongo::optimizer
diff --git a/src/mongo/db/query/optimizer/node.h b/src/mongo/db/query/optimizer/node.h
new file mode 100644
index 00000000000..1a202140f6e
--- /dev/null
+++ b/src/mongo/db/query/optimizer/node.h
@@ -0,0 +1,862 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include <memory>
+#include <sstream>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "mongo/db/query/optimizer/algebra/operator.h"
+#include "mongo/db/query/optimizer/defs.h"
+#include "mongo/db/query/optimizer/metadata.h"
+#include "mongo/db/query/optimizer/props.h"
+#include "mongo/db/query/optimizer/syntax/expr.h"
+#include "mongo/db/query/optimizer/syntax/path.h"
+
+
+namespace mongo::optimizer {
+
+using FilterType = ABT;
+using ProjectionType = ABT;
+
+/**
+ * Marker for node class (both logical and physical sub-classes).
+ * A node not marked with either LogicalNode or PhysicalNode is considered to be both a logical and
+ * a physical node (e.g. a filter node). It is invalid to mark a node with both tags at the
+ * same time.
+ */
+class Node {};
+
+/**
+ * Marker for exclusively logical nodes.
+ */
+class LogicalNode {};
+
+/**
+ * Marker for exclusively physical nodes.
+ */
+class PhysicalNode {};
+
+inline void assertNodeSort(const ABT& e) {
+ if (!e.is<Node>()) {
+ uasserted(6624009, "Node syntax sort expected");
+ }
+}
+
+template <class T>
+inline constexpr bool canBeLogicalNode() {
+ // Node which is not exclusively physical.
+ return std::is_base_of_v<Node, T> && !std::is_base_of_v<PhysicalNode, T>;
+}
+
+template <class T>
+inline constexpr bool canBePhysicalNode() {
+ // Node which is not exclusively logical.
+ return std::is_base_of_v<Node, T> && !std::is_base_of_v<LogicalNode, T>;
+}
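+
+// A minimal illustrative sketch of how the marker tags drive the predicates above (assuming
+// the node classes declared later in this file): FilterNode carries only the Node tag, so both
+// predicates hold for it; IndexScanNode also derives from PhysicalNode, so only the physical
+// predicate holds:
+//
+//   static_assert(canBeLogicalNode<FilterNode>() && canBePhysicalNode<FilterNode>());
+//   static_assert(!canBeLogicalNode<IndexScanNode>() && canBePhysicalNode<IndexScanNode>());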
+
+/**
+ * Logical Scan node.
+ * It defines scanning a collection with an optional projection name that contains the documents.
+ * The collection is specified via the scanDefName entry in the metadata.
+ */
+class ScanNode final : public Operator<ScanNode, 1>, public Node, public LogicalNode {
+ using Base = Operator<ScanNode, 1>;
+
+public:
+ static constexpr const char* kDefaultCollectionNameSpec = "collectionName";
+
+ ScanNode(ProjectionName projectionName, std::string scanDefName);
+
+ bool operator==(const ScanNode& other) const;
+
+ const ExpressionBinder& binder() const {
+ const ABT& result = get<0>();
+ uassert(6624010, "Invalid binder type", result.is<ExpressionBinder>());
+ return *result.cast<ExpressionBinder>();
+ }
+
+ const ProjectionName& getProjectionName() const;
+ const ProjectionType& getProjection() const;
+
+ const std::string& getScanDefName() const;
+
+private:
+ const std::string _scanDefName;
+};
+
+/**
+ * Physical Scan node.
+ * It defines scanning a collection with an optional projection name that contains the documents.
+ * The collection is specified via the scanDefName entry in the metadata.
+ *
+ * Optionally, a set of fields may be specified to retrieve from the underlying collection and
+ * expose as projections.
+ */
+class PhysicalScanNode final : public Operator<PhysicalScanNode, 1>,
+ public Node,
+ public PhysicalNode {
+ using Base = Operator<PhysicalScanNode, 1>;
+
+public:
+ PhysicalScanNode(FieldProjectionMap fieldProjectionMap,
+ std::string scanDefName,
+ bool useParallelScan);
+
+ bool operator==(const PhysicalScanNode& other) const;
+
+ const ExpressionBinder& binder() const {
+ const ABT& result = get<0>();
+ uassert(6624011, "Invalid binder type", result.is<ExpressionBinder>());
+ return *result.cast<ExpressionBinder>();
+ }
+
+ const FieldProjectionMap& getFieldProjectionMap() const;
+
+ const std::string& getScanDefName() const;
+
+ bool useParallelScan() const;
+
+private:
+ const FieldProjectionMap _fieldProjectionMap;
+ const std::string _scanDefName;
+ const bool _useParallelScan;
+};
+
+/**
+ * Logical ValueScanNode.
+ *
+ * It originates a set of projections, each with a fixed sequence of values. The values are
+ * encoded as an array of rows, where each row is an array with one element per projection.
+ */
+class ValueScanNode final : public Operator<ValueScanNode, 1>, public Node, public LogicalNode {
+ using Base = Operator<ValueScanNode, 1>;
+
+public:
+ ValueScanNode(ProjectionNameVector projections);
+ ValueScanNode(ProjectionNameVector projections, ABT valueArray);
+
+ bool operator==(const ValueScanNode& other) const;
+
+ const ExpressionBinder& binder() const {
+ const ABT& result = get<0>();
+ uassert(6624012, "Invalid binder type", result.is<ExpressionBinder>());
+ return *result.cast<ExpressionBinder>();
+ }
+
+ const ABT& getValueArray() const;
+ size_t getArraySize() const;
+
+private:
+ const ABT _valueArray;
+ size_t _arraySize;
+};
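+
+// A minimal illustrative sketch of the value array encoding (inferred from the constructor
+// checks in node.cpp; the concrete values are assumptions): binding projections ["a", "b"] to
+// two rows would use a nested array such as [[1, "x"], [2, "y"]]: one inner array per row,
+// with one element per bound projection.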
+
+/**
+ * Physical CoScanNode.
+ *
+ * Conceptually it originates an infinite stream of Nothing.
+ * A typical use case is to limit it to one document and attach projections via one or more
+ * subsequent EvaluationNodes.
+ */
+class CoScanNode final : public Operator<CoScanNode, 0>, public Node, public PhysicalNode {
+ using Base = Operator<CoScanNode, 0>;
+
+public:
+ CoScanNode();
+
+ bool operator==(const CoScanNode& other) const;
+};
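+
+// A minimal sketch of the pattern described above (a sketch only; it assumes
+// LimitSkipRequirement{limit, skip} and Constant::int64 from the expression syntax):
+//
+//   ABT oneDoc = make<LimitSkipNode>(properties::LimitSkipRequirement{1, 0},
+//                                    make<CoScanNode>());
+//   ABT withProj = make<EvaluationNode>("proj", Constant::int64(42), std::move(oneDoc));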
+
+/**
+ * Index scan node.
+ * Retrieve data using an index. Return recordIds or values (if the index is covering).
+ * This is a physical node.
+ *
+ * The collection is specified by scanDef, and the index by the indexDef.
+ */
+class IndexScanNode final : public Operator<IndexScanNode, 1>, public Node, public PhysicalNode {
+ using Base = Operator<IndexScanNode, 1>;
+
+public:
+ IndexScanNode(FieldProjectionMap fieldProjectionMap, IndexSpecification indexSpec);
+
+ bool operator==(const IndexScanNode& other) const;
+
+ const ExpressionBinder& binder() const {
+ const ABT& result = get<0>();
+ uassert(6624013, "Invalid binder type", result.is<ExpressionBinder>());
+ return *result.cast<ExpressionBinder>();
+ }
+
+ const FieldProjectionMap& getFieldProjectionMap() const;
+ const IndexSpecification& getIndexSpecification() const;
+
+private:
+ const FieldProjectionMap _fieldProjectionMap;
+ const IndexSpecification _indexSpec;
+};
+
+/**
+ * SeekNode.
+ * Retrieve values using record ids (typically previously retrieved using an index scan).
+ * This is a physical node.
+ *
+ * The 'ridProjectionName' parameter designates the incoming rid which is the starting point of
+ * the seek. The 'fieldProjectionMap' may choose to include an outgoing rid which will contain
+ * the successive document ids (if we do not have a following limit).
+ *
+ * TODO: Can we let it advance with a limit based on upper rid limit in case of primary index?
+ */
+class SeekNode final : public Operator<SeekNode, 2>, public Node, public PhysicalNode {
+ using Base = Operator<SeekNode, 2>;
+
+public:
+ SeekNode(ProjectionName ridProjectionName,
+ FieldProjectionMap fieldProjectionMap,
+ std::string scanDefName);
+
+ bool operator==(const SeekNode& other) const;
+
+ const ExpressionBinder& binder() const {
+ const ABT& result = get<0>();
+ uassert(6624014, "Invalid binder type", result.is<ExpressionBinder>());
+ return *result.cast<ExpressionBinder>();
+ }
+
+ const ProjectionName& getRIDProjectionName() const;
+
+ const FieldProjectionMap& getFieldProjectionMap() const;
+
+ const std::string& getScanDefName() const;
+
+private:
+ const ProjectionName _ridProjectionName;
+ const FieldProjectionMap _fieldProjectionMap;
+ const std::string _scanDefName;
+};
+
+
+/**
+ * Logical group delegator node: scan from a given group.
+ * Used in conjunction with memo.
+ */
+class MemoLogicalDelegatorNode final : public Operator<MemoLogicalDelegatorNode, 0>,
+ public Node,
+ public LogicalNode {
+ using Base = Operator<MemoLogicalDelegatorNode, 0>;
+
+public:
+ MemoLogicalDelegatorNode(GroupIdType groupId);
+
+ bool operator==(const MemoLogicalDelegatorNode& other) const;
+
+ GroupIdType getGroupId() const;
+
+private:
+ const GroupIdType _groupId;
+};
+
+/**
+ * Physical group delegator node: refer to a physical node in a memo group.
+ * Used in conjunction with memo.
+ */
+class MemoPhysicalDelegatorNode final : public Operator<MemoPhysicalDelegatorNode, 0>,
+ public Node,
+ public PhysicalNode {
+ using Base = Operator<MemoPhysicalDelegatorNode, 0>;
+
+public:
+ MemoPhysicalDelegatorNode(MemoPhysicalNodeId nodeId);
+
+ bool operator==(const MemoPhysicalDelegatorNode& other) const;
+
+ MemoPhysicalNodeId getNodeId() const;
+
+private:
+ const MemoPhysicalNodeId _nodeId;
+};
+
+/**
+ * Filter node.
+ * It applies a filter over its input.
+ *
+ * This node is both logical and physical.
+ */
+class FilterNode final : public Operator<FilterNode, 2>, public Node {
+ using Base = Operator<FilterNode, 2>;
+
+public:
+ FilterNode(FilterType filter, ABT child);
+
+ bool operator==(const FilterNode& other) const;
+
+ const FilterType& getFilter() const;
+ FilterType& getFilter();
+
+ const ABT& getChild() const;
+ ABT& getChild();
+};
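+
+// A minimal construction sketch (the EvalFilter expression and the projection/collection names
+// here are illustrative assumptions):
+//
+//   ABT filtered = make<FilterNode>(
+//       make<EvalFilter>(make<PathIdentity>(), make<Variable>("scan_0")),
+//       make<ScanNode>("scan_0", "collectionName"));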
+
+/**
+ * Evaluation node.
+ * Adds a new projection to its input.
+ *
+ * This node is both logical and physical.
+ */
+class EvaluationNode final : public Operator<EvaluationNode, 2>, public Node {
+ using Base = Operator<EvaluationNode, 2>;
+
+public:
+ EvaluationNode(ProjectionName projectionName, ProjectionType projection, ABT child);
+
+ bool operator==(const EvaluationNode& other) const;
+
+ const ExpressionBinder& binder() const {
+ const ABT& result = get<1>();
+ uassert(6624015, "Invalid binder type", result.is<ExpressionBinder>());
+ return *result.cast<ExpressionBinder>();
+ }
+
+ const ProjectionName& getProjectionName() const {
+ return binder().names()[0];
+ }
+
+ const ProjectionType& getProjection() const {
+ return binder().exprs()[0];
+ }
+
+ const ABT& getChild() const {
+ return get<0>();
+ }
+
+ ABT& getChild() {
+ return get<0>();
+ }
+};
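+
+// A minimal construction sketch (EvalPath, PathGet, and the names used are illustrative
+// assumptions): bind a new projection "pb" to the value of field "b" of the scanned document:
+//
+//   ABT projected = make<EvaluationNode>(
+//       "pb",
+//       make<EvalPath>(make<PathGet>("b", make<PathIdentity>()), make<Variable>("scan_0")),
+//       make<ScanNode>("scan_0", "collectionName"));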
+
+/**
+ * RID intersection node.
+ * This is a logical node representing either index-index intersection or index-collection scan
+ * (seek) fetch.
+ *
+ * It is equivalent to a join node, with the difference that RID projections do not exist on the
+ * logical level, and thus projection names are not determined until physical optimization. We
+ * also want to restrict the type of operations on RIDs (in this case only set intersection), as
+ * opposed to, say, a filter on rid = 5.
+ */
+class RIDIntersectNode final : public Operator<RIDIntersectNode, 2>,
+ public Node,
+ public LogicalNode {
+ using Base = Operator<RIDIntersectNode, 2>;
+
+public:
+ RIDIntersectNode(ProjectionName scanProjectionName,
+ bool hasLeftIntervals,
+ bool hasRightIntervals,
+ ABT leftChild,
+ ABT rightChild);
+
+ bool operator==(const RIDIntersectNode& other) const;
+
+ const ABT& getLeftChild() const;
+ ABT& getLeftChild();
+
+ const ABT& getRightChild() const;
+ ABT& getRightChild();
+
+ const ProjectionName& getScanProjectionName() const;
+
+ bool hasLeftIntervals() const;
+ bool hasRightIntervals() const;
+
+private:
+ const ProjectionName _scanProjectionName;
+
+    // If true, the left (respectively right) child has at least one proper interval (i.e. not
+    // fully open).
+ const bool _hasLeftIntervals;
+ const bool _hasRightIntervals;
+};
+
+/**
+ * Sargable node.
+ * This is a logical node which represents special kinds of (simple) evaluations and filters which
+ * are amenable to being used in indexing or covered scans.
+ *
+ * It collects a conjunction of predicates in the following form:
+ *    <path, inputProjection> -> <interval, outputProjection>
+ *
+ * For example, to encode a conjunction which filters with array traversal on "a"
+ * ($match: {a: {$gt: 1}}) combined with a retrieval of the field "b" (without restrictions
+ * on its value):
+ *    PathGet "a" Traverse Id | scan_0 -> [1, +inf], <none>
+ *    PathGet "b" Id | scan_0 -> (-inf, +inf), "pb"
+ */
+class SargableNode final : public Operator<SargableNode, 3>, public Node, public LogicalNode {
+ using Base = Operator<SargableNode, 3>;
+
+public:
+ SargableNode(PartialSchemaRequirements reqMap,
+ CandidateIndexMap candidateIndexMap,
+ IndexReqTarget target,
+ ABT child);
+
+ bool operator==(const SargableNode& other) const;
+
+ const ExpressionBinder& binder() const {
+ const ABT& result = get<1>();
+ uassert(6624016, "Invalid binder type", result.is<ExpressionBinder>());
+ return *result.cast<ExpressionBinder>();
+ }
+
+ const ABT& getChild() const {
+ return get<0>();
+ }
+ ABT& getChild() {
+ return get<0>();
+ }
+
+ const PartialSchemaRequirements& getReqMap() const;
+ const CandidateIndexMap& getCandidateIndexMap() const;
+
+ IndexReqTarget getTarget() const;
+
+private:
+ const PartialSchemaRequirements _reqMap;
+
+ CandidateIndexMap _candidateIndexMap;
+
+    // Performance optimization to limit the number of groups.
+    // Determines under what indexing requirements this node can be implemented.
+ const IndexReqTarget _target;
+};
+
+#define JOIN_TYPE(F) \
+ F(Inner) \
+ F(Left) \
+ F(Right) \
+ F(Full)
+
+MAKE_PRINTABLE_ENUM(JoinType, JOIN_TYPE);
+MAKE_PRINTABLE_ENUM_STRING_ARRAY(JoinTypeEnum, JoinType, JOIN_TYPE);
+#undef JOIN_TYPE
+
+/**
+ * Logical binary join.
+ * Join of two logical nodes. Can express inner and outer joins, with an associated join predicate.
+ *
+ * This node is logical, with a default physical implementation corresponding to a Nested Loops Join
+ * (NLJ).
+ * Variables used in the inner (right) side are automatically bound with variables from the left
+ * (outer) side.
+ */
+class BinaryJoinNode final : public Operator<BinaryJoinNode, 3>, public Node {
+ using Base = Operator<BinaryJoinNode, 3>;
+
+public:
+ BinaryJoinNode(JoinType joinType,
+ ProjectionNameSet correlatedProjectionNames,
+ FilterType filter,
+ ABT leftChild,
+ ABT rightChild);
+
+ bool operator==(const BinaryJoinNode& other) const;
+
+ JoinType getJoinType() const;
+
+ const ProjectionNameSet& getCorrelatedProjectionNames() const;
+
+ const ABT& getLeftChild() const;
+ ABT& getLeftChild();
+
+ const ABT& getRightChild() const;
+ ABT& getRightChild();
+
+ const ABT& getFilter() const;
+
+private:
+ const JoinType _joinType;
+
+ // Those projections must exist on the outer side and are used to bind free variables on the
+ // inner side.
+ const ProjectionNameSet _correlatedProjectionNames;
+};
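+
+// A minimal construction sketch of a correlated inner join (a sketch only; Constant::boolean
+// and the projection name are assumptions): projection "p_a" from the left (outer) side is
+// made available to the right (inner) side, with a trivially-true join predicate:
+//
+//   ABT join = make<BinaryJoinNode>(JoinType::Inner,
+//                                   ProjectionNameSet{"p_a"},
+//                                   Constant::boolean(true),
+//                                   std::move(leftChild),
+//                                   std::move(rightChild));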
+
+/**
+ * Physical hash join node.
+ * Join condition is a conjunction of pairwise equalities between corresponding left and right keys.
+ *
+ * TODO: support all join types (not just Inner).
+ */
+class HashJoinNode final : public Operator<HashJoinNode, 3>, public Node, public PhysicalNode {
+ using Base = Operator<HashJoinNode, 3>;
+
+public:
+ HashJoinNode(JoinType joinType,
+ ProjectionNameVector leftKeys,
+ ProjectionNameVector rightKeys,
+ ABT leftChild,
+ ABT rightChild);
+
+ bool operator==(const HashJoinNode& other) const;
+
+ JoinType getJoinType() const;
+ const ProjectionNameVector& getLeftKeys() const;
+ const ProjectionNameVector& getRightKeys() const;
+
+ const ABT& getLeftChild() const;
+ ABT& getLeftChild();
+
+ const ABT& getRightChild() const;
+ ABT& getRightChild();
+
+private:
+ const JoinType _joinType;
+
+ // Join condition is a conjunction of _leftKeys.at(i) == _rightKeys.at(i).
+ const ProjectionNameVector _leftKeys;
+ const ProjectionNameVector _rightKeys;
+};
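+
+// A minimal construction sketch (projection names are illustrative): join rows where the
+// left-side key "lk" equals the right-side key "rk":
+//
+//   ABT hashJoin = make<HashJoinNode>(JoinType::Inner,
+//                                     ProjectionNameVector{"lk"},
+//                                     ProjectionNameVector{"rk"},
+//                                     std::move(leftChild),
+//                                     std::move(rightChild));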
+
+/**
+ * Merge Join node.
+ * This is a physical node representing joining of two sorted inputs.
+ */
+class MergeJoinNode final : public Operator<MergeJoinNode, 3>, public Node, public PhysicalNode {
+ using Base = Operator<MergeJoinNode, 3>;
+
+public:
+ MergeJoinNode(ProjectionNameVector leftKeys,
+ ProjectionNameVector rightKeys,
+ std::vector<CollationOp> collation,
+ ABT leftChild,
+ ABT rightChild);
+
+ bool operator==(const MergeJoinNode& other) const;
+
+ const ProjectionNameVector& getLeftKeys() const;
+ const ProjectionNameVector& getRightKeys() const;
+
+ const std::vector<CollationOp>& getCollation() const;
+
+ const ABT& getLeftChild() const;
+ ABT& getLeftChild();
+
+ const ABT& getRightChild() const;
+ ABT& getRightChild();
+
+private:
+ // Describes how to merge the sorted streams.
+ std::vector<CollationOp> _collation;
+
+ // Join condition is a conjunction of _leftKeys.at(i) == _rightKeys.at(i).
+ const ProjectionNameVector _leftKeys;
+ const ProjectionNameVector _rightKeys;
+};
+
+/**
+ * Union of several logical nodes. Projections common to all children are logically unioned in
+ * the output. It can be used with a single child just to restrict projections.
+ *
+ * This node is both logical and physical.
+ */
+class UnionNode final : public OperatorDynamic<UnionNode, 2>, public Node {
+ using Base = OperatorDynamic<UnionNode, 2>;
+
+public:
+ UnionNode(ProjectionNameVector unionProjectionNames, ABTVector children);
+
+ bool operator==(const UnionNode& other) const;
+
+ const ExpressionBinder& binder() const {
+ const ABT& result = get<0>();
+ uassert(6624017, "Invalid binder type", result.is<ExpressionBinder>());
+ return *result.cast<ExpressionBinder>();
+ }
+};
+
+#define GROUPNODETYPE_OPNAMES(F) \
+ F(Complete) \
+ F(Local) \
+ F(Global)
+
+MAKE_PRINTABLE_ENUM(GroupNodeType, GROUPNODETYPE_OPNAMES);
+MAKE_PRINTABLE_ENUM_STRING_ARRAY(GroupNodeTypeEnum, GroupNodeType, GROUPNODETYPE_OPNAMES);
+#undef GROUPNODETYPE_OPNAMES
+
+/**
+ * Group-by node.
+ * This node is logical with a default physical implementation corresponding to a hash group-by.
+ * Projects the group-by column from its child, and adds aggregation expressions.
+ *
+ * TODO: other physical implementations: stream group-by.
+ */
+class GroupByNode : public Operator<GroupByNode, 5>, public Node {
+ using Base = Operator<GroupByNode, 5>;
+
+public:
+ GroupByNode(ProjectionNameVector groupByProjectionNames,
+ ProjectionNameVector aggregationProjectionNames,
+ ABTVector aggregationExpressions,
+ ABT child);
+
+ GroupByNode(ProjectionNameVector groupByProjectionNames,
+ ProjectionNameVector aggregationProjectionNames,
+ ABTVector aggregationExpressions,
+ GroupNodeType type,
+ ABT child);
+
+ bool operator==(const GroupByNode& other) const;
+
+ const ExpressionBinder& binderAgg() const {
+ const ABT& result = get<1>();
+ uassert(6624018, "Invalid binder type", result.is<ExpressionBinder>());
+ return *result.cast<ExpressionBinder>();
+ }
+
+ const ExpressionBinder& binderGb() const {
+ const ABT& result = get<3>();
+ uassert(6624019, "Invalid binder type", result.is<ExpressionBinder>());
+ return *result.cast<ExpressionBinder>();
+ }
+
+ const ProjectionNameVector& getGroupByProjectionNames() const {
+ return binderGb().names();
+ }
+
+ const ProjectionNameVector& getAggregationProjectionNames() const {
+ return binderAgg().names();
+ }
+
+ const auto& getAggregationProjections() const {
+ return binderAgg().exprs();
+ }
+
+ const auto& getGroupByProjections() const {
+ return binderGb().exprs();
+ }
+
+ const ABTVector& getAggregationExpressions() const;
+
+ const ABT& getChild() const;
+ ABT& getChild();
+
+ GroupNodeType getType() const;
+
+private:
+ // Used for local-global rewrite.
+ GroupNodeType _type;
+};
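+
+// A minimal construction sketch (the projection names and the aggregation expression, here a
+// plain variable reference, are illustrative assumptions): group by projection "a" and bind
+// one aggregate to "agg":
+//
+//   ABT groupBy = make<GroupByNode>(ProjectionNameVector{"a"},
+//                                   ProjectionNameVector{"agg"},
+//                                   ABTVector{make<Variable>("b")},
+//                                   std::move(child));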
+
+/**
+ * Unwind node.
+ * Unwinds an embedded relation inside an array. Generates unwinding positions in the PID
+ * projection.
+ *
+ * This node is both logical and physical.
+ */
+class UnwindNode final : public Operator<UnwindNode, 3>, public Node {
+ using Base = Operator<UnwindNode, 3>;
+
+public:
+ UnwindNode(ProjectionName projectionName,
+ ProjectionName pidProjectionName,
+ bool retainNonArrays,
+ ABT child);
+
+ bool operator==(const UnwindNode& other) const;
+
+ const ExpressionBinder& binder() const {
+ const ABT& result = get<1>();
+ uassert(6624020, "Invalid binder type", result.is<ExpressionBinder>());
+ return *result.cast<ExpressionBinder>();
+ }
+
+ const ProjectionName& getProjectionName() const {
+ return binder().names()[0];
+ }
+
+ const ProjectionName& getPIDProjectionName() const {
+ return binder().names()[1];
+ }
+
+ const ProjectionType& getProjection() const {
+ return binder().exprs()[0];
+ }
+
+ const ProjectionType& getPIDProjection() const {
+ return binder().exprs()[1];
+ }
+
+ const ABT& getChild() const;
+
+ ABT& getChild();
+
+ bool getRetainNonArrays() const;
+
+private:
+ const bool _retainNonArrays;
+};
+
+/**
+ * Unique node.
+ *
+ * This is a physical node. It encodes an operation which will de-duplicate the child input using
+ * a sequence of given projection names. It is similar to GroupBy using the given projections as
+ * a compound grouping key.
+ */
+class UniqueNode final : public Operator<UniqueNode, 2>, public Node, public PhysicalNode {
+ using Base = Operator<UniqueNode, 2>;
+
+public:
+ UniqueNode(ProjectionNameVector projections, ABT child);
+
+ bool operator==(const UniqueNode& other) const;
+
+ const ProjectionNameVector& getProjections() const;
+
+ const ABT& getChild() const;
+
+private:
+ ProjectionNameVector _projections;
+};
+
+/**
+ * Collation node.
+ * This node is both logical and physical.
+ *
+ * It represents an operator to collate (sort, or cluster) the input.
+ */
+class CollationNode final : public Operator<CollationNode, 2>, public Node {
+ using Base = Operator<CollationNode, 2>;
+
+public:
+ CollationNode(properties::CollationRequirement property, ABT child);
+
+ bool operator==(const CollationNode& other) const;
+
+ const properties::CollationRequirement& getProperty() const;
+ properties::CollationRequirement& getProperty();
+
+ const ABT& getChild() const;
+
+ ABT& getChild();
+
+private:
+ properties::CollationRequirement _property;
+};
+
+/**
+ * Limit and skip node.
+ * This node is both logical and physical.
+ *
+ * It limits the input to a fixed number of documents, optionally skipping a fixed number of
+ * documents first.
+ */
+class LimitSkipNode final : public Operator<LimitSkipNode, 1>, public Node {
+ using Base = Operator<LimitSkipNode, 1>;
+
+public:
+ LimitSkipNode(properties::LimitSkipRequirement property, ABT child);
+
+ bool operator==(const LimitSkipNode& other) const;
+
+ const properties::LimitSkipRequirement& getProperty() const;
+ properties::LimitSkipRequirement& getProperty();
+
+ const ABT& getChild() const;
+
+ ABT& getChild();
+
+private:
+ properties::LimitSkipRequirement _property;
+};
+
+/**
+ * Exchange node.
+ * It specifies how the relation is spread across machines in the execution environment.
+ * Currently only single-node and hash-based partitioning are supported.
+ * TODO: range-based partitioning, replication, and round-robin.
+ *
+ * This node is both logical and physical.
+ */
+class ExchangeNode final : public Operator<ExchangeNode, 2>, public Node {
+ using Base = Operator<ExchangeNode, 2>;
+
+public:
+ ExchangeNode(properties::DistributionRequirement distribution, ABT child);
+
+ bool operator==(const ExchangeNode& other) const;
+
+ const properties::DistributionRequirement& getProperty() const;
+ properties::DistributionRequirement& getProperty();
+
+ const ABT& getChild() const;
+
+ ABT& getChild();
+
+private:
+ properties::DistributionRequirement _distribution;
+
+ /**
+ * Defined for hash and range-based partitioning.
+ * TODO: other exchange-specific params (e.g. chunk boundaries?)
+ */
+ const ProjectionName _projectionName;
+};
+
+/**
+ * Root of the tree that holds references to the output of the query. In the MQL case the query
+ * outputs a single "column" (the document), but in the general case (SQL) we can output
+ * arbitrarily many "columns". We need the internal references for the output projections in
+ * order to keep them live; otherwise they would be dropped from the tree by DCE.
+ *
+ * This node is only logical.
+ */
+class RootNode final : public Operator<RootNode, 2>, public Node {
+ using Base = Operator<RootNode, 2>;
+
+public:
+ RootNode(properties::ProjectionRequirement property, ABT child);
+
+ bool operator==(const RootNode& other) const;
+
+ const properties::ProjectionRequirement& getProperty() const;
+
+ const ABT& getChild() const;
+ ABT& getChild();
+
+private:
+ const properties::ProjectionRequirement _property;
+};
+
+} // namespace mongo::optimizer
diff --git a/src/mongo/db/query/optimizer/node_defs.h b/src/mongo/db/query/optimizer/node_defs.h
new file mode 100644
index 00000000000..9a097f24591
--- /dev/null
+++ b/src/mongo/db/query/optimizer/node_defs.h
@@ -0,0 +1,66 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include "mongo/db/query/optimizer/node.h"
+#include "mongo/db/query/optimizer/props.h"
+
+namespace mongo::optimizer {
+
+// Used for physical rewrites. For each child we optimize, we specify physical properties.
+using ChildPropsType = std::vector<std::pair<ABT*, properties::PhysProps>>;
+
+// Used for physical rewrites. For each physical node we implement, we set the CE to use when
+// costing.
+using NodeCEMap = opt::unordered_map<const Node*, CEType>;
+
+struct NodeProps {
+ // Used to tie to a corresponding SBE stage.
+ int32_t _planNodeId;
+
+    // The corresponding memo group and its properties.
+ MemoPhysicalNodeId _groupId;
+ properties::LogicalProps _logicalProps;
+ properties::PhysProps _physicalProps;
+
+ // Total cost of the best plan (includes the subtree).
+ CostType _cost;
+ // Local cost (excludes subtree).
+ CostType _localCost;
+
+ // For display purposes, adjusted cardinality based on physical properties (e.g. Repetition and
+ // Limit-Skip).
+ CEType _adjustedCE;
+};
+
+// Map from node to various properties, including logical and physical. Used to determine for
+// example which of the available projections are used for exchanges.
+using NodeToGroupPropsMap = opt::unordered_map<const Node*, NodeProps>;
+
+} // namespace mongo::optimizer
diff --git a/src/mongo/db/query/optimizer/opt_phase_manager.cpp b/src/mongo/db/query/optimizer/opt_phase_manager.cpp
new file mode 100644
index 00000000000..7c282ea2fc5
--- /dev/null
+++ b/src/mongo/db/query/optimizer/opt_phase_manager.cpp
@@ -0,0 +1,332 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/db/query/optimizer/opt_phase_manager.h"
+
+#include "mongo/db/query/optimizer/cascades/ce_heuristic.h"
+#include "mongo/db/query/optimizer/cascades/cost_derivation.h"
+#include "mongo/db/query/optimizer/cascades/logical_props_derivation.h"
+#include "mongo/db/query/optimizer/rewrites/const_eval.h"
+#include "mongo/db/query/optimizer/rewrites/path.h"
+#include "mongo/db/query/optimizer/rewrites/path_lower.h"
+#include "mongo/db/query/optimizer/utils/memo_utils.h"
+
+namespace mongo::optimizer {
+
+OptPhaseManager::PhaseSet OptPhaseManager::_allRewrites = {OptPhase::ConstEvalPre,
+ OptPhase::PathFuse,
+ OptPhase::MemoSubstitutionPhase,
+ OptPhase::MemoExplorationPhase,
+ OptPhase::MemoImplementationPhase,
+ OptPhase::PathLower,
+ OptPhase::ConstEvalPost};
+
+OptPhaseManager::OptPhaseManager(OptPhaseManager::PhaseSet phaseSet,
+ PrefixId& prefixId,
+ Metadata metadata,
+ DebugInfo debugInfo)
+ : OptPhaseManager(std::move(phaseSet),
+ prefixId,
+ false /*requireRID*/,
+ std::move(metadata),
+ std::make_unique<HeuristicCE>(),
+ std::make_unique<DefaultCosting>(),
+ std::move(debugInfo)) {}
+
+OptPhaseManager::OptPhaseManager(OptPhaseManager::PhaseSet phaseSet,
+ PrefixId& prefixId,
+ const bool requireRID,
+ Metadata metadata,
+ std::unique_ptr<CEInterface> ceDerivation,
+ std::unique_ptr<CostingInterface> costDerivation,
+ DebugInfo debugInfo)
+ : _phaseSet(std::move(phaseSet)),
+ _debugInfo(std::move(debugInfo)),
+ _hints(),
+ _metadata(std::move(metadata)),
+ _memo(_debugInfo,
+ _metadata,
+ std::make_unique<DefaultLogicalPropsDerivation>(),
+ std::move(ceDerivation)),
+ _costDerivation(std::move(costDerivation)),
+ _physicalNodeId(),
+ _requireRID(requireRID),
+ _ridProjections(),
+ _prefixId(prefixId) {
+ uassert(6624093, "Empty Cost derivation", _costDerivation.get());
+
+ for (const auto& entry : _metadata._scanDefs) {
+ _ridProjections.emplace(entry.first, _prefixId.getNextId("rid"));
+ }
+}
+
+template <OptPhaseManager::OptPhase phase, class C>
+bool OptPhaseManager::runStructuralPhase(C instance, VariableEnvironment& env, ABT& input) {
+ if (!hasPhase(phase)) {
+ return true;
+ }
+
+ for (int iterationCount = 0; instance.optimize(input); iterationCount++) {
+ if (_debugInfo.exceedsIterationLimit(iterationCount)) {
+ // Iteration limit exceeded.
+ return false;
+ }
+ }
+
+ return !env.hasFreeVariables();
+}
+
+template <OptPhaseManager::OptPhase phase1, OptPhaseManager::OptPhase phase2, class C1, class C2>
+bool OptPhaseManager::runStructuralPhases(C1 instance1,
+ C2 instance2,
+ VariableEnvironment& env,
+ ABT& input) {
+ const bool hasPhase1 = hasPhase(phase1);
+ const bool hasPhase2 = hasPhase(phase2);
+ if (!hasPhase1 && !hasPhase2) {
+ return true;
+ }
+
+ bool changed = true;
+ for (int iterationCount = 0; changed; iterationCount++) {
+ if (_debugInfo.exceedsIterationLimit(iterationCount)) {
+ // Iteration limit exceeded.
+ return false;
+ }
+
+ changed = false;
+ if (hasPhase1) {
+ changed |= instance1.optimize(input);
+ }
+ if (hasPhase2) {
+ changed |= instance2.optimize(input);
+ }
+ }
+
+ return !env.hasFreeVariables();
+}
+
+bool OptPhaseManager::runMemoLogicalRewrite(const OptPhase phase,
+ VariableEnvironment& env,
+ const LogicalRewriter::RewriteSet& rewriteSet,
+ GroupIdType& rootGroupId,
+ const bool runStandalone,
+ std::unique_ptr<LogicalRewriter>& logicalRewriter,
+ ABT& input) {
+ if (!hasPhase(phase)) {
+ return true;
+ }
+
+ _memo.clear();
+ logicalRewriter = std::make_unique<LogicalRewriter>(_memo, _prefixId, rewriteSet);
+ rootGroupId = logicalRewriter->addRootNode(input);
+
+ if (runStandalone) {
+ if (!logicalRewriter->rewriteToFixPoint()) {
+ return false;
+ }
+
+ input = extractLatestPlan(_memo, rootGroupId);
+ env.rebuild(input);
+ }
+
+ return !env.hasFreeVariables();
+}
+
+bool OptPhaseManager::runMemoPhysicalRewrite(const OptPhase phase,
+ VariableEnvironment& env,
+ const GroupIdType rootGroupId,
+ std::unique_ptr<LogicalRewriter>& logicalRewriter,
+ ABT& input) {
+ using namespace properties;
+
+ if (!hasPhase(phase)) {
+ return true;
+ }
+ if (rootGroupId < 0) {
+ // Nothing inserted in the memo. Logical rewrites did not run?
+ return false;
+ }
+
+ // By default we require centralized result.
+ // Also by default we do not require projections: the Root node will add those.
+ PhysProps physProps = makePhysProps(DistributionRequirement(DistributionType::Centralized));
+ if (_requireRID) {
+ const auto& rootLogicalProps = _memo.getGroup(rootGroupId)._logicalProperties;
+ if (!hasProperty<IndexingAvailability>(rootLogicalProps)) {
+ // We cannot obtain rid for this query.
+ return false;
+ }
+
+ setProperty(
+ physProps,
+ IndexingRequirement(
+ IndexReqTarget::Complete, true /*needRID*/, true /*dedupRID*/, rootGroupId));
+ }
+
+ PhysicalRewriter rewriter(_memo, _hints, _ridProjections, *_costDerivation, logicalRewriter);
+
+ auto optGroupResult =
+ rewriter.optimizeGroup(rootGroupId, std::move(physProps), _prefixId, CostType::kInfinity);
+ if (!optGroupResult._success) {
+ return false;
+ }
+
+ _physicalNodeId = {rootGroupId, optGroupResult._index};
+ std::tie(input, _nodeToGroupPropsMap) = extractPhysicalPlan(_physicalNodeId, _metadata, _memo);
+
+ env.rebuild(input);
+ return !env.hasFreeVariables();
+}
+
+bool OptPhaseManager::runMemoRewritePhases(VariableEnvironment& env, ABT& input) {
+ GroupIdType rootGroupId = -1;
+ std::unique_ptr<LogicalRewriter> logicalRewriter;
+
+ if (!runMemoLogicalRewrite(OptPhase::MemoSubstitutionPhase,
+ env,
+ LogicalRewriter::getSubstitutionSet(),
+ rootGroupId,
+ true /*runStandalone*/,
+ logicalRewriter,
+ input)) {
+ return false;
+ }
+
+ if (!runMemoLogicalRewrite(OptPhase::MemoExplorationPhase,
+ env,
+ LogicalRewriter::getExplorationSet(),
+ rootGroupId,
+ !hasPhase(OptPhase::MemoImplementationPhase),
+ logicalRewriter,
+ input)) {
+ return false;
+ }
+
+ if (!runMemoPhysicalRewrite(
+ OptPhase::MemoImplementationPhase, env, rootGroupId, logicalRewriter, input)) {
+ return false;
+ }
+
+ return true;
+}
+
+bool OptPhaseManager::optimize(ABT& input) {
+ VariableEnvironment env = VariableEnvironment::build(input);
+ if (env.hasFreeVariables()) {
+ return false;
+ }
+
+ if (!runStructuralPhases<OptPhase::ConstEvalPre, OptPhase::PathFuse, ConstEval, PathFusion>(
+ ConstEval{env, true /*disableSargableInlining*/}, PathFusion{env}, env, input)) {
+ return false;
+ }
+
+ if (!runMemoRewritePhases(env, input)) {
+ return false;
+ }
+
+ if (!runStructuralPhase<OptPhase::PathLower, PathLowering>(
+ PathLowering{_prefixId, env}, env, input)) {
+ return false;
+ }
+
+ ProjectionNameSet erasedProjNames;
+ if (!runStructuralPhase<OptPhase::ConstEvalPost, ConstEval>(
+ ConstEval{env, false /*disableSargableInlining*/, &erasedProjNames}, env, input)) {
+ return false;
+ }
+ if (!erasedProjNames.empty()) {
+ // If we have erased some eval nodes, make sure to delete the corresponding projection names
+ // from the node property map.
+ for (auto& [nodePtr, props] : _nodeToGroupPropsMap) {
+ if (properties::hasProperty<properties::ProjectionRequirement>(props._physicalProps)) {
+ auto& requiredProjNames =
+ properties::getProperty<properties::ProjectionRequirement>(props._physicalProps)
+ .getProjections();
+ for (const ProjectionName& projName : erasedProjNames) {
+ requiredProjNames.erase(projName);
+ }
+ }
+ }
+ }
+
+ env.rebuild(input);
+ if (env.hasFreeVariables()) {
+ return false;
+ }
+
+ return true;
+}
+
+bool OptPhaseManager::hasPhase(const OptPhase phase) const {
+ return _phaseSet.find(phase) != _phaseSet.cend();
+}
+
+const OptPhaseManager::PhaseSet& OptPhaseManager::getAllRewritesSet() {
+ return _allRewrites;
+}
+
+MemoPhysicalNodeId OptPhaseManager::getPhysicalNodeId() const {
+ return _physicalNodeId;
+}
+
+const QueryHints& OptPhaseManager::getHints() const {
+ return _hints;
+}
+
+QueryHints& OptPhaseManager::getHints() {
+ return _hints;
+}
+
+const Memo& OptPhaseManager::getMemo() const {
+ return _memo;
+}
+
+const Metadata& OptPhaseManager::getMetadata() const {
+ return _metadata;
+}
+
+PrefixId& OptPhaseManager::getPrefixId() const {
+ return _prefixId;
+}
+
+const NodeToGroupPropsMap& OptPhaseManager::getNodeToGroupPropsMap() const {
+ return _nodeToGroupPropsMap;
+}
+
+NodeToGroupPropsMap& OptPhaseManager::getNodeToGroupPropsMap() {
+ return _nodeToGroupPropsMap;
+}
+
+const opt::unordered_map<std::string, ProjectionName>& OptPhaseManager::getRIDProjections() const {
+ return _ridProjections;
+}
+
+} // namespace mongo::optimizer
diff --git a/src/mongo/db/query/optimizer/opt_phase_manager.h b/src/mongo/db/query/optimizer/opt_phase_manager.h
new file mode 100644
index 00000000000..c079c4f08aa
--- /dev/null
+++ b/src/mongo/db/query/optimizer/opt_phase_manager.h
@@ -0,0 +1,181 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include <unordered_set>
+
+#include "mongo/db/query/optimizer/cascades/interfaces.h"
+#include "mongo/db/query/optimizer/cascades/logical_rewriter.h"
+#include "mongo/db/query/optimizer/cascades/physical_rewriter.h"
+
+namespace mongo::optimizer {
+
+using namespace cascades;
+
+/**
+ * This class wraps together different optimization phases.
+ * First, the transport-based rewrites, such as constant folding and redundant expression
+ * elimination, are applied. Second, the logical and physical reordering rewrites are applied
+ * using the memo. Third, the final transport rewrites are applied.
+ */
+class OptPhaseManager {
+public:
+ enum class OptPhase {
+ // ConstEval performs the following rewrites: constant folding, inlining, and dead code
+ // elimination.
+ ConstEvalPre,
+ PathFuse,
+
+ // Memo phases below perform Cascades-style optimization.
+ // Reorder and transform nodes. Convert Filter and Eval nodes to SargableNodes, and possibly
+ // merge them.
+ MemoSubstitutionPhase,
+        // Performs the local-global rewrite and rewrites to enable index intersection.
+ // If there is an implementation phase, it runs integrated with the top-down optimization.
+ // If there is no implementation phase, it runs standalone.
+ MemoExplorationPhase,
+ // Implementation and enforcement rules.
+ MemoImplementationPhase,
+
+ PathLower,
+ ConstEvalPost
+ };
+
+ using PhaseSet = opt::unordered_set<OptPhase>;
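+
+    // For example, a caller interested only in the structural rewrites could pass a subset such
+    // as PhaseSet{OptPhase::ConstEvalPre, OptPhase::PathFuse} (an illustrative sketch).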
+
+ OptPhaseManager(PhaseSet phaseSet, PrefixId& prefixId, Metadata metadata, DebugInfo debugInfo);
+ OptPhaseManager(PhaseSet phaseSet,
+ PrefixId& prefixId,
+ bool requireRID,
+ Metadata metadata,
+ std::unique_ptr<CEInterface> ceDerivation,
+ std::unique_ptr<CostingInterface> costDerivation,
+ DebugInfo debugInfo);
+
+ /**
+ * Optimization modifies the input argument.
+ * Returns true on successful optimization and false on failure.
+ */
+ bool optimize(ABT& input);
+
+ static const PhaseSet& getAllRewritesSet();
+
+ MemoPhysicalNodeId getPhysicalNodeId() const;
+
+ const QueryHints& getHints() const;
+ QueryHints& getHints();
+
+ const Memo& getMemo() const;
+
+ const Metadata& getMetadata() const;
+
+ PrefixId& getPrefixId() const;
+
+ const NodeToGroupPropsMap& getNodeToGroupPropsMap() const;
+ NodeToGroupPropsMap& getNodeToGroupPropsMap();
+
+ const opt::unordered_map<std::string, ProjectionName>& getRIDProjections() const;
+
+private:
+ bool hasPhase(OptPhase phase) const;
+
+ template <OptPhase phase, class C>
+ bool runStructuralPhase(C instance, VariableEnvironment& env, ABT& input);
+
+ /**
+ * Run two structural phases until a mutual fixpoint is reached.
+ * We assume the phase types can be constructed by initializing them with the environment.
+ */
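+ // Note: a mutual fixpoint is reached when a full pass of either phase no longer
+ // changes the tree (e.g., alternating ConstEvalPre and PathFuse until stable).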
+ template <const OptPhase phase1, const OptPhase phase2, class C1, class C2>
+ bool runStructuralPhases(C1 instance1, C2 instance2, VariableEnvironment& env, ABT& input);
+
+ bool runMemoLogicalRewrite(OptPhase phase,
+ VariableEnvironment& env,
+ const LogicalRewriter::RewriteSet& rewriteSet,
+ GroupIdType& rootGroupId,
+ bool runStandalone,
+ std::unique_ptr<LogicalRewriter>& logicalRewriter,
+ ABT& input);
+
+ bool runMemoPhysicalRewrite(OptPhase phase,
+ VariableEnvironment& env,
+ GroupIdType rootGroupId,
+ std::unique_ptr<LogicalRewriter>& logicalRewriter,
+ ABT& input);
+
+ bool runMemoRewritePhases(VariableEnvironment& env, ABT& input);
+
+
+ static PhaseSet _allRewrites;
+
+ const PhaseSet _phaseSet;
+
+ const DebugInfo _debugInfo;
+
+ QueryHints _hints;
+
+ Metadata _metadata;
+
+ /**
+ * Final state of the memo after physical rewrites are complete.
+ */
+ Memo _memo;
+
+ /**
+ * Cost derivation function.
+ */
+ std::unique_ptr<CostingInterface> _costDerivation;
+
+ /**
+ * Root physical node if we have performed physical rewrites.
+ */
+ MemoPhysicalNodeId _physicalNodeId;
+
+ /**
+ * Map from node to logical and physical properties.
+ */
+ NodeToGroupPropsMap _nodeToGroupPropsMap;
+
+ /**
+ * Used to optimize update and delete statements. If set, we include the indexing requirement
+ * with the seed physical properties.
+ */
+ const bool _requireRID;
+
+ /**
+ * RID projection names we have generated for each scanDef. Used for physical rewriting.
+ */
+ opt::unordered_map<std::string, ProjectionName> _ridProjections;
+
+ // We don't own this.
+ PrefixId& _prefixId;
+};
+
+} // namespace mongo::optimizer
diff --git a/src/mongo/db/query/optimizer/optimizer_test.cpp b/src/mongo/db/query/optimizer/optimizer_test.cpp
new file mode 100644
index 00000000000..856ad5e3d07
--- /dev/null
+++ b/src/mongo/db/query/optimizer/optimizer_test.cpp
@@ -0,0 +1,639 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/db/query/optimizer/explain.h"
+#include "mongo/db/query/optimizer/node.h"
+#include "mongo/db/query/optimizer/reference_tracker.h"
+#include "mongo/db/query/optimizer/rewrites/const_eval.h"
+#include "mongo/db/query/optimizer/utils/unit_test_utils.h"
+#include "mongo/db/query/optimizer/utils/utils.h"
+#include "mongo/unittest/unittest.h"
+
+namespace mongo::optimizer {
+namespace {
+
+TEST(Optimizer, ConstEval) {
+ // 1 + 2
+ auto tree = make<BinaryOp>(Operations::Add, Constant::int64(1), Constant::int64(2));
+
+ // Run the evaluator.
+ auto env = VariableEnvironment::build(tree);
+ ConstEval evaluator{env};
+ evaluator.optimize(tree);
+
+ // The result must be Constant.
+ auto result = tree.cast<Constant>();
+ ASSERT(result != nullptr);
+
+ // And the value must be 3 (i.e. 1+2).
+ ASSERT_EQ(result->getValueInt64(), 3);
+
+ ASSERT_NE(ABT::tagOf<Constant>(), ABT::tagOf<BinaryOp>());
+ ASSERT_EQ(tree.tagOf(), ABT::tagOf<Constant>());
+}
+
+TEST(Optimizer, ConstEvalCompose) {
+ // (1 + 2) + 3
+ auto tree =
+ make<BinaryOp>(Operations::Add,
+ make<BinaryOp>(Operations::Add, Constant::int64(1), Constant::int64(2)),
+ Constant::int64(3));
+
+ // Run the evaluator.
+ auto env = VariableEnvironment::build(tree);
+ ConstEval evaluator{env};
+ evaluator.optimize(tree);
+
+ // The result must be Constant.
+ auto result = tree.cast<Constant>();
+ ASSERT(result != nullptr);
+
+ // And the value must be 6 (i.e. 1+2+3).
+ ASSERT_EQ(result->getValueInt64(), 6);
+}
+
+TEST(Optimizer, Tracker1) {
+ ABT scanNode = make<ScanNode>("ptest", "test");
+ ABT filterNode = make<FilterNode>(
+ make<EvalFilter>(make<PathConstant>(make<UnaryOp>(Operations::Neg, Constant::int64(1))),
+ make<Variable>("ptest")),
+ std::move(scanNode));
+ ABT evalNode = make<EvaluationNode>(
+ "P1",
+ make<EvalPath>(make<PathConstant>(Constant::int64(2)), make<Variable>("ptest")),
+ std::move(filterNode));
+
+ ABT rootNode =
+ make<RootNode>(properties::ProjectionRequirement{ProjectionNameVector{"P1", "ptest"}},
+ std::move(evalNode));
+
+ auto env = VariableEnvironment::build(rootNode);
+ ASSERT(!env.hasFreeVariables());
+}
+
+TEST(Optimizer, Tracker2) {
+ ABT expr = make<Let>("x",
+ Constant::int64(4),
+ make<BinaryOp>(Operations::Add, make<Variable>("x"), make<Variable>("x")));
+
+ auto env = VariableEnvironment::build(expr);
+ ConstEval evaluator{env};
+ evaluator.optimize(expr);
+
+ // The result must be Constant.
+ auto result = expr.cast<Constant>();
+ ASSERT(result != nullptr);
+
+ // And the value must be 8 (i.e. x+x = 4+4 = 8).
+ ASSERT_EQ(result->getValueInt64(), 8);
+}
+
+TEST(Optimizer, Tracker3) {
+ ABT scanNode = make<ScanNode>("ptest", "test");
+ ABT filterNode = make<FilterNode>(make<Variable>("free"), std::move(scanNode));
+ ABT evalNode1 = make<EvaluationNode>("free", Constant::int64(5), std::move(filterNode));
+
+ auto env = VariableEnvironment::build(evalNode1);
+ // "free" must still be a free variable.
+ ASSERT(env.hasFreeVariables());
+ ASSERT_EQ(env.freeOccurences("free"), 1);
+
+ // Projecting "unrelated" must not resolve "free".
+ ABT evalNode2 = make<EvaluationNode>("unrelated", Constant::int64(5), std::move(evalNode1));
+
+ env.rebuild(evalNode2);
+ ASSERT(env.hasFreeVariables());
+ ASSERT_EQ(env.freeOccurences("free"), 1);
+
+ // Another expression referencing "free" will resolve. But the original "free" reference is
+ // unaffected (i.e. it is still a free variable).
+ ABT filterNode2 = make<FilterNode>(make<Variable>("free"), std::move(evalNode2));
+
+ env.rebuild(filterNode2);
+ ASSERT(env.hasFreeVariables());
+ ASSERT_EQ(env.freeOccurences("free"), 1);
+}
+
+TEST(Optimizer, Tracker4) {
+ ABT scanNode = make<ScanNode>("ptest", "test");
+ auto scanNodeRef = scanNode.ref();
+ ABT evalNode = make<EvaluationNode>("unrelated", Constant::int64(5), std::move(scanNode));
+ ABT filterNode = make<FilterNode>(make<Variable>("ptest"), std::move(evalNode));
+
+ auto env = VariableEnvironment::build(filterNode);
+ ASSERT(!env.hasFreeVariables());
+
+ // Get all variables from the expression.
+ auto vars = VariableEnvironment::getVariables(filterNode.cast<FilterNode>()->getFilter());
+ ASSERT(vars._variables.size() == 1);
+ // Get all definitions from the scan and below (even though there is nothing below the scan).
+ auto defs = env.getDefinitions(scanNodeRef);
+ // Make sure that variables are defined by the scan (and not by Eval).
+ for (auto v : vars._variables) {
+ auto it = defs.find(v->name());
+ ASSERT(it != defs.end());
+ ASSERT(it->second.definedBy == env.getDefinition(v).definedBy);
+ }
+}
+
+TEST(Optimizer, RefExplain) {
+ ABT scanNode = make<ScanNode>("ptest", "test");
+ ASSERT_EXPLAIN(
+ "Scan [test]\n"
+ " BindBlock:\n"
+ " [ptest]\n"
+ " Source []\n",
+ scanNode);
+
+ // Now repeat for the reference type.
+ auto ref = scanNode.ref();
+ ASSERT_EXPLAIN(
+ "Scan [test]\n"
+ " BindBlock:\n"
+ " [ptest]\n"
+ " Source []\n",
+ ref);
+
+ ASSERT_EQ(scanNode.tagOf(), ref.tagOf());
+}
+
+TEST(Optimizer, CoScan) {
+ ABT coScanNode = make<CoScanNode>();
+ ABT limitNode =
+ make<LimitSkipNode>(properties::LimitSkipRequirement(1, 0), std::move(coScanNode));
+
+ VariableEnvironment venv = VariableEnvironment::build(limitNode);
+ ASSERT_TRUE(!venv.hasFreeVariables());
+
+ ASSERT_EXPLAIN(
+ "LimitSkip []\n"
+ " limitSkip:\n"
+ " limit: 1\n"
+ " skip: 0\n"
+ " CoScan []\n",
+ limitNode);
+}
+
+TEST(Optimizer, Basic) {
+ ABT scanNode = make<ScanNode>("ptest", "test");
+ ASSERT_EXPLAIN(
+ "Scan [test]\n"
+ " BindBlock:\n"
+ " [ptest]\n"
+ " Source []\n",
+ scanNode);
+
+ ABT filterNode = make<FilterNode>(
+ make<EvalFilter>(make<PathConstant>(make<UnaryOp>(Operations::Neg, Constant::int64(1))),
+ make<Variable>("ptest")),
+ std::move(scanNode));
+ ABT evalNode = make<EvaluationNode>(
+ "P1",
+ make<EvalPath>(make<PathConstant>(Constant::int64(2)), make<Variable>("ptest")),
+ std::move(filterNode));
+
+ ABT rootNode =
+ make<RootNode>(properties::ProjectionRequirement{ProjectionNameVector{"P1", "ptest"}},
+ std::move(evalNode));
+
+ ASSERT_EXPLAIN(
+ "Root []\n"
+ " projections: \n"
+ " P1\n"
+ " ptest\n"
+ " RefBlock: \n"
+ " Variable [P1]\n"
+ " Variable [ptest]\n"
+ " Evaluation []\n"
+ " BindBlock:\n"
+ " [P1]\n"
+ " EvalPath []\n"
+ " PathConstant []\n"
+ " Const [2]\n"
+ " Variable [ptest]\n"
+ " Filter []\n"
+ " EvalFilter []\n"
+ " PathConstant []\n"
+ " UnaryOp [Neg]\n"
+ " Const [1]\n"
+ " Variable [ptest]\n"
+ " Scan [test]\n"
+ " BindBlock:\n"
+ " [ptest]\n"
+ " Source []\n",
+ rootNode);
+
+
+ ABT clonedNode = rootNode;
+ ASSERT_EXPLAIN(
+ "Root []\n"
+ " projections: \n"
+ " P1\n"
+ " ptest\n"
+ " RefBlock: \n"
+ " Variable [P1]\n"
+ " Variable [ptest]\n"
+ " Evaluation []\n"
+ " BindBlock:\n"
+ " [P1]\n"
+ " EvalPath []\n"
+ " PathConstant []\n"
+ " Const [2]\n"
+ " Variable [ptest]\n"
+ " Filter []\n"
+ " EvalFilter []\n"
+ " PathConstant []\n"
+ " UnaryOp [Neg]\n"
+ " Const [1]\n"
+ " Variable [ptest]\n"
+ " Scan [test]\n"
+ " BindBlock:\n"
+ " [ptest]\n"
+ " Source []\n",
+ clonedNode);
+
+ auto env = VariableEnvironment::build(rootNode);
+ ProjectionNameSet set = env.topLevelProjections();
+ ProjectionNameSet expSet = {"P1", "ptest"};
+ ASSERT(expSet == set);
+ ASSERT(!env.hasFreeVariables());
+}
+
+TEST(Optimizer, GroupBy) {
+ ABT scanNode = make<ScanNode>("ptest", "test");
+ ABT evalNode1 = make<EvaluationNode>("p1", Constant::int64(1), std::move(scanNode));
+ ABT evalNode2 = make<EvaluationNode>("p2", Constant::int64(2), std::move(evalNode1));
+ ABT evalNode3 = make<EvaluationNode>("p3", Constant::int64(3), std::move(evalNode2));
+
+ {
+ auto env = VariableEnvironment::build(evalNode3);
+ ProjectionNameSet projSet = env.topLevelProjections();
+ ProjectionNameSet expSet = {"p1", "p2", "p3", "ptest"};
+ ASSERT(expSet == projSet);
+ ASSERT(!env.hasFreeVariables());
+ }
+
+ ABT agg1 = Constant::int64(10);
+ ABT agg2 = Constant::int64(11);
+ ABT groupByNode = make<GroupByNode>(ProjectionNameVector{"p1", "p2"},
+ ProjectionNameVector{"a1", "a2"},
+ makeSeq(std::move(agg1), std::move(agg2)),
+ std::move(evalNode3));
+
+ ABT rootNode = make<RootNode>(
+ properties::ProjectionRequirement{ProjectionNameVector{"p1", "p2", "a1", "a2"}},
+ std::move(groupByNode));
+
+ ASSERT_EXPLAIN(
+ "Root []\n"
+ " projections: \n"
+ " p1\n"
+ " p2\n"
+ " a1\n"
+ " a2\n"
+ " RefBlock: \n"
+ " Variable [a1]\n"
+ " Variable [a2]\n"
+ " Variable [p1]\n"
+ " Variable [p2]\n"
+ " GroupBy []\n"
+ " groupings: \n"
+ " RefBlock: \n"
+ " Variable [p1]\n"
+ " Variable [p2]\n"
+ " aggregations: \n"
+ " [a1]\n"
+ " Const [10]\n"
+ " [a2]\n"
+ " Const [11]\n"
+ " Evaluation []\n"
+ " BindBlock:\n"
+ " [p3]\n"
+ " Const [3]\n"
+ " Evaluation []\n"
+ " BindBlock:\n"
+ " [p2]\n"
+ " Const [2]\n"
+ " Evaluation []\n"
+ " BindBlock:\n"
+ " [p1]\n"
+ " Const [1]\n"
+ " Scan [test]\n"
+ " BindBlock:\n"
+ " [ptest]\n"
+ " Source []\n",
+ rootNode);
+
+ {
+ auto env = VariableEnvironment::build(rootNode);
+ ProjectionNameSet projSet = env.topLevelProjections();
+ ProjectionNameSet expSet = {"p1", "p2", "a1", "a2"};
+ ASSERT(expSet == projSet);
+ ASSERT(!env.hasFreeVariables());
+ }
+}
+
+TEST(Optimizer, Union) {
+ ABT scanNode1 = make<ScanNode>("ptest", "test");
+ ABT projNode1 = make<EvaluationNode>("B", Constant::int64(3), std::move(scanNode1));
+ ABT scanNode2 = make<ScanNode>("ptest", "test");
+ ABT projNode2 = make<EvaluationNode>("B", Constant::int64(4), std::move(scanNode2));
+ ABT scanNode3 = make<ScanNode>("ptest1", "test");
+ ABT evalNode = make<EvaluationNode>(
+ "ptest",
+ make<EvalPath>(make<PathConstant>(Constant::int64(2)), make<Variable>("ptest1")),
+ std::move(scanNode3));
+ ABT projNode3 = make<EvaluationNode>("B", Constant::int64(5), std::move(evalNode));
+
+
+ ABT unionNode = make<UnionNode>(ProjectionNameVector{"ptest", "B"},
+ makeSeq(projNode1, projNode2, projNode3));
+
+ ABT rootNode =
+ make<RootNode>(properties::ProjectionRequirement{ProjectionNameVector{"ptest", "B"}},
+ std::move(unionNode));
+
+ {
+ auto env = VariableEnvironment::build(rootNode);
+ ProjectionNameSet projSet = env.topLevelProjections();
+ ProjectionNameSet expSet = {"ptest", "B"};
+ ASSERT(expSet == projSet);
+ ASSERT(!env.hasFreeVariables());
+ }
+
+ ASSERT_EXPLAIN(
+ "Root []\n"
+ " projections: \n"
+ " ptest\n"
+ " B\n"
+ " RefBlock: \n"
+ " Variable [B]\n"
+ " Variable [ptest]\n"
+ " Union []\n"
+ " BindBlock:\n"
+ " [B]\n"
+ " Source []\n"
+ " [ptest]\n"
+ " Source []\n"
+ " Evaluation []\n"
+ " BindBlock:\n"
+ " [B]\n"
+ " Const [3]\n"
+ " Scan [test]\n"
+ " BindBlock:\n"
+ " [ptest]\n"
+ " Source []\n"
+ " Evaluation []\n"
+ " BindBlock:\n"
+ " [B]\n"
+ " Const [4]\n"
+ " Scan [test]\n"
+ " BindBlock:\n"
+ " [ptest]\n"
+ " Source []\n"
+ " Evaluation []\n"
+ " BindBlock:\n"
+ " [B]\n"
+ " Const [5]\n"
+ " Evaluation []\n"
+ " BindBlock:\n"
+ " [ptest]\n"
+ " EvalPath []\n"
+ " PathConstant []\n"
+ " Const [2]\n"
+ " Variable [ptest1]\n"
+ " Scan [test]\n"
+ " BindBlock:\n"
+ " [ptest1]\n"
+ " Source []\n",
+ rootNode);
+}
+
+TEST(Optimizer, Unwind) {
+ ABT scanNode = make<ScanNode>("p1", "test");
+ ABT evalNode = make<EvaluationNode>(
+ "p2",
+ make<EvalPath>(make<PathConstant>(Constant::int64(2)), make<Variable>("p1")),
+ std::move(scanNode));
+ ABT unwindNode = make<UnwindNode>("p2", "p2pid", true /*retainNonArrays*/, std::move(evalNode));
+
+ // Pass unwindNode by value (copy) into the root node below, rather than moving it.
+ ABT rootNode = make<RootNode>(
+ properties::ProjectionRequirement{ProjectionNameVector{"p1", "p2", "p2pid"}}, unwindNode);
+
+ {
+ auto env = VariableEnvironment::build(rootNode);
+ ProjectionNameSet projSet = env.topLevelProjections();
+ ProjectionNameSet expSet = {"p1", "p2", "p2pid"};
+ ASSERT(expSet == projSet);
+ ASSERT(!env.hasFreeVariables());
+ }
+
+ ASSERT_EXPLAIN(
+ "Root []\n"
+ " projections: \n"
+ " p1\n"
+ " p2\n"
+ " p2pid\n"
+ " RefBlock: \n"
+ " Variable [p1]\n"
+ " Variable [p2]\n"
+ " Variable [p2pid]\n"
+ " Unwind [retainNonArrays]\n"
+ " BindBlock:\n"
+ " [p2]\n"
+ " Source []\n"
+ " [p2pid]\n"
+ " Source []\n"
+ " Evaluation []\n"
+ " BindBlock:\n"
+ " [p2]\n"
+ " EvalPath []\n"
+ " PathConstant []\n"
+ " Const [2]\n"
+ " Variable [p1]\n"
+ " Scan [test]\n"
+ " BindBlock:\n"
+ " [p1]\n"
+ " Source []\n",
+ rootNode);
+}
+
+TEST(Optimizer, Collation) {
+ ABT scanNode = make<ScanNode>("a", "test");
+ ABT evalNode = make<EvaluationNode>(
+ "b",
+ make<EvalPath>(make<PathConstant>(Constant::int64(2)), make<Variable>("a")),
+ std::move(scanNode));
+
+ ABT collationNode =
+ make<CollationNode>(properties::CollationRequirement(
+ {{"a", CollationOp::Ascending}, {"b", CollationOp::Clustered}}),
+ std::move(evalNode));
+ {
+ auto env = VariableEnvironment::build(collationNode);
+ ProjectionNameSet projSet = env.topLevelProjections();
+ ProjectionNameSet expSet = {"a", "b"};
+ ASSERT(expSet == projSet);
+ ASSERT(!env.hasFreeVariables());
+ }
+
+ ASSERT_EXPLAIN(
+ "Collation []\n"
+ " collation: \n"
+ " a: Ascending\n"
+ " b: Clustered\n"
+ " RefBlock: \n"
+ " Variable [a]\n"
+ " Variable [b]\n"
+ " Evaluation []\n"
+ " BindBlock:\n"
+ " [b]\n"
+ " EvalPath []\n"
+ " PathConstant []\n"
+ " Const [2]\n"
+ " Variable [a]\n"
+ " Scan [test]\n"
+ " BindBlock:\n"
+ " [a]\n"
+ " Source []\n",
+ collationNode);
+}
+
+TEST(Optimizer, LimitSkip) {
+ ABT scanNode = make<ScanNode>("a", "test");
+ ABT evalNode = make<EvaluationNode>(
+ "b",
+ make<EvalPath>(make<PathConstant>(Constant::int64(2)), make<Variable>("a")),
+ std::move(scanNode));
+
+ ABT limitSkipNode =
+ make<LimitSkipNode>(properties::LimitSkipRequirement(10, 20), std::move(evalNode));
+ {
+ auto env = VariableEnvironment::build(limitSkipNode);
+ ProjectionNameSet projSet = env.topLevelProjections();
+ ProjectionNameSet expSet = {"a", "b"};
+ ASSERT(expSet == projSet);
+ ASSERT(!env.hasFreeVariables());
+ }
+
+ ASSERT_EXPLAIN(
+ "LimitSkip []\n"
+ " limitSkip:\n"
+ " limit: 10\n"
+ " skip: 20\n"
+ " Evaluation []\n"
+ " BindBlock:\n"
+ " [b]\n"
+ " EvalPath []\n"
+ " PathConstant []\n"
+ " Const [2]\n"
+ " Variable [a]\n"
+ " Scan [test]\n"
+ " BindBlock:\n"
+ " [a]\n"
+ " Source []\n",
+ limitSkipNode);
+}
+
+TEST(Optimizer, Distribution) {
+ ABT scanNode = make<ScanNode>("a", "test");
+ ABT evalNode = make<EvaluationNode>(
+ "b",
+ make<EvalPath>(make<PathConstant>(Constant::int64(2)), make<Variable>("a")),
+ std::move(scanNode));
+
+ ABT exchangeNode = make<ExchangeNode>(
+ properties::DistributionRequirement({DistributionType::HashPartitioning, {"b"}}),
+ std::move(evalNode));
+
+ ASSERT_EXPLAIN(
+ "Exchange []\n"
+ " distribution: \n"
+ " type: HashPartitioning\n"
+ " projections: \n"
+ " b\n"
+ " RefBlock: \n"
+ " Variable [b]\n"
+ " Evaluation []\n"
+ " BindBlock:\n"
+ " [b]\n"
+ " EvalPath []\n"
+ " PathConstant []\n"
+ " Const [2]\n"
+ " Variable [a]\n"
+ " Scan [test]\n"
+ " BindBlock:\n"
+ " [a]\n"
+ " Source []\n",
+ exchangeNode);
+}
+
+TEST(Properties, Basic) {
+ using namespace properties;
+
+ CollationRequirement collation1(
+ {{"p1", CollationOp::Ascending}, {"p2", CollationOp::Descending}});
+ CollationRequirement collation2(
+ {{"p1", CollationOp::Ascending}, {"p2", CollationOp::Clustered}});
+ ASSERT_TRUE(collationsCompatible(collation1.getCollationSpec(), collation2.getCollationSpec()));
+ ASSERT_FALSE(
+ collationsCompatible(collation2.getCollationSpec(), collation1.getCollationSpec()));
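+ // Presumably the Clustered op on "p2" is satisfied by a concrete Descending
+ // order, while Descending is not satisfied by Clustered.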
+
+ PhysProps props;
+ ASSERT_FALSE(hasProperty<CollationRequirement>(props));
+ ASSERT_TRUE(setProperty(props, collation1));
+ ASSERT_TRUE(hasProperty<CollationRequirement>(props));
+ ASSERT_FALSE(setProperty(props, collation2));
+ ASSERT_TRUE(collation1 == getProperty<CollationRequirement>(props));
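+ // setProperty does not appear to overwrite an existing property: the second call
+ // returns false and collation1 remains in effect.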
+
+ LimitSkipRequirement ls(10, 20);
+ ASSERT_FALSE(hasProperty<LimitSkipRequirement>(props));
+ ASSERT_TRUE(setProperty(props, ls));
+ ASSERT_TRUE(hasProperty<LimitSkipRequirement>(props));
+ ASSERT_TRUE(ls == getProperty<LimitSkipRequirement>(props));
+
+ LimitSkipRequirement ls1(-1, 10);
+ LimitSkipRequirement ls2(5, 0);
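+ // ls1 appears to denote "no limit" via -1 while skipping 10; ls2 takes 5 rows.
+ // Judging by the assertions, the first argument of combineLimitSkipProperties is
+ // the enclosing (outer) requirement: skip 10 first (ls1), then take 5 (ls2),
+ // yielding the composite (limit 5, skip 10).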
+ {
+ LimitSkipRequirement ls3 = ls2;
+ combineLimitSkipProperties(ls3, ls1);
+ ASSERT_EQ(5, ls3.getLimit());
+ ASSERT_EQ(10, ls3.getSkip());
+ }
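+ // Reversed nesting: the inner limit of 5 rows is entirely consumed by the outer
+ // skip of 10, so the combined requirement is empty (limit 0, skip 0).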
+ {
+ LimitSkipRequirement ls3 = ls1;
+ combineLimitSkipProperties(ls3, ls2);
+ ASSERT_EQ(0, ls3.getLimit());
+ ASSERT_EQ(0, ls3.getSkip());
+ }
+}
+
+} // namespace
+} // namespace mongo::optimizer
diff --git a/src/mongo/db/query/optimizer/physical_rewriter_optimizer_test.cpp b/src/mongo/db/query/optimizer/physical_rewriter_optimizer_test.cpp
new file mode 100644
index 00000000000..80ee0f07ebb
--- /dev/null
+++ b/src/mongo/db/query/optimizer/physical_rewriter_optimizer_test.cpp
@@ -0,0 +1,4659 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/db/query/optimizer/cascades/ce_heuristic.h"
+#include "mongo/db/query/optimizer/cascades/ce_hinted.h"
+#include "mongo/db/query/optimizer/cascades/cost_derivation.h"
+#include "mongo/db/query/optimizer/explain.h"
+#include "mongo/db/query/optimizer/metadata.h"
+#include "mongo/db/query/optimizer/node.h"
+#include "mongo/db/query/optimizer/opt_phase_manager.h"
+#include "mongo/db/query/optimizer/utils/unit_test_utils.h"
+#include "mongo/unittest/unittest.h"
+
+namespace mongo::optimizer {
+namespace {
+
+TEST(PhysRewriter, PhysicalRewriterBasic) {
+ using namespace properties;
+ PrefixId prefixId;
+
+ ABT scanNode = make<ScanNode>("p1", "test");
+
+ ABT projectionNode1 = make<EvaluationNode>(
+ "p2", make<EvalPath>(make<PathIdentity>(), make<Variable>("p1")), std::move(scanNode));
+
+ ABT filter1Node = make<FilterNode>(make<EvalFilter>(make<PathIdentity>(), make<Variable>("p1")),
+ std::move(projectionNode1));
+
+ ABT filter2Node = make<FilterNode>(
+ make<EvalFilter>(make<PathGet>("a", make<PathCompare>(Operations::Eq, Constant::int64(1))),
+ make<Variable>("p2")),
+ std::move(filter1Node));
+
+ ABT rootNode =
+ make<RootNode>(ProjectionRequirement{ProjectionNameVector{"p2"}}, std::move(filter2Node));
+
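+ // The metadata argument maps each scan definition name to a ScanDefinition;
+ // "test" here has empty options and no indexes. DebugInfo enables debug mode at
+ // level 2 with the iteration limit used for tests.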
+ OptPhaseManager phaseManager(
+ {OptPhaseManager::OptPhase::MemoSubstitutionPhase,
+ OptPhaseManager::OptPhase::MemoExplorationPhase,
+ OptPhaseManager::OptPhase::MemoImplementationPhase},
+ prefixId,
+ {{{"test", {{}, {}}}}},
+ {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests});
+
+ ABT optimized = std::move(rootNode);
+ ASSERT_TRUE(phaseManager.optimize(optimized));
+ {
+ auto env = VariableEnvironment::build(optimized);
+ ProjectionNameSet expSet = {"p1", "p2"};
+ ASSERT_TRUE(expSet == env.topLevelProjections());
+ }
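+ // The memo statistics record how many physical plan alternatives were explored.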
+ ASSERT_EQ(5, phaseManager.getMemo().getStats()._physPlanExplorationCount);
+
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | p2\n"
+ "| RefBlock: \n"
+ "| Variable [p2]\n"
+ "Filter []\n"
+ "| EvalFilter []\n"
+ "| | Variable [p2]\n"
+ "| PathGet [a]\n"
+ "| PathCompare [Eq]\n"
+ "| Const [1]\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [p2]\n"
+ "| EvalPath []\n"
+ "| | Variable [p1]\n"
+ "| PathIdentity []\n"
+ "Filter []\n"
+ "| EvalFilter []\n"
+ "| | Variable [p1]\n"
+ "| PathIdentity []\n"
+ "PhysicalScan [{'<root>': p1}, test]\n"
+ " BindBlock:\n"
+ " [p1]\n"
+ " Source []\n",
+ optimized);
+
+ // Plan output with properties.
+ ASSERT_EXPLAIN_PROPS_V2(
+ "Properties [cost: 1.02, localCost: 0, adjustedCE: 10]\n"
+ "| | Logical:\n"
+ "| | cardinalityEstimate: \n"
+ "| | ce: 10\n"
+ "| | projections: \n"
+ "| | p1\n"
+ "| | p2\n"
+ "| | indexingAvailability: \n"
+ "| | [groupId: 0, scanProjection: p1, scanDefName: test]\n"
+ "| | collectionAvailability: \n"
+ "| | test\n"
+ "| | distributionAvailability: \n"
+ "| | distribution: \n"
+ "| | type: Centralized\n"
+ "| Physical:\n"
+ "| distribution: \n"
+ "| type: Centralized\n"
+ "| indexingRequirement: \n"
+ "| Complete, dedupRID\n"
+ "Root []\n"
+ "| | projections: \n"
+ "| | p2\n"
+ "| RefBlock: \n"
+ "| Variable [p2]\n"
+ "Properties [cost: 1.02, localCost: 0.020001, adjustedCE: 10]\n"
+ "| | Logical:\n"
+ "| | cardinalityEstimate: \n"
+ "| | ce: 10\n"
+ "| | requirementCEs: \n"
+ "| | refProjection: p2, path: 'PathGet [a] PathIdentity []', ce: 10\n"
+ "| | projections: \n"
+ "| | p1\n"
+ "| | p2\n"
+ "| | indexingAvailability: \n"
+ "| | [groupId: 0, scanProjection: p1, scanDefName: test]\n"
+ "| | collectionAvailability: \n"
+ "| | test\n"
+ "| | distributionAvailability: \n"
+ "| | distribution: \n"
+ "| | type: Centralized\n"
+ "| Physical:\n"
+ "| projections: \n"
+ "| p2\n"
+ "| distribution: \n"
+ "| type: Centralized\n"
+ "| indexingRequirement: \n"
+ "| Complete, dedupRID\n"
+ "Filter []\n"
+ "| EvalFilter []\n"
+ "| | Variable [p2]\n"
+ "| PathGet [a]\n"
+ "| PathCompare [Eq]\n"
+ "| Const [1]\n"
+ "Properties [cost: 1, localCost: 0.200001, adjustedCE: 100]\n"
+ "| | Logical:\n"
+ "| | cardinalityEstimate: \n"
+ "| | ce: 100\n"
+ "| | projections: \n"
+ "| | p1\n"
+ "| | p2\n"
+ "| | indexingAvailability: \n"
+ "| | [groupId: 0, scanProjection: p1, scanDefName: test]\n"
+ "| | collectionAvailability: \n"
+ "| | test\n"
+ "| | distributionAvailability: \n"
+ "| | distribution: \n"
+ "| | type: Centralized\n"
+ "| Physical:\n"
+ "| projections: \n"
+ "| p2\n"
+ "| distribution: \n"
+ "| type: Centralized, disableExchanges\n"
+ "| indexingRequirement: \n"
+ "| Complete, dedupRID\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [p2]\n"
+ "| EvalPath []\n"
+ "| | Variable [p1]\n"
+ "| PathIdentity []\n"
+ "Properties [cost: 0.800002, localCost: 0.200001, adjustedCE: 100]\n"
+ "| | Logical:\n"
+ "| | cardinalityEstimate: \n"
+ "| | ce: 100\n"
+ "| | projections: \n"
+ "| | p1\n"
+ "| | indexingAvailability: \n"
+ "| | [groupId: 0, scanProjection: p1, scanDefName: test]\n"
+ "| | collectionAvailability: \n"
+ "| | test\n"
+ "| | distributionAvailability: \n"
+ "| | distribution: \n"
+ "| | type: Centralized\n"
+ "| Physical:\n"
+ "| projections: \n"
+ "| p1\n"
+ "| distribution: \n"
+ "| type: Centralized, disableExchanges\n"
+ "| indexingRequirement: \n"
+ "| Complete, dedupRID\n"
+ "Filter []\n"
+ "| EvalFilter []\n"
+ "| | Variable [p1]\n"
+ "| PathIdentity []\n"
+ "Properties [cost: 0.600001, localCost: 0.600001, adjustedCE: 1000]\n"
+ "| | Logical:\n"
+ "| | cardinalityEstimate: \n"
+ "| | ce: 1000\n"
+ "| | projections: \n"
+ "| | p1\n"
+ "| | indexingAvailability: \n"
+ "| | [groupId: 0, scanProjection: p1, scanDefName: test, possiblyEqPredsOnly]\n"
+ "| | collectionAvailability: \n"
+ "| | test\n"
+ "| | distributionAvailability: \n"
+ "| | distribution: \n"
+ "| | type: Centralized\n"
+ "| Physical:\n"
+ "| projections: \n"
+ "| p1\n"
+ "| distribution: \n"
+ "| type: Centralized, disableExchanges\n"
+ "| indexingRequirement: \n"
+ "| Complete, dedupRID\n"
+ "PhysicalScan [{'<root>': p1}, test]\n"
+ " BindBlock:\n"
+ " [p1]\n"
+ " Source []\n",
+ phaseManager);
+}
+
+TEST(PhysRewriter, GroupBy) {
+ using namespace properties;
+ PrefixId prefixId;
+
+ ABT scanNode = make<ScanNode>("ptest", "test");
+
+ ABT projectionANode = make<EvaluationNode>(
+ "a", make<EvalPath>(make<PathIdentity>(), make<Variable>("ptest")), std::move(scanNode));
+ ABT projectionBNode =
+ make<EvaluationNode>("b",
+ make<EvalPath>(make<PathIdentity>(), make<Variable>("ptest")),
+ std::move(projectionANode));
+
+ ABT groupByNode = make<GroupByNode>(ProjectionNameVector{"a"},
+ ProjectionNameVector{"c"},
+ makeSeq(make<Variable>("b")),
+ std::move(projectionBNode));
+
+ ABT filterCNode = make<FilterNode>(make<EvalFilter>(make<PathIdentity>(), make<Variable>("c")),
+ std::move(groupByNode));
+
+ ABT filterANode = make<FilterNode>(make<EvalFilter>(make<PathIdentity>(), make<Variable>("a")),
+ std::move(filterCNode));
+
+ ABT rootNode =
+ make<RootNode>(ProjectionRequirement{ProjectionNameVector{"c"}}, std::move(filterANode));
+
+ OptPhaseManager phaseManager(
+ {OptPhaseManager::OptPhase::MemoSubstitutionPhase,
+ OptPhaseManager::OptPhase::MemoExplorationPhase,
+ OptPhaseManager::OptPhase::MemoImplementationPhase},
+ prefixId,
+ {{{"test", {{}, {}}}}},
+ {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests});
+
+ ABT optimized = std::move(rootNode);
+ ASSERT_TRUE(phaseManager.optimize(optimized));
+ ASSERT_EQ(7, phaseManager.getMemo().getStats()._physPlanExplorationCount);
+
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | c\n"
+ "| RefBlock: \n"
+ "| Variable [c]\n"
+ "Filter []\n"
+ "| EvalFilter []\n"
+ "| | Variable [a]\n"
+ "| PathIdentity []\n"
+ "Filter []\n"
+ "| EvalFilter []\n"
+ "| | Variable [c]\n"
+ "| PathIdentity []\n"
+ "GroupBy []\n"
+ "| | groupings: \n"
+ "| | RefBlock: \n"
+ "| | Variable [a]\n"
+ "| aggregations: \n"
+ "| [c]\n"
+ "| Variable [b]\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [b]\n"
+ "| EvalPath []\n"
+ "| | Variable [ptest]\n"
+ "| PathIdentity []\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [a]\n"
+ "| EvalPath []\n"
+ "| | Variable [ptest]\n"
+ "| PathIdentity []\n"
+ "PhysicalScan [{'<root>': ptest}, test]\n"
+ " BindBlock:\n"
+ " [ptest]\n"
+ " Source []\n",
+ optimized);
+}
+
+TEST(PhysRewriter, GroupBy1) {
+ using namespace properties;
+ PrefixId prefixId;
+
+ ABT scanNode = make<ScanNode>("ptest", "test");
+
+ ABT projectionANode = make<EvaluationNode>("pa", Constant::null(), std::move(scanNode));
+ ABT projectionA1Node =
+ make<EvaluationNode>("pa1", Constant::null(), std::move(projectionANode));
+
+ ABT groupByNode = make<GroupByNode>(ProjectionNameVector{},
+ ProjectionNameVector{"pb", "pb1"},
+ makeSeq(make<Variable>("pa"), make<Variable>("pa1")),
+ std::move(projectionA1Node));
+
+ ABT rootNode =
+ make<RootNode>(ProjectionRequirement{ProjectionNameVector{"pb"}}, std::move(groupByNode));
+
+ OptPhaseManager phaseManager(
+ {OptPhaseManager::OptPhase::MemoSubstitutionPhase,
+ OptPhaseManager::OptPhase::MemoExplorationPhase,
+ OptPhaseManager::OptPhase::MemoImplementationPhase},
+ prefixId,
+ {{{"test", {{}, {}}}}},
+ {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests});
+
+ ABT optimized = std::move(rootNode);
+ ASSERT_TRUE(phaseManager.optimize(optimized));
+ ASSERT_EQ(5, phaseManager.getMemo().getStats()._physPlanExplorationCount);
+
+ // Projection "pb1" is unused and we do not generate an aggregation expression for it.
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | pb\n"
+ "| RefBlock: \n"
+ "| Variable [pb]\n"
+ "GroupBy []\n"
+ "| | groupings: \n"
+ "| | RefBlock: \n"
+ "| aggregations: \n"
+ "| [pb]\n"
+ "| Variable [pa]\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [pa]\n"
+ "| Const [null]\n"
+ "PhysicalScan [{}, test]\n"
+ " BindBlock:\n",
+ optimized);
+}
+
+TEST(PhysRewriter, Unwind) {
+ using namespace properties;
+ PrefixId prefixId;
+
+ ABT scanNode = make<ScanNode>("ptest", "test");
+
+ ABT projectionANode = make<EvaluationNode>(
+ "a", make<EvalPath>(make<PathIdentity>(), make<Variable>("ptest")), std::move(scanNode));
+ ABT projectionBNode =
+ make<EvaluationNode>("b",
+ make<EvalPath>(make<PathIdentity>(), make<Variable>("ptest")),
+ std::move(projectionANode));
+
+ ABT unwindNode =
+ make<UnwindNode>("a", "a_pid", false /*retainNonArrays*/, std::move(projectionBNode));
+
+ // This filter should stay above the unwind.
+ ABT filterANode = make<FilterNode>(make<EvalFilter>(make<PathIdentity>(), make<Variable>("a")),
+ std::move(unwindNode));
+
+ // This filter should be pushed down below the unwind.
+ ABT filterBNode = make<FilterNode>(make<EvalFilter>(make<PathIdentity>(), make<Variable>("b")),
+ std::move(filterANode));
+
+ ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"a", "b"}},
+ std::move(filterBNode));
+
+ OptPhaseManager phaseManager(
+ {OptPhaseManager::OptPhase::MemoSubstitutionPhase,
+ OptPhaseManager::OptPhase::MemoExplorationPhase,
+ OptPhaseManager::OptPhase::MemoImplementationPhase},
+ prefixId,
+ {{{"test", {{}, {}}}}},
+ {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests});
+
+ ABT optimized = std::move(rootNode);
+ ASSERT_TRUE(phaseManager.optimize(optimized));
+ ASSERT_EQ(7, phaseManager.getMemo().getStats()._physPlanExplorationCount);
+
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | a\n"
+ "| | b\n"
+ "| RefBlock: \n"
+ "| Variable [a]\n"
+ "| Variable [b]\n"
+ "Filter []\n"
+ "| EvalFilter []\n"
+ "| | Variable [b]\n"
+ "| PathIdentity []\n"
+ "Filter []\n"
+ "| EvalFilter []\n"
+ "| | Variable [a]\n"
+ "| PathIdentity []\n"
+ "Unwind []\n"
+ "| BindBlock:\n"
+ "| [a]\n"
+ "| Source []\n"
+ "| [a_pid]\n"
+ "| Source []\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [b]\n"
+ "| EvalPath []\n"
+ "| | Variable [ptest]\n"
+ "| PathIdentity []\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [a]\n"
+ "| EvalPath []\n"
+ "| | Variable [ptest]\n"
+ "| PathIdentity []\n"
+ "PhysicalScan [{'<root>': ptest}, test]\n"
+ " BindBlock:\n"
+ " [ptest]\n"
+ " Source []\n",
+ optimized);
+}
+
+TEST(PhysRewriter, DuplicateFilter) {
+ using namespace properties;
+ PrefixId prefixId;
+
+ ABT scanNode = make<ScanNode>("root", "c1");
+
+ ABT filterNode1 = make<FilterNode>(
+ make<EvalFilter>(
+ make<PathGet>(
+ "a", make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(0)))),
+ make<Variable>("root")),
+ std::move(scanNode));
+
+ ABT filterNode2 = make<FilterNode>(
+ make<EvalFilter>(
+ make<PathGet>(
+ "a", make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(0)))),
+ make<Variable>("root")),
+ std::move(filterNode1));
+
+ ABT rootNode =
+ make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(filterNode2));
+
+ OptPhaseManager phaseManager(
+ {OptPhaseManager::OptPhase::MemoSubstitutionPhase,
+ OptPhaseManager::OptPhase::MemoExplorationPhase,
+ OptPhaseManager::OptPhase::MemoImplementationPhase},
+ prefixId,
+ {{{"c1", {{}, {}}}}},
+ {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests});
+
+ ABT optimized = std::move(rootNode);
+ ASSERT_TRUE(phaseManager.optimize(optimized));
+ ASSERT_EQ(2, phaseManager.getMemo().getStats()._physPlanExplorationCount);
+
+ // Only one copy of the filter.
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | root\n"
+ "| RefBlock: \n"
+ "| Variable [root]\n"
+ "Filter []\n"
+ "| EvalFilter []\n"
+ "| | Variable [evalTemp_0]\n"
+ "| PathTraverse []\n"
+ "| PathCompare [Eq]\n"
+ "| Const [0]\n"
+ "PhysicalScan [{'<root>': root, 'a': evalTemp_0}, c1]\n"
+ " BindBlock:\n"
+ " [evalTemp_0]\n"
+ " Source []\n"
+ " [root]\n"
+ " Source []\n",
+ optimized);
+}
+
+TEST(PhysRewriter, FilterCollation) {
+ using namespace properties;
+ PrefixId prefixId;
+
+ ABT scanNode = make<ScanNode>("root", "c1");
+
+ ABT evalNode = make<EvaluationNode>(
+ "pb",
+ make<EvalPath>(make<PathGet>("b", make<PathIdentity>()), make<Variable>("root")),
+ std::move(scanNode));
+
+ ABT filterNode = make<FilterNode>(
+ make<EvalFilter>(
+ make<PathGet>(
+ "a", make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(1)))),
+ make<Variable>("root")),
+ std::move(evalNode));
+
+ ABT collationNode = make<CollationNode>(CollationRequirement({{"pb", CollationOp::Ascending}}),
+ std::move(filterNode));
+
+ ABT limitSkipNode = make<LimitSkipNode>(LimitSkipRequirement{10, 0}, std::move(collationNode));
+
+ ABT rootNode =
+ make<RootNode>(ProjectionRequirement{ProjectionNameVector{"pb"}}, std::move(limitSkipNode));
+
+ OptPhaseManager phaseManager(
+ {OptPhaseManager::OptPhase::MemoSubstitutionPhase,
+ OptPhaseManager::OptPhase::MemoExplorationPhase,
+ OptPhaseManager::OptPhase::MemoImplementationPhase},
+ prefixId,
+ {{{"c1", {{}, {}}}}},
+ {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests});
+
+ ABT optimized = std::move(rootNode);
+ ASSERT_TRUE(phaseManager.optimize(optimized));
+ ASSERT_BETWEEN(9, 11, phaseManager.getMemo().getStats()._physPlanExplorationCount);
+
+ // Limit-skip is attached to the collation node by virtue of physical props.
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | pb\n"
+ "| RefBlock: \n"
+ "| Variable [pb]\n"
+ "Collation []\n"
+ "| | collation: \n"
+ "| | pb: Ascending\n"
+ "| RefBlock: \n"
+ "| Variable [pb]\n"
+ "Filter []\n"
+ "| EvalFilter []\n"
+ "| | Variable [evalTemp_0]\n"
+ "| PathTraverse []\n"
+ "| PathCompare [Eq]\n"
+ "| Const [1]\n"
+ "PhysicalScan [{'a': evalTemp_0, 'b': pb}, c1]\n"
+ " BindBlock:\n"
+ " [evalTemp_0]\n"
+ " Source []\n"
+ " [pb]\n"
+ " Source []\n",
+ optimized);
+}
+
+TEST(PhysRewriter, EvalCollation) {
+ using namespace properties;
+ PrefixId prefixId;
+
+ ABT scanNode = make<ScanNode>("root", "c1");
+
+ ABT evalNode = make<EvaluationNode>(
+ "pa",
+ make<EvalPath>(make<PathGet>("a", make<PathIdentity>()), make<Variable>("root")),
+ std::move(scanNode));
+
+ ABT collationNode = make<CollationNode>(CollationRequirement({{"pa", CollationOp::Ascending}}),
+ std::move(evalNode));
+
+ ABT rootNode =
+ make<RootNode>(ProjectionRequirement{ProjectionNameVector{"pa"}}, std::move(collationNode));
+
+ OptPhaseManager phaseManager(
+ {OptPhaseManager::OptPhase::MemoSubstitutionPhase,
+ OptPhaseManager::OptPhase::MemoExplorationPhase,
+ OptPhaseManager::OptPhase::MemoImplementationPhase},
+ prefixId,
+ {{{"c1", {{}, {}}}}},
+ {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests});
+
+ ABT optimized = std::move(rootNode);
+ ASSERT_TRUE(phaseManager.optimize(optimized));
+ ASSERT_EQ(4, phaseManager.getMemo().getStats()._physPlanExplorationCount);
+
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | pa\n"
+ "| RefBlock: \n"
+ "| Variable [pa]\n"
+ "Collation []\n"
+ "| | collation: \n"
+ "| | pa: Ascending\n"
+ "| RefBlock: \n"
+ "| Variable [pa]\n"
+ "PhysicalScan [{'a': pa}, c1]\n"
+ " BindBlock:\n"
+ " [pa]\n"
+ " Source []\n",
+ optimized);
+}
+
+TEST(PhysRewriter, FilterEvalCollation) {
+ using namespace properties;
+ PrefixId prefixId;
+
+ ABT scanNode = make<ScanNode>("root", "c1");
+
+ ABT filterNode = make<FilterNode>(
+ make<EvalFilter>(
+ make<PathGet>(
+ "a", make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(10)))),
+ make<Variable>("root")),
+ std::move(scanNode));
+
+ ABT evalNode = make<EvaluationNode>(
+ "pa",
+ make<EvalPath>(make<PathGet>("a", make<PathIdentity>()), make<Variable>("root")),
+ std::move(filterNode));
+
+ ABT collationNode = make<CollationNode>(CollationRequirement({{"pa", CollationOp::Ascending}}),
+ std::move(evalNode));
+
+ ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}},
+ std::move(collationNode));
+
+ OptPhaseManager phaseManager(
+ {OptPhaseManager::OptPhase::MemoSubstitutionPhase,
+ OptPhaseManager::OptPhase::MemoExplorationPhase,
+ OptPhaseManager::OptPhase::MemoImplementationPhase},
+ prefixId,
+ {{{"c1", {{}, {}}}}},
+ {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests});
+
+ ABT optimized = std::move(rootNode);
+ ASSERT_TRUE(phaseManager.optimize(optimized));
+ ASSERT_EQ(4, phaseManager.getMemo().getStats()._physPlanExplorationCount);
+
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | root\n"
+ "| RefBlock: \n"
+ "| Variable [root]\n"
+ "Collation []\n"
+ "| | collation: \n"
+ "| | pa: Ascending\n"
+ "| RefBlock: \n"
+ "| Variable [pa]\n"
+ "Filter []\n"
+ "| EvalFilter []\n"
+ "| | Variable [pa]\n"
+ "| PathTraverse []\n"
+ "| PathCompare [Eq]\n"
+ "| Const [10]\n"
+ "PhysicalScan [{'<root>': root, 'a': pa}, c1]\n"
+ " BindBlock:\n"
+ " [pa]\n"
+ " Source []\n"
+ " [root]\n"
+ " Source []\n",
+ optimized);
+}
+
+TEST(PhysRewriter, FilterIndexing) {
+ using namespace properties;
+
+ ABT scanNode = make<ScanNode>("root", "c1");
+
+ ABT filterNode = make<FilterNode>(
+ make<EvalFilter>(
+ make<PathGet>(
+ "a", make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(1)))),
+ make<Variable>("root")),
+ std::move(scanNode));
+
+ ABT rootNode =
+ make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(filterNode));
+
+ {
+ PrefixId prefixId;
+ OptPhaseManager phaseManager(
+ {OptPhaseManager::OptPhase::MemoSubstitutionPhase,
+ OptPhaseManager::OptPhase::MemoExplorationPhase},
+ prefixId,
+ {{{"c1",
+ ScanDefinition{{},
+ {{"index1", makeIndexDefinition("a", CollationOp::Ascending)}}}}}},
+ {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests});
+
+ // Demonstrate that the sargable node is rewritten from the filter node.
+ // Note: SargableNodes cannot be lowered, and by default they are not created unless we
+ // have indexes.
+
+ ABT optimized = rootNode;
+ ASSERT_TRUE(phaseManager.optimize(optimized));
+
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | root\n"
+ "| RefBlock: \n"
+ "| Variable [root]\n"
+ "RIDIntersect [root, hasLeftIntervals]\n"
+ "| Scan [c1]\n"
+ "| BindBlock:\n"
+ "| [root]\n"
+ "| Source []\n"
+ "Sargable [Index]\n"
+ "| | | | requirementsMap: \n"
+ "| | | | refProjection: root, path: 'PathGet [a] PathTraverse [] "
+ "PathIdentity []', intervals: {{{[Const [1], Const [1]]}}}\n"
+ "| | | candidateIndexes: \n"
+ "| | | candidateId: 1, index1, {}, {}, {{{[Const [1], Const [1]]}}}\n"
+ "| | BindBlock:\n"
+ "| RefBlock: \n"
+ "| Variable [root]\n"
+ "Scan [c1]\n"
+ " BindBlock:\n"
+ " [root]\n"
+ " Source []\n",
+ optimized);
+ }
+
+ {
+ PrefixId prefixId;
+ OptPhaseManager phaseManager(
+ {OptPhaseManager::OptPhase::MemoSubstitutionPhase,
+ OptPhaseManager::OptPhase::MemoExplorationPhase,
+ OptPhaseManager::OptPhase::MemoImplementationPhase},
+ prefixId,
+ {{{"c1",
+ ScanDefinition{{},
+ {{"index1", makeIndexDefinition("a", CollationOp::Ascending)}}}}}},
+ {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests});
+
+ ABT optimized = rootNode;
+ ASSERT_TRUE(phaseManager.optimize(optimized));
+ ASSERT_EQ(4, phaseManager.getMemo().getStats()._physPlanExplorationCount);
+
+ // Test sargable filter is satisfied with an index scan.
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | root\n"
+ "| RefBlock: \n"
+ "| Variable [root]\n"
+ "BinaryJoin [joinType: Inner, {rid_0}]\n"
+ "| | Const [true]\n"
+ "| LimitSkip []\n"
+ "| | limitSkip:\n"
+ "| | limit: 1\n"
+ "| | skip: 0\n"
+ "| Seek [ridProjection: rid_0, {'<root>': root}, c1]\n"
+ "| | BindBlock:\n"
+ "| | [root]\n"
+ "| | Source []\n"
+ "| RefBlock: \n"
+ "| Variable [rid_0]\n"
+ "IndexScan [{'<rid>': rid_0}, scanDefName: c1, indexDefName: index1, interval: "
+ "{[Const [1], Const [1]]}]\n"
+ " BindBlock:\n"
+ " [rid_0]\n"
+ " Source []\n",
+ optimized);
+ }
+
+ {
+ PrefixId prefixId;
+ OptPhaseManager phaseManager(
+ {OptPhaseManager::OptPhase::MemoSubstitutionPhase,
+ OptPhaseManager::OptPhase::MemoExplorationPhase,
+ OptPhaseManager::OptPhase::MemoImplementationPhase},
+ prefixId,
+ {{{"c1", {{}, {}}}}},
+ {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests});
+
+ ABT optimized = rootNode;
+ ASSERT_TRUE(phaseManager.optimize(optimized));
+ ASSERT_EQ(2, phaseManager.getMemo().getStats()._physPlanExplorationCount);
+
+ // Test we can optimize sargable filter nodes even without an index.
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | root\n"
+ "| RefBlock: \n"
+ "| Variable [root]\n"
+ "Filter []\n"
+ "| EvalFilter []\n"
+ "| | Variable [evalTemp_0]\n"
+ "| PathTraverse []\n"
+ "| PathCompare [Eq]\n"
+ "| Const [1]\n"
+ "PhysicalScan [{'<root>': root, 'a': evalTemp_0}, c1]\n"
+ " BindBlock:\n"
+ " [evalTemp_0]\n"
+ " Source []\n"
+ " [root]\n"
+ " Source []\n",
+ optimized);
+ }
+}
+
+TEST(PhysRewriter, FilterIndexing1) {
+ using namespace properties;
+
+ ABT scanNode = make<ScanNode>("root", "c1");
+
+ // This node will not be converted to Sargable.
+ ABT evalNode = make<EvaluationNode>(
+ "p1",
+ make<EvalPath>(
+ make<PathGet>(
+ "b",
+ make<PathLambda>(make<LambdaAbstraction>(
+ "t",
+ make<BinaryOp>(Operations::Add, make<Variable>("t"), Constant::int64(1))))),
+ make<Variable>("root")),
+ std::move(scanNode));
+
+ ABT filterNode = make<FilterNode>(
+ make<EvalFilter>(
+ make<PathGet>(
+ "a", make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(1)))),
+ make<Variable>("p1")),
+ std::move(evalNode));
+
+ ABT rootNode =
+ make<RootNode>(ProjectionRequirement{ProjectionNameVector{"p1"}}, std::move(filterNode));
+
+ PrefixId prefixId;
+ OptPhaseManager phaseManager(
+ {OptPhaseManager::OptPhase::MemoSubstitutionPhase,
+ OptPhaseManager::OptPhase::MemoExplorationPhase,
+ OptPhaseManager::OptPhase::MemoImplementationPhase},
+ prefixId,
+ {{{"c1",
+ ScanDefinition{{}, {{"index1", makeIndexDefinition("a", CollationOp::Ascending)}}}}}},
+ {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests});
+
+ ABT optimized = rootNode;
+ ASSERT_TRUE(phaseManager.optimize(optimized));
+ ASSERT_EQ(7, phaseManager.getMemo().getStats()._physPlanExplorationCount);
+
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | p1\n"
+ "| RefBlock: \n"
+ "| Variable [p1]\n"
+ "Filter []\n"
+ "| EvalFilter []\n"
+ "| | Variable [p1]\n"
+ "| PathGet [a]\n"
+ "| PathTraverse []\n"
+ "| PathCompare [Eq]\n"
+ "| Const [1]\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [p1]\n"
+ "| EvalPath []\n"
+ "| | Variable [root]\n"
+ "| PathGet [b]\n"
+ "| PathLambda []\n"
+ "| LambdaAbstraction [t]\n"
+ "| BinaryOp [Add]\n"
+ "| | Const [1]\n"
+ "| Variable [t]\n"
+ "PhysicalScan [{'<root>': root}, c1]\n"
+ " BindBlock:\n"
+ " [root]\n"
+ " Source []\n",
+ optimized);
+}
+
+TEST(PhysRewriter, FilterIndexing2) {
+ using namespace properties;
+
+ ABT scanNode = make<ScanNode>("root", "c1");
+
+ ABT filterNode = make<FilterNode>(
+ make<EvalFilter>(make<PathGet>("a",
+ make<PathTraverse>(make<PathGet>(
+ "b",
+ make<PathTraverse>(make<PathCompare>(
+ Operations::Eq, Constant::int64(1)))))),
+ make<Variable>("root")),
+ std::move(scanNode));
+
+ ABT rootNode =
+ make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(filterNode));
+
+ PrefixId prefixId;
+ OptPhaseManager phaseManager(
+ {OptPhaseManager::OptPhase::MemoSubstitutionPhase,
+ OptPhaseManager::OptPhase::MemoExplorationPhase,
+ OptPhaseManager::OptPhase::MemoImplementationPhase},
+ prefixId,
+ {{{"c1",
+ ScanDefinition{{},
+ {{"index1",
+ {{{make<PathGet>("a", make<PathGet>("b", make<PathIdentity>())),
+ CollationOp::Ascending}},
+ false /*isMultiKey*/}}}}}}},
+ {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests});
+
+ ABT optimized = rootNode;
+ ASSERT_TRUE(phaseManager.optimize(optimized));
+ ASSERT_EQ(4, phaseManager.getMemo().getStats()._physPlanExplorationCount);
+
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | root\n"
+ "| RefBlock: \n"
+ "| Variable [root]\n"
+ "BinaryJoin [joinType: Inner, {rid_0}]\n"
+ "| | Const [true]\n"
+ "| LimitSkip []\n"
+ "| | limitSkip:\n"
+ "| | limit: 1\n"
+ "| | skip: 0\n"
+ "| Seek [ridProjection: rid_0, {'<root>': root}, c1]\n"
+ "| | BindBlock:\n"
+ "| | [root]\n"
+ "| | Source []\n"
+ "| RefBlock: \n"
+ "| Variable [rid_0]\n"
+ "IndexScan [{'<rid>': rid_0}, scanDefName: c1, indexDefName: index1, interval: {[Const "
+ "[1], Const [1]]}]\n"
+ " BindBlock:\n"
+ " [rid_0]\n"
+ " Source []\n",
+ optimized);
+}
+
+TEST(PhysRewriter, FilterIndexing2NonSarg) {
+ using namespace properties;
+
+ ABT scanNode = make<ScanNode>("root", "c1");
+
+ ABT evalNode1 = make<EvaluationNode>(
+ "pa",
+ make<EvalPath>(make<PathGet>("a", make<PathIdentity>()), make<Variable>("root")),
+ std::move(scanNode));
+
+ ABT filterNode1 = make<FilterNode>(
+ make<EvalFilter>(make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(1))),
+ make<Variable>("pa")),
+ std::move(evalNode1));
+
+ // Dependent eval node.
+ ABT evalNode2 = make<EvaluationNode>(
+ "pb",
+ make<EvalPath>(make<PathGet>("b", make<PathIdentity>()), make<Variable>("pa")),
+ std::move(filterNode1));
+
+ // Non-sargable filter.
+ ABT filterNode2 = make<FilterNode>(
+ make<EvalFilter>(
+ make<PathTraverse>(make<PathLambda>(make<LambdaAbstraction>(
+ "var", make<FunctionCall>("someFunction", makeSeq(make<Variable>("var")))))),
+ make<Variable>("pb")),
+ std::move(evalNode2));
+
+
+ ABT rootNode =
+ make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(filterNode2));
+
+ PrefixId prefixId;
+ OptPhaseManager phaseManager(
+ {OptPhaseManager::OptPhase::MemoSubstitutionPhase,
+ OptPhaseManager::OptPhase::MemoExplorationPhase,
+ OptPhaseManager::OptPhase::MemoImplementationPhase},
+ prefixId,
+ {{{"c1",
+ ScanDefinition{
+ {},
+ {{"index1",
+ makeIndexDefinition("a", CollationOp::Ascending, false /*isMultiKey*/)}}}}}},
+ {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests});
+
+ ABT optimized = rootNode;
+ ASSERT_TRUE(phaseManager.optimize(optimized));
+ ASSERT_BETWEEN(15, 20, phaseManager.getMemo().getStats()._physPlanExplorationCount);
+
+ // Demonstrate that the non-sargable evaluation and filter are moved under the NLJ+seek.
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | root\n"
+ "| RefBlock: \n"
+ "| Variable [root]\n"
+ "BinaryJoin [joinType: Inner, {rid_0}]\n"
+ "| | Const [true]\n"
+ "| LimitSkip []\n"
+ "| | limitSkip:\n"
+ "| | limit: 1\n"
+ "| | skip: 0\n"
+ "| Seek [ridProjection: rid_0, {'<root>': root}, c1]\n"
+ "| | BindBlock:\n"
+ "| | [root]\n"
+ "| | Source []\n"
+ "| RefBlock: \n"
+ "| Variable [rid_0]\n"
+ "Filter []\n"
+ "| EvalFilter []\n"
+ "| | Variable [pb]\n"
+ "| PathTraverse []\n"
+ "| PathLambda []\n"
+ "| LambdaAbstraction [var]\n"
+ "| FunctionCall [someFunction]\n"
+ "| Variable [var]\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [pb]\n"
+ "| EvalPath []\n"
+ "| | Variable [pa]\n"
+ "| PathGet [b]\n"
+ "| PathIdentity []\n"
+ "IndexScan [{'<indexKey> 0': pa, '<rid>': rid_0}, scanDefName: c1, indexDefName: index1, "
+ "interval: {[Const [1], Const [1]]}]\n"
+ " BindBlock:\n"
+ " [pa]\n"
+ " Source []\n"
+ " [rid_0]\n"
+ " Source []\n",
+ optimized);
+}
+
+TEST(PhysRewriter, FilterIndexing3) {
+ using namespace properties;
+ PrefixId prefixId;
+
+ ABT scanNode = make<ScanNode>("root", "c1");
+
+ ABT evalNode = make<EvaluationNode>(
+ "pa",
+ make<EvalPath>(make<PathGet>("a", make<PathIdentity>()), make<Variable>("root")),
+ std::move(scanNode));
+
+ ABT filterNode = make<FilterNode>(
+ make<EvalFilter>(make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(1))),
+ make<Variable>("pa")),
+ std::move(evalNode));
+
+ ABT rootNode =
+ make<RootNode>(ProjectionRequirement{ProjectionNameVector{"pa"}}, std::move(filterNode));
+
+ OptPhaseManager phaseManager(
+ {OptPhaseManager::OptPhase::MemoSubstitutionPhase,
+ OptPhaseManager::OptPhase::MemoExplorationPhase,
+ OptPhaseManager::OptPhase::MemoImplementationPhase},
+ prefixId,
+ {{{"c1",
+ ScanDefinition{
+ {},
+ {{"index1",
+ IndexDefinition{{{makeNonMultikeyIndexPath("a"), CollationOp::Ascending},
+ {makeNonMultikeyIndexPath("b"), CollationOp::Ascending}},
+ false /*isMultiKey*/,
+ {DistributionType::Centralized},
+ {}}}}}}}},
+ {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests});
+
+ ABT optimized = std::move(rootNode);
+ ASSERT_TRUE(phaseManager.optimize(optimized));
+ ASSERT_EQ(5, phaseManager.getMemo().getStats()._physPlanExplorationCount);
+
+ // We don't need a Seek if we don't have multi-key paths.
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | pa\n"
+ "| RefBlock: \n"
+ "| Variable [pa]\n"
+ "IndexScan [{'<indexKey> 0': pa}, scanDefName: c1, indexDefName: index1, interval: {[Const "
+ "[1], Const [1]], (-inf, +inf)}]\n"
+ " BindBlock:\n"
+ " [pa]\n"
+ " Source []\n",
+ optimized);
+}
+
+TEST(PhysRewriter, FilterIndexing3MultiKey) {
+ using namespace properties;
+ PrefixId prefixId;
+
+ ABT scanNode = make<ScanNode>("root", "c1");
+
+ ABT evalNode = make<EvaluationNode>(
+ "pa",
+ make<EvalPath>(make<PathGet>("a", make<PathIdentity>()), make<Variable>("root")),
+ std::move(scanNode));
+
+ ABT filterNode = make<FilterNode>(
+ make<EvalFilter>(make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(1))),
+ make<Variable>("pa")),
+ std::move(evalNode));
+
+ ABT rootNode =
+ make<RootNode>(ProjectionRequirement{ProjectionNameVector{"pa"}}, std::move(filterNode));
+
+ OptPhaseManager phaseManager(
+ {OptPhaseManager::OptPhase::MemoSubstitutionPhase,
+ OptPhaseManager::OptPhase::MemoExplorationPhase,
+ OptPhaseManager::OptPhase::MemoImplementationPhase},
+ prefixId,
+ {{{"c1",
+ ScanDefinition{{},
+ {{"index1",
+ IndexDefinition{{{makeIndexPath("a"), CollationOp::Ascending},
+ {makeIndexPath("b"), CollationOp::Ascending}},
+ true /*isMultiKey*/,
+ {DistributionType::Centralized},
+ {}}}}}}}},
+ {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests});
+
+ ABT optimized = std::move(rootNode);
+ ASSERT_TRUE(phaseManager.optimize(optimized));
+ ASSERT_EQ(7, phaseManager.getMemo().getStats()._physPlanExplorationCount);
+
+ // We need a Seek to obtain the value for "a".
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | pa\n"
+ "| RefBlock: \n"
+ "| Variable [pa]\n"
+ "BinaryJoin [joinType: Inner, {rid_0}]\n"
+ "| | Const [true]\n"
+ "| LimitSkip []\n"
+ "| | limitSkip:\n"
+ "| | limit: 1\n"
+ "| | skip: 0\n"
+ "| Seek [ridProjection: rid_0, {'a': pa}, c1]\n"
+ "| | BindBlock:\n"
+ "| | [pa]\n"
+ "| | Source []\n"
+ "| RefBlock: \n"
+ "| Variable [rid_0]\n"
+ "Unique []\n"
+ "| projections: \n"
+ "| rid_0\n"
+ "IndexScan [{'<rid>': rid_0}, scanDefName: c1, indexDefName: index1, interval: {[Const "
+ "[1], Const [1]], (-inf, +inf)}]\n"
+ " BindBlock:\n"
+ " [rid_0]\n"
+ " Source []\n",
+ optimized);
+}
+
+TEST(PhysRewriter, FilterIndexing4) {
+ using namespace properties;
+ PrefixId prefixId;
+
+ ABT scanNode = make<ScanNode>("root", "c1");
+
+ ABT evalNode = make<EvaluationNode>(
+ "pa",
+ make<EvalPath>(make<PathGet>("a", make<PathIdentity>()), make<Variable>("root")),
+ std::move(scanNode));
+
+ ABT filterANode = make<FilterNode>(
+ make<EvalFilter>(make<PathTraverse>(make<PathCompare>(Operations::Lt, Constant::int64(1))),
+ make<Variable>("pa")),
+ std::move(evalNode));
+
+ ABT filterBNode = make<FilterNode>(
+ make<EvalFilter>(
+ make<PathGet>(
+ "b", make<PathTraverse>(make<PathCompare>(Operations::Lt, Constant::int64(1)))),
+ make<Variable>("root")),
+ std::move(filterANode));
+
+ ABT filterCNode = make<FilterNode>(
+ make<EvalFilter>(
+ make<PathGet>(
+ "c", make<PathTraverse>(make<PathCompare>(Operations::Lt, Constant::int64(1)))),
+ make<Variable>("root")),
+ std::move(filterBNode));
+
+ ABT filterDNode = make<FilterNode>(
+ make<EvalFilter>(
+ make<PathGet>(
+ "d", make<PathTraverse>(make<PathCompare>(Operations::Lt, Constant::int64(1)))),
+ make<Variable>("root")),
+ std::move(filterCNode));
+
+ ABT rootNode =
+ make<RootNode>(ProjectionRequirement{ProjectionNameVector{"pa"}}, std::move(filterDNode));
+
+ OptPhaseManager phaseManager(
+ {OptPhaseManager::OptPhase::MemoSubstitutionPhase,
+ OptPhaseManager::OptPhase::MemoExplorationPhase,
+ OptPhaseManager::OptPhase::MemoImplementationPhase},
+ prefixId,
+ {{{"c1",
+ ScanDefinition{
+ {},
+ {{"index1",
+ IndexDefinition{{{makeNonMultikeyIndexPath("a"), CollationOp::Ascending},
+ {makeNonMultikeyIndexPath("b"), CollationOp::Ascending},
+ {makeNonMultikeyIndexPath("c"), CollationOp::Ascending},
+ {makeNonMultikeyIndexPath("d"), CollationOp::Ascending}},
+ false /*isMultiKey*/,
+ {DistributionType::Centralized},
+ {}}}}}}}},
+ {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests});
+
+ ABT optimized = std::move(rootNode);
+
+ // For now, leave only the GroupBy+Union RIDIntersect strategy.
+ phaseManager.getHints()._disableHashJoinRIDIntersect = true;
+
+ ASSERT_TRUE(phaseManager.optimize(optimized));
+ ASSERT_BETWEEN(65, 80, phaseManager.getMemo().getStats()._physPlanExplorationCount);
+
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | pa\n"
+ "| RefBlock: \n"
+ "| Variable [pa]\n"
+ "Filter []\n"
+ "| EvalFilter []\n"
+ "| | Variable [evalTemp_8]\n"
+ "| PathCompare [Lt]\n"
+ "| Const [1]\n"
+ "Filter []\n"
+ "| EvalFilter []\n"
+ "| | Variable [evalTemp_7]\n"
+ "| PathCompare [Lt]\n"
+ "| Const [1]\n"
+ "Filter []\n"
+ "| EvalFilter []\n"
+ "| | Variable [evalTemp_6]\n"
+ "| PathCompare [Lt]\n"
+ "| Const [1]\n"
+ "IndexScan [{'<indexKey> 0': pa, '<indexKey> 1': evalTemp_6, '<indexKey> 2': evalTemp_7, "
+ "'<indexKey> 3': evalTemp_8}, scanDefName: c1, indexDefName: index1, interval: {(-inf, "
+ "Const [1]), (-inf, +inf), (-inf, +inf), (-inf, +inf)}]\n"
+ " BindBlock:\n"
+ " [evalTemp_6]\n"
+ " Source []\n"
+ " [evalTemp_7]\n"
+ " Source []\n"
+ " [evalTemp_8]\n"
+ " Source []\n"
+ " [pa]\n"
+ " Source []\n",
+ optimized);
+}
+
+TEST(PhysRewriter, FilterIndexing5) {
+ using namespace properties;
+ PrefixId prefixId;
+
+ ABT scanNode = make<ScanNode>("root", "c1");
+
+ ABT evalANode = make<EvaluationNode>(
+ "pa",
+ make<EvalPath>(make<PathGet>("a", make<PathIdentity>()), make<Variable>("root")),
+ std::move(scanNode));
+
+ ABT filterANode = make<FilterNode>(
+ make<EvalFilter>(make<PathTraverse>(make<PathCompare>(Operations::Gt, Constant::int64(0))),
+ make<Variable>("pa")),
+ std::move(evalANode));
+
+ ABT evalBNode = make<EvaluationNode>(
+ "pb",
+ make<EvalPath>(make<PathGet>("b", make<PathIdentity>()), make<Variable>("root")),
+ std::move(filterANode));
+
+ ABT filterBNode = make<FilterNode>(
+ make<EvalFilter>(make<PathTraverse>(make<PathCompare>(Operations::Gt, Constant::int64(0))),
+ make<Variable>("pb")),
+ std::move(evalBNode));
+
+ ABT collationNode = make<CollationNode>(CollationRequirement({{"pb", CollationOp::Ascending}}),
+ std::move(filterBNode));
+
+ ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"pa", "pb"}},
+ std::move(collationNode));
+
+ OptPhaseManager phaseManager(
+ {OptPhaseManager::OptPhase::MemoSubstitutionPhase,
+ OptPhaseManager::OptPhase::MemoExplorationPhase,
+ OptPhaseManager::OptPhase::MemoImplementationPhase},
+ prefixId,
+ {{{"c1",
+ ScanDefinition{
+ {},
+ {{"index1",
+ IndexDefinition{{{makeNonMultikeyIndexPath("a"), CollationOp::Ascending},
+ {makeNonMultikeyIndexPath("b"), CollationOp::Ascending}},
+ false /*isMultiKey*/,
+ {DistributionType::Centralized},
+ {}}}}}}}},
+ {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests});
+
+ ABT optimized = std::move(rootNode);
+ ASSERT_TRUE(phaseManager.optimize(optimized));
+ ASSERT_BETWEEN(25, 55, phaseManager.getMemo().getStats()._physPlanExplorationCount);
+
+ // We can cover both fields with the index, but because the interval on "a" is a range,
+ // the index does not deliver "b" in sorted order, so we need a separate sort on "b".
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | pa\n"
+ "| | pb\n"
+ "| RefBlock: \n"
+ "| Variable [pa]\n"
+ "| Variable [pb]\n"
+ "Collation []\n"
+ "| | collation: \n"
+ "| | pb: Ascending\n"
+ "| RefBlock: \n"
+ "| Variable [pb]\n"
+ "Filter []\n"
+ "| EvalFilter []\n"
+ "| | Variable [pb]\n"
+ "| PathCompare [Gt]\n"
+ "| Const [0]\n"
+ "IndexScan [{'<indexKey> 0': pa, '<indexKey> 1': pb}, scanDefName: c1, indexDefName: "
+ "index1, interval: {(Const [0], +inf), (-inf, +inf)}]\n"
+ " BindBlock:\n"
+ " [pa]\n"
+ " Source []\n"
+ " [pb]\n"
+ " Source []\n",
+ optimized);
+}
+
+TEST(PhysRewriter, FilterIndexing6) {
+ using namespace properties;
+ PrefixId prefixId;
+
+ ABT scanNode = make<ScanNode>("root", "c1");
+
+ ABT evalANode = make<EvaluationNode>(
+ "pa",
+ make<EvalPath>(make<PathGet>("a", make<PathIdentity>()), make<Variable>("root")),
+ std::move(scanNode));
+
+ ABT filterANode = make<FilterNode>(
+ make<EvalFilter>(make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(0))),
+ make<Variable>("pa")),
+ std::move(evalANode));
+
+ ABT evalBNode = make<EvaluationNode>(
+ "pb",
+ make<EvalPath>(make<PathGet>("b", make<PathIdentity>()), make<Variable>("root")),
+ std::move(filterANode));
+
+ ABT filterBNode = make<FilterNode>(
+ make<EvalFilter>(make<PathTraverse>(make<PathCompare>(Operations::Gt, Constant::int64(0))),
+ make<Variable>("pb")),
+ std::move(evalBNode));
+
+ ABT collationNode = make<CollationNode>(CollationRequirement({{"pb", CollationOp::Ascending}}),
+ std::move(filterBNode));
+
+ ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"pa", "pb"}},
+ std::move(collationNode));
+
+ OptPhaseManager phaseManager(
+ {OptPhaseManager::OptPhase::MemoSubstitutionPhase,
+ OptPhaseManager::OptPhase::MemoExplorationPhase,
+ OptPhaseManager::OptPhase::MemoImplementationPhase},
+ prefixId,
+ {{{"c1",
+ ScanDefinition{
+ {},
+ {{"index1",
+ IndexDefinition{{{makeNonMultikeyIndexPath("a"), CollationOp::Ascending},
+ {makeNonMultikeyIndexPath("b"), CollationOp::Ascending}},
+ false /*isMultiKey*/,
+ {DistributionType::Centralized},
+ {}}}}}}}},
+ {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests});
+
+ ABT optimized = std::move(rootNode);
+ ASSERT_TRUE(phaseManager.optimize(optimized));
+ ASSERT_BETWEEN(9, 15, phaseManager.getMemo().getStats()._physPlanExplorationCount);
+
+ // We can cover both fields with the index, and because the interval on "a" is a point
+ // equality, the index already delivers "b" in sorted order: no separate sort is needed.
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | pa\n"
+ "| | pb\n"
+ "| RefBlock: \n"
+ "| Variable [pa]\n"
+ "| Variable [pb]\n"
+ "IndexScan [{'<indexKey> 0': pa, '<indexKey> 1': pb}, scanDefName: c1, indexDefName: "
+ "index1, interval: {[Const [0], Const [0]], (Const [0], +inf)}]\n"
+ " BindBlock:\n"
+ " [pa]\n"
+ " Source []\n"
+ " [pb]\n"
+ " Source []\n",
+ optimized);
+}
+
+TEST(PhysRewriter, FilterIndexingStress) {
+ using namespace properties;
+ PrefixId prefixId;
+
+ ABT result = make<ScanNode>("root", "c1");
+
+ static constexpr size_t kFilterCount = 15;
+ // A query with a large number of filters on different fields.
+ for (size_t index = 0; index < kFilterCount; index++) {
+ std::ostringstream os;
+ os << "field" << index;
+
+ result = make<FilterNode>(
+ make<EvalFilter>(make<PathGet>(os.str(),
+ make<PathTraverse>(make<PathCompare>(
+ Operations::Eq, Constant::int64(0)))),
+ make<Variable>("root")),
+ std::move(result));
+ }
+
+ ABT rootNode =
+ make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(result));
+
+ OptPhaseManager phaseManager(
+ {OptPhaseManager::OptPhase::MemoSubstitutionPhase,
+ OptPhaseManager::OptPhase::MemoExplorationPhase,
+ OptPhaseManager::OptPhase::MemoImplementationPhase},
+ prefixId,
+ {{{"c1",
+ ScanDefinition{
+ {},
+ {{"index1",
+ IndexDefinition{{{makeNonMultikeyIndexPath("field0"), CollationOp::Ascending},
+ {makeNonMultikeyIndexPath("field1"), CollationOp::Ascending}},
+ false /*isMultiKey*/,
+ {DistributionType::Centralized},
+ {}}},
+ {"index2",
+ IndexDefinition{{{makeNonMultikeyIndexPath("field2"), CollationOp::Ascending}},
+ false /*isMultiKey*/,
+ {DistributionType::Centralized},
+ {}}},
+ {"index3",
+ IndexDefinition{{{makeNonMultikeyIndexPath("field3"), CollationOp::Ascending},
+ {makeNonMultikeyIndexPath("field4"), CollationOp::Ascending}},
+ false /*isMultiKey*/,
+ {DistributionType::Centralized},
+ {}}}}}}}},
+ {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests});
+
+ ABT optimized = std::move(rootNode);
+ ASSERT_TRUE(phaseManager.optimize(optimized));
+
+ // Without the restriction on SargableNode splitting that this test is tied to, we would
+ // be exploring 2^kFilterCount plans, one for each created group.
+ ASSERT_EQ(51, phaseManager.getMemo().getStats()._physPlanExplorationCount);
+
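+ // The winning plan merge-joins the RIDs of index1 and index3 and evaluates the remaining
+ // predicates as filters, either over the values fetched by the Seek or over the full
+ // document.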
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | root\n"
+ "| RefBlock: \n"
+ "| Variable [root]\n"
+ "Filter []\n"
+ "| EvalFilter []\n"
+ "| | Variable [root]\n"
+ "| PathGet [field14]\n"
+ "| PathTraverse []\n"
+ "| PathCompare [Eq]\n"
+ "| Const [0]\n"
+ "Filter []\n"
+ "| EvalFilter []\n"
+ "| | Variable [root]\n"
+ "| PathGet [field13]\n"
+ "| PathTraverse []\n"
+ "| PathCompare [Eq]\n"
+ "| Const [0]\n"
+ "Filter []\n"
+ "| EvalFilter []\n"
+ "| | Variable [root]\n"
+ "| PathGet [field12]\n"
+ "| PathTraverse []\n"
+ "| PathCompare [Eq]\n"
+ "| Const [0]\n"
+ "Filter []\n"
+ "| EvalFilter []\n"
+ "| | Variable [root]\n"
+ "| PathGet [field11]\n"
+ "| PathTraverse []\n"
+ "| PathCompare [Eq]\n"
+ "| Const [0]\n"
+ "Filter []\n"
+ "| EvalFilter []\n"
+ "| | Variable [root]\n"
+ "| PathGet [field10]\n"
+ "| PathTraverse []\n"
+ "| PathCompare [Eq]\n"
+ "| Const [0]\n"
+ "BinaryJoin [joinType: Inner, {rid_0}]\n"
+ "| | Const [true]\n"
+ "| Filter []\n"
+ "| | EvalFilter []\n"
+ "| | | Variable [evalTemp_17]\n"
+ "| | PathTraverse []\n"
+ "| | PathCompare [Eq]\n"
+ "| | Const [0]\n"
+ "| Filter []\n"
+ "| | EvalFilter []\n"
+ "| | | Variable [evalTemp_16]\n"
+ "| | PathTraverse []\n"
+ "| | PathCompare [Eq]\n"
+ "| | Const [0]\n"
+ "| Filter []\n"
+ "| | EvalFilter []\n"
+ "| | | Variable [evalTemp_15]\n"
+ "| | PathTraverse []\n"
+ "| | PathCompare [Eq]\n"
+ "| | Const [0]\n"
+ "| Filter []\n"
+ "| | EvalFilter []\n"
+ "| | | Variable [evalTemp_14]\n"
+ "| | PathTraverse []\n"
+ "| | PathCompare [Eq]\n"
+ "| | Const [0]\n"
+ "| Filter []\n"
+ "| | EvalFilter []\n"
+ "| | | Variable [evalTemp_13]\n"
+ "| | PathTraverse []\n"
+ "| | PathCompare [Eq]\n"
+ "| | Const [0]\n"
+ "| Filter []\n"
+ "| | EvalFilter []\n"
+ "| | | Variable [evalTemp_12]\n"
+ "| | PathTraverse []\n"
+ "| | PathCompare [Eq]\n"
+ "| | Const [0]\n"
+ "| LimitSkip []\n"
+ "| | limitSkip:\n"
+ "| | limit: 1\n"
+ "| | skip: 0\n"
+ "| Seek [ridProjection: rid_0, {'<root>': root, 'field2': evalTemp_12, 'field5': "
+ "evalTemp_13, 'field6': evalTemp_14, 'field7': evalTemp_15, 'field8': evalTemp_16, "
+ "'field9': evalTemp_17}, c1]\n"
+ "| | BindBlock:\n"
+ "| | [evalTemp_12]\n"
+ "| | Source []\n"
+ "| | [evalTemp_13]\n"
+ "| | Source []\n"
+ "| | [evalTemp_14]\n"
+ "| | Source []\n"
+ "| | [evalTemp_15]\n"
+ "| | Source []\n"
+ "| | [evalTemp_16]\n"
+ "| | Source []\n"
+ "| | [evalTemp_17]\n"
+ "| | Source []\n"
+ "| | [root]\n"
+ "| | Source []\n"
+ "| RefBlock: \n"
+ "| Variable [rid_0]\n"
+ "MergeJoin []\n"
+ "| | | Condition\n"
+ "| | | rid_0 = rid_1\n"
+ "| | Collation\n"
+ "| | Ascending\n"
+ "| Union []\n"
+ "| | BindBlock:\n"
+ "| | [rid_1]\n"
+ "| | Source []\n"
+ "| Evaluation []\n"
+ "| | BindBlock:\n"
+ "| | [rid_1]\n"
+ "| | Variable [rid_0]\n"
+ "| IndexScan [{'<rid>': rid_0}, scanDefName: c1, indexDefName: index3, interval: {[Const "
+ "[0], Const [0]], [Const [0], Const [0]]}]\n"
+ "| BindBlock:\n"
+ "| [rid_0]\n"
+ "| Source []\n"
+ "IndexScan [{'<rid>': rid_0}, scanDefName: c1, indexDefName: index1, interval: {[Const "
+ "[0], Const [0]], [Const [0], Const [0]]}]\n"
+ " BindBlock:\n"
+ " [rid_0]\n"
+ " Source []\n",
+ optimized);
+}
+
+TEST(PhysRewriter, FilterIndexingVariable) {
+ using namespace properties;
+ PrefixId prefixId;
+
+ // In the absence of a full implementation of query parameterization, we pretend here that
+ // we have a function "getQueryParam" which returns a query parameter by its index.
+ const auto getQueryParamFn = [](const size_t index) {
+ return make<FunctionCall>("getQueryParam", makeSeq(Constant::int32(index)));
+ };
+
+ ABT scanNode = make<ScanNode>("root", "c1");
+
+ // Encode a condition using two query parameters (expressed as functions):
+ // "a" > param_0 AND "a" >= param_1 (observe param_1 comparison is inclusive).
+ ABT filterNode = make<FilterNode>(
+ make<EvalFilter>(
+ make<PathGet>("a",
+ make<PathTraverse>(make<PathComposeM>(
+ make<PathCompare>(Operations::Gt, getQueryParamFn(0)),
+ make<PathCompare>(Operations::Gte, getQueryParamFn(1))))),
+ make<Variable>("root")),
+ std::move(scanNode));
+
+ ABT rootNode =
+ make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(filterNode));
+
+ OptPhaseManager phaseManager(
+ {OptPhaseManager::OptPhase::MemoSubstitutionPhase,
+ OptPhaseManager::OptPhase::MemoExplorationPhase,
+ OptPhaseManager::OptPhase::MemoImplementationPhase},
+ prefixId,
+ {{{"c1",
+ ScanDefinition{
+ {},
+ {{"index1",
+ makeIndexDefinition("a", CollationOp::Ascending, false /*isMultiKey*/)}}}}}},
+ {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests});
+
+ ABT optimized = std::move(rootNode);
+ ASSERT_TRUE(phaseManager.optimize(optimized));
+ ASSERT_EQ(4, phaseManager.getMemo().getStats()._physPlanExplorationCount);
+
+ // Observe the unioning of two index scans with complex expressions for the bounds. This
+ // encodes:
+ // (max(param_0, param_1), +inf) U [param_0 > param_1 ? MaxKey : param_1, max(param_0, param_1)]
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | root\n"
+ "| RefBlock: \n"
+ "| Variable [root]\n"
+ "BinaryJoin [joinType: Inner, {rid_0}]\n"
+ "| | Const [true]\n"
+ "| LimitSkip []\n"
+ "| | limitSkip:\n"
+ "| | limit: 1\n"
+ "| | skip: 0\n"
+ "| Seek [ridProjection: rid_0, {'<root>': root}, c1]\n"
+ "| | BindBlock:\n"
+ "| | [root]\n"
+ "| | Source []\n"
+ "| RefBlock: \n"
+ "| Variable [rid_0]\n"
+ "GroupBy []\n"
+ "| | groupings: \n"
+ "| | RefBlock: \n"
+ "| | Variable [rid_0]\n"
+ "| aggregations: \n"
+ "Union []\n"
+ "| | BindBlock:\n"
+ "| | [rid_0]\n"
+ "| | Source []\n"
+ "| IndexScan [{'<rid>': rid_0}, scanDefName: c1, indexDefName: index1, interval: {(If [] "
+ "BinaryOp [Gte] FunctionCall [getQueryParam] Const [0] FunctionCall [getQueryParam] Const "
+ "[1] FunctionCall [getQueryParam] Const [0] FunctionCall [getQueryParam] Const [1], "
+ "+inf)}]\n"
+ "| BindBlock:\n"
+ "| [rid_0]\n"
+ "| Source []\n"
+ "IndexScan [{'<rid>': rid_0}, scanDefName: c1, indexDefName: index1, interval: {[If [] "
+ "BinaryOp [Gte] FunctionCall [getQueryParam] Const [0] FunctionCall [getQueryParam] Const "
+ "[1] Const [maxKey] FunctionCall [getQueryParam] Const [1], If [] BinaryOp [Gte] "
+ "FunctionCall [getQueryParam] Const [0] FunctionCall [getQueryParam] Const [1] "
+ "FunctionCall [getQueryParam] Const [0] FunctionCall [getQueryParam] Const [1]]}]\n"
+ " BindBlock:\n"
+ " [rid_0]\n"
+ " Source []\n",
+ optimized);
+}
+
+TEST(PhysRewriter, FilterReorder) {
+ using namespace properties;
+ PrefixId prefixId;
+
+ ABT result = make<ScanNode>("root", "c1");
+
+ PartialSchemaSelHints hints;
+ static constexpr size_t kFilterCount = 5;
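+ // Hint each "field_<i>" with selectivity 0.1 * (kFilterCount - i): filters created later
+ // in the loop are more selective than earlier ones.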
+ for (size_t i = 0; i < kFilterCount; i++) {
+ ProjectionName projName = prefixId.getNextId("field");
+ hints.emplace(
+ PartialSchemaKey{"root",
+ make<PathGet>(projName, make<PathTraverse>(make<PathIdentity>()))},
+ 0.1 * (kFilterCount - i));
+ result = make<FilterNode>(
+ make<EvalFilter>(make<PathGet>(std::move(projName),
+ make<PathTraverse>(make<PathCompare>(
+ Operations::Eq, Constant::int64(i)))),
+ make<Variable>("root")),
+ std::move(result));
+ }
+
+ ABT rootNode =
+ make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(result));
+
+ OptPhaseManager phaseManager(
+ {OptPhaseManager::OptPhase::MemoSubstitutionPhase,
+ OptPhaseManager::OptPhase::MemoExplorationPhase,
+ OptPhaseManager::OptPhase::MemoImplementationPhase},
+ prefixId,
+ false /*requireRID*/,
+ {{{"c1", ScanDefinition{{}, {}}}}},
+ std::make_unique<HintedCE>(std::move(hints)),
+ std::make_unique<DefaultCosting>(),
+ {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests});
+
+ ABT optimized = std::move(rootNode);
+ ASSERT_TRUE(phaseManager.optimize(optimized));
+ ASSERT_EQ(2, phaseManager.getMemo().getStats()._physPlanExplorationCount);
+
+ // Observe that the filters are ordered from most selective (lowest selectivity) to least
+ // selective (highest selectivity).
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | root\n"
+ "| RefBlock: \n"
+ "| Variable [root]\n"
+ "Filter []\n"
+ "| EvalFilter []\n"
+ "| | Variable [evalTemp_0]\n"
+ "| PathTraverse []\n"
+ "| PathCompare [Eq]\n"
+ "| Const [0]\n"
+ "Filter []\n"
+ "| EvalFilter []\n"
+ "| | Variable [evalTemp_1]\n"
+ "| PathTraverse []\n"
+ "| PathCompare [Eq]\n"
+ "| Const [1]\n"
+ "Filter []\n"
+ "| EvalFilter []\n"
+ "| | Variable [evalTemp_2]\n"
+ "| PathTraverse []\n"
+ "| PathCompare [Eq]\n"
+ "| Const [2]\n"
+ "Filter []\n"
+ "| EvalFilter []\n"
+ "| | Variable [evalTemp_3]\n"
+ "| PathTraverse []\n"
+ "| PathCompare [Eq]\n"
+ "| Const [3]\n"
+ "Filter []\n"
+ "| EvalFilter []\n"
+ "| | Variable [evalTemp_4]\n"
+ "| PathTraverse []\n"
+ "| PathCompare [Eq]\n"
+ "| Const [4]\n"
+ "PhysicalScan [{'<root>': root, 'field_0': evalTemp_0, 'field_1': evalTemp_1, "
+ "'field_2': "
+ "evalTemp_2, 'field_3': evalTemp_3, 'field_4': evalTemp_4}, c1]\n"
+ " BindBlock:\n"
+ " [evalTemp_0]\n"
+ " Source []\n"
+ " [evalTemp_1]\n"
+ " Source []\n"
+ " [evalTemp_2]\n"
+ " Source []\n"
+ " [evalTemp_3]\n"
+ " Source []\n"
+ " [evalTemp_4]\n"
+ " Source []\n"
+ " [root]\n"
+ " Source []\n",
+ optimized);
+}
+
+TEST(PhysRewriter, CoveredScan) {
+ using namespace properties;
+ PrefixId prefixId;
+
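+ // Only "pa" (field "a") is required by the root, so the plan can be fully covered by the
+ // index on "a" without fetching documents.
+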
+ ABT scanNode = make<ScanNode>("root", "c1");
+
+ ABT evalNode1 = make<EvaluationNode>(
+ "pa",
+ make<EvalPath>(make<PathGet>("a", make<PathIdentity>()), make<Variable>("root")),
+ std::move(scanNode));
+
+ ABT rootNode =
+ make<RootNode>(ProjectionRequirement{ProjectionNameVector{"pa"}}, std::move(evalNode1));
+
+ OptPhaseManager phaseManager(
+ {OptPhaseManager::OptPhase::MemoSubstitutionPhase,
+ OptPhaseManager::OptPhase::MemoExplorationPhase,
+ OptPhaseManager::OptPhase::MemoImplementationPhase},
+ prefixId,
+ {{{"c1",
+ ScanDefinition{
+ {},
+ {{"index1",
+ makeIndexDefinition("a", CollationOp::Ascending, false /*isMultiKey*/)}}}}}},
+ {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests});
+
+ ABT optimized = std::move(rootNode);
+ ASSERT_TRUE(phaseManager.optimize(optimized));
+ ASSERT_EQ(4, phaseManager.getMemo().getStats()._physPlanExplorationCount);
+
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | pa\n"
+ "| RefBlock: \n"
+ "| Variable [pa]\n"
+ "IndexScan [{'<indexKey> 0': pa}, scanDefName: c1, indexDefName: index1, interval: "
+ "{(-inf, "
+ "+inf)}]\n"
+ " BindBlock:\n"
+ " [pa]\n"
+ " Source []\n",
+ optimized);
+}
+
+TEST(PhysRewriter, EvalIndexing) {
+ using namespace properties;
+ PrefixId prefixId;
+
+ ABT scanNode = make<ScanNode>("root", "c1");
+
+ ABT evalNode = make<EvaluationNode>(
+ "pa",
+ make<EvalPath>(make<PathGet>("a", make<PathIdentity>()), make<Variable>("root")),
+ std::move(scanNode));
+
+ ABT filterNode =
+ make<FilterNode>(make<EvalFilter>(make<PathCompare>(Operations::Gt, Constant::int64(1)),
+ make<Variable>("pa")),
+ std::move(evalNode));
+
+ ABT collationNode = make<CollationNode>(CollationRequirement({{"pa", CollationOp::Ascending}}),
+ std::move(filterNode));
+
+ ABT rootNode =
+ make<RootNode>(ProjectionRequirement{ProjectionNameVector{"pa"}}, std::move(collationNode));
+
+ {
+ OptPhaseManager phaseManager(
+ {OptPhaseManager::OptPhase::MemoSubstitutionPhase,
+ OptPhaseManager::OptPhase::MemoExplorationPhase,
+ OptPhaseManager::OptPhase::MemoImplementationPhase},
+ prefixId,
+ {{{"c1",
+ ScanDefinition{
+ {},
+ {{"index1",
+ makeIndexDefinition("a", CollationOp::Ascending, false /*isMultiKey*/)}}}}}},
+ {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests});
+
+ ABT optimized = rootNode;
+ ASSERT_TRUE(phaseManager.optimize(optimized));
+ ASSERT_BETWEEN(5, 10, phaseManager.getMemo().getStats()._physPlanExplorationCount);
+
+ // Should not need a collation node.
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | pa\n"
+ "| RefBlock: \n"
+ "| Variable [pa]\n"
+ "IndexScan [{'<indexKey> 0': pa}, scanDefName: c1, indexDefName: index1, "
+ "interval: {(Const [1], +inf)}]\n"
+ " BindBlock:\n"
+ " [pa]\n"
+ " Source []\n",
+ optimized);
+ }
+
+ {
+ // The index and the collation node have incompatible collation ops.
+ OptPhaseManager phaseManager(
+ {OptPhaseManager::OptPhase::MemoSubstitutionPhase,
+ OptPhaseManager::OptPhase::MemoExplorationPhase,
+ OptPhaseManager::OptPhase::MemoImplementationPhase},
+ prefixId,
+ {{{"c1",
+ ScanDefinition{
+ {},
+ {{"index1",
+ makeIndexDefinition("a", CollationOp::Clustered, false /*isMultiKey*/)}}}}}},
+ {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests});
+
+ ABT optimized = rootNode;
+ ASSERT_TRUE(phaseManager.optimize(optimized));
+ ASSERT_EQ(10, phaseManager.getMemo().getStats()._physPlanExplorationCount);
+
+ // The index does not have the right collation, so now we need a collation node.
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | pa\n"
+ "| RefBlock: \n"
+ "| Variable [pa]\n"
+ "Collation []\n"
+ "| | collation: \n"
+ "| | pa: Ascending\n"
+ "| RefBlock: \n"
+ "| Variable [pa]\n"
+ "IndexScan [{'<indexKey> 0': pa}, scanDefName: c1, indexDefName: index1, "
+ "interval: {(Const [1], +inf)}]\n"
+ " BindBlock:\n"
+ " [pa]\n"
+ " Source []\n",
+ optimized);
+ }
+}
+
+TEST(PhysRewriter, EvalIndexing1) {
+ using namespace properties;
+ PrefixId prefixId;
+
+ ABT scanNode = make<ScanNode>("root", "c1");
+
+ ABT evalNode = make<EvaluationNode>(
+ "pa",
+ make<EvalPath>(make<PathGet>("a", make<PathIdentity>()), make<Variable>("root")),
+ std::move(scanNode));
+
+ ABT filterNode =
+ make<FilterNode>(make<EvalFilter>(make<PathCompare>(Operations::Eq, Constant::int64(1)),
+ make<Variable>("pa")),
+ std::move(evalNode));
+
+ ABT collationNode = make<CollationNode>(CollationRequirement({{"pa", CollationOp::Ascending}}),
+ std::move(filterNode));
+
+ ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}},
+ std::move(collationNode));
+
+ OptPhaseManager phaseManager(
+ {OptPhaseManager::OptPhase::MemoSubstitutionPhase,
+ OptPhaseManager::OptPhase::MemoExplorationPhase,
+ OptPhaseManager::OptPhase::MemoImplementationPhase},
+ prefixId,
+ {{{"c1",
+ ScanDefinition{
+ {},
+ {{"index1",
+ makeIndexDefinition("a", CollationOp::Ascending, false /*isMultiKey*/)}}}}}},
+ {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests});
+
+ ABT optimized = rootNode;
+ ASSERT_TRUE(phaseManager.optimize(optimized));
+ ASSERT_EQ(8, phaseManager.getMemo().getStats()._physPlanExplorationCount);
+
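+ // The root requires the full document ("root"), so the index scan on "a" is joined with
+ // a Seek to fetch each matching document.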
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | root\n"
+ "| RefBlock: \n"
+ "| Variable [root]\n"
+ "BinaryJoin [joinType: Inner, {rid_0}]\n"
+ "| | Const [true]\n"
+ "| LimitSkip []\n"
+ "| | limitSkip:\n"
+ "| | limit: 1\n"
+ "| | skip: 0\n"
+ "| Seek [ridProjection: rid_0, {'<root>': root}, c1]\n"
+ "| | BindBlock:\n"
+ "| | [root]\n"
+ "| | Source []\n"
+ "| RefBlock: \n"
+ "| Variable [rid_0]\n"
+ "IndexScan [{'<rid>': rid_0}, scanDefName: c1, indexDefName: index1, interval: {[Const "
+ "[1], Const [1]]}]\n"
+ " BindBlock:\n"
+ " [rid_0]\n"
+ " Source []\n",
+ optimized);
+}
+
+TEST(PhysRewriter, EvalIndexing2) {
+ using namespace properties;
+ PrefixId prefixId;
+
+ ABT scanNode = make<ScanNode>("root", "c1");
+
+ ABT evalNode1 = make<EvaluationNode>(
+ "pa1",
+ make<EvalPath>(make<PathGet>("a", make<PathIdentity>()), make<Variable>("root")),
+ std::move(scanNode));
+
+ ABT evalNode2 = make<EvaluationNode>(
+ "pa2",
+ make<EvalPath>(make<PathField>("a", make<PathConstant>(make<Variable>("pa1"))),
+ Constant::int32(0)),
+ std::move(evalNode1));
+
+ ABT evalNode3 = make<EvaluationNode>(
+ "pa3",
+ make<EvalPath>(make<PathGet>("a", make<PathIdentity>()), make<Variable>("pa2")),
+ std::move(evalNode2));
+
+ ABT collationNode = make<CollationNode>(CollationRequirement({{"pa3", CollationOp::Ascending}}),
+ std::move(evalNode3));
+
+ ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"pa2"}},
+ std::move(collationNode));
+
+ OptPhaseManager phaseManager(
+ {OptPhaseManager::OptPhase::ConstEvalPre,
+ OptPhaseManager::OptPhase::PathFuse,
+ OptPhaseManager::OptPhase::MemoSubstitutionPhase,
+ OptPhaseManager::OptPhase::MemoExplorationPhase,
+ OptPhaseManager::OptPhase::MemoImplementationPhase},
+ prefixId,
+ {{{"c1",
+ ScanDefinition{
+ {},
+ {{"index1",
+ makeIndexDefinition("a", CollationOp::Ascending, false /*isMultiKey*/)}}}}}},
+ {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests});
+
+ ABT optimized = rootNode;
+ ASSERT_TRUE(phaseManager.optimize(optimized));
+ ASSERT_BETWEEN(10, 20, phaseManager.getMemo().getStats()._physPlanExplorationCount);
+
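+ // Path fusion rewrites "pa3" (Get "a" over the document {a: pa1}) into "pa1", which lets
+ // the collation requirement on "pa3" be satisfied by the index on "a".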
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | pa2\n"
+ "| RefBlock: \n"
+ "| Variable [pa2]\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [pa3]\n"
+ "| Variable [pa1]\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [pa2]\n"
+ "| EvalPath []\n"
+ "| | Const [0]\n"
+ "| PathField [a]\n"
+ "| PathConstant []\n"
+ "| Variable [pa1]\n"
+ "IndexScan [{'<indexKey> 0': pa1}, scanDefName: c1, indexDefName: index1, interval: "
+ "{(-inf, +inf)}]\n"
+ " BindBlock:\n"
+ " [pa1]\n"
+ " Source []\n",
+ optimized);
+}
+
+TEST(PhysRewriter, MultiKeyIndex) {
+ using namespace properties;
+ PrefixId prefixId;
+
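+ // Conjunction over "a" and "b" with a separate single-field index on each. The two
+ // sub-cases below exercise both RID-intersection strategies: GroupBy+Union and HashJoin.
+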
+ ABT scanNode = make<ScanNode>("root", "c1");
+
+ ABT evalANode = make<EvaluationNode>(
+ "pa",
+ make<EvalPath>(make<PathGet>("a", make<PathIdentity>()), make<Variable>("root")),
+ std::move(scanNode));
+
+ ABT filterANode =
+ make<FilterNode>(make<EvalFilter>(make<PathCompare>(Operations::Eq, Constant::int64(1)),
+ make<Variable>("pa")),
+ std::move(evalANode));
+
+ ABT evalBNode = make<EvaluationNode>(
+ "pb",
+ make<EvalPath>(make<PathGet>("b", make<PathIdentity>()), make<Variable>("root")),
+ std::move(filterANode));
+
+ ABT filterBNode =
+ make<FilterNode>(make<EvalFilter>(make<PathCompare>(Operations::Gt, Constant::int64(2)),
+ make<Variable>("pb")),
+ std::move(evalBNode));
+
+ ABT collationNode = make<CollationNode>(
+ CollationRequirement({{"pa", CollationOp::Ascending}, {"pb", CollationOp::Ascending}}),
+ std::move(filterBNode));
+
+ ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}},
+ std::move(collationNode));
+
+ OptPhaseManager phaseManager(
+ {OptPhaseManager::OptPhase::MemoSubstitutionPhase,
+ OptPhaseManager::OptPhase::MemoExplorationPhase,
+ OptPhaseManager::OptPhase::MemoImplementationPhase},
+ prefixId,
+ {{{"c1",
+ ScanDefinition{
+ {},
+ {{"index1", makeIndexDefinition("a", CollationOp::Ascending, false /*isMultiKey*/)},
+ {"index2",
+ makeIndexDefinition("b", CollationOp::Descending, false /*isMultiKey*/)}}}}}},
+ {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests});
+
+ {
+ ABT optimized = rootNode;
+
+ // Test RIDIntersect using only GroupBy+Union.
+ phaseManager.getHints()._disableHashJoinRIDIntersect = true;
+
+ ASSERT_TRUE(phaseManager.optimize(optimized));
+ ASSERT_BETWEEN(15, 25, phaseManager.getMemo().getStats()._physPlanExplorationCount);
+
+ // GroupBy+Union cannot propagate the collation requirement, so we need a separate
+ // CollationNode.
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | root\n"
+ "| RefBlock: \n"
+ "| Variable [root]\n"
+ "Collation []\n"
+ "| | collation: \n"
+ "| | pa: Ascending\n"
+ "| | pb: Ascending\n"
+ "| RefBlock: \n"
+ "| Variable [pa]\n"
+ "| Variable [pb]\n"
+ "BinaryJoin [joinType: Inner, {rid_0}]\n"
+ "| | Const [true]\n"
+ "| LimitSkip []\n"
+ "| | limitSkip:\n"
+ "| | limit: 1\n"
+ "| | skip: 0\n"
+ "| Seek [ridProjection: rid_0, {'<root>': root}, c1]\n"
+ "| | BindBlock:\n"
+ "| | [root]\n"
+ "| | Source []\n"
+ "| RefBlock: \n"
+ "| Variable [rid_0]\n"
+ "Filter []\n"
+ "| EvalFilter []\n"
+ "| | FunctionCall [getArraySize]\n"
+ "| | Variable [sides_0]\n"
+ "| PathCompare [Eq]\n"
+ "| Const [2]\n"
+ "GroupBy []\n"
+ "| | groupings: \n"
+ "| | RefBlock: \n"
+ "| | Variable [rid_0]\n"
+ "| aggregations: \n"
+ "| [pa]\n"
+ "| FunctionCall [$max]\n"
+ "| Variable [unionTemp_0]\n"
+ "| [pb]\n"
+ "| FunctionCall [$max]\n"
+ "| Variable [unionTemp_1]\n"
+ "| [sides_0]\n"
+ "| FunctionCall [$addToSet]\n"
+ "| Variable [sideId_0]\n"
+ "Union []\n"
+ "| | BindBlock:\n"
+ "| | [rid_0]\n"
+ "| | Source []\n"
+ "| | [sideId_0]\n"
+ "| | Source []\n"
+ "| | [unionTemp_0]\n"
+ "| | Source []\n"
+ "| | [unionTemp_1]\n"
+ "| | Source []\n"
+ "| Evaluation []\n"
+ "| | BindBlock:\n"
+ "| | [unionTemp_1]\n"
+ "| | Variable [pb]\n"
+ "| Evaluation []\n"
+ "| | BindBlock:\n"
+ "| | [unionTemp_0]\n"
+ "| | Const [Nothing]\n"
+ "| Evaluation []\n"
+ "| | BindBlock:\n"
+ "| | [sideId_0]\n"
+ "| | Const [1]\n"
+ "| IndexScan [{'<indexKey> 0': pb, '<rid>': rid_0}, scanDefName: c1, "
+ "indexDefName: "
+ "index2, interval: {(Const [2], +inf)}]\n"
+ "| BindBlock:\n"
+ "| [pb]\n"
+ "| Source []\n"
+ "| [rid_0]\n"
+ "| Source []\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [unionTemp_1]\n"
+ "| Const [Nothing]\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [unionTemp_0]\n"
+ "| Variable [pa]\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [sideId_0]\n"
+ "| Const [0]\n"
+ "IndexScan [{'<indexKey> 0': pa, '<rid>': rid_0}, scanDefName: c1, indexDefName: "
+ "index1, interval: {[Const [1], Const [1]]}]\n"
+ " BindBlock:\n"
+ " [pa]\n"
+ " Source []\n"
+ " [rid_0]\n"
+ " Source []\n",
+ optimized);
+ }
+
+ {
+ ABT optimized = rootNode;
+
+ phaseManager.getHints()._disableGroupByAndUnionRIDIntersect = false;
+ phaseManager.getHints()._disableHashJoinRIDIntersect = false;
+
+ ASSERT_TRUE(phaseManager.optimize(optimized));
+ ASSERT_BETWEEN(15, 25, phaseManager.getMemo().getStats()._physPlanExplorationCount);
+
+ // Index2 will be used in the reverse direction.
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | root\n"
+ "| RefBlock: \n"
+ "| Variable [root]\n"
+ "BinaryJoin [joinType: Inner, {rid_0}]\n"
+ "| | Const [true]\n"
+ "| LimitSkip []\n"
+ "| | limitSkip:\n"
+ "| | limit: 1\n"
+ "| | skip: 0\n"
+ "| Seek [ridProjection: rid_0, {'<root>': root}, c1]\n"
+ "| | BindBlock:\n"
+ "| | [root]\n"
+ "| | Source []\n"
+ "| RefBlock: \n"
+ "| Variable [rid_0]\n"
+ "HashJoin [joinType: Inner]\n"
+ "| | Condition\n"
+ "| | rid_0 = rid_1\n"
+ "| Union []\n"
+ "| | BindBlock:\n"
+ "| | [rid_1]\n"
+ "| | Source []\n"
+ "| Evaluation []\n"
+ "| | BindBlock:\n"
+ "| | [rid_1]\n"
+ "| | Variable [rid_0]\n"
+ "| IndexScan [{'<rid>': rid_0}, scanDefName: c1, indexDefName: index2, interval: "
+ "{(Const [2], +inf)}, reversed]\n"
+ "| BindBlock:\n"
+ "| [rid_0]\n"
+ "| Source []\n"
+ "IndexScan [{'<rid>': rid_0}, scanDefName: c1, indexDefName: index1, interval: "
+ "{[Const [1], Const [1]]}]\n"
+ " BindBlock:\n"
+ " [rid_0]\n"
+ " Source []\n",
+ optimized);
+ }
+}
+
+TEST(PhysRewriter, CompoundIndex1) {
+ using namespace properties;
+ PrefixId prefixId;
+
+ ABT scanNode = make<ScanNode>("root", "c1");
+
+ ABT evalANode = make<EvaluationNode>(
+ "pa",
+ make<EvalPath>(make<PathGet>("a", make<PathIdentity>()), make<Variable>("root")),
+ std::move(scanNode));
+
+ ABT filterANode = make<FilterNode>(
+ make<EvalFilter>(make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(1))),
+ make<Variable>("pa")),
+ std::move(evalANode));
+
+ ABT evalBNode = make<EvaluationNode>(
+ "pb",
+ make<EvalPath>(make<PathGet>("b", make<PathIdentity>()), make<Variable>("root")),
+ std::move(filterANode));
+
+ ABT filterBNode = make<FilterNode>(
+ make<EvalFilter>(make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(2))),
+ make<Variable>("pb")),
+ std::move(evalBNode));
+
+ ABT filterCNode = make<FilterNode>(
+ make<EvalFilter>(
+ make<PathGet>(
+ "c", make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(3)))),
+ make<Variable>("root")),
+ std::move(filterBNode));
+
+ ABT filterDNode = make<FilterNode>(
+ make<EvalFilter>(
+ make<PathGet>(
+ "d", make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(4)))),
+ make<Variable>("root")),
+ std::move(filterCNode));
+
+ ABT rootNode =
+ make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(filterDNode));
+
+ OptPhaseManager phaseManager(
+ {OptPhaseManager::OptPhase::MemoSubstitutionPhase,
+ OptPhaseManager::OptPhase::MemoExplorationPhase,
+ OptPhaseManager::OptPhase::MemoImplementationPhase},
+ prefixId,
+ {{{"c1",
+ ScanDefinition{
+ {},
+ {{"index1",
+ IndexDefinition{{{makeNonMultikeyIndexPath("a"), CollationOp::Ascending},
+ {makeNonMultikeyIndexPath("c"), CollationOp::Descending}},
+ false /*isMultiKey*/}},
+ {"index2",
+ IndexDefinition{{{makeNonMultikeyIndexPath("b"), CollationOp::Ascending},
+ {makeNonMultikeyIndexPath("d"), CollationOp::Ascending}},
+ false /*isMultiKey*/}}}}}}},
+ {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests});
+
+ ABT optimized = rootNode;
+ ASSERT_TRUE(phaseManager.optimize(optimized));
+ ASSERT_BETWEEN(60, 110, phaseManager.getMemo().getStats()._physPlanExplorationCount);
+
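+ // Expect an intersection of the two compound indexes: index1 answers the predicates on
+ // ("a", "c") and index2 the predicates on ("b", "d"), with the RIDs merge-joined.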
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | root\n"
+ "| RefBlock: \n"
+ "| Variable [root]\n"
+ "BinaryJoin [joinType: Inner, {rid_0}]\n"
+ "| | Const [true]\n"
+ "| LimitSkip []\n"
+ "| | limitSkip:\n"
+ "| | limit: 1\n"
+ "| | skip: 0\n"
+ "| Seek [ridProjection: rid_0, {'<root>': root}, c1]\n"
+ "| | BindBlock:\n"
+ "| | [root]\n"
+ "| | Source []\n"
+ "| RefBlock: \n"
+ "| Variable [rid_0]\n"
+ "MergeJoin []\n"
+ "| | | Condition\n"
+ "| | | rid_0 = rid_1\n"
+ "| | Collation\n"
+ "| | Ascending\n"
+ "| Union []\n"
+ "| | BindBlock:\n"
+ "| | [rid_1]\n"
+ "| | Source []\n"
+ "| Evaluation []\n"
+ "| | BindBlock:\n"
+ "| | [rid_1]\n"
+ "| | Variable [rid_0]\n"
+ "| IndexScan [{'<rid>': rid_0}, scanDefName: c1, indexDefName: index2, interval: {[Const "
+ "[2], Const [2]], [Const [4], Const [4]]}]\n"
+ "| BindBlock:\n"
+ "| [rid_0]\n"
+ "| Source []\n"
+ "IndexScan [{'<rid>': rid_0}, scanDefName: c1, indexDefName: index1, interval: {[Const "
+ "[1], Const [1]], [Const [3], Const [3]]}]\n"
+ " BindBlock:\n"
+ " [rid_0]\n"
+ " Source []\n",
+ optimized);
+}
+
+TEST(PhysRewriter, CompoundIndex2) {
+ using namespace properties;
+ PrefixId prefixId;
+
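+ // Same query shape as CompoundIndex1, with an additional collation requirement on
+ // ("pa", "pb").
+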
+ ABT scanNode = make<ScanNode>("root", "c1");
+
+ ABT evalANode = make<EvaluationNode>(
+ "pa",
+ make<EvalPath>(make<PathGet>("a", make<PathIdentity>()), make<Variable>("root")),
+ std::move(scanNode));
+
+ ABT filterANode = make<FilterNode>(
+ make<EvalFilter>(make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(1))),
+ make<Variable>("pa")),
+ std::move(evalANode));
+
+ ABT evalBNode = make<EvaluationNode>(
+ "pb",
+ make<EvalPath>(make<PathGet>("b", make<PathIdentity>()), make<Variable>("root")),
+ std::move(filterANode));
+
+ ABT filterBNode = make<FilterNode>(
+ make<EvalFilter>(make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(2))),
+ make<Variable>("pb")),
+ std::move(evalBNode));
+
+ ABT filterCNode = make<FilterNode>(
+ make<EvalFilter>(
+ make<PathGet>(
+ "c", make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(3)))),
+ make<Variable>("root")),
+ std::move(filterBNode));
+
+ ABT filterDNode = make<FilterNode>(
+ make<EvalFilter>(
+ make<PathGet>(
+ "d", make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(4)))),
+ make<Variable>("root")),
+ std::move(filterCNode));
+
+ ABT collationNode = make<CollationNode>(
+ CollationRequirement({{"pa", CollationOp::Ascending}, {"pb", CollationOp::Ascending}}),
+ std::move(filterDNode));
+
+ ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}},
+ std::move(collationNode));
+
+ OptPhaseManager phaseManager(
+ {OptPhaseManager::OptPhase::MemoSubstitutionPhase,
+ OptPhaseManager::OptPhase::MemoExplorationPhase,
+ OptPhaseManager::OptPhase::MemoImplementationPhase},
+ prefixId,
+ {{{"c1",
+ ScanDefinition{
+ {},
+ {
+ {"index1",
+ IndexDefinition{{{makeNonMultikeyIndexPath("a"), CollationOp::Ascending},
+ {makeNonMultikeyIndexPath("c"), CollationOp::Descending}},
+ false /*isMultiKey*/}},
+ {"index2",
+ IndexDefinition{{{makeNonMultikeyIndexPath("b"), CollationOp::Ascending},
+ {makeNonMultikeyIndexPath("d"), CollationOp::Ascending}},
+ false /*isMultiKey*/}},
+ }}}}},
+ {true /*debugMode*/, 3 /*debugLevel*/, DebugInfo::kIterationLimitForTests});
+
+ ABT optimized = rootNode;
+
+ ASSERT_TRUE(phaseManager.optimize(optimized));
+ ASSERT_BETWEEN(100, 170, phaseManager.getMemo().getStats()._physPlanExplorationCount);
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | root\n"
+ "| RefBlock: \n"
+ "| Variable [root]\n"
+ "BinaryJoin [joinType: Inner, {rid_0}]\n"
+ "| | Const [true]\n"
+ "| LimitSkip []\n"
+ "| | limitSkip:\n"
+ "| | limit: 1\n"
+ "| | skip: 0\n"
+ "| Seek [ridProjection: rid_0, {'<root>': root}, c1]\n"
+ "| | BindBlock:\n"
+ "| | [root]\n"
+ "| | Source []\n"
+ "| RefBlock: \n"
+ "| Variable [rid_0]\n"
+ "MergeJoin []\n"
+ "| | | Condition\n"
+ "| | | rid_0 = rid_1\n"
+ "| | Collation\n"
+ "| | Ascending\n"
+ "| Union []\n"
+ "| | BindBlock:\n"
+ "| | [rid_1]\n"
+ "| | Source []\n"
+ "| Evaluation []\n"
+ "| | BindBlock:\n"
+ "| | [rid_1]\n"
+ "| | Variable [rid_0]\n"
+ "| IndexScan [{'<rid>': rid_0}, scanDefName: c1, indexDefName: index2, interval: "
+ "{[Const "
+ "[2], Const [2]], [Const [4], Const [4]]}]\n"
+ "| BindBlock:\n"
+ "| [rid_0]\n"
+ "| Source []\n"
+ "IndexScan [{'<rid>': rid_0}, scanDefName: c1, indexDefName: index1, interval: {[Const "
+ "[1], Const [1]], [Const [3], Const [3]]}]\n"
+ " BindBlock:\n"
+ " [rid_0]\n"
+ " Source []\n",
+ optimized);
+}
+
+TEST(PhysRewriter, CompoundIndex3) {
+ using namespace properties;
+ PrefixId prefixId;
+
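+ // Same query and collation requirement as CompoundIndex2, but with multikey indexes:
+ // "pa" and "pb" must be obtained from the Seek, and an explicit Collation node is needed.
+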
+ ABT scanNode = make<ScanNode>("root", "c1");
+
+ ABT evalANode = make<EvaluationNode>(
+ "pa",
+ make<EvalPath>(make<PathGet>("a", make<PathIdentity>()), make<Variable>("root")),
+ std::move(scanNode));
+
+ ABT filterANode = make<FilterNode>(
+ make<EvalFilter>(make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(1))),
+ make<Variable>("pa")),
+ std::move(evalANode));
+
+ ABT evalBNode = make<EvaluationNode>(
+ "pb",
+ make<EvalPath>(make<PathGet>("b", make<PathIdentity>()), make<Variable>("root")),
+ std::move(filterANode));
+
+ ABT filterBNode = make<FilterNode>(
+ make<EvalFilter>(make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(2))),
+ make<Variable>("pb")),
+ std::move(evalBNode));
+
+ ABT filterCNode = make<FilterNode>(
+ make<EvalFilter>(
+ make<PathGet>(
+ "c", make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(3)))),
+ make<Variable>("root")),
+ std::move(filterBNode));
+
+ ABT filterDNode = make<FilterNode>(
+ make<EvalFilter>(
+ make<PathGet>(
+ "d", make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(4)))),
+ make<Variable>("root")),
+ std::move(filterCNode));
+
+ ABT collationNode = make<CollationNode>(
+ CollationRequirement({{"pa", CollationOp::Ascending}, {"pb", CollationOp::Ascending}}),
+ std::move(filterDNode));
+
+ ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}},
+ std::move(collationNode));
+
+ OptPhaseManager phaseManager(
+ {OptPhaseManager::OptPhase::MemoSubstitutionPhase,
+ OptPhaseManager::OptPhase::MemoExplorationPhase,
+ OptPhaseManager::OptPhase::MemoImplementationPhase},
+ prefixId,
+ {{{"c1",
+ ScanDefinition{{},
+ {{"index1",
+ IndexDefinition{{{makeIndexPath("a"), CollationOp::Ascending},
+ {makeIndexPath("c"), CollationOp::Descending}},
+ true /*isMultiKey*/}},
+ {"index2",
+ IndexDefinition{{{makeIndexPath("b"), CollationOp::Ascending},
+ {makeIndexPath("d"), CollationOp::Ascending}},
+ true /*isMultiKey*/}}}}}}},
+ {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests});
+
+ ABT optimized = rootNode;
+ ASSERT_TRUE(phaseManager.optimize(optimized));
+ ASSERT_BETWEEN(70, 110, phaseManager.getMemo().getStats()._physPlanExplorationCount);
+
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | root\n"
+ "| RefBlock: \n"
+ "| Variable [root]\n"
+ "Collation []\n"
+ "| | collation: \n"
+ "| | pa: Ascending\n"
+ "| | pb: Ascending\n"
+ "| RefBlock: \n"
+ "| Variable [pa]\n"
+ "| Variable [pb]\n"
+ "BinaryJoin [joinType: Inner, {rid_0}]\n"
+ "| | Const [true]\n"
+ "| LimitSkip []\n"
+ "| | limitSkip:\n"
+ "| | limit: 1\n"
+ "| | skip: 0\n"
+ "| Seek [ridProjection: rid_0, {'<root>': root, 'a': pa, 'b': pb}, c1]\n"
+ "| | BindBlock:\n"
+ "| | [pa]\n"
+ "| | Source []\n"
+ "| | [pb]\n"
+ "| | Source []\n"
+ "| | [root]\n"
+ "| | Source []\n"
+ "| RefBlock: \n"
+ "| Variable [rid_0]\n"
+ "MergeJoin []\n"
+ "| | | Condition\n"
+ "| | | rid_0 = rid_1\n"
+ "| | Collation\n"
+ "| | Ascending\n"
+ "| Union []\n"
+ "| | BindBlock:\n"
+ "| | [rid_1]\n"
+ "| | Source []\n"
+ "| Evaluation []\n"
+ "| | BindBlock:\n"
+ "| | [rid_1]\n"
+ "| | Variable [rid_0]\n"
+ "| IndexScan [{'<rid>': rid_0}, scanDefName: c1, indexDefName: index2, interval: {[Const "
+ "[2], Const [2]], [Const [4], Const [4]]}]\n"
+ "| BindBlock:\n"
+ "| [rid_0]\n"
+ "| Source []\n"
+ "IndexScan [{'<rid>': rid_0}, scanDefName: c1, indexDefName: index1, interval: {[Const "
+ "[1], Const [1]], [Const [3], Const [3]]}]\n"
+ " BindBlock:\n"
+ " [rid_0]\n"
+ " Source []\n",
+ optimized);
+}
+
+TEST(PhysRewriter, IndexBoundsIntersect) {
+ using namespace properties;
+ PrefixId prefixId;
+
+ ABT scanNode = make<ScanNode>("root", "c1");
+
+ ABT filterNode1 = make<FilterNode>(
+ make<EvalFilter>(
+ make<PathGet>(
+ "b", make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(1)))),
+ make<Variable>("root")),
+ std::move(scanNode));
+
+ ABT filterNode2 = make<FilterNode>(
+ make<EvalFilter>(
+ make<PathComposeA>(
+ make<PathComposeM>(make<PathGet>("a",
+ make<PathTraverse>(make<PathCompare>(
+ Operations::Gt, Constant::int64(70)))),
+ make<PathGet>("a",
+ make<PathTraverse>(make<PathCompare>(
+ Operations::Lt, Constant::int64(90))))),
+ make<PathGet>(
+ "a",
+ make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(100))))),
+ make<Variable>("root")),
+ std::move(filterNode1));
+
+ ABT rootNode =
+ make<RootNode>(ProjectionRequirement{ProjectionNameVector{}}, std::move(filterNode2));
+
+ OptPhaseManager phaseManager(
+ {OptPhaseManager::OptPhase::MemoSubstitutionPhase,
+ OptPhaseManager::OptPhase::MemoExplorationPhase,
+ OptPhaseManager::OptPhase::MemoImplementationPhase},
+ prefixId,
+ {{{"c1",
+ ScanDefinition{{},
+ {{"index1",
+ IndexDefinition{{{makeIndexPath("a"), CollationOp::Ascending},
+ {makeIndexPath("b"), CollationOp::Ascending}},
+ true /*isMultiKey*/}}}}}}},
+ {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests});
+
+ ABT optimized = rootNode;
+ ASSERT_TRUE(phaseManager.optimize(optimized));
+ ASSERT_BETWEEN(10, 15, phaseManager.getMemo().getStats()._physPlanExplorationCount);
+
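+ // The disjunction on "a" is answered by a union of index scans: one scan for a == 100,
+ // and a RID intersection of two scans for 70 < a < 90. The predicate on "b" is then
+ // checked against the second index key carried through the union.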
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| RefBlock: \n"
+ "Filter []\n"
+ "| EvalFilter []\n"
+ "| | Variable [evalTemp_1]\n"
+ "| PathCompare [Eq]\n"
+ "| Const [1]\n"
+ "GroupBy []\n"
+ "| | groupings: \n"
+ "| | RefBlock: \n"
+ "| | Variable [rid_0]\n"
+ "| aggregations: \n"
+ "| [evalTemp_1]\n"
+ "| FunctionCall [$first]\n"
+ "| Variable [disjunction_0]\n"
+ "Union []\n"
+ "| | BindBlock:\n"
+ "| | [disjunction_0]\n"
+ "| | Source []\n"
+ "| | [rid_0]\n"
+ "| | Source []\n"
+ "| IndexScan [{'<indexKey> 1': disjunction_0, '<rid>': rid_0}, scanDefName: c1, "
+ "indexDefName: index1, interval: {[Const [100], Const [100]], (-inf, +inf)}]\n"
+ "| BindBlock:\n"
+ "| [disjunction_0]\n"
+ "| Source []\n"
+ "| [rid_0]\n"
+ "| Source []\n"
+ "Filter []\n"
+ "| EvalFilter []\n"
+ "| | FunctionCall [getArraySize]\n"
+ "| | Variable [sides_0]\n"
+ "| PathCompare [Eq]\n"
+ "| Const [2]\n"
+ "GroupBy []\n"
+ "| | groupings: \n"
+ "| | RefBlock: \n"
+ "| | Variable [rid_0]\n"
+ "| aggregations: \n"
+ "| [disjunction_0]\n"
+ "| FunctionCall [$first]\n"
+ "| Variable [conjunction_0]\n"
+ "| [sides_0]\n"
+ "| FunctionCall [$addToSet]\n"
+ "| Variable [sideId_0]\n"
+ "Union []\n"
+ "| | BindBlock:\n"
+ "| | [conjunction_0]\n"
+ "| | Source []\n"
+ "| | [rid_0]\n"
+ "| | Source []\n"
+ "| | [sideId_0]\n"
+ "| | Source []\n"
+ "| Evaluation []\n"
+ "| | BindBlock:\n"
+ "| | [sideId_0]\n"
+ "| | Const [1]\n"
+ "| IndexScan [{'<indexKey> 1': conjunction_0, '<rid>': rid_0}, scanDefName: c1, "
+ "indexDefName: index1, interval: {(-inf, Const [90]), (-inf, +inf)}]\n"
+ "| BindBlock:\n"
+ "| [conjunction_0]\n"
+ "| Source []\n"
+ "| [rid_0]\n"
+ "| Source []\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [sideId_0]\n"
+ "| Const [0]\n"
+ "IndexScan [{'<indexKey> 1': conjunction_0, '<rid>': rid_0}, scanDefName: c1, "
+ "indexDefName: index1, interval: {(Const [70], +inf), (-inf, +inf)}]\n"
+ " BindBlock:\n"
+ " [conjunction_0]\n"
+ " Source []\n"
+ " [rid_0]\n"
+ " Source []\n",
+ optimized);
+}
+
+TEST(PhysRewriter, IndexBoundsIntersect1) {
+ using namespace properties;
+ PrefixId prefixId;
+
+ ABT scanNode = make<ScanNode>("root", "c1");
+
+ ABT evalNode = make<EvaluationNode>(
+ "pa",
+ make<EvalPath>(make<PathGet>("a", make<PathIdentity>()), make<Variable>("root")),
+ std::move(scanNode));
+
+ ABT filterNode = make<FilterNode>(
+ make<EvalFilter>(
+ make<PathComposeM>(
+ make<PathTraverse>(make<PathCompare>(Operations::Gt, Constant::int64(70))),
+ make<PathTraverse>(make<PathCompare>(Operations::Lt, Constant::int64(90)))),
+ make<Variable>("pa")),
+ std::move(evalNode));
+
+ ABT collationNode = make<CollationNode>(CollationRequirement({{"pa", CollationOp::Ascending}}),
+ std::move(filterNode));
+
+ ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}},
+ std::move(collationNode));
+
+ OptPhaseManager phaseManager(
+ {OptPhaseManager::OptPhase::MemoSubstitutionPhase,
+ OptPhaseManager::OptPhase::MemoExplorationPhase,
+ OptPhaseManager::OptPhase::MemoImplementationPhase},
+ prefixId,
+ {{{"c1",
+ ScanDefinition{
+ {},
+ {{"index1",
+ IndexDefinition{{{makeNonMultikeyIndexPath("a"), CollationOp::Ascending}},
+ false /*isMultiKey*/}}}}}}},
+ {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests});
+
+ ABT optimized = rootNode;
+ ASSERT_TRUE(phaseManager.optimize(optimized));
+ ASSERT_BETWEEN(15, 20, phaseManager.getMemo().getStats()._physPlanExplorationCount);
+
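+ // With a non-multikey index the two comparisons compose into a single interval (70, 90).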
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | root\n"
+ "| RefBlock: \n"
+ "| Variable [root]\n"
+ "BinaryJoin [joinType: Inner, {rid_0}]\n"
+ "| | Const [true]\n"
+ "| LimitSkip []\n"
+ "| | limitSkip:\n"
+ "| | limit: 1\n"
+ "| | skip: 0\n"
+ "| Seek [ridProjection: rid_0, {'<root>': root}, c1]\n"
+ "| | BindBlock:\n"
+ "| | [root]\n"
+ "| | Source []\n"
+ "| RefBlock: \n"
+ "| Variable [rid_0]\n"
+ "IndexScan [{'<rid>': rid_0}, scanDefName: c1, indexDefName: index1, interval: {(Const "
+ "[70], Const [90])}]\n"
+ " BindBlock:\n"
+ " [rid_0]\n"
+ " Source []\n",
+ optimized);
+}
+
+TEST(PhysRewriter, IndexBoundsIntersect2) {
+ using namespace properties;
+ PrefixId prefixId;
+
+ ABT scanNode = make<ScanNode>("root", "c1");
+
+ ABT evalNode = make<EvaluationNode>(
+ "pa",
+ make<EvalPath>(make<PathGet>("a", make<PathIdentity>()), make<Variable>("root")),
+ std::move(scanNode));
+
+ ABT filterNode = make<FilterNode>(
+ make<EvalFilter>(make<PathTraverse>(make<PathComposeM>(
+ make<PathCompare>(Operations::Gt, Constant::int64(70)),
+ make<PathCompare>(Operations::Lt, Constant::int64(90)))),
+ make<Variable>("pa")),
+ std::move(evalNode));
+
+ ABT rootNode =
+ make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(filterNode));
+
+ OptPhaseManager phaseManager(
+ {OptPhaseManager::OptPhase::MemoSubstitutionPhase,
+ OptPhaseManager::OptPhase::MemoExplorationPhase,
+ OptPhaseManager::OptPhase::MemoImplementationPhase},
+ prefixId,
+ {{{"c1",
+ ScanDefinition{{},
+ {{"index1",
+ IndexDefinition{{{makeIndexPath("a"), CollationOp::Ascending}},
+ true /*isMultiKey*/}}}}}}},
+ {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests});
+
+ ABT optimized = rootNode;
+ ASSERT_TRUE(phaseManager.optimize(optimized));
+ ASSERT_EQ(6, phaseManager.getMemo().getStats()._physPlanExplorationCount);
+
+ // Demonstrate that we can intersect the bounds here because the composition does not
+ // contain a traverse.
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | root\n"
+ "| RefBlock: \n"
+ "| Variable [root]\n"
+ "BinaryJoin [joinType: Inner, {rid_0}]\n"
+ "| | Const [true]\n"
+ "| LimitSkip []\n"
+ "| | limitSkip:\n"
+ "| | limit: 1\n"
+ "| | skip: 0\n"
+ "| Seek [ridProjection: rid_0, {'<root>': root}, c1]\n"
+ "| | BindBlock:\n"
+ "| | [root]\n"
+ "| | Source []\n"
+ "| RefBlock: \n"
+ "| Variable [rid_0]\n"
+ "Unique []\n"
+ "| projections: \n"
+ "| rid_0\n"
+ "IndexScan [{'<rid>': rid_0}, scanDefName: c1, indexDefName: index1, interval: {(Const "
+ "[70], Const [90])}]\n"
+ " BindBlock:\n"
+ " [rid_0]\n"
+ " Source []\n",
+ optimized);
+}
+
+TEST(PhysRewriter, IndexBoundsIntersect3) {
+ using namespace properties;
+ PrefixId prefixId;
+
+ ABT scanNode = make<ScanNode>("root", "c1");
+
+ ABT filterNode = make<FilterNode>(
+ make<EvalFilter>(
+ make<PathGet>("a",
+ make<PathTraverse>(make<PathComposeM>(
+ make<PathGet>("b",
+ make<PathTraverse>(make<PathCompare>(
+ Operations::Gt, Constant::int64(70)))),
+ make<PathGet>("b",
+ make<PathTraverse>(make<PathCompare>(
+ Operations::Lt, Constant::int64(90))))))),
+ make<Variable>("root")),
+ std::move(scanNode));
+
+ ABT rootNode =
+ make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(filterNode));
+
+ OptPhaseManager phaseManager(
+ {OptPhaseManager::OptPhase::MemoSubstitutionPhase,
+ OptPhaseManager::OptPhase::MemoExplorationPhase,
+ OptPhaseManager::OptPhase::MemoImplementationPhase},
+ prefixId,
+ {{{"c1",
+ ScanDefinition{
+ {},
+ {{"index1",
+ IndexDefinition{{{makeIndexPath(FieldPathType{"a", "b"}, true /*isMultiKey*/),
+ CollationOp::Ascending}},
+ true /*isMultiKey*/}}}}}}},
+ {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests});
+
+ ABT optimized = rootNode;
+ ASSERT_TRUE(phaseManager.optimize(optimized));
+ ASSERT_EQ(4, phaseManager.getMemo().getStats()._physPlanExplorationCount);
+
+ // We intersect indexes because the outer composition is over the same field ("b").
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | root\n"
+ "| RefBlock: \n"
+ "| Variable [root]\n"
+ "BinaryJoin [joinType: Inner, {rid_0}]\n"
+ "| | Const [true]\n"
+ "| LimitSkip []\n"
+ "| | limitSkip:\n"
+ "| | limit: 1\n"
+ "| | skip: 0\n"
+ "| Seek [ridProjection: rid_0, {'<root>': root}, c1]\n"
+ "| | BindBlock:\n"
+ "| | [root]\n"
+ "| | Source []\n"
+ "| RefBlock: \n"
+ "| Variable [rid_0]\n"
+ "Filter []\n"
+ "| EvalFilter []\n"
+ "| | FunctionCall [getArraySize]\n"
+ "| | Variable [sides_0]\n"
+ "| PathCompare [Eq]\n"
+ "| Const [2]\n"
+ "GroupBy []\n"
+ "| | groupings: \n"
+ "| | RefBlock: \n"
+ "| | Variable [rid_0]\n"
+ "| aggregations: \n"
+ "| [sides_0]\n"
+ "| FunctionCall [$addToSet]\n"
+ "| Variable [sideId_0]\n"
+ "Union []\n"
+ "| | BindBlock:\n"
+ "| | [rid_0]\n"
+ "| | Source []\n"
+ "| | [sideId_0]\n"
+ "| | Source []\n"
+ "| Evaluation []\n"
+ "| | BindBlock:\n"
+ "| | [sideId_0]\n"
+ "| | Const [1]\n"
+ "| IndexScan [{'<rid>': rid_0}, scanDefName: c1, indexDefName: index1, interval: "
+ "{(-inf, "
+ "Const [90])}]\n"
+ "| BindBlock:\n"
+ "| [rid_0]\n"
+ "| Source []\n"
+ "Evaluation []\n"
+ "| BindBlock:\n"
+ "| [sideId_0]\n"
+ "| Const [0]\n"
+ "IndexScan [{'<rid>': rid_0}, scanDefName: c1, indexDefName: index1, interval: {(Const "
+ "[70], +inf)}]\n"
+ " BindBlock:\n"
+ " [rid_0]\n"
+ " Source []\n",
+ optimized);
+}
+
+TEST(PhysRewriter, IndexBoundsIntersect4) {
+ using namespace properties;
+ PrefixId prefixId;
+
+ ABT scanNode = make<ScanNode>("root", "c1");
+
+ ABT filterNode = make<FilterNode>(
+ make<EvalFilter>(
+ make<PathGet>("a",
+ make<PathTraverse>(make<PathComposeM>(
+ make<PathGet>("b",
+ make<PathTraverse>(make<PathCompare>(
+ Operations::Gt, Constant::int64(70)))),
+ make<PathGet>("c",
+ make<PathTraverse>(make<PathCompare>(
+ Operations::Lt, Constant::int64(90))))))),
+ make<Variable>("root")),
+ std::move(scanNode));
+
+ ABT rootNode =
+ make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(filterNode));
+
+ OptPhaseManager phaseManager(
+ {OptPhaseManager::OptPhase::MemoSubstitutionPhase,
+ OptPhaseManager::OptPhase::MemoExplorationPhase,
+ OptPhaseManager::OptPhase::MemoImplementationPhase},
+ prefixId,
+ {{{"c1",
+ ScanDefinition{
+ {},
+ {{"index1",
+ IndexDefinition{{{makeIndexPath(FieldPathType{"a", "b"}, true /*isMultiKey*/),
+ CollationOp::Ascending}},
+ true /*isMultiKey*/}},
+ {"index2",
+ IndexDefinition{{{makeIndexPath(FieldPathType{"a", "c"}, true /*isMultiKey*/),
+ CollationOp::Ascending}},
+ true /*isMultiKey*/}}}}}}},
+ {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests});
+
+ ABT optimized = rootNode;
+ ASSERT_TRUE(phaseManager.optimize(optimized));
+ ASSERT_EQ(3, phaseManager.getMemo().getStats()._physPlanExplorationCount);
+
+ // We do not intersect the indexes because the outer composition is over different fields.
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | root\n"
+ "| RefBlock: \n"
+ "| Variable [root]\n"
+ "Filter []\n"
+ "| EvalFilter []\n"
+ "| | Variable [root]\n"
+ "| PathGet [a]\n"
+ "| PathTraverse []\n"
+ "| PathComposeM []\n"
+ "| | PathGet [c]\n"
+ "| | PathTraverse []\n"
+ "| | PathCompare [Lt]\n"
+ "| | Const [90]\n"
+ "| PathGet [b]\n"
+ "| PathTraverse []\n"
+ "| PathCompare [Gt]\n"
+ "| Const [70]\n"
+ "PhysicalScan [{'<root>': root}, c1]\n"
+ " BindBlock:\n"
+ " [root]\n"
+ " Source []\n",
+ optimized);
+}
+
+TEST(PhysRewriter, IndexResidualReq) {
+ using namespace properties;
+ PrefixId prefixId;
+
+ ABT scanNode = make<ScanNode>("root", "c1");
+
+ ABT evalANode = make<EvaluationNode>(
+ "pa",
+ make<EvalPath>(make<PathGet>("a", make<PathIdentity>()), make<Variable>("root")),
+ std::move(scanNode));
+
+ ABT filterANode =
+ make<FilterNode>(make<EvalFilter>(make<PathCompare>(Operations::Gt, Constant::int64(0)),
+ make<Variable>("pa")),
+ std::move(evalANode));
+
+ ABT filterBNode = make<FilterNode>(
+ make<EvalFilter>(
+ make<PathGet>(
+ "b", make<PathGet>("c", make<PathCompare>(Operations::Gt, Constant::int64(0)))),
+ make<Variable>("root")),
+ std::move(filterANode));
+
+ ABT collationNode = make<CollationNode>(CollationRequirement({{"pa", CollationOp::Ascending}}),
+ std::move(filterBNode));
+
+ ABT rootNode =
+ make<RootNode>(ProjectionRequirement{ProjectionNameVector{"pa"}}, std::move(collationNode));
+
+ OptPhaseManager phaseManager(
+ {OptPhaseManager::OptPhase::MemoSubstitutionPhase,
+ OptPhaseManager::OptPhase::MemoExplorationPhase,
+ OptPhaseManager::OptPhase::MemoImplementationPhase},
+ prefixId,
+ {{{"c1",
+ ScanDefinition{
+ {},
+ {{"index1",
+ IndexDefinition{{{makeNonMultikeyIndexPath("a"), CollationOp::Ascending},
+ {makeNonMultikeyIndexPath("b"), CollationOp::Ascending}},
+ false /*isMultiKey*/}}}}}}},
+ {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests});
+
+ ABT optimized = rootNode;
+ ASSERT_TRUE(phaseManager.optimize(optimized));
+ ASSERT_BETWEEN(10, 26, phaseManager.getMemo().getStats()._physPlanExplorationCount);
+
+ // Make sure we can use the index to cover "b" while testing "b.c" with a separate filter.
+ ASSERT_EXPLAIN_PROPS_V2(
+ "Properties [cost: 0.070002, localCost: 0, adjustedCE: 10]\n"
+ "| | Logical:\n"
+ "| | cardinalityEstimate: \n"
+ "| | ce: 10\n"
+ "| | projections: \n"
+ "| | pa\n"
+ "| | root\n"
+ "| | indexingAvailability: \n"
+ "| | [groupId: 0, scanProjection: root, scanDefName: c1]\n"
+ "| | collectionAvailability: \n"
+ "| | c1\n"
+ "| | distributionAvailability: \n"
+ "| | distribution: \n"
+ "| | type: Centralized\n"
+ "| Physical:\n"
+ "| distribution: \n"
+ "| type: Centralized\n"
+ "| indexingRequirement: \n"
+ "| Complete, dedupRID\n"
+ "Root []\n"
+ "| | projections: \n"
+ "| | pa\n"
+ "| RefBlock: \n"
+ "| Variable [pa]\n"
+ "Properties [cost: 0.070002, localCost: 0.070002, adjustedCE: 10]\n"
+ "| | Logical:\n"
+ "| | cardinalityEstimate: \n"
+ "| | ce: 10\n"
+ "| | requirementCEs: \n"
+ "| | refProjection: root, path: 'PathGet [a] PathIdentity []', ce: 100\n"
+ "| | refProjection: root, path: 'PathGet [b] PathGet [c] PathIdentity []', "
+ "ce: 100\n"
+ "| | projections: \n"
+ "| | pa\n"
+ "| | root\n"
+ "| | indexingAvailability: \n"
+ "| | [groupId: 0, scanProjection: root, scanDefName: c1]\n"
+ "| | collectionAvailability: \n"
+ "| | c1\n"
+ "| | distributionAvailability: \n"
+ "| | distribution: \n"
+ "| | type: Centralized\n"
+ "| Physical:\n"
+ "| collation: \n"
+ "| pa: Ascending\n"
+ "| projections: \n"
+ "| pa\n"
+ "| distribution: \n"
+ "| type: Centralized\n"
+ "| indexingRequirement: \n"
+ "| Index, dedupRID\n"
+ "Filter []\n"
+ "| EvalFilter []\n"
+ "| | Variable [evalTemp_1]\n"
+ "| PathGet [c]\n"
+ "| PathCompare [Gt]\n"
+ "| Const [0]\n"
+ "IndexScan [{'<indexKey> 0': pa, '<indexKey> 1': evalTemp_1}, scanDefName: c1, "
+ "indexDefName: index1, interval: {(Const [0], +inf), (-inf, +inf)}]\n"
+ " BindBlock:\n"
+ " [evalTemp_1]\n"
+ " Source []\n"
+ " [pa]\n"
+ " Source []\n",
+ phaseManager);
+}
+
+TEST(PhysRewriter, IndexResidualReq1) {
+ using namespace properties;
+ PrefixId prefixId;
+
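+    // Encodes: filter a == 0, b == 0, and c == 0, project pd = "d", sort by pd, return root.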
+ ABT scanNode = make<ScanNode>("root", "c1");
+
+ ABT filterANode = make<FilterNode>(
+ make<EvalFilter>(make<PathGet>("a", make<PathCompare>(Operations::Eq, Constant::int64(0))),
+ make<Variable>("root")),
+ std::move(scanNode));
+
+ ABT filterBNode = make<FilterNode>(
+ make<EvalFilter>(make<PathGet>("b", make<PathCompare>(Operations::Eq, Constant::int64(0))),
+ make<Variable>("root")),
+ std::move(filterANode));
+
+ ABT filterCNode = make<FilterNode>(
+ make<EvalFilter>(make<PathGet>("c", make<PathCompare>(Operations::Eq, Constant::int64(0))),
+ make<Variable>("root")),
+ std::move(filterBNode));
+
+ ABT evalDNode = make<EvaluationNode>(
+ "pd",
+ make<EvalPath>(make<PathGet>("d", make<PathIdentity>()), make<Variable>("root")),
+ std::move(filterCNode));
+
+ ABT collationNode = make<CollationNode>(CollationRequirement({{"pd", CollationOp::Ascending}}),
+ std::move(evalDNode));
+
+ ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}},
+ std::move(collationNode));
+
+ OptPhaseManager phaseManager(
+ {OptPhaseManager::OptPhase::MemoSubstitutionPhase,
+ OptPhaseManager::OptPhase::MemoExplorationPhase,
+ OptPhaseManager::OptPhase::MemoImplementationPhase},
+ prefixId,
+ {{{"c1",
+ ScanDefinition{
+ {},
+ {{"index1",
+ makeCompositeIndexDefinition({{"a", CollationOp::Ascending, false /*isMultiKey*/},
+ {"b", CollationOp::Ascending, false /*isMultiKey*/},
+ {"c", CollationOp::Ascending, false /*isMultiKey*/},
+ {"d", CollationOp::Ascending, false /*isMultiKey*/}},
+ false /*isMultiKey*/)},
+ {"index2",
+ makeCompositeIndexDefinition({{"a", CollationOp::Ascending, false /*isMultiKey*/},
+ {"b", CollationOp::Ascending, false /*isMultiKey*/},
+ {"d", CollationOp::Ascending, false /*isMultiKey*/}},
+ false /*isMultiKey*/)},
+ {"index3",
+ makeCompositeIndexDefinition({{"a", CollationOp::Ascending, false /*isMultiKey*/},
+ {"d", CollationOp::Ascending, false /*isMultiKey*/}},
+ false /*isMultiKey*/)}}}}}},
+ {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests});
+
+ ABT optimized = rootNode;
+ ASSERT_TRUE(phaseManager.optimize(optimized));
+ ASSERT_BETWEEN(25, 30, phaseManager.getMemo().getStats()._physPlanExplorationCount);
+
+ // Prefer index1 over index2 and index3 in order to cover all fields.
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | root\n"
+ "| RefBlock: \n"
+ "| Variable [root]\n"
+ "BinaryJoin [joinType: Inner, {rid_0}]\n"
+ "| | Const [true]\n"
+ "| LimitSkip []\n"
+ "| | limitSkip:\n"
+ "| | limit: 1\n"
+ "| | skip: 0\n"
+ "| Seek [ridProjection: rid_0, {'<root>': root}, c1]\n"
+ "| | BindBlock:\n"
+ "| | [root]\n"
+ "| | Source []\n"
+ "| RefBlock: \n"
+ "| Variable [rid_0]\n"
+ "IndexScan [{'<rid>': rid_0}, scanDefName: c1, indexDefName: index1, interval: {[Const "
+ "[0], Const [0]], [Const [0], Const [0]], [Const [0], Const [0]], (-inf, +inf)}]\n"
+ " BindBlock:\n"
+ " [rid_0]\n"
+ " Source []\n",
+ optimized);
+}
+
+TEST(PhysRewriter, IndexResidualReq2) {
+ using namespace properties;
+ PrefixId prefixId;
+
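+    // Encodes: a == 0 and b == 0 with implicit array traversal (as in MQL {a: 0, b: 0}).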
+ ABT scanNode = make<ScanNode>("root", "c1");
+
+ ABT filterANode = make<FilterNode>(
+ make<EvalFilter>(
+ make<PathGet>(
+ "a", make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(0)))),
+ make<Variable>("root")),
+ std::move(scanNode));
+
+ ABT filterBNode = make<FilterNode>(
+ make<EvalFilter>(
+ make<PathGet>(
+ "b", make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(0)))),
+ make<Variable>("root")),
+ std::move(filterANode));
+
+ ABT rootNode =
+ make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(filterBNode));
+
+ OptPhaseManager phaseManager(
+ {OptPhaseManager::OptPhase::MemoSubstitutionPhase,
+ OptPhaseManager::OptPhase::MemoExplorationPhase,
+ OptPhaseManager::OptPhase::MemoImplementationPhase},
+ prefixId,
+ {{{"c1",
+ ScanDefinition{{},
+ {{"index1",
+ makeCompositeIndexDefinition(
+ {{"a", CollationOp::Ascending, true /*isMultiKey*/},
+ {"c", CollationOp::Ascending, true /*isMultiKey*/},
+ {"b", CollationOp::Ascending, true /*isMultiKey*/}})}}}}}},
+ {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests});
+
+ ABT optimized = rootNode;
+ ASSERT_TRUE(phaseManager.optimize(optimized));
+ ASSERT_EQ(7, phaseManager.getMemo().getStats()._physPlanExplorationCount);
+
+ // We can cover "b" with the index and filter before we Seek.
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | root\n"
+ "| RefBlock: \n"
+ "| Variable [root]\n"
+ "BinaryJoin [joinType: Inner, {rid_0}]\n"
+ "| | Const [true]\n"
+ "| LimitSkip []\n"
+ "| | limitSkip:\n"
+ "| | limit: 1\n"
+ "| | skip: 0\n"
+ "| Seek [ridProjection: rid_0, {'<root>': root}, c1]\n"
+ "| | BindBlock:\n"
+ "| | [root]\n"
+ "| | Source []\n"
+ "| RefBlock: \n"
+ "| Variable [rid_0]\n"
+ "Unique []\n"
+ "| projections: \n"
+ "| rid_0\n"
+ "Filter []\n"
+ "| EvalFilter []\n"
+ "| | Variable [evalTemp_4]\n"
+ "| PathCompare [Eq]\n"
+ "| Const [0]\n"
+ "IndexScan [{'<indexKey> 2': evalTemp_4, '<rid>': rid_0}, scanDefName: c1, "
+ "indexDefName: "
+ "index1, interval: {[Const [0], Const [0]], (-inf, +inf), (-inf, +inf)}]\n"
+ " BindBlock:\n"
+ " [evalTemp_4]\n"
+ " Source []\n"
+ " [rid_0]\n"
+ " Source []\n",
+ optimized);
+}
+
+TEST(PhysRewriter, ElemMatchIndex) {
+ using namespace properties;
+ PrefixId prefixId;
+
+ ABT scanNode = make<ScanNode>("root", "c1");
+
+    // This encodes an elemMatch with a conjunction of > 70 and < 90.
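+    // In MQL terms: {a: {$elemMatch: {$gt: 70, $lt: 90}}}; PathArr requires "a" to be an array.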
+ ABT filterNode = make<FilterNode>(
+ make<EvalFilter>(
+ make<PathGet>(
+ "a",
+ make<PathComposeM>(make<PathArr>(),
+ make<PathTraverse>(make<PathComposeM>(
+ make<PathCompare>(Operations::Gt, Constant::int64(70)),
+ make<PathCompare>(Operations::Lt, Constant::int64(90)))))),
+ make<Variable>("root")),
+ std::move(scanNode));
+
+ ABT rootNode =
+ make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(filterNode));
+
+ OptPhaseManager phaseManager(
+ {OptPhaseManager::OptPhase::MemoSubstitutionPhase,
+ OptPhaseManager::OptPhase::MemoExplorationPhase,
+ OptPhaseManager::OptPhase::MemoImplementationPhase},
+ prefixId,
+ {{{"c1",
+ ScanDefinition{{}, {{"index1", makeIndexDefinition("a", CollationOp::Ascending)}}}}}},
+ {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests});
+
+ ABT optimized = rootNode;
+ ASSERT_TRUE(phaseManager.optimize(optimized));
+ ASSERT_EQ(5, phaseManager.getMemo().getStats()._physPlanExplorationCount);
+
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | root\n"
+ "| RefBlock: \n"
+ "| Variable [root]\n"
+ "Filter []\n"
+ "| EvalFilter []\n"
+ "| | Variable [root]\n"
+ "| PathGet [a]\n"
+ "| PathArr []\n"
+ "BinaryJoin [joinType: Inner, {rid_0}]\n"
+ "| | Const [true]\n"
+ "| LimitSkip []\n"
+ "| | limitSkip:\n"
+ "| | limit: 1\n"
+ "| | skip: 0\n"
+ "| Seek [ridProjection: rid_0, {'<root>': root}, c1]\n"
+ "| | BindBlock:\n"
+ "| | [root]\n"
+ "| | Source []\n"
+ "| RefBlock: \n"
+ "| Variable [rid_0]\n"
+ "Unique []\n"
+ "| projections: \n"
+ "| rid_0\n"
+ "IndexScan [{'<rid>': rid_0}, scanDefName: c1, indexDefName: index1, interval: {(Const "
+ "[70], Const [90])}]\n"
+ " BindBlock:\n"
+ " [rid_0]\n"
+ " Source []\n",
+ optimized);
+}
+
+TEST(PhysRewriter, ElemMatchIndex1) {
+ using namespace properties;
+ PrefixId prefixId;
+
+ ABT scanNode = make<ScanNode>("root", "c1");
+
+ ABT filterNode1 = make<FilterNode>(
+ make<EvalFilter>(
+ make<PathGet>(
+ "b", make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(1)))),
+ make<Variable>("root")),
+ std::move(scanNode));
+
+    // This encodes an elemMatch with a conjunction of > 70 and < 90.
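+    // Together with the filter above, this is {b: 1, a: {$elemMatch: {$gt: 70, $lt: 90}}}.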
+ ABT filterNode2 = make<FilterNode>(
+ make<EvalFilter>(
+ make<PathGet>(
+ "a",
+ make<PathComposeM>(make<PathArr>(),
+ make<PathTraverse>(make<PathComposeM>(
+ make<PathCompare>(Operations::Gt, Constant::int64(70)),
+ make<PathCompare>(Operations::Lt, Constant::int64(90)))))),
+ make<Variable>("root")),
+ std::move(filterNode1));
+
+ ABT rootNode =
+ make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(filterNode2));
+
+ OptPhaseManager phaseManager(
+ {OptPhaseManager::OptPhase::MemoSubstitutionPhase,
+ OptPhaseManager::OptPhase::MemoExplorationPhase,
+ OptPhaseManager::OptPhase::MemoImplementationPhase},
+ prefixId,
+ {{{"c1",
+ ScanDefinition{{},
+ {{"index1",
+ makeCompositeIndexDefinition(
+ {{"b", CollationOp::Ascending, true /*isMultiKey*/},
+ {"a", CollationOp::Ascending, true /*isMultiKey*/}})}}}}}},
+ {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests});
+
+ ABT optimized = rootNode;
+ ASSERT_TRUE(phaseManager.optimize(optimized));
+ ASSERT_BETWEEN(5, 10, phaseManager.getMemo().getStats()._physPlanExplorationCount);
+
+ // Demonstrate we can cover both the filter and the extracted elemMatch predicate with the
+ // index.
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | root\n"
+ "| RefBlock: \n"
+ "| Variable [root]\n"
+ "Filter []\n"
+ "| EvalFilter []\n"
+ "| | Variable [root]\n"
+ "| PathGet [a]\n"
+ "| PathArr []\n"
+ "BinaryJoin [joinType: Inner, {rid_0}]\n"
+ "| | Const [true]\n"
+ "| LimitSkip []\n"
+ "| | limitSkip:\n"
+ "| | limit: 1\n"
+ "| | skip: 0\n"
+ "| Seek [ridProjection: rid_0, {'<root>': root}, c1]\n"
+ "| | BindBlock:\n"
+ "| | [root]\n"
+ "| | Source []\n"
+ "| RefBlock: \n"
+ "| Variable [rid_0]\n"
+ "Unique []\n"
+ "| projections: \n"
+ "| rid_0\n"
+ "IndexScan [{'<rid>': rid_0}, scanDefName: c1, indexDefName: index1, interval: {[Const "
+ "[1], Const [1]], (Const [70], Const [90])}]\n"
+ " BindBlock:\n"
+ " [rid_0]\n"
+ " Source []\n",
+ optimized);
+}
+
+TEST(PhysRewriter, ObjectElemMatch) {
+ using namespace properties;
+ PrefixId prefixId;
+
+ ABT scanNode = make<ScanNode>("root", "c1");
+
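+    // Encodes an object elemMatch: {a: {$elemMatch: {b: 1, c: 2}}}.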
+ ABT filterNode = make<FilterNode>(
+ make<EvalFilter>(
+ make<PathGet>(
+ "a",
+ make<PathComposeM>(make<PathArr>(),
+ make<PathTraverse>(make<PathComposeM>(
+ make<PathGet>("b",
+ make<PathTraverse>(make<PathCompare>(
+ Operations::Eq, Constant::int64(1)))),
+ make<PathGet>("c",
+ make<PathTraverse>(make<PathCompare>(
+ Operations::Eq, Constant::int64(2)))))))),
+ make<Variable>("root")),
+ std::move(scanNode));
+
+ ABT rootNode =
+ make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(filterNode));
+
+ OptPhaseManager phaseManager(
+ {OptPhaseManager::OptPhase::MemoSubstitutionPhase,
+ OptPhaseManager::OptPhase::MemoExplorationPhase,
+ OptPhaseManager::OptPhase::MemoImplementationPhase},
+ prefixId,
+ {{{"c1",
+ ScanDefinition{{},
+ {{"index1",
+ makeCompositeIndexDefinition(
+ {{"b", CollationOp::Ascending, true /*isMultiKey*/},
+ {"a", CollationOp::Ascending, true /*isMultiKey*/}})}}}}}},
+ {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests});
+
+ ABT optimized = rootNode;
+ ASSERT_TRUE(phaseManager.optimize(optimized));
+ ASSERT_EQ(4, phaseManager.getMemo().getStats()._physPlanExplorationCount);
+
+ // We currently cannot use indexes with ObjectElemMatch.
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | root\n"
+ "| RefBlock: \n"
+ "| Variable [root]\n"
+ "Filter []\n"
+ "| EvalFilter []\n"
+ "| | Variable [root]\n"
+ "| PathGet [a]\n"
+ "| PathTraverse []\n"
+ "| PathComposeM []\n"
+ "| | PathGet [c]\n"
+ "| | PathTraverse []\n"
+ "| | PathCompare [Eq]\n"
+ "| | Const [2]\n"
+ "| PathGet [b]\n"
+ "| PathTraverse []\n"
+ "| PathCompare [Eq]\n"
+ "| Const [1]\n"
+ "Filter []\n"
+ "| EvalFilter []\n"
+ "| | Variable [root]\n"
+ "| PathGet [a]\n"
+ "| PathArr []\n"
+ "PhysicalScan [{'<root>': root}, c1]\n"
+ " BindBlock:\n"
+ " [root]\n"
+ " Source []\n",
+ optimized);
+}
+
+TEST(PhysRewriter, ParallelScan) {
+ using namespace properties;
+ PrefixId prefixId;
+
+ ABT scanNode = make<ScanNode>("root", "c1");
+
+ ABT filterNode = make<FilterNode>(
+ make<EvalFilter>(
+ make<PathGet>(
+ "a", make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(1)))),
+ make<Variable>("root")),
+ std::move(scanNode));
+
+ ABT rootNode =
+ make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(filterNode));
+
+ OptPhaseManager phaseManager(
+ {OptPhaseManager::OptPhase::MemoSubstitutionPhase,
+ OptPhaseManager::OptPhase::MemoExplorationPhase,
+ OptPhaseManager::OptPhase::MemoImplementationPhase},
+ prefixId,
+ {{{"c1", ScanDefinition{{}, {}, {DistributionType::UnknownPartitioning}}}},
+ 5 /*numberOfPartitions*/},
+ {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests});
+
+ ABT optimized = rootNode;
+ ASSERT_TRUE(phaseManager.optimize(optimized));
+ ASSERT_EQ(4, phaseManager.getMemo().getStats()._physPlanExplorationCount);
+
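+    // The scan of the UnknownPartitioning collection runs in parallel on each partition; a
+    // Centralized exchange collects the filtered results.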
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | root\n"
+ "| RefBlock: \n"
+ "| Variable [root]\n"
+ "Exchange []\n"
+ "| | distribution: \n"
+ "| | type: Centralized\n"
+ "| RefBlock: \n"
+ "Filter []\n"
+ "| EvalFilter []\n"
+ "| | Variable [evalTemp_0]\n"
+ "| PathTraverse []\n"
+ "| PathCompare [Eq]\n"
+ "| Const [1]\n"
+ "PhysicalScan [{'<root>': root, 'a': evalTemp_0}, c1, parallel]\n"
+ " BindBlock:\n"
+ " [evalTemp_0]\n"
+ " Source []\n"
+ " [root]\n"
+ " Source []\n",
+ optimized);
+}
+
+TEST(PhysRewriter, HashPartitioning) {
+ using namespace properties;
+ PrefixId prefixId;
+
+ ABT scanNode = make<ScanNode>("root", "c1");
+
+ ABT projectionANode = make<EvaluationNode>(
+ "pa",
+ make<EvalPath>(make<PathGet>("a", make<PathIdentity>()), make<Variable>("root")),
+ std::move(scanNode));
+ ABT projectionBNode = make<EvaluationNode>(
+ "pb",
+ make<EvalPath>(make<PathGet>("b", make<PathIdentity>()), make<Variable>("root")),
+ std::move(projectionANode));
+
+ ABT groupByNode = make<GroupByNode>(ProjectionNameVector{"pa"},
+ ProjectionNameVector{"pc"},
+ makeSeq(make<Variable>("pb")),
+ std::move(projectionBNode));
+
+ ABT rootNode =
+ make<RootNode>(ProjectionRequirement{ProjectionNameVector{"pc"}}, std::move(groupByNode));
+
+ OptPhaseManager phaseManager(
+ {OptPhaseManager::OptPhase::MemoSubstitutionPhase,
+ OptPhaseManager::OptPhase::MemoExplorationPhase,
+ OptPhaseManager::OptPhase::MemoImplementationPhase},
+ prefixId,
+ {{{"c1",
+ ScanDefinition{{},
+ {},
+ {DistributionType::HashPartitioning,
+ makeSeq(make<PathGet>("a", make<PathIdentity>()))}}}},
+ 5 /*numberOfPartitions*/},
+ {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests});
+
+ ABT optimized = rootNode;
+ ASSERT_TRUE(phaseManager.optimize(optimized));
+ ASSERT_BETWEEN(5, 10, phaseManager.getMemo().getStats()._physPlanExplorationCount);
+
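+    // The collection is hash-partitioned on "a" and we group by pa = "a", so the GroupBy can
+    // execute on each partition in place; only a final Centralized exchange is needed.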
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | pc\n"
+ "| RefBlock: \n"
+ "| Variable [pc]\n"
+ "Exchange []\n"
+ "| | distribution: \n"
+ "| | type: Centralized\n"
+ "| RefBlock: \n"
+ "GroupBy []\n"
+ "| | groupings: \n"
+ "| | RefBlock: \n"
+ "| | Variable [pa]\n"
+ "| aggregations: \n"
+ "| [pc]\n"
+ "| Variable [pb]\n"
+ "PhysicalScan [{'a': pa, 'b': pb}, c1]\n"
+ " BindBlock:\n"
+ " [pa]\n"
+ " Source []\n"
+ " [pb]\n"
+ " Source []\n",
+ optimized);
+}
+
+TEST(PhysRewriter, IndexPartitioning) {
+ using namespace properties;
+ PrefixId prefixId;
+
+ ABT scanNode = make<ScanNode>("root", "c1");
+
+ ABT projectionANode = make<EvaluationNode>(
+ "pa",
+ make<EvalPath>(make<PathGet>("a", make<PathIdentity>()), make<Variable>("root")),
+ std::move(scanNode));
+
+ ABT filterANode =
+ make<FilterNode>(make<EvalFilter>(make<PathCompare>(Operations::Gt, Constant::int64(0)),
+ make<Variable>("pa")),
+ std::move(projectionANode));
+
+ ABT projectionBNode = make<EvaluationNode>(
+ "pb",
+ make<EvalPath>(make<PathGet>("b", make<PathIdentity>()), make<Variable>("root")),
+ std::move(filterANode));
+
+ ABT filterBNode =
+ make<FilterNode>(make<EvalFilter>(make<PathCompare>(Operations::Gt, Constant::int64(1)),
+ make<Variable>("pb")),
+ std::move(projectionBNode));
+
+ ABT groupByNode = make<GroupByNode>(ProjectionNameVector{"pa"},
+ ProjectionNameVector{"pc"},
+ makeSeq(make<Variable>("pb")),
+ std::move(filterBNode));
+
+ ABT rootNode =
+ make<RootNode>(ProjectionRequirement{ProjectionNameVector{"pc"}}, std::move(groupByNode));
+
+ OptPhaseManager phaseManager(
+ {OptPhaseManager::OptPhase::MemoSubstitutionPhase,
+ OptPhaseManager::OptPhase::MemoExplorationPhase,
+ OptPhaseManager::OptPhase::MemoImplementationPhase},
+ prefixId,
+ {{{"c1",
+ ScanDefinition{
+ {},
+ {{"index1",
+ IndexDefinition{
+ {{makeNonMultikeyIndexPath("a"), CollationOp::Ascending}},
+ false /*isMultiKey*/,
+ {DistributionType::HashPartitioning, makeSeq(makeNonMultikeyIndexPath("a"))},
+ {}}}},
+ {DistributionType::HashPartitioning, makeSeq(makeNonMultikeyIndexPath("b"))}}}},
+ 5 /*numberOfPartitions*/},
+ {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests});
+
+ ABT optimized = rootNode;
+ ASSERT_TRUE(phaseManager.optimize(optimized));
+ ASSERT_BETWEEN(75, 150, phaseManager.getMemo().getStats()._physPlanExplorationCount);
+
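+    // The index is hash-partitioned on "a" while the collection is partitioned on "b":
+    // index-scan results are redistributed (RoundRobin exchange) before the Seek, and the
+    // joined rows are hash-repartitioned on pa for the GroupBy.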
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | pc\n"
+ "| RefBlock: \n"
+ "| Variable [pc]\n"
+ "Exchange []\n"
+ "| | distribution: \n"
+ "| | type: Centralized\n"
+ "| RefBlock: \n"
+ "GroupBy []\n"
+ "| | groupings: \n"
+ "| | RefBlock: \n"
+ "| | Variable [pa]\n"
+ "| aggregations: \n"
+ "| [pc]\n"
+ "| Variable [pb]\n"
+ "Exchange []\n"
+ "| | distribution: \n"
+ "| | type: HashPartitioning\n"
+ "| | projections: \n"
+ "| | pa\n"
+ "| RefBlock: \n"
+ "| Variable [pa]\n"
+ "BinaryJoin [joinType: Inner, {rid_0}]\n"
+ "| | Const [true]\n"
+ "| Filter []\n"
+ "| | EvalFilter []\n"
+ "| | | Variable [pb]\n"
+ "| | PathCompare [Gt]\n"
+ "| | Const [1]\n"
+ "| LimitSkip []\n"
+ "| | limitSkip:\n"
+ "| | limit: 1\n"
+ "| | skip: 0\n"
+ "| Seek [ridProjection: rid_0, {'b': pb}, c1]\n"
+ "| | BindBlock:\n"
+ "| | [pb]\n"
+ "| | Source []\n"
+ "| RefBlock: \n"
+ "| Variable [rid_0]\n"
+ "Exchange []\n"
+ "| | distribution: \n"
+ "| | type: RoundRobin\n"
+ "| RefBlock: \n"
+ "IndexScan [{'<indexKey> 0': pa, '<rid>': rid_0}, scanDefName: c1, indexDefName: index1, "
+ "interval: {(Const [0], +inf)}]\n"
+ " BindBlock:\n"
+ " [pa]\n"
+ " Source []\n"
+ " [rid_0]\n"
+ " Source []\n",
+ optimized);
+}
+
+TEST(PhysRewriter, IndexPartitioning1) {
+ using namespace properties;
+ PrefixId prefixId;
+
+ ABT scanNode = make<ScanNode>("root", "c1");
+
+ ABT projectionANode = make<EvaluationNode>(
+ "pa",
+ make<EvalPath>(make<PathGet>("a", make<PathIdentity>()), make<Variable>("root")),
+ std::move(scanNode));
+
+ ABT filterANode =
+ make<FilterNode>(make<EvalFilter>(make<PathCompare>(Operations::Gt, Constant::int64(0)),
+ make<Variable>("pa")),
+ std::move(projectionANode));
+
+ ABT projectionBNode = make<EvaluationNode>(
+ "pb",
+ make<EvalPath>(make<PathGet>("b", make<PathIdentity>()), make<Variable>("root")),
+ std::move(filterANode));
+
+ ABT filterBNode =
+ make<FilterNode>(make<EvalFilter>(make<PathCompare>(Operations::Gt, Constant::int64(1)),
+ make<Variable>("pb")),
+ std::move(projectionBNode));
+
+ ABT groupByNode = make<GroupByNode>(ProjectionNameVector{"pa"},
+ ProjectionNameVector{"pc"},
+ makeSeq(make<Variable>("pb")),
+ std::move(filterBNode));
+
+ ABT rootNode =
+ make<RootNode>(ProjectionRequirement{ProjectionNameVector{"pc"}}, std::move(groupByNode));
+
+ OptPhaseManager phaseManager(
+ {OptPhaseManager::OptPhase::MemoSubstitutionPhase,
+ OptPhaseManager::OptPhase::MemoExplorationPhase,
+ OptPhaseManager::OptPhase::MemoImplementationPhase},
+ prefixId,
+ {{{"c1",
+ ScanDefinition{
+ {},
+ {{"index1",
+ IndexDefinition{
+ {{makeNonMultikeyIndexPath("a"), CollationOp::Ascending}},
+ false /*isMultiKey*/,
+ {DistributionType::HashPartitioning, makeSeq(makeNonMultikeyIndexPath("a"))},
+ {}}},
+ {"index2",
+ IndexDefinition{
+ {{makeNonMultikeyIndexPath("b"), CollationOp::Ascending}},
+ false /*isMultiKey*/,
+ {DistributionType::HashPartitioning, makeSeq(makeNonMultikeyIndexPath("b"))},
+ {}}}},
+ {DistributionType::HashPartitioning, makeSeq(makeNonMultikeyIndexPath("c"))}}}},
+ 5 /*numberOfPartitions*/},
+ {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests});
+
+ ABT optimized = rootNode;
+ ASSERT_TRUE(phaseManager.optimize(optimized));
+ ASSERT_BETWEEN(150, 300, phaseManager.getMemo().getStats()._physPlanExplorationCount);
+
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | pc\n"
+ "| RefBlock: \n"
+ "| Variable [pc]\n"
+ "Exchange []\n"
+ "| | distribution: \n"
+ "| | type: Centralized\n"
+ "| RefBlock: \n"
+ "GroupBy []\n"
+ "| | groupings: \n"
+ "| | RefBlock: \n"
+ "| | Variable [pa]\n"
+ "| aggregations: \n"
+ "| [pc]\n"
+ "| Variable [pb]\n"
+ "HashJoin [joinType: Inner]\n"
+ "| | Condition\n"
+ "| | rid_0 = rid_3\n"
+ "| Union []\n"
+ "| | BindBlock:\n"
+ "| | [pa]\n"
+ "| | Source []\n"
+ "| | [rid_3]\n"
+ "| | Source []\n"
+ "| Evaluation []\n"
+ "| | BindBlock:\n"
+ "| | [rid_3]\n"
+ "| | Variable [rid_0]\n"
+ "| IndexScan [{'<indexKey> 0': pa, '<rid>': rid_0}, scanDefName: c1, indexDefName: "
+ "index1, interval: {(Const [0], +inf)}]\n"
+ "| BindBlock:\n"
+ "| [pa]\n"
+ "| Source []\n"
+ "| [rid_0]\n"
+ "| Source []\n"
+ "Exchange []\n"
+ "| | distribution: \n"
+ "| | type: HashPartitioning\n"
+ "| | projections: \n"
+ "| | pa\n"
+ "| RefBlock: \n"
+ "| Variable [pa]\n"
+ "IndexScan [{'<indexKey> 0': pb, '<rid>': rid_0}, scanDefName: c1, indexDefName: index2, "
+ "interval: {(Const [1], +inf)}]\n"
+ " BindBlock:\n"
+ " [pb]\n"
+ " Source []\n"
+ " [rid_0]\n"
+ " Source []\n",
+ optimized);
+}
+
+TEST(PhysRewriter, LocalGlobalAgg) {
+ using namespace properties;
+ PrefixId prefixId;
+
+ ABT scanNode = make<ScanNode>("root", "c1");
+
+ ABT evalANode = make<EvaluationNode>(
+ "pa",
+ make<EvalPath>(make<PathGet>("a", make<PathIdentity>()), make<Variable>("root")),
+ std::move(scanNode));
+ ABT evalBNode = make<EvaluationNode>(
+ "pb",
+ make<EvalPath>(make<PathGet>("b", make<PathIdentity>()), make<Variable>("root")),
+ std::move(evalANode));
+
+ ABT groupByNode =
+ make<GroupByNode>(ProjectionNameVector{"pa"},
+ ProjectionNameVector{"pc"},
+ makeSeq(make<FunctionCall>("$sum", makeSeq(make<Variable>("pb")))),
+ std::move(evalBNode));
+
+ ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"pa", "pc"}},
+ std::move(groupByNode));
+
+ OptPhaseManager phaseManager(
+ {OptPhaseManager::OptPhase::MemoSubstitutionPhase,
+ OptPhaseManager::OptPhase::MemoExplorationPhase,
+ OptPhaseManager::OptPhase::MemoImplementationPhase},
+ prefixId,
+ {{{"c1", ScanDefinition{{}, {}, {DistributionType::UnknownPartitioning}}}},
+ 5 /*numberOfPartitions*/},
+ {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests});
+
+ ABT optimized = rootNode;
+ ASSERT_TRUE(phaseManager.optimize(optimized));
+ ASSERT_BETWEEN(15, 25, phaseManager.getMemo().getStats()._physPlanExplorationCount);
+
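+    // The $sum is split into a Local (per-partition) GroupBy and a Global GroupBy, with a hash
+    // exchange on the grouping key pa in between.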
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | pa\n"
+ "| | pc\n"
+ "| RefBlock: \n"
+ "| Variable [pa]\n"
+ "| Variable [pc]\n"
+ "Exchange []\n"
+ "| | distribution: \n"
+ "| | type: Centralized\n"
+ "| RefBlock: \n"
+ "GroupBy [Global]\n"
+ "| | groupings: \n"
+ "| | RefBlock: \n"
+ "| | Variable [pa]\n"
+ "| aggregations: \n"
+ "| [pc]\n"
+ "| FunctionCall [$sum]\n"
+ "| Variable [preagg_0]\n"
+ "Exchange []\n"
+ "| | distribution: \n"
+ "| | type: HashPartitioning\n"
+ "| | projections: \n"
+ "| | pa\n"
+ "| RefBlock: \n"
+ "| Variable [pa]\n"
+ "GroupBy [Local]\n"
+ "| | groupings: \n"
+ "| | RefBlock: \n"
+ "| | Variable [pa]\n"
+ "| aggregations: \n"
+ "| [preagg_0]\n"
+ "| FunctionCall [$sum]\n"
+ "| Variable [pb]\n"
+ "PhysicalScan [{'a': pa, 'b': pb}, c1, parallel]\n"
+ " BindBlock:\n"
+ " [pa]\n"
+ " Source []\n"
+ " [pb]\n"
+ " Source []\n",
+ optimized);
+}
+
+TEST(PhysRewriter, LocalGlobalAgg1) {
+ using namespace properties;
+ PrefixId prefixId;
+
+ ABT scanNode = make<ScanNode>("root", "c1");
+
+ ABT evalBNode = make<EvaluationNode>(
+ "pb",
+ make<EvalPath>(make<PathGet>("b", make<PathIdentity>()), make<Variable>("root")),
+ std::move(scanNode));
+
+ ABT groupByNode =
+ make<GroupByNode>(ProjectionNameVector{},
+ ProjectionNameVector{"pc"},
+ makeSeq(make<FunctionCall>("$sum", makeSeq(make<Variable>("pb")))),
+ std::move(evalBNode));
+
+ ABT rootNode =
+ make<RootNode>(ProjectionRequirement{ProjectionNameVector{"pc"}}, std::move(groupByNode));
+
+ OptPhaseManager phaseManager(
+ {OptPhaseManager::OptPhase::MemoSubstitutionPhase,
+ OptPhaseManager::OptPhase::MemoExplorationPhase,
+ OptPhaseManager::OptPhase::MemoImplementationPhase},
+ prefixId,
+ {{{"c1", ScanDefinition{{}, {}, {DistributionType::UnknownPartitioning}}}},
+ 5 /*numberOfPartitions*/},
+ {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests});
+
+ ABT optimized = rootNode;
+ ASSERT_TRUE(phaseManager.optimize(optimized));
+ ASSERT_BETWEEN(5, 15, phaseManager.getMemo().getStats()._physPlanExplorationCount);
+
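+    // With no grouping keys, the Local/Global split only needs a Centralized exchange between
+    // the two GroupBy nodes.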
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | pc\n"
+ "| RefBlock: \n"
+ "| Variable [pc]\n"
+ "GroupBy [Global]\n"
+ "| | groupings: \n"
+ "| | RefBlock: \n"
+ "| aggregations: \n"
+ "| [pc]\n"
+ "| FunctionCall [$sum]\n"
+ "| Variable [preagg_0]\n"
+ "Exchange []\n"
+ "| | distribution: \n"
+ "| | type: Centralized\n"
+ "| RefBlock: \n"
+ "GroupBy [Local]\n"
+ "| | groupings: \n"
+ "| | RefBlock: \n"
+ "| aggregations: \n"
+ "| [preagg_0]\n"
+ "| FunctionCall [$sum]\n"
+ "| Variable [pb]\n"
+ "PhysicalScan [{'b': pb}, c1, parallel]\n"
+ " BindBlock:\n"
+ " [pb]\n"
+ " Source []\n",
+ optimized);
+}
+
+TEST(PhysRewriter, LocalLimitSkip) {
+ using namespace properties;
+ PrefixId prefixId;
+
+ ABT scanNode = make<ScanNode>("root", "c1");
+
+ ABT limitSkipNode = make<LimitSkipNode>(LimitSkipRequirement{20, 10}, std::move(scanNode));
+ ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}},
+ std::move(limitSkipNode));
+
+ OptPhaseManager phaseManager(
+ {OptPhaseManager::OptPhase::MemoSubstitutionPhase,
+ OptPhaseManager::OptPhase::MemoExplorationPhase,
+ OptPhaseManager::OptPhase::MemoImplementationPhase},
+ prefixId,
+ {{{"c1", ScanDefinition{{}, {}, {DistributionType::UnknownPartitioning}}}},
+ 5 /*numberOfPartitions*/},
+ {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests});
+
+ ABT optimized = rootNode;
+ ASSERT_TRUE(phaseManager.optimize(optimized));
+ ASSERT_BETWEEN(5, 15, phaseManager.getMemo().getStats()._physPlanExplorationCount);
+
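+    // The limit-skip is reflected as limitEstimate 30 (limit 20 + skip 10) on the parallel
+    // scan, bounding per-partition work before the Centralized exchange.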
+ ASSERT_EXPLAIN_PROPS_V2(
+ "Properties [cost: 0.0066022, localCost: 0, adjustedCE: 20]\n"
+ "| | Logical:\n"
+ "| | cardinalityEstimate: \n"
+ "| | ce: 20\n"
+ "| | projections: \n"
+ "| | root\n"
+ "| | collectionAvailability: \n"
+ "| | c1\n"
+ "| | distributionAvailability: \n"
+ "| | distribution: \n"
+ "| | type: Centralized\n"
+ "| | distribution: \n"
+ "| | type: UnknownPartitioning\n"
+ "| Physical:\n"
+ "| distribution: \n"
+ "| type: Centralized\n"
+ "Root []\n"
+ "| | projections: \n"
+ "| | root\n"
+ "| RefBlock: \n"
+ "| Variable [root]\n"
+ "Properties [cost: 0.0066022, localCost: 1e-06, adjustedCE: 30]\n"
+ "| | Logical:\n"
+ "| | cardinalityEstimate: \n"
+ "| | ce: 1000\n"
+ "| | projections: \n"
+ "| | root\n"
+ "| | indexingAvailability: \n"
+ "| | [groupId: 0, scanProjection: root, scanDefName: c1, possiblyEqPredsOnly]\n"
+ "| | collectionAvailability: \n"
+ "| | c1\n"
+ "| | distributionAvailability: \n"
+ "| | distribution: \n"
+ "| | type: UnknownPartitioning\n"
+ "| Physical:\n"
+ "| limitSkip:\n"
+ "| limit: 20\n"
+ "| skip: 10\n"
+ "| projections: \n"
+ "| root\n"
+ "| distribution: \n"
+ "| type: Centralized\n"
+ "| indexingRequirement: \n"
+ "| Complete, dedupRID\n"
+ "LimitSkip []\n"
+ "| limitSkip:\n"
+ "| limit: 20\n"
+ "| skip: 10\n"
+ "Properties [cost: 0.0066012, localCost: 0.003001, adjustedCE: 30]\n"
+ "| | Logical:\n"
+ "| | cardinalityEstimate: \n"
+ "| | ce: 1000\n"
+ "| | projections: \n"
+ "| | root\n"
+ "| | indexingAvailability: \n"
+ "| | [groupId: 0, scanProjection: root, scanDefName: c1, possiblyEqPredsOnly]\n"
+ "| | collectionAvailability: \n"
+ "| | c1\n"
+ "| | distributionAvailability: \n"
+ "| | distribution: \n"
+ "| | type: UnknownPartitioning\n"
+ "| Physical:\n"
+ "| projections: \n"
+ "| root\n"
+ "| distribution: \n"
+ "| type: Centralized\n"
+ "| indexingRequirement: \n"
+ "| Complete, dedupRID\n"
+ "| limitEstimate: 30\n"
+ "Exchange []\n"
+ "| | distribution: \n"
+ "| | type: Centralized\n"
+ "| RefBlock: \n"
+ "Properties [cost: 0.0036002, localCost: 0.0036002, adjustedCE: 30]\n"
+ "| | Logical:\n"
+ "| | cardinalityEstimate: \n"
+ "| | ce: 1000\n"
+ "| | projections: \n"
+ "| | root\n"
+ "| | indexingAvailability: \n"
+ "| | [groupId: 0, scanProjection: root, scanDefName: c1, possiblyEqPredsOnly]\n"
+ "| | collectionAvailability: \n"
+ "| | c1\n"
+ "| | distributionAvailability: \n"
+ "| | distribution: \n"
+ "| | type: UnknownPartitioning\n"
+ "| Physical:\n"
+ "| projections: \n"
+ "| root\n"
+ "| distribution: \n"
+ "| type: UnknownPartitioning, disableExchanges\n"
+ "| indexingRequirement: \n"
+ "| Complete, dedupRID\n"
+ "| limitEstimate: 30\n"
+ "PhysicalScan [{'<root>': root}, c1, parallel]\n"
+ " BindBlock:\n"
+ " [root]\n"
+ " Source []\n",
+ phaseManager);
+}
+
+TEST(PhysRewriter, CollationLimit) {
+ using namespace properties;
+ PrefixId prefixId;
+
+ ABT scanNode = make<ScanNode>("root", "c1");
+
+ ABT evalNode = make<EvaluationNode>(
+ "pa",
+ make<EvalPath>(make<PathGet>("a", make<PathIdentity>()), make<Variable>("root")),
+ std::move(scanNode));
+
+ ABT collationNode = make<CollationNode>(CollationRequirement({{"pa", CollationOp::Ascending}}),
+ std::move(evalNode));
+ ABT limitSkipNode = make<LimitSkipNode>(LimitSkipRequirement{20, 0}, std::move(collationNode));
+
+ ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}},
+ std::move(limitSkipNode));
+
+ OptPhaseManager phaseManager(
+ {OptPhaseManager::OptPhase::MemoSubstitutionPhase,
+ OptPhaseManager::OptPhase::MemoExplorationPhase,
+ OptPhaseManager::OptPhase::MemoImplementationPhase},
+ prefixId,
+ {{{"c1", {{}, {}}}}},
+ {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests});
+
+ ABT optimized = rootNode;
+ ASSERT_TRUE(phaseManager.optimize(optimized));
+ ASSERT_BETWEEN(9, 11, phaseManager.getMemo().getStats()._physPlanExplorationCount);
+
+ // We have a collation node with limit-skip physical properties. It will be lowered to a
+ // sort node with limit.
+ ASSERT_EXPLAIN_PROPS_V2(
+ "Properties [cost: 4.92193, localCost: 0, adjustedCE: 20]\n"
+ "| | Logical:\n"
+ "| | cardinalityEstimate: \n"
+ "| | ce: 20\n"
+ "| | projections: \n"
+ "| | pa\n"
+ "| | root\n"
+ "| | collectionAvailability: \n"
+ "| | c1\n"
+ "| | distributionAvailability: \n"
+ "| | distribution: \n"
+ "| | type: Centralized\n"
+ "| Physical:\n"
+ "| distribution: \n"
+ "| type: Centralized\n"
+ "Root []\n"
+ "| | projections: \n"
+ "| | root\n"
+ "| RefBlock: \n"
+ "| Variable [root]\n"
+ "Properties [cost: 4.92193, localCost: 4.32193, adjustedCE: 20]\n"
+ "| | Logical:\n"
+ "| | cardinalityEstimate: \n"
+ "| | ce: 1000\n"
+ "| | requirementCEs: \n"
+ "| | refProjection: root, path: 'PathGet [a] PathIdentity []', ce: "
+ "1000\n"
+ "| | projections: \n"
+ "| | pa\n"
+ "| | root\n"
+ "| | indexingAvailability: \n"
+ "| | [groupId: 0, scanProjection: root, scanDefName: c1]\n"
+ "| | collectionAvailability: \n"
+ "| | c1\n"
+ "| | distributionAvailability: \n"
+ "| | distribution: \n"
+ "| | type: Centralized\n"
+ "| Physical:\n"
+ "| collation: \n"
+ "| pa: Ascending\n"
+ "| limitSkip:\n"
+ "| limit: 20\n"
+ "| skip: 0\n"
+ "| projections: \n"
+ "| root\n"
+ "| distribution: \n"
+ "| type: Centralized\n"
+ "| indexingRequirement: \n"
+ "| Complete, dedupRID\n"
+ "Collation []\n"
+ "| | collation: \n"
+ "| | pa: Ascending\n"
+ "| RefBlock: \n"
+ "| Variable [pa]\n"
+ "Properties [cost: 0.600001, localCost: 0.600001, adjustedCE: 1000]\n"
+ "| | Logical:\n"
+ "| | cardinalityEstimate: \n"
+ "| | ce: 1000\n"
+ "| | requirementCEs: \n"
+ "| | refProjection: root, path: 'PathGet [a] PathIdentity []', ce: "
+ "1000\n"
+ "| | projections: \n"
+ "| | pa\n"
+ "| | root\n"
+ "| | indexingAvailability: \n"
+ "| | [groupId: 0, scanProjection: root, scanDefName: c1]\n"
+ "| | collectionAvailability: \n"
+ "| | c1\n"
+ "| | distributionAvailability: \n"
+ "| | distribution: \n"
+ "| | type: Centralized\n"
+ "| Physical:\n"
+ "| projections: \n"
+ "| root\n"
+ "| pa\n"
+ "| distribution: \n"
+ "| type: Centralized\n"
+ "| indexingRequirement: \n"
+ "| Complete, dedupRID\n"
+ "PhysicalScan [{'<root>': root, 'a': pa}, c1]\n"
+ " BindBlock:\n"
+ " [pa]\n"
+ " Source []\n"
+ " [root]\n"
+ " Source []\n",
+ phaseManager);
+}
+
+TEST(PhysRewriter, PartialIndex1) {
+ using namespace properties;
+ PrefixId prefixId;
+
+ ABT scanNode = make<ScanNode>("root", "c1");
+
+ ABT filterANode = make<FilterNode>(
+ make<EvalFilter>(
+ make<PathGet>(
+ "a", make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(3)))),
+ make<Variable>("root")),
+ std::move(scanNode));
+ ABT filterBNode = make<FilterNode>(
+ make<EvalFilter>(
+ make<PathGet>(
+ "b", make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(2)))),
+ make<Variable>("root")),
+ std::move(filterANode));
+
+ ABT rootNode =
+ make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(filterBNode));
+
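+    // Query: {a: 3, b: 2}. The index on "a" is partial with filter {b: 2}, which the query
+    // satisfies, so the index is applicable.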
+    // TODO: Test cases where the partial filter bound is a range which subsumes the query
+    // requirement (e.g. a half-open interval).
+ auto conversionResult = convertExprToPartialSchemaReq(make<EvalFilter>(
+ make<PathGet>("b",
+ make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(2)))),
+ make<Variable>("root")));
+ ASSERT_TRUE(conversionResult._success);
+ ASSERT_FALSE(conversionResult._hasEmptyInterval);
+
+ OptPhaseManager phaseManager(
+ {OptPhaseManager::OptPhase::MemoSubstitutionPhase,
+ OptPhaseManager::OptPhase::MemoExplorationPhase,
+ OptPhaseManager::OptPhase::MemoImplementationPhase},
+ prefixId,
+ {{{"c1",
+ ScanDefinition{{},
+ {{"index1",
+ IndexDefinition{{{makeIndexPath("a"), CollationOp::Ascending}},
+ true /*isMultiKey*/,
+ {DistributionType::Centralized},
+ std::move(conversionResult._reqMap)}}}}}}},
+ {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests});
+
+ ABT optimized = rootNode;
+ ASSERT_TRUE(phaseManager.optimize(optimized));
+ ASSERT_EQ(4, phaseManager.getMemo().getStats()._physPlanExplorationCount);
+
+    // The partial schema requirement is not on an index field; we fetch and filter that field
+    // on the Seek side of the plan.
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | root\n"
+ "| RefBlock: \n"
+ "| Variable [root]\n"
+ "BinaryJoin [joinType: Inner, {rid_0}]\n"
+ "| | Const [true]\n"
+ "| Filter []\n"
+ "| | EvalFilter []\n"
+ "| | | Variable [evalTemp_2]\n"
+ "| | PathTraverse []\n"
+ "| | PathCompare [Eq]\n"
+ "| | Const [2]\n"
+ "| LimitSkip []\n"
+ "| | limitSkip:\n"
+ "| | limit: 1\n"
+ "| | skip: 0\n"
+ "| Seek [ridProjection: rid_0, {'<root>': root, 'b': evalTemp_2}, c1]\n"
+ "| | BindBlock:\n"
+ "| | [evalTemp_2]\n"
+ "| | Source []\n"
+ "| | [root]\n"
+ "| | Source []\n"
+ "| RefBlock: \n"
+ "| Variable [rid_0]\n"
+ "IndexScan [{'<rid>': rid_0}, scanDefName: c1, indexDefName: index1, interval: {[Const "
+ "[3], Const [3]]}]\n"
+ " BindBlock:\n"
+ " [rid_0]\n"
+ " Source []\n",
+ optimized);
+}
+
+TEST(PhysRewriter, PartialIndex2) {
+ using namespace properties;
+ PrefixId prefixId;
+
+ ABT scanNode = make<ScanNode>("root", "c1");
+
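+    // Query: {a: 3}. The partial index filter is also {a: 3}: it matches the query predicate
+    // exactly, so no residual filter remains.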
+ ABT filterANode = make<FilterNode>(
+ make<EvalFilter>(
+ make<PathGet>(
+ "a", make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(3)))),
+ make<Variable>("root")),
+ std::move(scanNode));
+
+ ABT rootNode =
+ make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(filterANode));
+
+ auto conversionResult = convertExprToPartialSchemaReq(make<EvalFilter>(
+ make<PathGet>("a",
+ make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(3)))),
+ make<Variable>("root")));
+ ASSERT_TRUE(conversionResult._success);
+ ASSERT_FALSE(conversionResult._hasEmptyInterval);
+
+ OptPhaseManager phaseManager(
+ {OptPhaseManager::OptPhase::MemoSubstitutionPhase,
+ OptPhaseManager::OptPhase::MemoExplorationPhase,
+ OptPhaseManager::OptPhase::MemoImplementationPhase},
+ prefixId,
+ {{{"c1",
+ ScanDefinition{{},
+ {{"index1",
+ IndexDefinition{{{makeIndexPath("a"), CollationOp::Ascending}},
+ true /*isMultiKey*/,
+ {DistributionType::Centralized},
+ std::move(conversionResult._reqMap)}}}}}}},
+ {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests});
+
+ ABT optimized = rootNode;
+ ASSERT_TRUE(phaseManager.optimize(optimized));
+ ASSERT_EQ(4, phaseManager.getMemo().getStats()._physPlanExplorationCount);
+
+ // Partial schema requirement on an index field.
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | root\n"
+ "| RefBlock: \n"
+ "| Variable [root]\n"
+ "BinaryJoin [joinType: Inner, {rid_0}]\n"
+ "| | Const [true]\n"
+ "| LimitSkip []\n"
+ "| | limitSkip:\n"
+ "| | limit: 1\n"
+ "| | skip: 0\n"
+ "| Seek [ridProjection: rid_0, {'<root>': root}, c1]\n"
+ "| | BindBlock:\n"
+ "| | [root]\n"
+ "| | Source []\n"
+ "| RefBlock: \n"
+ "| Variable [rid_0]\n"
+ "IndexScan [{'<rid>': rid_0}, scanDefName: c1, indexDefName: index1, interval: {[Const "
+ "[3], Const [3]]}]\n"
+ " BindBlock:\n"
+ " [rid_0]\n"
+ " Source []\n",
+ optimized);
+}
+
+TEST(PhysRewriter, PartialIndexReject) {
+ using namespace properties;
+ PrefixId prefixId;
+
+ ABT scanNode = make<ScanNode>("root", "c1");
+
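+    // Query: {a: 3, b: 2}, but the partial index filter is {b: 4}; the query does not imply
+    // it, so the index must be rejected.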
+ ABT filterANode = make<FilterNode>(
+ make<EvalFilter>(
+ make<PathGet>(
+ "a", make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(3)))),
+ make<Variable>("root")),
+ std::move(scanNode));
+ ABT filterBNode = make<FilterNode>(
+ make<EvalFilter>(
+ make<PathGet>(
+ "b", make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(2)))),
+ make<Variable>("root")),
+ std::move(filterANode));
+
+ ABT rootNode =
+ make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(filterBNode));
+
+ auto conversionResult = convertExprToPartialSchemaReq(make<EvalFilter>(
+ make<PathGet>("b",
+ make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(4)))),
+ make<Variable>("root")));
+ ASSERT_TRUE(conversionResult._success);
+ ASSERT_FALSE(conversionResult._hasEmptyInterval);
+
+ OptPhaseManager phaseManager(
+ {OptPhaseManager::OptPhase::MemoSubstitutionPhase,
+ OptPhaseManager::OptPhase::MemoExplorationPhase,
+ OptPhaseManager::OptPhase::MemoImplementationPhase},
+ prefixId,
+ {{{"c1",
+ ScanDefinition{{},
+ {{"index1",
+ IndexDefinition{{{makeIndexPath("a"), CollationOp::Ascending}},
+ true /*isMultiKey*/,
+ {DistributionType::Centralized},
+ std::move(conversionResult._reqMap)}}}}}}},
+ {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests});
+
+ ABT optimized = rootNode;
+ ASSERT_TRUE(phaseManager.optimize(optimized));
+ ASSERT_EQ(3, phaseManager.getMemo().getStats()._physPlanExplorationCount);
+
+ // Incompatible partial filter. Use scan.
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | root\n"
+ "| RefBlock: \n"
+ "| Variable [root]\n"
+ "Filter []\n"
+ "| EvalFilter []\n"
+ "| | Variable [evalTemp_1]\n"
+ "| PathTraverse []\n"
+ "| PathCompare [Eq]\n"
+ "| Const [2]\n"
+ "Filter []\n"
+ "| EvalFilter []\n"
+ "| | Variable [evalTemp_0]\n"
+ "| PathTraverse []\n"
+ "| PathCompare [Eq]\n"
+ "| Const [3]\n"
+ "PhysicalScan [{'<root>': root, 'a': evalTemp_0, 'b': evalTemp_1}, c1]\n"
+ " BindBlock:\n"
+ " [evalTemp_0]\n"
+ " Source []\n"
+ " [evalTemp_1]\n"
+ " Source []\n"
+ " [root]\n"
+ " Source []\n",
+ optimized);
+}
+
+TEST(PhysRewriter, RequireRID) {
+ using namespace properties;
+ PrefixId prefixId;
+
+ ABT scanNode = make<ScanNode>("root", "c1");
+
+ ABT filterNode = make<FilterNode>(
+ make<EvalFilter>(
+ make<PathGet>(
+ "a", make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(3)))),
+ make<Variable>("root")),
+ std::move(scanNode));
+
+ ABT rootNode =
+ make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(filterNode));
+
+ OptPhaseManager phaseManager(
+ {OptPhaseManager::OptPhase::MemoSubstitutionPhase,
+ OptPhaseManager::OptPhase::MemoExplorationPhase,
+ OptPhaseManager::OptPhase::MemoImplementationPhase},
+ prefixId,
+ true /*requireRID*/,
+ {{{"c1", ScanDefinition{{}, {}}}}},
+ std::make_unique<HeuristicCE>(),
+ std::make_unique<DefaultCosting>(),
+ {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests});
+
+ ABT optimized = rootNode;
+ ASSERT_TRUE(phaseManager.optimize(optimized));
+ ASSERT_EQ(2, phaseManager.getMemo().getStats()._physPlanExplorationCount);
+
+ // Make sure the Scan node returns rid.
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | root\n"
+ "| RefBlock: \n"
+ "| Variable [root]\n"
+ "Filter []\n"
+ "| EvalFilter []\n"
+ "| | Variable [evalTemp_0]\n"
+ "| PathTraverse []\n"
+ "| PathCompare [Eq]\n"
+ "| Const [3]\n"
+ "PhysicalScan [{'<rid>': rid_0, '<root>': root, 'a': evalTemp_0}, c1]\n"
+ " BindBlock:\n"
+ " [evalTemp_0]\n"
+ " Source []\n"
+ " [rid_0]\n"
+ " Source []\n"
+ " [root]\n"
+ " Source []\n",
+ optimized);
+}
+
+TEST(PhysRewriter, UnionRewrite) {
+ using namespace properties;
+
+ ABT scanNode1 = make<ScanNode>("ptest1", "test1");
+ ABT scanNode2 = make<ScanNode>("ptest2", "test2");
+
+ // Each branch produces two projections, pUnion1 and pUnion2.
+ ABT evalNode1 = make<EvaluationNode>(
+ "pUnion1",
+ make<EvalPath>(make<PathGet>("a", make<PathIdentity>()), make<Variable>("ptest1")),
+ std::move(scanNode1));
+ ABT evalNode2 = make<EvaluationNode>(
+ "pUnion2",
+ make<EvalPath>(make<PathGet>("b", make<PathIdentity>()), make<Variable>("ptest1")),
+ std::move(evalNode1));
+
+ ABT evalNode3 = make<EvaluationNode>(
+ "pUnion1",
+ make<EvalPath>(make<PathGet>("a", make<PathIdentity>()), make<Variable>("ptest2")),
+ std::move(scanNode2));
+ ABT evalNode4 = make<EvaluationNode>(
+ "pUnion2",
+ make<EvalPath>(make<PathGet>("b", make<PathIdentity>()), make<Variable>("ptest2")),
+ std::move(evalNode3));
+
+ ABT unionNode =
+ make<UnionNode>(ProjectionNameVector{"pUnion1", "pUnion2"}, makeSeq(evalNode2, evalNode4));
+
+ ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"pUnion1"}},
+ std::move(unionNode));
+
+ PrefixId prefixId;
+ OptPhaseManager phaseManager(
+ {OptPhaseManager::OptPhase::MemoSubstitutionPhase,
+ OptPhaseManager::OptPhase::MemoExplorationPhase,
+ OptPhaseManager::OptPhase::MemoImplementationPhase},
+ prefixId,
+ {{{"test1", {{}, {}}}, {"test2", {{}, {}}}}},
+ {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests});
+
+ ABT optimized = std::move(rootNode);
+ ASSERT_TRUE(phaseManager.optimize(optimized));
+ ASSERT_EQ(4, phaseManager.getMemo().getStats()._physPlanExplorationCount);
+
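+    // Only pUnion1 is required above the Union, so the pUnion2 projections are pruned from
+    // both branches.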
+ ASSERT_EXPLAIN_V2(
+ "Root []\n"
+ "| | projections: \n"
+ "| | pUnion1\n"
+ "| RefBlock: \n"
+ "| Variable [pUnion1]\n"
+ "Union []\n"
+ "| | BindBlock:\n"
+ "| | [pUnion1]\n"
+ "| | Source []\n"
+ "| PhysicalScan [{'a': pUnion1}, test2]\n"
+ "| BindBlock:\n"
+ "| [pUnion1]\n"
+ "| Source []\n"
+ "PhysicalScan [{'a': pUnion1}, test1]\n"
+ " BindBlock:\n"
+ " [pUnion1]\n"
+ " Source []\n",
+ optimized);
+}
+
+} // namespace
+} // namespace mongo::optimizer
diff --git a/src/mongo/db/query/optimizer/props.cpp b/src/mongo/db/query/optimizer/props.cpp
new file mode 100644
index 00000000000..c03825c4101
--- /dev/null
+++ b/src/mongo/db/query/optimizer/props.cpp
@@ -0,0 +1,373 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/db/query/optimizer/props.h"
+#include "mongo/db/query/optimizer/utils/utils.h"
+#include "mongo/util/assert_util.h"
+
+namespace mongo::optimizer::properties {
+
+CollationRequirement::CollationRequirement(ProjectionCollationSpec spec) : _spec(std::move(spec)) {
+ ProjectionNameSet projections;
+ for (const auto& entry : _spec) {
+ uassert(6624021, "Repeated projection name", projections.insert(entry.first).second);
+ }
+}
+
+CollationRequirement CollationRequirement::Empty = CollationRequirement();
+
+bool CollationRequirement::operator==(const CollationRequirement& other) const {
+ return _spec == other._spec;
+}
+
+const ProjectionCollationSpec& CollationRequirement::getCollationSpec() const {
+ return _spec;
+}
+
+ProjectionCollationSpec& CollationRequirement::getCollationSpec() {
+ return _spec;
+}
+
+bool CollationRequirement::hasClusteredOp() const {
+ for (const auto& [projName, op] : _spec) {
+ if (op == CollationOp::Clustered) {
+ return true;
+ }
+ }
+ return false;
+}
+
+ProjectionNameSet CollationRequirement::getAffectedProjectionNames() const {
+ ProjectionNameSet result;
+ for (const auto& entry : _spec) {
+ result.insert(entry.first);
+ }
+ return result;
+}
+
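+// A negative limit is normalized to kMaxVal, which encodes "no limit" (see hasLimit()).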
+LimitSkipRequirement::LimitSkipRequirement(const int64_t limit, const int64_t skip)
+ : _limit((limit < 0) ? kMaxVal : limit), _skip(skip) {}
+
+bool LimitSkipRequirement::operator==(const LimitSkipRequirement& other) const {
+ return _skip == other._skip && _limit == other._limit;
+}
+
+int64_t LimitSkipRequirement::getLimit() const {
+ return _limit;
+}
+
+int64_t LimitSkipRequirement::getSkip() const {
+ return _skip;
+}
+
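+// Number of documents the input must produce to satisfy both skip and limit: skip + limit,
+// or kMaxVal if there is no limit.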
+int64_t LimitSkipRequirement::getAbsoluteLimit() const {
+ return hasLimit() ? (_skip + _limit) : kMaxVal;
+}
+
+ProjectionNameSet LimitSkipRequirement::getAffectedProjectionNames() const {
+ return {};
+}
+
+bool LimitSkipRequirement::hasLimit() const {
+ return _limit != kMaxVal;
+}
+
+ProjectionRequirement::ProjectionRequirement(ProjectionNameOrderPreservingSet projections)
+ : _projections(std::move(projections)) {}
+
+bool ProjectionRequirement::operator==(const ProjectionRequirement& other) const {
+ return _projections.isEqualIgnoreOrder(other.getProjections());
+}
+
+ProjectionNameSet ProjectionRequirement::getAffectedProjectionNames() const {
+ ProjectionNameSet result;
+ for (const ProjectionName& projection : _projections.getVector()) {
+ result.insert(projection);
+ }
+ return result;
+}
+
+const ProjectionNameOrderPreservingSet& ProjectionRequirement::getProjections() const {
+ return _projections;
+}
+
+ProjectionNameOrderPreservingSet& ProjectionRequirement::getProjections() {
+ return _projections;
+}
+
+DistributionAndProjections::DistributionAndProjections(DistributionType type)
+ : DistributionAndProjections(type, {}) {}
+
+DistributionAndProjections::DistributionAndProjections(DistributionType type,
+ ProjectionNameVector projectionNames)
+ : _type(type), _projectionNames(std::move(projectionNames)) {
+ uassert(6624096,
+ "Must have projection names when distributed under hash or range partitioning",
+ (_type != DistributionType::HashPartitioning &&
+ _type != DistributionType::RangePartitioning) ||
+ !_projectionNames.empty());
+}
+
+bool DistributionAndProjections::operator==(const DistributionAndProjections& other) const {
+ return _type == other._type && _projectionNames == other._projectionNames;
+}
+
+DistributionRequirement::DistributionRequirement(
+ DistributionAndProjections distributionAndProjections)
+ : _distributionAndProjections(std::move(distributionAndProjections)),
+ _disableExchanges(false) {}
+
+bool DistributionRequirement::operator==(const DistributionRequirement& other) const {
+ return _distributionAndProjections == other._distributionAndProjections &&
+ _disableExchanges == other._disableExchanges;
+}
+
+const DistributionAndProjections& DistributionRequirement::getDistributionAndProjections() const {
+ return _distributionAndProjections;
+}
+
+DistributionAndProjections& DistributionRequirement::getDistributionAndProjections() {
+ return _distributionAndProjections;
+}
+
+ProjectionNameSet DistributionRequirement::getAffectedProjectionNames() const {
+ ProjectionNameSet result;
+ for (const ProjectionName& projectionName : _distributionAndProjections._projectionNames) {
+ result.insert(projectionName);
+ }
+ return result;
+}
+
+bool DistributionRequirement::getDisableExchanges() const {
+ return _disableExchanges;
+}
+
+void DistributionRequirement::setDisableExchanges(const bool disableExchanges) {
+ _disableExchanges = disableExchanges;
+}
+
+IndexingRequirement::IndexingRequirement()
+ : IndexingRequirement(IndexReqTarget::Complete, "", true /*dedupRID*/, {}) {}
+
+IndexingRequirement::IndexingRequirement(IndexReqTarget indexReqTarget,
+ bool needsRID,
+ bool dedupRID,
+ GroupIdType satisfiedPartialIndexesGroupId)
+ : _indexReqTarget(indexReqTarget),
+ _needsRID(needsRID),
+ _dedupRID(dedupRID),
+ _satisfiedPartialIndexesGroupId(std::move(satisfiedPartialIndexesGroupId)) {
+ uassert(6624097,
+ "Avoiding dedup is only allowed for Index target",
+ _dedupRID || _indexReqTarget == IndexReqTarget::Index);
+}
+
+bool IndexingRequirement::operator==(const IndexingRequirement& other) const {
+ return _indexReqTarget == other._indexReqTarget && _needsRID == other._needsRID &&
+ _dedupRID == other._dedupRID &&
+ _satisfiedPartialIndexesGroupId == other._satisfiedPartialIndexesGroupId;
+}
+
+ProjectionNameSet IndexingRequirement::getAffectedProjectionNames() const {
+    // Specifically does not include the RID projection (even if one is required).
+ return {};
+}
+
+IndexReqTarget IndexingRequirement::getIndexReqTarget() const {
+ return _indexReqTarget;
+}
+
+bool IndexingRequirement::getNeedsRID() const {
+ return _needsRID;
+}
+
+bool IndexingRequirement::getDedupRID() const {
+ return _dedupRID;
+}
+
+void IndexingRequirement::setDedupRID(const bool value) {
+ _dedupRID = value;
+}
+
+const GroupIdType IndexingRequirement::getSatisfiedPartialIndexesGroupId() const {
+ return _satisfiedPartialIndexesGroupId;
+}
+
+RepetitionEstimate::RepetitionEstimate(const CEType estimate) : _estimate(estimate) {}
+
+bool RepetitionEstimate::operator==(const RepetitionEstimate& other) const {
+ return _estimate == other._estimate;
+}
+
+ProjectionNameSet RepetitionEstimate::getAffectedProjectionNames() const {
+ return {};
+}
+
+CEType RepetitionEstimate::getEstimate() const {
+ return _estimate;
+}
+
+LimitEstimate::LimitEstimate(const CEType estimate) : _estimate(estimate) {}
+
+bool LimitEstimate::operator==(const LimitEstimate& other) const {
+ return _estimate == other._estimate;
+}
+
+ProjectionNameSet LimitEstimate::getAffectedProjectionNames() const {
+ return {};
+}
+
+bool LimitEstimate::hasLimit() const {
+ return _estimate >= 0.0;
+}
+
+CEType LimitEstimate::getEstimate() const {
+ return _estimate;
+}
+
+ProjectionAvailability::ProjectionAvailability(ProjectionNameSet projections)
+ : _projections(std::move(projections)) {}
+
+bool ProjectionAvailability::operator==(const ProjectionAvailability& other) const {
+ return _projections == other._projections;
+}
+
+const ProjectionNameSet& ProjectionAvailability::getProjections() const {
+ return _projections;
+}
+
+CardinalityEstimate::CardinalityEstimate(const CEType estimate)
+ : _estimate(estimate), _partialSchemaKeyCEMap() {}
+
+bool CardinalityEstimate::operator==(const CardinalityEstimate& other) const {
+ return _estimate == other._estimate && _partialSchemaKeyCEMap == other._partialSchemaKeyCEMap;
+}
+
+CEType CardinalityEstimate::getEstimate() const {
+ return _estimate;
+}
+
+CEType& CardinalityEstimate::getEstimate() {
+ return _estimate;
+}
+
+const PartialSchemaKeyCE& CardinalityEstimate::getPartialSchemaKeyCEMap() const {
+ return _partialSchemaKeyCEMap;
+}
+
+PartialSchemaKeyCE& CardinalityEstimate::getPartialSchemaKeyCEMap() {
+ return _partialSchemaKeyCEMap;
+}
+
+IndexingAvailability::IndexingAvailability(GroupIdType scanGroupId,
+ ProjectionName scanProjection,
+ std::string scanDefName,
+ const bool possiblyEqPredsOnly,
+ opt::unordered_set<std::string> satisfiedPartialIndexes)
+ : _scanGroupId(scanGroupId),
+ _scanProjection(std::move(scanProjection)),
+ _scanDefName(std::move(scanDefName)),
+ _possiblyEqPredsOnly(possiblyEqPredsOnly),
+ _satisfiedPartialIndexes(std::move(satisfiedPartialIndexes)) {}
+
+bool IndexingAvailability::operator==(const IndexingAvailability& other) const {
+ return _scanGroupId == other._scanGroupId && _scanProjection == other._scanProjection &&
+ _scanDefName == other._scanDefName && _possiblyEqPredsOnly == other._possiblyEqPredsOnly &&
+ _satisfiedPartialIndexes == other._satisfiedPartialIndexes;
+}
+
+GroupIdType IndexingAvailability::getScanGroupId() const {
+ return _scanGroupId;
+}
+
+const ProjectionName& IndexingAvailability::getScanProjection() const {
+ return _scanProjection;
+}
+
+const std::string& IndexingAvailability::getScanDefName() const {
+ return _scanDefName;
+}
+
+const opt::unordered_set<std::string>& IndexingAvailability::getSatisfiedPartialIndexes() const {
+ return _satisfiedPartialIndexes;
+}
+
+opt::unordered_set<std::string>& IndexingAvailability::getSatisfiedPartialIndexes() {
+ return _satisfiedPartialIndexes;
+}
+
+bool IndexingAvailability::getPossiblyEqPredsOnly() const {
+ return _possiblyEqPredsOnly;
+}
+
+void IndexingAvailability::setPossiblyEqPredsOnly(const bool value) {
+ _possiblyEqPredsOnly = value;
+}
+
+CollectionAvailability::CollectionAvailability(opt::unordered_set<std::string> scanDefSet)
+ : _scanDefSet(std::move(scanDefSet)) {}
+
+bool CollectionAvailability::operator==(const CollectionAvailability& other) const {
+ return _scanDefSet == other._scanDefSet;
+}
+
+const opt::unordered_set<std::string>& CollectionAvailability::getScanDefSet() const {
+ return _scanDefSet;
+}
+
+opt::unordered_set<std::string>& CollectionAvailability::getScanDefSet() {
+ return _scanDefSet;
+}
+
+size_t DistributionHash::operator()(
+ const DistributionAndProjections& distributionAndProjections) const {
+ size_t result = 0;
+ updateHash(result, std::hash<DistributionType>()(distributionAndProjections._type));
+ for (const ProjectionName& projectionName : distributionAndProjections._projectionNames) {
+ updateHash(result, std::hash<ProjectionName>()(projectionName));
+ }
+ return result;
+}
+
+DistributionAvailability::DistributionAvailability(DistributionSet distributionSet)
+ : _distributionSet(std::move(distributionSet)) {}
+
+bool DistributionAvailability::operator==(const DistributionAvailability& other) const {
+ return _distributionSet == other._distributionSet;
+}
+
+const DistributionSet& DistributionAvailability::getDistributionSet() const {
+ return _distributionSet;
+}
+
+DistributionSet& DistributionAvailability::getDistributionSet() {
+ return _distributionSet;
+}
+
+} // namespace mongo::optimizer::properties
diff --git a/src/mongo/db/query/optimizer/props.h b/src/mongo/db/query/optimizer/props.h
new file mode 100644
index 00000000000..63c49f51774
--- /dev/null
+++ b/src/mongo/db/query/optimizer/props.h
@@ -0,0 +1,481 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include <map>
+#include <string>
+#include <vector>
+
+#include "mongo/db/query/optimizer/algebra/operator.h"
+#include "mongo/db/query/optimizer/algebra/polyvalue.h"
+#include "mongo/db/query/optimizer/defs.h"
+#include "mongo/db/query/optimizer/metadata.h"
+#include "mongo/util/assert_util.h"
+
+namespace mongo::optimizer::properties {
+
+/**
+ * Tag for logical property types.
+ */
+class LogicalPropertyTag {};
+
+/**
+ * Tag for physical property types.
+ */
+class PhysPropertyTag {};
+
+/**
+ * Logical properties.
+ */
+class CardinalityEstimate;
+
+class ProjectionAvailability;
+class IndexingAvailability;
+class CollectionAvailability;
+class DistributionAvailability;
+
+/**
+ * Physical properties.
+ */
+class CollationRequirement;
+class LimitSkipRequirement;
+class ProjectionRequirement;
+class DistributionRequirement;
+class IndexingRequirement;
+class RepetitionEstimate;
+class LimitEstimate;
+
+using LogicalProperty = algebra::PolyValue<CardinalityEstimate,
+ ProjectionAvailability,
+ IndexingAvailability,
+ CollectionAvailability,
+ DistributionAvailability>;
+
+using PhysProperty = algebra::PolyValue<CollationRequirement,
+ LimitSkipRequirement,
+ ProjectionRequirement,
+ DistributionRequirement,
+ IndexingRequirement,
+ RepetitionEstimate,
+ LimitEstimate>;
+
+using LogicalProps = opt::unordered_map<LogicalProperty::key_type, LogicalProperty>;
+using PhysProps = opt::unordered_map<PhysProperty::key_type, PhysProperty>;
+
+template <typename T, typename... Args>
+inline auto makeProperty(Args&&... args) {
+ if constexpr (std::is_base_of_v<LogicalPropertyTag, T>) {
+ return LogicalProperty::make<T>(std::forward<Args>(args)...);
+ } else if constexpr (std::is_base_of_v<PhysPropertyTag, T>) {
+ return PhysProperty::make<T>(std::forward<Args>(args)...);
+ } else {
+        // Dependent-false condition: fails to compile only if this branch is instantiated.
+        static_assert(!sizeof(T), "Unknown property type");
+ }
+}
+
+template <class P>
+static constexpr auto getPropertyKey() {
+ if constexpr (std::is_base_of_v<LogicalPropertyTag, P>) {
+ return LogicalProperty::template tagOf<P>();
+ } else if constexpr (std::is_base_of_v<PhysPropertyTag, P>) {
+ return PhysProperty::template tagOf<P>();
+ } else {
+        // Dependent-false condition: fails to compile only if this branch is instantiated.
+        static_assert(!sizeof(P), "Unknown property type");
+ }
+}
+
+template <class P, class C>
+bool hasProperty(const C& props) {
+ return props.find(getPropertyKey<P>()) != props.cend();
+}
+
+template <class P, class C>
+P& getProperty(C& props) {
+ if (!hasProperty<P>(props)) {
+ uasserted(6624022, "Property type does not exist.");
+ }
+ return *props.at(getPropertyKey<P>()).template cast<P>();
+}
+
+template <class P, class C>
+const P& getPropertyConst(const C& props) {
+ if (!hasProperty<P>(props)) {
+ uasserted(6624023, "Property type does not exist.");
+ }
+ return *props.at(getPropertyKey<P>()).template cast<P>();
+}
+
+template <class P, class C>
+void removeProperty(C& props) {
+ props.erase(getPropertyKey<P>());
+}
+
+template <class P, class C>
+bool setProperty(C& props, P property) {
+ return props.emplace(getPropertyKey<P>(), makeProperty<P>(std::move(property))).second;
+}
+
+template <class P, class C>
+void setPropertyOverwrite(C& props, P property) {
+ props.insert_or_assign(getPropertyKey<P>(), makeProperty<P>(std::move(property)));
+}
+
+template <class C, typename... Args>
+inline auto makeProps(Args&&... args) {
+ C props;
+    (setProperty(props, std::forward<Args>(args)), ...);
+ return props;
+}
+
+template <typename... Args>
+inline auto makeLogicalProps(Args&&... args) {
+ return makeProps<LogicalProps>(std::forward<Args>(args)...);
+}
+
+template <typename... Args>
+inline auto makePhysProps(Args&&... args) {
+ return makeProps<PhysProps>(std::forward<Args>(args)...);
+}
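+
+/**
+ * Minimal usage sketch for the helpers above (illustrative only; the arguments are placeholders
+ * for requirements built during optimization):
+ *
+ *   // limit 10, skip 0
+ *   PhysProps props = makePhysProps(CollationRequirement{}, LimitSkipRequirement{10, 0});
+ *   if (hasProperty<LimitSkipRequirement>(props)) {
+ *       const auto& limitSkip = getPropertyConst<LimitSkipRequirement>(props);
+ *       invariant(limitSkip.getLimit() == 10);
+ *   }
+ *   removeProperty<CollationRequirement>(props);
+ */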
+
+/**
+ * A physical property which specifies how the collection (or intermediate result) is required to be
+ * collated (sorted).
+ */
+class CollationRequirement final : public PhysPropertyTag {
+public:
+ static CollationRequirement Empty;
+
+ CollationRequirement() = default;
+ CollationRequirement(ProjectionCollationSpec spec);
+
+ bool operator==(const CollationRequirement& other) const;
+
+ const ProjectionCollationSpec& getCollationSpec() const;
+ ProjectionCollationSpec& getCollationSpec();
+
+ bool hasClusteredOp() const;
+
+ ProjectionNameSet getAffectedProjectionNames() const;
+
+private:
+ ProjectionCollationSpec _spec;
+};
+
+/**
+ * A physical property which specifies what portion of the result in terms of window defined by the
+ * limit and skip is to be returned.
+ */
+class LimitSkipRequirement final : public PhysPropertyTag {
+public:
+ static constexpr int64_t kMaxVal = std::numeric_limits<int64_t>::max();
+
+ LimitSkipRequirement(int64_t limit, int64_t skip);
+
+ bool operator==(const LimitSkipRequirement& other) const;
+
+ bool hasLimit() const;
+
+ int64_t getLimit() const;
+ int64_t getSkip() const;
+ int64_t getAbsoluteLimit() const;
+
+ ProjectionNameSet getAffectedProjectionNames() const;
+
+private:
+ // Max number of documents to return. Maximum integer value means unlimited.
+ int64_t _limit;
+    // Number of documents to skip before we start returning results.
+ int64_t _skip;
+};
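+
+// Illustrative: a requirement for documents 11..20 of the result corresponds to a limit of 10
+// with a skip of 10, i.e. LimitSkipRequirement{10, 10}.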
+
+/**
+ * A physical property which specifies required projections to be returned as part of the result.
+ */
+class ProjectionRequirement final : public PhysPropertyTag {
+public:
+ ProjectionRequirement(ProjectionNameOrderPreservingSet projections);
+
+ bool operator==(const ProjectionRequirement& other) const;
+
+ const ProjectionNameOrderPreservingSet& getProjections() const;
+ ProjectionNameOrderPreservingSet& getProjections();
+
+ ProjectionNameSet getAffectedProjectionNames() const;
+
+private:
+ ProjectionNameOrderPreservingSet _projections;
+};
+
+struct DistributionAndProjections {
+ DistributionAndProjections(DistributionType type);
+ DistributionAndProjections(DistributionType type, ProjectionNameVector projectionNames);
+
+ bool operator==(const DistributionAndProjections& other) const;
+
+ const DistributionType _type;
+
+ /**
+ * Defined for hash and range-based partitioning.
+ */
+ ProjectionNameVector _projectionNames;
+};
+
+/**
+ * A physical property which specifies how the result is to be distributed (or partitioned) amongst
+ * the computing partitions/nodes.
+ */
+class DistributionRequirement final : public PhysPropertyTag {
+public:
+ DistributionRequirement(DistributionAndProjections distributionAndProjections);
+
+ bool operator==(const DistributionRequirement& other) const;
+
+ const DistributionAndProjections& getDistributionAndProjections() const;
+ DistributionAndProjections& getDistributionAndProjections();
+
+ ProjectionNameSet getAffectedProjectionNames() const;
+
+ bool getDisableExchanges() const;
+ void setDisableExchanges(bool disableExchanges);
+
+private:
+ DistributionAndProjections _distributionAndProjections;
+
+ // Heuristic used to disable exchanges right after Filter, Eval, and local GroupBy nodes.
+ bool _disableExchanges;
+};
+
+/**
+ * A physical property which describes whether we intend to satisfy sargable predicates using an
+ * index. With indexing requirement "Complete", we are requiring a regular physical scan (both RID
+ * and row). With "Seek" (where we must have a non-empty RID projection name), we are targeting a
+ * physical Seek. With "Index" (with or without an RID projection name), we are targeting a
+ * physical IndexScan; if in this case an RID projection is set, we have either gone for a Seek or
+ * performed RID intersection, while with an empty RID projection we are targeting a covered index
+ * scan.
+ */
+class IndexingRequirement final : public PhysPropertyTag {
+public:
+ IndexingRequirement();
+ IndexingRequirement(IndexReqTarget indexReqTarget,
+ bool needsRID,
+ bool dedupRIDs,
+ GroupIdType satisfiedPartialIndexesGroupId);
+
+ bool operator==(const IndexingRequirement& other) const;
+
+ ProjectionNameSet getAffectedProjectionNames() const;
+
+ IndexReqTarget getIndexReqTarget() const;
+ bool getNeedsRID() const;
+
+ bool getDedupRID() const;
+ void setDedupRID(bool value);
+
+ const GroupIdType getSatisfiedPartialIndexesGroupId() const;
+
+private:
+ const IndexReqTarget _indexReqTarget;
+
+ // Do we need to return an RID projection.
+ const bool _needsRID;
+
+ // If target == Index, specifies if we need to dedup RIDs.
+ // Prior RID intersection removes the need to dedup.
+ bool _dedupRID;
+
+    // Id of a group whose IndexingAvailability can be interrogated to find the partial indexes
+    // whose filters are satisfied considering the whole query.
+ const GroupIdType _satisfiedPartialIndexesGroupId;
+};
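+
+// Illustrative sketch, assuming the IndexReqTarget enumerators named in the class comment above.
+// Requesting a covered index scan: target Index, no RID needed, no RID dedup, with a placeholder
+// group id 0 carrying the satisfied partial indexes:
+//
+//   IndexingRequirement req{IndexReqTarget::Index, false, false, 0};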
+
+/**
+ * A physical property that specifies how many times we expect to execute the current subtree.
+ * Typically generated via a NLJ where it is set on the inner side to reflect the outer side's
+ * cardinality. This property affects costing of stateful physical operators such as sort and hash
+ * groupby.
+ */
+class RepetitionEstimate final : public PhysPropertyTag {
+public:
+ RepetitionEstimate(CEType estimate);
+
+ bool operator==(const RepetitionEstimate& other) const;
+
+ ProjectionNameSet getAffectedProjectionNames() const;
+
+ CEType getEstimate() const;
+
+private:
+ CEType _estimate;
+};
+
+/**
+ * A physical property that specifies that we will consider only an approximate number of
+ * documents. Typically generated after enforcing a LimitSkipRequirement. This property affects
+ * costing of stateful physical operators such as sort and hash groupby.
+ */
+class LimitEstimate final : public PhysPropertyTag {
+public:
+ LimitEstimate(CEType estimate);
+
+ bool operator==(const LimitEstimate& other) const;
+
+ ProjectionNameSet getAffectedProjectionNames() const;
+
+ bool hasLimit() const;
+ CEType getEstimate() const;
+
+private:
+ CEType _estimate;
+};
+
+/**
+ * A logical property which specifies available projections for a given ABT tree.
+ */
+class ProjectionAvailability final : public LogicalPropertyTag {
+public:
+ ProjectionAvailability(ProjectionNameSet projections);
+
+ bool operator==(const ProjectionAvailability& other) const;
+
+ const ProjectionNameSet& getProjections() const;
+
+private:
+ ProjectionNameSet _projections;
+};
+
+/**
+ * A logical property which provides an estimated row count for a given ABT tree.
+ */
+class CardinalityEstimate final : public LogicalPropertyTag {
+public:
+ CardinalityEstimate(CEType estimate);
+
+ bool operator==(const CardinalityEstimate& other) const;
+
+ CEType getEstimate() const;
+ CEType& getEstimate();
+
+ const PartialSchemaKeyCE& getPartialSchemaKeyCEMap() const;
+ PartialSchemaKeyCE& getPartialSchemaKeyCEMap();
+
+private:
+ CEType _estimate;
+
+    // Used for SargableNodes. Provides an additional CE estimate per partial schema key.
+ PartialSchemaKeyCE _partialSchemaKeyCEMap;
+};
+
+/**
+ * A logical property which specifies the ability to use indexes for predicates in the ABT
+ * subtree, and contains the scan projection. The projection and scan definition name are kept
+ * here for convenience: they could also be retrieved via the scan group from the memo.
+ */
+class IndexingAvailability final : public LogicalPropertyTag {
+public:
+ IndexingAvailability(GroupIdType scanGroupId,
+ ProjectionName scanProjection,
+ std::string scanDefName,
+ bool possiblyEqPredsOnly,
+ opt::unordered_set<std::string> satisfiedPartialIndexes);
+
+ bool operator==(const IndexingAvailability& other) const;
+
+ GroupIdType getScanGroupId() const;
+ const ProjectionName& getScanProjection() const;
+ const std::string& getScanDefName() const;
+
+ const opt::unordered_set<std::string>& getSatisfiedPartialIndexes() const;
+ opt::unordered_set<std::string>& getSatisfiedPartialIndexes();
+
+ bool getPossiblyEqPredsOnly() const;
+ void setPossiblyEqPredsOnly(bool value);
+
+private:
+ const GroupIdType _scanGroupId;
+ const ProjectionName _scanProjection;
+ const std::string _scanDefName;
+
+ // Specifies if all predicates in the current group and child group are "possibly" equalities.
+ // This is determined based on SargableNode exclusively containing equality intervals.
+    // The "possibly" part is due to 'Get "a" Id' being equivalent to 'Get "a" Traverse Id' with a
+    // multi-key index.
+ bool _possiblyEqPredsOnly;
+
+    // Set of partial indexes whose filters are satisfied for the current
+    // group.
+ opt::unordered_set<std::string> _satisfiedPartialIndexes;
+};
+
+/**
+ * Logical property which specifies which collections (scanDefs) are available for a particular
+ * group. For example if the group contains a join of two tables, we would have (at least) two
+ * collections in the set.
+ */
+class CollectionAvailability final : public LogicalPropertyTag {
+public:
+ CollectionAvailability(opt::unordered_set<std::string> scanDefSet);
+
+ bool operator==(const CollectionAvailability& other) const;
+
+ const opt::unordered_set<std::string>& getScanDefSet() const;
+ opt::unordered_set<std::string>& getScanDefSet();
+
+private:
+ opt::unordered_set<std::string> _scanDefSet;
+};
+
+struct DistributionHash {
+ size_t operator()(const DistributionAndProjections& distributionAndProjections) const;
+};
+
+using DistributionSet = opt::unordered_set<DistributionAndProjections, DistributionHash>;
+
+/**
+ * Logical property which specifies promising projections and distributions to attempt to enforce
+ * during physical optimization. For example, a group containing a GroupByNode would add hash
+ * partitioning on the group-by projections.
+ */
+class DistributionAvailability final : public LogicalPropertyTag {
+public:
+ DistributionAvailability(DistributionSet distributionSet);
+
+ bool operator==(const DistributionAvailability& other) const;
+
+ const DistributionSet& getDistributionSet() const;
+ DistributionSet& getDistributionSet();
+
+private:
+ DistributionSet _distributionSet;
+};
+
+} // namespace mongo::optimizer::properties
diff --git a/src/mongo/db/query/optimizer/reference_tracker.cpp b/src/mongo/db/query/optimizer/reference_tracker.cpp
new file mode 100644
index 00000000000..c6afaba7024
--- /dev/null
+++ b/src/mongo/db/query/optimizer/reference_tracker.cpp
@@ -0,0 +1,689 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/db/query/optimizer/cascades/memo.h"
+#include "mongo/db/query/optimizer/reference_tracker.h"
+
+namespace mongo::optimizer {
+
+struct CollectedInfo {
+ using VarRefsMap = opt::unordered_map<std::string, opt::unordered_map<const Variable*, bool>>;
+
+ /**
+ * All resolved variables so far
+ */
+ opt::unordered_map<const Variable*, Definition> useMap;
+
+ /**
+ * Definitions available for use in ancestor nodes (projections)
+ */
+ DefinitionsMap defs;
+
+ /**
+ * Free variables (i.e. so far not resolved)
+ */
+ opt::unordered_map<std::string, std::vector<const Variable*>> freeVars;
+
+ /**
+ * Per relational node projections.
+ */
+ opt::unordered_map<const Node*, DefinitionsMap> nodeDefs;
+
+ /**
+     * Support for tracking the last reference to each local variable.
+ */
+ VarRefsMap varLastRefs;
+ opt::unordered_set<const Variable*> lastRefs;
+
+ /**
+ * This is a destructive merge, the 'other' will be siphoned out.
+ */
+ void merge(CollectedInfo&& other) {
+ // Incoming (other) info has some definitions. So let's try to resolve our free variables.
+ if (!other.defs.empty() && !freeVars.empty()) {
+ for (auto&& [name, def] : other.defs) {
+ resolveFreeVars(name, def);
+ }
+ }
+
+        // We have some definitions so let's try to resolve the other's free variables.
+ if (!defs.empty() && !other.freeVars.empty()) {
+ for (auto&& [name, def] : defs) {
+ other.resolveFreeVars(name, def);
+ }
+ }
+
+ useMap.merge(other.useMap);
+
+        // It should be impossible to have duplicate Variable pointers, so everything should be
+        // copied.
+ uassert(6624024, "use map is not empty", other.useMap.empty());
+
+ defs.merge(other.defs);
+
+ // Projection names are globally unique so everything should be copied.
+ uassert(6624025, "duplicate projections", other.defs.empty());
+
+ for (auto&& [name, vars] : other.freeVars) {
+ auto& v = freeVars[name];
+ v.insert(v.end(), vars.begin(), vars.end());
+ }
+ other.freeVars.clear();
+
+ nodeDefs.merge(other.nodeDefs);
+
+        // It should be impossible to have duplicate Node pointers, so everything should be
+        // copied.
+ uassert(6624026, "duplicate nodes", other.nodeDefs.empty());
+
+ // Merge last references.
+ mergeLastRefs(std::move(other.varLastRefs));
+ lastRefs.merge(other.lastRefs);
+ uassert(6624027, "duplicate lastRefs", other.lastRefs.empty());
+ }
+
+ /**
+     * Merges variable references from 'other' and adjusts last references as needed.
+ */
+ void mergeLastRefs(VarRefsMap&& other) {
+ mergeLastRefsImpl(std::move(other), false, true);
+ }
+
+ /**
+ * Merges variable references from 'other' but keeps the last references from 'this'; i.e. it
+ * resets the 'other' side.
+ */
+ void mergeKeepLastRefs(VarRefsMap&& other) {
+ mergeLastRefsImpl(std::move(other), true, false);
+ }
+
+ /**
+ * Merges variable references from 'other' and keeps the last references from both sides.
+ */
+ void unionLastRefs(VarRefsMap&& other) {
+ mergeLastRefsImpl(std::move(other), false, false);
+ }
+
+ void mergeLastRefsImpl(VarRefsMap&& other, bool resetOther, bool resetBoth) {
+ for (auto otherIt = other.begin(), end = other.end(); otherIt != end;) {
+ if (auto localIt = varLastRefs.find(otherIt->first); localIt != varLastRefs.end()) {
+                // This variable is referenced in both sets; when resetOther is set, mark the
+                // other side's references as no longer last.
+ if (resetOther) {
+ for (auto& [k, isLastRef] : otherIt->second) {
+ isLastRef = false;
+ }
+ }
+
+ // Merge the maps.
+ localIt->second.merge(otherIt->second);
+
+ // This variable is referenced in both sets so it will be marked as NOT last if
+ // resetBoth is set.
+ if (resetBoth) {
+ for (auto& [k, isLastRef] : localIt->second) {
+ isLastRef = false;
+ }
+ }
+ other.erase(otherIt++);
+ } else {
+ ++otherIt;
+ }
+ }
+ varLastRefs.merge(other);
+ uassert(6624098, "varLastRefs must be empty", other.empty());
+ }
+
+ /**
+ * Records collected last variable references for a specific variable.
+ */
+ void finalizeLastRefs(const std::string& name) {
+ if (auto it = varLastRefs.find(name); it != varLastRefs.end()) {
+ for (auto& [var, isLastRef] : it->second) {
+ if (isLastRef) {
+ lastRefs.emplace(var);
+ }
+ }
+
+ // After the finalization the map is not needed anymore.
+ varLastRefs.erase(it);
+ }
+ }
+
+ /**
+ * This is a destructive merge, the 'others' will be siphoned out.
+ */
+ void merge(std::vector<CollectedInfo>&& others) {
+ for (auto& other : others) {
+ merge(std::move(other));
+ }
+ }
+
+ /**
+ * A special merge asserting that the 'other' has no defined projections. Expressions do not
+ * project anything, only Nodes do.
+ *
+ * We still have to track free variables though.
+ */
+ void mergeNoDefs(CollectedInfo&& other) {
+ other.assertEmptyDefs();
+ merge(std::move(other));
+ }
+
+ static ProjectionNameSet getProjections(const DefinitionsMap& defs) {
+ ProjectionNameSet result;
+
+ for (auto&& [k, v] : defs) {
+ result.emplace(k);
+ }
+ return result;
+ }
+
+ ProjectionNameSet getProjections() const {
+ return getProjections(defs);
+ }
+
+ void resolveFreeVars(const ProjectionName& name, const Definition& def) {
+ if (auto it = freeVars.find(name); it != freeVars.end()) {
+ for (const auto var : it->second) {
+ useMap.emplace(var, def);
+ }
+ freeVars.erase(it);
+ }
+ }
+
+ void assertEmptyDefs() {
+ uassert(6624028, "Definitions must be empty", defs.empty());
+ }
+};
+
+/**
+ * Collect all Variables into a set.
+ */
+class VariableCollector {
+public:
+ template <typename T, typename... Ts>
+ void transport(const T& /*op*/, Ts&&... /*ts*/) {}
+
+ void transport(const Variable& op) {
+ _result._variables.emplace(&op);
+ }
+
+ void transport(const LambdaAbstraction& op, const ABT& /*bind*/) {
+ _result._definedVars.insert(op.varName());
+ }
+
+ void transport(const Let& op, const ABT& /*bind*/, const ABT& /*expr*/) {
+ _result._definedVars.insert(op.varName());
+ }
+
+ static VariableCollectorResult collect(const ABT& n) {
+ VariableCollector collector;
+ collector.collectInternal(n);
+ return std::move(collector._result);
+ }
+
+private:
+ void collectInternal(const ABT& n) {
+ algebra::transport<false>(n, *this);
+ }
+
+ VariableCollectorResult _result;
+};
+
+struct Collector {
+ explicit Collector(const cascades::Memo* memo) : _memo(memo) {}
+
+ template <typename T, typename... Ts>
+ CollectedInfo transport(const ABT&, const T& op, Ts&&... ts) {
+ CollectedInfo result{};
+ (result.merge(std::forward<Ts>(ts)), ...);
+
+ if constexpr (std::is_base_of_v<Node, T>) {
+ result.nodeDefs[&op] = result.defs;
+ }
+
+ return result;
+ }
+
+ CollectedInfo transport(const ABT& n, const Variable& variable) {
+ CollectedInfo result{};
+
+ // Every variable starts as a free variable until it is resolved.
+ result.freeVars[variable.name()].push_back(&variable);
+
+        // Similarly, every variable starts as the last reference until proven otherwise.
+ result.varLastRefs[variable.name()].emplace(&variable, true);
+
+ return result;
+ }
+
+ CollectedInfo transport(const ABT& n,
+ const Let& let,
+ CollectedInfo bindResult,
+ CollectedInfo inResult) {
+ CollectedInfo result{};
+
+ inResult.mergeKeepLastRefs(std::move(bindResult.varLastRefs));
+ inResult.finalizeLastRefs(let.varName());
+
+ result.merge(std::move(bindResult));
+
+ // Local variables are not part of projections (i.e. we do not track them in defs) so
+ // resolve any free variables manually.
+ inResult.resolveFreeVars(let.varName(), Definition{n.ref(), let.bind().ref()});
+ result.merge(std::move(inResult));
+
+ return result;
+ }
+
+ CollectedInfo transport(const ABT& n, const LambdaAbstraction& lam, CollectedInfo inResult) {
+ CollectedInfo result{};
+
+ inResult.finalizeLastRefs(lam.varName());
+ // Local variables are not part of projections (i.e. we do not track them in defs) so
+ // resolve any free variables manually.
+ inResult.resolveFreeVars(lam.varName(), Definition{n.ref(), ABT::reference_type{}});
+ result.merge(std::move(inResult));
+
+ return result;
+ }
+
+ CollectedInfo transport(const ABT& n,
+ const If&,
+ CollectedInfo condResult,
+ CollectedInfo thenResult,
+ CollectedInfo elseResult) {
+        CollectedInfo result{};
+
+ result.unionLastRefs(std::move(thenResult.varLastRefs));
+ result.unionLastRefs(std::move(elseResult.varLastRefs));
+ result.mergeKeepLastRefs(std::move(condResult.varLastRefs));
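+
+        // Illustrative: in 'if (p) then x else x', the 'x' references in the then and else
+        // branches may both remain last references (only one branch executes at runtime), whereas
+        // a reference to 'x' in the condition could not be, since the branches run after it.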
+
+ result.merge(std::move(condResult));
+ result.merge(std::move(thenResult));
+ result.merge(std::move(elseResult));
+
+ return result;
+ }
+
+ static CollectedInfo collectForScan(const ABT& n,
+ const Node& node,
+ const ExpressionBinder& binder,
+ CollectedInfo refs) {
+ CollectedInfo result{};
+
+ result.mergeNoDefs(std::move(refs));
+
+ for (size_t i = 0; i < binder.names().size(); i++) {
+ result.defs[binder.names()[i]] = Definition{n.ref(), binder.exprs()[i].ref()};
+ }
+ result.nodeDefs[&node] = result.defs;
+
+ return result;
+ }
+
+ CollectedInfo transport(const ABT& n, const ScanNode& node, CollectedInfo /*bindResult*/) {
+ return collectForScan(n, node, node.binder(), {});
+ }
+
+ CollectedInfo transport(const ABT& n, const ValueScanNode& node, CollectedInfo /*bindResult*/) {
+ return collectForScan(n, node, node.binder(), {});
+ }
+
+ CollectedInfo transport(const ABT& n,
+ const PhysicalScanNode& node,
+ CollectedInfo /*bindResult*/) {
+ return collectForScan(n, node, node.binder(), {});
+ }
+
+ CollectedInfo transport(const ABT& n, const IndexScanNode& node, CollectedInfo /*bindResult*/) {
+ return collectForScan(n, node, node.binder(), {});
+ }
+
+ CollectedInfo transport(const ABT& n,
+ const SeekNode& node,
+ CollectedInfo /*bindResult*/,
+ CollectedInfo refResult) {
+ return collectForScan(n, node, node.binder(), std::move(refResult));
+ }
+
+ CollectedInfo transport(const ABT& n,
+ const MemoLogicalDelegatorNode& memoLogicalDelegatorNode) {
+ CollectedInfo result{};
+
+ uassert(6624029, "Uninitialized memo", _memo);
+
+ auto& group = _memo->getGroup(memoLogicalDelegatorNode.getGroupId());
+
+ auto& projectionNames = group.binder().names();
+ auto& projections = group.binder().exprs();
+ for (size_t i = 0; i < projectionNames.size(); i++) {
+ result.defs[projectionNames.at(i)] = Definition{n.ref(), projections[i].ref()};
+ }
+
+ result.nodeDefs[&memoLogicalDelegatorNode] = result.defs;
+
+ return result;
+ }
+
+ CollectedInfo transport(const ABT& n,
+ const EvaluationNode& evaluationNode,
+ CollectedInfo childResult,
+ CollectedInfo exprResult) {
+ CollectedInfo result{};
+
+ // Make the definition available upstream.
+ uassert(6624030,
+ str::stream() << "Cannot overwrite project " << evaluationNode.getProjectionName(),
+ childResult.defs.count(evaluationNode.getProjectionName()) == 0);
+
+ result.merge(std::move(childResult));
+ result.mergeNoDefs(std::move(exprResult));
+
+ result.defs[evaluationNode.getProjectionName()] =
+ Definition{n.ref(), evaluationNode.getProjection().ref()};
+
+ result.nodeDefs[&evaluationNode] = result.defs;
+
+ return result;
+ }
+
+ CollectedInfo transport(const ABT& n,
+ const SargableNode& node,
+ CollectedInfo childResult,
+ CollectedInfo bindResult,
+ CollectedInfo /*refResult*/) {
+ CollectedInfo result{};
+
+ result.merge(std::move(childResult));
+ result.mergeNoDefs(std::move(bindResult));
+
+ const auto& projectionNames = node.binder().names();
+ const auto& projections = node.binder().exprs();
+ for (size_t i = 0; i < projectionNames.size(); i++) {
+ result.defs[projectionNames.at(i)] = Definition{n.ref(), projections[i].ref()};
+ }
+
+ result.nodeDefs[&node] = result.defs;
+
+ return result;
+ }
+
+ CollectedInfo transport(const ABT& n,
+ const RIDIntersectNode& node,
+ CollectedInfo leftChildResult,
+ CollectedInfo rightChildResult) {
+ CollectedInfo result{};
+
+ rightChildResult.defs.erase(node.getScanProjectionName());
+
+ result.merge(std::move(leftChildResult));
+ result.merge(std::move(rightChildResult));
+
+ result.nodeDefs[&node] = result.defs;
+
+ return result;
+ }
+
+ CollectedInfo transport(const ABT& n,
+ const BinaryJoinNode& binaryJoinNode,
+ CollectedInfo leftChildResult,
+ CollectedInfo rightChildResult,
+ CollectedInfo filterResult) {
+ CollectedInfo result{};
+
+ {
+ const ProjectionNameSet& leftProjections = leftChildResult.getProjections();
+ for (const ProjectionName& boundProjectionName :
+ binaryJoinNode.getCorrelatedProjectionNames()) {
+ uassert(6624099,
+ "Correlated projections must exist in left child.",
+ leftProjections.find(boundProjectionName) != leftProjections.cend());
+ }
+ }
+
+ result.merge(std::move(leftChildResult));
+ result.merge(std::move(rightChildResult));
+ result.mergeNoDefs(std::move(filterResult));
+
+ result.nodeDefs[&binaryJoinNode] = result.defs;
+
+ return result;
+ }
+
+ CollectedInfo transport(const ABT& n,
+ const UnionNode& unionNode,
+ std::vector<CollectedInfo> childResults,
+ CollectedInfo bindResult,
+ CollectedInfo refsResult) {
+ CollectedInfo result{};
+
+ const auto& names = unionNode.binder().names();
+
+ refsResult.assertEmptyDefs();
+
+ // Merge children but disregard any defined projections.
+ // Note that refsResult follows the structure as built by buildUnionReferences.
+ size_t counter = 0;
+ for (auto& u : childResults) {
+ // Manually copy and resolve references of specific child.
+ for (const auto& name : names) {
+ uassert(6624031, "Union projection does not exist", u.defs.count(name) != 0);
+ u.useMap.emplace(refsResult.freeVars[name][counter], u.defs[name]);
+ }
+ u.defs.clear();
+ result.merge(std::move(u));
+ ++counter;
+ }
+
+ result.mergeNoDefs(std::move(bindResult));
+
+ // Propagate union projections.
+ const auto& defs = unionNode.binder().exprs();
+ for (size_t idx = 0; idx < names.size(); ++idx) {
+ result.defs[names[idx]] = Definition{n.ref(), defs[idx].ref()};
+ }
+
+ result.nodeDefs[&unionNode] = result.defs;
+
+ return result;
+ }
+
+ CollectedInfo transport(const ABT& n,
+ const GroupByNode& groupNode,
+ CollectedInfo childResult,
+ CollectedInfo bindAggResult,
+ CollectedInfo refsAggResult,
+ CollectedInfo bindGbResult,
+ CollectedInfo refsGbResult) {
+ CollectedInfo result{};
+
+ // First resolve all variables from the inside point of view; i.e. agg expressions and group
+ // by expressions reference variables from the input child.
+ result.merge(std::move(refsAggResult));
+ result.merge(std::move(refsGbResult));
+ // Make a copy of 'childResult' as we need it later and 'merge' is destructive.
+ result.merge(CollectedInfo{childResult});
+
+ // GroupBy completely masks projected variables; i.e. outside expressions cannot reach
+ // inside the groupby. We will create a brand new set of projections from aggs and gbs here.
+ result.defs.clear();
+
+ const auto& aggs = groupNode.getAggregationProjectionNames();
+ const auto& gbs = groupNode.getGroupByProjectionNames();
+ for (size_t idx = 0; idx < aggs.size(); ++idx) {
+ uassert(6624032,
+ "Aggregation overwrites a child projection",
+ childResult.defs.count(aggs[idx]) == 0);
+ result.defs[aggs[idx]] =
+ Definition{n.ref(), groupNode.getAggregationProjections()[idx].ref()};
+ }
+
+ for (size_t idx = 0; idx < gbs.size(); ++idx) {
+ uassert(6624033,
+ "Group-by projection does not exist",
+ childResult.defs.count(gbs[idx]) != 0);
+ result.defs[gbs[idx]] =
+ Definition{n.ref(), groupNode.getGroupByProjections()[idx].ref()};
+ }
+
+ result.mergeNoDefs(std::move(bindAggResult));
+ result.mergeNoDefs(std::move(bindGbResult));
+
+ result.nodeDefs[&groupNode] = result.defs;
+
+ return result;
+ }
+
+ CollectedInfo transport(const ABT& n,
+ const UnwindNode& unwindNode,
+ CollectedInfo childResult,
+ CollectedInfo bindResult,
+ CollectedInfo refsResult) {
+ CollectedInfo result{};
+
+ // First resolve all variables from the inside point of view.
+ result.merge(std::move(refsResult));
+ result.merge(std::move(childResult));
+
+ const auto& name = unwindNode.getProjectionName();
+ uassert(6624034, "Unwind projection does not exist", result.defs.count(name) != 0);
+
+ // Redefine unwind projection.
+ result.defs[name] = Definition{n.ref(), unwindNode.getProjection().ref()};
+ // Define unwind PID.
+ result.defs[unwindNode.getPIDProjectionName()] =
+ Definition{n.ref(), unwindNode.getPIDProjection().ref()};
+
+ result.mergeNoDefs(std::move(bindResult));
+
+ result.nodeDefs[&unwindNode] = result.defs;
+
+ return result;
+ }
+
+ CollectedInfo collect(const ABT& n) {
+ return algebra::transport<true>(n, *this);
+ }
+
+private:
+ const cascades::Memo* _memo;
+};
+
+VariableEnvironment VariableEnvironment::build(const ABT& root, const cascades::Memo* memo) {
+    Collector c(memo);
+    auto info = std::make_unique<CollectedInfo>(c.collect(root));
+
+ return VariableEnvironment{std::move(info), memo};
+}
+
+void VariableEnvironment::rebuild(const ABT& root) {
+ _info = std::make_unique<CollectedInfo>(Collector{_memo}.collect(root));
+}
+
+VariableEnvironment::VariableEnvironment(std::unique_ptr<CollectedInfo> info,
+ const cascades::Memo* memo)
+ : _info(std::move(info)), _memo(memo) {}
+
+VariableEnvironment::~VariableEnvironment() {}
+
+Definition VariableEnvironment::getDefinition(const Variable* var) const {
+ auto it = _info->useMap.find(var);
+ if (it == _info->useMap.end()) {
+ return Definition();
+ }
+
+ return it->second;
+}
+
+const DefinitionsMap& VariableEnvironment::getDefinitions(const Node* node) const {
+ auto it = _info->nodeDefs.find(node);
+ uassert(6624035, "node does not exist", it != _info->nodeDefs.end());
+
+ return it->second;
+}
+
+bool VariableEnvironment::hasDefinitions(const Node* node) const {
+ return _info->nodeDefs.find(node) != _info->nodeDefs.cend();
+}
+
+ProjectionNameSet VariableEnvironment::getProjections(const Node* node) const {
+ return CollectedInfo::getProjections(getDefinitions(node));
+}
+
+const DefinitionsMap& VariableEnvironment::getDefinitions(ABT::reference_type node) const {
+ uassert(6624036, "Invalid node type", node.is<Node>());
+ return getDefinitions(node.cast<Node>());
+}
+
+bool VariableEnvironment::hasDefinitions(ABT::reference_type node) const {
+ uassert(6624037, "Invalid node type", node.is<Node>());
+ return hasDefinitions(node.cast<Node>());
+}
+
+ProjectionNameSet VariableEnvironment::topLevelProjections() const {
+ return _info->getProjections();
+}
+
+bool VariableEnvironment::hasFreeVariables() const {
+ return !_info->freeVars.empty();
+}
+
+size_t VariableEnvironment::freeOccurences(const std::string& variable) const {
+ auto it = _info->freeVars.find(variable);
+ if (it == _info->freeVars.end()) {
+ return 0;
+ }
+
+ return it->second.size();
+}
+
+bool VariableEnvironment::isLastRef(const Variable* var) const {
+ if (_info->lastRefs.count(var)) {
+ return true;
+ }
+
+ return false;
+}
+
+VariableCollectorResult VariableEnvironment::getVariables(const ABT& n) {
+ return VariableCollector::collect(n);
+}
+
+} // namespace mongo::optimizer
diff --git a/src/mongo/db/query/optimizer/reference_tracker.h b/src/mongo/db/query/optimizer/reference_tracker.h
new file mode 100644
index 00000000000..8d0e96e11aa
--- /dev/null
+++ b/src/mongo/db/query/optimizer/reference_tracker.h
@@ -0,0 +1,113 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include "mongo/db/query/optimizer/node.h"
+
+namespace mongo::optimizer {
+/**
+ * Every Variable ABT conceptually references a point in the ABT tree; the subtree it points to is
+ * the definition of the variable.
+ */
+struct Definition {
+ /**
+ * Pointer to ABT that defines the variable. It can be any Node (e.g. ScanNode, EvaluationNode,
+ * etc.) or Expr (e.g. let expression, lambda expression).
+ */
+ ABT::reference_type definedBy;
+
+ /**
+ * Pointer to actual definition of variable.
+ */
+ ABT::reference_type definition;
+};
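+
+/**
+ * Illustrative example: for a subtree 'EvaluationNode [p1] <expr> <child>', a Variable referencing
+ * "p1" above the node resolves to a Definition whose 'definedBy' points at the EvaluationNode and
+ * whose 'definition' points at the projected expression.
+ */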
+
+namespace cascades {
+class Memo;
+}
+
+struct CollectedInfo;
+using DefinitionsMap = opt::unordered_map<ProjectionName, Definition>;
+
+struct VariableCollectorResult {
+ opt::unordered_set<const Variable*> _variables;
+    // TODO: consider using a variable environment instance for this, but it does not seem to be
+    // always viable, especially with rewrites.
+ opt::unordered_set<std::string> _definedVars;
+};
+
+class VariableEnvironment {
+ VariableEnvironment(std::unique_ptr<CollectedInfo> info, const cascades::Memo* memo);
+
+public:
+ /**
+ * Build the environment for the given ABT tree. The environment is valid as long as the tree
+     * does not change. More specifically, if a variable-defining node is removed from the tree,
+     * the environment becomes stale and has to be rebuilt.
+ */
+ static VariableEnvironment build(const ABT& root, const cascades::Memo* memo = nullptr);
+ void rebuild(const ABT& root);
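+
+    /**
+     * Minimal usage sketch ('tree' stands for any ABT owned by the caller):
+     *
+     *   auto env = VariableEnvironment::build(tree);
+     *   if (env.hasFreeVariables()) {
+     *       // Some Variable in 'tree' is not defined by any node or expression.
+     *   }
+     */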
+
+ ~VariableEnvironment();
+
+    /**
+     * Returns the definition of the given variable, or an empty Definition if the variable is
+     * free (unresolved).
+     */
+ Definition getDefinition(const Variable* var) const;
+
+ /**
+ * We may revisit what we return from here.
+ */
+ ProjectionNameSet topLevelProjections() const;
+
+ const DefinitionsMap& getDefinitions(const Node* node) const;
+
+ /**
+     * Per-node projection names.
+ */
+ ProjectionNameSet getProjections(const Node* node) const;
+ bool hasDefinitions(const Node* node) const;
+
+ const DefinitionsMap& getDefinitions(ABT::reference_type node) const;
+ bool hasDefinitions(ABT::reference_type node) const;
+
+ bool hasFreeVariables() const;
+ size_t freeOccurences(const std::string& variable) const;
+
+ bool isLastRef(const Variable* var) const;
+
+ static VariableCollectorResult getVariables(const ABT& n);
+
+private:
+ std::unique_ptr<CollectedInfo> _info;
+ const cascades::Memo* _memo{nullptr};
+};
+
+} // namespace mongo::optimizer
diff --git a/src/mongo/db/query/optimizer/rewrites/const_eval.cpp b/src/mongo/db/query/optimizer/rewrites/const_eval.cpp
new file mode 100644
index 00000000000..723bd29400d
--- /dev/null
+++ b/src/mongo/db/query/optimizer/rewrites/const_eval.cpp
@@ -0,0 +1,565 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/db/query/optimizer/rewrites/const_eval.h"
+#include "mongo/db/query/optimizer/utils/utils.h"
+
+namespace mongo::optimizer {
+bool ConstEval::optimize(ABT& n) {
+ invariant(_letRefs.empty());
+ invariant(_projectRefs.empty());
+ invariant(_singleRef.empty());
+ invariant(_noRefProj.empty());
+ invariant(!_inRefBlock);
+ invariant(_inCostlyCtx == 0);
+ invariant(_staleDefs.empty());
+ invariant(_staleABTs.empty());
+ invariant(_seenProjects.empty());
+ invariant(_inlinedDefs.empty());
+
+ _changed = false;
+
+    // We run transport<true>, which passes a reference to the ABT to the specific transport
+    // functions. The reference serves as a conceptual 'this' pointer, allowing a transport
+    // function to change the node itself.
+ algebra::transport<true>(n, *this);
+
+    // Test if there are any projections with no references. If so, remove them from the tree.
+ removeUnusedEvalNodes();
+
+ invariant(_letRefs.empty());
+ invariant(_projectRefs.empty());
+
+ while (_changed) {
+ _env.rebuild(n);
+
+ if (_singleRef.empty() && _noRefProj.empty()) {
+ break;
+ }
+ _changed = false;
+ algebra::transport<true>(n, *this);
+ removeUnusedEvalNodes();
+ }
+
+ // TODO: should we be clearing here?
+ _singleRef.clear();
+
+ _staleDefs.clear();
+ _staleABTs.clear();
+ return _changed;
+}
+
+void ConstEval::removeUnusedEvalNodes() {
+ for (auto&& [k, v] : _projectRefs) {
+ if (v.size() == 0) {
+            // Schedule node replacement as it has no references.
+ _noRefProj.emplace(k);
+ _changed = true;
+ } else if (v.size() == 1) {
+ // Do not inline nodes which can become Sargable.
+ // TODO: consider caching.
+ // TODO: consider deriving IndexingAvailability.
+ if (!_disableSargableInlining ||
+ !convertExprToPartialSchemaReq(k->getProjection())._success) {
+ // Schedule node inlining as there is exactly one reference.
+ _singleRef.emplace(v.front());
+ _changed = true;
+ }
+ }
+ }
+
+ _projectRefs.clear();
+ _seenProjects.clear();
+ _inlinedDefs.clear();
+}
+
+void ConstEval::transport(ABT& n, const Variable& var) {
+ auto def = _env.getDefinition(&var);
+
+ if (!def.definition.empty()) {
+ // See if we have already manipulated this definition and if so then use the newer version.
+ if (auto it = _staleDefs.find(def.definition); it != _staleDefs.end()) {
+ def.definition = it->second;
+ }
+ if (auto it = _staleDefs.find(def.definedBy); it != _staleDefs.end()) {
+ def.definedBy = it->second;
+ }
+
+ if (auto constant = def.definition.cast<Constant>(); constant && !_inRefBlock) {
+ // If we find the definition and it is a simple constant then substitute the variable.
+ swapAndUpdate(n, def.definition);
+ } else if (auto variable = def.definition.cast<Variable>(); variable && !_inRefBlock) {
+            // This is an indirection to another variable, so we can skip it; but first remember
+            // that we inlined this variable so that we won't try to replace it with a common
+            // expression and revert the inlining.
+ _inlinedDefs.emplace(def.definition);
+ swapAndUpdate(n, def.definition);
+ } else if (_singleRef.erase(&var)) {
+ // If this is the only reference to some expression then substitute the variable, but
+ // first remember that we inlined this expression so that we won't try to replace it
+ // with a common expression and revert the inlining.
+ _inlinedDefs.emplace(def.definition);
+ swapAndUpdate(n, def.definition);
+ } else if (auto let = def.definedBy.cast<Let>(); let) {
+ invariant(_letRefs.count(let));
+ _letRefs[let].emplace_back(&var);
+ } else if (auto project = def.definedBy.cast<EvaluationNode>(); project) {
+ invariant(_projectRefs.count(project));
+ _projectRefs[project].emplace_back(&var);
+
+            // If we are in the ref block we do not want to inline even if there is only a single
+            // reference. Similarly, we do not want to inline any variable under a traverse.
+            // Emplacing the variable a second time inflates its reference count past one, which
+            // prevents the single-reference inlining.
+ if (_inRefBlock || _inCostlyCtx > 0) {
+ _projectRefs[project].emplace_back(&var);
+ }
+ }
+ }
+}
+
+void ConstEval::prepare(ABT&, const Let& let) {
+ _letRefs[&let] = {};
+}
+
+void ConstEval::transport(ABT& n, const Let& let, ABT& bind, ABT& in) {
+ auto& letRefs = _letRefs[&let];
+ if (letRefs.size() == 0) {
+        // The bind expression has not been referenced, so it is dead code and the whole let
+        // expression can be removed; i.e. we implement the following rewrite:
+ //
+ // n == let var=<bind expr> in <in expr>
+ //
+ // v
+ //
+ // n == <in expr>
+
+ // We don't want to make a copy of 'in' as it may be arbitrarily large. Also, we cannot
+ // move it out as it is part of the Let object and we do not want to invalidate any
+ // assumptions the Let may have about its structure. Hence we swap it for the "special"
+ // Blackhole object. The Blackhole does nothing, it just plugs the hole left in the 'in'
+ // place.
+ auto result = std::exchange(in, make<Blackhole>());
+
+ // Swap the current node (n) for the result.
+ swapAndUpdate(n, std::move(result));
+ } else if (letRefs.size() == 1) {
+ // The bind expression has been referenced exactly once so schedule it for inlining.
+ _singleRef.emplace(letRefs.front());
+ _changed = true;
+ }
+ _letRefs.erase(&let);
+}
+
+void ConstEval::transport(ABT& n, const LambdaApplication& app, ABT& lam, ABT& arg) {
+ // If the 'lam' expression is LambdaAbstraction then we can do the inplace beta reduction.
+ // TODO - missing alpha conversion so for now assume globally unique names.
+ if (auto lambda = lam.cast<LambdaAbstraction>(); lambda) {
+ auto result = make<Let>(lambda->varName(),
+ std::exchange(arg, make<Blackhole>()),
+ std::exchange(lambda->getBody(), make<Blackhole>()));
+
+ swapAndUpdate(n, std::move(result));
+ }
+}
+
+namespace fold_helpers {
+using namespace sbe::value;
+
+template <class T>
+sbe::value::Value constFoldNumberHelper(const sbe::value::TypeTags lhsTag,
+ const sbe::value::Value lhsValue,
+ const TypeTags rhsTag,
+ const sbe::value::Value rhsValue) {
+ const auto result = numericCast<T>(lhsTag, lhsValue) + numericCast<T>(rhsTag, rhsValue);
+ return bitcastFrom<T>(result);
+}
+
+template <>
+sbe::value::Value constFoldNumberHelper<Decimal128>(const TypeTags lhsTag,
+ const sbe::value::Value lhsValue,
+ const TypeTags rhsTag,
+ const sbe::value::Value rhsValue) {
+ const auto result =
+ numericCast<Decimal128>(lhsTag, lhsValue).add(numericCast<Decimal128>(rhsTag, rhsValue));
+ return makeCopyDecimal(result).second;
+}
+
+} // namespace fold_helpers
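+
+// Illustrative effect of the folding below: a subtree such as
+//   BinaryOp [Add] Const [1] Const [2]
+// is replaced in place by Const [3], with the result computed in the widest numeric type of the
+// two operands (e.g. int32 + double folds to a double).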
+
+// Specific transport for binary operations.
+// The const-correctness is probably wrong ('const ABT& lhs, const ABT& rhs' does not work for
+// some reason); we can fix it later.
+void ConstEval::transport(ABT& n, const BinaryOp& op, ABT& lhs, ABT& rhs) {
+ using namespace fold_helpers;
+
+ switch (op.op()) {
+ case Operations::Add: {
+            // Say we want to recognize ConstLhs + ConstRhs and replace it with the result of the
+            // addition.
+ auto lhsConst = lhs.cast<Constant>();
+ auto rhsConst = rhs.cast<Constant>();
+
+ if (lhsConst && rhsConst) {
+ auto [lhsTag, lhsValue] = lhsConst->get();
+ auto [rhsTag, rhsValue] = rhsConst->get();
+
+ if (isNumber(lhsTag) && isNumber(rhsTag)) {
+ // So this is the addition operation and both arguments are number constants,
+ // hence we can compute the result.
+
+ const TypeTags resultType = getWidestNumericalType(lhsTag, rhsTag);
+ sbe::value::Value resultValue;
+
+ switch (resultType) {
+ case TypeTags::NumberInt32: {
+ resultValue =
+ constFoldNumberHelper<int32_t>(lhsTag, lhsValue, rhsTag, rhsValue);
+ break;
+ }
+
+ case TypeTags::NumberInt64: {
+ resultValue =
+ constFoldNumberHelper<int64_t>(lhsTag, lhsValue, rhsTag, rhsValue);
+ break;
+ }
+
+ case TypeTags::NumberDouble: {
+ resultValue =
+ constFoldNumberHelper<double>(lhsTag, lhsValue, rhsTag, rhsValue);
+ break;
+ }
+
+ case TypeTags::NumberDecimal: {
+ resultValue = constFoldNumberHelper<Decimal128>(
+ lhsTag, lhsValue, rhsTag, rhsValue);
+ break;
+ }
+
+ default:
+ MONGO_UNREACHABLE;
+ }
+
+ // And this is the crucial step - we swap the current node (n) for the result.
+ swapAndUpdate(n, make<Constant>(resultType, resultValue));
+ }
+ }
+ break;
+ }
+
+ case Operations::Or: {
+            // The Nothing-propagation and short-circuiting semantics of the 'or' operation in SBE
+            // allow us to interrogate only 'lhs'.
+ if (auto lhsConst = lhs.cast<Constant>(); lhsConst) {
+ auto [lhsTag, lhsValue] = lhsConst->get();
+ if (lhsTag == sbe::value::TypeTags::Boolean &&
+ !sbe::value::bitcastTo<bool>(lhsValue)) {
+ // false || rhs -> rhs
+ swapAndUpdate(n, std::exchange(rhs, make<Blackhole>()));
+ } else if (lhsTag == sbe::value::TypeTags::Boolean &&
+ sbe::value::bitcastTo<bool>(lhsValue)) {
+ // true || rhs -> true
+ swapAndUpdate(n, Constant::boolean(true));
+ }
+ }
+ break;
+ }
+
+ case Operations::And: {
+            // The Nothing-propagation and short-circuiting semantics of the 'and' operation in
+            // SBE allow us to interrogate only 'lhs'.
+ if (auto lhsConst = lhs.cast<Constant>(); lhsConst) {
+ auto [lhsTag, lhsValue] = lhsConst->get();
+ if (lhsTag == sbe::value::TypeTags::Boolean &&
+ !sbe::value::bitcastTo<bool>(lhsValue)) {
+ // false && rhs -> false
+ swapAndUpdate(n, Constant::boolean(false));
+ } else if (lhsTag == sbe::value::TypeTags::Boolean &&
+ sbe::value::bitcastTo<bool>(lhsValue)) {
+ // true && rhs -> rhs
+ swapAndUpdate(n, std::exchange(rhs, make<Blackhole>()));
+ }
+ }
+ break;
+ }
+
+ case Operations::Eq: {
+ if (lhs == rhs) {
+ // If the subtrees are equal, we can conclude that their result is equal because we
+ // have only pure functions.
+ swapAndUpdate(n, Constant::boolean(true));
+ } else if (lhs.is<Constant>() && rhs.is<Constant>()) {
+ // We have two constants which are not equal.
+ swapAndUpdate(n, Constant::boolean(false));
+ }
+ break;
+ }
+
+ case Operations::Lt:
+ case Operations::Lte:
+ case Operations::Gt:
+ case Operations::Gte:
+ case Operations::Cmp3w: {
+ const auto lhsConst = lhs.cast<Constant>();
+ const auto rhsConst = rhs.cast<Constant>();
+
+ if (lhsConst) {
+ const auto [lhsTag, lhsVal] = lhsConst->get();
+
+ if (rhsConst) {
+ const auto [rhsTag, rhsVal] = rhsConst->get();
+
+ const auto [compareTag, compareVal] =
+ sbe::value::compareValue(lhsTag, lhsVal, rhsTag, rhsVal);
+ const auto cmpVal = sbe::value::bitcastTo<int32_t>(compareVal);
+
+ switch (op.op()) {
+ case Operations::Lt:
+ swapAndUpdate(n, Constant::boolean(cmpVal < 0));
+ break;
+ case Operations::Lte:
+ swapAndUpdate(n, Constant::boolean(cmpVal <= 0));
+ break;
+ case Operations::Gt:
+ swapAndUpdate(n, Constant::boolean(cmpVal > 0));
+ break;
+ case Operations::Gte:
+ swapAndUpdate(n, Constant::boolean(cmpVal >= 0));
+ break;
+ case Operations::Cmp3w:
+ swapAndUpdate(n, Constant::int32(cmpVal));
+ break;
+
+ default:
+ MONGO_UNREACHABLE;
+ }
+ } else {
+ if (lhsTag == sbe::value::TypeTags::MinKey) {
+ switch (op.op()) {
+ case Operations::Lte:
+ swapAndUpdate(n, Constant::boolean(true));
+ break;
+ case Operations::Gt:
+ swapAndUpdate(n, Constant::boolean(false));
+ break;
+
+ default:
+ break;
+ }
+ } else if (lhsTag == sbe::value::TypeTags::MaxKey) {
+ switch (op.op()) {
+ case Operations::Lt:
+ swapAndUpdate(n, Constant::boolean(false));
+ break;
+ case Operations::Gte:
+ swapAndUpdate(n, Constant::boolean(true));
+ break;
+
+ default:
+ break;
+ }
+ }
+ }
+ } else if (rhsConst) {
+ const auto [rhsTag, rhsVal] = rhsConst->get();
+
+ if (rhsTag == sbe::value::TypeTags::MinKey) {
+ switch (op.op()) {
+ case Operations::Lt:
+ swapAndUpdate(n, Constant::boolean(false));
+ break;
+
+ case Operations::Gte:
+ swapAndUpdate(n, Constant::boolean(true));
+ break;
+
+ default:
+ break;
+ }
+ } else if (rhsTag == sbe::value::TypeTags::MaxKey) {
+ switch (op.op()) {
+ case Operations::Lte:
+ swapAndUpdate(n, Constant::boolean(true));
+ break;
+
+ case Operations::Gt:
+ swapAndUpdate(n, Constant::boolean(false));
+ break;
+
+ default:
+ break;
+ }
+ }
+ }
+            break;
+        }
+
+ default:
+ // Not implemented.
+ break;
+ }
+}
+
+void ConstEval::transport(ABT& n, const FunctionCall& op, std::vector<ABT>& args) {
+ // We can simplify exists(constant) to true if the said constant is not Nothing.
+ if (op.name() == "exists" && args.size() == 1 && args[0].is<Constant>()) {
+ auto [tag, val] = args[0].cast<Constant>()->get();
+ if (tag != sbe::value::TypeTags::Nothing) {
+ swapAndUpdate(n, Constant::boolean(true));
+ }
+ }
+
+ if (op.name() == "newArray") {
+ bool allConstants = true;
+ for (const ABT& arg : op.nodes()) {
+ if (!arg.is<Constant>()) {
+ allConstants = false;
+ break;
+ }
+ }
+
+ if (allConstants) {
+ // All arguments are constants. Replace with an array constant.
+
+ sbe::value::Array array;
+ for (const ABT& arg : op.nodes()) {
+ auto [tag, val] = arg.cast<Constant>()->get();
+ // Copy the value before inserting into the array.
+ auto [tagCopy, valCopy] = sbe::value::copyValue(tag, val);
+ array.push_back(tagCopy, valCopy);
+ }
+
+ auto [tag, val] = sbe::value::makeCopyArray(array);
+ swapAndUpdate(n, make<Constant>(tag, val));
+ }
+ }
+}
+
+void ConstEval::transport(ABT& n, const If& op, ABT& cond, ABT& thenBranch, ABT& elseBranch) {
+ // If the condition is a boolean constant we can simplify.
+ if (auto condConst = cond.cast<Constant>(); condConst) {
+ auto [condTag, condValue] = condConst->get();
+ if (condTag == sbe::value::TypeTags::Boolean && sbe::value::bitcastTo<bool>(condValue)) {
+ // if true -> thenBranch
+ swapAndUpdate(n, std::exchange(thenBranch, make<Blackhole>()));
+ } else if (condTag == sbe::value::TypeTags::Boolean &&
+ !sbe::value::bitcastTo<bool>(condValue)) {
+ // if false -> elseBranch
+ swapAndUpdate(n, std::exchange(elseBranch, make<Blackhole>()));
+ }
+ }
+}
+
+void ConstEval::prepare(ABT&, const PathTraverse&) {
+ ++_inCostlyCtx;
+}
+
+void ConstEval::transport(ABT&, const PathTraverse&, ABT&) {
+ --_inCostlyCtx;
+}
+
+void ConstEval::prepare(ABT&, const LambdaAbstraction&) {
+ ++_inCostlyCtx;
+}
+
+void ConstEval::transport(ABT&, const LambdaAbstraction&, ABT&) {
+ --_inCostlyCtx;
+}
+
+void ConstEval::transport(ABT& n, const EvaluationNode& op, ABT& child, ABT& expr) {
+ if (_noRefProj.erase(&op)) {
+ // The evaluation node is unused so replace it with its own child.
+ if (_erasedProjNames != nullptr) {
+ _erasedProjNames->insert(op.getProjectionName());
+ }
+
+ // First, pull out the child and put in a blackhole.
+ auto result = std::exchange(child, make<Blackhole>());
+
+ // Replace the evaluation node itself with the extracted child.
+ swapAndUpdate(n, std::move(result));
+ } else {
+ if (!_projectRefs.count(&op)) {
+ _projectRefs[&op] = {};
+ }
+
+ // Do not consider simple constants or variable references for elimination.
+ if (!op.getProjection().is<Constant>() && !op.getProjection().is<Variable>()) {
+ // Try to find a projection with the same expression as the current 'op' node and
+ // substitute it with a variable pointing to that source projection.
+ if (auto source = _seenProjects.find(&op); source != _seenProjects.end() &&
+ // Make sure that the matched projection is visible to the current 'op'.
+ _env.getProjections(&op).count((*source)->getProjectionName()) &&
+ // If we already inlined the matched projection, we don't want to use it as a source
+ // for common expression as it will negate the inlining.
+ !_inlinedDefs.count((*source)->getProjection().ref())) {
+ invariant(_projectRefs.count(*source));
+
+ auto var = make<Variable>((*source)->getProjectionName());
+ // Source now will have an extra reference from the newly constructed projection.
+ _projectRefs[*source].emplace_back(var.cast<Variable>());
+
+ auto newN = make<EvaluationNode>(op.getProjectionName(),
+ std::move(var),
+ std::exchange(child, make<Blackhole>()));
+ // The new projection node should inherit the references from the old node.
+ _projectRefs[newN.cast<EvaluationNode>()] = std::move(_projectRefs[&op]);
+ _projectRefs.erase(&op);
+
+ swapAndUpdate(n, std::move(newN));
+ } else {
+ _seenProjects.emplace(&op);
+ }
+ }
+ }
+}
+
+void ConstEval::prepare(ABT&, const References& refs) {
+ // It is structurally impossible to nest References nodes.
+ invariant(!_inRefBlock);
+ _inRefBlock = true;
+}
+void ConstEval::transport(ABT& n, const References& op, std::vector<ABT>&) {
+ invariant(_inRefBlock);
+ _inRefBlock = false;
+}
+
+void ConstEval::swapAndUpdate(ABT& n, ABT newN) {
+ // Record the mapping from the old to the new.
+ invariant(_staleDefs.count(n.ref()) == 0);
+ invariant(_staleDefs.count(newN.ref()) == 0);
+
+ _staleDefs[n.ref()] = newN.ref();
+
+ // Do the swap.
+ std::swap(n, newN);
+
+    // 'newN' now contains the old ABT; we keep it alive in '_staleABTs' so that references
+    // recorded in '_staleDefs' stay valid (avoiding the ABA problem).
+ _staleABTs.emplace_back(std::move(newN));
+
+ _changed = true;
+}
+} // namespace mongo::optimizer
diff --git a/src/mongo/db/query/optimizer/rewrites/const_eval.h b/src/mongo/db/query/optimizer/rewrites/const_eval.h
new file mode 100644
index 00000000000..42a1fdde77a
--- /dev/null
+++ b/src/mongo/db/query/optimizer/rewrites/const_eval.h
@@ -0,0 +1,121 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include "mongo/db/query/optimizer/reference_tracker.h"
+#include "mongo/db/query/optimizer/utils/abt_hash.h"
+
+namespace mongo::optimizer {
+/**
+ * This rewriter performs constant evaluation (folding) in-place. It also eliminates
+ * unused EvaluationNodes and reuses projections that bind identical expressions.
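+ *
+ * A typical usage sketch (mirroring the unit tests): build a VariableEnvironment for the
+ * tree, then re-run the rewriter until it reports that nothing changed:
+ *
+ *   auto env = VariableEnvironment::build(tree);
+ *   while (ConstEval{env}.optimize(tree)) {
+ *   }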
+ */
+class ConstEval {
+public:
+ ConstEval(VariableEnvironment& env,
+ const bool disableSargableInlining = false,
+ ProjectionNameSet* erasedProjNames = nullptr)
+ : _disableSargableInlining(disableSargableInlining),
+ _env(env),
+ _erasedProjNames(erasedProjNames) {}
+
+ // The default noop transport. Note the first ABT& parameter.
+ template <typename T, typename... Ts>
+ void transport(ABT&, const T&, Ts&&...) {}
+
+ void transport(ABT& n, const Variable& var);
+
+ void prepare(ABT&, const Let& let);
+ void transport(ABT& n, const Let& let, ABT&, ABT& in);
+ void transport(ABT& n, const LambdaApplication& app, ABT& lam, ABT& arg);
+ void prepare(ABT&, const LambdaAbstraction&);
+ void transport(ABT&, const LambdaAbstraction&, ABT&);
+
+    // Specific transport for binary operations.
+    // The const-correctness is probably wrong ('const ABT& lhs, const ABT& rhs' does not
+    // work for some reason, but we can fix it later).
+ void transport(ABT& n, const BinaryOp& op, ABT& lhs, ABT& rhs);
+ void transport(ABT& n, const FunctionCall& op, std::vector<ABT>& args);
+ void transport(ABT& n, const If& op, ABT& cond, ABT& thenBranch, ABT& elseBranch);
+ void transport(ABT& n, const EvaluationNode& op, ABT& child, ABT& expr);
+
+ void prepare(ABT&, const PathTraverse&);
+ void transport(ABT&, const PathTraverse&, ABT&);
+
+ void prepare(ABT&, const References& refs);
+ void transport(ABT& n, const References& op, std::vector<ABT>&);
+
+    // The tree is passed in as a NON-const reference because we will be updating it.
+ bool optimize(ABT& n);
+
+private:
+ struct EvalNodeHash {
+ size_t operator()(const EvaluationNode* node) const {
+ return ABTHashGenerator::generate(node->getProjection());
+ }
+ };
+
+ struct EvalNodeCompare {
+        bool operator()(const EvaluationNode* lhs, const EvaluationNode* rhs) const {
+ return lhs->getProjection() == rhs->getProjection();
+ }
+ };
+
+ struct RefHash {
+ size_t operator()(const ABT::reference_type& nodeRef) const {
+ return nodeRef.hash();
+ }
+ };
+
+ void swapAndUpdate(ABT& n, ABT newN);
+ void removeUnusedEvalNodes();
+
+ // Controls if we can inline certain EvaluationNodes.
+ const bool _disableSargableInlining;
+
+ VariableEnvironment& _env;
+ opt::unordered_set<const Variable*> _singleRef;
+ opt::unordered_set<const EvaluationNode*> _noRefProj;
+ opt::unordered_map<const Let*, std::vector<const Variable*>> _letRefs;
+ opt::unordered_map<const EvaluationNode*, std::vector<const Variable*>> _projectRefs;
+ opt::unordered_set<const EvaluationNode*, EvalNodeHash, EvalNodeCompare> _seenProjects;
+ opt::unordered_set<ABT::reference_type, RefHash> _inlinedDefs;
+ opt::unordered_map<ABT::reference_type, ABT::reference_type, RefHash> _staleDefs;
+ // We collect old ABTs in order to avoid the ABA problem.
+ std::vector<ABT> _staleABTs;
+
+ bool _inRefBlock{false};
+ size_t _inCostlyCtx{0};
+ bool _changed{false};
+
+ // Optionally collect projection names from erased Eval nodes.
+ ProjectionNameSet* _erasedProjNames;
+};
+
+} // namespace mongo::optimizer
diff --git a/src/mongo/db/query/optimizer/rewrites/path.cpp b/src/mongo/db/query/optimizer/rewrites/path.cpp
new file mode 100644
index 00000000000..1378dadb6ae
--- /dev/null
+++ b/src/mongo/db/query/optimizer/rewrites/path.cpp
@@ -0,0 +1,346 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/db/query/optimizer/rewrites/path.h"
+#include "mongo/db/query/optimizer/utils/utils.h"
+
+namespace mongo::optimizer {
+ABT::reference_type PathFusion::follow(ABT::reference_type n) {
+ if (auto var = n.cast<Variable>(); var) {
+ auto def = _env.getDefinition(var);
+ if (!def.definition.empty() && !def.definition.is<Source>()) {
+ return follow(def.definition);
+ }
+ }
+
+ return n;
+}
+
+bool PathFusion::fuse(ABT& lhs, const ABT& rhs) {
+ if (auto rhsComposeM = rhs.cast<PathComposeM>(); rhsComposeM != nullptr) {
+ if (_info[rhsComposeM->getPath1().cast<PathSyntaxSort>()]._isConst) {
+ if (fuse(lhs, rhsComposeM->getPath1())) {
+ return true;
+ }
+ if (_info[rhsComposeM->getPath2().cast<PathSyntaxSort>()]._isConst &&
+ fuse(lhs, rhsComposeM->getPath2())) {
+ return true;
+ }
+ }
+ }
+
+ if (auto lhsGet = lhs.cast<PathGet>(); lhsGet != nullptr) {
+ if (auto rhsField = rhs.cast<PathField>();
+ rhsField != nullptr && lhsGet->name() == rhsField->name()) {
+ return fuse(lhsGet->getPath(), rhsField->getPath());
+ }
+
+ if (auto rhsKeep = rhs.cast<PathKeep>(); rhsKeep != nullptr) {
+ if (rhsKeep->getNames().count(lhsGet->name()) > 0) {
+ return true;
+ }
+ }
+ }
+
+ if (auto lhsTraverse = lhs.cast<PathTraverse>(); lhsTraverse != nullptr) {
+ if (auto rhsTraverse = rhs.cast<PathTraverse>(); rhsTraverse != nullptr) {
+ return fuse(lhsTraverse->getPath(), rhsTraverse->getPath());
+ }
+
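+        // If the rhs provably does not produce an array, the Traverse on the lhs is a
+        // no-op; unwrap it and retry the fuse on the inner path.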
+ auto rhsType = _info[rhs.cast<PathSyntaxSort>()]._type;
+ if (rhsType != Type::unknown && rhsType != Type::array) {
+ auto result = std::exchange(lhsTraverse->getPath(), make<Blackhole>());
+ std::swap(lhs, result);
+ return fuse(lhs, rhs);
+ }
+ }
+
+ if (lhs.is<PathIdentity>()) {
+ lhs = rhs;
+ return true;
+ }
+
+ if (rhs.is<PathLambda>()) {
+ lhs = make<PathComposeM>(std::move(lhs), rhs);
+ return true;
+ }
+
+ if (auto rhsConst = rhs.cast<PathConstant>(); rhsConst != nullptr) {
+ if (auto lhsCmp = lhs.cast<PathCompare>(); lhsCmp != nullptr) {
+ auto result = make<PathConstant>(make<BinaryOp>(
+ lhsCmp->op(),
+ make<BinaryOp>(Operations::Cmp3w, rhsConst->getConstant(), lhsCmp->getVal()),
+ Constant::int64(0)));
+
+ std::swap(lhs, result);
+ return true;
+ }
+
+ lhs = make<PathComposeM>(rhs, std::move(lhs));
+ return true;
+ }
+
+ return false;
+}
+
+bool PathFusion::optimize(ABT& root) {
+ _changed = false;
+ algebra::transport<true>(root, *this);
+
+ if (_changed) {
+ _env.rebuild(root);
+ }
+
+ return _changed;
+}
+
+void PathFusion::transport(ABT& n, const PathConstant& path, ABT& c) {
+ CollectedInfo ci;
+ if (auto exprC = path.getConstant().cast<Constant>(); exprC) {
+ // Let's see what we can determine from the constant expression
+ auto [tag, val] = exprC->get();
+ if (sbe::value::isObject(tag)) {
+ ci._type = Type::object;
+ } else if (sbe::value::isArray(tag)) {
+ ci._type = Type::array;
+ } else if (tag == sbe::value::TypeTags::Boolean) {
+ ci._type = Type::boolean;
+ } else if (tag == sbe::value::TypeTags::Nothing) {
+ ci._type = Type::nothing;
+ } else {
+ ci._type = Type::any;
+ }
+ }
+
+ ci._isConst = true;
+ _info[&path] = ci;
+}
+
+void PathFusion::transport(ABT& n, const PathCompare& path, ABT& c) {
+ CollectedInfo ci;
+
+    // TODO: follow up on Nothing and three-valued logic. Assume plain boolean for now.
+ ci._type = Type::boolean;
+ ci._isConst = _info[path.getVal().cast<PathSyntaxSort>()]._isConst;
+ _info[&path] = ci;
+}
+
+void PathFusion::transport(ABT& n, const PathGet& get, ABT& path) {
+ // Get "a" Const <c> -> Const <c>
+ if (auto constPath = path.cast<PathConstant>(); constPath) {
+ // Pull out the constant path
+ auto result = std::exchange(path, make<Blackhole>());
+
+ // And swap it for the current node
+ std::swap(n, result);
+
+ _changed = true;
+ } else {
+ auto it = _info.find(path.cast<PathSyntaxSort>());
+ uassert(6624129, "expected to find path", it != _info.end());
+
+ // Simply move the info from the child.
+ _info[&get] = it->second;
+ }
+}
+
+void PathFusion::transport(ABT& n, const PathField& field, ABT& path) {
+ auto it = _info.find(path.cast<PathSyntaxSort>());
+ uassert(6624130, "expected to find path", it != _info.end());
+
+ CollectedInfo ci;
+ if (it->second._type == Type::unknown) {
+ // We don't know anything about the child.
+ ci._type = Type::unknown;
+ } else if (it->second._type == Type::nothing) {
+        // We are setting a field to Nothing (i.e. dropping it), hence we do not know what
+        // the result will be (it depends entirely on the input).
+ ci._type = Type::unknown;
+ } else {
+        // The child produces a bona fide value, hence the result will be an object.
+ ci._type = Type::object;
+ }
+
+ ci._isConst = it->second._isConst;
+ _info[&field] = ci;
+}
+
+void PathFusion::transport(ABT& n, const PathTraverse& path, ABT& inner) {
+    // Traverse depends entirely on its input, so we cannot infer anything about its result.
+ CollectedInfo ci;
+ ci._type = Type::unknown;
+ ci._isConst = false;
+ _info[&path] = ci;
+}
+
+void PathFusion::transport(ABT& n, const PathComposeM& path, ABT& p1, ABT& p2) {
+ if (auto p1Const = p1.cast<PathConstant>(); p1Const != nullptr) {
+ switch (_kindCtx.back()) {
+ case Kind::filter:
+ n = make<PathConstant>(make<EvalFilter>(p2, p1Const->getConstant()));
+ _changed = true;
+ return;
+
+ case Kind::project:
+ n = make<PathConstant>(make<EvalPath>(p2, p1Const->getConstant()));
+ _changed = true;
+ return;
+
+ default:
+ MONGO_UNREACHABLE;
+ }
+ }
+
+ if (auto p1Get = p1.cast<PathGet>(); p1Get != nullptr && p1Get->getPath().is<PathIdentity>()) {
+ // TODO: handle chain of Gets.
+ n = make<PathGet>(p1Get->name(), std::move(p2));
+ _changed = true;
+ return;
+ }
+
+ if (p1.is<PathIdentity>()) {
+ // Id * p2 -> p2
+ auto result = std::exchange(p2, make<Blackhole>());
+ std::swap(n, result);
+ _changed = true;
+ return;
+ } else if (p2.is<PathIdentity>()) {
+ // p1 * Id -> p1
+ auto result = std::exchange(p1, make<Blackhole>());
+ std::swap(n, result);
+ _changed = true;
+ return;
+ } else if (_redundant.erase(p1.cast<PathSyntaxSort>())) {
+ auto result = std::exchange(p2, make<Blackhole>());
+ std::swap(n, result);
+ _changed = true;
+ return;
+ } else if (_redundant.erase(p2.cast<PathSyntaxSort>())) {
+ auto result = std::exchange(p1, make<Blackhole>());
+ std::swap(n, result);
+ _changed = true;
+ return;
+ }
+
+ auto p1InfoIt = _info.find(p1.cast<PathSyntaxSort>());
+ auto p2InfoIt = _info.find(p2.cast<PathSyntaxSort>());
+
+ uassert(6624131, "info must be defined", p1InfoIt != _info.end() && p2InfoIt != _info.end());
+
+ if (p1.is<PathDefault>() && p2InfoIt->second.isNotNothing()) {
+ // Default * Const e -> e (provided we can prove e is not Nothing and we can do that only
+ // when e is Constant expression)
+ auto result = std::exchange(p2, make<Blackhole>());
+ std::swap(n, result);
+ _changed = true;
+ } else if (p2.is<PathDefault>() && p1InfoIt->second.isNotNothing()) {
+ // Const e * Default -> e (provided we can prove e is not Nothing and we can do that only
+ // when e is Constant expression)
+ auto result = std::exchange(p1, make<Blackhole>());
+ std::swap(n, result);
+ _changed = true;
+ } else if (p2InfoIt->second._type == Type::object) {
+ auto left = collectComposed(p1);
+ for (auto l : left) {
+ if (l.is<PathObj>()) {
+ _redundant.emplace(l.cast<PathSyntaxSort>());
+ _changed = true;
+ }
+ }
+ _info[&path] = p2InfoIt->second;
+ } else {
+ _info[&path] = p2InfoIt->second;
+ }
+}
+
+void PathFusion::transport(ABT& n, const EvalPath& eval, ABT& path, ABT& input) {
+ auto realInput = follow(input);
+    // If we are evaluating a constant path, we can simply replace the whole expression
+    // with the result.
+ if (auto constPath = path.cast<PathConstant>(); constPath) {
+        // Pull the constant out of the path.
+ auto result = std::exchange(constPath->getConstant(), make<Blackhole>());
+
+ // And swap it for the current node
+ std::swap(n, result);
+
+ _changed = true;
+ } else if (auto evalInput = realInput.cast<EvalPath>(); evalInput) {
+        // The input to this EvalPath expression is another EvalPath, so we may try to fuse
+        // the paths.
+ if (fuse(n.cast<EvalPath>()->getPath(), evalInput->getPath())) {
+ // We have fused paths so replace the input (by making a copy).
+ input = evalInput->getInput();
+
+ _changed = true;
+ } else if (auto evalImmediateInput = input.cast<EvalPath>();
+ evalImmediateInput != nullptr) {
+ // Compose immediate EvalPath input.
+ n = make<EvalPath>(
+ make<PathComposeM>(std::move(evalImmediateInput->getPath()), std::move(path)),
+ std::move(evalImmediateInput->getInput()));
+
+ _changed = true;
+ }
+ }
+ _kindCtx.pop_back();
+}
+
+void PathFusion::transport(ABT& n, const EvalFilter& eval, ABT& path, ABT& input) {
+ auto realInput = follow(input);
+    // If we are evaluating a constant path, we can simply replace the whole expression
+    // with the result.
+ if (auto constPath = path.cast<PathConstant>(); constPath) {
+        // Pull the constant out of the path.
+ auto result = std::exchange(constPath->getConstant(), make<Blackhole>());
+
+ // And swap it for the current node
+ std::swap(n, result);
+
+ _changed = true;
+ } else if (auto evalImmediateInput = input.cast<EvalFilter>(); evalImmediateInput != nullptr) {
+ // Compose immediate EvalFilter input.
+ n = make<EvalFilter>(
+ make<PathComposeM>(std::move(evalImmediateInput->getPath()), std::move(path)),
+ std::move(evalImmediateInput->getInput()));
+
+ _changed = true;
+ } else if (auto evalInput = realInput.cast<EvalPath>(); evalInput) {
+        // The input to this EvalFilter expression is another EvalPath, so we may try to
+        // fuse the paths.
+ if (fuse(n.cast<EvalFilter>()->getPath(), evalInput->getPath())) {
+ // We have fused paths so replace the input (by making a copy).
+ input = evalInput->getInput();
+
+ _changed = true;
+ }
+ }
+ _kindCtx.pop_back();
+}
+} // namespace mongo::optimizer
diff --git a/src/mongo/db/query/optimizer/rewrites/path.h b/src/mongo/db/query/optimizer/rewrites/path.h
new file mode 100644
index 00000000000..6e53fe5a7b5
--- /dev/null
+++ b/src/mongo/db/query/optimizer/rewrites/path.h
@@ -0,0 +1,94 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include "mongo/db/query/optimizer/reference_tracker.h"
+
+namespace mongo::optimizer {
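+/**
+ * Fuses consecutive path operations, e.g. a PathGet applied to the result of a PathField
+ * of the same name, and tracks simple type information about path results to enable the
+ * rewrites. Like ConstEval, it is typically run in a loop, alternating with ConstEval,
+ * until optimize() reports no more changes (see path_optimizer_test.cpp).
+ */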
+class PathFusion {
+ enum class Type { unknown, nothing, object, array, boolean, any };
+ enum class Kind { project, filter };
+
+ struct CollectedInfo {
+ bool isNotNothing() const {
+ return _type != Type::unknown && _type != Type::nothing;
+ }
+
+ Type _type{Type::unknown};
+
+        // Is the result of the path independent of its input? (E.g. it can be if the path
+        // terminates with PathConstant, but not necessarily if it terminates with
+        // PathIdentity.)
+ bool _isConst{false};
+ };
+
+public:
+ PathFusion(VariableEnvironment& env) : _env(env) {}
+
+ template <typename T, typename... Ts>
+ void transport(ABT&, const T& op, Ts&&...) {
+ if constexpr (std::is_base_of_v<PathSyntaxSort, T>) {
+ _info[&op] = CollectedInfo{};
+ }
+ }
+
+ void transport(ABT& n, const PathConstant&, ABT& c);
+ void transport(ABT& n, const PathCompare&, ABT& c);
+ void transport(ABT& n, const PathGet&, ABT& path);
+ void transport(ABT& n, const PathField&, ABT& path);
+ void transport(ABT& n, const PathTraverse&, ABT& inner);
+ void transport(ABT& n, const PathComposeM&, ABT& p1, ABT& p2);
+
+ void prepare(ABT& n, const EvalPath& eval) {
+ _kindCtx.push_back(Kind::project);
+ }
+ void transport(ABT& n, const EvalPath& eval, ABT& path, ABT& input);
+ void prepare(ABT& n, const EvalFilter& eval) {
+ _kindCtx.push_back(Kind::filter);
+ }
+ void transport(ABT& n, const EvalFilter& eval, ABT& path, ABT& input);
+
+ bool optimize(ABT& root);
+
+private:
+ ABT::reference_type follow(ABT::reference_type n);
+ ABT::reference_type follow(const ABT& n) {
+ return follow(n.ref());
+ }
+ bool fuse(ABT& lhs, const ABT& rhs);
+
+ VariableEnvironment& _env;
+ opt::unordered_map<const PathSyntaxSort*, CollectedInfo> _info;
+ opt::unordered_set<const PathSyntaxSort*> _redundant;
+
+    // A stack of contexts (either project or filter).
+ std::vector<Kind> _kindCtx;
+ bool _changed{false};
+};
+} // namespace mongo::optimizer
diff --git a/src/mongo/db/query/optimizer/rewrites/path_lower.cpp b/src/mongo/db/query/optimizer/rewrites/path_lower.cpp
new file mode 100644
index 00000000000..c5eb294c6c8
--- /dev/null
+++ b/src/mongo/db/query/optimizer/rewrites/path_lower.cpp
@@ -0,0 +1,400 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/db/query/optimizer/rewrites/path_lower.h"
+
+
+namespace mongo::optimizer {
+bool EvalPathLowering::optimize(ABT& n) {
+ _changed = false;
+
+ algebra::transport<true>(n, *this);
+
+ if (_changed) {
+ _env.rebuild(n);
+ }
+
+ return _changed;
+}
+
+void EvalPathLowering::transport(ABT& n, const PathConstant&, ABT& c) {
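+    // Const c -> \_ . c, i.e. a lambda that ignores its input and returns the constant.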
+ n = make<LambdaAbstraction>(_prefixId.getNextId("_"), std::exchange(c, make<Blackhole>()));
+ _changed = true;
+}
+
+void EvalPathLowering::transport(ABT& n, const PathIdentity&) {
+ const std::string& name = _prefixId.getNextId("x");
+
+ n = make<LambdaAbstraction>(name, make<Variable>(name));
+ _changed = true;
+}
+
+void EvalPathLowering::transport(ABT& n, const PathLambda&, ABT& lam) {
+ n = std::exchange(lam, make<Blackhole>());
+ _changed = true;
+}
+
+void EvalPathLowering::transport(ABT& n, const PathDefault&, ABT& c) {
+ // if (exists(x), x, c)
+ const std::string& name = _prefixId.getNextId("valDefault");
+
+ n = make<LambdaAbstraction>(
+ name,
+ make<If>(make<FunctionCall>("exists", makeSeq(make<Variable>(name))),
+ make<Variable>(name),
+ std::exchange(c, make<Blackhole>())));
+ _changed = true;
+}
+
+void EvalPathLowering::transport(ABT& n, const PathCompare&, ABT& c) {
+ uasserted(6624132, "cannot lower compare in projection");
+}
+
+void EvalPathLowering::transport(ABT& n, const PathGet& p, ABT& inner) {
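+    // Get "name" inner -> \input . (inner (getField(input, "name")))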
+ const std::string& name = _prefixId.getNextId("inputGet");
+
+ n = make<LambdaAbstraction>(
+ name,
+ make<LambdaApplication>(
+ std::exchange(inner, make<Blackhole>()),
+ make<FunctionCall>("getField",
+ makeSeq(make<Variable>(name), Constant::str(p.name())))));
+ _changed = true;
+}
+
+void EvalPathLowering::transport(ABT& n, const PathDrop& drop) {
+    // if (isObject(x), dropFields(x, ...), x)
+ // Alternatively, we can implement a special builtin function that does the comparison and drop.
+ const std::string& name = _prefixId.getNextId("valDrop");
+
+ std::vector<ABT> params;
+ params.emplace_back(make<Variable>(name));
+ for (const auto& fieldName : drop.getNames()) {
+ params.emplace_back(Constant::str(fieldName));
+ }
+
+ n = make<LambdaAbstraction>(
+ name,
+ make<If>(make<FunctionCall>("isObject", makeSeq(make<Variable>(name))),
+ make<FunctionCall>("dropFields", std::move(params)),
+ make<Variable>(name)));
+ _changed = true;
+}
+
+void EvalPathLowering::transport(ABT& n, const PathKeep& keep) {
+    // if (isObject(x), keepFields(x, ...), x)
+    // Alternatively, we can implement a special builtin function that does the comparison and keep.
+ const std::string& name = _prefixId.getNextId("valKeep");
+
+ std::vector<ABT> params;
+ params.emplace_back(make<Variable>(name));
+ for (const auto& fieldName : keep.getNames()) {
+ params.emplace_back(Constant::str(fieldName));
+ }
+
+ n = make<LambdaAbstraction>(
+ name,
+ make<If>(make<FunctionCall>("isObject", makeSeq(make<Variable>(name))),
+ make<FunctionCall>("keepFields", std::move(params)),
+ make<Variable>(name)));
+ _changed = true;
+}
+
+void EvalPathLowering::transport(ABT& n, const PathObj&) {
+ // if (isObject(x), x, Nothing)
+ const std::string& name = _prefixId.getNextId("valObj");
+
+ n = make<LambdaAbstraction>(
+ name,
+ make<If>(make<FunctionCall>("isObject", makeSeq(make<Variable>(name))),
+ make<Variable>(name),
+ Constant::nothing()));
+ _changed = true;
+}
+
+void EvalPathLowering::transport(ABT& n, const PathArr&) {
+ // if (isArray(x), x, Nothing)
+ const std::string& name = _prefixId.getNextId("valArr");
+
+ n = make<LambdaAbstraction>(
+ name,
+ make<If>(make<FunctionCall>("isArray", makeSeq(make<Variable>(name))),
+ make<Variable>(name),
+ Constant::nothing()));
+
+ _changed = true;
+}
+
+void EvalPathLowering::transport(ABT& n, const PathTraverse&, ABT& inner) {
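+    // Lowered to the traverseP builtin: it applies the inner lambda to the input,
+    // element-wise when the input is an array.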
+ const std::string& name = _prefixId.getNextId("valTraverse");
+
+ n = make<LambdaAbstraction>(
+ name,
+ make<FunctionCall>("traverseP",
+ makeSeq(make<Variable>(name), std::exchange(inner, make<Blackhole>()))));
+ _changed = true;
+}
+
+void EvalPathLowering::transport(ABT& n, const PathField& p, ABT& inner) {
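+    // Lowered shape (a sketch of the construction below):
+    //   \input . let val = (inner (getField(input, "name")))
+    //            in if (exists(val) || isObject(input))
+    //               then setField(input, "name", val)
+    //               else input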
+ const std::string& name = _prefixId.getNextId("inputField");
+ const std::string& val = _prefixId.getNextId("valField");
+
+ n = make<LambdaAbstraction>(
+ name,
+ make<Let>(
+ val,
+ make<LambdaApplication>(
+ std::exchange(inner, make<Blackhole>()),
+ make<FunctionCall>("getField",
+ makeSeq(make<Variable>(name), Constant::str(p.name())))),
+ make<If>(
+ make<BinaryOp>(Operations::Or,
+ make<FunctionCall>("exists", makeSeq(make<Variable>(val))),
+ make<FunctionCall>("isObject", makeSeq(make<Variable>(name)))),
+ make<FunctionCall>(
+ "setField",
+ makeSeq(make<Variable>(name), Constant::str(p.name()), make<Variable>(val))),
+ make<Variable>(name))));
+
+ _changed = true;
+}
+
+void EvalPathLowering::transport(ABT& n, const PathComposeM&, ABT& p1, ABT& p2) {
+ // p1 * p2 -> (p2 (p1 input))
+ const std::string& name = _prefixId.getNextId("inputComposeM");
+
+ n = make<LambdaAbstraction>(
+ name,
+ make<LambdaApplication>(
+ std::exchange(p2, make<Blackhole>()),
+ make<LambdaApplication>(std::exchange(p1, make<Blackhole>()), make<Variable>(name))));
+
+ _changed = true;
+}
+
+void EvalPathLowering::transport(ABT& n, const PathComposeA&, ABT& p1, ABT& p2) {
+ uasserted(6624133, "cannot lower additive composite in projection");
+}
+
+void EvalPathLowering::transport(ABT& n, const EvalPath&, ABT& path, ABT& input) {
+ // In order to completely dissolve EvalPath the incoming path must be lowered to an expression
+ // (lambda).
+ uassert(6624134, "Incomplete evalpath lowering", path.is<LambdaAbstraction>());
+
+ n = make<LambdaApplication>(std::exchange(path, make<Blackhole>()),
+ std::exchange(input, make<Blackhole>()));
+
+ _changed = true;
+}
+
+bool EvalFilterLowering::optimize(ABT& n) {
+ _changed = false;
+
+ algebra::transport<true>(n, *this);
+
+ if (_changed) {
+ _env.rebuild(n);
+ }
+
+ return _changed;
+}
+
+void EvalFilterLowering::transport(ABT& n, const PathConstant&, ABT& c) {
+ n = make<LambdaAbstraction>(_prefixId.getNextId("_"), std::exchange(c, make<Blackhole>()));
+ _changed = true;
+}
+
+void EvalFilterLowering::transport(ABT& n, const PathIdentity&) {
+ n = make<LambdaAbstraction>(_prefixId.getNextId("_"), Constant::boolean(true));
+ _changed = true;
+
+    // TODO: do we need an identity element for the additive composition (i.e. a false
+    // constant)? Or should PathIdentity be left undefined and removed by PathFusion?
+}
+
+void EvalFilterLowering::transport(ABT& n, const PathLambda&, ABT& lam) {
+ n = std::exchange(lam, make<Blackhole>());
+ _changed = true;
+}
+
+void EvalFilterLowering::transport(ABT& n, const PathDefault&, ABT& c) {
+ uasserted(6624135, "cannot lower default in filter");
+}
+
+void EvalFilterLowering::transport(ABT& n, const PathCompare& cmp, ABT& c) {
+ const std::string& name = _prefixId.getNextId("valCmp");
+
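+    // Eq compares directly; every other operator is lowered through a three-way
+    // comparison, e.g. PathCompare [Lt] c -> \v . (cmp3w(v, c) < 0).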
+ if (cmp.op() == Operations::Eq) {
+ n = make<LambdaAbstraction>(
+ name,
+ make<BinaryOp>(cmp.op(), make<Variable>(name), std::exchange(c, make<Blackhole>())));
+ } else {
+ n = make<LambdaAbstraction>(
+ name,
+ make<BinaryOp>(cmp.op(),
+ make<BinaryOp>(Operations::Cmp3w,
+ make<Variable>(name),
+ std::exchange(c, make<Blackhole>())),
+ Constant::int64(0)));
+ }
+
+ _changed = true;
+}
+
+void EvalFilterLowering::transport(ABT& n, const PathGet& p, ABT& inner) {
+ const std::string& name = _prefixId.getNextId("inputGet");
+
+ int idx;
+ bool isNumber = NumberParser{}(p.name(), &idx).isOK();
+ n = make<LambdaAbstraction>(
+ name,
+ make<LambdaApplication>(
+ std::exchange(inner, make<Blackhole>()),
+ make<FunctionCall>(isNumber ? "getFieldOrElement" : "getField",
+ makeSeq(make<Variable>(name), Constant::str(p.name())))));
+ _changed = true;
+}
+
+void EvalFilterLowering::transport(ABT& n, const PathDrop& drop) {
+ uasserted(6624136, "cannot lower drop in filter");
+}
+
+void EvalFilterLowering::transport(ABT& n, const PathKeep& keep) {
+ uasserted(6624137, "cannot lower keep in filter");
+}
+
+void EvalFilterLowering::transport(ABT& n, const PathObj&) {
+ const std::string& name = _prefixId.getNextId("valObj");
+ n = make<LambdaAbstraction>(name,
+ make<FunctionCall>("isObject", makeSeq(make<Variable>(name))));
+ _changed = true;
+}
+
+void EvalFilterLowering::transport(ABT& n, const PathArr&) {
+ const std::string& name = _prefixId.getNextId("valArr");
+ n = make<LambdaAbstraction>(name, make<FunctionCall>("isArray", makeSeq(make<Variable>(name))));
+ _changed = true;
+}
+
+void EvalFilterLowering::prepare(ABT& n, const PathTraverse& t) {
+ int idx;
+    // This is a bad hack that detects whether the child is a numeric path element.
+ if (auto child = t.getPath().cast<PathGet>();
+ child && NumberParser{}(child->name(), &idx).isOK()) {
+ _traverseStack.emplace_back(n.ref());
+ }
+}
+
+void EvalFilterLowering::transport(ABT& n, const PathTraverse&, ABT& inner) {
+ const std::string& name = _prefixId.getNextId("valTraverse");
+
+ ABT numberPath = Constant::boolean(false);
+ if (!_traverseStack.empty() && _traverseStack.back() == n.ref()) {
+ numberPath = Constant::boolean(true);
+ _traverseStack.pop_back();
+ }
+ n = make<LambdaAbstraction>(name,
+ make<FunctionCall>("traverseF",
+ makeSeq(make<Variable>(name),
+ std::exchange(inner, make<Blackhole>()),
+ std::move(numberPath))));
+
+ _changed = true;
+}
+
+void EvalFilterLowering::transport(ABT& n, const PathField& p, ABT& inner) {
+    uasserted(6624140, "cannot lower field in filter");
+}
+
+void EvalFilterLowering::transport(ABT& n, const PathComposeM&, ABT& p1, ABT& p2) {
+ const std::string& name = _prefixId.getNextId("inputComposeM");
+
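+    // Multiplicative composition in a filter behaves like a short-circuiting AND: p2 is
+    // evaluated only if p1 holds, and fillEmpty treats a Nothing result from p1 as false.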
+ n = make<LambdaAbstraction>(
+ name,
+ make<If>(
+ make<FunctionCall>("fillEmpty",
+ makeSeq(make<LambdaApplication>(std::exchange(p1, make<Blackhole>()),
+ make<Variable>(name)),
+ Constant::boolean(false))),
+ make<LambdaApplication>(std::exchange(p2, make<Blackhole>()), make<Variable>(name)),
+ Constant::boolean(false)));
+
+ _changed = true;
+}
+
+void EvalFilterLowering::transport(ABT& n, const PathComposeA&, ABT& p1, ABT& p2) {
+ const std::string& name = _prefixId.getNextId("inputComposeA");
+
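+    // Additive composition in a filter behaves like a short-circuiting OR: if p1 holds
+    // the result is true, otherwise the result is p2; fillEmpty treats a Nothing result
+    // from p1 as false.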
+ n = make<LambdaAbstraction>(
+ name,
+ make<If>(
+ make<FunctionCall>("fillEmpty",
+ makeSeq(make<LambdaApplication>(std::exchange(p1, make<Blackhole>()),
+ make<Variable>(name)),
+ Constant::boolean(false))),
+ Constant::boolean(true),
+ make<LambdaApplication>(std::exchange(p2, make<Blackhole>()), make<Variable>(name))));
+
+ _changed = true;
+}
+
+void EvalFilterLowering::transport(ABT& n, const EvalFilter&, ABT& path, ABT& input) {
+ // In order to completely dissolve EvalFilter the incoming path must be lowered to an expression
+ // (lambda).
+ uassert(6624141, "Incomplete evalfilter lowering", path.is<LambdaAbstraction>());
+
+ n = make<LambdaApplication>(std::exchange(path, make<Blackhole>()),
+ std::exchange(input, make<Blackhole>()));
+
+ _changed = true;
+}
+
+void PathLowering::transport(ABT& n, const EvalPath&, ABT&, ABT&) {
+ _changed = _changed || _project.optimize(n);
+}
+
+void PathLowering::transport(ABT& n, const EvalFilter&, ABT&, ABT&) {
+ _changed = _changed || _filter.optimize(n);
+}
+
+bool PathLowering::optimize(ABT& n) {
+ _changed = false;
+
+ algebra::transport<true>(n, *this);
+
+ // TODO investigate why we crash when this is removed. It should not be needed here.
+ if (_changed) {
+ _env.rebuild(n);
+ }
+
+ return _changed;
+}
+
+
+} // namespace mongo::optimizer
diff --git a/src/mongo/db/query/optimizer/rewrites/path_lower.h b/src/mongo/db/query/optimizer/rewrites/path_lower.h
new file mode 100644
index 00000000000..59df0f50050
--- /dev/null
+++ b/src/mongo/db/query/optimizer/rewrites/path_lower.h
@@ -0,0 +1,158 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include "mongo/db/query/optimizer/reference_tracker.h"
+#include "mongo/db/query/optimizer/utils/utils.h"
+
+
+namespace mongo::optimizer {
+/**
+ * This class lowers projection paths (aka EvalPath) to simple expressions.
+ */
+class EvalPathLowering {
+public:
+ EvalPathLowering(PrefixId& prefixId, VariableEnvironment& env)
+ : _prefixId(prefixId), _env(env) {}
+
+ // The default noop transport.
+ template <typename T, typename... Ts>
+ void transport(ABT&, const T&, Ts&&...) {
+ static_assert(!std::is_base_of_v<PathSyntaxSort, T>,
+ "Path elements must define their transport");
+ }
+
+ void transport(ABT& n, const PathConstant&, ABT& c);
+ void transport(ABT& n, const PathIdentity&);
+ void transport(ABT& n, const PathLambda&, ABT& lam);
+ void transport(ABT& n, const PathDefault&, ABT& c);
+ void transport(ABT& n, const PathCompare&, ABT& c);
+ void transport(ABT& n, const PathDrop&);
+ void transport(ABT& n, const PathKeep&);
+ void transport(ABT& n, const PathObj&);
+ void transport(ABT& n, const PathArr&);
+
+ void transport(ABT& n, const PathTraverse&, ABT& inner);
+
+ void transport(ABT& n, const PathGet&, ABT& inner);
+ void transport(ABT& n, const PathField&, ABT& inner);
+
+ void transport(ABT& n, const PathComposeM&, ABT& p1, ABT& p2);
+ void transport(ABT& n, const PathComposeA&, ABT& p1, ABT& p2);
+
+ void transport(ABT& n, const EvalPath&, ABT& path, ABT& input);
+
+    // The tree is passed in as a NON-const reference because we will be updating it.
+ // Returns true if the tree changed.
+ bool optimize(ABT& n);
+
+private:
+ // We don't own these.
+ PrefixId& _prefixId;
+ VariableEnvironment& _env;
+
+ bool _changed{false};
+};
+
+/**
+ * This class lowers match/filter paths (aka EvalFilter) to simple expressions.
+ */
+class EvalFilterLowering {
+public:
+ EvalFilterLowering(PrefixId& prefixId, VariableEnvironment& env)
+ : _prefixId(prefixId), _env(env) {}
+
+ // The default noop transport.
+ template <typename T, typename... Ts>
+ void transport(ABT&, const T&, Ts&&...) {
+ static_assert(!std::is_base_of_v<PathSyntaxSort, T>,
+ "Path elements must define their transport");
+ }
+
+ void transport(ABT& n, const PathConstant&, ABT& c);
+ void transport(ABT& n, const PathIdentity&);
+ void transport(ABT& n, const PathLambda&, ABT& lam);
+ void transport(ABT& n, const PathDefault&, ABT& c);
+ void transport(ABT& n, const PathCompare&, ABT& c);
+ void transport(ABT& n, const PathDrop&);
+ void transport(ABT& n, const PathKeep&);
+ void transport(ABT& n, const PathObj&);
+ void transport(ABT& n, const PathArr&);
+
+ void prepare(ABT& n, const PathTraverse& t);
+ void transport(ABT& n, const PathTraverse&, ABT& inner);
+
+ void transport(ABT& n, const PathGet&, ABT& inner);
+ void transport(ABT& n, const PathField&, ABT& inner);
+
+ void transport(ABT& n, const PathComposeM&, ABT& p1, ABT& p2);
+ void transport(ABT& n, const PathComposeA&, ABT& p1, ABT& p2);
+
+ void transport(ABT& n, const EvalFilter&, ABT& path, ABT& input);
+
+    // The tree is passed in as a NON-const reference because we will be updating it.
+ // Returns true if the tree changed.
+ bool optimize(ABT& n);
+
+private:
+ // We don't own these.
+ PrefixId& _prefixId;
+ VariableEnvironment& _env;
+
+ std::vector<ABT::reference_type> _traverseStack;
+
+ bool _changed{false};
+};
+
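+/**
+ * Top-level lowering pass: dispatches EvalPath nodes to EvalPathLowering and EvalFilter
+ * nodes to EvalFilterLowering.
+ */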
+class PathLowering {
+public:
+ PathLowering(PrefixId& prefixId, VariableEnvironment& env)
+ : _prefixId(prefixId), _env(env), _project(_prefixId, _env), _filter(_prefixId, _env) {}
+
+ // The default noop transport.
+ template <typename T, typename... Ts>
+ void transport(ABT&, const T&, Ts&&...) {}
+
+ void transport(ABT& n, const EvalPath&, ABT&, ABT&);
+ void transport(ABT& n, const EvalFilter&, ABT&, ABT&);
+
+ bool optimize(ABT& n);
+
+private:
+ // We don't own these.
+ PrefixId& _prefixId;
+ VariableEnvironment& _env;
+
+ EvalPathLowering _project;
+ EvalFilterLowering _filter;
+
+ bool _changed{false};
+};
+} // namespace mongo::optimizer
diff --git a/src/mongo/db/query/optimizer/rewrites/path_optimizer_test.cpp b/src/mongo/db/query/optimizer/rewrites/path_optimizer_test.cpp
new file mode 100644
index 00000000000..23f2c1e8899
--- /dev/null
+++ b/src/mongo/db/query/optimizer/rewrites/path_optimizer_test.cpp
@@ -0,0 +1,881 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/db/query/optimizer/explain.h"
+#include "mongo/db/query/optimizer/rewrites/const_eval.h"
+#include "mongo/db/query/optimizer/rewrites/path.h"
+#include "mongo/db/query/optimizer/rewrites/path_lower.h"
+#include "mongo/db/query/optimizer/utils/unit_test_utils.h"
+#include "mongo/unittest/unittest.h"
+
+namespace mongo::optimizer {
+namespace {
+TEST(Path, Const) {
+ auto tree = make<EvalPath>(make<PathConstant>(Constant::int64(2)), make<Variable>("ptest"));
+ auto env = VariableEnvironment::build(tree);
+
+ auto fusor = PathFusion(env);
+ fusor.optimize(tree);
+
+ // The result must be Constant.
+ auto result = tree.cast<Constant>();
+ ASSERT(result != nullptr);
+
+ // And the value must be 2
+ ASSERT_EQ(result->getValueInt64(), 2);
+}
+
+TEST(Path, GetConst) {
+ // Get "any" Const 2
+ auto tree = make<EvalPath>(make<PathGet>("any", make<PathConstant>(Constant::int64(2))),
+ make<Variable>("ptest"));
+ auto env = VariableEnvironment::build(tree);
+
+ auto fusor = PathFusion(env);
+ fusor.optimize(tree);
+
+ // The result must be Constant.
+ auto result = tree.cast<Constant>();
+ ASSERT(result != nullptr);
+
+ // And the value must be 2
+ ASSERT_EQ(result->getValueInt64(), 2);
+}
+
+TEST(Path, Fuse1) {
+ // Field "a" Const 2
+ auto field = make<EvalPath>(make<PathField>("a", make<PathConstant>(Constant::int64(2))),
+ make<Variable>("root"));
+
+ // Get "a" Id
+ auto get = make<EvalPath>(make<PathGet>("a", make<PathIdentity>()), make<Variable>("x"));
+
+ // let x = (Field "a" Const 2 | root)
+ // in (Get "a" Id | x)
+ auto tree = make<Let>("x", std::move(field), std::move(get));
+ auto env = VariableEnvironment::build(tree);
+
+ // Run rewriters while things change
+ bool changed = false;
+ do {
+ changed = false;
+ if (PathFusion{env}.optimize(tree)) {
+ changed = true;
+ }
+
+ if (ConstEval{env}.optimize(tree)) {
+ changed = true;
+ }
+
+ } while (changed);
+
+ // The result must be Constant.
+ auto result = tree.cast<Constant>();
+ ASSERT(result != nullptr);
+
+ // And the value must be 2
+ ASSERT_EQ(result->getValueInt64(), 2);
+}
+
+TEST(Path, Fuse2) {
+ auto scanNode = make<ScanNode>("root", "test");
+
+ // Field "a" Const 2
+ auto field = make<EvalPath>(make<PathField>("a", make<PathConstant>(Constant::int64(2))),
+ make<Variable>("root"));
+
+ auto project1 = make<EvaluationNode>("x", std::move(field), std::move(scanNode));
+
+ // Get "a" Id
+ auto get = make<EvalPath>(make<PathGet>("a", make<PathIdentity>()), make<Variable>("x"));
+ auto project2 = make<EvaluationNode>("y", std::move(get), std::move(project1));
+
+ auto tree = make<RootNode>(properties::ProjectionRequirement{ProjectionNameVector{"y"}},
+ std::move(project2));
+
+ auto env = VariableEnvironment::build(tree);
+ {
+ ProjectionNameSet projSet = env.topLevelProjections();
+ ProjectionNameSet expSet = {"x", "y", "root"};
+ ASSERT(expSet == projSet);
+ ASSERT(!env.hasFreeVariables());
+ }
+
+ // Run rewriters while things change
+ bool changed = false;
+ do {
+ changed = false;
+ if (PathFusion{env}.optimize(tree)) {
+ changed = true;
+ }
+
+ if (ConstEval{env}.optimize(tree)) {
+ changed = true;
+ }
+
+ } while (changed);
+
+    // After the rewrites, the projection 'x' disappears from the tree.
+ {
+ ProjectionNameSet projSet = env.topLevelProjections();
+ ProjectionNameSet expSet = {"root", "y"};
+ ASSERT(expSet == projSet);
+ ASSERT(!env.hasFreeVariables());
+ }
+}
+
+TEST(Path, Fuse3) {
+ auto scanNode = make<ScanNode>("root", "test");
+
+ auto project0 = make<EvaluationNode>(
+ "z",
+ make<EvalPath>(make<PathGet>("z", make<PathIdentity>()), make<Variable>("root")),
+ std::move(scanNode));
+
+ // Field "a" Const Var "z"
+ auto field = make<EvalPath>(make<PathField>("a", make<PathConstant>(make<Variable>("z"))),
+ make<Variable>("root"));
+ auto project1 = make<EvaluationNode>("x", std::move(field), std::move(project0));
+
+ // Get "a" Traverse Const 2
+ auto get = make<EvalPath>(
+ make<PathGet>("a", make<PathTraverse>(make<PathConstant>(Constant::int64(2)))),
+ make<Variable>("x"));
+ auto project2 = make<EvaluationNode>("y", std::move(get), std::move(project1));
+
+ auto tree = make<RootNode>(properties::ProjectionRequirement{ProjectionNameVector{"y"}},
+ std::move(project2));
+ ASSERT_EXPLAIN(
+ "Root []\n"
+ " projections: \n"
+ " y\n"
+ " RefBlock: \n"
+ " Variable [y]\n"
+ " Evaluation []\n"
+ " BindBlock:\n"
+ " [y]\n"
+ " EvalPath []\n"
+ " PathGet [a]\n"
+ " PathTraverse []\n"
+ " PathConstant []\n"
+ " Const [2]\n"
+ " Variable [x]\n"
+ " Evaluation []\n"
+ " BindBlock:\n"
+ " [x]\n"
+ " EvalPath []\n"
+ " PathField [a]\n"
+ " PathConstant []\n"
+ " Variable [z]\n"
+ " Variable [root]\n"
+ " Evaluation []\n"
+ " BindBlock:\n"
+ " [z]\n"
+ " EvalPath []\n"
+ " PathGet [z]\n"
+ " PathIdentity []\n"
+ " Variable [root]\n"
+ " Scan [test]\n"
+ " BindBlock:\n"
+ " [root]\n"
+ " Source []\n",
+ tree);
+
+ auto env = VariableEnvironment::build(tree);
+ bool changed = false;
+ do {
+ changed = false;
+ if (PathFusion{env}.optimize(tree)) {
+ changed = true;
+ }
+ if (ConstEval{env}.optimize(tree)) {
+ changed = true;
+ }
+ } while (changed);
+
+ ASSERT_EXPLAIN(
+ "Root []\n"
+ " projections: \n"
+ " y\n"
+ " RefBlock: \n"
+ " Variable [y]\n"
+ " Evaluation []\n"
+ " BindBlock:\n"
+ " [y]\n"
+ " EvalPath []\n"
+ " PathGet [z]\n"
+ " PathTraverse []\n"
+ " PathConstant []\n"
+ " Const [2]\n"
+ " Variable [root]\n"
+ " Scan [test]\n"
+ " BindBlock:\n"
+ " [root]\n"
+ " Source []\n",
+ tree);
+}
+
+TEST(Path, Fuse4) {
+ auto scanNode = make<ScanNode>("root", "test");
+
+ auto project0 = make<EvaluationNode>(
+ "z",
+ make<EvalPath>(make<PathGet>("z", make<PathIdentity>()), make<Variable>("root")),
+ std::move(scanNode));
+
+ auto project01 = make<EvaluationNode>(
+ "z1",
+ make<EvalPath>(make<PathGet>("z1", make<PathIdentity>()), make<Variable>("root")),
+ std::move(project0));
+
+ auto project02 = make<EvaluationNode>(
+ "z2",
+ make<EvalPath>(make<PathGet>("z2", make<PathIdentity>()), make<Variable>("root")),
+ std::move(project01));
+
+ // Field "a" Const Var "z" * Field "b" Const Var "z1"
+ auto field = make<EvalPath>(
+ make<PathComposeM>(
+ make<PathField>("c", make<PathConstant>(make<Variable>("z2"))),
+ make<PathComposeM>(make<PathField>("a", make<PathConstant>(make<Variable>("z"))),
+ make<PathField>("b", make<PathConstant>(make<Variable>("z1"))))),
+ make<Variable>("root"));
+ auto project1 = make<EvaluationNode>("x", std::move(field), std::move(project02));
+
+ // Get "a" Traverse Const 2
+ auto get = make<EvalPath>(
+ make<PathGet>("a", make<PathTraverse>(make<PathConstant>(Constant::int64(2)))),
+ make<Variable>("x"));
+ auto project2 = make<EvaluationNode>("y", std::move(get), std::move(project1));
+
+ auto tree = make<RootNode>(properties::ProjectionRequirement{ProjectionNameVector{"x", "y"}},
+ std::move(project2));
+
+ ASSERT_EXPLAIN(
+ "Root []\n"
+ " projections: \n"
+ " x\n"
+ " y\n"
+ " RefBlock: \n"
+ " Variable [x]\n"
+ " Variable [y]\n"
+ " Evaluation []\n"
+ " BindBlock:\n"
+ " [y]\n"
+ " EvalPath []\n"
+ " PathGet [a]\n"
+ " PathTraverse []\n"
+ " PathConstant []\n"
+ " Const [2]\n"
+ " Variable [x]\n"
+ " Evaluation []\n"
+ " BindBlock:\n"
+ " [x]\n"
+ " EvalPath []\n"
+ " PathComposeM []\n"
+ " PathField [c]\n"
+ " PathConstant []\n"
+ " Variable [z2]\n"
+ " PathComposeM []\n"
+ " PathField [a]\n"
+ " PathConstant []\n"
+ " Variable [z]\n"
+ " PathField [b]\n"
+ " PathConstant []\n"
+ " Variable [z1]\n"
+ " Variable [root]\n"
+ " Evaluation []\n"
+ " BindBlock:\n"
+ " [z2]\n"
+ " EvalPath []\n"
+ " PathGet [z2]\n"
+ " PathIdentity []\n"
+ " Variable [root]\n"
+ " Evaluation []\n"
+ " BindBlock:\n"
+ " [z1]\n"
+ " EvalPath []\n"
+ " PathGet [z1]\n"
+ " PathIdentity []\n"
+ " Variable [root]\n"
+ " Evaluation []\n"
+ " BindBlock:\n"
+ " [z]\n"
+ " EvalPath []\n"
+ " PathGet [z]\n"
+ " PathIdentity []\n"
+ " Variable [root]\n"
+ " Scan [test]\n"
+ " BindBlock:\n"
+ " [root]\n"
+ " Source []\n",
+ tree);
+
+ auto env = VariableEnvironment::build(tree);
+ bool changed = false;
+ do {
+ changed = false;
+ if (PathFusion{env}.optimize(tree)) {
+ changed = true;
+ }
+ if (ConstEval{env}.optimize(tree)) {
+ changed = true;
+ }
+ } while (changed);
+
+ ASSERT_EXPLAIN(
+ "Root []\n"
+ " projections: \n"
+ " x\n"
+ " y\n"
+ " RefBlock: \n"
+ " Variable [x]\n"
+ " Variable [y]\n"
+ " Evaluation []\n"
+ " BindBlock:\n"
+ " [y]\n"
+ " EvalPath []\n"
+ " PathTraverse []\n"
+ " PathConstant []\n"
+ " Const [2]\n"
+ " Variable [z]\n"
+ " Evaluation []\n"
+ " BindBlock:\n"
+ " [x]\n"
+ " EvalPath []\n"
+ " PathComposeM []\n"
+ " PathField [c]\n"
+ " PathConstant []\n"
+ " EvalPath []\n"
+ " PathGet [z2]\n"
+ " PathIdentity []\n"
+ " Variable [root]\n"
+ " PathComposeM []\n"
+ " PathField [a]\n"
+ " PathConstant []\n"
+ " Variable [z]\n"
+ " PathField [b]\n"
+ " PathConstant []\n"
+ " EvalPath []\n"
+ " PathGet [z1]\n"
+ " PathIdentity []\n"
+ " Variable [root]\n"
+ " Variable [root]\n"
+ " Evaluation []\n"
+ " BindBlock:\n"
+ " [z]\n"
+ " EvalPath []\n"
+ " PathGet [z]\n"
+ " PathIdentity []\n"
+ " Variable [root]\n"
+ " Scan [test]\n"
+ " BindBlock:\n"
+ " [root]\n"
+ " Source []\n",
+ tree);
+}
+
+TEST(Path, Fuse5) {
+ auto scanNode = make<ScanNode>("root", "test");
+
+ auto project = make<EvaluationNode>(
+ "x",
+ make<EvalPath>(make<PathKeep>(PathKeep::NameSet{"a", "b", "c"}), make<Variable>("root")),
+ std::move(scanNode));
+
+ // Get "a" Traverse Compare= 2
+ auto filter = make<FilterNode>(
+ make<EvalFilter>(
+ make<PathGet>(
+ "a", make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(2)))),
+ make<Variable>("x")),
+ std::move(project));
+
+ auto tree = make<RootNode>(properties::ProjectionRequirement{ProjectionNameVector{"x"}},
+ std::move(filter));
+
+ ASSERT_EXPLAIN(
+ "Root []\n"
+ " projections: \n"
+ " x\n"
+ " RefBlock: \n"
+ " Variable [x]\n"
+ " Filter []\n"
+ " EvalFilter []\n"
+ " PathGet [a]\n"
+ " PathTraverse []\n"
+ " PathCompare [Eq]\n"
+ " Const [2]\n"
+ " Variable [x]\n"
+ " Evaluation []\n"
+ " BindBlock:\n"
+ " [x]\n"
+ " EvalPath []\n"
+ " PathKeep [a, b, c]\n"
+ " Variable [root]\n"
+ " Scan [test]\n"
+ " BindBlock:\n"
+ " [root]\n"
+ " Source []\n",
+ tree);
+
+ auto env = VariableEnvironment::build(tree);
+ bool changed = false;
+ do {
+ changed = false;
+ if (PathFusion{env}.optimize(tree)) {
+ changed = true;
+ }
+ if (ConstEval{env}.optimize(tree)) {
+ changed = true;
+ }
+ } while (changed);
+
+ // The filter now refers directly to the root projection.
+ ASSERT_EXPLAIN(
+ "Root []\n"
+ " projections: \n"
+ " x\n"
+ " RefBlock: \n"
+ " Variable [x]\n"
+ " Filter []\n"
+ " EvalFilter []\n"
+ " PathGet [a]\n"
+ " PathTraverse []\n"
+ " PathCompare [Eq]\n"
+ " Const [2]\n"
+ " Variable [root]\n"
+ " Evaluation []\n"
+ " BindBlock:\n"
+ " [x]\n"
+ " EvalPath []\n"
+ " PathKeep [a, b, c]\n"
+ " Variable [root]\n"
+ " Scan [test]\n"
+ " BindBlock:\n"
+ " [root]\n"
+ " Source []\n",
+ tree);
+}
+
+TEST(Path, Lower1) {
+ PrefixId prefixId;
+
+ auto tree = make<EvalPath>(make<PathIdentity>(), make<Variable>("foo"));
+ auto env = VariableEnvironment::build(tree);
+
+ // Run rewriters while things change
+ bool changed = false;
+ do {
+ changed = false;
+ if (PathLowering{prefixId, env}.optimize(tree)) {
+ changed = true;
+ }
+ if (ConstEval{env}.optimize(tree)) {
+ changed = true;
+ }
+ } while (changed);
+
+ ASSERT(tree.is<Variable>());
+ ASSERT_EQ(tree.cast<Variable>()->name(), "foo");
+}
+
+TEST(Path, Lower2) {
+ PrefixId prefixId;
+
+ auto tree = make<EvalPath>(make<PathConstant>(Constant::int64(10)), make<Variable>("foo"));
+ auto env = VariableEnvironment::build(tree);
+
+ // Run rewriters while things change
+ bool changed = false;
+ do {
+ changed = false;
+ if (PathLowering{prefixId, env}.optimize(tree)) {
+ changed = true;
+ }
+ if (ConstEval{env}.optimize(tree)) {
+ changed = true;
+ }
+ } while (changed);
+
+ ASSERT(tree.is<Constant>());
+ ASSERT_EQ(tree.cast<Constant>()->getValueInt64(), 10);
+}
+
+TEST(Path, Lower3) {
+ PrefixId prefixId;
+
+ auto tree = make<EvalPath>(
+ make<PathLambda>(make<LambdaAbstraction>(
+ "x", make<BinaryOp>(Operations::Add, make<Variable>("x"), Constant::int64(1)))),
+ Constant::int64(9));
+ auto env = VariableEnvironment::build(tree);
+
+ // Run rewriters while things change
+ bool changed = false;
+ do {
+ changed = false;
+ if (PathLowering{prefixId, env}.optimize(tree)) {
+ changed = true;
+ }
+ if (ConstEval{env}.optimize(tree)) {
+ changed = true;
+ }
+ } while (changed);
+
+ ASSERT(tree.is<Constant>());
+ ASSERT_EQ(tree.cast<Constant>()->getValueInt64(), 10);
+}
+
+TEST(Path, Lower4) {
+ PrefixId prefixId;
+
+ auto tree = make<EvalPath>(
+ make<PathGet>(
+ "fieldA",
+ make<PathGet>("fieldB",
+ /*make<PathIdentity>()*/ make<PathConstant>(Constant::int64(100)))),
+ make<Variable>("rootObj"));
+ auto env = VariableEnvironment::build(tree);
+
+ // Run rewriters while things change
+ bool changed = false;
+ do {
+ changed = false;
+ if (PathLowering{prefixId, env}.optimize(tree)) {
+ changed = true;
+ }
+ if (ConstEval{env}.optimize(tree)) {
+ changed = true;
+ }
+ } while (changed);
+
+ ASSERT(tree.is<Constant>());
+ ASSERT_EQ(tree.cast<Constant>()->getValueInt64(), 100);
+}
+
+TEST(Path, Lower5) {
+ PrefixId prefixId;
+
+ auto tree = make<EvalPath>(
+ make<PathGet>(
+ "fieldA",
+ make<PathGet>(
+ "fieldB",
+ make<PathLambda>(make<LambdaAbstraction>(
+ "x",
+ make<BinaryOp>(Operations::Add, make<Variable>("x"), Constant::int64(1)))))),
+ make<Variable>("rootObj"));
+ auto env = VariableEnvironment::build(tree);
+
+ // Run rewriters while things change
+ bool changed = false;
+ do {
+ changed = false;
+ if (PathLowering{prefixId, env}.optimize(tree)) {
+ changed = true;
+ }
+ if (ConstEval{env}.optimize(tree)) {
+ changed = true;
+ }
+ } while (changed);
+
+ ASSERT_EXPLAIN(
+ "BinaryOp [Add]\n"
+ " FunctionCall [getField]\n"
+ " FunctionCall [getField]\n"
+ " Variable [rootObj]\n"
+ " Const [\"fieldA\"]\n"
+ " Const [\"fieldB\"]\n"
+ " Const [1]\n",
+ tree);
+}
+
+TEST(Path, ProjElim1) {
+ PrefixId prefixId;
+
+ auto scanNode = make<ScanNode>("root", "test");
+
+ auto expr1 = make<FunctionCall>("anyFunctionWillDo", makeSeq(make<Variable>("root")));
+ auto project1 = make<EvaluationNode>("x", std::move(expr1), std::move(scanNode));
+
+ auto expr2 = make<Variable>("x");
+ auto project2 = make<EvaluationNode>("y", std::move(expr2), std::move(project1));
+
+ auto tree = make<RootNode>(properties::ProjectionRequirement{ProjectionNameVector{"y"}},
+ std::move(project2));
+
+ auto env = VariableEnvironment::build(tree);
+
+ // Run rewriters while things change
+ bool changed = false;
+ do {
+ changed = false;
+ if (PathLowering{prefixId, env}.optimize(tree)) {
+ changed = true;
+ }
+ if (ConstEval{env}.optimize(tree)) {
+ changed = true;
+ }
+ } while (changed);
+
+ ASSERT_EXPLAIN(
+ "Root []\n"
+ " projections: \n"
+ " y\n"
+ " RefBlock: \n"
+ " Variable [y]\n"
+ " Evaluation []\n"
+ " BindBlock:\n"
+ " [y]\n"
+ " FunctionCall [anyFunctionWillDo]\n"
+ " Variable [root]\n"
+ " Scan [test]\n"
+ " BindBlock:\n"
+ " [root]\n"
+ " Source []\n",
+ tree);
+}
+
+TEST(Path, ProjElim2) {
+ PrefixId prefixId;
+
+ auto scanNode = make<ScanNode>("root", "test");
+
+ auto expr1 = make<FunctionCall>("anyFunctionWillDo", makeSeq(make<Variable>("root")));
+ auto project1 = make<EvaluationNode>("x", std::move(expr1), std::move(scanNode));
+
+ auto expr2 = make<Variable>("x");
+ auto project2 = make<EvaluationNode>("y", std::move(expr2), std::move(project1));
+
+ auto tree = make<RootNode>(properties::ProjectionRequirement{{}}, std::move(project2));
+
+ auto env = VariableEnvironment::build(tree);
+
+ // Run rewriters while things change
+ bool changed = false;
+ do {
+ changed = false;
+ if (PathLowering{prefixId, env}.optimize(tree)) {
+ changed = true;
+ }
+ if (ConstEval{env}.optimize(tree)) {
+ changed = true;
+ }
+ } while (changed);
+
+ ASSERT_EXPLAIN(
+ "Root []\n"
+ " projections: \n"
+ " RefBlock: \n"
+ " Scan [test]\n"
+ " BindBlock:\n"
+ " [root]\n"
+ " Source []\n",
+ tree);
+}
+
+TEST(Path, ProjElim3) {
+ auto node = make<ScanNode>("root", "test");
+ std::string var = "root";
+ for (int i = 0; i < 100; ++i) {
+ std::string newVar = "p" + std::to_string(i);
+ node = make<EvaluationNode>(
+ newVar,
+ // make<FunctionCall>("anyFunctionWillDo", makeSeq(make<Variable>(var))),
+ make<Variable>(var),
+ std::move(node));
+ var = newVar;
+ }
+
+ auto tree = make<RootNode>(properties::ProjectionRequirement{ProjectionNameVector{var}},
+ std::move(node));
+
+ auto env = VariableEnvironment::build(tree);
+
+ // Run rewriters while things change
+ bool changed = false;
+ do {
+ changed = false;
+ if (ConstEval{env}.optimize(tree)) {
+ changed = true;
+ }
+ } while (changed);
+
+ ASSERT_EXPLAIN(
+ "Root []\n"
+ " projections: \n"
+ " p99\n"
+ " RefBlock: \n"
+ " Variable [p99]\n"
+ " Evaluation []\n"
+ " BindBlock:\n"
+ " [p99]\n"
+ " Variable [root]\n"
+ " Scan [test]\n"
+ " BindBlock:\n"
+ " [root]\n"
+ " Source []\n",
+ tree);
+}
+
+TEST(Path, Lower6) {
+ PrefixId prefixId;
+
+ auto tree = make<EvalPath>(
+ make<PathGet>("fieldA", make<PathGet>("fieldB", make<PathDefault>(Constant::int64(0)))),
+ make<Variable>("rootObj"));
+ auto env = VariableEnvironment::build(tree);
+
+ // Run rewriters while things change
+ bool changed = false;
+ do {
+ changed = false;
+ if (PathLowering{prefixId, env}.optimize(tree)) {
+ changed = true;
+ }
+ if (ConstEval{env}.optimize(tree)) {
+ changed = true;
+ }
+ } while (changed);
+
+ ASSERT_EXPLAIN(
+ "Let [valDefault_0]\n"
+ " FunctionCall [getField]\n"
+ " FunctionCall [getField]\n"
+ " Variable [rootObj]\n"
+ " Const [\"fieldA\"]\n"
+ " Const [\"fieldB\"]\n"
+ " If []\n"
+ " FunctionCall [exists]\n"
+ " Variable [valDefault_0]\n"
+ " Variable [valDefault_0]\n"
+ " Const [0]\n",
+ tree);
+}
+
+TEST(Path, Lower7) {
+ PrefixId prefixId;
+
+ auto tree = make<EvalPath>(make<PathGet>("fieldA",
+ make<PathTraverse>(make<PathGet>(
+ "fieldB", make<PathDefault>(Constant::int64(0))))),
+ make<Variable>("rootObj"));
+ auto env = VariableEnvironment::build(tree);
+
+ // Run rewriters while things change
+ bool changed = false;
+ do {
+ changed = false;
+ if (PathLowering{prefixId, env}.optimize(tree)) {
+ changed = true;
+ }
+ if (ConstEval{env}.optimize(tree)) {
+ changed = true;
+ }
+ } while (changed);
+
+ // TODO: assert on the shape of the lowered tree.
+}
+
+TEST(Path, Lower8) {
+ PrefixId prefixId;
+
+ auto tree = make<EvalPath>(
+ make<PathComposeM>(make<PathIdentity>(), make<PathConstant>(Constant::int64(100))),
+ make<Variable>("rootObj"));
+ auto env = VariableEnvironment::build(tree);
+
+ // Run rewriters while things change
+ bool changed = false;
+ do {
+ changed = false;
+ if (PathLowering{prefixId, env}.optimize(tree)) {
+ changed = true;
+ }
+ if (ConstEval{env}.optimize(tree)) {
+ changed = true;
+ }
+ } while (changed);
+
+ ASSERT(tree.is<Constant>());
+ ASSERT_EQ(tree.cast<Constant>()->getValueInt64(), 100);
+}
+
+TEST(Path, Lower9) {
+ PrefixId prefixId;
+
+ auto tree = make<EvalPath>(make<PathComposeM>(make<PathGet>("fieldA", make<PathIdentity>()),
+ make<PathConstant>(Constant::int64(100))),
+ make<Variable>("rootObj"));
+ auto env = VariableEnvironment::build(tree);
+
+ // Run rewriters while things change
+ bool changed = false;
+ do {
+ changed = false;
+ if (PathLowering{prefixId, env}.optimize(tree)) {
+ changed = true;
+ }
+ if (ConstEval{env}.optimize(tree)) {
+ changed = true;
+ }
+ } while (changed);
+
+ ASSERT(tree.is<Constant>());
+ ASSERT_EQ(tree.cast<Constant>()->getValueInt64(), 100);
+}
+
+TEST(Path, Lower10) {
+ PrefixId prefixId;
+
+ auto tree = make<EvalPath>(
+ make<PathField>(
+ "fieldA",
+ make<PathTraverse>(make<PathField>("fieldB", make<PathDefault>(Constant::int64(0))))),
+ make<Variable>("rootObj"));
+ auto env = VariableEnvironment::build(tree);
+
+ // Run rewriters while things change
+ bool changed = false;
+ do {
+ changed = false;
+ if (PathLowering{prefixId, env}.optimize(tree)) {
+ changed = true;
+ }
+ if (ConstEval{env}.optimize(tree)) {
+ changed = true;
+ }
+ } while (changed);
+
+ // TODO: assert on the shape of the lowered tree.
+}
+
+} // namespace
+} // namespace mongo::optimizer
diff --git a/src/mongo/db/query/optimizer/syntax/expr.cpp b/src/mongo/db/query/optimizer/syntax/expr.cpp
new file mode 100644
index 00000000000..1dcd85de61a
--- /dev/null
+++ b/src/mongo/db/query/optimizer/syntax/expr.cpp
@@ -0,0 +1,128 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/db/query/optimizer/syntax/expr.h"
+#include "mongo/db/query/optimizer/node.h"
+
+namespace mongo::optimizer {
+
+using namespace sbe::value;
+
+Constant::Constant(TypeTags tag, sbe::value::Value val) : _tag(tag), _val(val) {}
+
+Constant::Constant(const Constant& other) {
+ auto [tag, val] = copyValue(other._tag, other._val);
+ _tag = tag;
+ _val = val;
+}
+
+Constant::Constant(Constant&& other) noexcept {
+ _tag = other._tag;
+ _val = other._val;
+
+ other._tag = TypeTags::Nothing;
+ other._val = 0;
+}
+
+ABT Constant::str(std::string str) {
+ // Views are non-owning so we have to make a copy.
+ auto [tag, val] = makeNewString(str);
+ return make<Constant>(tag, val);
+}
+
+ABT Constant::int32(int32_t valueInt32) {
+ return make<Constant>(TypeTags::NumberInt32, bitcastFrom<int32_t>(valueInt32));
+}
+
+ABT Constant::int64(int64_t valueInt64) {
+ return make<Constant>(TypeTags::NumberInt64, bitcastFrom<int64_t>(valueInt64));
+}
+
+ABT Constant::fromDouble(double value) {
+ return make<Constant>(TypeTags::NumberDouble, bitcastFrom<double>(value));
+}
+
+ABT Constant::emptyObject() {
+ auto [tag, val] = makeNewObject();
+ return make<Constant>(tag, val);
+}
+
+ABT Constant::emptyArray() {
+ auto [tag, val] = makeNewArray();
+ return make<Constant>(tag, val);
+}
+
+ABT Constant::nothing() {
+ return make<Constant>(TypeTags::Nothing, 0);
+}
+
+ABT Constant::null() {
+ return make<Constant>(TypeTags::Null, 0);
+}
+
+ABT Constant::boolean(bool b) {
+ return make<Constant>(TypeTags::Boolean, bitcastFrom<bool>(b));
+}
+
+ABT Constant::minKey() {
+ return make<Constant>(TypeTags::MinKey, 0);
+}
+
+ABT Constant::maxKey() {
+ return make<Constant>(TypeTags::MaxKey, 0);
+}
+
+Constant::~Constant() {
+ releaseValue(_tag, _val);
+}
+
+bool Constant::operator==(const Constant& other) const {
+ const auto [compareTag, compareVal] = compareValue(_tag, _val, other._tag, other._val);
+ // compareValue() produces a NumberInt32 result only when the two values are comparable;
+ // any other result tag (e.g. Nothing) means the constants cannot compare equal.
+ return compareTag == TypeTags::NumberInt32 && bitcastTo<int32_t>(compareVal) == 0;
+}
+
+bool Constant::isValueInt64() const {
+ return _tag == TypeTags::NumberInt64;
+}
+
+int64_t Constant::getValueInt64() const {
+ uassert(6624057, "Constant value type is not int64_t", isValueInt64());
+ return bitcastTo<int64_t>(_val);
+}
+
+bool Constant::isValueInt32() const {
+ return _tag == TypeTags::NumberInt32;
+}
+
+int32_t Constant::getValueInt32() const {
+ uassert(6624354, "Constant value type is not int32_t", isValueInt32());
+ return bitcastTo<int32_t>(_val);
+}
+
+} // namespace mongo::optimizer
diff --git a/src/mongo/db/query/optimizer/syntax/expr.h b/src/mongo/db/query/optimizer/syntax/expr.h
new file mode 100644
index 00000000000..a795625d59c
--- /dev/null
+++ b/src/mongo/db/query/optimizer/syntax/expr.h
@@ -0,0 +1,350 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include <ostream>
+
+#include "mongo/db/exec/sbe/values/value.h"
+#include "mongo/db/query/optimizer/syntax/syntax.h"
+
+namespace mongo::optimizer {
+
+class Constant final : public Operator<Constant, 0>, public ExpressionSyntaxSort {
+public:
+ Constant(sbe::value::TypeTags tag, sbe::value::Value val);
+
+ static ABT str(std::string str);
+
+ static ABT int32(int32_t valueInt32);
+ static ABT int64(int64_t valueInt64);
+
+ static ABT fromDouble(double value);
+
+ static ABT emptyObject();
+ static ABT emptyArray();
+
+ static ABT nothing();
+ static ABT null();
+
+ static ABT boolean(bool b);
+
+ static ABT minKey();
+ static ABT maxKey();
+
+ ~Constant();
+
+ Constant(const Constant& other);
+ Constant(Constant&& other) noexcept;
+
+ bool operator==(const Constant& other) const;
+
+ auto get() const {
+ return std::pair{_tag, _val};
+ }
+
+ bool isValueInt64() const;
+ int64_t getValueInt64() const;
+
+ bool isValueInt32() const;
+ int32_t getValueInt32() const;
+
+ bool isNumber() const {
+ return sbe::value::isNumber(_tag);
+ }
+
+ bool isNothing() const {
+ return _tag == sbe::value::TypeTags::Nothing;
+ }
+
+private:
+ sbe::value::TypeTags _tag;
+ sbe::value::Value _val;
+};
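+
+// Illustrative usage (a sketch, not part of the API contract): Constant values are
+// normally created via the static factories above rather than the raw (tag, value)
+// constructor, since the factories take care of value ownership, e.g.:
+// ABT c1 = Constant::int64(42);
+// ABT c2 = Constant::str("abc"); // makeNewString copies the characters.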
+
+class Variable final : public Operator<Variable, 0>, public ExpressionSyntaxSort {
+ std::string _name;
+
+public:
+ Variable(std::string inName) : _name(std::move(inName)) {}
+
+ bool operator==(const Variable& other) const {
+ return _name == other._name;
+ }
+
+ auto& name() const {
+ return _name;
+ }
+};
+
+class UnaryOp final : public Operator<UnaryOp, 1>, public ExpressionSyntaxSort {
+ using Base = Operator<UnaryOp, 1>;
+ Operations _op;
+
+public:
+ UnaryOp(Operations inOp, ABT inExpr) : Base(std::move(inExpr)), _op(inOp) {
+ assertExprSort(getChild());
+ }
+
+ bool operator==(const UnaryOp& other) const {
+ return _op == other._op && getChild() == other.getChild();
+ }
+
+ auto op() const {
+ return _op;
+ }
+
+ const ABT& getChild() const {
+ return get<0>();
+ }
+};
+
+class BinaryOp final : public Operator<BinaryOp, 2>, public ExpressionSyntaxSort {
+ using Base = Operator<BinaryOp, 2>;
+ Operations _op;
+
+public:
+ BinaryOp(Operations inOp, ABT inLhs, ABT inRhs)
+ : Base(std::move(inLhs), std::move(inRhs)), _op(inOp) {
+ assertExprSort(getLeftChild());
+ assertExprSort(getRightChild());
+ }
+
+ bool operator==(const BinaryOp& other) const {
+ return _op == other._op && getLeftChild() == other.getLeftChild() &&
+ getRightChild() == other.getRightChild();
+ }
+
+ auto op() const {
+ return _op;
+ }
+
+ const ABT& getLeftChild() const {
+ return get<0>();
+ }
+
+ const ABT& getRightChild() const {
+ return get<1>();
+ }
+};
+
+class If final : public Operator<If, 3>, public ExpressionSyntaxSort {
+ using Base = Operator<If, 3>;
+
+public:
+ If(ABT inCond, ABT inThen, ABT inElse)
+ : Base(std::move(inCond), std::move(inThen), std::move(inElse)) {
+ assertExprSort(getCondChild());
+ assertExprSort(getThenChild());
+ assertExprSort(getElseChild());
+ }
+
+ bool operator==(const If& other) const {
+ return getCondChild() == other.getCondChild() && getThenChild() == other.getThenChild() &&
+ getElseChild() == other.getElseChild();
+ }
+
+ const ABT& getCondChild() const {
+ return get<0>();
+ }
+
+ const ABT& getThenChild() const {
+ return get<1>();
+ }
+
+ const ABT& getElseChild() const {
+ return get<2>();
+ }
+};
+
+class Let final : public Operator<Let, 2>, public ExpressionSyntaxSort {
+ using Base = Operator<Let, 2>;
+
+ std::string _varName;
+
+public:
+ Let(std::string var, ABT inBind, ABT inExpr)
+ : Base(std::move(inBind), std::move(inExpr)), _varName(std::move(var)) {
+ assertExprSort(bind());
+ assertExprSort(in());
+ }
+
+ bool operator==(const Let& other) const {
+ return _varName == other._varName && bind() == other.bind() && in() == other.in();
+ }
+
+ auto& varName() const {
+ return _varName;
+ }
+
+ const ABT& bind() const {
+ return get<0>();
+ }
+
+ const ABT& in() const {
+ return get<1>();
+ }
+};
+
+class LambdaAbstraction final : public Operator<LambdaAbstraction, 1>, public ExpressionSyntaxSort {
+ using Base = Operator<LambdaAbstraction, 1>;
+
+ std::string _varName;
+
+public:
+ LambdaAbstraction(std::string var, ABT inBody)
+ : Base(std::move(inBody)), _varName(std::move(var)) {
+ assertExprSort(getBody());
+ }
+
+ bool operator==(const LambdaAbstraction& other) const {
+ return _varName == other._varName && getBody() == other.getBody();
+ }
+
+ auto& varName() const {
+ return _varName;
+ }
+
+ const ABT& getBody() const {
+ return get<0>();
+ }
+
+ ABT& getBody() {
+ return get<0>();
+ }
+};
+
+class LambdaApplication final : public Operator<LambdaApplication, 2>, public ExpressionSyntaxSort {
+ using Base = Operator<LambdaApplication, 2>;
+
+public:
+ LambdaApplication(ABT inLambda, ABT inArgument)
+ : Base(std::move(inLambda), std::move(inArgument)) {
+ assertExprSort(getLambda());
+ assertExprSort(getArgument());
+ }
+
+ bool operator==(const LambdaApplication& other) const {
+ return getLambda() == other.getLambda() && getArgument() == other.getArgument();
+ }
+
+ const ABT& getLambda() const {
+ return get<0>();
+ }
+
+ const ABT& getArgument() const {
+ return get<1>();
+ }
+};
+
+class FunctionCall final : public OperatorDynamicHomogenous<FunctionCall>,
+ public ExpressionSyntaxSort {
+ using Base = OperatorDynamicHomogenous<FunctionCall>;
+ std::string _name;
+
+public:
+ FunctionCall(std::string inName, ABTVector inArgs)
+ : Base(std::move(inArgs)), _name(std::move(inName)) {
+ for (auto& a : nodes()) {
+ assertExprSort(a);
+ }
+ }
+
+ bool operator==(const FunctionCall& other) const {
+ return _name == other._name && nodes() == other.nodes();
+ }
+
+ auto& name() const {
+ return _name;
+ }
+};
+
+class EvalPath final : public Operator<EvalPath, 2>, public ExpressionSyntaxSort {
+ using Base = Operator<EvalPath, 2>;
+
+public:
+ EvalPath(ABT inPath, ABT inInput) : Base(std::move(inPath), std::move(inInput)) {
+ assertPathSort(getPath());
+ assertExprSort(getInput());
+ }
+
+ bool operator==(const EvalPath& other) const {
+ return getPath() == other.getPath() && getInput() == other.getInput();
+ }
+
+ const ABT& getPath() const {
+ return get<0>();
+ }
+
+ ABT& getPath() {
+ return get<0>();
+ }
+
+ const ABT& getInput() const {
+ return get<1>();
+ }
+};
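+
+// Illustrative sketch: EvalPath applies a path to the value produced by its input
+// expression. For example, reading field "a" from the "root" projection (cf. the
+// Path lowering tests) can be written as:
+// make<EvalPath>(make<PathGet>("a", make<PathIdentity>()), make<Variable>("root"))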
+
+class EvalFilter final : public Operator<EvalFilter, 2>, public ExpressionSyntaxSort {
+ using Base = Operator<EvalFilter, 2>;
+
+public:
+ EvalFilter(ABT inPath, ABT inInput) : Base(std::move(inPath), std::move(inInput)) {
+ assertPathSort(getPath());
+ assertExprSort(getInput());
+ }
+
+ bool operator==(const EvalFilter& other) const {
+ return getPath() == other.getPath() && getInput() == other.getInput();
+ }
+
+ const ABT& getPath() const {
+ return get<0>();
+ }
+
+ ABT& getPath() {
+ return get<0>();
+ }
+
+ const ABT& getInput() const {
+ return get<1>();
+ }
+};
+
+/**
+ * This class represents a source of values originating from relational nodes.
+ */
+class Source final : public Operator<Source, 0>, public ExpressionSyntaxSort {
+public:
+ bool operator==(const Source& other) const {
+ return true;
+ }
+};
+
+} // namespace mongo::optimizer
diff --git a/src/mongo/db/query/optimizer/syntax/path.h b/src/mongo/db/query/optimizer/syntax/path.h
new file mode 100644
index 00000000000..4f6ec8775b4
--- /dev/null
+++ b/src/mongo/db/query/optimizer/syntax/path.h
@@ -0,0 +1,349 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include <ostream>
+#include <unordered_set>
+
+#include "mongo/db/query/optimizer/syntax/syntax.h"
+
+
+namespace mongo::optimizer {
+
+/**
+ * A constant path element - any input value is disregarded and replaced by the result of the
+ * (constant) expression.
+ *
+ * It could also be expressed as a lambda that ignores its input: \ _ . c
+ */
+class PathConstant final : public Operator<PathConstant, 1>, public PathSyntaxSort {
+ using Base = Operator<PathConstant, 1>;
+
+public:
+ PathConstant(ABT inConstant) : Base(std::move(inConstant)) {
+ assertExprSort(getConstant());
+ }
+
+ bool operator==(const PathConstant& other) const {
+ return getConstant() == other.getConstant();
+ }
+
+ const ABT& getConstant() const {
+ return get<0>();
+ }
+
+ ABT& getConstant() {
+ return get<0>();
+ }
+};
+
+/**
+ * A lambda path element - the expression must be a single-argument lambda, which is applied
+ * to the input value.
+ */
+class PathLambda final : public Operator<PathLambda, 1>, public PathSyntaxSort {
+ using Base = Operator<PathLambda, 1>;
+
+public:
+ PathLambda(ABT inLambda) : Base(std::move(inLambda)) {
+ assertExprSort(getLambda());
+ }
+
+ bool operator==(const PathLambda& other) const {
+ return getLambda() == other.getLambda();
+ }
+
+ const ABT& getLambda() const {
+ return get<0>();
+ }
+};
+
+/**
+ * An identity path element - the input is not disturbed at all.
+ *
+ * It could also be expressed as a lambda: \ x . x
+ */
+class PathIdentity final : public Operator<PathIdentity, 0>, public PathSyntaxSort {
+public:
+ bool operator==(const PathIdentity& other) const {
+ return true;
+ }
+};
+
+/**
+ * A default path element - if the input is Nothing then return the result of the expression
+ * (assumed to return non-Nothing); otherwise return the input undisturbed.
+ */
+class PathDefault final : public Operator<PathDefault, 1>, public PathSyntaxSort {
+ using Base = Operator<PathDefault, 1>;
+
+public:
+ PathDefault(ABT inDefault) : Base(std::move(inDefault)) {
+ assertExprSort(getDefault());
+ }
+
+ bool operator==(const PathDefault& other) const {
+ return getDefault() == other.getDefault();
+ }
+
+ const ABT& getDefault() const {
+ return get<0>();
+ }
+};
+
+/**
+ * A comparison path element - the input value is compared to the result of the (constant)
+ * expression. The actual semantics (return value) depend on which component evaluates the
+ * path (i.e. filter or project).
+ */
+class PathCompare : public Operator<PathCompare, 1>, public PathSyntaxSort {
+ using Base = Operator<PathCompare, 1>;
+
+ Operations _cmp;
+
+public:
+ PathCompare(Operations inCmp, ABT inVal) : Base(std::move(inVal)), _cmp(inCmp) {
+ assertExprSort(getVal());
+ }
+
+ bool operator==(const PathCompare& other) const {
+ return _cmp == other._cmp && getVal() == other.getVal();
+ }
+
+ auto op() const {
+ return _cmp;
+ }
+
+ const ABT& getVal() const {
+ return get<0>();
+ }
+
+ ABT& getVal() {
+ return get<0>();
+ }
+};
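+
+// Illustrative sketch: a predicate such as "field a equals 5" can be written as
+// make<PathGet>("a", make<PathCompare>(Operations::Eq, Constant::int64(5)))
+// and evaluated under EvalFilter semantics.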
+
+/**
+ * A drop path element - drops the named fields from the input if it is an object; otherwise
+ * returns the input undisturbed.
+ */
+class PathDrop final : public Operator<PathDrop, 0>, public PathSyntaxSort {
+public:
+ using NameSet = opt::unordered_set<std::string>;
+
+ PathDrop(NameSet inNames) : _names(std::move(inNames)) {}
+
+ bool operator==(const PathDrop& other) const {
+ return _names == other._names;
+ }
+
+ const NameSet& getNames() const {
+ return _names;
+ }
+
+private:
+ const NameSet _names;
+};
+
+/**
+ * A keep path element - keeps only the named fields of the input if it is an object;
+ * otherwise returns the input undisturbed.
+ */
+class PathKeep final : public Operator<PathKeep, 0>, public PathSyntaxSort {
+public:
+ using NameSet = opt::unordered_set<std::string>;
+
+ PathKeep(NameSet inNames) : _names(std::move(inNames)) {}
+
+ bool operator==(const PathKeep& other) const {
+ return _names == other._names;
+ }
+
+ const NameSet& getNames() const {
+ return _names;
+ }
+
+private:
+ const NameSet _names;
+};
+
+/**
+ * Returns the input undisturbed if it is an object; otherwise returns Nothing.
+ */
+class PathObj final : public Operator<PathObj, 0>, public PathSyntaxSort {
+public:
+ bool operator==(const PathObj& other) const {
+ return true;
+ }
+};
+
+/**
+ * Returns the input undisturbed if it is an array; otherwise returns Nothing.
+ */
+class PathArr final : public Operator<PathArr, 0>, public PathSyntaxSort {
+public:
+ bool operator==(const PathArr& other) const {
+ return true;
+ }
+};
+
+/**
+ * A traverse path element - if the input is an array, applies the inner path to each of its
+ * elements; otherwise applies the inner path to the input itself.
+ */
+class PathTraverse final : public Operator<PathTraverse, 1>, public PathSyntaxSort {
+ using Base = Operator<PathTraverse, 1>;
+
+public:
+ PathTraverse(ABT inPath) : Base(std::move(inPath)) {
+ assertPathSort(getPath());
+ }
+
+ bool operator==(const PathTraverse& other) const {
+ return getPath() == other.getPath();
+ }
+
+ const ABT& getPath() const {
+ return get<0>();
+ }
+
+ ABT& getPath() {
+ return get<0>();
+ }
+};
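+
+// Illustrative sketch (cf. test Path.Lower7): in
+// make<PathGet>("fieldA", make<PathTraverse>(make<PathGet>("fieldB", ...)))
+// the Traverse applies the inner Get to each element when "fieldA" holds an array.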
+
+/**
+ * A field path element - applies the inner path to the value of the named object field.
+ */
+class PathField final : public Operator<PathField, 1>, public PathSyntaxSort {
+ using Base = Operator<PathField, 1>;
+ std::string _name;
+
+public:
+ PathField(std::string inName, ABT inPath) : Base(std::move(inPath)), _name(std::move(inName)) {
+ assertPathSort(getPath());
+ }
+
+ bool operator==(const PathField& other) const {
+ return _name == other._name && getPath() == other.getPath();
+ }
+
+ auto& name() const {
+ return _name;
+ }
+
+ const ABT& getPath() const {
+ return get<0>();
+ }
+
+ ABT& getPath() {
+ return get<0>();
+ }
+};
+
+/**
+ * A get path element - retrieves the value of the named field from the input and applies the
+ * inner path to it (the read counterpart of the field path element).
+ */
+class PathGet final : public Operator<PathGet, 1>, public PathSyntaxSort {
+ using Base = Operator<PathGet, 1>;
+ std::string _name;
+
+public:
+ PathGet(std::string inName, ABT inPath) : Base(std::move(inPath)), _name(std::move(inName)) {
+ assertPathSort(getPath());
+ }
+
+ bool operator==(const PathGet& other) const {
+ return _name == other._name && getPath() == other.getPath();
+ }
+
+ auto& name() const {
+ return _name;
+ }
+
+ const ABT& getPath() const {
+ return get<0>();
+ }
+
+ ABT& getPath() {
+ return get<0>();
+ }
+};
+
+/**
+ * A multiplicative composition path element - applies the second path to the result of the
+ * first; under filter semantics it behaves like a conjunction.
+ */
+class PathComposeM final : public Operator<PathComposeM, 2>, public PathSyntaxSort {
+ using Base = Operator<PathComposeM, 2>;
+
+public:
+ PathComposeM(ABT inPath1, ABT inPath2) : Base(std::move(inPath1), std::move(inPath2)) {
+ assertPathSort(getPath1());
+ assertPathSort(getPath2());
+ }
+
+ bool operator==(const PathComposeM& other) const {
+ return getPath1() == other.getPath1() && getPath2() == other.getPath2();
+ }
+
+ const ABT& getPath1() const {
+ return get<0>();
+ }
+
+ const ABT& getPath2() const {
+ return get<1>();
+ }
+};
+
+/**
+ * An additive composition path element - under filter semantics it behaves like a
+ * disjunction.
+ */
+class PathComposeA final : public Operator<PathComposeA, 2>, public PathSyntaxSort {
+ using Base = Operator<PathComposeA, 2>;
+
+public:
+ PathComposeA(ABT inPath1, ABT inPath2) : Base(std::move(inPath1), std::move(inPath2)) {
+ assertPathSort(getPath1());
+ assertPathSort(getPath2());
+ }
+
+ bool operator==(const PathComposeA& other) const {
+ return getPath1() == other.getPath1() && getPath2() == other.getPath2();
+ }
+
+ const ABT& getPath1() const {
+ return get<0>();
+ }
+
+ const ABT& getPath2() const {
+ return get<1>();
+ }
+};
+
+} // namespace mongo::optimizer
diff --git a/src/mongo/db/query/optimizer/syntax/syntax.h b/src/mongo/db/query/optimizer/syntax/syntax.h
new file mode 100644
index 00000000000..cfdd70ce4dc
--- /dev/null
+++ b/src/mongo/db/query/optimizer/syntax/syntax.h
@@ -0,0 +1,263 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include <string>
+
+#include "mongo/db/query/optimizer/algebra/operator.h"
+#include "mongo/db/query/optimizer/algebra/polyvalue.h"
+#include "mongo/db/query/optimizer/syntax/syntax_fwd_declare.h"
+#include "mongo/db/query/optimizer/utils/printable_enum.h"
+#include "mongo/util/assert_util.h"
+
+namespace mongo::optimizer {
+
+using ABT = algebra::PolyValue<Blackhole,
+ Constant, // expressions
+ Variable,
+ UnaryOp,
+ BinaryOp,
+ If,
+ Let,
+ LambdaAbstraction,
+ LambdaApplication,
+ FunctionCall,
+ EvalPath,
+ EvalFilter,
+ Source,
+ PathConstant, // path elements
+ PathLambda,
+ PathIdentity,
+ PathDefault,
+ PathCompare,
+ PathDrop,
+ PathKeep,
+ PathObj,
+ PathArr,
+ PathTraverse,
+ PathField,
+ PathGet,
+ PathComposeM,
+ PathComposeA,
+ ScanNode, // nodes
+ PhysicalScanNode,
+ ValueScanNode,
+ CoScanNode,
+ IndexScanNode,
+ SeekNode,
+ MemoLogicalDelegatorNode,
+ MemoPhysicalDelegatorNode,
+ FilterNode,
+ EvaluationNode,
+ SargableNode,
+ RIDIntersectNode,
+ BinaryJoinNode,
+ HashJoinNode,
+ MergeJoinNode,
+ UnionNode,
+ GroupByNode,
+ UnwindNode,
+ UniqueNode,
+ CollationNode,
+ LimitSkipNode,
+ ExchangeNode,
+ RootNode,
+ References, // utilities
+ ExpressionBinder>;
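+
+// An ABT (abstract binding tree) holds exactly one of the alternatives listed
+// above; instances are created via the make<T>(...) helper below.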
+
+template <typename Derived, size_t Arity>
+using Operator = algebra::OpSpecificArity<ABT, Derived, Arity>;
+
+template <typename Derived, size_t Arity>
+using OperatorDynamic = algebra::OpSpecificDynamicArity<ABT, Derived, Arity>;
+
+template <typename Derived>
+using OperatorDynamicHomogenous = OperatorDynamic<Derived, 0>;
+
+using ABTVector = std::vector<ABT>;
+
+template <typename T, typename... Args>
+inline auto make(Args&&... args) {
+ return ABT::make<T>(std::forward<Args>(args)...);
+}
+
+template <typename... Args>
+inline auto makeSeq(Args&&... args) {
+ ABTVector seq;
+ (seq.emplace_back(std::forward<Args>(args)), ...);
+ return seq;
+}
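+
+// Illustrative sketch: make<T> and makeSeq are the basic ABT building blocks, e.g.:
+// auto expr = make<BinaryOp>(Operations::Add, make<Variable>("x"), Constant::int64(1));
+// auto call = make<FunctionCall>("f", makeSeq(make<Variable>("x"), Constant::int64(2)));
+// where "f" is an arbitrary function name used purely for illustration.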
+
+class ExpressionSyntaxSort {};
+
+class PathSyntaxSort {};
+
+inline void assertExprSort(const ABT& e) {
+ if (!e.is<ExpressionSyntaxSort>()) {
+ uasserted(6624058, "expression syntax sort expected");
+ }
+}
+
+inline void assertPathSort(const ABT& e) {
+ if (!e.is<PathSyntaxSort>()) {
+ uasserted(6624059, "path syntax sort expected");
+ }
+}
+
+inline bool operator!=(const ABT& left, const ABT& right) {
+ return !(left == right);
+}
+
+#define PATHSYNTAX_OPNAMES(F) \
+ /* comparison operations */ \
+ F(Eq) \
+ F(Neq) \
+ F(Gt) \
+ F(Gte) \
+ F(Lt) \
+ F(Lte) \
+ F(Cmp3w) \
+ \
+ /* binary operations */ \
+ F(Add) \
+ F(Sub) \
+ F(Mult) \
+ F(Div) \
+ \
+ /* unary operations */ \
+ F(Neg) \
+ \
+ /* logical operations */ \
+ F(And) \
+ F(Or) \
+ F(Not)
+
+MAKE_PRINTABLE_ENUM(Operations, PATHSYNTAX_OPNAMES);
+MAKE_PRINTABLE_ENUM_STRING_ARRAY(OperationsEnum, Operations, PATHSYNTAX_OPNAMES);
+#undef PATHSYNTAX_OPNAMES
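+
+// Sketch of the expansion (assuming the usual X-macro pattern behind
+// MAKE_PRINTABLE_ENUM): the F(...) list above yields roughly
+// enum class Operations { Eq, Neq, Gt, ... };
+// plus a parallel array of name strings in OperationsEnum used for printing.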
+
+inline constexpr bool isUnaryOp(Operations op) {
+ return op == Operations::Neg || op == Operations::Not;
+}
+
+inline constexpr bool isBinaryOp(Operations op) {
+ return !isUnaryOp(op);
+}
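+
+// Note: every operation that is not unary (comparisons, arithmetic, and the logical
+// And/Or) is classified as binary by isBinaryOp above.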
+
+/**
+ * This is a special inert ABT node. It is used by rewriters to preserve structural properties of
+ * nodes during in-place rewriting.
+ */
+class Blackhole final : public Operator<Blackhole, 0> {
+public:
+ bool operator==(const Blackhole& other) const {
+ return true;
+ }
+};
+
+/**
+ * This is a helper structure that represents Node internal references. Some relational nodes
+ * implicitly reference named projections from their children.
+ *
+ * Canonical examples are: GROUP BY "a", ORDER BY "b", etc.
+ *
+ * We want to capture these references. The rule for ABTs is that the ONLY way to reference a
+ * named entity is through the Variable class. The uniformity of this approach makes life much
+ * easier for optimizer developers. On the other hand, using Variables everywhere makes writing
+ * code more verbose, hence this helper.
+ */
+class References final : public OperatorDynamicHomogenous<References> {
+ using Base = OperatorDynamicHomogenous<References>;
+
+public:
+ /*
+ * Construct Variable objects out of the provided vector of strings.
+ */
+ References(const std::vector<std::string>& names) : Base(ABTVector{}) {
+ // Construct actual Variable objects from names and make them the children of this object.
+ for (const auto& name : names) {
+ nodes().emplace_back(make<Variable>(name));
+ }
+ }
+
+ /*
+ * Alternatively, construct references out of the provided ABTs. This may be useful when the
+ * internal references are more complex than a simple string; consider e.g. GROUP BY (a+b).
+ */
+ References(ABTVector refs) : Base(std::move(refs)) {
+ for (const auto& node : nodes()) {
+ assertExprSort(node);
+ }
+ }
+
+ bool operator==(const References& other) const {
+ return nodes() == other.nodes();
+ }
+};
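+
+// Illustrative sketch: References(std::vector<std::string>{"a", "b"}) holds
+// make<Variable>("a") and make<Variable>("b") as its children, letting e.g. a
+// GroupByNode refer to its grouping projections through ordinary Variables.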
+
+/**
+ * This class represents a unified way of binding identifiers to expressions. Every ABT node that
+ * introduces a new identifier must use this binder (i.e. all relational nodes adding new
+ * projections and expression nodes adding new local variables).
+ */
+class ExpressionBinder : public OperatorDynamicHomogenous<ExpressionBinder> {
+ using Base = OperatorDynamicHomogenous<ExpressionBinder>;
+ std::vector<std::string> _names;
+
+public:
+ ExpressionBinder(std::string name, ABT expr) : Base(makeSeq(std::move(expr))) {
+ _names.emplace_back(std::move(name));
+ for (const auto& node : nodes()) {
+ assertExprSort(node);
+ }
+ }
+
+ ExpressionBinder(std::vector<std::string> names, ABTVector exprs)
+ : Base(std::move(exprs)), _names(std::move(names)) {
+ for (const auto& node : nodes()) {
+ assertExprSort(node);
+ }
+ }
+
+ bool operator==(const ExpressionBinder& other) const {
+ return _names == other._names && exprs() == other.exprs();
+ }
+
+ const std::vector<std::string>& names() const {
+ return _names;
+ }
+
+ const ABTVector& exprs() const {
+ return nodes();
+ }
+};
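+
+// Illustrative sketch: a node projecting x = <expr> wraps the expression as
+// ExpressionBinder("x", std::move(expr)), so that downstream Variables named "x"
+// resolve to this binder.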
+
+} // namespace mongo::optimizer
diff --git a/src/mongo/db/query/optimizer/syntax/syntax_fwd_declare.h b/src/mongo/db/query/optimizer/syntax/syntax_fwd_declare.h
new file mode 100644
index 00000000000..b21781959e4
--- /dev/null
+++ b/src/mongo/db/query/optimizer/syntax/syntax_fwd_declare.h
@@ -0,0 +1,101 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+namespace mongo::optimizer {
+
+/**
+ * Expressions
+ */
+class Blackhole;
+class Constant;
+class Variable;
+class UnaryOp;
+class BinaryOp;
+class If;
+class Let;
+class LambdaAbstraction;
+class LambdaApplication;
+class FunctionCall;
+class EvalPath;
+class EvalFilter;
+class Source;
+
+/**
+ * Path elements
+ */
+class PathConstant;
+class PathLambda;
+class PathIdentity;
+class PathDefault;
+class PathCompare;
+class PathDrop;
+class PathKeep;
+class PathObj;
+class PathArr;
+class PathTraverse;
+class PathField;
+class PathGet;
+class PathComposeM;
+class PathComposeA;
+
+/**
+ * Nodes
+ */
+class ScanNode;
+class PhysicalScanNode;
+class ValueScanNode;
+class CoScanNode;
+class IndexScanNode;
+class SeekNode;
+class MemoLogicalDelegatorNode;
+class MemoPhysicalDelegatorNode;
+class FilterNode;
+class EvaluationNode;
+class SargableNode;
+class RIDIntersectNode;
+class BinaryJoinNode;
+class HashJoinNode;
+class MergeJoinNode;
+class UnionNode;
+class GroupByNode;
+class UnwindNode;
+class UniqueNode;
+class CollationNode;
+class LimitSkipNode;
+class ExchangeNode;
+class RootNode;
+
+/**
+ * Utility
+ */
+class References;
+class ExpressionBinder;
+} // namespace mongo::optimizer
diff --git a/src/mongo/db/query/optimizer/utils/abt_hash.cpp b/src/mongo/db/query/optimizer/utils/abt_hash.cpp
new file mode 100644
index 00000000000..9edb830f893
--- /dev/null
+++ b/src/mongo/db/query/optimizer/utils/abt_hash.cpp
@@ -0,0 +1,473 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/db/query/optimizer/utils/abt_hash.h"
+
+#include "mongo/db/query/optimizer/node.h"
+#include "mongo/db/query/optimizer/utils/utils.h"
+
+namespace mongo::optimizer {
+
+static size_t computeCollationHash(const properties::CollationRequirement& prop) {
+ size_t collationHash = 17;
+ for (const auto& entry : prop.getCollationSpec()) {
+ updateHash(collationHash, std::hash<ProjectionName>()(entry.first));
+ updateHash(collationHash, std::hash<CollationOp>()(entry.second));
+ }
+ return collationHash;
+}
+
+static size_t computeLimitSkipHash(const properties::LimitSkipRequirement& prop) {
+ size_t limitSkipHash = 17;
+ updateHash(limitSkipHash, std::hash<int64_t>()(prop.getLimit()));
+ updateHash(limitSkipHash, std::hash<int64_t>()(prop.getSkip()));
+ return limitSkipHash;
+}
+
+static size_t computePropertyProjectionsHash(const ProjectionNameVector& projections) {
+ size_t resultHash = 17;
+ for (const ProjectionName& projection : projections) {
+ updateHashUnordered(resultHash, std::hash<ProjectionName>()(projection));
+ }
+ return resultHash;
+}
+
+static size_t computeProjectionRequirementHash(const properties::ProjectionRequirement& prop) {
+ return computePropertyProjectionsHash(prop.getProjections().getVector());
+}
+
+static size_t computeDistributionHash(const properties::DistributionRequirement& prop) {
+ size_t resultHash = 17;
+ const auto& distribAndProjections = prop.getDistributionAndProjections();
+ updateHash(resultHash, std::hash<DistributionType>()(distribAndProjections._type));
+ updateHash(resultHash, computePropertyProjectionsHash(distribAndProjections._projectionNames));
+ return resultHash;
+}
+
+static void updateBoundHash(size_t& result, const BoundRequirement& bound) {
+ updateHash(result, std::hash<bool>()(bound.isInclusive()));
+ if (!bound.isInfinite()) {
+ updateHash(result, ABTHashGenerator::generate(bound.getBound()));
+ }
+}
+
+template <class T>
+class IntervalHasher {
+public:
+ size_t computeHash(const IntervalRequirement& req) {
+ size_t result = 17;
+ updateBoundHash(result, req.getLowBound());
+ updateBoundHash(result, req.getHighBound());
+ return result;
+ }
+
+ size_t computeHash(const MultiKeyIntervalRequirement& req) {
+ size_t result = 19;
+ for (const auto& interval : req) {
+ updateHash(result, computeHash(interval));
+ }
+ return result;
+ }
+
+ size_t transport(const typename T::Atom& node) {
+ return computeHash(node.getExpr());
+ }
+
+ size_t transport(const typename T::Conjunction& node, const std::vector<size_t>& childResults) {
+ size_t result = 31;
+ for (const size_t childResult : childResults) {
+ updateHash(result, childResult);
+ }
+ return result;
+ }
+
+ size_t transport(const typename T::Disjunction& node, const std::vector<size_t>& childResults) {
+ size_t result = 29;
+ for (const size_t childResult : childResults) {
+ updateHash(result, childResult);
+ }
+ return result;
+ }
+
+ size_t compute(const typename T::Node& intervals) {
+ return algebra::transport<false>(intervals, *this);
+ }
+};
+
+static size_t computePartialSchemaReqHash(const PartialSchemaRequirements& reqMap) {
+ size_t result = 17;
+
+ IntervalHasher<IntervalReqExpr> intervalHasher;
+ for (const auto& [key, req] : reqMap) {
+ updateHash(result, std::hash<ProjectionName>()(key._projectionName));
+ updateHash(result, ABTHashGenerator::generate(key._path));
+ updateHash(result, std::hash<ProjectionName>()(req.getBoundProjectionName()));
+ updateHash(result, intervalHasher.compute(req.getIntervals()));
+ }
+ return result;
+}
+
+static size_t computeCandidateIndexMapHash(const CandidateIndexMap& map) {
+ size_t result = 17;
+
+ IntervalHasher<MultiKeyIntervalReqExpr> intervalHasher;
+ for (const auto& [indexDefName, candidateIndexEntry] : map) {
+ updateHash(result, std::hash<std::string>()(indexDefName));
+
+ {
+ const auto& fieldProjectionMap = candidateIndexEntry._fieldProjectionMap;
+ updateHash(result, std::hash<ProjectionName>()(fieldProjectionMap._ridProjection));
+ updateHash(result, std::hash<ProjectionName>()(fieldProjectionMap._rootProjection));
+ for (const auto& fieldProjectionMapEntry : fieldProjectionMap._fieldProjections) {
+ updateHash(result, std::hash<FieldNameType>()(fieldProjectionMapEntry.first));
+ updateHash(result, std::hash<ProjectionName>()(fieldProjectionMapEntry.second));
+ }
+ }
+ updateHash(result, intervalHasher.compute(candidateIndexEntry._intervals));
+ }
+
+ return result;
+}
+
+/**
+ * Hasher for ABT nodes. Used in conjunction with the memo.
+ */
+class ABTHashTransporter {
+public:
+ /**
+ * Nodes
+ */
+ template <typename T, typename... Ts>
+ size_t transport(const T& /*node*/, Ts&&...) {
+ // Physical nodes do not currently need to implement hash.
+ static_assert(!canBeLogicalNode<T>(), "Logical node must implement its hash.");
+ uasserted(6624142, "must implement custom hash");
+ }
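+
+ // Note: algebra::transport<false> (see generate() below) walks the ABT bottom-up,
+ // so each transport() overload receives the already-computed hashes of its
+ // children and only combines them with node-local state.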
+
+ size_t transport(const References& references, std::vector<size_t> inResults) {
+ return computeHashSeq<1>(computeVectorHash(inResults));
+ }
+
+ size_t transport(const ExpressionBinder& binders, std::vector<size_t> inResults) {
+ return computeHashSeq<2>(computeVectorHash(binders.names()), computeVectorHash(inResults));
+ }
+
+ size_t transport(const ScanNode& node, size_t bindResult) {
+ return computeHashSeq<3>(std::hash<std::string>()(node.getScanDefName()), bindResult);
+ }
+
+ size_t transport(const ValueScanNode& node, size_t bindResult) {
+ return computeHashSeq<46>(std::hash<size_t>()(node.getArraySize()),
+ ABTHashGenerator::generate(node.getValueArray()),
+ bindResult);
+ }
+
+ size_t transport(const MemoLogicalDelegatorNode& node) {
+ return computeHashSeq<4>(std::hash<GroupIdType>()(node.getGroupId()));
+ }
+
+ size_t transport(const FilterNode& node, size_t childResult, size_t filterResult) {
+ return computeHashSeq<5>(filterResult, childResult);
+ }
+
+ size_t transport(const EvaluationNode& node, size_t childResult, size_t projectionResult) {
+ return computeHashSeq<6>(projectionResult, childResult);
+ }
+
+ size_t transport(const SargableNode& node,
+ size_t childResult,
+ size_t /*bindResult*/,
+ size_t /*refResult*/) {
+ return computeHashSeq<44>(computePartialSchemaReqHash(node.getReqMap()),
+ computeCandidateIndexMapHash(node.getCandidateIndexMap()),
+ std::hash<IndexReqTarget>()(node.getTarget()),
+ childResult);
+ }
+
+ size_t transport(const RIDIntersectNode& node,
+ size_t leftChildResult,
+ size_t rightChildResult) {
+ // Note: we deliberately include the children in the hash.
+ return computeHashSeq<45>(std::hash<ProjectionName>()(node.getScanProjectionName()),
+ std::hash<bool>()(node.hasLeftIntervals()),
+ std::hash<bool>()(node.hasRightIntervals()),
+ leftChildResult,
+ rightChildResult);
+ }
+
+ size_t transport(const BinaryJoinNode& node,
+ size_t leftChildResult,
+ size_t rightChildResult,
+ size_t filterResult) {
+ // Note: we deliberately include the children in the hash.
+ return computeHashSeq<7>(filterResult, leftChildResult, rightChildResult);
+ }
+
+ size_t transport(const UnionNode& node,
+ std::vector<size_t> childResults,
+ size_t bindResult,
+ size_t refsResult) {
+ // Note: we deliberately include the children in the hash.
+ return computeHashSeq<9>(bindResult, refsResult, computeVectorHash(childResults));
+ }
+
+ size_t transport(const GroupByNode& node,
+ size_t childResult,
+ size_t bindAggResult,
+ size_t refsAggResult,
+ size_t bindGbResult,
+ size_t refsGbResult) {
+ return computeHashSeq<10>(bindAggResult,
+ refsAggResult,
+ bindGbResult,
+ refsGbResult,
+ std::hash<GroupNodeType>()(node.getType()),
+ childResult);
+ }
+
+ size_t transport(const UnwindNode& node,
+ size_t childResult,
+ size_t bindResult,
+ size_t refsResult) {
+ return computeHashSeq<11>(
+ std::hash<bool>()(node.getRetainNonArrays()), bindResult, refsResult, childResult);
+ }
+
+ size_t transport(const CollationNode& node, size_t childResult, size_t /*refsResult*/) {
+ return computeHashSeq<13>(computeCollationHash(node.getProperty()), childResult);
+ }
+
+ size_t transport(const LimitSkipNode& node, size_t childResult) {
+ return computeHashSeq<14>(computeLimitSkipHash(node.getProperty()), childResult);
+ }
+
+ size_t transport(const ExchangeNode& node, size_t childResult, size_t /*refsResult*/) {
+ return computeHashSeq<43>(computeDistributionHash(node.getProperty()), childResult);
+ }
+
+ size_t transport(const RootNode& node, size_t childResult, size_t /*refsResult*/) {
+ return computeHashSeq<15>(computeProjectionRequirementHash(node.getProperty()),
+ childResult);
+ }
+
+ /**
+ * Expressions
+ */
+ size_t transport(const Blackhole& expr) {
+ return computeHashSeq<16>();
+ }
+
+ size_t transport(const Constant& expr) {
+ auto [tag, val] = expr.get();
+ return computeHashSeq<17>(sbe::value::hashValue(tag, val));
+ }
+
+ size_t transport(const Variable& expr) {
+ return computeHashSeq<18>(std::hash<std::string>()(expr.name()));
+ }
+
+ size_t transport(const UnaryOp& expr, size_t inResult) {
+ return computeHashSeq<19>(std::hash<Operations>()(expr.op()), inResult);
+ }
+
+ size_t transport(const BinaryOp& expr, size_t leftResult, size_t rightResult) {
+ return computeHashSeq<20>(std::hash<Operations>()(expr.op()), leftResult, rightResult);
+ }
+
+ size_t transport(const If& expr, size_t condResult, size_t thenResult, size_t elseResult) {
+ return computeHashSeq<21>(condResult, thenResult, elseResult);
+ }
+
+ size_t transport(const Let& expr, size_t bindResult, size_t exprResult) {
+ return computeHashSeq<22>(std::hash<std::string>()(expr.varName()), bindResult, exprResult);
+ }
+
+ size_t transport(const LambdaAbstraction& expr, size_t inResult) {
+ return computeHashSeq<23>(std::hash<std::string>()(expr.varName()), inResult);
+ }
+
+ size_t transport(const LambdaApplication& expr, size_t lambdaResult, size_t argumentResult) {
+ return computeHashSeq<24>(lambdaResult, argumentResult);
+ }
+
+ size_t transport(const FunctionCall& expr, std::vector<size_t> argResults) {
+ return computeHashSeq<25>(std::hash<std::string>()(expr.name()),
+ computeVectorHash(argResults));
+ }
+
+ size_t transport(const EvalPath& expr, size_t pathResult, size_t inputResult) {
+ return computeHashSeq<26>(pathResult, inputResult);
+ }
+
+ size_t transport(const EvalFilter& expr, size_t pathResult, size_t inputResult) {
+ return computeHashSeq<27>(pathResult, inputResult);
+ }
+
+ size_t transport(const Source& expr) {
+ return computeHashSeq<28>();
+ }
+
+ /**
+ * Paths
+ */
+ size_t transport(const PathConstant& path, size_t inResult) {
+ return computeHashSeq<29>(inResult);
+ }
+
+ size_t transport(const PathLambda& path, size_t inResult) {
+ return computeHashSeq<30>(inResult);
+ }
+
+ size_t transport(const PathIdentity& path) {
+ return computeHashSeq<31>();
+ }
+
+ size_t transport(const PathDefault& path, size_t inResult) {
+ return computeHashSeq<32>(inResult);
+ }
+
+ size_t transport(const PathCompare& path, size_t valueResult) {
+ return computeHashSeq<33>(std::hash<Operations>()(path.op()), valueResult);
+ }
+
+ size_t transport(const PathDrop& path) {
+ size_t namesHash = 17;
+ for (const std::string& name : path.getNames()) {
+ updateHash(namesHash, std::hash<std::string>()(name));
+ }
+ return computeHashSeq<34>(namesHash);
+ }
+
+ size_t transport(const PathKeep& path) {
+ size_t namesHash = 17;
+ for (const std::string& name : path.getNames()) {
+ updateHash(namesHash, std::hash<std::string>()(name));
+ }
+ return computeHashSeq<35>(namesHash);
+ }
+
+ size_t transport(const PathObj& path) {
+ return computeHashSeq<36>();
+ }
+
+ size_t transport(const PathArr& path) {
+ return computeHashSeq<37>();
+ }
+
+ size_t transport(const PathTraverse& path, size_t inResult) {
+ return computeHashSeq<38>(inResult);
+ }
+
+ size_t transport(const PathField& path, size_t inResult) {
+ return computeHashSeq<39>(std::hash<std::string>()(path.name()), inResult);
+ }
+
+ size_t transport(const PathGet& path, size_t inResult) {
+ return computeHashSeq<40>(std::hash<std::string>()(path.name()), inResult);
+ }
+
+ size_t transport(const PathComposeM& path, size_t leftResult, size_t rightResult) {
+ return computeHashSeq<41>(leftResult, rightResult);
+ }
+
+ size_t transport(const PathComposeA& path, size_t leftResult, size_t rightResult) {
+ return computeHashSeq<42>(leftResult, rightResult);
+ }
+
+ size_t generate(const ABT& node) {
+ return algebra::transport<false>(node, *this);
+ }
+
+ size_t generate(const ABT::reference_type& nodeRef) {
+ return algebra::transport<false>(nodeRef, *this);
+ }
+};
+
+size_t ABTHashGenerator::generate(const ABT& node) {
+ ABTHashTransporter gen;
+ return gen.generate(node);
+}
+
+size_t ABTHashGenerator::generate(const ABT::reference_type& nodeRef) {
+ ABTHashTransporter gen;
+ return gen.generate(nodeRef);
+}
+
+class PhysPropsHasher {
+public:
+ size_t operator()(const properties::PhysProperty&,
+ const properties::CollationRequirement& prop) {
+ return computeHashSeq<1>(computeCollationHash(prop));
+ }
+
+ size_t operator()(const properties::PhysProperty&,
+ const properties::LimitSkipRequirement& prop) {
+ return computeHashSeq<2>(computeLimitSkipHash(prop));
+ }
+
+ size_t operator()(const properties::PhysProperty&,
+ const properties::ProjectionRequirement& prop) {
+ return computeHashSeq<3>(computeProjectionRequirementHash(prop));
+ }
+
+ size_t operator()(const properties::PhysProperty&,
+ const properties::DistributionRequirement& prop) {
+ return computeHashSeq<4>(computeDistributionHash(prop));
+ }
+
+ size_t operator()(const properties::PhysProperty&,
+ const properties::IndexingRequirement& prop) {
+ return computeHashSeq<5>(std::hash<IndexReqTarget>()(prop.getIndexReqTarget()),
+ std::hash<bool>()(prop.getNeedsRID()),
+ std::hash<bool>()(prop.getDedupRID()));
+ }
+
+ size_t operator()(const properties::PhysProperty&, const properties::RepetitionEstimate& prop) {
+ return computeHashSeq<6>(std::hash<CEType>()(prop.getEstimate()));
+ }
+
+ size_t operator()(const properties::PhysProperty&, const properties::LimitEstimate& prop) {
+ return computeHashSeq<7>(std::hash<CEType>()(prop.getEstimate()));
+ }
+
+ static size_t computeHash(const properties::PhysProps& props) {
+ PhysPropsHasher visitor;
+ size_t result = 17;
+ for (const auto& prop : props) {
+ updateHashUnordered(result, prop.second.visit(visitor));
+ }
+ return result;
+ }
+};
+
+size_t ABTHashGenerator::generateForPhysProps(const properties::PhysProps& props) {
+ return PhysPropsHasher::computeHash(props);
+}
+
+} // namespace mongo::optimizer
diff --git a/src/mongo/db/query/optimizer/utils/abt_hash.h b/src/mongo/db/query/optimizer/utils/abt_hash.h
new file mode 100644
index 00000000000..1f097bed3a1
--- /dev/null
+++ b/src/mongo/db/query/optimizer/utils/abt_hash.h
@@ -0,0 +1,45 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include "mongo/db/query/optimizer/props.h"
+#include "mongo/db/query/optimizer/syntax/syntax.h"
+
+namespace mongo::optimizer {
+
+class ABTHashGenerator {
+public:
+ static size_t generate(const ABT& node);
+ static size_t generate(const ABT::reference_type& nodeRef);
+
+ static size_t generateForPhysProps(const properties::PhysProps& props);
+};
+
+} // namespace mongo::optimizer
diff --git a/src/mongo/db/query/optimizer/utils/interval_utils.cpp b/src/mongo/db/query/optimizer/utils/interval_utils.cpp
new file mode 100644
index 00000000000..5a0f11e3299
--- /dev/null
+++ b/src/mongo/db/query/optimizer/utils/interval_utils.cpp
@@ -0,0 +1,330 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/db/query/optimizer/utils/interval_utils.h"
+
+#include "mongo/db/exec/sbe/values/value.h"
+#include "mongo/db/query/optimizer/node.h"
+#include "mongo/db/query/optimizer/rewrites/const_eval.h"
+
+
+namespace mongo::optimizer {
+
+void combineIntervalsDNF(const bool intersect,
+ IntervalReqExpr::Node& target,
+ const IntervalReqExpr::Node& source) {
+ if (target == source) {
+ // Intervals are the same. Leave target unchanged.
+ return;
+ }
+
+ if (isIntervalReqFullyOpenDNF(target)) {
+ // Intersecting with fully open interval is redundant.
+ // Unioning with fully open interval results in a fully-open interval.
+ if (intersect) {
+ target = source;
+ }
+ return;
+ }
+
+ if (isIntervalReqFullyOpenDNF(source)) {
+ // Intersecting with fully open interval is redundant.
+ // Unioning with fully open interval results in a fully-open interval.
+ if (!intersect) {
+ target = source;
+ }
+ return;
+ }
+
+ IntervalReqExpr::NodeVector newDisjunction;
+ // Integrate both compound bounds.
+ if (intersect) {
+ // Intersection is analogous to polynomial multiplication. Using '.' to denote intersection
+ // and '+' to denote union. (a.b + c.d) . (e+f) = a.b.e + c.d.e + a.b.f + c.d.f
+ // TODO: in certain cases we can simplify further. For example if we only have scalars, we
+ // can simplify (-inf, 10) ^ (5, +inf) to (5, 10), but this does not work with arrays.
+
+ for (const auto& sourceConjunction : source.cast<IntervalReqExpr::Disjunction>()->nodes()) {
+ const auto& sourceConjunctionIntervals =
+ sourceConjunction.cast<IntervalReqExpr::Conjunction>()->nodes();
+ for (const auto& targetConjunction :
+ target.cast<IntervalReqExpr::Disjunction>()->nodes()) {
+                // TODO: handle the case with targetConjunct fully open.
+                // TODO: handle the case with targetConjunct half-open and sourceConjunct an
+                // equality.
+                // TODO: handle the case with both targetConjunct and sourceConjunct equalities
+                // (with different constants).
+
+ auto newConjunctionIntervals =
+ targetConjunction.cast<IntervalReqExpr::Conjunction>()->nodes();
+ std::copy(sourceConjunctionIntervals.cbegin(),
+ sourceConjunctionIntervals.cend(),
+ std::back_inserter(newConjunctionIntervals));
+ newDisjunction.emplace_back(IntervalReqExpr::make<IntervalReqExpr::Conjunction>(
+ std::move(newConjunctionIntervals)));
+ }
+ }
+ } else {
+ // Unioning is analogous to polynomial addition.
+ // (a.b + c.d) + (e+f) = a.b + c.d + e + f
+ newDisjunction = target.cast<IntervalReqExpr::Disjunction>()->nodes();
+ for (const auto& sourceConjunction : source.cast<IntervalReqExpr::Disjunction>()->nodes()) {
+ newDisjunction.push_back(sourceConjunction);
+ }
+ }
+ target = IntervalReqExpr::make<IntervalReqExpr::Disjunction>(std::move(newDisjunction));
+}
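+
+// For illustration, a worked example of the polynomial analogy above: given the target DNF
+// (a ^ b) v (c ^ d) and the source DNF (e) v (f), intersection distributes the conjunctions
+// pairwise: (a ^ b ^ e) v (c ^ d ^ e) v (a ^ b ^ f) v (c ^ d ^ f), while union simply
+// concatenates the disjuncts: (a ^ b) v (c ^ d) v (e) v (f).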
+
+std::vector<IntervalRequirement> intersectIntervals(const IntervalRequirement& i1,
+ const IntervalRequirement& i2) {
+ // Handle trivial cases of intersection.
+ if (i1.isFullyOpen()) {
+ return {i2};
+ }
+ if (i2.isFullyOpen()) {
+ return {i1};
+ }
+
+ const ABT low1 =
+ i1.getLowBound().isInfinite() ? Constant::minKey() : i1.getLowBound().getBound();
+ const ABT high1 =
+ i1.getHighBound().isInfinite() ? Constant::maxKey() : i1.getHighBound().getBound();
+ const ABT low2 =
+ i2.getLowBound().isInfinite() ? Constant::minKey() : i2.getLowBound().getBound();
+ const ABT high2 =
+ i2.getHighBound().isInfinite() ? Constant::maxKey() : i2.getHighBound().getBound();
+
+ const auto foldFn = [](ABT expr) {
+ // Performs constant folding.
+ VariableEnvironment env = VariableEnvironment::build(expr);
+ ConstEval instance(env);
+ instance.optimize(expr);
+ return expr;
+ };
+ const auto minMaxFn = [](const Operations op, const ABT& v1, const ABT& v2) {
+        // Encodes max(v1, v2) when op is Gte, and min(v1, v2) when op is Lte.
+ return make<If>(make<BinaryOp>(op, v1, v2), v1, v2);
+ };
+ const auto minMaxFn1 = [](const Operations op, const ABT& v1, const ABT& v2, const ABT& v3) {
+        // Encodes (v1 op v2) ? v3 : v2.
+ return make<If>(make<BinaryOp>(op, v1, v2), v3, v2);
+ };
+
+ // In the simplest case our bound is (max(low1, low2), min(high1, high2)) if none of the bounds
+ // are inclusive.
+ const ABT maxLow = foldFn(minMaxFn(Operations::Gte, low1, low2));
+ const ABT minHigh = foldFn(minMaxFn(Operations::Lte, high1, high2));
+ if (foldFn(make<BinaryOp>(Operations::Gt, maxLow, minHigh)) == Constant::boolean(true)) {
+ // Low bound is greater than high bound.
+ return {};
+ }
+
+ const bool low1Inc = i1.getLowBound().isInclusive();
+ const bool high1Inc = i1.getHighBound().isInclusive();
+ const bool low2Inc = i2.getLowBound().isInclusive();
+ const bool high2Inc = i2.getHighBound().isInclusive();
+
+ // We form a "main" result interval which is closed on any side with "agreement" between the two
+ // intervals. For example [low1, high1] ^ [low2, high2) -> [max(low1, low2), min(high1, high2))
+ BoundRequirement lowBoundMain = (maxLow == Constant::minKey())
+ ? BoundRequirement::makeInfinite()
+ : BoundRequirement{low1Inc && low2Inc, maxLow};
+ BoundRequirement highBoundMain = (minHigh == Constant::maxKey())
+ ? BoundRequirement::makeInfinite()
+ : BoundRequirement{high1Inc && high2Inc, minHigh};
+
+ const bool boundsEqual =
+ foldFn(make<BinaryOp>(Operations::Eq, maxLow, minHigh)) == Constant::boolean(true);
+ if (boundsEqual) {
+ if (low1Inc && high1Inc && low2Inc && high2Inc) {
+ // Point interval.
+ return {{std::move(lowBoundMain), std::move(highBoundMain)}};
+ }
+ if ((!low1Inc && !low2Inc) || (!high1Inc && !high2Inc)) {
+ // Fully open on both sides.
+ return {};
+ }
+ }
+ if (low1Inc == low2Inc && high1Inc == high2Inc) {
+ // Inclusion matches on both sides.
+ return {{std::move(lowBoundMain), std::move(highBoundMain)}};
+ }
+
+    // At this point we have intervals without inclusion agreement, for example
+    // [low1, high1) ^ (low2, high2]. We have the main result which in this case is the open
+    // interval (max(low1, low2), min(high1, high2)). Then we add an extra closed interval for
+    // each side with disagreement. For example, for the low sides we add:
+    // [low2 >= low1 ? MaxKey : low1, min(max(low1, low2), min(high1, high2))]. This is a closed
+    // interval which reduces to the point [max(low1, low2), max(low1, low2)] if low1 > low2. If
+    // low2 >= low1, the interval reduces to the empty interval
+    // [MaxKey, min(max(low1, low2), min(high1, high2))], which returns no results from an index
+    // scan. In general we cannot determine this statically if the bounds are not constants (we
+    // cannot fold).
+ //
+ // If we can fold the extra interval, we exploit the fact that (max(low1, low2),
+ // min(high1, high2)) U [max(low1, low2), max(low1, low2)] is [max(low1, low2), min(high1,
+ // high2)) (observe left side is now closed). Then we create a similar auxiliary interval for
+ // the right side if there is disagreement on the inclusion. Finally, we attempt to fold both
+ // intervals. Should we conclude definitively that they are point intervals, we update the
+ // inclusion of the main interval for the respective side.
+
+ std::vector<IntervalRequirement> result;
+ const auto addAuxInterval = [&](ABT low, ABT high, BoundRequirement& bound) {
+ IntervalRequirement interval{{true, low}, {true, high}};
+
+ const ABT comparison = foldFn(make<BinaryOp>(Operations::Lte, low, high));
+ if (comparison == Constant::boolean(true)) {
+ if (interval.isEquality()) {
+ // We can determine the two bounds are equal.
+ bound.setInclusive(true);
+ } else {
+ result.push_back(std::move(interval));
+ }
+ } else if (!comparison.is<Constant>()) {
+ // We cannot determine statically how the two bounds compare.
+ result.push_back(std::move(interval));
+ }
+ };
+
+ if (low1Inc != low2Inc) {
+ const ABT low = foldFn(minMaxFn1(
+ Operations::Gte, low1Inc ? low2 : low1, low1Inc ? low1 : low2, Constant::maxKey()));
+ const ABT high = foldFn(minMaxFn(Operations::Lte, maxLow, minHigh));
+ addAuxInterval(std::move(low), std::move(high), lowBoundMain);
+ }
+
+ if (high1Inc != high2Inc) {
+ const ABT low = foldFn(minMaxFn(Operations::Gte, maxLow, minHigh));
+ const ABT high = foldFn(minMaxFn1(Operations::Lte,
+ high1Inc ? high2 : high1,
+ high1Inc ? high1 : high2,
+ Constant::minKey()));
+ addAuxInterval(std::move(low), std::move(high), highBoundMain);
+ }
+
+ if (!boundsEqual || (lowBoundMain.isInclusive() && highBoundMain.isInclusive())) {
+ // We add the main interval to the result as long as it is a valid point interval, or the
+ // bounds are not equal.
+ result.emplace_back(std::move(lowBoundMain), std::move(highBoundMain));
+ }
+ return result;
+}
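+
+// For illustration, tracing the logic above on constants: intersecting [1, 3] with (2, 5) gives
+// maxLow = 2 and minHigh = 3, and a main interval open on both sides: (2, 3). The low bounds
+// disagree on inclusion, but the auxiliary low-side interval folds to [MaxKey, 2], which is
+// empty and discarded. The high bounds also disagree; the auxiliary high-side interval folds to
+// the point [3, 3], so the main interval's high bound is made inclusive, yielding (2, 3].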
+
+boost::optional<IntervalReqExpr::Node> intersectDNFIntervals(
+ const IntervalReqExpr::Node& intervalDNF) {
+ IntervalReqExpr::NodeVector disjuncts;
+
+ for (const auto& disjunct : intervalDNF.cast<IntervalReqExpr::Disjunction>()->nodes()) {
+ const auto& conjuncts = disjunct.cast<IntervalReqExpr::Conjunction>()->nodes();
+ uassert(6624149, "Empty disjunct in interval DNF.", !conjuncts.empty());
+
+ std::vector<IntervalRequirement> intersectedIntervalDisjunction;
+ bool isEmpty = false;
+ bool isFirst = true;
+
+ for (const auto& conjunct : conjuncts) {
+ const auto& interval = conjunct.cast<IntervalReqExpr::Atom>()->getExpr();
+ if (isFirst) {
+ isFirst = false;
+ intersectedIntervalDisjunction = {interval};
+ } else {
+ std::vector<IntervalRequirement> newResult;
+ for (const auto& intersectedInterval : intersectedIntervalDisjunction) {
+ auto intersectionResult = intersectIntervals(intersectedInterval, interval);
+ newResult.insert(
+ newResult.end(), intersectionResult.cbegin(), intersectionResult.cend());
+ }
+ if (newResult.empty()) {
+                    // The intersection is empty; there is no need to process the remaining
+                    // conjuncts.
+ isEmpty = true;
+ break;
+ }
+ std::swap(intersectedIntervalDisjunction, newResult);
+ }
+ }
+ if (isEmpty) {
+            continue;  // The whole conjunction is empty (false), so skip this disjunct.
+ }
+
+ for (const auto& interval : intersectedIntervalDisjunction) {
+ auto conjunction =
+ IntervalReqExpr::make<IntervalReqExpr::Conjunction>(IntervalReqExpr::makeSeq(
+ IntervalReqExpr::make<IntervalReqExpr::Atom>(std::move(interval))));
+ disjuncts.emplace_back(conjunction);
+ }
+ }
+
+ if (disjuncts.empty()) {
+ return {};
+ }
+ return IntervalReqExpr::make<IntervalReqExpr::Disjunction>(std::move(disjuncts));
+}
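+
+// For illustration: applied to the single-conjunction DNF ([1, 3] ^ (2, 5)), the atoms are
+// intersected via intersectIntervals() and the result is the DNF ((2, 3]). A conjunction whose
+// atoms intersect to nothing is dropped, and if every conjunction is dropped the function
+// returns boost::none.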
+
+bool combineMultiKeyIntervalsDNF(MultiKeyIntervalReqExpr::Node& targetIntervals,
+ const IntervalReqExpr::Node& sourceIntervals) {
+ MultiKeyIntervalReqExpr::NodeVector newDisjunction;
+
+ for (const auto& sourceConjunction :
+ sourceIntervals.cast<IntervalReqExpr::Disjunction>()->nodes()) {
+ for (const auto& targetConjunction :
+ targetIntervals.cast<MultiKeyIntervalReqExpr::Disjunction>()->nodes()) {
+ MultiKeyIntervalReqExpr::NodeVector newConjunction;
+
+ for (const auto& sourceConjunct :
+ sourceConjunction.cast<IntervalReqExpr::Conjunction>()->nodes()) {
+ const auto& sourceInterval =
+ sourceConjunct.cast<IntervalReqExpr::Atom>()->getExpr();
+ for (const auto& targetConjunct :
+ targetConjunction.cast<MultiKeyIntervalReqExpr::Conjunction>()->nodes()) {
+ const auto& targetInterval =
+ targetConjunct.cast<MultiKeyIntervalReqExpr::Atom>()->getExpr();
+ if (!targetInterval.empty() && !targetInterval.back().isEquality() &&
+ !sourceInterval.isFullyOpen()) {
+ // We do not have an equality prefix. Reject.
+ return {};
+ }
+
+ auto newInterval = targetInterval;
+ newInterval.push_back(sourceInterval);
+ newConjunction.emplace_back(
+ MultiKeyIntervalReqExpr::make<MultiKeyIntervalReqExpr::Atom>(
+ std::move(newInterval)));
+ }
+ }
+
+ newDisjunction.emplace_back(
+ MultiKeyIntervalReqExpr::make<MultiKeyIntervalReqExpr::Conjunction>(
+ std::move(newConjunction)));
+ }
+ }
+
+ targetIntervals = MultiKeyIntervalReqExpr::make<MultiKeyIntervalReqExpr::Disjunction>(
+ std::move(newDisjunction));
+ return true;
+}
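+
+// For illustration: combining a target over one index field with interval [3, 3] (an equality)
+// and a source interval [1, 2] for the next field yields the multi-key interval ([3, 3], [1, 2]).
+// Had the target's last component been the non-equality [1, 2] with a source that is not fully
+// open, the combination would be rejected to preserve the single equality-prefix restriction.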
+
+} // namespace mongo::optimizer
diff --git a/src/mongo/db/query/optimizer/utils/interval_utils.h b/src/mongo/db/query/optimizer/utils/interval_utils.h
new file mode 100644
index 00000000000..4186e0aa276
--- /dev/null
+++ b/src/mongo/db/query/optimizer/utils/interval_utils.h
@@ -0,0 +1,68 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include "mongo/db/query/optimizer/index_bounds.h"
+
+namespace mongo::optimizer {
+
+/**
+ * Intersects or unions two intervals without simplification, since simplification might depend
+ * on multi-keyness.
+ * Currently assumes intervals are in DNF.
+ * TODO: handle generic interval expressions (not necessarily DNF).
+ */
+void combineIntervalsDNF(bool intersect,
+ IntervalReqExpr::Node& target,
+ const IntervalReqExpr::Node& source);
+
+/**
+ * Intersect all intervals within each conjunction of intervals in a disjunction of intervals.
+ * Notice that all intervals reference the same path (which is an index field).
+ * Returns a DNF of the intersected intervals, where there is at most one interval inside each
+ * conjunct. If the resulting interval is empty, returns boost::none.
+ * The intervals themselves can contain Constants, Variables, or arbitrary arithmetic expressions.
+ * TODO: handle generic interval expressions (not necessarily DNF).
+ */
+boost::optional<IntervalReqExpr::Node> intersectDNFIntervals(
+ const IntervalReqExpr::Node& intervalDNF);
+
+/**
+ * Combines a source interval over a single path with a target multi-component interval. The
+ * multi-component interval is extended to contain an extra field. The resulting multi-component
+ * interval defines the boundaries over the index components used by the index access execution
+ * operator. If we fail to combine, the target multi-key interval is left unchanged.
+ * Currently we only support a single "equality prefix": zero or more equalities followed by at
+ * most one inequality, with trailing fully-open intervals.
+ * TODO: support Recursive Index Navigation.
+ */
+bool combineMultiKeyIntervalsDNF(MultiKeyIntervalReqExpr::Node& targetIntervals,
+ const IntervalReqExpr::Node& sourceIntervals);
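+
+// For illustration: for a compound index on (a, b), an equality on 'a' followed by a range on
+// 'b' forms a single equality prefix and can be combined, while a range on 'a' followed by a
+// non-trivial interval on 'b' cannot.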
+
+} // namespace mongo::optimizer
diff --git a/src/mongo/db/query/optimizer/utils/memo_utils.cpp b/src/mongo/db/query/optimizer/utils/memo_utils.cpp
new file mode 100644
index 00000000000..9dfa3e595a2
--- /dev/null
+++ b/src/mongo/db/query/optimizer/utils/memo_utils.cpp
@@ -0,0 +1,176 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/db/query/optimizer/utils/memo_utils.h"
+
+#include "mongo/db/query/optimizer/cascades/memo.h"
+
+
+namespace mongo::optimizer {
+
+ABT wrapConstFilter(ABT node) {
+ return make<FilterNode>(Constant::boolean(true), std::move(node));
+}
+
+ABT unwrapConstFilter(ABT node) {
+ if (auto nodePtr = node.cast<FilterNode>();
+ nodePtr != nullptr && nodePtr->getFilter() == Constant::boolean(true)) {
+ return nodePtr->getChild();
+ }
+ return node;
+}
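+
+// For illustration: unwrapConstFilter(wrapConstFilter(n)) returns n unchanged, since the wrapper
+// is a FilterNode with a constant 'true' predicate.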
+
+class MemoLatestPlanExtractor {
+public:
+ explicit MemoLatestPlanExtractor(const cascades::Memo& memo) : _memo(memo) {}
+
+ /**
+ * Logical delegator node.
+ */
+ void transport(ABT& n,
+ const MemoLogicalDelegatorNode& node,
+ opt::unordered_set<GroupIdType>& visitedGroups) {
+ n = extractLatest(node.getGroupId(), visitedGroups);
+ }
+
+ /**
+ * Other ABT types.
+ */
+ template <typename T, typename... Ts>
+ void transport(ABT& /*n*/,
+ const T& /*node*/,
+ opt::unordered_set<GroupIdType>& visitedGroups,
+ Ts&&...) {
+ // noop
+ }
+
+ ABT extractLatest(const GroupIdType groupId, opt::unordered_set<GroupIdType>& visitedGroups) {
+ const cascades::Group& group = _memo.getGroup(groupId);
+ if (!visitedGroups.insert(groupId).second) {
+ const GroupIdType scanGroupId =
+ properties::getPropertyConst<properties::IndexingAvailability>(
+ group._logicalProperties)
+ .getScanGroupId();
+ uassert(
+ 6624357, "Visited the same non-scan group more than once", groupId == scanGroupId);
+ }
+
+ ABT rootNode = group._logicalNodes.getVector().back();
+ algebra::transport<true>(rootNode, *this, visitedGroups);
+ return rootNode;
+ }
+
+private:
+ const cascades::Memo& _memo;
+};
+
+ABT extractLatestPlan(const cascades::Memo& memo, const GroupIdType rootGroupId) {
+ MemoLatestPlanExtractor extractor(memo);
+ opt::unordered_set<GroupIdType> visitedGroups;
+ return extractor.extractLatest(rootGroupId, visitedGroups);
+}
+
+class MemoPhysicalPlanExtractor {
+public:
+ explicit MemoPhysicalPlanExtractor(const cascades::Memo& memo,
+ const Metadata& metadata,
+ NodeToGroupPropsMap& nodeToGroupPropsMap)
+ : _memo(memo),
+ _metadata(metadata),
+ _nodeToGroupPropsMap(nodeToGroupPropsMap),
+ _planNodeId(0) {}
+
+ /**
+ * Physical delegator node.
+ */
+ void transport(ABT& n, const MemoPhysicalDelegatorNode& node, const MemoPhysicalNodeId /*id*/) {
+ n = extract(node.getNodeId());
+ }
+
+ /**
+ * Other ABT types.
+ */
+ template <typename T, typename... Ts>
+ void transport(ABT& /*n*/, const T& node, MemoPhysicalNodeId id, Ts&&...) {
+ if constexpr (std::is_base_of_v<Node, T>) {
+ using namespace properties;
+
+ const cascades::Group& group = _memo.getGroup(id._groupId);
+ const auto& physicalResult = group._physicalNodes.at(id._index);
+ const auto& nodeInfo = *physicalResult._nodeInfo;
+
+ LogicalProps logicalProps = group._logicalProperties;
+ PhysProps physProps = physicalResult._physProps;
+ if (!_metadata.isParallelExecution()) {
+                // Do not display distribution availability and requirement under a centralized
+                // (non-parallel) setting.
+ removeProperty<DistributionAvailability>(logicalProps);
+ removeProperty<DistributionRequirement>(physProps);
+ }
+
+ _nodeToGroupPropsMap.emplace(&node,
+ NodeProps{_planNodeId++,
+ id,
+ std::move(logicalProps),
+ std::move(physProps),
+ nodeInfo._cost,
+ nodeInfo._localCost,
+ nodeInfo._adjustedCE});
+ }
+ }
+
+ ABT extract(const MemoPhysicalNodeId nodeId) {
+ const auto& result = _memo.getGroup(nodeId._groupId)._physicalNodes.at(nodeId._index);
+ uassert(6624143,
+ "Physical delegator must be pointing to an optimized result.",
+ result._nodeInfo.has_value());
+ ABT node = result._nodeInfo->_node;
+
+ algebra::transport<true>(node, *this, nodeId);
+ return node;
+ }
+
+private:
+ // We don't own this.
+ const cascades::Memo& _memo;
+ const Metadata& _metadata;
+ NodeToGroupPropsMap& _nodeToGroupPropsMap;
+
+ int32_t _planNodeId;
+};
+
+std::pair<ABT, NodeToGroupPropsMap> extractPhysicalPlan(const MemoPhysicalNodeId id,
+ const Metadata& metadata,
+ const cascades::Memo& memo) {
+ NodeToGroupPropsMap resultMap;
+ MemoPhysicalPlanExtractor extractor(memo, metadata, resultMap);
+ ABT resultNode = extractor.extract(id);
+ return {std::move(resultNode), std::move(resultMap)};
+}
+
+} // namespace mongo::optimizer
diff --git a/src/mongo/db/query/optimizer/utils/memo_utils.h b/src/mongo/db/query/optimizer/utils/memo_utils.h
new file mode 100644
index 00000000000..af28a2c3ef8
--- /dev/null
+++ b/src/mongo/db/query/optimizer/utils/memo_utils.h
@@ -0,0 +1,75 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include "mongo/db/query/optimizer/cascades/memo.h"
+#include "mongo/db/query/optimizer/metadata.h"
+#include "mongo/db/query/optimizer/node_defs.h"
+#include "mongo/db/query/optimizer/utils/utils.h"
+
+
+namespace mongo::optimizer {
+
+ABT wrapConstFilter(ABT node);
+ABT unwrapConstFilter(ABT node);
+
+template <class ToAddType, class ToRemoveType>
+static void addRemoveProjectionsToProperties(properties::PhysProps& properties,
+ const ToAddType& toAdd,
+ const ToRemoveType& toRemove) {
+ ProjectionNameOrderPreservingSet& projections =
+ properties::getProperty<properties::ProjectionRequirement>(properties).getProjections();
+ for (const auto& varName : toRemove) {
+ projections.erase(varName);
+ }
+ for (const auto& varName : toAdd) {
+ projections.emplace_back(varName);
+ }
+}
+
+template <class ToAddType>
+static void addProjectionsToProperties(properties::PhysProps& properties, const ToAddType& toAdd) {
+ addRemoveProjectionsToProperties(properties, toAdd, ToAddType{});
+}
+
+/**
+ * Extracts the "latest" logical plan. Starting from the root group, we follow the last logical
+ * nodes.
+ */
+ABT extractLatestPlan(const cascades::Memo& memo, GroupIdType rootGroupId);
+
+/**
+ * Extracts a complete physical plan by inlining MemoPhysicalDelegatorNode references.
+ */
+std::pair<ABT, NodeToGroupPropsMap> extractPhysicalPlan(MemoPhysicalNodeId id,
+ const Metadata& metadata,
+ const cascades::Memo& memo);
+
+} // namespace mongo::optimizer
diff --git a/src/mongo/db/query/optimizer/utils/printable_enum.h b/src/mongo/db/query/optimizer/utils/printable_enum.h
new file mode 100644
index 00000000000..873187f5a25
--- /dev/null
+++ b/src/mongo/db/query/optimizer/utils/printable_enum.h
@@ -0,0 +1,48 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+namespace mongo::optimizer {
+
+// Helper macro which emits an enumerator in an enum definition.
+#define PRINTABLE_ENUMERATOR(value) value,
+// Helper macro which emits an entry in a string array for an enum.
+#define PRINTABLE_ENUM_STRING(value) #value,
+
+// Makes an enum class with a name based on a LIST macro of enumerators.
+#define MAKE_PRINTABLE_ENUM(name, LIST) enum class name { LIST(PRINTABLE_ENUMERATOR) };
+
+// Makes an array of enum names with a name based on a LIST macro of enumerators.
+#define MAKE_PRINTABLE_ENUM_STRING_ARRAY(nspace, name, LIST) \
+ namespace nspace { \
+ constexpr const char* toString[] = {LIST(PRINTABLE_ENUM_STRING)}; \
+ } // namespace nspace
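+
+// Usage sketch with a hypothetical 'Color' enum (illustrative only, not used elsewhere):
+//   #define COLOR_LIST(F) F(Red) F(Green) F(Blue)
+//   MAKE_PRINTABLE_ENUM(Color, COLOR_LIST)
+//   MAKE_PRINTABLE_ENUM_STRING_ARRAY(ColorEnum, Color, COLOR_LIST)
+//   // ColorEnum::toString[static_cast<int>(Color::Green)] yields "Green".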
+
+} // namespace mongo::optimizer
diff --git a/src/mongo/db/query/optimizer/utils/unit_test_utils.cpp b/src/mongo/db/query/optimizer/utils/unit_test_utils.cpp
new file mode 100644
index 00000000000..2f5252e573f
--- /dev/null
+++ b/src/mongo/db/query/optimizer/utils/unit_test_utils.cpp
@@ -0,0 +1,157 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/db/query/optimizer/utils/unit_test_utils.h"
+
+#include "mongo/db/pipeline/abt/abt_document_source_visitor.h"
+#include "mongo/db/pipeline/expression_context_for_test.h"
+#include "mongo/db/query/optimizer/explain.h"
+#include "mongo/db/query/optimizer/metadata.h"
+#include "mongo/db/query/optimizer/node.h"
+#include "mongo/unittest/temp_dir.h"
+#include "mongo/unittest/unittest.h"
+
+namespace mongo::optimizer {
+
+static constexpr bool kDebugAsserts = false;
+
+void maybePrintABT(const ABT& abt) {
+    // Always generate the explain output in all supported versions to make sure we don't crash.
+ const std::string strV1 = ExplainGenerator::explain(abt);
+ const std::string strV2 = ExplainGenerator::explainV2(abt);
+ auto [tag, val] = ExplainGenerator::explainBSON(abt);
+ sbe::value::ValueGuard vg(tag, val);
+
+ if constexpr (kDebugAsserts) {
+ std::cout << "V1: " << strV1 << "\n";
+ std::cout << "V2: " << strV2 << "\n";
+ std::cout << "BSON: " << ExplainGenerator::printBSON(tag, val) << "\n";
+ }
+}
+
+ABT makeIndexPath(FieldPathType fieldPath, bool isMultiKey) {
+ ABT result = make<PathIdentity>();
+
+ for (size_t i = fieldPath.size(); i-- > 0;) {
+ if (isMultiKey) {
+ result = make<PathTraverse>(std::move(result));
+ }
+ result = make<PathGet>(std::move(fieldPath.at(i)), std::move(result));
+ }
+
+ return result;
+}
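+
+// For illustration: makeIndexPath({"a", "b"}, true /*isMultiKey*/) builds the path
+// Get "a" Traverse Get "b" Traverse Id, while the non-multikey variant omits the Traverse
+// elements: Get "a" Get "b" Id.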
+
+ABT makeIndexPath(FieldNameType fieldName) {
+ return makeIndexPath(FieldPathType{std::move(fieldName)});
+}
+
+ABT makeNonMultikeyIndexPath(FieldNameType fieldName) {
+ return makeIndexPath(FieldPathType{std::move(fieldName)}, false /*isMultiKey*/);
+}
+
+IndexDefinition makeIndexDefinition(FieldNameType fieldName, CollationOp op, bool isMultiKey) {
+ IndexCollationSpec idxCollSpec{
+ IndexCollationEntry((isMultiKey ? makeIndexPath(std::move(fieldName))
+ : makeNonMultikeyIndexPath(std::move(fieldName))),
+ op)};
+ return IndexDefinition{std::move(idxCollSpec), isMultiKey};
+}
+
+IndexDefinition makeCompositeIndexDefinition(std::vector<TestIndexField> indexFields,
+ bool isMultiKey) {
+ IndexCollationSpec idxCollSpec;
+ for (auto& idxField : indexFields) {
+ idxCollSpec.emplace_back((idxField.isMultiKey
+ ? makeIndexPath(std::move(idxField.fieldName))
+ : makeNonMultikeyIndexPath(std::move(idxField.fieldName))),
+ idxField.op);
+ }
+ return IndexDefinition{std::move(idxCollSpec), isMultiKey};
+}
+
+std::unique_ptr<mongo::Pipeline, mongo::PipelineDeleter> parsePipeline(
+ const NamespaceString& nss,
+ const std::string& inputPipeline,
+ OperationContextNoop& opCtx,
+ const std::vector<ExpressionContext::ResolvedNamespace>& involvedNss) {
+ const BSONObj inputBson = fromjson("{pipeline: " + inputPipeline + "}");
+
+ std::vector<BSONObj> rawPipeline;
+ for (auto&& stageElem : inputBson["pipeline"].Array()) {
+ ASSERT_EQUALS(stageElem.type(), BSONType::Object);
+ rawPipeline.push_back(stageElem.embeddedObject());
+ }
+
+ AggregateCommandRequest request(nss, rawPipeline);
+ boost::intrusive_ptr<ExpressionContextForTest> ctx(
+ new ExpressionContextForTest(&opCtx, request));
+
+    // Set up the resolved namespaces for other involved collections.
+ for (const auto& resolvedNss : involvedNss) {
+ ctx->setResolvedNamespace(resolvedNss.ns, resolvedNss);
+ }
+
+ unittest::TempDir tempDir("ABTPipelineTest");
+ ctx->tempDir = tempDir.path();
+
+ return Pipeline::parse(request.getPipeline(), ctx);
+}
+
+ABT translatePipeline(const Metadata& metadata,
+ const std::string& pipelineStr,
+ ProjectionName scanProjName,
+ std::string scanDefName,
+ PrefixId& prefixId,
+ const std::vector<ExpressionContext::ResolvedNamespace>& involvedNss) {
+ OperationContextNoop opCtx;
+ auto pipeline =
+ parsePipeline(NamespaceString("a." + scanDefName), pipelineStr, opCtx, involvedNss);
+ return translatePipelineToABT(metadata,
+ *pipeline.get(),
+ scanProjName,
+ make<ScanNode>(scanProjName, std::move(scanDefName)),
+ prefixId);
+}
+
+ABT translatePipeline(Metadata& metadata,
+ const std::string& pipelineStr,
+ std::string scanDefName,
+ PrefixId& prefixId) {
+ return translatePipeline(
+ metadata, pipelineStr, prefixId.getNextId("scan"), scanDefName, prefixId);
+}
+
+ABT translatePipeline(const std::string& pipelineStr, std::string scanDefName) {
+ PrefixId prefixId;
+ Metadata metadata{{}};
+ return translatePipeline(metadata, pipelineStr, std::move(scanDefName), prefixId);
+}
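+
+// Usage sketch: translatePipeline("[{$match: {a: 1}}]") parses the pipeline against the default
+// "collection" scan definition and returns the translated ABT built on top of a ScanNode.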
+
+} // namespace mongo::optimizer
diff --git a/src/mongo/db/query/optimizer/utils/unit_test_utils.h b/src/mongo/db/query/optimizer/utils/unit_test_utils.h
new file mode 100644
index 00000000000..402a81618d8
--- /dev/null
+++ b/src/mongo/db/query/optimizer/utils/unit_test_utils.h
@@ -0,0 +1,99 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include "mongo/db/operation_context_noop.h"
+#include "mongo/db/pipeline/expression_context_for_test.h"
+#include "mongo/db/pipeline/pipeline.h"
+#include "mongo/db/query/optimizer/defs.h"
+#include "mongo/db/query/optimizer/metadata.h"
+#include "mongo/db/query/optimizer/utils/utils.h"
+
+namespace mongo::optimizer {
+
+void maybePrintABT(const ABT& abt);
+
+#define ASSERT_EXPLAIN(expected, abt) \
+ maybePrintABT(abt); \
+ ASSERT_EQ(expected, ExplainGenerator::explain(abt))
+
+#define ASSERT_EXPLAIN_V2(expected, abt) \
+ maybePrintABT(abt); \
+ ASSERT_EQ(expected, ExplainGenerator::explainV2(abt))
+
+#define ASSERT_EXPLAIN_BSON(expected, abt) \
+ maybePrintABT(abt); \
+ ASSERT_EQ(expected, ExplainGenerator::explainBSON(abt))
+
+#define ASSERT_EXPLAIN_PROPS_V2(expected, phaseManager) \
+ ASSERT_EQ(expected, \
+ ExplainGenerator::explainV2( \
+ make<MemoPhysicalDelegatorNode>(phaseManager.getPhysicalNodeId()), \
+ true /*displayPhysicalProperties*/, \
+ &phaseManager.getMemo()))
+
+#define ASSERT_EXPLAIN_MEMO(expected, memo) ASSERT_EQ(expected, ExplainGenerator::explainMemo(memo))
+
+#define ASSERT_BETWEEN(a, b, value) \
+ ASSERT_LTE(a, value); \
+ ASSERT_GTE(b, value);
+
+struct TestIndexField {
+ FieldNameType fieldName;
+ CollationOp op;
+ bool isMultiKey;
+};
+
+ABT makeIndexPath(FieldPathType fieldPath, bool isMultiKey = true);
+
+ABT makeIndexPath(FieldNameType fieldName);
+ABT makeNonMultikeyIndexPath(FieldNameType fieldName);
+
+IndexDefinition makeIndexDefinition(FieldNameType fieldName,
+ CollationOp op,
+ bool isMultiKey = true);
+IndexDefinition makeCompositeIndexDefinition(std::vector<TestIndexField> indexFields,
+ bool isMultiKey = true);
+
+ABT translatePipeline(const Metadata& metadata,
+ const std::string& pipelineStr,
+ ProjectionName scanProjName,
+ std::string scanDefName,
+ PrefixId& prefixId,
+ const std::vector<ExpressionContext::ResolvedNamespace>& involvedNss = {});
+
+ABT translatePipeline(Metadata& metadata,
+ const std::string& pipelineStr,
+ std::string scanDefName,
+ PrefixId& prefixId);
+
+ABT translatePipeline(const std::string& pipelineStr, std::string scanDefName = "collection");
+
+} // namespace mongo::optimizer
diff --git a/src/mongo/db/query/optimizer/utils/utils.cpp b/src/mongo/db/query/optimizer/utils/utils.cpp
new file mode 100644
index 00000000000..5bf7d93db29
--- /dev/null
+++ b/src/mongo/db/query/optimizer/utils/utils.cpp
@@ -0,0 +1,1583 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/db/query/optimizer/utils/utils.h"
+
+#include "boost/none.hpp"
+#include "mongo/db/query/optimizer/metadata.h"
+#include "mongo/db/query/optimizer/reference_tracker.h"
+#include "mongo/db/query/optimizer/syntax/path.h"
+#include "mongo/db/query/optimizer/utils/interval_utils.h"
+#include "mongo/db/storage/storage_parameters_gen.h"
+#include "mongo/util/assert_util.h"
+
+namespace mongo::optimizer {
+
+size_t roundUpToNextPow2(const size_t v, const size_t maxPower) {
+ if (v == 0) {
+ return 0;
+ }
+
+ size_t result = 1;
+ for (size_t power = 0; result < v && power < maxPower; result <<= 1, power++) {
+ }
+ return result;
+}
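+
+// For illustration: roundUpToNextPow2(5, 10) returns 8, and the result is capped at 2^maxPower,
+// so roundUpToNextPow2(100, 3) also returns 8.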
+
+std::vector<ABT::reference_type> collectComposed(const ABT& n) {
+ if (auto comp = n.cast<PathComposeM>(); comp) {
+ auto lhs = collectComposed(comp->getPath1());
+ auto rhs = collectComposed(comp->getPath2());
+ lhs.insert(lhs.end(), rhs.begin(), rhs.end());
+
+ return lhs;
+ }
+
+ return {n.ref()};
+}
+
+FieldNameType getSimpleField(const ABT& node) {
+ const PathGet* pathGet = node.cast<PathGet>();
+ return pathGet != nullptr ? pathGet->name() : "";
+}
+
+std::string PrefixId::getNextId(const std::string& key) {
+ std::ostringstream os;
+ os << key << "_" << _idCounterPerKey[key]++;
+ return os.str();
+}
+
+ProjectionNameOrderedSet convertToOrderedSet(ProjectionNameSet unordered) {
+ ProjectionNameOrderedSet ordered;
+ for (const ProjectionName& projection : unordered) {
+ ordered.emplace(projection);
+ }
+ return ordered;
+}
+
+opt::unordered_set<FieldNameType> toUnorderedFieldNameSet(std::set<FieldNameType> set) {
+ opt::unordered_set<FieldNameType> result;
+ for (auto&& fieldName : set) {
+ result.emplace(std::move(fieldName));
+ }
+ return result;
+}
+
+void combineLimitSkipProperties(properties::LimitSkipRequirement& aboveProp,
+ const properties::LimitSkipRequirement& belowProp) {
+ using namespace properties;
+
+ const int64_t newAbsLimit = std::min<int64_t>(
+ aboveProp.hasLimit() ? (belowProp.getSkip() + aboveProp.getAbsoluteLimit())
+ : LimitSkipRequirement::kMaxVal,
+ std::max<int64_t>(0,
+ belowProp.hasLimit()
+ ? (belowProp.getAbsoluteLimit() - aboveProp.getSkip())
+ : LimitSkipRequirement::kMaxVal));
+
+ const int64_t newLimit = (newAbsLimit == LimitSkipRequirement::kMaxVal)
+ ? LimitSkipRequirement::kMaxVal
+ : (newAbsLimit - belowProp.getSkip());
+ const int64_t newSkip = (newLimit == 0) ? 0 : belowProp.getSkip();
+ aboveProp = {newLimit, newSkip};
+}
+
+/**
+ * Used to track references originating from a set of properties.
+ */
+class PropertiesAffectedColumnsExtractor {
+public:
+ template <class T>
+ void operator()(const properties::PhysProperty&, const T& prop) {
+ for (const ProjectionName& projection : prop.getAffectedProjectionNames()) {
+ _projections.insert(projection);
+ }
+ }
+
+ static ProjectionNameSet extract(const properties::PhysProps& properties) {
+ PropertiesAffectedColumnsExtractor extractor;
+ for (const auto& entry : properties) {
+ entry.second.visit(extractor);
+ }
+ return extractor._projections;
+ }
+
+private:
+ ProjectionNameSet _projections;
+};
+
+ProjectionNameSet extractReferencedColumns(const properties::PhysProps& properties) {
+ return PropertiesAffectedColumnsExtractor::extract(properties);
+}
+
+bool areMultiKeyIntervalsEqualities(const MultiKeyIntervalRequirement& intervals) {
+ for (const auto& interval : intervals) {
+ if (!interval.isEquality()) {
+ return false;
+ }
+ }
+ return true;
+}
+
+CollationSplitResult splitCollationSpec(const ProjectionCollationSpec& collationSpec,
+ const ProjectionNameSet& leftProjections,
+ const ProjectionNameSet& rightProjections) {
+ bool leftSide = true;
+ ProjectionCollationSpec leftCollationSpec;
+ ProjectionCollationSpec rightCollationSpec;
+
+ for (const auto& collationEntry : collationSpec) {
+ const ProjectionName& projectionName = collationEntry.first;
+
+ if (leftProjections.count(projectionName) > 0) {
+ if (!leftSide) {
+                // The left and right projections must be complementary, forming a prefix and a
+                // suffix of the collation spec.
+ return {};
+ }
+ leftCollationSpec.push_back(collationEntry);
+ } else if (rightProjections.count(projectionName) > 0) {
+ if (leftSide) {
+ leftSide = false;
+ }
+ rightCollationSpec.push_back(collationEntry);
+ } else {
+ uasserted(6624146,
+ "Collation projection must appear in either the left or the right "
+ "child projections");
+ return {};
+ }
+ }
+
+ return {true /*validSplit*/, std::move(leftCollationSpec), std::move(rightCollationSpec)};
+}
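+
+// For illustration: splitting the collation spec [(a, asc), (b, asc)] with left projections {a}
+// and right projections {b} yields the valid split with left spec [(a, asc)] and right spec
+// [(b, asc)]. The spec [(b, asc), (a, asc)] with the same projection sets is rejected, because
+// the left projections do not form a prefix.
+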
+/**
+ * Helper class used to extract variable references from a node.
+ */
+class NodeVariableTracker {
+public:
+ template <typename T, typename... Ts>
+ VariableNameSetType walk(const T&, Ts&&...) {
+ static_assert(!std::is_base_of_v<Node, T>, "Nodes must implement variable tracking");
+
+ // Default case: no variables.
+ return {};
+ }
+
+ VariableNameSetType walk(const ScanNode& /*node*/, const ABT& /*binds*/) {
+ return {};
+ }
+
+ VariableNameSetType walk(const ValueScanNode& /*node*/, const ABT& /*binds*/) {
+ return {};
+ }
+
+ VariableNameSetType walk(const PhysicalScanNode& /*node*/, const ABT& /*binds*/) {
+ return {};
+ }
+
+ VariableNameSetType walk(const CoScanNode& /*node*/) {
+ return {};
+ }
+
+ VariableNameSetType walk(const IndexScanNode& /*node*/, const ABT& /*binds*/) {
+ return {};
+ }
+
+ VariableNameSetType walk(const SeekNode& /*node*/, const ABT& /*binds*/, const ABT& refs) {
+ return extractFromABT(refs);
+ }
+
+ VariableNameSetType walk(const MemoLogicalDelegatorNode& /*node*/) {
+ return {};
+ }
+
+ VariableNameSetType walk(const MemoPhysicalDelegatorNode& /*node*/) {
+ return {};
+ }
+
+ VariableNameSetType walk(const FilterNode& /*node*/, const ABT& /*child*/, const ABT& expr) {
+ return extractFromABT(expr);
+ }
+
+ VariableNameSetType walk(const EvaluationNode& /*node*/,
+ const ABT& /*child*/,
+ const ABT& expr) {
+ return extractFromABT(expr);
+ }
+
+ VariableNameSetType walk(const SargableNode& /*node*/,
+ const ABT& /*child*/,
+ const ABT& /*binds*/,
+ const ABT& refs) {
+ return extractFromABT(refs);
+ }
+
+ VariableNameSetType walk(const RIDIntersectNode& /*node*/,
+ const ABT& /*leftChild*/,
+ const ABT& /*rightChild*/) {
+ return {};
+ }
+
+ VariableNameSetType walk(const BinaryJoinNode& /*node*/,
+ const ABT& /*leftChild*/,
+ const ABT& /*rightChild*/,
+ const ABT& expr) {
+ return extractFromABT(expr);
+ }
+
+ VariableNameSetType walk(const HashJoinNode& /*node*/,
+ const ABT& /*leftChild*/,
+ const ABT& /*rightChild*/,
+ const ABT& refs) {
+ return extractFromABT(refs);
+ }
+
+ VariableNameSetType walk(const MergeJoinNode& /*node*/,
+ const ABT& /*leftChild*/,
+ const ABT& /*rightChild*/,
+ const ABT& refs) {
+ return extractFromABT(refs);
+ }
+
+ VariableNameSetType walk(const UnionNode& /*node*/,
+ const ABTVector& /*children*/,
+ const ABT& /*binder*/,
+ const ABT& refs) {
+ return extractFromABT(refs);
+ }
+
+ VariableNameSetType walk(const GroupByNode& /*node*/,
+ const ABT& /*child*/,
+ const ABT& /*aggBinder*/,
+ const ABT& aggRefs,
+ const ABT& /*groupbyBinder*/,
+ const ABT& groupbyRefs) {
+ VariableNameSetType result;
+ extractFromABT(result, aggRefs);
+ extractFromABT(result, groupbyRefs);
+ return result;
+ }
+
+ VariableNameSetType walk(const UnwindNode& /*node*/,
+ const ABT& /*child*/,
+ const ABT& /*binds*/,
+ const ABT& refs) {
+ return extractFromABT(refs);
+ }
+
+ VariableNameSetType walk(const UniqueNode& /*node*/, const ABT& /*child*/, const ABT& refs) {
+ return extractFromABT(refs);
+ }
+
+ VariableNameSetType walk(const CollationNode& /*node*/, const ABT& /*child*/, const ABT& refs) {
+ return extractFromABT(refs);
+ }
+
+ VariableNameSetType walk(const LimitSkipNode& /*node*/, const ABT& /*child*/) {
+ return {};
+ }
+
+ VariableNameSetType walk(const ExchangeNode& /*node*/, const ABT& /*child*/, const ABT& refs) {
+ return extractFromABT(refs);
+ }
+
+ VariableNameSetType walk(const RootNode& /*node*/, const ABT& /*child*/, const ABT& refs) {
+ return extractFromABT(refs);
+ }
+
+ static VariableNameSetType collect(const ABT& n) {
+ NodeVariableTracker tracker;
+ return algebra::walk<false>(n, tracker);
+ }
+
+private:
+ void extractFromABT(VariableNameSetType& vars, const ABT& v) {
+ const auto& result = VariableEnvironment::getVariables(v);
+ for (const Variable* var : result._variables) {
+ if (result._definedVars.count(var->name()) == 0) {
+ // We are interested in either free variables, or variables defined on other nodes.
+ vars.insert(var->name());
+ }
+ }
+ }
+
+ VariableNameSetType extractFromABT(const ABT& v) {
+ VariableNameSetType result;
+ extractFromABT(result, v);
+ return result;
+ }
+};
+
+VariableNameSetType collectVariableReferences(const ABT& n) {
+ return NodeVariableTracker::collect(n);
+}
+
+PartialSchemaReqConversion::PartialSchemaReqConversion()
+ : _success(false),
+ _bound(),
+ _reqMap(),
+ _hasIntersected(false),
+ _hasTraversed(false),
+ _hasEmptyInterval(false) {}
+
+PartialSchemaReqConversion::PartialSchemaReqConversion(PartialSchemaRequirements reqMap)
+ : _success(true),
+ _bound(),
+ _reqMap(std::move(reqMap)),
+ _hasIntersected(false),
+ _hasTraversed(false),
+ _hasEmptyInterval(false) {}
+
+PartialSchemaReqConversion::PartialSchemaReqConversion(ABT bound)
+ : _success(true),
+ _bound(std::move(bound)),
+ _reqMap(),
+ _hasIntersected(false),
+ _hasTraversed(false),
+ _hasEmptyInterval(false) {}
+
+/**
+ * Helper class that builds PartialSchemaRequirements property from an EvalFilter or an EvalPath.
+ */
+class PartialSchemaReqConverter {
+public:
+ PartialSchemaReqConverter() = default;
+
+ PartialSchemaReqConversion handleEvalPathAndEvalFilter(PartialSchemaReqConversion pathResult,
+ PartialSchemaReqConversion inputResult) {
+ if (!pathResult._success || !inputResult._success) {
+ return {};
+ }
+ if (pathResult._bound.has_value() || !inputResult._bound.has_value() ||
+ !inputResult._reqMap.empty()) {
+ return {};
+ }
+
+ if (auto boundPtr = inputResult._bound->cast<Variable>(); boundPtr != nullptr) {
+ const ProjectionName& boundVarName = boundPtr->name();
+ PartialSchemaRequirements newMap;
+
+ for (auto& [key, req] : pathResult._reqMap) {
+ if (!key._projectionName.empty()) {
+ return {};
+ }
+ newMap.emplace(PartialSchemaKey{boundVarName, key._path}, std::move(req));
+ }
+
+ PartialSchemaReqConversion result{std::move(newMap)};
+ result._hasEmptyInterval = pathResult._hasEmptyInterval;
+ return result;
+ }
+
+ return {};
+ }
+
+ PartialSchemaReqConversion transport(const ABT& n,
+ const EvalPath& evalPath,
+ PartialSchemaReqConversion pathResult,
+ PartialSchemaReqConversion inputResult) {
+ return handleEvalPathAndEvalFilter(std::move(pathResult), std::move(inputResult));
+ }
+
+ PartialSchemaReqConversion transport(const ABT& n,
+ const EvalFilter& evalFilter,
+ PartialSchemaReqConversion pathResult,
+ PartialSchemaReqConversion inputResult) {
+ return handleEvalPathAndEvalFilter(std::move(pathResult), std::move(inputResult));
+ }
+
+ static PartialSchemaReqConversion handleComposition(const bool isMultiplicative,
+ PartialSchemaReqConversion leftResult,
+ PartialSchemaReqConversion rightResult) {
+ if (!leftResult._success || !rightResult._success) {
+ return {};
+ }
+ if (leftResult._bound.has_value() || rightResult._bound.has_value()) {
+ return {};
+ }
+
+ auto& leftReq = leftResult._reqMap;
+ auto& rightReq = rightResult._reqMap;
+ if (isMultiplicative) {
+ {
+ ProjectionRenames projectionRenames;
+ if (!intersectPartialSchemaReq(leftReq, rightReq, projectionRenames)) {
+ return {};
+ }
+ if (!projectionRenames.empty()) {
+ return {};
+ }
+ }
+
+ if (!leftResult._hasTraversed && !rightResult._hasTraversed) {
+ // Intersect intervals only if we have not traversed. E.g. (-inf, 90) ^ (70, +inf)
+ // becomes (70, 90).
+ for (auto& [key, req] : leftReq) {
+ auto newIntervals = intersectDNFIntervals(req.getIntervals());
+ if (newIntervals) {
+ req.getIntervals() = std::move(newIntervals.get());
+ } else {
+ leftResult._hasEmptyInterval = true;
+ break;
+ }
+ }
+ } else if (leftReq.size() > 1) {
+            // Reject if we have traversed and there is more than one composed requirement.
+ return {};
+ }
+
+ leftResult._hasIntersected = true;
+ return leftResult;
+ }
+
+ // We can only perform additive composition (OR) if we have a single matching key on both
+ // sides.
+ if (leftReq.size() != 1 || rightReq.size() != 1) {
+ return {};
+ }
+
+ auto leftEntry = leftReq.begin();
+ auto rightEntry = rightReq.begin();
+ if (!(leftEntry->first == rightEntry->first)) {
+ return {};
+ }
+
+ combineIntervalsDNF(false /*intersect*/,
+ leftEntry->second.getIntervals(),
+ rightEntry->second.getIntervals());
+ return leftResult;
+ }
+
+ PartialSchemaReqConversion transport(const ABT& n,
+ const PathComposeM& pathComposeM,
+ PartialSchemaReqConversion leftResult,
+ PartialSchemaReqConversion rightResult) {
+ return handleComposition(
+ true /*isMultiplicative*/, std::move(leftResult), std::move(rightResult));
+ }
+
+ PartialSchemaReqConversion transport(const ABT& n,
+ const PathComposeA& pathComposeA,
+ PartialSchemaReqConversion leftResult,
+ PartialSchemaReqConversion rightResult) {
+ return handleComposition(
+ false /*isMultiplicative*/, std::move(leftResult), std::move(rightResult));
+ }
+
+ template <class T>
+ static PartialSchemaReqConversion handleGetAndTraverse(const ABT& n,
+ PartialSchemaReqConversion inputResult) {
+ if (!inputResult._success) {
+ return {};
+ }
+ if (inputResult._bound.has_value()) {
+ return {};
+ }
+
+ // New map has keys with appended paths.
+ PartialSchemaRequirements newMap;
+
+ for (auto& entry : inputResult._reqMap) {
+ const ProjectionName& projectionName = entry.first._projectionName;
+ if (!projectionName.empty()) {
+ return {};
+ }
+
+ ABT path = entry.first._path;
+
+            // Update the key path so that it is rooted at n, with the existing key path as its
+            // child.
+ ABT appendedPath = n;
+ std::swap(appendedPath.cast<T>()->getPath(), path);
+ std::swap(path, appendedPath);
+
+ newMap.emplace(PartialSchemaKey{projectionName, std::move(path)},
+ std::move(entry.second));
+ }
+
+ inputResult._reqMap = std::move(newMap);
+ return inputResult;
+ }
+
+ PartialSchemaReqConversion transport(const ABT& n,
+ const PathGet& pathGet,
+ PartialSchemaReqConversion inputResult) {
+ return handleGetAndTraverse<PathGet>(n, std::move(inputResult));
+ }
+
+ PartialSchemaReqConversion transport(const ABT& n,
+ const PathTraverse& pathTraverse,
+ PartialSchemaReqConversion inputResult) {
+ if (inputResult._reqMap.size() > 1) {
+ // Cannot append traverse if we have more than one requirement.
+ return {};
+ }
+
+ PartialSchemaReqConversion result =
+ handleGetAndTraverse<PathTraverse>(n, std::move(inputResult));
+ result._hasTraversed = true;
+ return result;
+ }
+
+ PartialSchemaReqConversion transport(const ABT& n,
+ const PathCompare& pathCompare,
+ PartialSchemaReqConversion inputResult) {
+ if (!inputResult._success) {
+ return {};
+ }
+ if (!inputResult._bound.has_value() || !inputResult._reqMap.empty()) {
+ return {};
+ }
+
+ const auto& bound = inputResult._bound;
+ bool lowBoundInclusive = false;
+ boost::optional<ABT> lowBound;
+ bool highBoundInclusive = false;
+ boost::optional<ABT> highBound;
+
+ const Operations op = pathCompare.op();
+ switch (op) {
+ case Operations::Eq:
+ lowBoundInclusive = true;
+ lowBound = bound;
+ highBoundInclusive = true;
+ highBound = bound;
+ break;
+
+ case Operations::Lt:
+ case Operations::Lte:
+ lowBoundInclusive = false;
+ highBoundInclusive = op == Operations::Lte;
+ highBound = bound;
+ break;
+
+ case Operations::Gt:
+ case Operations::Gte:
+ lowBoundInclusive = op == Operations::Gte;
+ lowBound = bound;
+ highBoundInclusive = false;
+ break;
+
+ default:
+                // TODO: handle other comparison operations.
+ return {};
+ }
+
+ auto intervalExpr = IntervalReqExpr::makeSingularDNF(IntervalRequirement{
+ {lowBoundInclusive, std::move(lowBound)}, {highBoundInclusive, std::move(highBound)}});
+ return {PartialSchemaRequirements{
+ {PartialSchemaKey{},
+ PartialSchemaRequirement{"" /*boundProjectionName*/, std::move(intervalExpr)}}}};
+ }
+
+ PartialSchemaReqConversion transport(const ABT& n, const PathIdentity& pathIdentity) {
+ return {PartialSchemaRequirements{{{}, {}}}};
+ }
+
+ template <typename T, typename... Ts>
+ PartialSchemaReqConversion transport(const ABT& n, const T& node, Ts&&...) {
+ if constexpr (std::is_base_of_v<ExpressionSyntaxSort, T>) {
+ // We allow expressions to participate in bounds.
+ return {n};
+ }
+ // General case. Reject conversion.
+ return {};
+ }
+
+ PartialSchemaReqConversion convert(const ABT& input) {
+ return algebra::transport<true>(input, *this);
+ }
+};
+
+PartialSchemaReqConversion convertExprToPartialSchemaReq(const ABT& expr) {
+ PartialSchemaReqConverter converter;
+ PartialSchemaReqConversion result = converter.convert(expr);
+ if (result._reqMap.empty()) {
+ result._success = false;
+ return result;
+ }
+
+ for (const auto& entry : result._reqMap) {
+ if (entry.first.emptyPath() && isIntervalReqFullyOpenDNF(entry.second.getIntervals())) {
+ // A usable requirement must determine at least one of the path or the interval.
+ result._success = false;
+ return result;
+ }
+ }
+ return result;
+}
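+
+// A worked sketch of the conversion above (assuming the EvalFilter/EvalPath handling earlier in
+// this file binds the input projection): converting the filter expression
+//   EvalFilter (PathGet "a" (PathTraverse (PathCompare [Eq] (Const 1)))) (Variable "scan")
+// would produce a single requirement with key {projection "scan", path PathGet "a" PathTraverse
+// PathId} and the closed interval [1, 1].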
+
+static bool intersectPartialSchemaReq(PartialSchemaRequirements& reqMap,
+ PartialSchemaKey key,
+ PartialSchemaRequirement req,
+ ProjectionRenames& projectionRenames) {
+ for (;;) {
+ bool merged = false;
+ PartialSchemaKey newKey;
+ PartialSchemaRequirement newReq;
+
+ const bool reqHasBoundProj = req.hasBoundProjectionName();
+ {
+ // Look for exact match on the path, and if found combine intervals.
+ auto it = reqMap.find(key);
+ if (it != reqMap.cend()) {
+ PartialSchemaRequirement& existingReq = it->second;
+ if (reqHasBoundProj) {
+ if (existingReq.hasBoundProjectionName()) {
+ projectionRenames.emplace(req.getBoundProjectionName(),
+ existingReq.getBoundProjectionName());
+ } else {
+ existingReq.setBoundProjectionName(req.getBoundProjectionName());
+ }
+ }
+ combineIntervalsDNF(
+ true /*intersect*/, existingReq.getIntervals(), req.getIntervals());
+ return true;
+ }
+ }
+
+ for (auto it = reqMap.begin(); it != reqMap.cend();) {
+ const auto& [existingKey, existingReq] = *it;
+ uassert(6624150,
+ "Existing key referring to new requirement.",
+ !reqHasBoundProj ||
+ existingKey._projectionName != req.getBoundProjectionName());
+
+ if (existingReq.hasBoundProjectionName() &&
+ key._projectionName == existingReq.getBoundProjectionName()) {
+ // The new key is referring to a projection the existing requirement binds.
+ if (reqHasBoundProj) {
+ return false;
+ }
+
+ newKey = existingKey;
+ newReq = req;
+
+ PathAppender appender(key._path);
+ appender.append(newKey._path);
+
+ combineIntervalsDNF(
+ true /*intersect*/, newReq.getIntervals(), existingReq.getIntervals());
+
+ if (key._path.is<PathIdentity>()) {
+ newReq.setBoundProjectionName(existingReq.getBoundProjectionName());
+ reqMap.erase(it++);
+ } else if (!isIntervalReqFullyOpenDNF(existingReq.getIntervals())) {
+ return false;
+ }
+
+ merged = true;
+ break;
+ } else {
+ it++;
+ continue;
+ }
+ }
+
+ if (merged) {
+ key = std::move(newKey);
+ req = std::move(newReq);
+ } else {
+ reqMap[key] = req;
+ break;
+ }
+ }
+
+ return true;
+}
+
+bool intersectPartialSchemaReq(PartialSchemaRequirements& target,
+ const PartialSchemaRequirements& source,
+ ProjectionRenames& projectionRenames) {
+ for (const auto& [key, req] : source) {
+ if (!intersectPartialSchemaReq(target, key, req, projectionRenames)) {
+ return false;
+ }
+ }
+
+ return true;
+}
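+
+// A worked example of the intersection above (a sketch): intersecting a target containing
+// {key: ("scan", PathGet "a" PathId), interval [0, +inf)} with a source containing
+// {key: ("scan", PathGet "a" PathId), interval (-inf, 10]} combines the intervals on the shared
+// key, leaving a single requirement with interval [0, 10].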
+
+std::string encodeIndexKeyName(const size_t indexField) {
+ std::ostringstream os;
+ os << kIndexKeyPrefix << " " << indexField;
+ return os.str();
+}
+
+size_t decodeIndexKeyName(const std::string& fieldName) {
+ std::istringstream is(fieldName);
+
+ std::string prefix;
+ is >> prefix;
+ uassert(6624151, "Invalid index key prefix", prefix == kIndexKeyPrefix);
+
+ int key;
+ is >> key;
+ return key;
+}
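+
+// The two functions above are inverses: decodeIndexKeyName(encodeIndexKeyName(i)) == i for any
+// field position 'i'. For example, assuming kIndexKeyPrefix were "<indexKey>" (hypothetical
+// value), encodeIndexKeyName(2) would produce "<indexKey> 2".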
+
+/**
+ * Checks if an index path is a prefix of a query path. Considers only Get, Traverse, and Id
+ * elements. Returns the query path suffix left over after the index path prefix is matched,
+ * or nothing if the index path is not a prefix of the query path.
+ */
+class PathSuffixExtractor {
+public:
+ using ResultType = boost::optional<ABT::reference_type>;
+
+ /**
+ * 'n' - The current element of the index path we are matching against.
+ * 'node' - Same as 'n', but cast to a specific type by the visitor in order to invoke the
+ * correct operator.
+ * 'other' - The corresponding element of the query path, which is checked against the index
+ * path prefix.
+ */
+ ResultType operator()(const ABT& n, const PathGet& node, const ABT& other) {
+ if (auto otherGet = other.cast<PathGet>();
+ otherGet != nullptr && otherGet->name() == node.name()) {
+ if (auto otherChildTraverse = otherGet->getPath().cast<PathTraverse>();
+ otherChildTraverse != nullptr && !node.getPath().is<PathTraverse>()) {
+ // If a query path has a Traverse, but the index path doesn't, the query can
+ // still be evaluated by this index. Skip the Traverse node, and continue matching.
+ return node.getPath().visit(*this, otherChildTraverse->getPath());
+ } else {
+ return node.getPath().visit(*this, otherGet->getPath());
+ }
+ }
+ return {};
+ }
+
+ ResultType operator()(const ABT& n, const PathTraverse& node, const ABT& other) {
+ if (auto otherTraverse = other.cast<PathTraverse>(); otherTraverse != nullptr) {
+ return node.getPath().visit(*this, otherTraverse->getPath());
+ }
+ return {};
+ }
+
+ ResultType operator()(const ABT& n, const PathIdentity& node, const ABT& other) {
+ return {other.ref()};
+ }
+
+ template <typename T, typename... Ts>
+ ResultType operator()(const ABT& /*n*/, const T& /*node*/, Ts&&...) {
+ uasserted(6624152, "Unexpected node type");
+ }
+
+ static ResultType check(const ABT& node, const ABT& candidatePrefix) {
+ PathSuffixExtractor instance;
+ return candidatePrefix.visit(instance, node);
+ }
+};
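+
+// Usage sketch for the extractor above (paths are illustrative):
+//   query path: PathGet "a" (PathTraverse (PathCompare [Eq] (Const 1)))
+//   index path: PathGet "a" (PathTraverse PathId)
+//   check(queryPath, indexPath) returns the unmatched query suffix PathCompare [Eq] (Const 1).
+// If the index path were PathGet "a" PathId (no Traverse), the query's Traverse would be
+// skipped and the same suffix would be returned.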
+
+/**
+ * Check if a path contains a Traverse node over the last path component.
+ */
+class PathTraverseChecker {
+public:
+ PathTraverseChecker() {}
+
+ /**
+ * PathIdentity is always the last node in a path.
+ * Return true if 'node' is a PathTraverse directly preceding the final PathIdentity.
+ * If 'node' is a Traverse somewhere in the middle of the path, propagate the result
+ * computed for its child.
+ */
+ bool transport(const ABT& /*n*/, const PathTraverse& node, bool childResult) {
+ return node.getPath().is<PathIdentity>() || childResult;
+ }
+
+ bool transport(const ABT& /*n*/, const PathGet& node, bool childResult) {
+ return childResult;
+ }
+
+ bool transport(const ABT& /*n*/, const PathIdentity& node) {
+ return false;
+ }
+
+ template <typename T, typename... Ts>
+ bool transport(const ABT& /*n*/, const T& /*node*/, Ts&&...) {
+ uasserted(6624153, "Index paths only consist of Get, Traverse, and Id nodes.");
+ return false;
+ }
+
+ /**
+ * Return true if the node before the last one in 'path' is a PathTraverse.
+ * Return false otherwise.
+ */
+ bool check(const ABT& path) {
+ return algebra::transport<true>(path, *this);
+ }
+};
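+
+// Illustrative results for the checker above:
+//   PathGet "a" (PathTraverse PathId)                -> true  (Traverse over the last component)
+//   PathGet "a" (PathTraverse (PathGet "b" PathId))  -> false (Traverse in the middle)
+//   PathGet "a" (PathGet "b" (PathTraverse PathId))  -> true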
+
+void findMatchingSchemaRequirement(const PartialSchemaKey& indexKey,
+ const PartialSchemaRequirements& reqMap,
+ PartialSchemaKeySet& keySet,
+ PartialSchemaRequirement& req,
+ const bool setIntervalsAndBoundProj) {
+ for (const auto& [queryKey, queryReq] : reqMap) {
+ const auto pathSuffixResult = PathSuffixExtractor::check(queryKey._path, indexKey._path);
+ if (pathSuffixResult.has_value() && pathSuffixResult->is<PathIdentity>()) {
+ keySet.insert(queryKey);
+
+ if (setIntervalsAndBoundProj) {
+ // Combine all matching requirements into a single entry.
+ if (queryReq.hasBoundProjectionName()) {
+ uassert(6624154, "Unexpected bound projection", !req.hasBoundProjectionName());
+ req.setBoundProjectionName(queryReq.getBoundProjectionName());
+ }
+ if (!isIntervalReqFullyOpenDNF(queryReq.getIntervals())) {
+ uassert(6624350,
+ "Unexpected intervals",
+ isIntervalReqFullyOpenDNF(req.getIntervals()));
+ req.getIntervals() = queryReq.getIntervals();
+ }
+ }
+ }
+ }
+}
+
+CandidateIndexMap computeCandidateIndexMap(PrefixId& prefixId,
+ const ProjectionName& scanProjectionName,
+ const PartialSchemaRequirements& reqMap,
+ const ScanDefinition& scanDef,
+ bool& hasEmptyInterval) {
+ CandidateIndexMap result;
+ hasEmptyInterval = false;
+
+ for (const auto& [indexDefName, indexDef] : scanDef.getIndexDefs()) {
+ FieldProjectionMap indexProjectionMap;
+ auto intervals = MultiKeyIntervalReqExpr::makeSingularDNF();  // Start with a single fully-open interval.
+ PartialSchemaRequirements residualRequirements;
+ ProjectionNameSet residualRequirementsTempProjections;
+ ResidualKeyMap residualKeyMap;
+ opt::unordered_set<size_t> fieldsToCollate;
+
+ PartialSchemaKeySet unsatisfiedKeys;
+ for (const auto& [key, req] : reqMap) {
+ unsatisfiedKeys.insert(key);
+ }
+
+ // True if the paths from partial schema requirements form a strict prefix of the index
+ // collation.
+ bool isPrefix = true;
+ // If we formed bounds using at least one requirement (as opposed to having only residual
+ // requirements).
+ bool hasExactMatch = false;
+ bool indexSuitable = true;
+
+ const IndexCollationSpec& indexCollationSpec = indexDef.getCollationSpec();
+ for (size_t indexField = 0; indexField < indexCollationSpec.size(); indexField++) {
+ const auto& indexCollationEntry = indexCollationSpec.at(indexField);
+ PartialSchemaKey indexKey{scanProjectionName, indexCollationEntry._path};
+
+ PartialSchemaKeySet keySet;
+ PartialSchemaRequirement req;
+ findMatchingSchemaRequirement(indexKey, reqMap, keySet, req);
+ if (!keySet.empty() && isPrefix) {
+ hasExactMatch = true;
+ const PartialSchemaKey& queryKey = *keySet.begin();
+ for (const auto& key : keySet) {
+ unsatisfiedKeys.erase(key);
+ }
+
+ // The interval derived from the query (not from the index).
+ const auto& requiredInterval = req.getIntervals();
+
+ PathTraverseChecker pathChecker;
+ const bool indexPathContainsArrays = pathChecker.check(indexCollationEntry._path);
+ bool combineSuccess = false;
+ if (indexPathContainsArrays) {
+ combineSuccess = combineMultiKeyIntervalsDNF(intervals, requiredInterval);
+ } else {
+ auto intersectedIntervals = intersectDNFIntervals(requiredInterval);
+ if (intersectedIntervals.has_value()) {
+ combineSuccess =
+ combineMultiKeyIntervalsDNF(intervals, intersectedIntervals.get());
+ } else {
+ if (indexDef.getPartialReqMap().empty()) {
+ hasEmptyInterval = true;
+ return CandidateIndexMap();
+ } else {
+ // This is a partial index, so skip the empty interval, but consider the
+ // remaining indexes.
+ indexSuitable = false;
+ break;
+ }
+ }
+ }
+
+ if (!combineSuccess) {
+ if (!combineMultiKeyIntervalsDNF(intervals,
+ IntervalReqExpr::makeSingularDNF())) {
+ uasserted(6624155, "Cannot combine with an open interval");
+ }
+
+ // Move interval into residual requirements.
+ if (req.hasBoundProjectionName()) {
+ PartialSchemaKey residualKey{req.getBoundProjectionName(),
+ make<PathIdentity>()};
+ residualRequirements.emplace(
+ residualKey, PartialSchemaRequirement{"", req.getIntervals()});
+ residualKeyMap.emplace(std::move(residualKey), queryKey);
+ residualRequirementsTempProjections.insert(req.getBoundProjectionName());
+ } else {
+ ProjectionName tempProj = prefixId.getNextId("evalTemp");
+ PartialSchemaKey residualKey{tempProj, make<PathIdentity>()};
+ residualRequirements.emplace(residualKey, req);
+ residualKeyMap.emplace(std::move(residualKey), queryKey);
+ residualRequirementsTempProjections.insert(tempProj);
+
+ // Include bounds projection into index spec.
+ if (!indexProjectionMap._fieldProjections
+ .emplace(encodeIndexKeyName(indexField), std::move(tempProj))
+ .second) {
+ uasserted(6624156, "Duplicate field name");
+ }
+ }
+ }
+
+ // Include bounds projection into index spec.
+ if (req.hasBoundProjectionName() &&
+ !indexProjectionMap._fieldProjections
+ .emplace(encodeIndexKeyName(indexField), req.getBoundProjectionName())
+ .second) {
+ uasserted(6624157, "Duplicate field name");
+ }
+
+ if (auto singularInterval = IntervalReqExpr::getSingularDNF(requiredInterval);
+ !singularInterval || !singularInterval->isEquality()) {
+ // We only care about collation for non-equality intervals.
+ // Equivalently, it is sufficient for singular intervals to be clustered.
+ fieldsToCollate.insert(indexField);
+ }
+ } else {
+ bool foundPathPrefix = false;
+ for (const auto& queryKey : unsatisfiedKeys) {
+ const auto pathPrefixResult =
+ PathSuffixExtractor::check(queryKey._path, indexCollationEntry._path);
+ if (pathPrefixResult.has_value()) {
+ ProjectionName tempProj = prefixId.getNextId("evalTemp");
+ PartialSchemaKey residualKey{tempProj, pathPrefixResult.get()};
+ residualRequirements.emplace(residualKey, reqMap.at(queryKey));
+ residualKeyMap.emplace(std::move(residualKey), queryKey);
+ residualRequirementsTempProjections.insert(tempProj);
+
+ // Include bounds projection into index spec.
+ if (!indexProjectionMap._fieldProjections
+ .emplace(encodeIndexKeyName(indexField), std::move(tempProj))
+ .second) {
+ uasserted(6624158, "Duplicate field name");
+ }
+
+ if (!combineMultiKeyIntervalsDNF(intervals,
+ IntervalReqExpr::makeSingularDNF())) {
+ uasserted(6624159, "Cannot combine with an open interval");
+ }
+
+ foundPathPrefix = true;
+ unsatisfiedKeys.erase(queryKey);
+ break;
+ }
+ }
+
+ if (!foundPathPrefix) {
+ isPrefix = false;
+ if (!combineMultiKeyIntervalsDNF(intervals,
+ IntervalReqExpr::makeSingularDNF())) {
+ uasserted(6624160, "Cannot combine with an open interval");
+ }
+ }
+ }
+ }
+ if (!indexSuitable) {
+ continue;
+ }
+ if (!hasExactMatch) {
+ continue;
+ }
+ if (!unsatisfiedKeys.empty()) {
+ continue;
+ }
+
+ uassert(6624161, "Invalid map sizes", residualRequirements.size() == residualKeyMap.size());
+ result.emplace(indexDefName,
+ CandidateIndexEntry{std::move(indexProjectionMap),
+ std::move(intervals),
+ std::move(residualRequirements),
+ std::move(residualRequirementsTempProjections),
+ std::move(residualKeyMap),
+ std::move(fieldsToCollate)});
+ }
+
+ return result;
+}
+
+class PartialSchemaReqLowerTransport {
+public:
+ ABT transport(const IntervalReqExpr::Atom& node) {
+ const auto& interval = node.getExpr();
+ const auto& lowBound = interval.getLowBound();
+ const auto& highBound = interval.getHighBound();
+
+ if (interval.isEquality()) {
+ return make<PathCompare>(Operations::Eq, lowBound.getBound());
+ }
+
+ ABT result = make<PathIdentity>();
+ if (!lowBound.isInfinite()) {
+ maybeComposePath(
+ result,
+ make<PathCompare>(lowBound.isInclusive() ? Operations::Gte : Operations::Gt,
+ lowBound.getBound()));
+ }
+ if (!highBound.isInfinite()) {
+ maybeComposePath(
+ result,
+ make<PathCompare>(highBound.isInclusive() ? Operations::Lte : Operations::Lt,
+ highBound.getBound()));
+ }
+ return result;
+ }
+
+ template <class Element>
+ ABT composeChildren(ABTVector childResults) {
+ ABT result = make<PathIdentity>();
+ for (ABT& n : childResults) {
+ maybeComposePath<Element>(result, std::move(n));
+ }
+ return result;
+ }
+
+ ABT transport(const IntervalReqExpr::Conjunction& node, ABTVector childResults) {
+ return composeChildren<PathComposeM>(std::move(childResults));
+ }
+
+ ABT transport(const IntervalReqExpr::Disjunction& node, ABTVector childResults) {
+ return composeChildren<PathComposeA>(std::move(childResults));
+ }
+
+ ABT lower(const IntervalReqExpr::Node& intervals) {
+ return algebra::transport<false>(intervals, *this);
+ }
+};
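+
+// Lowering sketch for the transport above: a closed interval [5, 10] becomes
+//   PathComposeM (PathCompare [Gte] (Const 5)) (PathCompare [Lte] (Const 10)),
+// an equality interval [5, 5] becomes PathCompare [Eq] (Const 5), and a fully-open interval
+// lowers to PathIdentity.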
+
+void lowerPartialSchemaRequirement(const PartialSchemaKey& key,
+ const PartialSchemaRequirement& req,
+ ABT& node,
+ const std::function<void(const ABT& node)>& visitor) {
+ PartialSchemaReqLowerTransport transport;
+ ABT path = transport.lower(req.getIntervals());
+ const bool pathIsId = path.is<PathIdentity>();
+
+ if (req.hasBoundProjectionName()) {
+ node = make<EvaluationNode>(req.getBoundProjectionName(),
+ make<EvalPath>(key._path, make<Variable>(key._projectionName)),
+ std::move(node));
+ visitor(node);
+
+ if (!pathIsId) {
+ node = make<FilterNode>(
+ make<EvalFilter>(std::move(path), make<Variable>(req.getBoundProjectionName())),
+ std::move(node));
+ visitor(node);
+ }
+ } else {
+ uassert(
+ 6624162, "If we do not have a bound projection, then we have a proper path", !pathIsId);
+
+ PathAppender appender(std::move(path));
+ path = key._path;
+ appender.append(path);
+
+ node =
+ make<FilterNode>(make<EvalFilter>(std::move(path), make<Variable>(key._projectionName)),
+ std::move(node));
+ visitor(node);
+ }
+}
+
+void lowerPartialSchemaRequirements(const CEType baseCE,
+ const CEType scanGroupCE,
+ ResidualRequirements& requirements,
+ ABT& physNode,
+ NodeCEMap& nodeCEMap) {
+ sortResidualRequirements(requirements);
+
+ CEType residualCE = baseCE;
+ for (const auto& [residualKey, residualReq, ce] : requirements) {
+ if (scanGroupCE > 0.0) {
+ residualCE *= ce / scanGroupCE;
+ }
+ lowerPartialSchemaRequirement(residualKey, residualReq, physNode, [&](const ABT& node) {
+ nodeCEMap.emplace(node.cast<Node>(), residualCE);
+ });
+ }
+}
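+
+// Worked example of the CE scaling above (hypothetical numbers): with baseCE = 100,
+// scanGroupCE = 1000, and residual requirements estimated at ce = 10 and ce = 100 (sorted
+// ascending), the first filter is costed at 100 * (10 / 1000) = 1 and the second at
+// 1 * (100 / 1000) = 0.1, since 'residualCE' is updated cumulatively.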
+
+void computePhysicalScanParams(PrefixId& prefixId,
+ const PartialSchemaRequirements& reqMap,
+ const PartialSchemaKeyCE& partialSchemaKeyCEMap,
+ const ProjectionNameOrderPreservingSet& requiredProjections,
+ ResidualRequirements& residualRequirements,
+ ProjectionRenames& projectionRenames,
+ FieldProjectionMap& fieldProjectionMap,
+ bool& requiresRootProjection) {
+ for (const auto& [key, req] : reqMap) {
+ bool hasBoundProjection = req.hasBoundProjectionName();
+ if (hasBoundProjection && !requiredProjections.find(req.getBoundProjectionName()).second) {
+ // Bound projection is not required, pretend we don't bind.
+ hasBoundProjection = false;
+ }
+ if (!hasBoundProjection && isIntervalReqFullyOpenDNF(req.getIntervals())) {
+ // Redundant requirement.
+ continue;
+ }
+
+ const CEType keyCE = partialSchemaKeyCEMap.at(key);
+ if (auto pathGet = key._path.cast<PathGet>(); pathGet != nullptr) {
+ // Extract a new requirements path with removed simple paths.
+ // For example if we have a key Get "a" Traverse Compare = 0 we leave only
+ // Traverse Compare 0.
+ if (pathGet->getPath().is<PathIdentity>() && hasBoundProjection) {
+ const auto [it, inserted] = fieldProjectionMap._fieldProjections.emplace(
+ pathGet->name(), req.getBoundProjectionName());
+ if (!inserted) {
+ projectionRenames.emplace(req.getBoundProjectionName(), it->second);
+ }
+
+ if (!isIntervalReqFullyOpenDNF(req.getIntervals())) {
+ residualRequirements.emplace_back(
+ PartialSchemaKey{req.getBoundProjectionName(), make<PathIdentity>()},
+ PartialSchemaRequirement{"", req.getIntervals()},
+ keyCE);
+ }
+ } else {
+ ProjectionName tempProjName;
+ auto it = fieldProjectionMap._fieldProjections.find(pathGet->name());
+ if (it == fieldProjectionMap._fieldProjections.cend()) {
+ tempProjName = prefixId.getNextId("evalTemp");
+ fieldProjectionMap._fieldProjections.emplace(pathGet->name(), tempProjName);
+ } else {
+ tempProjName = it->second;
+ }
+
+ residualRequirements.emplace_back(
+ PartialSchemaKey{std::move(tempProjName), pathGet->getPath()}, req, keyCE);
+ }
+ } else {
+ // Move other conditions into the residual map.
+ requiresRootProjection = true;
+ residualRequirements.emplace_back(key, req, keyCE);
+ }
+ }
+}
+
+void sortResidualRequirements(ResidualRequirements& residualReq) {
+ // Sort residual requirements by estimated CE.
+ std::sort(
+ residualReq.begin(),
+ residualReq.end(),
+ [](const ResidualRequirement& x, const ResidualRequirement& y) { return x._ce < y._ce; });
+}
+
+void applyProjectionRenames(ProjectionRenames projectionRenames,
+ ABT& node,
+ const std::function<void(const ABT& node)>& visitor) {
+ for (auto&& [targetProjName, sourceProjName] : projectionRenames) {
+ node = make<EvaluationNode>(
+ std::move(targetProjName), make<Variable>(std::move(sourceProjName)), std::move(node));
+ visitor(node);
+ }
+}
+
+ABT lowerRIDIntersectGroupBy(PrefixId& prefixId,
+ const ProjectionName& ridProjName,
+ const CEType intersectedCE,
+ const CEType leftCE,
+ const CEType rightCE,
+ const properties::PhysProps& physProps,
+ const properties::PhysProps& leftPhysProps,
+ const properties::PhysProps& rightPhysProps,
+ ABT leftChild,
+ ABT rightChild,
+ NodeCEMap& nodeCEMap,
+ ChildPropsType& childProps) {
+ using namespace properties;
+
+ const auto& leftProjections =
+ getPropertyConst<ProjectionRequirement>(leftPhysProps).getProjections();
+
+ ABTVector aggExpressions;
+ ProjectionNameVector aggProjectionNames;
+
+ const ProjectionName sideIdProjectionName = prefixId.getNextId("sideId");
+ const ProjectionName sideSetProjectionName = prefixId.getNextId("sides");
+
+ aggExpressions.emplace_back(
+ make<FunctionCall>("$addToSet", makeSeq(make<Variable>(sideIdProjectionName))));
+ aggProjectionNames.push_back(sideSetProjectionName);
+
+ leftChild =
+ make<EvaluationNode>(sideIdProjectionName, Constant::int64(0), std::move(leftChild));
+ childProps.emplace_back(&leftChild.cast<EvaluationNode>()->getChild(), leftPhysProps);
+ nodeCEMap.emplace(leftChild.cast<Node>(), leftCE);
+
+ rightChild =
+ make<EvaluationNode>(sideIdProjectionName, Constant::int64(1), std::move(rightChild));
+ childProps.emplace_back(&rightChild.cast<EvaluationNode>()->getChild(), rightPhysProps);
+ nodeCEMap.emplace(rightChild.cast<Node>(), rightCE);
+
+ ProjectionNameVector sortedProjections =
+ getPropertyConst<ProjectionRequirement>(physProps).getProjections().getVector();
+ std::sort(sortedProjections.begin(), sortedProjections.end());
+
+ ProjectionNameVector unionProjections{ridProjName, sideIdProjectionName};
+ for (const ProjectionName& projectionName : sortedProjections) {
+ if (projectionName == ridProjName) {
+ continue;
+ }
+
+ ProjectionName tempProjectionName = prefixId.getNextId("unionTemp");
+ unionProjections.push_back(tempProjectionName);
+
+ if (leftProjections.find(projectionName).second) {
+ leftChild = make<EvaluationNode>(
+ tempProjectionName, make<Variable>(projectionName), std::move(leftChild));
+ nodeCEMap.emplace(leftChild.cast<Node>(), leftCE);
+
+ rightChild = make<EvaluationNode>(
+ tempProjectionName, Constant::nothing(), std::move(rightChild));
+ nodeCEMap.emplace(rightChild.cast<Node>(), rightCE);
+ } else {
+ leftChild =
+ make<EvaluationNode>(tempProjectionName, Constant::nothing(), std::move(leftChild));
+ nodeCEMap.emplace(leftChild.cast<Node>(), leftCE);
+
+ rightChild = make<EvaluationNode>(
+ tempProjectionName, make<Variable>(projectionName), std::move(rightChild));
+ nodeCEMap.emplace(rightChild.cast<Node>(), rightCE);
+ }
+
+ aggExpressions.emplace_back(
+ make<FunctionCall>("$max", makeSeq(make<Variable>(tempProjectionName))));
+ aggProjectionNames.push_back(projectionName);
+ }
+
+ ABT result = make<UnionNode>(std::move(unionProjections),
+ makeSeq(std::move(leftChild), std::move(rightChild)));
+ nodeCEMap.emplace(result.cast<Node>(), leftCE + rightCE);
+
+ result = make<GroupByNode>(ProjectionNameVector{ridProjName},
+ std::move(aggProjectionNames),
+ std::move(aggExpressions),
+ std::move(result));
+ nodeCEMap.emplace(result.cast<Node>(), intersectedCE);
+
+ result = make<FilterNode>(
+ make<EvalFilter>(
+ make<PathCompare>(Operations::Eq, Constant::int64(2)),
+ make<FunctionCall>("getArraySize", makeSeq(make<Variable>(sideSetProjectionName)))),
+ std::move(result));
+ nodeCEMap.emplace(result.cast<Node>(), intersectedCE);
+
+ return result;
+}
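+
+// Shape of the plan produced above (a sketch; projection names elided):
+//   Filter [getArraySize(sides) == 2]
+//   GroupBy [rid]  {sides: $addToSet(sideId), <proj>: $max(unionTemp)}
+//   Union [rid, sideId, unionTemp...]
+//     Evaluation [sideId = 0] over leftChild
+//     Evaluation [sideId = 1] over rightChild
+// A rid survives the final filter only if it was produced by both children.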
+
+ABT lowerRIDIntersectHashJoin(PrefixId& prefixId,
+ const ProjectionName& ridProjName,
+ const CEType intersectedCE,
+ const CEType leftCE,
+ const CEType rightCE,
+ const properties::PhysProps& leftPhysProps,
+ const properties::PhysProps& rightPhysProps,
+ ABT leftChild,
+ ABT rightChild,
+ NodeCEMap& nodeCEMap,
+ ChildPropsType& childProps) {
+ using namespace properties;
+
+ ProjectionName rightRIDProjName = prefixId.getNextId("rid");
+ rightChild =
+ make<EvaluationNode>(rightRIDProjName, make<Variable>(ridProjName), std::move(rightChild));
+ ABT* rightChildPtr = &rightChild.cast<EvaluationNode>()->getChild();
+ nodeCEMap.emplace(rightChild.cast<Node>(), rightCE);
+
+ auto rightProjections =
+ getPropertyConst<ProjectionRequirement>(rightPhysProps).getProjections();
+ rightProjections.erase(ridProjName);
+ rightProjections.emplace_back(rightRIDProjName);
+ ProjectionNameVector sortedProjections = rightProjections.getVector();
+ std::sort(sortedProjections.begin(), sortedProjections.end());
+
+ // Use a union node to restrict the rid projection name coming from the right child in order
+ // to ensure we do not have the same rid from both children. This node is optimized away
+ // during lowering.
+ rightChild = make<UnionNode>(std::move(sortedProjections), makeSeq(std::move(rightChild)));
+ nodeCEMap.emplace(rightChild.cast<Node>(), rightCE);
+
+ ABT result = make<HashJoinNode>(JoinType::Inner,
+ ProjectionNameVector{ridProjName},
+ ProjectionNameVector{std::move(rightRIDProjName)},
+ std::move(leftChild),
+ std::move(rightChild));
+ nodeCEMap.emplace(result.cast<Node>(), intersectedCE);
+
+ childProps.emplace_back(&result.cast<HashJoinNode>()->getLeftChild(), leftPhysProps);
+ childProps.emplace_back(rightChildPtr, rightPhysProps);
+
+ return result;
+}
+
+ABT lowerRIDIntersectMergeJoin(PrefixId& prefixId,
+ const ProjectionName& ridProjName,
+ const CEType intersectedCE,
+ const CEType leftCE,
+ const CEType rightCE,
+ const properties::PhysProps& leftPhysProps,
+ const properties::PhysProps& rightPhysProps,
+ ABT leftChild,
+ ABT rightChild,
+ NodeCEMap& nodeCEMap,
+ ChildPropsType& childProps) {
+ using namespace properties;
+
+ ProjectionName rightRIDProjName = prefixId.getNextId("rid");
+ rightChild =
+ make<EvaluationNode>(rightRIDProjName, make<Variable>(ridProjName), std::move(rightChild));
+ ABT* rightChildPtr = &rightChild.cast<EvaluationNode>()->getChild();
+ nodeCEMap.emplace(rightChild.cast<Node>(), rightCE);
+
+ auto rightProjections =
+ getPropertyConst<ProjectionRequirement>(rightPhysProps).getProjections();
+ rightProjections.erase(ridProjName);
+ rightProjections.emplace_back(rightRIDProjName);
+ ProjectionNameVector sortedProjections = rightProjections.getVector();
+ std::sort(sortedProjections.begin(), sortedProjections.end());
+
+ // Use a union node to restrict the rid projection name coming from the right child in order
+ // to ensure we do not have the same rid from both children. This node is optimized away
+ // during lowering.
+ rightChild = make<UnionNode>(std::move(sortedProjections), makeSeq(std::move(rightChild)));
+ nodeCEMap.emplace(rightChild.cast<Node>(), rightCE);
+
+ ABT result = make<MergeJoinNode>(ProjectionNameVector{ridProjName},
+ ProjectionNameVector{std::move(rightRIDProjName)},
+ std::vector<CollationOp>{CollationOp::Ascending},
+ std::move(leftChild),
+ std::move(rightChild));
+ nodeCEMap.emplace(result.cast<Node>(), intersectedCE);
+
+ childProps.emplace_back(&result.cast<MergeJoinNode>()->getLeftChild(), leftPhysProps);
+ childProps.emplace_back(rightChildPtr, rightPhysProps);
+
+ return result;
+}
+
+class IntervalLowerTransport {
+public:
+ IntervalLowerTransport(PrefixId& prefixId,
+ const ProjectionName& ridProjName,
+ FieldProjectionMap indexProjectionMap,
+ const std::string& scanDefName,
+ const std::string& indexDefName,
+ const bool reverseOrder,
+ const CEType indexCE,
+ const CEType scanGroupCE,
+ NodeCEMap& nodeCEMap)
+ : _prefixId(prefixId),
+ _ridProjName(ridProjName),
+ _scanDefName(scanDefName),
+ _indexDefName(indexDefName),
+ _reverseOrder(reverseOrder),
+ _scanGroupCE(scanGroupCE),
+ _nodeCEMap(nodeCEMap) {
+ const SelectivityType indexSel = (scanGroupCE == 0.0) ? 0.0 : (indexCE / _scanGroupCE);
+ _estimateStack.push_back(indexSel);
+ _fpmStack.push_back(std::move(indexProjectionMap));
+ }
+
+ ABT transport(const MultiKeyIntervalReqExpr::Atom& node) {
+ ABT physicalIndexScan = make<IndexScanNode>(
+ _fpmStack.back(),
+ IndexSpecification{_scanDefName, _indexDefName, node.getExpr(), _reverseOrder});
+ _nodeCEMap.emplace(physicalIndexScan.cast<Node>(), _scanGroupCE * _estimateStack.back());
+ return physicalIndexScan;
+ }
+
+ template <bool isConjunction>
+ void prepare(const size_t childCount) {
+ // Here we assume the children of each conjunction and disjunction contribute uniformly
+ // to the total selectivity.
+ // TODO: consider estimates per individual interval.
+
+ const SelectivityType parentSel = _estimateStack.back();
+ SelectivityType childSel = 0.0;
+ if constexpr (isConjunction) {
+ childSel = (parentSel == 0.0) ? 0.0 : std::pow(parentSel, 1.0 / childCount);
+ } else {
+ childSel = parentSel / childCount;
+ }
+ _estimateStack.push_back(childSel);
+
+ FieldProjectionMap childMap = _fpmStack.back();
+ if (childMap._ridProjection.empty()) {
+ childMap._ridProjection = _ridProjName;
+ }
+ if (childCount > 1) {
+ for (auto& [fieldName, projectionName] : childMap._fieldProjections) {
+ projectionName = _prefixId.getNextId(isConjunction ? "conjunction" : "disjunction");
+ }
+ }
+ _fpmStack.push_back(std::move(childMap));
+ }
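+
+ // Worked example (hypothetical numbers): with parentSel = 0.25 and childCount = 2, each
+ // conjunct receives childSel = sqrt(0.25) = 0.5 (so 0.5 * 0.5 recovers 0.25), while each
+ // disjunct receives childSel = 0.25 / 2 = 0.125 (so 0.125 + 0.125 recovers 0.25).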
+
+ void prepare(const MultiKeyIntervalReqExpr::Conjunction& node) {
+ prepare<true /*isConjunction*/>(node.nodes().size());
+ }
+
+ template <bool isIntersect>
+ ABT implement(ABTVector inputs) {
+ _estimateStack.pop_back();
+ const CEType ce = _scanGroupCE * _estimateStack.back();
+
+ auto innerMap = std::move(_fpmStack.back());
+ _fpmStack.pop_back();
+ auto outerMap = _fpmStack.back();
+
+ const size_t inputSize = inputs.size();
+ if (inputSize == 1) {
+ return std::move(inputs.front());
+ }
+
+ ProjectionNameVector unionProjectionNames;
+ unionProjectionNames.push_back(innerMap._ridProjection);
+ for (const auto& [fieldName, projectionName] : innerMap._fieldProjections) {
+ unionProjectionNames.push_back(projectionName);
+ }
+
+ ProjectionNameVector aggProjectionNames;
+ for (const auto& [fieldName, projectionName] : outerMap._fieldProjections) {
+ aggProjectionNames.push_back(projectionName);
+ }
+
+ ABTVector aggExpressions;
+ for (const auto& [fieldName, projectionName] : innerMap._fieldProjections) {
+ aggExpressions.emplace_back(
+ make<FunctionCall>("$first", makeSeq(make<Variable>(projectionName))));
+ }
+
+ ProjectionName sideSetProjectionName;
+ if constexpr (isIntersect) {
+ const ProjectionName sideIdProjectionName = _prefixId.getNextId("sideId");
+ unionProjectionNames.push_back(sideIdProjectionName);
+ sideSetProjectionName = _prefixId.getNextId("sides");
+
+ for (size_t index = 0; index < inputSize; index++) {
+ ABT& input = inputs.at(index);
+ input = make<EvaluationNode>(
+ sideIdProjectionName, Constant::int64(index), std::move(input));
+ // Not relevant for cost.
+ _nodeCEMap.emplace(input.cast<Node>(), 0.0);
+ }
+
+ aggExpressions.emplace_back(
+ make<FunctionCall>("$addToSet", makeSeq(make<Variable>(sideIdProjectionName))));
+ aggProjectionNames.push_back(sideSetProjectionName);
+ }
+
+ ABT result = make<UnionNode>(std::move(unionProjectionNames), std::move(inputs));
+ _nodeCEMap.emplace(result.cast<Node>(), ce);
+
+ result = make<GroupByNode>(ProjectionNameVector{innerMap._ridProjection},
+ std::move(aggProjectionNames),
+ std::move(aggExpressions),
+ std::move(result));
+ _nodeCEMap.emplace(result.cast<Node>(), ce);
+
+ if constexpr (isIntersect) {
+ result = make<FilterNode>(
+ make<EvalFilter>(
+ make<PathCompare>(Operations::Eq, Constant::int64(inputSize)),
+ make<FunctionCall>("getArraySize",
+ makeSeq(make<Variable>(sideSetProjectionName)))),
+ std::move(result));
+ _nodeCEMap.emplace(result.cast<Node>(), ce);
+ }
+ return result;
+ }
+
+ ABT transport(const MultiKeyIntervalReqExpr::Conjunction& node, ABTVector childResults) {
+ return implement<true /*isIntersect*/>(std::move(childResults));
+ }
+
+ void prepare(const MultiKeyIntervalReqExpr::Disjunction& node) {
+ prepare<false /*isConjunction*/>(node.nodes().size());
+ }
+
+ ABT transport(const MultiKeyIntervalReqExpr::Disjunction& node, ABTVector childResults) {
+ return implement<false /*isIntersect*/>(std::move(childResults));
+ }
+
+ ABT lower(const MultiKeyIntervalReqExpr::Node& intervals) {
+ return algebra::transport<false>(intervals, *this);
+ }
+
+private:
+ PrefixId& _prefixId;
+ const ProjectionName& _ridProjName;
+ const std::string& _scanDefName;
+ const std::string& _indexDefName;
+ const bool _reverseOrder;
+ const CEType _scanGroupCE;
+ NodeCEMap& _nodeCEMap;
+
+ std::vector<SelectivityType> _estimateStack;
+ std::vector<FieldProjectionMap> _fpmStack;
+};
+
+ABT lowerIntervals(PrefixId& prefixId,
+ const ProjectionName& ridProjName,
+ FieldProjectionMap indexProjectionMap,
+ const std::string& scanDefName,
+ const std::string& indexDefName,
+ const MultiKeyIntervalReqExpr::Node& intervals,
+ const bool reverseOrder,
+ const CEType indexCE,
+ const CEType scanGroupCE,
+ NodeCEMap& nodeCEMap) {
+ IntervalLowerTransport lowerTransport(prefixId,
+ ridProjName,
+ std::move(indexProjectionMap),
+ scanDefName,
+ indexDefName,
+ reverseOrder,
+ indexCE,
+ scanGroupCE,
+ nodeCEMap);
+ return lowerTransport.lower(intervals);
+}
+
+} // namespace mongo::optimizer
diff --git a/src/mongo/db/query/optimizer/utils/utils.h b/src/mongo/db/query/optimizer/utils/utils.h
new file mode 100644
index 00000000000..371dac2a48d
--- /dev/null
+++ b/src/mongo/db/query/optimizer/utils/utils.h
@@ -0,0 +1,317 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include "mongo/db/query/optimizer/defs.h"
+#include "mongo/db/query/optimizer/node.h"
+#include "mongo/db/query/optimizer/node_defs.h"
+#include "mongo/db/query/optimizer/props.h"
+
+namespace mongo::optimizer {
+
+inline void updateHash(size_t& result, const size_t hash) {
+ result = 31 * result + hash;
+}
+
+inline void updateHashUnordered(size_t& result, const size_t hash) {
+ result ^= hash;
+}
+
+template <class T, class T1 = std::conditional_t<std::is_arithmetic_v<T>, const T, const T&>>
+inline size_t computeVectorHash(const std::vector<T>& v) {
+ size_t result = 17;
+ for (T1 e : v) {
+ updateHash(result, std::hash<T>()(e));
+ }
+ return result;
+}
+
+template <int typeCode, typename... Args>
+inline size_t computeHashSeq(const Args&... seq) {
+ size_t result = 17 + typeCode;
+ (updateHash(result, seq), ...);
+ return result;
+}
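+
+// Expansion sketch for the helpers above: computeHashSeq<7>(h1, h2) evaluates to
+// 31 * (31 * (17 + 7) + h1) + h2, i.e. the seed 17 + typeCode folded left-to-right with
+// updateHash. Note the arguments are assumed to already be hash values.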
+
+size_t roundUpToNextPow2(size_t v, size_t maxPower);
+
+std::vector<ABT::reference_type> collectComposed(const ABT& n);
+
+/**
+ * Returns the path represented by 'node' as a simple dotted string (e.g. "a.b"). Returns an
+ * empty string if 'node' cannot be represented as such a path.
+ */
+FieldNameType getSimpleField(const ABT& node);
+
+template <class Element = PathComposeM>
+inline void maybeComposePath(ABT& composition, ABT child) {
+ if (child.is<PathIdentity>()) {
+ return;
+ }
+ if (composition.is<PathIdentity>()) {
+ composition = std::move(child);
+ return;
+ }
+
+ composition = make<Element>(std::move(composition), std::move(child));
+}
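+
+// Composition sketch: maybeComposePath(p, PathIdentity) leaves 'p' unchanged; if 'p' is
+// PathIdentity it is replaced by the child; otherwise the two paths are joined, e.g.
+//   p = PathCompare [Gte] (Const 5), child = PathCompare [Lte] (Const 10)
+//   => p becomes PathComposeM (PathCompare [Gte] (Const 5)) (PathCompare [Lte] (Const 10)).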
+
+/**
+ * Used to vend out fresh ids for projection names.
+ */
+class PrefixId {
+public:
+ std::string getNextId(const std::string& key);
+
+private:
+ opt::unordered_map<std::string, int> _idCounterPerKey;
+};
+
+ProjectionNameOrderedSet convertToOrderedSet(ProjectionNameSet unordered);
+
+opt::unordered_set<FieldNameType> toUnorderedFieldNameSet(std::set<FieldNameType> set);
+
+void combineLimitSkipProperties(properties::LimitSkipRequirement& aboveProp,
+ const properties::LimitSkipRequirement& belowProp);
+
+/**
+ * Used to track references originating from a set of physical properties.
+ */
+ProjectionNameSet extractReferencedColumns(const properties::PhysProps& properties);
+
+/**
+ * Returns true if all components of the compound interval are equalities.
+ */
+bool areMultiKeyIntervalsEqualities(const MultiKeyIntervalRequirement& intervals);
+
+struct CollationSplitResult {
+ bool _validSplit = false;
+ ProjectionCollationSpec _leftCollation;
+ ProjectionCollationSpec _rightCollation;
+};
+
+/**
+ * Split a collation requirement between an outer (left) and inner (right) side. The outer side must
+ * be a prefix in the collation spec, and the right side a suffix.
+ */
+CollationSplitResult splitCollationSpec(const ProjectionCollationSpec& collationSpec,
+ const ProjectionNameSet& leftProjections,
+ const ProjectionNameSet& rightProjections);
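+
+// For example (a sketch): splitting the collation [("pa", Ascending), ("pb", Ascending)] with
+// leftProjections = {"pa"} and rightProjections = {"pb"} yields a valid split with left
+// collation [("pa", Ascending)] and right collation [("pb", Ascending)]; if "pb" preceded "pa"
+// in the spec, the split would be invalid since the left side must be a prefix.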
+
+/**
+ * Used to extract variable references from a node.
+ */
+using VariableNameSetType = opt::unordered_set<std::string>;
+VariableNameSetType collectVariableReferences(const ABT& n);
+
+/**
+ * Appends a path to another path. Performs the append at PathIdentity elements.
+ */
+class PathAppender {
+public:
+ PathAppender(ABT toAppend) : _toAppend(std::move(toAppend)) {}
+
+ void transport(ABT& n, const PathIdentity& node) {
+ n = _toAppend;
+ }
+
+ template <typename T, typename... Ts>
+ void transport(ABT& /*n*/, const T& /*node*/, Ts&&...) {
+ // noop
+ }
+
+ void append(ABT& path) {
+ return algebra::transport<true>(path, *this);
+ }
+
+private:
+ ABT _toAppend;
+};
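+
+// Usage sketch for the appender above (illustrative values):
+//   ABT path = make<PathGet>("a", make<PathIdentity>());
+//   PathAppender appender(make<PathCompare>(Operations::Eq, Constant::int64(1)));
+//   appender.append(path);
+//   // path is now PathGet "a" (PathCompare [Eq] (Const 1)): the PathIdentity leaf was replaced.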
+
+struct PartialSchemaReqConversion {
+ PartialSchemaReqConversion();
+ PartialSchemaReqConversion(PartialSchemaRequirements reqMap);
+ PartialSchemaReqConversion(ABT bound);
+
+ // Whether the bottom-up conversion has been successful so far. If not, short-circuit to the top.
+ bool _success;
+
+ // If set, contains a Constant or Variable bound of a (yet unknown) interval.
+ boost::optional<ABT> _bound;
+
+ // Requirements we have built so far.
+ PartialSchemaRequirements _reqMap;
+
+ // Have we added a PathComposeM.
+ bool _hasIntersected;
+
+ // Have we added a PathTraverse.
+ bool _hasTraversed;
+
+ // If we have determined that we have a contradiction.
+ bool _hasEmptyInterval;
+};
+
+/**
+ * Takes an expression that comes from a Filter or Evaluation node, and attempts to convert it
+ * to a PartialSchemaReqConversion. This is done independently of the availability of indexes.
+ * Essentially this means extracting intervals over paths whenever possible.
+ */
+PartialSchemaReqConversion convertExprToPartialSchemaReq(const ABT& expr);
+
+bool intersectPartialSchemaReq(PartialSchemaRequirements& target,
+ const PartialSchemaRequirements& source,
+ ProjectionRenames& projectionRenames);
+
+/**
+ * Encode the position of an index field as a field name, for use with a FieldProjectionMap.
+ */
+std::string encodeIndexKeyName(size_t indexField);
+
+/**
+ * Decode a field name back into an index field position.
+ */
+size_t decodeIndexKeyName(const std::string& fieldName);
+
+/**
+ * Given a partial schema key that specifies an index path, and a map of partial requirements
+ * created from sargable query conditions, collect into 'keySet' the query keys which match the
+ * index path (and thus can be evaluated via this path), and combine their requirements into
+ * 'req'.
+ */
+void findMatchingSchemaRequirement(const PartialSchemaKey& indexKey,
+ const PartialSchemaRequirements& reqMap,
+ PartialSchemaKeySet& keySet,
+ PartialSchemaRequirement& req,
+ bool setIntervalsAndBoundProj = true);
+
+/**
+ * Compute a mapping [indexName -> CandidateIndexEntry] that describes intervals that could be
+ * used for accessing each of the indexes in the map. The intervals themselves are derived from
+ * 'reqMap'.
+ * If the intersection of any of the interval requirements in 'reqMap' results in an empty
+ * interval, return an empty mapping and set 'hasEmptyInterval' to true.
+ * Otherwise return the computed mapping, and set 'hasEmptyInterval' to false.
+ */
+CandidateIndexMap computeCandidateIndexMap(PrefixId& prefixId,
+ const ProjectionName& scanProjectionName,
+ const PartialSchemaRequirements& reqMap,
+ const ScanDefinition& scanDef,
+ bool& hasEmptyInterval);
+
+/**
+ * Used to lower a Sargable node to a subtree consisting of functionally equivalent Filter and Eval
+ * nodes.
+ */
+void lowerPartialSchemaRequirement(const PartialSchemaKey& key,
+ const PartialSchemaRequirement& req,
+ ABT& node,
+ const std::function<void(const ABT& node)>& visitor =
+ [](const ABT&) {});
+
+void lowerPartialSchemaRequirements(CEType baseCE,
+ CEType scanGroupCE,
+ ResidualRequirements& requirements,
+ ABT& physNode,
+ NodeCEMap& nodeCEMap);
+
+void computePhysicalScanParams(PrefixId& prefixId,
+ const PartialSchemaRequirements& reqMap,
+ const PartialSchemaKeyCE& partialSchemaKeyCEMap,
+ const ProjectionNameOrderPreservingSet& requiredProjections,
+ ResidualRequirements& residualRequirements,
+ ProjectionRenames& projectionRenames,
+ FieldProjectionMap& fieldProjectionMap,
+ bool& requiresRootProjection);
+
+void sortResidualRequirements(ResidualRequirements& residualReq);
+
+void applyProjectionRenames(ProjectionRenames projectionRenames,
+ ABT& node,
+ const std::function<void(const ABT& node)>& visitor = [](const ABT&) {
+ });
+
+/**
+ * Implements an RID Intersect node using Union and GroupBy.
+ */
+ABT lowerRIDIntersectGroupBy(PrefixId& prefixId,
+ const ProjectionName& ridProjName,
+ CEType intersectedCE,
+ CEType leftCE,
+ CEType rightCE,
+ const properties::PhysProps& physProps,
+ const properties::PhysProps& leftPhysProps,
+ const properties::PhysProps& rightPhysProps,
+ ABT leftChild,
+ ABT rightChild,
+ NodeCEMap& nodeCEMap,
+ ChildPropsType& childProps);
+
+/**
+ * Implements an RID Intersect node using a HashJoin.
+ */
+ABT lowerRIDIntersectHashJoin(PrefixId& prefixId,
+ const ProjectionName& ridProjName,
+ CEType intersectedCE,
+ CEType leftCE,
+ CEType rightCE,
+ const properties::PhysProps& leftPhysProps,
+ const properties::PhysProps& rightPhysProps,
+ ABT leftChild,
+ ABT rightChild,
+ NodeCEMap& nodeCEMap,
+ ChildPropsType& childProps);
+
+ABT lowerRIDIntersectMergeJoin(PrefixId& prefixId,
+ const ProjectionName& ridProjName,
+ CEType intersectedCE,
+ CEType leftCE,
+ CEType rightCE,
+ const properties::PhysProps& leftPhysProps,
+ const properties::PhysProps& rightPhysProps,
+ ABT leftChild,
+ ABT rightChild,
+ NodeCEMap& nodeCEMap,
+ ChildPropsType& childProps);
+
+ABT lowerIntervals(PrefixId& prefixId,
+ const ProjectionName& ridProjName,
+ FieldProjectionMap indexProjectionMap,
+ const std::string& scanDefName,
+ const std::string& indexDefName,
+ const MultiKeyIntervalReqExpr::Node& intervals,
+ bool reverseOrder,
+ CEType indexCE,
+ CEType scanGroupCE,
+ NodeCEMap& nodeCEMap);
+
+} // namespace mongo::optimizer
diff --git a/src/mongo/db/query/plan_executor_factory.cpp b/src/mongo/db/query/plan_executor_factory.cpp
index bd11182daa2..c72b768adb2 100644
--- a/src/mongo/db/query/plan_executor_factory.cpp
+++ b/src/mongo/db/query/plan_executor_factory.cpp
@@ -122,6 +122,7 @@ StatusWith<std::unique_ptr<PlanExecutor, PlanExecutor::Deleter>> make(
std::unique_ptr<CanonicalQuery> cq,
std::unique_ptr<QuerySolution> solution,
std::pair<std::unique_ptr<sbe::PlanStage>, stage_builder::PlanStageData> root,
+ std::unique_ptr<optimizer::AbstractABTPrinter> optimizerData,
const CollectionPtr* collection,
size_t plannerOptions,
NamespaceString nss,
@@ -140,6 +141,7 @@ StatusWith<std::unique_ptr<PlanExecutor, PlanExecutor::Deleter>> make(
return {{new PlanExecutorSBE(
opCtx,
std::move(cq),
+ std::move(optimizerData),
{makeVector<sbe::plan_ranker::CandidatePlan>(sbe::plan_ranker::CandidatePlan{
std::move(solution), std::move(rootStage), std::move(data)}),
0},
@@ -169,6 +171,7 @@ StatusWith<std::unique_ptr<PlanExecutor, PlanExecutor::Deleter>> make(
return {{new PlanExecutorSBE(opCtx,
std::move(cq),
+ {},
std::move(candidates),
*collection,
plannerOptions & QueryPlannerParams::RETURN_OWNED_DATA,
diff --git a/src/mongo/db/query/plan_executor_factory.h b/src/mongo/db/query/plan_executor_factory.h
index 380e21cb1e4..17694489add 100644
--- a/src/mongo/db/query/plan_executor_factory.h
+++ b/src/mongo/db/query/plan_executor_factory.h
@@ -35,6 +35,7 @@
#include "mongo/db/exec/working_set.h"
#include "mongo/db/pipeline/pipeline.h"
#include "mongo/db/pipeline/plan_executor_pipeline.h"
+#include "mongo/db/query/optimizer/explain_interface.h"
#include "mongo/db/query/plan_executor.h"
#include "mongo/db/query/plan_yield_policy_sbe.h"
#include "mongo/db/query/query_solution.h"
@@ -106,12 +107,14 @@ StatusWith<std::unique_ptr<PlanExecutor, PlanExecutor::Deleter>> make(
/**
* Constructs a PlanExecutor for the query 'cq' which will execute the SBE plan 'root'. A yield
* policy can optionally be provided if the plan should automatically yield during execution.
+ * "optimizerData" is used to print optimizer ABT plans, and may be empty.
*/
StatusWith<std::unique_ptr<PlanExecutor, PlanExecutor::Deleter>> make(
OperationContext* opCtx,
std::unique_ptr<CanonicalQuery> cq,
std::unique_ptr<QuerySolution> solution,
std::pair<std::unique_ptr<sbe::PlanStage>, stage_builder::PlanStageData> root,
+ std::unique_ptr<optimizer::AbstractABTPrinter> optimizerData,
const CollectionPtr* collection,
size_t plannerOptions,
NamespaceString nss,
diff --git a/src/mongo/db/query/plan_executor_sbe.cpp b/src/mongo/db/query/plan_executor_sbe.cpp
index 3b7a2db857f..e1dded76398 100644
--- a/src/mongo/db/query/plan_executor_sbe.cpp
+++ b/src/mongo/db/query/plan_executor_sbe.cpp
@@ -48,6 +48,7 @@ extern FailPoint planExecutorHangBeforeShouldWaitForInserts;
PlanExecutorSBE::PlanExecutorSBE(OperationContext* opCtx,
std::unique_ptr<CanonicalQuery> cq,
+ std::unique_ptr<optimizer::AbstractABTPrinter> optimizerData,
sbe::CandidatePlans candidates,
const CollectionPtr& collection,
bool returnOwnedBson,
@@ -103,8 +104,7 @@ PlanExecutorSBE::PlanExecutorSBE(OperationContext* opCtx,
const auto isMultiPlan = candidates.plans.size() > 1;
- uassert(5088500, "Query does not have a valid CanonicalQuery", _cq);
- if (!_cq->getExpCtx()->explain) {
+ if (!_cq || !_cq->getExpCtx()->explain) {
// If we're not in explain mode, there is no need to keep rejected candidate plans around.
candidates.plans.clear();
} else {
@@ -115,6 +115,7 @@ PlanExecutorSBE::PlanExecutorSBE(OperationContext* opCtx,
_planExplainer = plan_explainer_factory::make(_root.get(),
&_rootData,
_solution.get(),
+ std::move(optimizerData),
std::move(candidates.plans),
isMultiPlan,
std::move(_rootData.debugInfo));
diff --git a/src/mongo/db/query/plan_executor_sbe.h b/src/mongo/db/query/plan_executor_sbe.h
index a0f7a62f57b..f20e399d36e 100644
--- a/src/mongo/db/query/plan_executor_sbe.h
+++ b/src/mongo/db/query/plan_executor_sbe.h
@@ -32,6 +32,7 @@
#include <queue>
#include "mongo/db/exec/sbe/stages/stages.h"
+#include "mongo/db/query/optimizer/explain_interface.h"
#include "mongo/db/query/plan_executor.h"
#include "mongo/db/query/plan_explainer_sbe.h"
#include "mongo/db/query/plan_yield_policy_sbe.h"
@@ -44,6 +45,7 @@ class PlanExecutorSBE final : public PlanExecutor {
public:
PlanExecutorSBE(OperationContext* opCtx,
std::unique_ptr<CanonicalQuery> cq,
+ std::unique_ptr<optimizer::AbstractABTPrinter> optimizerData,
sbe::CandidatePlans candidates,
const CollectionPtr& collection,
bool returnOwnedBson,
diff --git a/src/mongo/db/query/plan_explainer_factory.cpp b/src/mongo/db/query/plan_explainer_factory.cpp
index ef4858e1f7c..eecbba3908a 100644
--- a/src/mongo/db/query/plan_explainer_factory.cpp
+++ b/src/mongo/db/query/plan_explainer_factory.cpp
@@ -47,25 +47,32 @@ std::unique_ptr<PlanExplainer> make(PlanStage* root, const PlanEnumeratorExplain
std::unique_ptr<PlanExplainer> make(sbe::PlanStage* root,
const stage_builder::PlanStageData* data,
const QuerySolution* solution) {
- return make(root, data, solution, {}, false);
+ return make(root, data, solution, {}, {}, false);
}
std::unique_ptr<PlanExplainer> make(sbe::PlanStage* root,
const stage_builder::PlanStageData* data,
const QuerySolution* solution,
+ std::unique_ptr<optimizer::AbstractABTPrinter> optimizerData,
std::vector<sbe::plan_ranker::CandidatePlan> rejectedCandidates,
bool isMultiPlan) {
// Pre-compute Debugging info for explain use.
auto debugInfoSBE = std::make_unique<plan_cache_debug_info::DebugInfoSBE>(
plan_cache_util::buildDebugInfo(solution));
- return std::make_unique<PlanExplainerSBE>(
- root, data, solution, std::move(rejectedCandidates), isMultiPlan, std::move(debugInfoSBE));
+ return std::make_unique<PlanExplainerSBE>(root,
+ data,
+ solution,
+ std::move(optimizerData),
+ std::move(rejectedCandidates),
+ isMultiPlan,
+ std::move(debugInfoSBE));
}
std::unique_ptr<PlanExplainer> make(
sbe::PlanStage* root,
const stage_builder::PlanStageData* data,
const QuerySolution* solution,
+ std::unique_ptr<optimizer::AbstractABTPrinter> optimizerData,
std::vector<sbe::plan_ranker::CandidatePlan> rejectedCandidates,
bool isMultiPlan,
std::unique_ptr<plan_cache_debug_info::DebugInfoSBE> debugInfoSBE) {
@@ -77,7 +84,12 @@ std::unique_ptr<PlanExplainer> make(
plan_cache_util::buildDebugInfo(solution));
}
- return std::make_unique<PlanExplainerSBE>(
- root, data, solution, std::move(rejectedCandidates), isMultiPlan, std::move(debugInfoSBE));
+ return std::make_unique<PlanExplainerSBE>(root,
+ data,
+ solution,
+ std::move(optimizerData),
+ std::move(rejectedCandidates),
+ isMultiPlan,
+ std::move(debugInfoSBE));
}
} // namespace mongo::plan_explainer_factory
diff --git a/src/mongo/db/query/plan_explainer_factory.h b/src/mongo/db/query/plan_explainer_factory.h
index f0165f2c1cb..ddeb3b94e91 100644
--- a/src/mongo/db/query/plan_explainer_factory.h
+++ b/src/mongo/db/query/plan_explainer_factory.h
@@ -31,6 +31,7 @@
#include "mongo/db/exec/plan_stage.h"
#include "mongo/db/exec/sbe/stages/stages.h"
+#include "mongo/db/query/optimizer/explain_interface.h"
#include "mongo/db/query/plan_enumerator_explain_info.h"
#include "mongo/db/query/plan_explainer.h"
#include "mongo/db/query/query_solution.h"
@@ -49,12 +50,14 @@ std::unique_ptr<PlanExplainer> make(sbe::PlanStage* root,
std::unique_ptr<PlanExplainer> make(sbe::PlanStage* root,
const stage_builder::PlanStageData* data,
const QuerySolution* solution,
+ std::unique_ptr<optimizer::AbstractABTPrinter> optimizerData,
std::vector<sbe::plan_ranker::CandidatePlan> rejectedCandidates,
bool isMultiPlan);
std::unique_ptr<PlanExplainer> make(sbe::PlanStage* root,
const stage_builder::PlanStageData* data,
const QuerySolution* solution,
+ std::unique_ptr<optimizer::AbstractABTPrinter> optimizerData,
std::vector<sbe::plan_ranker::CandidatePlan> rejectedCandidates,
bool isMultiPlan,
std::unique_ptr<plan_cache_debug_info::DebugInfoSBE> debugInfo);
diff --git a/src/mongo/db/query/plan_explainer_sbe.cpp b/src/mongo/db/query/plan_explainer_sbe.cpp
index 8bde2311ab6..00035f473f2 100644
--- a/src/mongo/db/query/plan_explainer_sbe.cpp
+++ b/src/mongo/db/query/plan_explainer_sbe.cpp
@@ -36,6 +36,8 @@
#include "mongo/db/exec/plan_stats_walker.h"
#include "mongo/db/fts/fts_query_impl.h"
#include "mongo/db/keypattern.h"
+#include "mongo/db/query/optimizer/explain_interface.h"
+#include "mongo/db/query/optimizer/node.h"
#include "mongo/db/query/plan_explainer_impl.h"
#include "mongo/db/query/plan_summary_stats_visitor.h"
#include "mongo/db/query/projection_ast_util.h"
@@ -305,12 +307,13 @@ PlanExplainer::PlanStatsDetails buildPlanStatsDetails(
const QuerySolution* solution,
const sbe::PlanStageStats* stats,
const boost::optional<BSONObj>& execPlanDebugInfo,
+ const boost::optional<BSONObj>& optimizerExplain,
ExplainOptions::Verbosity verbosity) {
BSONObjBuilder bob;
if (verbosity >= ExplainOptions::Verbosity::kExecStats) {
auto summary = collectExecutionStatsSummary(stats);
- if (verbosity >= ExplainOptions::Verbosity::kExecAllPlans) {
+ if (solution != nullptr && verbosity >= ExplainOptions::Verbosity::kExecAllPlans) {
summary.score = solution->score;
}
statsToBSON(stats, &bob, &bob);
@@ -323,9 +326,18 @@ PlanExplainer::PlanStatsDetails buildPlanStatsDetails(
return {bob.obj(), std::move(summary)};
}
- statsToBSON(solution->root(), &bob, &bob);
+ if (solution != nullptr) {
+ statsToBSON(solution->root(), &bob, &bob);
+ }
+
invariant(execPlanDebugInfo);
- return {BSON("queryPlan" << bob.obj() << "slotBasedPlan" << *execPlanDebugInfo), boost::none};
+ if (optimizerExplain) {
+ return {BSON("optimizerPlan" << *optimizerExplain << "slotBasedPlan" << *execPlanDebugInfo),
+ boost::none};
+ } else {
+ return {BSON("queryPlan" << bob.obj() << "slotBasedPlan" << *execPlanDebugInfo),
+ boost::none};
+ }
}
} // namespace
@@ -369,10 +381,12 @@ void PlanExplainerSBE::getSummaryStats(PlanSummaryStats* statsOut) const {
PlanExplainer::PlanStatsDetails PlanExplainerSBE::getWinningPlanStats(
ExplainOptions::Verbosity verbosity) const {
invariant(_root);
- invariant(_solution);
auto stats = _root->getStats(true /* includeDebugInfo */);
- return buildPlanStatsDetails(
- _solution, stats.get(), buildExecPlanDebugInfo(_root, _rootData), verbosity);
+ return buildPlanStatsDetails(_solution,
+ stats.get(),
+ buildExecPlanDebugInfo(_root, _rootData),
+ buildCascadesPlan(),
+ verbosity);
}
PlanExplainer::PlanStatsDetails PlanExplainerSBE::getWinningPlanTrialStats() const {
@@ -385,6 +399,7 @@ PlanExplainer::PlanStatsDetails PlanExplainerSBE::getWinningPlanTrialStats() con
// This parameter is not used in `buildPlanStatsDetails` if the last parameter is
// `ExplainOptions::Verbosity::kExecAllPlans`, as is the case here.
boost::none,
+ boost::none,
ExplainOptions::Verbosity::kExecAllPlans);
}
return getWinningPlanStats(ExplainOptions::Verbosity::kExecAllPlans);
@@ -405,7 +420,7 @@ std::vector<PlanExplainer::PlanStatsDetails> PlanExplainerSBE::getRejectedPlansS
auto stats = candidate.root->getStats(true /* includeDebugInfo */);
auto execPlanDebugInfo = buildExecPlanDebugInfo(candidate.root.get(), &candidate.data);
res.push_back(buildPlanStatsDetails(
- candidate.solution.get(), stats.get(), execPlanDebugInfo, verbosity));
+ candidate.solution.get(), stats.get(), execPlanDebugInfo, boost::none, verbosity));
}
return res;
}
@@ -418,7 +433,8 @@ std::vector<PlanExplainer::PlanStatsDetails> PlanExplainerSBE::getCachedPlanStat
auto&& stats = decision.getStats<mongo::sbe::PlanStageStats>();
if (verbosity >= ExplainOptions::Verbosity::kExecStats) {
for (auto&& planStats : stats.candidatePlanStats) {
- res.push_back(buildPlanStatsDetails(nullptr, planStats.get(), boost::none, verbosity));
+ res.push_back(buildPlanStatsDetails(
+ nullptr, planStats.get(), boost::none, boost::none, verbosity));
}
} else {
// At the "queryPlanner" verbosity we only need to provide details about the winning plan
@@ -429,4 +445,12 @@ std::vector<PlanExplainer::PlanStatsDetails> PlanExplainerSBE::getCachedPlanStat
return res;
}
+
+boost::optional<BSONObj> PlanExplainerSBE::buildCascadesPlan() const {
+ if (_optimizerData) {
+ return _optimizerData->explainBSON();
+ }
+ return {};
+}
+
} // namespace mongo
diff --git a/src/mongo/db/query/plan_explainer_sbe.h b/src/mongo/db/query/plan_explainer_sbe.h
index 070ed479647..1edaf3209ed 100644
--- a/src/mongo/db/query/plan_explainer_sbe.h
+++ b/src/mongo/db/query/plan_explainer_sbe.h
@@ -30,6 +30,7 @@
#pragma once
#include "mongo/db/exec/sbe/stages/stages.h"
+#include "mongo/db/query/optimizer/explain_interface.h"
#include "mongo/db/query/plan_cache_debug_info.h"
#include "mongo/db/query/plan_explainer.h"
#include "mongo/db/query/query_solution.h"
@@ -44,6 +45,7 @@ public:
PlanExplainerSBE(const sbe::PlanStage* root,
const stage_builder::PlanStageData* data,
const QuerySolution* solution,
+ std::unique_ptr<optimizer::AbstractABTPrinter> optimizerData,
std::vector<sbe::plan_ranker::CandidatePlan> rejectedCandidates,
bool isMultiPlan,
std::unique_ptr<plan_cache_debug_info::DebugInfoSBE> debugInfo)
@@ -51,6 +53,7 @@ public:
_root{root},
_rootData{data},
_solution{solution},
+ _optimizerData(std::move(optimizerData)),
_rejectedCandidates{std::move(rejectedCandidates)},
_isMultiPlan{isMultiPlan},
_debugInfo{std::move(debugInfo)} {
@@ -81,11 +84,15 @@ private:
return boost::none;
}
+ boost::optional<BSONObj> buildCascadesPlan() const;
+
// These fields are are owned elsewhere (e.g. the PlanExecutor or CandidatePlan).
const sbe::PlanStage* _root{nullptr};
const stage_builder::PlanStageData* _rootData{nullptr};
const QuerySolution* _solution{nullptr};
+ const std::unique_ptr<optimizer::AbstractABTPrinter> _optimizerData;
+
const std::vector<sbe::plan_ranker::CandidatePlan> _rejectedCandidates;
const bool _isMultiPlan{false};
// Pre-computed debugging info so we don't necessarily have to collect them from QuerySolution.
diff --git a/src/mongo/db/query/query_feature_flags.idl b/src/mongo/db/query/query_feature_flags.idl
index a15f8fbd4c4..e1ba84203f6 100644
--- a/src/mongo/db/query/query_feature_flags.idl
+++ b/src/mongo/db/query/query_feature_flags.idl
@@ -110,11 +110,16 @@ feature_flags:
cpp_varname: gFeatureFlagPerShardCursor
default: false
+ featureFlagCommonQueryFramework:
+ description: "Feature flag for allowing use of Cascades-based query optimizer"
+ cpp_varname: gfeatureFlagCommonQueryFramework
+ default: false
+
featureFlagLastPointQuery:
description : "Feature flag for optimizing Last Point queries on time-series collections"
cpp_varname: gfeatureFlagLastPointQuery
default: false
-
+
featureFlagChangeStreamPreAndPostImagesTimeBasedRetentionPolicy:
description: "Feature flag to enable time based retention policy of point-in-time pre- and post-images of documents in change streams"
cpp_varname: gFeatureFlagChangeStreamPreAndPostImagesTimeBasedRetentionPolicy
diff --git a/src/mongo/db/query/query_knobs.idl b/src/mongo/db/query/query_knobs.idl
index 0b30f179764..52952861b6e 100644
--- a/src/mongo/db/query/query_knobs.idl
+++ b/src/mongo/db/query/query_knobs.idl
@@ -641,3 +641,66 @@ server_parameters:
cpp_varname: "internalEnableMultipleAutoGetCollections"
cpp_vartype: AtomicWord<bool>
default: false
+
+ internalQueryEnableSamplingCardinalityEstimator:
+ description: "Set to use the sampling-based method for estimating cardinality."
+ set_at: [ startup, runtime ]
+ cpp_varname: "internalQueryEnableSamplingCardinalityEstimator"
+ cpp_vartype: AtomicWord<bool>
+ default: true
+
+ internalQueryEnableCascadesOptimizer:
+ description: "Set to use the new optimizer path, must be used in conjunction with the feature flag."
+ set_at: [ startup, runtime ]
+ cpp_varname: "internalQueryEnableCascadesOptimizer"
+ cpp_vartype: AtomicWord<bool>
+ default: false
+
+ internalCascadesOptimizerDisableScan:
+ description: "Disable full collection scans."
+ set_at: [ startup, runtime ]
+ cpp_varname: "internalCascadesOptimizerDisableScan"
+ cpp_vartype: AtomicWord<bool>
+ default: false
+
+ internalCascadesOptimizerDisableIndexes:
+ description: "Disable index scan plans."
+ set_at: [ startup, runtime ]
+ cpp_varname: "internalCascadesOptimizerDisableIndexes"
+ cpp_vartype: AtomicWord<bool>
+ default: false
+
+ internalCascadesOptimizerDisableMergeJoinRIDIntersect:
+ description: "Disable index RID intersection via merge join."
+ set_at: [ startup, runtime ]
+ cpp_varname: "internalCascadesOptimizerDisableMergeJoinRIDIntersect"
+ cpp_vartype: AtomicWord<bool>
+ default: false
+
+ internalCascadesOptimizerDisableHashJoinRIDIntersect:
+ description: "Disable index RID intersection via hash join."
+ set_at: [ startup, runtime ]
+ cpp_varname: "internalCascadesOptimizerDisableHashJoinRIDIntersect"
+ cpp_vartype: AtomicWord<bool>
+ default: false
+
+ internalCascadesOptimizerDisableGroupByAndUnionRIDIntersect:
+ description: "Disable index RID intersection via group by and union."
+ set_at: [ startup, runtime ]
+ cpp_varname: "internalCascadesOptimizerDisableGroupByAndUnionRIDIntersect"
+ cpp_vartype: AtomicWord<bool>
+ default: false
+
+ internalCascadesOptimizerKeepRejectedPlans:
+ description: "Keep track of rejected plans in the memo."
+ set_at: [ startup, runtime ]
+ cpp_varname: "internalCascadesOptimizerKeepRejectedPlans"
+ cpp_vartype: AtomicWord<bool>
+ default: false
+
+ internalCascadesOptimizerDisableBranchAndBound:
+ description: "Disable cascades branch-and-bound strategy, and fully evaluate all plans."
+ set_at: [ startup, runtime ]
+ cpp_varname: "internalCascadesOptimizerDisableBranchAndBound"
+ cpp_vartype: AtomicWord<bool>
+ default: false
diff --git a/src/mongo/db/query/sbe_cached_solution_planner.cpp b/src/mongo/db/query/sbe_cached_solution_planner.cpp
index f9b7ba125ea..d0fd8db0b81 100644
--- a/src/mongo/db/query/sbe_cached_solution_planner.cpp
+++ b/src/mongo/db/query/sbe_cached_solution_planner.cpp
@@ -73,6 +73,7 @@ CandidatePlans CachedSolutionPlanner::plan(
candidate.root.get(),
&candidate.data,
candidate.solution.get(),
+ {}, /* optimizedData */
{}, /* rejectedCandidates */
false, /* isMultiPlan */
candidate.data.debugInfo