author     Svilen Mihaylov <svilen.mihaylov@mongodb.com>      2022-01-31 21:05:27 +0000
committer  Evergreen Agent <no-reply@evergreen.mongodb.com>   2022-01-31 21:48:46 +0000
commit     50db8e9573e191ba2c193b4ef3dba6b5c6488f82
tree       1d211e40920b5952af569bb6e9fa7dd830d5bbaa
parent     b696e034fe97e7699dd45ac2595422e1d510ba2c
SERVER-62434 Implement query optimizer based on Path algebra and Cascades
162 files changed, 40729 insertions, 96 deletions
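For readers who want to exercise the new optimizer from this commit, here is a minimal mongo shell sketch. It assumes a locally built mongod started with the parameters the new cqf suites set below (enableTestCommands=1, featureFlagCommonQueryFramework=true, internalQueryEnableCascadesOptimizer=true); the collection name is made up for illustration and is not part of this patch.

    // Confirm the new optimizer is active, using the helper added in jstests/libs/optimizer_utils.js.
    load("jstests/libs/optimizer_utils.js");
    assert(checkCascadesOptimizerEnabled(db));

    // Run a simple match and inspect the optimizer plan attached to the explain output,
    // mirroring the assertions made by the new jstests/cqf tests.
    const coll = db.cqf_smoke;  // hypothetical collection name
    coll.drop();
    assert.commandWorked(coll.insert([{a: 1}, {a: 2}]));
    const res = coll.explain("executionStats").aggregate([{$match: {a: 2}}]);
    printjson(res.queryPlanner.winningPlan.optimizerPlan);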
diff --git a/buildscripts/gdb/mongo_printers.py b/buildscripts/gdb/mongo_printers.py
index 83289b33ffe..07ebc47bd64 100644
--- a/buildscripts/gdb/mongo_printers.py
+++ b/buildscripts/gdb/mongo_printers.py
@@ -815,6 +815,45 @@ class SbeCodeFragmentPrinter(object):
             instr_count if not error else '? (successfully parsed {})'.format(instr_count)
 
+def eval_print_fn(val, print_fn_name):
+    """Evaluate a print function, and return the resulting string."""
+    print_fn_symbol = gdb.lookup_symbol(print_fn_name)[0]
+    print_fn = print_fn_symbol.value()
+    # The generated output from explain contains the string "\n" (two characters)
+    # replace them with a single EOL character so that GDB prints multi-line
+    # explains nicely.
+    pp_result = print_fn(val)
+    pp_str = str(pp_result).replace("\"", "").replace("\\n", "\n")
+    return pp_str
+
+
+class ABTPrinter(object):
+    """Pretty-printer for mongo::optimizer::ABT."""
+
+    def __init__(self, val):
+        """Initialize ABTPrinter."""
+        self.val = val
+
+    @staticmethod
+    def display_hint():
+        """Display hint."""
+        return 'ABT'
+
+    def to_string(self):
+        """Return ABT for printing."""
+        # Python will truncate/compress certain strings that contain many repeated characters.
+        # For an ABT, this is quite common when indenting nodes to represent children, so
+        # disable it for now.
+        prior_repeats = gdb.parameter("print repeats")
+        prior_elements = gdb.parameter("print elements")
+        gdb.execute("set print repeats 0")  # for "<repeats N times>"
+        gdb.execute("set print elements 0")  # for ... on long strings
+        res = eval_print_fn(self.val, "_printNode")
+        gdb.execute("set print repeats " + str(prior_repeats))
+        gdb.execute("set print elements " + str(prior_elements))
+        return res
+
+
 def build_pretty_printer():
     """Build a pretty printer."""
     pp = MongoPrettyPrinterCollection()
@@ -836,6 +875,12 @@ def build_pretty_printer():
     pp.add('__wt_txn', '__wt_txn', False, WtTxnPrinter)
     pp.add('__wt_update', '__wt_update', False, WtUpdateToBsonPrinter)
     pp.add('CodeFragment', 'mongo::sbe::vm::CodeFragment', False, SbeCodeFragmentPrinter)
+
+    # TODO: enable with SERVER-62044.
+    # Optimizer/ABT related pretty printers that can be used only with a running process.
+    # abt_type = gdb.lookup_type("mongo::optimizer::ABT").strip_typedefs()
+    # pp.add("ABT", abt_type.name, True, ABTPrinter)
+
     return pp
diff --git a/buildscripts/resmokeconfig/fully_disabled_feature_flags.yml b/buildscripts/resmokeconfig/fully_disabled_feature_flags.yml
index 59025aba509..c08125ad738 100644
--- a/buildscripts/resmokeconfig/fully_disabled_feature_flags.yml
+++ b/buildscripts/resmokeconfig/fully_disabled_feature_flags.yml
@@ -4,6 +4,7 @@
 # by modifying their respective definitions in evergreen.yml.
 - featureFlagFryer
+- featureFlagCommonQueryFramework
 # Disable featureFlagRequireTenantID until all paths pass tenant id to TenantNamespace
 # and TenantDatabase constructors.
 - featureFlagRequireTenantID
diff --git a/buildscripts/resmokeconfig/suites/cqf.yml b/buildscripts/resmokeconfig/suites/cqf.yml
new file mode 100644
index 00000000000..5c4415228b7
--- /dev/null
+++ b/buildscripts/resmokeconfig/suites/cqf.yml
@@ -0,0 +1,30 @@
+test_kind: js_test
+
+selector:
+  roots:
+  - jstests/cqf/**/*.js
+
+executor:
+  archive:
+    hooks:
+      - ValidateCollections
+  config:
+    shell_options:
+      crashOnInvalidBSONError: ""
+      objcheck: ""
+      eval: load("jstests/libs/override_methods/detect_spawning_own_mongod.js");
+  hooks:
+  - class: ValidateCollections
+    shell_options:
+      global_vars:
+        TestData:
+          skipValidationOnNamespaceNotFound: false
+  - class: CleanEveryN
+    n: 20
+  fixture:
+    class: MongoDFixture
+    mongod_options:
+      set_parameters:
+        enableTestCommands: 1
+        featureFlagCommonQueryFramework: true
+        internalQueryEnableCascadesOptimizer: true
diff --git a/buildscripts/resmokeconfig/suites/cqf_parallel.yml b/buildscripts/resmokeconfig/suites/cqf_parallel.yml
new file mode 100644
index 00000000000..57d55f023a3
--- /dev/null
+++ b/buildscripts/resmokeconfig/suites/cqf_parallel.yml
@@ -0,0 +1,31 @@
+test_kind: js_test
+
+selector:
+  roots:
+  - jstests/cqf_parallel/**/*.js
+
+executor:
+  archive:
+    hooks:
+      - ValidateCollections
+  config:
+    shell_options:
+      crashOnInvalidBSONError: ""
+      objcheck: ""
+      eval: load("jstests/libs/override_methods/detect_spawning_own_mongod.js");
+  hooks:
+  - class: ValidateCollections
+    shell_options:
+      global_vars:
+        TestData:
+          skipValidationOnNamespaceNotFound: false
+  - class: CleanEveryN
+    n: 20
+  fixture:
+    class: MongoDFixture
+    mongod_options:
+      set_parameters:
+        enableTestCommands: 1
+        featureFlagCommonQueryFramework: true
+        internalQueryEnableCascadesOptimizer: true
+        internalQueryDefaultDOP: 5
diff --git a/etc/evergreen.yml b/etc/evergreen.yml
index 10557a651a7..48b467f05bb 100644
--- a/etc/evergreen.yml
+++ b/etc/evergreen.yml
@@ -7081,6 +7081,20 @@ tasks:
     vars:
       resmoke_jobs_max: 1
+- <<: *task_template
+  name: cqf
+  tags: []
+  commands:
+  - func: "do setup"
+  - func: "run tests"
+
+- <<: *task_template
+  name: cqf_parallel
+  tags: []
+  commands:
+  - func: "do setup"
+  - func: "run tests"
+
 - name: shared_scons_cache_pruning
   tags: []
   exec_timeout_secs: 7200 # 2 hour timeout for the task overall
@@ -8918,6 +8932,8 @@ buildvariants:
       --excludeWithAnyTags=incompatible_with_windows_tls
       --excludeWithAnyTags=incompatible_with_shard_merge
   tasks:
+  - name: cqf
+  - name: cqf_parallel
   - name: compile_and_archive_dist_test_then_package_TG
     distros:
     - windows-vsCurrent-xlarge
@@ -9790,6 +9806,8 @@ buildvariants:
       --runAllFeatureFlagTests
      --excludeWithAnyTags=incompatible_with_shard_merge
   tasks: &enterprise-rhel-80-64-bit-dynamic-all-feature-flags-tasks
+  - name: cqf
+  - name: cqf_parallel
   - name: compile_test_and_package_parallel_core_stream_TG
     distros:
     - rhel80-large
@@ -11614,6 +11632,8 @@ buildvariants:
       --excludeWithAnyTags=incompatible_with_shard_merge
   separate_debug: off
   tasks:
+  - name: cqf
+  - name: cqf_parallel
   - name: compile_and_archive_dist_test_then_package_TG
   - name: compile_benchmarks
   - name: build_variant_gen
@@ -11856,6 +11876,8 @@ buildvariants:
   scons_cache_scope: shared
   separate_debug: off
   tasks:
+  - name: cqf
+  - name: cqf_parallel
   - name: compile_and_archive_dist_test_then_package_TG
   - name: compile_benchmarks
   - name: build_variant_gen
diff --git a/jstests/cqf/array_index.js b/jstests/cqf/array_index.js
new file mode 100644
index 00000000000..5a40a3040fb
--- /dev/null
+++ b/jstests/cqf/array_index.js
@@ -0,0 +1,33 @@
+(function() {
+"use strict"; + +load("jstests/libs/optimizer_utils.js"); // For checkCascadesOptimizerEnabled. +if (!checkCascadesOptimizerEnabled(db)) { + jsTestLog("Skipping test because the optimizer is not enabled"); + return; +} + +const t = db.cqf_array_index; +t.drop(); + +assert.commandWorked(t.insert({a: [1, 2, 3, 4]})); +assert.commandWorked(t.insert({a: [2, 3, 4]})); +assert.commandWorked(t.insert({a: [2]})); +assert.commandWorked(t.insert({a: 2})); +assert.commandWorked(t.insert({a: [1, 3]})); + +// Generate enough documents for index to be preferable. +for (let i = 0; i < 100; i++) { + assert.commandWorked(t.insert({a: i + 10})); +} + +assert.commandWorked(t.createIndex({a: 1})); + +let res = t.explain("executionStats").aggregate([{$match: {a: 2}}]); +assert.eq(4, res.executionStats.nReturned); +assert.eq("IndexScan", res.queryPlanner.winningPlan.optimizerPlan.child.leftChild.nodeType); + +res = t.explain("executionStats").aggregate([{$match: {a: {$lt: 2}}}]); +assert.eq(2, res.executionStats.nReturned); +assert.eq("IndexScan", res.queryPlanner.winningPlan.optimizerPlan.child.leftChild.child.nodeType); +}()); diff --git a/jstests/cqf/basic_agg.js b/jstests/cqf/basic_agg.js new file mode 100644 index 00000000000..3165b4403d0 --- /dev/null +++ b/jstests/cqf/basic_agg.js @@ -0,0 +1,42 @@ +(function() { +"use strict"; + +load("jstests/libs/optimizer_utils.js"); // For checkCascadesOptimizerEnabled. +if (!checkCascadesOptimizerEnabled(db)) { + jsTestLog("Skipping test because the optimizer is not enabled"); + return; +} + +const coll = db.cqf_basic_index; +coll.drop(); + +assert.commandWorked( + coll.insert([{a: {b: 1}}, {a: {b: 2}}, {a: {b: 3}}, {a: {b: 4}}, {a: {b: 5}}])); + +const extraDocCount = 50; +// Add extra docs to make sure indexes can be picked. +for (let i = 0; i < extraDocCount; i++) { + assert.commandWorked(coll.insert({a: {b: i + 10}})); +} +assert.commandWorked(coll.createIndex({'a.b': 1})); + +let res = coll.explain("executionStats").aggregate([{$match: {'a.b': 2}}]); +assert.eq(1, res.executionStats.nReturned); +assert.eq("IndexScan", res.queryPlanner.winningPlan.optimizerPlan.child.leftChild.nodeType); + +res = coll.explain("executionStats").aggregate([{$match: {'a.b': {$gt: 2}}}]); +assert.eq(3 + extraDocCount, res.executionStats.nReturned); +assert.eq("PhysicalScan", res.queryPlanner.winningPlan.optimizerPlan.child.child.nodeType); + +res = coll.explain("executionStats").aggregate([{$match: {'a.b': {$gte: 2}}}]); +assert.eq(4 + extraDocCount, res.executionStats.nReturned); +assert.eq("PhysicalScan", res.queryPlanner.winningPlan.optimizerPlan.child.child.nodeType); + +res = coll.explain("executionStats").aggregate([{$match: {'a.b': {$lt: 2}}}]); +assert.eq(1, res.executionStats.nReturned); +assert.eq("IndexScan", res.queryPlanner.winningPlan.optimizerPlan.child.leftChild.nodeType); + +res = coll.explain("executionStats").aggregate([{$match: {'a.b': {$lte: 2}}}]); +assert.eq(2, res.executionStats.nReturned); +assert.eq("IndexScan", res.queryPlanner.winningPlan.optimizerPlan.child.leftChild.nodeType); +}()); diff --git a/jstests/cqf/basic_find.js b/jstests/cqf/basic_find.js new file mode 100644 index 00000000000..37abd4b5eaa --- /dev/null +++ b/jstests/cqf/basic_find.js @@ -0,0 +1,42 @@ +(function() { +"use strict"; + +load("jstests/libs/optimizer_utils.js"); // For checkCascadesOptimizerEnabled. 
+if (!checkCascadesOptimizerEnabled(db)) { + jsTestLog("Skipping test because the optimizer is not enabled"); + return; +} + +const coll = db.cqf_basic_find; +coll.drop(); + +assert.commandWorked( + coll.insert([{a: {b: 1}}, {a: {b: 2}}, {a: {b: 3}}, {a: {b: 4}}, {a: {b: 5}}])); + +const extraDocCount = 50; +// Add extra docs to make sure indexes can be picked. +for (let i = 0; i < extraDocCount; i++) { + assert.commandWorked(coll.insert({a: {b: i + 10}})); +} +assert.commandWorked(coll.createIndex({'a.b': 1})); + +let res = coll.explain("executionStats").find({'a.b': 2}).finish(); +assert.eq(1, res.executionStats.nReturned); +assert.eq("IndexScan", res.queryPlanner.winningPlan.optimizerPlan.child.leftChild.nodeType); + +res = coll.explain("executionStats").find({'a.b': {$gt: 2}}).finish(); +assert.eq(3 + extraDocCount, res.executionStats.nReturned); +assert.eq("PhysicalScan", res.queryPlanner.winningPlan.optimizerPlan.child.child.nodeType); + +res = coll.explain("executionStats").find({'a.b': {$gte: 2}}).finish(); +assert.eq(4 + extraDocCount, res.executionStats.nReturned); +assert.eq("PhysicalScan", res.queryPlanner.winningPlan.optimizerPlan.child.child.nodeType); + +res = coll.explain("executionStats").find({'a.b': {$lt: 2}}).finish(); +assert.eq(1, res.executionStats.nReturned); +assert.eq("IndexScan", res.queryPlanner.winningPlan.optimizerPlan.child.leftChild.nodeType); + +res = coll.explain("executionStats").find({'a.b': {$lte: 2}}).finish(); +assert.eq(2, res.executionStats.nReturned); +assert.eq("IndexScan", res.queryPlanner.winningPlan.optimizerPlan.child.leftChild.nodeType); +}()); diff --git a/jstests/cqf/basic_unwind.js b/jstests/cqf/basic_unwind.js new file mode 100644 index 00000000000..c0faa5c0e0d --- /dev/null +++ b/jstests/cqf/basic_unwind.js @@ -0,0 +1,25 @@ +(function() { +"use strict"; + +load("jstests/libs/optimizer_utils.js"); // For checkCascadesOptimizerEnabled. +if (!checkCascadesOptimizerEnabled(db)) { + jsTestLog("Skipping test because the optimizer is not enabled"); + return; +} + +const coll = db.cqf_basic_unwind; +coll.drop(); + +assert.commandWorked(coll.insert([ + {_id: 1}, + {_id: 2, x: null}, + {_id: 3, x: []}, + {_id: 4, x: [1, 2]}, + {_id: 5, x: [3]}, + {_id: 6, x: 4} +])); + +let res = coll.explain("executionStats").aggregate([{$unwind: '$x'}]); +assert.eq(4, res.executionStats.nReturned); +assert.eq("Unwind", res.queryPlanner.winningPlan.optimizerPlan.child.child.nodeType); +}()); diff --git a/jstests/cqf/chess.js b/jstests/cqf/chess.js new file mode 100644 index 00000000000..30124ea99b8 --- /dev/null +++ b/jstests/cqf/chess.js @@ -0,0 +1,107 @@ +(function() { +"use strict"; + +load("jstests/libs/optimizer_utils.js"); // For checkCascadesOptimizerEnabled. 
+if (!checkCascadesOptimizerEnabled(db)) { + jsTestLog("Skipping test because the optimizer is not enabled"); + return; +} + +const coll = db.cqf_chess; +Random.srand(0); + +const players = [ + "penguingim1", "aladdin65", "aleksey472", "azuaga", "benpig", + "blackboarder", "bockosrb555", "bogdan_low_player", "charlytb", "chchbbuur", + "chessexplained", "cmcookiemonster", "crptone", "cselhu3", "darkzam", + "dmitri31", "dorado99", "ericrosen", "fast-tsunami", "flaneur" +]; +const sources = [1, 2, 3, 4, 5, 6, 7, 8]; +const variants = [1, 2, 3, 4, 5, 6, 7, 8]; +const results = [1, 2, 3, 4, 5, 6, 7, 8]; +const winColor = [true, false, null]; + +const nbGames = 1000; // 1000 * 1000; + +function intRandom(max) { + return Random.randInt(max); +} +function anyOf(as) { + return as [intRandom(as.length)]; +} + +coll.drop(); + +print(`Adding ${nbGames} games`); +const bulk = coll.initializeUnorderedBulkOp(); +for (let i = 0; i < nbGames; i++) { + const users = [anyOf(players), anyOf(players)]; + const winnerIndex = intRandom(2); + bulk.insert({ + users: users, + winner: users[winnerIndex], + loser: users[1 - winnerIndex], + winColor: anyOf(winColor), + avgRating: NumberInt(600 + intRandom(2400)), + source: NumberInt(anyOf(sources)), + variants: NumberInt(anyOf(variants)), + mode: !!intRandom(2), + turns: NumberInt(1 + intRandom(300)), + minutes: NumberInt(30 + intRandom(3600 * 3)), + clock: {init: NumberInt(0 + intRandom(10800)), inc: NumberInt(0 + intRandom(180))}, + result: anyOf(results), + date: new Date(Date.now() - intRandom(118719488)), + analysed: !!intRandom(2) + }); + if (i % 1000 == 0) { + print(`${i} / ${nbGames}`); + } +} +assert.commandWorked(bulk.execute()); + +const indexes = [ + {users: 1}, + {winner: 1}, + {loser: 1}, + {winColor: 1}, + {avgRating: 1}, + {source: 1}, + {variants: 1}, + {mode: 1}, + {turns: 1}, + {minutes: 1}, + {'clock.init': 1}, + {'clock.inc': 1}, + {result: 1}, + {date: 1}, + {analysed: 1} +]; + +print("Adding indexes"); +indexes.forEach(index => { + printjson(index); + coll.createIndex(index); +}); + +print("Searching"); + +const res = coll.explain("executionStats").aggregate([ + { + $match: { + avgRating: {$gt: 1000}, + turns: {$lt: 250}, + 'clock.init': {$gt: 1}, + minutes: {$gt: 2, $lt: 150} + } + }, + {$sort: {date: -1}}, + {$limit: 20} +]); + +// TODO: verify expected results. + +// Verify we are using the index on "minutes". +const indexNode = res.queryPlanner.winningPlan.optimizerPlan.child.child.leftChild; +assert.eq("IndexScan", indexNode.nodeType); +assert.eq("minutes_1", indexNode.indexDefName); +}()); diff --git a/jstests/cqf/empty_results.js b/jstests/cqf/empty_results.js new file mode 100644 index 00000000000..5eed556189c --- /dev/null +++ b/jstests/cqf/empty_results.js @@ -0,0 +1,18 @@ +(function() { +"use strict"; + +load("jstests/libs/optimizer_utils.js"); // For checkCascadesOptimizerEnabled. 
+if (!checkCascadesOptimizerEnabled(db)) { + jsTestLog("Skipping test because the optimizer is not enabled"); + return; +} + +const t = db.cqf_empty_results; +t.drop(); + +assert.commandWorked(t.insert([{a: 1}, {a: 2}])); + +const res = t.explain("executionStats").aggregate([{$match: {'a': 2}}, {$limit: 1}, {$skip: 10}]); +assert.eq(0, res.executionStats.nReturned); +assert.eq("CoScan", res.queryPlanner.winningPlan.optimizerPlan.child.child.child.nodeType); +}()); diff --git a/jstests/cqf/filter_order.js b/jstests/cqf/filter_order.js new file mode 100644 index 00000000000..e33e45c661e --- /dev/null +++ b/jstests/cqf/filter_order.js @@ -0,0 +1,22 @@ +(function() { +"use strict"; + +load("jstests/libs/optimizer_utils.js"); // For checkCascadesOptimizerEnabled. +if (!checkCascadesOptimizerEnabled(db)) { + jsTestLog("Skipping test because the optimizer is not enabled"); + return; +} + +const coll = db.cqf_filter_order; +coll.drop(); + +const bulk = coll.initializeUnorderedBulkOp(); +for (let i = 0; i < 10000; i++) { + // "a" has the most ones, then "b", then "c". + bulk.insert({a: (i % 2), b: (i % 3), c: (i % 4)}); +} +assert.commandWorked(bulk.execute()); + +let res = coll.aggregate([{$match: {'a': {$eq: 1}, 'b': {$eq: 1}, 'c': {$eq: 1}}}]).toArray(); +// TODO: verify plan that predicate on "c" is applied first (most selective), then "b", then "a". +}()); diff --git a/jstests/cqf/find_sort.js b/jstests/cqf/find_sort.js new file mode 100644 index 00000000000..5ead920b48c --- /dev/null +++ b/jstests/cqf/find_sort.js @@ -0,0 +1,41 @@ +(function() { +"use strict"; + +load("jstests/libs/optimizer_utils.js"); // For checkCascadesOptimizerEnabled. +if (!checkCascadesOptimizerEnabled(db)) { + jsTestLog("Skipping test because the optimizer is not enabled"); + return; +} + +const coll = db.cqf_find_sort; +coll.drop(); + +const bulk = coll.initializeUnorderedBulkOp(); +const nDocs = 10000; +let numResults = 0; + +Random.srand(0); +for (let i = 0; i < nDocs; i++) { + const va = 100.0 * Random.rand(); + const vb = 100.0 * Random.rand(); + if (va < 5.0 && vb < 5.0) { + numResults++; + } + bulk.insert({a: va, b: vb}); +} +assert.gt(numResults, 0); + +assert.commandWorked(bulk.execute()); + +assert.commandWorked(coll.createIndex({a: 1, b: 1})); + +const res = coll.explain("executionStats") + .find({a: {$lt: 5}, b: {$lt: 5}}, {a: 1, b: 1}) + .sort({b: 1}) + .finish(); +assert.eq(numResults, res.executionStats.nReturned); + +const indexScanNode = res.queryPlanner.winningPlan.optimizerPlan.child.child.child.leftChild.child; +assert.eq("IndexScan", indexScanNode.nodeType); +assert.eq(5, indexScanNode.interval[0].highBound.bound.value); +}()); diff --git a/jstests/cqf/group.js b/jstests/cqf/group.js new file mode 100644 index 00000000000..4af1ad6b021 --- /dev/null +++ b/jstests/cqf/group.js @@ -0,0 +1,27 @@ +(function() { +"use strict"; + +load("jstests/libs/optimizer_utils.js"); // For checkCascadesOptimizerEnabled. 
+if (!checkCascadesOptimizerEnabled(db)) { + jsTestLog("Skipping test because the optimizer is not enabled"); + return; +} + +const coll = db.cqf_group; +coll.drop(); + +assert.commandWorked(coll.insert([ + {a: 1, b: 1, c: 1}, + {a: 1, b: 2, c: 2}, + {a: 1, b: 2, c: 3}, + {a: 2, b: 1, c: 4}, + {a: 2, b: 1, c: 5}, + {a: 2, b: 2, c: 6}, +])); + +const res = coll.explain("executionStats").aggregate([ + {$group: {_id: {a: '$a', b: '$b'}, sum: {$sum: '$c'}, avg: {$avg: '$c'}}} +]); +assert.eq("GroupBy", res.queryPlanner.winningPlan.optimizerPlan.child.child.nodeType); +assert.eq(4, res.executionStats.nReturned); +}()); diff --git a/jstests/cqf/index_intersect.js b/jstests/cqf/index_intersect.js new file mode 100644 index 00000000000..66ad1935996 --- /dev/null +++ b/jstests/cqf/index_intersect.js @@ -0,0 +1,47 @@ +(function() { +"use strict"; + +load("jstests/libs/optimizer_utils.js"); // For checkCascadesOptimizerEnabled. +if (!checkCascadesOptimizerEnabled(db)) { + jsTestLog("Skipping test because the optimizer is not enabled"); + return; +} + +const t = db.cqf_index_intersect; +t.drop(); + +const nMatches = 60; + +assert.commandWorked(t.insert({a: 1, b: 1, c: 1})); +assert.commandWorked(t.insert({a: 3, b: 2, c: 1})); +for (let i = 0; i < nMatches; i++) { + assert.commandWorked(t.insert({a: 3, b: 3, c: i})); +} +assert.commandWorked(t.insert({a: 4, b: 3, c: 2})); +assert.commandWorked(t.insert({a: 5, b: 5, c: 2})); + +for (let i = 1; i < nMatches + 100; i++) { + assert.commandWorked(t.insert({a: i + nMatches, b: i + nMatches, c: i + nMatches})); +} + +assert.commandWorked(t.createIndex({'a': 1})); +assert.commandWorked(t.createIndex({'b': 1})); + +let res = t.explain("executionStats").aggregate([{$match: {'a': 3, 'b': 3}}]); +assert.eq(nMatches, res.executionStats.nReturned); + +// Verify we can place a MergeJoin +let joinNode = res.queryPlanner.winningPlan.optimizerPlan.child.leftChild; +assert.eq("MergeJoin", joinNode.nodeType); +assert.eq("IndexScan", joinNode.leftChild.nodeType); +assert.eq("IndexScan", joinNode.rightChild.children[0].child.nodeType); + +// One side is not equality, and we use a HashJoin. +res = t.explain("executionStats").aggregate([{$match: {'a': {$lte: 3}, 'b': 3}}]); +assert.eq(nMatches, res.executionStats.nReturned); + +joinNode = res.queryPlanner.winningPlan.optimizerPlan.child.leftChild; +assert.eq("HashJoin", joinNode.nodeType); +assert.eq("IndexScan", joinNode.leftChild.nodeType); +assert.eq("IndexScan", joinNode.rightChild.children[0].child.nodeType); +}()); diff --git a/jstests/cqf/index_intersect1.js b/jstests/cqf/index_intersect1.js new file mode 100644 index 00000000000..fcf0036c974 --- /dev/null +++ b/jstests/cqf/index_intersect1.js @@ -0,0 +1,35 @@ +(function() { +"use strict"; + +load("jstests/libs/optimizer_utils.js"); // For checkCascadesOptimizerEnabled. +if (!checkCascadesOptimizerEnabled(db)) { + jsTestLog("Skipping test because the optimizer is not enabled"); + return; +} + +const t = db.cqf_index_intersect1; +t.drop(); + +assert.commandWorked(t.insert({a: 50})); +assert.commandWorked(t.insert({a: 70})); +assert.commandWorked(t.insert({a: 90})); +assert.commandWorked(t.insert({a: 110})); +assert.commandWorked(t.insert({a: 130})); + +// Generate enough documents for index to be preferable. 
+for (let i = 0; i < 100; i++) { + assert.commandWorked(t.insert({a: 200 + i})); +} + +assert.commandWorked(t.createIndex({'a': 1})); + +let res = t.explain("executionStats").aggregate([{$match: {'a': {$gt: 60, $lt: 100}}}]); +assert.eq(2, res.executionStats.nReturned); +assert.eq("IndexScan", res.queryPlanner.winningPlan.optimizerPlan.child.leftChild.nodeType); + +// Should get a covered plan. +res = t.explain("executionStats") + .aggregate([{$project: {'_id': 0, 'a': 1}}, {$match: {'a': {$gt: 60, $lt: 100}}}]); +assert.eq(2, res.executionStats.nReturned); +assert.eq("IndexScan", res.queryPlanner.winningPlan.optimizerPlan.child.child.nodeType); +}());
\ No newline at end of file diff --git a/jstests/cqf/no_collection.js b/jstests/cqf/no_collection.js new file mode 100644 index 00000000000..3c7ecae4c32 --- /dev/null +++ b/jstests/cqf/no_collection.js @@ -0,0 +1,15 @@ +(function() { +"use strict"; + +load("jstests/libs/optimizer_utils.js"); // For checkCascadesOptimizerEnabled. +if (!checkCascadesOptimizerEnabled(db)) { + jsTestLog("Skipping test because the optimizer is not enabled"); + return; +} + +let t = db.cqf_no_collection; +t.drop(); + +const res = t.explain("executionStats").aggregate([{$match: {'a': 2}}]); +assert.eq(0, res.executionStats.nReturned); +}());
\ No newline at end of file diff --git a/jstests/cqf/nonselective_index.js b/jstests/cqf/nonselective_index.js new file mode 100644 index 00000000000..f951ae7dc40 --- /dev/null +++ b/jstests/cqf/nonselective_index.js @@ -0,0 +1,30 @@ +/** + * Tests scenario related to SERVER-13065. + */ +(function() { +"use strict"; + +load("jstests/libs/optimizer_utils.js"); // For checkCascadesOptimizerEnabled. +if (!checkCascadesOptimizerEnabled(db)) { + jsTestLog("Skipping test because the optimizer is not enabled"); + return; +} + +const t = db.cqf_nonselective_index; +t.drop(); + +const bulk = t.initializeUnorderedBulkOp(); +const nDocs = 1000; +for (let i = 0; i < nDocs; i++) { + bulk.insert({a: i}); +} +assert.commandWorked(bulk.execute()); + +assert.commandWorked(t.createIndex({a: 1})); + +// We pick collection scan since the query is not selective. +const res = t.explain("executionStats").aggregate([{$match: {a: {$gte: 0}}}]); +assert.eq(nDocs, res.executionStats.nReturned); + +assert.eq("PhysicalScan", res.queryPlanner.winningPlan.optimizerPlan.child.child.nodeType); +}());
\ No newline at end of file diff --git a/jstests/cqf/object_elemMatch.js b/jstests/cqf/object_elemMatch.js new file mode 100644 index 00000000000..f402c590658 --- /dev/null +++ b/jstests/cqf/object_elemMatch.js @@ -0,0 +1,33 @@ +(function() { +"use strict"; + +load("jstests/libs/optimizer_utils.js"); // For checkCascadesOptimizerEnabled. +if (!checkCascadesOptimizerEnabled(db)) { + jsTestLog("Skipping test because the optimizer is not enabled"); + return; +} + +const t = db.cqf_object_elemMatch; + +t.drop(); +assert.commandWorked(t.insert({a: [{a: 1, b: 1}, {a: 1, b: 2}]})); +assert.commandWorked(t.insert({a: [{a: 2, b: 1}, {a: 2, b: 2}]})); +assert.commandWorked(t.insert({a: {a: 2, b: 1}})); +assert.commandWorked(t.insert({a: [{b: [1, 2], c: [3, 4]}]})); + +{ + // Object elemMatch. Currently we do not support index here. + const res = t.explain("executionStats").aggregate([{$match: {a: {$elemMatch: {a: 2, b: 1}}}}]); + assert.eq(1, res.executionStats.nReturned); + assert.eq("PhysicalScan", + res.queryPlanner.winningPlan.optimizerPlan.child.child.child.nodeType); +} + +{ + // Should not be getting any results. + const res = t.explain("executionStats").aggregate([ + {$match: {a: {$elemMatch: {b: {$elemMatch: {}}, c: {$elemMatch: {}}}}}} + ]); + assert.eq(0, res.executionStats.nReturned); +} +}()); diff --git a/jstests/cqf/partial_index.js b/jstests/cqf/partial_index.js new file mode 100644 index 00000000000..d8196c8cea8 --- /dev/null +++ b/jstests/cqf/partial_index.js @@ -0,0 +1,34 @@ +(function() { +"use strict"; + +load("jstests/libs/optimizer_utils.js"); // For checkCascadesOptimizerEnabled. +if (!checkCascadesOptimizerEnabled(db)) { + jsTestLog("Skipping test because the optimizer is not enabled"); + return; +} + +const t = db.cqf_partial_index; +t.drop(); + +assert.commandWorked(t.insert({a: 1, b: 1, c: 1})); +assert.commandWorked(t.insert({a: 3, b: 2, c: 1})); +assert.commandWorked(t.insert({a: 3, b: 3, c: 1})); +assert.commandWorked(t.insert({a: 3, b: 3, c: 2})); +assert.commandWorked(t.insert({a: 4, b: 3, c: 2})); +assert.commandWorked(t.insert({a: 5, b: 5, c: 2})); + +for (let i = 0; i < 40; i++) { + assert.commandWorked(t.insert({a: i + 10, b: i + 10, c: i + 10})); +} + +assert.commandWorked(t.createIndex({'a': 1}, {partialFilterExpression: {'b': 2}})); +// assert.commandWorked(t.createIndex({'a': 1})); + +// TODO: verify with explain the plan should use the index. +let res = t.aggregate([{$match: {'a': 3, 'b': 2}}]).toArray(); +assert.eq(1, res.length); + +// TODO: verify with explain the plan should not use the index. +res = t.aggregate([{$match: {'a': 3, 'b': 3}}]).toArray(); +assert.eq(2, res.length); +}());
\ No newline at end of file diff --git a/jstests/cqf/residual_pred_costing.js b/jstests/cqf/residual_pred_costing.js new file mode 100644 index 00000000000..07bc7211836 --- /dev/null +++ b/jstests/cqf/residual_pred_costing.js @@ -0,0 +1,35 @@ +/** + * Tests scenario related to SERVER-21697. + */ +(function() { +"use strict"; + +load("jstests/libs/optimizer_utils.js"); // For checkCascadesOptimizerEnabled. +if (!checkCascadesOptimizerEnabled(db)) { + jsTestLog("Skipping test because the optimizer is not enabled"); + return; +} + +const t = db.cqf_residual_pred_costing; +t.drop(); + +const bulk = t.initializeUnorderedBulkOp(); +const nDocs = 2000; +for (let i = 0; i < nDocs; i++) { + bulk.insert({a: i % 10, b: i % 10, c: i % 10, d: i % 10}); +} +assert.commandWorked(bulk.execute()); + +assert.commandWorked(t.createIndex({a: 1, b: 1, c: 1, d: 1})); +assert.commandWorked(t.createIndex({a: 1, b: 1, d: 1})); +assert.commandWorked(t.createIndex({a: 1, d: 1})); + +let res = t.explain("executionStats") + .aggregate([{$match: {a: {$eq: 0}, b: {$eq: 0}, c: {$eq: 0}}}, {$sort: {d: 1}}]); +assert.eq(nDocs * 0.1, res.executionStats.nReturned); + +// Demonstrate we can pick the indexing covering most fields. +const indexNode = res.queryPlanner.winningPlan.optimizerPlan.child.leftChild; +assert.eq("IndexScan", indexNode.nodeType); +assert.eq("a_1_b_1_c_1_d_1", indexNode.indexDefName); +}()); diff --git a/jstests/cqf/sampling.js b/jstests/cqf/sampling.js new file mode 100644 index 00000000000..37dd0ae0e44 --- /dev/null +++ b/jstests/cqf/sampling.js @@ -0,0 +1,31 @@ +(function() { +"use strict"; + +load("jstests/libs/optimizer_utils.js"); // For checkCascadesOptimizerEnabled. +if (!checkCascadesOptimizerEnabled(db)) { + jsTestLog("Skipping test because the optimizer is not enabled"); + return; +} + +const coll = db.cqf_sampling; +coll.drop(); + +const bulk = coll.initializeUnorderedBulkOp(); +const nDocs = 10000; + +Random.srand(0); +for (let i = 0; i < nDocs; i++) { + const valA = 10.0 * Random.rand(); + const valB = 10.0 * Random.rand(); + bulk.insert({a: valA, b: valB}); +} +assert.commandWorked(bulk.execute()); + +const res = coll.explain().aggregate([{$match: {'a': {$lt: 2}}}]); +assert(res.queryPlanner.winningPlan.optimizerPlan.hasOwnProperty("properties")); +const props = res.queryPlanner.winningPlan.optimizerPlan.properties; + +// Verify the winning plan cardinality is within roughly 25% of the expected documents. +assert.lt(nDocs * 0.2 * 0.75, props.adjustedCE); +assert.gt(nDocs * 0.2 * 1.25, props.adjustedCE); +}()); diff --git a/jstests/cqf/selective_index.js b/jstests/cqf/selective_index.js new file mode 100644 index 00000000000..722f04e75c7 --- /dev/null +++ b/jstests/cqf/selective_index.js @@ -0,0 +1,34 @@ +/** + * Tests scenario related to SERVER-20616. + */ +(function() { +"use strict"; + +load("jstests/libs/optimizer_utils.js"); // For checkCascadesOptimizerEnabled. +if (!checkCascadesOptimizerEnabled(db)) { + jsTestLog("Skipping test because the optimizer is not enabled"); + return; +} + +const t = db.cqf_selective_index; +t.drop(); + +const bulk = t.initializeUnorderedBulkOp(); +const nDocs = 1000; +for (let i = 0; i < nDocs; i++) { + bulk.insert({a: i % 10, b: i}); +} +assert.commandWorked(bulk.execute()); + +assert.commandWorked(t.createIndex({a: 1})); +assert.commandWorked(t.createIndex({b: 1})); + +// Predicate on "b" is more selective than the one on "a": 0.1% vs 10%. 
+const res = t.explain("executionStats").aggregate([{$match: {a: {$eq: 0}, b: {$eq: 0}}}]); +assert.eq(1, res.executionStats.nReturned); + +// Demonstrate we can pick index on "b". +const indexNode = res.queryPlanner.winningPlan.optimizerPlan.child.leftChild; +assert.eq("IndexScan", indexNode.nodeType); +assert.eq("b_1", indexNode.indexDefName); +}());
\ No newline at end of file diff --git a/jstests/cqf/sort.js b/jstests/cqf/sort.js new file mode 100644 index 00000000000..1a9f4582262 --- /dev/null +++ b/jstests/cqf/sort.js @@ -0,0 +1,22 @@ +(function() { +"use strict"; + +load("jstests/libs/optimizer_utils.js"); // For checkCascadesOptimizerEnabled. +if (!checkCascadesOptimizerEnabled(db)) { + jsTestLog("Skipping test because the optimizer is not enabled"); + return; +} + +const t = db.cqf_basic_unwind1; +t.drop(); + +assert.commandWorked(t.insert({_id: 1})); +assert.commandWorked(t.insert({_id: 2, x: null})); +assert.commandWorked(t.insert({_id: 3, x: []})); +assert.commandWorked(t.insert({_id: 4, x: [1, 2]})); +assert.commandWorked(t.insert({_id: 5, x: [10]})); +assert.commandWorked(t.insert({_id: 6, x: 4})); + +const res = t.aggregate([{$unwind: '$x'}, {$sort: {'x': 1}}]).toArray(); +assert.eq(4, res.length); +}());
\ No newline at end of file diff --git a/jstests/cqf/sort_match.js b/jstests/cqf/sort_match.js new file mode 100644 index 00000000000..54a22a64071 --- /dev/null +++ b/jstests/cqf/sort_match.js @@ -0,0 +1,33 @@ +/** + * Tests scenario related to SERVER-12923. + */ +(function() { +"use strict"; + +load("jstests/libs/optimizer_utils.js"); // For checkCascadesOptimizerEnabled. +if (!checkCascadesOptimizerEnabled(db)) { + jsTestLog("Skipping test because the optimizer is not enabled"); + return; +} + +const t = db.cqf_sort_match; +t.drop(); + +const bulk = t.initializeUnorderedBulkOp(); +const nDocs = 1000; +for (let i = 0; i < nDocs; i++) { + bulk.insert({a: i, b: i % 10}); +} +assert.commandWorked(bulk.execute()); + +assert.commandWorked(t.createIndex({a: 1})); +assert.commandWorked(t.createIndex({b: 1})); + +let res = t.explain("executionStats").aggregate([{$sort: {b: 1}}, {$match: {a: {$eq: 0}}}]); +assert.eq(1, res.executionStats.nReturned); + +// Index on "a" is preferred. +const indexNode = res.queryPlanner.winningPlan.optimizerPlan.child.child.leftChild; +assert.eq("IndexScan", indexNode.nodeType); +assert.eq("a_1", indexNode.indexDefName); +}());
\ No newline at end of file diff --git a/jstests/cqf/sort_project.js b/jstests/cqf/sort_project.js new file mode 100644 index 00000000000..49beb912191 --- /dev/null +++ b/jstests/cqf/sort_project.js @@ -0,0 +1,74 @@ +(function() { +"use strict"; + +load("jstests/libs/optimizer_utils.js"); // For checkCascadesOptimizerEnabled. +if (!checkCascadesOptimizerEnabled(db)) { + jsTestLog("Skipping test because the optimizer is not enabled"); + return; +} + +var coll = db.cqf_testCovIndxScan; + +coll.drop(); + +coll.createIndex({f_0: 1, f_1: 1, f_2: 1, f_3: 1, f_4: 1}); +coll.getIndexes(); + +coll.insertMany([ + {f_0: 2, f_1: 8, f_2: 2, f_3: 0, f_4: 2}, {f_0: 7, f_1: 9, f_2: 8, f_3: 3, f_4: 3}, + {f_0: 6, f_1: 6, f_2: 2, f_3: 8, f_4: 3}, {f_0: 9, f_1: 2, f_2: 3, f_3: 5, f_4: 7}, + {f_0: 7, f_1: 8, f_2: 8, f_3: 2, f_4: 9}, {f_0: 7, f_1: 1, f_2: 7, f_3: 3, f_4: 1}, + {f_0: 7, f_1: 3, f_2: 4, f_3: 0, f_4: 7}, {f_0: 8, f_1: 4, f_2: 5, f_3: 6, f_4: 0}, + {f_0: 5, f_1: 2, f_2: 0, f_3: 7, f_4: 0}, {f_0: 0, f_1: 2, f_2: 1, f_3: 9, f_4: 2}, + {f_0: 6, f_1: 0, f_2: 5, f_3: 9, f_4: 1}, {f_0: 0, f_1: 1, f_2: 6, f_3: 8, f_4: 6}, + {f_0: 6, f_1: 5, f_2: 3, f_3: 8, f_4: 5}, {f_0: 2, f_1: 9, f_2: 7, f_3: 2, f_4: 3}, + {f_0: 0, f_1: 6, f_2: 9, f_3: 6, f_4: 8}, {f_0: 5, f_1: 7, f_2: 8, f_3: 1, f_4: 4}, + {f_0: 8, f_1: 5, f_2: 1, f_3: 4, f_4: 6}, {f_0: 6, f_1: 2, f_2: 8, f_3: 4, f_4: 3}, + {f_0: 1, f_1: 6, f_2: 2, f_3: 0, f_4: 3}, {f_0: 1, f_1: 8, f_2: 2, f_3: 5, f_4: 2} +]); + +const nDocs = 20; + +var pln0 = [{'$project': {_id: 0, f_0: 1, f_1: 1, f_2: 1, f_3: 1, f_4: 1}}]; + +var pln1 = [{'$sort': {f_0: 1, f_1: 1, f_2: 1, f_3: 1, f_4: 1}}]; + +var pln2 = [ + {'$project': {_id: 0, f_0: 1, f_1: 1, f_2: 1, f_3: 1, f_4: 1}}, + {'$sort': {f_0: 1, f_1: 1, f_2: 1, f_3: 1, f_4: 1}} +]; + +var pln3 = [ + {'$sort': {f_0: 1, f_1: 1, f_2: 1, f_3: 1, f_4: 1}}, + {'$project': {_id: 0, f_0: 1, f_1: 1, f_2: 1, f_3: 1, f_4: 1}} +]; + +{ + // Covered plan. Still chooses collection scan because there is no field size/count statistics. + // Also an index scan on all fields is not cheaper than a collection scan. + let res = coll.explain("executionStats").aggregate(pln0); + assert.eq(nDocs, res.executionStats.nReturned); + assert.eq("IndexScan", res.queryPlanner.winningPlan.optimizerPlan.child.child.nodeType); +} + +{ + // Covered plan. + let res = coll.explain("executionStats").aggregate(pln1); + assert.eq(nDocs, res.executionStats.nReturned); + assert.eq("IndexScan", res.queryPlanner.winningPlan.optimizerPlan.child.leftChild.nodeType); +} + +{ + // Covered plan. + let res = coll.explain("executionStats").aggregate(pln2); + assert.eq(nDocs, res.executionStats.nReturned); + assert.eq("IndexScan", res.queryPlanner.winningPlan.optimizerPlan.child.child.nodeType); +} + +{ + // Covered plan. + let res = coll.explain("executionStats").aggregate(pln3); + assert.eq(nDocs, res.executionStats.nReturned); + assert.eq("IndexScan", res.queryPlanner.winningPlan.optimizerPlan.child.child.nodeType); +} +}()); diff --git a/jstests/cqf/type_bracket.js b/jstests/cqf/type_bracket.js new file mode 100644 index 00000000000..1cacba0df2e --- /dev/null +++ b/jstests/cqf/type_bracket.js @@ -0,0 +1,62 @@ +(function() { +"use strict"; + +load("jstests/libs/optimizer_utils.js"); // For checkCascadesOptimizerEnabled. +if (!checkCascadesOptimizerEnabled(db)) { + jsTestLog("Skipping test because the optimizer is not enabled"); + return; +} + +const t = db.cqf_type_bracket; +t.drop(); + +// Generate enough documents for index to be preferable if it exists. 
+for (let i = 0; i < 100; i++) { + assert.commandWorked(t.insert({a: i})); + assert.commandWorked(t.insert({a: i.toString()})); +} + +{ + const res = t.explain("executionStats").aggregate([{$match: {a: {$lt: "2"}}}]); + assert.eq(12, res.executionStats.nReturned); + assert.eq("PhysicalScan", res.queryPlanner.winningPlan.optimizerPlan.child.child.nodeType); +} +{ + const res = t.explain("executionStats").aggregate([{$match: {a: {$gt: "95"}}}]); + assert.eq(4, res.executionStats.nReturned); + assert.eq("PhysicalScan", res.queryPlanner.winningPlan.optimizerPlan.child.child.nodeType); +} +{ + const res = t.explain("executionStats").aggregate([{$match: {a: {$lt: 2}}}]); + assert.eq(2, res.executionStats.nReturned); + assert.eq("PhysicalScan", res.queryPlanner.winningPlan.optimizerPlan.child.child.nodeType); +} +{ + const res = t.explain("executionStats").aggregate([{$match: {a: {$gt: 95}}}]); + assert.eq(4, res.executionStats.nReturned); + assert.eq("PhysicalScan", res.queryPlanner.winningPlan.optimizerPlan.child.child.nodeType); +} + +assert.commandWorked(t.createIndex({a: 1})); + +{ + const res = t.explain("executionStats").aggregate([{$match: {a: {$lt: "2"}}}]); + assert.eq(12, res.executionStats.nReturned); + assert.eq("IndexScan", res.queryPlanner.winningPlan.optimizerPlan.child.leftChild.nodeType); +} +{ + const res = t.explain("executionStats").aggregate([{$match: {a: {$gt: "95"}}}]); + assert.eq(4, res.executionStats.nReturned); + assert.eq("IndexScan", res.queryPlanner.winningPlan.optimizerPlan.child.leftChild.nodeType); +} +{ + const res = t.explain("executionStats").aggregate([{$match: {a: {$lt: 2}}}]); + assert.eq(2, res.executionStats.nReturned); + assert.eq("IndexScan", res.queryPlanner.winningPlan.optimizerPlan.child.leftChild.nodeType); +} +{ + const res = t.explain("executionStats").aggregate([{$match: {a: {$gt: 95}}}]); + assert.eq(4, res.executionStats.nReturned); + assert.eq("IndexScan", res.queryPlanner.winningPlan.optimizerPlan.child.leftChild.nodeType); +} +}());
\ No newline at end of file diff --git a/jstests/cqf/type_predicate.js b/jstests/cqf/type_predicate.js new file mode 100644 index 00000000000..eb8de44b3f6 --- /dev/null +++ b/jstests/cqf/type_predicate.js @@ -0,0 +1,26 @@ +(function() { +"use strict"; + +load("jstests/libs/optimizer_utils.js"); // For checkCascadesOptimizerEnabled. +if (!checkCascadesOptimizerEnabled(db)) { + jsTestLog("Skipping test because the optimizer is not enabled"); + return; +} + +const t = db.cqf_type_predicate; +t.drop(); + +for (let i = 0; i < 10; i++) { + assert.commandWorked(t.insert({a: i})); + assert.commandWorked(t.insert({a: i.toString()})); +} + +{ + const res = t.explain("executionStats").aggregate([{$match: {a: {$type: "string"}}}]); + assert.eq(10, res.executionStats.nReturned); +} +{ + const res = t.explain("executionStats").aggregate([{$match: {a: {$type: "double"}}}]); + assert.eq(10, res.executionStats.nReturned); +} +}());
\ No newline at end of file diff --git a/jstests/cqf/unionWith.js b/jstests/cqf/unionWith.js new file mode 100644 index 00000000000..63dedc9d750 --- /dev/null +++ b/jstests/cqf/unionWith.js @@ -0,0 +1,54 @@ +(function() { +"use strict"; + +load("jstests/libs/optimizer_utils.js"); // For checkCascadesOptimizerEnabled. +if (!checkCascadesOptimizerEnabled(db)) { + jsTestLog("Skipping test because the optimizer is not enabled"); + return; +} + +load("jstests/aggregation/extras/utils.js"); + +const collA = db.collA; +collA.drop(); + +const collB = db.collB; +collB.drop(); + +assert.commandWorked(collA.insert({_id: 0, a: 1})); +assert.commandWorked(collB.insert({_id: 0, a: 2})); + +let res = collA.aggregate([{$unionWith: "collB"}]).toArray(); +assert.eq(2, res.length); +assert.eq([{_id: 0, a: 1}, {_id: 0, a: 2}], res); + +// Test a filter after the union which can be pushed down to each branch. +res = collA.aggregate([{$unionWith: "collB"}, {$match: {a: {$lt: 2}}}]).toArray(); +assert.eq(1, res.length); +assert.eq([{_id: 0, a: 1}], res); + +// Test a non-simple inner pipeline. +res = collA.aggregate([{$unionWith: {coll: "collB", pipeline: [{$match: {a: 2}}]}}]).toArray(); +assert.eq(2, res.length); +assert.eq([{_id: 0, a: 1}, {_id: 0, a: 2}], res); + +// Test a union with non-existent collection. +res = collA.aggregate([{$unionWith: "non_existent"}]).toArray(); +assert.eq(1, res.length); +assert.eq([{_id: 0, a: 1}], res); + +// Test union alongside projections. This is meant to test the pipeline translation logic that adds +// a projection to the inner pipeline when necessary. +res = collA.aggregate([{$project: {_id: 0, a: 1}}, {$unionWith: "collB"}]).toArray(); +assert.eq(2, res.length); +assert.eq([{a: 1}, {_id: 0, a: 2}], res); + +res = collA.aggregate([{$unionWith: {coll: "collB", pipeline: [{$project: {_id: 0, a: 1}}]}}]) + .toArray(); +assert.eq(2, res.length); +assert.eq([{_id: 0, a: 1}, {a: 2}], res); + +res = collA.aggregate([{$unionWith: "collB"}, {$project: {_id: 0, a: 1}}]).toArray(); +assert.eq(2, res.length); +assert.eq([{a: 1}, {a: 2}], res); +}()); diff --git a/jstests/cqf/value_elemMatch.js b/jstests/cqf/value_elemMatch.js new file mode 100644 index 00000000000..4bb46e6f1a7 --- /dev/null +++ b/jstests/cqf/value_elemMatch.js @@ -0,0 +1,52 @@ +(function() { +"use strict"; + +load("jstests/libs/optimizer_utils.js"); // For checkCascadesOptimizerEnabled. +if (!checkCascadesOptimizerEnabled(db)) { + jsTestLog("Skipping test because the optimizer is not enabled"); + return; +} + +const t = db.cqf_value_elemMatch; +t.drop(); + +assert.commandWorked(t.insert({a: [1, 2, 3, 4, 5, 6]})); +assert.commandWorked(t.insert({a: [5, 6, 7, 8, 9]})); +assert.commandWorked(t.insert({a: [1, 2, 3]})); +assert.commandWorked(t.insert({a: []})); +assert.commandWorked(t.insert({a: [1]})); +assert.commandWorked(t.insert({a: [10]})); +assert.commandWorked(t.insert({a: 5})); +assert.commandWorked(t.insert({a: 6})); + +// Generate enough documents for index to be preferable. +const nDocs = 400; +for (let i = 0; i < nDocs; i++) { + assert.commandWorked(t.insert({a: i + 10})); +} + +assert.commandWorked(t.createIndex({a: 1})); + +{ + // Value elemMatch. Demonstrate we can use an index. 
+ const res = + t.explain("executionStats").aggregate([{$match: {a: {$elemMatch: {$gte: 5, $lte: 6}}}}]); + assert.eq(2, res.executionStats.nReturned); + assert.eq("IndexScan", + res.queryPlanner.winningPlan.optimizerPlan.child.child.leftChild.child.nodeType); +} +{ + const res = + t.explain("executionStats").aggregate([{$match: {a: {$elemMatch: {$lt: 11, $gt: 9}}}}]); + assert.eq(1, res.executionStats.nReturned); + assert.eq("IndexScan", + res.queryPlanner.winningPlan.optimizerPlan.child.child.leftChild.child.nodeType); +} +{ + // Contradiction. + const res = + t.explain("executionStats").aggregate([{$match: {a: {$elemMatch: {$lt: 5, $gt: 6}}}}]); + assert.eq(0, res.executionStats.nReturned); + assert.eq("CoScan", res.queryPlanner.winningPlan.optimizerPlan.child.child.child.nodeType); +} +}()); diff --git a/jstests/cqf_parallel/basic_exchange.js b/jstests/cqf_parallel/basic_exchange.js new file mode 100644 index 00000000000..3be8768b0de --- /dev/null +++ b/jstests/cqf_parallel/basic_exchange.js @@ -0,0 +1,22 @@ +(function() { +"use strict"; + +load("jstests/libs/optimizer_utils.js"); // For checkCascadesOptimizerEnabled. +if (!checkCascadesOptimizerEnabled(db)) { + jsTestLog("Skipping test because the optimizer is not enabled"); + return; +} + +const t = db.cqf_exchange; +t.drop(); + +assert.commandWorked(t.insert({a: {b: 1}})); +assert.commandWorked(t.insert({a: {b: 2}})); +assert.commandWorked(t.insert({a: {b: 3}})); +assert.commandWorked(t.insert({a: {b: 4}})); +assert.commandWorked(t.insert({a: {b: 5}})); + +const res = t.explain("executionStats").aggregate([{$match: {'a.b': 2}}]); +assert.eq(1, res.executionStats.nReturned); +assert.eq("Exchange", res.queryPlanner.winningPlan.optimizerPlan.child.nodeType); +}()); diff --git a/jstests/cqf_parallel/groupby.js b/jstests/cqf_parallel/groupby.js new file mode 100644 index 00000000000..91fa3fc80fa --- /dev/null +++ b/jstests/cqf_parallel/groupby.js @@ -0,0 +1,37 @@ +(function() { +"use strict"; + +load("jstests/libs/optimizer_utils.js"); // For checkCascadesOptimizerEnabled. +if (!checkCascadesOptimizerEnabled(db)) { + jsTestLog("Skipping test because the optimizer is not enabled"); + return; +} + +const t = db.cqf_exchange; +t.drop(); + +assert.commandWorked(t.insert({a: 1})); +assert.commandWorked(t.insert({a: 2})); +assert.commandWorked(t.insert({a: 3})); +assert.commandWorked(t.insert({a: 4})); +assert.commandWorked(t.insert({a: 5})); + +// Demonstrate local-global optimization. +const res = t.explain("executionStats").aggregate([{$group: {_id: "$a", cnt: {$sum: 1}}}]); +assert.eq(5, res.executionStats.nReturned); + +assert.eq("Exchange", res.queryPlanner.winningPlan.optimizerPlan.child.nodeType); +assert.eq("Centralized", res.queryPlanner.winningPlan.optimizerPlan.child.distribution.type); + +assert.eq("GroupBy", res.queryPlanner.winningPlan.optimizerPlan.child.child.child.nodeType); + +assert.eq("Exchange", res.queryPlanner.winningPlan.optimizerPlan.child.child.child.child.nodeType); +assert.eq("HashPartitioning", + res.queryPlanner.winningPlan.optimizerPlan.child.child.child.child.distribution.type); + +assert.eq("GroupBy", + res.queryPlanner.winningPlan.optimizerPlan.child.child.child.child.child.nodeType); +assert.eq("UnknownPartitioning", + res.queryPlanner.winningPlan.optimizerPlan.child.child.child.child.child.properties + .physicalProperties.distribution.type); +}());
\ No newline at end of file diff --git a/jstests/cqf_parallel/index.js b/jstests/cqf_parallel/index.js new file mode 100644 index 00000000000..c155e9b9f61 --- /dev/null +++ b/jstests/cqf_parallel/index.js @@ -0,0 +1,25 @@ +(function() { +"use strict"; + +load("jstests/libs/optimizer_utils.js"); // For checkCascadesOptimizerEnabled. +if (!checkCascadesOptimizerEnabled(db)) { + jsTestLog("Skipping test because the optimizer is not enabled"); + return; +} + +const coll = db.cqf_parallel_index; +coll.drop(); + +const bulk = coll.initializeUnorderedBulkOp(); +for (let i = 0; i < 1000; i++) { + bulk.insert({a: i}); +} +assert.commandWorked(bulk.execute()); + +assert.commandWorked(coll.createIndex({a: 1})); + +let res = coll.explain("executionStats").aggregate([{$match: {a: {$lt: 10}}}]); +assert.eq(10, res.executionStats.nReturned); +assert.eq("IndexScan", + res.queryPlanner.winningPlan.optimizerPlan.child.child.leftChild.child.nodeType); +}()); diff --git a/jstests/libs/optimizer_utils.js b/jstests/libs/optimizer_utils.js new file mode 100644 index 00000000000..ff2a179388a --- /dev/null +++ b/jstests/libs/optimizer_utils.js @@ -0,0 +1,8 @@ +/* + * Utility for checking if the query optimizer is enabled. + */ +function checkCascadesOptimizerEnabled(theDB) { + const param = theDB.adminCommand({getParameter: 1, featureFlagCommonQueryFramework: 1}); + return param.hasOwnProperty("featureFlagCommonQueryFramework") && + param.featureFlagCommonQueryFramework.value; +} diff --git a/src/mongo/db/commands/SConscript b/src/mongo/db/commands/SConscript index a5c8c6ce453..7551a05e8a6 100644 --- a/src/mongo/db/commands/SConscript +++ b/src/mongo/db/commands/SConscript @@ -317,6 +317,7 @@ env.Library( target="standalone", source=[ "count_cmd.cpp", + "cqf/cqf_aggregate.cpp", "create_command.cpp", "create_indexes.cpp", "current_op.cpp", @@ -364,13 +365,16 @@ env.Library( '$BUILD_DIR/mongo/db/concurrency/lock_manager', '$BUILD_DIR/mongo/db/concurrency/write_conflict_exception', '$BUILD_DIR/mongo/db/curop_failpoint_helpers', + '$BUILD_DIR/mongo/db/exec/sbe/query_sbe_abt', '$BUILD_DIR/mongo/db/index_builds_coordinator_interface', '$BUILD_DIR/mongo/db/index_commands_idl', '$BUILD_DIR/mongo/db/ops/write_ops_exec', '$BUILD_DIR/mongo/db/pipeline/aggregation_request_helper', '$BUILD_DIR/mongo/db/pipeline/process_interface/mongo_process_interface', + '$BUILD_DIR/mongo/db/query/ce/query_ce', '$BUILD_DIR/mongo/db/query/command_request_response', '$BUILD_DIR/mongo/db/query/cursor_response_idl', + '$BUILD_DIR/mongo/db/query/optimizer/optimizer', '$BUILD_DIR/mongo/db/query_exec', '$BUILD_DIR/mongo/db/repl/replica_set_messages', '$BUILD_DIR/mongo/db/repl/tenant_migration_access_blocker', @@ -766,4 +770,4 @@ env.CppUnitTest( "servers", "standalone", ], -) +)
\ No newline at end of file diff --git a/src/mongo/db/commands/cqf/cqf_aggregate.cpp b/src/mongo/db/commands/cqf/cqf_aggregate.cpp new file mode 100644 index 00000000000..f092233de11 --- /dev/null +++ b/src/mongo/db/commands/cqf/cqf_aggregate.cpp @@ -0,0 +1,431 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include "mongo/db/commands/cqf/cqf_aggregate.h" + +#include "mongo/db/exec/sbe/abt/abt_lower.h" +#include "mongo/db/pipeline/abt/abt_document_source_visitor.h" +#include "mongo/db/pipeline/abt/match_expression_visitor.h" +#include "mongo/db/query/ce/ce_sampling.h" +#include "mongo/db/query/optimizer/cascades/ce_heuristic.h" +#include "mongo/db/query/optimizer/cascades/cost_derivation.h" +#include "mongo/db/query/optimizer/explain.h" +#include "mongo/db/query/optimizer/node.h" +#include "mongo/db/query/optimizer/opt_phase_manager.h" +#include "mongo/db/query/plan_executor_factory.h" +#include "mongo/db/query/query_knobs_gen.h" +#include "mongo/db/query/query_planner_params.h" +#include "mongo/db/query/sbe_stage_builder.h" +#include "mongo/db/query/yield_policy_callbacks_impl.h" + +namespace mongo { + +using namespace optimizer; + +static opt::unordered_map<std::string, optimizer::IndexDefinition> buildIndexSpecsOptimizer( + boost::intrusive_ptr<ExpressionContext> expCtx, + OperationContext* opCtx, + const CollectionPtr& collection, + const optimizer::ProjectionName& scanProjName, + const DisableIndexOptions disableIndexOptions) { + using namespace optimizer; + + if (disableIndexOptions == DisableIndexOptions::DisableAll) { + return {}; + } + + const IndexCatalog& indexCatalog = *collection->getIndexCatalog(); + opt::unordered_map<std::string, IndexDefinition> result; + auto indexIterator = indexCatalog.getIndexIterator(opCtx, false /*includeUnfinished*/); + + while (indexIterator->more()) { + const IndexCatalogEntry& catalogEntry = *indexIterator->next(); + + const bool isMultiKey = catalogEntry.isMultikey(opCtx, collection); + const MultikeyPaths& multiKeyPaths = catalogEntry.getMultikeyPaths(opCtx, collection); + uassert(6624251, "Multikey paths cannot be empty.", 
!multiKeyPaths.empty()); + + const IndexDescriptor& descriptor = *catalogEntry.descriptor(); + if (descriptor.hidden() || descriptor.isSparse() || + descriptor.getIndexType() != IndexType::INDEX_BTREE) { + // Not supported for now. + continue; + } + + // SBE version is base 0. + const int64_t version = static_cast<int>(descriptor.version()) - 1; + + uint32_t orderingBits = 0; + { + const Ordering ordering = catalogEntry.ordering(); + for (int i = 0; i < descriptor.getNumFields(); i++) { + if ((ordering.get(i) == 1)) { + orderingBits |= (1ull << i); + } + } + } + + IndexCollationSpec indexCollationSpec; + bool useIndex = true; + size_t elementIdx = 0; + for (const auto& element : descriptor.keyPattern()) { + FieldPathType fieldPath; + FieldPath path(element.fieldName()); + + for (size_t i = 0; i < path.getPathLength(); i++) { + const std::string& fieldName = path.getFieldName(i).toString(); + if (fieldName == "$**") { + // TODO: For now disallow wildcard indexes. + useIndex = false; + break; + } + fieldPath.emplace_back(fieldName); + } + if (!useIndex) { + break; + } + + const int direction = element.numberInt(); + if (direction != -1 && direction != 1) { + // Invalid value? + useIndex = false; + break; + } + + const CollationOp collationOp = + (direction == 1) ? CollationOp::Ascending : CollationOp::Descending; + + // Construct an ABT path for each index component (field path). + const MultikeyComponents& elementMultiKeyInfo = multiKeyPaths[elementIdx]; + ABT abtPath = make<PathIdentity>(); + for (size_t i = fieldPath.size(); i-- > 0;) { + if (isMultiKey && elementMultiKeyInfo.find(i) != elementMultiKeyInfo.cend()) { + // This is a multikey element of the path. + abtPath = make<PathTraverse>(std::move(abtPath)); + } + abtPath = make<PathGet>(fieldPath.at(i), std::move(abtPath)); + } + indexCollationSpec.emplace_back(std::move(abtPath), collationOp); + ++elementIdx; + } + if (!useIndex) { + continue; + } + + PartialSchemaRequirements partialIndexReqMap; + if (descriptor.isPartial() && + disableIndexOptions != DisableIndexOptions::DisablePartialOnly) { + auto expr = MatchExpressionParser::parseAndNormalize( + descriptor.partialFilterExpression(), + expCtx, + ExtensionsCallbackNoop(), + MatchExpressionParser::kBanAllSpecialFeatures); + + ABT exprABT = generateMatchExpression(expr.get(), false /*allowAggExpression*/, "", ""); + exprABT = make<EvalFilter>(std::move(exprABT), make<Variable>(scanProjName)); + + // TODO: simplify expression. + + PartialSchemaReqConversion conversion = convertExprToPartialSchemaReq(exprABT); + if (!conversion._success || conversion._hasEmptyInterval) { + // Unsatisfiable partial index filter? + continue; + } + partialIndexReqMap = std::move(conversion._reqMap); + } + + // For now we assume distribution is Centralized. + result.emplace(descriptor.indexName(), + IndexDefinition(std::move(indexCollationSpec), + version, + orderingBits, + isMultiKey, + DistributionType::Centralized, + std::move(partialIndexReqMap))); + } + + return result; +} + +static QueryHints getHintsFromQueryKnobs() { + QueryHints hints; + + hints._disableScan = internalCascadesOptimizerDisableScan.load(); + hints._disableIndexes = internalCascadesOptimizerDisableIndexes.load() + ? 
DisableIndexOptions::DisableAll + : DisableIndexOptions::Enabled; + hints._disableHashJoinRIDIntersect = + internalCascadesOptimizerDisableHashJoinRIDIntersect.load(); + hints._disableMergeJoinRIDIntersect = + internalCascadesOptimizerDisableMergeJoinRIDIntersect.load(); + hints._disableGroupByAndUnionRIDIntersect = + internalCascadesOptimizerDisableGroupByAndUnionRIDIntersect.load(); + hints._keepRejectedPlans = internalCascadesOptimizerKeepRejectedPlans.load(); + hints._disableBranchAndBound = internalCascadesOptimizerDisableBranchAndBound.load(); + + return hints; +} + +static std::unique_ptr<PlanExecutor, PlanExecutor::Deleter> optimizeAndCreateExecutor( + OptPhaseManager& phaseManager, + ABT abtTree, + OperationContext* opCtx, + boost::intrusive_ptr<ExpressionContext> expCtx, + const NamespaceString& nss, + const CollectionPtr& collection) { + + const bool optimizationResult = phaseManager.optimize(abtTree); + uassert(6624252, "Optimization failed", optimizationResult); + + // TODO: SERVER-62648. std::cerr is used for debugging. Consider structured logging. + std::cerr << "********* Optimizer Stats *********\n"; + { + const auto& memo = phaseManager.getMemo(); + std::cerr << "Memo groups: " << memo.getGroupCount() << "\n"; + std::cerr << "Memo logical nodes: " << memo.getLogicalNodeCount() << "\n"; + std::cerr << "Memo phys. nodes: " << memo.getPhysicalNodeCount() << "\n"; + + const auto& memoStats = memo.getStats(); + std::cerr << "Memo integrations: " << memoStats._numIntegrations << "\n"; + std::cerr << "Phys. plans explored: " << memoStats._physPlanExplorationCount << "\n"; + std::cerr << "Phys. memo checks: " << memoStats._physMemoCheckCount << "\n"; + } + std::cerr << "********* Optimizer Stats *********\n"; + + std::cerr << "********* Optimized ABT *********\n"; + std::cerr << ExplainGenerator::explainV2( + make<MemoPhysicalDelegatorNode>(phaseManager.getPhysicalNodeId()), + true /*displayPhysicalProperties*/, + &phaseManager.getMemo()); + std::cerr << "********* Optimized ABT *********\n"; + + auto env = VariableEnvironment::build(abtTree); + SlotVarMap slotMap; + sbe::value::SlotIdGenerator ids; + SBENodeLowering g{env, + slotMap, + ids, + phaseManager.getMetadata(), + phaseManager.getNodeToGroupPropsMap(), + phaseManager.getRIDProjections()}; + auto sbePlan = g.optimize(abtTree); + + uassert(6624253, "Lowering failed: did not produce a plan.", sbePlan != nullptr); + uassert(6624254, "Lowering failed: did not produce any output slots.", !slotMap.empty()); + + { + std::cerr << "********* SBE *********\n"; + sbe::DebugPrinter p; + std::cerr << p.print(*sbePlan.get()) << "\n"; + std::cerr << "********* SBE *********\n"; + } + + stage_builder::PlanStageData data{std::make_unique<sbe::RuntimeEnvironment>()}; + data.outputs.set(stage_builder::PlanStageSlots::kResult, slotMap.begin()->second); + + sbePlan->attachToOperationContext(opCtx); + if (expCtx->explain || expCtx->mayDbProfile) { + sbePlan->markShouldCollectTimingInfo(); + } + + auto yieldPolicy = + std::make_unique<PlanYieldPolicySBE>(PlanYieldPolicy::YieldPolicy::YIELD_AUTO, + opCtx->getServiceContext()->getFastClockSource(), + internalQueryExecYieldIterations.load(), + Milliseconds{internalQueryExecYieldPeriodMS.load()}, + nullptr, + std::make_unique<YieldPolicyCallbacksImpl>(nss), + false /*useExperimentalCommitTxnBehavior*/); + + auto planExec = uassertStatusOK(plan_executor_factory::make( + opCtx, + nullptr /*cq*/, + nullptr /*solution*/, + {std::move(sbePlan), std::move(data)}, + 
std::make_unique<ABTPrinter>(std::move(abtTree), phaseManager.getNodeToGroupPropsMap()), + &collection, + QueryPlannerParams::Options::DEFAULT, + nss, + std::move(yieldPolicy))); + return planExec; +} + +static void populateAdditionalScanDefs(OperationContext* opCtx, + boost::intrusive_ptr<ExpressionContext> expCtx, + const Pipeline& pipeline, + const size_t numberOfPartitions, + PrefixId& prefixId, + opt::unordered_map<std::string, ScanDefinition>& scanDefs, + const DisableIndexOptions disableIndexOptions) { + for (const auto& involvedNss : pipeline.getInvolvedCollections()) { + // TODO handle views? + AutoGetCollectionForReadCommandMaybeLockFree ctx( + opCtx, involvedNss, AutoGetCollectionViewMode::kViewsForbidden); + const CollectionPtr& collection = ctx ? ctx.getCollection() : CollectionPtr::null; + const bool collectionExists = collection != nullptr; + const std::string uuidStr = + collectionExists ? collection->uuid().toString() : "<missing_uuid>"; + + const std::string collNameStr = involvedNss.coll().toString(); + // TODO: We cannot add the uuidStr suffix because the pipeline translation does not have + // access to the metadata so it generates a scan over just the collection name. + const std::string scanDefName = collNameStr; + + opt::unordered_map<std::string, optimizer::IndexDefinition> indexDefs; + const ProjectionName& scanProjName = prefixId.getNextId("scan"); + if (collectionExists) { + // TODO: add locks on used indexes? + indexDefs = buildIndexSpecsOptimizer( + expCtx, opCtx, collection, scanProjName, disableIndexOptions); + } + + // For now handle only local parallelism (no over-the-network exchanges). + DistributionAndPaths distribution{(numberOfPartitions == 1) + ? DistributionType::Centralized + : DistributionType::UnknownPartitioning}; + + const CEType collectionCE = collectionExists ? collection->numRecords(opCtx) : -1.0; + scanDefs[scanDefName] = + ScanDefinition({{"type", "mongod"}, + {"database", involvedNss.db().toString()}, + {"uuid", uuidStr}, + {ScanNode::kDefaultCollectionNameSpec, collNameStr}}, + std::move(indexDefs), + std::move(distribution), + collectionExists, + collectionCE); + } +} + +std::unique_ptr<PlanExecutor, PlanExecutor::Deleter> getSBEExecutorViaCascadesOptimizer( + OperationContext* opCtx, + boost::intrusive_ptr<ExpressionContext> expCtx, + const NamespaceString& nss, + const CollectionPtr& collection, + const Pipeline& pipeline) { + const bool collectionExists = collection != nullptr; + const std::string uuidStr = collectionExists ? collection->uuid().toString() : "<missing_uuid>"; + const std::string collNameStr = nss.coll().toString(); + const std::string scanDefName = collNameStr + "_" + uuidStr; + + QueryHints queryHints = getHintsFromQueryKnobs(); + + PrefixId prefixId; + const ProjectionName& scanProjName = prefixId.getNextId("scan"); + + // Add the base collection metadata. + opt::unordered_map<std::string, optimizer::IndexDefinition> indexDefs; + if (collectionExists) { + // TODO: add locks on used indexes? + indexDefs = buildIndexSpecsOptimizer( + expCtx, opCtx, collection, scanProjName, queryHints._disableIndexes); + } + + const size_t numberOfPartitions = internalQueryDefaultDOP.load(); + // For now handle only local parallelism (no over-the-network exchanges). + DistributionAndPaths distribution{(numberOfPartitions == 1) + ? DistributionType::Centralized + : DistributionType::UnknownPartitioning}; + + opt::unordered_map<std::string, ScanDefinition> scanDefs; + const int64_t numRecords = collectionExists ? 
collection->numRecords(opCtx) : -1; + scanDefs.emplace(scanDefName, + ScanDefinition({{"type", "mongod"}, + {"database", nss.db().toString()}, + {"uuid", uuidStr}, + {ScanNode::kDefaultCollectionNameSpec, collNameStr}}, + std::move(indexDefs), + std::move(distribution), + collectionExists, + static_cast<CEType>(numRecords))); + + // Add a scan definition for all involved collections. Note that the base namespace has already + // been accounted for above and isn't included here. + populateAdditionalScanDefs(opCtx, + expCtx, + pipeline, + numberOfPartitions, + prefixId, + scanDefs, + queryHints._disableIndexes); + + Metadata metadata(std::move(scanDefs), numberOfPartitions); + + ABT abtTree = collectionExists ? make<ScanNode>(scanProjName, scanDefName) + : make<ValueScanNode>(ProjectionNameVector{scanProjName}); + abtTree = + translatePipelineToABT(metadata, pipeline, scanProjName, std::move(abtTree), prefixId); + + std::cerr << "******* Translated ABT **********\n"; + std::cerr << ExplainGenerator::explainV2(abtTree) << std::endl; + std::cerr << "******* Translated ABT **********\n"; + + if (collectionExists && numRecords > 0 && + internalQueryEnableSamplingCardinalityEstimator.load()) { + Metadata metadataForSampling = metadata; + // Do not use indexes for sampling. + for (auto& entry : metadataForSampling._scanDefs) { + entry.second.getIndexDefs().clear(); + } + + // TODO: consider a limited rewrite set. + OptPhaseManager phaseManagerForSampling(OptPhaseManager::getAllRewritesSet(), + prefixId, + false /*requireRID*/, + std::move(metadataForSampling), + std::make_unique<HeuristicCE>(), + std::make_unique<DefaultCosting>(), + DebugInfo::kDefaultForProd); + + OptPhaseManager phaseManager{ + OptPhaseManager::getAllRewritesSet(), + prefixId, + false /*requireRID*/, + std::move(metadata), + std::make_unique<CESamplingTransport>(opCtx, phaseManagerForSampling, numRecords), + std::make_unique<DefaultCosting>(), + DebugInfo::kDefaultForProd}; + phaseManager.getHints() = queryHints; + + return optimizeAndCreateExecutor( + phaseManager, std::move(abtTree), opCtx, expCtx, nss, collection); + } + + // Use heuristics. + OptPhaseManager phaseManager{OptPhaseManager::getAllRewritesSet(), + prefixId, + std::move(metadata), + DebugInfo::kDefaultForProd}; + phaseManager.getHints() = queryHints; + + return optimizeAndCreateExecutor( + phaseManager, std::move(abtTree), opCtx, expCtx, nss, collection); +} + +} // namespace mongo diff --git a/src/mongo/db/commands/cqf/cqf_aggregate.h b/src/mongo/db/commands/cqf/cqf_aggregate.h new file mode 100644 index 00000000000..b98716400f2 --- /dev/null +++ b/src/mongo/db/commands/cqf/cqf_aggregate.h @@ -0,0 +1,44 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. 
+ * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#pragma once + +#include "mongo/db/catalog/collection.h" +#include "mongo/db/query/plan_executor.h" + +namespace mongo { + +std::unique_ptr<PlanExecutor, PlanExecutor::Deleter> getSBEExecutorViaCascadesOptimizer( + OperationContext* opCtx, + boost::intrusive_ptr<ExpressionContext> expCtx, + const NamespaceString& nss, + const CollectionPtr& collection, + const Pipeline& pipeline); + +} // namespace mongo diff --git a/src/mongo/db/commands/find_cmd.cpp b/src/mongo/db/commands/find_cmd.cpp index aaf86a480ef..d1fbe6fa8c7 100644 --- a/src/mongo/db/commands/find_cmd.cpp +++ b/src/mongo/db/commands/find_cmd.cpp @@ -51,6 +51,7 @@ #include "mongo/db/query/find.h" #include "mongo/db/query/find_common.h" #include "mongo/db/query/get_executor.h" +#include "mongo/db/query/query_knobs_gen.h" #include "mongo/db/repl/replication_coordinator.h" #include "mongo/db/service_context.h" #include "mongo/db/stats/counters.h" @@ -317,7 +318,12 @@ public: extensionsCallback, MatchExpressionParser::kAllowAllSpecialFeatures)); - if (ctx->getView()) { + // If we are running a query against a view, or if we are trying to test the new + // optimizer, redirect this query through the aggregation system. + if (ctx->getView() || + (feature_flags::gfeatureFlagCommonQueryFramework.isEnabled( + serverGlobalParams.featureCompatibility) && + internalQueryEnableCascadesOptimizer.load())) { // Relinquish locks. The aggregation command will re-acquire them. ctx.reset(); @@ -521,7 +527,12 @@ public: extensionsCallback, MatchExpressionParser::kAllowAllSpecialFeatures)); - if (ctx->getView()) { + // If we are running a query against a view, or if we are trying to test the new + // optimizer, redirect this query through the aggregation system. + if (ctx->getView() || + (feature_flags::gfeatureFlagCommonQueryFramework.isEnabled( + serverGlobalParams.featureCompatibility) && + internalQueryEnableCascadesOptimizer.load())) { // Relinquish locks. The aggregation command will re-acquire them. 
ctx.reset(); diff --git a/src/mongo/db/commands/run_aggregate.cpp b/src/mongo/db/commands/run_aggregate.cpp index 2bb3d6f9d3f..b1b80cc7827 100644 --- a/src/mongo/db/commands/run_aggregate.cpp +++ b/src/mongo/db/commands/run_aggregate.cpp @@ -41,6 +41,7 @@ #include "mongo/db/auth/authorization_session.h" #include "mongo/db/catalog/database.h" #include "mongo/db/catalog/database_holder.h" +#include "mongo/db/commands/cqf/cqf_aggregate.h" #include "mongo/db/curop.h" #include "mongo/db/cursor_manager.h" #include "mongo/db/db_raii.h" @@ -66,6 +67,7 @@ #include "mongo/db/query/get_executor.h" #include "mongo/db/query/plan_executor_factory.h" #include "mongo/db/query/plan_summary_stats.h" +#include "mongo/db/query/query_feature_flags_gen.h" #include "mongo/db/query/query_planner_common.h" #include "mongo/db/read_concern.h" #include "mongo/db/repl/oplog.h" @@ -129,7 +131,6 @@ bool handleCursorCommand(OperationContext* opCtx, request.getCursor().getBatchSize().value_or(aggregation_request_helper::kDefaultBatchSize); if (cursors.size() > 1) { - uassert( ErrorCodes::BadValue, "the exchange initial batch size must be zero", batchSize == 0); @@ -535,6 +536,74 @@ void performValidationChecks(const OperationContext* opCtx, aggregation_request_helper::validateRequestForAPIVersion(opCtx, request); } +std::vector<std::unique_ptr<PlanExecutor, PlanExecutor::Deleter>> createLegacyExecutor( + std::unique_ptr<Pipeline, PipelineDeleter> pipeline, + const LiteParsedPipeline& liteParsedPipeline, + const NamespaceString& nss, + const MultiCollection& collections, + const AggregateCommandRequest& request, + CurOp* curOp, + const std::function<void(void)>& resetContextFn) { + const auto expCtx = pipeline->getContext(); + // Check if the pipeline has a $geoNear stage, as it will be ripped away during the build query + // executor phase below (to be replaced with a $geoNearCursorStage later during the executor + // attach phase). + auto hasGeoNearStage = !pipeline->getSources().empty() && + dynamic_cast<DocumentSourceGeoNear*>(pipeline->peekFront()); + + // Prepare a PlanExecutor to provide input into the pipeline, if needed. + auto attachExecutorCallback = + PipelineD::buildInnerQueryExecutor(collections, nss, &request, pipeline.get()); + + std::vector<std::unique_ptr<PlanExecutor, PlanExecutor::Deleter>> execs; + if (canOptimizeAwayPipeline(pipeline.get(), + attachExecutorCallback.second.get(), + request, + hasGeoNearStage, + liteParsedPipeline.hasChangeStream())) { + // Mark that this query does not use DocumentSource. + curOp->debug().documentSourceUsed = false; + + // This pipeline is currently empty, but once completed it will have only one source, + // which is a DocumentSourceCursor. Instead of creating a whole pipeline to do nothing + // more than forward the results of its cursor document source, we can use the + // PlanExecutor by itself. The resulting cursor will look like what the client would + // have gotten from find command. + execs.emplace_back(std::move(attachExecutorCallback.second)); + } else { + // Mark that this query uses DocumentSource. + curOp->debug().documentSourceUsed = true; + + // Complete creation of the initial $cursor stage, if needed. 
+ PipelineD::attachInnerQueryExecutorToPipeline(collections, + attachExecutorCallback.first, + std::move(attachExecutorCallback.second), + pipeline.get()); + + auto pipelines = createExchangePipelinesIfNeeded( + expCtx->opCtx, expCtx, request, std::move(pipeline), expCtx->uuid); + for (auto&& pipelineIt : pipelines) { + // There are separate ExpressionContexts for each exchange pipeline, so make sure to + // pass the pipeline's ExpressionContext to the plan executor factory. + auto pipelineExpCtx = pipelineIt->getContext(); + + execs.emplace_back( + plan_executor_factory::make(std::move(pipelineExpCtx), + std::move(pipelineIt), + aggregation_request_helper::getResumableScanType( + request, liteParsedPipeline.hasChangeStream()))); + } + + // With the pipelines created, we can relinquish locks as they will manage the locks + // internally further on. We still need to keep the lock for an optimized away pipeline + // though, as we will be changing its lock policy to 'kLockExternally' (see details + // below), and in order to execute the initial getNext() call in 'handleCursorCommand', + // we need to hold the collection lock. + resetContextFn(); + } + return execs; +} + } // namespace Status runAggregate(OperationContext* opCtx, @@ -604,6 +673,7 @@ Status runAggregate(OperationContext* opCtx, std::vector<unique_ptr<PlanExecutor, PlanExecutor::Deleter>> execs; boost::intrusive_ptr<ExpressionContext> expCtx; auto curOp = CurOp::get(opCtx); + { // If we are in a transaction, check whether the parsed pipeline supports being in // a transaction and if the transaction's read concern is supported. @@ -812,59 +882,29 @@ Status runAggregate(OperationContext* opCtx, constexpr bool alreadyOptimized = true; pipeline->validateCommon(alreadyOptimized); - // Check if the pipeline has a $geoNear stage, as it will be ripped away during the build - // query executor phase below (to be replaced with a $geoNearCursorStage later during the - // executor attach phase). - auto hasGeoNearStage = !pipeline->getSources().empty() && - dynamic_cast<DocumentSourceGeoNear*>(pipeline->peekFront()); - - // Prepare a PlanExecutor to provide input into the pipeline, if needed. - auto attachExecutorCallback = - PipelineD::buildInnerQueryExecutor(collections, nss, &request, pipeline.get()); - - if (canOptimizeAwayPipeline(pipeline.get(), - attachExecutorCallback.second.get(), - request, - hasGeoNearStage, - liteParsedPipeline.hasChangeStream())) { - // This pipeline is currently empty, but once completed it will have only one source, - // which is a DocumentSourceCursor. Instead of creating a whole pipeline to do nothing - // more than forward the results of its cursor document source, we can use the - // PlanExecutor by itself. The resulting cursor will look like what the client would - // have gotten from find command. - execs.emplace_back(std::move(attachExecutorCallback.second)); - // Mark that this query does not use DocumentSource. 
- curOp->debug().documentSourceUsed = false; + if (feature_flags::gfeatureFlagCommonQueryFramework.isEnabled( + serverGlobalParams.featureCompatibility) && + internalQueryEnableCascadesOptimizer.load()) { + uassert(6624344, + "Exchanging is not supported in the Cascades optimizer", + !request.getExchange().has_value()); + + auto timeBegin = Date_t::now(); + execs.emplace_back(getSBEExecutorViaCascadesOptimizer( + opCtx, expCtx, nss, collections.getMainCollection(), *pipeline)); + auto elapsed = + (Date_t::now().toMillisSinceEpoch() - timeBegin.toMillisSinceEpoch()) / 1000.0; + std::cerr << "Optimization took: " << elapsed << " s.\n"; } else { - // Mark that this query uses DocumentSource. - curOp->debug().documentSourceUsed = true; - // Complete creation of the initial $cursor stage, if needed. - PipelineD::attachInnerQueryExecutorToPipeline(collections, - attachExecutorCallback.first, - std::move(attachExecutorCallback.second), - pipeline.get()); - - auto pipelines = - createExchangePipelinesIfNeeded(opCtx, expCtx, request, std::move(pipeline), uuid); - for (auto&& pipelineIt : pipelines) { - // There are separate ExpressionContexts for each exchange pipeline, so make sure to - // pass the pipeline's ExpressionContext to the plan executor factory. - auto pipelineExpCtx = pipelineIt->getContext(); - - execs.emplace_back(plan_executor_factory::make( - std::move(pipelineExpCtx), - std::move(pipelineIt), - aggregation_request_helper::getResumableScanType( - request, liteParsedPipeline.hasChangeStream()))); - } - - // With the pipelines created, we can relinquish locks as they will manage the locks - // internally further on. We still need to keep the lock for an optimized away pipeline - // though, as we will be changing its lock policy to 'kLockExternally' (see details - // below), and in order to execute the initial getNext() call in 'handleCursorCommand', - // we need to hold the collection lock. - resetContext(); + execs = createLegacyExecutor(std::move(pipeline), + liteParsedPipeline, + nss, + collections, + request, + curOp, + resetContext); } + tassert(6624353, "No executors", !execs.empty()); { auto planSummary = execs[0]->getPlanExplainer().getPlanSummary(); diff --git a/src/mongo/db/exec/sbe/SConscript b/src/mongo/db/exec/sbe/SConscript index 0e220aafed2..c9395ca88e6 100644 --- a/src/mongo/db/exec/sbe/SConscript +++ b/src/mongo/db/exec/sbe/SConscript @@ -114,6 +114,18 @@ env.Library( ) env.Library( + target='query_sbe_abt', + source=[ + 'abt/abt_lower.cpp', + ], + LIBDEPS=[ + '$BUILD_DIR/mongo/db/query/optimizer/optimizer', + 'query_sbe', + 'query_sbe_storage' + ] + ) + +env.Library( target='sbe_plan_stage_test', source=[ 'sbe_plan_stage_test.cpp', @@ -191,3 +203,20 @@ env.CppUnitTest( 'sbe_plan_stage_test', ], ) + +env.CppUnitTest( + target='sbe_abt_test', + source=[ + 'abt/sbe_abt_diff_test.cpp', + 'abt/sbe_abt_test.cpp', + 'abt/sbe_abt_test_util.cpp', + ], + LIBDEPS=[ + "$BUILD_DIR/mongo/db/auth/authmocks", + '$BUILD_DIR/mongo/db/query/query_test_service_context', + '$BUILD_DIR/mongo/db/query_exec', + '$BUILD_DIR/mongo/db/service_context_test_fixture', + '$BUILD_DIR/mongo/unittest/unittest', + 'query_sbe_abt', + ], +) diff --git a/src/mongo/db/exec/sbe/abt/abt_lower.cpp b/src/mongo/db/exec/sbe/abt/abt_lower.cpp new file mode 100644 index 00000000000..737f94064da --- /dev/null +++ b/src/mongo/db/exec/sbe/abt/abt_lower.cpp @@ -0,0 +1,1014 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. 
+ * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include "mongo/db/exec/sbe/abt/abt_lower.h" +#include "mongo/db/exec/sbe/stages/bson_scan.h" +#include "mongo/db/exec/sbe/stages/co_scan.h" +#include "mongo/db/exec/sbe/stages/exchange.h" +#include "mongo/db/exec/sbe/stages/filter.h" +#include "mongo/db/exec/sbe/stages/hash_agg.h" +#include "mongo/db/exec/sbe/stages/hash_join.h" +#include "mongo/db/exec/sbe/stages/ix_scan.h" +#include "mongo/db/exec/sbe/stages/limit_skip.h" +#include "mongo/db/exec/sbe/stages/loop_join.h" +#include "mongo/db/exec/sbe/stages/merge_join.h" +#include "mongo/db/exec/sbe/stages/project.h" +#include "mongo/db/exec/sbe/stages/scan.h" +#include "mongo/db/exec/sbe/stages/sort.h" +#include "mongo/db/exec/sbe/stages/union.h" +#include "mongo/db/exec/sbe/stages/unique.h" +#include "mongo/db/exec/sbe/stages/unwind.h" +#include "mongo/db/query/optimizer/utils/utils.h" + +namespace mongo::optimizer { + +static sbe::EExpression::Vector toInlinedVector( + std::vector<std::unique_ptr<sbe::EExpression>> args) { + sbe::EExpression::Vector inlined; + for (auto&& arg : args) { + inlined.emplace_back(std::move(arg)); + } + return inlined; +} + +std::unique_ptr<sbe::EExpression> SBEExpressionLowering::optimize(const ABT& n) { + return algebra::transport<false>(n, *this); +} + +std::unique_ptr<sbe::EExpression> SBEExpressionLowering::transport(const Constant& c) { + auto [tag, val] = c.get(); + auto [copyTag, copyVal] = sbe::value::copyValue(tag, val); + sbe::value::ValueGuard guard(copyTag, copyVal); + + auto result = sbe::makeE<sbe::EConstant>(copyTag, copyVal); + + guard.reset(); + return result; +} + +std::unique_ptr<sbe::EExpression> SBEExpressionLowering::transport(const Source&) { + uasserted(6624202, "not yet implemented"); + return nullptr; +} + +std::unique_ptr<sbe::EExpression> SBEExpressionLowering::transport(const Variable& var) { + auto def = _env.getDefinition(&var); + + if (!def.definedBy.empty()) { + if (auto let = def.definedBy.cast<Let>(); let) { + auto it = _letMap.find(let); + uassert(6624203, "incorrect let map", it != _letMap.end()); + + return sbe::makeE<sbe::EVariable>(it->second, 0, _env.isLastRef(&var)); + } 
else if (auto lam = def.definedBy.cast<LambdaAbstraction>(); lam) { + // This is a lambda parameter. + auto it = _lambdaMap.find(lam); + uassert(6624204, "incorrect lambda map", it != _lambdaMap.end()); + + return sbe::makeE<sbe::EVariable>(it->second, 0, _env.isLastRef(&var)); + } + } + if (auto it = _slotMap.find(var.name()); it != _slotMap.end()) { + // Found the slot. + return sbe::makeE<sbe::EVariable>(it->second); + } + uasserted(6624205, str::stream() << "undefined variable: " << var.name()); + return nullptr; +} + +std::unique_ptr<sbe::EExpression> SBEExpressionLowering::transport( + const BinaryOp& op, + std::unique_ptr<sbe::EExpression> lhs, + std::unique_ptr<sbe::EExpression> rhs) { + + sbe::EPrimBinary::Op sbeOp = [](const auto abtOp) { + switch (abtOp) { + case Operations::Eq: + return sbe::EPrimBinary::eq; + case Operations::Neq: + return sbe::EPrimBinary::neq; + case Operations::Gt: + return sbe::EPrimBinary::greater; + case Operations::Gte: + return sbe::EPrimBinary::greaterEq; + case Operations::Lt: + return sbe::EPrimBinary::less; + case Operations::Lte: + return sbe::EPrimBinary::lessEq; + case Operations::Add: + return sbe::EPrimBinary::add; + case Operations::Sub: + return sbe::EPrimBinary::sub; + case Operations::And: + return sbe::EPrimBinary::logicAnd; + case Operations::Or: + return sbe::EPrimBinary::logicOr; + case Operations::Cmp3w: + return sbe::EPrimBinary::cmp3w; + case Operations::Div: + return sbe::EPrimBinary::div; + case Operations::Mult: + return sbe::EPrimBinary::mul; + default: + MONGO_UNREACHABLE; + } + }(op.op()); + + return sbe::makeE<sbe::EPrimBinary>(sbeOp, std::move(lhs), std::move(rhs)); +} + +std::unique_ptr<sbe::EExpression> SBEExpressionLowering::transport( + const UnaryOp& op, std::unique_ptr<sbe::EExpression> arg) { + + sbe::EPrimUnary::Op sbeOp = [](const auto abtOp) { + switch (abtOp) { + case Operations::Neg: + return sbe::EPrimUnary::negate; + case Operations::Not: + return sbe::EPrimUnary::logicNot; + default: + MONGO_UNREACHABLE; + } + }(op.op()); + + return sbe::makeE<sbe::EPrimUnary>(sbeOp, std::move(arg)); +} + +std::unique_ptr<sbe::EExpression> SBEExpressionLowering::transport( + const If&, + std::unique_ptr<sbe::EExpression> cond, + std::unique_ptr<sbe::EExpression> thenBranch, + std::unique_ptr<sbe::EExpression> elseBranch) { + return sbe::makeE<sbe::EIf>(std::move(cond), std::move(thenBranch), std::move(elseBranch)); +} + +void SBEExpressionLowering::prepare(const Let& let) { + _letMap[&let] = ++_frameCounter; +} + +std::unique_ptr<sbe::EExpression> SBEExpressionLowering::transport( + const Let& let, std::unique_ptr<sbe::EExpression> bind, std::unique_ptr<sbe::EExpression> in) { + auto it = _letMap.find(&let); + uassert(6624206, "incorrect let map", it != _letMap.end()); + auto frameId = it->second; + _letMap.erase(it); + + // ABT let binds only a single variable. When we extend it to support multiple binds then we + // have to revisit how we map variable names to sbe slot ids. 
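+    // For example, an ABT "let x = <bind> in <in>" lowers to roughly
+    // ELocalBind(frameId, {<bind>}, <in>), and each reference to x inside <in> is lowered by the
+    // Variable transport above to EVariable(frameId, 0). (Illustrative sketch; actual frame ids
+    // are assigned from _frameCounter.)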
+ return sbe::makeE<sbe::ELocalBind>(frameId, sbe::makeEs(std::move(bind)), std::move(in)); +} + +void SBEExpressionLowering::prepare(const LambdaAbstraction& lam) { + _lambdaMap[&lam] = ++_frameCounter; +} + +std::unique_ptr<sbe::EExpression> SBEExpressionLowering::transport( + const LambdaAbstraction& lam, std::unique_ptr<sbe::EExpression> body) { + auto it = _lambdaMap.find(&lam); + uassert(6624207, "incorrect lambda map", it != _lambdaMap.end()); + auto frameId = it->second; + _lambdaMap.erase(it); + + return sbe::makeE<sbe::ELocalLambda>(frameId, std::move(body)); +} + +std::unique_ptr<sbe::EExpression> SBEExpressionLowering::transport( + const LambdaApplication&, + std::unique_ptr<sbe::EExpression> lam, + std::unique_ptr<sbe::EExpression> arg) { + // lambda applications are not directly supported by SBE (yet) and must not be present. + uasserted(6624208, "lambda application is not implemented"); + return nullptr; +} + +std::unique_ptr<sbe::EExpression> SBEExpressionLowering::transport( + const FunctionCall& fn, std::vector<std::unique_ptr<sbe::EExpression>> args) { + auto name = fn.name(); + + if (name == "typeMatch") { + uassert(6624209, "Invalid number of typeMatch arguments", fn.nodes().size() == 2); + if (const auto* constPtr = fn.nodes().at(1).cast<Constant>(); + constPtr != nullptr && constPtr->isValueInt64()) { + return sbe::makeE<sbe::ETypeMatch>(std::move(args.at(0)), constPtr->getValueInt64()); + } else { + uasserted(6624210, "Second argument of typeMatch must be an integer constant"); + } + } + + // TODO - this is an open question how to do the name mappings. + if (name == "$sum") { + name = "sum"; + } else if (name == "$first") { + name = "first"; + } else if (name == "$last") { + name = "last"; + } else if (name == "$min") { + name = "min"; + } else if (name == "$max") { + name = "max"; + } else if (name == "$addToSet") { + name = "addToSet"; + } + + return sbe::makeE<sbe::EFunction>(name, toInlinedVector(std::move(args))); +} + +sbe::value::SlotVector SBENodeLowering::convertProjectionsToSlots( + const ProjectionNameVector& projectionNames) { + sbe::value::SlotVector result; + for (const ProjectionName& projectionName : projectionNames) { + auto it = _slotMap.find(projectionName); + uassert(6624211, + str::stream() << "undefined variable: " << projectionName, + it != _slotMap.end()); + result.push_back(it->second); + } + return result; +} + +sbe::value::SlotVector SBENodeLowering::convertRequiredProjectionsToSlots( + const NodeProps& props, const bool addRIDProjection, const ProjectionNameVector& toExclude) { + using namespace properties; + + const PhysProps& physProps = props._physicalProps; + auto projections = getPropertyConst<ProjectionRequirement>(physProps).getProjections(); + + if (addRIDProjection && hasProperty<IndexingRequirement>(physProps) && + getPropertyConst<IndexingRequirement>(physProps).getNeedsRID()) { + const auto& scanDefName = + getPropertyConst<IndexingAvailability>(props._logicalProps).getScanDefName(); + projections.emplace_back(_ridProjections.at(scanDefName)); + } + + for (const ProjectionName& projName : toExclude) { + projections.erase(projName); + } + + return convertProjectionsToSlots(projections.getVector()); +} + +std::unique_ptr<sbe::PlanStage> SBENodeLowering::optimize(const ABT& n) { + return generateInternal(n); +} + +std::unique_ptr<sbe::PlanStage> SBENodeLowering::generateInternal(const ABT& n) { + return algebra::walk<false>(n, *this); +} + +std::unique_ptr<sbe::PlanStage> SBENodeLowering::walk(const RootNode& n, + const ABT& 
child, + const ABT& refs) { + + auto input = generateInternal(child); + + auto output = refs.cast<References>(); + uassert(6624212, "refs expected", output); + + SlotVarMap finalMap; + for (auto& o : output->nodes()) { + auto var = o.cast<Variable>(); + uassert(6624213, "var expected", var); + if (auto it = _slotMap.find(var->name()); it != _slotMap.end()) { + finalMap.emplace(var->name(), it->second); + } + } + + _slotMap = finalMap; + + return input; +} + +std::unique_ptr<sbe::PlanStage> SBENodeLowering::walk(const EvaluationNode& n, + const ABT& child, + const ABT& binds) { + auto input = generateInternal(child); + + if (auto varPtr = n.getProjection().cast<Variable>(); varPtr != nullptr) { + // Evaluation node is only renaming a variable. Do not place a project stage. + _slotMap.emplace(n.getProjectionName(), _slotMap.at(varPtr->name())); + return input; + } + + auto binder = binds.cast<ExpressionBinder>(); + uassert(6624214, "binder expected", binder); + + auto& names = binder->names(); + auto& exprs = binder->exprs(); + + sbe::value::SlotMap<std::unique_ptr<sbe::EExpression>> projects; + + for (size_t idx = 0; idx < exprs.size(); ++idx) { + auto expr = SBEExpressionLowering{_env, _slotMap}.optimize(exprs[idx]); + auto slot = _slotIdGenerator.generate(); + + _slotMap.emplace(names[idx], slot); + projects.emplace(slot, std::move(expr)); + } + + const PlanNodeId planNodeId = _nodeToGroupPropsMap.at(&n)._planNodeId; + return sbe::makeS<sbe::ProjectStage>(std::move(input), std::move(projects), planNodeId); +} + +std::unique_ptr<sbe::PlanStage> SBENodeLowering::walk(const FilterNode& n, + const ABT& child, + const ABT& filter) { + auto input = generateInternal(child); + auto expr = SBEExpressionLowering{_env, _slotMap}.optimize(filter); + const PlanNodeId planNodeId = _nodeToGroupPropsMap.at(&n)._planNodeId; + + // Check if the filter expression is 'constant' (i.e. does not depend on any variables) + // and create FilterStage<true>. + if (_env.getVariables(filter)._variables.empty()) { + return sbe::makeS<sbe::FilterStage<true>>(std::move(input), std::move(expr), planNodeId); + } else { + return sbe::makeS<sbe::FilterStage<false>>(std::move(input), std::move(expr), planNodeId); + } +} + +std::unique_ptr<sbe::PlanStage> SBENodeLowering::walk(const LimitSkipNode& n, const ABT& child) { + auto input = generateInternal(child); + + const PlanNodeId planNodeId = _nodeToGroupPropsMap.at(&n)._planNodeId; + return sbe::makeS<sbe::LimitSkipStage>( + std::move(input), n.getProperty().getLimit(), n.getProperty().getSkip(), planNodeId); +} + +std::unique_ptr<sbe::PlanStage> SBENodeLowering::walk(const ExchangeNode& n, + const ABT& child, + const ABT& refs) { + using namespace std::literals; + // TODO: Implement all types of distributions. + using namespace properties; + + // The DOP is obtained from the child (number of producers). + const auto& childProps = _nodeToGroupPropsMap.at(n.getChild().cast<Node>())._physicalProps; + const auto& childDistribution = getPropertyConst<DistributionRequirement>(childProps); + uassert(6624330, + "Parent and child distributions are the same", + !(childDistribution == n.getProperty())); + + const size_t localDOP = + (childDistribution.getDistributionAndProjections()._type == DistributionType::Centralized) + ? 
1 + : _metadata._numberOfPartitions; + uassert(6624215, "invalid DOP", localDOP >= 1); + + auto input = generateInternal(child); + + // Initialized to arbitrary placeholder + sbe::ExchangePolicy localPolicy{}; + std::unique_ptr<sbe::EExpression> partitionExpr; + + const auto& distribAndProjections = n.getProperty().getDistributionAndProjections(); + switch (distribAndProjections._type) { + case DistributionType::Centralized: + case DistributionType::Replicated: + localPolicy = sbe::ExchangePolicy::broadcast; + break; + + case DistributionType::RoundRobin: + localPolicy = sbe::ExchangePolicy::roundrobin; + break; + + case DistributionType::RangePartitioning: + localPolicy = sbe::ExchangePolicy::rangepartition; + break; + + case DistributionType::HashPartitioning: { + localPolicy = sbe::ExchangePolicy::hashpartition; + std::vector<std::unique_ptr<sbe::EExpression>> args; + for (const ProjectionName& proj : distribAndProjections._projectionNames) { + auto it = _slotMap.find(proj); + uassert(6624216, str::stream() << "undefined var: " << proj, it != _slotMap.end()); + + args.emplace_back(sbe::makeE<sbe::EVariable>(it->second)); + } + partitionExpr = sbe::makeE<sbe::EFunction>("hash"_sd, toInlinedVector(std::move(args))); + break; + } + + case DistributionType::UnknownPartitioning: + uasserted(6624217, "Cannot partition into unknown distribution"); + + default: + MONGO_UNREACHABLE; + } + + const auto& nodeProps = _nodeToGroupPropsMap.at(&n); + auto fields = convertRequiredProjectionsToSlots(nodeProps, true /*addRIDProjection*/); + + return sbe::makeS<sbe::ExchangeConsumer>(std::move(input), + localDOP, + std::move(fields), + localPolicy, + std::move(partitionExpr), + nullptr, + nodeProps._planNodeId); +} + +std::unique_ptr<sbe::PlanStage> SBENodeLowering::walk(const CollationNode& n, + const ABT& child, + const ABT& refs) { + auto input = generateInternal(child); + + sbe::value::SlotVector orderBySlots; + std::vector<sbe::value::SortDirection> directions; + ProjectionNameVector collationProjections; + for (const auto& entry : n.getProperty().getCollationSpec()) { + collationProjections.push_back(entry.first); + auto it = _slotMap.find(entry.first); + + uassert(6624219, + str::stream() << "undefined orderBy var: " << entry.first, + it != _slotMap.end()); + orderBySlots.push_back(it->second); + + switch (entry.second) { + case CollationOp::Ascending: + case CollationOp::Clustered: + // TODO: is there a more efficient way to compute clustered collation op than sort? + directions.push_back(sbe::value::SortDirection::Ascending); + break; + + case CollationOp::Descending: + directions.push_back(sbe::value::SortDirection::Descending); + break; + + default: + MONGO_UNREACHABLE; + } + } + + const auto& nodeProps = _nodeToGroupPropsMap.at(&n); + const auto& physProps = nodeProps._physicalProps; + + size_t limit = std::numeric_limits<std::size_t>::max(); + if (properties::hasProperty<properties::LimitSkipRequirement>(physProps)) { + const auto& limitSkipReq = + properties::getPropertyConst<properties::LimitSkipRequirement>(physProps); + uassert(6624221, "We should not have skip set here", limitSkipReq.getSkip() == 0); + limit = limitSkipReq.getLimit(); + } + + // TODO: obtain defaults for these. 
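+    // Descriptive note: for now the sort runs with a fixed 100MB in-memory budget and spilling
+    // disabled, so a sort that exceeds the budget is expected to fail rather than spill to disk.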
+ const size_t memoryLimit = 100 * (1ul << 20); // 100MB
+ const bool allowDiskUse = false;
+
+ auto vals = convertRequiredProjectionsToSlots(
+ nodeProps, true /*addRIDProjection*/, collationProjections);
+ return sbe::makeS<sbe::SortStage>(std::move(input),
+ std::move(orderBySlots),
+ std::move(directions),
+ std::move(vals),
+ limit,
+ memoryLimit,
+ allowDiskUse,
+ nodeProps._planNodeId);
+}
+
+std::unique_ptr<sbe::PlanStage> SBENodeLowering::walk(const UniqueNode& n,
+ const ABT& child,
+ const ABT& refs) {
+ auto input = generateInternal(child);
+
+ sbe::value::SlotVector keySlots;
+ for (const ProjectionName& projectionName : n.getProjections()) {
+ auto it = _slotMap.find(projectionName);
+ uassert(6624222,
+ str::stream() << "undefined variable: " << projectionName,
+ it != _slotMap.end());
+ keySlots.push_back(it->second);
+ }
+
+ const PlanNodeId planNodeId = _nodeToGroupPropsMap.at(&n)._planNodeId;
+ return sbe::makeS<sbe::UniqueStage>(std::move(input), std::move(keySlots), planNodeId);
+}
+
+std::unique_ptr<sbe::PlanStage> SBENodeLowering::walk(const GroupByNode& n,
+ const ABT& child,
+ const ABT& aggBinds,
+ const ABT& aggRefs,
+ const ABT& gbBind,
+ const ABT& gbRefs) {
+
+ auto input = generateInternal(child);
+
+ // Ideally, we should make a distinction between gbBind and gbRefs; i.e. internal references
+ // used by the hash agg to determine the group by values from its input and group by values as
+ // outputted by the hash agg after the grouping. However, SBE hash agg uses the same slot to
+ // represent both, so that distinction is largely moot.
+ sbe::value::SlotVector gbs;
+ auto gbCols = gbRefs.cast<References>();
+ uassert(6624223, "refs expected", gbCols);
+ for (auto& o : gbCols->nodes()) {
+ auto var = o.cast<Variable>();
+ uassert(6624224, "var expected", var);
+ auto it = _slotMap.find(var->name());
+ uassert(6624225, str::stream() << "undefined var: " << var->name(), it != _slotMap.end());
+ gbs.push_back(it->second);
+ }
+
+ // Similar considerations apply to the agg expressions as to the group by columns.
+ auto binderAgg = aggBinds.cast<ExpressionBinder>();
+ uassert(6624226, "binder expected", binderAgg);
+ auto refsAgg = aggRefs.cast<References>();
+ uassert(6624227, "refs expected", refsAgg);
+
+ auto& names = binderAgg->names();
+ auto& exprs = refsAgg->nodes();
+
+ sbe::value::SlotMap<std::unique_ptr<sbe::EExpression>> aggs;
+
+ for (size_t idx = 0; idx < exprs.size(); ++idx) {
+ auto expr = SBEExpressionLowering{_env, _slotMap}.optimize(exprs[idx]);
+ auto slot = _slotIdGenerator.generate();
+
+ _slotMap.emplace(names[idx], slot);
+ aggs.emplace(slot, std::move(expr));
+ }
+
+ // TODO: use collator slot.
+ boost::optional<sbe::value::SlotId> collatorSlot;
+ // Unused
+ sbe::value::SlotVector seekKeysSlots;
+
+ const PlanNodeId planNodeId = _nodeToGroupPropsMap.at(&n)._planNodeId;
+ return sbe::makeS<sbe::HashAggStage>(std::move(input),
+ std::move(gbs),
+ std::move(aggs),
+ std::move(seekKeysSlots),
+ true /*optimizedClose*/,
+ collatorSlot,
+ false /*allowDiskUse*/,
+ planNodeId);
+}
+
+std::unique_ptr<sbe::PlanStage> SBENodeLowering::walk(const BinaryJoinNode& n,
+ const ABT& leftChild,
+ const ABT& rightChild,
+ const ABT& filter) {
+ auto outerStage = generateInternal(leftChild);
+ auto innerStage = generateInternal(rightChild);
+
+ // List of correlated projections (bound in outer side and referred to in the inner side).
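+    // Descriptive note: these slots become the correlated-slot vector of the LoopJoinStage built
+    // below, which re-opens the inner (right) side for every outer row with the current values of
+    // the correlated projections and applies the lowered filter as the join predicate.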
+ sbe::value::SlotVector correlatedSlots; + for (const ProjectionName& projectionName : n.getCorrelatedProjectionNames()) { + correlatedSlots.push_back(_slotMap.at(projectionName)); + } + + const auto& leftChildProps = _nodeToGroupPropsMap.at(n.getLeftChild().cast<Node>()); + auto expr = SBEExpressionLowering{_env, _slotMap}.optimize(filter); + + auto outerProjects = + convertRequiredProjectionsToSlots(leftChildProps, true /*addRIDProjection*/); + + const PlanNodeId planNodeId = _nodeToGroupPropsMap.at(&n)._planNodeId; + return sbe::makeS<sbe::LoopJoinStage>(std::move(outerStage), + std::move(innerStage), + std::move(outerProjects), + std::move(correlatedSlots), + std::move(expr), + planNodeId); +} + +std::unique_ptr<sbe::PlanStage> SBENodeLowering::walk(const HashJoinNode& n, + const ABT& leftChild, + const ABT& rightChild, + const ABT& refs) { + auto outerStage = generateInternal(leftChild); + auto innerStage = generateInternal(rightChild); + + uassert(6624228, "Only inner joins supported for now", n.getJoinType() == JoinType::Inner); + + const auto& leftProps = _nodeToGroupPropsMap.at(n.getLeftChild().cast<Node>()); + const auto& rightProps = _nodeToGroupPropsMap.at(n.getRightChild().cast<Node>()); + + // Add RID projection only from outer side. + auto outerKeys = convertProjectionsToSlots(n.getLeftKeys()); + auto outerProjects = + convertRequiredProjectionsToSlots(leftProps, true /*addRIDProjection*/, n.getLeftKeys()); + auto innerKeys = convertProjectionsToSlots(n.getRightKeys()); + auto innerProjects = + convertRequiredProjectionsToSlots(rightProps, false /*addRIDProjection*/, n.getRightKeys()); + + // TODO: use collator slot. + boost::optional<sbe::value::SlotId> collatorSlot; + const PlanNodeId planNodeId = _nodeToGroupPropsMap.at(&n)._planNodeId; + return sbe::makeS<sbe::HashJoinStage>(std::move(outerStage), + std::move(innerStage), + std::move(outerKeys), + std::move(outerProjects), + std::move(innerKeys), + std::move(innerProjects), + collatorSlot, + planNodeId); +} + +std::unique_ptr<sbe::PlanStage> SBENodeLowering::walk(const MergeJoinNode& n, + const ABT& leftChild, + const ABT& rightChild, + const ABT& refs) { + auto outerStage = generateInternal(leftChild); + auto innerStage = generateInternal(rightChild); + + const auto& leftProps = _nodeToGroupPropsMap.at(n.getLeftChild().cast<Node>()); + const auto& rightProps = _nodeToGroupPropsMap.at(n.getRightChild().cast<Node>()); + + std::vector<sbe::value::SortDirection> sortDirs; + for (const CollationOp op : n.getCollation()) { + switch (op) { + case CollationOp::Ascending: + case CollationOp::Clustered: + sortDirs.push_back(sbe::value::SortDirection::Ascending); + break; + + case CollationOp::Descending: + sortDirs.push_back(sbe::value::SortDirection::Descending); + break; + + default: + MONGO_UNREACHABLE; + } + } + + // Add RID projection only from outer side. 
+ auto outerKeys = convertProjectionsToSlots(n.getLeftKeys()); + auto outerProjects = + convertRequiredProjectionsToSlots(leftProps, true /*addRIDProjection*/, n.getLeftKeys()); + auto innerKeys = convertProjectionsToSlots(n.getRightKeys()); + auto innerProjects = + convertRequiredProjectionsToSlots(rightProps, false /*addRIDProjection*/, n.getRightKeys()); + + const PlanNodeId planNodeId = _nodeToGroupPropsMap.at(&n)._planNodeId; + return sbe::makeS<sbe::MergeJoinStage>(std::move(outerStage), + std::move(innerStage), + std::move(outerKeys), + std::move(outerProjects), + std::move(innerKeys), + std::move(innerProjects), + std::move(sortDirs), + planNodeId); +} + +std::unique_ptr<sbe::PlanStage> SBENodeLowering::walk(const UnionNode& n, + const ABTVector& children, + const ABT& binder, + const ABT& refs) { + auto unionBinder = binder.cast<ExpressionBinder>(); + uassert(6624229, "binder expected", unionBinder); + const auto& names = unionBinder->names(); + + sbe::PlanStage::Vector loweredChildren; + std::vector<sbe::value::SlotVector> inputVals; + + for (const ABT& child : children) { + // Use a fresh map to prevent same projections for every child being overwritten. + SlotVarMap localMap; + SBENodeLowering localLowering(_env, + localMap, + _slotIdGenerator, + _metadata, + _nodeToGroupPropsMap, + _ridProjections, + _randomScan); + auto loweredChild = localLowering.optimize(child); + + if (children.size() == 1) { + // Union with one child is used to restrict projections. Do not place a union stage. + for (const auto& name : names) { + _slotMap.emplace(name, localMap.at(name)); + } + return loweredChild; + } + loweredChildren.push_back(std::move(loweredChild)); + + sbe::value::SlotVector childSlots; + for (const auto& name : names) { + childSlots.push_back(localMap.at(name)); + } + inputVals.emplace_back(std::move(childSlots)); + } + + sbe::value::SlotVector outputVals; + for (const auto& name : names) { + const auto outputSlot = _slotIdGenerator.generate(); + _slotMap.emplace(name, outputSlot); + outputVals.push_back(outputSlot); + } + + const PlanNodeId planNodeId = _nodeToGroupPropsMap.at(&n)._planNodeId; + return sbe::makeS<sbe::UnionStage>( + std::move(loweredChildren), std::move(inputVals), std::move(outputVals), planNodeId); +} + +std::unique_ptr<sbe::PlanStage> SBENodeLowering::walk(const UnwindNode& n, + const ABT& child, + const ABT& pidBind, + const ABT& refs) { + auto input = generateInternal(child); + + auto it = _slotMap.find(n.getProjectionName()); + uassert(6624230, + str::stream() << "undefined unwind variable: " << n.getProjectionName(), + it != _slotMap.end()); + + auto inputSlot = it->second; + auto outputSlot = _slotIdGenerator.generate(); + auto outputPidSlot = _slotIdGenerator.generate(); + + _slotMap[n.getProjectionName()] = outputSlot; + _slotMap[n.getPIDProjectionName()] = outputPidSlot; + + const PlanNodeId planNodeId = _nodeToGroupPropsMap.at(&n)._planNodeId; + return sbe::makeS<sbe::UnwindStage>( + std::move(input), inputSlot, outputSlot, outputPidSlot, n.getRetainNonArrays(), planNodeId); +} + +void SBENodeLowering::generateSlots(const FieldProjectionMap& fieldProjectionMap, + boost::optional<sbe::value::SlotId>& ridSlot, + boost::optional<sbe::value::SlotId>& rootSlot, + std::vector<std::string>& fields, + sbe::value::SlotVector& vars) { + if (!fieldProjectionMap._ridProjection.empty()) { + ridSlot = _slotIdGenerator.generate(); + _slotMap.emplace(fieldProjectionMap._ridProjection, ridSlot.value()); + } + if (!fieldProjectionMap._rootProjection.empty()) { + 
rootSlot = _slotIdGenerator.generate(); + _slotMap.emplace(fieldProjectionMap._rootProjection, rootSlot.value()); + } + for (const auto& [fieldName, projectionName] : fieldProjectionMap._fieldProjections) { + vars.push_back(_slotIdGenerator.generate()); + _slotMap.emplace(projectionName, vars.back()); + fields.push_back(fieldName); + } +} + +static NamespaceStringOrUUID parseFromScanDef(const ScanDefinition& def) { + const auto& dbName = def.getOptionsMap().at("database"); + const auto& uuidStr = def.getOptionsMap().at("uuid"); + return {dbName, UUID::parse(uuidStr).getValue()}; +} + +std::unique_ptr<sbe::PlanStage> SBENodeLowering::lowerScanNode( + const Node& n, + const std::string& scanDefName, + const FieldProjectionMap& fieldProjectionMap, + const bool useParallelScan) { + const ScanDefinition& def = _metadata._scanDefs.at(scanDefName); + uassert(6624231, "Collection must exist to lower Scan", def.exists()); + auto& typeSpec = def.getOptionsMap().at("type"); + + boost::optional<sbe::value::SlotId> ridSlot; + boost::optional<sbe::value::SlotId> rootSlot; + std::vector<std::string> fields; + sbe::value::SlotVector vars; + generateSlots(fieldProjectionMap, ridSlot, rootSlot, fields, vars); + + const PlanNodeId planNodeId = _nodeToGroupPropsMap.at(&n)._planNodeId; + if (typeSpec == "mongod") { + NamespaceStringOrUUID nss = parseFromScanDef(def); + + // Unused. + boost::optional<sbe::value::SlotId> seekKeySlot; + + sbe::ScanCallbacks callbacks({}, {}, {}); + if (useParallelScan) { + return sbe::makeS<sbe::ParallelScanStage>(nss.uuid().get(), + rootSlot, + ridSlot, + boost::none, + boost::none, + boost::none, + boost::none, + fields, + vars, + nullptr /*yieldPolicy*/, + planNodeId, + callbacks); + } else { + return sbe::makeS<sbe::ScanStage>(nss.uuid().get(), + rootSlot, + ridSlot, + boost::none, + boost::none, + boost::none, + boost::none, + boost::none, + fields, + vars, + seekKeySlot, + true /*forward*/, + nullptr /*yieldPolicy*/, + planNodeId, + callbacks, + _randomScan); + } + } else { + uasserted(6624355, "Unknown scan type."); + } + return nullptr; +} + +std::unique_ptr<sbe::PlanStage> SBENodeLowering::walk(const ScanNode& n, const ABT& /*binds*/) { + FieldProjectionMap fieldProjectionMap; + fieldProjectionMap._rootProjection = n.getProjectionName(); + return lowerScanNode(n, n.getScanDefName(), fieldProjectionMap, false /*useParallelScan*/); +} + +std::unique_ptr<sbe::PlanStage> SBENodeLowering::walk(const PhysicalScanNode& n, + const ABT& /*binds*/) { + return lowerScanNode(n, n.getScanDefName(), n.getFieldProjectionMap(), n.useParallelScan()); +} + +std::unique_ptr<sbe::PlanStage> SBENodeLowering::walk(const CoScanNode& n) { + const PlanNodeId planNodeId = _nodeToGroupPropsMap.at(&n)._planNodeId; + return sbe::makeS<sbe::CoScanStage>(planNodeId); +} + +std::unique_ptr<sbe::EExpression> SBENodeLowering::convertBoundsToExpr( + const bool isLower, + const IndexDefinition& indexDef, + const MultiKeyIntervalRequirement& interval) { + std::vector<std::unique_ptr<sbe::EExpression>> ksFnArgs; + ksFnArgs.emplace_back( + sbe::makeE<sbe::EConstant>(sbe::value::TypeTags::NumberInt64, + sbe::value::bitcastFrom<int64_t>(indexDef.getVersion()))); + + // TODO: ordering is unsigned int32?? 
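+    // Descriptive note: the generated expression has the overall shape
+    // ks(<version>, <ordering>, <bound_1>, ..., <bound_n>, <discriminator>), where the trailing
+    // discriminator constant appended at the end of this function is derived from the bound's
+    // inclusivity and direction.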
+ ksFnArgs.emplace_back( + sbe::makeE<sbe::EConstant>(sbe::value::TypeTags::NumberInt32, + sbe::value::bitcastFrom<uint32_t>(indexDef.getOrdering()))); + + auto exprLower = SBEExpressionLowering{_env, _slotMap}; + bool inclusive = true; + bool fullyInfinite = true; + for (const auto& entry : interval) { + const BoundRequirement& entryBound = isLower ? entry.getLowBound() : entry.getHighBound(); + const bool isInfinite = entryBound.isInfinite(); + if (!isInfinite) { + fullyInfinite = false; + if (!entryBound.isInclusive()) { + inclusive = false; + } + } + + ABT bound = isInfinite ? (isLower ? Constant::minKey() : Constant::maxKey()) + : entryBound.getBound(); + auto boundExpr = exprLower.optimize(std::move(bound)); + ksFnArgs.emplace_back(std::move(boundExpr)); + } + if (fullyInfinite && !isLower) { + // We can skip if fully infinite only for upper bound. For lower bound we need to generate + // minkeys. + return nullptr; + }; + + ksFnArgs.emplace_back(sbe::makeE<sbe::EConstant>( + sbe::value::TypeTags::NumberInt64, + sbe::value::bitcastFrom<int64_t>((isLower == inclusive) ? 1 : 2))); + return sbe::makeE<sbe::EFunction>("ks", toInlinedVector(std::move(ksFnArgs))); +} + +std::unique_ptr<sbe::PlanStage> SBENodeLowering::walk(const IndexScanNode& n, const ABT&) { + const auto& fieldProjectionMap = n.getFieldProjectionMap(); + const auto& indexSpec = n.getIndexSpecification(); + const auto& interval = indexSpec.getInterval(); + + const std::string& indexDefName = n.getIndexSpecification().getIndexDefName(); + const ScanDefinition& scanDef = _metadata._scanDefs.at(indexSpec.getScanDefName()); + uassert(6624232, "Collection must exist to lower IndexScan", scanDef.exists()); + const IndexDefinition& indexDef = scanDef.getIndexDefs().at(indexDefName); + + NamespaceStringOrUUID nss = parseFromScanDef(scanDef); + + boost::optional<sbe::value::SlotId> ridSlot; + boost::optional<sbe::value::SlotId> rootSlot; + std::vector<std::string> fields; + sbe::value::SlotVector vars; + generateSlots(fieldProjectionMap, ridSlot, rootSlot, fields, vars); + uassert(6624233, "Cannot deliver root projection in this context", !rootSlot.has_value()); + + sbe::IndexKeysInclusionSet indexKeysToInclude; + for (const std::string& fieldName : fields) { + indexKeysToInclude.set(decodeIndexKeyName(fieldName), true); + } + + auto lowerBoundExpr = convertBoundsToExpr(true /*isLower*/, indexDef, interval); + auto upperBoundExpr = convertBoundsToExpr(false /*isLower*/, indexDef, interval); + const bool hasLowerBound = lowerBoundExpr != nullptr; + const bool hasUpperBound = upperBoundExpr != nullptr; + uassert(6624234, "Invalid bounds combination", hasLowerBound || !hasUpperBound); + + boost::optional<sbe::value::SlotId> seekKeySlotLower; + boost::optional<sbe::value::SlotId> seekKeySlotUpper; + sbe::value::SlotVector correlatedSlotsForJoin; + + const PlanNodeId planNodeId = _nodeToGroupPropsMap.at(&n)._planNodeId; + auto projectForKeyStringBounds = sbe::makeS<sbe::LimitSkipStage>( + sbe::makeS<sbe::CoScanStage>(planNodeId), 1, boost::none, planNodeId); + if (hasLowerBound) { + seekKeySlotLower = _slotIdGenerator.generate(); + correlatedSlotsForJoin.push_back(seekKeySlotLower.value()); + projectForKeyStringBounds = sbe::makeProjectStage(std::move(projectForKeyStringBounds), + planNodeId, + seekKeySlotLower.value(), + std::move(lowerBoundExpr)); + } + if (hasUpperBound) { + seekKeySlotUpper = _slotIdGenerator.generate(); + correlatedSlotsForJoin.push_back(seekKeySlotUpper.value()); + projectForKeyStringBounds = 
sbe::makeProjectStage(std::move(projectForKeyStringBounds), + planNodeId, + seekKeySlotUpper.value(), + std::move(upperBoundExpr)); + } + + // Unused. + boost::optional<sbe::value::SlotId> resultSlot; + + auto result = sbe::makeS<sbe::IndexScanStage>(nss.uuid().get(), + indexDefName, + !indexSpec.isReverseOrder(), + resultSlot, + ridSlot, + boost::none, + indexKeysToInclude, + vars, + seekKeySlotLower, + seekKeySlotUpper, + nullptr /*yieldPolicy*/, + planNodeId); + + return sbe::makeS<sbe::LoopJoinStage>(std::move(projectForKeyStringBounds), + std::move(result), + sbe::makeSV(), + std::move(correlatedSlotsForJoin), + nullptr, + planNodeId); +} + +std::unique_ptr<sbe::PlanStage> SBENodeLowering::walk(const SeekNode& n, + const ABT& /*binds*/, + const ABT& /*refs*/) { + const ScanDefinition& def = _metadata._scanDefs.at(n.getScanDefName()); + uassert(6624235, "Collection must exist to lower Seek", def.exists()); + + auto& typeSpec = def.getOptionsMap().at("type"); + uassert(6624236, "SeekNode only supports mongod collections", typeSpec == "mongod"); + NamespaceStringOrUUID nss = parseFromScanDef(def); + + boost::optional<sbe::value::SlotId> ridSlot; + boost::optional<sbe::value::SlotId> rootSlot; + std::vector<std::string> fields; + sbe::value::SlotVector vars; + generateSlots(n.getFieldProjectionMap(), ridSlot, rootSlot, fields, vars); + + boost::optional<sbe::value::SlotId> seekKeySlot = _slotMap.at(n.getRIDProjectionName()); + + sbe::ScanCallbacks callbacks({}, {}, {}); + const PlanNodeId planNodeId = _nodeToGroupPropsMap.at(&n)._planNodeId; + return sbe::makeS<sbe::ScanStage>(nss.uuid().get(), + rootSlot, + ridSlot, + boost::none, + boost::none, + boost::none, + boost::none, + boost::none, + fields, + vars, + seekKeySlot, + true /*forward*/, + nullptr /*yieldPolicy*/, + planNodeId, + callbacks); +} + +} // namespace mongo::optimizer diff --git a/src/mongo/db/exec/sbe/abt/abt_lower.h b/src/mongo/db/exec/sbe/abt/abt_lower.h new file mode 100644 index 00000000000..17365f7925f --- /dev/null +++ b/src/mongo/db/exec/sbe/abt/abt_lower.h @@ -0,0 +1,204 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. 
+ */ + +#pragma once + +#include "mongo/db/exec/sbe/expressions/expression.h" +#include "mongo/db/exec/sbe/stages/stages.h" +#include "mongo/db/query/optimizer/node_defs.h" +#include "mongo/db/query/optimizer/reference_tracker.h" +#include "mongo/db/query/optimizer/utils/utils.h" + +namespace mongo::optimizer { +using SlotVarMap = stdx::unordered_map<std::string, sbe::value::SlotId>; + +class SBEExpressionLowering { +public: + SBEExpressionLowering(const VariableEnvironment& env, SlotVarMap& slotMap) + : _env(env), _slotMap(slotMap) {} + + // The default noop transport. + template <typename T, typename... Ts> + std::unique_ptr<sbe::EExpression> transport(const T&, Ts&&...) { + uasserted(6624237, "abt tree is not lowered correctly"); + return nullptr; + } + + std::unique_ptr<sbe::EExpression> transport(const Constant&); + std::unique_ptr<sbe::EExpression> transport(const Variable& var); + std::unique_ptr<sbe::EExpression> transport(const Source&); + std::unique_ptr<sbe::EExpression> transport(const BinaryOp& op, + std::unique_ptr<sbe::EExpression> lhs, + std::unique_ptr<sbe::EExpression> rhs); + std::unique_ptr<sbe::EExpression> transport(const UnaryOp& op, + std::unique_ptr<sbe::EExpression> arg); + std::unique_ptr<sbe::EExpression> transport(const If&, + std::unique_ptr<sbe::EExpression> cond, + std::unique_ptr<sbe::EExpression> thenBranch, + std::unique_ptr<sbe::EExpression> elseBranch); + + void prepare(const Let& let); + std::unique_ptr<sbe::EExpression> transport(const Let& let, + std::unique_ptr<sbe::EExpression> bind, + std::unique_ptr<sbe::EExpression> in); + void prepare(const LambdaAbstraction& lam); + std::unique_ptr<sbe::EExpression> transport(const LambdaAbstraction& lam, + std::unique_ptr<sbe::EExpression> body); + std::unique_ptr<sbe::EExpression> transport(const LambdaApplication&, + std::unique_ptr<sbe::EExpression> lam, + std::unique_ptr<sbe::EExpression> arg); + std::unique_ptr<sbe::EExpression> transport( + const FunctionCall& fn, std::vector<std::unique_ptr<sbe::EExpression>> args); + + std::unique_ptr<sbe::EExpression> optimize(const ABT& n); + +private: + const VariableEnvironment& _env; + SlotVarMap& _slotMap; + + sbe::FrameId _frameCounter{100}; + stdx::unordered_map<const Let*, sbe::FrameId> _letMap; + stdx::unordered_map<const LambdaAbstraction*, sbe::FrameId> _lambdaMap; +}; + +class SBENodeLowering { +public: + SBENodeLowering(const VariableEnvironment& env, + SlotVarMap& slotMap, + sbe::value::SlotIdGenerator& ids, + const Metadata& metadata, + const NodeToGroupPropsMap& nodeToGroupPropsMap, + const opt::unordered_map<std::string, ProjectionName>& ridProjections, + const bool randomScan = false) + : _env(env), + _slotMap(slotMap), + _slotIdGenerator(ids), + _metadata(metadata), + _nodeToGroupPropsMap(nodeToGroupPropsMap), + _ridProjections(ridProjections), + _randomScan(randomScan) {} + + // The default noop transport. + template <typename T, typename... Ts> + std::unique_ptr<sbe::PlanStage> walk(const T&, Ts&&...) 
{ + if constexpr (std::is_base_of_v<LogicalNode, T>) { + uasserted(6624238, "A physical plan should not contain exclusively logical nodes."); + } + return nullptr; + } + + std::unique_ptr<sbe::PlanStage> walk(const RootNode& n, const ABT& child, const ABT& refs); + std::unique_ptr<sbe::PlanStage> walk(const EvaluationNode& n, + const ABT& child, + const ABT& binds); + + std::unique_ptr<sbe::PlanStage> walk(const FilterNode& n, const ABT& child, const ABT& filter); + + std::unique_ptr<sbe::PlanStage> walk(const LimitSkipNode& n, const ABT& child); + std::unique_ptr<sbe::PlanStage> walk(const ExchangeNode& n, const ABT& child, const ABT& refs); + std::unique_ptr<sbe::PlanStage> walk(const CollationNode& n, const ABT& child, const ABT& refs); + + std::unique_ptr<sbe::PlanStage> walk(const UniqueNode& n, const ABT& child, const ABT& refs); + + std::unique_ptr<sbe::PlanStage> walk(const GroupByNode& n, + const ABT& child, + const ABT& aggBinds, + const ABT& aggRefs, + const ABT& gbBind, + const ABT& gbRefs); + + std::unique_ptr<sbe::PlanStage> walk(const BinaryJoinNode& n, + const ABT& leftChild, + const ABT& rightChild, + const ABT& filter); + std::unique_ptr<sbe::PlanStage> walk(const HashJoinNode& n, + const ABT& leftChild, + const ABT& rightChild, + const ABT& refs); + std::unique_ptr<sbe::PlanStage> walk(const MergeJoinNode& n, + const ABT& leftChild, + const ABT& rightChild, + const ABT& refs); + + std::unique_ptr<sbe::PlanStage> walk(const UnionNode& n, + const ABTVector& children, + const ABT& binder, + const ABT& refs); + + std::unique_ptr<sbe::PlanStage> walk(const UnwindNode& n, + const ABT& child, + const ABT& pidBind, + const ABT& refs); + + std::unique_ptr<sbe::PlanStage> walk(const ScanNode& n, const ABT& /*binds*/); + std::unique_ptr<sbe::PlanStage> walk(const PhysicalScanNode& n, const ABT& /*binds*/); + std::unique_ptr<sbe::PlanStage> walk(const CoScanNode& n); + + std::unique_ptr<sbe::PlanStage> walk(const IndexScanNode& n, const ABT& /*binds*/); + std::unique_ptr<sbe::PlanStage> walk(const SeekNode& n, + const ABT& /*binds*/, + const ABT& /*refs*/); + + std::unique_ptr<sbe::PlanStage> optimize(const ABT& n); + +private: + std::unique_ptr<sbe::PlanStage> lowerScanNode(const Node& n, + const std::string& scanDefName, + const FieldProjectionMap& fieldProjectionMap, + bool useParallelScan); + void generateSlots(const FieldProjectionMap& fieldProjectionMap, + boost::optional<sbe::value::SlotId>& ridSlot, + boost::optional<sbe::value::SlotId>& rootSlot, + std::vector<std::string>& fields, + sbe::value::SlotVector& vars); + + sbe::value::SlotVector convertProjectionsToSlots(const ProjectionNameVector& projectionNames); + sbe::value::SlotVector convertRequiredProjectionsToSlots( + const NodeProps& props, bool addRIDProjection, const ProjectionNameVector& toExclude = {}); + + std::unique_ptr<sbe::EExpression> convertBoundsToExpr( + bool isLower, const IndexDefinition& indexDef, const MultiKeyIntervalRequirement& interval); + + std::unique_ptr<sbe::PlanStage> generateInternal(const ABT& n); + + const VariableEnvironment& _env; + SlotVarMap& _slotMap; + sbe::value::SlotIdGenerator& _slotIdGenerator; + + const Metadata& _metadata; + const NodeToGroupPropsMap& _nodeToGroupPropsMap; + const opt::unordered_map<std::string, ProjectionName>& _ridProjections; + + // If true, will create scan nodes using a random cursor to support sampling. + // Currently only supported for single-threaded (non parallel-scanned) mongod collections. 
+ // TODO: handle cases where we have more than one collection scan. + const bool _randomScan; +}; + +} // namespace mongo::optimizer diff --git a/src/mongo/db/exec/sbe/abt/sbe_abt_diff_test.cpp b/src/mongo/db/exec/sbe/abt/sbe_abt_diff_test.cpp new file mode 100644 index 00000000000..3758bf6ac18 --- /dev/null +++ b/src/mongo/db/exec/sbe/abt/sbe_abt_diff_test.cpp @@ -0,0 +1,369 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. 
+ */ + +#include "mongo/db/concurrency/lock_state.h" +#include "mongo/db/exec/sbe/abt/abt_lower.h" +#include "mongo/db/exec/sbe/abt/sbe_abt_test_util.h" +#include "mongo/db/pipeline/abt/abt_document_source_visitor.h" +#include "mongo/db/pipeline/document_source_queue.h" +#include "mongo/db/query/optimizer/explain.h" +#include "mongo/db/query/optimizer/opt_phase_manager.h" +#include "mongo/db/query/plan_executor.h" +#include "mongo/db/query/plan_executor_factory.h" +#include "mongo/unittest/temp_dir.h" + +namespace mongo::optimizer { +namespace { + +ABT createValueArray(const std::vector<std::string>& jsonVector) { + const auto [tag, val] = sbe::value::makeNewArray(); + auto outerArrayPtr = sbe::value::getArrayView(val); + + for (const std::string& s : jsonVector) { + const auto [tag1, val1] = sbe::value::makeNewArray(); + auto innerArrayPtr = sbe::value::getArrayView(val1); + + const BSONObj& bsonObj = fromjson(s); + const auto [tag2, val2] = + sbe::value::copyValue(sbe::value::TypeTags::bsonObject, + sbe::value::bitcastFrom<const char*>(bsonObj.objdata())); + innerArrayPtr->push_back(tag2, val2); + + outerArrayPtr->push_back(tag1, val1); + } + + return make<Constant>(tag, val); +} + +using ContextFn = std::function<ServiceContext::UniqueOperationContext()>; +using ResultSet = std::vector<BSONObj>; + +ResultSet runSBEAST(const ContextFn& fn, + const std::string& pipelineStr, + const std::vector<std::string>& jsonVector) { + PrefixId prefixId; + Metadata metadata{{}}; + + auto pipeline = parsePipeline(pipelineStr, NamespaceString("test")); + + ABT tree = createValueArray(jsonVector); + + const ProjectionName scanProjName = prefixId.getNextId("scan"); + tree = translatePipelineToABT( + metadata, + *pipeline.get(), + scanProjName, + make<ValueScanNode>(ProjectionNameVector{scanProjName}, std::move(tree)), + prefixId); + + std::cerr << "********* Translated ABT *********\n"; + std::cerr << ExplainGenerator::explainV2(tree); + std::cerr << "********* Translated ABT *********\n"; + + OptPhaseManager phaseManager( + OptPhaseManager::getAllRewritesSet(), prefixId, {{}}, DebugInfo::kDefaultForTests); + ASSERT_TRUE(phaseManager.optimize(tree)); + + std::cerr << "********* Optimized ABT *********\n"; + std::cerr << ExplainGenerator::explainV2(tree); + std::cerr << "********* Optimized ABT *********\n"; + + SlotVarMap map; + sbe::value::SlotIdGenerator ids; + + auto env = VariableEnvironment::build(tree); + SBENodeLowering g{env, + map, + ids, + phaseManager.getMetadata(), + phaseManager.getNodeToGroupPropsMap(), + phaseManager.getRIDProjections()}; + auto sbePlan = g.optimize(tree); + uassert(6624249, "Cannot optimize SBE plan", sbePlan != nullptr); + + sbe::CompileCtx ctx(std::make_unique<sbe::RuntimeEnvironment>()); + sbePlan->prepare(ctx); + + std::vector<sbe::value::SlotAccessor*> accessors; + for (auto& [name, slot] : map) { + accessors.emplace_back(sbePlan->getAccessor(ctx, slot)); + } + // For now assert we only have one final projection. 
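// Clarifying note (an assumption, not asserted by the patch beyond the check that follows):
// after lowering, the slot map holds one slot per projection still referenced at the root of
// the plan, and for the pipelines exercised here that is the single root projection, so every
// result row is read from accessors[0] below.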
+ ASSERT_EQ(1, accessors.size()); + + sbePlan->attachToOperationContext(fn().get()); + sbePlan->open(false); + + ResultSet results; + while (sbePlan->getNext() != sbe::PlanState::IS_EOF) { + if (results.size() > 1000) { + uasserted(6624250, "Too many results!"); + } + + std::ostringstream os; + os << accessors.at(0)->getViewOfValue(); + results.push_back(fromjson(os.str())); + }; + sbePlan->close(); + + return results; +} + +ResultSet runPipeline(const ContextFn& fn, + const std::string& pipelineStr, + const std::vector<std::string>& jsonVector) { + NamespaceString nss("test"); + std::unique_ptr<mongo::Pipeline, mongo::PipelineDeleter> pipeline = + parsePipeline(pipelineStr, nss); + + const auto queueStage = DocumentSourceQueue::create(pipeline->getContext()); + for (const std::string& s : jsonVector) { + BSONObj bsonObj = fromjson(s); + queueStage->emplace_back(Document{bsonObj}); + } + + pipeline->addInitialSource(queueStage); + + boost::intrusive_ptr<ExpressionContext> expCtx; + expCtx.reset(new ExpressionContext(fn().get(), nullptr, nss)); + + std::unique_ptr<PlanExecutor, PlanExecutor::Deleter> planExec = + plan_executor_factory::make(expCtx, std::move(pipeline)); + + ResultSet results; + bool done = false; + while (!done) { + BSONObj outObj; + auto result = planExec->getNext(&outObj, nullptr); + switch (result) { + case PlanExecutor::ADVANCED: + results.push_back(outObj); + break; + case PlanExecutor::IS_EOF: + done = true; + break; + } + } + + return results; +} + +bool compareBSONObj(const BSONObj& actual, const BSONObj& expected, const bool preserveFieldOrder) { + BSONObj::ComparisonRulesSet rules = BSONObj::ComparisonRules::kConsiderFieldName; + if (!preserveFieldOrder) { + rules |= BSONObj::ComparisonRules::kIgnoreFieldOrder; + } + return actual.woCompare(expected, BSONObj(), rules) == 0; +} + +bool compareResults(const ResultSet& expected, + const ResultSet& actual, + const bool preserveFieldOrder) { + if (expected.size() != actual.size()) { + std::cout << "Different result size: expected: " << expected.size() + << " vs actual: " << actual.size() << "\n"; + if (!expected.empty()) { + std::cout << "First expected result: " << expected.front() << "\n"; + } + if (!actual.empty()) { + std::cout << "First actual result: " << actual.front() << "\n"; + } + return false; + } + + for (size_t i = 0; i < expected.size(); i++) { + if (!compareBSONObj(actual.at(i), expected.at(i), preserveFieldOrder)) { + std::cout << "Result at position " << i << "/" << expected.size() + << " mismatch: expected: " << expected.at(i) << " vs actual: " << actual.at(i) + << "\n"; + return false; + } + } + + return true; +} + +bool compareSBEABTAgainstExpected(const ContextFn& fn, + const std::string& pipelineStr, + const std::vector<std::string>& jsonVector, + const ResultSet& expected) { + const ResultSet& actual = runSBEAST(fn, pipelineStr, jsonVector); + return compareResults(expected, actual, true /*preserveFieldOrder*/); +} + +bool comparePipelineAgainstExpected(const ContextFn& fn, + const std::string& pipelineStr, + const std::vector<std::string>& jsonVector, + const ResultSet& expected) { + const ResultSet& actual = runPipeline(fn, pipelineStr, jsonVector); + return compareResults(expected, actual, true /*preserveFieldOrder*/); +} + +bool compareSBEABTAgainstPipeline(const ContextFn& fn, + const std::string& pipelineStr, + const std::vector<std::string>& jsonVector, + const bool preserveFieldOrder = true) { + const ResultSet& pipelineResults = runPipeline(fn, pipelineStr, jsonVector); + const 
ResultSet& sbeResults = runSBEAST(fn, pipelineStr, jsonVector); + + std::cout << "Pipeline: " << pipelineStr << ", input size: " << jsonVector.size() << "\n"; + const bool result = compareResults(pipelineResults, sbeResults, preserveFieldOrder); + if (result) { + std::cout << "Success. Result count: " << pipelineResults.size() << "\n"; + constexpr size_t maxResults = 1; + for (size_t i = 0; i < std::min(pipelineResults.size(), maxResults); i++) { + std::cout << "Result " << (i + 1) << "/" << pipelineResults.size() + << ": expected (pipeline): " << pipelineResults.at(i) + << " vs actual (SBE): " << sbeResults.at(i) << "\n"; + } + } + + return result; +} + +ResultSet toResultSet(const std::vector<std::string>& jsonVector) { + ResultSet results; + for (const std::string& jsonStr : jsonVector) { + results.emplace_back(fromjson(jsonStr)); + } + return results; +} + +class TestObserver : public ServiceContext::ClientObserver { +public: + TestObserver() = default; + ~TestObserver() = default; + + void onCreateClient(Client* client) final {} + + void onDestroyClient(Client* client) final {} + + void onCreateOperationContext(OperationContext* opCtx) override { + opCtx->setLockState(std::make_unique<LockerImpl>()); + } + + void onDestroyOperationContext(OperationContext* opCtx) final {} +}; + +const ServiceContext::ConstructorActionRegisterer clientObserverRegisterer{ + "TestObserver", + [](ServiceContext* service) { + service->registerClientObserver(std::make_unique<TestObserver>()); + }, + [](ServiceContext* serviceContext) {}}; + +TEST_F(NodeSBE, DiffTestBasic) { + const auto contextFn = [this]() { return makeOperationContext(); }; + const auto compare = [&contextFn](const std::string& pipelineStr, + const std::vector<std::string>& jsonVector) { + return compareSBEABTAgainstPipeline( + contextFn, pipelineStr, jsonVector, true /*preserveFieldOrder*/); + }; + + ASSERT_TRUE(compareSBEABTAgainstExpected( + contextFn, "[]", {"{a:1, b:2, c:3}"}, toResultSet({"{ a: 1, b: 2, c: 3 }"}))); + ASSERT_TRUE(compareSBEABTAgainstExpected(contextFn, + "[{$addFields: {c: {$literal: 3}}}]", + {"{a:1, b:2}"}, + toResultSet({"{ a: 1, b: 2, c: 3 }"}))); + + ASSERT_TRUE(comparePipelineAgainstExpected( + contextFn, "[]", {"{a:1, b:2, c:3}"}, toResultSet({"{ a: 1, b: 2, c: 3 }"}))); + ASSERT_TRUE(comparePipelineAgainstExpected(contextFn, + "[{$addFields: {c: {$literal: 3}}}]", + {"{a:1, b:2}"}, + toResultSet({"{ a: 1, b: 2, c: 3 }"}))); + + ASSERT_TRUE(compare("[]", {"{a:1, b:2, c:3}"})); + ASSERT_TRUE(compare("[{$addFields: {c: {$literal: 3}}}]", {"{a:1, b:2}"})); +} + +TEST_F(NodeSBE, DiffTest) { + const auto contextFn = [this]() { return makeOperationContext(); }; + const auto compare = [&contextFn](const std::string& pipelineStr, + const std::vector<std::string>& jsonVector) { + return compareSBEABTAgainstPipeline( + contextFn, pipelineStr, jsonVector, true /*preserveFieldOrder*/); + }; + + // Consider checking if compare() works first. 
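// Illustrative example (hypothetical engine behavior): a computed dotted-path projection such
// as {$project: {'a.b': '$c', c: 1}} may yield {c: 2, a: {b: 2}} from one engine and
// {a: {b: 2}, c: 2} from the other. The documents differ only in field order, which is why
// some of the cases below are checked with compareUnordered rather than compare.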
+ const auto compareUnordered = [&contextFn](const std::string& pipelineStr, + const std::vector<std::string>& jsonVector) { + return compareSBEABTAgainstPipeline( + contextFn, pipelineStr, jsonVector, false /*preserveFieldOrder*/); + }; + + ASSERT_TRUE(compare("[]", {})); + + ASSERT_TRUE(compare("[{$project: {a: 1, b: 1}}]", {"{a: 10, b: 20, c: 30}"})); + ASSERT_TRUE(compare("[{$match: {a: 2}}]", {"{a: [1, 2, 3, 4]}"})); + ASSERT_TRUE(compare("[{$match: {a: 5}}]", {"{a: [1, 2, 3, 4]}"})); + ASSERT_TRUE(compare("[{$match: {a: {$gte: 3}}}]", {"{a: [1, 2, 3, 4]}"})); + ASSERT_TRUE(compare("[{$match: {a: {$gte: 30}}}]", {"{a: [1, 2, 3, 4]}"})); + ASSERT_TRUE( + compare("[{$match: {a: {$elemMatch: {$gte: 2, $lte: 3}}}}]", {"{a: [1, 2, 3, 4]}"})); + ASSERT_TRUE( + compare("[{$match: {a: {$elemMatch: {$gte: 20, $lte: 30}}}}]", {"{a: [1, 2, 3, 4]}"})); + + ASSERT_TRUE(compare("[{$project: {'a.b': '$c'}}]", {"{a: {d: 1}, c: 2}"})); + ASSERT_TRUE(compare("[{$project: {'a.b': '$c'}}]", {"{a: [{d: 1}, {d: 2}, {b: 10}], c: 2}"})); + + ASSERT_TRUE(compareUnordered("[{$project: {'a.b': '$c', c: 1}}]", {"{a: {d: 1}, c: 2}"})); + ASSERT_TRUE(compareUnordered("[{$project: {'a.b': '$c', 'a.d': 1, c: 1}}]", + {"{a: [{d: 1}, {d: 2}, {b: 10}], c: 2}"})); + + ASSERT_TRUE( + compare("[{$project: {a: {$filter: {input: '$b', as: 'num', cond: {$and: [{$gte: ['$$num', " + "2]}, {$lte: ['$$num', 3]}]}}}}}]", + {"{b: [1, 2, 3, 4]}"})); + ASSERT_TRUE( + compare("[{$project: {a: {$filter: {input: '$b', as: 'num', cond: {$and: [{$gte: ['$$num', " + "3]}, {$lte: ['$$num', 2]}]}}}}}]", + {"{b: [1, 2, 3, 4]}"})); + + ASSERT_TRUE(compare("[{$unwind: {path: '$a'}}]", {"{a: [1, 2, 3, 4]}"})); + ASSERT_TRUE(compare("[{$unwind: {path: '$a.b'}}]", {"{a: {b: [1, 2, 3, 4]}}"})); + + ASSERT_TRUE(compare("[{$match:{'a.b.c':'aaa'}}]", {"{a: {b: {c: 'aaa'}}}"})); + ASSERT_TRUE( + compare("[{$match:{'a.b.c':'aaa'}}]", {"{a: {b: {c: 'aaa'}}}", "{a: {b: {c: 'aaa'}}}"})); + + ASSERT_TRUE(compare("[{$match: {a: {$lt: 5, $gt: 5}}}]", {"{_id: 1, a: [4, 6]}"})); + ASSERT_TRUE(compare("[{$match: {a: {$gt: null}}}]", {"{_id: 1, a: 1}"})); + + ASSERT_TRUE(compare("[{$match: {a: {$elemMatch: {$lt: 6, $gt: 4}}}}]", {"{a: [5]}"})); + ASSERT_TRUE(compare("[{$match: {'a.b': {$elemMatch: {$lt: 6, $gt: 4}}}}]", + {"{a: {b: [5]}}", "{a: [{b: 5}]}"})); + + ASSERT_TRUE(compare("[{$match: {a: {$elemMatch: {$elemMatch: {$lt: 6, $gt: 4}}}}}]", + {"{a: [[4, 5, 6], [5]]}", "{a: [4, 5, 6]}"})); +} + +} // namespace +} // namespace mongo::optimizer diff --git a/src/mongo/db/exec/sbe/abt/sbe_abt_test.cpp b/src/mongo/db/exec/sbe/abt/sbe_abt_test.cpp new file mode 100644 index 00000000000..a8dbb516982 --- /dev/null +++ b/src/mongo/db/exec/sbe/abt/sbe_abt_test.cpp @@ -0,0 +1,311 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. 
+ * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include "mongo/db/exec/sbe/abt/abt_lower.h" +#include "mongo/db/exec/sbe/abt/sbe_abt_test_util.h" +#include "mongo/db/pipeline/abt/abt_document_source_visitor.h" +#include "mongo/db/query/optimizer/explain.h" +#include "mongo/db/query/optimizer/opt_phase_manager.h" +#include "mongo/db/query/optimizer/rewrites/const_eval.h" +#include "mongo/db/query/optimizer/rewrites/path_lower.h" +#include "mongo/unittest/unittest.h" + +namespace mongo::optimizer { +namespace { + +TEST_F(ABTSBE, Lower1) { + auto tree = Constant::int64(100); + auto env = VariableEnvironment::build(tree); + SlotVarMap map; + + auto expr = SBEExpressionLowering{env, map}.optimize(tree); + + ASSERT(expr); + + auto compiledExpr = compileExpression(*expr); + auto [resultTag, resultVal] = runCompiledExpression(compiledExpr.get()); + + ASSERT_EQ(sbe::value::TypeTags::NumberInt64, resultTag); + ASSERT_EQ(sbe::value::bitcastTo<int64_t>(resultVal), 100); +} + +TEST_F(ABTSBE, Lower2) { + auto tree = + make<Let>("x", + Constant::int64(100), + make<BinaryOp>(Operations::Add, make<Variable>("x"), Constant::int64(100))); + + auto env = VariableEnvironment::build(tree); + SlotVarMap map; + + auto expr = SBEExpressionLowering{env, map}.optimize(tree); + + ASSERT(expr); + + auto compiledExpr = compileExpression(*expr); + auto [resultTag, resultVal] = runCompiledExpression(compiledExpr.get()); + + ASSERT_EQ(sbe::value::TypeTags::NumberInt64, resultTag); + ASSERT_EQ(sbe::value::bitcastTo<int64_t>(resultVal), 200); +} + +TEST_F(ABTSBE, Lower3) { + auto tree = make<FunctionCall>("isNumber", makeSeq(Constant::int64(10))); + auto env = VariableEnvironment::build(tree); + SlotVarMap map; + + auto expr = SBEExpressionLowering{env, map}.optimize(tree); + + ASSERT(expr); + + auto compiledExpr = compileExpression(*expr); + auto result = runCompiledExpressionPredicate(compiledExpr.get()); + + ASSERT(result); +} + +TEST_F(ABTSBE, Lower4) { + auto [tagArr, valArr] = sbe::value::makeNewArray(); + auto arr = sbe::value::getArrayView(valArr); + arr->push_back(sbe::value::TypeTags::NumberInt64, 1); + arr->push_back(sbe::value::TypeTags::NumberInt64, 2); + auto [tagArrNest, valArrNest] = sbe::value::makeNewArray(); + auto arrNest = sbe::value::getArrayView(valArrNest); + arrNest->push_back(sbe::value::TypeTags::NumberInt64, 21); + arrNest->push_back(sbe::value::TypeTags::NumberInt64, 22); + arr->push_back(tagArrNest, valArrNest); + arr->push_back(sbe::value::TypeTags::NumberInt64, 3); + + auto tree = make<FunctionCall>( + "traverseP", + makeSeq( + make<Constant>(tagArr, valArr), + make<LambdaAbstraction>( + "x", make<BinaryOp>(Operations::Add, make<Variable>("x"), Constant::int64(10))))); + auto env = VariableEnvironment::build(tree); + SlotVarMap map; + + auto expr = 
SBEExpressionLowering{env, map}.optimize(tree); + + ASSERT(expr); + + auto compiledExpr = compileExpression(*expr); + auto [resultTag, resultVal] = runCompiledExpression(compiledExpr.get()); + sbe::value::ValueGuard guard(resultTag, resultVal); + + ASSERT_EQ(sbe::value::TypeTags::Array, resultTag); +} + +TEST_F(ABTSBE, Lower5) { + auto tree = make<FunctionCall>( + "setField", makeSeq(Constant::nothing(), Constant::str("fieldA"), Constant::int64(10))); + + auto env = VariableEnvironment::build(tree); + SlotVarMap map; + + auto expr = SBEExpressionLowering{env, map}.optimize(tree); + + ASSERT(expr); + + auto compiledExpr = compileExpression(*expr); + auto [resultTag, resultVal] = runCompiledExpression(compiledExpr.get()); + sbe::value::ValueGuard guard(resultTag, resultVal); +} + +TEST_F(ABTSBE, Lower6) { + PrefixId prefixId; + + auto [tagObj, valObj] = sbe::value::makeNewObject(); + auto obj = sbe::value::getObjectView(valObj); + + auto [tagObjIn, valObjIn] = sbe::value::makeNewObject(); + auto objIn = sbe::value::getObjectView(valObjIn); + objIn->push_back("fieldB", sbe::value::TypeTags::NumberInt64, 100); + obj->push_back("fieldA", tagObjIn, valObjIn); + + sbe::value::OwnedValueAccessor accessor; + auto slotId = bindAccessor(&accessor); + SlotVarMap map; + map["root"] = slotId; + + accessor.reset(tagObj, valObj); + + auto tree = make<EvalPath>( + make<PathField>("fieldA", + make<PathTraverse>(make<PathComposeM>( + make<PathField>("fieldB", make<PathDefault>(Constant::int64(0))), + make<PathField>("fieldC", make<PathConstant>(Constant::int64(50)))))), + make<Variable>("root")); + auto env = VariableEnvironment::build(tree); + + // Run rewriters while things change + bool changed = false; + do { + changed = false; + if (PathLowering{prefixId, env}.optimize(tree)) { + changed = true; + } + if (ConstEval{env}.optimize(tree)) { + changed = true; + } + } while (changed); + + // std::cout << ExplainGenerator::explain(tree); + + auto expr = SBEExpressionLowering{env, map}.optimize(tree); + + ASSERT(expr); + + auto compiledExpr = compileExpression(*expr); + auto [resultTag, resultVal] = runCompiledExpression(compiledExpr.get()); + sbe::value::ValueGuard guard(resultTag, resultVal); + + // std::cout << std::pair{resultTag, resultVal} << "\n"; + + ASSERT_EQ(sbe::value::TypeTags::Object, resultTag); +} + +TEST_F(ABTSBE, Lower7) { + PrefixId prefixId; + + auto [tagArr, valArr] = sbe::value::makeNewArray(); + auto arr = sbe::value::getArrayView(valArr); + arr->push_back(sbe::value::TypeTags::NumberInt64, 1); + arr->push_back(sbe::value::TypeTags::NumberInt64, 2); + arr->push_back(sbe::value::TypeTags::NumberInt64, 3); + + auto [tagObj, valObj] = sbe::value::makeNewObject(); + auto obj = sbe::value::getObjectView(valObj); + obj->push_back("fieldA", tagArr, valArr); + + sbe::value::OwnedValueAccessor accessor; + auto slotId = bindAccessor(&accessor); + SlotVarMap map; + map["root"] = slotId; + + accessor.reset(tagObj, valObj); + auto tree = make<EvalFilter>( + make<PathGet>("fieldA", + make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(2)))), + make<Variable>("root")); + + auto env = VariableEnvironment::build(tree); + + // Run rewriters while things change + bool changed = false; + do { + changed = false; + if (PathLowering{prefixId, env}.optimize(tree)) { + changed = true; + } + if (ConstEval{env}.optimize(tree)) { + changed = true; + } + } while (changed); + + auto expr = SBEExpressionLowering{env, map}.optimize(tree); + + ASSERT(expr); + auto compiledExpr = compileExpression(*expr); 
+ auto result = runCompiledExpressionPredicate(compiledExpr.get()); + + ASSERT(result); +} + +TEST_F(NodeSBE, Lower1) { + PrefixId prefixId; + Metadata metadata{{}}; + + auto pipeline = + parsePipeline("[{$project:{'a.b.c.d':{$literal:'abc'}}}]", NamespaceString("test")); + + const auto [tag, val] = sbe::value::makeNewArray(); + { + // Create an array of array with one empty document. + auto outerArrayPtr = sbe::value::getArrayView(val); + + const auto [tag1, val1] = sbe::value::makeNewArray(); + auto innerArrayPtr = sbe::value::getArrayView(val1); + + const auto [tag2, val2] = sbe::value::makeNewObject(); + innerArrayPtr->push_back(tag2, val2); + + outerArrayPtr->push_back(tag1, val1); + } + ABT tree = make<Constant>(tag, val); + + const ProjectionName scanProjName = prefixId.getNextId("scan"); + tree = translatePipelineToABT( + metadata, + *pipeline.get(), + scanProjName, + make<ValueScanNode>(ProjectionNameVector{scanProjName}, std::move(tree)), + prefixId); + + OptPhaseManager phaseManager( + OptPhaseManager::getAllRewritesSet(), prefixId, {{}}, DebugInfo::kDefaultForTests); + + ASSERT_TRUE(phaseManager.optimize(tree)); + auto env = VariableEnvironment::build(tree); + SlotVarMap map; + sbe::value::SlotIdGenerator ids; + + SBENodeLowering g{env, + map, + ids, + phaseManager.getMetadata(), + phaseManager.getNodeToGroupPropsMap(), + phaseManager.getRIDProjections()}; + + auto sbePlan = g.optimize(tree); + + auto opCtx = makeOperationContext(); + + sbe::CompileCtx ctx(std::make_unique<sbe::RuntimeEnvironment>()); + sbePlan->prepare(ctx); + + std::vector<sbe::value::SlotAccessor*> accessors; + for (auto& [name, slot] : map) { + std::cout << name << " "; + accessors.emplace_back(sbePlan->getAccessor(ctx, slot)); + } + std::cout << "\n"; + sbePlan->attachToOperationContext(opCtx.get()); + sbePlan->open(false); + while (sbePlan->getNext() != sbe::PlanState::IS_EOF) { + for (auto acc : accessors) { + std::cout << acc->getViewOfValue() << " "; + } + std::cout << "\n"; + }; + sbePlan->close(); +} + +} // namespace +} // namespace mongo::optimizer diff --git a/src/mongo/db/exec/sbe/abt/sbe_abt_test_util.cpp b/src/mongo/db/exec/sbe/abt/sbe_abt_test_util.cpp new file mode 100644 index 00000000000..95b9796c324 --- /dev/null +++ b/src/mongo/db/exec/sbe/abt/sbe_abt_test_util.cpp @@ -0,0 +1,63 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. 
If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include "mongo/db/exec/sbe/abt/sbe_abt_test_util.h" +#include "mongo/db/pipeline/aggregate_command_gen.h" +#include "mongo/db/pipeline/expression_context_for_test.h" +#include "mongo/unittest/temp_dir.h" + +namespace mongo::optimizer { + +static std::unique_ptr<mongo::Pipeline, mongo::PipelineDeleter> parsePipelineInternal( + NamespaceString nss, const std::string& inputPipeline, OperationContextNoop& opCtx) { + const BSONObj inputBson = fromjson("{pipeline: " + inputPipeline + "}"); + + std::vector<BSONObj> rawPipeline; + for (auto&& stageElem : inputBson["pipeline"].Array()) { + ASSERT_EQUALS(stageElem.type(), BSONType::Object); + rawPipeline.push_back(stageElem.embeddedObject()); + } + + AggregateCommandRequest request(std::move(nss), rawPipeline); + boost::intrusive_ptr<ExpressionContextForTest> ctx( + new ExpressionContextForTest(&opCtx, request)); + + unittest::TempDir tempDir("ABTPipelineTest"); + ctx->tempDir = tempDir.path(); + + return Pipeline::parse(request.getPipeline(), ctx); +} + +std::unique_ptr<mongo::Pipeline, mongo::PipelineDeleter> parsePipeline( + const std::string& pipelineStr, NamespaceString nss) { + OperationContextNoop opCtx; + return parsePipelineInternal(std::move(nss), pipelineStr, opCtx); +} + +} // namespace mongo::optimizer diff --git a/src/mongo/db/exec/sbe/abt/sbe_abt_test_util.h b/src/mongo/db/exec/sbe/abt/sbe_abt_test_util.h new file mode 100644 index 00000000000..c56d035b6ff --- /dev/null +++ b/src/mongo/db/exec/sbe/abt/sbe_abt_test_util.h @@ -0,0 +1,48 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. 
+ */ + +#include "mongo/db/exec/sbe/expression_test_base.h" +#include "mongo/db/operation_context_noop.h" +#include "mongo/db/pipeline/pipeline.h" +#include "mongo/db/query/optimizer/node.h" +#include "mongo/db/query/optimizer/utils/utils.h" +#include "mongo/db/service_context_test_fixture.h" + +#pragma once + +namespace mongo::optimizer { + +class NodeSBE : public ServiceContextTest {}; + +std::unique_ptr<mongo::Pipeline, mongo::PipelineDeleter> parsePipeline( + const std::string& pipelineStr, NamespaceString nss); + +using ABTSBE = sbe::EExpressionTestFixture; + +} // namespace mongo::optimizer diff --git a/src/mongo/db/exec/sbe/stages/exchange.cpp b/src/mongo/db/exec/sbe/stages/exchange.cpp index 95b143ebec5..8cd7b065559 100644 --- a/src/mongo/db/exec/sbe/stages/exchange.cpp +++ b/src/mongo/db/exec/sbe/stages/exchange.cpp @@ -33,6 +33,7 @@ #include "mongo/base/init.h" #include "mongo/db/client.h" +#include "mongo/db/concurrency/d_concurrency.h" #include "mongo/db/exec/sbe/size_estimator.h" namespace mongo::sbe { @@ -389,7 +390,10 @@ void ExchangeConsumer::close() { std::unique_ptr<PlanStageStats> ExchangeConsumer::getStats(bool includeDebugInfo) const { auto ret = std::make_unique<PlanStageStats>(_commonStats); - ret->children.emplace_back(_children[0]->getStats(includeDebugInfo)); + if (!_children.empty()) { + // TODO: handle empty _children. + ret->children.emplace_back(_children[0]->getStats(includeDebugInfo)); + } return ret; } @@ -428,8 +432,11 @@ std::vector<DebugPrinter::Block> ExchangeConsumer::debugPrint() const { uasserted(4822835, "policy not yet implemented"); } - DebugPrinter::addNewLine(ret); - DebugPrinter::addBlocks(ret, _children[0]->debugPrint()); + if (!_children.empty()) { + // TODO: handle empty _children. + DebugPrinter::addNewLine(ret); + DebugPrinter::addBlocks(ret, _children[0]->debugPrint()); + } return ret; } @@ -497,6 +504,9 @@ void ExchangeProducer::start(OperationContext* opCtx, std::unique_ptr<PlanStage> producer) { ExchangeProducer* p = static_cast<ExchangeProducer*>(producer.get()); + // TODO: SERVER-62925. Rationalize this lock. + Lock::GlobalLock lock(opCtx, MODE_IS); + p->attachToOperationContext(opCtx); try { diff --git a/src/mongo/db/exec/sbe/stages/scan.cpp b/src/mongo/db/exec/sbe/stages/scan.cpp index ad3d1b7e06b..cf1e40bf6cd 100644 --- a/src/mongo/db/exec/sbe/stages/scan.cpp +++ b/src/mongo/db/exec/sbe/stages/scan.cpp @@ -55,7 +55,8 @@ ScanStage::ScanStage(UUID collectionUuid, bool forward, PlanYieldPolicy* yieldPolicy, PlanNodeId nodeId, - ScanCallbacks scanCallbacks) + ScanCallbacks scanCallbacks, + bool useRandomCursor) : PlanStage(seekKeySlot ? "seek"_sd : "scan"_sd, yieldPolicy, nodeId), _collUuid(collectionUuid), _recordSlot(recordSlot), @@ -69,7 +70,8 @@ ScanStage::ScanStage(UUID collectionUuid, _vars(std::move(vars)), _seekKeySlot(seekKeySlot), _forward(forward), - _scanCallbacks(std::move(scanCallbacks)) { + _scanCallbacks(std::move(scanCallbacks)), + _useRandomCursor(useRandomCursor) { invariant(_fields.size() == _vars.size()); invariant(!_seekKeySlot || _forward); tassert(5567202, @@ -77,6 +79,8 @@ ScanStage::ScanStage(UUID collectionUuid, !_oplogTsSlot || (std::find(_fields.begin(), _fields.end(), repl::OpTime::kTimestampFieldName) != _fields.end())); + // We cannot use a random cursor if we are seeking or requesting a reverse scan. 
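// For illustration only (hypothetical, simplified argument list): a sampling scan is the one
// combination the invariant below permits, i.e. a random cursor with no seek key and a forward
// direction, roughly
//   sbe::makeS<sbe::ScanStage>(collUuid, recordSlot, ridSlot, /* optional slots */ ...,
//                              fields, vars, boost::none /*seekKeySlot*/, true /*forward*/,
//                              nullptr /*yieldPolicy*/, nodeId, callbacks,
//                              true /*useRandomCursor*/);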
+ invariant(!_useRandomCursor || (!_seekKeySlot && _forward)); } std::unique_ptr<PlanStage> ScanStage::clone() const { @@ -197,12 +201,12 @@ void ScanStage::doSaveState(bool relinquishCursor) { #endif - if (_cursor && relinquishCursor) { - _cursor->save(); + if (auto cursor = getActiveCursor(); cursor != nullptr && relinquishCursor) { + cursor->save(); } - if (_cursor) { - _cursor->setSaveStorageCursorOnDetachFromOperationContext(!relinquishCursor); + if (auto cursor = getActiveCursor()) { + cursor->setSaveStorageCursorOnDetachFromOperationContext(!relinquishCursor); } _coll.reset(); @@ -220,10 +224,10 @@ void ScanStage::doRestoreState(bool relinquishCursor) { tassert(5777408, "Catalog epoch should be initialized", _catalogEpoch); _coll = restoreCollection(_opCtx, *_collName, _collUuid, *_catalogEpoch); - if (_cursor) { + if (auto cursor = getActiveCursor(); cursor != nullptr) { if (relinquishCursor) { const auto tolerateCappedCursorRepositioning = false; - const bool couldRestore = _cursor->restore(tolerateCappedCursorRepositioning); + const bool couldRestore = cursor->restore(tolerateCappedCursorRepositioning); uassert( ErrorCodes::CappedPositionLost, str::stream() @@ -262,14 +266,14 @@ void ScanStage::doRestoreState(bool relinquishCursor) { } void ScanStage::doDetachFromOperationContext() { - if (_cursor) { - _cursor->detachFromOperationContext(); + if (auto cursor = getActiveCursor()) { + cursor->detachFromOperationContext(); } } void ScanStage::doAttachToOperationContext(OperationContext* opCtx) { - if (_cursor) { - _cursor->reattachToOperationContext(opCtx); + if (auto cursor = getActiveCursor()) { + cursor->reattachToOperationContext(opCtx); } } @@ -283,6 +287,10 @@ PlanStage::TrialRunTrackerAttachResultMask ScanStage::doAttachToTrialRunTracker( return childrenAttachResult | TrialRunTrackerAttachResultFlags::AttachedToStreamingStage; } +RecordCursor* ScanStage::getActiveCursor() const { + return _useRandomCursor ? _randomCursor.get() : _cursor.get(); +} + void ScanStage::open(bool reOpen) { auto optTimer(getOptTimer(_opCtx)); @@ -292,13 +300,13 @@ void ScanStage::open(bool reOpen) { if (_open) { tassert(5071001, "reopened ScanStage but reOpen=false", reOpen); tassert(5071002, "ScanStage is open but _coll is not null", _coll); - tassert(5071003, "ScanStage is open but don't have _cursor", _cursor); + tassert(5071003, "ScanStage is open but doesn't have a cursor", getActiveCursor()); } else { tassert(5071004, "first open to ScanStage but reOpen=true", !reOpen); if (!_coll) { // We're being opened after 'close()'. We need to re-acquire '_coll' in this case and // make some validity checks (the collection has not been dropped, renamed, etc.). - tassert(5071005, "ScanStage is not open but have _cursor", !_cursor); + tassert(5071005, "ScanStage is not open but has a cursor", !getActiveCursor()); tassert(5777401, "Collection name should be initialized", _collName); tassert(5777402, "Catalog epoch should be initialized", _catalogEpoch); _coll = restoreCollection(_opCtx, *_collName, _collUuid, *_catalogEpoch); @@ -321,7 +329,11 @@ void ScanStage::open(bool reOpen) { } if (!_cursor || !_seekKeyAccessor) { - _cursor = _coll->getCursor(_opCtx, _forward); + if (_useRandomCursor) { + _randomCursor = _coll->getRecordStore()->getRandomCursor(_opCtx); + } else { + _cursor = _coll->getCursor(_opCtx, _forward); + } } } else { MONGO_UNREACHABLE_TASSERT(5959701); @@ -355,7 +367,8 @@ PlanState ScanStage::getNext() { } auto res = _firstGetNext && _seekKeyAccessor; - auto nextRecord = res ? 
_cursor->seekExact(_key) : _cursor->next(); + auto nextRecord = _useRandomCursor ? _randomCursor->next() + : (res ? _cursor->seekExact(_key) : _cursor->next()); _firstGetNext = false; if (!nextRecord) { @@ -445,6 +458,7 @@ void ScanStage::close() { trackClose(); _cursor.reset(); + _randomCursor.reset(); _coll.reset(); _open = false; } @@ -532,6 +546,10 @@ std::vector<DebugPrinter::Block> ScanStage::debugPrint() const { DebugPrinter::addIdentifier(ret, DebugPrinter::kNoneKeyword); } + if (_useRandomCursor) { + DebugPrinter::addKeyword(ret, "random"); + } + ret.emplace_back(DebugPrinter::Block("[`")); for (size_t idx = 0; idx < _fields.size(); ++idx) { if (idx) { diff --git a/src/mongo/db/exec/sbe/stages/scan.h b/src/mongo/db/exec/sbe/stages/scan.h index 95982e6eb0c..37462ac5e14 100644 --- a/src/mongo/db/exec/sbe/stages/scan.h +++ b/src/mongo/db/exec/sbe/stages/scan.h @@ -107,7 +107,8 @@ public: bool forward, PlanYieldPolicy* yieldPolicy, PlanNodeId nodeId, - ScanCallbacks scanCallbacks); + ScanCallbacks scanCallbacks, + bool useRandomCursor = false); std::unique_ptr<PlanStage> clone() const final; @@ -132,6 +133,9 @@ protected: TrialRunTracker* tracker, TrialRunTrackerAttachResultMask childrenAttachResult) override; private: + // Returns the primary cursor or the random cursor depending on whether _useRandomCursor is set. + RecordCursor* getActiveCursor() const; + const UUID _collUuid; const boost::optional<value::SlotId> _recordSlot; const boost::optional<value::SlotId> _recordIdSlot; @@ -168,6 +172,9 @@ private: value::SlotAccessor* _indexKeyPatternAccessor{nullptr}; RuntimeEnvironment::Accessor* _oplogTsAccessor{nullptr}; + // Used to return a random sample of the collection. + const bool _useRandomCursor; + value::FieldAccessorMap _fieldAccessors; value::SlotAccessorMap _varAccessors; value::SlotAccessor* _seekKeyAccessor{nullptr}; @@ -177,6 +184,10 @@ private: bool _open{false}; std::unique_ptr<SeekableRecordCursor> _cursor; + + // TODO: SERVER-62647. Consider removing random cursor when no longer needed. + std::unique_ptr<RecordCursor> _randomCursor; + RecordId _key; bool _firstGetNext{false}; diff --git a/src/mongo/db/exec/sbe/values/bson.h b/src/mongo/db/exec/sbe/values/bson.h index 73e0cd272e8..f5c429d6890 100644 --- a/src/mongo/db/exec/sbe/values/bson.h +++ b/src/mongo/db/exec/sbe/values/bson.h @@ -40,6 +40,12 @@ std::pair<value::TypeTags, value::Value> convertFrom(const char* be, const char* end, size_t fieldNameSize); +template <bool View> +std::pair<value::TypeTags, value::Value> convertFrom(const BSONElement& elem) { + return convertFrom<View>( + elem.rawdata(), elem.rawdata() + elem.size(), elem.fieldNameSize() - 1); +} + const char* advance(const char* be, size_t fieldNameSize); inline auto fieldNameView(const char* be) noexcept { diff --git a/src/mongo/db/exec/sbe/vm/vm.cpp b/src/mongo/db/exec/sbe/vm/vm.cpp index 14dbb993def..b4646cefc17 100644 --- a/src/mongo/db/exec/sbe/vm/vm.cpp +++ b/src/mongo/db/exec/sbe/vm/vm.cpp @@ -1387,7 +1387,7 @@ std::tuple<bool, value::TypeTags, value::Value> ByteCode::builtinDropFields(Arit } // Build the set of fields to drop. - StringDataSet restrictFieldsSet; + StringSet restrictFieldsSet; for (ArityType idx = 1; idx < arity; ++idx) { auto [owned, tag, val] = getFromStack(idx); @@ -1463,7 +1463,7 @@ std::tuple<bool, value::TypeTags, value::Value> ByteCode::builtinKeepFields(Arit } // Build the set of fields to keep. 
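// Note on the change below (motivation inferred, not stated in the patch): both
// builtinDropFields above and builtinKeepFields here now collect field names into a StringSet,
// which owns its strings, instead of a StringDataSet of views; owned copies avoid relying on
// the stack-provided name storage outliving the set.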
- StringDataSet keepFieldsSet; + StringSet keepFieldsSet; for (uint8_t idx = 1; idx < arity; ++idx) { auto [owned, tag, val] = getFromStack(idx); diff --git a/src/mongo/db/exec/sbe_cmd.cpp b/src/mongo/db/exec/sbe_cmd.cpp index 5468c9a54e7..c76885bec0d 100644 --- a/src/mongo/db/exec/sbe_cmd.cpp +++ b/src/mongo/db/exec/sbe_cmd.cpp @@ -124,6 +124,7 @@ public: std::move(cq), nullptr, {std::move(root), std::move(data)}, + {}, &CollectionPtr::null, false, /* returnOwnedBson */ nss, diff --git a/src/mongo/db/matcher/match_expression_walker.h b/src/mongo/db/matcher/match_expression_walker.h index ccef9cb5106..3313a2948ea 100644 --- a/src/mongo/db/matcher/match_expression_walker.h +++ b/src/mongo/db/matcher/match_expression_walker.h @@ -44,15 +44,21 @@ public: : _preVisitor{preVisitor}, _inVisitor{inVisitor}, _postVisitor{postVisitor} {} void preVisit(const MatchExpression* expr) { - expr->acceptVisitor(_preVisitor); + if (_preVisitor != nullptr) { + expr->acceptVisitor(_preVisitor); + } } void postVisit(const MatchExpression* expr) { - expr->acceptVisitor(_postVisitor); + if (_postVisitor != nullptr) { + expr->acceptVisitor(_postVisitor); + } } void inVisit(long count, const MatchExpression* expr) { - expr->acceptVisitor(_inVisitor); + if (_inVisitor != nullptr) { + expr->acceptVisitor(_inVisitor); + } } private: diff --git a/src/mongo/db/pipeline/SConscript b/src/mongo/db/pipeline/SConscript index fd3b9ecbb49..ffbda680ee3 100644 --- a/src/mongo/db/pipeline/SConscript +++ b/src/mongo/db/pipeline/SConscript @@ -262,6 +262,11 @@ pipelineEnv.InjectThirdParty(libraries=['snappy']) pipelineEnv.Library( target='pipeline', source=[ + 'abt/abt_document_source_visitor.cpp', + 'abt/agg_expression_visitor.cpp', + 'abt/expr_algebrizer_context.cpp', + 'abt/match_expression_visitor.cpp', + 'abt/utils.cpp', 'document_source.cpp', 'document_source_add_fields.cpp', 'document_source_bucket.cpp', @@ -315,6 +320,8 @@ pipelineEnv.Library( 'sequential_document_cache.cpp', 'skip_and_limit.cpp', 'tee_buffer.cpp', + 'visitors/document_source_walker.cpp', + 'visitors/transformer_interface_walker.cpp', 'window_function/partition_iterator.cpp', 'window_function/spillable_cache.cpp', 'window_function/window_function_exec.cpp', @@ -344,6 +351,7 @@ pipelineEnv.Library( '$BUILD_DIR/mongo/db/query/collation/collator_factory_interface', '$BUILD_DIR/mongo/db/query/collation/collator_interface', '$BUILD_DIR/mongo/db/query/datetime/date_time_support', + '$BUILD_DIR/mongo/db/query/optimizer/optimizer', '$BUILD_DIR/mongo/db/query/query_knobs', '$BUILD_DIR/mongo/db/query/sort_pattern', '$BUILD_DIR/mongo/db/repl/oplog_entry', @@ -487,6 +495,7 @@ env.Library( env.CppUnitTest( target='db_pipeline_test', source=[ + 'abt/pipeline_test.cpp', 'accumulator_js_test.cpp' if get_option('js-engine') != 'none' else [], 'accumulator_test.cpp', 'aggregation_request_test.cpp', @@ -598,6 +607,7 @@ env.CppUnitTest( '$BUILD_DIR/mongo/db/exec/document_value/document_value_test_util', '$BUILD_DIR/mongo/db/mongohasher', '$BUILD_DIR/mongo/db/query/collation/collator_interface_mock', + '$BUILD_DIR/mongo/db/query/optimizer/unit_test_utils', '$BUILD_DIR/mongo/db/query/query_test_service_context', '$BUILD_DIR/mongo/db/repl/image_collection_entry', '$BUILD_DIR/mongo/db/repl/oplog_entry', diff --git a/src/mongo/db/pipeline/abt/abt_document_source_visitor.cpp b/src/mongo/db/pipeline/abt/abt_document_source_visitor.cpp new file mode 100644 index 00000000000..913866c7364 --- /dev/null +++ b/src/mongo/db/pipeline/abt/abt_document_source_visitor.cpp @@ -0,0 +1,878 @@ 
+/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include "mongo/db/pipeline/abt/abt_document_source_visitor.h" +#include "mongo/db/exec/add_fields_projection_executor.h" +#include "mongo/db/exec/exclusion_projection_executor.h" +#include "mongo/db/exec/inclusion_projection_executor.h" +#include "mongo/db/pipeline/abt/agg_expression_visitor.h" +#include "mongo/db/pipeline/abt/match_expression_visitor.h" +#include "mongo/db/pipeline/abt/utils.h" +#include "mongo/db/pipeline/document_source_bucket_auto.h" +#include "mongo/db/pipeline/document_source_coll_stats.h" +#include "mongo/db/pipeline/document_source_current_op.h" +#include "mongo/db/pipeline/document_source_cursor.h" +#include "mongo/db/pipeline/document_source_exchange.h" +#include "mongo/db/pipeline/document_source_facet.h" +#include "mongo/db/pipeline/document_source_geo_near.h" +#include "mongo/db/pipeline/document_source_geo_near_cursor.h" +#include "mongo/db/pipeline/document_source_graph_lookup.h" +#include "mongo/db/pipeline/document_source_group.h" +#include "mongo/db/pipeline/document_source_index_stats.h" +#include "mongo/db/pipeline/document_source_internal_inhibit_optimization.h" +#include "mongo/db/pipeline/document_source_internal_shard_filter.h" +#include "mongo/db/pipeline/document_source_internal_split_pipeline.h" +#include "mongo/db/pipeline/document_source_limit.h" +#include "mongo/db/pipeline/document_source_list_cached_and_active_users.h" +#include "mongo/db/pipeline/document_source_list_local_sessions.h" +#include "mongo/db/pipeline/document_source_list_sessions.h" +#include "mongo/db/pipeline/document_source_lookup.h" +#include "mongo/db/pipeline/document_source_match.h" +#include "mongo/db/pipeline/document_source_merge.h" +#include "mongo/db/pipeline/document_source_operation_metrics.h" +#include "mongo/db/pipeline/document_source_out.h" +#include "mongo/db/pipeline/document_source_plan_cache_stats.h" +#include "mongo/db/pipeline/document_source_queue.h" +#include "mongo/db/pipeline/document_source_redact.h" +#include "mongo/db/pipeline/document_source_replace_root.h" +#include 
"mongo/db/pipeline/document_source_sample.h" +#include "mongo/db/pipeline/document_source_sample_from_random_cursor.h" +#include "mongo/db/pipeline/document_source_sequential_document_cache.h" +#include "mongo/db/pipeline/document_source_single_document_transformation.h" +#include "mongo/db/pipeline/document_source_skip.h" +#include "mongo/db/pipeline/document_source_sort.h" +#include "mongo/db/pipeline/document_source_tee_consumer.h" +#include "mongo/db/pipeline/document_source_union_with.h" +#include "mongo/db/pipeline/document_source_unwind.h" +#include "mongo/db/pipeline/visitors/document_source_walker.h" +#include "mongo/db/pipeline/visitors/transformer_interface_walker.h" +#include "mongo/s/query/document_source_merge_cursors.h" +#include "mongo/util/string_map.h" + +namespace mongo::optimizer { + +class DSAlgebrizerContext { +public: + struct NodeWithRootProjection { + NodeWithRootProjection(ProjectionName rootProjection, ABT node) + : _rootProjection(std::move(rootProjection)), _node(std::move(node)) {} + + ProjectionName _rootProjection; + ABT _node; + }; + + DSAlgebrizerContext(PrefixId& prefixId, NodeWithRootProjection node) + : _node(std::move(node)), _scanProjName(_node._rootProjection), _prefixId(prefixId) { + assertNodeSort(_node._node); + } + + template <typename T, typename... Args> + inline auto setNode(ProjectionName rootProjection, Args&&... args) { + setNode(std::move(rootProjection), std::move(ABT::make<T>(std::forward<Args>(args)...))); + } + + void setNode(ProjectionName rootProjection, ABT node) { + assertNodeSort(node); + _node._node = std::move(node); + _node._rootProjection = std::move(rootProjection); + } + + NodeWithRootProjection& getNode() { + return _node; + } + + std::string getNextId(const std::string& key) { + return _prefixId.getNextId(key); + } + + PrefixId& getPrefixId() { + return _prefixId; + } + + const ProjectionName& getScanProjName() const { + return _scanProjName; + } + +private: + NodeWithRootProjection _node; + ProjectionName _scanProjName; + + // We don't own this. + PrefixId& _prefixId; +}; + +class ABTTransformerVisitor : public TransformerInterfaceConstVisitor { + static constexpr const char* kRootElement = "$root"; + +public: + ABTTransformerVisitor(DSAlgebrizerContext& ctx) : _ctx(ctx) {} + + void visit(const projection_executor::AddFieldsProjectionExecutor* transformer) override { + visitInclusionNode(transformer->getRoot(), true /*isAddingFields*/); + } + + void visit(const projection_executor::ExclusionProjectionExecutor* transformer) override { + visitExclusionNode(*transformer->getRoot()); + } + + void visit(const projection_executor::InclusionProjectionExecutor* transformer) override { + visitInclusionNode(*transformer->getRoot(), false /*isAddingFields*/); + } + + void visit(const GroupFromFirstDocumentTransformation* transformer) override { + // TODO: Is this internal-only? 
+ unsupportedTransformer(transformer); + } + + void visit(const ReplaceRootTransformation* transformer) override { + auto entry = _ctx.getNode(); + const std::string& projName = _ctx.getNextId("newRoot"); + ABT expr = generateAggExpression( + transformer->getExpression().get(), entry._rootProjection, projName); + + _ctx.setNode<EvaluationNode>(projName, projName, std::move(expr), std::move(entry._node)); + } + + void generateCombinedProjection() const { + auto it = _fieldMap.find(kRootElement); + if (it == _fieldMap.cend()) { + return; + } + + ABT result = generateABTForField(it->second); + auto entry = _ctx.getNode(); + const ProjectionName projName = _ctx.getNextId("combinedProjection"); + _ctx.setNode<EvaluationNode>(projName, projName, std::move(result), std::move(entry._node)); + } + +private: + struct FieldMapEntry { + FieldMapEntry(std::string fieldName) : _fieldName(std::move(fieldName)) { + uassert(6624200, "Empty field name", !_fieldName.empty()); + } + + std::string _fieldName; + bool _hasKeep = false; + bool _hasLeadingObj = false; + bool _hasTrailingDefault = false; + bool _hasDrop = false; + std::string _constVarName; + + std::set<std::string> _childPaths; + }; + + ABT generateABTForField(const FieldMapEntry& entry) const { + const bool isRootEntry = entry._fieldName == kRootElement; + + bool hasLeadingObj = false; + bool hasTrailingDefault = false; + std::set<std::string> keepSet; + std::set<std::string> dropSet; + std::map<std::string, std::string> varMap; + + for (const std::string& childField : entry._childPaths) { + const FieldMapEntry& childEntry = _fieldMap.at(childField); + const std::string& childFieldName = childEntry._fieldName; + + if (childEntry._hasKeep) { + keepSet.insert(childFieldName); + } + if (childEntry._hasDrop) { + dropSet.insert(childFieldName); + } + if (childEntry._hasLeadingObj) { + hasLeadingObj = true; + } + if (childEntry._hasTrailingDefault) { + hasTrailingDefault = true; + } + if (!childEntry._constVarName.empty()) { + varMap.emplace(childFieldName, childEntry._constVarName); + } + } + + const auto& ctxEntry = _ctx.getNode(); + const ProjectionName& rootProjName = ctxEntry._rootProjection; + + ABT result = make<PathIdentity>(); + if (hasLeadingObj && (!isRootEntry || rootProjName != _ctx.getScanProjName())) { + // We do not need a leading Obj if we are using the scan projection directly (scan + // delivers Objects). 
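// Rough illustration (hypothetical projection names; the exact shape depends on the field map):
// for a plain inclusion such as {$project: {a: 1}} evaluated directly over the scan projection,
// this function would produce approximately
//   EvalPath(PathKeep["_id", "a"], Variable("scan_0"))
// with no leading PathObj, while nested or computed paths additionally compose in
// PathField/PathTraverse terms via PathComposeM.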
+ maybeComposePath(result, make<PathObj>()); + } + if (!keepSet.empty()) { + maybeComposePath(result, make<PathKeep>(toUnorderedFieldNameSet(std::move(keepSet)))); + } + if (!dropSet.empty()) { + maybeComposePath(result, make<PathDrop>(toUnorderedFieldNameSet(std::move(dropSet)))); + } + + for (const auto& varMapEntry : varMap) { + maybeComposePath( + result, + make<PathField>(varMapEntry.first, + make<PathConstant>(make<Variable>(varMapEntry.second)))); + } + + for (const std::string& childPath : entry._childPaths) { + const FieldMapEntry& childEntry = _fieldMap.at(childPath); + + ABT childResult = generateABTForField(childEntry); + if (!childResult.is<PathIdentity>()) { + maybeComposePath(result, + make<PathField>(childEntry._fieldName, + make<PathTraverse>(std::move(childResult)))); + } + } + + if (hasTrailingDefault) { + maybeComposePath(result, make<PathDefault>(Constant::emptyObject())); + } + if (!isRootEntry) { + return result; + } + return make<EvalPath>(std::move(result), make<Variable>(rootProjName)); + } + + void integrateFieldPath( + const FieldPath& fieldPath, + const std::function<void(const bool isLastElement, FieldMapEntry& entry)>& fn) { + std::string path = kRootElement; + auto it = _fieldMap.emplace(path, kRootElement); + const size_t fieldPathLength = fieldPath.getPathLength(); + + for (size_t i = 0; i < fieldPathLength; i++) { + const std::string& fieldName = fieldPath.getFieldName(i).toString(); + path += '.' + fieldName; + + it.first->second._childPaths.insert(path); + it = _fieldMap.emplace(path, fieldName); + fn(i == fieldPathLength - 1, it.first->second); + } + } + + void unsupportedTransformer(const TransformerInterface* transformer) const { + uasserted(ErrorCodes::InternalErrorNotSupported, + str::stream() << "Transformer is not supported (code: " + << static_cast<int>(transformer->getType()) << ")"); + } + + void processProjectedPaths(const projection_executor::InclusionNode& node) { + std::set<std::string> preservedPaths; + node.reportProjectedPaths(&preservedPaths); + + for (const std::string& preservedPathStr : preservedPaths) { + integrateFieldPath(FieldPath(preservedPathStr), + [](const bool isLastElement, FieldMapEntry& entry) { + entry._hasLeadingObj = true; + entry._hasKeep = true; + }); + } + } + + void processComputedPaths(const projection_executor::InclusionNode& node, + const std::string& rootProjection, + const bool isAddingFields) { + std::set<std::string> computedPaths; + StringMap<std::string> renamedPaths; + node.reportComputedPaths(&computedPaths, &renamedPaths); + + // Handle path renames: essentially single element FieldPath expression. + for (const auto& renamedPathEntry : renamedPaths) { + ABT path = translateFieldPath( + FieldPath(renamedPathEntry.second), + make<PathIdentity>(), + [](const std::string& fieldName, const bool isLastElement, ABT input) { + return make<PathGet>(fieldName, + isLastElement ? 
std::move(input) + : make<PathTraverse>(std::move(input))); + }); + + auto entry = _ctx.getNode(); + const std::string& renamedProjName = _ctx.getNextId("projRenamedPath"); + _ctx.setNode<EvaluationNode>( + entry._rootProjection, + renamedProjName, + make<EvalPath>(std::move(path), make<Variable>(entry._rootProjection)), + std::move(entry._node)); + + integrateFieldPath(FieldPath(renamedPathEntry.first), + [&renamedProjName, &isAddingFields](const bool isLastElement, + FieldMapEntry& entry) { + if (!isAddingFields) { + entry._hasKeep = true; + } + if (isLastElement) { + entry._constVarName = renamedProjName; + entry._hasTrailingDefault = true; + } + }); + } + + // Handle general expression projection. + for (const std::string& computedPathStr : computedPaths) { + const FieldPath computedPath(computedPathStr); + + auto entry = _ctx.getNode(); + const std::string& getProjName = _ctx.getNextId("projGetPath"); + ABT getExpr = generateAggExpression( + node.getExpressionForPath(computedPath).get(), rootProjection, getProjName); + + _ctx.setNode<EvaluationNode>(std::move(entry._rootProjection), + getProjName, + std::move(getExpr), + std::move(entry._node)); + + integrateFieldPath( + computedPath, + [&getProjName, &isAddingFields](const bool isLastElement, FieldMapEntry& entry) { + if (!isAddingFields) { + entry._hasKeep = true; + } + if (isLastElement) { + entry._constVarName = getProjName; + entry._hasTrailingDefault = true; + } + }); + } + } + + void visitInclusionNode(const projection_executor::InclusionNode& node, + const bool isAddingFields) { + auto entry = _ctx.getNode(); + const std::string rootProjection = entry._rootProjection; + + processProjectedPaths(node); + processComputedPaths(node, rootProjection, isAddingFields); + } + + void visitExclusionNode(const projection_executor::ExclusionNode& node) { + std::set<std::string> preservedPaths; + node.reportProjectedPaths(&preservedPaths); + + for (const std::string& preservedPathStr : preservedPaths) { + integrateFieldPath(FieldPath(preservedPathStr), + [](const bool isLastElement, FieldMapEntry& entry) { + if (isLastElement) { + entry._hasDrop = true; + } + }); + } + } + + DSAlgebrizerContext& _ctx; + + opt::unordered_map<std::string, FieldMapEntry> _fieldMap; +}; + +class ABTDocumentSourceVisitor : public DocumentSourceConstVisitor { +public: + ABTDocumentSourceVisitor(DSAlgebrizerContext& ctx, const Metadata& metadata) + : _ctx(ctx), _metadata(metadata) {} + + void visit(const DocumentSourceBucketAuto* source) override { + unsupportedStage(source); + } + + void visit(const DocumentSourceCollStats* source) override { + unsupportedStage(source); + } + + void visit(const DocumentSourceCurrentOp* source) override { + unsupportedStage(source); + } + + void visit(const DocumentSourceCursor* source) override { + unsupportedStage(source); + } + + void visit(const DocumentSourceExchange* source) override { + unsupportedStage(source); + } + + void visit(const DocumentSourceFacet* source) override { + unsupportedStage(source); + } + + void visit(const DocumentSourceGeoNear* source) override { + unsupportedStage(source); + } + + void visit(const DocumentSourceGeoNearCursor* source) override { + unsupportedStage(source); + } + + void visit(const DocumentSourceGraphLookUp* source) override { + unsupportedStage(source); + } + + void visit(const DocumentSourceGroup* source) override { + const StringMap<boost::intrusive_ptr<Expression>>& idFields = source->getIdFields(); + uassert(6624201, "Empty idFields map", !idFields.empty()); + + 
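+ // Bind each _id field to its own projection via an EvaluationNode; the group keys and accumulators then feed a single GroupByNode below.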
std::vector<FieldNameType> groupByFieldNames; + for (const auto& [fieldName, expr] : idFields) { + groupByFieldNames.push_back(fieldName); + } + // Sort in order to generate consistent plans. + std::sort(groupByFieldNames.begin(), groupByFieldNames.end()); + + ProjectionNameVector groupByProjNames; + auto entry = _ctx.getNode(); + for (const FieldNameType& fieldName : groupByFieldNames) { + const ProjectionName groupByProjName = _ctx.getNextId("groupByProj"); + groupByProjNames.push_back(groupByProjName); + + ABT groupByExpr = generateAggExpression( + idFields.at(fieldName).get(), entry._rootProjection, groupByProjName); + + _ctx.setNode<EvaluationNode>(entry._rootProjection, + groupByProjName, + std::move(groupByExpr), + std::move(entry._node)); + entry = _ctx.getNode(); + } + + // Fields corresponding to each accumulator + ProjectionNameVector aggProjFieldNames; + // Projection names corresponding to each high-level accumulator ($avg can be broken down + // into sum and count.). + ProjectionNameVector aggOutputProjNames; + // Projection names corresponding to each low-level accumulator (no $avg). + ProjectionNameVector aggLowLevelOutputProjNames; + + ABTVector aggregationProjections; + const std::vector<AccumulationStatement>& accumulatedFields = + source->getAccumulatedFields(); + + // Used to keep track which $sum and $count projections to use to compute $avg. + struct AvgProjNames { + ProjectionName _output; + ProjectionName _sum; + ProjectionName _count; + }; + std::vector<AvgProjNames> avgProjNames; + + for (const AccumulationStatement& stmt : accumulatedFields) { + const FieldNameType& fieldName = stmt.fieldName; + aggProjFieldNames.push_back(fieldName); + + ProjectionName aggOutputProjName = _ctx.getNextId(fieldName + "_agg"); + + ABT aggInputExpr = generateAggExpression( + stmt.expr.argument.get(), entry._rootProjection, aggOutputProjName); + if (!aggInputExpr.is<Constant>() && !aggInputExpr.is<Variable>()) { + // Generate nodes for complex projections, otherwise inline constants and variables + // into the group. + const ProjectionName aggInputProjName = _ctx.getNextId("groupByInputProj"); + _ctx.setNode<EvaluationNode>(entry._rootProjection, + aggInputProjName, + std::move(aggInputExpr), + std::move(entry._node)); + entry = _ctx.getNode(); + aggInputExpr = make<Variable>(aggInputProjName); + } + + aggOutputProjNames.push_back(aggOutputProjName); + if (stmt.makeAccumulator()->getOpName() == "$avg"_sd) { + // Express $avg as sum / count. 
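+ // e.g. {$avg: "$x"} lowers to a $sum over the input and a $sum over the constant 1; the quotient is assembled after the GroupByNode.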
+ ProjectionName sumProjName = _ctx.getNextId(fieldName + "_sum_agg"); + aggLowLevelOutputProjNames.push_back(sumProjName); + ProjectionName countProjName = _ctx.getNextId(fieldName + "_count_agg"); + aggLowLevelOutputProjNames.push_back(countProjName); + avgProjNames.emplace_back(AvgProjNames{std::move(aggOutputProjName), + std::move(sumProjName), + std::move(countProjName)}); + + aggregationProjections.emplace_back( + make<FunctionCall>("$sum", makeSeq(aggInputExpr))); + aggregationProjections.emplace_back( + make<FunctionCall>("$sum", makeSeq(Constant::int64(1)))); + } else { + aggLowLevelOutputProjNames.push_back(std::move(aggOutputProjName)); + aggregationProjections.emplace_back(make<FunctionCall>( + stmt.makeAccumulator()->getOpName(), makeSeq(std::move(aggInputExpr)))); + } + } + + ABT result = make<GroupByNode>(ProjectionNameVector{groupByProjNames}, + aggLowLevelOutputProjNames, + aggregationProjections, + std::move(entry._node)); + + for (auto&& [outputProjName, sumProjName, countProjName] : avgProjNames) { + result = make<EvaluationNode>( + std::move(outputProjName), + make<If>(make<BinaryOp>( + Operations::Gt, make<Variable>(countProjName), Constant::int64(0)), + make<BinaryOp>(Operations::Div, + make<Variable>(std::move(sumProjName)), + make<Variable>(countProjName)), + Constant::nothing()), + std::move(result)); + } + + ABT integrationPath = make<PathIdentity>(); + for (size_t i = 0; i < groupByFieldNames.size(); i++) { + maybeComposePath(integrationPath, + make<PathField>(std::move(groupByFieldNames.at(i)), + make<PathConstant>(make<Variable>( + std::move(groupByProjNames.at(i)))))); + } + for (size_t i = 0; i < aggProjFieldNames.size(); i++) { + maybeComposePath( + integrationPath, + make<PathField>(aggProjFieldNames.at(i), + make<PathConstant>(make<Variable>(aggOutputProjNames.at(i))))); + } + + entry = _ctx.getNode(); + const std::string& mergeProject = _ctx.getNextId("agg_project"); + _ctx.setNode<EvaluationNode>( + mergeProject, + mergeProject, + make<EvalPath>(std::move(integrationPath), Constant::emptyObject()), + std::move(result)); + } + + void visit(const DocumentSourceIndexStats* source) override { + unsupportedStage(source); + } + + void visit(const DocumentSourceInternalInhibitOptimization* source) override { + // Can be ignored. + } + + void visit(const DocumentSourceInternalShardFilter* source) override { + unsupportedStage(source); + } + + void visit(const DocumentSourceInternalSplitPipeline* source) override { + unsupportedStage(source); + } + + void visit(const DocumentSourceLimit* source) override { + pushLimitSkip(source->getLimit(), 0); + } + + void visit(const DocumentSourceListCachedAndActiveUsers* source) override { + unsupportedStage(source); + } + + void visit(const DocumentSourceListLocalSessions* source) override { + unsupportedStage(source); + } + + void visit(const DocumentSourceListSessions* source) override { + unsupportedStage(source); + } + + void visit(const DocumentSourceLookUp* source) override { + unsupportedStage(source); + } + + void visit(const DocumentSourceMatch* source) override { + auto entry = _ctx.getNode(); + ABT matchExpr = generateMatchExpression(source->getMatchExpression(), + true /*allowAggExpressions*/, + entry._rootProjection, + _ctx.getNextId("matchExpression")); + + // If we have a top-level composition, create separate filters. 
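+ // e.g. a conjunctive match such as {a: 1, b: 2} yields one FilterNode per conjunct rather than a single composed filter.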
+ const auto& composition = collectComposed(matchExpr); + for (const auto& path : composition) { + _ctx.setNode<FilterNode>(entry._rootProjection, + make<EvalFilter>(path, make<Variable>(entry._rootProjection)), + std::move(entry._node)); + entry = _ctx.getNode(); + } + } + + void visit(const DocumentSourceMerge* source) override { + unsupportedStage(source); + } + + void visit(const DocumentSourceMergeCursors* source) override { + unsupportedStage(source); + } + + void visit(const DocumentSourceOperationMetrics* source) override { + unsupportedStage(source); + } + + void visit(const DocumentSourceOut* source) override { + unsupportedStage(source); + } + + void visit(const DocumentSourcePlanCacheStats* source) override { + unsupportedStage(source); + } + + void visit(const DocumentSourceQueue* source) override { + unsupportedStage(source); + } + + void visit(const DocumentSourceRedact* source) override { + unsupportedStage(source); + } + + void visit(const DocumentSourceSample* source) override { + unsupportedStage(source); + } + + void visit(const DocumentSourceSampleFromRandomCursor* source) override { + unsupportedStage(source); + } + + void visit(const DocumentSourceSequentialDocumentCache* source) override { + unsupportedStage(source); + } + + void visit(const DocumentSourceSingleDocumentTransformation* source) override { + ABTTransformerVisitor visitor(_ctx); + TransformerInterfaceWalker walker(&visitor); + walker.walk(&source->getTransformer()); + visitor.generateCombinedProjection(); + } + + void visit(const DocumentSourceSkip* source) override { + pushLimitSkip(-1, source->getSkip()); + } + + void visit(const DocumentSourceSort* source) override { + ProjectionCollationSpec collationSpec; + const SortPattern& pattern = source->getSortKeyPattern(); + for (size_t i = 0; i < pattern.size(); i++) { + const SortPattern::SortPatternPart& part = pattern[i]; + if (!part.fieldPath.has_value()) { + // TODO: consider metadata expression. + continue; + } + + const std::string& sortProjName = _ctx.getNextId("sort"); + collationSpec.emplace_back( + sortProjName, part.isAscending ? CollationOp::Ascending : CollationOp::Descending); + + const FieldPath& fieldPath = part.fieldPath.value(); + ABT sortPath = make<PathIdentity>(); + for (size_t j = 0; j < fieldPath.getPathLength(); j++) { + sortPath = make<PathGet>(fieldPath.getFieldName(j).toString(), std::move(sortPath)); + } + + auto entry = _ctx.getNode(); + _ctx.setNode<EvaluationNode>( + entry._rootProjection, + sortProjName, + make<EvalPath>(std::move(sortPath), make<Variable>(entry._rootProjection)), + std::move(entry._node)); + } + + if (!collationSpec.empty()) { + auto entry = _ctx.getNode(); + _ctx.setNode<CollationNode>(std::move(entry._rootProjection), + properties::CollationRequirement(std::move(collationSpec)), + std::move(entry._node)); + } + + if (source->getLimit().has_value()) { + // We need to limit the result of the collation. 
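+ // The limit is pushed as a LimitSkipNode (with skip 0) on top of the CollationNode built above.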
+ pushLimitSkip(source->getLimit().value(), 0); + } + } + + void visit(const DocumentSourceTeeConsumer* source) override { + unsupportedStage(source); + } + + void visit(const DocumentSourceUnionWith* source) override { + auto entry = _ctx.getNode(); + ProjectionName unionProjName = entry._rootProjection; + + const Pipeline& pipeline = source->getPipeline(); + + NamespaceString involvedNss = pipeline.getContext()->ns; + std::string scanDefName = involvedNss.coll().toString(); + const ProjectionName& scanProjName = _ctx.getNextId("scan"); + + PrefixId prefixId; + ABT initialNode = _metadata._scanDefs.at(scanDefName).exists() + ? make<ScanNode>(scanProjName, scanDefName) + : make<ValueScanNode>(ProjectionNameVector{scanProjName}); + ABT pipelineABT = translatePipelineToABT( + _metadata, pipeline, scanProjName, std::move(initialNode), prefixId); + + uassert(6624425, "Expected root node for union pipeline", pipelineABT.is<RootNode>()); + ABT pipelineABTWithoutRoot = pipelineABT.cast<RootNode>()->getChild(); + // Pull out the root projection(s) from the inner pipeline. + const ProjectionNameVector& rootProjections = + pipelineABT.cast<RootNode>()->getProperty().getProjections().getVector(); + uassert(6624426, + "Expected a single projection for inner union branch", + rootProjections.size() == 1); + + // Add an evaluation node such that it shares a projection with the outer pipeline. If the + // same projection name is already defined in the inner pipeline then there's no need for + // the extra eval node. + const ProjectionName& innerProjection = rootProjections[0]; + ProjectionName newRootProj = unionProjName; + if (innerProjection != unionProjName) { + ABT evalNodeInner = make<EvaluationNode>( + unionProjName, + make<EvalPath>(make<PathIdentity>(), make<Variable>(innerProjection)), + std::move(pipelineABTWithoutRoot)); + _ctx.setNode( + std::move(newRootProj), + make<UnionNode>(ProjectionNameVector{std::move(unionProjName)}, + makeSeq(std::move(entry._node), std::move(evalNodeInner)))); + } else { + _ctx.setNode(std::move(newRootProj), + make<UnionNode>( + ProjectionNameVector{std::move(unionProjName)}, + makeSeq(std::move(entry._node), std::move(pipelineABTWithoutRoot)))); + } + } + + void visit(const DocumentSourceUnwind* source) override { + const FieldPath& unwindFieldPath = source->getUnwindPath(); + const bool preserveNullAndEmpty = source->preserveNullAndEmptyArrays(); + + const std::string pidProjName = _ctx.getNextId("unwoundPid"); + const std::string unwoundProjName = _ctx.getNextId("unwoundProj"); + + const auto generatePidGteZeroTest = [&pidProjName](ABT thenCond, ABT elseCond) { + return make<If>( + make<BinaryOp>(Operations::Gte, make<Variable>(pidProjName), Constant::int64(0)), + std::move(thenCond), + std::move(elseCond)); + }; + + ABT embedPath = make<Variable>(unwoundProjName); + if (preserveNullAndEmpty) { + const std::string unwindLambdaVarName = _ctx.getNextId("unwoundLambdaVarName"); + embedPath = make<PathLambda>(make<LambdaAbstraction>( + unwindLambdaVarName, + generatePidGteZeroTest(std::move(embedPath), make<Variable>(unwindLambdaVarName)))); + } else { + embedPath = make<PathConstant>(std::move(embedPath)); + } + embedPath = translateFieldPath( + unwindFieldPath, + std::move(embedPath), + [](const std::string& fieldName, const bool isLastElement, ABT input) { + return make<PathField>(fieldName, + isLastElement ? 
std::move(input) + : make<PathTraverse>(std::move(input))); + }); + + ABT unwoundPath = translateFieldPath( + unwindFieldPath, + make<PathIdentity>(), + [](const std::string& fieldName, const bool isLastElement, ABT input) { + return make<PathGet>(fieldName, std::move(input)); + }); + + auto entry = _ctx.getNode(); + _ctx.setNode<EvaluationNode>( + entry._rootProjection, + unwoundProjName, + make<EvalPath>(std::move(unwoundPath), make<Variable>(entry._rootProjection)), + std::move(entry._node)); + + entry = _ctx.getNode(); + _ctx.setNode<UnwindNode>(std::move(entry._rootProjection), + unwoundProjName, + pidProjName, + preserveNullAndEmpty, + std::move(entry._node)); + + entry = _ctx.getNode(); + const std::string embedProjName = _ctx.getNextId("embedProj"); + _ctx.setNode<EvaluationNode>( + embedProjName, + embedProjName, + make<EvalPath>(std::move(embedPath), make<Variable>(entry._rootProjection)), + std::move(entry._node)); + + if (source->indexPath().has_value()) { + const FieldPath indexFieldPath = source->indexPath().get(); + if (indexFieldPath.getPathLength() > 0) { + ABT indexPath = translateFieldPath( + indexFieldPath, + make<PathConstant>( + generatePidGteZeroTest(make<Variable>(pidProjName), Constant::null())), + [](const std::string& fieldName, const bool isLastElement, ABT input) { + return make<PathField>(fieldName, std::move(input)); + }); + + entry = _ctx.getNode(); + const std::string embedPidProjName = _ctx.getNextId("embedPidProj"); + _ctx.setNode<EvaluationNode>( + embedPidProjName, + embedPidProjName, + make<EvalPath>(std::move(indexPath), make<Variable>(entry._rootProjection)), + std::move(entry._node)); + } + } + } + +private: + void unsupportedStage(const DocumentSource* source) const { + uasserted(ErrorCodes::InternalErrorNotSupported, + str::stream() << "Stage is not supported: " << source->getSourceName()); + } + + void pushLimitSkip(const int64_t limit, const int64_t skip) { + auto entry = _ctx.getNode(); + _ctx.setNode<LimitSkipNode>(std::move(entry._rootProjection), + properties::LimitSkipRequirement(limit, skip), + std::move(entry._node)); + } + + DSAlgebrizerContext& _ctx; + const Metadata& _metadata; +}; + +ABT translatePipelineToABT(const Metadata& metadata, + const Pipeline& pipeline, + ProjectionName scanProjName, + ABT initialNode, + PrefixId& prefixId) { + DSAlgebrizerContext ctx(prefixId, {scanProjName, std::move(initialNode)}); + ABTDocumentSourceVisitor visitor(ctx, metadata); + + DocumentSourceWalker walker(nullptr /*preVisitor*/, &visitor); + walker.walk(pipeline); + + auto entry = ctx.getNode(); + return make<RootNode>( + properties::ProjectionRequirement{ProjectionNameVector{std::move(entry._rootProjection)}}, + std::move(entry._node)); +} + +} // namespace mongo::optimizer diff --git a/src/mongo/db/pipeline/abt/abt_document_source_visitor.h b/src/mongo/db/pipeline/abt/abt_document_source_visitor.h new file mode 100644 index 00000000000..d52d1fe8b23 --- /dev/null +++ b/src/mongo/db/pipeline/abt/abt_document_source_visitor.h @@ -0,0 +1,44 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. 
+ * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#pragma once + +#include "mongo/db/pipeline/pipeline.h" +#include "mongo/db/query/optimizer/node.h" +#include "mongo/db/query/optimizer/utils/utils.h" + +namespace mongo::optimizer { + +ABT translatePipelineToABT(const Metadata& metadata, + const Pipeline& pipeline, + ProjectionName scanProjName, + ABT initialNode, + PrefixId& prefixId); + +} // namespace mongo::optimizer diff --git a/src/mongo/db/pipeline/abt/agg_expression_visitor.cpp b/src/mongo/db/pipeline/abt/agg_expression_visitor.cpp new file mode 100644 index 00000000000..ff6dc5e3b45 --- /dev/null +++ b/src/mongo/db/pipeline/abt/agg_expression_visitor.cpp @@ -0,0 +1,840 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. 
+ */ + +#include <stack> + +#include "mongo/base/error_codes.h" +#include "mongo/db/pipeline/abt/agg_expression_visitor.h" +#include "mongo/db/pipeline/abt/expr_algebrizer_context.h" +#include "mongo/db/pipeline/abt/utils.h" +#include "mongo/db/pipeline/accumulator.h" +#include "mongo/db/pipeline/accumulator_multi.h" +#include "mongo/db/pipeline/expression_walker.h" +#include "mongo/db/query/optimizer/utils/utils.h" + +namespace mongo::optimizer { + +class ABTAggExpressionVisitor final : public ExpressionConstVisitor { +public: + ABTAggExpressionVisitor(ExpressionAlgebrizerContext& ctx) : _prefixId(), _ctx(ctx){}; + + void visit(const ExpressionConstant* expr) override final { + auto [tag, val] = convertFrom(expr->getValue()); + _ctx.push<Constant>(tag, val); + } + + void visit(const ExpressionAbs* expr) override final { + pushSingleArgFunctionFromTop("abs"); + } + + void visit(const ExpressionAdd* expr) override final { + pushArithmeticBinaryExpr(expr, Operations::Add); + } + + void visit(const ExpressionAllElementsTrue* expr) override final { + unsupportedExpression(expr->getOpName()); + } + + void visit(const ExpressionAnd* expr) override final { + visitMultiBranchLogicExpression(expr, Operations::And); + } + + void visit(const ExpressionAnyElementTrue* expr) override final { + unsupportedExpression(expr->getOpName()); + } + + void visit(const ExpressionArray* expr) override final { + const size_t childCount = expr->getChildren().size(); + _ctx.ensureArity(childCount); + + // Need to process in reverse order because of the stack. + ABTVector args; + for (size_t i = 0; i < childCount; i++) { + args.emplace_back(_ctx.pop()); + } + + std::reverse(args.begin(), args.end()); + _ctx.push<FunctionCall>("newArray", std::move(args)); + } + + void visit(const ExpressionArrayElemAt* expr) override final { + unsupportedExpression(expr->getOpName()); + } + + void visit(const ExpressionFirst* expr) override final { + unsupportedExpression(expr->getOpName()); + } + + void visit(const ExpressionLast* expr) override final { + unsupportedExpression(expr->getOpName()); + } + + void visit(const ExpressionObjectToArray* expr) override final { + unsupportedExpression(expr->getOpName()); + } + + void visit(const ExpressionArrayToObject* expr) override final { + unsupportedExpression(expr->getOpName()); + } + + void visit(const ExpressionBsonSize* expr) override final { + pushSingleArgFunctionFromTop("bsonSize"); + } + + void visit(const ExpressionCeil* expr) override final { + pushSingleArgFunctionFromTop("ceil"); + } + + void visit(const ExpressionCoerceToBool* expr) override final { + // Since $coerceToBool is internal-only and there are not yet any input expressions that + // generate an ExpressionCoerceToBool expression, we will leave it as unreachable for now. 
+ MONGO_UNREACHABLE; + } + + void visit(const ExpressionCompare* expr) override final { + _ctx.ensureArity(2); + ABT right = _ctx.pop(); + ABT left = _ctx.pop(); + + switch (expr->getOp()) { + case ExpressionCompare::CmpOp::EQ: + _ctx.push<BinaryOp>(Operations::Eq, std::move(left), std::move(right)); + break; + + case ExpressionCompare::CmpOp::NE: + _ctx.push<BinaryOp>(Operations::Neq, std::move(left), std::move(right)); + break; + + case ExpressionCompare::CmpOp::GT: + _ctx.push<BinaryOp>(Operations::Gt, std::move(left), std::move(right)); + break; + + case ExpressionCompare::CmpOp::GTE: + _ctx.push<BinaryOp>(Operations::Gte, std::move(left), std::move(right)); + break; + + case ExpressionCompare::CmpOp::LT: + _ctx.push<BinaryOp>(Operations::Lt, std::move(left), std::move(right)); + break; + + case ExpressionCompare::CmpOp::LTE: + _ctx.push<BinaryOp>(Operations::Lte, std::move(left), std::move(right)); + break; + + case ExpressionCompare::CmpOp::CMP: + _ctx.push<FunctionCall>("cmp3w", makeSeq(std::move(left), std::move(right))); + break; + + default: + MONGO_UNREACHABLE; + } + } + + void visit(const ExpressionConcat* expr) override final { + pushMultiArgFunctionFromTop("concat", expr->getChildren().size()); + } + + void visit(const ExpressionConcatArrays* expr) override final { + unsupportedExpression(expr->getOpName()); + } + + void visit(const ExpressionCond* expr) override final { + _ctx.ensureArity(3); + ABT cond = generateCoerceToBoolPopInput(); + ABT thenCase = _ctx.pop(); + ABT elseCase = _ctx.pop(); + _ctx.push<If>(std::move(cond), std::move(thenCase), std::move(elseCase)); + } + + void visit(const ExpressionDateFromString* expr) override final { + unsupportedExpression("$dateFromString"); + } + + void visit(const ExpressionDateFromParts* expr) override final { + unsupportedExpression("$dateFromParts"); + } + + void visit(const ExpressionDateDiff* expr) override final { + unsupportedExpression("$dateDiff"); + } + + void visit(const ExpressionDateToParts* expr) override final { + unsupportedExpression("$dateToParts"); + } + + void visit(const ExpressionDateToString* expr) override final { + unsupportedExpression("$dateToString"); + } + + void visit(const ExpressionDateTrunc* expr) override final { + unsupportedExpression("$dateTrunc"); + } + + void visit(const ExpressionDivide* expr) override final { + pushArithmeticBinaryExpr(expr, Operations::Div); + } + + void visit(const ExpressionExp* expr) override final { + pushSingleArgFunctionFromTop("exp"); + } + + void visit(const ExpressionFieldPath* expr) override final { + const auto& varId = expr->getVariableId(); + if (Variables::isUserDefinedVariable(varId)) { + _ctx.push<Variable>(generateVariableName(varId)); + return; + } + + const FieldPath& fieldPath = expr->getFieldPath(); + const size_t pathLength = fieldPath.getPathLength(); + if (pathLength < 1) { + return; + } + + const auto& firstFieldName = fieldPath.getFieldName(0); + if (pathLength == 1 && firstFieldName == "ROOT") { + _ctx.push<Variable>(_ctx.getRootProjection()); + return; + } + uassert(6624239, "Unexpected leading path element.", firstFieldName == "CURRENT"); + + // Here we skip over "CURRENT" first path element. This is represented by rootProjection + // variable. + ABT path = translateFieldPath( + fieldPath, + make<PathIdentity>(), + [](const std::string& fieldName, const bool isLastElement, ABT input) { + return make<PathGet>(fieldName, + isLastElement ? 
std::move(input) + : make<PathTraverse>(std::move(input))); + }, + 1ul); + + _ctx.push<EvalPath>(std::move(path), make<Variable>(_ctx.getRootProjection())); + } + + void visit(const ExpressionFilter* expr) override final { + const auto& varId = expr->getVariableId(); + uassert(6624427, + "Filter variable must be user-defined.", + Variables::isUserDefinedVariable(varId)); + const std::string& varName = generateVariableName(varId); + + _ctx.ensureArity(2); + ABT filter = _ctx.pop(); + ABT input = _ctx.pop(); + + _ctx.push<EvalPath>(make<PathTraverse>(make<PathLambda>(make<LambdaAbstraction>( + varName, + make<If>(generateCoerceToBoolInternal(std::move(filter)), + make<Variable>(varName), + Constant::nothing())))), + std::move(input)); + } + + void visit(const ExpressionFloor* expr) override final { + pushSingleArgFunctionFromTop("floor"); + } + + void visit(const ExpressionIfNull* expr) override final { + pushMultiArgFunctionFromTop("ifNull", 2); + } + + void visit(const ExpressionIn* expr) override final { + unsupportedExpression(expr->getOpName()); + } + + void visit(const ExpressionIndexOfArray* expr) override final { + unsupportedExpression(expr->getOpName()); + } + + void visit(const ExpressionIndexOfBytes* expr) override final { + unsupportedExpression(expr->getOpName()); + } + + void visit(const ExpressionIndexOfCP* expr) override final { + unsupportedExpression(expr->getOpName()); + } + + void visit(const ExpressionIsNumber* expr) override final { + unsupportedExpression(expr->getOpName()); + } + + void visit(const ExpressionLet* expr) override final { + unsupportedExpression("$let"); + } + + void visit(const ExpressionLn* expr) override final { + unsupportedExpression(expr->getOpName()); + } + + void visit(const ExpressionLog* expr) override final { + unsupportedExpression(expr->getOpName()); + } + + void visit(const ExpressionLog10* expr) override final { + unsupportedExpression(expr->getOpName()); + } + + void visit(const ExpressionMap* expr) override final { + unsupportedExpression("$map"); + } + + void visit(const ExpressionMeta* expr) override final { + unsupportedExpression("$meta"); + } + + void visit(const ExpressionMod* expr) override final { + unsupportedExpression(expr->getOpName()); + } + + void visit(const ExpressionMultiply* expr) override final { + pushArithmeticBinaryExpr(expr, Operations::Mult); + } + + void visit(const ExpressionNot* expr) override final { + unsupportedExpression(expr->getOpName()); + } + + void visit(const ExpressionObject* expr) override final { + const auto& expressions = expr->getChildExpressions(); + const size_t childCount = expressions.size(); + _ctx.ensureArity(childCount); + + // Need to process in reverse order because of the stack. 
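+ // The stack pops children most-recent-first; they are re-indexed below so that the object fields keep their original declaration order.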
+ ABTVector children; + for (size_t i = 0; i < childCount; i++) { + children.emplace_back(_ctx.pop()); + } + + sbe::value::Object object; + for (size_t i = 0; i < childCount; i++) { + ABT& child = children.at(childCount - i - 1); + uassert( + 6624345, "Only constants are supported as object fields.", child.is<Constant>()); + + auto [tag, val] = child.cast<Constant>()->get(); + // Copy the value before inserting into the object + auto [tagCopy, valCopy] = sbe::value::copyValue(tag, val); + object.push_back(expressions.at(i).first, tagCopy, valCopy); + } + + auto [tag, val] = sbe::value::makeCopyObject(object); + _ctx.push<Constant>(tag, val); + } + + void visit(const ExpressionOr* expr) override final { + visitMultiBranchLogicExpression(expr, Operations::Or); + } + + void visit(const ExpressionPow* expr) override final { + unsupportedExpression("$pow"); + } + + void visit(const ExpressionRange* expr) override final { + unsupportedExpression(expr->getOpName()); + } + + void visit(const ExpressionReduce* expr) override final { + unsupportedExpression("$reduce"); + } + + void visit(const ExpressionReplaceOne* expr) override final { + unsupportedExpression(expr->getOpName()); + } + + void visit(const ExpressionReplaceAll* expr) override final { + unsupportedExpression(expr->getOpName()); + } + + void visit(const ExpressionSetDifference* expr) override final { + unsupportedExpression(expr->getOpName()); + } + + void visit(const ExpressionSetEquals* expr) override final { + unsupportedExpression(expr->getOpName()); + } + + void visit(const ExpressionSetIntersection* expr) override final { + unsupportedExpression(expr->getOpName()); + } + + void visit(const ExpressionSetIsSubset* expr) override final { + unsupportedExpression(expr->getOpName()); + } + + void visit(const ExpressionSetUnion* expr) override final { + unsupportedExpression(expr->getOpName()); + } + + void visit(const ExpressionSize* expr) override final { + unsupportedExpression(expr->getOpName()); + } + + void visit(const ExpressionReverseArray* expr) override final { + unsupportedExpression(expr->getOpName()); + } + + void visit(const ExpressionSortArray* expr) override final { + unsupportedExpression(expr->getOpName()); + } + + void visit(const ExpressionSlice* expr) override final { + unsupportedExpression(expr->getOpName()); + } + + void visit(const ExpressionIsArray* expr) override final { + unsupportedExpression(expr->getOpName()); + } + + void visit(const ExpressionRound* expr) override final { + unsupportedExpression(expr->getOpName()); + } + + void visit(const ExpressionSplit* expr) override final { + pushMultiArgFunctionFromTop("split", expr->getChildren().size()); + } + + void visit(const ExpressionSqrt* expr) override final { + unsupportedExpression(expr->getOpName()); + } + + void visit(const ExpressionStrcasecmp* expr) override final { + unsupportedExpression(expr->getOpName()); + } + + void visit(const ExpressionSubstrBytes* expr) override final { + unsupportedExpression(expr->getOpName()); + } + + void visit(const ExpressionSubstrCP* expr) override final { + unsupportedExpression(expr->getOpName()); + } + + void visit(const ExpressionStrLenBytes* expr) override final { + unsupportedExpression(expr->getOpName()); + } + + void visit(const ExpressionBinarySize* expr) override final { + unsupportedExpression(expr->getOpName()); + } + + void visit(const ExpressionStrLenCP* expr) override final { + unsupportedExpression(expr->getOpName()); + } + + void visit(const ExpressionSubtract* expr) override final { + 
pushArithmeticBinaryExpr(expr, Operations::Sub); + } + + void visit(const ExpressionSwitch* expr) override final { + const size_t arity = expr->getChildren().size(); + _ctx.ensureArity(arity); + const size_t numCases = (arity - 1) / 2; + + ABTVector children; + for (size_t i = 0; i < numCases; i++) { + children.emplace_back(generateCoerceToBoolPopInput()); + children.emplace_back(_ctx.pop()); + } + + if (expr->getChildren().back() != nullptr) { + children.push_back(_ctx.pop()); + } + + _ctx.push<FunctionCall>("switch", std::move(children)); + } + + void visit(const ExpressionTestApiVersion* expr) override final { + unsupportedExpression("$_testApiVersion"); + } + + void visit(const ExpressionToLower* expr) override final { + pushSingleArgFunctionFromTop("toLower"); + } + + void visit(const ExpressionToUpper* expr) override final { + pushSingleArgFunctionFromTop("toUpper"); + } + + void visit(const ExpressionTrim* expr) override final { + unsupportedExpression("$trim"); + } + + void visit(const ExpressionTrunc* expr) override final { + unsupportedExpression(expr->getOpName()); + } + + void visit(const ExpressionType* expr) override final { + unsupportedExpression(expr->getOpName()); + } + + void visit(const ExpressionZip* expr) override final { + unsupportedExpression("$zip"); + } + + void visit(const ExpressionConvert* expr) override final { + unsupportedExpression("$convert"); + } + + void visit(const ExpressionRegexFind* expr) override final { + unsupportedExpression("$regexFind"); + } + + void visit(const ExpressionRegexFindAll* expr) override final { + unsupportedExpression("$regexFindAll"); + } + + void visit(const ExpressionRegexMatch* expr) override final { + unsupportedExpression("$regexMatch"); + } + + void visit(const ExpressionCosine* expr) override final { + pushSingleArgFunctionFromTop("cosine"); + } + + void visit(const ExpressionSine* expr) override final { + pushSingleArgFunctionFromTop("sine"); + } + + void visit(const ExpressionTangent* expr) override final { + pushSingleArgFunctionFromTop("tangent"); + } + + void visit(const ExpressionArcCosine* expr) override final { + pushSingleArgFunctionFromTop("arcCosine"); + } + + void visit(const ExpressionArcSine* expr) override final { + pushSingleArgFunctionFromTop("arcSine"); + } + + void visit(const ExpressionArcTangent* expr) override final { + pushSingleArgFunctionFromTop("arcTangent"); + } + + void visit(const ExpressionArcTangent2* expr) override final { + pushSingleArgFunctionFromTop("arcTangent2"); + } + + void visit(const ExpressionHyperbolicArcTangent* expr) override final { + pushSingleArgFunctionFromTop("arcTangentH"); + } + + void visit(const ExpressionHyperbolicArcCosine* expr) override final { + pushSingleArgFunctionFromTop("arcCosineH"); + } + + void visit(const ExpressionHyperbolicArcSine* expr) override final { + pushSingleArgFunctionFromTop("arcSineH"); + } + + void visit(const ExpressionHyperbolicTangent* expr) override final { + pushSingleArgFunctionFromTop("tangentH"); + } + + void visit(const ExpressionHyperbolicCosine* expr) override final { + pushSingleArgFunctionFromTop("cosineH"); + } + + void visit(const ExpressionHyperbolicSine* expr) override final { + pushSingleArgFunctionFromTop("sineH"); + } + + void visit(const ExpressionDegreesToRadians* expr) override final { + pushSingleArgFunctionFromTop("degreesToRadians"); + } + + void visit(const ExpressionRadiansToDegrees* expr) override final { + pushSingleArgFunctionFromTop("radiansToDegrees"); + } + + void visit(const ExpressionDayOfMonth* expr) 
override final { + unsupportedExpression("$dayOfMonth"); + } + + void visit(const ExpressionDayOfWeek* expr) override final { + unsupportedExpression("$dayOfWeek"); + } + + void visit(const ExpressionDayOfYear* expr) override final { + unsupportedExpression("$dayOfYear"); + } + + void visit(const ExpressionHour* expr) override final { + unsupportedExpression("$hour"); + } + + void visit(const ExpressionMillisecond* expr) override final { + unsupportedExpression("$millisecond"); + } + + void visit(const ExpressionMinute* expr) override final { + unsupportedExpression("$minute"); + } + + void visit(const ExpressionMonth* expr) override final { + unsupportedExpression("$month"); + } + + void visit(const ExpressionSecond* expr) override final { + unsupportedExpression("$second"); + } + + void visit(const ExpressionWeek* expr) override final { + unsupportedExpression("$week"); + } + + void visit(const ExpressionIsoWeekYear* expr) override final { + unsupportedExpression("$isoWeekYear"); + } + + void visit(const ExpressionIsoDayOfWeek* expr) override final { + unsupportedExpression("$isoDayOfWeek"); + } + + void visit(const ExpressionIsoWeek* expr) override final { + unsupportedExpression("$isoWeek"); + } + + void visit(const ExpressionYear* expr) override final { + unsupportedExpression("$year"); + } + + void visit(const ExpressionFromAccumulator<AccumulatorAvg>* expr) override final { + unsupportedExpression(expr->getOpName()); + } + + void visit(const ExpressionFromAccumulatorN<AccumulatorFirstN>* expr) override final { + unsupportedExpression(expr->getOpName()); + } + + void visit(const ExpressionFromAccumulatorN<AccumulatorLastN>* expr) override final { + unsupportedExpression(expr->getOpName()); + } + + void visit(const ExpressionFromAccumulator<AccumulatorMax>* expr) override final { + unsupportedExpression(expr->getOpName()); + } + + void visit(const ExpressionFromAccumulator<AccumulatorMin>* expr) override final { + unsupportedExpression(expr->getOpName()); + } + + void visit(const ExpressionFromAccumulatorN<AccumulatorMaxN>* expr) override final { + unsupportedExpression(expr->getOpName()); + } + + void visit(const ExpressionFromAccumulatorN<AccumulatorMinN>* expr) override final { + unsupportedExpression(expr->getOpName()); + } + + void visit(const ExpressionFromAccumulator<AccumulatorStdDevPop>* expr) override final { + unsupportedExpression(expr->getOpName()); + } + + void visit(const ExpressionFromAccumulator<AccumulatorStdDevSamp>* expr) override final { + unsupportedExpression(expr->getOpName()); + } + + void visit(const ExpressionFromAccumulator<AccumulatorSum>* expr) override final { + unsupportedExpression(expr->getOpName()); + } + + void visit(const ExpressionFromAccumulator<AccumulatorMergeObjects>* expr) override final { + unsupportedExpression(expr->getOpName()); + } + + void visit(const ExpressionTests::Testable* expr) override final { + unsupportedExpression("$test"); + } + + void visit(const ExpressionInternalJsEmit* expr) override final { + unsupportedExpression("$internalJsEmit"); + } + + void visit(const ExpressionInternalFindSlice* expr) override final { + unsupportedExpression("$internalFindSlice"); + } + + void visit(const ExpressionInternalFindPositional* expr) override final { + unsupportedExpression("$internalFindPositional"); + } + + void visit(const ExpressionInternalFindElemMatch* expr) override final { + unsupportedExpression("$internalFindElemMatch"); + } + + void visit(const ExpressionFunction* expr) override final { + unsupportedExpression("$function"); 
+ } + + void visit(const ExpressionRandom* expr) override final { + unsupportedExpression(expr->getOpName()); + } + + void visit(const ExpressionToHashedIndexKey* expr) override final { + unsupportedExpression("$toHashedIndexKey"); + } + + void visit(const ExpressionDateAdd* expr) override final { + unsupportedExpression("dateAdd"); + } + + void visit(const ExpressionDateSubtract* expr) override final { + unsupportedExpression("dateSubtract"); + } + + void visit(const ExpressionSetField* expr) override final { + unsupportedExpression("$setField"); + } + + void visit(const ExpressionGetField* expr) override final { + unsupportedExpression("$getField"); + } + + void visit(const ExpressionTsSecond* expr) override final { + unsupportedExpression("tsSecond"); + } + + void visit(const ExpressionTsIncrement* expr) override final { + unsupportedExpression("tsIncrement"); + } + +private: + /** + * Shared logic for $and, $or. Converts each child into an EExpression that evaluates to Boolean + * true or false, based on MQL rules for $and and $or branches, and then chains the branches + * together using binary and/or EExpressions so that the result has MQL's short-circuit + * semantics. + */ + void visitMultiBranchLogicExpression(const Expression* expr, Operations logicOp) { + invariant(logicOp == Operations::And || logicOp == Operations::Or); + const size_t arity = expr->getChildren().size(); + _ctx.ensureArity(arity); + + if (arity == 0) { + // Empty $and and $or always evaluate to their logical operator's identity value: true + // and false, respectively. + const bool logicIdentityVal = (logicOp == Operations::And); + _ctx.push<Constant>(sbe::value::TypeTags::Boolean, + sbe::value::bitcastFrom<bool>(logicIdentityVal)); + return; + } + + ABT current = generateCoerceToBoolPopInput(); + for (size_t i = 0; i < arity - 1; i++) { + current = make<BinaryOp>(logicOp, std::move(current), generateCoerceToBoolPopInput()); + } + _ctx.push(std::move(current)); + } + + void pushMultiArgFunctionFromTop(const std::string& functionName, const size_t argCount) { + _ctx.ensureArity(argCount); + + ABTVector children; + for (size_t i = 0; i < argCount; i++) { + children.emplace_back(_ctx.pop()); + } + _ctx.push<FunctionCall>(functionName, children); + } + + void pushSingleArgFunctionFromTop(const std::string& functionName) { + pushMultiArgFunctionFromTop(functionName, 1); + } + + void pushArithmeticBinaryExpr(const Expression* expr, const Operations op) { + const size_t arity = expr->getChildren().size(); + _ctx.ensureArity(arity); + if (arity < 2) { + // Nothing to do for arity 0 and 1. + return; + } + + ABT current = _ctx.pop(); + for (size_t i = 0; i < arity - 1; i++) { + current = make<BinaryOp>(op, std::move(current), _ctx.pop()); + } + _ctx.push(std::move(current)); + } + + ABT generateCoerceToBoolInternal(ABT input) { + return generateCoerceToBool(std::move(input), getNextId("coerceToBool")); + } + + ABT generateCoerceToBoolPopInput() { + return generateCoerceToBoolInternal(_ctx.pop()); + } + + std::string generateVariableName(const Variables::Id varId) { + std::ostringstream os; + os << _ctx.getUniqueIdPrefix() << "_var_" << varId; + return os.str(); + } + + std::string getNextId(const std::string& key) { + return _ctx.getUniqueIdPrefix() + "_" + _prefixId.getNextId(key); + } + + void unsupportedExpression(const char* op) const { + uasserted(ErrorCodes::InternalErrorNotSupported, + str::stream() << "Expression is not supported: " << op); + } + + PrefixId _prefixId; + + // We don't own this. 
+ ExpressionAlgebrizerContext& _ctx; +}; + +class AggExpressionWalker final { +public: + AggExpressionWalker(ABTAggExpressionVisitor* visitor) : _visitor{visitor} {} + + void postVisit(const Expression* expr) { + expr->acceptVisitor(_visitor); + } + +private: + ABTAggExpressionVisitor* _visitor; +}; + +ABT generateAggExpression(const Expression* expr, + const std::string& rootProjection, + const std::string& uniqueIdPrefix) { + ExpressionAlgebrizerContext ctx( + true /*assertExprSort*/, false /*assertPathSort*/, rootProjection, uniqueIdPrefix); + ABTAggExpressionVisitor visitor(ctx); + + AggExpressionWalker walker(&visitor); + expression_walker::walk<const Expression>(expr, &walker); + return ctx.pop(); +} + +} // namespace mongo::optimizer diff --git a/src/mongo/db/pipeline/abt/agg_expression_visitor.h b/src/mongo/db/pipeline/abt/agg_expression_visitor.h new file mode 100644 index 00000000000..b0300dceb63 --- /dev/null +++ b/src/mongo/db/pipeline/abt/agg_expression_visitor.h @@ -0,0 +1,42 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#pragma once + +#include "mongo/db/operation_context.h" +#include "mongo/db/pipeline/expression.h" +#include "mongo/db/query/optimizer/node.h" + +namespace mongo::optimizer { + +ABT generateAggExpression(const Expression* expr, + const std::string& rootProjection, + const std::string& uniqueIdPrefix); + +} // namespace mongo::optimizer diff --git a/src/mongo/db/pipeline/abt/expr_algebrizer_context.cpp b/src/mongo/db/pipeline/abt/expr_algebrizer_context.cpp new file mode 100644 index 00000000000..af838a74b06 --- /dev/null +++ b/src/mongo/db/pipeline/abt/expr_algebrizer_context.cpp @@ -0,0 +1,73 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include "mongo/db/pipeline/abt/expr_algebrizer_context.h" + +namespace mongo::optimizer { + +ExpressionAlgebrizerContext::ExpressionAlgebrizerContext(const bool assertExprSort, + const bool assertPathSort, + const std::string& rootProjection, + const std::string& uniqueIdPrefix) + : _assertExprSort(assertExprSort), + _assertPathSort(assertPathSort), + _rootProjection(rootProjection), + _uniqueIdPrefix(uniqueIdPrefix) {} + +void ExpressionAlgebrizerContext::push(ABT node) { + if (_assertExprSort) { + assertExprSort(node); + } else if (_assertPathSort) { + assertPathSort(node); + } + + _stack.emplace(node); +} + +ABT ExpressionAlgebrizerContext::pop() { + uassert(6624428, "Arity violation", !_stack.empty()); + + ABT node = _stack.top(); + _stack.pop(); + return node; +} + +void ExpressionAlgebrizerContext::ensureArity(const size_t arity) { + uassert(6624429, "Arity violation", _stack.size() >= arity); +} + +const std::string& ExpressionAlgebrizerContext::getRootProjection() const { + return _rootProjection; +} + +const std::string& ExpressionAlgebrizerContext::getUniqueIdPrefix() const { + return _uniqueIdPrefix; +} + +} // namespace mongo::optimizer diff --git a/src/mongo/db/pipeline/abt/expr_algebrizer_context.h b/src/mongo/db/pipeline/abt/expr_algebrizer_context.h new file mode 100644 index 00000000000..7a04da91da1 --- /dev/null +++ b/src/mongo/db/pipeline/abt/expr_algebrizer_context.h @@ -0,0 +1,70 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. 
You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#pragma once + +#include <stack> + +#include "mongo/db/query/optimizer/node.h" + +namespace mongo::optimizer { + +class ExpressionAlgebrizerContext { +public: + ExpressionAlgebrizerContext(bool assertExprSort, + bool assertPathSort, + const std::string& rootProjection, + const std::string& uniqueIdPrefix); + + template <typename T, typename... Args> + inline auto push(Args&&... args) { + push(std::move(ABT::make<T>(std::forward<Args>(args)...))); + } + + void push(ABT node); + + ABT pop(); + + void ensureArity(size_t arity); + + const std::string& getRootProjection() const; + + const std::string& getUniqueIdPrefix() const; + +private: + const bool _assertExprSort; + const bool _assertPathSort; + + const std::string _rootProjection; + const std::string _uniqueIdPrefix; + + std::stack<ABT> _stack; +}; + +} // namespace mongo::optimizer diff --git a/src/mongo/db/pipeline/abt/match_expression_visitor.cpp b/src/mongo/db/pipeline/abt/match_expression_visitor.cpp new file mode 100644 index 00000000000..445c2e719ad --- /dev/null +++ b/src/mongo/db/pipeline/abt/match_expression_visitor.cpp @@ -0,0 +1,479 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. 
+ */ + +#include "mongo/db/pipeline/abt/match_expression_visitor.h" +#include "mongo/db/matcher/expression_always_boolean.h" +#include "mongo/db/matcher/expression_array.h" +#include "mongo/db/matcher/expression_expr.h" +#include "mongo/db/matcher/expression_geo.h" +#include "mongo/db/matcher/expression_internal_bucket_geo_within.h" +#include "mongo/db/matcher/expression_internal_expr_comparison.h" +#include "mongo/db/matcher/expression_leaf.h" +#include "mongo/db/matcher/expression_text.h" +#include "mongo/db/matcher/expression_text_noop.h" +#include "mongo/db/matcher/expression_tree.h" +#include "mongo/db/matcher/expression_type.h" +#include "mongo/db/matcher/expression_visitor.h" +#include "mongo/db/matcher/expression_where.h" +#include "mongo/db/matcher/expression_where_noop.h" +#include "mongo/db/matcher/match_expression_walker.h" +#include "mongo/db/matcher/schema/expression_internal_schema_all_elem_match_from_index.h" +#include "mongo/db/matcher/schema/expression_internal_schema_allowed_properties.h" +#include "mongo/db/matcher/schema/expression_internal_schema_cond.h" +#include "mongo/db/matcher/schema/expression_internal_schema_eq.h" +#include "mongo/db/matcher/schema/expression_internal_schema_fmod.h" +#include "mongo/db/matcher/schema/expression_internal_schema_match_array_index.h" +#include "mongo/db/matcher/schema/expression_internal_schema_max_items.h" +#include "mongo/db/matcher/schema/expression_internal_schema_max_length.h" +#include "mongo/db/matcher/schema/expression_internal_schema_max_properties.h" +#include "mongo/db/matcher/schema/expression_internal_schema_min_items.h" +#include "mongo/db/matcher/schema/expression_internal_schema_min_length.h" +#include "mongo/db/matcher/schema/expression_internal_schema_min_properties.h" +#include "mongo/db/matcher/schema/expression_internal_schema_object_match.h" +#include "mongo/db/matcher/schema/expression_internal_schema_root_doc_eq.h" +#include "mongo/db/matcher/schema/expression_internal_schema_unique_items.h" +#include "mongo/db/matcher/schema/expression_internal_schema_xor.h" +#include "mongo/db/pipeline/abt/agg_expression_visitor.h" +#include "mongo/db/pipeline/abt/expr_algebrizer_context.h" +#include "mongo/db/pipeline/abt/utils.h" +#include "mongo/db/query/optimizer/utils/utils.h" + +namespace mongo::optimizer { + +class ABTMatchExpressionVisitor : public MatchExpressionConstVisitor { +public: + ABTMatchExpressionVisitor(ExpressionAlgebrizerContext& ctx, const bool allowAggExpressions) + : _prefixId(), _allowAggExpressions(allowAggExpressions), _ctx(ctx) {} + + void visit(const AlwaysFalseMatchExpression* expr) override { + generateBoolConstant(false); + } + + void visit(const AlwaysTrueMatchExpression* expr) override { + generateBoolConstant(true); + } + + void visit(const AndMatchExpression* expr) override { + visitAndOrExpression<PathComposeM, true>(expr); + } + + void visit(const BitsAllClearMatchExpression* expr) override { + unsupportedExpression(expr); + } + + void visit(const BitsAllSetMatchExpression* expr) override { + unsupportedExpression(expr); + } + + void visit(const BitsAnyClearMatchExpression* expr) override { + unsupportedExpression(expr); + } + + void visit(const BitsAnySetMatchExpression* expr) override { + unsupportedExpression(expr); + } + + void visit(const ElemMatchObjectMatchExpression* expr) override { + generateElemMatch<false /*isValueElemMatch*/>(expr); + } + + void visit(const ElemMatchValueMatchExpression* expr) override { + generateElemMatch<true /*isValueElemMatch*/>(expr); + } + + void 
visit(const EqualityMatchExpression* expr) override { + generateSimpleComparison(expr, Operations::Eq); + } + + void visit(const ExistsMatchExpression* expr) override { + unsupportedExpression(expr); + } + + void visit(const ExprMatchExpression* expr) override { + uassert(6624246, "Cannot generate an agg expression in this context", _allowAggExpressions); + + ABT result = generateAggExpression( + expr->getExpression().get(), _ctx.getRootProjection(), _ctx.getUniqueIdPrefix()); + _ctx.push<PathConstant>(generateCoerceToBool(std::move(result), getNextId("coerceToBool"))); + } + + void visit(const GTEMatchExpression* expr) override { + generateSimpleComparison(expr, Operations::Gte); + } + + void visit(const GTMatchExpression* expr) override { + generateSimpleComparison(expr, Operations::Gt); + } + + void visit(const GeoMatchExpression* expr) override { + unsupportedExpression(expr); + } + + void visit(const GeoNearMatchExpression* expr) override { + unsupportedExpression(expr); + } + + void visit(const InMatchExpression* expr) override { + unsupportedExpression(expr); + } + + void visit(const InternalBucketGeoWithinMatchExpression* expr) override { + unsupportedExpression(expr); + } + + void visit(const InternalExprEqMatchExpression* expr) override { + unsupportedExpression(expr); + } + + void visit(const InternalExprGTMatchExpression* expr) override { + unsupportedExpression(expr); + } + + void visit(const InternalExprGTEMatchExpression* expr) override { + unsupportedExpression(expr); + } + + void visit(const InternalExprLTMatchExpression* expr) override { + unsupportedExpression(expr); + } + + void visit(const InternalExprLTEMatchExpression* expr) override { + unsupportedExpression(expr); + } + + void visit(const InternalSchemaAllElemMatchFromIndexMatchExpression* expr) override { + unsupportedExpression(expr); + } + + void visit(const InternalSchemaAllowedPropertiesMatchExpression* expr) override { + unsupportedExpression(expr); + } + + void visit(const InternalSchemaBinDataEncryptedTypeExpression* expr) override { + unsupportedExpression(expr); + } + + void visit(const InternalSchemaBinDataSubTypeExpression* expr) override { + unsupportedExpression(expr); + } + + void visit(const InternalSchemaCondMatchExpression* expr) override { + unsupportedExpression(expr); + } + + void visit(const InternalSchemaEqMatchExpression* expr) override { + unsupportedExpression(expr); + } + + void visit(const InternalSchemaFmodMatchExpression* expr) override { + unsupportedExpression(expr); + } + + void visit(const InternalSchemaMatchArrayIndexMatchExpression* expr) override { + unsupportedExpression(expr); + } + + void visit(const InternalSchemaMaxItemsMatchExpression* expr) override { + unsupportedExpression(expr); + } + + void visit(const InternalSchemaMaxLengthMatchExpression* expr) override { + unsupportedExpression(expr); + } + + void visit(const InternalSchemaMaxPropertiesMatchExpression* expr) override { + unsupportedExpression(expr); + } + + void visit(const InternalSchemaMinItemsMatchExpression* expr) override { + unsupportedExpression(expr); + } + + void visit(const InternalSchemaMinLengthMatchExpression* expr) override { + unsupportedExpression(expr); + } + + void visit(const InternalSchemaMinPropertiesMatchExpression* expr) override { + unsupportedExpression(expr); + } + + void visit(const InternalSchemaObjectMatchExpression* expr) override { + unsupportedExpression(expr); + } + + void visit(const InternalSchemaRootDocEqMatchExpression* expr) override { + unsupportedExpression(expr); + } + + 
void visit(const InternalSchemaTypeExpression* expr) override { + unsupportedExpression(expr); + } + + void visit(const InternalSchemaUniqueItemsMatchExpression* expr) override { + unsupportedExpression(expr); + } + + void visit(const InternalSchemaXorMatchExpression* expr) override { + unsupportedExpression(expr); + } + + void visit(const LTEMatchExpression* expr) override { + generateSimpleComparison(expr, Operations::Lte); + } + + void visit(const LTMatchExpression* expr) override { + generateSimpleComparison(expr, Operations::Lt); + } + + void visit(const ModMatchExpression* expr) override { + unsupportedExpression(expr); + } + + void visit(const NorMatchExpression* expr) override { + unsupportedExpression(expr); + } + + void visit(const NotMatchExpression* expr) override { + unsupportedExpression(expr); + } + + void visit(const OrMatchExpression* expr) override { + visitAndOrExpression<PathComposeA, false>(expr); + } + + void visit(const RegexMatchExpression* expr) override { + unsupportedExpression(expr); + } + + void visit(const SizeMatchExpression* expr) override { + unsupportedExpression(expr); + } + + void visit(const TextMatchExpression* expr) override { + unsupportedExpression(expr); + } + + void visit(const TextNoOpMatchExpression* expr) override { + unsupportedExpression(expr); + } + + void visit(const TwoDPtInAnnulusExpression* expr) override { + unsupportedExpression(expr); + } + + void visit(const TypeMatchExpression* expr) override { + const std::string lambdaProjName = _prefixId.getNextId("lambda_typeMatch"); + ABT result = make<PathLambda>(make<LambdaAbstraction>( + lambdaProjName, + make<FunctionCall>("typeMatch", + makeSeq(make<Variable>(lambdaProjName), + Constant::int64(expr->typeSet().getBSONTypeMask()))))); + + if (!expr->path().empty()) { + result = generateFieldPath(FieldPath(expr->path().toString()), std::move(result)); + } + _ctx.push(std::move(result)); + } + + void visit(const WhereMatchExpression* expr) override { + unsupportedExpression(expr); + } + + void visit(const WhereNoOpMatchExpression* expr) override { + unsupportedExpression(expr); + } + +private: + void generateBoolConstant(const bool value) { + _ctx.push<PathConstant>(Constant::boolean(value)); + } + + template <bool isValueElemMatch> + void generateElemMatch(const ArrayMatchingMatchExpression* expr) { + // Returns true if at least one sub-objects matches the condition. + + const size_t childCount = expr->numChildren(); + if (childCount == 0) { + _ctx.push(Constant::boolean(true)); + } + + _ctx.ensureArity(childCount); + ABT result = _ctx.pop(); + for (size_t i = 1; i < childCount; i++) { + maybeComposePath(result, _ctx.pop()); + } + if constexpr (!isValueElemMatch) { + // Make sure we consider only objects as elements of the array. + maybeComposePath(result, make<PathObj>()); + } + result = make<PathTraverse>(std::move(result)); + + // Make sure we consider only arrays fields on the path. + maybeComposePath(result, make<PathArr>()); + + if (!expr->path().empty()) { + result = translateFieldPath( + FieldPath{expr->path().toString()}, + std::move(result), + [&](const std::string& fieldName, const bool isLastElement, ABT input) { + if (!isLastElement) { + input = make<PathTraverse>(std::move(input)); + } + return make<PathGet>(fieldName, std::move(input)); + }); + } + + _ctx.push(std::move(result)); + } + + /** + * Return the minimum or maximum value for the "class" of values represented by the input + * constant. Used to support type bracketing. 
+     */
+    template <bool isMin>
+    std::pair<boost::optional<ABT>, bool> getMinMaxBoundForType(const sbe::value::TypeTags& tag) {
+        if (isNumber(tag)) {
+            if constexpr (isMin) {
+                return {Constant::fromDouble(std::numeric_limits<double>::quiet_NaN()), true};
+            } else {
+                return {Constant::str(""), false};
+            }
+        } else if (isStringOrSymbol(tag)) {
+            if constexpr (isMin) {
+                return {Constant::str(""), true};
+            } else {
+                // TODO: we need a limit string from above.
+                return {boost::none, false};
+            }
+        } else if (tag == sbe::value::TypeTags::Null) {
+            // Same bound above and below.
+            return {Constant::null(), true};
+        } else {
+            // TODO: compute bounds for other types based on bsonobjbuilder.cpp.
+            return {boost::none, false};
+        }
+
+        MONGO_UNREACHABLE;
+    }
+
+    ABT generateFieldPath(const FieldPath& fieldPath, ABT initial) {
+        return translateFieldPath(
+            fieldPath,
+            std::move(initial),
+            [](const std::string& fieldName, const bool /*isLastElement*/, ABT input) {
+                return make<PathGet>(fieldName, make<PathTraverse>(std::move(input)));
+            });
+    }
+
+    void generateSimpleComparison(const ComparisonMatchExpressionBase* expr, const Operations op) {
+        auto [tag, val] = convertFrom(Value(expr->getData()));
+        ABT result = make<PathCompare>(op, make<Constant>(tag, val));
+
+        switch (op) {
+            case Operations::Lt:
+            case Operations::Lte: {
+                auto&& [constant, inclusive] = getMinMaxBoundForType<true /*isMin*/>(tag);
+                if (constant) {
+                    maybeComposePath(result,
+                                     make<PathCompare>(inclusive ? Operations::Gte : Operations::Gt,
+                                                       std::move(constant.get())));
+                }
+                break;
+            }
+
+            case Operations::Gt:
+            case Operations::Gte: {
+                auto&& [constant, inclusive] = getMinMaxBoundForType<false /*isMin*/>(tag);
+                if (constant) {
+                    maybeComposePath(result,
+                                     make<PathCompare>(inclusive ? Operations::Lte : Operations::Lt,
+                                                       std::move(constant.get())));
+                }
+                break;
+            }
+
+            default:
+                break;
+        }
+
+        if (!expr->path().empty()) {
+            result = generateFieldPath(FieldPath(expr->path().toString()), std::move(result));
+        }
+        _ctx.push(std::move(result));
+    }
+
+    template <class Composition, bool defaultResult>
+    void visitAndOrExpression(const ListOfMatchExpression* expr) {
+        const size_t childCount = expr->numChildren();
+        if (childCount == 0) {
+            generateBoolConstant(defaultResult);
+            return;
+        }
+        if (childCount == 1) {
+            return;
+        }
+
+        ABT node = _ctx.pop();
+        for (size_t i = 0; i < childCount - 1; i++) {
+            node = make<Composition>(_ctx.pop(), std::move(node));
+        }
+        _ctx.push(std::move(node));
+    }
+
+    std::string getNextId(const std::string& key) {
+        return _ctx.getUniqueIdPrefix() + "_" + _prefixId.getNextId(key);
+    }
+
+    void unsupportedExpression(const MatchExpression* expr) const {
+        uasserted(ErrorCodes::InternalErrorNotSupported,
+                  str::stream() << "Match expression is not supported: " << expr->matchType());
+    }
+
+    PrefixId _prefixId;
+
+    // If we are parsing a partial index filter, we don't allow agg expressions.
+ const bool _allowAggExpressions; + + // We don't own this + ExpressionAlgebrizerContext& _ctx; +}; + +ABT generateMatchExpression(const MatchExpression* expr, + const bool allowAggExpressions, + const std::string& rootProjection, + const std::string& uniqueIdPrefix) { + ExpressionAlgebrizerContext ctx( + false /*assertExprSort*/, true /*assertPathSort*/, rootProjection, uniqueIdPrefix); + ABTMatchExpressionVisitor visitor(ctx, allowAggExpressions); + MatchExpressionWalker walker(nullptr /*preVisitor*/, nullptr /*inVisitor*/, &visitor); + tree_walker::walk<true, MatchExpression>(expr, &walker); + return ctx.pop(); +} + +} // namespace mongo::optimizer diff --git a/src/mongo/db/pipeline/abt/match_expression_visitor.h b/src/mongo/db/pipeline/abt/match_expression_visitor.h new file mode 100644 index 00000000000..d971ffd5ec6 --- /dev/null +++ b/src/mongo/db/pipeline/abt/match_expression_visitor.h @@ -0,0 +1,45 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#pragma once + +#include "mongo/db/matcher/expression.h" +#include "mongo/db/query/optimizer/node.h" + +namespace mongo::optimizer { + +/** + * Returns a path encoding the match expression. + */ +ABT generateMatchExpression(const MatchExpression* expr, + bool allowAggExpressions, + const std::string& rootProjection, + const std::string& uniqueIdPrefix); + +} // namespace mongo::optimizer diff --git a/src/mongo/db/pipeline/abt/pipeline_test.cpp b/src/mongo/db/pipeline/abt/pipeline_test.cpp new file mode 100644 index 00000000000..36dc2235db8 --- /dev/null +++ b/src/mongo/db/pipeline/abt/pipeline_test.cpp @@ -0,0 +1,2508 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. 
+ * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include <boost/intrusive_ptr.hpp> + +#include "mongo/db/pipeline/aggregate_command_gen.h" +#include "mongo/db/pipeline/pipeline.h" +#include "mongo/db/query/optimizer/explain.h" +#include "mongo/db/query/optimizer/opt_phase_manager.h" +#include "mongo/db/query/optimizer/utils/unit_test_utils.h" +#include "mongo/unittest/unittest.h" + +namespace mongo { +namespace { + +using namespace optimizer; + +TEST(ABTTranslate, SortLimitSkip) { + ABT translated = translatePipeline( + "[{$limit: 5}, " + "{$skip: 3}, " + "{$sort: {a: 1, b: -1}}]"); + + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | scan_0\n" + "| RefBlock: \n" + "| Variable [scan_0]\n" + "Collation []\n" + "| | collation: \n" + "| | sort_0: Ascending\n" + "| | sort_1: Descending\n" + "| RefBlock: \n" + "| Variable [sort_0]\n" + "| Variable [sort_1]\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [sort_1]\n" + "| EvalPath []\n" + "| | Variable [scan_0]\n" + "| PathGet [b]\n" + "| PathIdentity []\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [sort_0]\n" + "| EvalPath []\n" + "| | Variable [scan_0]\n" + "| PathGet [a]\n" + "| PathIdentity []\n" + "LimitSkip []\n" + "| limitSkip:\n" + "| limit: (none)\n" + "| skip: 3\n" + "LimitSkip []\n" + "| limitSkip:\n" + "| limit: 5\n" + "| skip: 0\n" + "Scan [collection]\n" + " BindBlock:\n" + " [scan_0]\n" + " Source []\n", + translated); +} + +TEST(ABTTranslate, ProjectRetain) { + PrefixId prefixId; + std::string scanDefName = "collection"; + Metadata metadata = {{{scanDefName, ScanDefinition{{}, {}}}}}; + ABT translated = translatePipeline( + metadata, "[{$project: {a: 1, b: 1}}, {$match: {a: 2}}]", scanDefName, prefixId); + + OptPhaseManager phaseManager({OptPhaseManager::OptPhase::ConstEvalPre, + OptPhaseManager::OptPhase::PathFuse, + OptPhaseManager::OptPhase::MemoSubstitutionPhase, + OptPhaseManager::OptPhase::MemoExplorationPhase, + OptPhaseManager::OptPhase::MemoImplementationPhase}, + prefixId, + metadata, + DebugInfo::kDefaultForTests); + + ABT optimized = translated; + ASSERT_TRUE(phaseManager.optimize(optimized)); + + // Observe the Filter can be reordered against the Eval node. 
+ ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | combinedProjection_0\n" + "| RefBlock: \n" + "| Variable [combinedProjection_0]\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [combinedProjection_0]\n" + "| EvalPath []\n" + "| | Const [{}]\n" + "| PathComposeM []\n" + "| | PathField [b]\n" + "| | PathConstant []\n" + "| | Variable [fieldProj_2]\n" + "| PathComposeM []\n" + "| | PathField [a]\n" + "| | PathConstant []\n" + "| | Variable [fieldProj_1]\n" + "| PathField [_id]\n" + "| PathConstant []\n" + "| Variable [fieldProj_0]\n" + "Filter []\n" + "| EvalFilter []\n" + "| | Variable [fieldProj_1]\n" + "| PathTraverse []\n" + "| PathCompare [Eq]\n" + "| Const [2]\n" + "PhysicalScan [{'_id': fieldProj_0, 'a': fieldProj_1, 'b': fieldProj_2}, collection]\n" + " BindBlock:\n" + " [fieldProj_0]\n" + " Source []\n" + " [fieldProj_1]\n" + " Source []\n" + " [fieldProj_2]\n" + " Source []\n", + optimized); +} + +TEST(ABTTranslate, ProjectRetain1) { + ABT translated = translatePipeline("[{$project: {a1: 1, a2: 1, a3: 1, a4: 1, a5: 1, a6: 1}}]"); + + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | combinedProjection_0\n" + "| RefBlock: \n" + "| Variable [combinedProjection_0]\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [combinedProjection_0]\n" + "| EvalPath []\n" + "| | Variable [scan_0]\n" + "| PathKeep [_id, a1, a2, a3, a4, a5, a6]\n" + "Scan [collection]\n" + " BindBlock:\n" + " [scan_0]\n" + " Source []\n", + translated); +} + +TEST(ABTTranslate, AddFields) { + // Since '$z' is a single element, it will be considered a renamed path. + ABT translated = translatePipeline("[{$addFields: {a: '$z'}}]"); + + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | combinedProjection_0\n" + "| RefBlock: \n" + "| Variable [combinedProjection_0]\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [combinedProjection_0]\n" + "| EvalPath []\n" + "| | Variable [scan_0]\n" + "| PathComposeM []\n" + "| | PathDefault []\n" + "| | Const [{}]\n" + "| PathField [a]\n" + "| PathConstant []\n" + "| Variable [projRenamedPath_0]\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [projRenamedPath_0]\n" + "| EvalPath []\n" + "| | Variable [scan_0]\n" + "| PathGet [z]\n" + "| PathIdentity []\n" + "Scan [collection]\n" + " BindBlock:\n" + " [scan_0]\n" + " Source []\n", + translated); +} + +TEST(ABTTranslate, ProjectRenames) { + // Since '$c' is a single element, it will be considered a renamed path. 
+ ABT translated = translatePipeline("[{$project: {'a.b': '$c'}}]"); + + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | combinedProjection_0\n" + "| RefBlock: \n" + "| Variable [combinedProjection_0]\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [combinedProjection_0]\n" + "| EvalPath []\n" + "| | Variable [scan_0]\n" + "| PathComposeM []\n" + "| | PathField [a]\n" + "| | PathTraverse []\n" + "| | PathComposeM []\n" + "| | | PathDefault []\n" + "| | | Const [{}]\n" + "| | PathComposeM []\n" + "| | | PathField [b]\n" + "| | | PathConstant []\n" + "| | | Variable [projRenamedPath_0]\n" + "| | PathKeep [b]\n" + "| PathKeep [_id, a]\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [projRenamedPath_0]\n" + "| EvalPath []\n" + "| | Variable [scan_0]\n" + "| PathGet [c]\n" + "| PathIdentity []\n" + "Scan [collection]\n" + " BindBlock:\n" + " [scan_0]\n" + " Source []\n", + translated); +} + +TEST(ABTTranslate, ProjectPaths) { + ABT translated = translatePipeline("[{$project: {'a.b.c': '$x.y.z'}}]"); + + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | combinedProjection_0\n" + "| RefBlock: \n" + "| Variable [combinedProjection_0]\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [combinedProjection_0]\n" + "| EvalPath []\n" + "| | Variable [scan_0]\n" + "| PathComposeM []\n" + "| | PathField [a]\n" + "| | PathTraverse []\n" + "| | PathComposeM []\n" + "| | | PathField [b]\n" + "| | | PathTraverse []\n" + "| | | PathComposeM []\n" + "| | | | PathDefault []\n" + "| | | | Const [{}]\n" + "| | | PathComposeM []\n" + "| | | | PathField [c]\n" + "| | | | PathConstant []\n" + "| | | | Variable [projGetPath_0]\n" + "| | | PathKeep [c]\n" + "| | PathKeep [b]\n" + "| PathKeep [_id, a]\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [projGetPath_0]\n" + "| EvalPath []\n" + "| | Variable [scan_0]\n" + "| PathGet [x]\n" + "| PathTraverse []\n" + "| PathGet [y]\n" + "| PathTraverse []\n" + "| PathGet [z]\n" + "| PathIdentity []\n" + "Scan [collection]\n" + " BindBlock:\n" + " [scan_0]\n" + " Source []\n", + translated); +} + +TEST(ABTTranslate, ProjectPaths1) { + ABT translated = translatePipeline("[{$project: {'a.b':1, 'a.c':1, 'b':1}}]"); + + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | combinedProjection_0\n" + "| RefBlock: \n" + "| Variable [combinedProjection_0]\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [combinedProjection_0]\n" + "| EvalPath []\n" + "| | Variable [scan_0]\n" + "| PathComposeM []\n" + "| | PathField [a]\n" + "| | PathTraverse []\n" + "| | PathComposeM []\n" + "| | | PathKeep [b, c]\n" + "| | PathObj []\n" + "| PathKeep [_id, a, b]\n" + "Scan [collection]\n" + " BindBlock:\n" + " [scan_0]\n" + " Source []\n", + translated); +} + +TEST(ABTTranslate, ProjectInclusion) { + ABT translated = translatePipeline("[{$project: {a: {$add: ['$c.d', 2]}, b: 1}}]"); + + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | combinedProjection_0\n" + "| RefBlock: \n" + "| Variable [combinedProjection_0]\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [combinedProjection_0]\n" + "| EvalPath []\n" + "| | Variable [scan_0]\n" + "| PathComposeM []\n" + "| | PathDefault []\n" + "| | Const [{}]\n" + "| PathComposeM []\n" + "| | PathField [a]\n" + "| | PathConstant []\n" + "| | Variable [projGetPath_0]\n" + "| PathKeep [_id, a, b]\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [projGetPath_0]\n" + "| BinaryOp [Add]\n" + "| | EvalPath []\n" + "| | | Variable [scan_0]\n" + "| | PathGet [c]\n" + "| | PathTraverse []\n" + "| | PathGet [d]\n" + "| | 
PathIdentity []\n" + "| Const [2]\n" + "Scan [collection]\n" + " BindBlock:\n" + " [scan_0]\n" + " Source []\n", + translated); +} + +TEST(ABTTranslate, ProjectExclusion) { + ABT translated = translatePipeline("[{$project: {a: 0, b: 0}}]"); + + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | combinedProjection_0\n" + "| RefBlock: \n" + "| Variable [combinedProjection_0]\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [combinedProjection_0]\n" + "| EvalPath []\n" + "| | Variable [scan_0]\n" + "| PathDrop [a, b]\n" + "Scan [collection]\n" + " BindBlock:\n" + " [scan_0]\n" + " Source []\n", + translated); +} + +TEST(ABTTranslate, ProjectReplaceRoot) { + ABT translated = translatePipeline("[{$replaceRoot: {newRoot: '$a'}}]"); + + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | newRoot_0\n" + "| RefBlock: \n" + "| Variable [newRoot_0]\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [newRoot_0]\n" + "| EvalPath []\n" + "| | Variable [scan_0]\n" + "| PathGet [a]\n" + "| PathIdentity []\n" + "Scan [collection]\n" + " BindBlock:\n" + " [scan_0]\n" + " Source []\n", + translated); +} + +TEST(ABTTranslate, MatchBasic) { + PrefixId prefixId; + std::string scanDefName = "collection"; + + ABT translated = translatePipeline("[{$match: {a: 1, b: 2}}]"); + + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | scan_0\n" + "| RefBlock: \n" + "| Variable [scan_0]\n" + "Filter []\n" + "| EvalFilter []\n" + "| | Variable [scan_0]\n" + "| PathGet [b]\n" + "| PathTraverse []\n" + "| PathCompare [Eq]\n" + "| Const [2]\n" + "Filter []\n" + "| EvalFilter []\n" + "| | Variable [scan_0]\n" + "| PathGet [a]\n" + "| PathTraverse []\n" + "| PathCompare [Eq]\n" + "| Const [1]\n" + "Scan [collection]\n" + " BindBlock:\n" + " [scan_0]\n" + " Source []\n", + translated); + + OptPhaseManager phaseManager({OptPhaseManager::OptPhase::MemoSubstitutionPhase, + OptPhaseManager::OptPhase::MemoExplorationPhase, + OptPhaseManager::OptPhase::MemoImplementationPhase}, + prefixId, + {{{scanDefName, ScanDefinition{{}, {}}}}}, + DebugInfo::kDefaultForTests); + + ABT optimized = translated; + ASSERT_TRUE(phaseManager.optimize(optimized)); + + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | scan_0\n" + "| RefBlock: \n" + "| Variable [scan_0]\n" + "Filter []\n" + "| EvalFilter []\n" + "| | Variable [evalTemp_1]\n" + "| PathTraverse []\n" + "| PathCompare [Eq]\n" + "| Const [2]\n" + "Filter []\n" + "| EvalFilter []\n" + "| | Variable [evalTemp_0]\n" + "| PathTraverse []\n" + "| PathCompare [Eq]\n" + "| Const [1]\n" + "PhysicalScan [{'<root>': scan_0, 'a': evalTemp_0, 'b': evalTemp_1}, collection]\n" + " BindBlock:\n" + " [evalTemp_0]\n" + " Source []\n" + " [evalTemp_1]\n" + " Source []\n" + " [scan_0]\n" + " Source []\n", + optimized); +} + +TEST(ABTTranslate, MatchPath) { + ABT translated = translatePipeline("[{$match: {$expr: {$eq: ['$a.b', 1]}}}]"); + + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | scan_0\n" + "| RefBlock: \n" + "| Variable [scan_0]\n" + "Filter []\n" + "| EvalFilter []\n" + "| | Variable [scan_0]\n" + "| PathConstant []\n" + "| Let [matchExpression_0_coerceToBool_0]\n" + "| | BinaryOp [And]\n" + "| | | BinaryOp [And]\n" + "| | | | BinaryOp [And]\n" + "| | | | | FunctionCall [exists]\n" + "| | | | | BinaryOp [Eq]\n" + "| | | | | | Const [1]\n" + "| | | | | EvalPath []\n" + "| | | | | | Variable [scan_0]\n" + "| | | | | PathGet [a]\n" + "| | | | | PathTraverse []\n" + "| | | | | PathGet [b]\n" + "| | | | | PathIdentity []\n" + "| | | | UnaryOp 
[Not]\n" + "| | | | FunctionCall [isNull]\n" + "| | | | BinaryOp [Eq]\n" + "| | | | | Const [1]\n" + "| | | | EvalPath []\n" + "| | | | | Variable [scan_0]\n" + "| | | | PathGet [a]\n" + "| | | | PathTraverse []\n" + "| | | | PathGet [b]\n" + "| | | | PathIdentity []\n" + "| | | BinaryOp [Neq]\n" + "| | | | Const [0]\n" + "| | | BinaryOp [Cmp3w]\n" + "| | | | Const [false]\n" + "| | | Variable [matchExpression_0_coerceToBool_0]\n" + "| | BinaryOp [Neq]\n" + "| | | Const [0]\n" + "| | BinaryOp [Cmp3w]\n" + "| | | Const [0]\n" + "| | Variable [matchExpression_0_coerceToBool_0]\n" + "| BinaryOp [Eq]\n" + "| | Const [1]\n" + "| EvalPath []\n" + "| | Variable [scan_0]\n" + "| PathGet [a]\n" + "| PathTraverse []\n" + "| PathGet [b]\n" + "| PathIdentity []\n" + "Scan [collection]\n" + " BindBlock:\n" + " [scan_0]\n" + " Source []\n", + translated); +} + +TEST(ABTTranslate, ElemMatchPath) { + ABT translated = translatePipeline( + "[{$project: {a: {$literal: [1, 2, 3, 4]}}}, {$match: {a: {$elemMatch: {$gte: 2, $lte: " + "3}}}}]"); + + // Observe type bracketing in the filter. + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | combinedProjection_0\n" + "| RefBlock: \n" + "| Variable [combinedProjection_0]\n" + "Filter []\n" + "| EvalFilter []\n" + "| | Variable [combinedProjection_0]\n" + "| PathGet [a]\n" + "| PathComposeM []\n" + "| | PathArr []\n" + "| PathTraverse []\n" + "| PathComposeM []\n" + "| | PathComposeM []\n" + "| | | PathCompare [Lt]\n" + "| | | Const [\"\"]\n" + "| | PathCompare [Gte]\n" + "| | Const [2]\n" + "| PathComposeM []\n" + "| | PathCompare [Gte]\n" + "| | Const [nan]\n" + "| PathCompare [Lte]\n" + "| Const [3]\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [combinedProjection_0]\n" + "| EvalPath []\n" + "| | Variable [scan_0]\n" + "| PathComposeM []\n" + "| | PathDefault []\n" + "| | Const [{}]\n" + "| PathComposeM []\n" + "| | PathField [a]\n" + "| | PathConstant []\n" + "| | Variable [projGetPath_0]\n" + "| PathKeep [_id, a]\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [projGetPath_0]\n" + "| Const [[1, 2, 3, 4]]\n" + "Scan [collection]\n" + " BindBlock:\n" + " [scan_0]\n" + " Source []\n", + translated); +} + +TEST(ABTTranslate, MatchProject) { + ABT translated = translatePipeline( + "[{$project: {s: {$add: ['$a', '$b']}, c: 1}}, " + "{$match: {$or: [{c: 2}, {s: {$gte: 10}}]}}]"); + + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | combinedProjection_0\n" + "| RefBlock: \n" + "| Variable [combinedProjection_0]\n" + "Filter []\n" + "| EvalFilter []\n" + "| | Variable [combinedProjection_0]\n" + "| PathComposeA []\n" + "| | PathGet [s]\n" + "| | PathTraverse []\n" + "| | PathComposeM []\n" + "| | | PathCompare [Lt]\n" + "| | | Const [\"\"]\n" + "| | PathCompare [Gte]\n" + "| | Const [10]\n" + "| PathGet [c]\n" + "| PathTraverse []\n" + "| PathCompare [Eq]\n" + "| Const [2]\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [combinedProjection_0]\n" + "| EvalPath []\n" + "| | Variable [scan_0]\n" + "| PathComposeM []\n" + "| | PathDefault []\n" + "| | Const [{}]\n" + "| PathComposeM []\n" + "| | PathField [s]\n" + "| | PathConstant []\n" + "| | Variable [projGetPath_0]\n" + "| PathKeep [_id, c, s]\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [projGetPath_0]\n" + "| BinaryOp [Add]\n" + "| | EvalPath []\n" + "| | | Variable [scan_0]\n" + "| | PathGet [a]\n" + "| | PathIdentity []\n" + "| EvalPath []\n" + "| | Variable [scan_0]\n" + "| PathGet [b]\n" + "| PathIdentity []\n" + "Scan [collection]\n" + " BindBlock:\n" + " [scan_0]\n" + " Source 
[]\n", + translated); +} + +TEST(ABTTranslate, ProjectComplex) { + ABT translated = translatePipeline("[{$project: {'a1.b.c':1, 'a.b.c.d.e':'str'}}]"); + + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | combinedProjection_0\n" + "| RefBlock: \n" + "| Variable [combinedProjection_0]\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [combinedProjection_0]\n" + "| EvalPath []\n" + "| | Variable [scan_0]\n" + "| PathComposeM []\n" + "| | PathField [a1]\n" + "| | PathTraverse []\n" + "| | PathComposeM []\n" + "| | | PathField [b]\n" + "| | | PathTraverse []\n" + "| | | PathComposeM []\n" + "| | | | PathKeep [c]\n" + "| | | PathObj []\n" + "| | PathComposeM []\n" + "| | | PathKeep [b]\n" + "| | PathObj []\n" + "| PathComposeM []\n" + "| | PathField [a]\n" + "| | PathTraverse []\n" + "| | PathComposeM []\n" + "| | | PathField [b]\n" + "| | | PathTraverse []\n" + "| | | PathComposeM []\n" + "| | | | PathField [c]\n" + "| | | | PathTraverse []\n" + "| | | | PathComposeM []\n" + "| | | | | PathField [d]\n" + "| | | | | PathTraverse []\n" + "| | | | | PathComposeM []\n" + "| | | | | | PathDefault []\n" + "| | | | | | Const [{}]\n" + "| | | | | PathComposeM []\n" + "| | | | | | PathField [e]\n" + "| | | | | | PathConstant []\n" + "| | | | | | Variable [projGetPath_0]\n" + "| | | | | PathKeep [e]\n" + "| | | | PathKeep [d]\n" + "| | | PathKeep [c]\n" + "| | PathKeep [b]\n" + "| PathKeep [_id, a, a1]\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [projGetPath_0]\n" + "| Const [\"str\"]\n" + "Scan [collection]\n" + " BindBlock:\n" + " [scan_0]\n" + " Source []\n", + translated); +} + +TEST(ABTTranslate, ExprFilter) { + ABT translated = translatePipeline( + "[{$project: {a: {$filter: {input: [1, 2, 'str', {a: 2.0, b:'s'}, 3, 4], as: 'num', cond: " + "{$and: [{$gte: ['$$num', 2]}, {$lte: ['$$num', 3]}]}}}}}]"); + + PrefixId prefixId; + std::string scanDefName = "collection"; + OptPhaseManager phaseManager({OptPhaseManager::OptPhase::ConstEvalPre}, + prefixId, + {{{scanDefName, ScanDefinition{{}, {}}}}}, + DebugInfo::kDefaultForTests); + + ABT optimized = translated; + ASSERT_TRUE(phaseManager.optimize(optimized)); + + // Make sure we have a single array constant for (1, 2, 'str', ...). 
+ ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | combinedProjection_0\n" + "| RefBlock: \n" + "| Variable [combinedProjection_0]\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [combinedProjection_0]\n" + "| EvalPath []\n" + "| | Variable [scan_0]\n" + "| PathComposeM []\n" + "| | PathDefault []\n" + "| | Const [{}]\n" + "| PathComposeM []\n" + "| | PathField [a]\n" + "| | PathConstant []\n" + "| | EvalPath []\n" + "| | | Const [[1, 2, \"str\", {\"a\" : 2, \"b\" : \"s\"}, 3, 4]]\n" + "| | PathTraverse []\n" + "| | PathLambda []\n" + "| | LambdaAbstraction [projGetPath_0_var_1]\n" + "| | If []\n" + "| | | | Const [Nothing]\n" + "| | | Variable [projGetPath_0_var_1]\n" + "| | Let [projGetPath_0_coerceToBool_2]\n" + "| | | BinaryOp [And]\n" + "| | | | BinaryOp [And]\n" + "| | | | | BinaryOp [And]\n" + "| | | | | | FunctionCall [exists]\n" + "| | | | | | BinaryOp [And]\n" + "| | | | | | | Let [projGetPath_0_coerceToBool_1]\n" + "| | | | | | | | BinaryOp [And]\n" + "| | | | | | | | | BinaryOp [And]\n" + "| | | | | | | | | | BinaryOp [And]\n" + "| | | | | | | | | | | FunctionCall [exists]\n" + "| | | | | | | | | | | BinaryOp [Gte]\n" + "| | | | | | | | | | | | Const [2]\n" + "| | | | | | | | | | | Variable [projGetPath_0_var_1]\n" + "| | | | | | | | | | UnaryOp [Not]\n" + "| | | | | | | | | | FunctionCall [isNull]\n" + "| | | | | | | | | | BinaryOp [Gte]\n" + "| | | | | | | | | | | Const [2]\n" + "| | | | | | | | | | Variable [projGetPath_0_var_1]\n" + "| | | | | | | | | BinaryOp [Neq]\n" + "| | | | | | | | | | Const [0]\n" + "| | | | | | | | | BinaryOp [Cmp3w]\n" + "| | | | | | | | | | Const [false]\n" + "| | | | | | | | | Variable [projGetPath_0_coerceToBool_1]\n" + "| | | | | | | | BinaryOp [Neq]\n" + "| | | | | | | | | Const [0]\n" + "| | | | | | | | BinaryOp [Cmp3w]\n" + "| | | | | | | | | Const [0]\n" + "| | | | | | | | Variable [projGetPath_0_coerceToBool_1]\n" + "| | | | | | | BinaryOp [Gte]\n" + "| | | | | | | | Const [2]\n" + "| | | | | | | Variable [projGetPath_0_var_1]\n" + "| | | | | | Let [projGetPath_0_coerceToBool_0]\n" + "| | | | | | | BinaryOp [And]\n" + "| | | | | | | | BinaryOp [And]\n" + "| | | | | | | | | BinaryOp [And]\n" + "| | | | | | | | | | FunctionCall [exists]\n" + "| | | | | | | | | | BinaryOp [Lte]\n" + "| | | | | | | | | | | Const [3]\n" + "| | | | | | | | | | Variable [projGetPath_0_var_1]\n" + "| | | | | | | | | UnaryOp [Not]\n" + "| | | | | | | | | FunctionCall [isNull]\n" + "| | | | | | | | | BinaryOp [Lte]\n" + "| | | | | | | | | | Const [3]\n" + "| | | | | | | | | Variable [projGetPath_0_var_1]\n" + "| | | | | | | | BinaryOp [Neq]\n" + "| | | | | | | | | Const [0]\n" + "| | | | | | | | BinaryOp [Cmp3w]\n" + "| | | | | | | | | Const [false]\n" + "| | | | | | | | Variable [projGetPath_0_coerceToBool_0]\n" + "| | | | | | | BinaryOp [Neq]\n" + "| | | | | | | | Const [0]\n" + "| | | | | | | BinaryOp [Cmp3w]\n" + "| | | | | | | | Const [0]\n" + "| | | | | | | Variable [projGetPath_0_coerceToBool_0]\n" + "| | | | | | BinaryOp [Lte]\n" + "| | | | | | | Const [3]\n" + "| | | | | | Variable [projGetPath_0_var_1]\n" + "| | | | | UnaryOp [Not]\n" + "| | | | | FunctionCall [isNull]\n" + "| | | | | BinaryOp [And]\n" + "| | | | | | Let [projGetPath_0_coerceToBool_1]\n" + "| | | | | | | BinaryOp [And]\n" + "| | | | | | | | BinaryOp [And]\n" + "| | | | | | | | | BinaryOp [And]\n" + "| | | | | | | | | | FunctionCall [exists]\n" + "| | | | | | | | | | BinaryOp [Gte]\n" + "| | | | | | | | | | | Const [2]\n" + "| | | | | | | | | | Variable 
[projGetPath_0_var_1]\n" + "| | | | | | | | | UnaryOp [Not]\n" + "| | | | | | | | | FunctionCall [isNull]\n" + "| | | | | | | | | BinaryOp [Gte]\n" + "| | | | | | | | | | Const [2]\n" + "| | | | | | | | | Variable [projGetPath_0_var_1]\n" + "| | | | | | | | BinaryOp [Neq]\n" + "| | | | | | | | | Const [0]\n" + "| | | | | | | | BinaryOp [Cmp3w]\n" + "| | | | | | | | | Const [false]\n" + "| | | | | | | | Variable [projGetPath_0_coerceToBool_1]\n" + "| | | | | | | BinaryOp [Neq]\n" + "| | | | | | | | Const [0]\n" + "| | | | | | | BinaryOp [Cmp3w]\n" + "| | | | | | | | Const [0]\n" + "| | | | | | | Variable [projGetPath_0_coerceToBool_1]\n" + "| | | | | | BinaryOp [Gte]\n" + "| | | | | | | Const [2]\n" + "| | | | | | Variable [projGetPath_0_var_1]\n" + "| | | | | Let [projGetPath_0_coerceToBool_0]\n" + "| | | | | | BinaryOp [And]\n" + "| | | | | | | BinaryOp [And]\n" + "| | | | | | | | BinaryOp [And]\n" + "| | | | | | | | | FunctionCall [exists]\n" + "| | | | | | | | | BinaryOp [Lte]\n" + "| | | | | | | | | | Const [3]\n" + "| | | | | | | | | Variable [projGetPath_0_var_1]\n" + "| | | | | | | | UnaryOp [Not]\n" + "| | | | | | | | FunctionCall [isNull]\n" + "| | | | | | | | BinaryOp [Lte]\n" + "| | | | | | | | | Const [3]\n" + "| | | | | | | | Variable [projGetPath_0_var_1]\n" + "| | | | | | | BinaryOp [Neq]\n" + "| | | | | | | | Const [0]\n" + "| | | | | | | BinaryOp [Cmp3w]\n" + "| | | | | | | | Const [false]\n" + "| | | | | | | Variable [projGetPath_0_coerceToBool_0]\n" + "| | | | | | BinaryOp [Neq]\n" + "| | | | | | | Const [0]\n" + "| | | | | | BinaryOp [Cmp3w]\n" + "| | | | | | | Const [0]\n" + "| | | | | | Variable [projGetPath_0_coerceToBool_0]\n" + "| | | | | BinaryOp [Lte]\n" + "| | | | | | Const [3]\n" + "| | | | | Variable [projGetPath_0_var_1]\n" + "| | | | BinaryOp [Neq]\n" + "| | | | | Const [0]\n" + "| | | | BinaryOp [Cmp3w]\n" + "| | | | | Const [false]\n" + "| | | | Variable [projGetPath_0_coerceToBool_2]\n" + "| | | BinaryOp [Neq]\n" + "| | | | Const [0]\n" + "| | | BinaryOp [Cmp3w]\n" + "| | | | Const [0]\n" + "| | | Variable [projGetPath_0_coerceToBool_2]\n" + "| | BinaryOp [And]\n" + "| | | Let [projGetPath_0_coerceToBool_1]\n" + "| | | | BinaryOp [And]\n" + "| | | | | BinaryOp [And]\n" + "| | | | | | BinaryOp [And]\n" + "| | | | | | | FunctionCall [exists]\n" + "| | | | | | | BinaryOp [Gte]\n" + "| | | | | | | | Const [2]\n" + "| | | | | | | Variable [projGetPath_0_var_1]\n" + "| | | | | | UnaryOp [Not]\n" + "| | | | | | FunctionCall [isNull]\n" + "| | | | | | BinaryOp [Gte]\n" + "| | | | | | | Const [2]\n" + "| | | | | | Variable [projGetPath_0_var_1]\n" + "| | | | | BinaryOp [Neq]\n" + "| | | | | | Const [0]\n" + "| | | | | BinaryOp [Cmp3w]\n" + "| | | | | | Const [false]\n" + "| | | | | Variable [projGetPath_0_coerceToBool_1]\n" + "| | | | BinaryOp [Neq]\n" + "| | | | | Const [0]\n" + "| | | | BinaryOp [Cmp3w]\n" + "| | | | | Const [0]\n" + "| | | | Variable [projGetPath_0_coerceToBool_1]\n" + "| | | BinaryOp [Gte]\n" + "| | | | Const [2]\n" + "| | | Variable [projGetPath_0_var_1]\n" + "| | Let [projGetPath_0_coerceToBool_0]\n" + "| | | BinaryOp [And]\n" + "| | | | BinaryOp [And]\n" + "| | | | | BinaryOp [And]\n" + "| | | | | | FunctionCall [exists]\n" + "| | | | | | BinaryOp [Lte]\n" + "| | | | | | | Const [3]\n" + "| | | | | | Variable [projGetPath_0_var_1]\n" + "| | | | | UnaryOp [Not]\n" + "| | | | | FunctionCall [isNull]\n" + "| | | | | BinaryOp [Lte]\n" + "| | | | | | Const [3]\n" + "| | | | | Variable [projGetPath_0_var_1]\n" + "| | | | BinaryOp [Neq]\n" + "| | 
| | | Const [0]\n" + "| | | | BinaryOp [Cmp3w]\n" + "| | | | | Const [false]\n" + "| | | | Variable [projGetPath_0_coerceToBool_0]\n" + "| | | BinaryOp [Neq]\n" + "| | | | Const [0]\n" + "| | | BinaryOp [Cmp3w]\n" + "| | | | Const [0]\n" + "| | | Variable [projGetPath_0_coerceToBool_0]\n" + "| | BinaryOp [Lte]\n" + "| | | Const [3]\n" + "| | Variable [projGetPath_0_var_1]\n" + "| PathKeep [_id, a]\n" + "Scan [collection]\n" + " BindBlock:\n" + " [scan_0]\n" + " Source []\n", + optimized); +} + +TEST(ABTTranslate, GroupBasic) { + ABT translated = + translatePipeline("[{$group: {_id: '$a.b', s: {$sum: {$multiply: ['$b', '$c']}}}}]"); + + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | agg_project_0\n" + "| RefBlock: \n" + "| Variable [agg_project_0]\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [agg_project_0]\n" + "| EvalPath []\n" + "| | Const [{}]\n" + "| PathComposeM []\n" + "| | PathField [s]\n" + "| | PathConstant []\n" + "| | Variable [s_agg_0]\n" + "| PathField [_id]\n" + "| PathConstant []\n" + "| Variable [groupByProj_0]\n" + "GroupBy []\n" + "| | groupings: \n" + "| | RefBlock: \n" + "| | Variable [groupByProj_0]\n" + "| aggregations: \n" + "| [s_agg_0]\n" + "| FunctionCall [$sum]\n" + "| Variable [groupByInputProj_0]\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [groupByInputProj_0]\n" + "| BinaryOp [Mult]\n" + "| | EvalPath []\n" + "| | | Variable [scan_0]\n" + "| | PathGet [b]\n" + "| | PathIdentity []\n" + "| EvalPath []\n" + "| | Variable [scan_0]\n" + "| PathGet [c]\n" + "| PathIdentity []\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [groupByProj_0]\n" + "| EvalPath []\n" + "| | Variable [scan_0]\n" + "| PathGet [a]\n" + "| PathTraverse []\n" + "| PathGet [b]\n" + "| PathIdentity []\n" + "Scan [collection]\n" + " BindBlock:\n" + " [scan_0]\n" + " Source []\n", + translated); +} + +TEST(ABTTranslate, GroupLocalGlobal) { + ABT translated = translatePipeline("[{$group: {_id: '$a', c: {$sum: '$b'}}}]"); + + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | agg_project_0\n" + "| RefBlock: \n" + "| Variable [agg_project_0]\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [agg_project_0]\n" + "| EvalPath []\n" + "| | Const [{}]\n" + "| PathComposeM []\n" + "| | PathField [c]\n" + "| | PathConstant []\n" + "| | Variable [c_agg_0]\n" + "| PathField [_id]\n" + "| PathConstant []\n" + "| Variable [groupByProj_0]\n" + "GroupBy []\n" + "| | groupings: \n" + "| | RefBlock: \n" + "| | Variable [groupByProj_0]\n" + "| aggregations: \n" + "| [c_agg_0]\n" + "| FunctionCall [$sum]\n" + "| Variable [groupByInputProj_0]\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [groupByInputProj_0]\n" + "| EvalPath []\n" + "| | Variable [scan_0]\n" + "| PathGet [b]\n" + "| PathIdentity []\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [groupByProj_0]\n" + "| EvalPath []\n" + "| | Variable [scan_0]\n" + "| PathGet [a]\n" + "| PathIdentity []\n" + "Scan [collection]\n" + " BindBlock:\n" + " [scan_0]\n" + " Source []\n", + translated); + + PrefixId prefixId; + std::string scanDefName = "collection"; + OptPhaseManager phaseManager( + {OptPhaseManager::OptPhase::MemoSubstitutionPhase, + OptPhaseManager::OptPhase::MemoExplorationPhase, + OptPhaseManager::OptPhase::MemoImplementationPhase}, + prefixId, + {{{scanDefName, ScanDefinition{{}, {}, {DistributionType::UnknownPartitioning}}}}, + 5 /*numberOfPartitions*/}, + DebugInfo::kDefaultForTests); + + ABT optimized = std::move(translated); + ASSERT_TRUE(phaseManager.optimize(optimized)); + + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | 
projections: \n" + "| | agg_project_0\n" + "| RefBlock: \n" + "| Variable [agg_project_0]\n" + "Exchange []\n" + "| | distribution: \n" + "| | type: Centralized\n" + "| RefBlock: \n" + "Evaluation []\n" + "| BindBlock:\n" + "| [agg_project_0]\n" + "| EvalPath []\n" + "| | Const [{}]\n" + "| PathComposeM []\n" + "| | PathField [c]\n" + "| | PathConstant []\n" + "| | Variable [c_agg_0]\n" + "| PathField [_id]\n" + "| PathConstant []\n" + "| Variable [groupByProj_0]\n" + "GroupBy [Global]\n" + "| | groupings: \n" + "| | RefBlock: \n" + "| | Variable [groupByProj_0]\n" + "| aggregations: \n" + "| [c_agg_0]\n" + "| FunctionCall [$sum]\n" + "| Variable [preagg_0]\n" + "Exchange []\n" + "| | distribution: \n" + "| | type: HashPartitioning\n" + "| | projections: \n" + "| | groupByProj_0\n" + "| RefBlock: \n" + "| Variable [groupByProj_0]\n" + "GroupBy [Local]\n" + "| | groupings: \n" + "| | RefBlock: \n" + "| | Variable [groupByProj_0]\n" + "| aggregations: \n" + "| [preagg_0]\n" + "| FunctionCall [$sum]\n" + "| Variable [groupByInputProj_0]\n" + "PhysicalScan [{'a': groupByProj_0, 'b': groupByInputProj_0}, collection, parallel]\n" + " BindBlock:\n" + " [groupByInputProj_0]\n" + " Source []\n" + " [groupByProj_0]\n" + " Source []\n", + optimized); +} + +TEST(ABTTranslate, UnwindBasic) { + ABT translated = translatePipeline("[{$unwind: {path: '$a.b.c'}}]"); + + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | embedProj_0\n" + "| RefBlock: \n" + "| Variable [embedProj_0]\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [embedProj_0]\n" + "| EvalPath []\n" + "| | Variable [scan_0]\n" + "| PathField [a]\n" + "| PathTraverse []\n" + "| PathField [b]\n" + "| PathTraverse []\n" + "| PathField [c]\n" + "| PathConstant []\n" + "| Variable [unwoundProj_0]\n" + "Unwind []\n" + "| BindBlock:\n" + "| [unwoundPid_0]\n" + "| Source []\n" + "| [unwoundProj_0]\n" + "| Source []\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [unwoundProj_0]\n" + "| EvalPath []\n" + "| | Variable [scan_0]\n" + "| PathGet [a]\n" + "| PathGet [b]\n" + "| PathGet [c]\n" + "| PathIdentity []\n" + "Scan [collection]\n" + " BindBlock:\n" + " [scan_0]\n" + " Source []\n", + translated); +} + +TEST(ABTTranslate, UnwindComplex) { + ABT translated = translatePipeline( + "[{$unwind: {path: '$a.b.c', includeArrayIndex: 'p1.pid', preserveNullAndEmptyArrays: " + "true}}]"); + + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | embedPidProj_0\n" + "| RefBlock: \n" + "| Variable [embedPidProj_0]\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [embedPidProj_0]\n" + "| EvalPath []\n" + "| | Variable [embedProj_0]\n" + "| PathField [p1]\n" + "| PathField [pid]\n" + "| PathConstant []\n" + "| If []\n" + "| | | Const [null]\n" + "| | Variable [unwoundPid_0]\n" + "| BinaryOp [Gte]\n" + "| | Const [0]\n" + "| Variable [unwoundPid_0]\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [embedProj_0]\n" + "| EvalPath []\n" + "| | Variable [scan_0]\n" + "| PathField [a]\n" + "| PathTraverse []\n" + "| PathField [b]\n" + "| PathTraverse []\n" + "| PathField [c]\n" + "| PathLambda []\n" + "| LambdaAbstraction [unwoundLambdaVarName_0]\n" + "| If []\n" + "| | | Variable [unwoundLambdaVarName_0]\n" + "| | Variable [unwoundProj_0]\n" + "| BinaryOp [Gte]\n" + "| | Const [0]\n" + "| Variable [unwoundPid_0]\n" + "Unwind [retainNonArrays]\n" + "| BindBlock:\n" + "| [unwoundPid_0]\n" + "| Source []\n" + "| [unwoundProj_0]\n" + "| Source []\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [unwoundProj_0]\n" + "| EvalPath []\n" + "| | Variable 
[scan_0]\n" + "| PathGet [a]\n" + "| PathGet [b]\n" + "| PathGet [c]\n" + "| PathIdentity []\n" + "Scan [collection]\n" + " BindBlock:\n" + " [scan_0]\n" + " Source []\n", + translated); +} + +TEST(ABTTranslate, UnwindAndGroup) { + ABT translated = translatePipeline( + "[{$unwind:{path: '$a.b', preserveNullAndEmptyArrays: true}}, " + "{$group:{_id: '$a.b'}}]"); + + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | agg_project_0\n" + "| RefBlock: \n" + "| Variable [agg_project_0]\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [agg_project_0]\n" + "| EvalPath []\n" + "| | Const [{}]\n" + "| PathField [_id]\n" + "| PathConstant []\n" + "| Variable [groupByProj_0]\n" + "GroupBy []\n" + "| | groupings: \n" + "| | RefBlock: \n" + "| | Variable [groupByProj_0]\n" + "| aggregations: \n" + "Evaluation []\n" + "| BindBlock:\n" + "| [groupByProj_0]\n" + "| EvalPath []\n" + "| | Variable [embedProj_0]\n" + "| PathGet [a]\n" + "| PathTraverse []\n" + "| PathGet [b]\n" + "| PathIdentity []\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [embedProj_0]\n" + "| EvalPath []\n" + "| | Variable [scan_0]\n" + "| PathField [a]\n" + "| PathTraverse []\n" + "| PathField [b]\n" + "| PathLambda []\n" + "| LambdaAbstraction [unwoundLambdaVarName_0]\n" + "| If []\n" + "| | | Variable [unwoundLambdaVarName_0]\n" + "| | Variable [unwoundProj_0]\n" + "| BinaryOp [Gte]\n" + "| | Const [0]\n" + "| Variable [unwoundPid_0]\n" + "Unwind [retainNonArrays]\n" + "| BindBlock:\n" + "| [unwoundPid_0]\n" + "| Source []\n" + "| [unwoundProj_0]\n" + "| Source []\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [unwoundProj_0]\n" + "| EvalPath []\n" + "| | Variable [scan_0]\n" + "| PathGet [a]\n" + "| PathGet [b]\n" + "| PathIdentity []\n" + "Scan [collection]\n" + " BindBlock:\n" + " [scan_0]\n" + " Source []\n", + translated); +} + +TEST(ABTTranslate, UnwindSort) { + ABT translated = translatePipeline("[{$unwind: '$x'}, {$sort: {'x': 1}}]"); + + PrefixId prefixId; + std::string scanDefName = "collection"; + OptPhaseManager phaseManager(OptPhaseManager::getAllRewritesSet(), + prefixId, + {{{scanDefName, ScanDefinition{{}, {}}}}}, + DebugInfo::kDefaultForTests); + + ABT optimized = translated; + ASSERT_TRUE(phaseManager.optimize(optimized)); + + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | embedProj_0\n" + "| RefBlock: \n" + "| Variable [embedProj_0]\n" + "Collation []\n" + "| | collation: \n" + "| | sort_0: Ascending\n" + "| RefBlock: \n" + "| Variable [sort_0]\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [sort_0]\n" + "| Variable [unwoundProj_0]\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [embedProj_0]\n" + "| If []\n" + "| | | Variable [scan_0]\n" + "| | FunctionCall [setField]\n" + "| | | | Variable [unwoundProj_0]\n" + "| | | Const [\"x\"]\n" + "| | Variable [scan_0]\n" + "| BinaryOp [Or]\n" + "| | FunctionCall [isObject]\n" + "| | Variable [scan_0]\n" + "| FunctionCall [exists]\n" + "| Variable [unwoundProj_0]\n" + "Unwind []\n" + "| BindBlock:\n" + "| [unwoundPid_0]\n" + "| Source []\n" + "| [unwoundProj_0]\n" + "| Source []\n" + "PhysicalScan [{'<root>': scan_0, 'x': unwoundProj_0}, collection]\n" + " BindBlock:\n" + " [scan_0]\n" + " Source []\n" + " [unwoundProj_0]\n" + " Source []\n", + optimized); +} + +TEST(ABTTranslate, MatchIndex) { + PrefixId prefixId; + std::string scanDefName = "collection"; + + Metadata metadata = { + {{scanDefName, + ScanDefinition{{}, {{"index1", makeIndexDefinition("a", CollationOp::Ascending)}}}}}}; + ABT translated = translatePipeline(metadata, "[{$match: 
{'a': 10}}]", scanDefName, prefixId); + + OptPhaseManager phaseManager({OptPhaseManager::OptPhase::MemoSubstitutionPhase, + OptPhaseManager::OptPhase::MemoExplorationPhase, + OptPhaseManager::OptPhase::MemoImplementationPhase}, + prefixId, + metadata, + DebugInfo::kDefaultForTests); + + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | scan_0\n" + "| RefBlock: \n" + "| Variable [scan_0]\n" + "Filter []\n" + "| EvalFilter []\n" + "| | Variable [scan_0]\n" + "| PathGet [a]\n" + "| PathTraverse []\n" + "| PathCompare [Eq]\n" + "| Const [10]\n" + "Scan [collection]\n" + " BindBlock:\n" + " [scan_0]\n" + " Source []\n", + translated); + + ABT optimized = std::move(translated); + ASSERT_TRUE(phaseManager.optimize(optimized)); + + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | scan_0\n" + "| RefBlock: \n" + "| Variable [scan_0]\n" + "BinaryJoin [joinType: Inner, {rid_0}]\n" + "| | Const [true]\n" + "| LimitSkip []\n" + "| | limitSkip:\n" + "| | limit: 1\n" + "| | skip: 0\n" + "| Seek [ridProjection: rid_0, {'<root>': scan_0}, collection]\n" + "| | BindBlock:\n" + "| | [scan_0]\n" + "| | Source []\n" + "| RefBlock: \n" + "| Variable [rid_0]\n" + "IndexScan [{'<rid>': rid_0}, scanDefName: collection, indexDefName: index1, interval: " + "{[Const [10], Const [10]]}]\n" + " BindBlock:\n" + " [rid_0]\n" + " Source []\n", + optimized); +} + +TEST(ABTTranslate, MatchIndexCovered) { + PrefixId prefixId; + std::string scanDefName = "collection"; + + Metadata metadata = { + {{scanDefName, + ScanDefinition{ + {}, + {{"index1", + IndexDefinition{{{{makeNonMultikeyIndexPath("a"), CollationOp::Ascending}}}, + false /*multiKey*/}}}}}}}; + ABT translated = translatePipeline( + metadata, "[{$project: {_id: 0, a: 1}}, {$match: {'a': 10}}]", scanDefName, prefixId); + + OptPhaseManager phaseManager({OptPhaseManager::OptPhase::ConstEvalPre, + OptPhaseManager::OptPhase::PathFuse, + OptPhaseManager::OptPhase::MemoSubstitutionPhase, + OptPhaseManager::OptPhase::MemoExplorationPhase, + OptPhaseManager::OptPhase::MemoImplementationPhase}, + prefixId, + metadata, + DebugInfo::kDefaultForTests); + + ABT optimized = std::move(translated); + ASSERT_TRUE(phaseManager.optimize(optimized)); + + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | combinedProjection_0\n" + "| RefBlock: \n" + "| Variable [combinedProjection_0]\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [combinedProjection_0]\n" + "| EvalPath []\n" + "| | Const [{}]\n" + "| PathField [a]\n" + "| PathConstant []\n" + "| Variable [fieldProj_0]\n" + "IndexScan [{'<indexKey> 0': fieldProj_0}, scanDefName: collection, indexDefName: index1, " + "interval: {[Const [10], Const [10]]}]\n" + " BindBlock:\n" + " [fieldProj_0]\n" + " Source []\n", + optimized); +} + +TEST(ABTTranslate, MatchIndexCovered1) { + PrefixId prefixId; + std::string scanDefName = "collection"; + + Metadata metadata = { + {{scanDefName, + ScanDefinition{ + {}, + {{"index1", + IndexDefinition{{{{makeNonMultikeyIndexPath("a"), CollationOp::Ascending}}}, + false /*multiKey*/}}}}}}}; + ABT translated = translatePipeline( + metadata, "[{$match: {'a': 10}}, {$project: {_id: 0, a: 1}}]", scanDefName, prefixId); + + OptPhaseManager phaseManager({OptPhaseManager::OptPhase::ConstEvalPre, + OptPhaseManager::OptPhase::PathFuse, + OptPhaseManager::OptPhase::MemoSubstitutionPhase, + OptPhaseManager::OptPhase::MemoExplorationPhase, + OptPhaseManager::OptPhase::MemoImplementationPhase}, + prefixId, + metadata, + DebugInfo::kDefaultForTests); + + ABT optimized = 
std::move(translated); + ASSERT_TRUE(phaseManager.optimize(optimized)); + + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | combinedProjection_0\n" + "| RefBlock: \n" + "| Variable [combinedProjection_0]\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [combinedProjection_0]\n" + "| EvalPath []\n" + "| | Const [{}]\n" + "| PathField [a]\n" + "| PathConstant []\n" + "| Variable [fieldProj_0]\n" + "IndexScan [{'<indexKey> 0': fieldProj_0}, scanDefName: collection, indexDefName: index1, " + "interval: {[Const [10], Const [10]]}]\n" + " BindBlock:\n" + " [fieldProj_0]\n" + " Source []\n", + optimized); +} + +TEST(ABTTranslate, MatchIndexCovered2) { + PrefixId prefixId; + std::string scanDefName = "collection"; + + Metadata metadata = { + {{scanDefName, + ScanDefinition{ + {}, + {{"index1", + IndexDefinition{{{{makeNonMultikeyIndexPath("a"), CollationOp::Ascending}, + {makeNonMultikeyIndexPath("b"), CollationOp::Ascending}}}, + false /*multiKey*/}}}}}}}; + ABT translated = translatePipeline(metadata, + "[{$match: {'a': 10, 'b': 20}}, {$project: {_id: 0, a: 1}}]", + scanDefName, + prefixId); + + OptPhaseManager phaseManager({OptPhaseManager::OptPhase::ConstEvalPre, + OptPhaseManager::OptPhase::PathFuse, + OptPhaseManager::OptPhase::MemoSubstitutionPhase, + OptPhaseManager::OptPhase::MemoExplorationPhase, + OptPhaseManager::OptPhase::MemoImplementationPhase}, + prefixId, + metadata, + DebugInfo::kDefaultForTests); + + ABT optimized = std::move(translated); + ASSERT_TRUE(phaseManager.optimize(optimized)); + + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | combinedProjection_0\n" + "| RefBlock: \n" + "| Variable [combinedProjection_0]\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [combinedProjection_0]\n" + "| EvalPath []\n" + "| | Const [{}]\n" + "| PathField [a]\n" + "| PathConstant []\n" + "| Variable [fieldProj_0]\n" + "IndexScan [{'<indexKey> 0': fieldProj_0}, scanDefName: collection, indexDefName: index1, " + "interval: {[Const [10], Const [10]], [Const [20], Const [20]]}]\n" + " BindBlock:\n" + " [fieldProj_0]\n" + " Source []\n", + optimized); +} + +TEST(ABTTranslate, MatchIndexCovered3) { + PrefixId prefixId; + std::string scanDefName = "collection"; + + Metadata metadata = { + {{scanDefName, + ScanDefinition{ + {}, + {{"index1", + IndexDefinition{{{{makeNonMultikeyIndexPath("a"), CollationOp::Ascending}, + {makeNonMultikeyIndexPath("b"), CollationOp::Ascending}, + {makeNonMultikeyIndexPath("c"), CollationOp::Ascending}}}, + false /*multiKey*/}}}}}}}; + ABT translated = translatePipeline( + metadata, + "[{$match: {'a': 10, 'b': 20, 'c': 30}}, {$project: {_id: 0, a: 1, b: 1, c: 1}}]", + scanDefName, + prefixId); + + OptPhaseManager phaseManager({OptPhaseManager::OptPhase::ConstEvalPre, + OptPhaseManager::OptPhase::PathFuse, + OptPhaseManager::OptPhase::MemoSubstitutionPhase, + OptPhaseManager::OptPhase::MemoExplorationPhase, + OptPhaseManager::OptPhase::MemoImplementationPhase}, + prefixId, + metadata, + DebugInfo::kDefaultForTests); + + ABT optimized = std::move(translated); + ASSERT_TRUE(phaseManager.optimize(optimized)); + + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | combinedProjection_0\n" + "| RefBlock: \n" + "| Variable [combinedProjection_0]\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [combinedProjection_0]\n" + "| EvalPath []\n" + "| | Const [{}]\n" + "| PathComposeM []\n" + "| | PathField [c]\n" + "| | PathConstant []\n" + "| | Variable [fieldProj_2]\n" + "| PathComposeM []\n" + "| | PathField [b]\n" + "| | PathConstant []\n" 
+ "| | Variable [fieldProj_1]\n" + "| PathField [a]\n" + "| PathConstant []\n" + "| Variable [fieldProj_0]\n" + "IndexScan [{'<indexKey> 0': fieldProj_0, '<indexKey> 1': fieldProj_1, '<indexKey> 2': " + "fieldProj_2}, scanDefName: collection, indexDefName: index1, interval: {[Const [10], " + "Const [10]], [Const [20], Const [20]], [Const [30], Const [30]]}]\n" + " BindBlock:\n" + " [fieldProj_0]\n" + " Source []\n" + " [fieldProj_1]\n" + " Source []\n" + " [fieldProj_2]\n" + " Source []\n", + optimized); +} + +TEST(ABTTranslate, MatchIndexCovered4) { + PrefixId prefixId; + std::string scanDefName = "collection"; + + Metadata metadata = { + {{scanDefName, + ScanDefinition{ + {}, + {{"index1", + IndexDefinition{{{{makeNonMultikeyIndexPath("a"), CollationOp::Ascending}, + {makeNonMultikeyIndexPath("b"), CollationOp::Ascending}, + {makeNonMultikeyIndexPath("c"), CollationOp::Ascending}}}, + false /*multiKey*/}}}}}}}; + ABT translated = translatePipeline( + metadata, + "[{$project: {_id: 0, a: 1, b: 1, c: 1}}, {$match: {'a': 10, 'b': 20, 'c': 30}}]", + scanDefName, + prefixId); + + OptPhaseManager phaseManager({OptPhaseManager::OptPhase::ConstEvalPre, + OptPhaseManager::OptPhase::PathFuse, + OptPhaseManager::OptPhase::MemoSubstitutionPhase, + OptPhaseManager::OptPhase::MemoExplorationPhase, + OptPhaseManager::OptPhase::MemoImplementationPhase}, + prefixId, + metadata, + DebugInfo::kDefaultForTests); + + ABT optimized = std::move(translated); + ASSERT_TRUE(phaseManager.optimize(optimized)); + + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | combinedProjection_0\n" + "| RefBlock: \n" + "| Variable [combinedProjection_0]\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [combinedProjection_0]\n" + "| EvalPath []\n" + "| | Const [{}]\n" + "| PathComposeM []\n" + "| | PathField [c]\n" + "| | PathConstant []\n" + "| | Variable [fieldProj_2]\n" + "| PathComposeM []\n" + "| | PathField [b]\n" + "| | PathConstant []\n" + "| | Variable [fieldProj_1]\n" + "| PathField [a]\n" + "| PathConstant []\n" + "| Variable [fieldProj_0]\n" + "IndexScan [{'<indexKey> 0': fieldProj_0, '<indexKey> 1': fieldProj_1, '<indexKey> 2': " + "fieldProj_2}, scanDefName: collection, indexDefName: index1, interval: {[Const [10], " + "Const [10]], [Const [20], Const [20]], [Const [30], Const [30]]}]\n" + " BindBlock:\n" + " [fieldProj_0]\n" + " Source []\n" + " [fieldProj_1]\n" + " Source []\n" + " [fieldProj_2]\n" + " Source []\n", + optimized); +} + +TEST(ABTTranslate, MatchSortIndex) { + PrefixId prefixId; + std::string scanDefName = "collection"; + + Metadata metadata = { + {{scanDefName, + ScanDefinition{{}, {{"index1", makeIndexDefinition("a", CollationOp::Ascending)}}}}}}; + ABT translated = translatePipeline( + metadata, "[{$match: {'a': 10}}, {$sort: {'a': 1}}]", scanDefName, prefixId); + + OptPhaseManager phaseManager({OptPhaseManager::OptPhase::MemoSubstitutionPhase, + OptPhaseManager::OptPhase::MemoExplorationPhase, + OptPhaseManager::OptPhase::MemoImplementationPhase}, + prefixId, + metadata, + DebugInfo::kDefaultForTests); + + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | scan_0\n" + "| RefBlock: \n" + "| Variable [scan_0]\n" + "Collation []\n" + "| | collation: \n" + "| | sort_0: Ascending\n" + "| RefBlock: \n" + "| Variable [sort_0]\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [sort_0]\n" + "| EvalPath []\n" + "| | Variable [scan_0]\n" + "| PathGet [a]\n" + "| PathIdentity []\n" + "Filter []\n" + "| EvalFilter []\n" + "| | Variable [scan_0]\n" + "| PathGet [a]\n" + "| PathTraverse 
[]\n" + "| PathCompare [Eq]\n" + "| Const [10]\n" + "Scan [collection]\n" + " BindBlock:\n" + " [scan_0]\n" + " Source []\n", + translated); + + ABT optimized = std::move(translated); + ASSERT_TRUE(phaseManager.optimize(optimized)); + + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | scan_0\n" + "| RefBlock: \n" + "| Variable [scan_0]\n" + "Collation []\n" + "| | collation: \n" + "| | sort_0: Ascending\n" + "| RefBlock: \n" + "| Variable [sort_0]\n" + "BinaryJoin [joinType: Inner, {rid_0}]\n" + "| | Const [true]\n" + "| LimitSkip []\n" + "| | limitSkip:\n" + "| | limit: 1\n" + "| | skip: 0\n" + "| Seek [ridProjection: rid_0, {'<root>': scan_0, 'a': sort_0}, collection]\n" + "| | BindBlock:\n" + "| | [scan_0]\n" + "| | Source []\n" + "| | [sort_0]\n" + "| | Source []\n" + "| RefBlock: \n" + "| Variable [rid_0]\n" + "IndexScan [{'<rid>': rid_0}, scanDefName: collection, indexDefName: index1, interval: " + "{[Const [10], Const [10]]}]\n" + " BindBlock:\n" + " [rid_0]\n" + " Source []\n", + optimized); +} + +TEST(ABTTranslate, RangeIndex) { + PrefixId prefixId; + std::string scanDefName = "collection"; + Metadata metadata = { + {{scanDefName, + ScanDefinition{{}, {{"index1", makeIndexDefinition("a", CollationOp::Ascending)}}}}}}; + ABT translated = + translatePipeline(metadata, "[{$match: {'a': {$gt: 70, $lt: 90}}}]", scanDefName, prefixId); + + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | scan_0\n" + "| RefBlock: \n" + "| Variable [scan_0]\n" + "Filter []\n" + "| EvalFilter []\n" + "| | Variable [scan_0]\n" + "| PathGet [a]\n" + "| PathTraverse []\n" + "| PathComposeM []\n" + "| | PathCompare [Gte]\n" + "| | Const [nan]\n" + "| PathCompare [Lt]\n" + "| Const [90]\n" + "Filter []\n" + "| EvalFilter []\n" + "| | Variable [scan_0]\n" + "| PathGet [a]\n" + "| PathTraverse []\n" + "| PathComposeM []\n" + "| | PathCompare [Lt]\n" + "| | Const [\"\"]\n" + "| PathCompare [Gt]\n" + "| Const [70]\n" + "Scan [collection]\n" + " BindBlock:\n" + " [scan_0]\n" + " Source []\n", + translated); + + OptPhaseManager phaseManager({OptPhaseManager::OptPhase::MemoSubstitutionPhase, + OptPhaseManager::OptPhase::MemoExplorationPhase, + OptPhaseManager::OptPhase::MemoImplementationPhase}, + prefixId, + metadata, + DebugInfo::kDefaultForTests); + + // Demonstrate we can get an intersection plan, even though it might not be the best one under + // the heuristic CE. 
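+    // Disabling the scan alternative forces the optimizer to answer the query from 'index1'. The
+    // expected plan below unions two index scans, one for each half of the range predicate,
+    // groups by RID, keeps only the RIDs produced by both sides (getArraySize == 2), and then
+    // seeks into the collection for the surviving documents.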
+ phaseManager.getHints()._disableScan = true; + + ABT optimized = std::move(translated); + ASSERT_TRUE(phaseManager.optimize(optimized)); + + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | scan_0\n" + "| RefBlock: \n" + "| Variable [scan_0]\n" + "BinaryJoin [joinType: Inner, {rid_0}]\n" + "| | Const [true]\n" + "| LimitSkip []\n" + "| | limitSkip:\n" + "| | limit: 1\n" + "| | skip: 0\n" + "| Seek [ridProjection: rid_0, {'<root>': scan_0}, collection]\n" + "| | BindBlock:\n" + "| | [scan_0]\n" + "| | Source []\n" + "| RefBlock: \n" + "| Variable [rid_0]\n" + "Filter []\n" + "| EvalFilter []\n" + "| | FunctionCall [getArraySize]\n" + "| | Variable [sides_0]\n" + "| PathCompare [Eq]\n" + "| Const [2]\n" + "GroupBy []\n" + "| | groupings: \n" + "| | RefBlock: \n" + "| | Variable [rid_0]\n" + "| aggregations: \n" + "| [sides_0]\n" + "| FunctionCall [$addToSet]\n" + "| Variable [sideId_0]\n" + "Union []\n" + "| | BindBlock:\n" + "| | [rid_0]\n" + "| | Source []\n" + "| | [sideId_0]\n" + "| | Source []\n" + "| Evaluation []\n" + "| | BindBlock:\n" + "| | [sideId_0]\n" + "| | Const [1]\n" + "| IndexScan [{'<rid>': rid_0}, scanDefName: collection, indexDefName: index1, interval: " + "{[Const [nan], Const [90])}]\n" + "| BindBlock:\n" + "| [rid_0]\n" + "| Source []\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [sideId_0]\n" + "| Const [0]\n" + "IndexScan [{'<rid>': rid_0}, scanDefName: collection, indexDefName: index1, interval: " + "{(Const [70], Const [\"\"])}]\n" + " BindBlock:\n" + " [rid_0]\n" + " Source []\n", + optimized); +} + +TEST(ABTTranslate, Index1) { + { + PrefixId prefixId; + std::string scanDefName = "collection"; + Metadata metadata = { + {{scanDefName, + ScanDefinition{{}, + {{"index1", + IndexDefinition{{{makeIndexPath("a"), CollationOp::Ascending}, + {makeIndexPath("b"), CollationOp::Ascending}}, + true /*multiKey*/}}}}}}}; + + ABT translated = + translatePipeline(metadata, "[{$match: {'a': 2, 'b': 2}}]", scanDefName, prefixId); + + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | scan_0\n" + "| RefBlock: \n" + "| Variable [scan_0]\n" + "Filter []\n" + "| EvalFilter []\n" + "| | Variable [scan_0]\n" + "| PathGet [b]\n" + "| PathTraverse []\n" + "| PathCompare [Eq]\n" + "| Const [2]\n" + "Filter []\n" + "| EvalFilter []\n" + "| | Variable [scan_0]\n" + "| PathGet [a]\n" + "| PathTraverse []\n" + "| PathCompare [Eq]\n" + "| Const [2]\n" + "Scan [collection]\n" + " BindBlock:\n" + " [scan_0]\n" + " Source []\n", + translated); + + OptPhaseManager phaseManager({OptPhaseManager::OptPhase::MemoSubstitutionPhase, + OptPhaseManager::OptPhase::MemoExplorationPhase, + OptPhaseManager::OptPhase::MemoImplementationPhase}, + prefixId, + metadata, + DebugInfo::kDefaultForTests); + + ABT optimized = translated; + ASSERT_TRUE(phaseManager.optimize(optimized)); + + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | scan_0\n" + "| RefBlock: \n" + "| Variable [scan_0]\n" + "BinaryJoin [joinType: Inner, {rid_0}]\n" + "| | Const [true]\n" + "| LimitSkip []\n" + "| | limitSkip:\n" + "| | limit: 1\n" + "| | skip: 0\n" + "| Seek [ridProjection: rid_0, {'<root>': scan_0}, collection]\n" + "| | BindBlock:\n" + "| | [scan_0]\n" + "| | Source []\n" + "| RefBlock: \n" + "| Variable [rid_0]\n" + "IndexScan [{'<rid>': rid_0}, scanDefName: collection, indexDefName: index1, interval: " + "{[Const [2], Const [2]], [Const [2], Const [2]]}]\n" + " BindBlock:\n" + " [rid_0]\n" + " Source []\n", + optimized); + } + + { + PrefixId prefixId; + std::string scanDefName = 
"collection"; + Metadata metadata = { + {{scanDefName, + ScanDefinition{{}, {{"index1", makeIndexDefinition("a", CollationOp::Ascending)}}}}}}; + + ABT translated = + translatePipeline(metadata, "[{$match: {'a': 2, 'b': 2}}]", scanDefName, prefixId); + + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | scan_0\n" + "| RefBlock: \n" + "| Variable [scan_0]\n" + "Filter []\n" + "| EvalFilter []\n" + "| | Variable [scan_0]\n" + "| PathGet [b]\n" + "| PathTraverse []\n" + "| PathCompare [Eq]\n" + "| Const [2]\n" + "Filter []\n" + "| EvalFilter []\n" + "| | Variable [scan_0]\n" + "| PathGet [a]\n" + "| PathTraverse []\n" + "| PathCompare [Eq]\n" + "| Const [2]\n" + "Scan [collection]\n" + " BindBlock:\n" + " [scan_0]\n" + " Source []\n", + translated); + + // Demonstrate we can use an index over only one field. + OptPhaseManager phaseManager({OptPhaseManager::OptPhase::MemoSubstitutionPhase, + OptPhaseManager::OptPhase::MemoExplorationPhase, + OptPhaseManager::OptPhase::MemoImplementationPhase, + OptPhaseManager::OptPhase::ConstEvalPost}, + prefixId, + metadata, + DebugInfo::kDefaultForTests); + + ABT optimized = translated; + ASSERT_TRUE(phaseManager.optimize(optimized)); + + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | scan_0\n" + "| RefBlock: \n" + "| Variable [scan_0]\n" + "BinaryJoin [joinType: Inner, {rid_0}]\n" + "| | Const [true]\n" + "| Filter []\n" + "| | EvalFilter []\n" + "| | | Variable [evalTemp_2]\n" + "| | PathTraverse []\n" + "| | PathCompare [Eq]\n" + "| | Const [2]\n" + "| LimitSkip []\n" + "| | limitSkip:\n" + "| | limit: 1\n" + "| | skip: 0\n" + "| Seek [ridProjection: rid_0, {'<root>': scan_0, 'b': evalTemp_2}, collection]\n" + "| | BindBlock:\n" + "| | [evalTemp_2]\n" + "| | Source []\n" + "| | [scan_0]\n" + "| | Source []\n" + "| RefBlock: \n" + "| Variable [rid_0]\n" + "IndexScan [{'<rid>': rid_0}, scanDefName: collection, indexDefName: index1, interval: " + "{[Const [2], Const [2]]}]\n" + " BindBlock:\n" + " [rid_0]\n" + " Source []\n", + optimized); + } +} + +TEST(ABTTranslate, GroupMultiKey) { + ABT translated = translatePipeline( + "[{$group: {_id: {'isin': '$isin', 'year': '$year'}, 'count': {$sum: 1}, 'open': {$first: " + "'$$ROOT'}}}]"); + + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | agg_project_0\n" + "| RefBlock: \n" + "| Variable [agg_project_0]\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [agg_project_0]\n" + "| EvalPath []\n" + "| | Const [{}]\n" + "| PathComposeM []\n" + "| | PathField [open]\n" + "| | PathConstant []\n" + "| | Variable [open_agg_0]\n" + "| PathComposeM []\n" + "| | PathField [count]\n" + "| | PathConstant []\n" + "| | Variable [count_agg_0]\n" + "| PathComposeM []\n" + "| | PathField [_id.year]\n" + "| | PathConstant []\n" + "| | Variable [groupByProj_1]\n" + "| PathField [_id.isin]\n" + "| PathConstant []\n" + "| Variable [groupByProj_0]\n" + "GroupBy []\n" + "| | groupings: \n" + "| | RefBlock: \n" + "| | Variable [groupByProj_0]\n" + "| | Variable [groupByProj_1]\n" + "| aggregations: \n" + "| [count_agg_0]\n" + "| FunctionCall [$sum]\n" + "| Const [1]\n" + "| [open_agg_0]\n" + "| FunctionCall [$first]\n" + "| Variable [scan_0]\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [groupByProj_1]\n" + "| EvalPath []\n" + "| | Variable [scan_0]\n" + "| PathGet [year]\n" + "| PathIdentity []\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [groupByProj_0]\n" + "| EvalPath []\n" + "| | Variable [scan_0]\n" + "| PathGet [isin]\n" + "| PathIdentity []\n" + "Scan [collection]\n" + " BindBlock:\n" 
+ " [scan_0]\n" + " Source []\n", + translated); +} + +TEST(ABTTranslate, GroupEvalNoInline) { + ABT translated = translatePipeline("[{$group: {_id: null, a: {$first: '$b'}}}]"); + + PrefixId prefixId; + std::string scanDefName = "collection"; + OptPhaseManager phaseManager(OptPhaseManager::getAllRewritesSet(), + prefixId, + {{{scanDefName, ScanDefinition{{}, {}}}}}, + DebugInfo::kDefaultForTests); + + ABT optimized = translated; + ASSERT_TRUE(phaseManager.optimize(optimized)); + + // Verify that "b" is not inlined in the group expression, but is coming from the physical scan. + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | agg_project_0\n" + "| RefBlock: \n" + "| Variable [agg_project_0]\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [agg_project_0]\n" + "| Let [inputField_1]\n" + "| | If []\n" + "| | | | Variable [inputField_1]\n" + "| | | FunctionCall [setField]\n" + "| | | | | Variable [a_agg_0]\n" + "| | | | Const [\"a\"]\n" + "| | | Variable [inputField_1]\n" + "| | BinaryOp [Or]\n" + "| | | FunctionCall [isObject]\n" + "| | | Variable [inputField_1]\n" + "| | FunctionCall [exists]\n" + "| | Variable [a_agg_0]\n" + "| If []\n" + "| | | Const [{}]\n" + "| | FunctionCall [setField]\n" + "| | | | Variable [groupByProj_0]\n" + "| | | Const [\"_id\"]\n" + "| | Const [{}]\n" + "| BinaryOp [Or]\n" + "| | FunctionCall [isObject]\n" + "| | Const [{}]\n" + "| FunctionCall [exists]\n" + "| Variable [groupByProj_0]\n" + "GroupBy []\n" + "| | groupings: \n" + "| | RefBlock: \n" + "| | Variable [groupByProj_0]\n" + "| aggregations: \n" + "| [a_agg_0]\n" + "| FunctionCall [$first]\n" + "| Variable [groupByInputProj_0]\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [groupByProj_0]\n" + "| Const [null]\n" + "PhysicalScan [{'b': groupByInputProj_0}, collection]\n" + " BindBlock:\n" + " [groupByInputProj_0]\n" + " Source []\n", + optimized); +} + +TEST(ABTTranslate, ArrayExpr) { + ABT translated = translatePipeline("[{$project: {a: ['$b', '$c']}}]"); + + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | combinedProjection_0\n" + "| RefBlock: \n" + "| Variable [combinedProjection_0]\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [combinedProjection_0]\n" + "| EvalPath []\n" + "| | Variable [scan_0]\n" + "| PathComposeM []\n" + "| | PathDefault []\n" + "| | Const [{}]\n" + "| PathComposeM []\n" + "| | PathField [a]\n" + "| | PathConstant []\n" + "| | Variable [projGetPath_0]\n" + "| PathKeep [_id, a]\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [projGetPath_0]\n" + "| FunctionCall [newArray]\n" + "| | EvalPath []\n" + "| | | Variable [scan_0]\n" + "| | PathGet [c]\n" + "| | PathIdentity []\n" + "| EvalPath []\n" + "| | Variable [scan_0]\n" + "| PathGet [b]\n" + "| PathIdentity []\n" + "Scan [collection]\n" + " BindBlock:\n" + " [scan_0]\n" + " Source []\n", + translated); +} + +TEST(ABTTranslate, Union) { + PrefixId prefixId; + std::string scanDefA = "collA"; + std::string scanDefB = "collB"; + + Metadata metadata{{{scanDefA, {}}, {scanDefB, {}}}}; + ABT translated = translatePipeline(metadata, + "[{$unionWith: 'collB'}, {$match: {_id: 1}}]", + prefixId.getNextId("scan"), + scanDefA, + prefixId, + {{NamespaceString("a." 
+ scanDefB), {}}}); + + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | scan_0\n" + "| RefBlock: \n" + "| Variable [scan_0]\n" + "Filter []\n" + "| EvalFilter []\n" + "| | Variable [scan_0]\n" + "| PathGet [_id]\n" + "| PathTraverse []\n" + "| PathCompare [Eq]\n" + "| Const [1]\n" + "Union []\n" + "| | BindBlock:\n" + "| | [scan_0]\n" + "| | Source []\n" + "| Evaluation []\n" + "| | BindBlock:\n" + "| | [scan_0]\n" + "| | EvalPath []\n" + "| | | Variable [scan_1]\n" + "| | PathIdentity []\n" + "| Scan [collB]\n" + "| BindBlock:\n" + "| [scan_1]\n" + "| Source []\n" + "Scan [collA]\n" + " BindBlock:\n" + " [scan_0]\n" + " Source []\n", + translated); + + OptPhaseManager phaseManager( + {OptPhaseManager::OptPhase::MemoSubstitutionPhase, + OptPhaseManager::OptPhase::MemoExplorationPhase, + OptPhaseManager::OptPhase::MemoImplementationPhase}, + prefixId, + {{{scanDefA, ScanDefinition{{}, {}}}, {scanDefB, ScanDefinition{{}, {}}}}}, + DebugInfo::kDefaultForTests); + + ABT optimized = translated; + ASSERT_TRUE(phaseManager.optimize(optimized)); + + // Note that the optimized ABT will show the filter push-down. + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | scan_0\n" + "| RefBlock: \n" + "| Variable [scan_0]\n" + "Union []\n" + "| | BindBlock:\n" + "| | [scan_0]\n" + "| | Source []\n" + "| Filter []\n" + "| | EvalFilter []\n" + "| | | Variable [scan_0]\n" + "| | PathGet [_id]\n" + "| | PathTraverse []\n" + "| | PathCompare [Eq]\n" + "| | Const [1]\n" + "| Evaluation []\n" + "| | BindBlock:\n" + "| | [scan_0]\n" + "| | EvalPath []\n" + "| | | Variable [scan_1]\n" + "| | PathIdentity []\n" + "| PhysicalScan [{'<root>': scan_1}, collB]\n" + "| BindBlock:\n" + "| [scan_1]\n" + "| Source []\n" + "Filter []\n" + "| EvalFilter []\n" + "| | Variable [evalTemp_0]\n" + "| PathTraverse []\n" + "| PathCompare [Eq]\n" + "| Const [1]\n" + "PhysicalScan [{'<root>': scan_0, '_id': evalTemp_0}, collA]\n" + " BindBlock:\n" + " [evalTemp_0]\n" + " Source []\n" + " [scan_0]\n" + " Source []\n", + optimized); +} + +TEST(ABTTranslate, PartialIndex) { + PrefixId prefixId; + std::string scanDefName = "collection"; + ProjectionName scanProjName = prefixId.getNextId("scan"); + + // The expression matches the pipeline. + // By default the constant is translated as "int32". 
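+    // Express the filter b == 2 as a partial schema requirement. Below it is attached to
+    // 'index1' as the partial index filter expression, so the index is only applicable when the
+    // query predicate implies b == 2 (which the $match in this test does).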
+ auto conversionResult = convertExprToPartialSchemaReq(make<EvalFilter>( + make<PathGet>("b", + make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int32(2)))), + make<Variable>(scanProjName))); + ASSERT_TRUE(conversionResult._success); + ASSERT_FALSE(conversionResult._hasEmptyInterval); + + Metadata metadata = { + {{scanDefName, + ScanDefinition{{}, + {{"index1", + IndexDefinition{{{makeIndexPath("a"), CollationOp::Ascending}}, + true /*multiKey*/, + {DistributionType::Centralized}, + std::move(conversionResult._reqMap)}}}}}}}; + + ABT translated = translatePipeline( + metadata, "[{$match: {'a': 3, 'b': 2}}]", scanProjName, scanDefName, prefixId); + + OptPhaseManager phaseManager( + OptPhaseManager::getAllRewritesSet(), prefixId, metadata, DebugInfo::kDefaultForTests); + + ABT optimized = translated; + ASSERT_TRUE(phaseManager.optimize(optimized)); + + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | scan_0\n" + "| RefBlock: \n" + "| Variable [scan_0]\n" + "BinaryJoin [joinType: Inner, {rid_0}]\n" + "| | Const [true]\n" + "| Filter []\n" + "| | FunctionCall [traverseF]\n" + "| | | | Const [false]\n" + "| | | LambdaAbstraction [valCmp_0]\n" + "| | | BinaryOp [Eq]\n" + "| | | | Const [2]\n" + "| | | Variable [valCmp_0]\n" + "| | Variable [evalTemp_2]\n" + "| LimitSkip []\n" + "| | limitSkip:\n" + "| | limit: 1\n" + "| | skip: 0\n" + "| Seek [ridProjection: rid_0, {'<root>': scan_0, 'b': evalTemp_2}, collection]\n" + "| | BindBlock:\n" + "| | [evalTemp_2]\n" + "| | Source []\n" + "| | [scan_0]\n" + "| | Source []\n" + "| RefBlock: \n" + "| Variable [rid_0]\n" + "IndexScan [{'<rid>': rid_0}, scanDefName: collection, indexDefName: index1, interval: " + "{[Const [3], Const [3]]}]\n" + " BindBlock:\n" + " [rid_0]\n" + " Source []\n", + optimized); +} + +TEST(ABTTranslate, PartialIndexNegative) { + PrefixId prefixId; + std::string scanDefName = "collection"; + ProjectionName scanProjName = prefixId.getNextId("scan"); + + // The expression does not match the pipeline. 
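+    // As above, 'index1' carries the partial filter requirement b == 2. The pipeline below asks
+    // for b == 3 instead, so the partial index cannot be used and the optimized plan falls back
+    // to a collection scan with residual filters on 'a' and 'b'.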
+ auto conversionResult = convertExprToPartialSchemaReq(make<EvalFilter>( + make<PathGet>("b", + make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int32(2)))), + make<Variable>(scanProjName))); + ASSERT_TRUE(conversionResult._success); + ASSERT_FALSE(conversionResult._hasEmptyInterval); + + Metadata metadata = { + {{scanDefName, + ScanDefinition{{}, + {{"index1", + IndexDefinition{{{makeIndexPath("a"), CollationOp::Ascending}}, + true /*multiKey*/, + {DistributionType::Centralized}, + std::move(conversionResult._reqMap)}}}}}}}; + + ABT translated = translatePipeline( + metadata, "[{$match: {'a': 3, 'b': 3}}]", scanProjName, scanDefName, prefixId); + + OptPhaseManager phaseManager( + OptPhaseManager::getAllRewritesSet(), prefixId, metadata, DebugInfo::kDefaultForTests); + + ABT optimized = translated; + ASSERT_TRUE(phaseManager.optimize(optimized)); + + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | scan_0\n" + "| RefBlock: \n" + "| Variable [scan_0]\n" + "Filter []\n" + "| FunctionCall [traverseF]\n" + "| | | Const [false]\n" + "| | LambdaAbstraction [valCmp_1]\n" + "| | BinaryOp [Eq]\n" + "| | | Const [3]\n" + "| | Variable [valCmp_1]\n" + "| Variable [evalTemp_1]\n" + "Filter []\n" + "| FunctionCall [traverseF]\n" + "| | | Const [false]\n" + "| | LambdaAbstraction [valCmp_0]\n" + "| | BinaryOp [Eq]\n" + "| | | Const [3]\n" + "| | Variable [valCmp_0]\n" + "| Variable [evalTemp_0]\n" + "PhysicalScan [{'<root>': scan_0, 'a': evalTemp_0, 'b': evalTemp_1}, collection]\n" + " BindBlock:\n" + " [evalTemp_0]\n" + " Source []\n" + " [evalTemp_1]\n" + " Source []\n" + " [scan_0]\n" + " Source []\n", + optimized); +} + +TEST(ABTTranslate, CommonExpressionElimination) { + PrefixId prefixId; + Metadata metadata = {{{"test", {{}, {}}}}}; + + auto rootNode = + translatePipeline(metadata, + "[{$project: {foo: {$add: ['$b', 1]}, bar: {$add: ['$b', 1]}}}]", + "test", + prefixId); + + OptPhaseManager phaseManager( + {OptPhaseManager::OptPhase::ConstEvalPre}, prefixId, metadata, DebugInfo::kDefaultForTests); + + ASSERT_TRUE(phaseManager.optimize(rootNode)); + + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | combinedProjection_0\n" + "| RefBlock: \n" + "| Variable [combinedProjection_0]\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [combinedProjection_0]\n" + "| EvalPath []\n" + "| | Variable [scan_0]\n" + "| PathComposeM []\n" + "| | PathDefault []\n" + "| | Const [{}]\n" + "| PathComposeM []\n" + "| | PathField [foo]\n" + "| | PathConstant []\n" + "| | Variable [projGetPath_0]\n" + "| PathComposeM []\n" + "| | PathField [bar]\n" + "| | PathConstant []\n" + "| | Variable [projGetPath_0]\n" + "| PathKeep [_id, bar, foo]\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [projGetPath_0]\n" + "| BinaryOp [Add]\n" + "| | EvalPath []\n" + "| | | Variable [scan_0]\n" + "| | PathGet [b]\n" + "| | PathIdentity []\n" + "| Const [1]\n" + "Scan [test]\n" + " BindBlock:\n" + " [scan_0]\n" + " Source []\n", + rootNode); +} + +} // namespace +} // namespace mongo diff --git a/src/mongo/db/pipeline/abt/utils.cpp b/src/mongo/db/pipeline/abt/utils.cpp new file mode 100644 index 00000000000..18e22eb265d --- /dev/null +++ b/src/mongo/db/pipeline/abt/utils.cpp @@ -0,0 +1,104 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. 
+ * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include "mongo/db/pipeline/abt/utils.h" + +namespace mongo::optimizer { + +template <bool isConjunction, typename... Args> +ABT generateConjunctionOrDisjunction(Args&... args) { + ABTVector elements; + (elements.emplace_back(args), ...); + + if (elements.size() == 0) { + return Constant::boolean(isConjunction); + } + + ABT result = std::move(elements.at(0)); + for (size_t i = 1; i < elements.size(); i++) { + result = make<BinaryOp>(isConjunction ? Operations::And : Operations::Or, + std::move(elements.at(i)), + std::move(result)); + } + return result; +} + +ABT generateCoerceToBool(ABT input, const std::string& varName) { + // Adapted from sbe_stage_builder_expression.cpp::generateCoerceToBoolExpression. + + const auto makeNeqCheckFn = [&varName](ABT valExpr) { + return make<BinaryOp>( + Operations::Neq, + make<BinaryOp>(Operations::Cmp3w, make<Variable>(varName), std::move(valExpr)), + Constant::int64(0)); + }; + + // If any of these are false, the branch is considered false for the purposes of the + // any logical expression. + ABT checkExists = make<FunctionCall>("exists", makeSeq(input)); + ABT checkNotNull = make<UnaryOp>(Operations::Not, make<FunctionCall>("isNull", makeSeq(input))); + ABT checkNotFalse = makeNeqCheckFn(Constant::boolean(false)); + ABT checkNotZero = makeNeqCheckFn(Constant::int64(0)); + + return make<Let>(varName, + std::move(input), + generateConjunctionOrDisjunction<true /*isConjunction*/>( + checkExists, checkNotNull, checkNotFalse, checkNotZero)); +}; + +std::pair<sbe::value::TypeTags, sbe::value::Value> convertFrom(const Value val) { + // TODO: Either make this conversion unnecessary by changing the value representation in + // ExpressionConstant, or provide a nicer way to convert directly from Document/Value to + // sbe::Value. 
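+    // Wrap the Value in a single-element BSON object under an empty field name, then parse that
+    // element back with the SBE BSON reader, skipping the 4-byte object length prefix.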
+ BSONObjBuilder bob; + val.addToBsonObj(&bob, ""_sd); + auto obj = bob.done(); + auto be = obj.objdata(); + auto end = be + ConstDataView(be).read<LittleEndian<uint32_t>>(); + return sbe::bson::convertFrom<false>(be + 4, end, 0); +} + +ABT translateFieldPath(const FieldPath& fieldPath, + ABT initial, + const ABTFieldNameFn& fieldNameFn, + const size_t skipFromStart) { + ABT result = std::move(initial); + + const size_t fieldPathLength = fieldPath.getPathLength(); + bool isLastElement = true; + for (size_t i = fieldPathLength; i-- > skipFromStart;) { + result = + fieldNameFn(fieldPath.getFieldName(i).toString(), isLastElement, std::move(result)); + isLastElement = false; + } + + return result; +} + +} // namespace mongo::optimizer diff --git a/src/mongo/db/pipeline/abt/utils.h b/src/mongo/db/pipeline/abt/utils.h new file mode 100644 index 00000000000..85c1a643764 --- /dev/null +++ b/src/mongo/db/pipeline/abt/utils.h @@ -0,0 +1,54 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#pragma once + +#include "mongo/db/exec/document_value/value.h" +#include "mongo/db/exec/sbe/values/bson.h" +#include "mongo/db/pipeline/field_path.h" +#include "mongo/db/query/optimizer/node.h" + +namespace mongo::optimizer { + +/** + * Generate an AST to compute "coerceToBool" on the input. Expects a variable name to use as a let + * variable. 
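+ * For example, for an input expression 'e' and let variable 'v' the generated tree is roughly:
+ * Let [v = e] in exists(e) && !isNull(e) && ((v <=> false) != 0) && ((v <=> 0) != 0),
+ * where <=> denotes the three-way comparison (Cmp3w).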
+ */ +ABT generateCoerceToBool(ABT input, const std::string& varName); + +std::pair<sbe::value::TypeTags, sbe::value::Value> convertFrom(Value val); + +using ABTFieldNameFn = + std::function<ABT(const std::string& fieldName, const bool isLastElement, ABT input)>; +ABT translateFieldPath(const FieldPath& fieldPath, + ABT initial, + const ABTFieldNameFn& fieldNameFn, + size_t skipFromStart = 0); + +} // namespace mongo::optimizer diff --git a/src/mongo/db/pipeline/document_source_union_with.h b/src/mongo/db/pipeline/document_source_union_with.h index 1d9068a9f4c..239a41c9235 100644 --- a/src/mongo/db/pipeline/document_source_union_with.h +++ b/src/mongo/db/pipeline/document_source_union_with.h @@ -134,6 +134,10 @@ public: return &_stats; } + const Pipeline& getPipeline() const { + return *_pipeline; + } + boost::intrusive_ptr<DocumentSource> clone() const final; protected: diff --git a/src/mongo/db/pipeline/visitors/document_source_visitor.h b/src/mongo/db/pipeline/visitors/document_source_visitor.h new file mode 100644 index 00000000000..a0158147e38 --- /dev/null +++ b/src/mongo/db/pipeline/visitors/document_source_visitor.h @@ -0,0 +1,135 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. 
+ */ + +#pragma once + +#include "mongo/db/query/tree_walker.h" + +namespace mongo { + +class DocumentSourceBucketAuto; +class DocumentSourceCollStats; +class DocumentSourceCurrentOp; +class DocumentSourceCursor; +class DocumentSourceExchange; +class DocumentSourceFacet; +class DocumentSourceGeoNear; +class DocumentSourceGeoNearCursor; +class DocumentSourceGraphLookUp; +class DocumentSourceGroup; +class DocumentSourceIndexStats; +class DocumentSourceInternalInhibitOptimization; +class DocumentSourceInternalShardFilter; +class DocumentSourceInternalSplitPipeline; +class DocumentSourceLimit; +class DocumentSourceListCachedAndActiveUsers; +class DocumentSourceListLocalSessions; +class DocumentSourceListSessions; +class DocumentSourceLookUp; +class DocumentSourceMatch; +class DocumentSourceMerge; +class DocumentSourceMergeCursors; +class DocumentSourceOperationMetrics; +class DocumentSourceOut; +class DocumentSourcePlanCacheStats; +class DocumentSourceQueue; +class DocumentSourceRedact; +class DocumentSourceSample; +class DocumentSourceSampleFromRandomCursor; +class DocumentSourceSequentialDocumentCache; +class DocumentSourceSingleDocumentTransformation; +class DocumentSourceSkip; +class DocumentSourceSort; +class DocumentSourceTeeConsumer; +class DocumentSourceUnionWith; +class DocumentSourceUnwind; + +/** + * Visitor pattern for pipeline document sources. + * + * This code is not responsible for traversing the tree, only for performing the double-dispatch. + * + * If the visitor doesn't intend to modify the tree, then the template argument 'IsConst' should be + * set to 'true'. In this case all 'visit()' methods will take a const pointer to a visiting node. + */ +template <bool IsConst = false> +class DocumentSourceVisitor { +public: + virtual void visit(tree_walker::MaybeConstPtr<IsConst, DocumentSourceBucketAuto> source) = 0; + virtual void visit(tree_walker::MaybeConstPtr<IsConst, DocumentSourceCollStats> source) = 0; + virtual void visit(tree_walker::MaybeConstPtr<IsConst, DocumentSourceCurrentOp> source) = 0; + virtual void visit(tree_walker::MaybeConstPtr<IsConst, DocumentSourceCursor> source) = 0; + virtual void visit(tree_walker::MaybeConstPtr<IsConst, DocumentSourceExchange> source) = 0; + virtual void visit(tree_walker::MaybeConstPtr<IsConst, DocumentSourceFacet> source) = 0; + virtual void visit(tree_walker::MaybeConstPtr<IsConst, DocumentSourceGeoNear> source) = 0; + virtual void visit(tree_walker::MaybeConstPtr<IsConst, DocumentSourceGeoNearCursor> source) = 0; + virtual void visit(tree_walker::MaybeConstPtr<IsConst, DocumentSourceGraphLookUp> source) = 0; + virtual void visit(tree_walker::MaybeConstPtr<IsConst, DocumentSourceGroup> source) = 0; + virtual void visit(tree_walker::MaybeConstPtr<IsConst, DocumentSourceIndexStats> source) = 0; + virtual void visit( + tree_walker::MaybeConstPtr<IsConst, DocumentSourceInternalInhibitOptimization> source) = 0; + virtual void visit( + tree_walker::MaybeConstPtr<IsConst, DocumentSourceInternalShardFilter> source) = 0; + virtual void visit( + tree_walker::MaybeConstPtr<IsConst, DocumentSourceInternalSplitPipeline> source) = 0; + virtual void visit(tree_walker::MaybeConstPtr<IsConst, DocumentSourceLimit> source) = 0; + virtual void visit( + tree_walker::MaybeConstPtr<IsConst, DocumentSourceListCachedAndActiveUsers> source) = 0; + virtual void visit( + tree_walker::MaybeConstPtr<IsConst, DocumentSourceListLocalSessions> source) = 0; + virtual void visit(tree_walker::MaybeConstPtr<IsConst, DocumentSourceListSessions> source) = 0; + virtual 
void visit(tree_walker::MaybeConstPtr<IsConst, DocumentSourceLookUp> source) = 0; + virtual void visit(tree_walker::MaybeConstPtr<IsConst, DocumentSourceMatch> source) = 0; + virtual void visit(tree_walker::MaybeConstPtr<IsConst, DocumentSourceMerge> source) = 0; + virtual void visit(tree_walker::MaybeConstPtr<IsConst, DocumentSourceMergeCursors> source) = 0; + virtual void visit( + tree_walker::MaybeConstPtr<IsConst, DocumentSourceOperationMetrics> source) = 0; + virtual void visit(tree_walker::MaybeConstPtr<IsConst, DocumentSourceOut> source) = 0; + virtual void visit( + tree_walker::MaybeConstPtr<IsConst, DocumentSourcePlanCacheStats> source) = 0; + virtual void visit(tree_walker::MaybeConstPtr<IsConst, DocumentSourceQueue> source) = 0; + virtual void visit(tree_walker::MaybeConstPtr<IsConst, DocumentSourceRedact> source) = 0; + virtual void visit(tree_walker::MaybeConstPtr<IsConst, DocumentSourceSample> source) = 0; + virtual void visit( + tree_walker::MaybeConstPtr<IsConst, DocumentSourceSampleFromRandomCursor> source) = 0; + virtual void visit( + tree_walker::MaybeConstPtr<IsConst, DocumentSourceSequentialDocumentCache> source) = 0; + virtual void visit( + tree_walker::MaybeConstPtr<IsConst, DocumentSourceSingleDocumentTransformation> source) = 0; + virtual void visit(tree_walker::MaybeConstPtr<IsConst, DocumentSourceSkip> source) = 0; + virtual void visit(tree_walker::MaybeConstPtr<IsConst, DocumentSourceSort> source) = 0; + virtual void visit(tree_walker::MaybeConstPtr<IsConst, DocumentSourceTeeConsumer> source) = 0; + virtual void visit(tree_walker::MaybeConstPtr<IsConst, DocumentSourceUnionWith> source) = 0; + virtual void visit(tree_walker::MaybeConstPtr<IsConst, DocumentSourceUnwind> source) = 0; +}; + +using DocumentSourceMutableVisitor = DocumentSourceVisitor<false>; +using DocumentSourceConstVisitor = DocumentSourceVisitor<true>; + +} // namespace mongo diff --git a/src/mongo/db/pipeline/visitors/document_source_walker.cpp b/src/mongo/db/pipeline/visitors/document_source_walker.cpp new file mode 100644 index 00000000000..b0ea004cae9 --- /dev/null +++ b/src/mongo/db/pipeline/visitors/document_source_walker.cpp @@ -0,0 +1,144 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. 
If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include "mongo/db/pipeline/visitors/document_source_walker.h" + +#include "mongo/base/error_codes.h" +#include "mongo/db/pipeline/document_source_bucket_auto.h" +#include "mongo/db/pipeline/document_source_coll_stats.h" +#include "mongo/db/pipeline/document_source_current_op.h" +#include "mongo/db/pipeline/document_source_cursor.h" +#include "mongo/db/pipeline/document_source_exchange.h" +#include "mongo/db/pipeline/document_source_facet.h" +#include "mongo/db/pipeline/document_source_geo_near.h" +#include "mongo/db/pipeline/document_source_geo_near_cursor.h" +#include "mongo/db/pipeline/document_source_graph_lookup.h" +#include "mongo/db/pipeline/document_source_group.h" +#include "mongo/db/pipeline/document_source_index_stats.h" +#include "mongo/db/pipeline/document_source_internal_inhibit_optimization.h" +#include "mongo/db/pipeline/document_source_internal_shard_filter.h" +#include "mongo/db/pipeline/document_source_internal_split_pipeline.h" +#include "mongo/db/pipeline/document_source_limit.h" +#include "mongo/db/pipeline/document_source_list_cached_and_active_users.h" +#include "mongo/db/pipeline/document_source_list_local_sessions.h" +#include "mongo/db/pipeline/document_source_list_sessions.h" +#include "mongo/db/pipeline/document_source_lookup.h" +#include "mongo/db/pipeline/document_source_match.h" +#include "mongo/db/pipeline/document_source_merge.h" +#include "mongo/db/pipeline/document_source_operation_metrics.h" +#include "mongo/db/pipeline/document_source_out.h" +#include "mongo/db/pipeline/document_source_plan_cache_stats.h" +#include "mongo/db/pipeline/document_source_queue.h" +#include "mongo/db/pipeline/document_source_redact.h" +#include "mongo/db/pipeline/document_source_sample.h" +#include "mongo/db/pipeline/document_source_sample_from_random_cursor.h" +#include "mongo/db/pipeline/document_source_sequential_document_cache.h" +#include "mongo/db/pipeline/document_source_single_document_transformation.h" +#include "mongo/db/pipeline/document_source_skip.h" +#include "mongo/db/pipeline/document_source_sort.h" +#include "mongo/db/pipeline/document_source_tee_consumer.h" +#include "mongo/db/pipeline/document_source_union_with.h" +#include "mongo/db/pipeline/document_source_unwind.h" +#include "mongo/s/query/document_source_merge_cursors.h" + +namespace mongo { + +template <class T> +bool DocumentSourceWalker::visitHelper(const DocumentSource* source) { + const T* concrete = dynamic_cast<const T*>(source); + if (concrete == nullptr) { + return false; + } + + _postVisitor->visit(concrete); + return true; +} + +void DocumentSourceWalker::walk(const Pipeline& pipeline) { + const Pipeline::SourceContainer& sources = pipeline.getSources(); + + if (_postVisitor != nullptr) { + for (auto it = sources.begin(); it != sources.end(); it++) { + // TODO: use acceptVisitor method when DocumentSources get ability to visit. + // source->acceptVisitor(*_preVisitor); + // + // For now, however, we use a crutch walker which performs a series of dynamic casts. + // Some types are commented out because of dependency issues (e.g. 
not in pipeline + // target but in query_exec target) + const DocumentSource* ds = it->get(); + const bool visited = visitHelper<DocumentSourceBucketAuto>(ds) || + visitHelper<DocumentSourceBucketAuto>(ds) || + visitHelper<DocumentSourceCollStats>(ds) || + visitHelper<DocumentSourceCurrentOp>(ds) || + // TODO: uncomment after fixing dependency + // visitHelper<DocumentSourceCursor>(ds) || + visitHelper<DocumentSourceExchange>(ds) || visitHelper<DocumentSourceFacet>(ds) || + visitHelper<DocumentSourceGeoNear>(ds) || + + // TODO: uncomment after fixing dependency + //! visitHelper<DocumentSourceGeoNearCursor>(ds) || + visitHelper<DocumentSourceGraphLookUp>(ds) || + visitHelper<DocumentSourceGroup>(ds) || visitHelper<DocumentSourceIndexStats>(ds) || + visitHelper<DocumentSourceInternalInhibitOptimization>(ds) || + visitHelper<DocumentSourceInternalShardFilter>(ds) || + visitHelper<DocumentSourceInternalSplitPipeline>(ds) || + visitHelper<DocumentSourceLimit>(ds) || + visitHelper<DocumentSourceListCachedAndActiveUsers>(ds) || + visitHelper<DocumentSourceListLocalSessions>(ds) || + visitHelper<DocumentSourceListSessions>(ds) || + visitHelper<DocumentSourceLookUp>(ds) || visitHelper<DocumentSourceMatch>(ds) || + visitHelper<DocumentSourceMerge>(ds) || + // TODO: uncomment after fixing dependency + // visitHelper<DocumentSourceMergeCursors>(ds) || + visitHelper<DocumentSourceOperationMetrics>(ds) || + visitHelper<DocumentSourceOut>(ds) || + visitHelper<DocumentSourcePlanCacheStats>(ds) || + visitHelper<DocumentSourceQueue>(ds) || visitHelper<DocumentSourceRedact>(ds) || + visitHelper<DocumentSourceSample>(ds) || + visitHelper<DocumentSourceSampleFromRandomCursor>(ds) || + visitHelper<DocumentSourceSequentialDocumentCache>(ds) || + visitHelper<DocumentSourceSingleDocumentTransformation>(ds) || + visitHelper<DocumentSourceSkip>(ds) || visitHelper<DocumentSourceSort>(ds) || + visitHelper<DocumentSourceTeeConsumer>(ds) || + visitHelper<DocumentSourceUnionWith>(ds) || visitHelper<DocumentSourceUnwind>(ds) + // TODO: uncomment after fixing dependency + //&& visitHelper<DocumentSourceUpdateOnAddShard>(ds) + ; + + if (!visited) { + uasserted(ErrorCodes::InternalErrorNotSupported, + str::stream() << "Stage is not supported: " << ds->getSourceName()); + } + } + } + + // TODO: reverse for pre-visitor +} + +} // namespace mongo diff --git a/src/mongo/db/pipeline/visitors/document_source_walker.h b/src/mongo/db/pipeline/visitors/document_source_walker.h new file mode 100644 index 00000000000..cb4a844de4d --- /dev/null +++ b/src/mongo/db/pipeline/visitors/document_source_walker.h @@ -0,0 +1,58 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. 
+ * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#pragma once + +#include "mongo/db/pipeline/document_source.h" +#include "mongo/db/pipeline/pipeline.h" +#include "mongo/db/pipeline/visitors/document_source_visitor.h" + +namespace mongo { + +/** + * A document source walker. + * TODO: SERVER-62657. Implement a hash-table based resolution instead of sequential dynamic casts. + */ +class DocumentSourceWalker final { +public: + DocumentSourceWalker(DocumentSourceConstVisitor* preVisitor, + DocumentSourceConstVisitor* postVisitor) + : _preVisitor{preVisitor}, _postVisitor{postVisitor} {} + + void walk(const Pipeline& pipeline); + +private: + template <class T> + bool visitHelper(const DocumentSource* source); + + DocumentSourceConstVisitor* _preVisitor; + DocumentSourceConstVisitor* _postVisitor; +}; + +} // namespace mongo diff --git a/src/mongo/db/pipeline/visitors/transformer_interface_visitor.h b/src/mongo/db/pipeline/visitors/transformer_interface_visitor.h new file mode 100644 index 00000000000..e81075e8e9f --- /dev/null +++ b/src/mongo/db/pipeline/visitors/transformer_interface_visitor.h @@ -0,0 +1,66 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. 
+ */ + +#pragma once + +#include "mongo/db/query/tree_walker.h" + +namespace mongo { + +namespace projection_executor { +class AddFieldsProjectionExecutor; +class ExclusionProjectionExecutor; +class InclusionProjectionExecutor; +} // namespace projection_executor +class GroupFromFirstDocumentTransformation; +class ReplaceRootTransformation; + +/** + * Visitor pattern for Transformer Interface instances + */ +template <bool IsConst = false> +class TransformerInterfaceVisitor { +public: + virtual void visit( + tree_walker::MaybeConstPtr<IsConst, projection_executor::AddFieldsProjectionExecutor> + visitor) = 0; + virtual void visit( + tree_walker::MaybeConstPtr<IsConst, projection_executor::ExclusionProjectionExecutor> + visitor) = 0; + virtual void visit( + tree_walker::MaybeConstPtr<IsConst, projection_executor::InclusionProjectionExecutor> + visitor) = 0; + virtual void visit( + tree_walker::MaybeConstPtr<IsConst, GroupFromFirstDocumentTransformation> visitor) = 0; + virtual void visit(tree_walker::MaybeConstPtr<IsConst, ReplaceRootTransformation> visitor) = 0; +}; + +using TransformerInterfaceMutableVisitor = TransformerInterfaceVisitor<false>; +using TransformerInterfaceConstVisitor = TransformerInterfaceVisitor<true>; +} // namespace mongo diff --git a/src/mongo/db/pipeline/visitors/transformer_interface_walker.cpp b/src/mongo/db/pipeline/visitors/transformer_interface_walker.cpp new file mode 100644 index 00000000000..11555f764b9 --- /dev/null +++ b/src/mongo/db/pipeline/visitors/transformer_interface_walker.cpp @@ -0,0 +1,72 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. 
+ */ + +#include "mongo/db/pipeline/visitors/transformer_interface_walker.h" +#include "mongo/db/exec/add_fields_projection_executor.h" +#include "mongo/db/exec/exclusion_projection_executor.h" +#include "mongo/db/exec/inclusion_projection_executor.h" +#include "mongo/db/pipeline/document_source_group.h" +#include "mongo/db/pipeline/document_source_replace_root.h" + +namespace mongo { + +TransformerInterfaceWalker::TransformerInterfaceWalker(TransformerInterfaceConstVisitor* visitor) + : _visitor(visitor) {} + +void TransformerInterfaceWalker::walk(const TransformerInterface* transformer) { + switch (transformer->getType()) { + case TransformerInterface::TransformerType::kExclusionProjection: + _visitor->visit( + static_cast<const projection_executor::ExclusionProjectionExecutor*>(transformer)); + break; + + case TransformerInterface::TransformerType::kInclusionProjection: + _visitor->visit( + static_cast<const projection_executor::InclusionProjectionExecutor*>(transformer)); + break; + + case TransformerInterface::TransformerType::kComputedProjection: + _visitor->visit( + static_cast<const projection_executor::AddFieldsProjectionExecutor*>(transformer)); + break; + + case TransformerInterface::TransformerType::kReplaceRoot: + _visitor->visit(static_cast<const ReplaceRootTransformation*>(transformer)); + break; + + case TransformerInterface::TransformerType::kGroupFromFirstDocument: + _visitor->visit(static_cast<const GroupFromFirstDocumentTransformation*>(transformer)); + break; + + default: + MONGO_UNREACHABLE; + } +} + +} // namespace mongo diff --git a/src/mongo/db/pipeline/visitors/transformer_interface_walker.h b/src/mongo/db/pipeline/visitors/transformer_interface_walker.h new file mode 100644 index 00000000000..32c8248e00e --- /dev/null +++ b/src/mongo/db/pipeline/visitors/transformer_interface_walker.h @@ -0,0 +1,48 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. 
+ */ + +#pragma once + + +#include "mongo/db/pipeline/transformer_interface.h" +#include "mongo/db/pipeline/visitors/transformer_interface_visitor.h" + +namespace mongo { + +class TransformerInterfaceWalker final { +public: + TransformerInterfaceWalker(TransformerInterfaceConstVisitor* visitor); + + void walk(const TransformerInterface* transformer); + +private: + TransformerInterfaceConstVisitor* _visitor; +}; + +} // namespace mongo diff --git a/src/mongo/db/query/SConscript b/src/mongo/db/query/SConscript index 67b992fe02a..77d2fabc948 100644 --- a/src/mongo/db/query/SConscript +++ b/src/mongo/db/query/SConscript @@ -6,8 +6,10 @@ env = env.Clone() env.SConscript( dirs=[ + "ce", "collation", "datetime", + 'optimizer', ], exports=[ 'env' diff --git a/src/mongo/db/query/ce/SConscript b/src/mongo/db/query/ce/SConscript new file mode 100644 index 00000000000..8ab2ca62f51 --- /dev/null +++ b/src/mongo/db/query/ce/SConscript @@ -0,0 +1,16 @@ +# -*- mode: python -*- + +Import("env") + +env = env.Clone() + +env.Library( + target="query_ce", + source=[ + 'ce_sampling.cpp', + ], + LIBDEPS_PRIVATE=[ + '$BUILD_DIR/mongo/db/exec/sbe/query_sbe_abt', + '$BUILD_DIR/mongo/db/query/optimizer/optimizer', + ] +)
\ No newline at end of file diff --git a/src/mongo/db/query/ce/ce_sampling.cpp b/src/mongo/db/query/ce/ce_sampling.cpp new file mode 100644 index 00000000000..e95b836bd57 --- /dev/null +++ b/src/mongo/db/query/ce/ce_sampling.cpp @@ -0,0 +1,290 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include "mongo/db/query/ce/ce_sampling.h" + +#include "mongo/db/exec/sbe/abt/abt_lower.h" +#include "mongo/db/query/optimizer/cascades/ce_heuristic.h" +#include "mongo/db/query/optimizer/explain.h" +#include "mongo/db/query/optimizer/utils/abt_hash.h" +#include "mongo/db/query/optimizer/utils/memo_utils.h" + +namespace mongo::optimizer::cascades { + +using namespace properties; + +class SamplingPlanExtractor { +public: + SamplingPlanExtractor(const Memo& memo, const size_t sampleSize) + : _memo(memo), _sampleSize(sampleSize) {} + + void transport(ABT& n, const MemoLogicalDelegatorNode& node) { + n = extract(_memo.getGroup(node.getGroupId())._logicalNodes.at(0)); + } + + void transport(ABT& n, const ScanNode& /*node*/, ABT& /*binder*/) { + // We will lower the scan node in a sampling context here. + // TODO: for now just return the documents in random order. + n = make<LimitSkipNode>(LimitSkipRequirement(_sampleSize, 0), std::move(n)); + } + + void transport(ABT& n, const FilterNode& /*node*/, ABT& childResult, ABT& /*exprResult*/) { + // Skip over filters. + n = childResult; + } + + void transport(ABT& /*n*/, + const EvaluationNode& /*node*/, + ABT& /*childResult*/, + ABT& /*exprResult*/) { + // Keep Eval nodes. + } + + void transport(ABT& n, const SargableNode& /*node*/, ABT& childResult, ABT& refs, ABT& binds) { + // Skip over sargable nodes. + n = childResult; + } + + template <typename T, typename... Ts> + void transport(ABT& /*n*/, const T& /*node*/, Ts&&...) 
{ + if constexpr (std::is_base_of_v<Node, T>) { + uasserted(6624242, "Should not be seeing other types of nodes here."); + } + } + + ABT extract(ABT node) { + algebra::transport<true>(node, *this); + return node; + } + +private: + const Memo& _memo; + const size_t _sampleSize; +}; + +class CESamplingTransportImpl { + static constexpr size_t kMaxSampleSize = 1000; + +public: + CESamplingTransportImpl(OperationContext* opCtx, + OptPhaseManager& phaseManager, + const int64_t numRecords) + : _opCtx(opCtx), + _phaseManager(phaseManager), + _heuristicCE(), + _sampleSize(std::min<int64_t>(numRecords, kMaxSampleSize)) {} + + CEType transport(const ABT& n, + const FilterNode& node, + const Memo& memo, + const LogicalProps& logicalProps, + CEType childResult, + CEType /*exprResult*/) { + if (!hasProperty<IndexingAvailability>(logicalProps)) { + return _heuristicCE.deriveCE(memo, logicalProps, n.ref()); + } + + SamplingPlanExtractor planExtractor(memo, _sampleSize); + // Create a plan with all eval nodes so far and the filter last. + ABT abtTree = make<FilterNode>(node.getFilter(), planExtractor.extract(n)); + + return estimateFilterCE(memo, logicalProps, n, std::move(abtTree), childResult); + } + + CEType transport(const ABT& n, + const SargableNode& node, + const Memo& memo, + const LogicalProps& logicalProps, + CEType childResult, + CEType /*bindResult*/, + CEType /*refsResult*/) { + if (!hasProperty<IndexingAvailability>(logicalProps)) { + return _heuristicCE.deriveCE(memo, logicalProps, n.ref()); + } + + SamplingPlanExtractor planExtractor(memo, _sampleSize); + ABT extracted = planExtractor.extract(n); + + // Estimate individual requirements separately by potentially re-using cached results. + // Here we assume that each requirement is independent. + // TODO: consider estimating together the entire set of requirements (but caching!) + CEType result = childResult; + for (const auto& [key, req] : node.getReqMap()) { + if (!isIntervalReqFullyOpenDNF(req.getIntervals())) { + ABT lowered = extracted; + lowerPartialSchemaRequirement(key, req, lowered); + uassert(6624243, "Expected a filter node", lowered.is<FilterNode>()); + result = estimateFilterCE(memo, logicalProps, n, std::move(lowered), result); + } + } + + return result; + } + + /** + * Other ABT types. + */ + template <typename T, typename... Ts> + CEType transport(const ABT& n, + const T& /*node*/, + const Memo& memo, + const LogicalProps& logicalProps, + Ts&&...) { + if (canBeLogicalNode<T>()) { + return _heuristicCE.deriveCE(memo, logicalProps, n.ref()); + } + return 0.0; + } + + CEType derive(const Memo& memo, + const properties::LogicalProps& logicalProps, + const ABT::reference_type logicalNodeRef) { + return algebra::transport<true>(logicalNodeRef, *this, memo, logicalProps); + } + +private: + CEType estimateFilterCE(const Memo& memo, + const LogicalProps& logicalProps, + const ABT& n, + ABT abtTree, + CEType childResult) { + auto it = _selectivityCacheMap.find(abtTree); + if (it != _selectivityCacheMap.cend()) { + // Cache hit. + return it->second * childResult; + } + + const auto [success, selectivity] = estimateSelectivity(abtTree); + if (!success) { + return _heuristicCE.deriveCE(memo, logicalProps, n.ref()); + } + + _selectivityCacheMap.emplace(std::move(abtTree), selectivity); + std::cerr << "Sampling sel.: " << selectivity << "\n"; + return selectivity * childResult; + } + + std::pair<bool, SelectivityType> estimateSelectivity(ABT abtTree) { + // Add a group by to count number of documents. 
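+        // The sampling plan is roughly: Root <- GroupBy [sum: $sum(1)] <- (predicate under test)
+        // <- LimitSkip [sample size] <- Scan. After lowering to SBE and executing it, the
+        // returned count divided by the sample size is the selectivity estimate.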
+ const ProjectionName sampleSumProjection = "sum"; + abtTree = + make<GroupByNode>(ProjectionNameVector{}, + ProjectionNameVector{sampleSumProjection}, + makeSeq(make<FunctionCall>("$sum", makeSeq(Constant::int64(1)))), + std::move(abtTree)); + abtTree = make<RootNode>( + properties::ProjectionRequirement{ProjectionNameVector{sampleSumProjection}}, + std::move(abtTree)); + + std::cerr << "********* Sampling ABT *********\n"; + std::cerr << ExplainGenerator::explainV2(abtTree); + std::cerr << "********* Sampling ABT *********\n"; + + if (!_phaseManager.optimize(abtTree)) { + return {false, {}}; + } + + auto env = VariableEnvironment::build(abtTree); + SlotVarMap slotMap; + sbe::value::SlotIdGenerator ids; + SBENodeLowering g{env, + slotMap, + ids, + _phaseManager.getMetadata(), + _phaseManager.getNodeToGroupPropsMap(), + _phaseManager.getRIDProjections(), + true /*randomScan*/}; + auto sbePlan = g.optimize(abtTree); + + // TODO: return errors instead of exceptions? + uassert(6624244, "Lowering failed", sbePlan != nullptr); + uassert(6624245, "Invalid slot map size", slotMap.size() == 1); + + sbePlan->attachToOperationContext(_opCtx); + sbe::CompileCtx ctx(std::make_unique<sbe::RuntimeEnvironment>()); + sbePlan->prepare(ctx); + + std::vector<sbe::value::SlotAccessor*> accessors; + for (auto& [name, slot] : slotMap) { + accessors.emplace_back(sbePlan->getAccessor(ctx, slot)); + } + + sbePlan->open(false); + ON_BLOCK_EXIT([&] { sbePlan->close(); }); + + while (sbePlan->getNext() != sbe::PlanState::IS_EOF) { + const auto [tag, value] = accessors.at(0)->getViewOfValue(); + if (tag == sbe::value::TypeTags::NumberInt64) { + // TODO: check if we get exactly one result from the groupby? + return {true, static_cast<double>(value) / _sampleSize}; + } + return {false, {}}; + }; + + // If nothing passes the filter, estimate 0.0 selectivity. HashGroup will return 0 results. + return {true, 0.0}; + } + + struct NodeRefHash { + size_t operator()(const ABT& node) const { + return ABTHashGenerator::generate(node); + } + }; + + struct NodeRefCompare { + bool operator()(const ABT& left, const ABT& right) const { + return left == right; + } + }; + + // Cache a logical node reference to computed selectivity. Used for Filter and Sargable nodes. + opt::unordered_map<ABT, SelectivityType, NodeRefHash, NodeRefCompare> _selectivityCacheMap; + + // We don't own those. + OperationContext* _opCtx; + OptPhaseManager& _phaseManager; + + HeuristicCE _heuristicCE; + const int64_t _sampleSize; +}; + +CESamplingTransport::CESamplingTransport(OperationContext* opCtx, + OptPhaseManager& phaseManager, + const int64_t numRecords) + : _impl(std::make_unique<CESamplingTransportImpl>(opCtx, phaseManager, numRecords)) {} + +CESamplingTransport::~CESamplingTransport() {} + +CEType CESamplingTransport::deriveCE(const Memo& memo, + const LogicalProps& logicalProps, + const ABT::reference_type logicalNodeRef) const { + return _impl->derive(memo, logicalProps, logicalNodeRef); +} + +} // namespace mongo::optimizer::cascades diff --git a/src/mongo/db/query/ce/ce_sampling.h b/src/mongo/db/query/ce/ce_sampling.h new file mode 100644 index 00000000000..68eada866a1 --- /dev/null +++ b/src/mongo/db/query/ce/ce_sampling.h @@ -0,0 +1,52 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. 
+ * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#pragma once + +#include "mongo/db/query/optimizer/cascades/interfaces.h" +#include "mongo/db/query/optimizer/opt_phase_manager.h" + +namespace mongo::optimizer::cascades { + +class CESamplingTransportImpl; + +class CESamplingTransport : public CEInterface { +public: + CESamplingTransport(OperationContext* opCtx, OptPhaseManager& phaseManager, int64_t numRecords); + ~CESamplingTransport(); + + CEType deriveCE(const Memo& memo, + const properties::LogicalProps& logicalProps, + ABT::reference_type logicalNodeRef) const final; + +private: + std::unique_ptr<CESamplingTransportImpl> _impl; +}; + +} // namespace mongo::optimizer::cascades diff --git a/src/mongo/db/query/get_executor.cpp b/src/mongo/db/query/get_executor.cpp index a1a0716d29e..cd61dfefe7c 100644 --- a/src/mongo/db/query/get_executor.cpp +++ b/src/mongo/db/query/get_executor.cpp @@ -1318,6 +1318,7 @@ StatusWith<std::unique_ptr<PlanExecutor, PlanExecutor::Deleter>> getSlotBasedExe std::move(cq), std::move(solutions[0]), std::move(roots[0]), + {}, mainColl, plannerOptions, std::move(nss), diff --git a/src/mongo/db/query/optimizer/SConscript b/src/mongo/db/query/optimizer/SConscript new file mode 100644 index 00000000000..8843d549d8e --- /dev/null +++ b/src/mongo/db/query/optimizer/SConscript @@ -0,0 +1,76 @@ +# -*- mode: python -*- + +Import("env") + +env = env.Clone() + +env.SConscript( + dirs=[ + "algebra", + ], + exports=[ + 'env' + ], +) + +env.Library( + target="optimizer", + source=[ + "cascades/ce_heuristic.cpp", + "cascades/ce_hinted.cpp", + "cascades/cost_derivation.cpp", + "cascades/enforcers.cpp", + "cascades/implementers.cpp", + "cascades/logical_props_derivation.cpp", + "cascades/logical_rewriter.cpp", + "cascades/memo.cpp", + "cascades/physical_rewriter.cpp", + "cascades/rewrite_queues.cpp", + "defs.cpp", + "explain.cpp", + "index_bounds.cpp", + "metadata.cpp", + "node.cpp", + "opt_phase_manager.cpp", + "props.cpp", + "reference_tracker.cpp", + "rewrites/const_eval.cpp", + "rewrites/path.cpp", + "rewrites/path_lower.cpp", + "syntax/expr.cpp", + "utils/abt_hash.cpp", + "utils/interval_utils.cpp", + "utils/memo_utils.cpp", + "utils/utils.cpp" + ], + LIBDEPS=[ + "$BUILD_DIR/mongo/db/exec/sbe/query_sbe_values", + ], +) + +env.Library( + target="unit_test_utils", + source=[ + "utils/unit_test_utils.cpp", + ], 
+ LIBDEPS=[ + "$BUILD_DIR/mongo/db/pipeline/pipeline", + "$BUILD_DIR/mongo/db/query/query_test_service_context", + "$BUILD_DIR/mongo/unittest/unittest", + ], +) + +env.CppUnitTest( + target='optimizer_test', + source=[ + "logical_rewriter_optimizer_test.cpp", + "optimizer_test.cpp", + "physical_rewriter_optimizer_test.cpp", + "rewrites/path_optimizer_test.cpp", + "interval_intersection_test.cpp", + ], + LIBDEPS=[ + "optimizer", + "unit_test_utils", + ] +) diff --git a/src/mongo/db/query/optimizer/algebra/SConscript b/src/mongo/db/query/optimizer/algebra/SConscript new file mode 100644 index 00000000000..0d2a48c24d3 --- /dev/null +++ b/src/mongo/db/query/optimizer/algebra/SConscript @@ -0,0 +1,15 @@ +# -*- mode: python -*- + +Import("env") + +env = env.Clone() + +env.CppUnitTest( + target='algebra_test', + source=[ + 'algebra_test.cpp', + ], + LIBDEPS=[ + + ] +) diff --git a/src/mongo/db/query/optimizer/algebra/algebra_test.cpp b/src/mongo/db/query/optimizer/algebra/algebra_test.cpp new file mode 100644 index 00000000000..48e668a6e32 --- /dev/null +++ b/src/mongo/db/query/optimizer/algebra/algebra_test.cpp @@ -0,0 +1,570 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include "mongo/db/query/optimizer/algebra/operator.h" +#include "mongo/db/query/optimizer/algebra/polyvalue.h" +#include "mongo/unittest/unittest.h" + +namespace mongo::optimizer::algebra { + +namespace { + +class Leaf; +class BinaryNode; +class NaryNode; +class AtLeastBinaryNode; +using Tree = PolyValue<Leaf, BinaryNode, NaryNode, AtLeastBinaryNode>; + +/** + * A leaf in the tree. Just contains data - in this case a double. + */ +class Leaf : public OpSpecificArity<Tree, Leaf, 0> { +public: + Leaf(double x) : x(x) {} + double x; +}; + +/** + * An inner node in the tree with exactly two children. + */ +class BinaryNode : public OpSpecificArity<Tree, BinaryNode, 2> { +public: + BinaryNode(Tree left, Tree right) + : OpSpecificArity<Tree, BinaryNode, 2>(std::move(left), std::move(right)) {} +}; + +/** + * An inner node in the tree with any number of children, zero or greater. 
+ */ +class NaryNode : public OpSpecificDynamicArity<Tree, NaryNode, 0> { +public: + NaryNode(std::vector<Tree> children) + : OpSpecificDynamicArity<Tree, NaryNode, 0>(std::move(children)) {} +}; + +/** + * An inner node in the tree with 2 or more nodes. + */ +class AtLeastBinaryNode : public OpSpecificDynamicArity<Tree, AtLeastBinaryNode, 2> { +public: + /** + * Notice the required number of nodes are given as separate arguments from the vector. + */ + AtLeastBinaryNode(std::vector<Tree> children, Tree left, Tree right) + : OpSpecificDynamicArity<Tree, AtLeastBinaryNode, 2>( + std::move(children), std::move(left), std::move(right)) {} +}; + +/** + * A visitor of the tree with methods to visit each kind of node. + * + * This is a very basic visitor to just demonstrate the transport() API - all it does is sum up + * doubles in the leaf nodes of the tree. + * + * Notice that each kind of node did not need to fill out some boilerplate "visit()" method or + * anything like that. The PolyValue templating magic took care of all the boilerplate for us, and + * the operator classes (e.g. OpSpecificArity) exposes the tree structure and children. + */ +class NodeTransporter { +public: + double transport(Leaf& leaf) { + return leaf.x; + } + double transport(BinaryNode& node, double child0, double child1) { + return child0 + child1; + } + double transport(NaryNode& node, std::vector<double> children) { + return std::accumulate(children.begin(), children.end(), 0.0); + } + double transport(AtLeastBinaryNode& node, + std::vector<double> children, + double child0, + double child1) { + return child0 + child1 + std::accumulate(children.begin(), children.end(), 0.0); + } +}; + +/** + * A visitor of the tree with methods to visit each kind of node. This visitor also takes a + * reference to the Tree itself. Unused here, this reference can be used to mutate or replace the + * node itself while the walking takes place. + */ +class TreeTransporter { +public: + double transport(Tree& tree, Leaf& leaf) { + return leaf.x; + } + double transport(Tree& tree, BinaryNode& node, double child0, double child1) { + return child0 + child1; + } + double transport(Tree& tree, NaryNode& node, std::vector<double> children) { + return std::accumulate(children.begin(), children.end(), 0.0); + } + double transport(Tree& tree, + AtLeastBinaryNode& node, + std::vector<double> children, + double child0, + double child1) { + return child0 + child1 + std::accumulate(children.begin(), children.end(), 0.0); + } +}; + +TEST(PolyValueTest, SumTransportFixedArity) { + NodeTransporter nodeTransporter; + TreeTransporter treeTransporter; + { + Tree simple = Tree::make<BinaryNode>(Tree::make<Leaf>(2.0), Tree::make<Leaf>(1.0)); + // Notice the template parameter true or false matches whether the walker expects to have a + // Tree& parameter first in the transport implementations. + double result = transport<false>(simple, nodeTransporter); + ASSERT_EQ(result, 3.0); + // This 'true' template means we expect the 'Tree&' argument to come first in all the + // 'transport()' implementations. 
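/*
 * Minimal sketch of the bottom-up fold performed by transport() on the 'simple'
 * tree above:
 *   transport(Leaf{2.0})            -> 2.0
 *   transport(Leaf{1.0})            -> 1.0
 *   transport(BinaryNode, 2.0, 1.0) -> 3.0
 * Children are reduced first and their results are passed to the parent visit.
 * The transport<true> call below performs the same fold through the overloads
 * that also take the Tree& argument.
 */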
+ result = transport<true>(simple, treeTransporter); + ASSERT_EQ(result, 3.0); + } + + { + Tree deeper = Tree::make<BinaryNode>( + Tree::make<BinaryNode>(Tree::make<Leaf>(2.0), Tree::make<Leaf>(1.0)), + Tree::make<BinaryNode>(Tree::make<Leaf>(2.0), Tree::make<Leaf>(1.0))); + double result = transport<false>(deeper, nodeTransporter); + ASSERT_EQ(result, 6.0); + result = transport<true>(deeper, treeTransporter); + ASSERT_EQ(result, 6.0); + } +} + +/** + * Prove out that the walking/visiting can hit the variadic NaryNode. + */ +TEST(PolyValueTest, SumTransportVariadic) { + NodeTransporter nodeTransporter; + TreeTransporter treeTransporter; + Tree naryDemoTree = Tree::make<NaryNode>( + std::vector<Tree>{Tree::make<Leaf>(6.0), + Tree::make<Leaf>(5.0), + Tree::make<NaryNode>(std::vector<Tree>{ + Tree::make<Leaf>(4.0), Tree::make<Leaf>(3.0), Tree::make<Leaf>(2.0)}), + Tree::make<Leaf>(1.0)}); + + double result = transport<false>(naryDemoTree, nodeTransporter); + ASSERT_EQ(result, 21.0); + result = transport<true>(naryDemoTree, treeTransporter); + ASSERT_EQ(result, 21.0); +} + +TEST(PolyValueTest, SumTransportAtLeast2Children) { + NodeTransporter nodeTransporter; + TreeTransporter treeTransporter; + Tree demoTree = Tree::make<AtLeastBinaryNode>( + std::vector<Tree>{Tree::make<Leaf>(7.0), Tree::make<Leaf>(6.0)}, + Tree::make<Leaf>(5.0), + Tree::make<AtLeastBinaryNode>( + std::vector<Tree>{Tree::make<Leaf>(4.0), Tree::make<Leaf>(3.0)}, + Tree::make<Leaf>(2.0), + Tree::make<Leaf>(1.0))); + double result = transport<false>(demoTree, nodeTransporter); + ASSERT_EQ(result, 28.0); + result = transport<true>(demoTree, treeTransporter); + ASSERT_EQ(result, 28.0); +} + +/** + * A visitor of the tree like those above but which takes const references so is forbidden from + * modifying the tree or nodes. + * + * This visitor creates a copy of the tree but with the values at the leaves doubled. + */ +class ConstTransporterCopyAndDouble { +public: + Tree transport(const Leaf& leaf) { + return Tree::make<Leaf>(2 * leaf.x); + } + Tree transport(const BinaryNode& node, Tree child0, Tree child1) { + return Tree::make<BinaryNode>(std::move(child0), std::move(child1)); + } + Tree transport(const NaryNode& node, std::vector<Tree> children) { + return Tree::make<NaryNode>(std::move(children)); + } + Tree transport(const AtLeastBinaryNode& node, + std::vector<Tree> children, + Tree child0, + Tree child1) { + return Tree::make<AtLeastBinaryNode>( + std::move(children), std::move(child0), std::move(child1)); + } + + // Add all the same walkers with the optional 'tree' argument. Note this is also const. + Tree transport(const Tree& tree, const Leaf& leaf) { + return Tree::make<Leaf>(2 * leaf.x); + } + Tree transport(const Tree& tree, const BinaryNode& node, Tree child0, Tree child1) { + return Tree::make<BinaryNode>(std::move(child0), std::move(child1)); + } + Tree transport(const Tree& tree, const NaryNode& node, std::vector<Tree> children) { + return Tree::make<NaryNode>(std::move(children)); + } + Tree transport(const Tree& tree, + const AtLeastBinaryNode& node, + std::vector<Tree> children, + Tree child0, + Tree child1) { + return Tree::make<AtLeastBinaryNode>( + std::move(children), std::move(child0), std::move(child1)); + } +}; + +TEST(PolyValueTest, CopyAndDoubleTreeConst) { + // Test that we can create a copy of a tree and walk with a const transporter to provide extra + // proof that it's actually a deep copy. 
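/*
 * Sketch of the transformation verified below: the const transporter rebuilds the
 * tree bottom-up with doubled leaf values,
 *   BinaryNode(Leaf 2.0, Leaf 1.0)  ->  BinaryNode(Leaf 4.0, Leaf 2.0)
 * while the original (const) input tree is left untouched.
 */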
+ ConstTransporterCopyAndDouble transporter; + { + const Tree simple = Tree::make<BinaryNode>(Tree::make<Leaf>(2.0), Tree::make<Leaf>(1.0)); + // Notice 'simple' is const. + Tree result = transport<false>(simple, transporter); + BinaryNode* newRoot = result.cast<BinaryNode>(); + ASSERT(newRoot); + Leaf* newLeafLeft = newRoot->get<0>().cast<Leaf>(); + ASSERT(newLeafLeft); + ASSERT_EQ(newLeafLeft->x, 4.0); + + Leaf* newLeafRight = newRoot->get<1>().cast<Leaf>(); + ASSERT(newLeafRight); + ASSERT_EQ(newLeafRight->x, 2.0); + } + { + // Do the same test but walk with the tree reference (pass 'true' to transport). + const Tree simple = Tree::make<BinaryNode>(Tree::make<Leaf>(2.0), Tree::make<Leaf>(1.0)); + // Notice 'simple' is const. + Tree result = transport<true>(simple, transporter); + BinaryNode* newRoot = result.cast<BinaryNode>(); + ASSERT(newRoot); + Leaf* newLeafLeft = newRoot->get<0>().cast<Leaf>(); + ASSERT(newLeafLeft); + ASSERT_EQ(newLeafLeft->x, 4.0); + + Leaf* newLeafRight = newRoot->get<1>().cast<Leaf>(); + ASSERT(newLeafRight); + ASSERT_EQ(newLeafRight->x, 2.0); + } +} + +/** + * A walker which accumulates all nodes into a std::set to demonstrate which nodes are visited. + * + * The order of the visitation is not guaranteed, except that we visit "bottom-up" so leaves must + * happen before parents. This much must be true since the API to visit a node depends on the + * results of its children being pre-computed. + */ +class AccumulateToSetTransporter { +public: + std::set<double> transport(Leaf& leaf) { + return {leaf.x}; + } + + std::set<double> transport(BinaryNode& node, + std::set<double> visitedChild0, + std::set<double> visitedChild1) { + // 'visistedChild0' and 'visitedChild1' represent the accumulated results of their visited + // numbers. Here we just merge the two. + std::set<double> merged; + std::merge(visitedChild0.begin(), + visitedChild0.end(), + visitedChild1.begin(), + visitedChild1.end(), + std::inserter(merged, merged.begin())); + return merged; + } + + std::set<double> transport(NaryNode& node, std::vector<std::set<double>> childrenVisitedSets) { + return std::accumulate(childrenVisitedSets.begin(), + childrenVisitedSets.end(), + std::set<double>{}, + [](auto&& visited1, auto&& visited2) { + std::set<double> merged; + std::merge(visited1.begin(), + visited1.end(), + visited2.begin(), + visited2.end(), + std::inserter(merged, merged.begin())); + return merged; + }); + } + + std::set<double> transport(AtLeastBinaryNode& node, + std::vector<std::set<double>> childrenVisitedSets, + std::set<double> visitedChild0, + std::set<double> visitedChild1) { + std::set<double> merged; + std::merge(visitedChild0.begin(), + visitedChild0.end(), + visitedChild1.begin(), + visitedChild1.end(), + std::inserter(merged, merged.begin())); + + return std::accumulate(childrenVisitedSets.begin(), + childrenVisitedSets.end(), + merged, + [](auto&& visited1, auto&& visited2) { + std::set<double> merged; + std::merge(visited1.begin(), + visited1.end(), + visited2.begin(), + visited2.end(), + std::inserter(merged, merged.begin())); + return merged; + }); + } +}; + +/** + * Here we see a test which walks all the various types of nodes at once, and in this case + * accumulates into a std::set any visited leaves. 
+ */ +TEST(PolyValueTest, AccumulateAllDoubles) { + AccumulateToSetTransporter nodeTransporter; + + { + Tree simple = Tree::make<AtLeastBinaryNode>( + std::vector<Tree>{Tree::make<Leaf>(4.0)}, + Tree::make<Leaf>(3.0), + Tree::make<BinaryNode>(Tree::make<Leaf>(2.0), + Tree::make<NaryNode>(std::vector<Tree>{Tree::make<Leaf>(1.0)}))); + std::set<double> result = transport<false>(simple, nodeTransporter); + ASSERT_EQ(result.size(), 4UL); + ASSERT_EQ(result.count(1.0), 1UL); + ASSERT_EQ(result.count(2.0), 1UL); + ASSERT_EQ(result.count(3.0), 1UL); + ASSERT_EQ(result.count(4.0), 1UL); + } + { + Tree complex = Tree::make<AtLeastBinaryNode>( + std::vector<Tree>{Tree::make<Leaf>(1.0), Tree::make<Leaf>(2.0)}, + Tree::make<Leaf>(3.0), + Tree::make<BinaryNode>( + Tree::make<Leaf>(4.0), + Tree::make<NaryNode>(std::vector<Tree>{ + Tree::make<Leaf>(5.0), + Tree::make<BinaryNode>(Tree::make<Leaf>(6.0), Tree::make<Leaf>(7.0))}))); + std::set<double> result = transport<false>(complex, nodeTransporter); + ASSERT_EQ(result.size(), 7UL); + ASSERT_EQ(result.count(1.0), 1UL); + ASSERT_EQ(result.count(2.0), 1UL); + ASSERT_EQ(result.count(3.0), 1UL); + ASSERT_EQ(result.count(4.0), 1UL); + ASSERT_EQ(result.count(5.0), 1UL); + ASSERT_EQ(result.count(6.0), 1UL); + ASSERT_EQ(result.count(7.0), 1UL); + } +} + + +/** + * A walker which accepts an extra 'multiplier' argument to each transport call. + */ +class NodeTransporterWithExtraArg { +public: + double transport(Leaf& leaf, double multiplier) { + return leaf.x * multiplier; + } + double transport(BinaryNode& node, double multiplier, double child0, double child1) { + return child0 + + child1; // No need to apply multiplier here, would be applied in the children already. + } + double transport(NaryNode& node, double multiplier, std::vector<double> children) { + return std::accumulate(children.begin(), children.end(), 0); + } + double transport(AtLeastBinaryNode& node, + double multiplier, + std::vector<double> children, + double child0, + double child1) { + return child0 + child1 + std::accumulate(children.begin(), children.end(), 0); + } +}; + +TEST(PolyValueTest, TransporterWithAnExtrArgument) { + NodeTransporterWithExtraArg nodeTransporter; + + Tree simple = Tree::make<AtLeastBinaryNode>( + std::vector<Tree>{Tree::make<Leaf>(4.0)}, + Tree::make<Leaf>(3.0), + Tree::make<BinaryNode>(Tree::make<Leaf>(2.0), + Tree::make<NaryNode>(std::vector<Tree>{Tree::make<Leaf>(1.0)}))); + double result = transport<false>(simple, nodeTransporter, 2.0); + ASSERT_EQ(result, 20.0); +} + +/** + * A simple walker which trackes whether it has seen a zero. While the task is simple, this walker + * demosntrates: + * - A walker with state attached ('iHaveSeenAZero'). Note it could be done without tracking state + * also. + * - The capability of 'transport' to return void. + * - You can add a templated 'transport()' to avoid needing to fill in each and every instantiation + * for the PolyValue. + */ +class TemplatedNodeTransporterWithContext { +public: + bool iHaveSeenAZero = false; + + void transport(Leaf& leaf) { + if (leaf.x == 0.0) { + iHaveSeenAZero = true; + } + } + + /** + * Template to handle all other cases - we don't care or need to do anything here, so we knock + * out all the other required implementations at once with this template. + */ + template <typename T, typename... Args> + void transport(T&& node, Args&&... 
args) { + return; + } +}; + +TEST(PolyValueTest, TransporterTrackingState) { + TemplatedNodeTransporterWithContext templatedNodeTransporter; + + Tree noZero = Tree::make<AtLeastBinaryNode>( + std::vector<Tree>{Tree::make<Leaf>(4.0)}, + Tree::make<Leaf>(3.0), + Tree::make<BinaryNode>(Tree::make<Leaf>(2.0), + Tree::make<NaryNode>(std::vector<Tree>{Tree::make<Leaf>(1.0)}))); + transport<false>(noZero, templatedNodeTransporter); + ASSERT_EQ(templatedNodeTransporter.iHaveSeenAZero, false); + + Tree yesZero = Tree::make<AtLeastBinaryNode>( + std::vector<Tree>{Tree::make<Leaf>(3.0)}, + Tree::make<Leaf>(2.0), + Tree::make<BinaryNode>(Tree::make<Leaf>(1.0), + Tree::make<NaryNode>(std::vector<Tree>{Tree::make<Leaf>(0.0)}))); + transport<false>(yesZero, templatedNodeTransporter); + ASSERT_EQ(templatedNodeTransporter.iHaveSeenAZero, true); +} + +/** + * A walker demonstrating the 'prepare()' API which tracks the depth and weights things deeper in + * the tree at factors of 10 higher. So the top level is worth 1x, second level 10x, third level + * 100x, etc. + */ +class NodeTransporterTrackingDepth { + int _depthMultiplier = 1; + +public: + double transport(Leaf& leaf) { + return leaf.x * _depthMultiplier; + } + + void prepare(Leaf&) { + // Noop. Just here to prevent from yet another 10x multiplication if we were to fall into + // the generic 'prepare()'. + } + + /** + * 'prepare()' is called as we descend the tree before we walk/visit the children. + */ + template <typename T, typename... Args> + void prepare(T&& node, Args&&... args) { + _depthMultiplier *= 10; + } + + double transport(BinaryNode& node, double child0, double child1) { + _depthMultiplier /= 10; + return child0 + child1; + } + double transport(NaryNode& node, std::vector<double> children) { + _depthMultiplier /= 10; + return std::accumulate(children.begin(), children.end(), 0); + } + double transport(AtLeastBinaryNode& node, + std::vector<double> children, + double child0, + double child1) { + _depthMultiplier /= 10; + return child0 + child1 + std::accumulate(children.begin(), children.end(), 0); + } +}; + +TEST(PolyValueTest, TransporterUsingPrepare) { + NodeTransporterTrackingDepth nodeTransporter; + + Tree demoTree = Tree::make<AtLeastBinaryNode>( + std::vector<Tree>{Tree::make<Leaf>(4.0)}, + Tree::make<Leaf>(3.0), + Tree::make<BinaryNode>(Tree::make<Leaf>(2.0), + Tree::make<NaryNode>(std::vector<Tree>{Tree::make<Leaf>(1.0)}))); + const double result = transport<false>(demoTree, nodeTransporter); + /* + demoTree + 1x level: root + / | \ + 10x level: 4 3 binary + / \ + 100x level: 2 nary + \ + 1000x level: 1 + */ + ASSERT_EQ(result, 1270.0); +} + +class NodeWalkerIsLeaf { +public: + bool walk(Leaf& leaf) { + return true; + } + + bool walk(BinaryNode& node, Tree& leftChild, Tree& rightChild) { + return false; + } + + bool walk(AtLeastBinaryNode& node, + std::vector<Tree>& extraChildren, + Tree& leftChild, + Tree& rightChild) { + return false; + } + + bool walk(NaryNode& node, std::vector<Tree>& children) { + return false; + } +}; + +TEST(PolyValueTest, WalkerBasic) { + NodeWalkerIsLeaf walker; + auto tree = Tree::make<BinaryNode>(Tree::make<Leaf>(1.0), Tree::make<Leaf>(2.0)); + ASSERT(!walk<false>(tree, walker)); + ASSERT(walk<false>(tree.cast<BinaryNode>()->get<0>(), walker)); + ASSERT(walk<false>(tree.cast<BinaryNode>()->get<1>(), walker)); +} + +} // namespace +} // namespace mongo::optimizer::algebra diff --git a/src/mongo/db/query/optimizer/algebra/operator.h b/src/mongo/db/query/optimizer/algebra/operator.h new file mode 100644 
index 00000000000..fb6dbc4d474 --- /dev/null +++ b/src/mongo/db/query/optimizer/algebra/operator.h @@ -0,0 +1,341 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#pragma once + +#include <vector> + +#include "mongo/db/query/optimizer/algebra/polyvalue.h" + +namespace mongo::optimizer { +namespace algebra { + +template <typename T, int S> +struct OpNodeStorage { + T _nodes[S]; + + template <typename... Ts> + OpNodeStorage(Ts&&... vals) : _nodes{std::forward<Ts>(vals)...} {} +}; + +template <typename T> +struct OpNodeStorage<T, 0> {}; + +/*=====----- + * + * Arity of operator can be: + * 1. statically known - A, A, A, ... + * 2. dynamic prefix with optional statically know - vector<A>, A, A, A, ... + * + * Denotations map A to some B. + * So static arity <A,A,A> is mapped to <B,B,B>. + * Similarly, arity <vector<A>,A> is mapped to <vector<B>,B> + * + * There is a wrinkle when B is a reference (if allowed at all) + * Arity <vector<A>, A, A> is mapped to <vector<B>&, B&, B&> - note that the reference is lifted + * outside of the vector. + * + */ +template <typename Slot, typename Derived, int Arity> +class OpSpecificArity : public OpNodeStorage<Slot, Arity> { + using Base = OpNodeStorage<Slot, Arity>; + +public: + template <typename... Ts> + OpSpecificArity(Ts&&... vals) : Base({std::forward<Ts>(vals)...}) { + static_assert(sizeof...(Ts) == Arity, "constructor paramaters do not match"); + } + + template <int I, std::enable_if_t<(I >= 0 && I < Arity), int> = 0> + auto& get() noexcept { + return this->_nodes[I]; + } + + template <int I, std::enable_if_t<(I >= 0 && I < Arity), int> = 0> + const auto& get() const noexcept { + return this->_nodes[I]; + } +}; +/*=====----- + * + * Operator with dynamic arity + * + */ +template <typename Slot, typename Derived, int Arity> +class OpSpecificDynamicArity : public OpSpecificArity<Slot, Derived, Arity> { + using Base = OpSpecificArity<Slot, Derived, Arity>; + + std::vector<Slot> _dyNodes; + +public: + template <typename... Ts> + OpSpecificDynamicArity(std::vector<Slot>&& nodes, Ts&&... 
vals) + : Base({std::forward<Ts>(vals)...}), _dyNodes(std::move(nodes)) {} + + auto& nodes() { + return _dyNodes; + } + const auto& nodes() const { + return _dyNodes; + } +}; + +/*=====----- + * + * Semantic transport interface + * + */ +namespace detail { +template <typename D, typename T, typename... Args> +using call_prepare_t = + decltype(std::declval<D>().prepare(std::declval<T&>(), std::declval<Args>()...)); + +template <typename N, typename D, typename T, typename... Args> +using call_prepare_slot_t = decltype( + std::declval<D>().prepare(std::declval<N&>(), std::declval<T&>(), std::declval<Args>()...)); + +template <typename Void, template <class...> class Op, class... Args> +struct has_prepare : std::false_type {}; + +template <template <class...> class Op, class... Args> +struct has_prepare<std::void_t<Op<Args...>>, Op, Args...> : std::true_type {}; + +template <bool withSlot, typename N, typename D, typename T, typename... Args> +inline constexpr auto has_prepare_v = + std::conditional_t<withSlot, + has_prepare<void, call_prepare_slot_t, N, D, T, Args...>, + has_prepare<void, call_prepare_t, D, T, Args...>>::value; + +template <typename Slot, typename Derived, int Arity> +inline constexpr int get_arity(const OpSpecificArity<Slot, Derived, Arity>*) { + return Arity; +} + +template <typename Slot, typename Derived, int Arity> +inline constexpr bool is_dynamic(const OpSpecificArity<Slot, Derived, Arity>*) { + return false; +} + +template <typename Slot, typename Derived, int Arity> +inline constexpr bool is_dynamic(const OpSpecificDynamicArity<Slot, Derived, Arity>*) { + return true; +} + +template <typename T> +using OpConcreteType = typename std::remove_reference_t<T>::template get_t<0>; +} // namespace detail + +template <typename D, bool withSlot> +class OpTransporter { + D& _domain; + + template <typename T, bool B, typename... Args> + struct Deducer {}; + template <typename T, typename... Args> + struct Deducer<T, true, Args...> { + using type = + decltype(std::declval<D>().transport(std::declval<T>(), + std::declval<detail::OpConcreteType<T>&>(), + std::declval<Args>()...)); + }; + template <typename T, typename... Args> + struct Deducer<T, false, Args...> { + using type = decltype(std::declval<D>().transport( + std::declval<detail::OpConcreteType<T>&>(), std::declval<Args>()...)); + }; + template <typename T, typename... Args> + using deduced_t = typename Deducer<T, withSlot, Args...>::type; + + template <typename N, typename T, typename... Ts> + auto transformStep(N&& slot, T&& op, Ts&&... args) { + if constexpr (withSlot) { + return _domain.transport( + std::forward<N>(slot), std::forward<T>(op), std::forward<Ts>(args)...); + } else { + return _domain.transport(std::forward<T>(op), std::forward<Ts>(args)...); + } + } + + template <typename N, typename T, typename... Args, size_t... I> + auto transportUnpack(N&& slot, T&& op, std::index_sequence<I...>, Args&&... args) { + return transformStep(std::forward<N>(slot), + std::forward<T>(op), + std::forward<Args>(args)..., + op.template get<I>().visit(*this, std::forward<Args>(args)...)...); + } + template <typename N, typename T, typename... Args, size_t... I> + auto transportDynamicUnpack(N&& slot, T&& op, std::index_sequence<I...>, Args&&... 
args) { + std::vector<decltype(slot.visit(*this, std::forward<Args>(args)...))> v; + for (auto& node : op.nodes()) { + v.emplace_back(node.visit(*this, std::forward<Args>(args)...)); + } + return transformStep(std::forward<N>(slot), + std::forward<T>(op), + std::forward<Args>(args)..., + std::move(v), + op.template get<I>().visit(*this, std::forward<Args>(args)...)...); + } + template <typename N, typename T, typename... Args, size_t... I> + void transportUnpackVoid(N&& slot, T&& op, std::index_sequence<I...>, Args&&... args) { + (op.template get<I>().visit(*this, std::forward<Args>(args)...), ...); + return transformStep(std::forward<N>(slot), + std::forward<T>(op), + std::forward<Args>(args)..., + op.template get<I>()...); + } + template <typename N, typename T, typename... Args, size_t... I> + void transportDynamicUnpackVoid(N&& slot, T&& op, std::index_sequence<I...>, Args&&... args) { + for (auto& node : op.nodes()) { + node.visit(*this, std::forward<Args>(args)...); + } + (op.template get<I>().visit(*this, std::forward<Args>(args)...), ...); + return transformStep(std::forward<N>(slot), + std::forward<T>(op), + std::forward<Args>(args)..., + op.nodes(), + op.template get<I>()...); + } + +public: + OpTransporter(D& domain) : _domain(domain) {} + + template <typename N, typename T, typename... Args, typename R = deduced_t<N, Args...>> + R operator()(N&& slot, T&& op, Args&&... args) { + // N is either `PolyValue<Ts...>&` or `const PolyValue<Ts...>&` i.e. reference + // T is either `A&` or `const A&` where A is one of Ts + using type = std::remove_reference_t<T>; + + constexpr int arity = detail::get_arity(static_cast<type*>(nullptr)); + constexpr bool is_dynamic = detail::is_dynamic(static_cast<type*>(nullptr)); + + if constexpr (detail::has_prepare_v<withSlot, N, D, type, Args...>) { + if constexpr (withSlot) { + _domain.prepare( + std::forward<N>(slot), std::forward<T>(op), std::forward<Args>(args)...); + } else { + _domain.prepare(std::forward<T>(op), std::forward<Args>(args)...); + } + } + + if constexpr (is_dynamic) { + if constexpr (std::is_same_v<R, void>) { + return transportDynamicUnpackVoid(std::forward<N>(slot), + std::forward<T>(op), + std::make_index_sequence<arity>{}, + std::forward<Args>(args)...); + } else { + return transportDynamicUnpack(std::forward<N>(slot), + std::forward<T>(op), + std::make_index_sequence<arity>{}, + std::forward<Args>(args)...); + } + } else { + if constexpr (std::is_same_v<R, void>) { + return transportUnpackVoid(std::forward<N>(slot), + std::forward<T>(op), + std::make_index_sequence<arity>{}, + std::forward<Args>(args)...); + } else { + return transportUnpack(std::forward<N>(slot), + std::forward<T>(op), + std::make_index_sequence<arity>{}, + std::forward<Args>(args)...); + } + } + } +}; + +template <typename D, bool withSlot> +class OpWalker { + D& _domain; + + template <typename N, typename T, typename... Ts> + auto walkStep(N&& slot, T&& op, Ts&&... args) { + if constexpr (withSlot) { + return _domain.walk( + std::forward<N>(slot), std::forward<T>(op), std::forward<Ts>(args)...); + } else { + return _domain.walk(std::forward<T>(op), std::forward<Ts>(args)...); + } + } + + template <typename N, typename T, typename... Args, size_t... I> + auto walkUnpack(N&& slot, T&& op, std::index_sequence<I...>, Args&&... args) { + return walkStep(std::forward<N>(slot), + std::forward<T>(op), + std::forward<Args>(args)..., + op.template get<I>()...); + } + template <typename N, typename T, typename... Args, size_t... 
I> + auto walkDynamicUnpack(N&& slot, T&& op, std::index_sequence<I...>, Args&&... args) { + return walkStep(std::forward<N>(slot), + std::forward<T>(op), + std::forward<Args>(args)..., + op.nodes(), + op.template get<I>()...); + } + +public: + OpWalker(D& domain) : _domain(domain) {} + + template <typename N, typename T, typename... Args> + auto operator()(N&& slot, T&& op, Args&&... args) { + // N is either `PolyValue<Ts...>&` or `const PolyValue<Ts...>&` i.e. reference + // T is either `A&` or `const A&` where A is one of Ts + using type = std::remove_reference_t<T>; + + constexpr int arity = detail::get_arity(static_cast<type*>(nullptr)); + constexpr bool is_dynamic = detail::is_dynamic(static_cast<type*>(nullptr)); + + if constexpr (is_dynamic) { + return walkDynamicUnpack(std::forward<N>(slot), + std::forward<T>(op), + std::make_index_sequence<arity>{}, + std::forward<Args>(args)...); + } else { + return walkUnpack(std::forward<N>(slot), + std::forward<T>(op), + std::make_index_sequence<arity>{}, + std::forward<Args>(args)...); + } + } +}; + +template <bool withSlot = false, typename D, typename N, typename... Args> +auto transport(N&& node, D& domain, Args&&... args) { + return node.visit(OpTransporter<D, withSlot>{domain}, std::forward<Args>(args)...); +} + +template <bool withSlot = false, typename D, typename N, typename... Args> +auto walk(N&& node, D& domain, Args&&... args) { + return node.visit(OpWalker<D, withSlot>{domain}, std::forward<Args>(args)...); +} + +} // namespace algebra +} // namespace mongo::optimizer diff --git a/src/mongo/db/query/optimizer/algebra/polyvalue.h b/src/mongo/db/query/optimizer/algebra/polyvalue.h new file mode 100644 index 00000000000..63f5965c50c --- /dev/null +++ b/src/mongo/db/query/optimizer/algebra/polyvalue.h @@ -0,0 +1,541 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#pragma once + +#include <array> +#include <stdexcept> +#include <type_traits> + +namespace mongo::optimizer { +namespace algebra { +namespace detail { + +template <typename T, typename... 
Args> +inline constexpr bool is_one_of_v = std::disjunction_v<std::is_same<T, Args>...>; + +template <typename T, typename... Args> +inline constexpr bool is_one_of_f() { + return is_one_of_v<T, Args...>; +} + +template <typename... Args> +struct is_unique_t : std::true_type {}; + +template <typename H, typename... T> +struct is_unique_t<H, T...> + : std::bool_constant<!is_one_of_f<H, T...>() && is_unique_t<T...>::value> {}; + +template <typename... Args> +inline constexpr bool is_unique_v = is_unique_t<Args...>::value; + +// Given the type T, find its index in Ts +template <typename T, typename... Ts> +static inline constexpr int find_index() { + static_assert(detail::is_unique_v<Ts...>, "Types must be unique"); + constexpr bool matchVector[] = {std::is_same<T, Ts>::value...}; + + for (int index = 0; index < static_cast<int>(sizeof...(Ts)); ++index) { + if (matchVector[index]) { + return index; + } + } + + return -1; +} + +template <int N, typename T, typename... Ts> +struct get_type_by_index_impl { + using type = typename get_type_by_index_impl<N - 1, Ts...>::type; +}; +template <typename T, typename... Ts> +struct get_type_by_index_impl<0, T, Ts...> { + using type = T; +}; + +// Given the index I, return the type from Ts +template <int I, typename... Ts> +using get_type_by_index = typename get_type_by_index_impl<I, Ts...>::type; + +} // namespace detail + +/*=====----- + * + * The overload trick to construct visitors from lambdas. + * + */ +template <class... Ts> +struct overload : Ts... { + using Ts::operator()...; +}; +template <class... Ts> +overload(Ts...)->overload<Ts...>; + +/*=====----- + * + * Forward declarations + * + */ +template <typename... Ts> +class PolyValue; + +template <typename T, typename... Ts> +class ControlBlockVTable; + +/*=====----- + * + * The base control block that PolyValue holds. + * + * It does not contain anything else but the runtime tag. + * + */ +template <typename... Ts> +class ControlBlock { + const int _tag; + +protected: + ControlBlock(int tag) noexcept : _tag(tag) {} + +public: + auto getRuntimeTag() const noexcept { + return _tag; + } +}; + +/*=====----- + * + * The concrete control block VTable generator. + * + * It must be empty as PolyValue derives from the generators + * and we want EBO to kick in. + * + */ +template <typename T, typename... Ts> +class ControlBlockVTable { +protected: + static constexpr int _staticTag = detail::find_index<T, Ts...>(); + static_assert(_staticTag != -1, "Type must be on the list"); + + using AbstractType = ControlBlock<Ts...>; + + /*=====----- + * + * The concrete control block for every type T of Ts. + * + * It derives from the ControlBlock. All methods are private and only + * the friend class ControlBlockVTable can call them. + * + */ + class ConcreteType : public AbstractType { + T _t; + + public: + template <typename... Args> + ConcreteType(Args&&... args) : AbstractType(_staticTag), _t(std::forward<Args>(args)...) {} + + const T* getPtr() const noexcept { + return &_t; + } + + T* getPtr() noexcept { + return &_t; + } + }; + + static constexpr auto concrete(AbstractType* block) noexcept { + return static_cast<ConcreteType*>(block); + } + + static constexpr auto concrete(const AbstractType* block) noexcept { + return static_cast<const ConcreteType*>(block); + } + +public: + template <typename... Args> + static AbstractType* make(Args&&...
args) { + return new ConcreteType(std::forward<Args>(args)...); + } + + static AbstractType* clone(const AbstractType* block) { + return new ConcreteType(*concrete(block)); + } + + static void destroy(AbstractType* block) noexcept { + delete concrete(block); + } + + static bool compareEq(AbstractType* blockLhs, AbstractType* blockRhs) noexcept { + if (blockLhs->getRuntimeTag() == blockRhs->getRuntimeTag()) { + return *castConst<T>(blockLhs) == *castConst<T>(blockRhs); + } + return false; + } + + template <typename U> + static constexpr bool is_v = std::is_base_of_v<U, T>; + + template <typename U> + static U* cast(AbstractType* block) noexcept { + if constexpr (is_v<U>) { + return static_cast<U*>(concrete(block)->getPtr()); + } else { + // gcc bug 81676 + (void)block; + return nullptr; + } + } + + template <typename U> + static const U* castConst(const AbstractType* block) noexcept { + if constexpr (is_v<U>) { + return static_cast<const U*>(concrete(block)->getPtr()); + } else { + // gcc bug 81676 + (void)block; + return nullptr; + } + } + + template <typename V, typename N, typename... Args> + static auto visit(V&& v, N& holder, AbstractType* block, Args&&... args) { + return v(holder, *cast<T>(block), std::forward<Args>(args)...); + } + + template <typename V, typename N, typename... Args> + static auto visitConst(V&& v, const N& holder, const AbstractType* block, Args&&... args) { + return v(holder, *castConst<T>(block), std::forward<Args>(args)...); + } +}; + +/*=====----- + * + * This is a variation on variant and polymorphic value theme. + * + * A tag based dispatch + * + * Supported operations: + * - construction + * - destruction + * - clone a = b; + * - cast a.cast<T>() + * - multi-method cast to common base a.cast<B>() + * - multi-method visit + */ +template <typename... Ts> +class PolyValue : private ControlBlockVTable<Ts, Ts...>... 
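/*
 * Minimal usage sketch, following the patterns exercised in algebra_test.cpp
 * (no additional API is implied; 'visitor' stands for any class providing the
 * matching transport() overloads):
 *
 *   using Tree = PolyValue<Leaf, BinaryNode>;
 *   Tree t = Tree::make<BinaryNode>(Tree::make<Leaf>(2.0), Tree::make<Leaf>(1.0));
 *   BinaryNode* b = t.cast<BinaryNode>();               // tag-checked downcast, nullptr on mismatch
 *   double sum = algebra::transport<false>(t, visitor); // bottom-up fold over the tree
 */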
{ +public: + using key_type = int; + +private: + static_assert(detail::is_unique_v<Ts...>, "Types must be unique"); + static_assert(std::conjunction_v<std::is_empty<ControlBlockVTable<Ts, Ts...>>...>, + "VTable base classes must be empty"); + + ControlBlock<Ts...>* _object{nullptr}; + + PolyValue(ControlBlock<Ts...>* object) noexcept : _object(object) {} + + auto tag() const noexcept { + return _object->getRuntimeTag(); + } + + static void check(const ControlBlock<Ts...>* object) { + if (!object) { + throw std::logic_error("PolyValue is empty"); + } + } + + static void destroy(ControlBlock<Ts...>* object) noexcept { + static constexpr std::array destroyTbl = {&ControlBlockVTable<Ts, Ts...>::destroy...}; + + destroyTbl[object->getRuntimeTag()](object); + } + + template <typename T> + static T* cast(ControlBlock<Ts...>* object) { + check(object); + static constexpr std::array castTbl = {&ControlBlockVTable<Ts, Ts...>::template cast<T>...}; + return castTbl[object->getRuntimeTag()](object); + } + + template <typename T> + static const T* castConst(ControlBlock<Ts...>* object) { + check(object); + static constexpr std::array castTbl = { + &ControlBlockVTable<Ts, Ts...>::template castConst<T>...}; + return castTbl[object->getRuntimeTag()](object); + } + + template <typename T> + static bool is(ControlBlock<Ts...>* object) { + check(object); + static constexpr std::array isTbl = {ControlBlockVTable<Ts, Ts...>::template is_v<T>...}; + return isTbl[object->getRuntimeTag()]; + } + + class CompareHelper { + ControlBlock<Ts...>* _object{nullptr}; + + auto tag() const noexcept { + return _object->getRuntimeTag(); + } + + public: + CompareHelper() = default; + CompareHelper(ControlBlock<Ts...>* object) : _object(object) {} + + bool operator==(const CompareHelper& rhs) const noexcept { + static constexpr std::array cmp = {ControlBlockVTable<Ts, Ts...>::compareEq...}; + return cmp[tag()](_object, rhs._object); + } + }; + + class Reference { + ControlBlock<Ts...>* _object{nullptr}; + + auto tag() const noexcept { + return _object->getRuntimeTag(); + } + + public: + Reference() = default; + Reference(ControlBlock<Ts...>* object) : _object(object) {} + + template <int I> + using get_t = detail::get_type_by_index<I, Ts...>; + + key_type tagOf() const { + check(_object); + + return tag(); + } + + + template <typename V, typename... Args> + auto visit(V&& v, Args&&... args) { + // unfortunately gcc rejects much nicer code, clang and msvc accept + // static constexpr std::array visitTbl = { &ControlBlockVTable<Ts, Ts...>::template + // visit<V>... }; + + using FunPtrType = decltype( + &ControlBlockVTable<get_t<0>, Ts...>::template visit<V, Reference, Args...>); + static constexpr FunPtrType visitTbl[] = { + &ControlBlockVTable<Ts, Ts...>::template visit<V, Reference, Args...>...}; + + check(_object); + return visitTbl[tag()](std::forward<V>(v), *this, _object, std::forward<Args>(args)...); + } + + template <typename V, typename... Args> + auto visit(V&& v, Args&&... args) const { + // unfortunately gcc rejects much nicer code, clang and msvc accept + // static constexpr std::array visitTbl = { &ControlBlockVTable<Ts, Ts...>::template + // visitConst<V>... 
}; + + using FunPtrType = decltype( + &ControlBlockVTable<get_t<0>, Ts...>::template visitConst<V, Reference, Args...>); + static constexpr FunPtrType visitTbl[] = { + &ControlBlockVTable<Ts, Ts...>::template visitConst<V, Reference, Args...>...}; + + check(_object); + return visitTbl[tag()](std::forward<V>(v), *this, _object, std::forward<Args>(args)...); + } + + template <typename T> + T* cast() { + return PolyValue<Ts...>::template cast<T>(_object); + } + + template <typename T> + const T* cast() const { + return PolyValue<Ts...>::template castConst<T>(_object); + } + + template <typename T> + bool is() const { + return PolyValue<Ts...>::template is<T>(_object); + } + + bool empty() const noexcept { + return !_object; + } + + void swap(Reference& other) noexcept { + std::swap(other._object, _object); + } + + // Compare references, not the objects themselves. + bool operator==(const Reference& rhs) const noexcept { + return _object == rhs._object; + } + + bool operator==(const PolyValue& rhs) const noexcept { + return rhs == (*this); + } + + auto hash() const noexcept { + return std::hash<const void*>{}(_object); + } + + auto follow() const { + return CompareHelper(_object); + } + + friend class PolyValue; + }; + +public: + using reference_type = Reference; + + template <typename T> + static constexpr key_type tagOf() { + return ControlBlockVTable<T, Ts...>::_staticTag; + } + + key_type tagOf() const { + check(_object); + + return tag(); + } + + PolyValue() = delete; + + PolyValue(const PolyValue& other) { + static constexpr std::array cloneTbl = {&ControlBlockVTable<Ts, Ts...>::clone...}; + if (other._object) { + _object = cloneTbl[other.tag()](other._object); + } + } + + PolyValue(const Reference& other) { + static constexpr std::array cloneTbl = {&ControlBlockVTable<Ts, Ts...>::clone...}; + if (other._object) { + _object = cloneTbl[other.tag()](other._object); + } + } + + PolyValue(PolyValue&& other) noexcept { + swap(other); + } + + ~PolyValue() noexcept { + if (_object) { + destroy(_object); + } + } + + PolyValue& operator=(PolyValue other) noexcept { + swap(other); + return *this; + } + + template <typename T, typename... Args> + static PolyValue make(Args&&... args) { + return PolyValue{ControlBlockVTable<T, Ts...>::make(std::forward<Args>(args)...)}; + } + + template <int I> + using get_t = detail::get_type_by_index<I, Ts...>; + + template <typename V, typename... Args> + auto visit(V&& v, Args&&... args) { + // unfortunately gcc rejects much nicer code, clang and msvc accept + // static constexpr std::array visitTbl = { &ControlBlockVTable<Ts, Ts...>::template + // visit<V>... }; + + using FunPtrType = + decltype(&ControlBlockVTable<get_t<0>, Ts...>::template visit<V, PolyValue, Args...>); + static constexpr FunPtrType visitTbl[] = { + &ControlBlockVTable<Ts, Ts...>::template visit<V, PolyValue, Args...>...}; + + check(_object); + return visitTbl[tag()](std::forward<V>(v), *this, _object, std::forward<Args>(args)...); + } + + template <typename V, typename... Args> + auto visit(V&& v, Args&&... args) const { + // unfortunately gcc rejects much nicer code, clang and msvc accept + // static constexpr std::array visitTbl = { &ControlBlockVTable<Ts, Ts...>::template + // visitConst<V>... 
}; + + using FunPtrType = decltype( + &ControlBlockVTable<get_t<0>, Ts...>::template visitConst<V, PolyValue, Args...>); + static constexpr FunPtrType visitTbl[] = { + &ControlBlockVTable<Ts, Ts...>::template visitConst<V, PolyValue, Args...>...}; + + check(_object); + return visitTbl[tag()](std::forward<V>(v), *this, _object, std::forward<Args>(args)...); + } + + template <typename T> + T* cast() { + return cast<T>(_object); + } + + template <typename T> + const T* cast() const { + return castConst<T>(_object); + } + + template <typename T> + bool is() const { + return is<T>(_object); + } + + bool empty() const noexcept { + return !_object; + } + + void swap(PolyValue& other) noexcept { + std::swap(other._object, _object); + } + + bool operator==(const PolyValue& rhs) const noexcept { + static constexpr std::array cmp = {ControlBlockVTable<Ts, Ts...>::compareEq...}; + return cmp[tag()](_object, rhs._object); + } + + bool operator==(const Reference& rhs) const noexcept { + static constexpr std::array cmp = {ControlBlockVTable<Ts, Ts...>::compareEq...}; + return cmp[tag()](_object, rhs._object); + } + + auto ref() { + check(_object); + return Reference(_object); + } + + auto ref() const { + check(_object); + return Reference(_object); + } +}; + +} // namespace algebra +} // namespace mongo::optimizer diff --git a/src/mongo/db/query/optimizer/bool_expression.h b/src/mongo/db/query/optimizer/bool_expression.h new file mode 100644 index 00000000000..bf00f907504 --- /dev/null +++ b/src/mongo/db/query/optimizer/bool_expression.h @@ -0,0 +1,140 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + + +#pragma once + +#include "boost/optional.hpp" +#include <vector> + +#include "mongo/db/query/optimizer/algebra/operator.h" +#include "mongo/db/query/optimizer/algebra/polyvalue.h" +#include "mongo/util/assert_util.h" + +namespace mongo::optimizer { + +/** + * Represents a generic boolean expression with arbitrarily nested conjunctions and disjunction + * elements. 
+ */ +template <class T> +struct BoolExpr { + class Atom; + class Conjunction; + class Disjunction; + + using Node = algebra::PolyValue<Atom, Conjunction, Disjunction>; + using NodeVector = std::vector<Node>; + + + class Atom final : public algebra::OpSpecificArity<Node, Atom, 0> { + using Base = algebra::OpSpecificArity<Node, Atom, 0>; + + public: + Atom(T expr) : Base(), _expr(std::move(expr)) {} + + bool operator==(const Atom& other) const { + return _expr == other._expr; + } + + const T& getExpr() const { + return _expr; + } + T& getExpr() { + return _expr; + } + + private: + T _expr; + }; + + class Conjunction final : public algebra::OpSpecificDynamicArity<Node, Conjunction, 0> { + using Base = algebra::OpSpecificDynamicArity<Node, Conjunction, 0>; + + public: + Conjunction(NodeVector children) : Base(std::move(children)) { + uassert(6624351, "Must have at least one child", !Base::nodes().empty()); + } + + bool operator==(const Conjunction& other) const { + return Base::nodes() == other.nodes(); + } + }; + + class Disjunction final : public algebra::OpSpecificDynamicArity<Node, Disjunction, 0> { + using Base = algebra::OpSpecificDynamicArity<Node, Disjunction, 0>; + + public: + Disjunction(NodeVector children) : Base(std::move(children)) { + uassert(6624301, "Must have at least one child", !Base::nodes().empty()); + } + + bool operator==(const Disjunction& other) const { + return Base::nodes() == other.nodes(); + } + }; + + + /** + * Utility functions. + */ + template <typename T1, typename... Args> + static auto make(Args&&... args) { + return Node::template make<T1>(std::forward<Args>(args)...); + } + + template <typename... Args> + static auto makeSeq(Args&&... args) { + NodeVector seq; + (seq.emplace_back(std::forward<Args>(args)), ...); + return seq; + } + + template <typename... Args> + static Node makeSingularDNF(Args&&... args) { + return make<Disjunction>( + makeSeq(make<Conjunction>(makeSeq(make<Atom>(T{std::forward<Args>(args)...}))))); + } + + static boost::optional<const T&> getSingularDNF(const Node& n) { + if (auto disjunction = n.template cast<Disjunction>(); + disjunction != nullptr && disjunction->nodes().size() == 1) { + if (auto conjunction = disjunction->nodes().front().template cast<Conjunction>(); + conjunction != nullptr && conjunction->nodes().size() == 1) { + if (auto atom = conjunction->nodes().front().template cast<Atom>(); + atom != nullptr) { + return {atom->getExpr()}; + } + } + } + return {}; + } +}; + +} // namespace mongo::optimizer diff --git a/src/mongo/db/query/optimizer/cascades/ce_heuristic.cpp b/src/mongo/db/query/optimizer/cascades/ce_heuristic.cpp new file mode 100644 index 00000000000..263fab109a5 --- /dev/null +++ b/src/mongo/db/query/optimizer/cascades/ce_heuristic.cpp @@ -0,0 +1,192 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. 
+ * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include "mongo/db/query/optimizer/cascades/ce_heuristic.h" +#include "mongo/db/query/optimizer/utils/memo_utils.h" + +namespace mongo::optimizer::cascades { + +using namespace properties; + +class CEHeuristicTransport { +public: + CEType transport(const ScanNode& node, CEType /*bindResult*/) { + // Default cardinality estimate. + const CEType metadataCE = _memo.getMetadata()._scanDefs.at(node.getScanDefName()).getCE(); + return (metadataCE < 0.0) ? 1000.00 : metadataCE; + } + + CEType transport(const ValueScanNode& node, CEType /*bindResult*/) { + return node.getArraySize(); + } + + CEType transport(const MemoLogicalDelegatorNode& node) { + return getPropertyConst<CardinalityEstimate>( + _memo.getGroup(node.getGroupId())._logicalProperties) + .getEstimate(); + } + + CEType transport(const FilterNode& node, CEType childResult, CEType /*exprResult*/) { + if (node.getFilter() == Constant::boolean(true)) { + // Trivially true filter. + return childResult; + } else if (node.getFilter() == Constant::boolean(false)) { + // Trivially false filter. + return 0.0; + } else { + // Estimate filter selectivity at 0.1. + return 0.1 * childResult; + } + } + + CEType transport(const EvaluationNode& node, CEType childResult, CEType /*exprResult*/) { + // Evaluations do not change cardinality. + return childResult; + } + + CEType transport(const SargableNode& node, + CEType /*childResult*/, + CEType /*bindsResult*/, + CEType /*refsResult*/) { + ABT lowered = node.getChild(); + for (const auto& [key, req] : node.getReqMap()) { + // TODO: consider issuing one filter node per interval. + lowerPartialSchemaRequirement(key, req, lowered); + } + return algebra::transport<false>(lowered, *this); + } + + CEType transport(const RIDIntersectNode& node, + CEType /*leftChildResult*/, + CEType /*rightChildResult*/) { + // CE for the group should already be derived via the underlying Filter or Evaluation + // logical nodes. + uasserted(6624038, "Should not be necessary to derive CE for RIDIntersectNode"); + } + + CEType transport(const BinaryJoinNode& node, + CEType /*leftChildResult*/, + CEType /*rightChildResult*/, + CEType /*exprResult*/) { + uasserted(6624039, "CE derivation not implemented."); + } + + CEType transport(const UnionNode& node, + std::vector<CEType> childResults, + CEType /*bindResult*/, + CEType /*refsResult*/) { + // Combine the CE of each child. + CEType result = 0; + for (auto&& child : childResults) { + result += child; + } + return result; + } + + CEType transport(const GroupByNode& node, + CEType childResult, + CEType /*bindAggResult*/, + CEType /*refsAggResult*/, + CEType /*bindGbResult*/, + CEType /*refsGbResult*/) { + // TODO: estimate number of groups. 
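+ // Note: the constants below are chosen so that the Local and Global phases compose back to
+ // the Complete estimate, i.e. 0.5 (Global) * 0.02 (Local) == 0.01 (Complete) of the child CE.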
+ switch (node.getType()) { + case GroupNodeType::Complete: + return 0.01 * childResult; + + // Global and Local selectivity should multiply to Complete selectivity. + case GroupNodeType::Global: + return 0.5 * childResult; + case GroupNodeType::Local: + return 0.02 * childResult; + + default: + MONGO_UNREACHABLE; + } + } + + CEType transport(const UnwindNode& node, + CEType childResult, + CEType /*bindResult*/, + CEType /*refsResult*/) { + // Estimate unwind selectivity at 10.0 + return 10.0 * childResult; + } + + CEType transport(const CollationNode& node, CEType childResult, CEType /*refsResult*/) { + // Collations do not change cardinality. + return childResult; + } + + CEType transport(const LimitSkipNode& node, CEType childResult) { + const auto limit = node.getProperty().getLimit(); + if (limit < childResult) { + return limit; + } + return childResult; + } + + CEType transport(const ExchangeNode& node, CEType childResult, CEType /*refsResult*/) { + // Exchanges do not change cardinality. + return childResult; + } + + CEType transport(const RootNode& node, CEType childResult, CEType /*refsResult*/) { + // Root node does not change cardinality. + return childResult; + } + + /** + * Other ABT types. + */ + template <typename T, typename... Ts> + CEType transport(const T& /*node*/, Ts&&...) { + static_assert(!canBeLogicalNode<T>(), "Logical node must implement its CE derivation."); + return 0.0; + } + + static CEType derive(const Memo& memo, const ABT::reference_type logicalNodeRef) { + CEHeuristicTransport instance(memo); + return algebra::transport<false>(logicalNodeRef, instance); + } + +private: + CEHeuristicTransport(const Memo& memo) : _memo(memo) {} + + // We don't own this. + const Memo& _memo; +}; + +CEType HeuristicCE::deriveCE(const Memo& memo, + const LogicalProps& /*logicalProps*/, + const ABT::reference_type logicalNodeRef) const { + return CEHeuristicTransport::derive(memo, logicalNodeRef); +} + +} // namespace mongo::optimizer::cascades diff --git a/src/mongo/db/query/optimizer/cascades/ce_heuristic.h b/src/mongo/db/query/optimizer/cascades/ce_heuristic.h new file mode 100644 index 00000000000..40547b90867 --- /dev/null +++ b/src/mongo/db/query/optimizer/cascades/ce_heuristic.h @@ -0,0 +1,48 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. 
If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#pragma once + +#include "mongo/db/query/optimizer/cascades/interfaces.h" + +namespace mongo::optimizer::cascades { + +/** + * Default cardinality estimation in the absence of statistics. + * Relies purely on heuristics. + * We currently do not use logical properties for heuristic ce. + */ +class HeuristicCE : public CEInterface { +public: + CEType deriveCE(const Memo& memo, + const properties::LogicalProps& /*logicalProps*/, + ABT::reference_type logicalNodeRef) const override final; +}; + +} // namespace mongo::optimizer::cascades diff --git a/src/mongo/db/query/optimizer/cascades/ce_hinted.cpp b/src/mongo/db/query/optimizer/cascades/ce_hinted.cpp new file mode 100644 index 00000000000..ed3f4ce9335 --- /dev/null +++ b/src/mongo/db/query/optimizer/cascades/ce_hinted.cpp @@ -0,0 +1,96 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include "mongo/db/query/optimizer/cascades/ce_hinted.h" +#include "mongo/db/query/optimizer/cascades/ce_heuristic.h" + +namespace mongo::optimizer::cascades { + +using namespace properties; + +class CEHintedTransport { +public: + CEType transport(const ABT& n, + const SargableNode& node, + const Memo& memo, + const LogicalProps& logicalProps, + CEType childResult, + CEType /*bindsResult*/, + CEType /*refsResult*/) { + CEType result = childResult; + for (const auto& [key, req] : node.getReqMap()) { + if (!isIntervalReqFullyOpenDNF(req.getIntervals())) { + auto it = _hints.find(key); + if (it != _hints.cend()) { + // Assume independence. + result *= it->second; + } + } + } + + return result; + } + + template <typename T, typename... Ts> + CEType transport(const ABT& n, + const T& /*node*/, + const Memo& memo, + const LogicalProps& logicalProps, + Ts&&...) 
{ + if (canBeLogicalNode<T>()) { + return _heuristicCE.deriveCE(memo, logicalProps, n.ref()); + } + return 0.0; + } + + static CEType derive(const Memo& memo, + const PartialSchemaSelHints& hints, + const LogicalProps& logicalProps, + const ABT::reference_type logicalNodeRef) { + CEHintedTransport instance(memo, hints); + return algebra::transport<true>(logicalNodeRef, instance, memo, logicalProps); + } + +private: + CEHintedTransport(const Memo& memo, const PartialSchemaKeyCE& hints) + : _heuristicCE(), _hints(hints) {} + + HeuristicCE _heuristicCE; + + // We don't own this. + const PartialSchemaSelHints& _hints; +}; + +CEType HintedCE::deriveCE(const Memo& memo, + const LogicalProps& logicalProps, + const ABT::reference_type logicalNodeRef) const { + return CEHintedTransport::derive(memo, _hints, logicalProps, logicalNodeRef); +} + +} // namespace mongo::optimizer::cascades diff --git a/src/mongo/db/query/optimizer/cascades/ce_hinted.h b/src/mongo/db/query/optimizer/cascades/ce_hinted.h new file mode 100644 index 00000000000..422bfcdfa73 --- /dev/null +++ b/src/mongo/db/query/optimizer/cascades/ce_hinted.h @@ -0,0 +1,56 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#pragma once + +#include "mongo/db/query/optimizer/cascades/interfaces.h" + +namespace mongo::optimizer::cascades { + +using PartialSchemaSelHints = + std::map<PartialSchemaKey, SelectivityType, PartialSchemaKeyLessComparator>; + +/** + * Estimation based on hints. The hints are organized in a PartialSchemaKeyCE structure. + * SargableNodes are estimated based on the matching PartialSchemaKeys. + */ +class HintedCE : public CEInterface { +public: + HintedCE(PartialSchemaSelHints hints) : _hints(std::move(hints)) {} + + CEType deriveCE(const Memo& memo, + const properties::LogicalProps& logicalProps, + ABT::reference_type logicalNodeRef) const override final; + +private: + // Selectivity hints per PartialSchemaKey. 
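+ // A SargableNode requirement whose interval is not fully open is scaled by the hint matching
+ // its key; multiple matching hints are multiplied together (independence assumption), and
+ // nodes without applicable hints fall back to the heuristic estimate (see ce_hinted.cpp).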
+ PartialSchemaSelHints _hints; +}; + +} // namespace mongo::optimizer::cascades diff --git a/src/mongo/db/query/optimizer/cascades/cost_derivation.cpp b/src/mongo/db/query/optimizer/cascades/cost_derivation.cpp new file mode 100644 index 00000000000..20da31e0d80 --- /dev/null +++ b/src/mongo/db/query/optimizer/cascades/cost_derivation.cpp @@ -0,0 +1,429 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include "mongo/db/query/optimizer/cascades/cost_derivation.h" +#include "mongo/db/query/optimizer/defs.h" + +namespace mongo::optimizer::cascades { + +using namespace properties; + +struct CostAndCEInternal { + CostAndCEInternal(double cost, CEType ce) : _cost(cost), _ce(ce) { + uassert(8423334, "Invalid cost.", !std::isnan(cost) && cost >= 0.0); + uassert(8423332, "Invalid cardinality", std::isfinite(ce) && ce >= 0.0); + } + double _cost; + CEType _ce; +}; + +class CostDerivation { + // These cost should reflect estimated aggregated execution time in milliseconds. + static constexpr double ms = 1.0e-3; + + // Startup cost of an operator. This is the minimal cost of an operator since it is + // present even if it doesn't process any input. + // TODO: calibrate the cost individually for each operator + static constexpr double kStartupCost = 0.000001; + + // TODO: collection scan should depend on the width of the doc. + // TODO: the actual measured cost is (0.4 * ms), however we increase it here because currently + // it is not possible to estimate the cost of a collection scan vs a full index scan. + static constexpr double kScanIncrementalCost = 0.6 * ms; + + // TODO: cost(N fields) ~ (0.55 + 0.025 * N) + static constexpr double kIndexScanIncrementalCost = 0.5 * ms; + + // TODO: cost(N fields) ~ 0.7 + 0.19 * N + static constexpr double kSeekCost = 2.0 * ms; + + // TODO: take the expression into account. + // cost(N conditions) = 0.2 + N * ??? 
+ static constexpr double kFilterIncrementalCost = 0.2 * ms; + // TODO: the cost of projection depends on number of fields: cost(N fields) ~ 0.1 + 0.2 * N + static constexpr double kEvalIncrementalCost = 2.0 * ms; + + // TODO: cost(N fields) ~ 0.04 + 0.03*(N^2) + static constexpr double kGroupByIncrementalCost = 0.07 * ms; + static constexpr double kUnwindIncrementalCost = 0.03 * ms; // TODO: not yet calibrated + // TODO: not yet calibrated, should be at least as expensive as a filter + static constexpr double kBinaryJoinIncrementalCost = 0.2 * ms; + static constexpr double kHashJoinIncrementalCost = 0.05 * ms; // TODO: not yet calibrated + static constexpr double kMergeJoinIncrementalCost = 0.02 * ms; // TODO: not yet calibrated + + static constexpr double kUniqueIncrementalCost = 0.7 * ms; + + // TODO: implement collation cost that depends on number and size of sorted fields + // Based on a mix of int and str(64) fields: + // 1 sort field: sort_cost(N) = 1.0/10 * N * log(N) + // 5 sort fields: sort_cost(N) = 2.5/10 * N * log(N) + // 10 sort fields: sort_cost(N) = 3.0/10 * N * log(N) + // field_cost_coeff(F) ~ 0.75 + 0.2 * F + static constexpr double kCollationIncrementalCost = 2.5 * ms; // 5 fields avg + static constexpr double kCollationWithLimitIncrementalCost = + 1.0 * ms; // TODO: not yet calibrated + + static constexpr double kUnionIncrementalCost = 0.02 * ms; + + static constexpr double kExchangeIncrementalCost = 0.1 * ms; // TODO: not yet calibrated + +public: + CostAndCEInternal operator()(const ABT& /*n*/, const PhysicalScanNode& /*node*/) { + // Default estimate for scan. + const double collectionScanCost = + kStartupCost + kScanIncrementalCost * _cardinalityEstimate; + return {collectionScanCost, _cardinalityEstimate}; + } + + CostAndCEInternal operator()(const ABT& /*n*/, const CoScanNode& /*node*/) { + // Assumed to be free. + return {kStartupCost, _cardinalityEstimate}; + } + + CostAndCEInternal operator()(const ABT& /*n*/, const IndexScanNode& node) { + const double indexScanCost = + kStartupCost + kIndexScanIncrementalCost * _cardinalityEstimate; + return {indexScanCost, _cardinalityEstimate}; + } + + CostAndCEInternal operator()(const ABT& /*n*/, const SeekNode& /*node*/) { + // SeekNode should deliver one result via cardinality estimate override. + // TODO: consider using node.getProjectionMap()._fieldProjections.size() to make the cost + // dependent on the size of the projection. + const double seekCost = kStartupCost + kSeekCost * _cardinalityEstimate; + return {seekCost, _cardinalityEstimate}; + } + + CostAndCEInternal operator()(const ABT& /*n*/, const MemoLogicalDelegatorNode& node) { + const LogicalProps& childLogicalProps = + _memo.getGroup(node.getGroupId())._logicalProperties; + // Notice that unlike all physical nodes, this logical node takes its cardinality directly + // from the memo group logical property, ignoring _cardinalityEstimate. + CEType baseCE = getPropertyConst<CardinalityEstimate>(childLogicalProps).getEstimate(); + + if (hasProperty<IndexingRequirement>(_physProps)) { + const auto& indexingReq = getPropertyConst<IndexingRequirement>(_physProps); + if (indexingReq.getIndexReqTarget() == IndexReqTarget::Seek) { + // If we are performing a seek, normalize against the scan group cardinality.
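+ // Dividing by the scan group estimate (below) converts the group's absolute cardinality
+ // into a per-seek figure; getAdjustedCE() then applies any RepetitionEstimate on top, so a
+ // subtree under a Seek requirement is costed per individual seek rather than once per
+ // collection.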
+ const GroupIdType scanGroupId = + getPropertyConst<IndexingAvailability>(childLogicalProps).getScanGroupId(); + if (scanGroupId == node.getGroupId()) { + baseCE = 1.0; + } else { + const CEType scanGroupCE = getPropertyConst<CardinalityEstimate>( + _memo.getGroup(scanGroupId)._logicalProperties) + .getEstimate(); + if (scanGroupCE > 0.0) { + baseCE /= scanGroupCE; + } + } + } + } + + return {0.0, getAdjustedCE(baseCE, _physProps)}; + } + + CostAndCEInternal operator()(const ABT& /*n*/, const MemoPhysicalDelegatorNode& /*node*/) { + uasserted(6624040, "Should not be costing physical delegator nodes."); + } + + CostAndCEInternal operator()(const ABT& /*n*/, const FilterNode& node) { + CostAndCEInternal childResult = deriveChild(node.getChild(), 0); + double filterCost = childResult._cost; + if (!node.getFilter().is<Constant>() && !node.getFilter().is<Variable>()) { + // Non-trivial filter. + filterCost += kStartupCost + kFilterIncrementalCost * childResult._ce; + } + return {filterCost, _cardinalityEstimate}; + } + + CostAndCEInternal operator()(const ABT& /*n*/, const EvaluationNode& node) { + CostAndCEInternal childResult = deriveChild(node.getChild(), 0); + double evalCost = childResult._cost; + if (!node.getProjection().is<Constant>() && !node.getProjection().is<Variable>()) { + // Non-trivial projection. + evalCost += kStartupCost + kEvalIncrementalCost * childResult._ce; + } + return {evalCost, childResult._ce}; + } + + CostAndCEInternal operator()(const ABT& /*n*/, const BinaryJoinNode& node) { + CostAndCEInternal leftChildResult = deriveChild(node.getLeftChild(), 0); + CostAndCEInternal rightChildResult = deriveChild(node.getRightChild(), 1); + const double joinCost = kStartupCost + + kBinaryJoinIncrementalCost * (leftChildResult._ce + rightChildResult._ce) + + leftChildResult._cost + rightChildResult._cost; + return {joinCost, _cardinalityEstimate}; + } + + CostAndCEInternal operator()(const ABT& /*n*/, const HashJoinNode& node) { + CostAndCEInternal leftChildResult = deriveChild(node.getLeftChild(), 0); + CostAndCEInternal rightChildResult = deriveChild(node.getRightChild(), 1); + + // TODO: distinguish build side and probe side. + const double hashJoinCost = kStartupCost + + kHashJoinIncrementalCost * (leftChildResult._ce + rightChildResult._ce) + + leftChildResult._cost + rightChildResult._cost; + return {hashJoinCost, _cardinalityEstimate}; + } + + CostAndCEInternal operator()(const ABT& /*n*/, const MergeJoinNode& node) { + CostAndCEInternal leftChildResult = deriveChild(node.getLeftChild(), 0); + CostAndCEInternal rightChildResult = deriveChild(node.getRightChild(), 1); + + const double mergeJoinCost = kStartupCost + + kMergeJoinIncrementalCost * (leftChildResult._ce + rightChildResult._ce) + + leftChildResult._cost + rightChildResult._cost; + + return {mergeJoinCost, _cardinalityEstimate}; + } + + CostAndCEInternal operator()(const ABT& /*n*/, const UnionNode& node) { + const ABTVector& children = node.nodes(); + // UnionNode with one child is optimized away before lowering, therefore + // its cost is the cost of its child. + if (children.size() == 1) { + CostAndCEInternal childResult = deriveChild(children[0], 0); + return {childResult._cost, _cardinalityEstimate}; + } + + double totalCost = kStartupCost; + // The cost is the sum of the costs of its children and the cost to union each child. 
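+ // That is: totalCost = kStartupCost + sum_i(cost_i) + sum_{i > 0}(kUnionIncrementalCost * ce_i);
+ // the first child is not charged the incremental union cost.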
+ for (size_t childIdx = 0; childIdx < children.size(); childIdx++) { + CostAndCEInternal childResult = deriveChild(children[childIdx], childIdx); + const double childCost = + childResult._cost + (childIdx > 0 ? kUnionIncrementalCost * childResult._ce : 0); + totalCost += childCost; + } + return {totalCost, _cardinalityEstimate}; + } + + CostAndCEInternal operator()(const ABT& /*n*/, const GroupByNode& node) { + CostAndCEInternal childResult = deriveChild(node.getChild(), 0); + double groupByCost = kStartupCost; + + // TODO: for now pretend global group by is free. + if (node.getType() == GroupNodeType::Global) { + groupByCost += childResult._cost; + } else { + // TODO: consider RepetitionEstimate since this is a stateful operation. + groupByCost += kGroupByIncrementalCost * childResult._ce + childResult._cost; + } + return {groupByCost, _cardinalityEstimate}; + } + + CostAndCEInternal operator()(const ABT& /*n*/, const UnwindNode& node) { + CostAndCEInternal childResult = deriveChild(node.getChild(), 0); + // Unwind probably depends mostly on its output size. + const double unwindCost = kUnwindIncrementalCost * _cardinalityEstimate + childResult._cost; + return {unwindCost, _cardinalityEstimate}; + } + + CostAndCEInternal operator()(const ABT& /*n*/, const UniqueNode& node) { + CostAndCEInternal childResult = deriveChild(node.getChild(), 0); + const double uniqueCost = + kStartupCost + kUniqueIncrementalCost * childResult._ce + childResult._cost; + return {uniqueCost, _cardinalityEstimate}; + } + + CostAndCEInternal operator()(const ABT& /*n*/, const CollationNode& node) { + CostAndCEInternal childResult = deriveChild(node.getChild(), 0); + // TODO: consider RepetitionEstimate since this is a stateful operation. + + double logFactor = childResult._ce; + double incrConst = kCollationIncrementalCost; + if (hasProperty<LimitSkipRequirement>(_physProps)) { + if (auto limit = getPropertyConst<LimitSkipRequirement>(_physProps).getAbsoluteLimit(); + limit < logFactor) { + logFactor = limit; + incrConst = kCollationWithLimitIncrementalCost; + } + } + + // Notice that log2(x) < 0 for any x < 1, and log2(1) = 0. Generally it makes sense that + // there is no cost to sort 1 document, so the only cost left is the startup cost. + const double sortCost = kStartupCost + childResult._cost + + ((logFactor <= 1.0) + ? 0.0 + // TODO: The cost formula below is based on 1 field, mix of int and str. Instead we + // have to take into account the number and size of sorted fields. + : incrConst * childResult._ce * std::log2(logFactor)); + return {sortCost, _cardinalityEstimate}; + } + + CostAndCEInternal operator()(const ABT& /*n*/, const LimitSkipNode& node) { + // Assumed to be free. 
+ CostAndCEInternal childResult = deriveChild(node.getChild(), 0); + const double limitCost = kStartupCost + childResult._cost; + return {limitCost, _cardinalityEstimate}; + } + + CostAndCEInternal operator()(const ABT& /*n*/, const ExchangeNode& node) { + CostAndCEInternal childResult = deriveChild(node.getChild(), 0); + double localCost = kStartupCost + kExchangeIncrementalCost * _cardinalityEstimate; + + switch (node.getProperty().getDistributionAndProjections()._type) { + case DistributionType::Replicated: + localCost *= 2.0; + break; + + case DistributionType::HashPartitioning: + case DistributionType::RangePartitioning: + localCost *= 1.1; + break; + + default: + break; + } + + return {localCost + childResult._cost, _cardinalityEstimate}; + } + + CostAndCEInternal operator()(const ABT& /*n*/, const RootNode& node) { + return deriveChild(node.getChild(), 0); + } + + /** + * Other ABT types. + */ + template <typename T, typename... Ts> + CostAndCEInternal operator()(const ABT& /*n*/, const T& /*node*/, Ts&&...) { + static_assert(!canBePhysicalNode<T>(), "Physical node must implement its cost derivation."); + return {0.0, 0.0}; + } + + static CostAndCEInternal derive(const Memo& memo, + const PhysProps& physProps, + const ABT::reference_type physNodeRef, + const ChildPropsType& childProps, + const NodeCEMap& nodeCEMap) { + CostAndCEInternal result = + deriveInternal(memo, physProps, physNodeRef, childProps, nodeCEMap); + + switch (getPropertyConst<DistributionRequirement>(physProps) + .getDistributionAndProjections() + ._type) { + case DistributionType::Centralized: + case DistributionType::Replicated: + break; + + case DistributionType::RoundRobin: + case DistributionType::HashPartitioning: + case DistributionType::RangePartitioning: + case DistributionType::UnknownPartitioning: + result._cost /= memo.getMetadata()._numberOfPartitions; + break; + + default: + MONGO_UNREACHABLE; + } + + return result; + } + +private: + CostDerivation(const Memo& memo, + const CEType ce, + const PhysProps& physProps, + const ChildPropsType& childProps, + const NodeCEMap& nodeCEMap) + : _memo(memo), + _physProps(physProps), + _cardinalityEstimate(getAdjustedCE(ce, _physProps)), + _childProps(childProps), + _nodeCEMap(nodeCEMap) {} + + static CostAndCEInternal deriveInternal(const Memo& memo, + const PhysProps& physProps, + const ABT::reference_type physNodeRef, + const ChildPropsType& childProps, + const NodeCEMap& nodeCEMap) { + auto it = nodeCEMap.find(physNodeRef.cast<Node>()); + bool found = (it != nodeCEMap.cend()); + uassert(8423330, + "Only MemoLogicalDelegatorNode can be missing from nodeCEMap.", + found || physNodeRef.is<MemoLogicalDelegatorNode>()); + const CEType ce = (found ? it->second : 0.0); + + CostDerivation instance(memo, ce, physProps, childProps, nodeCEMap); + CostAndCEInternal costCEestimates = physNodeRef.visit(instance); + return costCEestimates; + } + + CostAndCEInternal deriveChild(const ABT& child, const size_t childIndex) { + PhysProps physProps = _childProps.empty() ? _physProps : _childProps.at(childIndex).second; + return deriveInternal(_memo, physProps, child.ref(), {}, _nodeCEMap); + } + + static CEType getAdjustedCE(CEType baseCE, const PhysProps& physProps) { + CEType result = baseCE; + + // First: correct for un-enforced limit. + if (hasProperty<LimitSkipRequirement>(physProps)) { + const auto limit = getPropertyConst<LimitSkipRequirement>(physProps).getAbsoluteLimit(); + if (result > limit) { + result = limit; + } + } + + // Second: correct for enforced limit. 
+ if (hasProperty<LimitEstimate>(physProps)) { + const auto limit = getPropertyConst<LimitEstimate>(physProps).getEstimate(); + if (result > limit) { + result = limit; + } + } + + // Third: correct for repetition. + if (hasProperty<RepetitionEstimate>(physProps)) { + result *= getPropertyConst<RepetitionEstimate>(physProps).getEstimate(); + } + + return result; + } + + // We don't own this. + const Memo& _memo; + const PhysProps& _physProps; + const CEType _cardinalityEstimate; + const ChildPropsType& _childProps; + const NodeCEMap& _nodeCEMap; +}; + +CostAndCE DefaultCosting::deriveCost(const Memo& memo, + const PhysProps& physProps, + const ABT::reference_type physNodeRef, + const ChildPropsType& childProps, + const NodeCEMap& nodeCEMap) const { + const CostAndCEInternal result = + CostDerivation::derive(memo, physProps, physNodeRef, childProps, nodeCEMap); + return {CostType::fromDouble(result._cost), result._ce}; +} + +} // namespace mongo::optimizer::cascades diff --git a/src/mongo/db/query/optimizer/cascades/cost_derivation.h b/src/mongo/db/query/optimizer/cascades/cost_derivation.h new file mode 100644 index 00000000000..2d85db5f63f --- /dev/null +++ b/src/mongo/db/query/optimizer/cascades/cost_derivation.h @@ -0,0 +1,49 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#pragma once + +#include "mongo/db/query/optimizer/cascades/interfaces.h" +#include "mongo/db/query/optimizer/cascades/memo.h" + +namespace mongo::optimizer::cascades { + +/** + * Default costing for physical nodes with logical delegator (not-yet-optimized) inputs. 
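+ *
+ * deriveCost() returns a CostAndCE pair: the estimated cost of the physical subtree rooted at
+ * physNodeRef together with the cardinality that subtree is expected to produce. Children are
+ * costed recursively against the physical properties supplied in childProps, and per-node
+ * cardinality overrides are taken from nodeCEMap (only MemoLogicalDelegatorNode may be absent
+ * from that map).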
+ */ +class DefaultCosting : public CostingInterface { +public: + CostAndCE deriveCost(const Memo& memo, + const properties::PhysProps& physProps, + ABT::reference_type physNodeRef, + const ChildPropsType& childProps, + const NodeCEMap& nodeCEMap) const override final; +}; + +} // namespace mongo::optimizer::cascades diff --git a/src/mongo/db/query/optimizer/cascades/enforcers.cpp b/src/mongo/db/query/optimizer/cascades/enforcers.cpp new file mode 100644 index 00000000000..ba5413126bf --- /dev/null +++ b/src/mongo/db/query/optimizer/cascades/enforcers.cpp @@ -0,0 +1,269 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include "mongo/db/query/optimizer/cascades/enforcers.h" + +#include "mongo/db/query/optimizer/utils/memo_utils.h" + +namespace mongo::optimizer::cascades { + +using namespace properties; + +// Maximum Limit to consider to implement in the sort stage (via a min-heap internally). +static constexpr int64_t kMaxLimitForSort = 100; + +static bool isDistributionCentralizedOrReplicated(const PhysProps& physProps) { + switch (getPropertyConst<DistributionRequirement>(physProps) + .getDistributionAndProjections() + ._type) { + case DistributionType::Centralized: + case DistributionType::Replicated: + return true; + + default: + return false; + } +} + +/** + * Checks if we are not trying to satisfy using the entire collection. We are either aiming for a + * covered index, or for a seek. 
+ */ +static bool hasIncompleteScanIndexingRequirement(const PhysProps& physProps) { + return hasProperty<IndexingRequirement>(physProps) && + getPropertyConst<IndexingRequirement>(physProps).getIndexReqTarget() != + IndexReqTarget::Complete; +} + +class PropEnforcerVisitor { +public: + PropEnforcerVisitor(const GroupIdType groupId, + const Metadata& metadata, + PrefixId& prefixId, + PhysRewriteQueue& queue, + const PhysProps& physProps, + const LogicalProps& logicalProps) + : _groupId(groupId), + _metadata(metadata), + _prefixId(prefixId), + _queue(queue), + _physProps(physProps), + _logicalProps(logicalProps) {} + + void operator()(const PhysProperty&, const CollationRequirement& prop) { + if (hasIncompleteScanIndexingRequirement(_physProps)) { + // If we have indexing requirements, we do not enforce collation separately. + // It will be satisfied as part of the index collation. + return; + } + + PhysProps childProps = _physProps; + removeProperty<CollationRequirement>(childProps); + addProjectionsToProperties(childProps, prop.getAffectedProjectionNames()); + + // TODO: also remove RepetitionEstimate if the subtree does not use bound variables. + removeProperty<LimitEstimate>(childProps); + + if (hasProperty<LimitSkipRequirement>(_physProps)) { + const auto& limitSkipReq = getPropertyConst<LimitSkipRequirement>(_physProps); + if (prop.hasClusteredOp() || limitSkipReq.getSkip() != 0 || + limitSkipReq.getLimit() > kMaxLimitForSort) { + // We cannot enforce collation+skip or collation+large limit. + return; + } + + // We can satisfy both collation and limit-skip requirement. During lowering, physical + // properties will indicate presence of limit skip, and thus we set the limit on the sbe + // stage. + removeProperty<LimitSkipRequirement>(childProps); + } + + ABT enforcer = make<CollationNode>(prop, make<MemoLogicalDelegatorNode>(_groupId)); + optimizeChild<CollationNode>( + _queue, kDefaultPriority, std::move(enforcer), std::move(childProps)); + } + + void operator()(const PhysProperty&, const LimitSkipRequirement& prop) { + if (hasIncompleteScanIndexingRequirement(_physProps)) { + // If we have indexing requirements, we do not enforce limit skip. + return; + } + if (!isDistributionCentralizedOrReplicated(_physProps)) { + // Can only enforce limit-skip under centralized or replicated distribution. + return; + } + + PhysProps childProps = _physProps; + removeProperty<LimitSkipRequirement>(childProps); + setPropertyOverwrite<LimitEstimate>( + childProps, LimitEstimate{static_cast<CEType>(prop.getAbsoluteLimit())}); + + ABT enforcer = make<LimitSkipNode>(prop, make<MemoLogicalDelegatorNode>(_groupId)); + optimizeChild<LimitSkipNode>( + _queue, kDefaultPriority, std::move(enforcer), std::move(childProps)); + } + + void operator()(const PhysProperty&, const DistributionRequirement& prop) { + if (!_metadata.isParallelExecution()) { + // We're running in serial mode. + return; + } + if (prop.getDisableExchanges()) { + // We cannot change distributions. + return; + } + if (hasProperty<IndexingRequirement>(_physProps) && + getPropertyConst<IndexingRequirement>(_physProps).getIndexReqTarget() == + IndexReqTarget::Seek) { + // Cannot change distributions while under Seek requirement. + return; + } + + if (prop.getDistributionAndProjections()._type == DistributionType::UnknownPartitioning) { + // Cannot exchange into unknown partitioning. + return; + } + + // TODO: consider hash partition on RID if under IndexingAvailability. 
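+ // Below, after bailing out when a collation requirement is present, we propose one
+ // ExchangeNode per available distribution that differs from the required one (Replicated
+ // sources are skipped since we cannot switch away from them), disabling further exchanges on
+ // the child so that enforcement does not recurse.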
+ + const bool hasCollation = hasProperty<CollationRequirement>(_physProps); + if (hasCollation) { + // For now we cannot enforce if we have collation requirement. + // TODO: try enforcing into partitioning distributions which form prefixes over the + // collation, with ordered exchange. + return; + } + + const auto& distributions = + getPropertyConst<DistributionAvailability>(_logicalProps).getDistributionSet(); + for (const auto& distribution : distributions) { + if (distribution == prop.getDistributionAndProjections()) { + // Same distribution. + continue; + } + if (distribution._type == DistributionType::Replicated) { + // Cannot switch "away" from replicated distribution. + continue; + } + + PhysProps childProps = _physProps; + setPropertyOverwrite<DistributionRequirement>(childProps, distribution); + + addProjectionsToProperties(childProps, distribution._projectionNames); + getProperty<DistributionRequirement>(childProps).setDisableExchanges(true); + + ABT enforcer = make<ExchangeNode>(prop, make<MemoLogicalDelegatorNode>(_groupId)); + optimizeChild<ExchangeNode>( + _queue, kDefaultPriority, std::move(enforcer), std::move(childProps)); + } + } + + void operator()(const PhysProperty&, const ProjectionRequirement& prop) { + const ProjectionNameSet& availableProjections = + getPropertyConst<ProjectionAvailability>(_logicalProps).getProjections(); + + // Verify we can satisfy the required projections using the logical projections. + for (const ProjectionName& projectionName : prop.getProjections().getVector()) { + if (availableProjections.find(projectionName) == availableProjections.cend()) { + uasserted(6624100, "Cannot satisfy all projections"); + } + } + } + + void operator()(const PhysProperty&, const IndexingRequirement& prop) { + if (prop.getIndexReqTarget() != IndexReqTarget::Complete) { + return; + } + + uassert(6624101, + "IndexingRequirement without indexing availability", + hasProperty<IndexingAvailability>(_logicalProps)); + const IndexingAvailability& indexingAvailability = + getPropertyConst<IndexingAvailability>(_logicalProps); + + // TODO: consider left outer joins. We can propagate rid from the outer side. + if (_metadata._scanDefs.at(indexingAvailability.getScanDefName()).getIndexDefs().empty()) { + // No indexes on the collection. + return; + } + + const ProjectionNameOrderPreservingSet& requiredProjections = + getPropertyConst<ProjectionRequirement>(_physProps).getProjections(); + const ProjectionName& scanProjection = indexingAvailability.getScanProjection(); + const bool requiresScanProjection = requiredProjections.find(scanProjection).second; + + if (!requiresScanProjection) { + // Try indexScanOnly (covered index) if we do not require scan projection. + PhysProps newProps = _physProps; + setPropertyOverwrite<IndexingRequirement>(newProps, + {IndexReqTarget::Index, + prop.getNeedsRID(), + prop.getDedupRID(), + prop.getSatisfiedPartialIndexesGroupId()}); + + optimizeUnderNewProperties(_queue, + kDefaultPriority, + make<MemoLogicalDelegatorNode>(_groupId), + std::move(newProps)); + } + } + + void operator()(const PhysProperty&, const RepetitionEstimate& prop) { + // Noop. We do not currently enforce this property. It only affects costing. + // TODO: consider materializing the subtree if we estimate a lot of repetitions. + } + + void operator()(const PhysProperty&, const LimitEstimate& prop) { + // Noop. We do not currently enforce this property. It only affects costing. + } + +private: + const GroupIdType _groupId; + + // We don't own any of those. 
+ const Metadata& _metadata; + PrefixId& _prefixId; + PhysRewriteQueue& _queue; + const PhysProps& _physProps; + const LogicalProps& _logicalProps; +}; + +void addEnforcers(const GroupIdType groupId, + const Metadata& metadata, + PrefixId& prefixId, + PhysRewriteQueue& queue, + const PhysProps& physProps, + const LogicalProps& logicalProps) { + PropEnforcerVisitor visitor(groupId, metadata, prefixId, queue, physProps, logicalProps); + for (const auto& entry : physProps) { + entry.second.visit(visitor); + } +} + +} // namespace mongo::optimizer::cascades diff --git a/src/mongo/db/query/optimizer/cascades/enforcers.h b/src/mongo/db/query/optimizer/cascades/enforcers.h new file mode 100644 index 00000000000..6ad1ace8bce --- /dev/null +++ b/src/mongo/db/query/optimizer/cascades/enforcers.h @@ -0,0 +1,47 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#pragma once + +#include "mongo/db/query/optimizer/cascades/rewrite_queues.h" +#include "mongo/db/query/optimizer/utils/utils.h" + +namespace mongo::optimizer::cascades { + +/** + * Adds property enforcement rules for particular group and physical properties. + */ +void addEnforcers(GroupIdType groupId, + const Metadata& metadata, + PrefixId& prefixId, + PhysRewriteQueue& queue, + const properties::PhysProps& physProps, + const properties::LogicalProps& logicalProps); + +} // namespace mongo::optimizer::cascades diff --git a/src/mongo/db/query/optimizer/cascades/implementers.cpp b/src/mongo/db/query/optimizer/cascades/implementers.cpp new file mode 100644 index 00000000000..54dc07eca53 --- /dev/null +++ b/src/mongo/db/query/optimizer/cascades/implementers.cpp @@ -0,0 +1,1441 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. 
+ * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include "mongo/db/query/optimizer/cascades/implementers.h" + +#include "mongo/db/query/optimizer/utils/memo_utils.h" + +namespace mongo::optimizer::cascades { +using namespace properties; + +template <class P, class T> +static bool propertyAffectsProjections(const PhysProps& props, const T& projections) { + if (!hasProperty<P>(props)) { + return false; + } + + const ProjectionNameSet& propProjections = + getPropertyConst<P>(props).getAffectedProjectionNames(); + for (const ProjectionName& projectionName : projections) { + if (propProjections.find(projectionName) != propProjections.cend()) { + return true; + } + } + + return false; +} + +template <class P> +static bool propertyAffectsProjection(const PhysProps& props, + const ProjectionName& projectionName) { + return propertyAffectsProjections<P>(props, ProjectionNameVector{projectionName}); +} + +/** + * Implement physical nodes based on existing logical nodes. + */ +class ImplementationVisitor { +public: + void operator()(const ABT& /*n*/, const ScanNode& node) { + if (hasProperty<LimitSkipRequirement>(_physProps)) { + // Cannot satisfy limit-skip. + return; + } + if (hasProperty<CollationRequirement>(_physProps)) { + // Regular scan cannot satisfy any collation requirement. + // TODO: consider rid? + return; + } + + const auto& indexReq = getPropertyConst<IndexingRequirement>(_physProps); + const IndexReqTarget indexReqTarget = indexReq.getIndexReqTarget(); + switch (indexReqTarget) { + case IndexReqTarget::Index: + // At this point we cannot satisfy an index-only requirement. + return; + + case IndexReqTarget::Seek: + if (_hints._disableIndexes == DisableIndexOptions::DisableAll) { + return; + } + uassert(6624102, "RID projection is required for seek", indexReq.getNeedsRID()); + // Fall through to code below. + break; + + case IndexReqTarget::Complete: + if (_hints._disableScan) { + return; + } + // Fall through to code below. + break; + + default: + MONGO_UNREACHABLE; + } + + // Handle complete indexing requirement.
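+ // (This branch also covers the Seek target.) We first check that the required distribution
+ // is compatible with the collection's distribution-and-paths, then build a FieldProjectionMap;
+ // a Seek target lowers to a SeekNode under a LimitSkipNode with limit 1 and CE overridden to
+ // 1.0, while a Complete target lowers to a PhysicalScanNode.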
+ bool canUseParallelScan = false; + if (!distributionsCompatible( + indexReqTarget, + _memo.getMetadata()._scanDefs.at(node.getScanDefName()).getDistributionAndPaths(), + node.getProjectionName(), + _logicalProps, + {}, + canUseParallelScan)) { + return; + } + + FieldProjectionMap fieldProjectionMap; + for (const ProjectionName& required : + getPropertyConst<ProjectionRequirement>(_physProps).getAffectedProjectionNames()) { + if (required == node.getProjectionName()) { + fieldProjectionMap._rootProjection = node.getProjectionName(); + } else { + // Regular scan node can satisfy only using its root projection (not fields). + return; + } + } + + const ProjectionName& ridProjName = _ridProjections.at( + getPropertyConst<IndexingAvailability>(_logicalProps).getScanDefName()); + if (indexReqTarget == IndexReqTarget::Seek) { + NodeCEMap nodeCEMap; + + ABT physicalSeek = + make<SeekNode>(ridProjName, std::move(fieldProjectionMap), node.getScanDefName()); + // If optimizing a Seek, override CE to 1.0. + nodeCEMap.emplace(physicalSeek.cast<Node>(), 1.0); + + ABT limitSkip = + make<LimitSkipNode>(LimitSkipRequirement{1, 0}, std::move(physicalSeek)); + nodeCEMap.emplace(limitSkip.cast<Node>(), 1.0); + + optimizeChildrenNoAssert( + _queue, kDefaultPriority, std::move(limitSkip), {}, std::move(nodeCEMap)); + } else { + if (indexReq.getNeedsRID()) { + fieldProjectionMap._ridProjection = ridProjName; + } + ABT physicalScan = make<PhysicalScanNode>( + std::move(fieldProjectionMap), node.getScanDefName(), canUseParallelScan); + optimizeChild<PhysicalScanNode>(_queue, kDefaultPriority, std::move(physicalScan)); + } + } + + void operator()(const ABT& n, const ValueScanNode& node) { + if (hasProperty<LimitSkipRequirement>(_physProps)) { + // Cannot satisfy limit-skip. + return; + } + if (hasProperty<CollationRequirement>(_physProps)) { + // Cannot satisfy any collation requirement. + return; + } + + NodeCEMap nodeCEMap; + ABT physNode = make<CoScanNode>(); + const auto& requiredProjections = + getPropertyConst<ProjectionRequirement>(_physProps).getProjections(); + + if (node.getArraySize() == 0) { + nodeCEMap.emplace(physNode.cast<Node>(), 0.0); + + physNode = + make<LimitSkipNode>(properties::LimitSkipRequirement{0, 0}, std::move(physNode)); + nodeCEMap.emplace(physNode.cast<Node>(), 0.0); + + for (const ProjectionName& projectionName : requiredProjections.getVector()) { + physNode = + make<EvaluationNode>(projectionName, Constant::nothing(), std::move(physNode)); + nodeCEMap.emplace(physNode.cast<Node>(), 0.0); + } + } else { + nodeCEMap.emplace(physNode.cast<Node>(), 1.0); + + physNode = + make<LimitSkipNode>(properties::LimitSkipRequirement{1, 0}, std::move(physNode)); + nodeCEMap.emplace(physNode.cast<Node>(), 1.0); + + const ProjectionName valueScanProj = _prefixId.getNextId("valueScan"); + physNode = + make<EvaluationNode>(valueScanProj, node.getValueArray(), std::move(physNode)); + nodeCEMap.emplace(physNode.cast<Node>(), 1.0); + + // Unwind the combined array constant and pick an element for each required projection + // in sequence. + physNode = make<UnwindNode>(valueScanProj, + _prefixId.getNextId("valueScanPid"), + false /*retainNonArrays*/, + std::move(physNode)); + nodeCEMap.emplace(physNode.cast<Node>(), node.getArraySize()); + + /** + * Iterate over the bound projections here as opposed to the required projections, since + * the array elements are ordered accordingly. 
+ */ + const ProjectionNameVector& boundProjNames = node.binder().names(); + for (size_t i = 0; i < boundProjNames.size(); i++) { + const ProjectionName& boundProjName = boundProjNames.at(i); + if (requiredProjections.find(boundProjName).second) { + physNode = make<EvaluationNode>( + boundProjName, + make<FunctionCall>( + "getElement", + makeSeq(make<Variable>(valueScanProj), Constant::int32(i))), + std::move(physNode)); + nodeCEMap.emplace(physNode.cast<Node>(), node.getArraySize()); + } + } + } + + optimizeChildrenNoAssert( + _queue, kDefaultPriority, std::move(physNode), {}, std::move(nodeCEMap)); + } + + void operator()(const ABT& /*n*/, const MemoLogicalDelegatorNode& /*node*/) { + uasserted(6624041, + "Must not have logical delegator nodes in the list of the logical nodes"); + } + + void operator()(const ABT& n, const FilterNode& node) { + if (hasProperty<LimitSkipRequirement>(_physProps)) { + // We cannot satisfy here. + return; + } + + VariableNameSetType references = collectVariableReferences(n); + if (checkIntroducesScanProjectionUnderIndexOnly(references)) { + // Reject if under indexing requirements and now we introduce dependence on scan + // projection. + return; + } + + PhysProps newProps = _physProps; + // Add projections we depend on to the requirement. + addProjectionsToProperties(newProps, std::move(references)); + getProperty<DistributionRequirement>(newProps).setDisableExchanges(true); + + ABT physicalFilter = n; + optimizeChild<FilterNode>( + _queue, kDefaultPriority, std::move(physicalFilter), std::move(newProps)); + } + + void operator()(const ABT& n, const EvaluationNode& node) { + const ProjectionName& projectionName = node.getProjectionName(); + + if (const auto* varPtr = node.getProjection().cast<Variable>(); varPtr != nullptr) { + // Special case of evaluation node: rebinds to a different variable. + const ProjectionName& newProjName = varPtr->name(); + PhysProps newProps = _physProps; + + { + // Update required projections. + auto& reqProjections = + getProperty<ProjectionRequirement>(newProps).getProjections(); + reqProjections.erase(projectionName); + reqProjections.emplace_back(newProjName); + } + + if (hasProperty<CollationRequirement>(newProps)) { + // Update the collation specification to use the input variable. + auto& collationReq = getProperty<CollationRequirement>(newProps); + for (auto& [projName, op] : collationReq.getCollationSpec()) { + if (projName == projectionName) { + projName = newProjName; + } + } + } + + { + // Update the distribution specification to use the input variable; + auto& distribReq = getProperty<DistributionRequirement>(newProps); + for (auto& projName : distribReq.getDistributionAndProjections()._projectionNames) { + if (projName == projectionName) { + projName = newProjName; + } + } + } + + ABT physicalEval = n; + optimizeChild<EvaluationNode>( + _queue, kDefaultPriority, std::move(physicalEval), std::move(newProps)); + return; + } + + if (propertyAffectsProjection<DistributionRequirement>(_physProps, projectionName)) { + // We cannot satisfy distribution on the projection we output. + return; + } + if (propertyAffectsProjection<CollationRequirement>(_physProps, projectionName)) { + // In general, we cannot satisfy collation on the projection we output. + // TODO consider x = y+1, we can propagate the collation requirement from x to y. + return; + } + if (!propertyAffectsProjection<ProjectionRequirement>(_physProps, projectionName)) { + // We do not require the projection. 
Do not place a physical evaluation node and + // continue optimizing the child. + optimizeUnderNewProperties(_queue, kDefaultPriority, node.getChild(), _physProps); + return; + } + + // Remove our projection from requirement, and add projections we depend on to the + // requirement. + PhysProps newProps = _physProps; + + VariableNameSetType references = collectVariableReferences(n); + if (checkIntroducesScanProjectionUnderIndexOnly(references)) { + // Reject if under indexing requirements and now we introduce dependence on scan + // projection. + return; + } + + addRemoveProjectionsToProperties( + newProps, std::move(references), ProjectionNameVector{projectionName}); + getProperty<DistributionRequirement>(newProps).setDisableExchanges(true); + + ABT physicalEval = n; + optimizeChild<EvaluationNode>( + _queue, kDefaultPriority, std::move(physicalEval), std::move(newProps)); + } + + void operator()(const ABT& n, const SargableNode& node) { + if (hasProperty<LimitSkipRequirement>(_physProps)) { + // Cannot satisfy limit-skip. + return; + } + + const IndexingAvailability& indexingAvailability = + getPropertyConst<IndexingAvailability>(_logicalProps); + if (node.getChild().cast<MemoLogicalDelegatorNode>()->getGroupId() != + indexingAvailability.getScanGroupId()) { + // To optimize a sargable predicate, we must have the scan group as a child. + return; + } + + const std::string& scanDefName = indexingAvailability.getScanDefName(); + const auto& scanDef = _memo.getMetadata()._scanDefs.at(scanDefName); + + + // We do not check indexDefs to be empty here. We want to allow evaluations to be covered + // via a physical scan even in the absence of indexes. + + const IndexingRequirement& requirements = getPropertyConst<IndexingRequirement>(_physProps); + const IndexReqTarget indexReqTarget = requirements.getIndexReqTarget(); + switch (indexReqTarget) { + case IndexReqTarget::Complete: + if (_hints._disableScan) { + return; + } + break; + + case IndexReqTarget::Index: + case IndexReqTarget::Seek: + if (_hints._disableIndexes == DisableIndexOptions::DisableAll) { + return; + } + break; + + default: + MONGO_UNREACHABLE; + } + const auto& satisfiedPartialIndexes = + getPropertyConst<IndexingAvailability>( + _memo.getGroup(requirements.getSatisfiedPartialIndexesGroupId())._logicalProperties) + .getSatisfiedPartialIndexes(); + + const bool needsRID = requirements.getNeedsRID(); + const ProjectionName& ridProjName = _ridProjections.at(scanDefName); + + const ProjectionName& scanProjectionName = indexingAvailability.getScanProjection(); + const GroupIdType scanGroupId = indexingAvailability.getScanGroupId(); + const LogicalProps& scanLogicalProps = _memo.getGroup(scanGroupId)._logicalProperties; + const CEType scanGroupCE = + getPropertyConst<CardinalityEstimate>(scanLogicalProps).getEstimate(); + + const PartialSchemaRequirements& reqMap = node.getReqMap(); + + if (indexReqTarget != IndexReqTarget::Index && + hasProperty<CollationRequirement>(_physProps)) { + // PhysicalScan or Seek cannot satisfy any collation requirement. + // TODO: consider rid? + return; + } + + for (const auto& [key, req] : reqMap) { + if (key.emptyPath()) { + // We cannot satisfy without a field. + return; + } + if (key._projectionName != scanProjectionName) { + // We can only satisfy partial schema requirements using our root projection. 
+ return; + } + } + + const auto& requiredProjections = + getPropertyConst<ProjectionRequirement>(_physProps).getProjections(); + bool requiresRootProjection = false; + { + auto projectionsLeftToSatisfy = requiredProjections; + if (indexReqTarget != IndexReqTarget::Index) { + // Deliver root projection if required. + requiresRootProjection = projectionsLeftToSatisfy.erase(scanProjectionName); + } + + for (const auto& entry : reqMap) { + if (entry.second.hasBoundProjectionName()) { + // Project field only if it required. + const ProjectionName& projectionName = entry.second.getBoundProjectionName(); + projectionsLeftToSatisfy.erase(projectionName); + } + } + if (!projectionsLeftToSatisfy.getVector().empty()) { + // Unknown projections remain. Reject. + return; + } + } + + const auto& ceProperty = getPropertyConst<CardinalityEstimate>(_logicalProps); + const CEType currentCE = ceProperty.getEstimate(); + const PartialSchemaKeyCE& partialSchemaKeyCEMap = ceProperty.getPartialSchemaKeyCEMap(); + + ProjectionRenames projectionRenames; + if (indexReqTarget == IndexReqTarget::Index) { + ProjectionCollationSpec requiredCollation; + if (hasProperty<CollationRequirement>(_physProps)) { + requiredCollation = + getPropertyConst<CollationRequirement>(_physProps).getCollationSpec(); + } + + // Consider all candidate indexes, and check if they satisfy the collation and + // distribution requirements. + for (const auto& [indexDefName, candidateIndexEntry] : node.getCandidateIndexMap()) { + const auto& indexDef = scanDef.getIndexDefs().at(indexDefName); + if (!indexDef.getPartialReqMap().empty() && + (_hints._disableIndexes == DisableIndexOptions::DisablePartialOnly || + satisfiedPartialIndexes.count(indexDefName) == 0)) { + // Consider only indexes for which we satisfy partial requirements. + continue; + } + + { + bool canUseParallelScanUnused = false; + if (!distributionsCompatible(IndexReqTarget::Index, + indexDef.getDistributionAndPaths(), + scanProjectionName, + scanLogicalProps, + reqMap, + canUseParallelScanUnused)) { + return; + } + } + + const auto availableDirections = indexSatisfiesCollation( + indexDef.getCollationSpec(), candidateIndexEntry, requiredCollation); + if (!availableDirections._forward && !availableDirections._backward) { + // Failed to satisfy collation. + continue; + } + + uassert(6624103, + "Either forward or backward direction must be available.", + availableDirections._forward || availableDirections._backward); + + auto indexProjectionMap = candidateIndexEntry._fieldProjectionMap; + indexProjectionMap._ridProjection = needsRID ? ridProjName : ""; + + { + // Remove unused projections from the field projection map. 
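+ // A field projection is retained only if the parent requires it or a residual
+ // requirement still references it; all other entries are dropped from the index scan.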
+ auto& fieldProjMap = indexProjectionMap._fieldProjections; + for (auto it = fieldProjMap.begin(); it != fieldProjMap.end();) { + const ProjectionName& projName = it->second; + if (!requiredProjections.find(projName).second && + candidateIndexEntry._residualRequirementsTempProjections.count( + projName) == 0) { + fieldProjMap.erase(it++); + } else { + it++; + } + } + } + + CEType indexCE = currentCE; + ResidualRequirements residualRequirements; + if (!candidateIndexEntry._residualRequirements.empty()) { + SelectivityType residualSelectivity = 1.0; + SelectivityType currentSelectivity = currentCE / scanGroupCE; + + for (const auto& [residualKey, residualReq] : + candidateIndexEntry._residualRequirements) { + const auto& queryKey = candidateIndexEntry._residualKeyMap.at(residualKey); + const CEType ce = partialSchemaKeyCEMap.at(queryKey); + if (scanGroupCE > 0.0) { + residualSelectivity *= ce / scanGroupCE; + } + residualRequirements.emplace_back(residualKey, residualReq, ce); + } + if (residualSelectivity > 0.0) { + indexCE = scanGroupCE * (currentSelectivity / residualSelectivity); + } + } + + const auto& intervals = candidateIndexEntry._intervals; + ABT physNode = make<Blackhole>(); + NodeCEMap nodeCEMap; + + // TODO: consider pre-computing as part of the candidateIndexes structure. + const auto singularInterval = MultiKeyIntervalReqExpr::getSingularDNF(intervals); + const bool needsUniqueStage = singularInterval && + !areMultiKeyIntervalsEqualities(*singularInterval) && indexDef.isMultiKey() && + requirements.getDedupRID(); + + if (singularInterval) { + physNode = + make<IndexScanNode>(std::move(indexProjectionMap), + IndexSpecification{scanDefName, + indexDefName, + *singularInterval, + !availableDirections._forward}); + nodeCEMap.emplace(physNode.cast<Node>(), indexCE); + } else { + physNode = lowerIntervals(_prefixId, + ridProjName, + std::move(indexProjectionMap), + scanDefName, + indexDefName, + intervals, + !availableDirections._forward, + indexCE, + scanGroupCE, + nodeCEMap); + } + + applyProjectionRenames(projectionRenames, physNode, [&](const ABT& node) { + nodeCEMap.emplace(node.cast<Node>(), indexCE); + }); + + lowerPartialSchemaRequirements( + indexCE, scanGroupCE, residualRequirements, physNode, nodeCEMap); + + if (needsUniqueStage) { + // Insert unique stage if we need to, after the residual requirements. + physNode = + make<UniqueNode>(ProjectionNameVector{ridProjName}, std::move(physNode)); + nodeCEMap.emplace(physNode.cast<Node>(), currentCE); + } + + optimizeChildrenNoAssert( + _queue, kDefaultPriority, std::move(physNode), {}, std::move(nodeCEMap)); + } + } else { + bool canUseParallelScan = false; + if (!distributionsCompatible(indexReqTarget, + scanDef.getDistributionAndPaths(), + scanProjectionName, + scanLogicalProps, + reqMap, + canUseParallelScan)) { + return; + } + + FieldProjectionMap fieldProjectionMap; + if (indexReqTarget == IndexReqTarget::Complete && needsRID) { + fieldProjectionMap._ridProjection = ridProjName; + } + + ResidualRequirements residualRequirements; + computePhysicalScanParams(_prefixId, + reqMap, + partialSchemaKeyCEMap, + requiredProjections, + residualRequirements, + projectionRenames, + fieldProjectionMap, + requiresRootProjection); + if (requiresRootProjection) { + fieldProjectionMap._rootProjection = scanProjectionName; + } + + NodeCEMap nodeCEMap; + ABT physNode = make<Blackhole>(); + CEType baseCE = 0.0; + + if (indexReqTarget == IndexReqTarget::Complete) { + baseCE = scanGroupCE; + + // Return a physical scan with field map. 
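+ // The field projection map may carry the rid projection (when a RID is needed), the
+ // root projection, and individual field projections; canUseParallelScan was determined
+ // by the distribution check above.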
+ physNode = make<PhysicalScanNode>( + std::move(fieldProjectionMap), scanDefName, canUseParallelScan); + nodeCEMap.emplace(physNode.cast<Node>(), baseCE); + } else { + baseCE = 1.0; + + // Try Seek with Limit 1. + physNode = make<SeekNode>(ridProjName, std::move(fieldProjectionMap), scanDefName); + nodeCEMap.emplace(physNode.cast<Node>(), baseCE); + + physNode = make<LimitSkipNode>(LimitSkipRequirement{1, 0}, std::move(physNode)); + nodeCEMap.emplace(physNode.cast<Node>(), baseCE); + } + + applyProjectionRenames(std::move(projectionRenames), physNode, [&](const ABT& node) { + nodeCEMap.emplace(node.cast<Node>(), baseCE); + }); + + lowerPartialSchemaRequirements( + baseCE, scanGroupCE, residualRequirements, physNode, nodeCEMap); + optimizeChildrenNoAssert( + _queue, kDefaultPriority, std::move(physNode), {}, std::move(nodeCEMap)); + } + } + + void operator()(const ABT& /*n*/, const RIDIntersectNode& node) { + const auto& indexingAvailability = getPropertyConst<IndexingAvailability>(_logicalProps); + const std::string& scanDefName = indexingAvailability.getScanDefName(); + { + const auto& scanDef = _memo.getMetadata()._scanDefs.at(scanDefName); + if (scanDef.getIndexDefs().empty()) { + // Reject if we do not have any indexes. + return; + } + } + + if (hasProperty<LimitSkipRequirement>(_physProps)) { + // Cannot satisfy limit-skip. + return; + } + + const IndexingRequirement& requirements = getPropertyConst<IndexingRequirement>(_physProps); + const bool dedupRID = requirements.getDedupRID(); + const IndexReqTarget indexReqTarget = requirements.getIndexReqTarget(); + if (indexReqTarget == IndexReqTarget::Seek) { + return; + } + const bool isIndex = indexReqTarget == IndexReqTarget::Index; + if (isIndex && (!node.hasLeftIntervals() || !node.hasRightIntervals())) { + // We need to have proper intervals on both sides. + return; + } + + const auto& distribRequirement = getPropertyConst<DistributionRequirement>(_physProps); + const auto& distrAndProjections = distribRequirement.getDistributionAndProjections(); + if (isIndex) { + switch (distrAndProjections._type) { + case DistributionType::UnknownPartitioning: + case DistributionType::RoundRobin: + // Cannot satisfy unknown or round-robin distributions. + return; + + default: + break; + } + } + + const GroupIdType leftGroupId = + node.getLeftChild().cast<MemoLogicalDelegatorNode>()->getGroupId(); + const GroupIdType rightGroupId = + node.getRightChild().cast<MemoLogicalDelegatorNode>()->getGroupId(); + + const LogicalProps& leftLogicalProps = _memo.getGroup(leftGroupId)._logicalProperties; + const LogicalProps& rightLogicalProps = _memo.getGroup(rightGroupId)._logicalProperties; + + const CEType intersectedCE = + getPropertyConst<CardinalityEstimate>(_logicalProps).getEstimate(); + const CEType leftCE = getPropertyConst<CardinalityEstimate>(leftLogicalProps).getEstimate(); + const CEType rightCE = + getPropertyConst<CardinalityEstimate>(rightLogicalProps).getEstimate(); + const ProjectionNameSet& leftProjections = + getPropertyConst<ProjectionAvailability>(leftLogicalProps).getProjections(); + const ProjectionNameSet& rightProjections = + getPropertyConst<ProjectionAvailability>(rightLogicalProps).getProjections(); + + // Split required projections between inner and outer side. 
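+ // Here the left child is the outer side and the right child the inner side. The scan
+ // (root) projection is never assigned to the left child, and requiring it rejects the
+ // plan entirely when both children are index scans.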
+ ProjectionNameOrderPreservingSet leftChildProjections; + ProjectionNameOrderPreservingSet rightChildProjections; + for (const ProjectionName& projectionName : + getPropertyConst<ProjectionRequirement>(_physProps).getProjections().getVector()) { + if (projectionName != node.getScanProjectionName() && + leftProjections.count(projectionName) > 0) { + leftChildProjections.emplace_back(projectionName); + } else if (rightProjections.count(projectionName) > 0) { + if (isIndex && projectionName == node.getScanProjectionName()) { + return; + } + rightChildProjections.emplace_back(projectionName); + } else { + uasserted(6624104, + "Required projection must appear in either the left or the right child " + "projections"); + return; + } + } + + ProjectionCollationSpec collationSpec; + if (hasProperty<CollationRequirement>(_physProps)) { + collationSpec = getPropertyConst<CollationRequirement>(_physProps).getCollationSpec(); + } + + // Split collation between inner and outer side. + const CollationSplitResult& collationLeftRightSplit = + splitCollationSpec(collationSpec, leftProjections, rightProjections); + const CollationSplitResult& collationRightLeftSplit = + splitCollationSpec(collationSpec, rightProjections, leftProjections); + + // We are propagating the distribution requirements to both sides. + PhysProps leftPhysProps = _physProps; + PhysProps rightPhysProps = _physProps; + + // Specifically do not propagate limit-skip. + // TODO: handle similarly to physical join. + removeProperty<LimitSkipRequirement>(leftPhysProps); + removeProperty<LimitSkipRequirement>(rightPhysProps); + + getProperty<DistributionRequirement>(leftPhysProps).setDisableExchanges(false); + getProperty<DistributionRequirement>(rightPhysProps).setDisableExchanges(false); + + const ProjectionName& ridProjName = _ridProjections.at( + getPropertyConst<IndexingAvailability>(_logicalProps).getScanDefName()); + setPropertyOverwrite<IndexingRequirement>( + leftPhysProps, + {IndexReqTarget::Index, + true /*needRID*/, + !isIndex && dedupRID, + requirements.getSatisfiedPartialIndexesGroupId()}); + setPropertyOverwrite<IndexingRequirement>( + rightPhysProps, + {isIndex ? IndexReqTarget::Index : IndexReqTarget::Seek, + true /*needRID*/, + !isIndex && dedupRID, + requirements.getSatisfiedPartialIndexesGroupId()}); + + setPropertyOverwrite<ProjectionRequirement>(leftPhysProps, std::move(leftChildProjections)); + setPropertyOverwrite<ProjectionRequirement>(rightPhysProps, + std::move(rightChildProjections)); + + if (!isIndex) { + // Add repeated execution property to inner side. + CEType estimatedRepetitions = hasProperty<RepetitionEstimate>(_physProps) + ? getPropertyConst<RepetitionEstimate>(_physProps).getEstimate() + : 1.0; + estimatedRepetitions *= + getPropertyConst<CardinalityEstimate>(leftLogicalProps).getEstimate(); + setPropertyOverwrite<RepetitionEstimate>(rightPhysProps, + RepetitionEstimate{estimatedRepetitions}); + } + + const auto& optimizeFn = std::bind(&ImplementationVisitor::optimizeRIDIntersect, + this, + isIndex, + dedupRID, + indexingAvailability.getPossiblyEqPredsOnly(), + std::cref(ridProjName), + std::cref(collationLeftRightSplit), + std::cref(collationRightLeftSplit), + intersectedCE, + leftCE, + rightCE, + std::cref(leftPhysProps), + std::cref(rightPhysProps), + std::cref(node.getLeftChild()), + std::cref(node.getRightChild())); + + // Always optimize under same distributions on left and on right. 
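+ // (For hash- or range-partitioned index intersection, the code below additionally
+ // tries keeping the required distribution on one side while replicating the other.)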
+ optimizeFn(); + + if (isIndex) { + switch (distrAndProjections._type) { + case DistributionType::HashPartitioning: + case DistributionType::RangePartitioning: { + // Specifically for index intersection, try propagating the requirement on one + // side and replicating the other. + + const auto& leftDistributions = + getPropertyConst<DistributionAvailability>(leftLogicalProps) + .getDistributionSet(); + const auto& rightDistributions = + getPropertyConst<DistributionAvailability>(rightLogicalProps) + .getDistributionSet(); + + if (leftDistributions.count(distrAndProjections) > 0) { + setPropertyOverwrite<DistributionRequirement>(leftPhysProps, + distribRequirement); + setPropertyOverwrite<DistributionRequirement>( + rightPhysProps, DistributionRequirement{DistributionType::Replicated}); + optimizeFn(); + } + + if (rightDistributions.count(distrAndProjections) > 0) { + setPropertyOverwrite<DistributionRequirement>( + leftPhysProps, DistributionRequirement{DistributionType::Replicated}); + setPropertyOverwrite<DistributionRequirement>(rightPhysProps, + distribRequirement); + optimizeFn(); + } + break; + } + + default: + break; + } + } + } + + void operator()(const ABT& /*n*/, const BinaryJoinNode& node) { + // TODO: optimize binary joins + uasserted(6624105, "not implemented"); + } + + void operator()(const ABT& n, const UnionNode& node) { + if (hasProperty<LimitSkipRequirement>(_physProps)) { + // We cannot satisfy limit-skip requirements. + return; + } + if (hasProperty<CollationRequirement>(_physProps)) { + // In general we cannot satisfy collation requirements. + // TODO: This may be possible with a merge sort type of node. + return; + } + + // Only need to propagate the required projection set. + ABT physicalUnion = make<UnionNode>( + getPropertyConst<ProjectionRequirement>(_physProps).getProjections().getVector(), + node.nodes()); + + // Optimize each child under the same physical properties. + ChildPropsType childProps; + for (auto& child : physicalUnion.cast<UnionNode>()->nodes()) { + PhysProps newProps = _physProps; + childProps.emplace_back(&child, std::move(newProps)); + } + + optimizeChildren<UnionNode>( + _queue, kDefaultPriority, std::move(physicalUnion), std::move(childProps)); + } + + void operator()(const ABT& n, const GroupByNode& node) { + if (hasProperty<LimitSkipRequirement>(_physProps)) { + // We cannot satisfy limit-skip requirements. + // TODO: consider an optimization where we keep track of at most "limit" groups. + return; + } + if (hasProperty<CollationRequirement>(_physProps)) { + // In general we cannot satisfy collation requirements. + // TODO: consider stream group-by. + return; + } + + if (propertyAffectsProjections<DistributionRequirement>( + _physProps, node.getAggregationProjectionNames())) { + // We cannot satisfy distribution on the aggregations. + return; + } + + const ProjectionNameVector& groupByProjections = node.getGroupByProjectionNames(); + + const bool isLocal = node.getType() == GroupNodeType::Local; + if (!isLocal) { + // We are constrained in terms of distribution only if we are a global or complete agg. + + const auto& distribAndProjections = + getPropertyConst<DistributionRequirement>(_physProps) + .getDistributionAndProjections(); + switch (distribAndProjections._type) { + case DistributionType::UnknownPartitioning: + case DistributionType::RoundRobin: + // Cannot satisfy unknown or round-robin partitioning. 
+ return; + + case DistributionType::HashPartitioning: { + ProjectionNameSet groupByProjectionSet; + for (const ProjectionName& projectionName : groupByProjections) { + groupByProjectionSet.insert(projectionName); + } + for (const ProjectionName& projectionName : + distribAndProjections._projectionNames) { + if (groupByProjectionSet.count(projectionName) == 0) { + // We can only be partitioned on projections on which we group. + return; + } + } + break; + } + + case DistributionType::RangePartitioning: + if (distribAndProjections._projectionNames != groupByProjections) { + // For range partitioning we need to be partitioned exactly in the same + // order as our group-by projections. + return; + } + break; + + default: + break; + } + } + + PhysProps newProps = _physProps; + + // TODO: remove RepetitionEstimate if the subtree does not use bound variables. + // TODO: this is not the case for stream group-by. + + // Specifically do not propagate limit-skip. + removeProperty<LimitSkipRequirement>(newProps); + + getProperty<DistributionRequirement>(newProps).setDisableExchanges(isLocal); + + // Iterate over the aggregation expressions and only add those required. + ABTVector aggregationProjections; + ProjectionNameVector aggregationProjectionNames; + VariableNameSetType projectionsToAdd; + for (const ProjectionName& groupByProjName : groupByProjections) { + projectionsToAdd.insert(groupByProjName); + } + + const auto& requiredProjections = + getPropertyConst<ProjectionRequirement>(_physProps).getProjections(); + for (size_t aggIndex = 0; aggIndex < node.getAggregationExpressions().size(); aggIndex++) { + const ProjectionName& aggProjectionName = + node.getAggregationProjectionNames().at(aggIndex); + + if (requiredProjections.find(aggProjectionName).second) { + // We require this agg expression. + aggregationProjectionNames.push_back(aggProjectionName); + const ABT& aggExpr = node.getAggregationExpressions().at(aggIndex); + aggregationProjections.push_back(aggExpr); + + for (const auto& var : VariableEnvironment::getVariables(aggExpr)._variables) { + // Add all references this expression requires. + projectionsToAdd.insert(var->name()); + } + } + } + + addRemoveProjectionsToProperties( + newProps, projectionsToAdd, node.getAggregationProjectionNames()); + + ABT physicalGroupBy = make<GroupByNode>(groupByProjections, + std::move(aggregationProjectionNames), + std::move(aggregationProjections), + node.getType(), + node.getChild()); + optimizeChild<GroupByNode>( + _queue, kDefaultPriority, std::move(physicalGroupBy), std::move(newProps)); + } + + void operator()(const ABT& n, const UnwindNode& node) { + const ProjectionName& pidProjectionName = node.getPIDProjectionName(); + const ProjectionNameVector& projectionNames = {(node.getProjectionName()), + pidProjectionName}; + + if (propertyAffectsProjections<DistributionRequirement>(_physProps, projectionNames)) { + // We cannot satisfy distribution on the unwound output, or pid. + return; + } + if (propertyAffectsProjections<CollationRequirement>(_physProps, projectionNames)) { + // We cannot satisfy collation on the output. + return; + } + if (hasProperty<LimitSkipRequirement>(_physProps)) { + // Cannot satisfy limit-skip. + return; + } + + PhysProps newProps = _physProps; + addRemoveProjectionsToProperties( + newProps, collectVariableReferences(n), ProjectionNameVector{pidProjectionName}); + + // Specifically do not propagate limit-skip. + removeProperty<LimitSkipRequirement>(newProps); + // Keep collation property if given it does not affect output. 
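+ // (Any collation requirement still present can only refer to projections other than
+ // the unwound output and pid; those cases were rejected above.)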
+ + getProperty<DistributionRequirement>(newProps).setDisableExchanges(false); + + ABT physicalUnwind = n; + optimizeChild<UnwindNode>( + _queue, kDefaultPriority, std::move(physicalUnwind), std::move(newProps)); + } + + void operator()(const ABT& /*n*/, const CollationNode& node) { + if (getPropertyConst<DistributionRequirement>(_physProps) + .getDistributionAndProjections() + ._type != DistributionType::Centralized) { + // We can only pick up collation under centralized (but we can enforce under any + // distribution). + return; + } + + optimizeSimplePropertyNode<CollationNode, CollationRequirement>(node); + } + + void operator()(const ABT& /*n*/, const LimitSkipNode& node) { + // We can pick-up limit-skip under any distribution (but enforce under centralized or + // replicated). + + PhysProps newProps = _physProps; + LimitSkipRequirement newProp = node.getProperty(); + + removeProperty<LimitEstimate>(newProps); + + if (hasProperty<LimitSkipRequirement>(_physProps)) { + const LimitSkipRequirement& required = + getPropertyConst<LimitSkipRequirement>(_physProps); + LimitSkipRequirement merged(required.getLimit(), required.getSkip()); + + combineLimitSkipProperties(merged, newProp); + // Continue with new unenforced requirement. + newProp = std::move(merged); + } + + setPropertyOverwrite<LimitSkipRequirement>(newProps, std::move(newProp)); + getProperty<DistributionRequirement>(newProps).setDisableExchanges(false); + + optimizeUnderNewProperties(_queue, kDefaultPriority, node.getChild(), std::move(newProps)); + } + + void operator()(const ABT& /*n*/, const ExchangeNode& node) { + optimizeSimplePropertyNode<ExchangeNode, DistributionRequirement>(node); + } + + void operator()(const ABT& n, const RootNode& node) { + PhysProps newProps = _physProps; + setPropertyOverwrite<ProjectionRequirement>(newProps, node.getProperty()); + getProperty<DistributionRequirement>(newProps).setDisableExchanges(false); + + ABT rootNode = n; + optimizeChild<RootNode>(_queue, kDefaultPriority, std::move(rootNode), std::move(newProps)); + } + + template <typename T> + void operator()(const ABT& /*n*/, const T& /*node*/) { + static_assert(!canBeLogicalNode<T>(), "Logical node must implement its visitor."); + } + + ImplementationVisitor(const Memo& memo, + const QueryHints& hints, + const opt::unordered_map<std::string, ProjectionName>& ridProjections, + PrefixId& prefixId, + PhysRewriteQueue& queue, + const PhysProps& physProps, + const LogicalProps& logicalProps) + : _memo(memo), + _hints(hints), + _ridProjections(ridProjections), + _prefixId(prefixId), + _queue(queue), + _physProps(physProps), + _logicalProps(logicalProps) {} + +private: + template <class NodeType, class PropType> + void optimizeSimplePropertyNode(const NodeType& node) { + const PropType& nodeProp = node.getProperty(); + PhysProps newProps = _physProps; + setPropertyOverwrite<PropType>(newProps, nodeProp); + + getProperty<DistributionRequirement>(newProps).setDisableExchanges(false); + optimizeUnderNewProperties(_queue, kDefaultPriority, node.getChild(), std::move(newProps)); + } + + struct IndexAvailableDirections { + // Keep track if we can match against forward or backward direction. 
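+ // Both directions start out available and are invalidated individually as index
+ // collation entries fail to match the required collation.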
+ bool _forward = true; + bool _backward = true; + }; + + IndexAvailableDirections indexSatisfiesCollation( + const IndexCollationSpec& indexCollationSpec, + const CandidateIndexEntry& candidateIndexEntry, + const ProjectionCollationSpec& requiredCollationSpec) { + if (requiredCollationSpec.empty()) { + return {true, true}; + } + + IndexAvailableDirections result; + size_t collationSpecIndex = 0; + bool indexSuitable = true; + const auto& fieldProjections = candidateIndexEntry._fieldProjectionMap._fieldProjections; + + // Verify the index is compatible with our collation requirement, and can deliver the right + // order of paths. + for (size_t indexField = 0; indexField < indexCollationSpec.size(); indexField++) { + const bool needsCollation = candidateIndexEntry._fieldsToCollate.count(indexField) > 0; + + auto it = fieldProjections.find(encodeIndexKeyName(indexField)); + if (it == fieldProjections.cend()) { + // No bound projection for this index field. + if (needsCollation) { + // We cannot satisfy the rest of the collation requirements. + indexSuitable = false; + break; + } + continue; + } + const ProjectionName& projName = it->second; + + if (!needsCollation) { + // We do not need to collate this field because of equality. + if (requiredCollationSpec.at(collationSpecIndex).first == projName) { + // We can satisfy the next collation requirement independent of collation op. + if (++collationSpecIndex >= requiredCollationSpec.size()) { + break; + } + } + continue; + } + + // Check if we can satisfy the next collation requirement. + const auto& collationEntry = requiredCollationSpec.at(collationSpecIndex); + if (collationEntry.first != projName) { + indexSuitable = false; + break; + } + + const auto& indexCollationEntry = indexCollationSpec.at(indexField); + if (result._forward && + !collationOpsCompatible(indexCollationEntry._op, collationEntry.second)) { + result._forward = false; + } + if (result._backward && + !collationOpsCompatible(reverseCollationOp(indexCollationEntry._op), + collationEntry.second)) { + result._backward = false; + } + if (!result._forward && !result._backward) { + indexSuitable = false; + break; + } + if (++collationSpecIndex >= requiredCollationSpec.size()) { + break; + } + } + + if (!indexSuitable || collationSpecIndex < requiredCollationSpec.size()) { + return {false, false}; + } + return result; + } + + /** + * Check if we are under index-only requirements and expression introduces dependency on scan + * projection. 
+ */ + bool checkIntroducesScanProjectionUnderIndexOnly(const VariableNameSetType& references) { + return hasProperty<IndexingAvailability>(_logicalProps) && + getPropertyConst<IndexingRequirement>(_physProps).getIndexReqTarget() == + IndexReqTarget::Index && + references.find( + getPropertyConst<IndexingAvailability>(_logicalProps).getScanProjection()) != + references.cend(); + } + + bool distributionsCompatible(const IndexReqTarget target, + const DistributionAndPaths& distributionAndPaths, + const ProjectionName& scanProjection, + const LogicalProps& scanLogicalProps, + const PartialSchemaRequirements& reqMap, + bool& canUseParallelScan) { + const DistributionRequirement& required = + getPropertyConst<DistributionRequirement>(_physProps); + const auto& distribAndProjections = required.getDistributionAndProjections(); + + const auto& scanDistributions = + getPropertyConst<DistributionAvailability>(scanLogicalProps).getDistributionSet(); + + switch (distribAndProjections._type) { + case DistributionType::Centralized: + return scanDistributions.count({DistributionType::Centralized}) > 0 || + scanDistributions.count({DistributionType::Replicated}) > 0; + + case DistributionType::Replicated: + return scanDistributions.count({DistributionType::Replicated}) > 0; + + case DistributionType::RoundRobin: + if (target == IndexReqTarget::Seek) { + // We can satisfy Seek with RoundRobin if we can scan the collection in + // parallel. + return scanDistributions.count({DistributionType::UnknownPartitioning}) > 0; + } + + // TODO: Are two round robin distributions compatible? + return false; + + case DistributionType::UnknownPartitioning: + if (target == IndexReqTarget::Index) { + // We cannot satisfy unknown partitioning with an index as (unlike parallel + // collection scan) we currently cannot perform a parallel index scan. 
+ return false; + } + + if (scanDistributions.count({DistributionType::UnknownPartitioning}) > 0) { + canUseParallelScan = true; + return true; + } + return false; + + case DistributionType::HashPartitioning: + case DistributionType::RangePartitioning: { + if (distribAndProjections._type != distributionAndPaths._type) { + return false; + } + + size_t distributionPartitionIndex = 0; + const ProjectionNameVector& requiredProjections = + distribAndProjections._projectionNames; + + for (const ABT& partitioningPath : distributionAndPaths._paths) { + auto it = reqMap.find(PartialSchemaKey{scanProjection, partitioningPath}); + if (it == reqMap.cend()) { + return false; + } + + if (it->second.getBoundProjectionName() != + requiredProjections.at(distributionPartitionIndex)) { + return false; + } + distributionPartitionIndex++; + } + + return distributionPartitionIndex == requiredProjections.size(); + } + + default: + MONGO_UNREACHABLE; + } + } + + void setCollationForRIDIntersect(const CollationSplitResult& collationSplit, + PhysProps& leftPhysProps, + PhysProps& rightPhysProps) { + if (collationSplit._leftCollation.empty()) { + removeProperty<CollationRequirement>(leftPhysProps); + } else { + setPropertyOverwrite<CollationRequirement>(leftPhysProps, + collationSplit._leftCollation); + } + + if (collationSplit._rightCollation.empty()) { + removeProperty<CollationRequirement>(rightPhysProps); + } else { + setPropertyOverwrite<CollationRequirement>(rightPhysProps, + collationSplit._rightCollation); + } + } + + void optimizeRIDIntersect(const bool isIndex, + const bool dedupRID, + const bool useMergeJoin, + const ProjectionName& ridProjectionName, + const CollationSplitResult& collationLeftRightSplit, + const CollationSplitResult& collationRightLeftSplit, + const CEType intersectedCE, + const CEType leftCE, + const CEType rightCE, + const PhysProps& leftPhysProps, + const PhysProps& rightPhysProps, + const ABT& leftChild, + const ABT& rightChild) { + if (isIndex && collationRightLeftSplit._validSplit && + (!collationLeftRightSplit._validSplit || leftCE > rightCE)) { + // Need to reverse the left and right side as the left collation split is not valid, or + // to use the larger CE as the other side. + optimizeRIDIntersect(true /*isIndex*/, + dedupRID, + useMergeJoin, + ridProjectionName, + collationRightLeftSplit, + {}, + intersectedCE, + rightCE, + leftCE, + rightPhysProps, + leftPhysProps, + rightChild, + leftChild); + return; + } + if (!collationLeftRightSplit._validSplit) { + return; + } + + if (isIndex) { + if (useMergeJoin && !_hints._disableMergeJoinRIDIntersect) { + // Try a merge join on RID since both of our children only have equality + // predicates. 
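+ // With equality-only predicates each child should produce its RIDs in index order,
+ // so the two RID streams can be intersected by merging rather than hashing.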
+ + NodeCEMap nodeCEMap; + ChildPropsType childProps; + PhysProps leftPhysPropsLocal = leftPhysProps; + PhysProps rightPhysPropsLocal = rightPhysProps; + + setCollationForRIDIntersect( + collationLeftRightSplit, leftPhysPropsLocal, rightPhysPropsLocal); + if (dedupRID) { + getProperty<IndexingRequirement>(leftPhysPropsLocal) + .setDedupRID(true /*dedupRID*/); + getProperty<IndexingRequirement>(rightPhysPropsLocal) + .setDedupRID(true /*dedupRID*/); + } + + ABT physNode = lowerRIDIntersectMergeJoin(_prefixId, + ridProjectionName, + intersectedCE, + leftCE, + rightCE, + leftPhysPropsLocal, + rightPhysPropsLocal, + leftChild, + rightChild, + nodeCEMap, + childProps); + optimizeChildrenNoAssert(_queue, + kDefaultPriority, + std::move(physNode), + std::move(childProps), + std::move(nodeCEMap)); + } else { + if (!_hints._disableHashJoinRIDIntersect) { + // Try a HashJoin. Propagate dedupRID on left and right indexing + // requirements. + + NodeCEMap nodeCEMap; + ChildPropsType childProps; + PhysProps leftPhysPropsLocal = leftPhysProps; + PhysProps rightPhysPropsLocal = rightPhysProps; + + setCollationForRIDIntersect( + collationLeftRightSplit, leftPhysPropsLocal, rightPhysPropsLocal); + if (dedupRID) { + getProperty<IndexingRequirement>(leftPhysPropsLocal) + .setDedupRID(true /*dedupRID*/); + getProperty<IndexingRequirement>(rightPhysPropsLocal) + .setDedupRID(true /*dedupRID*/); + } + + ABT physNode = lowerRIDIntersectHashJoin(_prefixId, + ridProjectionName, + intersectedCE, + leftCE, + rightCE, + leftPhysPropsLocal, + rightPhysPropsLocal, + leftChild, + rightChild, + nodeCEMap, + childProps); + optimizeChildrenNoAssert(_queue, + kDefaultPriority, + std::move(physNode), + std::move(childProps), + std::move(nodeCEMap)); + } + + // We can only attempt this strategy if we have no collation requirements. + if (!_hints._disableGroupByAndUnionRIDIntersect && dedupRID && + collationLeftRightSplit._leftCollation.empty() && + collationLeftRightSplit._rightCollation.empty()) { + // Try a Union+GroupBy. left and right indexing requirements are already + // initialized to not dedup. 
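+ // Intersection here is computed by unioning the RID streams of both children and
+ // grouping on the RID, so neither child has to deduplicate its own RIDs first.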
+ + NodeCEMap nodeCEMap; + ChildPropsType childProps; + PhysProps leftPhysPropsLocal = leftPhysProps; + PhysProps rightPhysPropsLocal = rightPhysProps; + + setCollationForRIDIntersect( + collationLeftRightSplit, leftPhysPropsLocal, rightPhysPropsLocal); + + ABT physNode = lowerRIDIntersectGroupBy(_prefixId, + ridProjectionName, + intersectedCE, + leftCE, + rightCE, + _physProps, + leftPhysPropsLocal, + rightPhysPropsLocal, + leftChild, + rightChild, + nodeCEMap, + childProps); + optimizeChildrenNoAssert(_queue, + kDefaultPriority, + std::move(physNode), + std::move(childProps), + std::move(nodeCEMap)); + } + } + } else { + ABT physicalJoin = make<BinaryJoinNode>(JoinType::Inner, + ProjectionNameSet{ridProjectionName}, + Constant::boolean(true), + leftChild, + rightChild); + + PhysProps leftPhysPropsLocal = leftPhysProps; + PhysProps rightPhysPropsLocal = rightPhysProps; + setCollationForRIDIntersect( + collationLeftRightSplit, leftPhysPropsLocal, rightPhysPropsLocal); + + optimizeChildren<BinaryJoinNode>(_queue, + kDefaultPriority, + std::move(physicalJoin), + std::move(leftPhysPropsLocal), + std::move(rightPhysPropsLocal)); + } + } + + // We don't own any of those; + const Memo& _memo; + const QueryHints& _hints; + const opt::unordered_map<std::string, ProjectionName>& _ridProjections; + PrefixId& _prefixId; + PhysRewriteQueue& _queue; + const PhysProps& _physProps; + const LogicalProps& _logicalProps; +}; + +void addImplementers(const Memo& memo, + const QueryHints& hints, + const opt::unordered_map<std::string, ProjectionName>& ridProjections, + PrefixId& prefixId, + PhysOptimizationResult& bestResult, + const properties::LogicalProps& logicalProps, + const OrderPreservingABTSet& logicalNodes) { + ImplementationVisitor visitor(memo, + hints, + ridProjections, + prefixId, + bestResult._queue, + bestResult._physProps, + logicalProps); + while (bestResult._lastImplementedNodePos < logicalNodes.size()) { + logicalNodes.at(bestResult._lastImplementedNodePos++).visit(visitor); + } +} + +} // namespace mongo::optimizer::cascades diff --git a/src/mongo/db/query/optimizer/cascades/implementers.h b/src/mongo/db/query/optimizer/cascades/implementers.h new file mode 100644 index 00000000000..c26bbbbdf69 --- /dev/null +++ b/src/mongo/db/query/optimizer/cascades/implementers.h @@ -0,0 +1,50 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. 
If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#pragma once + +#include "mongo/db/query/optimizer/cascades/memo.h" +#include "mongo/db/query/optimizer/cascades/rewrite_queues.h" +#include "mongo/db/query/optimizer/utils/utils.h" + +namespace mongo::optimizer::cascades { + +/** + * Adds operator implementation rules for particular group and physical properties. Only add + * rewrites for newly added logical nodes + */ +void addImplementers(const Memo& memo, + const QueryHints& hints, + const opt::unordered_map<std::string, ProjectionName>& ridProjections, + PrefixId& prefixId, + PhysOptimizationResult& bestResult, + const properties::LogicalProps& logicalProps, + const OrderPreservingABTSet& logicalNodes); + +} // namespace mongo::optimizer::cascades diff --git a/src/mongo/db/query/optimizer/cascades/interfaces.h b/src/mongo/db/query/optimizer/cascades/interfaces.h new file mode 100644 index 00000000000..2a09720bed5 --- /dev/null +++ b/src/mongo/db/query/optimizer/cascades/interfaces.h @@ -0,0 +1,81 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#pragma once + +#include "mongo/db/query/optimizer/node_defs.h" +#include "mongo/db/query/optimizer/props.h" + +namespace mongo::optimizer::cascades { + +class Memo; + +/** + * Interface for deriving properties. Typically we supply memo, but may derive without a memo as + * long as metadata is provided. 
+ */ +class LogicalPropsInterface { +public: + virtual ~LogicalPropsInterface() = default; + + using NodePropsMap = opt::unordered_map<const Node*, properties::LogicalProps>; + + virtual properties::LogicalProps deriveProps(const Metadata& metadata, + ABT::reference_type nodeRef, + NodePropsMap* nodePropsMap = nullptr, + const Memo* memo = nullptr, + GroupIdType groupId = -1) const = 0; +}; + +/** + * Interface for deriving CE for a newly added logical node in a new memo group. + */ +class CEInterface { +public: + virtual ~CEInterface() = default; + + virtual CEType deriveCE(const Memo& memo, + const properties::LogicalProps& logicalProps, + ABT::reference_type logicalNodeRef) const = 0; +}; + +/** + * Interface for deriving costs and adjusted CE (based on physical props) for a physical node. + */ +class CostingInterface { +public: + virtual ~CostingInterface() = default; + virtual CostAndCE deriveCost(const Memo& memo, + const properties::PhysProps& physProps, + ABT::reference_type physNodeRef, + const ChildPropsType& childProps, + const NodeCEMap& nodeCEMap) const = 0; +}; + +} // namespace mongo::optimizer::cascades diff --git a/src/mongo/db/query/optimizer/cascades/logical_props_derivation.cpp b/src/mongo/db/query/optimizer/cascades/logical_props_derivation.cpp new file mode 100644 index 00000000000..4f2a655a363 --- /dev/null +++ b/src/mongo/db/query/optimizer/cascades/logical_props_derivation.cpp @@ -0,0 +1,494 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. 
+ */ + +#include "mongo/db/query/optimizer/cascades/logical_props_derivation.h" +#include "mongo/db/query/optimizer/utils/utils.h" + +namespace mongo::optimizer::cascades { + +using namespace properties; + +static void populateInitialDistributions(const DistributionAndPaths& distributionAndPaths, + const bool isMultiPartition, + DistributionSet& distributions) { + switch (distributionAndPaths._type) { + case DistributionType::Centralized: + distributions.insert({DistributionType::Centralized}); + break; + + case DistributionType::Replicated: + uassert(6624106, "Invalid distribution specification", isMultiPartition); + + distributions.insert({DistributionType::Centralized}); + distributions.insert({DistributionType::Replicated}); + break; + + case DistributionType::HashPartitioning: + case DistributionType::RangePartitioning: + case DistributionType::UnknownPartitioning: + uassert(6624107, "Invalid distribution specification", isMultiPartition); + + distributions.insert({DistributionType::UnknownPartitioning}); + break; + + default: + uasserted(6624108, "Invalid collection distribution"); + } +} + +static void populateDistributionPaths(const PartialSchemaRequirements& req, + const ProjectionName& scanProjectionName, + const DistributionAndPaths& distributionAndPaths, + DistributionSet& distributions) { + switch (distributionAndPaths._type) { + case DistributionType::HashPartitioning: + case DistributionType::RangePartitioning: { + ProjectionNameVector distributionProjections; + + for (const ABT& path : distributionAndPaths._paths) { + auto it = req.find(PartialSchemaKey{scanProjectionName, path}); + if (it == req.cend()) { + break; + } + if (it->second.hasBoundProjectionName()) { + distributionProjections.push_back(it->second.getBoundProjectionName()); + } + } + + if (distributionProjections.size() == distributionAndPaths._paths.size()) { + distributions.emplace(distributionAndPaths._type, + std::move(distributionProjections)); + } + } + + default: + break; + } +} + +static bool computePossiblyEqPredsOnly(const PartialSchemaRequirements& reqMap) { + PartialSchemaRequirements equalitiesReqMap; + PartialSchemaRequirements fullyOpenReqMap; + + for (const auto& [key, req] : reqMap) { + const auto& intervals = req.getIntervals(); + if (auto singularInterval = IntervalReqExpr::getSingularDNF(intervals)) { + if (singularInterval->isFullyOpen()) { + fullyOpenReqMap.emplace(key, req); + } else if (singularInterval->isEquality()) { + equalitiesReqMap.emplace(key, req); + } else { + // Encountered a non-equality and not-fully-open interval. + return false; + } + } else { + // Encountered a non-trivial interval. + return false; + } + } + + PartialSchemaKeySet resultKeySet; + PartialSchemaRequirement req_unused; + for (const auto& [key, req] : fullyOpenReqMap) { + findMatchingSchemaRequirement( + key, equalitiesReqMap, resultKeySet, req_unused, false /*setIntervalsAndBoundProj*/); + if (resultKeySet.empty()) { + // No possible match for fully open requirement. 
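+ // A fully-open requirement passes only if a matching equality requirement exists;
+ // otherwise the predicate set is not considered equality-only.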
+ return false; + } + } + + return true; +} + +class DeriveLogicalProperties { +public: + LogicalProps transport(const ScanNode& node, LogicalProps /*bindResult*/) { + DistributionSet distributions; + + const auto& scanDef = _metadata._scanDefs.at(node.getScanDefName()); + populateInitialDistributions( + scanDef.getDistributionAndPaths(), _metadata.isParallelExecution(), distributions); + for (const auto& entry : scanDef.getIndexDefs()) { + populateInitialDistributions(entry.second.getDistributionAndPaths(), + _metadata.isParallelExecution(), + distributions); + } + + return maybeUpdateNodePropsMap( + node, + makeLogicalProps(IndexingAvailability(_groupId, + node.getProjectionName(), + node.getScanDefName(), + true /*possiblyEqPredsOnly*/, + {} /*satisfiedPartialIndexes*/), + CollectionAvailability({node.getScanDefName()}), + DistributionAvailability(std::move(distributions)))); + } + + LogicalProps transport(const ValueScanNode& node, LogicalProps /*bindResult*/) { + // We do not originate indexing availability, and have empty collection availability with + // Centralized + Replicated distribution availability. During physical optimization we + // accept optimization under any distribution. + LogicalProps result = + makeLogicalProps(CollectionAvailability{{}}, DistributionAvailability{{}}); + addCentralizedAndRoundRobinDistributions(result); + return maybeUpdateNodePropsMap(node, std::move(result)); + } + + LogicalProps transport(const MemoLogicalDelegatorNode& node) { + uassert(6624109, "Uninitialized memo", _memo != nullptr); + return maybeUpdateNodePropsMap(node, _memo->getGroup(node.getGroupId())._logicalProperties); + } + + LogicalProps transport(const FilterNode& node, + LogicalProps childResult, + LogicalProps /*exprResult*/) { + // Propagate indexing, collection, and distribution availabilities. + LogicalProps result = std::move(childResult); + if (hasProperty<IndexingAvailability>(result)) { + getProperty<IndexingAvailability>(result).setPossiblyEqPredsOnly(false); + } + addCentralizedAndRoundRobinDistributions(result); + return maybeUpdateNodePropsMap(node, std::move(result)); + } + + LogicalProps transport(const EvaluationNode& node, + LogicalProps childResult, + LogicalProps /*exprResult*/) { + // We are specifically not adding the node's projection to ProjectionAvailability here. + // The logical properties already contains projection availability which is derived first + // when the memo group is created. 
+ LogicalProps result = std::move(childResult); + if (hasProperty<IndexingAvailability>(result)) { + getProperty<IndexingAvailability>(result).setPossiblyEqPredsOnly(false); + } + addCentralizedAndRoundRobinDistributions(result); + return maybeUpdateNodePropsMap(node, std::move(result)); + } + + LogicalProps transport(const SargableNode& node, + LogicalProps childResult, + LogicalProps /*bindsResult*/, + LogicalProps /*refsResult*/) { + LogicalProps result = std::move(childResult); + + auto& indexingAvailability = getProperty<IndexingAvailability>(result); + const ProjectionName& scanProjectionName = indexingAvailability.getScanProjection(); + const std::string& scanDefName = indexingAvailability.getScanDefName(); + const auto& scanDef = _metadata._scanDefs.at(scanDefName); + + auto& distributions = getProperty<DistributionAvailability>(result).getDistributionSet(); + addCentralizedAndRoundRobinDistributions(result); + + populateDistributionPaths( + node.getReqMap(), scanProjectionName, scanDef.getDistributionAndPaths(), distributions); + for (const auto& entry : scanDef.getIndexDefs()) { + populateDistributionPaths(node.getReqMap(), + scanProjectionName, + entry.second.getDistributionAndPaths(), + distributions); + } + + if (indexingAvailability.getPossiblyEqPredsOnly()) { + indexingAvailability.setPossiblyEqPredsOnly( + computePossiblyEqPredsOnly(node.getReqMap())); + } + + auto& satisfiedPartialIndexes = + getProperty<IndexingAvailability>(result).getSatisfiedPartialIndexes(); + for (const auto& [indexDefName, indexDef] : scanDef.getIndexDefs()) { + if (!indexDef.getPartialReqMap().empty()) { + auto intersection = node.getReqMap(); + // We specifically ignore projectionRenames here. + ProjectionRenames projectionRenames_unused; + if (intersectPartialSchemaReq( + intersection, indexDef.getPartialReqMap(), projectionRenames_unused) && + intersection == node.getReqMap()) { + satisfiedPartialIndexes.insert(indexDefName); + } + } + } + + return maybeUpdateNodePropsMap(node, std::move(result)); + } + + LogicalProps transport(const RIDIntersectNode& node, + LogicalProps /*leftChildResult*/, + LogicalProps /*rightChildResult*/) { + // Properties for the group should already be derived via the underlying Filter or + // Evaluation logical nodes. + uasserted(6624042, "Should not be necessary to derive properties for RIDIntersectNode"); + } + + LogicalProps transport(const BinaryJoinNode& node, + LogicalProps /*leftChildResult*/, + LogicalProps /*rightChildResult*/, + LogicalProps /*exprResult*/) { + // TODO: remove indexing availability property when implemented. + // TODO: combine scan defs from all children for CollectionAvailability. + uasserted(6624043, "Logical property derivation not implemented."); + } + + LogicalProps transport(const UnionNode& node, + std::vector<LogicalProps> childResults, + LogicalProps bindResult, + LogicalProps refsResult) { + uassert(6624044, "Unexpected empty child results for union node", !childResults.empty()); + + // We are specifically not adding the node's projection to ProjectionAvailability here. + // The logical properties already contains projection availability which is derived first + // when the memo group is created. 
+ LogicalProps result = std::move(childResults[0]); + auto& mergedScanDefs = getProperty<CollectionAvailability>(result).getScanDefSet(); + auto& mergedDistributionSet = + getProperty<DistributionAvailability>(result).getDistributionSet(); + for (size_t childIdx = 1; childIdx < childResults.size(); childIdx++) { + auto childScanDefs = + getProperty<CollectionAvailability>(childResults[childIdx]).getScanDefSet(); + mergedScanDefs.merge(std::move(childScanDefs)); + + // Only keep the distribution properties which are common across all children + // distributions. + const auto& childDistributionSet = + getProperty<DistributionAvailability>(childResults[childIdx]).getDistributionSet(); + + for (auto it = mergedDistributionSet.begin(); it != mergedDistributionSet.end(); it++) { + if (childDistributionSet.find(*it) == childDistributionSet.end()) { + mergedDistributionSet.erase(it); + } + } + } + + // Verify that there is at least one common distribution available. + uassert(6624045, "No common distributions for union", !mergedDistributionSet.empty()); + + removeProperty<IndexingAvailability>(result); + return maybeUpdateNodePropsMap(node, std::move(result)); + } + + LogicalProps transport(const GroupByNode& node, + LogicalProps childResult, + LogicalProps /*bindAggResult*/, + LogicalProps /*refsAggResult*/, + LogicalProps /*bindGbResult*/, + LogicalProps /*refsGbResult*/) { + LogicalProps result = std::move(childResult); + removeProperty<IndexingAvailability>(result); + + auto& distributions = getProperty<DistributionAvailability>(result).getDistributionSet(); + addCentralizedAndRoundRobinDistributions<false /*addRoundRobin*/>(distributions); + + if (_metadata.isParallelExecution() && node.getType() != GroupNodeType::Local) { + distributions.erase({DistributionType::UnknownPartitioning}); + distributions.erase({DistributionType::RoundRobin}); + + // We propagate hash and range partitioning only if we are global agg. + const ProjectionNameVector& groupByProjections = node.getGroupByProjectionNames(); + if (!groupByProjections.empty()) { + DistributionRequirement allowedRangePartitioning{ + {DistributionType::RangePartitioning, groupByProjections}}; + for (auto it = distributions.begin(); it != distributions.end();) { + switch (it->_type) { + case DistributionType::HashPartitioning: + // Erase all hash partition distributions. New ones will be generated + // after. + distributions.erase(it++); + break; + + case DistributionType::RangePartitioning: + // Retain only the range partition which contains the group by + // projections in the node order. + if (*it == allowedRangePartitioning.getDistributionAndProjections()) { + it++; + } else { + distributions.erase(it++); + } + + default: + it++; + break; + } + } + + // Generate hash distributions using the power set of group-by projections. 
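+ // Each bitmask selects a non-empty subset of the group-by projections (in their
+ // original order); e.g. for projections {a, b} this yields {a}, {b}, and {a, b}.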
+ for (size_t mask = 1; mask < (1ull << groupByProjections.size()); mask++) { + ProjectionNameVector projectionNames; + for (size_t index = 0; index < groupByProjections.size(); index++) { + if ((mask & (1ull << index)) != 0) { + projectionNames.push_back(groupByProjections.at(index)); + } + } + distributions.emplace(DistributionType::HashPartitioning, + std::move(projectionNames)); + } + } + } + + return maybeUpdateNodePropsMap(node, std::move(result)); + } + + LogicalProps transport(const UnwindNode& node, + LogicalProps childResult, + LogicalProps /*bindResult*/, + LogicalProps /*refsResult*/) { + LogicalProps result = std::move(childResult); + removeProperty<IndexingAvailability>(result); + + const ProjectionName& unwoundProjectionName = node.getProjectionName(); + auto& distributions = getProperty<DistributionAvailability>(result).getDistributionSet(); + addCentralizedAndRoundRobinDistributions(distributions); + + if (_metadata.isParallelExecution()) { + for (auto it = distributions.begin(); it != distributions.end();) { + switch (it->_type) { + case DistributionType::HashPartitioning: + case DistributionType::RangePartitioning: { + // Erase partitioned distributions which contain the projection to unwind. + bool containsProjection = false; + for (const ProjectionName& projectionName : it->_projectionNames) { + if (projectionName == unwoundProjectionName) { + containsProjection = true; + break; + } + } + if (containsProjection) { + distributions.erase(it); + } + it++; + break; + } + + default: + it++; + break; + } + } + } + + return maybeUpdateNodePropsMap(node, std::move(result)); + } + + LogicalProps transport(const CollationNode& node, + LogicalProps childResult, + LogicalProps /*refsResult*/) { + LogicalProps result = std::move(childResult); + // We propagate indexing availability. + + addCentralizedAndRoundRobinDistributions<false /*addRoundRobin*/>(result); + return maybeUpdateNodePropsMap(node, std::move(result)); + } + + LogicalProps transport(const LimitSkipNode& node, LogicalProps childResult) { + LogicalProps result = std::move(childResult); + removeProperty<IndexingAvailability>(result); + addCentralizedAndRoundRobinDistributions<false /*addRoundRobin*/>(result); + return maybeUpdateNodePropsMap(node, std::move(result)); + } + + LogicalProps transport(const ExchangeNode& node, + LogicalProps childResult, + LogicalProps /*refsResult*/) { + LogicalProps result = std::move(childResult); + removeProperty<IndexingAvailability>(result); + addCentralizedAndRoundRobinDistributions<false /*addRoundRobin*/>(result); + return maybeUpdateNodePropsMap(node, std::move(result)); + } + + LogicalProps transport(const RootNode& node, + LogicalProps childResult, + LogicalProps /*refsResult*/) { + return maybeUpdateNodePropsMap(node, std::move(childResult)); + } + + /** + * Other ABT types. + */ + template <typename T, typename... Ts> + LogicalProps transport(const T& /*node*/, Ts&&...) 
{ + static_assert(!canBeLogicalNode<T>(), + "Logical node must implement its logical property derivation."); + return {}; + } + + static LogicalProps derive(const Metadata& metadata, + const ABT::reference_type nodeRef, + LogicalPropsInterface::NodePropsMap* nodePropsMap, + const Memo* memo, + const GroupIdType groupId) { + DeriveLogicalProperties instance(memo, metadata, groupId, nodePropsMap); + return algebra::transport<false>(nodeRef, instance); + } + +private: + DeriveLogicalProperties(const Memo* memo, + const Metadata& metadata, + const GroupIdType groupId, + LogicalPropsInterface::NodePropsMap* nodePropsMap) + : _groupId(groupId), _memo(memo), _metadata(metadata), _nodePropsMap(nodePropsMap) {} + + template <bool addRoundRobin = true> + void addCentralizedAndRoundRobinDistributions(DistributionSet& distributions) { + distributions.emplace(DistributionType::Centralized); + if (addRoundRobin && _metadata.isParallelExecution()) { + distributions.emplace(DistributionType::RoundRobin); + } + } + + template <bool addRoundRobin = true> + void addCentralizedAndRoundRobinDistributions(LogicalProps& properties) { + addCentralizedAndRoundRobinDistributions<addRoundRobin>( + getProperty<DistributionAvailability>(properties).getDistributionSet()); + } + + LogicalProps maybeUpdateNodePropsMap(const Node& node, LogicalProps props) { + if (_nodePropsMap != nullptr) { + _nodePropsMap->emplace(&node, props); + } + return props; + } + + const GroupIdType _groupId; + + // We don't own any of those. + const Memo* _memo; + const Metadata& _metadata; + LogicalPropsInterface::NodePropsMap* _nodePropsMap; +}; + +properties::LogicalProps DefaultLogicalPropsDerivation::deriveProps( + const Metadata& metadata, + const ABT::reference_type nodeRef, + LogicalPropsInterface::NodePropsMap* nodePropsMap, + const Memo* memo, + const GroupIdType groupId) const { + return DeriveLogicalProperties::derive(metadata, nodeRef, nodePropsMap, memo, groupId); +} + +} // namespace mongo::optimizer::cascades diff --git a/src/mongo/db/query/optimizer/cascades/logical_props_derivation.h b/src/mongo/db/query/optimizer/cascades/logical_props_derivation.h new file mode 100644 index 00000000000..d8b4d8ec349 --- /dev/null +++ b/src/mongo/db/query/optimizer/cascades/logical_props_derivation.h @@ -0,0 +1,52 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. 
If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#pragma once + +#include "mongo/db/query/optimizer/cascades/interfaces.h" +#include "mongo/db/query/optimizer/cascades/memo.h" + +namespace mongo::optimizer::cascades { + + +/** + * Logical property derivation framework. + * If nodePropsMap is supplied, populate per-node properties. + */ +class DefaultLogicalPropsDerivation : public LogicalPropsInterface { +public: + properties::LogicalProps deriveProps(const Metadata& metadata, + ABT::reference_type nodeRef, + NodePropsMap* nodePropsMap = nullptr, + const Memo* memo = nullptr, + GroupIdType groupId = -1) const override final; +}; + + +} // namespace mongo::optimizer::cascades diff --git a/src/mongo/db/query/optimizer/cascades/logical_rewriter.cpp b/src/mongo/db/query/optimizer/cascades/logical_rewriter.cpp new file mode 100644 index 00000000000..b4a52153736 --- /dev/null +++ b/src/mongo/db/query/optimizer/cascades/logical_rewriter.cpp @@ -0,0 +1,1288 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. 
+ */ + +#include "mongo/db/query/optimizer/cascades/logical_rewriter.h" + + +namespace mongo::optimizer::cascades { + +LogicalRewriter::RewriteSet LogicalRewriter::_explorationSet = { + {LogicalRewriteType::GroupByExplore, 1}, + {LogicalRewriteType::SargableSplit, 2}, + {LogicalRewriteType::FilterRIDIntersectReorder, 2}, + {LogicalRewriteType::EvaluationRIDIntersectReorder, 2}}; + +LogicalRewriter::RewriteSet LogicalRewriter::_substitutionSet = { + {LogicalRewriteType::FilterEvaluationReorder, 1}, + {LogicalRewriteType::FilterCollationReorder, 1}, + {LogicalRewriteType::EvaluationCollationReorder, 1}, + {LogicalRewriteType::EvaluationLimitSkipReorder, 1}, + + {LogicalRewriteType::FilterGroupByReorder, 1}, + {LogicalRewriteType::GroupCollationReorder, 1}, + + {LogicalRewriteType::FilterUnwindReorder, 1}, + {LogicalRewriteType::EvaluationUnwindReorder, 1}, + {LogicalRewriteType::UnwindCollationReorder, 1}, + + {LogicalRewriteType::FilterExchangeReorder, 1}, + {LogicalRewriteType::ExchangeEvaluationReorder, 1}, + + {LogicalRewriteType::FilterUnionReorder, 1}, + + {LogicalRewriteType::CollationMerge, 1}, + {LogicalRewriteType::LimitSkipMerge, 1}, + + {LogicalRewriteType::SargableFilterReorder, 1}, + {LogicalRewriteType::SargableEvaluationReorder, 1}, + + {LogicalRewriteType::FilterValueScanPropagate, 1}, + {LogicalRewriteType::EvaluationValueScanPropagate, 1}, + {LogicalRewriteType::SargableValueScanPropagate, 1}, + {LogicalRewriteType::CollationValueScanPropagate, 1}, + {LogicalRewriteType::LimitSkipValueScanPropagate, 1}, + {LogicalRewriteType::ExchangeValueScanPropagate, 1}, + + {LogicalRewriteType::LimitSkipSubstitute, 1}, + + {LogicalRewriteType::FilterSubstitute, 2}, + {LogicalRewriteType::EvaluationSubstitute, 2}, + {LogicalRewriteType::SargableMerge, 2}}; + +LogicalRewriter::LogicalRewriter(Memo& memo, PrefixId& prefixId, RewriteSet rewriteSet) + : _activeRewriteSet(std::move(rewriteSet)), _groupsPending(), _memo(memo), _prefixId(prefixId) { + initializeRewrites(); + + if (_activeRewriteSet.count(LogicalRewriteType::SargableSplit) > 0) { + // If we are performing SargableSplit exploration rewrite, populate helper map. 
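+ // The helper map (_indexFieldPrefixMap) records, per scan definition, the set of
+ // top-level fields (e.g. "a" for an index on "a.b") appearing in its index collation
+ // specs; the SargableSplit rewrite later uses it to reject splits whose left-side
+ // requirements cannot be covered by any index.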
+ for (const auto& [scanDefName, scanDef] : _memo.getMetadata()._scanDefs) { + for (const auto& [indexDefName, indexDef] : scanDef.getIndexDefs()) { + for (const IndexCollationEntry& entry : indexDef.getCollationSpec()) { + if (auto pathPtr = entry._path.cast<PathGet>(); pathPtr != nullptr) { + _indexFieldPrefixMap[scanDefName].insert(pathPtr->name()); + } + } + } + } + } +} + +GroupIdType LogicalRewriter::addRootNode(const ABT& node) { + return addNode(node, -1, false /*addExistingNodeWithNewChild*/).first; +} + +std::pair<GroupIdType, NodeIdSet> LogicalRewriter::addNode(const ABT& node, + const GroupIdType targetGroupId, + const bool addExistingNodeWithNewChild) { + NodeIdSet insertNodeIds; + + Memo::NodeTargetGroupMap targetGroupMap; + if (targetGroupId >= 0) { + targetGroupMap = {{node.ref(), targetGroupId}}; + } + + const GroupIdType resultGroupId = _memo.integrate( + node, std::move(targetGroupMap), insertNodeIds, addExistingNodeWithNewChild); + + uassert(6624046, + "Result group is not the same as target group", + targetGroupId < 0 || targetGroupId == resultGroupId); + + for (const MemoLogicalNodeId& nodeMemoId : insertNodeIds) { + if (addExistingNodeWithNewChild && nodeMemoId._groupId == targetGroupId) { + continue; + } + + for (const auto [type, priority] : _activeRewriteSet) { + auto& groupQueue = _memo.getGroup(nodeMemoId._groupId)._logicalRewriteQueue; + groupQueue.push(std::make_unique<LogicalRewriteEntry>(priority, type, nodeMemoId)); + + _groupsPending.insert(nodeMemoId._groupId); + } + } + + return {resultGroupId, std::move(insertNodeIds)}; +} + +void LogicalRewriter::clearGroup(const GroupIdType groupId) { + _memo.clearLogicalNodes(groupId); +} + +class RewriteContext { +public: + RewriteContext(LogicalRewriter& rewriter, + const MemoLogicalNodeId aboveNodeId, + const MemoLogicalNodeId belowNodeId) + : RewriteContext(rewriter, aboveNodeId, true /*hasBelowNodeId*/, belowNodeId){}; + + RewriteContext(LogicalRewriter& rewriter, const MemoLogicalNodeId aboveNodeId) + : RewriteContext(rewriter, aboveNodeId, false /*hasBelowNodeId*/, {}){}; + + std::pair<GroupIdType, NodeIdSet> addNode(const ABT& node, + const bool substitute, + const bool addExistingNodeWithNewChild = false) { + if (substitute) { + uassert(6624110, "Cannot substitute twice", !_hasSubstituted); + _hasSubstituted = true; + + _rewriter.clearGroup(_aboveNodeId._groupId); + if (_hasBelowNodeId) { + _rewriter.clearGroup(_belowNodeId._groupId); + } + } + return _rewriter.addNode(node, _aboveNodeId._groupId, addExistingNodeWithNewChild); + } + + Memo& getMemo() const { + return _rewriter._memo; + } + + const Metadata& getMetadata() const { + return _rewriter._memo.getMetadata(); + } + + PrefixId& getPrefixId() const { + return _rewriter._prefixId; + } + + auto& getIndexFieldPrefixMap() const { + return _rewriter._indexFieldPrefixMap; + } + + const properties::LogicalProps& getAboveLogicalProps() const { + return getMemo().getGroup(_aboveNodeId._groupId)._logicalProperties; + } + + bool hasSubstituted() const { + return _hasSubstituted; + } + + MemoLogicalNodeId getAboveNodeId() const { + return _aboveNodeId; + } + + auto& getSargableSplitCountMap() const { + return _rewriter._sargableSplitCountMap; + } + +private: + RewriteContext(LogicalRewriter& rewriter, + const MemoLogicalNodeId aboveNodeId, + const bool hasBelowNodeId, + const MemoLogicalNodeId belowNodeId) + : _aboveNodeId(aboveNodeId), + _hasBelowNodeId(hasBelowNodeId), + _belowNodeId(belowNodeId), + _rewriter(rewriter), + _hasSubstituted(false){}; + + const 
MemoLogicalNodeId _aboveNodeId; + const bool _hasBelowNodeId; + const MemoLogicalNodeId _belowNodeId; + + // We don't own this. + LogicalRewriter& _rewriter; + + bool _hasSubstituted; +}; + +struct ReorderDependencies { + bool _hasNodeRef = false; + bool _hasChildRef = false; + bool _hasNodeAndChildRef = false; +}; + +template <class NodeType> +struct DefaultChildAccessor { + const ABT& operator()(const ABT& node) const { + return node.cast<NodeType>()->getChild(); + } + + ABT& operator()(ABT& node) const { + return node.cast<NodeType>()->getChild(); + } +}; + +template <class NodeType> +struct LeftChildAccessor { + const ABT& operator()(const ABT& node) const { + return node.cast<NodeType>()->getLeftChild(); + } + + ABT& operator()(ABT& node) const { + return node.cast<NodeType>()->getLeftChild(); + } +}; + +template <class NodeType> +struct RightChildAccessor { + const ABT& operator()(const ABT& node) const { + return node.cast<NodeType>()->getRightChild(); + } + + ABT& operator()(ABT& node) const { + return node.cast<NodeType>()->getRightChild(); + } +}; + +template <class AboveType, + class BelowType, + template <class> class BelowChildAccessor = DefaultChildAccessor> +ReorderDependencies computeDependencies(ABT::reference_type aboveNodeRef, + ABT::reference_type belowNodeRef, + RewriteContext& ctx) { + // Get variables from above node and check if they are bound at below node, or at below node's + // child. + const auto aboveNodeVarNames = collectVariableReferences(aboveNodeRef); + + ABT belowNode = belowNodeRef; + VariableEnvironment env = VariableEnvironment::build(belowNode, &ctx.getMemo()); + const DefinitionsMap belowNodeDefs = env.hasDefinitions(belowNode.ref()) + ? env.getDefinitions(belowNode.ref()) + : DefinitionsMap{}; + ABT::reference_type belowChild = BelowChildAccessor<BelowType>()(belowNode).ref(); + const DefinitionsMap belowChildNodeDefs = + env.hasDefinitions(belowChild) ? env.getDefinitions(belowChild) : DefinitionsMap{}; + + ReorderDependencies dependencies; + for (const std::string& varName : aboveNodeVarNames) { + auto it = belowNodeDefs.find(varName); + // Variable is exclusively defined in the below node. + const bool refersToNode = it != belowNodeDefs.cend() && it->second.definedBy == belowNode; + // Variable is defined in the belowNode's child subtree. + const bool refersToChild = belowChildNodeDefs.find(varName) != belowChildNodeDefs.cend(); + + if (refersToNode) { + if (refersToChild) { + dependencies._hasNodeAndChildRef = true; + } else { + dependencies._hasNodeRef = true; + } + } else if (refersToChild) { + dependencies._hasChildRef = true; + } else { + // Lambda variable. Ignore. 
+ } + } + + return dependencies; +} + +static ABT createEmptyValueScanNode(const RewriteContext& ctx) { + using namespace properties; + + const ProjectionNameSet& projNameSet = + getPropertyConst<ProjectionAvailability>(ctx.getAboveLogicalProps()).getProjections(); + ProjectionNameVector projNameVector; + projNameVector.insert(projNameVector.begin(), projNameSet.cbegin(), projNameSet.cend()); + return make<ValueScanNode>(std::move(projNameVector)); +} + +static void addEmptyValueScanNode(RewriteContext& ctx) { + ABT newNode = createEmptyValueScanNode(ctx); + ctx.addNode(newNode, true /*substitute*/); +} + +static void defaultPropagateEmptyValueScanNode(const ABT& n, RewriteContext& ctx) { + if (n.cast<ValueScanNode>()->getArraySize() == 0) { + addEmptyValueScanNode(ctx); + } +} + +template <class AboveType, + class BelowType, + template <class> class AboveChildAccessor = DefaultChildAccessor, + template <class> class BelowChildAccessor = DefaultChildAccessor, + bool substitute = true> +void defaultReorder(ABT::reference_type aboveNode, + ABT::reference_type belowNode, + RewriteContext& ctx) { + ABT newParent = belowNode; + ABT newChild = aboveNode; + + std::swap(BelowChildAccessor<BelowType>()(newParent), + AboveChildAccessor<AboveType>()(newChild)); + BelowChildAccessor<BelowType>()(newParent) = std::move(newChild); + + ctx.addNode(newParent, substitute); +} + +template <class AboveType, class BelowType> +void defaultReorderWithDependenceCheck(ABT::reference_type aboveNode, + ABT::reference_type belowNode, + RewriteContext& ctx) { + const ReorderDependencies dependencies = + computeDependencies<AboveType, BelowType>(aboveNode, belowNode, ctx); + if (dependencies._hasNodeRef) { + // Above node refers to a variable bound by below node. + return; + } + + defaultReorder<AboveType, BelowType>(aboveNode, belowNode, ctx); +} + +template <class AboveType, class BelowType> +struct SubstituteReorder { + void operator()(ABT::reference_type aboveNode, + ABT::reference_type belowNode, + RewriteContext& ctx) const { + defaultReorderWithDependenceCheck<AboveType, BelowType>(aboveNode, belowNode, ctx); + } +}; + +template <> +struct SubstituteReorder<FilterNode, FilterNode> { + void operator()(ABT::reference_type aboveNode, + ABT::reference_type belowNode, + RewriteContext& ctx) const { + defaultReorder<FilterNode, FilterNode>(aboveNode, belowNode, ctx); + } +}; + +template <> +struct SubstituteReorder<FilterNode, UnionNode> { + void operator()(ABT::reference_type aboveNode, + ABT::reference_type belowNode, + RewriteContext& ctx) const { + ABT newParent = belowNode; + + for (auto& childOfChild : newParent.cast<UnionNode>()->nodes()) { + ABT aboveCopy = aboveNode; + std::swap(aboveCopy.cast<FilterNode>()->getChild(), childOfChild); + std::swap(childOfChild, aboveCopy); + } + + ctx.addNode(newParent, true /*substitute*/); + } +}; + +template <class AboveType> +void unwindBelowReorder(ABT::reference_type aboveNode, + ABT::reference_type unwindNode, + RewriteContext& ctx) { + const ReorderDependencies dependencies = + computeDependencies<AboveType, UnwindNode>(aboveNode, unwindNode, ctx); + if (dependencies._hasNodeRef || dependencies._hasNodeAndChildRef) { + // Above node refers to projection being unwound. Reject rewrite. 
+ return; + } + + defaultReorder<AboveType, UnwindNode>(aboveNode, unwindNode, ctx); +} + +template <> +struct SubstituteReorder<FilterNode, UnwindNode> { + void operator()(ABT::reference_type aboveNode, + ABT::reference_type belowNode, + RewriteContext& ctx) const { + unwindBelowReorder<FilterNode>(aboveNode, belowNode, ctx); + } +}; + +template <> +struct SubstituteReorder<EvaluationNode, UnwindNode> { + void operator()(ABT::reference_type aboveNode, + ABT::reference_type belowNode, + RewriteContext& ctx) const { + unwindBelowReorder<EvaluationNode>(aboveNode, belowNode, ctx); + } +}; + +template <> +struct SubstituteReorder<UnwindNode, CollationNode> { + void operator()(ABT::reference_type aboveNode, + ABT::reference_type belowNode, + RewriteContext& ctx) const { + const ProjectionNameSet& collationProjections = + belowNode.cast<CollationNode>()->getProperty().getAffectedProjectionNames(); + if (collationProjections.find(aboveNode.cast<UnwindNode>()->getProjectionName()) != + collationProjections.cend()) { + // A projection being affected by the collation is being unwound. Reject rewrite. + return; + } + + defaultReorder<UnwindNode, CollationNode>(aboveNode, belowNode, ctx); + } +}; + +template <> +struct SubstituteReorder<FilterNode, ValueScanNode> { + void operator()(ABT::reference_type aboveNode, + ABT::reference_type belowNode, + RewriteContext& ctx) const { + defaultPropagateEmptyValueScanNode(belowNode, ctx); + } +}; + +template <> +struct SubstituteReorder<EvaluationNode, ValueScanNode> { + void operator()(ABT::reference_type aboveNode, + ABT::reference_type belowNode, + RewriteContext& ctx) const { + defaultPropagateEmptyValueScanNode(belowNode, ctx); + } +}; + +template <> +struct SubstituteReorder<SargableNode, ValueScanNode> { + void operator()(ABT::reference_type aboveNode, + ABT::reference_type belowNode, + RewriteContext& ctx) const { + defaultPropagateEmptyValueScanNode(belowNode, ctx); + } +}; + +template <> +struct SubstituteReorder<CollationNode, ValueScanNode> { + void operator()(ABT::reference_type aboveNode, + ABT::reference_type belowNode, + RewriteContext& ctx) const { + defaultPropagateEmptyValueScanNode(belowNode, ctx); + } +}; + +template <> +struct SubstituteReorder<LimitSkipNode, ValueScanNode> { + void operator()(ABT::reference_type aboveNode, + ABT::reference_type belowNode, + RewriteContext& ctx) const { + defaultPropagateEmptyValueScanNode(belowNode, ctx); + } +}; + +template <> +struct SubstituteReorder<ExchangeNode, ValueScanNode> { + void operator()(ABT::reference_type aboveNode, + ABT::reference_type belowNode, + RewriteContext& ctx) const { + defaultPropagateEmptyValueScanNode(belowNode, ctx); + } +}; + +template <class AboveType, class BelowType> +struct SubstituteMerge { + void operator()(ABT::reference_type aboveNode, + ABT::reference_type belowNode, + RewriteContext& ctx) = delete; +}; + +template <> +struct SubstituteMerge<CollationNode, CollationNode> { + void operator()(ABT::reference_type aboveNode, + ABT::reference_type belowNode, + RewriteContext& ctx) const { + ABT newRoot = aboveNode; + // Retain above property. 
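+ // Keep the above node's collation requirement and attach the below node's child
+ // directly, dropping the now-redundant below collation.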
+ newRoot.cast<CollationNode>()->getChild() = belowNode.cast<CollationNode>()->getChild(); + + ctx.addNode(newRoot, true /*substitute*/); + } +}; + +template <> +struct SubstituteMerge<LimitSkipNode, LimitSkipNode> { + void operator()(ABT::reference_type aboveNode, + ABT::reference_type belowNode, + RewriteContext& ctx) const { + using namespace properties; + + ABT newRoot = aboveNode; + LimitSkipNode& aboveCollationNode = *newRoot.cast<LimitSkipNode>(); + const LimitSkipNode& belowCollationNode = *belowNode.cast<LimitSkipNode>(); + + aboveCollationNode.getChild() = belowCollationNode.getChild(); + combineLimitSkipProperties(aboveCollationNode.getProperty(), + belowCollationNode.getProperty()); + + ctx.addNode(newRoot, true /*substitute*/); + } +}; + +static boost::optional<ABT> mergeSargableNodes( + const properties::IndexingAvailability& indexingAvailability, + const SargableNode& aboveNode, + const SargableNode& belowNode, + RewriteContext& ctx) { + if (indexingAvailability.getScanGroupId() != + belowNode.getChild().cast<MemoLogicalDelegatorNode>()->getGroupId()) { + // Do not merge if child is not another Sargable node, or the child's child is not a + // ScanNode. + return {}; + } + + PartialSchemaRequirements mergedReqs = belowNode.getReqMap(); + ProjectionRenames projectionRenames; + if (!intersectPartialSchemaReq(mergedReqs, aboveNode.getReqMap(), projectionRenames)) { + return {}; + } + if (mergedReqs.size() > LogicalRewriter::kMaxPartialSchemaReqCount) { + return {}; + } + + const ScanDefinition& scanDef = + ctx.getMetadata()._scanDefs.at(indexingAvailability.getScanDefName()); + bool hasEmptyInterval = false; + auto candidateIndexMap = computeCandidateIndexMap(ctx.getPrefixId(), + indexingAvailability.getScanProjection(), + mergedReqs, + scanDef, + hasEmptyInterval); + + if (hasEmptyInterval) { + return createEmptyValueScanNode(ctx); + } + + ABT result = make<SargableNode>(std::move(mergedReqs), + std::move(candidateIndexMap), + IndexReqTarget::Complete, + belowNode.getChild()); + applyProjectionRenames(std::move(projectionRenames), result); + return result; +} + +template <> +struct SubstituteMerge<SargableNode, SargableNode> { + void operator()(ABT::reference_type aboveNode, + ABT::reference_type belowNode, + RewriteContext& ctx) const { + using namespace properties; + const auto& result = + mergeSargableNodes(getPropertyConst<IndexingAvailability>(ctx.getAboveLogicalProps()), + *aboveNode.cast<SargableNode>(), + *belowNode.cast<SargableNode>(), + ctx); + if (result) { + ctx.addNode(*result, true /*substitute*/); + } + } +}; + +template <class Type> +struct SubstituteConvert { + void operator()(ABT::reference_type nodeRef, RewriteContext& ctx) = delete; +}; + +template <> +struct SubstituteConvert<LimitSkipNode> { + void operator()(ABT::reference_type node, RewriteContext& ctx) { + if (node.cast<LimitSkipNode>()->getProperty().getLimit() == 0) { + addEmptyValueScanNode(ctx); + } + } +}; + +static void addElemMatchAndSargableNode(const ABT& node, ABT sargableNode, RewriteContext& ctx) { + ABT newNode = node; + newNode.cast<FilterNode>()->getChild() = std::move(sargableNode); + ctx.addNode(newNode, false /*substitute*/, true /*addExistingNodeWithNewChild*/); +} + +void convertFilterToSargableNode(ABT::reference_type node, + const FilterNode& filterNode, + RewriteContext& ctx) { + using namespace properties; + + const LogicalProps& props = ctx.getAboveLogicalProps(); + if (!hasProperty<IndexingAvailability>(props)) { + // Can only convert to sargable node if we have indexing 
availability. + return; + } + + const auto& indexingAvailability = getPropertyConst<IndexingAvailability>(props); + const ScanDefinition& scanDef = + ctx.getMetadata()._scanDefs.at(indexingAvailability.getScanDefName()); + if (!scanDef.exists()) { + // Do not attempt to optimize for non-existing collections. + return; + } + + PartialSchemaReqConversion conversion = convertExprToPartialSchemaReq(filterNode.getFilter()); + if (!conversion._success) { + return; + } + if (conversion._hasEmptyInterval) { + addEmptyValueScanNode(ctx); + return; + } + + for (const auto& entry : conversion._reqMap) { + uassert(6624111, + "Filter partial schema requirement must contain a variable name.", + !entry.first._projectionName.empty()); + uassert(6624112, + "Filter partial schema requirement cannot bind.", + !entry.second.hasBoundProjectionName()); + uassert(6624113, + "Filter partial schema requirement must have a range.", + !isIntervalReqFullyOpenDNF(entry.second.getIntervals())); + } + + bool hasEmptyInterval = false; + auto candidateIndexMap = computeCandidateIndexMap(ctx.getPrefixId(), + indexingAvailability.getScanProjection(), + conversion._reqMap, + scanDef, + hasEmptyInterval); + + if (hasEmptyInterval) { + addEmptyValueScanNode(ctx); + } else { + ABT sargableNode = make<SargableNode>(std::move(conversion._reqMap), + std::move(candidateIndexMap), + IndexReqTarget::Complete, + filterNode.getChild()); + + ctx.addNode(sargableNode, true /*substitute*/); + } +} + +static ABT appendFieldPath(const FieldPathType& fieldPath, ABT input) { + for (size_t index = fieldPath.size(); index-- > 0;) { + input = make<PathGet>(fieldPath.at(index), std::move(input)); + } + return input; +} + +template <> +struct SubstituteConvert<FilterNode> { + void operator()(ABT::reference_type node, RewriteContext& ctx) { + const FilterNode& filterNode = *node.cast<FilterNode>(); + + // Sub-rewrite: attempt to de-compose filter. If we have a path with a prefix of PathGet's + // followed by a PathComposeM, then split into two filter nodes at the composition and + // retain the prefix for each. + // TODO: consider using a standalone rewrite. + if (auto evalFilter = filterNode.getFilter().cast<EvalFilter>(); evalFilter != nullptr) { + ABT::reference_type pathRef = evalFilter->getPath().ref(); + FieldPathType fieldPath; + for (;;) { + if (auto newPath = pathRef.cast<PathGet>(); newPath != nullptr) { + fieldPath.push_back(newPath->name()); + pathRef = newPath->getPath().ref(); + } else { + break; + } + } + + if (auto composition = pathRef.cast<PathComposeM>(); composition != nullptr) { + // Remove the path composition and insert two filter nodes. + ABT filterNode1 = make<FilterNode>( + make<EvalFilter>(appendFieldPath(fieldPath, composition->getPath1()), + evalFilter->getInput()), + filterNode.getChild()); + ABT filterNode2 = make<FilterNode>( + make<EvalFilter>(appendFieldPath(fieldPath, composition->getPath2()), + evalFilter->getInput()), + std::move(filterNode1)); + + ctx.addNode(filterNode2, true /*substitute*/); + return; + } + } + + convertFilterToSargableNode(node, filterNode, ctx); + } +}; + +template <> +struct SubstituteConvert<EvaluationNode> { + void operator()(ABT::reference_type node, RewriteContext& ctx) { + using namespace properties; + + const LogicalProps props = ctx.getAboveLogicalProps(); + if (!hasProperty<IndexingAvailability>(props)) { + // Can only convert to sargable node if we have indexing availability. 
+ return; + } + + const auto& indexingAvailability = getPropertyConst<IndexingAvailability>(props); + const ProjectionName& scanProjName = indexingAvailability.getScanProjection(); + + const ScanDefinition& scanDef = + ctx.getMetadata()._scanDefs.at(indexingAvailability.getScanDefName()); + if (!scanDef.exists()) { + // Do not attempt to optimize for non-existing collections. + return; + } + + const EvaluationNode& evalNode = *node.cast<EvaluationNode>(); + + // Sub-rewrite: attempt to convert Keep to a chain of individual evaluations. + // TODO: consider using a standalone rewrite. + if (auto evalPathPtr = evalNode.getProjection().cast<EvalPath>(); evalPathPtr != nullptr) { + if (auto inputPtr = evalPathPtr->getInput().cast<Variable>(); + inputPtr != nullptr && inputPtr->name() == scanProjName) { + if (auto pathKeepPtr = evalPathPtr->getPath().cast<PathKeep>(); + pathKeepPtr != nullptr) { + // Optimization. If we are retaining fields on the root level, generate + // EvalNodes with the intention of converting later to a SargableNode after + // reordering, in order to be able to cover the fields using a physical scan or + // index. + + ABT result = evalNode.getChild(); + ABT keepPath = make<PathIdentity>(); + + std::set<std::string> orderedSet; + for (const std::string& fieldName : pathKeepPtr->getNames()) { + orderedSet.insert(fieldName); + } + for (const std::string& fieldName : orderedSet) { + ProjectionName projName = ctx.getPrefixId().getNextId("fieldProj"); + result = make<EvaluationNode>( + projName, + make<EvalPath>(make<PathGet>(fieldName, make<PathIdentity>()), + evalPathPtr->getInput()), + std::move(result)); + + maybeComposePath(keepPath, + make<PathField>(fieldName, + make<PathConstant>( + make<Variable>(std::move(projName))))); + } + + result = make<EvaluationNode>( + evalNode.getProjectionName(), + make<EvalPath>(std::move(keepPath), Constant::emptyObject()), + std::move(result)); + ctx.addNode(result, true /*substitute*/); + return; + } + } + } + + // We still want to extract sargable nodes from EvalNode to use for PhysicalScans. + PartialSchemaReqConversion conversion = + convertExprToPartialSchemaReq(evalNode.getProjection()); + if (!conversion._success || conversion._reqMap.size() != 1) { + // For evaluation nodes we expect to create a single entry. 
+ return; + } + if (conversion._hasEmptyInterval) { + addEmptyValueScanNode(ctx); + return; + } + + for (auto& entry : conversion._reqMap) { + PartialSchemaRequirement& req = entry.second; + req.setBoundProjectionName(evalNode.getProjectionName()); + + uassert(6624114, + "Eval partial schema requirement must contain a variable name.", + !entry.first._projectionName.empty()); + uassert(6624115, + "Eval partial schema requirement cannot have a range", + isIntervalReqFullyOpenDNF(req.getIntervals())); + } + + bool hasEmptyInterval = false; + auto candidateIndexMap = computeCandidateIndexMap( + ctx.getPrefixId(), scanProjName, conversion._reqMap, scanDef, hasEmptyInterval); + + if (hasEmptyInterval) { + addEmptyValueScanNode(ctx); + } else { + ABT newNode = make<SargableNode>(std::move(conversion._reqMap), + std::move(candidateIndexMap), + IndexReqTarget::Complete, + evalNode.getChild()); + ctx.addNode(newNode, true /*substitute*/); + } + } +}; + +static void lowerSargableNode(const SargableNode& node, RewriteContext& ctx) { + ABT n = node.getChild(); + const auto reqMap = node.getReqMap(); + for (const auto& [key, req] : reqMap) { + lowerPartialSchemaRequirement(key, req, n); + } + ctx.addNode(n, true /*clear*/); +} + +template <class Type> +struct ExploreConvert { + void operator()(ABT::reference_type nodeRef, RewriteContext& ctx) = delete; +}; + +template <> +struct ExploreConvert<SargableNode> { + void operator()(ABT::reference_type node, RewriteContext& ctx) { + using namespace properties; + + const SargableNode& sargableNode = *node.cast<SargableNode>(); + const IndexReqTarget target = sargableNode.getTarget(); + if (target == IndexReqTarget::Seek) { + return; + } + + const LogicalProps& props = ctx.getAboveLogicalProps(); + const auto& indexingAvailability = getPropertyConst<IndexingAvailability>(props); + const GroupIdType scanGroupId = indexingAvailability.getScanGroupId(); + if (sargableNode.getChild().cast<MemoLogicalDelegatorNode>()->getGroupId() != scanGroupId) { + lowerSargableNode(sargableNode, ctx); + return; + } + + const std::string& scanDefName = indexingAvailability.getScanDefName(); + const ScanDefinition& scanDef = ctx.getMetadata()._scanDefs.at(scanDefName); + const size_t indexCount = scanDef.getIndexDefs().size(); + if (indexCount == 0) { + // Do not insert RIDIntersect if we do not have indexes available. + return; + } + + const auto aboveNodeId = ctx.getAboveNodeId(); + auto& sargableSplitCountMap = ctx.getSargableSplitCountMap(); + const size_t splitCount = sargableSplitCountMap[aboveNodeId]; + if ((1ull << splitCount) > + roundUpToNextPow2(indexCount, LogicalRewriter::kMaxSargableNodeSplitCount)) { + // We cannot split this node further. + return; + } + + const ProjectionName& scanProjectionName = indexingAvailability.getScanProjection(); + if (collectVariableReferences(node) != VariableNameSetType{scanProjectionName}) { + // Rewrite not applicable if we refer projections other than the scan projection. + return; + } + + const bool isIndex = target == IndexReqTarget::Index; + + const auto& indexFieldPrefixMap = ctx.getIndexFieldPrefixMap(); + const auto indexFieldPrefixMapIt = + isIndex ? indexFieldPrefixMap.cend() : indexFieldPrefixMap.find(scanDefName); + const bool indexFieldMapHasScanDef = indexFieldPrefixMapIt != indexFieldPrefixMap.cend(); + + const auto& reqMap = sargableNode.getReqMap(); + const size_t reqSize = reqMap.size(); + const size_t highMask = isIndex ? 
(1ull << (reqSize - 1)) : (1ull << reqSize); + for (size_t mask = 1; mask < highMask; mask++) { + PartialSchemaRequirements leftReqs; + PartialSchemaRequirements rightReqs; + bool hasFieldCoverage = true; + bool hasLeftIntervals = false; + bool hasRightIntervals = false; + + size_t index = 0; + for (const auto& [key, req] : reqMap) { + const bool fullyOpenInterval = isIntervalReqFullyOpenDNF(req.getIntervals()); + + if (((1ull << index) & mask) != 0) { + leftReqs.emplace(key, req); + + if (!fullyOpenInterval) { + hasLeftIntervals = true; + } + if (indexFieldMapHasScanDef) { + if (auto pathPtr = key._path.cast<PathGet>(); pathPtr != nullptr && + indexFieldPrefixMapIt->second.count(pathPtr->name()) == 0) { + // We have found a left requirement which cannot be covered with an + // index. + hasFieldCoverage = false; + break; + } + } + } else { + rightReqs.emplace(key, req); + + if (!fullyOpenInterval) { + hasRightIntervals = true; + } + } + index++; + } + + if (isIndex && (!hasLeftIntervals || !hasRightIntervals)) { + // Reject. Must have at least one proper interval on either side. + continue; + } + if (!hasFieldCoverage) { + // Reject rewrite. No suitable indexes. + continue; + } + + bool hasEmptyLeftInterval = false; + auto leftCandidateIndexMap = computeCandidateIndexMap( + ctx.getPrefixId(), scanProjectionName, leftReqs, scanDef, hasEmptyLeftInterval); + if (isIndex && leftCandidateIndexMap.empty()) { + // Reject rewrite. + continue; + } + + bool hasEmptyRightInterval = false; + auto rightCandidateIndexMap = computeCandidateIndexMap( + ctx.getPrefixId(), scanProjectionName, rightReqs, scanDef, hasEmptyRightInterval); + if (isIndex && rightCandidateIndexMap.empty()) { + // With empty candidate map, reject only if we cannot implement as Seek. + continue; + } + uassert(6624116, + "Empty intervals should already be rewritten to empty ValueScan nodes", + !hasEmptyLeftInterval && !hasEmptyRightInterval); + + ABT scanDelegator = make<MemoLogicalDelegatorNode>(scanGroupId); + ABT leftChild = make<SargableNode>(std::move(leftReqs), + std::move(leftCandidateIndexMap), + IndexReqTarget::Index, + scanDelegator); + ABT rightChild = rightReqs.empty() + ? scanDelegator + : make<SargableNode>(std::move(rightReqs), + std::move(rightCandidateIndexMap), + isIndex ? IndexReqTarget::Index : IndexReqTarget::Seek, + scanDelegator); + + ABT newRoot = make<RIDIntersectNode>(scanProjectionName, + hasLeftIntervals, + hasRightIntervals, + std::move(leftChild), + std::move(rightChild)); + + const auto& result = ctx.addNode(newRoot, false /*substitute*/); + for (const MemoLogicalNodeId nodeId : result.second) { + if (!(nodeId == aboveNodeId)) { + sargableSplitCountMap[nodeId] = splitCount + 1; + } + } + } + } +}; + +template <> +struct ExploreConvert<GroupByNode> { + void operator()(ABT::reference_type node, RewriteContext& ctx) { + const GroupByNode& groupByNode = *node.cast<GroupByNode>(); + if (groupByNode.getType() != GroupNodeType::Complete) { + return; + } + + ProjectionNameVector preaggVariableNames; + ABTVector preaggExpressions; + + const ABTVector& aggExpressions = groupByNode.getAggregationExpressions(); + for (const ABT& expr : aggExpressions) { + const FunctionCall* aggPtr = expr.cast<FunctionCall>(); + if (aggPtr == nullptr) { + return; + } + + // In order to be able to pre-aggregate for now we expect a simple aggregate like + // SUM(x). + const auto& aggFnName = aggPtr->name(); + if (aggFnName != "$sum" && aggFnName != "$min" && aggFnName != "$max") { + // TODO: allow more functions. 
+ return; + } + uassert(6624117, "Invalid argument count", aggPtr->nodes().size() == 1); + + preaggVariableNames.push_back(ctx.getPrefixId().getNextId("preagg")); + preaggExpressions.emplace_back( + make<FunctionCall>(aggFnName, makeSeq(make<Variable>(preaggVariableNames.back())))); + } + + ABT localGroupBy = make<GroupByNode>(groupByNode.getGroupByProjectionNames(), + std::move(preaggVariableNames), + aggExpressions, + GroupNodeType::Local, + groupByNode.getChild()); + + ABT newRoot = make<GroupByNode>(groupByNode.getGroupByProjectionNames(), + groupByNode.getAggregationProjectionNames(), + std::move(preaggExpressions), + GroupNodeType::Global, + std::move(localGroupBy)); + + ctx.addNode(newRoot, false /*substitute*/); + } +}; + +template <class AboveType, class BelowType> +struct ExploreReorder { + void operator()(ABT::reference_type aboveNode, + ABT::reference_type belowNode, + RewriteContext& ctx) const = delete; +}; + +template <class AboveNode> +void reorderAgainstRIDIntersectNode(ABT::reference_type aboveNode, + ABT::reference_type belowNode, + RewriteContext& ctx) { + const ReorderDependencies leftDeps = + computeDependencies<AboveNode, RIDIntersectNode, LeftChildAccessor>( + aboveNode, belowNode, ctx); + uassert(6624118, "RIDIntersect cannot bind projections", !leftDeps._hasNodeRef); + const bool hasLeftRef = leftDeps._hasChildRef; + + const ReorderDependencies rightDeps = + computeDependencies<AboveNode, RIDIntersectNode, RightChildAccessor>( + aboveNode, belowNode, ctx); + uassert(6624119, "RIDIntersect cannot bind projections", !rightDeps._hasNodeRef); + const bool hasRightRef = rightDeps._hasChildRef; + + if (hasLeftRef == hasRightRef) { + // Both left and right reorderings available means that we refer to both left and right + // sides. 
+ return; + } + + const RIDIntersectNode& node = *belowNode.cast<RIDIntersectNode>(); + if (node.hasLeftIntervals() && hasLeftRef) { + defaultReorder<AboveNode, + RIDIntersectNode, + DefaultChildAccessor, + LeftChildAccessor, + false /*substitute*/>(aboveNode, belowNode, ctx); + } + if (node.hasRightIntervals() && hasRightRef) { + defaultReorder<AboveNode, + RIDIntersectNode, + DefaultChildAccessor, + RightChildAccessor, + false /*substitute*/>(aboveNode, belowNode, ctx); + } +}; + +template <> +struct ExploreReorder<FilterNode, RIDIntersectNode> { + void operator()(ABT::reference_type aboveNode, + ABT::reference_type belowNode, + RewriteContext& ctx) const { + reorderAgainstRIDIntersectNode<FilterNode>(aboveNode, belowNode, ctx); + } +}; + +template <> +struct ExploreReorder<EvaluationNode, RIDIntersectNode> { + void operator()(ABT::reference_type aboveNode, + ABT::reference_type belowNode, + RewriteContext& ctx) const { + reorderAgainstRIDIntersectNode<EvaluationNode>(aboveNode, belowNode, ctx); + } +}; + +void LogicalRewriter::registerRewrite(const LogicalRewriteType rewriteType, RewriteFn fn) { + if (_activeRewriteSet.find(rewriteType) != _activeRewriteSet.cend()) { + _rewriteMap.emplace(rewriteType, fn); + } +} + +void LogicalRewriter::initializeRewrites() { + registerRewrite( + LogicalRewriteType::FilterEvaluationReorder, + &LogicalRewriter::bindAboveBelow<FilterNode, EvaluationNode, SubstituteReorder>); + registerRewrite(LogicalRewriteType::FilterCollationReorder, + &LogicalRewriter::bindAboveBelow<FilterNode, CollationNode, SubstituteReorder>); + registerRewrite( + LogicalRewriteType::EvaluationCollationReorder, + &LogicalRewriter::bindAboveBelow<EvaluationNode, CollationNode, SubstituteReorder>); + registerRewrite( + LogicalRewriteType::EvaluationLimitSkipReorder, + &LogicalRewriter::bindAboveBelow<EvaluationNode, LimitSkipNode, SubstituteReorder>); + registerRewrite(LogicalRewriteType::FilterGroupByReorder, + &LogicalRewriter::bindAboveBelow<FilterNode, GroupByNode, SubstituteReorder>); + registerRewrite( + LogicalRewriteType::GroupCollationReorder, + &LogicalRewriter::bindAboveBelow<GroupByNode, CollationNode, SubstituteReorder>); + registerRewrite(LogicalRewriteType::FilterUnwindReorder, + &LogicalRewriter::bindAboveBelow<FilterNode, UnwindNode, SubstituteReorder>); + registerRewrite( + LogicalRewriteType::EvaluationUnwindReorder, + &LogicalRewriter::bindAboveBelow<EvaluationNode, UnwindNode, SubstituteReorder>); + registerRewrite(LogicalRewriteType::UnwindCollationReorder, + &LogicalRewriter::bindAboveBelow<UnwindNode, CollationNode, SubstituteReorder>); + + registerRewrite(LogicalRewriteType::FilterExchangeReorder, + &LogicalRewriter::bindAboveBelow<FilterNode, ExchangeNode, SubstituteReorder>); + registerRewrite( + LogicalRewriteType::ExchangeEvaluationReorder, + &LogicalRewriter::bindAboveBelow<ExchangeNode, EvaluationNode, SubstituteReorder>); + + registerRewrite(LogicalRewriteType::FilterUnionReorder, + &LogicalRewriter::bindAboveBelow<FilterNode, UnionNode, SubstituteReorder>); + + registerRewrite( + LogicalRewriteType::CollationMerge, + &LogicalRewriter::bindAboveBelow<CollationNode, CollationNode, SubstituteMerge>); + registerRewrite( + LogicalRewriteType::LimitSkipMerge, + &LogicalRewriter::bindAboveBelow<LimitSkipNode, LimitSkipNode, SubstituteMerge>); + + registerRewrite(LogicalRewriteType::SargableFilterReorder, + &LogicalRewriter::bindAboveBelow<SargableNode, FilterNode, SubstituteReorder>); + registerRewrite( + LogicalRewriteType::SargableEvaluationReorder, + 
&LogicalRewriter::bindAboveBelow<SargableNode, EvaluationNode, SubstituteReorder>); + + registerRewrite(LogicalRewriteType::LimitSkipSubstitute, + &LogicalRewriter::bindSingleNode<LimitSkipNode, SubstituteConvert>); + + registerRewrite(LogicalRewriteType::SargableMerge, + &LogicalRewriter::bindAboveBelow<SargableNode, SargableNode, SubstituteMerge>); + registerRewrite(LogicalRewriteType::FilterSubstitute, + &LogicalRewriter::bindSingleNode<FilterNode, SubstituteConvert>); + registerRewrite(LogicalRewriteType::EvaluationSubstitute, + &LogicalRewriter::bindSingleNode<EvaluationNode, SubstituteConvert>); + + registerRewrite(LogicalRewriteType::FilterValueScanPropagate, + &LogicalRewriter::bindAboveBelow<FilterNode, ValueScanNode, SubstituteReorder>); + registerRewrite( + LogicalRewriteType::EvaluationValueScanPropagate, + &LogicalRewriter::bindAboveBelow<EvaluationNode, ValueScanNode, SubstituteReorder>); + registerRewrite( + LogicalRewriteType::SargableValueScanPropagate, + &LogicalRewriter::bindAboveBelow<SargableNode, ValueScanNode, SubstituteReorder>); + registerRewrite( + LogicalRewriteType::CollationValueScanPropagate, + &LogicalRewriter::bindAboveBelow<CollationNode, ValueScanNode, SubstituteReorder>); + registerRewrite( + LogicalRewriteType::LimitSkipValueScanPropagate, + &LogicalRewriter::bindAboveBelow<LimitSkipNode, ValueScanNode, SubstituteReorder>); + registerRewrite( + LogicalRewriteType::ExchangeValueScanPropagate, + &LogicalRewriter::bindAboveBelow<ExchangeNode, ValueScanNode, SubstituteReorder>); + + registerRewrite(LogicalRewriteType::GroupByExplore, + &LogicalRewriter::bindSingleNode<GroupByNode, ExploreConvert>); + registerRewrite(LogicalRewriteType::SargableSplit, + &LogicalRewriter::bindSingleNode<SargableNode, ExploreConvert>); + + registerRewrite(LogicalRewriteType::FilterRIDIntersectReorder, + &LogicalRewriter::bindAboveBelow<FilterNode, RIDIntersectNode, ExploreReorder>); + registerRewrite( + LogicalRewriteType::EvaluationRIDIntersectReorder, + &LogicalRewriter::bindAboveBelow<EvaluationNode, RIDIntersectNode, ExploreReorder>); +} + +bool LogicalRewriter::rewriteToFixPoint() { + int iterationCount = 0; + + while (!_groupsPending.empty()) { + iterationCount++; + if (_memo.getDebugInfo().exceedsIterationLimit(iterationCount)) { + // Iteration limit exceeded. + return false; + } + + const GroupIdType groupId = *_groupsPending.begin(); + rewriteGroup(groupId); + _groupsPending.erase(groupId); + } + + return true; +} + +void LogicalRewriter::rewriteGroup(const GroupIdType groupId) { + auto& queue = _memo.getGroup(groupId)._logicalRewriteQueue; + while (!queue.empty()) { + LogicalRewriteEntry rewriteEntry = std::move(*queue.top()); + // TODO: check if rewriteEntry is different than previous (remove duplicates). + queue.pop(); + + _rewriteMap.at(rewriteEntry._type)(this, rewriteEntry._nodeId); + } +} + +template <class AboveType, class BelowType, template <class, class> class R> +void LogicalRewriter::bindAboveBelow(const MemoLogicalNodeId nodeMemoId) { + // Get a reference to the node instead of the node itself. + // Rewrites insert into the memo and can move it. + ABT::reference_type node = _memo.getNode(nodeMemoId); + const GroupIdType currentGroupId = nodeMemoId._groupId; + + if (node.is<AboveType>()) { + // Try to bind as parent. 
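+ // The above node's single child is a memo delegator; scan the logical nodes of the
+ // delegated group looking for a BelowType node to pair with.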
+ const GroupIdType targetGroupId = node.cast<AboveType>() + ->getChild() + .template cast<MemoLogicalDelegatorNode>() + ->getGroupId(); + + for (size_t i = 0; i < _memo.getGroup(targetGroupId)._logicalNodes.size(); i++) { + const MemoLogicalNodeId targetNodeId{targetGroupId, i}; + auto targetNode = _memo.getNode(targetNodeId); + if (targetNode.is<BelowType>()) { + RewriteContext ctx(*this, nodeMemoId, targetNodeId); + R<AboveType, BelowType>()(node, targetNode, ctx); + if (ctx.hasSubstituted()) { + return; + } + } + } + } + + if (node.is<BelowType>()) { + // Try to bind as child. + NodeIdSet usageNodeIdSet; + { + const auto& inputGroupsToNodeId = _memo.getInputGroupsToNodeIdMap(); + auto it = inputGroupsToNodeId.find({currentGroupId}); + if (it != inputGroupsToNodeId.cend()) { + usageNodeIdSet = it->second; + } + } + + for (const MemoLogicalNodeId& parentNodeId : usageNodeIdSet) { + auto targetNode = _memo.getNode(parentNodeId); + if (targetNode.is<AboveType>()) { + uassert(6624047, + "Parent child groupId mismatch (usage map index incorrect?)", + targetNode.cast<AboveType>() + ->getChild() + .template cast<MemoLogicalDelegatorNode>() + ->getGroupId() == currentGroupId); + + RewriteContext ctx(*this, parentNodeId, nodeMemoId); + R<AboveType, BelowType>()(targetNode, node, ctx); + if (ctx.hasSubstituted()) { + return; + } + } + } + } +} + +template <class Type, template <class> class R> +void LogicalRewriter::bindSingleNode(const MemoLogicalNodeId nodeMemoId) { + // Get a reference to the node instead of the node itself. + // Rewrites insert into the memo and can move it. + ABT::reference_type node = _memo.getNode(nodeMemoId); + if (node.is<Type>()) { + RewriteContext ctx(*this, nodeMemoId); + R<Type>()(node, ctx); + } +} + +const LogicalRewriter::RewriteSet& LogicalRewriter::getExplorationSet() { + return _explorationSet; +} + +const LogicalRewriter::RewriteSet& LogicalRewriter::getSubstitutionSet() { + return _substitutionSet; +} + +} // namespace mongo::optimizer::cascades diff --git a/src/mongo/db/query/optimizer/cascades/logical_rewriter.h b/src/mongo/db/query/optimizer/cascades/logical_rewriter.h new file mode 100644 index 00000000000..642fdae1842 --- /dev/null +++ b/src/mongo/db/query/optimizer/cascades/logical_rewriter.h @@ -0,0 +1,132 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. 
If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#pragma once + +#include <queue> + +#include "mongo/db/query/optimizer/cascades/logical_rewriter_rules.h" +#include "mongo/db/query/optimizer/cascades/memo.h" +#include "mongo/db/query/optimizer/utils/utils.h" + +namespace mongo::optimizer::cascades { + +class LogicalRewriter { + friend class RewriteContext; + +public: + /** + * Maximum size of PartialSchemaRequirements for a SargableNode. + * This limits the number of splits for index intersection. + */ + static constexpr size_t kMaxPartialSchemaReqCount = 10; + + /* + * How many times are we allowed to split a sargable node to facilitate index intersection. + * Results in at most 2^N index intersections. + */ + static constexpr size_t kMaxSargableNodeSplitCount = 2; + + /** + * Map of rewrite type to rewrite priority + */ + using RewriteSet = opt::unordered_map<LogicalRewriteType, double>; + + LogicalRewriter(Memo& memo, PrefixId& prefixId, RewriteSet rewriteSet); + + LogicalRewriter() = delete; + LogicalRewriter(const LogicalRewriter& other) = delete; + LogicalRewriter(LogicalRewriter&& other) = default; + + GroupIdType addRootNode(const ABT& node); + std::pair<GroupIdType, NodeIdSet> addNode(const ABT& node, + GroupIdType targetGroupId, + bool addExistingNodeWithNewChild); + void clearGroup(GroupIdType groupId); + + /** + * Performs logical rewrites across all groups until a fix point is reached. + * Use this method to perform "standalone" rewrites. + */ + bool rewriteToFixPoint(); + + /** + * Performs rewrites only for a particular group. Use this method to perform rewrites driven by + * top-down optimization. + */ + void rewriteGroup(GroupIdType groupId); + + static const RewriteSet& getExplorationSet(); + static const RewriteSet& getSubstitutionSet(); + +private: + using RewriteFn = + std::function<void(LogicalRewriter* rewriter, const MemoLogicalNodeId nodeId)>; + using RewriteFnMap = opt::unordered_map<LogicalRewriteType, RewriteFn>; + + /** + * Attempts to perform a reordering rewrite specified by the R template argument. + */ + template <class AboveType, class BelowType, template <class, class> class R> + void bindAboveBelow(MemoLogicalNodeId nodeMemoId); + + /** + * Attempts to perform a simple rewrite specified by the R template argument. + */ + template <class Type, template <class> class R> + void bindSingleNode(MemoLogicalNodeId nodeMemoId); + + void registerRewrite(LogicalRewriteType rewriteType, RewriteFn fn); + void initializeRewrites(); + + static RewriteSet _explorationSet; + static RewriteSet _substitutionSet; + + const RewriteSet _activeRewriteSet; + + // For standalone logical rewrite phase, keeps track of which groups still have rewrites + // pending. + std::set<int> _groupsPending; + + // We don't own those: + Memo& _memo; + PrefixId& _prefixId; + + RewriteFnMap _rewriteMap; + + // Contains the set of top-level index fields for a given scanDef. For example "a.b" is encoded + // as "a". This is used to constrain the possible splits of a sargable node. + opt::unordered_map<std::string, opt::unordered_set<FieldNameType>> _indexFieldPrefixMap; + + // Track number of times a SargableNode at a given position in the memo has been split. 
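+ // The SargableSplit explore rewrite consults this count and stops splitting a node once
+ // the limit implied by the number of available indexes (capped by
+ // kMaxSargableNodeSplitCount) is reached.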
+ opt::unordered_map<MemoLogicalNodeId, size_t, NodeIdHash> _sargableSplitCountMap; +}; + + +} // namespace mongo::optimizer::cascades diff --git a/src/mongo/db/query/optimizer/cascades/logical_rewriter_rules.h b/src/mongo/db/query/optimizer/cascades/logical_rewriter_rules.h new file mode 100644 index 00000000000..ab562c685c8 --- /dev/null +++ b/src/mongo/db/query/optimizer/cascades/logical_rewriter_rules.h @@ -0,0 +1,89 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#pragma once + +#include "mongo/db/query/optimizer/defs.h" + +namespace mongo::optimizer::cascades { + +#define LOGICALREWRITER_NAMES(F) \ + /* "Linear" reordering rewrites. */ \ + F(FilterEvaluationReorder) \ + F(FilterCollationReorder) \ + F(EvaluationCollationReorder) \ + F(EvaluationLimitSkipReorder) \ + \ + F(FilterGroupByReorder) \ + F(GroupCollationReorder) \ + \ + F(FilterUnwindReorder) \ + F(EvaluationUnwindReorder) \ + F(UnwindCollationReorder) \ + \ + F(FilterExchangeReorder) \ + F(ExchangeEvaluationReorder) \ + \ + F(FilterUnionReorder) \ + \ + /* Merging rewrites. 
*/ \ + F(CollationMerge) \ + F(LimitSkipMerge) \ + F(SargableMerge) \ + \ + /* Local-global optimization for GroupBy */ \ + F(GroupByExplore) \ + \ + F(SargableFilterReorder) \ + F(SargableEvaluationReorder) \ + \ + /* Propagate ValueScan nodes*/ \ + F(FilterValueScanPropagate) \ + F(EvaluationValueScanPropagate) \ + F(SargableValueScanPropagate) \ + F(CollationValueScanPropagate) \ + F(LimitSkipValueScanPropagate) \ + F(ExchangeValueScanPropagate) \ + \ + F(LimitSkipSubstitute) \ + \ + /* Convert filter and evaluation nodes into sargable nodes */ \ + F(FilterSubstitute) \ + F(EvaluationSubstitute) \ + F(SargableSplit) \ + F(FilterRIDIntersectReorder) \ + F(EvaluationRIDIntersectReorder) + +MAKE_PRINTABLE_ENUM(LogicalRewriteType, LOGICALREWRITER_NAMES); +MAKE_PRINTABLE_ENUM_STRING_ARRAY(LogicalRewriterTypeEnum, + LogicalRewriterType, + LOGICALREWRITER_NAMES); +#undef LOGICALREWRITER_NAMES + +} // namespace mongo::optimizer::cascades diff --git a/src/mongo/db/query/optimizer/cascades/memo.cpp b/src/mongo/db/query/optimizer/cascades/memo.cpp new file mode 100644 index 00000000000..d39dc7ccbc5 --- /dev/null +++ b/src/mongo/db/query/optimizer/cascades/memo.cpp @@ -0,0 +1,794 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include <set> + +#include "mongo/db/query/optimizer/cascades/memo.h" +#include "mongo/db/query/optimizer/explain.h" +#include "mongo/db/query/optimizer/utils/abt_hash.h" +#include "mongo/db/query/optimizer/utils/utils.h" + +namespace mongo::optimizer::cascades { + +size_t MemoNodeRefHash::operator()(const ABT::reference_type& nodeRef) const { + // Compare delegator as well. + return ABTHashGenerator::generate(nodeRef); +} + +bool MemoNodeRefCompare::operator()(const ABT::reference_type& left, + const ABT::reference_type& right) const { + // Deep comparison. 
+ return left.follow() == right.follow(); +} + +ABT::reference_type OrderPreservingABTSet::at(const size_t index) const { + return _vector.at(index).ref(); +} + +std::pair<size_t, bool> OrderPreservingABTSet::emplace_back(ABT node) { + auto [index, found] = find(node.ref()); + if (found) { + return {index, false}; + } + + const size_t id = _vector.size(); + _vector.emplace_back(std::move(node)); + _map.emplace(_vector.back().ref(), id); + return {id, true}; +} + +std::pair<size_t, bool> OrderPreservingABTSet::find(ABT::reference_type node) const { + auto it = _map.find(node); + if (it == _map.end()) { + return {0, false}; + } + + return {it->second, true}; +} + +void OrderPreservingABTSet::clear() { + _map.clear(); + _vector.clear(); +} + +size_t OrderPreservingABTSet::size() const { + return _vector.size(); +} + +const ABTVector& OrderPreservingABTSet::getVector() const { + return _vector; +} + +PhysRewriteEntry::PhysRewriteEntry(const double priority, + ABT node, + std::vector<std::pair<ABT*, properties::PhysProps>> childProps, + NodeCEMap nodeCEMap) + : _priority(priority), + _node(std::move(node)), + _childProps(std::move(childProps)), + _nodeCEMap(std::move(nodeCEMap)) {} + +PhysOptimizationResult::PhysOptimizationResult() + : PhysOptimizationResult(0, {}, CostType::kInfinity) {} + +PhysOptimizationResult::PhysOptimizationResult(size_t index, + properties::PhysProps physProps, + CostType costLimit) + : _index(index), + _physProps(std::move(physProps)), + _costLimit(std::move(costLimit)), + _nodeInfo(), + _lastImplementedNodePos(0), + _queue() {} + +bool PhysOptimizationResult::isOptimized() const { + return _queue.empty(); +} + +void PhysOptimizationResult::raiseCostLimit(CostType costLimit) { + _costLimit = costLimit; + // Allow for re-optimization under the higher cost limit. + _lastImplementedNodePos = 0; +} + +bool PhysRewriteEntryComparator::operator()(const std::unique_ptr<PhysRewriteEntry>& x, + const std::unique_ptr<PhysRewriteEntry>& y) const { + // Lower numerical priority is considered last (and thus de-queued first). 
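+ // (This comparator is intended for a max-oriented priority queue such as
+ // std::priority_queue: inverting the comparison makes the smallest _priority value
+ // surface at top().)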
+ return x->_priority > y->_priority; +} + +static ABT createBinderMap(const properties::LogicalProps& logicalProperties) { + const properties::ProjectionAvailability& projSet = + properties::getPropertyConst<properties::ProjectionAvailability>(logicalProperties); + + ProjectionNameVector projectionVector; + ABTVector expressions; + + ProjectionNameOrderedSet ordered = convertToOrderedSet(projSet.getProjections()); + for (const ProjectionName& projection : ordered) { + projectionVector.push_back(projection); + expressions.emplace_back(make<Source>()); + } + + return make<ExpressionBinder>(std::move(projectionVector), std::move(expressions)); +} + +Group::Group(ProjectionNameSet projections) + : _logicalNodes(), + _logicalProperties( + properties::makeLogicalProps(properties::ProjectionAvailability(std::move(projections)))), + _binder(createBinderMap(_logicalProperties)), + _logicalRewriteQueue(), + _physicalNodes() {} + +const ExpressionBinder& Group::binder() const { + auto pointer = _binder.cast<ExpressionBinder>(); + uassert(6624048, "Invalid binder type", pointer); + + return *pointer; +} + +PhysOptimizationResult& PhysNodes::addOptimizationResult(properties::PhysProps properties, + CostType costLimit) { + const size_t index = _physicalNodes.size(); + _physPropsToPhysNodeMap.emplace(properties, index); + return *_physicalNodes.emplace_back(std::make_unique<PhysOptimizationResult>( + index, std::move(properties), std::move(costLimit))); +} + +const PhysOptimizationResult& PhysNodes::at(const size_t index) const { + return *_physicalNodes.at(index); +} + +PhysOptimizationResult& PhysNodes::at(const size_t index) { + return *_physicalNodes.at(index); +} + +std::pair<size_t, bool> PhysNodes::find(const properties::PhysProps& props) const { + auto it = _physPropsToPhysNodeMap.find(props); + if (it == _physPropsToPhysNodeMap.cend()) { + return {0, false}; + } + return {it->second, true}; +} + +const PhysNodes::PhysNodeVector& PhysNodes::getNodes() const { + return _physicalNodes; +} + +size_t PhysNodes::PhysPropsHasher::operator()(const properties::PhysProps& physProps) const { + return ABTHashGenerator::generateForPhysProps(physProps); +} + +class MemoIntegrator { +public: + explicit MemoIntegrator(Memo& memo, + Memo::NodeTargetGroupMap targetGroupMap, + NodeIdSet& insertedNodeIds, + const bool addExistingNodeWithNewChild) + : _memo(memo), + _insertedNodeIds(insertedNodeIds), + _targetGroupMap(std::move(targetGroupMap)), + _addExistingNodeWithNewChild(addExistingNodeWithNewChild) {} + + /** + * Nodes + */ + void prepare(const ABT& n, const ScanNode& node, const VariableEnvironment& /*env*/) { + // noop + } + + GroupIdType transport(const ABT& n, + const ScanNode& node, + const VariableEnvironment& env, + GroupIdType /*binder*/) { + return addNodes(n, node, n, env, {}); + } + + void prepare(const ABT& n, const ValueScanNode& node, const VariableEnvironment& /*env*/) { + // noop + } + + GroupIdType transport(const ABT& n, + const ValueScanNode& node, + const VariableEnvironment& env, + GroupIdType /*binder*/) { + return addNodes(n, node, n, env, {}); + } + + void prepare(const ABT& n, + const MemoLogicalDelegatorNode& node, + const VariableEnvironment& /*env*/) { + // noop + } + + GroupIdType transport(const ABT& /*n*/, + const MemoLogicalDelegatorNode& node, + const VariableEnvironment& /*env*/) { + return node.getGroupId(); + } + + void prepare(const ABT& n, const FilterNode& node, const VariableEnvironment& /*env*/) { + updateTargetGroupMapUnary(n, node); + } + + GroupIdType 
transport(const ABT& n, + const FilterNode& node, + const VariableEnvironment& env, + GroupIdType child, + GroupIdType /*binder*/) { + return addNode(n, node, env, child); + } + + void prepare(const ABT& n, const EvaluationNode& node, const VariableEnvironment& /*env*/) { + updateTargetGroupMapUnary(n, node); + } + + GroupIdType transport(const ABT& n, + const EvaluationNode& node, + const VariableEnvironment& env, + GroupIdType child, + GroupIdType /*binder*/) { + return addNode(n, node, env, child); + } + + void prepare(const ABT& n, const SargableNode& node, const VariableEnvironment& /*env*/) { + updateTargetGroupMapUnary(n, node); + } + + GroupIdType transport(const ABT& n, + const SargableNode& node, + const VariableEnvironment& env, + GroupIdType child, + GroupIdType /*binder*/, + GroupIdType /*references*/) { + return addNode(n, node, env, child); + } + + void prepare(const ABT& n, const RIDIntersectNode& node, const VariableEnvironment& /*env*/) { + // noop. + } + + GroupIdType transport(const ABT& n, + const RIDIntersectNode& node, + const VariableEnvironment& env, + GroupIdType leftChild, + GroupIdType rightChild) { + return addNodes(n, node, env, leftChild, rightChild); + } + + void prepare(const ABT& n, const BinaryJoinNode& node, const VariableEnvironment& /*env*/) { + updateTargetGroupMapBinary(n, node); + } + + GroupIdType transport(const ABT& n, + const BinaryJoinNode& node, + const VariableEnvironment& env, + GroupIdType leftChild, + GroupIdType rightChild, + GroupIdType /*filter*/) { + return addNodes(n, node, env, leftChild, rightChild); + } + + void prepare(const ABT& n, const UnionNode& node, const VariableEnvironment& /*env*/) { + updateTargetGroupMapNary(n, node); + } + + GroupIdType transport(const ABT& n, + const UnionNode& node, + const VariableEnvironment& env, + Memo::GroupIdVector children, + GroupIdType /*binder*/, + GroupIdType /*refs*/) { + return addNodes(n, node, env, std::move(children)); + } + + void prepare(const ABT& n, const GroupByNode& node, const VariableEnvironment& /*env*/) { + updateTargetGroupMapUnary(n, node); + } + + GroupIdType transport(const ABT& n, + const GroupByNode& node, + const VariableEnvironment& env, + GroupIdType child, + GroupIdType /*binderAgg*/, + GroupIdType /*refsAgg*/, + GroupIdType /*binderGb*/, + GroupIdType /*refsGb*/) { + return addNode(n, node, env, child); + } + + void prepare(const ABT& n, const UnwindNode& node, const VariableEnvironment& /*env*/) { + updateTargetGroupMapUnary(n, node); + } + + GroupIdType transport(const ABT& n, + const UnwindNode& node, + const VariableEnvironment& env, + GroupIdType child, + GroupIdType /*binder*/, + GroupIdType /*refs*/) { + return addNode(n, node, env, child); + } + + void prepare(const ABT& n, const CollationNode& node, const VariableEnvironment& /*env*/) { + updateTargetGroupMapUnary(n, node); + } + + GroupIdType transport(const ABT& n, + const CollationNode& node, + const VariableEnvironment& env, + GroupIdType child, + GroupIdType /*refs*/) { + return addNode(n, node, env, child); + } + + void prepare(const ABT& n, const LimitSkipNode& node, const VariableEnvironment& /*env*/) { + updateTargetGroupMapUnary(n, node); + } + + GroupIdType transport(const ABT& n, + const LimitSkipNode& node, + const VariableEnvironment& env, + GroupIdType child) { + return addNode(n, node, env, child); + } + + void prepare(const ABT& n, const ExchangeNode& node, const VariableEnvironment& /*env*/) { + updateTargetGroupMapUnary(n, node); + } + + GroupIdType transport(const ABT& n, + const 
ExchangeNode& node, + const VariableEnvironment& env, + GroupIdType child, + GroupIdType /*refs*/) { + return addNode(n, node, env, child); + } + + void prepare(const ABT& n, const RootNode& node, const VariableEnvironment& /*env*/) { + updateTargetGroupMapUnary(n, node); + } + + GroupIdType transport(const ABT& n, + const RootNode& node, + const VariableEnvironment& env, + GroupIdType child, + GroupIdType /*refs*/) { + return addNode(n, node, env, child); + } + + /** + * Other ABT types. + */ + template <typename T, typename... Ts> + GroupIdType transport(const ABT& /*n*/, + const T& /*node*/, + const VariableEnvironment& /*env*/, + Ts&&...) { + static_assert(!canBeLogicalNode<T>(), "Logical node must implement its transport."); + return -1; + } + + template <typename T, typename... Ts> + void prepare(const ABT& n, const T& /*node*/, const VariableEnvironment& /*env*/) { + static_assert(!canBeLogicalNode<T>(), "Logical node must implement its prepare."); + } + + GroupIdType integrate(const ABT& n) { + return algebra::transport<true>(n, *this, VariableEnvironment::build(n, &_memo)); + } + +private: + GroupIdType addNodes(const ABT& n, + const Node& node, + ABT forMemo, + const VariableEnvironment& env, + Memo::GroupIdVector childGroupIds) { + auto it = _targetGroupMap.find(n.ref()); + const GroupIdType targetGroupId = (it == _targetGroupMap.cend()) ? -1 : it->second; + const auto result = _memo.addNode(std::move(childGroupIds), + env.getProjections(&node), + targetGroupId, + _insertedNodeIds, + std::move(forMemo)); + return result._groupId; + } + + template <class T, typename... Args> + GroupIdType addNodes(const ABT& n, + const T& node, + const VariableEnvironment& env, + Memo::GroupIdVector childGroupIds) { + ABT forMemo = n; + auto& childNodes = forMemo.template cast<T>()->nodes(); + for (size_t i = 0; i < childNodes.size(); i++) { + const GroupIdType childGroupId = childGroupIds.at(i); + uassert(6624121, "Invalid child group", childGroupId >= 0); + childNodes.at(i) = make<MemoLogicalDelegatorNode>(childGroupId); + } + + return addNodes(n, node, std::move(forMemo), env, std::move(childGroupIds)); + } + + template <class T> + GroupIdType addNode(const ABT& n, + const T& node, + const VariableEnvironment& env, + GroupIdType childGroupId) { + ABT forMemo = n; + uassert(6624122, "Invalid child group", childGroupId >= 0); + forMemo.cast<T>()->getChild() = make<MemoLogicalDelegatorNode>(childGroupId); + return addNodes(n, node, std::move(forMemo), env, {childGroupId}); + } + + template <class T> + GroupIdType addNodes(const ABT& n, + const T& node, + const VariableEnvironment& env, + GroupIdType leftGroupId, + GroupIdType rightGroupId) { + ABT forMemo = n; + uassert(6624123, "Invalid left child group", leftGroupId >= 0); + uassert(6624124, "Invalid right child group", rightGroupId >= 0); + + forMemo.cast<T>()->getLeftChild() = make<MemoLogicalDelegatorNode>(leftGroupId); + forMemo.cast<T>()->getRightChild() = make<MemoLogicalDelegatorNode>(rightGroupId); + return addNodes(n, node, std::move(forMemo), env, {leftGroupId, rightGroupId}); + } + + template <class T> + ABT::reference_type findExistingNodeFromTargetGroupMap(const ABT& n, const T& node) { + auto it = _targetGroupMap.find(n.ref()); + if (it == _targetGroupMap.cend()) { + return nullptr; + } + const auto [index, found] = _memo.findNodeInGroup(it->second, n.ref()); + if (!found) { + return nullptr; + } + + ABT::reference_type result = _memo.getNode({it->second, index}); + uassert(6624049, "Node type in memo does not match target type", 
result.is<T>()); + return result; + } + + void updateTargetGroupRefs( + const std::vector<std::pair<ABT::reference_type, GroupIdType>>& childGroups) { + for (auto [childRef, targetGroupId] : childGroups) { + auto it = _targetGroupMap.find(childRef); + if (it == _targetGroupMap.cend()) { + _targetGroupMap.emplace(childRef, targetGroupId); + } else if (it->second != targetGroupId) { + uasserted(6624050, "Incompatible target groups for parent and child"); + } + } + } + + template <class T> + void updateTargetGroupMapUnary(const ABT& n, const T& node) { + if (_addExistingNodeWithNewChild) { + return; + } + + ABT::reference_type existing = findExistingNodeFromTargetGroupMap(n, node); + if (!existing.empty()) { + const GroupIdType targetGroupId = existing.cast<T>() + ->getChild() + .template cast<MemoLogicalDelegatorNode>() + ->getGroupId(); + updateTargetGroupRefs({{node.getChild().ref(), targetGroupId}}); + } + } + + template <class T> + void updateTargetGroupMapNary(const ABT& n, const T& node) { + ABT::reference_type existing = findExistingNodeFromTargetGroupMap(n, node); + if (!existing.empty()) { + const ABTVector& existingChildren = existing.cast<T>()->nodes(); + const ABTVector& targetChildren = node.nodes(); + uassert(6624051, + "Different number of children between existing and target node", + existingChildren.size() == targetChildren.size()); + + std::vector<std::pair<ABT::reference_type, GroupIdType>> childGroups; + for (size_t i = 0; i < existingChildren.size(); i++) { + const ABT& existingChild = existingChildren.at(i); + const ABT& targetChild = targetChildren.at(i); + childGroups.emplace_back( + targetChild.ref(), + existingChild.cast<MemoLogicalDelegatorNode>()->getGroupId()); + } + updateTargetGroupRefs(childGroups); + } + } + + template <class T> + void updateTargetGroupMapBinary(const ABT& n, const T& node) { + ABT::reference_type existing = findExistingNodeFromTargetGroupMap(n, node); + if (existing.empty()) { + return; + } + + const T& existingNode = *existing.cast<T>(); + const GroupIdType leftGroupId = + existingNode.getLeftChild().template cast<MemoLogicalDelegatorNode>()->getGroupId(); + const GroupIdType rightGroupId = + existingNode.getRightChild().template cast<MemoLogicalDelegatorNode>()->getGroupId(); + updateTargetGroupRefs( + {{node.getLeftChild().ref(), leftGroupId}, {node.getRightChild().ref(), rightGroupId}}); + } + + /** + * We do not own any of these. + */ + Memo& _memo; + NodeIdSet& _insertedNodeIds; + + /** + * We own this. + */ + Memo::NodeTargetGroupMap _targetGroupMap; + + // If set we enable modification of target group based on existing nodes. In practical terms, we + // would not assume that if F(x) = F(y) then x = y. This is currently used in conjunction with + // $elemMatch rewrite (PathTraverse over PathCompose). 
+ bool _addExistingNodeWithNewChild; +}; + +size_t Memo::GroupIdVectorHash::operator()(const Memo::GroupIdVector& v) const { + size_t result = 17; + for (const GroupIdType id : v) { + updateHash(result, std::hash<GroupIdType>()(id)); + } + return result; +} + +size_t Memo::NodeTargetGroupHash::operator()(const ABT::reference_type& nodeRef) const { + return std::hash<const Node*>()(nodeRef.cast<Node>()); +} + +Memo::Memo(DebugInfo debugInfo, + const Metadata& metadata, + std::unique_ptr<LogicalPropsInterface> logicalPropsDerivation, + std::unique_ptr<CEInterface> ceDerivation) + : _groups(), + _inputGroupsToNodeIdMap(), + _nodeIdToInputGroupsMap(), + _metadata(metadata), + _logicalPropsDerivation(std::move(logicalPropsDerivation)), + _ceDerivation(std::move(ceDerivation)), + _debugInfo(std::move(debugInfo)), + _stats() { + uassert(6624125, "Empty logical properties derivation", _logicalPropsDerivation.get()); + uassert(6624126, "Empty CE derivation", _ceDerivation.get()); +} + +const Group& Memo::getGroup(const GroupIdType groupId) const { + return *_groups.at(groupId); +} + +Group& Memo::getGroup(const GroupIdType groupId) { + return *_groups.at(groupId); +} + +std::pair<size_t, bool> Memo::findNodeInGroup(GroupIdType groupId, ABT::reference_type node) const { + return getGroup(groupId)._logicalNodes.find(node); +} + +GroupIdType Memo::addGroup(ProjectionNameSet projections) { + _groups.emplace_back(std::make_unique<Group>(std::move(projections))); + return _groups.size() - 1; +} + +std::pair<MemoLogicalNodeId, bool> Memo::addNode(GroupIdType groupId, ABT n) { + uassert(6624052, "Attempting to insert a physical node", !n.is<PhysicalNode>()); + uassert(6624053, + "Attempting to insert a logical delegator node", + !n.is<MemoLogicalDelegatorNode>()); + + OrderPreservingABTSet& nodes = _groups.at(groupId)->_logicalNodes; + auto [index, inserted] = nodes.emplace_back(std::move(n)); + return {{groupId, index}, inserted}; +} + +ABT::reference_type Memo::getNode(const MemoLogicalNodeId nodeMemoId) const { + return getGroup(nodeMemoId._groupId)._logicalNodes.at(nodeMemoId._index); +} + +std::pair<MemoLogicalNodeId, bool> Memo::findNode(const GroupIdVector& groups, const ABT& node) { + const auto it = _inputGroupsToNodeIdMap.find(groups); + if (it != _inputGroupsToNodeIdMap.cend()) { + for (const MemoLogicalNodeId& nodeMemoId : it->second) { + if (getNode(nodeMemoId) == node) { + return {nodeMemoId, true}; + } + } + } + return {{0, 0}, false}; +} + +void Memo::estimateCE(const GroupIdType groupId) { + // If inserted into a new group, derive logical properties, and cardinality estimation + // for the new group. 
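In outline, the code below merges the logical properties derived from the group's first (and so far only) logical node into the group, then computes a CardinalityEstimate property for it. For a SargableNode it additionally records a per-PartialSchemaKey estimate by building a singular SargableNode for each requirement and estimating it in isolation, presumably so later index-related rewrites can reason about each predicate's selectivity on its own.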
+ Group& group = getGroup(groupId); + properties::LogicalProps& props = group._logicalProperties; + + const ABT::reference_type nodeRef = group._logicalNodes.at(0); + properties::LogicalProps logicalProps = + _logicalPropsDerivation->deriveProps(_metadata, nodeRef, nullptr, this, groupId); + props.merge(logicalProps); + + const CEType estimate = _ceDerivation->deriveCE(*this, props, nodeRef); + auto ceProp = properties::CardinalityEstimate(estimate); + + if (auto sargablePtr = nodeRef.cast<SargableNode>(); sargablePtr != nullptr) { + auto& partialSchemaKeyCEMap = ceProp.getPartialSchemaKeyCEMap(); + for (const auto& [key, req] : sargablePtr->getReqMap()) { + ABT singularReq = make<SargableNode>(PartialSchemaRequirements{{key, req}}, + CandidateIndexMap{}, + sargablePtr->getTarget(), + sargablePtr->getChild()); + const CEType singularEst = _ceDerivation->deriveCE(*this, props, singularReq.ref()); + partialSchemaKeyCEMap.emplace(key, singularEst); + } + } + + properties::setPropertyOverwrite(props, std::move(ceProp)); + if (_debugInfo.hasDebugLevel(2)) { + std::cout << "Group " << groupId << ": " + << ExplainGenerator::explainLogicalProps("Logical properties", props); + } +} + +MemoLogicalNodeId Memo::addNode(GroupIdVector groupVector, + ProjectionNameSet projections, + const GroupIdType targetGroupId, + NodeIdSet& insertedNodeIds, + ABT n) { + for (const GroupIdType groupId : groupVector) { + // Invalid tree: node is its own child. + uassert(6624127, "Target group appears inside group vector", groupId != targetGroupId); + } + + auto [existingId, foundNode] = findNode(groupVector, n); + + if (foundNode) { + uassert(6624054, + "Found node outside target group", + targetGroupId < 0 || targetGroupId == existingId._groupId); + return existingId; + } + + const bool noTargetGroup = targetGroupId < 0; + // Only for debugging. + ProjectionNameSet projectionsCopy; + if (!noTargetGroup && _debugInfo.isDebugMode()) { + projectionsCopy = projections; + } + + // Current node is not in the memo. Insert unchanged. + const GroupIdType groupId = noTargetGroup ? addGroup(std::move(projections)) : targetGroupId; + auto [newId, inserted] = addNode(groupId, std::move(n)); + if (inserted || noTargetGroup) { + insertedNodeIds.insert(newId); + _inputGroupsToNodeIdMap[groupVector].insert(newId); + _nodeIdToInputGroupsMap[newId] = groupVector; + + if (noTargetGroup) { + estimateCE(groupId); + } else if (_debugInfo.isDebugMode()) { + const Group& group = getGroup(groupId); + // If inserted into an existing group, verify we deliver all expected projections. 
+ for (const ProjectionName& groupProjection : group.binder().names()) { + uassert(6624055, + "Node does not project all specified group projections", + projectionsCopy.find(groupProjection) != projectionsCopy.cend()); + } + + // TODO: possibly verify cardinality estimation + } + } + + return newId; +} + +GroupIdType Memo::integrate(const ABT& node, + NodeTargetGroupMap targetGroupMap, + NodeIdSet& insertedNodeIds, + const bool addExistingNodeWithNewChild) { + _stats._numIntegrations++; + MemoIntegrator integrator( + *this, std::move(targetGroupMap), insertedNodeIds, addExistingNodeWithNewChild); + return integrator.integrate(node); +} + +size_t Memo::getGroupCount() const { + return _groups.size(); +} + +void Memo::clearLogicalNodes(const GroupIdType groupId) { + auto& group = getGroup(groupId); + auto& logicalNodes = group._logicalNodes; + + for (size_t index = 0; index < logicalNodes.size(); index++) { + const MemoLogicalNodeId nodeId{groupId, index}; + const auto& groupVector = _nodeIdToInputGroupsMap.at(nodeId); + _inputGroupsToNodeIdMap.at(groupVector).erase(nodeId); + _nodeIdToInputGroupsMap.erase(nodeId); + } + + logicalNodes.clear(); + group._logicalRewriteQueue = {}; +} + +const Memo::InputGroupsToNodeIdMap& Memo::getInputGroupsToNodeIdMap() const { + return _inputGroupsToNodeIdMap; +} + +const DebugInfo& Memo::getDebugInfo() const { + return _debugInfo; +} + +void Memo::clear() { + _stats = {}; + _groups.clear(); + _inputGroupsToNodeIdMap.clear(); + _nodeIdToInputGroupsMap.clear(); +} + +const Memo::Stats& Memo::getStats() const { + return _stats; +} + +size_t Memo::getLogicalNodeCount() const { + size_t result = 0; + for (const auto& group : _groups) { + result += group->_logicalNodes.size(); + } + return result; +} + +size_t Memo::getPhysicalNodeCount() const { + size_t result = 0; + for (const auto& group : _groups) { + result += group->_physicalNodes.getNodes().size(); + } + return result; +} + +const Metadata& Memo::getMetadata() const { + return _metadata; +} + +const CEInterface& Memo::getCEDerivation() const { + return *_ceDerivation; +} + +} // namespace mongo::optimizer::cascades diff --git a/src/mongo/db/query/optimizer/cascades/memo.h b/src/mongo/db/query/optimizer/cascades/memo.h new file mode 100644 index 00000000000..4d0a943ac56 --- /dev/null +++ b/src/mongo/db/query/optimizer/cascades/memo.h @@ -0,0 +1,248 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. 
If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#pragma once + +#include <map> +#include <set> +#include <unordered_map> +#include <vector> + +#include "mongo/db/query/optimizer/cascades/interfaces.h" +#include "mongo/db/query/optimizer/cascades/rewrite_queues.h" +#include "mongo/db/query/optimizer/reference_tracker.h" + +namespace mongo::optimizer::cascades { + +struct MemoNodeRefHash { + size_t operator()(const ABT::reference_type& nodeRef) const; +}; + +struct MemoNodeRefCompare { + bool operator()(const ABT::reference_type& left, const ABT::reference_type& right) const; +}; + +class OrderPreservingABTSet { +public: + OrderPreservingABTSet() = default; + OrderPreservingABTSet(const OrderPreservingABTSet&) = delete; + OrderPreservingABTSet(OrderPreservingABTSet&&) = default; + + ABT::reference_type at(size_t index) const; + std::pair<size_t, bool> emplace_back(ABT node); + std::pair<size_t, bool> find(ABT::reference_type node) const; + + void clear(); + + size_t size() const; + const ABTVector& getVector() const; + +private: + opt::unordered_map<ABT::reference_type, size_t, MemoNodeRefHash, MemoNodeRefCompare> _map; + ABTVector _vector; +}; + +struct PhysNodeInfo { + ABT _node; + + // Total cost for the entire subtree. + CostType _cost; + + // Operator cost (without including the subtree). + CostType _localCost; + + // For display purposes, adjusted cardinality based on physical properties (e.g. Repetition and + // Limit-Skip). + CEType _adjustedCE; +}; + +struct PhysOptimizationResult { + PhysOptimizationResult(); + PhysOptimizationResult(size_t index, properties::PhysProps physProps, CostType costLimit); + + bool isOptimized() const; + void raiseCostLimit(CostType costLimit); + + const size_t _index; + const properties::PhysProps _physProps; + + CostType _costLimit; + // If set, we have successfully optimized. + boost::optional<PhysNodeInfo> _nodeInfo; + // Rejected physical plans. + std::vector<PhysNodeInfo> _rejectedNodeInfo; + + // Index of last logical node in our group we implemented. + size_t _lastImplementedNodePos; + + PhysRewriteQueue _queue; +}; + +struct PhysNodes { + using PhysNodeVector = std::vector<std::unique_ptr<PhysOptimizationResult>>; + + PhysNodes() = default; + + PhysOptimizationResult& addOptimizationResult(properties::PhysProps properties, + CostType costLimit); + + const PhysOptimizationResult& at(size_t index) const; + PhysOptimizationResult& at(size_t index); + + std::pair<size_t, bool> find(const properties::PhysProps& props) const; + + const PhysNodeVector& getNodes() const; + +private: + PhysNodeVector _physicalNodes; + + struct PhysPropsHasher { + size_t operator()(const properties::PhysProps& physProps) const; + }; + + // Used to speed up lookups into the winner's circle using physical properties. + opt::unordered_map<properties::PhysProps, size_t, PhysPropsHasher> _physPropsToPhysNodeMap; +}; + +struct Group { + explicit Group(ProjectionNameSet projections); + + Group(const Group&) = delete; + Group(Group&&) = default; + + const ExpressionBinder& binder() const; + + // Associated logical nodes. + OrderPreservingABTSet _logicalNodes; + // Group logical properties. 
+ properties::LogicalProps _logicalProperties; + ABT _binder; + + LogicalRewriteQueue _logicalRewriteQueue; + + // Best physical plan for given physical properties: aka "Winner's circle". + PhysNodes _physicalNodes; +}; + +class Memo { + friend class PhysicalRewriter; + +public: + using GroupIdVector = std::vector<GroupIdType>; + + struct Stats { + // Number of calls to integrate() + size_t _numIntegrations = 0; + // Number of recursive physical optimization calls. + size_t _physPlanExplorationCount = 0; + // Number of checks to winner's circle. + size_t _physMemoCheckCount = 0; + }; + + struct GroupIdVectorHash { + size_t operator()(const GroupIdVector& v) const; + }; + using InputGroupsToNodeIdMap = opt::unordered_map<GroupIdVector, NodeIdSet, GroupIdVectorHash>; + + /** + * Inverse map. + */ + using NodeIdToInputGroupsMap = opt::unordered_map<MemoLogicalNodeId, GroupIdVector, NodeIdHash>; + + struct NodeTargetGroupHash { + size_t operator()(const ABT::reference_type& nodeRef) const; + }; + using NodeTargetGroupMap = + opt::unordered_map<ABT::reference_type, GroupIdType, NodeTargetGroupHash>; + + Memo(DebugInfo debugInfo, + const Metadata& metadata, + std::unique_ptr<LogicalPropsInterface> logicalPropsDerivation, + std::unique_ptr<CEInterface> ceDerivation); + + const Group& getGroup(GroupIdType groupId) const; + Group& getGroup(GroupIdType groupId); + size_t getGroupCount() const; + + std::pair<size_t, bool> findNodeInGroup(GroupIdType groupId, ABT::reference_type node) const; + + ABT::reference_type getNode(MemoLogicalNodeId nodeMemoId) const; + + void estimateCE(GroupIdType groupId); + + MemoLogicalNodeId addNode(GroupIdVector groupVector, + ProjectionNameSet projections, + GroupIdType targetGroupId, + NodeIdSet& insertedNodeIds, + ABT n); + + GroupIdType integrate(const ABT& node, + NodeTargetGroupMap targetGroupMap, + NodeIdSet& insertedNodeIds, + bool addExistingNodeWithNewChild = false); + + void clearLogicalNodes(GroupIdType groupId); + + const InputGroupsToNodeIdMap& getInputGroupsToNodeIdMap() const; + + const DebugInfo& getDebugInfo() const; + + void clear(); + + const Stats& getStats() const; + size_t getLogicalNodeCount() const; + size_t getPhysicalNodeCount() const; + + const Metadata& getMetadata() const; + const CEInterface& getCEDerivation() const; + +private: + GroupIdType addGroup(ProjectionNameSet projections); + + std::pair<MemoLogicalNodeId, bool> addNode(GroupIdType groupId, ABT n); + + std::pair<MemoLogicalNodeId, bool> findNode(const GroupIdVector& groups, const ABT& node); + + std::vector<std::unique_ptr<Group>> _groups; + + // Used to find nodes using particular groups as inputs. + InputGroupsToNodeIdMap _inputGroupsToNodeIdMap; + + NodeIdToInputGroupsMap _nodeIdToInputGroupsMap; + + const Metadata& _metadata; + std::unique_ptr<LogicalPropsInterface> _logicalPropsDerivation; + std::unique_ptr<CEInterface> _ceDerivation; + + const DebugInfo _debugInfo; + + Stats _stats; +}; + +} // namespace mongo::optimizer::cascades diff --git a/src/mongo/db/query/optimizer/cascades/physical_rewriter.cpp b/src/mongo/db/query/optimizer/cascades/physical_rewriter.cpp new file mode 100644 index 00000000000..fb3d9185220 --- /dev/null +++ b/src/mongo/db/query/optimizer/cascades/physical_rewriter.cpp @@ -0,0 +1,407 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. 
+ * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include "mongo/db/query/optimizer/cascades/physical_rewriter.h" + +#include "mongo/db/query/optimizer/cascades/enforcers.h" +#include "mongo/db/query/optimizer/cascades/implementers.h" +#include "mongo/db/query/optimizer/explain.h" +#include "mongo/db/query/optimizer/utils/memo_utils.h" + +namespace mongo::optimizer::cascades { + +using namespace properties; + +/** + * Helper class used to check if two physical property sets are compatible by testing each + * constituent property for compatibility. This is used to check if a winner's circle entry can be + * reused. + */ +class PropCompatibleVisitor { +public: + bool operator()(const PhysProperty&, const CollationRequirement& requiredProp) { + return collationsCompatible( + getPropertyConst<CollationRequirement>(_availableProps).getCollationSpec(), + requiredProp.getCollationSpec()); + } + + bool operator()(const PhysProperty&, const LimitSkipRequirement& requiredProp) { + const auto& available = getPropertyConst<LimitSkipRequirement>(_availableProps); + return available.getSkip() >= requiredProp.getSkip() && + available.getAbsoluteLimit() <= requiredProp.getAbsoluteLimit(); + } + + bool operator()(const PhysProperty&, const ProjectionRequirement& requiredProp) { + const auto& availableProjections = + getPropertyConst<ProjectionRequirement>(_availableProps).getProjections(); + // Do we have a projection superset (not necessarily strict superset)? 
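For intuition, a small self-contained sketch of the superset test performed just below; std::set<std::string> stands in for the order-preserving projection set, and the names are invented:

#include <iostream>
#include <set>
#include <string>

// Returns true if every required projection is also available (non-strict superset).
bool providesAll(const std::set<std::string>& available, const std::set<std::string>& required) {
    for (const std::string& name : required) {
        if (available.count(name) == 0) {
            return false;
        }
    }
    return true;
}

int main() {
    std::set<std::string> available{"pa", "pb", "pc"};
    std::cout << providesAll(available, {"pa", "pb"}) << "\n";  // 1: the existing entry can be reused.
    std::cout << providesAll(available, {"pa", "pd"}) << "\n";  // 0: "pd" is not delivered.
    return 0;
}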
+ for (const ProjectionName& projectionName : requiredProp.getProjections().getVector()) { + if (!availableProjections.find(projectionName).second) { + return false; + } + } + return true; + } + + bool operator()(const PhysProperty&, const IndexingRequirement& requiredProp) { + const auto& available = getPropertyConst<IndexingRequirement>(_availableProps); + return available.getIndexReqTarget() == requiredProp.getIndexReqTarget() && + available.getNeedsRID() == requiredProp.getNeedsRID() && + (available.getDedupRID() || !requiredProp.getDedupRID()) && + available.getSatisfiedPartialIndexesGroupId() == + requiredProp.getSatisfiedPartialIndexesGroupId(); + } + + template <class T> + bool operator()(const PhysProperty&, const T& requiredProp) { + return getPropertyConst<T>(_availableProps) == requiredProp; + } + + static bool propertiesCompatible(const PhysProps& requiredProps, + const PhysProps& availableProps) { + if (requiredProps.size() != availableProps.size()) { + return false; + } + + PropCompatibleVisitor visitor(availableProps); + for (const auto& [key, prop] : requiredProps) { + if (availableProps.find(key) == availableProps.cend() || !prop.visit(visitor)) { + return false; + } + } + return true; + } + +private: + PropCompatibleVisitor(const PhysProps& availableProp) : _availableProps(availableProp) {} + + // We don't own this. + const PhysProps& _availableProps; +}; + +PhysicalRewriter::PhysicalRewriter( + Memo& memo, + const QueryHints& hints, + const opt::unordered_map<std::string, ProjectionName>& ridProjections, + const CostingInterface& costDerivation, + std::unique_ptr<LogicalRewriter>& logicalRewriter) + : _memo(memo), + _costDerivation(costDerivation), + _hints(hints), + _ridProjections(ridProjections), + _logicalRewriter(logicalRewriter) {} + +static void printCandidateInfo(const ABT& node, + const GroupIdType groupId, + const CostType nodeCost, + const ChildPropsType& childProps, + const PhysOptimizationResult& bestResult) { + std::cout + << "group: " << groupId << ", id: " << bestResult._index + << ", nodeCost: " << nodeCost.toString() << ", best cost: " + << (bestResult._nodeInfo ? bestResult._nodeInfo->_cost : CostType::kInfinity).toString() + << "\n"; + std::cout << ExplainGenerator::explainPhysProps("Physical properties", bestResult._physProps) + << "\n"; + std::cout << "Node: \n" << ExplainGenerator::explainV2(node) << "\n"; + + for (const auto& childProp : childProps) { + std::cout << ExplainGenerator::explainPhysProps("Child properties", childProp.second); + } +} + +void PhysicalRewriter::costAndRetainBestNode(ABT node, + ChildPropsType childProps, + NodeCEMap nodeCEMap, + const GroupIdType groupId, + const PrefixId& prefixId, + PhysOptimizationResult& bestResult) { + const CostAndCE nodeCostAndCE = + _costDerivation.deriveCost(_memo, bestResult._physProps, node.ref(), childProps, nodeCEMap); + const CostType nodeCost = nodeCostAndCE._cost; + uassert(6624056, "Must get non-infinity cost for physical node.", !nodeCost.isInfinite()); + + if (_memo.getDebugInfo().hasDebugLevel(3)) { + std::cout << "Requesting optimization\n"; + printCandidateInfo(node, groupId, nodeCost, childProps, bestResult); + } + + const CostType childCostLimit = + bestResult._nodeInfo ? 
bestResult._nodeInfo->_cost : bestResult._costLimit; + const auto [success, cost] = optimizeChildren(nodeCost, childProps, prefixId, childCostLimit); + const bool improvement = + success && (!bestResult._nodeInfo || cost < bestResult._nodeInfo->_cost); + + if (_memo.getDebugInfo().hasDebugLevel(3)) { + std::cout << (success ? (improvement ? "Improved" : "Did not improve") + : "Failed optimizing") + << "\n"; + printCandidateInfo(node, groupId, nodeCost, childProps, bestResult); + } + + PhysNodeInfo candidateNodeInfo{ + unwrapConstFilter(std::move(node)), cost, nodeCost, nodeCostAndCE._ce}; + const bool keepRejectedPlans = _hints._keepRejectedPlans; + if (improvement) { + if (keepRejectedPlans && bestResult._nodeInfo) { + bestResult._rejectedNodeInfo.push_back(std::move(*bestResult._nodeInfo)); + } + bestResult._nodeInfo = std::move(candidateNodeInfo); + } else if (keepRejectedPlans) { + bestResult._rejectedNodeInfo.push_back(std::move(candidateNodeInfo)); + } +} + +/** + * Convert nodes from logical to physical memo delegators. + * Performs branch-and-bound search. + */ +std::pair<bool, CostType> PhysicalRewriter::optimizeChildren(const CostType nodeCost, + ChildPropsType childProps, + const PrefixId& prefixId, + const CostType costLimit) { + const bool disableBranchAndBound = _hints._disableBranchAndBound; + + CostType totalCost = nodeCost; + if (costLimit < totalCost && !disableBranchAndBound) { + return {false, CostType::kInfinity}; + } + + for (auto& [node, props] : childProps) { + const GroupIdType groupId = node->cast<MemoLogicalDelegatorNode>()->getGroupId(); + + const CostType childCostLimit = + disableBranchAndBound ? CostType::kInfinity : (costLimit - totalCost); + auto optGroupResult = optimizeGroup(groupId, std::move(props), prefixId, childCostLimit); + if (!optGroupResult._success) { + return {false, CostType::kInfinity}; + } + + totalCost += optGroupResult._cost; + if (costLimit < totalCost && !disableBranchAndBound) { + return {false, CostType::kInfinity}; + } + + ABT optimizedChild = + make<MemoPhysicalDelegatorNode>(MemoPhysicalNodeId{groupId, optGroupResult._index}); + std::swap(*node, optimizedChild); + } + + return {true, totalCost}; +} + +PhysicalRewriter::OptimizeGroupResult::OptimizeGroupResult() + : _success(false), _index(0), _cost(CostType::kInfinity) {} + +PhysicalRewriter::OptimizeGroupResult::OptimizeGroupResult(const size_t index, const CostType cost) + : _success(true), _index(index), _cost(std::move(cost)) { + uassert(6624347, + "Cannot have successful optimization with infinite cost", + _cost < CostType::kInfinity); +} + +PhysicalRewriter::OptimizeGroupResult PhysicalRewriter::optimizeGroup(const GroupIdType groupId, + PhysProps physProps, + PrefixId prefixId, + CostType costLimit) { + const size_t localPlanExplorationCount = ++_memo._stats._physPlanExplorationCount; + if (_memo.getDebugInfo().hasDebugLevel(2)) { + std::cout << "#" << localPlanExplorationCount << " Optimizing group " << groupId + << ", cost limit: " << costLimit.toString() << "\n"; + std::cout << ExplainGenerator::explainPhysProps("Physical properties", physProps) << "\n"; + } + + Group& group = _memo.getGroup(groupId); + const LogicalProps& logicalProps = group._logicalProperties; + if (hasProperty<IndexingAvailability>(logicalProps) && + !hasProperty<IndexingRequirement>(physProps)) { + // Re-optimize under complete scan indexing requirements. 
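Stepping back to optimizeChildren above, a rough standalone illustration of its branch-and-bound pruning; the costs and limit are invented numbers, and it leaves out the part where each child is itself optimized under the remaining budget:

#include <iostream>
#include <limits>
#include <utility>
#include <vector>

// Toy branch-and-bound accumulation: start from the node's own cost, add each child's
// cost, and abandon the candidate as soon as the running total exceeds the limit.
// Returns {success, totalCost}; infinity stands in for CostType::kInfinity.
std::pair<bool, double> accumulate(double nodeCost, const std::vector<double>& childCosts, double costLimit) {
    constexpr double kInfinity = std::numeric_limits<double>::infinity();
    double total = nodeCost;
    if (costLimit < total) {
        return {false, kInfinity};
    }
    for (double childCost : childCosts) {
        total += childCost;
        if (costLimit < total) {
            return {false, kInfinity};  // Prune: cannot beat the best plan found so far.
        }
    }
    return {true, total};
}

int main() {
    auto [ok, cost] = accumulate(1.0, {2.0, 3.0}, 10.0);   // ok == true, cost == 6.0
    auto [ok2, cost2] = accumulate(1.0, {2.0, 3.0}, 4.0);  // ok2 == false: pruned early
    std::cout << ok << " " << cost << " " << ok2 << " " << cost2 << "\n";
    return 0;
}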
+ setPropertyOverwrite( + physProps, + IndexingRequirement{ + IndexReqTarget::Complete, false /*needRID*/, true /*dedupRID*/, groupId}); + } + + auto& physicalNodes = group._physicalNodes; + // Establish whether we have found an exact match of the physical properties in the winner's circle. + const auto [exactPropsIndex, hasExactProps] = physicalNodes.find(physProps); + // If true, we have found compatible (but not equal) props with cost under our cost limit. + bool hasCompatibleProps = false; + + if (hasExactProps) { + PhysOptimizationResult& physNode = physicalNodes.at(exactPropsIndex); + if (!physNode.isOptimized()) { + // Currently optimizing under the same properties higher up the stack (recursive loop). + return {}; + } + // At this point we have an optimized entry. + + if (!physNode._nodeInfo) { + if (physNode._costLimit < costLimit) { + physNode.raiseCostLimit(costLimit); + // Fall through and continue optimizing. + } else { + // Previously failed to optimize under a less strict cost limit. + return {}; + } + } else if (costLimit < physNode._nodeInfo->_cost) { + // We have a stricter limit than our previous optimization's cost. + return {}; + } else { + // Reuse result under identical properties. + if (_memo.getDebugInfo().hasDebugLevel(3)) { + std::cout << "Reusing winner's circle entry: group: " << groupId + << ", id: " << physNode._index + << ", cost: " << physNode._nodeInfo->_cost.toString() + << ", limit: " << costLimit.toString() << "\n"; + std::cout << "Existing props: " + << ExplainGenerator::explainPhysProps("existing", physNode._physProps) + << "\n"; + std::cout << "New props: " << ExplainGenerator::explainPhysProps("new", physProps) + << "\n"; + std::cout << "Reused plan: " + << ExplainGenerator::explainV2(physNode._nodeInfo->_node) << "\n"; + } + return {physNode._index, physNode._nodeInfo->_cost}; + } + } else { + // Check winner's circle for compatible properties. + for (const auto& physNode : physicalNodes.getNodes()) { + _memo._stats._physMemoCheckCount++; + + if (!physNode->_nodeInfo) { + continue; + } + // At this point we have an optimized entry. + + if (costLimit < physNode->_nodeInfo->_cost) { + // Properties are not identical. Continue exploring even if the limit was stricter. + continue; + } + + if (!PropCompatibleVisitor::propertiesCompatible(physProps, physNode->_physProps)) { + // We are stricter than what is available. + continue; + } + + if (physNode->_nodeInfo->_cost < costLimit) { + if (_memo.getDebugInfo().hasDebugLevel(3)) { + std::cout << "Reducing cost limit: group: " << groupId + << ", id: " << physNode->_index + << ", cost: " << physNode->_nodeInfo->_cost.toString() + << ", limit: " << costLimit.toString() << "\n"; + std::cout << ExplainGenerator::explainPhysProps("Existing props", + physNode->_physProps) + << "\n"; + std::cout << ExplainGenerator::explainPhysProps("New props", physProps) << "\n"; + } + + // Reduce the cost limit to the compatible entry's cost. + hasCompatibleProps = true; + costLimit = physNode->_nodeInfo->_cost; + } + } + } + + // If we found an exact match for the properties, re-use the entry and continue optimizing under a + // higher cost limit. Otherwise create a new entry for the current properties. + PhysOptimizationResult& bestResult = hasExactProps + ? physicalNodes.at(exactPropsIndex) + : physicalNodes.addOptimizationResult(physProps, costLimit); + + // Enforcement rewrites run just once, and are independent of the logical nodes.
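As a reading aid for the loop below: enforcers are queued once up front against the required physical properties (here, when a ProjectionRequirement is present); then, until no new logical nodes appear and the rewrite queue drains, the group is optionally rewritten logically, implementers convert any newly added logical nodes into physical candidates, and each queued candidate is costed via costAndRetainBestNode, which optimizes its children under the branch-and-bound budget and retains the cheapest plan in this winner's circle entry.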
+ if (hasProperty<ProjectionRequirement>(bestResult._physProps)) { + // Verify properties can be enforced and add enforcers if necessary. + addEnforcers(groupId, + _memo.getMetadata(), + prefixId, + bestResult._queue, + bestResult._physProps, + logicalProps); + } + + // Iterate until we perform all logical for the group and physical rewrites for our best plan. + const OrderPreservingABTSet& logicalNodes = group._logicalNodes; + while (bestResult._lastImplementedNodePos < logicalNodes.size() || !bestResult._queue.empty()) { + if (_logicalRewriter) { + // Attempt to perform logical rewrites. + _logicalRewriter->rewriteGroup(groupId); + } + + // Add rewrites to convert logical into physical nodes. Only add rewrites for newly added + // logical nodes. + addImplementers( + _memo, _hints, _ridProjections, prefixId, bestResult, logicalProps, logicalNodes); + + // Perform physical rewrites, use branch-and-bound. + while (!bestResult._queue.empty()) { + PhysRewriteEntry rewrite = std::move(*bestResult._queue.top()); + bestResult._queue.pop(); + + NodeCEMap nodeCEMap = std::move(rewrite._nodeCEMap); + if (nodeCEMap.empty()) { + nodeCEMap.emplace( + rewrite._node.cast<Node>(), + getPropertyConst<CardinalityEstimate>(logicalProps).getEstimate()); + } + + costAndRetainBestNode(std::move(rewrite._node), + std::move(rewrite._childProps), + std::move(nodeCEMap), + groupId, + prefixId, + bestResult); + } + } + + uassert(6624128, "Result is not optimized!", bestResult.isOptimized()); + if (!bestResult._nodeInfo) { + uassert(6624348, + "Must optimize successfully if found compatible properties!", + !hasCompatibleProps); + return {}; + } + + // We have a successful rewrite. + if (_memo.getDebugInfo().hasDebugLevel(2)) { + std::cout << "#" << localPlanExplorationCount << " Optimized group: " << groupId + << ", id: " << bestResult._index + << ", cost: " << bestResult._nodeInfo->_cost.toString() << "\n"; + std::cout << ExplainGenerator::explainPhysProps("Physical properties", + bestResult._physProps) + << "\n"; + std::cout << "Node: \n" + << ExplainGenerator::explainV2( + bestResult._nodeInfo->_node, false /*displayProperties*/, &_memo); + } + + return {bestResult._index, bestResult._nodeInfo->_cost}; +} + +} // namespace mongo::optimizer::cascades diff --git a/src/mongo/db/query/optimizer/cascades/physical_rewriter.h b/src/mongo/db/query/optimizer/cascades/physical_rewriter.h new file mode 100644 index 00000000000..1463c4a7858 --- /dev/null +++ b/src/mongo/db/query/optimizer/cascades/physical_rewriter.h @@ -0,0 +1,93 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. 
You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#pragma once + +#include "mongo/db/query/optimizer/cascades/logical_rewriter.h" +#include "mongo/db/query/optimizer/cascades/memo.h" +#include "mongo/db/query/optimizer/utils/utils.h" + +namespace mongo::optimizer::cascades { + +class PhysicalRewriter { + friend class PropEnforcerVisitor; + friend class ImplementationVisitor; + +public: + struct OptimizeGroupResult { + OptimizeGroupResult(); + OptimizeGroupResult(size_t index, CostType cost); + + OptimizeGroupResult(const OptimizeGroupResult& other) = default; + OptimizeGroupResult(OptimizeGroupResult&& other) = default; + + bool _success; + size_t _index; + CostType _cost; + }; + + PhysicalRewriter(Memo& memo, + const QueryHints& hints, + const opt::unordered_map<std::string, ProjectionName>& ridProjections, + const CostingInterface& costDerivation, + std::unique_ptr<LogicalRewriter>& logicalRewriter); + + /** + * Main entry point for physical optimization. + * Optimize a logical plan rooted at a RootNode, and return an index into the winner's circle if + * successful. + */ + OptimizeGroupResult optimizeGroup(GroupIdType groupId, + properties::PhysProps physProps, + PrefixId prefixId, + CostType costLimit); + +private: + void costAndRetainBestNode(ABT node, + ChildPropsType childProps, + NodeCEMap nodeCEMap, + GroupIdType groupId, + const PrefixId& prefixId, + PhysOptimizationResult& bestResult); + + std::pair<bool, CostType> optimizeChildren(CostType nodeCost, + ChildPropsType childProps, + const PrefixId& prefixId, + CostType costLimit); + + // We don't own any of this. + Memo& _memo; + const CostingInterface& _costDerivation; + const QueryHints& _hints; + const opt::unordered_map<std::string, ProjectionName>& _ridProjections; + // If set, we'll perform logical rewrites as part of OptimizeGroup(). + std::unique_ptr<LogicalRewriter>& _logicalRewriter; +}; + +} // namespace mongo::optimizer::cascades diff --git a/src/mongo/db/query/optimizer/cascades/rewrite_queues.cpp b/src/mongo/db/query/optimizer/cascades/rewrite_queues.cpp new file mode 100644 index 00000000000..8d0ef6809b7 --- /dev/null +++ b/src/mongo/db/query/optimizer/cascades/rewrite_queues.cpp @@ -0,0 +1,76 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. 
+ * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include "mongo/db/query/optimizer/cascades/rewrite_queues.h" +#include "mongo/db/query/optimizer/utils/memo_utils.h" +#include <mongo/db/query/optimizer/defs.h> + +namespace mongo::optimizer::cascades { + +LogicalRewriteEntry::LogicalRewriteEntry(const double priority, + LogicalRewriteType type, + MemoLogicalNodeId nodeId) + : _priority(priority), _type(type), _nodeId(nodeId) {} + +bool LogicalRewriteEntryComparator::operator()( + const std::unique_ptr<LogicalRewriteEntry>& x, + const std::unique_ptr<LogicalRewriteEntry>& y) const { + // Lower numerical priority is considered last (and thus de-queued first). + if (x->_priority > y->_priority) { + return true; + } else if (x->_priority < y->_priority) { + return false; + } + + // Make sure entries in the queue are consistently ordered. + if (x->_nodeId._groupId < y->_nodeId._groupId) { + return true; + } else if (x->_nodeId._groupId > y->_nodeId._groupId) { + return false; + } + return x->_nodeId._index < y->_nodeId._index; +} + +void optimizeChildrenNoAssert(PhysRewriteQueue& queue, + const double priority, + ABT node, + ChildPropsType childProps, + NodeCEMap nodeCEMap) { + queue.emplace(std::make_unique<PhysRewriteEntry>( + priority, std::move(node), std::move(childProps), std::move(nodeCEMap))); +} + +void optimizeUnderNewProperties(cascades::PhysRewriteQueue& queue, + const double priority, + ABT child, + properties::PhysProps props) { + optimizeChild<FilterNode>(queue, priority, wrapConstFilter(std::move(child)), std::move(props)); +} + +} // namespace mongo::optimizer::cascades diff --git a/src/mongo/db/query/optimizer/cascades/rewrite_queues.h b/src/mongo/db/query/optimizer/cascades/rewrite_queues.h new file mode 100644 index 00000000000..a64b7d000a1 --- /dev/null +++ b/src/mongo/db/query/optimizer/cascades/rewrite_queues.h @@ -0,0 +1,149 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. 
+ * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#pragma once + +#include <queue> + +#include "mongo/db/query/optimizer/cascades/logical_rewriter_rules.h" +#include "mongo/db/query/optimizer/node_defs.h" + +namespace mongo::optimizer::cascades { + +/** + * Keeps track of candidate logical rewrites. + */ +struct LogicalRewriteEntry { + LogicalRewriteEntry(double priority, LogicalRewriteType type, MemoLogicalNodeId nodeId); + + LogicalRewriteEntry() = delete; + LogicalRewriteEntry(const LogicalRewriteEntry& other) = delete; + LogicalRewriteEntry(LogicalRewriteEntry&& other) = default; + + // Numerically lower priority gets applied first. + double _priority; + + LogicalRewriteType _type; + MemoLogicalNodeId _nodeId; +}; + +struct LogicalRewriteEntryComparator { + bool operator()(const std::unique_ptr<LogicalRewriteEntry>& x, + const std::unique_ptr<LogicalRewriteEntry>& y) const; +}; + +using LogicalRewriteQueue = std::priority_queue<std::unique_ptr<LogicalRewriteEntry>, + std::vector<std::unique_ptr<LogicalRewriteEntry>>, + LogicalRewriteEntryComparator>; + +/** + * For now all physical rules use the same priority. + * TODO: use specific priorities (may depend on node parameters). + */ +static constexpr double kDefaultPriority = 10.0; + +/** + * Keeps track of candidate physical rewrites. + */ +struct PhysRewriteEntry { + PhysRewriteEntry(double priority, ABT node, ChildPropsType childProps, NodeCEMap nodeCEMap); + + PhysRewriteEntry() = delete; + PhysRewriteEntry(const PhysRewriteEntry& other) = delete; + PhysRewriteEntry(PhysRewriteEntry&& other) = default; + + // Numerically lower priority gets applied first.
+ double _priority; + + ABT _node; + ChildPropsType _childProps; + + NodeCEMap _nodeCEMap; +}; + +struct PhysRewriteEntryComparator { + bool operator()(const std::unique_ptr<PhysRewriteEntry>& x, + const std::unique_ptr<PhysRewriteEntry>& y) const; +}; + +using PhysRewriteQueue = std::priority_queue<std::unique_ptr<PhysRewriteEntry>, + std::vector<std::unique_ptr<PhysRewriteEntry>>, + PhysRewriteEntryComparator>; + +void optimizeChildrenNoAssert(PhysRewriteQueue& queue, + double priority, + ABT node, + ChildPropsType childProps, + NodeCEMap nodeCEMap = {}); + +template <class T> +static void optimizeChildren(PhysRewriteQueue& queue, + double priority, + ABT node, + ChildPropsType childProps) { + static_assert(canBePhysicalNode<T>(), "Can only optimize a physical node."); + optimizeChildrenNoAssert(queue, priority, std::move(node), std::move(childProps)); +} + +template <class T> +static void optimizeChild(PhysRewriteQueue& queue, + double priority, + ABT node, + properties::PhysProps childProps) { + ABT& childRef = node.cast<T>()->getChild(); + optimizeChildren<T>( + queue, priority, std::move(node), ChildPropsType{{&childRef, std::move(childProps)}}); +} + +template <class T> +static void optimizeChild(PhysRewriteQueue& queue, const double priority, ABT node) { + optimizeChildren<T>(queue, priority, std::move(node), {}); +} + +void optimizeUnderNewProperties(PhysRewriteQueue& queue, + double priority, + ABT child, + properties::PhysProps props); + +template <class T> +static void optimizeChildren(PhysRewriteQueue& queue, + double priority, + ABT node, + properties::PhysProps leftProps, + properties::PhysProps rightProps) { + ABT& leftChildRef = node.cast<T>()->getLeftChild(); + ABT& rightChildRef = node.cast<T>()->getRightChild(); + optimizeChildren<T>( + queue, + priority, + std::move(node), + {{&leftChildRef, std::move(leftProps)}, {&rightChildRef, std::move(rightProps)}}); +} + +} // namespace mongo::optimizer::cascades diff --git a/src/mongo/db/query/optimizer/containers.h b/src/mongo/db/query/optimizer/containers.h new file mode 100644 index 00000000000..5a9c11e9c12 --- /dev/null +++ b/src/mongo/db/query/optimizer/containers.h @@ -0,0 +1,81 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. 
If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#pragma once + +#include <unordered_map> +#include <unordered_set> + +#include "mongo/stdx/trusted_hasher.h" +#include "mongo/stdx/unordered_map.h" +#include "mongo/stdx/unordered_set.h" + +namespace mongo::optimizer::opt { + +namespace { +enum class ContainerImpl { STD, STDX }; + +// For debugging, switch between STD and STDX containers. +static constexpr ContainerImpl kContainerImpl = ContainerImpl::STDX; + +template <ContainerImpl> +struct OptContainers {}; + +template <> +struct OptContainers<ContainerImpl::STDX> { + template <class K> + using Hasher = mongo::DefaultHasher<K>; + + template <class T, class H, typename... Args> + using unordered_set = stdx::unordered_set<T, H, Args...>; + template <class K, class V, class H, typename... Args> + using unordered_map = stdx::unordered_map<K, V, H, Args...>; +}; + +template <> +struct OptContainers<ContainerImpl::STD> { + template <class K> + using Hasher = std::hash<K>; + + template <class T, class H, typename... Args> + using unordered_set = std::unordered_set<T, H, Args...>; // NOLINT + template <class K, class V, class H, typename... Args> + using unordered_map = std::unordered_map<K, V, H, Args...>; // NOLINT +}; + +using ActiveContainers = OptContainers<kContainerImpl>; +} // namespace + +template <class T, class H = ActiveContainers::Hasher<T>, typename... Args> +using unordered_set = ActiveContainers::unordered_set<T, H, Args...>; + +template <class K, class V, class H = ActiveContainers::Hasher<K>, typename... Args> +using unordered_map = ActiveContainers::unordered_map<K, V, H, Args...>; + +} // namespace mongo::optimizer::opt diff --git a/src/mongo/db/query/optimizer/defs.cpp b/src/mongo/db/query/optimizer/defs.cpp new file mode 100644 index 00000000000..9949bb07ab6 --- /dev/null +++ b/src/mongo/db/query/optimizer/defs.cpp @@ -0,0 +1,257 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. 
+ */ + +#include "mongo/db/query/optimizer/defs.h" +#include "mongo/db/query/optimizer/utils/utils.h" +#include "mongo/util/assert_util.h" + +namespace mongo::optimizer { + +static opt::unordered_map<ProjectionName, size_t> createMapFromVector( + const ProjectionNameVector& v) { + opt::unordered_map<ProjectionName, size_t> result; + for (size_t i = 0; i < v.size(); i++) { + result.emplace(v.at(i), i); + } + return result; +} + +ProjectionNameOrderPreservingSet::ProjectionNameOrderPreservingSet(ProjectionNameVector v) + : _map(createMapFromVector(v)), _vector(std::move(v)) {} + +ProjectionNameOrderPreservingSet::ProjectionNameOrderPreservingSet( + const ProjectionNameOrderPreservingSet& other) + : _map(other._map), _vector(other._vector) {} + +ProjectionNameOrderPreservingSet::ProjectionNameOrderPreservingSet( + ProjectionNameOrderPreservingSet&& other) noexcept + : _map(std::move(other._map)), _vector(std::move(other._vector)) {} + +bool ProjectionNameOrderPreservingSet::operator==( + const ProjectionNameOrderPreservingSet& other) const { + return _vector == other._vector; +} + +std::pair<size_t, bool> ProjectionNameOrderPreservingSet::emplace_back( + ProjectionName projectionName) { + auto [index, found] = find(projectionName); + if (found) { + return {index, false}; + } + + const size_t id = _vector.size(); + _vector.emplace_back(std::move(projectionName)); + _map.emplace(_vector.back(), id); + return {id, true}; +} + +std::pair<size_t, bool> ProjectionNameOrderPreservingSet::find( + const ProjectionName& projectionName) const { + auto it = _map.find(projectionName); + if (it == _map.end()) { + return {0, false}; + } + + return {it->second, true}; +} + +bool ProjectionNameOrderPreservingSet::erase(const ProjectionName& projectionName) { + auto [index, found] = find(projectionName); + if (!found) { + return false; + } + + if (index < _vector.size() - 1) { + // Repoint map. + _map.at(_vector.back()) = index; + // Fill gap with last element. 
+ _vector.at(index) = std::move(_vector.back()); + } + + _map.erase(projectionName); + _vector.resize(_vector.size() - 1); + + return true; +} + +bool ProjectionNameOrderPreservingSet::isEqualIgnoreOrder( + const ProjectionNameOrderPreservingSet& other) const { + size_t numMatches = 0; + for (const auto& projectionName : _vector) { + if (other.find(projectionName).second) { + numMatches++; + } else { + return false; + } + } + + return numMatches == other._vector.size(); +} + +const ProjectionNameVector& ProjectionNameOrderPreservingSet::getVector() const { + return _vector; +} + +bool FieldProjectionMap::operator==(const FieldProjectionMap& other) const { + return _ridProjection == other._ridProjection && _rootProjection == other._rootProjection && + _fieldProjections == other._fieldProjections; +} + +bool MemoLogicalNodeId::operator==(const MemoLogicalNodeId& other) const { + return _groupId == other._groupId && _index == other._index; +} + +size_t NodeIdHash::operator()(const MemoLogicalNodeId& id) const { + size_t result = 17; + updateHash(result, std::hash<GroupIdType>()(id._groupId)); + updateHash(result, std::hash<size_t>()(id._index)); + return result; +} + +bool MemoPhysicalNodeId::operator==(const MemoPhysicalNodeId& other) const { + return _groupId == other._groupId && _index == other._index; +} + +DebugInfo DebugInfo::kDefaultForTests = + DebugInfo(true, DebugInfo::kDefaultDebugLevelForTests, DebugInfo::kIterationLimitForTests); +DebugInfo DebugInfo::kDefaultForProd = DebugInfo(false, 0, -1); + +DebugInfo::DebugInfo(const bool debugMode, const int debugLevel, const int iterationLimit) + : _debugMode(debugMode), _debugLevel(debugLevel), _iterationLimit(iterationLimit) {} + +bool DebugInfo::isDebugMode() const { + return _debugMode; +} + +bool DebugInfo::hasDebugLevel(const int debugLevel) const { + return _debugLevel >= debugLevel; +} + +bool DebugInfo::exceedsIterationLimit(const int iterations) const { + return _iterationLimit >= 0 && iterations > _iterationLimit; +} + +CostType CostType::kInfinity = CostType(true /*isInfinite*/, 0.0); +CostType CostType::kZero = CostType(false /*isInfinite*/, 0.0); + +CostType::CostType(const bool isInfinite, const double cost) + : _isInfinite(isInfinite), _cost(cost) { + uassert(6624346, "Cost is negative", _cost >= 0.0); +} + +bool CostType::operator==(const CostType& other) const { + return _isInfinite == other._isInfinite && (_isInfinite || _cost == other._cost); +} + +bool CostType::operator!=(const CostType& other) const { + return !(*this == other); +} + +bool CostType::operator<(const CostType& other) const { + return !_isInfinite && (other._isInfinite || _cost < other._cost); +} + +CostType CostType::operator+(const CostType& other) const { + return (_isInfinite || other._isInfinite) ? kInfinity : fromDouble(_cost + other._cost); +} + +CostType CostType::operator-(const CostType& other) const { + uassert(6624001, "Cannot subtract an infinite cost", other != kInfinity); + return _isInfinite ? 
kInfinity : fromDouble(_cost - other._cost); +} + +CostType& CostType::operator+=(const CostType& other) { + *this = (*this + other); + return *this; +} + +CostType CostType::fromDouble(const double cost) { + uassert(8423327, "Invalid cost.", !std::isnan(cost) && cost >= 0.0); + return CostType(false /*isInfinite*/, cost); +} + +std::string CostType::toString() const { + std::ostringstream os; + if (_isInfinite) { + os << "{Infinite cost}"; + } else { + os << _cost; + } + return os.str(); +} + +double CostType::getCost() const { + uassert(6624002, "Attempted to coerce infinite cost to a double", !_isInfinite); + return _cost; +} + +bool CostType::isInfinite() const { + return _isInfinite; +} + +CollationOp reverseCollationOp(const CollationOp op) { + switch (op) { + case CollationOp::Ascending: + return CollationOp::Descending; + case CollationOp::Descending: + return CollationOp::Ascending; + case CollationOp::Clustered: + return CollationOp::Clustered; + + default: + MONGO_UNREACHABLE; + } +} + +bool collationOpsCompatible(const CollationOp availableOp, const CollationOp requiredOp) { + return requiredOp == CollationOp::Clustered || requiredOp == availableOp; +} + +bool collationsCompatible(const ProjectionCollationSpec& available, + const ProjectionCollationSpec& required) { + // Check if required is more restrictive than available. If yes, reject. + if (available.size() < required.size()) { + return false; + } + + for (size_t i = 0; i < required.size(); i++) { + const auto& requiredEntry = required.at(i); + const auto& availableEntry = available.at(i); + + if (requiredEntry.first != availableEntry.first || + !collationOpsCompatible(availableEntry.second, requiredEntry.second)) { + return false; + } + } + + // Available is at least as restrictive as required. + return true; +} + +} // namespace mongo::optimizer diff --git a/src/mongo/db/query/optimizer/defs.h b/src/mongo/db/query/optimizer/defs.h new file mode 100644 index 00000000000..12c153ef4e9 --- /dev/null +++ b/src/mongo/db/query/optimizer/defs.h @@ -0,0 +1,254 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. 
+ */ + +#pragma once + +#include <set> +#include <sstream> +#include <string> +#include <vector> + +#include "mongo/db/query/optimizer/containers.h" +#include "mongo/db/query/optimizer/utils/printable_enum.h" + + +namespace mongo::optimizer { + +using FieldNameType = std::string; +using FieldPathType = std::vector<FieldNameType>; + +using CollectionNameType = std::string; + +using ProjectionName = std::string; +using ProjectionNameSet = opt::unordered_set<ProjectionName>; +using ProjectionNameOrderedSet = std::set<ProjectionName>; +using ProjectionNameVector = std::vector<ProjectionName>; +using ProjectionRenames = opt::unordered_map<ProjectionName, ProjectionName>; + +class ProjectionNameOrderPreservingSet { +public: + ProjectionNameOrderPreservingSet() = default; + ProjectionNameOrderPreservingSet(ProjectionNameVector v); + + ProjectionNameOrderPreservingSet(const ProjectionNameOrderPreservingSet& other); + ProjectionNameOrderPreservingSet(ProjectionNameOrderPreservingSet&& other) noexcept; + + bool operator==(const ProjectionNameOrderPreservingSet& other) const; + + std::pair<size_t, bool> emplace_back(ProjectionName projectionName); + std::pair<size_t, bool> find(const ProjectionName& projectionName) const; + bool erase(const ProjectionName& projectionName); + + bool isEqualIgnoreOrder(const ProjectionNameOrderPreservingSet& other) const; + + const ProjectionNameVector& getVector() const; + +private: + opt::unordered_map<ProjectionName, size_t> _map; + ProjectionNameVector _vector; +}; + +#define INDEXREQTARGET_NAMES(F) \ + F(Index) \ + F(Seek) \ + F(Complete) + +MAKE_PRINTABLE_ENUM(IndexReqTarget, INDEXREQTARGET_NAMES); +MAKE_PRINTABLE_ENUM_STRING_ARRAY(IndexReqTargetEnum, IndexReqTarget, INDEXREQTARGET_NAMES); +#undef INDEXREQTARGET_NAMES + +#define DISTRIBUTIONTYPE_NAMES(F) \ + F(Centralized) \ + F(Replicated) \ + F(RoundRobin) \ + F(HashPartitioning) \ + F(RangePartitioning) \ + F(UnknownPartitioning) + +MAKE_PRINTABLE_ENUM(DistributionType, DISTRIBUTIONTYPE_NAMES); +MAKE_PRINTABLE_ENUM_STRING_ARRAY(DistributionTypeEnum, DistributionType, DISTRIBUTIONTYPE_NAMES); +#undef DISTRIBUTIONTYPE_NAMES + +// In case of covering scan, index, or fetch, specify names of bound projections for each field. +// Also optionally specify if applicable the rid and record (root) projections. +struct FieldProjectionMap { + ProjectionName _ridProjection; + ProjectionName _rootProjection; + opt::unordered_map<FieldNameType, ProjectionName> _fieldProjections; + + bool operator==(const FieldProjectionMap& other) const; +}; + +// Used to generate field names encoding index keys for covered indexes. +static constexpr const char* kIndexKeyPrefix = "<indexKey>"; + +/** + * Memo-related types. + */ +using GroupIdType = int64_t; + +// Logical node id. +struct MemoLogicalNodeId { + GroupIdType _groupId; + size_t _index; + + bool operator==(const MemoLogicalNodeId& other) const; +}; + +struct NodeIdHash { + size_t operator()(const MemoLogicalNodeId& id) const; +}; +using NodeIdSet = opt::unordered_set<MemoLogicalNodeId, NodeIdHash>; + +// Physical node id. 
+struct MemoPhysicalNodeId { + GroupIdType _groupId; + size_t _index; + + bool operator==(const MemoPhysicalNodeId& other) const; +}; + +class DebugInfo { +public: + static const int kIterationLimitForTests = 10000; + static const int kDefaultDebugLevelForTests = 1; + + static DebugInfo kDefaultForTests; + static DebugInfo kDefaultForProd; + + DebugInfo(bool debugMode, int debugLevel, int iterationLimit); + + bool isDebugMode() const; + + bool hasDebugLevel(int debugLevel) const; + + bool exceedsIterationLimit(int iterations) const; + +private: + // Are we in debug mode? Can we do additional logging, etc? + const bool _debugMode; + + const int _debugLevel; + + // Maximum number of rewrite iterations. + const int _iterationLimit; +}; + +using CEType = double; +using SelectivityType = double; + +class CostType { +public: + static CostType kInfinity; + static CostType kZero; + + static CostType fromDouble(double cost); + + CostType(const CostType& other) = default; + CostType(CostType&& other) = default; + CostType& operator=(const CostType& other) = default; + + bool operator==(const CostType& other) const; + bool operator!=(const CostType& other) const; + bool operator<(const CostType& other) const; + + CostType operator+(const CostType& other) const; + CostType operator-(const CostType& other) const; + CostType& operator+=(const CostType& other); + + std::string toString() const; + + /** + * Returns the cost as a double, or asserts if infinite. + */ + double getCost() const; + + bool isInfinite() const; + +private: + CostType(bool isInfinite, double cost); + + bool _isInfinite; + double _cost; +}; + +struct CostAndCE { + CostType _cost; + CEType _ce; +}; + +#define COLLATIONOP_OPNAMES(F) \ + F(Ascending) \ + F(Descending) \ + F(Clustered) + +MAKE_PRINTABLE_ENUM(CollationOp, COLLATIONOP_OPNAMES); +MAKE_PRINTABLE_ENUM_STRING_ARRAY(CollationOpEnum, CollationOp, COLLATIONOP_OPNAMES); +#undef COLLATIONOP_OPNAMES + +using ProjectionCollationEntry = std::pair<ProjectionName, CollationOp>; +using ProjectionCollationSpec = std::vector<ProjectionCollationEntry>; + +CollationOp reverseCollationOp(CollationOp op); + +bool collationOpsCompatible(CollationOp availableOp, CollationOp requiredOp); +bool collationsCompatible(const ProjectionCollationSpec& available, + const ProjectionCollationSpec& required); + +enum class DisableIndexOptions { + Enabled, // All types of indexes are enabled. + DisableAll, // Disable all indexes. + DisablePartialOnly // Only disable partial indexes. +}; + +struct QueryHints { + // Disable full collection scans. + bool _disableScan = false; + + // Disable index scans. + DisableIndexOptions _disableIndexes = DisableIndexOptions::Enabled; + + // Disable placing a hash-join during RIDIntersect implementation. + bool _disableHashJoinRIDIntersect = false; + + // Disable placing a merge-join during RIDIntersect implementation. + bool _disableMergeJoinRIDIntersect = false; + + // Disable placing a group-by and union based RIDIntersect implementation. + bool _disableGroupByAndUnionRIDIntersect = false; + + // If set, keep track of rejected plans in the memo. + bool _keepRejectedPlans = false; + + // Disable Cascades branch-and-bound strategy, and fully evaluate all plans. Used in conjunction + // with keeping rejected plans.
+ bool _disableBranchAndBound = false; +}; + +} // namespace mongo::optimizer diff --git a/src/mongo/db/query/optimizer/explain.cpp b/src/mongo/db/query/optimizer/explain.cpp new file mode 100644 index 00000000000..064ffc1c0aa --- /dev/null +++ b/src/mongo/db/query/optimizer/explain.cpp @@ -0,0 +1,2411 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include "mongo/db/query/optimizer/explain.h" + +#include "mongo/db/exec/sbe/values/bson.h" +#include "mongo/db/query/optimizer/cascades/memo.h" +#include "mongo/db/query/optimizer/node.h" +#include "mongo/util/assert_util.h" + +namespace mongo::optimizer { + +BSONObj ABTPrinter::explainBSON() const { + auto [tag, val] = optimizer::ExplainGenerator::explainBSON( + _abtTree, true /*displayProperties*/, nullptr /*Memo*/, _nodeToPropsMap); + uassert(6624070, "Expected an object", tag == sbe::value::TypeTags::Object); + sbe::value::ValueGuard vg(tag, val); + + BSONObjBuilder builder; + sbe::bson::convertToBsonObj(builder, sbe::value::getObjectView(val)); + return builder.done().getOwned(); +} + +enum class ExplainVersion { V1, V2, V3, Vmax }; + +bool constexpr operator<(const ExplainVersion v1, const ExplainVersion v2) { + return static_cast<int>(v1) < static_cast<int>(v2); +} +bool constexpr operator<=(const ExplainVersion v1, const ExplainVersion v2) { + return static_cast<int>(v1) <= static_cast<int>(v2); +} +bool constexpr operator>(const ExplainVersion v1, const ExplainVersion v2) { + return static_cast<int>(v1) > static_cast<int>(v2); +} +bool constexpr operator>=(const ExplainVersion v1, const ExplainVersion v2) { + return static_cast<int>(v1) >= static_cast<int>(v2); +} + +static constexpr ExplainVersion kDefaultExplainVersion = ExplainVersion::V1; + +enum class CommandType { Indent, Unindent, AddLine }; + +struct CommandStruct { + CommandStruct() = default; + CommandStruct(const CommandType type, std::string str) : _type(type), _str(std::move(str)) {} + + CommandType _type; + std::string _str; +}; + +using CommandVector = std::vector<CommandStruct>; + +template <const ExplainVersion version = kDefaultExplainVersion> +class ExplainPrinterImpl { 
+public: + ExplainPrinterImpl() + : _cmd(), + _os(), + _osDirty(false), + _indentCount(0), + _childrenRemaining(0), + _cmdInsertPos(-1) {} + + ~ExplainPrinterImpl() { + uassert(6624003, "Unmatched indentations", _indentCount == 0); + uassert(6624004, "Incorrect child count mark", _childrenRemaining == 0); + } + + ExplainPrinterImpl(const ExplainPrinterImpl& other) = delete; + ExplainPrinterImpl& operator=(const ExplainPrinterImpl& other) = delete; + + explicit ExplainPrinterImpl(const std::string& initialStr) : ExplainPrinterImpl() { + print(initialStr); + } + + ExplainPrinterImpl(ExplainPrinterImpl&& other) noexcept + : _cmd(std::move(other._cmd)), + _os(std::move(other._os)), + _osDirty(other._osDirty), + _indentCount(other._indentCount), + _childrenRemaining(other._childrenRemaining), + _cmdInsertPos(other._cmdInsertPos) {} + + template <class T> + ExplainPrinterImpl& print(const T& t) { + _os << t; + _osDirty = true; + return *this; + } + + /** + * Here and below: "other" printer(s) may be siphoned out. + */ + ExplainPrinterImpl& print(ExplainPrinterImpl& other) { + return print(other, false /*singleLevel*/); + } + + template <class P> + ExplainPrinterImpl& printSingleLevel(P& other, const std::string& singleLevelSpacer = " ") { + return print(other, true /*singleLevel*/, singleLevelSpacer); + } + + ExplainPrinterImpl& printAppend(ExplainPrinterImpl& other) { + // Ignore append + return print(other); + } + + ExplainPrinterImpl& print(std::vector<ExplainPrinterImpl>& other) { + for (auto&& element : other) { + print(element); + } + return *this; + } + + ExplainPrinterImpl& printAppend(std::vector<ExplainPrinterImpl>& other) { + // Ignore append. + return print(other); + } + + ExplainPrinterImpl& setChildCount(const size_t childCount) { + if (version > ExplainVersion::V1) { + _childrenRemaining = childCount; + indent(""); + for (int i = 0; i < _childrenRemaining - 1; i++) { + indent("|"); + } + } + return *this; + } + + ExplainPrinterImpl& maybeReverse() { + if (version > ExplainVersion::V1) { + _cmdInsertPos = _cmd.size(); + } + return *this; + } + + ExplainPrinterImpl& fieldName(const std::string& name, + const ExplainVersion minVersion = ExplainVersion::V1, + const ExplainVersion maxVersion = ExplainVersion::Vmax) { + if (minVersion <= version && maxVersion >= version) { + print(name); + print(": "); + } + return *this; + } + + ExplainPrinterImpl& separator(const std::string& separator) { + return print(separator); + } + + std::string str() { + newLine(); + + std::ostringstream os; + std::vector<std::string> linePrefix; + + bool firstAddLine = true; + for (const auto& cmd : _cmd) { + switch (cmd._type) { + case CommandType::Indent: + linePrefix.push_back(cmd._str); + break; + + case CommandType::Unindent: { + linePrefix.pop_back(); + break; + } + + case CommandType::AddLine: { + for (const std::string& element : linePrefix) { + if (!element.empty()) { + os << element << ((version == ExplainVersion::V1) ? 
" " : " "); + } + } + os << cmd._str << "\n"; + + firstAddLine = false; + break; + } + + default: { MONGO_UNREACHABLE; } + } + } + + return os.str(); + } + + void newLine() { + if (!_osDirty) { + return; + } + const std::string& str = _os.str(); + _cmd.emplace_back(CommandType::AddLine, str); + _os.str(""); + _os.clear(); + _osDirty = false; + } + + const CommandVector& getCommands() const { + return _cmd; + } + +private: + template <class P> + ExplainPrinterImpl& print(P& other, + const bool singleLevel, + const std::string& singleLevelSpacer = " ") { + CommandVector toAppend; + if (_cmdInsertPos >= 0) { + toAppend = CommandVector(_cmd.cbegin() + _cmdInsertPos, _cmd.cend()); + _cmd.resize(static_cast<size_t>(_cmdInsertPos)); + } + + const bool hadChildrenRemaining = _childrenRemaining > 0; + if (hadChildrenRemaining) { + _childrenRemaining--; + } + other.newLine(); + + if (singleLevel) { + uassert(6624071, "Unexpected dirty status", _osDirty); + + bool first = true; + for (const auto& element : other.getCommands()) { + if (element._type == CommandType::AddLine) { + if (first) { + first = false; + } else { + _os << singleLevelSpacer; + } + _os << element._str; + } + } + } else { + newLine(); + if (!hadChildrenRemaining) { + indent(); + } + for (const auto& element : other.getCommands()) { + _cmd.push_back(element); + } + unIndent(); + } + + if (_cmdInsertPos >= 0) { + std::copy(toAppend.cbegin(), toAppend.cend(), std::back_inserter(_cmd)); + } + + return *this; + } + + void indent(std::string s = " ") { + newLine(); + _indentCount++; + _cmd.emplace_back(CommandType::Indent, std::move(s)); + } + + void unIndent() { + newLine(); + _indentCount--; + _cmd.emplace_back(CommandType::Unindent, ""); + } + + CommandVector _cmd; + std::ostringstream _os; + bool _osDirty; + int _indentCount; + int _childrenRemaining; + int _cmdInsertPos; +}; + +template <> +class ExplainPrinterImpl<ExplainVersion::V3> { + static constexpr ExplainVersion version = ExplainVersion::V3; + +public: + ExplainPrinterImpl() { + reset(); + } + + ~ExplainPrinterImpl() { + if (_initialized) { + releaseValue(_tag, _val); + } + } + + ExplainPrinterImpl(const ExplainPrinterImpl& other) = delete; + ExplainPrinterImpl& operator=(const ExplainPrinterImpl& other) = delete; + + ExplainPrinterImpl(ExplainPrinterImpl&& other) noexcept { + _nextFieldName = std::move(other._nextFieldName); + _initialized = other._initialized; + _tag = other._tag; + _val = other._val; + _fieldNameSet = std::move(other._fieldNameSet); + + other.reset(); + } + + explicit ExplainPrinterImpl(const std::string& nodeName) : ExplainPrinterImpl() { + fieldName("nodeType").print(nodeName); + } + + auto moveValue() { + auto result = std::pair<sbe::value::TypeTags, sbe::value::Value>(_tag, _val); + reset(); + return result; + } + + ExplainPrinterImpl& print(const bool v) { + addValue(sbe::value::TypeTags::Boolean, v); + return *this; + } + + ExplainPrinterImpl& print(const int64_t v) { + addValue(sbe::value::TypeTags::NumberInt64, sbe::value::bitcastFrom<int64_t>(v)); + return *this; + } + + ExplainPrinterImpl& print(const int32_t v) { + addValue(sbe::value::TypeTags::NumberInt32, sbe::value::bitcastFrom<int32_t>(v)); + return *this; + } + + ExplainPrinterImpl& print(const size_t v) { + addValue(sbe::value::TypeTags::NumberInt64, sbe::value::bitcastFrom<size_t>(v)); + return *this; + } + + ExplainPrinterImpl& print(const double v) { + addValue(sbe::value::TypeTags::NumberDouble, sbe::value::bitcastFrom<double>(v)); + return *this; + } + + ExplainPrinterImpl& 
print(const std::pair<sbe::value::TypeTags, sbe::value::Value> v) { + auto [tag, val] = sbe::value::copyValue(v.first, v.second); + addValue(tag, val); + return *this; + } + + ExplainPrinterImpl& print(const std::string& s) { + auto [tag, val] = sbe::value::makeNewString(s); + addValue(tag, val); + return *this; + } + + ExplainPrinterImpl& print(const char* s) { + return print(static_cast<std::string>(s)); + } + + /** + * Here and below: "other" printer(s) may be siphoned out. + */ + ExplainPrinterImpl& print(ExplainPrinterImpl& other) { + return print(other, false /*append*/); + } + + ExplainPrinterImpl& printSingleLevel(ExplainPrinterImpl& other, + const std::string& /*singleLevelSpacer*/ = " ") { + // Ignore single level. + return print(other); + } + + ExplainPrinterImpl& printAppend(ExplainPrinterImpl& other) { + return print(other, true /*append*/); + } + + ExplainPrinterImpl& print(std::vector<ExplainPrinterImpl>& other) { + return print(other, false /*append*/); + } + + ExplainPrinterImpl& printAppend(std::vector<ExplainPrinterImpl>& other) { + return print(other, true /*append*/); + } + + ExplainPrinterImpl& setChildCount(const size_t /*childCount*/) { + // Ignored. + return *this; + } + + ExplainPrinterImpl& maybeReverse() { + // Ignored. + return *this; + } + + ExplainPrinterImpl& fieldName(const std::string& name, + const ExplainVersion minVersion = ExplainVersion::V1, + const ExplainVersion maxVersion = ExplainVersion::Vmax) { + if (minVersion <= version && maxVersion >= version) { + _nextFieldName = name; + } + return *this; + } + + ExplainPrinterImpl& separator(const std::string& /*separator*/) { + // Ignored. + return *this; + } + +private: + ExplainPrinterImpl& print(ExplainPrinterImpl& other, const bool append) { + auto [tag, val] = other.moveValue(); + addValue(tag, val, append); + if (append) { + sbe::value::releaseValue(tag, val); + } + return *this; + } + + ExplainPrinterImpl& print(std::vector<ExplainPrinterImpl>& other, const bool append) { + auto [tag, val] = sbe::value::makeNewArray(); + sbe::value::Array* arr = sbe::value::getArrayView(val); + for (auto&& element : other) { + auto [tag1, val1] = element.moveValue(); + arr->push_back(tag1, val1); + } + addValue(tag, val, append); + return *this; + } + + void addValue(sbe::value::TypeTags tag, sbe::value::Value val, const bool append = false) { + if (!_initialized) { + _initialized = true; + _canAppend = !_nextFieldName.empty(); + if (_canAppend) { + std::tie(_tag, _val) = sbe::value::makeNewObject(); + } else { + _tag = tag; + _val = val; + return; + } + } + + if (!_canAppend) { + uasserted(6624072, "Cannot append to scalar"); + return; + } + + if (append) { + uassert(6624073, "Field name is not empty", _nextFieldName.empty()); + uassert(6624349, + "Other printer does not contain Object", + tag == sbe::value::TypeTags::Object); + sbe::value::Object* obj = sbe::value::getObjectView(val); + for (size_t i = 0; i < obj->size(); i++) { + const auto field = obj->getAt(i); + auto [fieldTag, fieldVal] = sbe::value::copyValue(field.first, field.second); + addField(obj->field(i), fieldTag, fieldVal); + } + } else { + addField(_nextFieldName, tag, val); + _nextFieldName.clear(); + } + } + + void addField(const std::string& fieldName, sbe::value::TypeTags tag, sbe::value::Value val) { + uassert(6624074, "Field name is empty", !fieldName.empty()); + uassert(6624075, "Duplicate field name", _fieldNameSet.insert(fieldName).second); + sbe::value::getObjectView(_val)->push_back(fieldName, tag, val); + } + + void reset() { + 
_nextFieldName.clear(); + _initialized = false; + _canAppend = false; + _tag = sbe::value::TypeTags::Nothing; + _val = 0; + _fieldNameSet.clear(); + } + + std::string _nextFieldName; + bool _initialized; + bool _canAppend; + sbe::value::TypeTags _tag; + sbe::value::Value _val; + // For debugging. + opt::unordered_set<std::string> _fieldNameSet; +}; + +template <const ExplainVersion version = kDefaultExplainVersion> +class ExplainGeneratorTransporter { +public: + using ExplainPrinter = ExplainPrinterImpl<version>; + + ExplainGeneratorTransporter(bool displayProperties = false, + const cascades::Memo* memo = nullptr, + const NodeToGroupPropsMap& nodeMap = {}) + : _displayProperties(displayProperties), _memo(memo), _nodeMap(nodeMap) { + uassert(6624005, + "Memo must be provided in order to display properties.", + !_displayProperties || (_memo != nullptr || version == ExplainVersion::V3)); + } + + /** + * Helper function that appends the logical and physical properties of 'node' nested under a new + * field named 'properties'. Only applicable for BSON explain, for other versions this is a + * no-op. + */ + void maybePrintProps(ExplainPrinter& nodePrinter, const Node& node) { + if (!_displayProperties || version != ExplainVersion::V3 || _nodeMap.empty()) { + return; + } + auto it = _nodeMap.find(&node); + uassert(6624006, "Failed to find node properties", it != _nodeMap.end()); + + const NodeProps& props = it->second; + + ExplainPrinter logPropPrinter = printLogicalProps("logical", props._logicalProps); + ExplainPrinter physPropPrinter = printPhysProps("physical", props._physicalProps); + + ExplainPrinter propsPrinter; + propsPrinter.fieldName("cost") + .print(props._cost.getCost()) + .fieldName("localCost") + .print(props._localCost.getCost()) + .fieldName("adjustedCE") + .print(props._adjustedCE) + .fieldName("planNodeID") + .print(props._planNodeId) + .fieldName("logicalProperties") + .print(logPropPrinter) + .fieldName("physicalProperties") + .print(physPropPrinter); + ExplainPrinter res; + res.fieldName("properties").print(propsPrinter); + nodePrinter.printAppend(res); + } + + static void printBooleanFlag(ExplainPrinter& printer, + const std::string& name, + const bool flag, + const bool addComma = true) { + if constexpr (version < ExplainVersion::V3) { + if (flag) { + if (addComma) { + printer.print(", "); + } + printer.print(name); + } + } else if constexpr (version == ExplainVersion::V3) { + printer.fieldName(name).print(flag); + } else { + static_assert("Unknown version"); + } + } + + static void printDirectToParentHelper(const bool directToParent, + ExplainPrinter& parent, + std::function<void(ExplainPrinter& printer)> fn) { + if (directToParent) { + fn(parent); + } else { + ExplainPrinter printer; + fn(printer); + parent.printAppend(printer); + } + } + + /** + * Nodes + */ + ExplainPrinter transport(const References& references, std::vector<ExplainPrinter> inResults) { + ExplainPrinter printer; + printer.separator("RefBlock: ").printAppend(inResults); + return printer; + } + + ExplainPrinter transport(const ExpressionBinder& binders, + std::vector<ExplainPrinter> inResults) { + std::map<std::string, ExplainPrinter> ordered; + for (size_t idx = 0; idx < inResults.size(); ++idx) { + ordered.emplace(binders.names()[idx], std::move(inResults[idx])); + } + + ExplainPrinter printer; + printer.separator("BindBlock:"); + + for (auto& [name, child] : ordered) { + if constexpr (version < ExplainVersion::V3) { + ExplainPrinter local; + local.print("[").print(name).print("]").print(child); + 
printer.print(local); + } else if constexpr (version == ExplainVersion::V3) { + printer.separator(" ").fieldName(name).print(child); + } else { + static_assert("Unknown version"); + } + } + + return printer; + } + + static void printFieldProjectionMap(ExplainPrinter& printer, const FieldProjectionMap& map) { + std::map<FieldNameType, ProjectionName> ordered; + if (!map._ridProjection.empty()) { + ordered["<rid>"] = map._ridProjection; + } + if (!map._rootProjection.empty()) { + ordered["<root>"] = map._rootProjection; + } + for (const auto& entry : map._fieldProjections) { + ordered.insert(entry); + } + + if constexpr (version < ExplainVersion::V3) { + bool first = true; + for (const auto& [fieldName, projectionName] : ordered) { + if (first) { + first = false; + } else { + printer.print(", "); + } + printer.print("'").print(fieldName).print("': ").print(projectionName); + } + } else if constexpr (version == ExplainVersion::V3) { + ExplainPrinter local; + for (const auto& [fieldName, projectionName] : ordered) { + local.fieldName(fieldName).print(projectionName); + } + printer.fieldName("fieldProjectionMap").print(local); + } else { + static_assert("Unknown version"); + } + } + + ExplainPrinter transport(const ScanNode& node, ExplainPrinter bindResult) { + ExplainPrinter printer("Scan"); + maybePrintProps(printer, node); + + printer.separator(" [") + .fieldName("scanDefName", ExplainVersion::V3) + .print(node.getScanDefName()) + .separator("]") + .fieldName("bindings", ExplainVersion::V3) + .print(bindResult); + return printer; + } + + ExplainPrinter transport(const PhysicalScanNode& node, ExplainPrinter bindResult) { + ExplainPrinter printer("PhysicalScan"); + maybePrintProps(printer, node); + + printer.separator(" [{"); + printFieldProjectionMap(printer, node.getFieldProjectionMap()); + printer.separator("}, ") + .fieldName("scanDefName", ExplainVersion::V3) + .print(node.getScanDefName()); + + printBooleanFlag(printer, "parallel", node.useParallelScan()); + + printer.separator("]").fieldName("bindings", ExplainVersion::V3).print(bindResult); + + return printer; + } + + ExplainPrinter transport(const ValueScanNode& node, ExplainPrinter bindResult) { + ExplainPrinter valuePrinter = generate(node.getValueArray()); + + ExplainPrinter printer("ValueScan"); + maybePrintProps(printer, node); + printer.separator(" [") + .fieldName("arraySize") + .print(node.getArraySize()) + .separator("]") + .fieldName("values", ExplainVersion::V3) + .print(valuePrinter) + .fieldName("bindings", ExplainVersion::V3) + .print(bindResult); + return printer; + } + + ExplainPrinter transport(const CoScanNode& node) { + ExplainPrinter printer("CoScan"); + maybePrintProps(printer, node); + printer.separator(" []"); + return printer; + } + + void printInterval(ExplainPrinter& printer, const IntervalRequirement& interval) { + const BoundRequirement& lowBound = interval.getLowBound(); + const BoundRequirement& highBound = interval.getHighBound(); + + if constexpr (version < ExplainVersion::V3) { + const auto printBoundFn = [](ExplainPrinter& printer, const ABT& bound) { + // Since we are printing on a single level, use V1 printer in order to avoid + // children being reversed. + ExplainGeneratorTransporter<ExplainVersion::V1> gen; + auto boundPrinter = gen.generate(bound); + printer.printSingleLevel(boundPrinter); + }; + + printer.print(lowBound.isInclusive() ? 
"[" : "("); + if (lowBound.isInfinite()) { + printer.print("-inf"); + } else { + printBoundFn(printer, lowBound.getBound()); + } + + printer.print(", "); + if (highBound.isInfinite()) { + printer.print("+inf"); + } else { + printBoundFn(printer, highBound.getBound()); + } + + printer.print(highBound.isInclusive() ? "]" : ")"); + } else if constexpr (version == ExplainVersion::V3) { + ExplainPrinter lowBoundPrinter; + lowBoundPrinter.fieldName("inclusive") + .print(lowBound.isInclusive()) + .fieldName("infinite") + .print(lowBound.isInfinite()); + if (!lowBound.isInfinite()) { + ExplainPrinter boundPrinter = generate(lowBound.getBound()); + lowBoundPrinter.fieldName("bound").print(boundPrinter); + } + + ExplainPrinter highBoundPrinter; + highBoundPrinter.fieldName("inclusive") + .print(highBound.isInclusive()) + .fieldName("infinite") + .print(highBound.isInfinite()); + if (!highBound.isInfinite()) { + ExplainPrinter boundPrinter = generate(highBound.getBound()); + highBoundPrinter.fieldName("bound").print(boundPrinter); + } + + printer.fieldName("lowBound") + .print(lowBoundPrinter) + .fieldName("highBound") + .print(highBoundPrinter); + } else { + static_assert("Version not implemented"); + } + } + + std::string printInterval(const IntervalRequirement& interval) { + ExplainPrinter printer; + printInterval(printer, interval); + return printer.str(); + } + + ExplainPrinter printIntervalExpr(const IntervalReqExpr::Node& intervalExpr) { + IntervalPrinter<IntervalReqExpr> intervalPrinter(*this); + return intervalPrinter.print(intervalExpr); + } + + void printInterval(ExplainPrinter& printer, const MultiKeyIntervalRequirement& interval) { + if constexpr (version < ExplainVersion::V3) { + bool first = true; + for (const auto& entry : interval) { + if (first) { + first = false; + } else { + printer.print(", "); + } + printInterval(printer, entry); + } + } else if constexpr (version == ExplainVersion::V3) { + std::vector<ExplainPrinter> printers; + for (const auto& entry : interval) { + ExplainPrinter local; + printInterval(local, entry); + printers.push_back(std::move(local)); + } + printer.print(printers); + } else { + static_assert("Version not implemented"); + } + } + + template <class T> + class IntervalPrinter { + public: + IntervalPrinter(ExplainGeneratorTransporter& instance) : _instance(instance) {} + + ExplainPrinter transport(const typename T::Atom& node) { + ExplainPrinter printer; + printer.separator("{"); + _instance.printInterval(printer, node.getExpr()); + printer.separator("}"); + return printer; + } + + template <bool isConjunction> + ExplainPrinter print(std::vector<ExplainPrinter> childResults) { + if constexpr (version < ExplainVersion::V3) { + ExplainPrinter printer; + printer.separator("{"); + + bool first = true; + for (ExplainPrinter& child : childResults) { + if (first) { + first = false; + } else if constexpr (isConjunction) { + printer.print(" ^ "); + } else { + printer.print(" U "); + } + printer.print(child); + } + printer.separator("}"); + + return printer; + } else if constexpr (version == ExplainVersion::V3) { + ExplainPrinter printer; + if constexpr (isConjunction) { + printer.fieldName("conjunction"); + } else { + printer.fieldName("disjunction"); + } + printer.print(childResults); + return printer; + } else { + static_assert("Version not implemented"); + } + } + + ExplainPrinter transport(const typename T::Conjunction& node, + std::vector<ExplainPrinter> childResults) { + return print<true /*isConjunction*/>(std::move(childResults)); + } + + ExplainPrinter 
transport(const typename T::Disjunction& node, + std::vector<ExplainPrinter> childResults) { + return print<false /*isConjunction*/>(std::move(childResults)); + } + + ExplainPrinter print(const typename T::Node& intervals) { + return algebra::transport<false>(intervals, *this); + } + + private: + ExplainGeneratorTransporter& _instance; + }; + + ExplainPrinter transport(const IndexScanNode& node, ExplainPrinter bindResult) { + ExplainPrinter printer("IndexScan"); + maybePrintProps(printer, node); + + printer.separator(" [{"); + printFieldProjectionMap(printer, node.getFieldProjectionMap()); + printer.separator("}, "); + + const auto& spec = node.getIndexSpecification(); + printer.fieldName("scanDefName") + .print(spec.getScanDefName()) + .separator(", ") + .fieldName("indexDefName") + .print(spec.getIndexDefName()) + .separator(", "); + + printer.fieldName("interval").separator("{"); + printInterval(printer, spec.getInterval()); + printer.separator("}"); + + printBooleanFlag(printer, "reversed", spec.isReverseOrder()); + + printer.separator("]").fieldName("bindings", ExplainVersion::V3).print(bindResult); + return printer; + } + + ExplainPrinter transport(const SeekNode& node, + ExplainPrinter bindResult, + ExplainPrinter refsResult) { + ExplainPrinter printer("Seek"); + maybePrintProps(printer, node); + + printer.separator(" [") + .fieldName("ridProjection") + .print(node.getRIDProjectionName()) + .separator(", {"); + printFieldProjectionMap(printer, node.getFieldProjectionMap()); + printer.separator("}, ") + .fieldName("scanDefName", ExplainVersion::V3) + .print(node.getScanDefName()) + .separator("]") + .setChildCount(2) + .fieldName("bindings", ExplainVersion::V3) + .print(bindResult) + .fieldName("references", ExplainVersion::V3) + .print(refsResult); + + return printer; + } + + ExplainPrinter transport(const MemoLogicalDelegatorNode& node) { + ExplainPrinter printer("MemoLogicalDelegator"); + maybePrintProps(printer, node); + printer.separator(" [").fieldName("groupId").print(node.getGroupId()).separator("]"); + return printer; + } + + ExplainPrinter transport(const MemoPhysicalDelegatorNode& node) { + const auto id = node.getNodeId(); + + if (_displayProperties) { + const auto& group = _memo->getGroup(id._groupId); + const auto& result = group._physicalNodes.at(id._index); + uassert(6624076, + "Physical delegator must be pointing to an optimized result.", + result._nodeInfo.has_value()); + + const auto& nodeInfo = *result._nodeInfo; + const ABT& n = nodeInfo._node; + + ExplainPrinter nodePrinter = generate(n); + if (n.template is<MemoPhysicalDelegatorNode>()) { + // Handle delegation. 
+ return nodePrinter; + } + + ExplainPrinter logPropPrinter = printLogicalProps("Logical", group._logicalProperties); + ExplainPrinter physPropPrinter = printPhysProps("Physical", result._physProps); + + ExplainPrinter printer("Properties"); + printer.separator(" [") + .fieldName("cost") + .print(nodeInfo._cost.getCost()) + .separator(", ") + .fieldName("localCost") + .print(nodeInfo._localCost.getCost()) + .separator(", ") + .fieldName("adjustedCE") + .print(nodeInfo._adjustedCE) + .separator("]") + .setChildCount(3) + .fieldName("logicalProperties", ExplainVersion::V3) + .print(logPropPrinter) + .fieldName("physicalProperties", ExplainVersion::V3) + .print(physPropPrinter) + .fieldName("node", ExplainVersion::V3) + .print(nodePrinter); + return printer; + } + + ExplainPrinter printer("MemoPhysicalDelegator"); + printer.separator(" [") + .fieldName("groupId") + .print(id._groupId) + .separator(", ") + .fieldName("index") + .print(id._index) + .separator("]"); + return printer; + } + + ExplainPrinter transport(const FilterNode& node, + ExplainPrinter childResult, + ExplainPrinter filterResult) { + ExplainPrinter printer("Filter"); + maybePrintProps(printer, node); + printer.separator(" []") + .setChildCount(2) + .fieldName("filter", ExplainVersion::V3) + .print(filterResult) + .fieldName("child", ExplainVersion::V3) + .print(childResult); + return printer; + } + + ExplainPrinter transport(const EvaluationNode& node, + ExplainPrinter childResult, + ExplainPrinter projectionResult) { + ExplainPrinter printer("Evaluation"); + maybePrintProps(printer, node); + printer.separator(" []") + .setChildCount(2) + .fieldName("projection", ExplainVersion::V3) + .print(projectionResult) + .fieldName("child", ExplainVersion::V3) + .print(childResult); + return printer; + } + + void printPartialSchemaReqMap(ExplainPrinter& parent, const PartialSchemaRequirements& reqMap) { + std::vector<ExplainPrinter> printers; + for (const auto& [key, req] : reqMap) { + ExplainPrinter local; + + local.fieldName("refProjection").print(key._projectionName).separator(", "); + ExplainPrinter pathPrinter = generate(key._path); + local.fieldName("path").separator("'").printSingleLevel(pathPrinter).separator("', "); + + if (req.hasBoundProjectionName()) { + local.fieldName("boundProjection") + .print(req.getBoundProjectionName()) + .separator(", "); + } + + local.fieldName("intervals"); + { + ExplainPrinter intervals = printIntervalExpr(req.getIntervals()); + local.printSingleLevel(intervals, "" /*singleLevelSpacer*/); + } + + printers.push_back(std::move(local)); + } + + parent.fieldName("requirementsMap").print(printers); + } + + ExplainPrinter transport(const SargableNode& node, + ExplainPrinter childResult, + ExplainPrinter bindResult, + ExplainPrinter refsResult) { + ExplainPrinter printer("Sargable"); + maybePrintProps(printer, node); + printer.separator(" [") + .fieldName("target", ExplainVersion::V3) + .print(IndexReqTargetEnum::toString[static_cast<int>(node.getTarget())]) + .separator("]") + .setChildCount(5); + + if constexpr (version < ExplainVersion::V3) { + ExplainPrinter local; + printPartialSchemaReqMap(local, node.getReqMap()); + printer.print(local); + } else if constexpr (version == ExplainVersion::V3) { + printPartialSchemaReqMap(printer, node.getReqMap()); + } else { + static_assert("Unknown version"); + } + + std::set<std::string> orderedIndexDefName; + for (const auto& entry : node.getCandidateIndexMap()) { + orderedIndexDefName.insert(entry.first); + } + + std::vector<ExplainPrinter> 
candidateIndexesPrinters; + size_t candidateIndex = 0; + for (const auto& indexDefName : orderedIndexDefName) { + candidateIndex++; + ExplainPrinter local; + local.fieldName("candidateId") + .print(candidateIndex) + .separator(", ") + .fieldName("indexDefName", ExplainVersion::V3) + .print(indexDefName) + .separator(", "); + + const auto& candidateIndexEntry = node.getCandidateIndexMap().at(indexDefName); + local.separator("{"); + printFieldProjectionMap(local, candidateIndexEntry._fieldProjectionMap); + local.separator("}, {"); + + { + std::set<size_t> orderedFields; + for (const size_t fieldId : candidateIndexEntry._fieldsToCollate) { + orderedFields.insert(fieldId); + } + + if constexpr (version < ExplainVersion::V3) { + bool first = true; + for (const size_t fieldId : orderedFields) { + if (first) { + first = false; + } else { + local.print(", "); + } + local.print(fieldId); + } + } else if constexpr (version == ExplainVersion::V3) { + std::vector<ExplainPrinter> printers; + for (const size_t fieldId : orderedFields) { + ExplainPrinter local1; + local1.print(fieldId); + printers.push_back(std::move(local1)); + } + local.fieldName("fieldsToCollate").print(printers); + } else { + static_assert("Unknown version"); + } + } + + local.separator("}, ").fieldName("intervals", ExplainVersion::V3); + { + IntervalPrinter<MultiKeyIntervalReqExpr> intervalPrinter(*this); + ExplainPrinter intervals = intervalPrinter.print(candidateIndexEntry._intervals); + local.printSingleLevel(intervals, "" /*singleLevelSpacer*/); + } + + if (!candidateIndexEntry._residualRequirements.empty()) { + if constexpr (version < ExplainVersion::V3) { + ExplainPrinter residualReqMapPrinter; + printPartialSchemaReqMap(residualReqMapPrinter, + candidateIndexEntry._residualRequirements); + local.print(residualReqMapPrinter); + } else if (version == ExplainVersion::V3) { + printPartialSchemaReqMap(local, candidateIndexEntry._residualRequirements); + } else { + static_assert("Unknown version"); + } + } + + if (!candidateIndexEntry._residualKeyMap.empty()) { + std::vector<ExplainPrinter> residualKeyMapPrinters; + for (const auto& [queryKey, residualKey] : candidateIndexEntry._residualKeyMap) { + ExplainPrinter local1; + + ExplainPrinter pathPrinter = generate(queryKey._path); + local1.fieldName("queryRefProjection") + .print(queryKey._projectionName) + .separator(", ") + .fieldName("queryPath") + .separator("'") + .printSingleLevel(pathPrinter) + .separator("', ") + .fieldName("residualRefProjection") + .print(residualKey._projectionName) + .separator(", "); + + ExplainPrinter pathPrinter1 = generate(residualKey._path); + local1.fieldName("residualPath") + .separator("'") + .printSingleLevel(pathPrinter1) + .separator("'"); + + residualKeyMapPrinters.push_back(std::move(local1)); + } + + local.fieldName("residualKeyMap").print(residualKeyMapPrinters); + + std::vector<ExplainPrinter> projNamePrinters; + for (const ProjectionName& projName : + candidateIndexEntry._residualRequirementsTempProjections) { + ExplainPrinter local1; + local1.print(projName); + projNamePrinters.push_back(std::move(local1)); + } + local.fieldName("tempProjections").print(projNamePrinters); + } + + candidateIndexesPrinters.push_back(std::move(local)); + } + ExplainPrinter candidateIndexesPrinter; + candidateIndexesPrinter.fieldName("candidateIndexes").print(candidateIndexesPrinters); + + printer.printAppend(candidateIndexesPrinter) + .fieldName("bindings", ExplainVersion::V3) + .print(bindResult) + .fieldName("references", ExplainVersion::V3) + 
.print(refsResult) + .fieldName("child", ExplainVersion::V3) + .print(childResult); + return printer; + } + + ExplainPrinter transport(const RIDIntersectNode& node, + ExplainPrinter leftChildResult, + ExplainPrinter rightChildResult) { + ExplainPrinter printer("RIDIntersect"); + maybePrintProps(printer, node); + + printer.separator(" [") + .fieldName("scanProjectionName", ExplainVersion::V3) + .print(node.getScanProjectionName()); + printBooleanFlag(printer, "hasLeftIntervals", node.hasLeftIntervals()); + printBooleanFlag(printer, "hasRightIntervals", node.hasRightIntervals()); + + printer.separator("]") + .setChildCount(2) + .maybeReverse() + .fieldName("leftChild", ExplainVersion::V3) + .print(leftChildResult) + .fieldName("rightChild", ExplainVersion::V3) + .print(rightChildResult); + return printer; + } + + ExplainPrinter transport(const BinaryJoinNode& node, + ExplainPrinter leftChildResult, + ExplainPrinter rightChildResult, + ExplainPrinter filterResult) { + ExplainPrinter printer("BinaryJoin"); + maybePrintProps(printer, node); + + printer.separator(" [") + .fieldName("joinType") + .print(JoinTypeEnum::toString[static_cast<int>(node.getJoinType())]); + + if constexpr (version < ExplainVersion::V3) { + if (!node.getCorrelatedProjectionNames().empty()) { + printer.print(", {"); + bool first = true; + for (const ProjectionName& projectionName : node.getCorrelatedProjectionNames()) { + if (first) { + first = false; + } else { + printer.print(", "); + } + printer.print(projectionName); + } + printer.print("}"); + } + } else if constexpr (version == ExplainVersion::V3) { + std::vector<ExplainPrinter> printers; + for (const ProjectionName& projectionName : node.getCorrelatedProjectionNames()) { + ExplainPrinter local; + local.print(projectionName); + printers.push_back(std::move(local)); + } + printer.fieldName("correlatedProjections").print(printers); + } else { + static_assert("Version not implemented"); + } + + printer.separator("]") + .setChildCount(3) + .fieldName("expression", ExplainVersion::V3) + .print(filterResult) + .maybeReverse() + .fieldName("leftChild", ExplainVersion::V3) + .print(leftChildResult) + .fieldName("rightChild", ExplainVersion::V3) + .print(rightChildResult); + return printer; + } + + void printEqualityJoinCondition(ExplainPrinter& printer, + const ProjectionNameVector& leftKeys, + const ProjectionNameVector& rightKeys) { + if constexpr (version < ExplainVersion::V3) { + printer.print("Condition"); + for (size_t i = 0; i < leftKeys.size(); i++) { + ExplainPrinter local; + local.print(leftKeys.at(i)).print(" = ").print(rightKeys.at(i)); + printer.print(local); + } + } else if constexpr (version == ExplainVersion::V3) { + std::vector<ExplainPrinter> printers; + for (size_t i = 0; i < leftKeys.size(); i++) { + ExplainPrinter local; + local.fieldName("leftKey") + .print(leftKeys.at(i)) + .fieldName("rightKey") + .print(rightKeys.at(i)); + printers.push_back(std::move(local)); + } + printer.print(printers); + } else { + static_assert("Version not implemented"); + } + } + + ExplainPrinter transport(const HashJoinNode& node, + ExplainPrinter leftChildResult, + ExplainPrinter rightChildResult, + ExplainPrinter /*refsResult*/) { + ExplainPrinter printer("HashJoin"); + maybePrintProps(printer, node); + + printer.separator(" [") + .fieldName("joinType") + .print(JoinTypeEnum::toString[static_cast<int>(node.getJoinType())]) + .separator("]"); + + ExplainPrinter joinConditionPrinter; + printEqualityJoinCondition(joinConditionPrinter, node.getLeftKeys(), 
node.getRightKeys()); + + printer.setChildCount(3) + .fieldName("joinCondition", ExplainVersion::V3) + .print(joinConditionPrinter) + .maybeReverse() + .fieldName("leftChild", ExplainVersion::V3) + .print(leftChildResult) + .fieldName("rightChild", ExplainVersion::V3) + .print(rightChildResult); + return printer; + } + + ExplainPrinter transport(const MergeJoinNode& node, + ExplainPrinter leftChildResult, + ExplainPrinter rightChildResult, + ExplainPrinter /*refsResult*/) { + ExplainPrinter printer("MergeJoin"); + maybePrintProps(printer, node); + printer.separator(" []"); + + ExplainPrinter joinConditionPrinter; + printEqualityJoinCondition(joinConditionPrinter, node.getLeftKeys(), node.getRightKeys()); + + ExplainPrinter collationPrinter; + if constexpr (version < ExplainVersion::V3) { + collationPrinter.print("Collation"); + for (const CollationOp op : node.getCollation()) { + ExplainPrinter local; + local.print(CollationOpEnum::toString[static_cast<int>(op)]); + collationPrinter.print(local); + } + } else if constexpr (version == ExplainVersion::V3) { + std::vector<ExplainPrinter> printers; + for (const CollationOp op : node.getCollation()) { + ExplainPrinter local; + local.print(CollationOpEnum::toString[static_cast<int>(op)]); + printers.push_back(std::move(local)); + } + collationPrinter.print(printers); + } else { + static_assert("Version not implemented"); + } + + printer.setChildCount(4) + .fieldName("joinCondition", ExplainVersion::V3) + .print(joinConditionPrinter) + .fieldName("collation", ExplainVersion::V3) + .print(collationPrinter) + .maybeReverse() + .fieldName("leftChild", ExplainVersion::V3) + .print(leftChildResult) + .fieldName("rightChild", ExplainVersion::V3) + .print(rightChildResult); + return printer; + } + + ExplainPrinter transport(const UnionNode& node, + std::vector<ExplainPrinter> childResults, + ExplainPrinter bindResult, + ExplainPrinter /*refsResult*/) { + ExplainPrinter printer("Union"); + maybePrintProps(printer, node); + printer.separator(" []") + .setChildCount(childResults.size() + 1) + .fieldName("bindings", ExplainVersion::V3) + .print(bindResult) + .maybeReverse() + .fieldName("children", ExplainVersion::V3) + .print(childResults); + return printer; + } + + ExplainPrinter transport(const GroupByNode& node, + ExplainPrinter childResult, + ExplainPrinter bindAggResult, + ExplainPrinter refsAggResult, + ExplainPrinter bindGbResult, + ExplainPrinter refsGbResult) { + std::map<ProjectionName, size_t> ordered; + const ProjectionNameVector& aggProjectionNames = node.getAggregationProjectionNames(); + for (size_t i = 0; i < aggProjectionNames.size(); i++) { + ordered.emplace(aggProjectionNames.at(i), i); + } + + ExplainPrinter printer("GroupBy"); + maybePrintProps(printer, node); + printer.separator(" ["); + if (version >= ExplainVersion::V3 || node.getType() != GroupNodeType::Complete) { + printer.fieldName("type", ExplainVersion::V3) + .print(GroupNodeTypeEnum::toString[static_cast<int>(node.getType())]); + } + printer.separator("]"); + + std::vector<ExplainPrinter> aggPrinters; + for (const auto& [projectionName, index] : ordered) { + ExplainPrinter local; + local.separator("[") + .fieldName("projectionName", ExplainVersion::V3) + .print(projectionName) + .separator("]"); + ExplainPrinter aggExpr = generate(node.getAggregationExpressions().at(index)); + local.fieldName("aggregation", ExplainVersion::V3).print(aggExpr); + aggPrinters.push_back(std::move(local)); + } + + ExplainPrinter gbPrinter; + gbPrinter.fieldName("groupings").print(refsGbResult); + 
+ ExplainPrinter aggPrinter; + aggPrinter.fieldName("aggregations").print(aggPrinters); + + printer.setChildCount(3) + .printAppend(gbPrinter) + .printAppend(aggPrinter) + .fieldName("child", ExplainVersion::V3) + .print(childResult); + return printer; + } + + ExplainPrinter transport(const UnwindNode& node, + ExplainPrinter childResult, + ExplainPrinter bindResult, + ExplainPrinter refsResult) { + ExplainPrinter printer("Unwind"); + maybePrintProps(printer, node); + + printer.separator(" ["); + printBooleanFlag(printer, "retainNonArrays", node.getRetainNonArrays(), false /*addComma*/); + printer.separator("]"); + + printer.setChildCount(2) + .fieldName("bind", ExplainVersion::V3) + .print(bindResult) + .fieldName("child", ExplainVersion::V3) + .print(childResult); + return printer; + } + + static void printCollationProperty(ExplainPrinter& parent, + const properties::CollationRequirement& property, + const bool directToParent) { + std::vector<ExplainPrinter> propPrinters; + for (const auto& entry : property.getCollationSpec()) { + ExplainPrinter local; + local.fieldName("projectionName", ExplainVersion::V3) + .print(entry.first) + .separator(": ") + .fieldName("collationOp", ExplainVersion::V3) + .print(CollationOpEnum::toString[static_cast<int>(entry.second)]); + propPrinters.push_back(std::move(local)); + } + + printDirectToParentHelper(directToParent, parent, [&](ExplainPrinter& printer) { + printer.fieldName("collation").print(propPrinters); + }); + } + + ExplainPrinter transport(const UniqueNode& node, + ExplainPrinter childResult, + ExplainPrinter /*refsResult*/) { + ExplainPrinter printer("Unique"); + maybePrintProps(printer, node); + + printer.separator(" []").setChildCount(2); + printPropertyProjections(printer, node.getProjections(), false /*directToParent*/); + printer.fieldName("child", ExplainVersion::V3).print(childResult); + + return printer; + } + + ExplainPrinter transport(const CollationNode& node, + ExplainPrinter childResult, + ExplainPrinter refsResult) { + ExplainPrinter printer("Collation"); + maybePrintProps(printer, node); + + printer.separator(" []").setChildCount(3); + printCollationProperty(printer, node.getProperty(), false /*directToParent*/); + printer.fieldName("references", ExplainVersion::V3) + .print(refsResult) + .fieldName("child", ExplainVersion::V3) + .print(childResult); + + return printer; + } + + static void printLimitSkipProperty(ExplainPrinter& propPrinter, + ExplainPrinter& limitPrinter, + ExplainPrinter& skipPrinter, + const properties::LimitSkipRequirement& property) { + propPrinter.fieldName("propType", ExplainVersion::V3) + .print("limitSkip") + .separator(":") + .printAppend(limitPrinter) + .printAppend(skipPrinter); + } + + static void printLimitSkipProperty(ExplainPrinter& parent, + const properties::LimitSkipRequirement& property, + const bool directToParent) { + ExplainPrinter limitPrinter; + limitPrinter.fieldName("limit"); + if (property.hasLimit()) { + limitPrinter.print(property.getLimit()); + } else { + limitPrinter.print("(none)"); + } + + ExplainPrinter skipPrinter; + skipPrinter.fieldName("skip").print(property.getSkip()); + + printDirectToParentHelper(directToParent, parent, [&](ExplainPrinter& printer) { + printLimitSkipProperty(printer, limitPrinter, skipPrinter, property); + }); + } + + ExplainPrinter transport(const LimitSkipNode& node, ExplainPrinter childResult) { + ExplainPrinter printer("LimitSkip"); + maybePrintProps(printer, node); + + printer.separator(" []").setChildCount(2); + printLimitSkipProperty(printer, 
node.getProperty(), false /*directToParent*/); + printer.fieldName("child", ExplainVersion::V3).print(childResult); + + return printer; + } + + static void printPropertyProjections(ExplainPrinter& parent, + const ProjectionNameVector& projections, + const bool directToParent) { + std::vector<ExplainPrinter> printers; + for (const ProjectionName& projection : projections) { + ExplainPrinter local; + local.print(projection); + printers.push_back(std::move(local)); + } + + printDirectToParentHelper(directToParent, parent, [&](ExplainPrinter& printer) { + printer.fieldName("projections"); + if (printers.empty()) { + ExplainPrinter dummy; + printer.print(dummy); + } else { + printer.print(printers); + } + }); + } + + static void printDistributionProperty(ExplainPrinter& parent, + const properties::DistributionRequirement& property, + const bool directToParent) { + const auto& distribAndProjections = property.getDistributionAndProjections(); + + ExplainPrinter typePrinter; + typePrinter.fieldName("type").print( + DistributionTypeEnum::toString[static_cast<int>(distribAndProjections._type)]); + + printBooleanFlag(typePrinter, "disableExchanges", property.getDisableExchanges()); + + const bool hasProjections = !distribAndProjections._projectionNames.empty(); + ExplainPrinter projectionPrinter; + if (hasProjections) { + printPropertyProjections( + projectionPrinter, distribAndProjections._projectionNames, true /*directToParent*/); + typePrinter.printAppend(projectionPrinter); + } + + printDirectToParentHelper(directToParent, parent, [&](ExplainPrinter& printer) { + printer.fieldName("distribution").print(typePrinter); + }); + } + + static void printProjectionRequirementProperty( + ExplainPrinter& parent, + const properties::ProjectionRequirement& property, + const bool directToParent) { + printPropertyProjections(parent, property.getProjections().getVector(), directToParent); + } + + ExplainPrinter transport(const ExchangeNode& node, + ExplainPrinter childResult, + ExplainPrinter refsResult) { + ExplainPrinter printer("Exchange"); + maybePrintProps(printer, node); + + printer.separator(" []").setChildCount(3); + printDistributionProperty(printer, node.getProperty(), false /*directToParent*/); + printer.fieldName("references", ExplainVersion::V3) + .print(refsResult) + .fieldName("child", ExplainVersion::V3) + .print(childResult); + + return printer; + } + + struct LogicalPropPrintVisitor { + LogicalPropPrintVisitor(ExplainPrinter& parent) : _parent(parent){}; + + void operator()(const properties::LogicalProperty&, + const properties::ProjectionAvailability& prop) { + ProjectionNameOrderedSet ordered; + for (const ProjectionName& projection : prop.getProjections()) { + ordered.insert(projection); + } + + std::vector<ExplainPrinter> printers; + for (const ProjectionName& projection : ordered) { + ExplainPrinter local; + local.print(projection); + printers.push_back(std::move(local)); + } + _parent.fieldName("projections").print(printers); + } + + void operator()(const properties::LogicalProperty&, + const properties::CardinalityEstimate& prop) { + std::vector<ExplainPrinter> fieldPrinters; + + ExplainPrinter cePrinter; + cePrinter.fieldName("ce").print(prop.getEstimate()); + fieldPrinters.push_back(std::move(cePrinter)); + + if (!prop.getPartialSchemaKeyCEMap().empty()) { + std::vector<ExplainPrinter> reqPrinters; + for (const auto& [key, ce] : prop.getPartialSchemaKeyCEMap()) { + ExplainGeneratorTransporter<version> gen; + ExplainPrinter pathPrinter = gen.generate(key._path); + + ExplainPrinter 
local; + local.fieldName("refProjection") + .print(key._projectionName) + .separator(", ") + .fieldName("path") + .separator("'") + .printSingleLevel(pathPrinter) + .separator("', ") + .fieldName("ce") + .print(ce); + reqPrinters.push_back(std::move(local)); + } + ExplainPrinter requirementsPrinter; + requirementsPrinter.fieldName("requirementCEs").print(reqPrinters); + fieldPrinters.push_back(std::move(requirementsPrinter)); + } + + _parent.fieldName("cardinalityEstimate").print(fieldPrinters); + } + + void operator()(const properties::LogicalProperty&, + const properties::IndexingAvailability& prop) { + ExplainPrinter printer; + printer.separator("[") + .fieldName("groupId") + .print(prop.getScanGroupId()) + .separator(", ") + .fieldName("scanProjection") + .print(prop.getScanProjection()) + .separator(", ") + .fieldName("scanDefName") + .print(prop.getScanDefName()); + printBooleanFlag(printer, "possiblyEqPredsOnly", prop.getPossiblyEqPredsOnly()); + printer.separator("]"); + + if (!prop.getSatisfiedPartialIndexes().empty()) { + std::set<std::string> ordered; + for (const auto& indexName : prop.getSatisfiedPartialIndexes()) { + ordered.insert(indexName); + } + + std::vector<ExplainPrinter> printers; + for (const auto& indexName : ordered) { + ExplainPrinter local; + local.print(indexName); + printers.push_back(std::move(local)); + } + printer.fieldName("satisfiedPartialIndexes").print(printers); + } + + _parent.fieldName("indexingAvailability").print(printer); + } + + void operator()(const properties::LogicalProperty&, + const properties::CollectionAvailability& prop) { + std::set<std::string> orderedSet; + for (const std::string& scanDef : prop.getScanDefSet()) { + orderedSet.insert(scanDef); + } + + std::vector<ExplainPrinter> printers; + for (const std::string& scanDef : orderedSet) { + ExplainPrinter local; + local.print(scanDef); + printers.push_back(std::move(local)); + } + if (printers.empty()) { + ExplainPrinter dummy; + printers.push_back(std::move(dummy)); + } + + _parent.fieldName("collectionAvailability").print(printers); + } + + void operator()(const properties::LogicalProperty&, + const properties::DistributionAvailability& prop) { + struct Comparator { + bool operator()(const properties::DistributionRequirement& d1, + const properties::DistributionRequirement& d2) const { + const auto& distr1 = d1.getDistributionAndProjections(); + const auto& distr2 = d2.getDistributionAndProjections(); + + if (distr1._type < distr2._type) { + return true; + } + if (distr1._type > distr2._type) { + return false; + } + return distr1._projectionNames < distr2._projectionNames; + } + }; + + std::set<properties::DistributionRequirement, Comparator> ordered; + for (const auto& distributionProp : prop.getDistributionSet()) { + ordered.insert(distributionProp); + } + + std::vector<ExplainPrinter> printers; + for (const auto& distributionProp : ordered) { + ExplainPrinter local; + printDistributionProperty(local, distributionProp, true /*directToParent*/); + printers.push_back(std::move(local)); + } + _parent.fieldName("distributionAvailability").print(printers); + } + + private: + // We don't own this. 
+ ExplainPrinter& _parent; + }; + + struct PhysPropPrintVisitor { + PhysPropPrintVisitor(ExplainPrinter& parent) : _parent(parent){}; + + void operator()(const properties::PhysProperty&, + const properties::CollationRequirement& prop) { + printCollationProperty(_parent, prop, true /*directToParent*/); + } + + void operator()(const properties::PhysProperty&, + const properties::LimitSkipRequirement& prop) { + printLimitSkipProperty(_parent, prop, true /*directToParent*/); + } + + void operator()(const properties::PhysProperty&, + const properties::ProjectionRequirement& prop) { + printProjectionRequirementProperty(_parent, prop, true /*directToParent*/); + } + + void operator()(const properties::PhysProperty&, + const properties::DistributionRequirement& prop) { + printDistributionProperty(_parent, prop, true /*directToParent*/); + } + + void operator()(const properties::PhysProperty&, + const properties::IndexingRequirement& prop) { + ExplainPrinter printer; + + printer.fieldName("target", ExplainVersion::V3) + .print(IndexReqTargetEnum::toString[static_cast<int>(prop.getIndexReqTarget())]); + printBooleanFlag(printer, "needsRID", prop.getNeedsRID()); + printBooleanFlag(printer, "dedupRID", prop.getDedupRID()); + + // TODO: consider printing satisfied partial indexes. + _parent.fieldName("indexingRequirement").print(printer); + } + + void operator()(const properties::PhysProperty&, + const properties::RepetitionEstimate& prop) { + _parent.fieldName("repetitionEstimate").print(prop.getEstimate()); + } + + void operator()(const properties::PhysProperty&, const properties::LimitEstimate& prop) { + _parent.fieldName("limitEstimate").print(prop.getEstimate()); + } + + private: + // We don't own this. + ExplainPrinter& _parent; + }; + + template <class P, class V, class C> + static ExplainPrinter printProps(const std::string& description, const C& props) { + ExplainPrinter printer; + if (version < ExplainVersion::V3) { + printer.print(description).print(":"); + } + + std::map<typename P::key_type, P> ordered; + for (const auto& entry : props) { + ordered.insert(entry); + } + + ExplainPrinter local; + V visitor(local); + for (const auto& entry : ordered) { + entry.second.visit(visitor); + } + printer.print(local); + + return printer; + } + + static ExplainPrinter printLogicalProps(const std::string& description, + const properties::LogicalProps& props) { + return printProps<properties::LogicalProperty, LogicalPropPrintVisitor>(description, props); + } + + static ExplainPrinter printPhysProps(const std::string& description, + const properties::PhysProps& props) { + return printProps<properties::PhysProperty, PhysPropPrintVisitor>(description, props); + } + + ExplainPrinter transport(const RootNode& node, + ExplainPrinter childResult, + ExplainPrinter refsResult) { + ExplainPrinter printer("Root"); + maybePrintProps(printer, node); + + printer.separator(" []").setChildCount(3); + printProjectionRequirementProperty(printer, node.getProperty(), false /*directToParent*/); + printer.fieldName("references", ExplainVersion::V3) + .print(refsResult) + .fieldName("child", ExplainVersion::V3) + .print(childResult); + + return printer; + } + + /** + * Expressions + */ + ExplainPrinter transport(const Blackhole& expr) { + ExplainPrinter printer("Blackhole"); + printer.separator(" []"); + return printer; + } + + ExplainPrinter transport(const Constant& expr) { + ExplainPrinter printer("Const"); + printer.separator(" [") + .fieldName("value", ExplainVersion::V3) + .print(expr.get()) + .separator("]"); + return 
printer; + } + + ExplainPrinter transport(const Variable& expr) { + ExplainPrinter printer("Variable"); + printer.separator(" [") + .fieldName("name", ExplainVersion::V3) + .print(expr.name()) + .separator("]"); + return printer; + } + + ExplainPrinter transport(const UnaryOp& expr, ExplainPrinter inResult) { + ExplainPrinter printer("UnaryOp"); + printer.separator(" [") + .fieldName("op", ExplainVersion::V3) + .print(OperationsEnum::toString[static_cast<int>(expr.op())]) + .separator("]") + .setChildCount(1) + .fieldName("input", ExplainVersion::V3) + .print(inResult); + return printer; + } + + ExplainPrinter transport(const BinaryOp& expr, + ExplainPrinter leftResult, + ExplainPrinter rightResult) { + ExplainPrinter printer("BinaryOp"); + printer.separator(" [") + .fieldName("op", ExplainVersion::V3) + .print(OperationsEnum::toString[static_cast<int>(expr.op())]) + .separator("]") + .setChildCount(2) + .maybeReverse() + .fieldName("left", ExplainVersion::V3) + .print(leftResult) + .fieldName("right", ExplainVersion::V3) + .print(rightResult); + return printer; + } + + + ExplainPrinter transport(const If& expr, + ExplainPrinter condResult, + ExplainPrinter thenResult, + ExplainPrinter elseResult) { + ExplainPrinter printer("If"); + printer.separator(" []") + .setChildCount(3) + .maybeReverse() + .fieldName("condition", ExplainVersion::V3) + .print(condResult) + .fieldName("then", ExplainVersion::V3) + .print(thenResult) + .fieldName("else", ExplainVersion::V3) + .print(elseResult); + return printer; + } + + ExplainPrinter transport(const Let& expr, + ExplainPrinter bindResult, + ExplainPrinter exprResult) { + ExplainPrinter printer("Let"); + printer.separator(" [") + .fieldName("variable", ExplainVersion::V3) + .print(expr.varName()) + .separator("]") + .setChildCount(2) + .maybeReverse() + .fieldName("bind", ExplainVersion::V3) + .print(bindResult) + .fieldName("expression", ExplainVersion::V3) + .print(exprResult); + return printer; + } + + ExplainPrinter transport(const LambdaAbstraction& expr, ExplainPrinter inResult) { + ExplainPrinter printer("LambdaAbstraction"); + printer.separator(" [") + .fieldName("variable", ExplainVersion::V3) + .print(expr.varName()) + .separator("]") + .setChildCount(1) + .fieldName("input", ExplainVersion::V3) + .print(inResult); + return printer; + } + + ExplainPrinter transport(const LambdaApplication& expr, + ExplainPrinter lambdaResult, + ExplainPrinter argumentResult) { + ExplainPrinter printer("LambdaApplication"); + printer.separator(" []") + .setChildCount(2) + .maybeReverse() + .fieldName("lambda", ExplainVersion::V3) + .print(lambdaResult) + .fieldName("argument", ExplainVersion::V3) + .print(argumentResult); + return printer; + } + + ExplainPrinter transport(const FunctionCall& expr, std::vector<ExplainPrinter> argResults) { + ExplainPrinter printer("FunctionCall"); + printer.separator(" [") + .fieldName("name", ExplainVersion::V3) + .print(expr.name()) + .separator("]"); + if (!argResults.empty()) { + printer.setChildCount(argResults.size()) + .maybeReverse() + .fieldName("arguments", ExplainVersion::V3) + .print(argResults); + } + return printer; + } + + ExplainPrinter transport(const EvalPath& expr, + ExplainPrinter pathResult, + ExplainPrinter inputResult) { + ExplainPrinter printer("EvalPath"); + printer.separator(" []") + .setChildCount(2) + .maybeReverse() + .fieldName("path", ExplainVersion::V3) + .print(pathResult) + .fieldName("input", ExplainVersion::V3) + .print(inputResult); + return printer; + } + + ExplainPrinter transport(const 
EvalFilter& expr, + ExplainPrinter pathResult, + ExplainPrinter inputResult) { + ExplainPrinter printer("EvalFilter"); + printer.separator(" []") + .setChildCount(2) + .maybeReverse() + .fieldName("path", ExplainVersion::V3) + .print(pathResult) + .fieldName("input", ExplainVersion::V3) + .print(inputResult); + return printer; + } + + /** + * Paths + */ + ExplainPrinter transport(const PathConstant& path, ExplainPrinter inResult) { + ExplainPrinter printer("PathConstant"); + printer.separator(" []") + .setChildCount(1) + .fieldName("input", ExplainVersion::V3) + .print(inResult); + return printer; + } + + ExplainPrinter transport(const PathLambda& path, ExplainPrinter inResult) { + ExplainPrinter printer("PathLambda"); + printer.separator(" []") + .setChildCount(1) + .fieldName("input", ExplainVersion::V3) + .print(inResult); + return printer; + } + + ExplainPrinter transport(const PathIdentity& path) { + ExplainPrinter printer("PathIdentity"); + printer.separator(" []"); + return printer; + } + + ExplainPrinter transport(const PathDefault& path, ExplainPrinter inResult) { + ExplainPrinter printer("PathDefault"); + printer.separator(" []") + .setChildCount(1) + .fieldName("input", ExplainVersion::V3) + .print(inResult); + return printer; + } + + ExplainPrinter transport(const PathCompare& path, ExplainPrinter valueResult) { + ExplainPrinter printer("PathCompare"); + printer.separator(" [") + .fieldName("op", ExplainVersion::V3) + .print(OperationsEnum::toString[static_cast<int>(path.op())]) + .separator("]") + .setChildCount(1) + .fieldName("value", ExplainVersion::V3) + .print(valueResult); + return printer; + } + + static void printPathProjections(ExplainPrinter& printer, + const opt::unordered_set<std::string>& names) { + std::set<std::string> ordered; + for (const std::string& s : names) { + ordered.insert(s); + } + + if constexpr (version < ExplainVersion::V3) { + bool first = true; + for (const std::string& s : ordered) { + if (first) { + first = false; + } else { + printer.print(", "); + } + printer.print(s); + } + } else if constexpr (version == ExplainVersion::V3) { + std::vector<ExplainPrinter> printers; + for (const std::string& s : ordered) { + ExplainPrinter local; + local.print(s); + printers.push_back(std::move(local)); + } + printer.fieldName("projections").print(printers); + } else { + static_assert("Unknown version"); + } + } + + ExplainPrinter transport(const PathDrop& path) { + ExplainPrinter printer("PathDrop"); + printer.separator(" ["); + printPathProjections(printer, path.getNames()); + printer.separator("]"); + return printer; + } + + ExplainPrinter transport(const PathKeep& path) { + ExplainPrinter printer("PathKeep"); + printer.separator(" ["); + printPathProjections(printer, path.getNames()); + printer.separator("]"); + return printer; + } + + ExplainPrinter transport(const PathObj& path) { + ExplainPrinter printer("PathObj"); + printer.separator(" []"); + return printer; + } + + ExplainPrinter transport(const PathArr& path) { + ExplainPrinter printer("PathArr"); + printer.separator(" []"); + return printer; + } + + ExplainPrinter transport(const PathTraverse& path, ExplainPrinter inResult) { + ExplainPrinter printer("PathTraverse"); + printer.separator(" []") + .setChildCount(1) + .fieldName("input", ExplainVersion::V3) + .print(inResult); + return printer; + } + + ExplainPrinter transport(const PathField& path, ExplainPrinter inResult) { + ExplainPrinter printer("PathField"); + printer.separator(" [") + .fieldName("path", ExplainVersion::V3) + 
.print(path.name()) + .separator("]") + .setChildCount(1) + .fieldName("input", ExplainVersion::V3) + .print(inResult); + return printer; + } + + ExplainPrinter transport(const PathGet& path, ExplainPrinter inResult) { + ExplainPrinter printer("PathGet"); + printer.separator(" [") + .fieldName("path", ExplainVersion::V3) + .print(path.name()) + .separator("]") + .setChildCount(1) + .fieldName("input", ExplainVersion::V3) + .print(inResult); + return printer; + } + + ExplainPrinter transport(const PathComposeM& path, + ExplainPrinter leftResult, + ExplainPrinter rightResult) { + ExplainPrinter printer("PathComposeM"); + printer.separator(" []") + .setChildCount(2) + .maybeReverse() + .fieldName("leftInput", ExplainVersion::V3) + .print(leftResult) + .fieldName("rightInput", ExplainVersion::V3) + .print(rightResult); + return printer; + } + + ExplainPrinter transport(const PathComposeA& path, + ExplainPrinter leftResult, + ExplainPrinter rightResult) { + ExplainPrinter printer("PathComposeA"); + printer.separator(" []") + .setChildCount(2) + .maybeReverse() + .fieldName("leftInput", ExplainVersion::V3) + .print(leftResult) + .fieldName("rightInput", ExplainVersion::V3) + .print(rightResult); + return printer; + } + + ExplainPrinter transport(const Source& expr) { + ExplainPrinter printer("Source"); + printer.separator(" []"); + return printer; + } + + ExplainPrinter generate(const ABT& node) { + return algebra::transport<false>(node, *this); + } + + void printPhysNodeInfo(ExplainPrinter& printer, const cascades::PhysNodeInfo& nodeInfo) { + printer.fieldName("cost"); + if (nodeInfo._cost.isInfinite()) { + printer.print(nodeInfo._cost.toString()); + } else { + printer.print(nodeInfo._cost.getCost()); + } + printer.separator(", ") + .fieldName("localCost") + .print(nodeInfo._localCost.getCost()) + .separator(", ") + .fieldName("adjustedCE") + .print(nodeInfo._adjustedCE); + + ExplainPrinter nodePrinter = generate(nodeInfo._node); + printer.separator(", ").fieldName("node").print(nodePrinter); + } + + ExplainPrinter printMemo() { + std::vector<ExplainPrinter> groupPrinters; + for (size_t groupId = 0; groupId < _memo->getGroupCount(); groupId++) { + const cascades::Group& group = _memo->getGroup(groupId); + + ExplainPrinter groupPrinter; + groupPrinter.fieldName("groupId").print(groupId).setChildCount(3); + { + ExplainPrinter logicalPropPrinter = + printLogicalProps("Logical properties", group._logicalProperties); + groupPrinter.fieldName("logicalProperties", ExplainVersion::V3) + .print(logicalPropPrinter); + } + + { + std::vector<ExplainPrinter> logicalNodePrinters; + const ABTVector& logicalNodes = group._logicalNodes.getVector(); + for (size_t i = 0; i < logicalNodes.size(); i++) { + ExplainPrinter local; + local.fieldName("logicalNodeId").print(i); + ExplainPrinter nodePrinter = generate(logicalNodes.at(i)); + local.fieldName("node", ExplainVersion::V3).print(nodePrinter); + + logicalNodePrinters.push_back(std::move(local)); + } + ExplainPrinter logicalNodePrinter; + logicalNodePrinter.print(logicalNodePrinters); + + groupPrinter.fieldName("logicalNodes").print(logicalNodePrinter); + } + + { + std::vector<ExplainPrinter> physicalNodePrinters; + for (const auto& physOptResult : group._physicalNodes.getNodes()) { + ExplainPrinter local; + local.fieldName("physicalNodeId") + .print(physOptResult->_index) + .separator(", ") + .fieldName("costLimit"); + + if (physOptResult->_costLimit.isInfinite()) { + local.print(physOptResult->_costLimit.toString()); + } else { + 
local.print(physOptResult->_costLimit.getCost()); + } + + ExplainPrinter propPrinter = + printPhysProps("Physical properties", physOptResult->_physProps); + local.fieldName("physicalProperties", ExplainVersion::V3).print(propPrinter); + + if (physOptResult->_nodeInfo) { + ExplainPrinter local1; + printPhysNodeInfo(local1, *physOptResult->_nodeInfo); + + if (!physOptResult->_rejectedNodeInfo.empty()) { + std::vector<ExplainPrinter> rejectedPrinters; + for (const auto& rejectedPlan : physOptResult->_rejectedNodeInfo) { + ExplainPrinter local2; + printPhysNodeInfo(local2, rejectedPlan); + rejectedPrinters.emplace_back(std::move(local2)); + } + local1.fieldName("rejectedPlans").print(rejectedPrinters); + } + + local.fieldName("nodeInfo", ExplainVersion::V3).print(local1); + } else { + local.separator(" (failed to optimize)"); + } + + physicalNodePrinters.push_back(std::move(local)); + } + ExplainPrinter physNodePrinter; + physNodePrinter.print(physicalNodePrinters); + + groupPrinter.fieldName("physicalNodes").print(physNodePrinter); + } + + groupPrinters.push_back(std::move(groupPrinter)); + } + + ExplainPrinter printer; + printer.fieldName("Memo").print(groupPrinters); + return printer; + } + +private: + const bool _displayProperties; + + // We don't own this. + const cascades::Memo* _memo; + const NodeToGroupPropsMap& _nodeMap; +}; + +std::string ExplainGenerator::explain(const ABT& node, + const bool displayProperties, + const cascades::Memo* memo, + const NodeToGroupPropsMap& nodeMap) { + ExplainGeneratorTransporter gen(displayProperties, memo, nodeMap); + return gen.generate(node).str(); +} + +std::string ExplainGenerator::explainV2(const ABT& node, + const bool displayProperties, + const cascades::Memo* memo, + const NodeToGroupPropsMap& nodeMap) { + ExplainGeneratorTransporter<ExplainVersion::V2> gen(displayProperties, memo, nodeMap); + return gen.generate(node).str(); +} + +std::pair<sbe::value::TypeTags, sbe::value::Value> ExplainGenerator::explainBSON( + const ABT& node, + const bool displayProperties, + const cascades::Memo* memo, + const NodeToGroupPropsMap& nodeMap) { + ExplainGeneratorTransporter<ExplainVersion::V3> gen(displayProperties, memo, nodeMap); + return gen.generate(node).moveValue(); +} + +template <class PrinterType> +static void printBSONstr(PrinterType& printer, + const sbe::value::TypeTags tag, + const sbe::value::Value val) { + switch (tag) { + case sbe::value::TypeTags::Array: { + const auto* array = sbe::value::getArrayView(val); + + PrinterType local; + for (size_t index = 0; index < array->size(); index++) { + if (index > 0) { + local.print(", "); + local.newLine(); + } + const auto [tag1, val1] = array->getAt(index); + printBSONstr(local, tag1, val1); + } + printer.print("[").print(local).print("]"); + + break; + } + + case sbe::value::TypeTags::Object: { + const auto* obj = sbe::value::getObjectView(val); + + PrinterType local; + for (size_t index = 0; index < obj->size(); index++) { + if (index > 0) { + local.print(", "); + local.newLine(); + } + local.fieldName(obj->field(index)); + const auto [tag1, val1] = obj->getAt(index); + printBSONstr(local, tag1, val1); + } + printer.print("{").print(local).print("}"); + + break; + } + + default: { + std::ostringstream os; + os << std::make_pair(tag, val); + printer.print(os.str()); + } + } +} + +std::string ExplainGenerator::printBSON(const sbe::value::TypeTags tag, + const sbe::value::Value val) { + ExplainPrinterImpl<ExplainVersion::V2> printer; + printBSONstr(printer, tag, val); + return printer.str(); +} + 
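The three entry points defined above differ mainly in the ExplainVersion of the underlying transporter: explain() uses the transporter's default version, explainV2() produces the indented tree format used by the golden strings in the unit tests later in this patch, and explainBSON() returns an SBE tag/value pair (ExplainVersion::V3) for the BSON explain path. A minimal usage sketch follows; it assumes only the ExplainGenerator API declared in explain.h below, and the helper name explainForTest is illustrative rather than part of the patch.

    #include "mongo/db/query/optimizer/explain.h"

    namespace mongo::optimizer {

    // Render an optimized ABT in the human-readable V2 tree format,
    // without displaying memo properties (the default).
    std::string explainForTest(const ABT& optimized) {
        return ExplainGenerator::explainV2(optimized);
    }

    }  // namespace mongo::optimizer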
+std::string ExplainGenerator::explainLogicalProps(const std::string& description, + const properties::LogicalProps& props) { + return ExplainGeneratorTransporter<ExplainVersion::V2>::printLogicalProps(description, props) + .str(); +} + +std::string ExplainGenerator::explainPhysProps(const std::string& description, + const properties::PhysProps& props) { + return ExplainGeneratorTransporter<ExplainVersion::V2>::printPhysProps(description, props) + .str(); +} + +std::string ExplainGenerator::explainMemo(const cascades::Memo& memo) { + ExplainGeneratorTransporter<ExplainVersion::V2> gen(false /*displayProperties*/, &memo); + return gen.printMemo().str(); +} + +std::pair<sbe::value::TypeTags, sbe::value::Value> ExplainGenerator::explainMemoBSON( + const cascades::Memo& memo) { + ExplainGeneratorTransporter<ExplainVersion::V3> gen(false /*displayProperties*/, &memo); + return gen.printMemo().moveValue(); +} + +std::string ExplainGenerator::explainPartialSchemaReqMap(const PartialSchemaRequirements& reqMap) { + ExplainGeneratorTransporter<ExplainVersion::V2> gen; + ExplainGeneratorTransporter<ExplainVersion::V2>::ExplainPrinter result; + gen.printPartialSchemaReqMap(result, reqMap); + return result.str(); +} + +std::string ExplainGenerator::explainInterval(const IntervalRequirement& interval) { + ExplainGeneratorTransporter<ExplainVersion::V2> gen; + return gen.printInterval(interval); +} + +std::string ExplainGenerator::explainIntervalExpr(const IntervalReqExpr::Node& intervalExpr) { + ExplainGeneratorTransporter<ExplainVersion::V2> gen; + return gen.printIntervalExpr(intervalExpr).str(); +} + +std::string _printNode(const ABT& node) { + if (node.empty()) { + return "Empty\n"; + } + return ExplainGenerator::explainV2(node); +} + +std::string _printInterval(const IntervalRequirement& interval) { + return ExplainGenerator::explainInterval(interval); +} + +std::string _printLogicalProps(const properties::LogicalProps& props) { + return ExplainGenerator::explainLogicalProps("Logical Properties", props); +} + +std::string _printPhysProps(const properties::PhysProps& props) { + return ExplainGenerator::explainPhysProps("Physical Properties", props); +} + +} // namespace mongo::optimizer diff --git a/src/mongo/db/query/optimizer/explain.h b/src/mongo/db/query/optimizer/explain.h new file mode 100644 index 00000000000..970d912cdbc --- /dev/null +++ b/src/mongo/db/query/optimizer/explain.h @@ -0,0 +1,113 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. 
If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#pragma once + +#include "mongo/db/exec/sbe/values/value.h" +#include "mongo/db/query/optimizer/explain_interface.h" +#include "mongo/db/query/optimizer/metadata.h" +#include "mongo/db/query/optimizer/node_defs.h" +#include "mongo/db/query/optimizer/props.h" +#include "mongo/db/query/optimizer/syntax/syntax.h" + +namespace mongo::optimizer { + +namespace cascades { +class Memo; +} + +/** + * This structure holds any data that is required by the BSON version of explain. It is + * self-sufficient and separate because it must outlive the other optimizer state as it is used by + * the runtime plan executor. + */ +class ABTPrinter : public AbstractABTPrinter { +public: + ABTPrinter(ABT abtTree, NodeToGroupPropsMap nodeToPropsMap) + : _abtTree(std::move(abtTree)), _nodeToPropsMap(std::move(nodeToPropsMap)) {} + + BSONObj explainBSON() const override final; + +private: + ABT _abtTree; + NodeToGroupPropsMap _nodeToPropsMap; +}; + +class ExplainGenerator { +public: + // Optionally display logical and physical properties using the memo. + // whenever memo delegators are printed. + static std::string explain(const ABT& node, + bool displayProperties = false, + const cascades::Memo* memo = nullptr, + const NodeToGroupPropsMap& nodeMap = {}); + + // Optionally display logical and physical properties using the memo. + // whenever memo delegators are printed. + static std::string explainV2(const ABT& node, + bool displayProperties = false, + const cascades::Memo* memo = nullptr, + const NodeToGroupPropsMap& nodeMap = {}); + + static std::pair<sbe::value::TypeTags, sbe::value::Value> explainBSON( + const ABT& node, + bool displayProperties = false, + const cascades::Memo* memo = nullptr, + const NodeToGroupPropsMap& nodeMap = {}); + + static std::string printBSON(sbe::value::TypeTags tag, sbe::value::Value val); + + static std::string explainLogicalProps(const std::string& description, + const properties::LogicalProps& props); + static std::string explainPhysProps(const std::string& description, + const properties::PhysProps& props); + + static std::string explainMemo(const cascades::Memo& memo); + + static std::pair<sbe::value::TypeTags, sbe::value::Value> explainMemoBSON( + const cascades::Memo& memo); + + static std::string explainPartialSchemaReqMap(const PartialSchemaRequirements& reqMap); + + static std::string explainInterval(const IntervalRequirement& interval); + + static std::string explainIntervalExpr(const IntervalReqExpr::Node& intervalExpr); +}; + +// Functions used by GDB for pretty-printing. For whatever reason GDB cannot find the static +// methods of ExplainGenerator, while it can find the functions below. 
+ +std::string _printNode(const ABT& node); + +std::string _printInterval(const IntervalRequirement& interval); + +std::string _printLogicalProps(const properties::LogicalProps& props); +std::string _printPhysProps(const properties::PhysProps& props); + +} // namespace mongo::optimizer diff --git a/src/mongo/db/query/optimizer/explain_interface.h b/src/mongo/db/query/optimizer/explain_interface.h new file mode 100644 index 00000000000..a007d1c2b49 --- /dev/null +++ b/src/mongo/db/query/optimizer/explain_interface.h @@ -0,0 +1,48 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#pragma once + +#include "mongo/bson/bsonobj.h" + + +namespace mongo::optimizer { + +/** + * Used to abstract away the explaining of the optimizer. + * Note: we should not depend on anything here (certainly not the rest of the optimizer). + */ +class AbstractABTPrinter { +public: + virtual ~AbstractABTPrinter() = default; + + virtual BSONObj explainBSON() const = 0; +}; + +}; // namespace mongo::optimizer diff --git a/src/mongo/db/query/optimizer/index_bounds.cpp b/src/mongo/db/query/optimizer/index_bounds.cpp new file mode 100644 index 00000000000..5c2ff9fd66b --- /dev/null +++ b/src/mongo/db/query/optimizer/index_bounds.cpp @@ -0,0 +1,246 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. 
+ * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include "mongo/db/query/optimizer/index_bounds.h" + +#include "mongo/db/query/optimizer/node.h" + + +namespace mongo::optimizer { + +BoundRequirement::BoundRequirement() : _inclusive(false), _bound() {} + +BoundRequirement::BoundRequirement(bool inclusive, boost::optional<ABT> bound) + : _inclusive(inclusive), _bound(std::move(bound)) { + uassert(6624077, "Infinite bound cannot be inclusive", !inclusive || !isInfinite()); +} + +bool BoundRequirement::operator==(const BoundRequirement& other) const { + return _inclusive == other._inclusive && _bound == other._bound; +} + +bool BoundRequirement::isInclusive() const { + return _inclusive; +} + +void BoundRequirement::setInclusive(bool value) { + _inclusive = value; +} + +bool BoundRequirement::isInfinite() const { + return !_bound.has_value(); +} + +const ABT& BoundRequirement::getBound() const { + uassert(6624078, "Cannot retrieve infinite bound", !isInfinite()); + return _bound.get(); +} + +IntervalRequirement::IntervalRequirement(BoundRequirement lowBound, BoundRequirement highBound) + : _lowBound(std::move(lowBound)), _highBound(std::move(highBound)) {} + +bool IntervalRequirement::operator==(const IntervalRequirement& other) const { + return _lowBound == other._lowBound && _highBound == other._highBound; +} + +bool IntervalRequirement::isFullyOpen() const { + return _lowBound.isInfinite() && _highBound.isInfinite(); +} + +bool IntervalRequirement::isEquality() const { + return _lowBound.isInclusive() && _highBound.isInclusive() && _lowBound == _highBound; +} + +const BoundRequirement& IntervalRequirement::getLowBound() const { + return _lowBound; +} + +BoundRequirement& IntervalRequirement::getLowBound() { + return _lowBound; +} + +const BoundRequirement& IntervalRequirement::getHighBound() const { + return _highBound; +} + +BoundRequirement& IntervalRequirement::getHighBound() { + return _highBound; +} + +PartialSchemaKey::PartialSchemaKey() : PartialSchemaKey({}, make<PathIdentity>()) {} + +PartialSchemaKey::PartialSchemaKey(ProjectionName projectionName) + : PartialSchemaKey(std::move(projectionName), make<PathIdentity>()) {} + +PartialSchemaKey::PartialSchemaKey(ProjectionName projectionName, ABT path) + : _projectionName(std::move(projectionName)), _path(std::move(path)) { + assertPathSort(_path); +} + +bool PartialSchemaKey::operator==(const PartialSchemaKey& other) const { + return _projectionName == other._projectionName && _path == other._path; +} + +bool PartialSchemaKey::emptyPath() const { + return _path.is<PathIdentity>(); +} + +bool isIntervalReqFullyOpenDNF(const IntervalReqExpr::Node& n) { + if (auto singular = IntervalReqExpr::getSingularDNF(n); singular && singular->isFullyOpen()) { + return true; + } + return false; +} + 
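As implemented above, a BoundRequirement pairs an inclusivity flag with an optional ABT bound (an absent bound means infinite), and an IntervalRequirement simply holds a low and a high bound. A short sketch of constructing the open interval (14, 21) that shows up in the index scan plans of the interval-intersection tests below; it assumes the Constant::int64 factory from the optimizer's syntax layer, and the helper names makeOpenInterval/makeFullyOpen are illustrative only.

    using namespace mongo::optimizer;

    // Open interval (14, 21): neither bound is inclusive, both bounds are constant ABTs.
    IntervalRequirement makeOpenInterval() {
        BoundRequirement low(false /*inclusive*/, Constant::int64(14));
        BoundRequirement high(false /*inclusive*/, Constant::int64(21));
        return IntervalRequirement(std::move(low), std::move(high));
    }

    // A fully open interval has an infinite (absent) bound on both sides,
    // which is what isFullyOpen() above detects.
    IntervalRequirement makeFullyOpen() {
        return IntervalRequirement(BoundRequirement::makeInfinite(),
                                   BoundRequirement::makeInfinite());
    }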
+PartialSchemaRequirement::PartialSchemaRequirement() + : _intervals(IntervalReqExpr::makeSingularDNF()) {} + +PartialSchemaRequirement::PartialSchemaRequirement(ProjectionName boundProjectionName, + IntervalReqExpr::Node intervals) + : _boundProjectionName(std::move(boundProjectionName)), _intervals(std::move(intervals)) {} + +bool PartialSchemaRequirement::operator==(const PartialSchemaRequirement& other) const { + return _boundProjectionName == other._boundProjectionName && _intervals == other._intervals; +} + +bool PartialSchemaRequirement::hasBoundProjectionName() const { + return !_boundProjectionName.empty(); +} + +const ProjectionName& PartialSchemaRequirement::getBoundProjectionName() const { + return _boundProjectionName; +} + +void PartialSchemaRequirement::setBoundProjectionName(ProjectionName boundProjectionName) { + _boundProjectionName = std::move(boundProjectionName); +} + +const IntervalReqExpr::Node& PartialSchemaRequirement::getIntervals() const { + return _intervals; +} + +IntervalReqExpr::Node& PartialSchemaRequirement::getIntervals() { + return _intervals; +} + +/** + * Helper class used to compare PartialSchemaKey objects. + */ +class Path3WCompare { +public: + Path3WCompare() {} + + int compareTags(const ABT& n, const ABT& other) { + const auto t1 = n.tagOf(); + const auto t2 = other.tagOf(); + return (t1 == t2) ? 0 : ((t1 < t2) ? -1 : 1); + } + + int operator()(const ABT& n, const PathGet& node, const ABT& other) { + if (auto otherGet = other.cast<PathGet>(); otherGet != nullptr) { + const int varCmp = node.name().compare(otherGet->name()); + return (varCmp == 0) ? node.getPath().visit(*this, otherGet->getPath()) : varCmp; + } + return compareTags(n, other); + } + + int operator()(const ABT& n, const PathTraverse& node, const ABT& other) { + if (auto otherTraverse = other.cast<PathTraverse>(); otherTraverse != nullptr) { + return node.getPath().visit(*this, otherTraverse->getPath()); + } + return compareTags(n, other); + } + + int operator()(const ABT& n, const PathIdentity& node, const ABT& other) { + return compareTags(n, other); + } + + template <typename T, typename... Ts> + int operator()(const ABT& /*n*/, const T& /*node*/, Ts&&...) 
{ + uasserted(6624079, "Unexpected node type"); + return 0; + } + + static int compare(const ABT& node, const ABT& other) { + Path3WCompare instance; + return node.visit(instance, other); + } +}; + +bool PartialSchemaKeyLessComparator::operator()(const PartialSchemaKey& k1, + const PartialSchemaKey& k2) const { + const int projCmp = k1._projectionName.compare(k2._projectionName); + if (projCmp != 0) { + return projCmp < 0; + } + return Path3WCompare::compare(k1._path, k2._path) < 0; +} + +ResidualRequirement::ResidualRequirement(PartialSchemaKey key, + PartialSchemaRequirement req, + CEType ce) + : _key(std::move(key)), _req(std::move(req)), _ce(ce) {} + +bool CandidateIndexEntry::operator==(const CandidateIndexEntry& other) const { + return _fieldProjectionMap == other._fieldProjectionMap && _intervals == other._intervals && + _residualRequirements == other._residualRequirements && + _fieldsToCollate == other._fieldsToCollate; +} + +IndexSpecification::IndexSpecification(std::string scanDefName, + std::string indexDefName, + MultiKeyIntervalRequirement interval, + bool reverseOrder) + : _scanDefName(std::move(scanDefName)), + _indexDefName(std::move(indexDefName)), + _interval(std::move(interval)), + _reverseOrder(reverseOrder) {} + +bool IndexSpecification::operator==(const IndexSpecification& other) const { + return _scanDefName == other._scanDefName && _indexDefName == other._indexDefName && + _interval == other._interval && _reverseOrder == other._reverseOrder; +} + +const std::string& IndexSpecification::getScanDefName() const { + return _scanDefName; +} + +const std::string& IndexSpecification::getIndexDefName() const { + return _indexDefName; +} + +const MultiKeyIntervalRequirement& IndexSpecification::getInterval() const { + return _interval; +} + +bool IndexSpecification::isReverseOrder() const { + return _reverseOrder; +} + +} // namespace mongo::optimizer diff --git a/src/mongo/db/query/optimizer/index_bounds.h b/src/mongo/db/query/optimizer/index_bounds.h new file mode 100644 index 00000000000..f172c361292 --- /dev/null +++ b/src/mongo/db/query/optimizer/index_bounds.h @@ -0,0 +1,206 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. 
If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#pragma once + +#include "mongo/db/query/optimizer/bool_expression.h" +#include "mongo/db/query/optimizer/defs.h" +#include "mongo/db/query/optimizer/syntax/syntax.h" + + +namespace mongo::optimizer { + +class BoundRequirement { +public: + static BoundRequirement makeInfinite() { + return BoundRequirement(false, boost::none); + } + + BoundRequirement(); + BoundRequirement(bool inclusive, boost::optional<ABT> bound); + + bool operator==(const BoundRequirement& other) const; + + bool isInclusive() const; + void setInclusive(bool value); + + bool isInfinite() const; + const ABT& getBound() const; + +private: + bool _inclusive; + + // If we do not have a bound ABT, the bound is considered infinite. + boost::optional<ABT> _bound; +}; + +class IntervalRequirement { +public: + IntervalRequirement() = default; + IntervalRequirement(BoundRequirement lowBound, BoundRequirement highBound); + + bool operator==(const IntervalRequirement& other) const; + + bool isFullyOpen() const; + bool isEquality() const; + + const BoundRequirement& getLowBound() const; + BoundRequirement& getLowBound(); + const BoundRequirement& getHighBound() const; + BoundRequirement& getHighBound(); + +private: + BoundRequirement _lowBound; + BoundRequirement _highBound; +}; + +struct PartialSchemaKey { + PartialSchemaKey(); + PartialSchemaKey(ProjectionName projectionName); + PartialSchemaKey(ProjectionName projectionName, ABT path); + + bool operator==(const PartialSchemaKey& other) const; + + bool emptyPath() const; + + // Referred, or input projection name. + ProjectionName _projectionName; + + // (Partially determined) path. + ABT _path; +}; + +using IntervalReqExpr = BoolExpr<IntervalRequirement>; +bool isIntervalReqFullyOpenDNF(const IntervalReqExpr::Node& n); + +class PartialSchemaRequirement { +public: + PartialSchemaRequirement(); + PartialSchemaRequirement(ProjectionName boundProjectionName, IntervalReqExpr::Node intervals); + + bool operator==(const PartialSchemaRequirement& other) const; + + bool hasBoundProjectionName() const; + const ProjectionName& getBoundProjectionName() const; + void setBoundProjectionName(ProjectionName boundProjectionName); + + const IntervalReqExpr::Node& getIntervals() const; + IntervalReqExpr::Node& getIntervals(); + +private: + // Bound, or output projection name. + ProjectionName _boundProjectionName; + + IntervalReqExpr::Node _intervals; +}; + +struct PartialSchemaKeyLessComparator { + bool operator()(const PartialSchemaKey& k1, const PartialSchemaKey& k2) const; +}; + +// Map from referred (or input) projection name to list of requirements for that projection. +using PartialSchemaRequirements = + std::map<PartialSchemaKey, PartialSchemaRequirement, PartialSchemaKeyLessComparator>; + +using PartialSchemaKeyCE = std::map<PartialSchemaKey, CEType, PartialSchemaKeyLessComparator>; +using ResidualKeyMap = std::map<PartialSchemaKey, PartialSchemaKey, PartialSchemaKeyLessComparator>; + +using PartialSchemaKeySet = std::set<PartialSchemaKey, PartialSchemaKeyLessComparator>; + +// Requirements which are not satisfied directly by an IndexScan, PhysicalScan or Seek (e.g. using +// an index field, or scan field). They are intended to be sorted in their containing vector from +// most to least selective. 
+struct ResidualRequirement { + ResidualRequirement(PartialSchemaKey key, PartialSchemaRequirement req, CEType ce); + + PartialSchemaKey _key; + PartialSchemaRequirement _req; + CEType _ce; +}; +using ResidualRequirements = std::vector<ResidualRequirement>; + +// A sequence of intervals corresponding, one for each index key. +using MultiKeyIntervalRequirement = std::vector<IntervalRequirement>; + +// Multi-key intervals represent unions and conjunctions of individual multi-key intervals. +using MultiKeyIntervalReqExpr = BoolExpr<MultiKeyIntervalRequirement>; + +// Used to pre-compute candidate indexes for SargableNodes. +struct CandidateIndexEntry { + bool operator==(const CandidateIndexEntry& other) const; + + FieldProjectionMap _fieldProjectionMap; + MultiKeyIntervalReqExpr::Node _intervals; + + PartialSchemaRequirements _residualRequirements; + // Projections needed to evaluate residual requirements. + ProjectionNameSet _residualRequirementsTempProjections; + + // Used for CE. Mapping for residual requirement key to query key. + ResidualKeyMap _residualKeyMap; + + // We have equalities on those index fields, and thus do not consider for collation + // requirements. + // TODO: consider a bitset. + opt::unordered_set<size_t> _fieldsToCollate; +}; + +using CandidateIndexMap = opt::unordered_map<std::string /*index name*/, CandidateIndexEntry>; + +class IndexSpecification { +public: + IndexSpecification(std::string scanDefName, + std::string indexDefName, + MultiKeyIntervalRequirement interval, + bool reverseOrder); + + bool operator==(const IndexSpecification& other) const; + + const std::string& getScanDefName() const; + const std::string& getIndexDefName() const; + + const MultiKeyIntervalRequirement& getInterval() const; + + bool isReverseOrder() const; + +private: + // Name of the collection. + const std::string _scanDefName; + + // The name of the index. + const std::string _indexDefName; + + // The index interval. + MultiKeyIntervalRequirement _interval; + + // Do we reverse the index order. + const bool _reverseOrder; +}; + +} // namespace mongo::optimizer diff --git a/src/mongo/db/query/optimizer/interval_intersection_test.cpp b/src/mongo/db/query/optimizer/interval_intersection_test.cpp new file mode 100644 index 00000000000..7593e199846 --- /dev/null +++ b/src/mongo/db/query/optimizer/interval_intersection_test.cpp @@ -0,0 +1,667 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. 
If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include <string> +#include <vector> + +#include "mongo/db/query/optimizer/explain.h" +#include "mongo/db/query/optimizer/metadata.h" +#include "mongo/db/query/optimizer/opt_phase_manager.h" +#include "mongo/db/query/optimizer/utils/interval_utils.h" +#include "mongo/db/query/optimizer/utils/unit_test_utils.h" +#include "mongo/platform/atomic_word.h" +#include "mongo/unittest/unittest.h" +#include "mongo/util/processinfo.h" + +namespace mongo::optimizer { +namespace { + +struct QueryTest { + std::string query; + std::string expectedPlan; +}; + +std::string optimizedQueryPlan(const std::string& query, + const opt::unordered_map<std::string, IndexDefinition>& indexes) { + PrefixId prefixId; + std::string scanDefName = "coll"; + Metadata metadata = {{{scanDefName, ScanDefinition{{}, indexes}}}}; + ABT translated = + translatePipeline(metadata, "[{$match: " + query + "}]", scanDefName, prefixId); + + OptPhaseManager phaseManager( + OptPhaseManager::getAllRewritesSet(), prefixId, metadata, DebugInfo::kDefaultForTests); + + ABT optimized = translated; + ASSERT_TRUE(phaseManager.optimize(optimized)); + return ExplainGenerator::explainV2(optimized); +} + +TEST(IntervalIntersection, SingleFieldIntersection) { + opt::unordered_map<std::string, IndexDefinition> testIndex = { + {"index1", makeIndexDefinition("a0", CollationOp::Ascending, /*Not multikey*/ false)}}; + + std::vector<QueryTest> queryTests = { + {"{a0: {$gt:14, $lt:21}}", + "Root []\n" + "| | projections: \n" + "| | scan_0\n" + "| RefBlock: \n" + "| Variable [scan_0]\n" + "BinaryJoin [joinType: Inner, {rid_0}]\n" + "| | Const [true]\n" + "| LimitSkip []\n" + "| | limitSkip:\n" + "| | limit: 1\n" + "| | skip: 0\n" + "| Seek [ridProjection: rid_0, {'<root>': scan_0}, coll]\n" + "| | BindBlock:\n" + "| | [scan_0]\n" + "| | Source []\n" + "| RefBlock: \n" + "| Variable [rid_0]\n" + "IndexScan [{'<rid>': rid_0}, scanDefName: coll, indexDefName: index1, interval: " + "{(Const [14], Const [21])}]\n" + " BindBlock:\n" + " [rid_0]\n" + " Source []\n"}, + {"{$and: [{a0: {$gt:14}}, {a0: {$lt: 21}}]}", + "Root []\n" + "| | projections: \n" + "| | scan_0\n" + "| RefBlock: \n" + "| Variable [scan_0]\n" + "BinaryJoin [joinType: Inner, {rid_0}]\n" + "| | Const [true]\n" + "| LimitSkip []\n" + "| | limitSkip:\n" + "| | limit: 1\n" + "| | skip: 0\n" + "| Seek [ridProjection: rid_0, {'<root>': scan_0}, coll]\n" + "| | BindBlock:\n" + "| | [scan_0]\n" + "| | Source []\n" + "| RefBlock: \n" + "| Variable [rid_0]\n" + "IndexScan [{'<rid>': rid_0}, scanDefName: coll, indexDefName: index1, interval: " + "{(Const [14], Const [21])}]\n" + " BindBlock:\n" + " [rid_0]\n" + " Source []\n"}, + {"{$or: [{$and: [{a0: {$gt:9}}, {a0: {$lt: 12}}]}, {$and: [{a0: {$gt:40}}, {a0: {$lt: " + "44}}]}]}", + "Root []\n" + "| | projections: \n" + "| | scan_0\n" + "| RefBlock: \n" + "| Variable [scan_0]\n" + "BinaryJoin [joinType: Inner, {rid_0}]\n" + "| | Const [true]\n" + "| LimitSkip []\n" + "| | limitSkip:\n" + "| | limit: 1\n" + "| | skip: 0\n" + "| Seek [ridProjection: rid_0, {'<root>': scan_0}, coll]\n" + "| | BindBlock:\n" + "| | [scan_0]\n" + "| | Source []\n" + "| RefBlock: \n" + "| Variable [rid_0]\n" + "GroupBy 
[]\n" + "| | groupings: \n" + "| | RefBlock: \n" + "| | Variable [rid_0]\n" + "| aggregations: \n" + "Union []\n" + "| | BindBlock:\n" + "| | [rid_0]\n" + "| | Source []\n" + "| IndexScan [{'<rid>': rid_0}, scanDefName: coll, indexDefName: index1, " + "interval: {(Const [40], Const [44])}]\n" + "| BindBlock:\n" + "| [rid_0]\n" + "| Source []\n" + "IndexScan [{'<rid>': rid_0}, scanDefName: coll, indexDefName: index1, interval: " + "{(Const [9], Const [12])}]\n" + " BindBlock:\n" + " [rid_0]\n" + " Source []\n"}, + // Contradictions + // Empty interval + {"{$and: [{a0: {$gt:20}}, {a0: {$lt: 20}}]}", + "Root []\n" + "| | projections: \n" + "| | scan_0\n" + "| RefBlock: \n" + "| Variable [scan_0]\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [scan_0]\n" + "| Const [Nothing]\n" + "LimitSkip []\n" + "| limitSkip:\n" + "| limit: 0\n" + "| skip: 0\n" + "CoScan []\n"}, + // One conjunct non-empty, one conjunct empty + {"{$or: [{$and: [{a0: {$gt:9}}, {a0: {$lt: 12}}]}, {$and: [{a0: {$gt:44}}, {a0: {$lt: " + "40}}]}]}", + "Root []\n" + "| | projections: \n" + "| | scan_0\n" + "| RefBlock: \n" + "| Variable [scan_0]\n" + "BinaryJoin [joinType: Inner, {rid_0}]\n" + "| | Const [true]\n" + "| LimitSkip []\n" + "| | limitSkip:\n" + "| | limit: 1\n" + "| | skip: 0\n" + "| Seek [ridProjection: rid_0, {'<root>': scan_0}, coll]\n" + "| | BindBlock:\n" + "| | [scan_0]\n" + "| | Source []\n" + "| RefBlock: \n" + "| Variable [rid_0]\n" + "IndexScan [{'<rid>': rid_0}, scanDefName: coll, indexDefName: index1, interval: " + "{(Const [9], Const [12])}]\n" + " BindBlock:\n" + " [rid_0]\n" + " Source []\n"}, + // Both conjuncts empty, whole disjunct empty + {"{$or: [{$and: [{a0: {$gt:15}}, {a0: {$lt: 10}}]}, {$and: [{a0: {$gt:44}}, {a0: {$lt: " + "40}}]}]}", + "Root []\n" + "| | projections: \n" + "| | scan_0\n" + "| RefBlock: \n" + "| Variable [scan_0]\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [scan_0]\n" + "| Const [Nothing]\n" + "LimitSkip []\n" + "| limitSkip:\n" + "| limit: 0\n" + "| skip: 0\n" + "CoScan []\n"}, + {"{$or: [{$and: [{a0: {$gt:12}}, {a0: {$lt: 12}}]}, {$and: [{a0: {$gte:42}}, {a0: {$lt: " + "42}}]}]}", + "Root []\n" + "| | projections: \n" + "| | scan_0\n" + "| RefBlock: \n" + "| Variable [scan_0]\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [scan_0]\n" + "| Const [Nothing]\n" + "LimitSkip []\n" + "| limitSkip:\n" + "| limit: 0\n" + "| skip: 0\n" + "CoScan []\n"}, + }; + + /* + std::cout << "\nExpected query plans:\n\n"; + for (auto& qt : queryTests) { + std::cout << optimizedQueryPlan(qt.query, testIndex) << std::endl << std::endl; + } + */ + ASSERT_EQ(queryTests[0].expectedPlan, optimizedQueryPlan(queryTests[0].query, testIndex)); + ASSERT_EQ(queryTests[1].expectedPlan, optimizedQueryPlan(queryTests[1].query, testIndex)); + ASSERT_EQ(queryTests[2].expectedPlan, optimizedQueryPlan(queryTests[2].query, testIndex)); + ASSERT_EQ(queryTests[3].expectedPlan, optimizedQueryPlan(queryTests[3].query, testIndex)); + ASSERT_EQ(queryTests[4].expectedPlan, optimizedQueryPlan(queryTests[4].query, testIndex)); + ASSERT_EQ(queryTests[5].expectedPlan, optimizedQueryPlan(queryTests[5].query, testIndex)); + ASSERT_EQ(queryTests[6].expectedPlan, optimizedQueryPlan(queryTests[6].query, testIndex)); +} + +TEST(IntervalIntersection, MultiFieldIntersection) { + std::vector<TestIndexField> indexFields{{"a0", CollationOp::Ascending, false}, + {"b0", CollationOp::Ascending, false}}; + + opt::unordered_map<std::string, IndexDefinition> testIndex = { + {"index1", makeCompositeIndexDefinition(indexFields, false 
/*isMultiKey*/)}}; + + std::vector<QueryTest> queryTests = { + {"{$and: [{a0: {$gt:11}}, {a0: {$lt:14}}, {b0: {$gt: 21}}, {b0: {$lt: 12}}]}", + "Root []\n" + "| | projections: \n" + "| | scan_0\n" + "| RefBlock: \n" + "| Variable [scan_0]\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [scan_0]\n" + "| Const [Nothing]\n" + "LimitSkip []\n" + "| limitSkip:\n" + "| limit: 0\n" + "| skip: 0\n" + "CoScan []\n"}, + {"{$and: [{a0: {$gt:14}}, {a0: {$lt:11}}, {b0: {$gt: 12}}, {b0: {$lt: 21}}]}", + "Root []\n" + "| | projections: \n" + "| | scan_0\n" + "| RefBlock: \n" + "| Variable [scan_0]\n" + "Filter []\n" + "| FunctionCall [traverseF]\n" + "| | | Const [false]\n" + "| | LambdaAbstraction [valCmp4]\n" + "| | BinaryOp [Lt]\n" + "| | | Const [0]\n" + "| | BinaryOp [Cmp3w]\n" + "| | | Const [21]\n" + "| | Variable [valCmp4]\n" + "| FunctionCall [getField]\n" + "| | Const [\"b0\"]\n" + "| Const [Nothing]\n" + "Filter []\n" + "| FunctionCall [traverseF]\n" + "| | | Const [false]\n" + "| | LambdaAbstraction [valCmp1]\n" + "| | BinaryOp [Gt]\n" + "| | | Const [0]\n" + "| | BinaryOp [Cmp3w]\n" + "| | | Const [12]\n" + "| | Variable [valCmp1]\n" + "| FunctionCall [getField]\n" + "| | Const [\"b0\"]\n" + "| Const [Nothing]\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [scan_0]\n" + "| Const [Nothing]\n" + "LimitSkip []\n" + "| limitSkip:\n" + "| limit: 0\n" + "| skip: 0\n" + "CoScan []\n"}, + {"{$and: [{a0: {$gt:14}}, {a0: {$lt:11}}, {b0: {$gt: 21}}, {b0: {$lt: 12}}]}", + "Root []\n" + "| | projections: \n" + "| | scan_0\n" + "| RefBlock: \n" + "| Variable [scan_0]\n" + "Filter []\n" + "| FunctionCall [traverseF]\n" + "| | | Const [false]\n" + "| | LambdaAbstraction [valCmp10]\n" + "| | BinaryOp [Lt]\n" + "| | | Const [0]\n" + "| | BinaryOp [Cmp3w]\n" + "| | | Const [12]\n" + "| | Variable [valCmp10]\n" + "| FunctionCall [getField]\n" + "| | Const [\"b0\"]\n" + "| Const [Nothing]\n" + "Filter []\n" + "| FunctionCall [traverseF]\n" + "| | | Const [false]\n" + "| | LambdaAbstraction [valCmp7]\n" + "| | BinaryOp [Gt]\n" + "| | | Const [0]\n" + "| | BinaryOp [Cmp3w]\n" + "| | | Const [21]\n" + "| | Variable [valCmp7]\n" + "| FunctionCall [getField]\n" + "| | Const [\"b0\"]\n" + "| Const [Nothing]\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [scan_0]\n" + "| Const [Nothing]\n" + "LimitSkip []\n" + "| limitSkip:\n" + "| limit: 0\n" + "| skip: 0\n" + "CoScan []\n"}, + {"{$and: [{a0: 42}, {b0: {$gt: 21}}, {b0: {$lt: 12}}]}", + "Root []\n" + "| | projections: \n" + "| | scan_0\n" + "| RefBlock: \n" + "| Variable [scan_0]\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [scan_0]\n" + "| Const [Nothing]\n" + "LimitSkip []\n" + "| limitSkip:\n" + "| limit: 0\n" + "| skip: 0\n" + "CoScan []\n"}, + }; + + /*std::cout << "\nExpected query plans:\n\n"; + for (auto& qt : queryTests) { + std::cout << optimizedQueryPlan(qt.query, testIndex) << std::endl << std::endl; + }*/ + + ASSERT_EQ(queryTests[0].expectedPlan, optimizedQueryPlan(queryTests[0].query, testIndex)); + // TODO: these tests contain an escaped string literal in the explain output, which causes + // failure in other tests by not escaping string literals as in the expected explain. + // Most likely there is some shared state in the SBE printer that gets messed up. 
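+ // Until then, the remaining plan comparisons for this test stay commented out; only queryTests[0] is verified.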
+ // ASSERT_EQ(queryTests[1].expectedPlan, optimizedQueryPlan(queryTests[1].query, testIndex)); + // ASSERT_EQ(queryTests[2].expectedPlan, optimizedQueryPlan(queryTests[2].query, testIndex)); + // ASSERT_EQ(queryTests[3].expectedPlan, optimizedQueryPlan(queryTests[3].query, testIndex)); +} + +TEST(IntervalIntersection, VariableIntervals) { + { + auto interval = + IntervalReqExpr::make<IntervalReqExpr::Disjunction>(IntervalReqExpr::NodeVector{ + IntervalReqExpr::make<IntervalReqExpr::Conjunction>(IntervalReqExpr::NodeVector{ + IntervalReqExpr::make<IntervalReqExpr::Atom>(IntervalRequirement{ + BoundRequirement(true /*inclusive*/, make<Variable>("v1")), + BoundRequirement::makeInfinite()}), + IntervalReqExpr::make<IntervalReqExpr::Atom>(IntervalRequirement{ + BoundRequirement(false /*inclusive*/, make<Variable>("v2")), + BoundRequirement::makeInfinite()})})}); + + auto result = intersectDNFIntervals(interval); + ASSERT_TRUE(result); + + // (max(v1, v2), +inf) U [v2 >= v1 ? MaxKey : v1, max(v1, v2)] + ASSERT_EQ( + "{\n" + " {\n" + " {[If [] BinaryOp [Gte] Variable [v2] Variable [v1] Const [maxKey] Variable " + "[v1], If [] BinaryOp [Gte] Variable [v1] Variable [v2] Variable [v1] Variable [v2]]}\n" + " }\n" + " U \n" + " {\n" + " {(If [] BinaryOp [Gte] Variable [v1] Variable [v2] Variable [v1] Variable " + "[v2], +inf)}\n" + " }\n" + "}\n", + ExplainGenerator::explainIntervalExpr(*result)); + + // Make sure repeated intersection does not change the result. + auto result1 = intersectDNFIntervals(*result); + ASSERT_TRUE(result1); + ASSERT_TRUE(*result == *result1); + } + + { + auto interval = + IntervalReqExpr::make<IntervalReqExpr::Disjunction>(IntervalReqExpr::NodeVector{ + IntervalReqExpr::make<IntervalReqExpr::Conjunction>(IntervalReqExpr::NodeVector{ + IntervalReqExpr::make<IntervalReqExpr::Atom>(IntervalRequirement{ + BoundRequirement(true /*inclusive*/, make<Variable>("v1")), + BoundRequirement(true /*inclusive*/, make<Variable>("v3"))}), + IntervalReqExpr::make<IntervalReqExpr::Atom>(IntervalRequirement{ + BoundRequirement(true /*inclusive*/, make<Variable>("v2")), + BoundRequirement(true /*inclusive*/, make<Variable>("v4"))})})}); + + auto result = intersectDNFIntervals(interval); + ASSERT_TRUE(result); + + // [v1, v3] ^ [v2, v4] -> [max(v1, v2), min(v3, v4)] + ASSERT_EQ( + "{\n" + " {\n" + " {[If [] BinaryOp [Gte] Variable [v1] Variable [v2] Variable [v1] Variable " + "[v2], If [] BinaryOp [Lte] Variable [v3] Variable [v4] Variable [v3] Variable [v4]]}\n" + " }\n" + "}\n", + ExplainGenerator::explainIntervalExpr(*result)); + + // Make sure repeated intersection does not change the result. 
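+ // (i.e., the simplified DNF should be a fixpoint of intersectDNFIntervals.)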
+ auto result1 = intersectDNFIntervals(*result); + ASSERT_TRUE(result1); + ASSERT_TRUE(*result == *result1); + } + + { + auto interval = + IntervalReqExpr::make<IntervalReqExpr::Disjunction>(IntervalReqExpr::NodeVector{ + IntervalReqExpr::make<IntervalReqExpr::Conjunction>(IntervalReqExpr::NodeVector{ + IntervalReqExpr::make<IntervalReqExpr::Atom>(IntervalRequirement{ + BoundRequirement(false /*inclusive*/, make<Variable>("v1")), + BoundRequirement(true /*inclusive*/, make<Variable>("v3"))}), + IntervalReqExpr::make<IntervalReqExpr::Atom>(IntervalRequirement{ + BoundRequirement(true /*inclusive*/, make<Variable>("v2")), + BoundRequirement(true /*inclusive*/, make<Variable>("v4"))})})}); + + auto result = intersectDNFIntervals(interval); + ASSERT_TRUE(result); + + ASSERT_EQ( + "{\n" + " {\n" + " {[If [] BinaryOp [Gte] Variable [v1] Variable [v2] Const [maxKey] Variable " + "[v2], If [] BinaryOp [Lte] If [] BinaryOp [Gte] Variable [v1] Variable [v2] Variable " + "[v1] Variable [v2] If [] BinaryOp [Lte] Variable [v3] Variable [v4] Variable [v3] " + "Variable [v4] If [] BinaryOp [Gte] Variable [v1] Variable [v2] Variable [v1] Variable " + "[v2] If [] BinaryOp [Lte] Variable [v3] Variable [v4] Variable [v3] Variable [v4]]}\n" + " }\n" + " U \n" + " {\n" + " {(If [] BinaryOp [Gte] Variable [v1] Variable [v2] Variable [v1] Variable " + "[v2], If [] BinaryOp [Lte] Variable [v3] Variable [v4] Variable [v3] Variable [v4]]}\n" + " }\n" + "}\n", + ExplainGenerator::explainIntervalExpr(*result)); + + // Make sure repeated intersection does not change the result. + auto result1 = intersectDNFIntervals(*result); + ASSERT_TRUE(result1); + ASSERT_TRUE(*result == *result1); + } + + { + auto interval = + IntervalReqExpr::make<IntervalReqExpr::Disjunction>(IntervalReqExpr::NodeVector{ + IntervalReqExpr::make<IntervalReqExpr::Conjunction>(IntervalReqExpr::NodeVector{ + IntervalReqExpr::make<IntervalReqExpr::Atom>(IntervalRequirement{ + BoundRequirement(false /*inclusive*/, make<Variable>("v1")), + BoundRequirement(true /*inclusive*/, make<Variable>("v3"))}), + IntervalReqExpr::make<IntervalReqExpr::Atom>(IntervalRequirement{ + BoundRequirement(true /*inclusive*/, make<Variable>("v2")), + BoundRequirement(false /*inclusive*/, make<Variable>("v4"))})})}); + + auto result = intersectDNFIntervals(interval); + ASSERT_TRUE(result); + + ASSERT_EQ( + "{\n" + " {\n" + " {[If [] BinaryOp [Gte] Variable [v1] Variable [v2] Const [maxKey] Variable " + "[v2], If [] BinaryOp [Lte] If [] BinaryOp [Gte] Variable [v1] Variable [v2] Variable " + "[v1] Variable [v2] If [] BinaryOp [Lte] Variable [v3] Variable [v4] Variable [v3] " + "Variable [v4] If [] BinaryOp [Gte] Variable [v1] Variable [v2] Variable [v1] Variable " + "[v2] If [] BinaryOp [Lte] Variable [v3] Variable [v4] Variable [v3] Variable [v4]]}\n" + " }\n" + " U \n" + " {\n" + " {[If [] BinaryOp [Gte] If [] BinaryOp [Gte] Variable [v1] Variable [v2] " + "Variable [v1] Variable [v2] If [] BinaryOp [Lte] Variable [v3] Variable [v4] Variable " + "[v3] Variable [v4] If [] BinaryOp [Gte] Variable [v1] Variable [v2] Variable [v1] " + "Variable [v2] If [] BinaryOp [Lte] Variable [v3] Variable [v4] Variable [v3] Variable " + "[v4], If [] BinaryOp [Lte] Variable [v4] Variable [v3] Const [minKey] Variable " + "[v3]]}\n" + " }\n" + " U \n" + " {\n" + " {(If [] BinaryOp [Gte] Variable [v1] Variable [v2] Variable [v1] Variable " + "[v2], If [] BinaryOp [Lte] Variable [v3] Variable [v4] Variable [v3] Variable [v4])}\n" + " }\n" + "}\n", + 
ExplainGenerator::explainIntervalExpr(*result)); + + // Make sure repeated intersection does not change the result. + auto result1 = intersectDNFIntervals(*result); + ASSERT_TRUE(result1); + ASSERT_TRUE(*result == *result1); + } +} + +template <int N> +void updateResults(const bool lowInc, + const int low, + const bool highInc, + const int high, + std::bitset<N>& inclusion) { + for (int v = 0; v < low + (lowInc ? 0 : 1); v++) { + inclusion.set(v, false); + } + for (int v = high + (highInc ? 1 : 0); v < N; v++) { + inclusion.set(v, false); + } +} + +template <int N> +class IntervalInclusionTransport { +public: + using ResultType = std::bitset<N>; + + ResultType transport(const IntervalReqExpr::Atom& node) { + const auto& expr = node.getExpr(); + const auto& lb = expr.getLowBound(); + const auto& hb = expr.getHighBound(); + + std::bitset<N> result; + result.flip(); + updateResults<N>(lb.isInclusive(), + lb.getBound().cast<Constant>()->getValueInt32(), + hb.isInclusive(), + hb.getBound().cast<Constant>()->getValueInt32(), + result); + return result; + } + + ResultType transport(const IntervalReqExpr::Conjunction& node, + std::vector<ResultType> childResults) { + for (size_t index = 1; index < childResults.size(); index++) { + childResults.front() &= childResults.at(index); + } + return childResults.front(); + } + + ResultType transport(const IntervalReqExpr::Disjunction& node, + std::vector<ResultType> childResults) { + for (size_t index = 1; index < childResults.size(); index++) { + childResults.front() |= childResults.at(index); + } + return childResults.front(); + } + + ResultType computeInclusion(const IntervalReqExpr::Node& intervals) { + return algebra::transport<false>(intervals, *this); + } +}; + +template <int V> +int decode(int& permutation) { + const int result = permutation % V; + permutation /= V; + return result; +} + +template <int N> +void testInterval(int permutation) { + const bool low1Inc = decode<2>(permutation); + const int low1 = decode<N>(permutation); + const bool high1Inc = decode<2>(permutation); + const int high1 = decode<N>(permutation); + const bool low2Inc = decode<2>(permutation); + const int low2 = decode<N>(permutation); + const bool high2Inc = decode<2>(permutation); + const int high2 = decode<N>(permutation); + + auto interval = IntervalReqExpr::make<IntervalReqExpr::Disjunction>(IntervalReqExpr::NodeVector{ + IntervalReqExpr::make<IntervalReqExpr::Conjunction>(IntervalReqExpr::NodeVector{ + IntervalReqExpr::make<IntervalReqExpr::Atom>(IntervalRequirement{ + {low1Inc, Constant::int32(low1)}, {high1Inc, Constant::int32(high1)}}), + IntervalReqExpr::make<IntervalReqExpr::Atom>(IntervalRequirement{ + {low2Inc, Constant::int32(low2)}, {high2Inc, Constant::int32(high2)}})})}); + + auto result = intersectDNFIntervals(interval); + std::bitset<N> inclusionActual; + if (result) { + // Since we are testing with constants, we should have at most one interval. + ASSERT_TRUE(IntervalReqExpr::getSingularDNF(*result)); + + IntervalInclusionTransport<N> transport; + // Compute actual inclusion bitset based on the interval intersection code. + inclusionActual = transport.computeInclusion(*result); + } + + std::bitset<N> inclusionExpected; + inclusionExpected.flip(); + + // Compute ground truth. 
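+ // inclusionExpected starts with all N bits set; each updateResults() call below clears the bits
+ // outside one input interval, so the surviving bits are exactly the integers contained in the
+ // intersection. This is compared against the bitset derived from the intersectDNFIntervals() result.
+ // Example for N = 10: intersecting (2, 7] with [5, 9) leaves only bits {5, 6, 7} set.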
+ updateResults<N>(low1Inc, low1, high1Inc, high1, inclusionExpected); + updateResults<N>(low2Inc, low2, high2Inc, high2, inclusionExpected); + + ASSERT_EQ(inclusionExpected, inclusionActual); +} + +TEST(IntervalIntersection, IntervalPermutations) { + static constexpr int N = 10; + static constexpr int numPermutations = N * N * N * N * 2 * 2 * 2 * 2; + + /** + * Test for interval intersection. Generate intervals with constants in the + * range of [0, N), with random inclusion/exclusion of the endpoints. Intersect the intervals + * and verify against ground truth. + */ + const size_t numThreads = ProcessInfo::getNumCores(); + std::cout << "Testing " << numPermutations << " interval permutations using " << numThreads + << " cores...\n"; + auto timeBegin = Date_t::now(); + + AtomicWord<int> permutation(0); + std::vector<stdx::thread> threads; + for (size_t i = 0; i < numThreads; i++) { + threads.emplace_back([&permutation]() { + for (;;) { + const int nextP = permutation.fetchAndAdd(1); + if (nextP >= numPermutations) { + break; + } + testInterval<N>(nextP); + } + }); + } + for (auto& thread : threads) { + thread.join(); + } + + const auto elapsed = + (Date_t::now().toMillisSinceEpoch() - timeBegin.toMillisSinceEpoch()) / 1000.0; + std::cout << "...done. Took: " << elapsed << " s.\n"; +} + +} // namespace +} // namespace mongo::optimizer diff --git a/src/mongo/db/query/optimizer/logical_rewriter_optimizer_test.cpp b/src/mongo/db/query/optimizer/logical_rewriter_optimizer_test.cpp new file mode 100644 index 00000000000..4c912046d3a --- /dev/null +++ b/src/mongo/db/query/optimizer/logical_rewriter_optimizer_test.cpp @@ -0,0 +1,1495 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. 
+ */ + +#include "mongo/db/query/optimizer/cascades/ce_heuristic.h" +#include "mongo/db/query/optimizer/cascades/logical_props_derivation.h" +#include "mongo/db/query/optimizer/explain.h" +#include "mongo/db/query/optimizer/node.h" +#include "mongo/db/query/optimizer/opt_phase_manager.h" +#include "mongo/db/query/optimizer/utils/unit_test_utils.h" +#include "mongo/unittest/unittest.h" + +namespace mongo::optimizer { +namespace { + +TEST(LogicalRewriter, RootNodeMerge) { + PrefixId prefixId; + + ABT scanNode = make<ScanNode>("a", "test"); + ABT limitSkipNode1 = + make<LimitSkipNode>(properties::LimitSkipRequirement(-1, 10), std::move(scanNode)); + ABT limitSkipNode2 = + make<LimitSkipNode>(properties::LimitSkipRequirement(5, 0), std::move(limitSkipNode1)); + + ABT rootNode = make<RootNode>(properties::ProjectionRequirement{ProjectionNameVector{"a"}}, + std::move(limitSkipNode2)); + + ASSERT_EXPLAIN( + "Root []\n" + " projections: \n" + " a\n" + " RefBlock: \n" + " Variable [a]\n" + " LimitSkip []\n" + " limitSkip:\n" + " limit: 5\n" + " skip: 0\n" + " LimitSkip []\n" + " limitSkip:\n" + " limit: (none)\n" + " skip: 10\n" + " Scan [test]\n" + " BindBlock:\n" + " [a]\n" + " Source []\n", + rootNode); + + OptPhaseManager phaseManager({OptPhaseManager::OptPhase::MemoSubstitutionPhase}, + prefixId, + {{{"test", {{}, {}}}}}, + DebugInfo::kDefaultForTests); + ABT rewritten = std::move(rootNode); + ASSERT_TRUE(phaseManager.optimize(rewritten)); + + ASSERT_EXPLAIN( + "Root []\n" + " projections: \n" + " a\n" + " RefBlock: \n" + " Variable [a]\n" + " LimitSkip []\n" + " limitSkip:\n" + " limit: 5\n" + " skip: 10\n" + " Scan [test]\n" + " BindBlock:\n" + " [a]\n" + " Source []\n", + rewritten); +} + +TEST(LogicalRewriter, Memo) { + using namespace cascades; + using namespace properties; + + Metadata metadata{{{"test", {}}}}; + Memo memo(DebugInfo::kDefaultForTests, + metadata, + std::make_unique<DefaultLogicalPropsDerivation>(), + std::make_unique<HeuristicCE>()); + + ABT scanNode = make<ScanNode>("ptest", "test"); + ABT filterNode = make<FilterNode>( + make<EvalFilter>(make<PathConstant>(make<UnaryOp>(Operations::Neg, Constant::int64(1))), + make<Variable>("ptest")), + std::move(scanNode)); + ABT evalNode = make<EvaluationNode>( + "P1", + make<EvalPath>(make<PathConstant>(Constant::int64(2)), make<Variable>("ptest")), + std::move(filterNode)); + + NodeIdSet insertedNodeIds; + const GroupIdType rootGroupId = memo.integrate(evalNode, {}, insertedNodeIds); + ASSERT_EQ(2, rootGroupId); + ASSERT_EQ(3, memo.getGroupCount()); + + NodeIdSet expectedInsertedNodeIds = {{0, 0}, {1, 0}, {2, 0}}; + ASSERT_TRUE(insertedNodeIds == expectedInsertedNodeIds); + + ASSERT_EXPLAIN_MEMO( + "Memo: \n" + " groupId: 0\n" + " | | Logical properties:\n" + " | | cardinalityEstimate: \n" + " | | ce: 1000\n" + " | | projections: \n" + " | | ptest\n" + " | | indexingAvailability: \n" + " | | [groupId: 0, scanProjection: ptest, scanDefName: test, " + "possiblyEqPredsOnly]\n" + " | | collectionAvailability: \n" + " | | test\n" + " | | distributionAvailability: \n" + " | | distribution: \n" + " | | type: Centralized\n" + " | logicalNodes: \n" + " | logicalNodeId: 0\n" + " | Scan [test]\n" + " | BindBlock:\n" + " | [ptest]\n" + " | Source []\n" + " physicalNodes: \n" + " groupId: 1\n" + " | | Logical properties:\n" + " | | cardinalityEstimate: \n" + " | | ce: 100\n" + " | | projections: \n" + " | | ptest\n" + " | | indexingAvailability: \n" + " | | [groupId: 0, scanProjection: ptest, scanDefName: test]\n" + " | | 
collectionAvailability: \n" + " | | test\n" + " | | distributionAvailability: \n" + " | | distribution: \n" + " | | type: Centralized\n" + " | logicalNodes: \n" + " | logicalNodeId: 0\n" + " | Filter []\n" + " | | EvalFilter []\n" + " | | | Variable [ptest]\n" + " | | PathConstant []\n" + " | | UnaryOp [Neg]\n" + " | | Const [1]\n" + " | MemoLogicalDelegator [groupId: 0]\n" + " physicalNodes: \n" + " groupId: 2\n" + " | | Logical properties:\n" + " | | cardinalityEstimate: \n" + " | | ce: 100\n" + " | | projections: \n" + " | | P1\n" + " | | ptest\n" + " | | indexingAvailability: \n" + " | | [groupId: 0, scanProjection: ptest, scanDefName: test]\n" + " | | collectionAvailability: \n" + " | | test\n" + " | | distributionAvailability: \n" + " | | distribution: \n" + " | | type: Centralized\n" + " | logicalNodes: \n" + " | logicalNodeId: 0\n" + " | Evaluation []\n" + " | | BindBlock:\n" + " | | [P1]\n" + " | | EvalPath []\n" + " | | | Variable [ptest]\n" + " | | PathConstant []\n" + " | | Const [2]\n" + " | MemoLogicalDelegator [groupId: 1]\n" + " physicalNodes: \n", + memo); + + { + // Try to insert into the memo again. + NodeIdSet insertedNodeIds; + const GroupIdType group = memo.integrate(evalNode, {}, insertedNodeIds); + ASSERT_EQ(2, group); + ASSERT_EQ(3, memo.getGroupCount()); + + // Nothing was inserted. + ASSERT_EQ(1, memo.getGroup(0)._logicalNodes.size()); + ASSERT_EQ(1, memo.getGroup(1)._logicalNodes.size()); + ASSERT_EQ(1, memo.getGroup(2)._logicalNodes.size()); + } + + // Insert a different tree, this time only scan and project. + ABT scanNode1 = make<ScanNode>("ptest", "test"); + ABT evalNode1 = make<EvaluationNode>( + "P1", + make<EvalPath>(make<PathConstant>(Constant::int64(2)), make<Variable>("ptest")), + std::move(scanNode1)); + + { + NodeIdSet insertedNodeIds1; + const GroupIdType rootGroupId1 = memo.integrate(evalNode1, {}, insertedNodeIds1); + ASSERT_EQ(3, rootGroupId1); + ASSERT_EQ(4, memo.getGroupCount()); + + // Nothing was inserted in first 3 groups. 
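+ // The Scan subtree of the new tree is structurally identical to the existing one, so it is
+ // deduplicated against group 0; only the new Evaluation node creates group 3.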
+ ASSERT_EQ(1, memo.getGroup(0)._logicalNodes.size()); + ASSERT_EQ(1, memo.getGroup(1)._logicalNodes.size()); + ASSERT_EQ(1, memo.getGroup(2)._logicalNodes.size()); + } + + { + const Group& group = memo.getGroup(3); + ASSERT_EQ(1, group._logicalNodes.size()); + + ASSERT_EXPLAIN( + "Evaluation []\n" + " BindBlock:\n" + " [P1]\n" + " EvalPath []\n" + " PathConstant []\n" + " Const [2]\n" + " Variable [ptest]\n" + " MemoLogicalDelegator [groupId: 0]\n", + group._logicalNodes.at(0)); + } +} + +TEST(LogicalRewriter, FilterProjectRewrite) { + using namespace properties; + PrefixId prefixId; + + ABT scanNode = make<ScanNode>("ptest", "test"); + ABT collationNode = make<CollationNode>( + CollationRequirement({{"ptest", CollationOp::Ascending}}), std::move(scanNode)); + ABT evalNode = + make<EvaluationNode>("P1", + make<EvalPath>(make<PathIdentity>(), make<Variable>("ptest")), + std::move(collationNode)); + ABT filterNode = make<FilterNode>(make<EvalFilter>(make<PathIdentity>(), make<Variable>("P1")), + std::move(evalNode)); + + ABT rootNode = make<RootNode>(properties::ProjectionRequirement{{}}, std::move(filterNode)); + + ASSERT_EXPLAIN( + "Root []\n" + " projections: \n" + " RefBlock: \n" + " Filter []\n" + " EvalFilter []\n" + " PathIdentity []\n" + " Variable [P1]\n" + " Evaluation []\n" + " BindBlock:\n" + " [P1]\n" + " EvalPath []\n" + " PathIdentity []\n" + " Variable [ptest]\n" + " Collation []\n" + " collation: \n" + " ptest: Ascending\n" + " RefBlock: \n" + " Variable [ptest]\n" + " Scan [test]\n" + " BindBlock:\n" + " [ptest]\n" + " Source []\n", + rootNode); + + OptPhaseManager phaseManager({OptPhaseManager::OptPhase::MemoSubstitutionPhase}, + prefixId, + {{{"test", {{}, {}}}}}, + DebugInfo::kDefaultForTests); + ABT latest = std::move(rootNode); + ASSERT_TRUE(phaseManager.optimize(latest)); + + ASSERT_EXPLAIN( + "Root []\n" + " projections: \n" + " RefBlock: \n" + " Collation []\n" + " collation: \n" + " ptest: Ascending\n" + " RefBlock: \n" + " Variable [ptest]\n" + " Filter []\n" + " EvalFilter []\n" + " PathIdentity []\n" + " Variable [P1]\n" + " Evaluation []\n" + " BindBlock:\n" + " [P1]\n" + " EvalPath []\n" + " PathIdentity []\n" + " Variable [ptest]\n" + " Scan [test]\n" + " BindBlock:\n" + " [ptest]\n" + " Source []\n", + latest); +} + +TEST(LogicalRewriter, FilterProjectComplexRewrite) { + using namespace properties; + PrefixId prefixId; + + ABT scanNode = make<ScanNode>("ptest", "test"); + + ABT projection2Node = make<EvaluationNode>( + "p2", make<EvalPath>(make<PathIdentity>(), make<Variable>("ptest")), std::move(scanNode)); + + ABT projection3Node = + make<EvaluationNode>("p3", + make<EvalPath>(make<PathIdentity>(), make<Variable>("ptest")), + std::move(projection2Node)); + + ABT collationNode = make<CollationNode>( + CollationRequirement({{"ptest", CollationOp::Ascending}}), std::move(projection3Node)); + + ABT projection1Node = + make<EvaluationNode>("p1", + make<EvalPath>(make<PathIdentity>(), make<Variable>("ptest")), + std::move(collationNode)); + + ABT filter1Node = make<FilterNode>(make<EvalFilter>(make<PathIdentity>(), make<Variable>("p1")), + std::move(projection1Node)); + + ABT filterScanNode = make<FilterNode>( + make<EvalFilter>(make<PathIdentity>(), make<Variable>("ptest")), std::move(filter1Node)); + + ABT filter2Node = make<FilterNode>(make<EvalFilter>(make<PathIdentity>(), make<Variable>("p2")), + std::move(filterScanNode)); + + ABT rootNode = make<RootNode>(properties::ProjectionRequirement{{}}, std::move(filter2Node)); + + ASSERT_EXPLAIN_V2( + "Root []\n" 
+ "| | projections: \n" + "| RefBlock: \n" + "Filter []\n" + "| EvalFilter []\n" + "| | Variable [p2]\n" + "| PathIdentity []\n" + "Filter []\n" + "| EvalFilter []\n" + "| | Variable [ptest]\n" + "| PathIdentity []\n" + "Filter []\n" + "| EvalFilter []\n" + "| | Variable [p1]\n" + "| PathIdentity []\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [p1]\n" + "| EvalPath []\n" + "| | Variable [ptest]\n" + "| PathIdentity []\n" + "Collation []\n" + "| | collation: \n" + "| | ptest: Ascending\n" + "| RefBlock: \n" + "| Variable [ptest]\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [p3]\n" + "| EvalPath []\n" + "| | Variable [ptest]\n" + "| PathIdentity []\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [p2]\n" + "| EvalPath []\n" + "| | Variable [ptest]\n" + "| PathIdentity []\n" + "Scan [test]\n" + " BindBlock:\n" + " [ptest]\n" + " Source []\n", + rootNode); + + OptPhaseManager phaseManager({OptPhaseManager::OptPhase::MemoSubstitutionPhase}, + prefixId, + {{{"test", {{}, {}}}}}, + DebugInfo::kDefaultForTests); + ABT latest = std::move(rootNode); + ASSERT_TRUE(phaseManager.optimize(latest)); + + // Note: this assert depends on the order on which we consider rewrites. + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| RefBlock: \n" + "Collation []\n" + "| | collation: \n" + "| | ptest: Ascending\n" + "| RefBlock: \n" + "| Variable [ptest]\n" + "Filter []\n" + "| EvalFilter []\n" + "| | Variable [p2]\n" + "| PathIdentity []\n" + "Filter []\n" + "| EvalFilter []\n" + "| | Variable [ptest]\n" + "| PathIdentity []\n" + "Filter []\n" + "| EvalFilter []\n" + "| | Variable [p1]\n" + "| PathIdentity []\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [p1]\n" + "| EvalPath []\n" + "| | Variable [ptest]\n" + "| PathIdentity []\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [p3]\n" + "| EvalPath []\n" + "| | Variable [ptest]\n" + "| PathIdentity []\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [p2]\n" + "| EvalPath []\n" + "| | Variable [ptest]\n" + "| PathIdentity []\n" + "Scan [test]\n" + " BindBlock:\n" + " [ptest]\n" + " Source []\n", + latest); +} + +TEST(LogicalRewriter, FilterProjectGroupRewrite) { + using namespace properties; + PrefixId prefixId; + + ABT scanNode = make<ScanNode>("ptest", "test"); + + ABT projectionANode = make<EvaluationNode>( + "a", make<EvalPath>(make<PathIdentity>(), make<Variable>("ptest")), std::move(scanNode)); + ABT projectionBNode = + make<EvaluationNode>("b", + make<EvalPath>(make<PathIdentity>(), make<Variable>("ptest")), + std::move(projectionANode)); + + ABT groupByNode = make<GroupByNode>(ProjectionNameVector{"a"}, + ProjectionNameVector{"c"}, + makeSeq(make<Variable>("b")), + std::move(projectionBNode)); + + ABT filterANode = make<FilterNode>(make<EvalFilter>(make<PathIdentity>(), make<Variable>("a")), + std::move(groupByNode)); + + ABT rootNode = make<RootNode>(properties::ProjectionRequirement{ProjectionNameVector{"c"}}, + std::move(filterANode)); + + OptPhaseManager phaseManager({OptPhaseManager::OptPhase::MemoSubstitutionPhase}, + prefixId, + {{{"test", {{}, {}}}}}, + DebugInfo::kDefaultForTests); + ABT latest = std::move(rootNode); + ASSERT_TRUE(phaseManager.optimize(latest)); + + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | c\n" + "| RefBlock: \n" + "| Variable [c]\n" + "GroupBy []\n" + "| | groupings: \n" + "| | RefBlock: \n" + "| | Variable [a]\n" + "| aggregations: \n" + "| [c]\n" + "| Variable [b]\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [b]\n" + "| EvalPath []\n" + "| | Variable [ptest]\n" + "| PathIdentity []\n" + 
"Filter []\n" + "| EvalFilter []\n" + "| | Variable [a]\n" + "| PathIdentity []\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [a]\n" + "| EvalPath []\n" + "| | Variable [ptest]\n" + "| PathIdentity []\n" + "Scan [test]\n" + " BindBlock:\n" + " [ptest]\n" + " Source []\n", + latest); +} + +TEST(LogicalRewriter, FilterProjectUnwindRewrite) { + using namespace properties; + PrefixId prefixId; + + ABT scanNode = make<ScanNode>("ptest", "test"); + + ABT projectionANode = make<EvaluationNode>( + "a", make<EvalPath>(make<PathIdentity>(), make<Variable>("ptest")), std::move(scanNode)); + ABT projectionBNode = + make<EvaluationNode>("b", + make<EvalPath>(make<PathIdentity>(), make<Variable>("ptest")), + std::move(projectionANode)); + + ABT unwindNode = + make<UnwindNode>("a", "a_pid", false /*retainNonArrays*/, std::move(projectionBNode)); + + // This filter should stay above the unwind. + ABT filterANode = make<FilterNode>(make<EvalFilter>(make<PathIdentity>(), make<Variable>("a")), + std::move(unwindNode)); + + // This filter should be pushed down below the unwind. + ABT filterBNode = make<FilterNode>(make<EvalFilter>(make<PathIdentity>(), make<Variable>("b")), + std::move(filterANode)); + + ABT rootNode = make<RootNode>(properties::ProjectionRequirement{ProjectionNameVector{"a", "b"}}, + std::move(filterBNode)); + + OptPhaseManager phaseManager({OptPhaseManager::OptPhase::MemoSubstitutionPhase}, + prefixId, + {{{"test", {{}, {}}}}}, + DebugInfo::kDefaultForTests); + ABT latest = std::move(rootNode); + ASSERT_TRUE(phaseManager.optimize(latest)); + + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | a\n" + "| | b\n" + "| RefBlock: \n" + "| Variable [a]\n" + "| Variable [b]\n" + "Filter []\n" + "| EvalFilter []\n" + "| | Variable [b]\n" + "| PathIdentity []\n" + "Filter []\n" + "| EvalFilter []\n" + "| | Variable [a]\n" + "| PathIdentity []\n" + "Unwind []\n" + "| BindBlock:\n" + "| [a]\n" + "| Source []\n" + "| [a_pid]\n" + "| Source []\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [b]\n" + "| EvalPath []\n" + "| | Variable [ptest]\n" + "| PathIdentity []\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [a]\n" + "| EvalPath []\n" + "| | Variable [ptest]\n" + "| PathIdentity []\n" + "Scan [test]\n" + " BindBlock:\n" + " [ptest]\n" + " Source []\n", + latest); +} + +TEST(LogicalRewriter, FilterProjectExchangeRewrite) { + using namespace properties; + PrefixId prefixId; + + ABT scanNode = make<ScanNode>("ptest", "test"); + + ABT projectionANode = make<EvaluationNode>( + "a", make<EvalPath>(make<PathIdentity>(), make<Variable>("ptest")), std::move(scanNode)); + ABT projectionBNode = + make<EvaluationNode>("b", + make<EvalPath>(make<PathIdentity>(), make<Variable>("ptest")), + std::move(projectionANode)); + + ABT exchangeNode = make<ExchangeNode>( + properties::DistributionRequirement({DistributionType::HashPartitioning, {"a"}}), + std::move(projectionBNode)); + + ABT filterANode = make<FilterNode>(make<EvalFilter>(make<PathIdentity>(), make<Variable>("a")), + std::move(exchangeNode)); + + ABT rootNode = make<RootNode>(properties::ProjectionRequirement{ProjectionNameVector{"a", "b"}}, + std::move(filterANode)); + + OptPhaseManager phaseManager({OptPhaseManager::OptPhase::MemoSubstitutionPhase}, + prefixId, + {{{"test", {{}, {}}}}}, + DebugInfo::kDefaultForTests); + ABT latest = std::move(rootNode); + ASSERT_TRUE(phaseManager.optimize(latest)); + + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | a\n" + "| | b\n" + "| RefBlock: \n" + "| Variable [a]\n" + "| Variable [b]\n" + 
"Evaluation []\n" + "| BindBlock:\n" + "| [b]\n" + "| EvalPath []\n" + "| | Variable [ptest]\n" + "| PathIdentity []\n" + "Exchange []\n" + "| | distribution: \n" + "| | type: HashPartitioning\n" + "| | projections: \n" + "| | a\n" + "| RefBlock: \n" + "| Variable [a]\n" + "Filter []\n" + "| EvalFilter []\n" + "| | Variable [a]\n" + "| PathIdentity []\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [a]\n" + "| EvalPath []\n" + "| | Variable [ptest]\n" + "| PathIdentity []\n" + "Scan [test]\n" + " BindBlock:\n" + " [ptest]\n" + " Source []\n", + latest); +} + +TEST(LogicalRewriter, UnwindCollationRewrite) { + using namespace properties; + PrefixId prefixId; + + ABT scanNode = make<ScanNode>("ptest", "test"); + + ABT projectionANode = make<EvaluationNode>( + "a", make<EvalPath>(make<PathIdentity>(), make<Variable>("ptest")), std::move(scanNode)); + ABT projectionBNode = + make<EvaluationNode>("b", + make<EvalPath>(make<PathIdentity>(), make<Variable>("ptest")), + std::move(projectionANode)); + + // This collation node should stay below the unwind. + ABT collationANode = make<CollationNode>(CollationRequirement({{"a", CollationOp::Ascending}}), + std::move(projectionBNode)); + + // This collation node should go above the unwind. + ABT collationBNode = make<CollationNode>(CollationRequirement({{"b", CollationOp::Ascending}}), + std::move(collationANode)); + + ABT unwindNode = + make<UnwindNode>("a", "a_pid", false /*retainNonArrays*/, std::move(collationBNode)); + + ABT rootNode = make<RootNode>(properties::ProjectionRequirement{ProjectionNameVector{"a", "b"}}, + std::move(unwindNode)); + + OptPhaseManager phaseManager({OptPhaseManager::OptPhase::MemoSubstitutionPhase}, + prefixId, + {{{"test", {{}, {}}}}}, + DebugInfo::kDefaultForTests); + ABT latest = std::move(rootNode); + ASSERT_TRUE(phaseManager.optimize(latest)); + + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | a\n" + "| | b\n" + "| RefBlock: \n" + "| Variable [a]\n" + "| Variable [b]\n" + "Collation []\n" + "| | collation: \n" + "| | b: Ascending\n" + "| RefBlock: \n" + "| Variable [b]\n" + "Unwind []\n" + "| BindBlock:\n" + "| [a]\n" + "| Source []\n" + "| [a_pid]\n" + "| Source []\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [b]\n" + "| EvalPath []\n" + "| | Variable [ptest]\n" + "| PathIdentity []\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [a]\n" + "| EvalPath []\n" + "| | Variable [ptest]\n" + "| PathIdentity []\n" + "Scan [test]\n" + " BindBlock:\n" + " [ptest]\n" + " Source []\n", + latest); +} + +TEST(LogicalRewriter, FilterUnionReorderSingleProjection) { + PrefixId prefixId; + ABT scanNode1 = make<ScanNode>("ptest1", "test1"); + ABT scanNode2 = make<ScanNode>("ptest2", "test2"); + // Create two eval nodes such that the two branches of the union share a projection. 
+ ABT evalNode1 = + make<EvaluationNode>("pUnion", + make<EvalPath>(make<PathIdentity>(), make<Variable>("ptest1")), + std::move(scanNode1)); + ABT evalNode2 = + make<EvaluationNode>("pUnion", + make<EvalPath>(make<PathIdentity>(), make<Variable>("ptest2")), + std::move(scanNode2)); + + ABT unionNode = make<UnionNode>(ProjectionNameVector{"pUnion"}, makeSeq(evalNode1, evalNode2)); + + ABT filter = make<FilterNode>( + make<EvalFilter>( + make<PathGet>( + "a", make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(1)))), + make<Variable>("pUnion")), + std::move(unionNode)); + ABT rootNode = make<RootNode>(properties::ProjectionRequirement{ProjectionNameVector{"pUnion"}}, + std::move(filter)); + + ABT latest = std::move(rootNode); + + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | pUnion\n" + "| RefBlock: \n" + "| Variable [pUnion]\n" + "Filter []\n" + "| EvalFilter []\n" + "| | Variable [pUnion]\n" + "| PathGet [a]\n" + "| PathTraverse []\n" + "| PathCompare [Eq]\n" + "| Const [1]\n" + "Union []\n" + "| | BindBlock:\n" + "| | [pUnion]\n" + "| | Source []\n" + "| Evaluation []\n" + "| | BindBlock:\n" + "| | [pUnion]\n" + "| | EvalPath []\n" + "| | | Variable [ptest2]\n" + "| | PathIdentity []\n" + "| Scan [test2]\n" + "| BindBlock:\n" + "| [ptest2]\n" + "| Source []\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [pUnion]\n" + "| EvalPath []\n" + "| | Variable [ptest1]\n" + "| PathIdentity []\n" + "Scan [test1]\n" + " BindBlock:\n" + " [ptest1]\n" + " Source []\n", + latest); + + OptPhaseManager phaseManager({OptPhaseManager::OptPhase::MemoSubstitutionPhase, + OptPhaseManager::OptPhase::MemoExplorationPhase}, + prefixId, + {{{"test1", {{}, {}}}, {"test2", {{}, {}}}}}, + DebugInfo::kDefaultForTests); + ASSERT_TRUE(phaseManager.optimize(latest)); + + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | pUnion\n" + "| RefBlock: \n" + "| Variable [pUnion]\n" + "Union []\n" + "| | BindBlock:\n" + "| | [pUnion]\n" + "| | Source []\n" + "| Filter []\n" + "| | EvalFilter []\n" + "| | | Variable [pUnion]\n" + "| | PathGet [a]\n" + "| | PathTraverse []\n" + "| | PathCompare [Eq]\n" + "| | Const [1]\n" + "| Evaluation []\n" + "| | BindBlock:\n" + "| | [pUnion]\n" + "| | EvalPath []\n" + "| | | Variable [ptest2]\n" + "| | PathIdentity []\n" + "| Scan [test2]\n" + "| BindBlock:\n" + "| [ptest2]\n" + "| Source []\n" + "Filter []\n" + "| EvalFilter []\n" + "| | Variable [pUnion]\n" + "| PathGet [a]\n" + "| PathTraverse []\n" + "| PathCompare [Eq]\n" + "| Const [1]\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [pUnion]\n" + "| EvalPath []\n" + "| | Variable [ptest1]\n" + "| PathIdentity []\n" + "Scan [test1]\n" + " BindBlock:\n" + " [ptest1]\n" + " Source []\n", + latest); +} + +TEST(LogicalRewriter, MultipleFilterUnionReorder) { + PrefixId prefixId; + ABT scanNode1 = make<ScanNode>("ptest1", "test1"); + ABT scanNode2 = make<ScanNode>("ptest2", "test2"); + + // Create multiple shared projections for each child. 
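+ // Both filters (one per shared projection) should be duplicated below the Union, landing above the
+ // corresponding Evaluation node in each child.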
+ ABT pUnion11 = + make<EvaluationNode>("pUnion1", + make<EvalPath>(make<PathIdentity>(), make<Variable>("ptest1")), + std::move(scanNode1)); + ABT pUnion12 = + make<EvaluationNode>("pUnion2", + make<EvalPath>(make<PathIdentity>(), make<Variable>("ptest1")), + std::move(pUnion11)); + + ABT pUnion21 = + make<EvaluationNode>("pUnion1", + make<EvalPath>(make<PathIdentity>(), make<Variable>("ptest2")), + std::move(scanNode2)); + ABT pUnion22 = + make<EvaluationNode>("pUnion2", + make<EvalPath>(make<PathIdentity>(), make<Variable>("ptest2")), + std::move(pUnion21)); + + ABT unionNode = + make<UnionNode>(ProjectionNameVector{"pUnion1", "pUnion2"}, makeSeq(pUnion12, pUnion22)); + + // Create two filters, one for each of the two common projections. + ABT filterUnion1 = make<FilterNode>( + make<EvalFilter>( + make<PathGet>( + "a", make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(1)))), + make<Variable>("pUnion1")), + std::move(unionNode)); + ABT filterUnion2 = make<FilterNode>( + make<EvalFilter>( + make<PathGet>( + "a", make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(1)))), + make<Variable>("pUnion2")), + std::move(filterUnion1)); + ABT rootNode = make<RootNode>( + properties::ProjectionRequirement{ProjectionNameVector{"pUnion1", "pUnion2"}}, + std::move(filterUnion2)); + + ABT latest = std::move(rootNode); + + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | pUnion1\n" + "| | pUnion2\n" + "| RefBlock: \n" + "| Variable [pUnion1]\n" + "| Variable [pUnion2]\n" + "Filter []\n" + "| EvalFilter []\n" + "| | Variable [pUnion2]\n" + "| PathGet [a]\n" + "| PathTraverse []\n" + "| PathCompare [Eq]\n" + "| Const [1]\n" + "Filter []\n" + "| EvalFilter []\n" + "| | Variable [pUnion1]\n" + "| PathGet [a]\n" + "| PathTraverse []\n" + "| PathCompare [Eq]\n" + "| Const [1]\n" + "Union []\n" + "| | BindBlock:\n" + "| | [pUnion1]\n" + "| | Source []\n" + "| | [pUnion2]\n" + "| | Source []\n" + "| Evaluation []\n" + "| | BindBlock:\n" + "| | [pUnion2]\n" + "| | EvalPath []\n" + "| | | Variable [ptest2]\n" + "| | PathIdentity []\n" + "| Evaluation []\n" + "| | BindBlock:\n" + "| | [pUnion1]\n" + "| | EvalPath []\n" + "| | | Variable [ptest2]\n" + "| | PathIdentity []\n" + "| Scan [test2]\n" + "| BindBlock:\n" + "| [ptest2]\n" + "| Source []\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [pUnion2]\n" + "| EvalPath []\n" + "| | Variable [ptest1]\n" + "| PathIdentity []\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [pUnion1]\n" + "| EvalPath []\n" + "| | Variable [ptest1]\n" + "| PathIdentity []\n" + "Scan [test1]\n" + " BindBlock:\n" + " [ptest1]\n" + " Source []\n", + latest); + + OptPhaseManager phaseManager({OptPhaseManager::OptPhase::MemoSubstitutionPhase, + OptPhaseManager::OptPhase::MemoExplorationPhase}, + prefixId, + {{{"test1", {{}, {}}}, {"test2", {{}, {}}}}}, + DebugInfo::kDefaultForTests); + ASSERT_TRUE(phaseManager.optimize(latest)); + + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | pUnion1\n" + "| | pUnion2\n" + "| RefBlock: \n" + "| Variable [pUnion1]\n" + "| Variable [pUnion2]\n" + "Union []\n" + "| | BindBlock:\n" + "| | [pUnion1]\n" + "| | Source []\n" + "| | [pUnion2]\n" + "| | Source []\n" + "| Filter []\n" + "| | EvalFilter []\n" + "| | | Variable [pUnion2]\n" + "| | PathGet [a]\n" + "| | PathTraverse []\n" + "| | PathCompare [Eq]\n" + "| | Const [1]\n" + "| Evaluation []\n" + "| | BindBlock:\n" + "| | [pUnion2]\n" + "| | EvalPath []\n" + "| | | Variable [ptest2]\n" + "| | PathIdentity []\n" + "| Filter []\n" + "| | 
EvalFilter []\n" + "| | | Variable [pUnion1]\n" + "| | PathGet [a]\n" + "| | PathTraverse []\n" + "| | PathCompare [Eq]\n" + "| | Const [1]\n" + "| Evaluation []\n" + "| | BindBlock:\n" + "| | [pUnion1]\n" + "| | EvalPath []\n" + "| | | Variable [ptest2]\n" + "| | PathIdentity []\n" + "| Scan [test2]\n" + "| BindBlock:\n" + "| [ptest2]\n" + "| Source []\n" + "Filter []\n" + "| EvalFilter []\n" + "| | Variable [pUnion2]\n" + "| PathGet [a]\n" + "| PathTraverse []\n" + "| PathCompare [Eq]\n" + "| Const [1]\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [pUnion2]\n" + "| EvalPath []\n" + "| | Variable [ptest1]\n" + "| PathIdentity []\n" + "Filter []\n" + "| EvalFilter []\n" + "| | Variable [pUnion1]\n" + "| PathGet [a]\n" + "| PathTraverse []\n" + "| PathCompare [Eq]\n" + "| Const [1]\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [pUnion1]\n" + "| EvalPath []\n" + "| | Variable [ptest1]\n" + "| PathIdentity []\n" + "Scan [test1]\n" + " BindBlock:\n" + " [ptest1]\n" + " Source []\n", + latest); +} + +TEST(LogicalRewriter, FilterUnionUnionPushdown) { + PrefixId prefixId; + ABT scanNode1 = make<ScanNode>("ptest", "test1"); + ABT scanNode2 = make<ScanNode>("ptest", "test2"); + ABT unionNode = make<UnionNode>(ProjectionNameVector{"ptest"}, makeSeq(scanNode1, scanNode2)); + + ABT scanNode3 = make<ScanNode>("ptest", "test3"); + ABT parentUnionNode = + make<UnionNode>(ProjectionNameVector{"ptest"}, makeSeq(unionNode, scanNode3)); + + ABT filter = make<FilterNode>( + make<EvalFilter>( + make<PathGet>( + "a", make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(1)))), + make<Variable>("ptest")), + std::move(parentUnionNode)); + ABT rootNode = make<RootNode>(properties::ProjectionRequirement{ProjectionNameVector{"ptest"}}, + std::move(filter)); + + OptPhaseManager phaseManager({OptPhaseManager::OptPhase::MemoSubstitutionPhase}, + prefixId, + {{{"test1", {{}, {}}}, {"test2", {{}, {}}}, {"test3", {{}, {}}}}}, + DebugInfo::kDefaultForTests); + ABT latest = std::move(rootNode); + + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | ptest\n" + "| RefBlock: \n" + "| Variable [ptest]\n" + "Filter []\n" + "| EvalFilter []\n" + "| | Variable [ptest]\n" + "| PathGet [a]\n" + "| PathTraverse []\n" + "| PathCompare [Eq]\n" + "| Const [1]\n" + "Union []\n" + "| | BindBlock:\n" + "| | [ptest]\n" + "| | Source []\n" + "| Scan [test3]\n" + "| BindBlock:\n" + "| [ptest]\n" + "| Source []\n" + "Union []\n" + "| | BindBlock:\n" + "| | [ptest]\n" + "| | Source []\n" + "| Scan [test2]\n" + "| BindBlock:\n" + "| [ptest]\n" + "| Source []\n" + "Scan [test1]\n" + " BindBlock:\n" + " [ptest]\n" + " Source []\n", + latest); + + ASSERT_TRUE(phaseManager.optimize(latest)); + + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | ptest\n" + "| RefBlock: \n" + "| Variable [ptest]\n" + "Union []\n" + "| | BindBlock:\n" + "| | [ptest]\n" + "| | Source []\n" + "| Sargable [Complete]\n" + "| | | | | requirementsMap: \n" + "| | | | | refProjection: ptest, path: 'PathGet [a] PathTraverse [] " + "PathIdentity []', intervals: {{{[Const [1], Const [1]]}}}\n" + "| | | | candidateIndexes: \n" + "| | | BindBlock:\n" + "| | RefBlock: \n" + "| | Variable [ptest]\n" + "| Scan [test3]\n" + "| BindBlock:\n" + "| [ptest]\n" + "| Source []\n" + "Union []\n" + "| | BindBlock:\n" + "| | [ptest]\n" + "| | Source []\n" + "| Sargable [Complete]\n" + "| | | | | requirementsMap: \n" + "| | | | | refProjection: ptest, path: 'PathGet [a] PathTraverse [] " + "PathIdentity []', intervals: {{{[Const [1], Const [1]]}}}\n" + 
"| | | | candidateIndexes: \n" + "| | | BindBlock:\n" + "| | RefBlock: \n" + "| | Variable [ptest]\n" + "| Scan [test2]\n" + "| BindBlock:\n" + "| [ptest]\n" + "| Source []\n" + "Sargable [Complete]\n" + "| | | | requirementsMap: \n" + "| | | | refProjection: ptest, path: 'PathGet [a] PathTraverse [] PathIdentity " + "[]', intervals: {{{[Const [1], Const [1]]}}}\n" + "| | | candidateIndexes: \n" + "| | BindBlock:\n" + "| RefBlock: \n" + "| Variable [ptest]\n" + "Scan [test1]\n" + " BindBlock:\n" + " [ptest]\n" + " Source []\n", + latest); +} + +TEST(LogicalRewriter, UnionPreservesCommonLogicalProps) { + ABT scanNode1 = make<ScanNode>("ptest1", "test1"); + ABT scanNode2 = make<ScanNode>("ptest2", "test2"); + ABT evalNode1 = make<EvaluationNode>( + "a", + make<EvalPath>(make<PathGet>("a", make<PathIdentity>()), make<Variable>("ptest1")), + std::move(scanNode1)); + + ABT evalNode2 = make<EvaluationNode>( + "a", + make<EvalPath>(make<PathGet>("a", make<PathIdentity>()), make<Variable>("ptest2")), + std::move(scanNode2)); + ABT unionNode = make<UnionNode>(ProjectionNameVector{"a"}, makeSeq(evalNode1, evalNode2)); + + ABT rootNode = make<RootNode>(properties::ProjectionRequirement{ProjectionNameVector{"a"}}, + std::move(unionNode)); + + Metadata metadata{{{"test1", + ScanDefinition{{}, + {}, + {DistributionType::HashPartitioning, + makeSeq(make<PathGet>("a", make<PathIdentity>()))}}}, + {"test2", + ScanDefinition{{}, + {}, + {DistributionType::HashPartitioning, + makeSeq(make<PathGet>("a", make<PathIdentity>()))}}}}, + 2}; + + // Run the reordering rewrite such that the scan produces a hash partition. + PrefixId prefixId; + OptPhaseManager phaseManager({OptPhaseManager::OptPhase::MemoSubstitutionPhase, + OptPhaseManager::OptPhase::MemoExplorationPhase}, + prefixId, + metadata, + DebugInfo::kDefaultForTests); + + ABT optimized = rootNode; + ASSERT_TRUE(phaseManager.optimize(optimized)); + + ASSERT_EXPLAIN_MEMO( + "Memo: \n" + " groupId: 0\n" + " | | Logical properties:\n" + " | | cardinalityEstimate: \n" + " | | ce: 1000\n" + " | | projections: \n" + " | | ptest1\n" + " | | indexingAvailability: \n" + " | | [groupId: 0, scanProjection: ptest1, scanDefName: test1, " + "possiblyEqPredsOnly]\n" + " | | collectionAvailability: \n" + " | | test1\n" + " | | distributionAvailability: \n" + " | | distribution: \n" + " | | type: UnknownPartitioning\n" + " | logicalNodes: \n" + " | logicalNodeId: 0\n" + " | Scan [test1]\n" + " | BindBlock:\n" + " | [ptest1]\n" + " | Source []\n" + " physicalNodes: \n" + " groupId: 1\n" + " | | Logical properties:\n" + " | | cardinalityEstimate: \n" + " | | ce: 1000\n" + " | | requirementCEs: \n" + " | | refProjection: ptest1, path: 'PathGet [a] PathIdentity []', ce: " + "1000\n" + " | | projections: \n" + " | | a\n" + " | | ptest1\n" + " | | indexingAvailability: \n" + " | | [groupId: 0, scanProjection: ptest1, scanDefName: test1]\n" + " | | collectionAvailability: \n" + " | | test1\n" + " | | distributionAvailability: \n" + " | | distribution: \n" + " | | type: Centralized\n" + " | | distribution: \n" + " | | type: RoundRobin\n" + " | | distribution: \n" + " | | type: HashPartitioning\n" + " | | projections: \n" + " | | a\n" + " | | distribution: \n" + " | | type: UnknownPartitioning\n" + " | logicalNodes: \n" + " | logicalNodeId: 0\n" + " | Sargable [Complete]\n" + " | | | | | requirementsMap: \n" + " | | | | | refProjection: ptest1, path: 'PathGet [a] " + "PathIdentity []', boundProjection: a, intervals: {{{(-inf, +inf)}}}\n" + " | | | | candidateIndexes: \n" + " | | | 
BindBlock:\n" + " | | | [a]\n" + " | | | Source []\n" + " | | RefBlock: \n" + " | | Variable [ptest1]\n" + " | MemoLogicalDelegator [groupId: 0]\n" + " physicalNodes: \n" + " groupId: 2\n" + " | | Logical properties:\n" + " | | cardinalityEstimate: \n" + " | | ce: 1000\n" + " | | projections: \n" + " | | ptest2\n" + " | | indexingAvailability: \n" + " | | [groupId: 2, scanProjection: ptest2, scanDefName: test2, " + "possiblyEqPredsOnly]\n" + " | | collectionAvailability: \n" + " | | test2\n" + " | | distributionAvailability: \n" + " | | distribution: \n" + " | | type: UnknownPartitioning\n" + " | logicalNodes: \n" + " | logicalNodeId: 0\n" + " | Scan [test2]\n" + " | BindBlock:\n" + " | [ptest2]\n" + " | Source []\n" + " physicalNodes: \n" + " groupId: 3\n" + " | | Logical properties:\n" + " | | cardinalityEstimate: \n" + " | | ce: 1000\n" + " | | requirementCEs: \n" + " | | refProjection: ptest2, path: 'PathGet [a] PathIdentity []', ce: " + "1000\n" + " | | projections: \n" + " | | a\n" + " | | ptest2\n" + " | | indexingAvailability: \n" + " | | [groupId: 2, scanProjection: ptest2, scanDefName: test2]\n" + " | | collectionAvailability: \n" + " | | test2\n" + " | | distributionAvailability: \n" + " | | distribution: \n" + " | | type: Centralized\n" + " | | distribution: \n" + " | | type: RoundRobin\n" + " | | distribution: \n" + " | | type: HashPartitioning\n" + " | | projections: \n" + " | | a\n" + " | | distribution: \n" + " | | type: UnknownPartitioning\n" + " | logicalNodes: \n" + " | logicalNodeId: 0\n" + " | Sargable [Complete]\n" + " | | | | | requirementsMap: \n" + " | | | | | refProjection: ptest2, path: 'PathGet [a] " + "PathIdentity []', boundProjection: a, intervals: {{{(-inf, +inf)}}}\n" + " | | | | candidateIndexes: \n" + " | | | BindBlock:\n" + " | | | [a]\n" + " | | | Source []\n" + " | | RefBlock: \n" + " | | Variable [ptest2]\n" + " | MemoLogicalDelegator [groupId: 2]\n" + " physicalNodes: \n" + " groupId: 4\n" + " | | Logical properties:\n" + " | | cardinalityEstimate: \n" + " | | ce: 2000\n" + " | | projections: \n" + " | | a\n" + " | | collectionAvailability: \n" + " | | test1\n" + " | | test2\n" + " | | distributionAvailability: \n" + " | | distribution: \n" + " | | type: Centralized\n" + " | | distribution: \n" + " | | type: RoundRobin\n" + " | | distribution: \n" + " | | type: HashPartitioning\n" + " | | projections: \n" + " | | a\n" + " | | distribution: \n" + " | | type: UnknownPartitioning\n" + " | logicalNodes: \n" + " | logicalNodeId: 0\n" + " | Union []\n" + " | | | BindBlock:\n" + " | | | [a]\n" + " | | | Source []\n" + " | | MemoLogicalDelegator [groupId: 3]\n" + " | MemoLogicalDelegator [groupId: 1]\n" + " physicalNodes: \n" + " groupId: 5\n" + " | | Logical properties:\n" + " | | cardinalityEstimate: \n" + " | | ce: 2000\n" + " | | projections: \n" + " | | a\n" + " | | collectionAvailability: \n" + " | | test1\n" + " | | test2\n" + " | | distributionAvailability: \n" + " | | distribution: \n" + " | | type: Centralized\n" + " | | distribution: \n" + " | | type: RoundRobin\n" + " | | distribution: \n" + " | | type: HashPartitioning\n" + " | | projections: \n" + " | | a\n" + " | | distribution: \n" + " | | type: UnknownPartitioning\n" + " | logicalNodes: \n" + " | logicalNodeId: 0\n" + " | Root []\n" + " | | | projections: \n" + " | | | a\n" + " | | RefBlock: \n" + " | | Variable [a]\n" + " | MemoLogicalDelegator [groupId: 4]\n" + " physicalNodes: \n", + phaseManager.getMemo()); +} + +TEST(LogicalRewriter, SargableCE) { + using namespace properties; + 
PrefixId prefixId; + + ABT scanNode = make<ScanNode>("ptest", "test"); + + ABT filterANode = make<FilterNode>( + make<EvalFilter>(make<PathGet>("a", make<PathCompare>(Operations::Eq, Constant::int64(1))), + make<Variable>("ptest")), + std::move(scanNode)); + ABT filterBNode = make<FilterNode>( + make<EvalFilter>(make<PathGet>("b", make<PathCompare>(Operations::Eq, Constant::int64(2))), + make<Variable>("ptest")), + std::move(filterANode)); + + ABT rootNode = make<RootNode>(properties::ProjectionRequirement{ProjectionNameVector{"ptest"}}, + std::move(filterBNode)); + + OptPhaseManager phaseManager({OptPhaseManager::OptPhase::MemoSubstitutionPhase, + OptPhaseManager::OptPhase::MemoExplorationPhase}, + prefixId, + {{{"test", {{}, {}}}}}, + DebugInfo::kDefaultForTests); + ABT latest = std::move(rootNode); + ASSERT_TRUE(phaseManager.optimize(latest)); + + // Displays SargableNode-specific per-key estimates. + ASSERT_EXPLAIN_MEMO( + "Memo: \n" + " groupId: 0\n" + " | | Logical properties:\n" + " | | cardinalityEstimate: \n" + " | | ce: 1000\n" + " | | projections: \n" + " | | ptest\n" + " | | indexingAvailability: \n" + " | | [groupId: 0, scanProjection: ptest, scanDefName: test, " + "possiblyEqPredsOnly]\n" + " | | collectionAvailability: \n" + " | | test\n" + " | | distributionAvailability: \n" + " | | distribution: \n" + " | | type: Centralized\n" + " | logicalNodes: \n" + " | logicalNodeId: 0\n" + " | Scan [test]\n" + " | BindBlock:\n" + " | [ptest]\n" + " | Source []\n" + " physicalNodes: \n" + " groupId: 1\n" + " | | Logical properties:\n" + " | | cardinalityEstimate: \n" + " | | ce: 10\n" + " | | requirementCEs: \n" + " | | refProjection: ptest, path: 'PathGet [a] PathIdentity []', ce: " + "100\n" + " | | refProjection: ptest, path: 'PathGet [b] PathIdentity []', ce: " + "100\n" + " | | projections: \n" + " | | ptest\n" + " | | indexingAvailability: \n" + " | | [groupId: 0, scanProjection: ptest, scanDefName: test, " + "possiblyEqPredsOnly]\n" + " | | collectionAvailability: \n" + " | | test\n" + " | | distributionAvailability: \n" + " | | distribution: \n" + " | | type: Centralized\n" + " | logicalNodes: \n" + " | logicalNodeId: 0\n" + " | Sargable [Complete]\n" + " | | | | | requirementsMap: \n" + " | | | | | refProjection: ptest, path: 'PathGet [a] PathIdentity " + "[]', intervals: {{{[Const [1], Const [1]]}}}\n" + " | | | | | refProjection: ptest, path: 'PathGet [b] PathIdentity " + "[]', intervals: {{{[Const [2], Const [2]]}}}\n" + " | | | | candidateIndexes: \n" + " | | | BindBlock:\n" + " | | RefBlock: \n" + " | | Variable [ptest]\n" + " | MemoLogicalDelegator [groupId: 0]\n" + " physicalNodes: \n" + " groupId: 2\n" + " | | Logical properties:\n" + " | | cardinalityEstimate: \n" + " | | ce: 10\n" + " | | projections: \n" + " | | ptest\n" + " | | indexingAvailability: \n" + " | | [groupId: 0, scanProjection: ptest, scanDefName: test, " + "possiblyEqPredsOnly]\n" + " | | collectionAvailability: \n" + " | | test\n" + " | | distributionAvailability: \n" + " | | distribution: \n" + " | | type: Centralized\n" + " | logicalNodes: \n" + " | logicalNodeId: 0\n" + " | Root []\n" + " | | | projections: \n" + " | | | ptest\n" + " | | RefBlock: \n" + " | | Variable [ptest]\n" + " | MemoLogicalDelegator [groupId: 1]\n" + " physicalNodes: \n", + phaseManager.getMemo()); +} + +} // namespace +} // namespace mongo::optimizer diff --git a/src/mongo/db/query/optimizer/metadata.cpp b/src/mongo/db/query/optimizer/metadata.cpp new file mode 100644 index 00000000000..d81cec8b64c --- /dev/null +++ 
b/src/mongo/db/query/optimizer/metadata.cpp @@ -0,0 +1,161 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include "mongo/db/query/optimizer/metadata.h" + +#include "mongo/db/query/optimizer/node.h" + + +namespace mongo::optimizer { + +DistributionAndPaths::DistributionAndPaths(DistributionType type) + : DistributionAndPaths(type, {}) {} + +DistributionAndPaths::DistributionAndPaths(DistributionType type, ABTVector paths) + : _type(type), _paths(std::move(paths)) { + uassert(6624080, + "Invalid distribution type", + _paths.empty() || _type == DistributionType::HashPartitioning || + _type == DistributionType::RangePartitioning); +} + +bool IndexCollationEntry::operator==(const IndexCollationEntry& other) const { + return _path == other._path && _op == other._op; +} + +IndexCollationEntry::IndexCollationEntry(ABT path, CollationOp op) + : _path(std::move(path)), _op(op) {} + +IndexDefinition::IndexDefinition(IndexCollationSpec collationSpec, bool isMultiKey) + : IndexDefinition(std::move(collationSpec), isMultiKey, {DistributionType::Centralized}, {}) {} + +IndexDefinition::IndexDefinition(IndexCollationSpec collationSpec, + bool isMultiKey, + DistributionAndPaths distributionAndPaths, + PartialSchemaRequirements partialReqMap) + : IndexDefinition(std::move(collationSpec), + 2 /*version*/, + 0 /*orderingBits*/, + isMultiKey, + std::move(distributionAndPaths), + std::move(partialReqMap)) {} + +IndexDefinition::IndexDefinition(IndexCollationSpec collationSpec, + int64_t version, + uint32_t orderingBits, + bool isMultiKey, + DistributionAndPaths distributionAndPaths, + PartialSchemaRequirements partialReqMap) + : _collationSpec(std::move(collationSpec)), + _version(version), + _orderingBits(orderingBits), + _isMultiKey(isMultiKey), + _distributionAndPaths(distributionAndPaths), + _partialReqMap(std::move(partialReqMap)) {} + +const IndexCollationSpec& IndexDefinition::getCollationSpec() const { + return _collationSpec; +} + +int64_t IndexDefinition::getVersion() const { + return _version; +} + +uint32_t IndexDefinition::getOrdering() const { + return _orderingBits; +} + +bool 
IndexDefinition::isMultiKey() const { + return _isMultiKey; +} + +const DistributionAndPaths& IndexDefinition::getDistributionAndPaths() const { + return _distributionAndPaths; +} + +const PartialSchemaRequirements& IndexDefinition::getPartialReqMap() const { + return _partialReqMap; +} + +ScanDefinition::ScanDefinition() : ScanDefinition(OptionsMapType{}, {}) {} + +ScanDefinition::ScanDefinition(OptionsMapType options, + opt::unordered_map<std::string, IndexDefinition> indexDefs) + : ScanDefinition(std::move(options), + std::move(indexDefs), + {DistributionType::Centralized}, + true /*exists*/) {} + +ScanDefinition::ScanDefinition(OptionsMapType options, + opt::unordered_map<std::string, IndexDefinition> indexDefs, + DistributionAndPaths distributionAndPaths, + const bool exists, + const CEType ce) + : _options(std::move(options)), + _distributionAndPaths(std::move(distributionAndPaths)), + _indexDefs(std::move(indexDefs)), + _exists(exists), + _ce(ce) {} + +const ScanDefinition::OptionsMapType& ScanDefinition::getOptionsMap() const { + return _options; +} + +const DistributionAndPaths& ScanDefinition::getDistributionAndPaths() const { + return _distributionAndPaths; +} + +const opt::unordered_map<std::string, IndexDefinition>& ScanDefinition::getIndexDefs() const { + return _indexDefs; +} + +opt::unordered_map<std::string, IndexDefinition>& ScanDefinition::getIndexDefs() { + return _indexDefs; +} + +bool ScanDefinition::exists() const { + return _exists; +} + +CEType ScanDefinition::getCE() const { + return _ce; +} + +Metadata::Metadata(opt::unordered_map<std::string, ScanDefinition> scanDefs) + : Metadata(std::move(scanDefs), 1 /*numberOfPartitions*/) {} + +Metadata::Metadata(opt::unordered_map<std::string, ScanDefinition> scanDefs, + size_t numberOfPartitions) + : _scanDefs(std::move(scanDefs)), _numberOfPartitions(numberOfPartitions) {} + +bool Metadata::isParallelExecution() const { + return _numberOfPartitions > 1; +} + +} // namespace mongo::optimizer diff --git a/src/mongo/db/query/optimizer/metadata.h b/src/mongo/db/query/optimizer/metadata.h new file mode 100644 index 00000000000..62f18327d40 --- /dev/null +++ b/src/mongo/db/query/optimizer/metadata.h @@ -0,0 +1,160 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. 
If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#pragma once + +#include <map> + +#include "mongo/db/query/optimizer/index_bounds.h" + + +namespace mongo::optimizer { + +struct DistributionAndPaths { + DistributionAndPaths(DistributionType type); + DistributionAndPaths(DistributionType type, ABTVector paths); + + DistributionType _type; + ABTVector _paths; +}; + + +struct IndexCollationEntry { + IndexCollationEntry(ABT path, CollationOp op); + + bool operator==(const IndexCollationEntry& other) const; + + ABT _path; + CollationOp _op; +}; + +using IndexCollationSpec = std::vector<IndexCollationEntry>; + +/** + * Defines an available system index. + */ +class IndexDefinition { +public: + // For testing. + IndexDefinition(IndexCollationSpec collationSpec, bool isMultiKey); + + IndexDefinition(IndexCollationSpec collationSpec, + bool isMultiKey, + DistributionAndPaths distributionAndPaths, + PartialSchemaRequirements partialReqMap); + + IndexDefinition(IndexCollationSpec collationSpec, + int64_t version, + uint32_t orderingBits, + bool isMultiKey, + DistributionAndPaths distributionAndPaths, + PartialSchemaRequirements partialReqMap); + + const IndexCollationSpec& getCollationSpec() const; + + int64_t getVersion() const; + uint32_t getOrdering() const; + bool isMultiKey() const; + + const DistributionAndPaths& getDistributionAndPaths() const; + + const PartialSchemaRequirements& getPartialReqMap() const; + +private: + const IndexCollationSpec _collationSpec; + + const int64_t _version; + const uint32_t _orderingBits; + const bool _isMultiKey; + + const DistributionAndPaths _distributionAndPaths; + + // Requirements map for partial filter expression. + const PartialSchemaRequirements _partialReqMap; +}; + +// Used to specify parameters to scan node, such as collection name, or file where collection is +// read from. +class ScanDefinition { +public: + using OptionsMapType = opt::unordered_map<std::string, std::string>; + + ScanDefinition(); + ScanDefinition(OptionsMapType options, + opt::unordered_map<std::string, IndexDefinition> indexDefs); + ScanDefinition(OptionsMapType options, + opt::unordered_map<std::string, IndexDefinition> indexDefs, + DistributionAndPaths distributionAndPaths, + bool exists = true, + CEType ce = -1.0); + + const OptionsMapType& getOptionsMap() const; + + const DistributionAndPaths& getDistributionAndPaths() const; + + const opt::unordered_map<std::string, IndexDefinition>& getIndexDefs() const; + opt::unordered_map<std::string, IndexDefinition>& getIndexDefs(); + + bool exists() const; + + CEType getCE() const; + +private: + OptionsMapType _options; + DistributionAndPaths _distributionAndPaths; + + /** + * Indexes associated with this collection. + */ + opt::unordered_map<std::string, IndexDefinition> _indexDefs; + + /** + * True if the collection exists. + */ + bool _exists; + + // If positive, estimated number of docs in the collection. + CEType _ce; +}; + +struct Metadata { + Metadata(opt::unordered_map<std::string, ScanDefinition> scanDefs); + Metadata(opt::unordered_map<std::string, ScanDefinition> scanDefs, size_t numberOfPartitions); + + opt::unordered_map<std::string, ScanDefinition> _scanDefs; + + // Degree of parallelism. + size_t _numberOfPartitions; + + bool isParallelExecution() const; + + // TODO: generalize cluster spec. 
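+
+    // Illustrative construction, mirroring the unit tests: a single scan definition named
+    // "test" with no options and no indexes, optimized on one partition.
+    //   Metadata metadata{{{"test", ScanDefinition{{}, {}}}}};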
+}; + +} // namespace mongo::optimizer diff --git a/src/mongo/db/query/optimizer/node.cpp b/src/mongo/db/query/optimizer/node.cpp new file mode 100644 index 00000000000..b3d656930c3 --- /dev/null +++ b/src/mongo/db/query/optimizer/node.cpp @@ -0,0 +1,775 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include <functional> +#include <stack> + +#include "mongo/db/query/optimizer/node.h" +#include "mongo/db/query/optimizer/utils/utils.h" + +namespace mongo::optimizer { + +/** + * A simple helper that creates a vector of Sources and binds names. 
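+ * For example, buildSimpleBinder({"p1", "p2"}) returns an ExpressionBinder over the names
+ * "p1" and "p2", each bound to a fresh Source [] node.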
+ */ +static ABT buildSimpleBinder(const ProjectionNameVector& names) { + ABTVector sources; + for (size_t idx = 0; idx < names.size(); ++idx) { + sources.emplace_back(make<Source>()); + } + + return make<ExpressionBinder>(names, std::move(sources)); +} + +static ABT buildReferences(const ProjectionNameSet& projections) { + ABTVector variables; + ProjectionNameOrderedSet ordered = convertToOrderedSet(projections); + for (const ProjectionName& projection : ordered) { + variables.emplace_back(make<Variable>(projection)); + } + return make<References>(std::move(variables)); +} + +ScanNode::ScanNode(ProjectionName projectionName, std::string scanDefName) + : Base(buildSimpleBinder({std::move(projectionName)})), _scanDefName(std::move(scanDefName)) {} + +const ProjectionName& ScanNode::getProjectionName() const { + return binder().names()[0]; +} + +const ProjectionType& ScanNode::getProjection() const { + return binder().exprs()[0]; +} + +const std::string& ScanNode::getScanDefName() const { + return _scanDefName; +} + +bool ScanNode::operator==(const ScanNode& other) const { + return getProjectionName() == other.getProjectionName() && _scanDefName == other._scanDefName; +} + +static ProjectionNameVector extractProjectionNamesForScan( + const FieldProjectionMap& fieldProjectionMap) { + ProjectionNameVector result; + + if (!fieldProjectionMap._ridProjection.empty()) { + result.push_back(fieldProjectionMap._ridProjection); + } + if (!fieldProjectionMap._rootProjection.empty()) { + result.push_back(fieldProjectionMap._rootProjection); + } + for (const auto& entry : fieldProjectionMap._fieldProjections) { + result.push_back(entry.second); + } + + return result; +} + +PhysicalScanNode::PhysicalScanNode(FieldProjectionMap fieldProjectionMap, + std::string scanDefName, + bool useParallelScan) + : Base(buildSimpleBinder(extractProjectionNamesForScan(fieldProjectionMap))), + _fieldProjectionMap(std::move(fieldProjectionMap)), + _scanDefName(std::move(scanDefName)), + _useParallelScan(useParallelScan) {} + +bool PhysicalScanNode::operator==(const PhysicalScanNode& other) const { + return _fieldProjectionMap == other._fieldProjectionMap && _scanDefName == other._scanDefName && + _useParallelScan == other._useParallelScan; +} + +const FieldProjectionMap& PhysicalScanNode::getFieldProjectionMap() const { + return _fieldProjectionMap; +} + +const std::string& PhysicalScanNode::getScanDefName() const { + return _scanDefName; +} + +bool PhysicalScanNode::useParallelScan() const { + return _useParallelScan; +} + +ValueScanNode::ValueScanNode(ProjectionNameVector projections) + : ValueScanNode(std::move(projections), Constant::emptyArray()) {} + +ValueScanNode::ValueScanNode(ProjectionNameVector projections, ABT valueArray) + : Base(buildSimpleBinder(std::move(projections))), _valueArray(std::move(valueArray)) { + const auto constPtr = _valueArray.cast<Constant>(); + uassert(6624081, "Expected a constant", constPtr != nullptr); + + const auto [tag, val] = constPtr->get(); + uassert(6624082, "Expected an array constant.", tag == sbe::value::TypeTags::Array); + + const auto arr = sbe::value::getArrayView(val); + _arraySize = arr->size(); + const size_t projectionCount = binder().names().size(); + for (size_t i = 0; i < _arraySize; i++) { + const auto [tag1, val1] = arr->getAt(i); + uassert(6624083, "Expected an array element.", tag1 == sbe::value::TypeTags::Array); + const size_t innerSize = sbe::value::getArrayView(val1)->size(); + uassert(6624084, "Invalid array size.", innerSize == projectionCount); + } +} + 
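+// Note on the shape enforced above: the value array is an array of arrays, one inner array per
+// produced row, and each inner array must contain exactly one element per bound projection.
+// For projections {"p1", "p2"}, [[1, 2], [3, 4]] is accepted, while [[1], [3, 4]] fails the
+// "Invalid array size." assertion.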
+bool ValueScanNode::operator==(const ValueScanNode& other) const { + return binder() == other.binder() && _arraySize == other._arraySize && + _valueArray == other._valueArray; +} + +const ABT& ValueScanNode::getValueArray() const { + return _valueArray; +} + +size_t ValueScanNode::getArraySize() const { + return _arraySize; +} + +CoScanNode::CoScanNode() : Base() {} + +bool CoScanNode::operator==(const CoScanNode& other) const { + return true; +} + +IndexScanNode::IndexScanNode(FieldProjectionMap fieldProjectionMap, IndexSpecification indexSpec) + : Base(buildSimpleBinder(extractProjectionNamesForScan(fieldProjectionMap))), + _fieldProjectionMap(std::move(fieldProjectionMap)), + _indexSpec(std::move(indexSpec)) {} + +bool IndexScanNode::operator==(const IndexScanNode& other) const { + // Scan spec does not participate, the indexSpec by itself should determine equality. + return _fieldProjectionMap == other._fieldProjectionMap && _indexSpec == other._indexSpec; +} + +const FieldProjectionMap& IndexScanNode::getFieldProjectionMap() const { + return _fieldProjectionMap; +} + +const IndexSpecification& IndexScanNode::getIndexSpecification() const { + return _indexSpec; +} + +SeekNode::SeekNode(ProjectionName ridProjectionName, + FieldProjectionMap fieldProjectionMap, + std::string scanDefName) + : Base(buildSimpleBinder(extractProjectionNamesForScan(fieldProjectionMap)), + make<References>(ProjectionNameVector{ridProjectionName})), + _ridProjectionName(std::move(ridProjectionName)), + _fieldProjectionMap(std::move(fieldProjectionMap)), + _scanDefName(std::move(scanDefName)) {} + +bool SeekNode::operator==(const SeekNode& other) const { + return _ridProjectionName == other._ridProjectionName && + _fieldProjectionMap == other._fieldProjectionMap && _scanDefName == other._scanDefName; +} + +const FieldProjectionMap& SeekNode::getFieldProjectionMap() const { + return _fieldProjectionMap; +} + +const std::string& SeekNode::getScanDefName() const { + return _scanDefName; +} + +const ProjectionName& SeekNode::getRIDProjectionName() const { + return _ridProjectionName; +} + +MemoLogicalDelegatorNode::MemoLogicalDelegatorNode(const GroupIdType groupId) + : Base(), _groupId(groupId) {} + +GroupIdType MemoLogicalDelegatorNode::getGroupId() const { + return _groupId; +} + +bool MemoLogicalDelegatorNode::operator==(const MemoLogicalDelegatorNode& other) const { + return _groupId == other._groupId; +} + +MemoPhysicalDelegatorNode::MemoPhysicalDelegatorNode(const MemoPhysicalNodeId nodeId) + : Base(), _nodeId(nodeId) {} + +bool MemoPhysicalDelegatorNode::operator==(const MemoPhysicalDelegatorNode& other) const { + return _nodeId == other._nodeId; +} + +MemoPhysicalNodeId MemoPhysicalDelegatorNode::getNodeId() const { + return _nodeId; +} + +FilterNode::FilterNode(FilterType filter, ABT child) : Base(std::move(child), std::move(filter)) { + assertExprSort(getFilter()); + assertNodeSort(getChild()); +} + +bool FilterNode::operator==(const FilterNode& other) const { + return getFilter() == other.getFilter() && getChild() == other.getChild(); +} + +const FilterType& FilterNode::getFilter() const { + return get<1>(); +} + +FilterType& FilterNode::getFilter() { + return get<1>(); +} + +const ABT& FilterNode::getChild() const { + return get<0>(); +} + +ABT& FilterNode::getChild() { + return get<0>(); +} + +EvaluationNode::EvaluationNode(ProjectionName projectionName, ProjectionType projection, ABT child) + : Base(std::move(child), + make<ExpressionBinder>(std::move(projectionName), std::move(projection))) { + 
assertNodeSort(getChild()); +} + +bool EvaluationNode::operator==(const EvaluationNode& other) const { + return binder() == other.binder() && getProjection() == other.getProjection() && + getChild() == other.getChild(); +} + +RIDIntersectNode::RIDIntersectNode(ProjectionName scanProjectionName, + const bool hasLeftIntervals, + const bool hasRightIntervals, + ABT leftChild, + ABT rightChild) + : Base(std::move(leftChild), std::move(rightChild)), + _scanProjectionName(std::move(scanProjectionName)), + _hasLeftIntervals(hasLeftIntervals), + _hasRightIntervals(hasRightIntervals) { + assertNodeSort(getLeftChild()); + assertNodeSort(getRightChild()); +} + +const ABT& RIDIntersectNode::getLeftChild() const { + return get<0>(); +} + +ABT& RIDIntersectNode::getLeftChild() { + return get<0>(); +} + +const ABT& RIDIntersectNode::getRightChild() const { + return get<1>(); +} + +ABT& RIDIntersectNode::getRightChild() { + return get<1>(); +} + +bool RIDIntersectNode::operator==(const RIDIntersectNode& other) const { + return _scanProjectionName == other._scanProjectionName && + _hasLeftIntervals == other._hasLeftIntervals && + _hasRightIntervals == other._hasRightIntervals && getLeftChild() == other.getLeftChild() && + getRightChild() == other.getRightChild(); +} + +const ProjectionName& RIDIntersectNode::getScanProjectionName() const { + return _scanProjectionName; +} + +bool RIDIntersectNode::hasLeftIntervals() const { + return _hasLeftIntervals; +} + +bool RIDIntersectNode::hasRightIntervals() const { + return _hasRightIntervals; +} + +static ProjectionNameVector createSargableBindings(const PartialSchemaRequirements& reqMap) { + ProjectionNameVector result; + for (const auto& entry : reqMap) { + if (entry.second.hasBoundProjectionName()) { + result.push_back(entry.second.getBoundProjectionName()); + } + } + return result; +} + +static ProjectionNameVector createSargableReferences(const PartialSchemaRequirements& reqMap) { + ProjectionNameOrderPreservingSet result; + for (const auto& entry : reqMap) { + result.emplace_back(entry.first._projectionName); + } + return result.getVector(); +} + +SargableNode::SargableNode(PartialSchemaRequirements reqMap, + CandidateIndexMap candidateIndexMap, + const IndexReqTarget target, + ABT child) + : Base(std::move(child), + buildSimpleBinder(createSargableBindings(reqMap)), + make<References>(createSargableReferences(reqMap))), + _reqMap(std::move(reqMap)), + _candidateIndexMap(std::move(candidateIndexMap)), + _target(target) { + assertNodeSort(getChild()); + uassert(6624085, "Empty requirements map", !_reqMap.empty()); + // We currently use a 64-bit mask when splitting into left and right requirements. + uassert(6624086, "Requirements map too large", _reqMap.size() < 64); + + // Assert merged map does not contain duplicate bound projections. + ProjectionNameSet boundsProjectionNameSet; + for (const auto& entry : _reqMap) { + if (entry.second.hasBoundProjectionName() && + !boundsProjectionNameSet.insert(entry.second.getBoundProjectionName()).second) { + uasserted(6624087, "Duplicate bound projection"); + } + } + + // Assert there are no references to internally bound projections. 
+ for (const auto& entry : _reqMap) { + if (boundsProjectionNameSet.find(entry.first._projectionName) != + boundsProjectionNameSet.cend()) { + uasserted(6624088, "We are binding to an internal projection"); + } + } +} + +bool SargableNode::operator==(const SargableNode& other) const { + return _reqMap == other._reqMap && _candidateIndexMap == other._candidateIndexMap && + _target == other._target && getChild() == other.getChild(); +} + +const PartialSchemaRequirements& SargableNode::getReqMap() const { + return _reqMap; +} + +const CandidateIndexMap& SargableNode::getCandidateIndexMap() const { + return _candidateIndexMap; +} + +IndexReqTarget SargableNode::getTarget() const { + return _target; +} + +BinaryJoinNode::BinaryJoinNode(JoinType joinType, + ProjectionNameSet correlatedProjectionNames, + FilterType filter, + ABT leftChild, + ABT rightChild) + : Base(std::move(leftChild), std::move(rightChild), std::move(filter)), + _joinType(joinType), + _correlatedProjectionNames(std::move(correlatedProjectionNames)) { + assertExprSort(getFilter()); + assertNodeSort(getLeftChild()); + assertNodeSort(getRightChild()); +} + +JoinType BinaryJoinNode::getJoinType() const { + return _joinType; +} + +const ProjectionNameSet& BinaryJoinNode::getCorrelatedProjectionNames() const { + return _correlatedProjectionNames; +} + +bool BinaryJoinNode::operator==(const BinaryJoinNode& other) const { + return _joinType == other._joinType && + _correlatedProjectionNames == other._correlatedProjectionNames && + getLeftChild() == other.getLeftChild() && getRightChild() == other.getRightChild(); +} + +const ABT& BinaryJoinNode::getLeftChild() const { + return get<0>(); +} + +ABT& BinaryJoinNode::getLeftChild() { + return get<0>(); +} + +const ABT& BinaryJoinNode::getRightChild() const { + return get<1>(); +} + +ABT& BinaryJoinNode::getRightChild() { + return get<1>(); +} + +const ABT& BinaryJoinNode::getFilter() const { + return get<2>(); +} + +static ABT buildHashJoinReferences(const ProjectionNameVector& leftKeys, + const ProjectionNameVector& rightKeys) { + ABTVector variables; + for (const ProjectionName& projection : leftKeys) { + variables.emplace_back(make<Variable>(projection)); + } + for (const ProjectionName& projection : rightKeys) { + variables.emplace_back(make<Variable>(projection)); + } + + return make<References>(std::move(variables)); +} + +HashJoinNode::HashJoinNode(JoinType joinType, + ProjectionNameVector leftKeys, + ProjectionNameVector rightKeys, + ABT leftChild, + ABT rightChild) + : Base(std::move(leftChild), + std::move(rightChild), + buildHashJoinReferences(leftKeys, rightKeys)), + _joinType(joinType), + _leftKeys(std::move(leftKeys)), + _rightKeys(std::move(rightKeys)) { + uassert( + 6624089, "Invalid key sizes", !_leftKeys.empty() && _leftKeys.size() == _rightKeys.size()); + assertNodeSort(getLeftChild()); + assertNodeSort(getRightChild()); +} + +bool HashJoinNode::operator==(const HashJoinNode& other) const { + return _joinType == other._joinType && _leftKeys == other._leftKeys && + _rightKeys == other._rightKeys && getLeftChild() == other.getLeftChild() && + getRightChild() == other.getRightChild(); +} + +JoinType HashJoinNode::getJoinType() const { + return _joinType; +} + +const ProjectionNameVector& HashJoinNode::getLeftKeys() const { + return _leftKeys; +} + +const ProjectionNameVector& HashJoinNode::getRightKeys() const { + return _rightKeys; +} + +const ABT& HashJoinNode::getLeftChild() const { + return get<0>(); +} + +ABT& HashJoinNode::getLeftChild() { + return get<0>(); +} + 
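+// Illustrative construction (projection names and children are placeholders): an inner hash
+// join on a single equality between the left key "pa" and the right key "pb".
+//   make<HashJoinNode>(JoinType::Inner,
+//                      ProjectionNameVector{"pa"},
+//                      ProjectionNameVector{"pb"},
+//                      std::move(leftChild),
+//                      std::move(rightChild));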
+const ABT& HashJoinNode::getRightChild() const { + return get<1>(); +} + +ABT& HashJoinNode::getRightChild() { + return get<1>(); +} + +MergeJoinNode::MergeJoinNode(ProjectionNameVector leftKeys, + ProjectionNameVector rightKeys, + std::vector<CollationOp> collation, + ABT leftChild, + ABT rightChild) + : Base(std::move(leftChild), + std::move(rightChild), + buildHashJoinReferences(leftKeys, rightKeys)), + _collation(std::move(collation)), + _leftKeys(std::move(leftKeys)), + _rightKeys(std::move(rightKeys)) { + uassert( + 6624090, "Invalid key sizes", !_leftKeys.empty() && _leftKeys.size() == _rightKeys.size()); + uassert(6624091, "Invalid collation size", _collation.size() == _leftKeys.size()); + assertNodeSort(getLeftChild()); + assertNodeSort(getRightChild()); +} + +bool MergeJoinNode::operator==(const MergeJoinNode& other) const { + return _leftKeys == other._leftKeys && _rightKeys == other._rightKeys && + _collation == other._collation && getLeftChild() == other.getLeftChild() && + getRightChild() == other.getRightChild(); +} + +const ProjectionNameVector& MergeJoinNode::getLeftKeys() const { + return _leftKeys; +} + +const ProjectionNameVector& MergeJoinNode::getRightKeys() const { + return _rightKeys; +} + +const std::vector<CollationOp>& MergeJoinNode::getCollation() const { + return _collation; +} + +const ABT& MergeJoinNode::getLeftChild() const { + return get<0>(); +} + +ABT& MergeJoinNode::getLeftChild() { + return get<0>(); +} + +const ABT& MergeJoinNode::getRightChild() const { + return get<1>(); +} + +ABT& MergeJoinNode::getRightChild() { + return get<1>(); +} + +/** + * A helper that builds References object of UnionNode for reference tracking purposes. + * + * Example: union outputs 3 projections: A,B,C and it has 4 children. Then the References object is + * a vector of variables A,B,C,A,B,C,A,B,C,A,B,C. One group of variables per child. 
+ */ +static ABT buildUnionReferences(const ProjectionNameVector& names, const size_t numOfChildren) { + ABTVector variables; + for (size_t outerIdx = 0; outerIdx < numOfChildren; ++outerIdx) { + for (size_t idx = 0; idx < names.size(); ++idx) { + variables.emplace_back(make<Variable>(names[idx])); + } + } + + return make<References>(std::move(variables)); +} + +UnionNode::UnionNode(ProjectionNameVector unionProjectionNames, ABTVector children) + : Base(std::move(children), + buildSimpleBinder(unionProjectionNames), + buildUnionReferences(unionProjectionNames, children.size())) { + uassert(6624007, "Empty union", !unionProjectionNames.empty()); + + for (auto& n : nodes()) { + assertNodeSort(n); + } +} + +bool UnionNode::operator==(const UnionNode& other) const { + return binder() == other.binder() && nodes() == other.nodes(); +} + +GroupByNode::GroupByNode(ProjectionNameVector groupByProjectionNames, + ProjectionNameVector aggregationProjectionNames, + ABTVector aggregationExpressions, + ABT child) + : GroupByNode(std::move(groupByProjectionNames), + std::move(aggregationProjectionNames), + std::move(aggregationExpressions), + GroupNodeType::Complete, + std::move(child)) {} + +GroupByNode::GroupByNode(ProjectionNameVector groupByProjectionNames, + ProjectionNameVector aggregationProjectionNames, + ABTVector aggregationExpressions, + GroupNodeType type, + ABT child) + : Base(std::move(child), + buildSimpleBinder(aggregationProjectionNames), + make<References>(std::move(aggregationExpressions)), + buildSimpleBinder(groupByProjectionNames), + make<References>(groupByProjectionNames)), + _type(type) { + assertNodeSort(getChild()); + uassert(6624300, + "Mismatched number of agg expressions and names", + getAggregationExpressions().size() == getAggregationProjectionNames().size()); +} + +bool GroupByNode::operator==(const GroupByNode& other) const { + return getAggregationProjectionNames() == other.getAggregationProjectionNames() && + getAggregationProjections() == other.getAggregationProjections() && + getGroupByProjectionNames() == other.getGroupByProjectionNames() && _type == other._type && + getChild() == other.getChild(); +} + +const ABTVector& GroupByNode::getAggregationExpressions() const { + return get<2>().cast<References>()->nodes(); +} + +const ABT& GroupByNode::getChild() const { + return get<0>(); +} + +ABT& GroupByNode::getChild() { + return get<0>(); +} + +GroupNodeType GroupByNode::getType() const { + return _type; +} + +UnwindNode::UnwindNode(ProjectionName projectionName, + ProjectionName pidProjectionName, + const bool retainNonArrays, + ABT child) + : Base(std::move(child), + buildSimpleBinder(ProjectionNameVector{projectionName, std::move(pidProjectionName)}), + make<References>(ProjectionNameVector{projectionName})), + _retainNonArrays(retainNonArrays) { + assertNodeSort(getChild()); +} + +bool UnwindNode::getRetainNonArrays() const { + return _retainNonArrays; +} + +const ABT& UnwindNode::getChild() const { + return get<0>(); +} + +ABT& UnwindNode::getChild() { + return get<0>(); +} + +bool UnwindNode::operator==(const UnwindNode& other) const { + return binder() == other.binder() && _retainNonArrays == other._retainNonArrays && + getChild() == other.getChild(); +} + +UniqueNode::UniqueNode(ProjectionNameVector projections, ABT child) + : Base(std::move(child), make<References>(ProjectionNameVector{projections})), + _projections(std::move(projections)) { + assertNodeSort(getChild()); + uassert(6624092, "Empty projections", !_projections.empty()); +} + +bool 
UniqueNode::operator==(const UniqueNode& other) const { + return _projections == other._projections; +} + +const ProjectionNameVector& UniqueNode::getProjections() const { + return _projections; +} + +const ABT& UniqueNode::getChild() const { + return get<0>(); +} + +CollationNode::CollationNode(properties::CollationRequirement property, ABT child) + : Base(std::move(child), + buildReferences(extractReferencedColumns(properties::makePhysProps(property)))), + _property(std::move(property)) { + assertNodeSort(getChild()); +} + +const properties::CollationRequirement& CollationNode::getProperty() const { + return _property; +} + +properties::CollationRequirement& CollationNode::getProperty() { + return _property; +} + +bool CollationNode::operator==(const CollationNode& other) const { + return _property == other._property && getChild() == other.getChild(); +} + +const ABT& CollationNode::getChild() const { + return get<0>(); +} + +ABT& CollationNode::getChild() { + return get<0>(); +} + +LimitSkipNode::LimitSkipNode(properties::LimitSkipRequirement property, ABT child) + : Base(std::move(child)), _property(std::move(property)) { + assertNodeSort(getChild()); +} + +const properties::LimitSkipRequirement& LimitSkipNode::getProperty() const { + return _property; +} + +properties::LimitSkipRequirement& LimitSkipNode::getProperty() { + return _property; +} + +bool LimitSkipNode::operator==(const LimitSkipNode& other) const { + return _property == other._property && getChild() == other.getChild(); +} + +const ABT& LimitSkipNode::getChild() const { + return get<0>(); +} + +ABT& LimitSkipNode::getChild() { + return get<0>(); +} + +ExchangeNode::ExchangeNode(const properties::DistributionRequirement distribution, ABT child) + : Base(std::move(child), buildReferences(distribution.getAffectedProjectionNames())), + _distribution(std::move(distribution)) { + assertNodeSort(getChild()); + uassert(6624008, + "Cannot exchange towards an unknown distribution", + distribution.getDistributionAndProjections()._type != + DistributionType::UnknownPartitioning); +} + +bool ExchangeNode::operator==(const ExchangeNode& other) const { + return _distribution == other._distribution && getChild() == other.getChild(); +} + +const ABT& ExchangeNode::getChild() const { + return get<0>(); +} + +ABT& ExchangeNode::getChild() { + return get<0>(); +} + +const properties::DistributionRequirement& ExchangeNode::getProperty() const { + return _distribution; +} + +properties::DistributionRequirement& ExchangeNode::getProperty() { + return _distribution; +} + +RootNode::RootNode(properties::ProjectionRequirement property, ABT child) + : Base(std::move(child), buildReferences(property.getAffectedProjectionNames())), + _property(std::move(property)) { + assertNodeSort(getChild()); +} + +bool RootNode::operator==(const RootNode& other) const { + return getChild() == other.getChild() && _property == other._property; +} + +const properties::ProjectionRequirement& RootNode::getProperty() const { + return _property; +} + +const ABT& RootNode::getChild() const { + return get<0>(); +} + +ABT& RootNode::getChild() { + return get<0>(); +} + +} // namespace mongo::optimizer diff --git a/src/mongo/db/query/optimizer/node.h b/src/mongo/db/query/optimizer/node.h new file mode 100644 index 00000000000..1a202140f6e --- /dev/null +++ b/src/mongo/db/query/optimizer/node.h @@ -0,0 +1,862 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. 
+ * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#pragma once + +#include <memory> +#include <sstream> +#include <string> +#include <unordered_map> +#include <utility> +#include <vector> + +#include "mongo/db/query/optimizer/algebra/operator.h" +#include "mongo/db/query/optimizer/defs.h" +#include "mongo/db/query/optimizer/metadata.h" +#include "mongo/db/query/optimizer/props.h" +#include "mongo/db/query/optimizer/syntax/expr.h" +#include "mongo/db/query/optimizer/syntax/path.h" + + +namespace mongo::optimizer { + +using FilterType = ABT; +using ProjectionType = ABT; + +/** + * Marker for node class (both logical and physical sub-classes). + * A node not marked with either LogicalNode or PhysicalNode is considered to be both a logical and + * a physical node (e.g. a filter node). It is invalid to mark a node with both tags in the same + * time. + */ +class Node {}; + +/** + * Marker for exclusively logical nodes. + */ +class LogicalNode {}; + +/** + * Marker for exclusively physical nodes. + */ +class PhysicalNode {}; + +inline void assertNodeSort(const ABT& e) { + if (!e.is<Node>()) { + uasserted(6624009, "Node syntax sort expected"); + } +} + +template <class T> +inline constexpr bool canBeLogicalNode() { + // Node which is not exclusively physical. + return std::is_base_of_v<Node, T> && !std::is_base_of_v<PhysicalNode, T>; +} + +template <class T> +inline constexpr bool canBePhysicalNode() { + // Node which is not exclusively logical. + return std::is_base_of_v<Node, T> && !std::is_base_of_v<LogicalNode, T>; +} + +/** + * Logical Scan node. + * It defines scanning a collection with an optional projection name that contains the documents. + * The collection is specified via the scanDefName entry in the metadata. 
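+ * For example, make<ScanNode>("ptest", "test") binds each document of the scan definition
+ * named "test" to the projection "ptest".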
+ */ +class ScanNode final : public Operator<ScanNode, 1>, public Node, public LogicalNode { + using Base = Operator<ScanNode, 1>; + +public: + static constexpr const char* kDefaultCollectionNameSpec = "collectionName"; + + ScanNode(ProjectionName projectionName, std::string scanDefName); + + bool operator==(const ScanNode& other) const; + + const ExpressionBinder& binder() const { + const ABT& result = get<0>(); + uassert(6624010, "Invalid binder type", result.is<ExpressionBinder>()); + return *result.cast<ExpressionBinder>(); + } + + const ProjectionName& getProjectionName() const; + const ProjectionType& getProjection() const; + + const std::string& getScanDefName() const; + +private: + const std::string _scanDefName; +}; + +/** + * Physical Scan node. + * It defines scanning a collection with an optional projection name that contains the documents. + * The collection is specified via the scanDefName entry in the metadata. + * + * Optionally set of fields is specified to retrieve from the underlying collection, and expose as + * projections. + */ +class PhysicalScanNode final : public Operator<PhysicalScanNode, 1>, + public Node, + public PhysicalNode { + using Base = Operator<PhysicalScanNode, 1>; + +public: + PhysicalScanNode(FieldProjectionMap fieldProjectionMap, + std::string scanDefName, + bool useParallelScan); + + bool operator==(const PhysicalScanNode& other) const; + + const ExpressionBinder& binder() const { + const ABT& result = get<0>(); + uassert(6624011, "Invalid binder type", result.is<ExpressionBinder>()); + return *result.cast<ExpressionBinder>(); + } + + const FieldProjectionMap& getFieldProjectionMap() const; + + const std::string& getScanDefName() const; + + bool useParallelScan() const; + +private: + const FieldProjectionMap _fieldProjectionMap; + const std::string _scanDefName; + const bool _useParallelScan; +}; + +/** + * Logical ValueScanNode. + * + * It originates a set of projections each with a fixed + * sequence of values, which is encoded as an array. + */ +class ValueScanNode final : public Operator<ValueScanNode, 1>, public Node, public LogicalNode { + using Base = Operator<ValueScanNode, 1>; + +public: + ValueScanNode(ProjectionNameVector projections); + ValueScanNode(ProjectionNameVector projections, ABT valueArray); + + bool operator==(const ValueScanNode& other) const; + + const ExpressionBinder& binder() const { + const ABT& result = get<0>(); + uassert(6624012, "Invalid binder type", result.is<ExpressionBinder>()); + return *result.cast<ExpressionBinder>(); + } + + const ABT& getValueArray() const; + size_t getArraySize() const; + +private: + const ABT _valueArray; + size_t _arraySize; +}; + +/** + * Physical CoScanNode. + * + * Conceptually it originates an infinite stream of Nothing. + * A typical use case is to limit it to one document, and attach projections with a following + * EvaluationNode(s). + */ +class CoScanNode final : public Operator<CoScanNode, 0>, public Node, public PhysicalNode { + using Base = Operator<CoScanNode, 0>; + +public: + CoScanNode(); + + bool operator==(const CoScanNode& other) const; +}; + +/** + * Index scan node. + * Retrieve data using an index. Return recordIds or values (if the index is covering). + * This is a physical node. + * + * The collection is specified by scanDef, and the index by the indexDef. 
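+ * Two IndexScanNodes compare equal when their field projection maps and index specifications
+ * match (see operator== in node.cpp).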
+ */ +class IndexScanNode final : public Operator<IndexScanNode, 1>, public Node, public PhysicalNode { + using Base = Operator<IndexScanNode, 1>; + +public: + IndexScanNode(FieldProjectionMap fieldProjectionMap, IndexSpecification indexSpec); + + bool operator==(const IndexScanNode& other) const; + + const ExpressionBinder& binder() const { + const ABT& result = get<0>(); + uassert(6624013, "Invalid binder type", result.is<ExpressionBinder>()); + return *result.cast<ExpressionBinder>(); + } + + const FieldProjectionMap& getFieldProjectionMap() const; + const IndexSpecification& getIndexSpecification() const; + +private: + const FieldProjectionMap _fieldProjectionMap; + const IndexSpecification _indexSpec; +}; + +/** + * SeekNode. + * Retrieve values using rowIds (typically previously retrieved using an index scan). + * This is a physical node. + * + * 'ridProjectionName' parameter designates the incoming rid which is the starting point of the + * seek. 'fieldProjectionMap' may choose to include an outgoing rid which will contain the + * successive (if we do not have a following limit) document ids. + * + * TODO: Can we let it advance with a limit based on upper rid limit in case of primary index? + */ +class SeekNode final : public Operator<SeekNode, 2>, public Node, public PhysicalNode { + using Base = Operator<SeekNode, 2>; + +public: + SeekNode(ProjectionName ridProjectionName, + FieldProjectionMap fieldProjectionMap, + std::string scanDefName); + + bool operator==(const SeekNode& other) const; + + const ExpressionBinder& binder() const { + const ABT& result = get<0>(); + uassert(6624014, "Invalid binder type", result.is<ExpressionBinder>()); + return *result.cast<ExpressionBinder>(); + } + + const ProjectionName& getRIDProjectionName() const; + + const FieldProjectionMap& getFieldProjectionMap() const; + + const std::string& getScanDefName() const; + +private: + const ProjectionName _ridProjectionName; + const FieldProjectionMap _fieldProjectionMap; + const std::string _scanDefName; +}; + + +/** + * Logical group delegator node: scan from a given group. + * Used in conjunction with memo. + */ +class MemoLogicalDelegatorNode final : public Operator<MemoLogicalDelegatorNode, 0>, + public Node, + public LogicalNode { + using Base = Operator<MemoLogicalDelegatorNode, 0>; + +public: + MemoLogicalDelegatorNode(GroupIdType groupId); + + bool operator==(const MemoLogicalDelegatorNode& other) const; + + GroupIdType getGroupId() const; + +private: + const GroupIdType _groupId; +}; + +/** + * Physical group delegator node: refer to a physical node in a memo group. + * Used in conjunction with memo. + */ +class MemoPhysicalDelegatorNode final : public Operator<MemoPhysicalDelegatorNode, 0>, + public Node, + public PhysicalNode { + using Base = Operator<MemoPhysicalDelegatorNode, 0>; + +public: + MemoPhysicalDelegatorNode(MemoPhysicalNodeId nodeId); + + bool operator==(const MemoPhysicalDelegatorNode& other) const; + + MemoPhysicalNodeId getNodeId() const; + +private: + const MemoPhysicalNodeId _nodeId; +}; + +/** + * Filter node. + * It applies a filter over its input. + * + * This node is both logical and physical. + */ +class FilterNode final : public Operator<FilterNode, 2>, public Node { + using Base = Operator<FilterNode, 2>; + +public: + FilterNode(FilterType filter, ABT child); + + bool operator==(const FilterNode& other) const; + + const FilterType& getFilter() const; + FilterType& getFilter(); + + const ABT& getChild() const; + ABT& getChild(); +}; + +/** + * Evaluation node. 
+ * Adds a new projection to its input. + * + * This node is both logical and physical. + */ +class EvaluationNode final : public Operator<EvaluationNode, 2>, public Node { + using Base = Operator<EvaluationNode, 2>; + +public: + EvaluationNode(ProjectionName projectionName, ProjectionType projection, ABT child); + + bool operator==(const EvaluationNode& other) const; + + const ExpressionBinder& binder() const { + const ABT& result = get<1>(); + uassert(6624015, "Invalid binder type", result.is<ExpressionBinder>()); + return *result.cast<ExpressionBinder>(); + } + + const ProjectionName& getProjectionName() const { + return binder().names()[0]; + } + + const ProjectionType& getProjection() const { + return binder().exprs()[0]; + } + + const ABT& getChild() const { + return get<0>(); + } + + ABT& getChild() { + return get<0>(); + } +}; + +/** + * RID intersection node. + * This is a logical node representing either index-index intersection or index-collection scan + * (seek) fetch. + * + * It is equivalent to a join node with the difference that RID projections do not exist on logical + * level, and thus projection names are not determined until physical optimization. We want to also + * restrict the type of operations on RIDs (in this case only set intersection) as opposed to say + * filter on rid = 5. + */ +class RIDIntersectNode final : public Operator<RIDIntersectNode, 2>, + public Node, + public LogicalNode { + using Base = Operator<RIDIntersectNode, 2>; + +public: + RIDIntersectNode(ProjectionName scanProjectionName, + bool hasLeftIntervals, + bool hasRightIntervals, + ABT leftChild, + ABT rightChild); + + bool operator==(const RIDIntersectNode& other) const; + + const ABT& getLeftChild() const; + ABT& getLeftChild(); + + const ABT& getRightChild() const; + ABT& getRightChild(); + + const ProjectionName& getScanProjectionName() const; + + bool hasLeftIntervals() const; + bool hasRightIntervals() const; + +private: + const ProjectionName _scanProjectionName; + + // If true left and right children have at least one proper interval (not fully open). + const bool _hasLeftIntervals; + const bool _hasRightIntervals; +}; + +/** + * Sargable node. + * This is a logical node which represents special kinds of (simple) evaluations and filters which + * are amenable to being used in indexing or covered scans. + * + * It collects a conjunction of predicates in the following form: + * <path, inputProjection> -> <interval, outputProjection> + + * For example to encode a conjunction which encodes filtering with array traversal on "a" + ($match(a: {$gt, 1}} combined with a retrieval of the field "b" (without restrictions on its + value). 
+ * PathGet "a" Traverse Id | scan_0 -> [1, +inf], <none> + * PathGet "b" Id | scan_0 -> (-inf, +inf), "pb" + */ +class SargableNode final : public Operator<SargableNode, 3>, public Node, public LogicalNode { + using Base = Operator<SargableNode, 3>; + +public: + SargableNode(PartialSchemaRequirements reqMap, + CandidateIndexMap candidateIndexMap, + IndexReqTarget target, + ABT child); + + bool operator==(const SargableNode& other) const; + + const ExpressionBinder& binder() const { + const ABT& result = get<1>(); + uassert(6624016, "Invalid binder type", result.is<ExpressionBinder>()); + return *result.cast<ExpressionBinder>(); + } + + const ABT& getChild() const { + return get<0>(); + } + ABT& getChild() { + return get<0>(); + } + + const PartialSchemaRequirements& getReqMap() const; + const CandidateIndexMap& getCandidateIndexMap() const; + + IndexReqTarget getTarget() const; + +private: + const PartialSchemaRequirements _reqMap; + + CandidateIndexMap _candidateIndexMap; + + // Performance optimization to limit number of groups. + // Under what indexing requirements can this node be implemented. + const IndexReqTarget _target; +}; + +#define JOIN_TYPE(F) \ + F(Inner) \ + F(Left) \ + F(Right) \ + F(Full) + +MAKE_PRINTABLE_ENUM(JoinType, JOIN_TYPE); +MAKE_PRINTABLE_ENUM_STRING_ARRAY(JoinTypeEnum, JoinType, JOIN_TYPE); +#undef JOIN_TYPE + +/** + * Logical binary join. + * Join of two logical nodes. Can express inner and outer joins, with an associated join predicate. + * + * This node is logical, with a default physical implementation corresponding to a Nested Loops Join + * (NLJ). + * Variables used in the inner (right) side are automatically bound with variables from the left + * (outer) side. + */ +class BinaryJoinNode final : public Operator<BinaryJoinNode, 3>, public Node { + using Base = Operator<BinaryJoinNode, 3>; + +public: + BinaryJoinNode(JoinType joinType, + ProjectionNameSet correlatedProjectionNames, + FilterType filter, + ABT leftChild, + ABT rightChild); + + bool operator==(const BinaryJoinNode& other) const; + + JoinType getJoinType() const; + + const ProjectionNameSet& getCorrelatedProjectionNames() const; + + const ABT& getLeftChild() const; + ABT& getLeftChild(); + + const ABT& getRightChild() const; + ABT& getRightChild(); + + const ABT& getFilter() const; + +private: + const JoinType _joinType; + + // Those projections must exist on the outer side and are used to bind free variables on the + // inner side. + const ProjectionNameSet _correlatedProjectionNames; +}; + +/** + * Physical hash join node. + * Join condition is a conjunction of pairwise equalities between corresponding left and right keys. + * + * TODO: support all join types (not just Inner). + */ +class HashJoinNode final : public Operator<HashJoinNode, 3>, public Node, public PhysicalNode { + using Base = Operator<HashJoinNode, 3>; + +public: + HashJoinNode(JoinType joinType, + ProjectionNameVector leftKeys, + ProjectionNameVector rightKeys, + ABT leftChild, + ABT rightChild); + + bool operator==(const HashJoinNode& other) const; + + JoinType getJoinType() const; + const ProjectionNameVector& getLeftKeys() const; + const ProjectionNameVector& getRightKeys() const; + + const ABT& getLeftChild() const; + ABT& getLeftChild(); + + const ABT& getRightChild() const; + ABT& getRightChild(); + +private: + const JoinType _joinType; + + // Join condition is a conjunction of _leftKeys.at(i) == _rightKeys.at(i). 
+ const ProjectionNameVector _leftKeys; + const ProjectionNameVector _rightKeys; +}; + +/** + * Merge Join node. + * This is a physical node representing joining of two sorted inputs. + */ +class MergeJoinNode final : public Operator<MergeJoinNode, 3>, public Node, public PhysicalNode { + using Base = Operator<MergeJoinNode, 3>; + +public: + MergeJoinNode(ProjectionNameVector leftKeys, + ProjectionNameVector rightKeys, + std::vector<CollationOp> collation, + ABT leftChild, + ABT rightChild); + + bool operator==(const MergeJoinNode& other) const; + + const ProjectionNameVector& getLeftKeys() const; + const ProjectionNameVector& getRightKeys() const; + + const std::vector<CollationOp>& getCollation() const; + + const ABT& getLeftChild() const; + ABT& getLeftChild(); + + const ABT& getRightChild() const; + ABT& getRightChild(); + +private: + // Describes how to merge the sorted streams. + std::vector<CollationOp> _collation; + + // Join condition is a conjunction of _leftKeys.at(i) == _rightKeys.at(i). + const ProjectionNameVector _leftKeys; + const ProjectionNameVector _rightKeys; +}; + +/** + * Union of several logical nodes. Projections in common to all nodes are logically union-ed in the + * output. It can be used with a single child just to restrict projections. + * + * This node is both logical and physical. + */ +class UnionNode final : public OperatorDynamic<UnionNode, 2>, public Node { + using Base = OperatorDynamic<UnionNode, 2>; + +public: + UnionNode(ProjectionNameVector unionProjectionNames, ABTVector children); + + bool operator==(const UnionNode& other) const; + + const ExpressionBinder& binder() const { + const ABT& result = get<0>(); + uassert(6624017, "Invalid binder type", result.is<ExpressionBinder>()); + return *result.cast<ExpressionBinder>(); + } +}; + +#define GROUPNODETYPE_OPNAMES(F) \ + F(Complete) \ + F(Local) \ + F(Global) + +MAKE_PRINTABLE_ENUM(GroupNodeType, GROUPNODETYPE_OPNAMES); +MAKE_PRINTABLE_ENUM_STRING_ARRAY(GroupNodeTypeEnum, GroupNodeType, GROUPNODETYPE_OPNAMES); +#undef PATHSYNTAX_OPNAMES + +/** + * Group-by node. + * This node is logical with a default physical implementation corresponding to a hash group-by. + * Projects the group-by column from its child, and adds aggregation expressions. + * + * TODO: other physical implementations: stream group-by. 
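+ *
+ * Illustrative construction (projection names are placeholders): grouping the child stream by
+ * "key" and binding a single aggregation expression to the output projection "agg":
+ *   make<GroupByNode>(ProjectionNameVector{"key"},
+ *                     ProjectionNameVector{"agg"},
+ *                     ABTVector{make<Variable>("toAgg")},
+ *                     std::move(child));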
+ */ +class GroupByNode : public Operator<GroupByNode, 5>, public Node { + using Base = Operator<GroupByNode, 5>; + +public: + GroupByNode(ProjectionNameVector groupByProjectionNames, + ProjectionNameVector aggregationProjectionNames, + ABTVector aggregationExpressions, + ABT child); + + GroupByNode(ProjectionNameVector groupByProjectionNames, + ProjectionNameVector aggregationProjectionNames, + ABTVector aggregationExpressions, + GroupNodeType type, + ABT child); + + bool operator==(const GroupByNode& other) const; + + const ExpressionBinder& binderAgg() const { + const ABT& result = get<1>(); + uassert(6624018, "Invalid binder type", result.is<ExpressionBinder>()); + return *result.cast<ExpressionBinder>(); + } + + const ExpressionBinder& binderGb() const { + const ABT& result = get<3>(); + uassert(6624019, "Invalid binder type", result.is<ExpressionBinder>()); + return *result.cast<ExpressionBinder>(); + } + + const ProjectionNameVector& getGroupByProjectionNames() const { + return binderGb().names(); + } + + const ProjectionNameVector& getAggregationProjectionNames() const { + return binderAgg().names(); + } + + const auto& getAggregationProjections() const { + return binderAgg().exprs(); + } + + const auto& getGroupByProjections() const { + return binderGb().exprs(); + } + + const ABTVector& getAggregationExpressions() const; + + const ABT& getChild() const; + ABT& getChild(); + + GroupNodeType getType() const; + +private: + // Used for local-global rewrite. + GroupNodeType _type; +}; + +/** + * Unwind node. + * Unwinds an embedded relation inside an array. Generates unwinding positions in the CID + * projection. + * + * This node is both logical and physical. + */ +class UnwindNode final : public Operator<UnwindNode, 3>, public Node { + using Base = Operator<UnwindNode, 3>; + +public: + UnwindNode(ProjectionName projectionName, + ProjectionName pidProjectionName, + bool retainNonArrays, + ABT child); + + bool operator==(const UnwindNode& other) const; + + const ExpressionBinder& binder() const { + const ABT& result = get<1>(); + uassert(6624020, "Invalid binder type", result.is<ExpressionBinder>()); + return *result.cast<ExpressionBinder>(); + } + + const ProjectionName& getProjectionName() const { + return binder().names()[0]; + } + + const ProjectionName& getPIDProjectionName() const { + return binder().names()[1]; + } + + const ProjectionType& getProjection() const { + return binder().exprs()[0]; + } + + const ProjectionType& getPIDProjection() const { + return binder().exprs()[1]; + } + + const ABT& getChild() const; + + ABT& getChild(); + + bool getRetainNonArrays() const; + +private: + const bool _retainNonArrays; +}; + +/** + * Unique node. + * + * This is a physical node. It encodes an operation which will duplicate the child input using a + * sequence of given projection names. It is similar to GroupBy using the given projections as a + * compound grouping key. + */ +class UniqueNode final : public Operator<UniqueNode, 2>, public Node, public PhysicalNode { + using Base = Operator<UniqueNode, 2>; + +public: + UniqueNode(ProjectionNameVector projections, ABT child); + + bool operator==(const UniqueNode& other) const; + + const ProjectionNameVector& getProjections() const; + + const ABT& getChild() const; + +private: + ProjectionNameVector _projections; +}; + +/** + * Collation node. + * This node is both logical and physical. + * + * It represents an operator to collate (sort, or cluster) the input. 
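+ * For example, a CollationNode whose requirement sorts the projection "pa" in ascending order
+ * asks the child subtree to produce its rows ordered by "pa".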
+ */ +class CollationNode final : public Operator<CollationNode, 2>, public Node { + using Base = Operator<CollationNode, 2>; + +public: + CollationNode(properties::CollationRequirement property, ABT child); + + bool operator==(const CollationNode& other) const; + + const properties::CollationRequirement& getProperty() const; + properties::CollationRequirement& getProperty(); + + const ABT& getChild() const; + + ABT& getChild(); + +private: + properties::CollationRequirement _property; +}; + +/** + * Limit and skip node. + * This node is both logical and physical. + * + * It limits the size of the input by a fixed amount. + */ +class LimitSkipNode final : public Operator<LimitSkipNode, 1>, public Node { + using Base = Operator<LimitSkipNode, 1>; + +public: + LimitSkipNode(properties::LimitSkipRequirement property, ABT child); + + bool operator==(const LimitSkipNode& other) const; + + const properties::LimitSkipRequirement& getProperty() const; + properties::LimitSkipRequirement& getProperty(); + + const ABT& getChild() const; + + ABT& getChild(); + +private: + properties::LimitSkipRequirement _property; +}; + +/** + * Exchange node. + * It specifies how the relation is spread across machines in the execution environment. + * Currently only single-node, and hash-based partitioning are supported. + * TODO: range-based partitioning, replication, and round-robin. + * + * This node is both logical and physical. + */ +class ExchangeNode final : public Operator<ExchangeNode, 2>, public Node { + using Base = Operator<ExchangeNode, 2>; + +public: + ExchangeNode(properties::DistributionRequirement distribution, ABT child); + + bool operator==(const ExchangeNode& other) const; + + const properties::DistributionRequirement& getProperty() const; + properties::DistributionRequirement& getProperty(); + + const ABT& getChild() const; + + ABT& getChild(); + +private: + properties::DistributionRequirement _distribution; + + /** + * Defined for hash and range-based partitioning. + * TODO: other exchange-specific params (e.g. chunk boundaries?) + */ + const ProjectionName _projectionName; +}; + +/** + * Root of the tree that holds references to the output of the query. In the mql case the query + * outputs a single "column" (aka document) but in a general case (SQL) we can output arbitrary many + * "columns". We need the internal references for the output projections in order to keep them live, + * otherwise they would be dropped from the tree by DCE. + * + * This node is only logical. + */ +class RootNode final : public Operator<RootNode, 2>, public Node { + using Base = Operator<RootNode, 2>; + +public: + RootNode(properties::ProjectionRequirement property, ABT child); + + bool operator==(const RootNode& other) const; + + const properties::ProjectionRequirement& getProperty() const; + + const ABT& getChild() const; + ABT& getChild(); + +private: + const properties::ProjectionRequirement _property; +}; + +} // namespace mongo::optimizer diff --git a/src/mongo/db/query/optimizer/node_defs.h b/src/mongo/db/query/optimizer/node_defs.h new file mode 100644 index 00000000000..9a097f24591 --- /dev/null +++ b/src/mongo/db/query/optimizer/node_defs.h @@ -0,0 +1,66 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. 
+ * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#pragma once + +#include "mongo/db/query/optimizer/node.h" +#include "mongo/db/query/optimizer/props.h" + +namespace mongo::optimizer { + +// Used for physical rewrites. For each child we optimize, we specify physical properties. +using ChildPropsType = std::vector<std::pair<ABT*, properties::PhysProps>>; + +// Use for physical rewrites. For each physical node we implement, we set CE to use when costing. +using NodeCEMap = opt::unordered_map<const Node*, CEType>; + +struct NodeProps { + // Used to tie to a corresponding SBE stage. + int32_t _planNodeId; + + // Which is the corresponding memo group, and its properties. + MemoPhysicalNodeId _groupId; + properties::LogicalProps _logicalProps; + properties::PhysProps _physicalProps; + + // Total cost of the best plan (includes the subtree). + CostType _cost; + // Local cost (excludes subtree). + CostType _localCost; + + // For display purposes, adjusted cardinality based on physical properties (e.g. Repetition and + // Limit-Skip). + CEType _adjustedCE; +}; + +// Map from node to various properties, including logical and physical. Used to determine for +// example which of the available projections are used for exchanges. +using NodeToGroupPropsMap = opt::unordered_map<const Node*, NodeProps>; + +} // namespace mongo::optimizer diff --git a/src/mongo/db/query/optimizer/opt_phase_manager.cpp b/src/mongo/db/query/optimizer/opt_phase_manager.cpp new file mode 100644 index 00000000000..7c282ea2fc5 --- /dev/null +++ b/src/mongo/db/query/optimizer/opt_phase_manager.cpp @@ -0,0 +1,332 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. 
+ * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include "mongo/db/query/optimizer/opt_phase_manager.h" + +#include "mongo/db/query/optimizer/cascades/ce_heuristic.h" +#include "mongo/db/query/optimizer/cascades/cost_derivation.h" +#include "mongo/db/query/optimizer/cascades/logical_props_derivation.h" +#include "mongo/db/query/optimizer/rewrites/const_eval.h" +#include "mongo/db/query/optimizer/rewrites/path.h" +#include "mongo/db/query/optimizer/rewrites/path_lower.h" +#include "mongo/db/query/optimizer/utils/memo_utils.h" + +namespace mongo::optimizer { + +OptPhaseManager::PhaseSet OptPhaseManager::_allRewrites = {OptPhase::ConstEvalPre, + OptPhase::PathFuse, + OptPhase::MemoSubstitutionPhase, + OptPhase::MemoExplorationPhase, + OptPhase::MemoImplementationPhase, + OptPhase::PathLower, + OptPhase::ConstEvalPost}; + +OptPhaseManager::OptPhaseManager(OptPhaseManager::PhaseSet phaseSet, + PrefixId& prefixId, + Metadata metadata, + DebugInfo debugInfo) + : OptPhaseManager(std::move(phaseSet), + prefixId, + false /*requireRID*/, + std::move(metadata), + std::make_unique<HeuristicCE>(), + std::make_unique<DefaultCosting>(), + std::move(debugInfo)) {} + +OptPhaseManager::OptPhaseManager(OptPhaseManager::PhaseSet phaseSet, + PrefixId& prefixId, + const bool requireRID, + Metadata metadata, + std::unique_ptr<CEInterface> ceDerivation, + std::unique_ptr<CostingInterface> costDerivation, + DebugInfo debugInfo) + : _phaseSet(std::move(phaseSet)), + _debugInfo(std::move(debugInfo)), + _hints(), + _metadata(std::move(metadata)), + _memo(_debugInfo, + _metadata, + std::make_unique<DefaultLogicalPropsDerivation>(), + std::move(ceDerivation)), + _costDerivation(std::move(costDerivation)), + _physicalNodeId(), + _requireRID(requireRID), + _ridProjections(), + _prefixId(prefixId) { + uassert(6624093, "Empty Cost derivation", _costDerivation.get()); + + for (const auto& entry : _metadata._scanDefs) { + _ridProjections.emplace(entry.first, _prefixId.getNextId("rid")); + } +} + +template <OptPhaseManager::OptPhase phase, class C> +bool OptPhaseManager::runStructuralPhase(C instance, VariableEnvironment& env, ABT& input) { + if (!hasPhase(phase)) { + return true; + } + + for (int iterationCount = 0; instance.optimize(input); iterationCount++) { + if (_debugInfo.exceedsIterationLimit(iterationCount)) { + // Iteration limit exceeded. 
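The structural phases drive a single transport rewrite to a fixpoint, bounded by the DebugInfo iteration limit. A minimal standalone sketch of the same pattern, using only the ConstEval rewrite and the ABT/VariableEnvironment APIs introduced elsewhere in this patch; the cap kMaxIterations is a made-up stand-in for DebugInfo::exceedsIterationLimit():

    #include "mongo/db/query/optimizer/node.h"
    #include "mongo/db/query/optimizer/reference_tracker.h"
    #include "mongo/db/query/optimizer/rewrites/const_eval.h"

    namespace mongo::optimizer {

    bool foldToFixpoint(ABT& tree) {
        auto env = VariableEnvironment::build(tree);
        ConstEval rewriter{env};

        constexpr int kMaxIterations = 10000;  // hypothetical cap, not the real DebugInfo limit
        // optimize() returns true while it keeps changing the tree; stop at a fixpoint.
        for (int i = 0; rewriter.optimize(tree); i++) {
            if (i >= kMaxIterations) {
                // Treat runaway rewriting as a failure, as the phase manager does.
                return false;
            }
        }
        return !env.hasFreeVariables();
    }

    }  // namespace mongo::optimizer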
+ return false; + } + } + + return !env.hasFreeVariables(); +} + +template <OptPhaseManager::OptPhase phase1, OptPhaseManager::OptPhase phase2, class C1, class C2> +bool OptPhaseManager::runStructuralPhases(C1 instance1, + C2 instance2, + VariableEnvironment& env, + ABT& input) { + const bool hasPhase1 = hasPhase(phase1); + const bool hasPhase2 = hasPhase(phase2); + if (!hasPhase1 && !hasPhase2) { + return true; + } + + bool changed = true; + for (int iterationCount = 0; changed; iterationCount++) { + if (_debugInfo.exceedsIterationLimit(iterationCount)) { + // Iteration limit exceeded. + return false; + } + + changed = false; + if (hasPhase1) { + changed |= instance1.optimize(input); + } + if (hasPhase2) { + changed |= instance2.optimize(input); + } + } + + return !env.hasFreeVariables(); +} + +bool OptPhaseManager::runMemoLogicalRewrite(const OptPhase phase, + VariableEnvironment& env, + const LogicalRewriter::RewriteSet& rewriteSet, + GroupIdType& rootGroupId, + const bool runStandalone, + std::unique_ptr<LogicalRewriter>& logicalRewriter, + ABT& input) { + if (!hasPhase(phase)) { + return true; + } + + _memo.clear(); + logicalRewriter = std::make_unique<LogicalRewriter>(_memo, _prefixId, rewriteSet); + rootGroupId = logicalRewriter->addRootNode(input); + + if (runStandalone) { + if (!logicalRewriter->rewriteToFixPoint()) { + return false; + } + + input = extractLatestPlan(_memo, rootGroupId); + env.rebuild(input); + } + + return !env.hasFreeVariables(); +} + +bool OptPhaseManager::runMemoPhysicalRewrite(const OptPhase phase, + VariableEnvironment& env, + const GroupIdType rootGroupId, + std::unique_ptr<LogicalRewriter>& logicalRewriter, + ABT& input) { + using namespace properties; + + if (!hasPhase(phase)) { + return true; + } + if (rootGroupId < 0) { + // Nothing inserted in the memo. Logical rewrites did not run? + return false; + } + + // By default we require centralized result. + // Also by default we do not require projections: the Root node will add those. + PhysProps physProps = makePhysProps(DistributionRequirement(DistributionType::Centralized)); + if (_requireRID) { + const auto& rootLogicalProps = _memo.getGroup(rootGroupId)._logicalProperties; + if (!hasProperty<IndexingAvailability>(rootLogicalProps)) { + // We cannot obtain rid for this query. 
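The seed physical properties above are assembled with the property helpers from props.h. A small sketch of those helpers with illustrative values only; the Properties.Basic test further down exercises setProperty/hasProperty/getProperty directly:

    #include "mongo/db/query/optimizer/node.h"
    #include "mongo/db/query/optimizer/props.h"

    using namespace mongo::optimizer;
    using namespace mongo::optimizer::properties;

    PhysProps makeExampleProps() {
        // Start from a centralized distribution requirement, as runMemoPhysicalRewrite does.
        PhysProps physProps =
            makePhysProps(DistributionRequirement(DistributionType::Centralized));

        // setProperty() returns false if a property of the same kind is already present.
        setProperty(physProps, ProjectionRequirement{ProjectionNameVector{"p1"}});
        setProperty(physProps, CollationRequirement({{"p1", CollationOp::Ascending}}));
        setProperty(physProps, LimitSkipRequirement(10 /*limit*/, 0 /*skip*/));

        return physProps;
    }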
+ return false; + } + + setProperty( + physProps, + IndexingRequirement( + IndexReqTarget::Complete, true /*needRID*/, true /*dedupRID*/, rootGroupId)); + } + + PhysicalRewriter rewriter(_memo, _hints, _ridProjections, *_costDerivation, logicalRewriter); + + auto optGroupResult = + rewriter.optimizeGroup(rootGroupId, std::move(physProps), _prefixId, CostType::kInfinity); + if (!optGroupResult._success) { + return false; + } + + _physicalNodeId = {rootGroupId, optGroupResult._index}; + std::tie(input, _nodeToGroupPropsMap) = extractPhysicalPlan(_physicalNodeId, _metadata, _memo); + + env.rebuild(input); + return !env.hasFreeVariables(); +} + +bool OptPhaseManager::runMemoRewritePhases(VariableEnvironment& env, ABT& input) { + GroupIdType rootGroupId = -1; + std::unique_ptr<LogicalRewriter> logicalRewriter; + + if (!runMemoLogicalRewrite(OptPhase::MemoSubstitutionPhase, + env, + LogicalRewriter::getSubstitutionSet(), + rootGroupId, + true /*runStandalone*/, + logicalRewriter, + input)) { + return false; + } + + if (!runMemoLogicalRewrite(OptPhase::MemoExplorationPhase, + env, + LogicalRewriter::getExplorationSet(), + rootGroupId, + !hasPhase(OptPhase::MemoImplementationPhase), + logicalRewriter, + input)) { + return false; + } + + if (!runMemoPhysicalRewrite( + OptPhase::MemoImplementationPhase, env, rootGroupId, logicalRewriter, input)) { + return false; + } + + return true; +} + +bool OptPhaseManager::optimize(ABT& input) { + VariableEnvironment env = VariableEnvironment::build(input); + if (env.hasFreeVariables()) { + return false; + } + + if (!runStructuralPhases<OptPhase::ConstEvalPre, OptPhase::PathFuse, ConstEval, PathFusion>( + ConstEval{env, true /*disableSargableInlining*/}, PathFusion{env}, env, input)) { + return false; + } + + if (!runMemoRewritePhases(env, input)) { + return false; + } + + if (!runStructuralPhase<OptPhase::PathLower, PathLowering>( + PathLowering{_prefixId, env}, env, input)) { + return false; + } + + ProjectionNameSet erasedProjNames; + if (!runStructuralPhase<OptPhase::ConstEvalPost, ConstEval>( + ConstEval{env, false /*disableSargableInlining*/, &erasedProjNames}, env, input)) { + return false; + } + if (!erasedProjNames.empty()) { + // If we have erased some eval nodes, make sure to delete the corresponding projection names + // from the node property map. 
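For context, a sketch of how a caller is expected to drive this entry point, mirroring the PhysRewriter unit tests added later in this patch; the scan definition name, plan shape, and debug settings are illustrative:

    #include "mongo/db/query/optimizer/metadata.h"
    #include "mongo/db/query/optimizer/node.h"
    #include "mongo/db/query/optimizer/opt_phase_manager.h"

    using namespace mongo::optimizer;

    bool optimizeExample() {
        PrefixId prefixId;

        // Logical input: Root <- Filter (a == 1) <- Scan over scan definition "test".
        ABT scanNode = make<ScanNode>("root", "test");
        ABT filterNode = make<FilterNode>(
            make<EvalFilter>(
                make<PathGet>("a", make<PathCompare>(Operations::Eq, Constant::int64(1))),
                make<Variable>("root")),
            std::move(scanNode));
        ABT rootNode = make<RootNode>(
            properties::ProjectionRequirement{ProjectionNameVector{"root"}},
            std::move(filterNode));

        OptPhaseManager phaseManager(
            OptPhaseManager::getAllRewritesSet(),
            prefixId,
            {{{"test", {{}, {}}}}},  // metadata: a single scan definition, no indexes
            {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests});

        // optimize() rewrites the tree in place and reports success.
        return phaseManager.optimize(rootNode);
    }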
+ for (auto& [nodePtr, props] : _nodeToGroupPropsMap) { + if (properties::hasProperty<properties::ProjectionRequirement>(props._physicalProps)) { + auto& requiredProjNames = + properties::getProperty<properties::ProjectionRequirement>(props._physicalProps) + .getProjections(); + for (const ProjectionName& projName : erasedProjNames) { + requiredProjNames.erase(projName); + } + } + } + } + + env.rebuild(input); + if (env.hasFreeVariables()) { + return false; + } + + return true; +} + +bool OptPhaseManager::hasPhase(const OptPhase phase) const { + return _phaseSet.find(phase) != _phaseSet.cend(); +} + +const OptPhaseManager::PhaseSet& OptPhaseManager::getAllRewritesSet() { + return _allRewrites; +} + +MemoPhysicalNodeId OptPhaseManager::getPhysicalNodeId() const { + return _physicalNodeId; +} + +const QueryHints& OptPhaseManager::getHints() const { + return _hints; +} + +QueryHints& OptPhaseManager::getHints() { + return _hints; +} + +const Memo& OptPhaseManager::getMemo() const { + return _memo; +} + +const Metadata& OptPhaseManager::getMetadata() const { + return _metadata; +} + +PrefixId& OptPhaseManager::getPrefixId() const { + return _prefixId; +} + +const NodeToGroupPropsMap& OptPhaseManager::getNodeToGroupPropsMap() const { + return _nodeToGroupPropsMap; +} + +NodeToGroupPropsMap& OptPhaseManager::getNodeToGroupPropsMap() { + return _nodeToGroupPropsMap; +} + +const opt::unordered_map<std::string, ProjectionName>& OptPhaseManager::getRIDProjections() const { + return _ridProjections; +} + +} // namespace mongo::optimizer diff --git a/src/mongo/db/query/optimizer/opt_phase_manager.h b/src/mongo/db/query/optimizer/opt_phase_manager.h new file mode 100644 index 00000000000..c079c4f08aa --- /dev/null +++ b/src/mongo/db/query/optimizer/opt_phase_manager.h @@ -0,0 +1,181 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. 
+ */ + +#pragma once + +#include <unordered_set> + +#include "mongo/db/query/optimizer/cascades/interfaces.h" +#include "mongo/db/query/optimizer/cascades/logical_rewriter.h" +#include "mongo/db/query/optimizer/cascades/physical_rewriter.h" + +namespace mongo::optimizer { + +using namespace cascades; + +/** + * This class wraps together different optimization phases. + * First the transport rewrites are applied such as constant folding and redundant expression + * elimination. Second the logical and physical reordering rewrites are applied using the memo. + * Third the final transport rewritesd are applied. + */ +class OptPhaseManager { +public: + enum class OptPhase { + // ConstEval performs the following rewrites: constant folding, inlining, and dead code + // elimination. + ConstEvalPre, + PathFuse, + + // Memo phases below perform Cascades-style optimization. + // Reorder and transform nodes. Convert Filter and Eval nodes to SargableNodes, and possibly + // merge them. + MemoSubstitutionPhase, + // Performs Local-global and rewrites to enable index intersection. + // If there is an implementation phase, it runs integrated with the top-down optimization. + // If there is no implementation phase, it runs standalone. + MemoExplorationPhase, + // Implementation and enforcement rules. + MemoImplementationPhase, + + PathLower, + ConstEvalPost + }; + + using PhaseSet = opt::unordered_set<OptPhase>; + + OptPhaseManager(PhaseSet phaseSet, PrefixId& prefixId, Metadata metadata, DebugInfo debugInfo); + OptPhaseManager(PhaseSet phaseSet, + PrefixId& prefixId, + bool requireRID, + Metadata metadata, + std::unique_ptr<CEInterface> ceDerivation, + std::unique_ptr<CostingInterface> costDerivation, + DebugInfo debugInfo); + + /** + * Optimization modifies the input argument. + * Return result is true for successful optimization and false for failure. + */ + bool optimize(ABT& input); + + static const PhaseSet& getAllRewritesSet(); + + MemoPhysicalNodeId getPhysicalNodeId() const; + + const QueryHints& getHints() const; + QueryHints& getHints(); + + const Memo& getMemo() const; + + const Metadata& getMetadata() const; + + PrefixId& getPrefixId() const; + + const NodeToGroupPropsMap& getNodeToGroupPropsMap() const; + NodeToGroupPropsMap& getNodeToGroupPropsMap(); + + const opt::unordered_map<std::string, ProjectionName>& getRIDProjections() const; + +private: + bool hasPhase(OptPhase phase) const; + + template <OptPhase phase, class C> + bool runStructuralPhase(C instance, VariableEnvironment& env, ABT& input); + + /** + * Run two structural phases until mutual fixpoint. + * We assume we can construct from the types by initializing with env. + */ + template <const OptPhase phase1, const OptPhase phase2, class C1, class C2> + bool runStructuralPhases(C1 instance1, C2 instance2, VariableEnvironment& env, ABT& input); + + bool runMemoLogicalRewrite(OptPhase phase, + VariableEnvironment& env, + const LogicalRewriter::RewriteSet& rewriteSet, + GroupIdType& rootGroupId, + bool runStandalone, + std::unique_ptr<LogicalRewriter>& logicalRewriter, + ABT& input); + + bool runMemoPhysicalRewrite(OptPhase phase, + VariableEnvironment& env, + GroupIdType rootGroupId, + std::unique_ptr<LogicalRewriter>& logicalRewriter, + ABT& input); + + bool runMemoRewritePhases(VariableEnvironment& env, ABT& input); + + + static PhaseSet _allRewrites; + + const PhaseSet _phaseSet; + + const DebugInfo _debugInfo; + + QueryHints _hints; + + Metadata _metadata; + + /** + * Final state of the memo after physical rewrites are complete. 
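The memo and the winning plan's bookkeeping stay accessible after optimize() returns, which is what the unit tests in this patch rely on. A sketch of that post-optimization introspection, assuming optimize() already succeeded with an implementation phase:

    #include "mongo/db/query/optimizer/opt_phase_manager.h"

    namespace mongo::optimizer {

    auto inspectResult(const OptPhaseManager& phaseManager) {
        // The winning physical node: memo group id plus index within that group.
        [[maybe_unused]] const MemoPhysicalNodeId bestNodeId = phaseManager.getPhysicalNodeId();

        // Logical/physical properties and cost estimates per node of the extracted plan.
        [[maybe_unused]] const NodeToGroupPropsMap& nodeProps =
            phaseManager.getNodeToGroupPropsMap();

        // The unit tests assert on this counter to pin down how much of the search space
        // the physical rewriter explored.
        return phaseManager.getMemo().getStats()._physPlanExplorationCount;
    }

    }  // namespace mongo::optimizer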
+ */ + Memo _memo; + + /** + * Cost derivation function. + */ + std::unique_ptr<CostingInterface> _costDerivation; + + /** + * Root physical node if we have performed physical rewrites. + */ + MemoPhysicalNodeId _physicalNodeId; + + /** + * Map from node to logical and physical properties. + */ + NodeToGroupPropsMap _nodeToGroupPropsMap; + + /** + * Used to optimize update and delete statements. If set will include indexing requirement with + * seed physical properties. + */ + const bool _requireRID; + + /** + * RID projection names we have generated for each scanDef. Used for physical rewriting. + */ + opt::unordered_map<std::string, ProjectionName> _ridProjections; + + // We don't own this. + PrefixId& _prefixId; +}; + +} // namespace mongo::optimizer diff --git a/src/mongo/db/query/optimizer/optimizer_test.cpp b/src/mongo/db/query/optimizer/optimizer_test.cpp new file mode 100644 index 00000000000..856ad5e3d07 --- /dev/null +++ b/src/mongo/db/query/optimizer/optimizer_test.cpp @@ -0,0 +1,639 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include "mongo/db/query/optimizer/explain.h" +#include "mongo/db/query/optimizer/node.h" +#include "mongo/db/query/optimizer/reference_tracker.h" +#include "mongo/db/query/optimizer/rewrites/const_eval.h" +#include "mongo/db/query/optimizer/utils/unit_test_utils.h" +#include "mongo/db/query/optimizer/utils/utils.h" +#include "mongo/unittest/unittest.h" + +namespace mongo::optimizer { +namespace { + +TEST(Optimizer, ConstEval) { + // 1 + 2 + auto tree = make<BinaryOp>(Operations::Add, Constant::int64(1), Constant::int64(2)); + + // Run the evaluator. + auto env = VariableEnvironment::build(tree); + ConstEval evaluator{env}; + evaluator.optimize(tree); + + // The result must be Constant. + auto result = tree.cast<Constant>(); + ASSERT(result != nullptr); + + // And the value must be 3 (i.e. 1+2). 
+ ASSERT_EQ(result->getValueInt64(), 3); + + ASSERT_NE(ABT::tagOf<Constant>(), ABT::tagOf<BinaryOp>()); + ASSERT_EQ(tree.tagOf(), ABT::tagOf<Constant>()); +} + +TEST(Optimizer, ConstEvalCompose) { + // (1 + 2) + 3 + auto tree = + make<BinaryOp>(Operations::Add, + make<BinaryOp>(Operations::Add, Constant::int64(1), Constant::int64(2)), + Constant::int64(3)); + + // Run the evaluator. + auto env = VariableEnvironment::build(tree); + ConstEval evaluator{env}; + evaluator.optimize(tree); + + // The result must be Constant. + auto result = tree.cast<Constant>(); + ASSERT(result != nullptr); + + // And the value must be 6 (i.e. 1+2+3). + ASSERT_EQ(result->getValueInt64(), 6); +} + +TEST(Optimizer, Tracker1) { + ABT scanNode = make<ScanNode>("ptest", "test"); + ABT filterNode = make<FilterNode>( + make<EvalFilter>(make<PathConstant>(make<UnaryOp>(Operations::Neg, Constant::int64(1))), + make<Variable>("ptest")), + std::move(scanNode)); + ABT evalNode = make<EvaluationNode>( + "P1", + make<EvalPath>(make<PathConstant>(Constant::int64(2)), make<Variable>("ptest")), + std::move(filterNode)); + + ABT rootNode = + make<RootNode>(properties::ProjectionRequirement{ProjectionNameVector{"P1", "ptest"}}, + std::move(evalNode)); + + auto env = VariableEnvironment::build(rootNode); + ASSERT(!env.hasFreeVariables()); +} + +TEST(Optimizer, Tracker2) { + ABT expr = make<Let>("x", + Constant::int64(4), + make<BinaryOp>(Operations::Add, make<Variable>("x"), make<Variable>("x"))); + + auto env = VariableEnvironment::build(expr); + ConstEval evaluator{env}; + evaluator.optimize(expr); + + // The result must be Constant. + auto result = expr.cast<Constant>(); + ASSERT(result != nullptr); + + // And the value must be 8 (i.e. x+x = 4+4 = 8). + ASSERT_EQ(result->getValueInt64(), 8); +} + +TEST(Optimizer, Tracker3) { + ABT scanNode = make<ScanNode>("ptest", "test"); + ABT filterNode = make<FilterNode>(make<Variable>("free"), std::move(scanNode)); + ABT evalNode1 = make<EvaluationNode>("free", Constant::int64(5), std::move(filterNode)); + + auto env = VariableEnvironment::build(evalNode1); + // "free" must still be a free variable. + ASSERT(env.hasFreeVariables()); + ASSERT_EQ(env.freeOccurences("free"), 1); + + // Projecting "unrelated" must not resolve "free". + ABT evalNode2 = make<EvaluationNode>("unrelated", Constant::int64(5), std::move(evalNode1)); + + env.rebuild(evalNode2); + ASSERT(env.hasFreeVariables()); + ASSERT_EQ(env.freeOccurences("free"), 1); + + // Another expression referencing "free" will resolve. But the original "free" reference is + // unaffected (i.e. it is still a free variable). + ABT filterNode2 = make<FilterNode>(make<Variable>("free"), std::move(evalNode2)); + + env.rebuild(filterNode2); + ASSERT(env.hasFreeVariables()); + ASSERT_EQ(env.freeOccurences("free"), 1); +} + +TEST(Optimizer, Tracker4) { + ABT scanNode = make<ScanNode>("ptest", "test"); + auto scanNodeRef = scanNode.ref(); + ABT evalNode = make<EvaluationNode>("unrelated", Constant::int64(5), std::move(scanNode)); + ABT filterNode = make<FilterNode>(make<Variable>("ptest"), std::move(evalNode)); + + auto env = VariableEnvironment::build(filterNode); + ASSERT(!env.hasFreeVariables()); + + // Get all variables from the expression + auto vars = VariableEnvironment::getVariables(filterNode.cast<FilterNode>()->getFilter()); + ASSERT(vars._variables.size() == 1); + // Get all definitions from the scan and below (even though there is nothing below the scan). 
+ auto defs = env.getDefinitions(scanNodeRef); + // Make sure that variables are defined by the scan (and not by Eval). + for (auto v : vars._variables) { + auto it = defs.find(v->name()); + ASSERT(it != defs.end()); + ASSERT(it->second.definedBy == env.getDefinition(v).definedBy); + } +} + +TEST(Optimizer, RefExplain) { + ABT scanNode = make<ScanNode>("ptest", "test"); + ASSERT_EXPLAIN( + "Scan [test]\n" + " BindBlock:\n" + " [ptest]\n" + " Source []\n", + scanNode); + + // Now repeat for the reference type. + auto ref = scanNode.ref(); + ASSERT_EXPLAIN( + "Scan [test]\n" + " BindBlock:\n" + " [ptest]\n" + " Source []\n", + ref); + + ASSERT_EQ(scanNode.tagOf(), ref.tagOf()); +} + +TEST(Optimizer, CoScan) { + ABT coScanNode = make<CoScanNode>(); + ABT limitNode = + make<LimitSkipNode>(properties::LimitSkipRequirement(1, 0), std::move(coScanNode)); + + VariableEnvironment venv = VariableEnvironment::build(limitNode); + ASSERT_TRUE(!venv.hasFreeVariables()); + + ASSERT_EXPLAIN( + "LimitSkip []\n" + " limitSkip:\n" + " limit: 1\n" + " skip: 0\n" + " CoScan []\n", + limitNode); +} + +TEST(Optimizer, Basic) { + ABT scanNode = make<ScanNode>("ptest", "test"); + ASSERT_EXPLAIN( + "Scan [test]\n" + " BindBlock:\n" + " [ptest]\n" + " Source []\n", + scanNode); + + ABT filterNode = make<FilterNode>( + make<EvalFilter>(make<PathConstant>(make<UnaryOp>(Operations::Neg, Constant::int64(1))), + make<Variable>("ptest")), + std::move(scanNode)); + ABT evalNode = make<EvaluationNode>( + "P1", + make<EvalPath>(make<PathConstant>(Constant::int64(2)), make<Variable>("ptest")), + std::move(filterNode)); + + ABT rootNode = + make<RootNode>(properties::ProjectionRequirement{ProjectionNameVector{"P1", "ptest"}}, + std::move(evalNode)); + + ASSERT_EXPLAIN( + "Root []\n" + " projections: \n" + " P1\n" + " ptest\n" + " RefBlock: \n" + " Variable [P1]\n" + " Variable [ptest]\n" + " Evaluation []\n" + " BindBlock:\n" + " [P1]\n" + " EvalPath []\n" + " PathConstant []\n" + " Const [2]\n" + " Variable [ptest]\n" + " Filter []\n" + " EvalFilter []\n" + " PathConstant []\n" + " UnaryOp [Neg]\n" + " Const [1]\n" + " Variable [ptest]\n" + " Scan [test]\n" + " BindBlock:\n" + " [ptest]\n" + " Source []\n", + rootNode); + + + ABT clonedNode = rootNode; + ASSERT_EXPLAIN( + "Root []\n" + " projections: \n" + " P1\n" + " ptest\n" + " RefBlock: \n" + " Variable [P1]\n" + " Variable [ptest]\n" + " Evaluation []\n" + " BindBlock:\n" + " [P1]\n" + " EvalPath []\n" + " PathConstant []\n" + " Const [2]\n" + " Variable [ptest]\n" + " Filter []\n" + " EvalFilter []\n" + " PathConstant []\n" + " UnaryOp [Neg]\n" + " Const [1]\n" + " Variable [ptest]\n" + " Scan [test]\n" + " BindBlock:\n" + " [ptest]\n" + " Source []\n", + clonedNode); + + auto env = VariableEnvironment::build(rootNode); + ProjectionNameSet set = env.topLevelProjections(); + ProjectionNameSet expSet = {"P1", "ptest"}; + ASSERT(expSet == set); + ASSERT(!env.hasFreeVariables()); +} + +TEST(Optimizer, GroupBy) { + ABT scanNode = make<ScanNode>("ptest", "test"); + ABT evalNode1 = make<EvaluationNode>("p1", Constant::int64(1), std::move(scanNode)); + ABT evalNode2 = make<EvaluationNode>("p2", Constant::int64(2), std::move(evalNode1)); + ABT evalNode3 = make<EvaluationNode>("p3", Constant::int64(3), std::move(evalNode2)); + + { + auto env = VariableEnvironment::build(evalNode3); + ProjectionNameSet projSet = env.topLevelProjections(); + ProjectionNameSet expSet = {"p1", "p2", "p3", "ptest"}; + ASSERT(expSet == projSet); + ASSERT(!env.hasFreeVariables()); + } + + ABT agg1 = 
Constant::int64(10); + ABT agg2 = Constant::int64(11); + ABT groupByNode = make<GroupByNode>(ProjectionNameVector{"p1", "p2"}, + ProjectionNameVector{"a1", "a2"}, + makeSeq(std::move(agg1), std::move(agg2)), + std::move(evalNode3)); + + ABT rootNode = make<RootNode>( + properties::ProjectionRequirement{ProjectionNameVector{"p1", "p2", "a1", "a2"}}, + std::move(groupByNode)); + + ASSERT_EXPLAIN( + "Root []\n" + " projections: \n" + " p1\n" + " p2\n" + " a1\n" + " a2\n" + " RefBlock: \n" + " Variable [a1]\n" + " Variable [a2]\n" + " Variable [p1]\n" + " Variable [p2]\n" + " GroupBy []\n" + " groupings: \n" + " RefBlock: \n" + " Variable [p1]\n" + " Variable [p2]\n" + " aggregations: \n" + " [a1]\n" + " Const [10]\n" + " [a2]\n" + " Const [11]\n" + " Evaluation []\n" + " BindBlock:\n" + " [p3]\n" + " Const [3]\n" + " Evaluation []\n" + " BindBlock:\n" + " [p2]\n" + " Const [2]\n" + " Evaluation []\n" + " BindBlock:\n" + " [p1]\n" + " Const [1]\n" + " Scan [test]\n" + " BindBlock:\n" + " [ptest]\n" + " Source []\n", + rootNode); + + { + auto env = VariableEnvironment::build(rootNode); + ProjectionNameSet projSet = env.topLevelProjections(); + ProjectionNameSet expSet = {"p1", "p2", "a1", "a2"}; + ASSERT(expSet == projSet); + ASSERT(!env.hasFreeVariables()); + } +} + +TEST(Optimizer, Union) { + ABT scanNode1 = make<ScanNode>("ptest", "test"); + ABT projNode1 = make<EvaluationNode>("B", Constant::int64(3), std::move(scanNode1)); + ABT scanNode2 = make<ScanNode>("ptest", "test"); + ABT projNode2 = make<EvaluationNode>("B", Constant::int64(4), std::move(scanNode2)); + ABT scanNode3 = make<ScanNode>("ptest1", "test"); + ABT evalNode = make<EvaluationNode>( + "ptest", + make<EvalPath>(make<PathConstant>(Constant::int64(2)), make<Variable>("ptest1")), + std::move(scanNode3)); + ABT projNode3 = make<EvaluationNode>("B", Constant::int64(5), std::move(evalNode)); + + + ABT unionNode = make<UnionNode>(ProjectionNameVector{"ptest", "B"}, + makeSeq(projNode1, projNode2, projNode3)); + + ABT rootNode = + make<RootNode>(properties::ProjectionRequirement{ProjectionNameVector{"ptest", "B"}}, + std::move(unionNode)); + + { + auto env = VariableEnvironment::build(rootNode); + ProjectionNameSet projSet = env.topLevelProjections(); + ProjectionNameSet expSet = {"ptest", "B"}; + ASSERT(expSet == projSet); + ASSERT(!env.hasFreeVariables()); + } + + ASSERT_EXPLAIN( + "Root []\n" + " projections: \n" + " ptest\n" + " B\n" + " RefBlock: \n" + " Variable [B]\n" + " Variable [ptest]\n" + " Union []\n" + " BindBlock:\n" + " [B]\n" + " Source []\n" + " [ptest]\n" + " Source []\n" + " Evaluation []\n" + " BindBlock:\n" + " [B]\n" + " Const [3]\n" + " Scan [test]\n" + " BindBlock:\n" + " [ptest]\n" + " Source []\n" + " Evaluation []\n" + " BindBlock:\n" + " [B]\n" + " Const [4]\n" + " Scan [test]\n" + " BindBlock:\n" + " [ptest]\n" + " Source []\n" + " Evaluation []\n" + " BindBlock:\n" + " [B]\n" + " Const [5]\n" + " Evaluation []\n" + " BindBlock:\n" + " [ptest]\n" + " EvalPath []\n" + " PathConstant []\n" + " Const [2]\n" + " Variable [ptest1]\n" + " Scan [test]\n" + " BindBlock:\n" + " [ptest1]\n" + " Source []\n", + rootNode); +} + +TEST(Optimizer, Unwind) { + ABT scanNode = make<ScanNode>("p1", "test"); + ABT evalNode = make<EvaluationNode>( + "p2", + make<EvalPath>(make<PathConstant>(Constant::int64(2)), make<Variable>("p1")), + std::move(scanNode)); + ABT unwindNode = make<UnwindNode>("p2", "p2pid", true /*retainNonArrays*/, std::move(evalNode)); + + // Make a copy of unwindNode as it will be used later again in the 
wind test. + ABT rootNode = make<RootNode>( + properties::ProjectionRequirement{ProjectionNameVector{"p1", "p2", "p2pid"}}, unwindNode); + + { + auto env = VariableEnvironment::build(rootNode); + ProjectionNameSet projSet = env.topLevelProjections(); + ProjectionNameSet expSet = {"p1", "p2", "p2pid"}; + ASSERT(expSet == projSet); + ASSERT(!env.hasFreeVariables()); + } + + ASSERT_EXPLAIN( + "Root []\n" + " projections: \n" + " p1\n" + " p2\n" + " p2pid\n" + " RefBlock: \n" + " Variable [p1]\n" + " Variable [p2]\n" + " Variable [p2pid]\n" + " Unwind [retainNonArrays]\n" + " BindBlock:\n" + " [p2]\n" + " Source []\n" + " [p2pid]\n" + " Source []\n" + " Evaluation []\n" + " BindBlock:\n" + " [p2]\n" + " EvalPath []\n" + " PathConstant []\n" + " Const [2]\n" + " Variable [p1]\n" + " Scan [test]\n" + " BindBlock:\n" + " [p1]\n" + " Source []\n", + rootNode); +} + +TEST(Optimizer, Collation) { + ABT scanNode = make<ScanNode>("a", "test"); + ABT evalNode = make<EvaluationNode>( + "b", + make<EvalPath>(make<PathConstant>(Constant::int64(2)), make<Variable>("a")), + std::move(scanNode)); + + ABT collationNode = + make<CollationNode>(properties::CollationRequirement( + {{"a", CollationOp::Ascending}, {"b", CollationOp::Clustered}}), + std::move(evalNode)); + { + auto env = VariableEnvironment::build(collationNode); + ProjectionNameSet projSet = env.topLevelProjections(); + ProjectionNameSet expSet = {"a", "b"}; + ASSERT(expSet == projSet); + ASSERT(!env.hasFreeVariables()); + } + + ASSERT_EXPLAIN( + "Collation []\n" + " collation: \n" + " a: Ascending\n" + " b: Clustered\n" + " RefBlock: \n" + " Variable [a]\n" + " Variable [b]\n" + " Evaluation []\n" + " BindBlock:\n" + " [b]\n" + " EvalPath []\n" + " PathConstant []\n" + " Const [2]\n" + " Variable [a]\n" + " Scan [test]\n" + " BindBlock:\n" + " [a]\n" + " Source []\n", + collationNode); +} + +TEST(Optimizer, LimitSkip) { + ABT scanNode = make<ScanNode>("a", "test"); + ABT evalNode = make<EvaluationNode>( + "b", + make<EvalPath>(make<PathConstant>(Constant::int64(2)), make<Variable>("a")), + std::move(scanNode)); + + ABT limitSkipNode = + make<LimitSkipNode>(properties::LimitSkipRequirement(10, 20), std::move(evalNode)); + { + auto env = VariableEnvironment::build(limitSkipNode); + ProjectionNameSet projSet = env.topLevelProjections(); + ProjectionNameSet expSet = {"a", "b"}; + ASSERT(expSet == projSet); + ASSERT(!env.hasFreeVariables()); + } + + ASSERT_EXPLAIN( + "LimitSkip []\n" + " limitSkip:\n" + " limit: 10\n" + " skip: 20\n" + " Evaluation []\n" + " BindBlock:\n" + " [b]\n" + " EvalPath []\n" + " PathConstant []\n" + " Const [2]\n" + " Variable [a]\n" + " Scan [test]\n" + " BindBlock:\n" + " [a]\n" + " Source []\n", + limitSkipNode); +} + +TEST(Optimizer, Distribution) { + ABT scanNode = make<ScanNode>("a", "test"); + ABT evalNode = make<EvaluationNode>( + "b", + make<EvalPath>(make<PathConstant>(Constant::int64(2)), make<Variable>("a")), + std::move(scanNode)); + + ABT exchangeNode = make<ExchangeNode>( + properties::DistributionRequirement({DistributionType::HashPartitioning, {"b"}}), + std::move(evalNode)); + + ASSERT_EXPLAIN( + "Exchange []\n" + " distribution: \n" + " type: HashPartitioning\n" + " projections: \n" + " b\n" + " RefBlock: \n" + " Variable [b]\n" + " Evaluation []\n" + " BindBlock:\n" + " [b]\n" + " EvalPath []\n" + " PathConstant []\n" + " Const [2]\n" + " Variable [a]\n" + " Scan [test]\n" + " BindBlock:\n" + " [a]\n" + " Source []\n", + exchangeNode); +} + +TEST(Properties, Basic) { + using namespace properties; + + 
CollationRequirement collation1( + {{"p1", CollationOp::Ascending}, {"p2", CollationOp::Descending}}); + CollationRequirement collation2( + {{"p1", CollationOp::Ascending}, {"p2", CollationOp::Clustered}}); + ASSERT_TRUE(collationsCompatible(collation1.getCollationSpec(), collation2.getCollationSpec())); + ASSERT_FALSE( + collationsCompatible(collation2.getCollationSpec(), collation1.getCollationSpec())); + + PhysProps props; + ASSERT_FALSE(hasProperty<CollationRequirement>(props)); + ASSERT_TRUE(setProperty(props, collation1)); + ASSERT_TRUE(hasProperty<CollationRequirement>(props)); + ASSERT_FALSE(setProperty(props, collation2)); + ASSERT_TRUE(collation1 == getProperty<CollationRequirement>(props)); + + LimitSkipRequirement ls(10, 20); + ASSERT_FALSE(hasProperty<LimitSkipRequirement>(props)); + ASSERT_TRUE(setProperty(props, ls)); + ASSERT_TRUE(hasProperty<LimitSkipRequirement>(props)); + ASSERT_TRUE(ls == getProperty<LimitSkipRequirement>(props)); + + LimitSkipRequirement ls1(-1, 10); + LimitSkipRequirement ls2(5, 0); + { + LimitSkipRequirement ls3 = ls2; + combineLimitSkipProperties(ls3, ls1); + ASSERT_EQ(5, ls3.getLimit()); + ASSERT_EQ(10, ls3.getSkip()); + } + { + LimitSkipRequirement ls3 = ls1; + combineLimitSkipProperties(ls3, ls2); + ASSERT_EQ(0, ls3.getLimit()); + ASSERT_EQ(0, ls3.getSkip()); + } +} + +} // namespace +} // namespace mongo::optimizer diff --git a/src/mongo/db/query/optimizer/physical_rewriter_optimizer_test.cpp b/src/mongo/db/query/optimizer/physical_rewriter_optimizer_test.cpp new file mode 100644 index 00000000000..80ee0f07ebb --- /dev/null +++ b/src/mongo/db/query/optimizer/physical_rewriter_optimizer_test.cpp @@ -0,0 +1,4659 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. 
+ */ + +#include "mongo/db/query/optimizer/cascades/ce_heuristic.h" +#include "mongo/db/query/optimizer/cascades/ce_hinted.h" +#include "mongo/db/query/optimizer/cascades/cost_derivation.h" +#include "mongo/db/query/optimizer/explain.h" +#include "mongo/db/query/optimizer/metadata.h" +#include "mongo/db/query/optimizer/node.h" +#include "mongo/db/query/optimizer/opt_phase_manager.h" +#include "mongo/db/query/optimizer/utils/unit_test_utils.h" +#include "mongo/unittest/unittest.h" + +namespace mongo::optimizer { +namespace { + +TEST(PhysRewriter, PhysicalRewriterBasic) { + using namespace properties; + PrefixId prefixId; + + ABT scanNode = make<ScanNode>("p1", "test"); + + ABT projectionNode1 = make<EvaluationNode>( + "p2", make<EvalPath>(make<PathIdentity>(), make<Variable>("p1")), std::move(scanNode)); + + ABT filter1Node = make<FilterNode>(make<EvalFilter>(make<PathIdentity>(), make<Variable>("p1")), + std::move(projectionNode1)); + + ABT filter2Node = make<FilterNode>( + make<EvalFilter>(make<PathGet>("a", make<PathCompare>(Operations::Eq, Constant::int64(1))), + make<Variable>("p2")), + std::move(filter1Node)); + + ABT rootNode = + make<RootNode>(ProjectionRequirement{ProjectionNameVector{"p2"}}, std::move(filter2Node)); + + OptPhaseManager phaseManager( + {OptPhaseManager::OptPhase::MemoSubstitutionPhase, + OptPhaseManager::OptPhase::MemoExplorationPhase, + OptPhaseManager::OptPhase::MemoImplementationPhase}, + prefixId, + {{{"test", {{}, {}}}}}, + {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests}); + + ABT optimized = std::move(rootNode); + ASSERT_TRUE(phaseManager.optimize(optimized)); + { + auto env = VariableEnvironment::build(optimized); + ProjectionNameSet expSet = {"p1", "p2"}; + ASSERT_TRUE(expSet == env.topLevelProjections()); + } + ASSERT_EQ(5, phaseManager.getMemo().getStats()._physPlanExplorationCount); + + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | p2\n" + "| RefBlock: \n" + "| Variable [p2]\n" + "Filter []\n" + "| EvalFilter []\n" + "| | Variable [p2]\n" + "| PathGet [a]\n" + "| PathCompare [Eq]\n" + "| Const [1]\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [p2]\n" + "| EvalPath []\n" + "| | Variable [p1]\n" + "| PathIdentity []\n" + "Filter []\n" + "| EvalFilter []\n" + "| | Variable [p1]\n" + "| PathIdentity []\n" + "PhysicalScan [{'<root>': p1}, test]\n" + " BindBlock:\n" + " [p1]\n" + " Source []\n", + optimized); + + // Plan output with properties. 
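The properties rendered by ASSERT_EXPLAIN_PROPS_V2 below can also be read programmatically from the phase manager via the NodeProps fields declared in node_defs.h. A sketch, reusing the phaseManager from this test; the bindings are for illustration only:

    for (const auto& [node, props] : phaseManager.getNodeToGroupPropsMap()) {
        const CostType& totalCost = props._cost;       // best cost, including the subtree
        const CostType& localCost = props._localCost;  // cost of this node alone
        const CEType adjustedCE = props._adjustedCE;   // cardinality shown as "adjustedCE" below

        if (properties::hasProperty<properties::ProjectionRequirement>(props._physicalProps)) {
            // Projections this node is required to produce.
            const auto& required =
                properties::getProperty<properties::ProjectionRequirement>(props._physicalProps)
                    .getProjections();
        }
    }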
+ ASSERT_EXPLAIN_PROPS_V2( + "Properties [cost: 1.02, localCost: 0, adjustedCE: 10]\n" + "| | Logical:\n" + "| | cardinalityEstimate: \n" + "| | ce: 10\n" + "| | projections: \n" + "| | p1\n" + "| | p2\n" + "| | indexingAvailability: \n" + "| | [groupId: 0, scanProjection: p1, scanDefName: test]\n" + "| | collectionAvailability: \n" + "| | test\n" + "| | distributionAvailability: \n" + "| | distribution: \n" + "| | type: Centralized\n" + "| Physical:\n" + "| distribution: \n" + "| type: Centralized\n" + "| indexingRequirement: \n" + "| Complete, dedupRID\n" + "Root []\n" + "| | projections: \n" + "| | p2\n" + "| RefBlock: \n" + "| Variable [p2]\n" + "Properties [cost: 1.02, localCost: 0.020001, adjustedCE: 10]\n" + "| | Logical:\n" + "| | cardinalityEstimate: \n" + "| | ce: 10\n" + "| | requirementCEs: \n" + "| | refProjection: p2, path: 'PathGet [a] PathIdentity []', ce: 10\n" + "| | projections: \n" + "| | p1\n" + "| | p2\n" + "| | indexingAvailability: \n" + "| | [groupId: 0, scanProjection: p1, scanDefName: test]\n" + "| | collectionAvailability: \n" + "| | test\n" + "| | distributionAvailability: \n" + "| | distribution: \n" + "| | type: Centralized\n" + "| Physical:\n" + "| projections: \n" + "| p2\n" + "| distribution: \n" + "| type: Centralized\n" + "| indexingRequirement: \n" + "| Complete, dedupRID\n" + "Filter []\n" + "| EvalFilter []\n" + "| | Variable [p2]\n" + "| PathGet [a]\n" + "| PathCompare [Eq]\n" + "| Const [1]\n" + "Properties [cost: 1, localCost: 0.200001, adjustedCE: 100]\n" + "| | Logical:\n" + "| | cardinalityEstimate: \n" + "| | ce: 100\n" + "| | projections: \n" + "| | p1\n" + "| | p2\n" + "| | indexingAvailability: \n" + "| | [groupId: 0, scanProjection: p1, scanDefName: test]\n" + "| | collectionAvailability: \n" + "| | test\n" + "| | distributionAvailability: \n" + "| | distribution: \n" + "| | type: Centralized\n" + "| Physical:\n" + "| projections: \n" + "| p2\n" + "| distribution: \n" + "| type: Centralized, disableExchanges\n" + "| indexingRequirement: \n" + "| Complete, dedupRID\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [p2]\n" + "| EvalPath []\n" + "| | Variable [p1]\n" + "| PathIdentity []\n" + "Properties [cost: 0.800002, localCost: 0.200001, adjustedCE: 100]\n" + "| | Logical:\n" + "| | cardinalityEstimate: \n" + "| | ce: 100\n" + "| | projections: \n" + "| | p1\n" + "| | indexingAvailability: \n" + "| | [groupId: 0, scanProjection: p1, scanDefName: test]\n" + "| | collectionAvailability: \n" + "| | test\n" + "| | distributionAvailability: \n" + "| | distribution: \n" + "| | type: Centralized\n" + "| Physical:\n" + "| projections: \n" + "| p1\n" + "| distribution: \n" + "| type: Centralized, disableExchanges\n" + "| indexingRequirement: \n" + "| Complete, dedupRID\n" + "Filter []\n" + "| EvalFilter []\n" + "| | Variable [p1]\n" + "| PathIdentity []\n" + "Properties [cost: 0.600001, localCost: 0.600001, adjustedCE: 1000]\n" + "| | Logical:\n" + "| | cardinalityEstimate: \n" + "| | ce: 1000\n" + "| | projections: \n" + "| | p1\n" + "| | indexingAvailability: \n" + "| | [groupId: 0, scanProjection: p1, scanDefName: test, possiblyEqPredsOnly]\n" + "| | collectionAvailability: \n" + "| | test\n" + "| | distributionAvailability: \n" + "| | distribution: \n" + "| | type: Centralized\n" + "| Physical:\n" + "| projections: \n" + "| p1\n" + "| distribution: \n" + "| type: Centralized, disableExchanges\n" + "| indexingRequirement: \n" + "| Complete, dedupRID\n" + "PhysicalScan [{'<root>': p1}, test]\n" + " BindBlock:\n" + " [p1]\n" + " Source []\n", + 
phaseManager); +} + +TEST(PhysRewriter, GroupBy) { + using namespace properties; + PrefixId prefixId; + + ABT scanNode = make<ScanNode>("ptest", "test"); + + ABT projectionANode = make<EvaluationNode>( + "a", make<EvalPath>(make<PathIdentity>(), make<Variable>("ptest")), std::move(scanNode)); + ABT projectionBNode = + make<EvaluationNode>("b", + make<EvalPath>(make<PathIdentity>(), make<Variable>("ptest")), + std::move(projectionANode)); + + ABT groupByNode = make<GroupByNode>(ProjectionNameVector{"a"}, + ProjectionNameVector{"c"}, + makeSeq(make<Variable>("b")), + std::move(projectionBNode)); + + ABT filterCNode = make<FilterNode>(make<EvalFilter>(make<PathIdentity>(), make<Variable>("c")), + std::move(groupByNode)); + + ABT filterANode = make<FilterNode>(make<EvalFilter>(make<PathIdentity>(), make<Variable>("a")), + std::move(filterCNode)); + + ABT rootNode = + make<RootNode>(ProjectionRequirement{ProjectionNameVector{"c"}}, std::move(filterANode)); + + OptPhaseManager phaseManager( + {OptPhaseManager::OptPhase::MemoSubstitutionPhase, + OptPhaseManager::OptPhase::MemoExplorationPhase, + OptPhaseManager::OptPhase::MemoImplementationPhase}, + prefixId, + {{{"test", {{}, {}}}}}, + {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests}); + + ABT optimized = std::move(rootNode); + ASSERT_TRUE(phaseManager.optimize(optimized)); + ASSERT_EQ(7, phaseManager.getMemo().getStats()._physPlanExplorationCount); + + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | c\n" + "| RefBlock: \n" + "| Variable [c]\n" + "Filter []\n" + "| EvalFilter []\n" + "| | Variable [a]\n" + "| PathIdentity []\n" + "Filter []\n" + "| EvalFilter []\n" + "| | Variable [c]\n" + "| PathIdentity []\n" + "GroupBy []\n" + "| | groupings: \n" + "| | RefBlock: \n" + "| | Variable [a]\n" + "| aggregations: \n" + "| [c]\n" + "| Variable [b]\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [b]\n" + "| EvalPath []\n" + "| | Variable [ptest]\n" + "| PathIdentity []\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [a]\n" + "| EvalPath []\n" + "| | Variable [ptest]\n" + "| PathIdentity []\n" + "PhysicalScan [{'<root>': ptest}, test]\n" + " BindBlock:\n" + " [ptest]\n" + " Source []\n", + optimized); +} + +TEST(PhysRewriter, GroupBy1) { + using namespace properties; + PrefixId prefixId; + + ABT scanNode = make<ScanNode>("ptest", "test"); + + ABT projectionANode = make<EvaluationNode>("pa", Constant::null(), std::move(scanNode)); + ABT projectionA1Node = + make<EvaluationNode>("pa1", Constant::null(), std::move(projectionANode)); + + ABT groupByNode = make<GroupByNode>(ProjectionNameVector{}, + ProjectionNameVector{"pb", "pb1"}, + makeSeq(make<Variable>("pa"), make<Variable>("pa1")), + std::move(projectionA1Node)); + + ABT rootNode = + make<RootNode>(ProjectionRequirement{ProjectionNameVector{"pb"}}, std::move(groupByNode)); + + OptPhaseManager phaseManager( + {OptPhaseManager::OptPhase::MemoSubstitutionPhase, + OptPhaseManager::OptPhase::MemoExplorationPhase, + OptPhaseManager::OptPhase::MemoImplementationPhase}, + prefixId, + {{{"test", {{}, {}}}}}, + {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests}); + + ABT optimized = std::move(rootNode); + ASSERT_TRUE(phaseManager.optimize(optimized)); + ASSERT_EQ(5, phaseManager.getMemo().getStats()._physPlanExplorationCount); + + // Projection "pb1" is unused and we do not generate an aggregation expression for it. 
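A hypothetical variant of the construction above, not part of this test: requiring both aggregation outputs at the root should keep both expressions alive instead of pruning "pb1":

    ABT rootNodeBoth =
        make<RootNode>(ProjectionRequirement{ProjectionNameVector{"pb", "pb1"}},
                       std::move(groupByNode) /*the same group-by subtree as above*/);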
+ ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | pb\n" + "| RefBlock: \n" + "| Variable [pb]\n" + "GroupBy []\n" + "| | groupings: \n" + "| | RefBlock: \n" + "| aggregations: \n" + "| [pb]\n" + "| Variable [pa]\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [pa]\n" + "| Const [null]\n" + "PhysicalScan [{}, test]\n" + " BindBlock:\n", + optimized); +} + +TEST(PhysRewriter, Unwind) { + using namespace properties; + PrefixId prefixId; + + ABT scanNode = make<ScanNode>("ptest", "test"); + + ABT projectionANode = make<EvaluationNode>( + "a", make<EvalPath>(make<PathIdentity>(), make<Variable>("ptest")), std::move(scanNode)); + ABT projectionBNode = + make<EvaluationNode>("b", + make<EvalPath>(make<PathIdentity>(), make<Variable>("ptest")), + std::move(projectionANode)); + + ABT unwindNode = + make<UnwindNode>("a", "a_pid", false /*retainNonArrays*/, std::move(projectionBNode)); + + // This filter should stay above the unwind. + ABT filterANode = make<FilterNode>(make<EvalFilter>(make<PathIdentity>(), make<Variable>("a")), + std::move(unwindNode)); + + // This filter should be pushed down below the unwind. + ABT filterBNode = make<FilterNode>(make<EvalFilter>(make<PathIdentity>(), make<Variable>("b")), + std::move(filterANode)); + + ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"a", "b"}}, + std::move(filterBNode)); + + OptPhaseManager phaseManager( + {OptPhaseManager::OptPhase::MemoSubstitutionPhase, + OptPhaseManager::OptPhase::MemoExplorationPhase, + OptPhaseManager::OptPhase::MemoImplementationPhase}, + prefixId, + {{{"test", {{}, {}}}}}, + {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests}); + + ABT optimized = std::move(rootNode); + ASSERT_TRUE(phaseManager.optimize(optimized)); + ASSERT_EQ(7, phaseManager.getMemo().getStats()._physPlanExplorationCount); + + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | a\n" + "| | b\n" + "| RefBlock: \n" + "| Variable [a]\n" + "| Variable [b]\n" + "Filter []\n" + "| EvalFilter []\n" + "| | Variable [b]\n" + "| PathIdentity []\n" + "Filter []\n" + "| EvalFilter []\n" + "| | Variable [a]\n" + "| PathIdentity []\n" + "Unwind []\n" + "| BindBlock:\n" + "| [a]\n" + "| Source []\n" + "| [a_pid]\n" + "| Source []\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [b]\n" + "| EvalPath []\n" + "| | Variable [ptest]\n" + "| PathIdentity []\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [a]\n" + "| EvalPath []\n" + "| | Variable [ptest]\n" + "| PathIdentity []\n" + "PhysicalScan [{'<root>': ptest}, test]\n" + " BindBlock:\n" + " [ptest]\n" + " Source []\n", + optimized); +} + +TEST(PhysRewriter, DuplicateFilter) { + using namespace properties; + PrefixId prefixId; + + ABT scanNode = make<ScanNode>("root", "c1"); + + ABT filterNode1 = make<FilterNode>( + make<EvalFilter>( + make<PathGet>( + "a", make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(0)))), + make<Variable>("root")), + std::move(scanNode)); + + ABT filterNode2 = make<FilterNode>( + make<EvalFilter>( + make<PathGet>( + "a", make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(0)))), + make<Variable>("root")), + std::move(filterNode1)); + + ABT rootNode = + make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(filterNode2)); + + OptPhaseManager phaseManager( + {OptPhaseManager::OptPhase::MemoSubstitutionPhase, + OptPhaseManager::OptPhase::MemoExplorationPhase, + OptPhaseManager::OptPhase::MemoImplementationPhase}, + prefixId, + {{{"c1", {{}, {}}}}}, + {true 
/*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests}); + + ABT optimized = std::move(rootNode); + ASSERT_TRUE(phaseManager.optimize(optimized)); + ASSERT_EQ(2, phaseManager.getMemo().getStats()._physPlanExplorationCount); + + // Only one copy of the filter. + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | root\n" + "| RefBlock: \n" + "| Variable [root]\n" + "Filter []\n" + "| EvalFilter []\n" + "| | Variable [evalTemp_0]\n" + "| PathTraverse []\n" + "| PathCompare [Eq]\n" + "| Const [0]\n" + "PhysicalScan [{'<root>': root, 'a': evalTemp_0}, c1]\n" + " BindBlock:\n" + " [evalTemp_0]\n" + " Source []\n" + " [root]\n" + " Source []\n", + optimized); +} + +TEST(PhysRewriter, FilterCollation) { + using namespace properties; + PrefixId prefixId; + + ABT scanNode = make<ScanNode>("root", "c1"); + + ABT evalNode = make<EvaluationNode>( + "pb", + make<EvalPath>(make<PathGet>("b", make<PathIdentity>()), make<Variable>("root")), + std::move(scanNode)); + + ABT filterNode = make<FilterNode>( + make<EvalFilter>( + make<PathGet>( + "a", make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(1)))), + make<Variable>("root")), + std::move(evalNode)); + + ABT collationNode = make<CollationNode>(CollationRequirement({{"pb", CollationOp::Ascending}}), + std::move(filterNode)); + + ABT limitSkipNode = make<LimitSkipNode>(LimitSkipRequirement{10, 0}, std::move(collationNode)); + + ABT rootNode = + make<RootNode>(ProjectionRequirement{ProjectionNameVector{"pb"}}, std::move(limitSkipNode)); + + OptPhaseManager phaseManager( + {OptPhaseManager::OptPhase::MemoSubstitutionPhase, + OptPhaseManager::OptPhase::MemoExplorationPhase, + OptPhaseManager::OptPhase::MemoImplementationPhase}, + prefixId, + {{{"c1", {{}, {}}}}}, + {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests}); + + ABT optimized = std::move(rootNode); + ASSERT_TRUE(phaseManager.optimize(optimized)); + ASSERT_BETWEEN(9, 11, phaseManager.getMemo().getStats()._physPlanExplorationCount); + + // Limit-skip is attached to the collation node by virtue of physical props. 
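A sketch of checking the same thing programmatically rather than through the explain string, reusing the phaseManager from this test; whether the requirement is recorded on the winning node's physical properties exactly this way is an assumption:

    bool foundLimitSkip = false;
    for (const auto& [node, props] : phaseManager.getNodeToGroupPropsMap()) {
        if (properties::hasProperty<properties::LimitSkipRequirement>(props._physicalProps)) {
            const auto& ls =
                properties::getProperty<properties::LimitSkipRequirement>(props._physicalProps);
            foundLimitSkip = foundLimitSkip || (ls.getLimit() == 10 && ls.getSkip() == 0);
        }
    }
    ASSERT_TRUE(foundLimitSkip);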
+ ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | pb\n" + "| RefBlock: \n" + "| Variable [pb]\n" + "Collation []\n" + "| | collation: \n" + "| | pb: Ascending\n" + "| RefBlock: \n" + "| Variable [pb]\n" + "Filter []\n" + "| EvalFilter []\n" + "| | Variable [evalTemp_0]\n" + "| PathTraverse []\n" + "| PathCompare [Eq]\n" + "| Const [1]\n" + "PhysicalScan [{'a': evalTemp_0, 'b': pb}, c1]\n" + " BindBlock:\n" + " [evalTemp_0]\n" + " Source []\n" + " [pb]\n" + " Source []\n", + optimized); +} + +TEST(PhysRewriter, EvalCollation) { + using namespace properties; + PrefixId prefixId; + + ABT scanNode = make<ScanNode>("root", "c1"); + + ABT evalNode = make<EvaluationNode>( + "pa", + make<EvalPath>(make<PathGet>("a", make<PathIdentity>()), make<Variable>("root")), + std::move(scanNode)); + + ABT collationNode = make<CollationNode>(CollationRequirement({{"pa", CollationOp::Ascending}}), + std::move(evalNode)); + + ABT rootNode = + make<RootNode>(ProjectionRequirement{ProjectionNameVector{"pa"}}, std::move(collationNode)); + + OptPhaseManager phaseManager( + {OptPhaseManager::OptPhase::MemoSubstitutionPhase, + OptPhaseManager::OptPhase::MemoExplorationPhase, + OptPhaseManager::OptPhase::MemoImplementationPhase}, + prefixId, + {{{"c1", {{}, {}}}}}, + {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests}); + + ABT optimized = std::move(rootNode); + ASSERT_TRUE(phaseManager.optimize(optimized)); + ASSERT_EQ(4, phaseManager.getMemo().getStats()._physPlanExplorationCount); + + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | pa\n" + "| RefBlock: \n" + "| Variable [pa]\n" + "Collation []\n" + "| | collation: \n" + "| | pa: Ascending\n" + "| RefBlock: \n" + "| Variable [pa]\n" + "PhysicalScan [{'a': pa}, c1]\n" + " BindBlock:\n" + " [pa]\n" + " Source []\n", + optimized); +} + +TEST(PhysRewriter, FilterEvalCollation) { + using namespace properties; + PrefixId prefixId; + + ABT scanNode = make<ScanNode>("root", "c1"); + + ABT filterNode = make<FilterNode>( + make<EvalFilter>( + make<PathGet>( + "a", make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(10)))), + make<Variable>("root")), + std::move(scanNode)); + + ABT evalNode = make<EvaluationNode>( + "pa", + make<EvalPath>(make<PathGet>("a", make<PathIdentity>()), make<Variable>("root")), + std::move(filterNode)); + + ABT collationNode = make<CollationNode>(CollationRequirement({{"pa", CollationOp::Ascending}}), + std::move(evalNode)); + + ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, + std::move(collationNode)); + + OptPhaseManager phaseManager( + {OptPhaseManager::OptPhase::MemoSubstitutionPhase, + OptPhaseManager::OptPhase::MemoExplorationPhase, + OptPhaseManager::OptPhase::MemoImplementationPhase}, + prefixId, + {{{"c1", {{}, {}}}}}, + {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests}); + + ABT optimized = std::move(rootNode); + ASSERT_TRUE(phaseManager.optimize(optimized)); + ASSERT_EQ(4, phaseManager.getMemo().getStats()._physPlanExplorationCount); + + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | root\n" + "| RefBlock: \n" + "| Variable [root]\n" + "Collation []\n" + "| | collation: \n" + "| | pa: Ascending\n" + "| RefBlock: \n" + "| Variable [pa]\n" + "Filter []\n" + "| EvalFilter []\n" + "| | Variable [pa]\n" + "| PathTraverse []\n" + "| PathCompare [Eq]\n" + "| Const [10]\n" + "PhysicalScan [{'<root>': root, 'a': pa}, c1]\n" + " BindBlock:\n" + " [pa]\n" + " Source []\n" + " [root]\n" + " Source []\n", 
+ optimized); +} + +TEST(PhysRewriter, FilterIndexing) { + using namespace properties; + + ABT scanNode = make<ScanNode>("root", "c1"); + + ABT filterNode = make<FilterNode>( + make<EvalFilter>( + make<PathGet>( + "a", make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(1)))), + make<Variable>("root")), + std::move(scanNode)); + + ABT rootNode = + make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(filterNode)); + + { + PrefixId prefixId; + OptPhaseManager phaseManager( + {OptPhaseManager::OptPhase::MemoSubstitutionPhase, + OptPhaseManager::OptPhase::MemoExplorationPhase}, + prefixId, + {{{"c1", + ScanDefinition{{}, + {{"index1", makeIndexDefinition("a", CollationOp::Ascending)}}}}}}, + {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests}); + + // Demonstrate sargable node is rewritten from filter node. + // Note: SargableNodes cannot be lowered and by default are not created unless we have + // indexes. + + ABT optimized = rootNode; + ASSERT_TRUE(phaseManager.optimize(optimized)); + + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | root\n" + "| RefBlock: \n" + "| Variable [root]\n" + "RIDIntersect [root, hasLeftIntervals]\n" + "| Scan [c1]\n" + "| BindBlock:\n" + "| [root]\n" + "| Source []\n" + "Sargable [Index]\n" + "| | | | requirementsMap: \n" + "| | | | refProjection: root, path: 'PathGet [a] PathTraverse [] " + "PathIdentity []', intervals: {{{[Const [1], Const [1]]}}}\n" + "| | | candidateIndexes: \n" + "| | | candidateId: 1, index1, {}, {}, {{{[Const [1], Const [1]]}}}\n" + "| | BindBlock:\n" + "| RefBlock: \n" + "| Variable [root]\n" + "Scan [c1]\n" + " BindBlock:\n" + " [root]\n" + " Source []\n", + optimized); + } + + { + PrefixId prefixId; + OptPhaseManager phaseManager( + {OptPhaseManager::OptPhase::MemoSubstitutionPhase, + OptPhaseManager::OptPhase::MemoExplorationPhase, + OptPhaseManager::OptPhase::MemoImplementationPhase}, + prefixId, + {{{"c1", + ScanDefinition{{}, + {{"index1", makeIndexDefinition("a", CollationOp::Ascending)}}}}}}, + {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests}); + + ABT optimized = rootNode; + ASSERT_TRUE(phaseManager.optimize(optimized)); + ASSERT_EQ(4, phaseManager.getMemo().getStats()._physPlanExplorationCount); + + // Test sargable filter is satisfied with an index scan. + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | root\n" + "| RefBlock: \n" + "| Variable [root]\n" + "BinaryJoin [joinType: Inner, {rid_0}]\n" + "| | Const [true]\n" + "| LimitSkip []\n" + "| | limitSkip:\n" + "| | limit: 1\n" + "| | skip: 0\n" + "| Seek [ridProjection: rid_0, {'<root>': root}, c1]\n" + "| | BindBlock:\n" + "| | [root]\n" + "| | Source []\n" + "| RefBlock: \n" + "| Variable [rid_0]\n" + "IndexScan [{'<rid>': rid_0}, scanDefName: c1, indexDefName: index1, interval: " + "{[Const [1], Const [1]]}]\n" + " BindBlock:\n" + " [rid_0]\n" + " Source []\n", + optimized); + } + + { + PrefixId prefixId; + OptPhaseManager phaseManager( + {OptPhaseManager::OptPhase::MemoSubstitutionPhase, + OptPhaseManager::OptPhase::MemoExplorationPhase, + OptPhaseManager::OptPhase::MemoImplementationPhase}, + prefixId, + {{{"c1", {{}, {}}}}}, + {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests}); + + ABT optimized = rootNode; + ASSERT_TRUE(phaseManager.optimize(optimized)); + ASSERT_EQ(2, phaseManager.getMemo().getStats()._physPlanExplorationCount); + + // Test we can optimize sargable filter nodes even without an index. 
+ ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | root\n" + "| RefBlock: \n" + "| Variable [root]\n" + "Filter []\n" + "| EvalFilter []\n" + "| | Variable [evalTemp_0]\n" + "| PathTraverse []\n" + "| PathCompare [Eq]\n" + "| Const [1]\n" + "PhysicalScan [{'<root>': root, 'a': evalTemp_0}, c1]\n" + " BindBlock:\n" + " [evalTemp_0]\n" + " Source []\n" + " [root]\n" + " Source []\n", + optimized); + } +} + +TEST(PhysRewriter, FilterIndexing1) { + using namespace properties; + + ABT scanNode = make<ScanNode>("root", "c1"); + + // This node will not be converted to Sargable. + ABT evalNode = make<EvaluationNode>( + "p1", + make<EvalPath>( + make<PathGet>( + "b", + make<PathLambda>(make<LambdaAbstraction>( + "t", + make<BinaryOp>(Operations::Add, make<Variable>("t"), Constant::int64(1))))), + make<Variable>("root")), + std::move(scanNode)); + + ABT filterNode = make<FilterNode>( + make<EvalFilter>( + make<PathGet>( + "a", make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(1)))), + make<Variable>("p1")), + std::move(evalNode)); + + ABT rootNode = + make<RootNode>(ProjectionRequirement{ProjectionNameVector{"p1"}}, std::move(filterNode)); + + PrefixId prefixId; + OptPhaseManager phaseManager( + {OptPhaseManager::OptPhase::MemoSubstitutionPhase, + OptPhaseManager::OptPhase::MemoExplorationPhase, + OptPhaseManager::OptPhase::MemoImplementationPhase}, + prefixId, + {{{"c1", + ScanDefinition{{}, {{"index1", makeIndexDefinition("a", CollationOp::Ascending)}}}}}}, + {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests}); + + ABT optimized = rootNode; + ASSERT_TRUE(phaseManager.optimize(optimized)); + ASSERT_EQ(7, phaseManager.getMemo().getStats()._physPlanExplorationCount); + + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | p1\n" + "| RefBlock: \n" + "| Variable [p1]\n" + "Filter []\n" + "| EvalFilter []\n" + "| | Variable [p1]\n" + "| PathGet [a]\n" + "| PathTraverse []\n" + "| PathCompare [Eq]\n" + "| Const [1]\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [p1]\n" + "| EvalPath []\n" + "| | Variable [root]\n" + "| PathGet [b]\n" + "| PathLambda []\n" + "| LambdaAbstraction [t]\n" + "| BinaryOp [Add]\n" + "| | Const [1]\n" + "| Variable [t]\n" + "PhysicalScan [{'<root>': root}, c1]\n" + " BindBlock:\n" + " [root]\n" + " Source []\n", + optimized); +} + +TEST(PhysRewriter, FilterIndexing2) { + using namespace properties; + + ABT scanNode = make<ScanNode>("root", "c1"); + + ABT filterNode = make<FilterNode>( + make<EvalFilter>(make<PathGet>("a", + make<PathTraverse>(make<PathGet>( + "b", + make<PathTraverse>(make<PathCompare>( + Operations::Eq, Constant::int64(1)))))), + make<Variable>("root")), + std::move(scanNode)); + + ABT rootNode = + make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(filterNode)); + + PrefixId prefixId; + OptPhaseManager phaseManager( + {OptPhaseManager::OptPhase::MemoSubstitutionPhase, + OptPhaseManager::OptPhase::MemoExplorationPhase, + OptPhaseManager::OptPhase::MemoImplementationPhase}, + prefixId, + {{{"c1", + ScanDefinition{{}, + {{"index1", + {{{make<PathGet>("a", make<PathGet>("b", make<PathIdentity>())), + CollationOp::Ascending}}, + false /*isMultiKey*/}}}}}}}, + {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests}); + + ABT optimized = rootNode; + ASSERT_TRUE(phaseManager.optimize(optimized)); + ASSERT_EQ(4, phaseManager.getMemo().getStats()._physPlanExplorationCount); + + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | root\n" + "| 
RefBlock: \n" + "| Variable [root]\n" + "BinaryJoin [joinType: Inner, {rid_0}]\n" + "| | Const [true]\n" + "| LimitSkip []\n" + "| | limitSkip:\n" + "| | limit: 1\n" + "| | skip: 0\n" + "| Seek [ridProjection: rid_0, {'<root>': root}, c1]\n" + "| | BindBlock:\n" + "| | [root]\n" + "| | Source []\n" + "| RefBlock: \n" + "| Variable [rid_0]\n" + "IndexScan [{'<rid>': rid_0}, scanDefName: c1, indexDefName: index1, interval: {[Const " + "[1], Const [1]]}]\n" + " BindBlock:\n" + " [rid_0]\n" + " Source []\n", + optimized); +} + +TEST(PhysRewriter, FilterIndexing2NonSarg) { + using namespace properties; + + ABT scanNode = make<ScanNode>("root", "c1"); + + ABT evalNode1 = make<EvaluationNode>( + "pa", + make<EvalPath>(make<PathGet>("a", make<PathIdentity>()), make<Variable>("root")), + std::move(scanNode)); + + ABT filterNode1 = make<FilterNode>( + make<EvalFilter>(make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(1))), + make<Variable>("pa")), + std::move(evalNode1)); + + // Dependent eval node. + ABT evalNode2 = make<EvaluationNode>( + "pb", + make<EvalPath>(make<PathGet>("b", make<PathIdentity>()), make<Variable>("pa")), + std::move(filterNode1)); + + // Non-sargable filter. + ABT filterNode2 = make<FilterNode>( + make<EvalFilter>( + make<PathTraverse>(make<PathLambda>(make<LambdaAbstraction>( + "var", make<FunctionCall>("someFunction", makeSeq(make<Variable>("var")))))), + make<Variable>("pb")), + std::move(evalNode2)); + + + ABT rootNode = + make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(filterNode2)); + + PrefixId prefixId; + OptPhaseManager phaseManager( + {OptPhaseManager::OptPhase::MemoSubstitutionPhase, + OptPhaseManager::OptPhase::MemoExplorationPhase, + OptPhaseManager::OptPhase::MemoImplementationPhase}, + prefixId, + {{{"c1", + ScanDefinition{ + {}, + {{"index1", + makeIndexDefinition("a", CollationOp::Ascending, false /*isMultiKey*/)}}}}}}, + {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests}); + + ABT optimized = rootNode; + ASSERT_TRUE(phaseManager.optimize(optimized)); + ASSERT_BETWEEN(15, 20, phaseManager.getMemo().getStats()._physPlanExplorationCount); + + // Demonstrate non-sargable evaluation and filter are moved under the NLJ+seek, + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | root\n" + "| RefBlock: \n" + "| Variable [root]\n" + "BinaryJoin [joinType: Inner, {rid_0}]\n" + "| | Const [true]\n" + "| LimitSkip []\n" + "| | limitSkip:\n" + "| | limit: 1\n" + "| | skip: 0\n" + "| Seek [ridProjection: rid_0, {'<root>': root}, c1]\n" + "| | BindBlock:\n" + "| | [root]\n" + "| | Source []\n" + "| RefBlock: \n" + "| Variable [rid_0]\n" + "Filter []\n" + "| EvalFilter []\n" + "| | Variable [pb]\n" + "| PathTraverse []\n" + "| PathLambda []\n" + "| LambdaAbstraction [var]\n" + "| FunctionCall [someFunction]\n" + "| Variable [var]\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [pb]\n" + "| EvalPath []\n" + "| | Variable [pa]\n" + "| PathGet [b]\n" + "| PathIdentity []\n" + "IndexScan [{'<indexKey> 0': pa, '<rid>': rid_0}, scanDefName: c1, indexDefName: index1, " + "interval: {[Const [1], Const [1]]}]\n" + " BindBlock:\n" + " [pa]\n" + " Source []\n" + " [rid_0]\n" + " Source []\n", + optimized); +} + +TEST(PhysRewriter, FilterIndexing3) { + using namespace properties; + PrefixId prefixId; + + ABT scanNode = make<ScanNode>("root", "c1"); + + ABT evalNode = make<EvaluationNode>( + "pa", + make<EvalPath>(make<PathGet>("a", make<PathIdentity>()), make<Variable>("root")), + std::move(scanNode)); + 
+ ABT filterNode = make<FilterNode>( + make<EvalFilter>(make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(1))), + make<Variable>("pa")), + std::move(evalNode)); + + ABT rootNode = + make<RootNode>(ProjectionRequirement{ProjectionNameVector{"pa"}}, std::move(filterNode)); + + OptPhaseManager phaseManager( + {OptPhaseManager::OptPhase::MemoSubstitutionPhase, + OptPhaseManager::OptPhase::MemoExplorationPhase, + OptPhaseManager::OptPhase::MemoImplementationPhase}, + prefixId, + {{{"c1", + ScanDefinition{ + {}, + {{"index1", + IndexDefinition{{{makeNonMultikeyIndexPath("a"), CollationOp::Ascending}, + {makeNonMultikeyIndexPath("b"), CollationOp::Ascending}}, + false /*isMultiKey*/, + {DistributionType::Centralized}, + {}}}}}}}}, + {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests}); + + ABT optimized = std::move(rootNode); + ASSERT_TRUE(phaseManager.optimize(optimized)); + ASSERT_EQ(5, phaseManager.getMemo().getStats()._physPlanExplorationCount); + + // We dont need a Seek if we dont have multi-key paths. + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | pa\n" + "| RefBlock: \n" + "| Variable [pa]\n" + "IndexScan [{'<indexKey> 0': pa}, scanDefName: c1, indexDefName: index1, interval: {[Const " + "[1], Const [1]], (-inf, +inf)}]\n" + " BindBlock:\n" + " [pa]\n" + " Source []\n", + optimized); +} + +TEST(PhysRewriter, FilterIndexing3MultiKey) { + using namespace properties; + PrefixId prefixId; + + ABT scanNode = make<ScanNode>("root", "c1"); + + ABT evalNode = make<EvaluationNode>( + "pa", + make<EvalPath>(make<PathGet>("a", make<PathIdentity>()), make<Variable>("root")), + std::move(scanNode)); + + ABT filterNode = make<FilterNode>( + make<EvalFilter>(make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(1))), + make<Variable>("pa")), + std::move(evalNode)); + + ABT rootNode = + make<RootNode>(ProjectionRequirement{ProjectionNameVector{"pa"}}, std::move(filterNode)); + + OptPhaseManager phaseManager( + {OptPhaseManager::OptPhase::MemoSubstitutionPhase, + OptPhaseManager::OptPhase::MemoExplorationPhase, + OptPhaseManager::OptPhase::MemoImplementationPhase}, + prefixId, + {{{"c1", + ScanDefinition{{}, + {{"index1", + IndexDefinition{{{makeIndexPath("a"), CollationOp::Ascending}, + {makeIndexPath("b"), CollationOp::Ascending}}, + true /*isMultiKey*/, + {DistributionType::Centralized}, + {}}}}}}}}, + {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests}); + + ABT optimized = std::move(rootNode); + ASSERT_TRUE(phaseManager.optimize(optimized)); + ASSERT_EQ(7, phaseManager.getMemo().getStats()._physPlanExplorationCount); + + // We need a Seek to obtain value for "a". 
+ ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | pa\n" + "| RefBlock: \n" + "| Variable [pa]\n" + "BinaryJoin [joinType: Inner, {rid_0}]\n" + "| | Const [true]\n" + "| LimitSkip []\n" + "| | limitSkip:\n" + "| | limit: 1\n" + "| | skip: 0\n" + "| Seek [ridProjection: rid_0, {'a': pa}, c1]\n" + "| | BindBlock:\n" + "| | [pa]\n" + "| | Source []\n" + "| RefBlock: \n" + "| Variable [rid_0]\n" + "Unique []\n" + "| projections: \n" + "| rid_0\n" + "IndexScan [{'<rid>': rid_0}, scanDefName: c1, indexDefName: index1, interval: {[Const " + "[1], Const [1]], (-inf, +inf)}]\n" + " BindBlock:\n" + " [rid_0]\n" + " Source []\n", + optimized); +} + +TEST(PhysRewriter, FilterIndexing4) { + using namespace properties; + PrefixId prefixId; + + ABT scanNode = make<ScanNode>("root", "c1"); + + ABT evalNode = make<EvaluationNode>( + "pa", + make<EvalPath>(make<PathGet>("a", make<PathIdentity>()), make<Variable>("root")), + std::move(scanNode)); + + ABT filterANode = make<FilterNode>( + make<EvalFilter>(make<PathTraverse>(make<PathCompare>(Operations::Lt, Constant::int64(1))), + make<Variable>("pa")), + std::move(evalNode)); + + ABT filterBNode = make<FilterNode>( + make<EvalFilter>( + make<PathGet>( + "b", make<PathTraverse>(make<PathCompare>(Operations::Lt, Constant::int64(1)))), + make<Variable>("root")), + std::move(filterANode)); + + ABT filterCNode = make<FilterNode>( + make<EvalFilter>( + make<PathGet>( + "c", make<PathTraverse>(make<PathCompare>(Operations::Lt, Constant::int64(1)))), + make<Variable>("root")), + std::move(filterBNode)); + + ABT filterDNode = make<FilterNode>( + make<EvalFilter>( + make<PathGet>( + "d", make<PathTraverse>(make<PathCompare>(Operations::Lt, Constant::int64(1)))), + make<Variable>("root")), + std::move(filterCNode)); + + ABT rootNode = + make<RootNode>(ProjectionRequirement{ProjectionNameVector{"pa"}}, std::move(filterDNode)); + + OptPhaseManager phaseManager( + {OptPhaseManager::OptPhase::MemoSubstitutionPhase, + OptPhaseManager::OptPhase::MemoExplorationPhase, + OptPhaseManager::OptPhase::MemoImplementationPhase}, + prefixId, + {{{"c1", + ScanDefinition{ + {}, + {{"index1", + IndexDefinition{{{makeNonMultikeyIndexPath("a"), CollationOp::Ascending}, + {makeNonMultikeyIndexPath("b"), CollationOp::Ascending}, + {makeNonMultikeyIndexPath("c"), CollationOp::Ascending}, + {makeNonMultikeyIndexPath("d"), CollationOp::Ascending}}, + false /*isMultiKey*/, + {DistributionType::Centralized}, + {}}}}}}}}, + {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests}); + + ABT optimized = std::move(rootNode); + + // For now leave only GroupBy+Union RIDIntersect. 
+ phaseManager.getHints()._disableHashJoinRIDIntersect = true; + + ASSERT_TRUE(phaseManager.optimize(optimized)); + ASSERT_BETWEEN(65, 80, phaseManager.getMemo().getStats()._physPlanExplorationCount); + + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | pa\n" + "| RefBlock: \n" + "| Variable [pa]\n" + "Filter []\n" + "| EvalFilter []\n" + "| | Variable [evalTemp_8]\n" + "| PathCompare [Lt]\n" + "| Const [1]\n" + "Filter []\n" + "| EvalFilter []\n" + "| | Variable [evalTemp_7]\n" + "| PathCompare [Lt]\n" + "| Const [1]\n" + "Filter []\n" + "| EvalFilter []\n" + "| | Variable [evalTemp_6]\n" + "| PathCompare [Lt]\n" + "| Const [1]\n" + "IndexScan [{'<indexKey> 0': pa, '<indexKey> 1': evalTemp_6, '<indexKey> 2': evalTemp_7, " + "'<indexKey> 3': evalTemp_8}, scanDefName: c1, indexDefName: index1, interval: {(-inf, " + "Const [1]), (-inf, +inf), (-inf, +inf), (-inf, +inf)}]\n" + " BindBlock:\n" + " [evalTemp_6]\n" + " Source []\n" + " [evalTemp_7]\n" + " Source []\n" + " [evalTemp_8]\n" + " Source []\n" + " [pa]\n" + " Source []\n", + optimized); +} + +TEST(PhysRewriter, FilterIndexing5) { + using namespace properties; + PrefixId prefixId; + + ABT scanNode = make<ScanNode>("root", "c1"); + + ABT evalANode = make<EvaluationNode>( + "pa", + make<EvalPath>(make<PathGet>("a", make<PathIdentity>()), make<Variable>("root")), + std::move(scanNode)); + + ABT filterANode = make<FilterNode>( + make<EvalFilter>(make<PathTraverse>(make<PathCompare>(Operations::Gt, Constant::int64(0))), + make<Variable>("pa")), + std::move(evalANode)); + + ABT evalBNode = make<EvaluationNode>( + "pb", + make<EvalPath>(make<PathGet>("b", make<PathIdentity>()), make<Variable>("root")), + std::move(filterANode)); + + ABT filterBNode = make<FilterNode>( + make<EvalFilter>(make<PathTraverse>(make<PathCompare>(Operations::Gt, Constant::int64(0))), + make<Variable>("pb")), + std::move(evalBNode)); + + ABT collationNode = make<CollationNode>(CollationRequirement({{"pb", CollationOp::Ascending}}), + std::move(filterBNode)); + + ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"pa", "pb"}}, + std::move(collationNode)); + + OptPhaseManager phaseManager( + {OptPhaseManager::OptPhase::MemoSubstitutionPhase, + OptPhaseManager::OptPhase::MemoExplorationPhase, + OptPhaseManager::OptPhase::MemoImplementationPhase}, + prefixId, + {{{"c1", + ScanDefinition{ + {}, + {{"index1", + IndexDefinition{{{makeNonMultikeyIndexPath("a"), CollationOp::Ascending}, + {makeNonMultikeyIndexPath("b"), CollationOp::Ascending}}, + false /*isMultiKey*/, + {DistributionType::Centralized}, + {}}}}}}}}, + {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests}); + + ABT optimized = std::move(rootNode); + ASSERT_TRUE(phaseManager.optimize(optimized)); + ASSERT_BETWEEN(25, 55, phaseManager.getMemo().getStats()._physPlanExplorationCount); + + // We can cover both fields with the index, and need separate sort on "b". 
+ ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | pa\n" + "| | pb\n" + "| RefBlock: \n" + "| Variable [pa]\n" + "| Variable [pb]\n" + "Collation []\n" + "| | collation: \n" + "| | pb: Ascending\n" + "| RefBlock: \n" + "| Variable [pb]\n" + "Filter []\n" + "| EvalFilter []\n" + "| | Variable [pb]\n" + "| PathCompare [Gt]\n" + "| Const [0]\n" + "IndexScan [{'<indexKey> 0': pa, '<indexKey> 1': pb}, scanDefName: c1, indexDefName: " + "index1, interval: {(Const [0], +inf), (-inf, +inf)}]\n" + " BindBlock:\n" + " [pa]\n" + " Source []\n" + " [pb]\n" + " Source []\n", + optimized); +} + +TEST(PhysRewriter, FilterIndexing6) { + using namespace properties; + PrefixId prefixId; + + ABT scanNode = make<ScanNode>("root", "c1"); + + ABT evalANode = make<EvaluationNode>( + "pa", + make<EvalPath>(make<PathGet>("a", make<PathIdentity>()), make<Variable>("root")), + std::move(scanNode)); + + ABT filterANode = make<FilterNode>( + make<EvalFilter>(make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(0))), + make<Variable>("pa")), + std::move(evalANode)); + + ABT evalBNode = make<EvaluationNode>( + "pb", + make<EvalPath>(make<PathGet>("b", make<PathIdentity>()), make<Variable>("root")), + std::move(filterANode)); + + ABT filterBNode = make<FilterNode>( + make<EvalFilter>(make<PathTraverse>(make<PathCompare>(Operations::Gt, Constant::int64(0))), + make<Variable>("pb")), + std::move(evalBNode)); + + ABT collationNode = make<CollationNode>(CollationRequirement({{"pb", CollationOp::Ascending}}), + std::move(filterBNode)); + + ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"pa", "pb"}}, + std::move(collationNode)); + + OptPhaseManager phaseManager( + {OptPhaseManager::OptPhase::MemoSubstitutionPhase, + OptPhaseManager::OptPhase::MemoExplorationPhase, + OptPhaseManager::OptPhase::MemoImplementationPhase}, + prefixId, + {{{"c1", + ScanDefinition{ + {}, + {{"index1", + IndexDefinition{{{makeNonMultikeyIndexPath("a"), CollationOp::Ascending}, + {makeNonMultikeyIndexPath("b"), CollationOp::Ascending}}, + false /*isMultiKey*/, + {DistributionType::Centralized}, + {}}}}}}}}, + {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests}); + + ABT optimized = std::move(rootNode); + ASSERT_TRUE(phaseManager.optimize(optimized)); + ASSERT_BETWEEN(9, 15, phaseManager.getMemo().getStats()._physPlanExplorationCount); + + // We can cover both fields with the index, and do not need a separate sort on "b". + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | pa\n" + "| | pb\n" + "| RefBlock: \n" + "| Variable [pa]\n" + "| Variable [pb]\n" + "IndexScan [{'<indexKey> 0': pa, '<indexKey> 1': pb}, scanDefName: c1, indexDefName: " + "index1, interval: {[Const [0], Const [0]], (Const [0], +inf)}]\n" + " BindBlock:\n" + " [pa]\n" + " Source []\n" + " [pb]\n" + " Source []\n", + optimized); +} + +TEST(PhysRewriter, FilterIndexingStress) { + using namespace properties; + PrefixId prefixId; + + ABT result = make<ScanNode>("root", "c1"); + + static constexpr size_t kFilterCount = 15; + // A query with a large number of filters on different fields. 
+ for (size_t index = 0; index < kFilterCount; index++) { + std::ostringstream os; + os << "field" << index; + + result = make<FilterNode>( + make<EvalFilter>(make<PathGet>(os.str(), + make<PathTraverse>(make<PathCompare>( + Operations::Eq, Constant::int64(0)))), + make<Variable>("root")), + std::move(result)); + } + + ABT rootNode = + make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(result)); + + OptPhaseManager phaseManager( + {OptPhaseManager::OptPhase::MemoSubstitutionPhase, + OptPhaseManager::OptPhase::MemoExplorationPhase, + OptPhaseManager::OptPhase::MemoImplementationPhase}, + prefixId, + {{{"c1", + ScanDefinition{ + {}, + {{"index1", + IndexDefinition{{{makeNonMultikeyIndexPath("field0"), CollationOp::Ascending}, + {makeNonMultikeyIndexPath("field1"), CollationOp::Ascending}}, + false /*isMultiKey*/, + {DistributionType::Centralized}, + {}}}, + {"index2", + IndexDefinition{{{makeNonMultikeyIndexPath("field2"), CollationOp::Ascending}}, + false /*isMultiKey*/, + {DistributionType::Centralized}, + {}}}, + {"index3", + IndexDefinition{{{makeNonMultikeyIndexPath("field3"), CollationOp::Ascending}, + {makeNonMultikeyIndexPath("field4"), CollationOp::Ascending}}, + false /*isMultiKey*/, + {DistributionType::Centralized}, + {}}}}}}}}, + {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests}); + + ABT optimized = std::move(rootNode); + ASSERT_TRUE(phaseManager.optimize(optimized)); + + // Without the changes to restrict SargableNode split to which this test is tied, we would + // be exploring 2^kFilterCount plans, one for each created group. + ASSERT_EQ(51, phaseManager.getMemo().getStats()._physPlanExplorationCount); + + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | root\n" + "| RefBlock: \n" + "| Variable [root]\n" + "Filter []\n" + "| EvalFilter []\n" + "| | Variable [root]\n" + "| PathGet [field14]\n" + "| PathTraverse []\n" + "| PathCompare [Eq]\n" + "| Const [0]\n" + "Filter []\n" + "| EvalFilter []\n" + "| | Variable [root]\n" + "| PathGet [field13]\n" + "| PathTraverse []\n" + "| PathCompare [Eq]\n" + "| Const [0]\n" + "Filter []\n" + "| EvalFilter []\n" + "| | Variable [root]\n" + "| PathGet [field12]\n" + "| PathTraverse []\n" + "| PathCompare [Eq]\n" + "| Const [0]\n" + "Filter []\n" + "| EvalFilter []\n" + "| | Variable [root]\n" + "| PathGet [field11]\n" + "| PathTraverse []\n" + "| PathCompare [Eq]\n" + "| Const [0]\n" + "Filter []\n" + "| EvalFilter []\n" + "| | Variable [root]\n" + "| PathGet [field10]\n" + "| PathTraverse []\n" + "| PathCompare [Eq]\n" + "| Const [0]\n" + "BinaryJoin [joinType: Inner, {rid_0}]\n" + "| | Const [true]\n" + "| Filter []\n" + "| | EvalFilter []\n" + "| | | Variable [evalTemp_17]\n" + "| | PathTraverse []\n" + "| | PathCompare [Eq]\n" + "| | Const [0]\n" + "| Filter []\n" + "| | EvalFilter []\n" + "| | | Variable [evalTemp_16]\n" + "| | PathTraverse []\n" + "| | PathCompare [Eq]\n" + "| | Const [0]\n" + "| Filter []\n" + "| | EvalFilter []\n" + "| | | Variable [evalTemp_15]\n" + "| | PathTraverse []\n" + "| | PathCompare [Eq]\n" + "| | Const [0]\n" + "| Filter []\n" + "| | EvalFilter []\n" + "| | | Variable [evalTemp_14]\n" + "| | PathTraverse []\n" + "| | PathCompare [Eq]\n" + "| | Const [0]\n" + "| Filter []\n" + "| | EvalFilter []\n" + "| | | Variable [evalTemp_13]\n" + "| | PathTraverse []\n" + "| | PathCompare [Eq]\n" + "| | Const [0]\n" + "| Filter []\n" + "| | EvalFilter []\n" + "| | | Variable [evalTemp_12]\n" + "| | PathTraverse []\n" + "| | PathCompare [Eq]\n" + "| | 
Const [0]\n" + "| LimitSkip []\n" + "| | limitSkip:\n" + "| | limit: 1\n" + "| | skip: 0\n" + "| Seek [ridProjection: rid_0, {'<root>': root, 'field2': evalTemp_12, 'field5': " + "evalTemp_13, 'field6': evalTemp_14, 'field7': evalTemp_15, 'field8': evalTemp_16, " + "'field9': evalTemp_17}, c1]\n" + "| | BindBlock:\n" + "| | [evalTemp_12]\n" + "| | Source []\n" + "| | [evalTemp_13]\n" + "| | Source []\n" + "| | [evalTemp_14]\n" + "| | Source []\n" + "| | [evalTemp_15]\n" + "| | Source []\n" + "| | [evalTemp_16]\n" + "| | Source []\n" + "| | [evalTemp_17]\n" + "| | Source []\n" + "| | [root]\n" + "| | Source []\n" + "| RefBlock: \n" + "| Variable [rid_0]\n" + "MergeJoin []\n" + "| | | Condition\n" + "| | | rid_0 = rid_1\n" + "| | Collation\n" + "| | Ascending\n" + "| Union []\n" + "| | BindBlock:\n" + "| | [rid_1]\n" + "| | Source []\n" + "| Evaluation []\n" + "| | BindBlock:\n" + "| | [rid_1]\n" + "| | Variable [rid_0]\n" + "| IndexScan [{'<rid>': rid_0}, scanDefName: c1, indexDefName: index3, interval: {[Const " + "[0], Const [0]], [Const [0], Const [0]]}]\n" + "| BindBlock:\n" + "| [rid_0]\n" + "| Source []\n" + "IndexScan [{'<rid>': rid_0}, scanDefName: c1, indexDefName: index1, interval: {[Const " + "[0], Const [0]], [Const [0], Const [0]]}]\n" + " BindBlock:\n" + " [rid_0]\n" + " Source []\n", + optimized); +} + +TEST(PhysRewriter, FilterIndexingVariable) { + using namespace properties; + PrefixId prefixId; + + // In the absence of full implementation of query parameterization, here we pretend we have a + // function "getQueryParam" which will return a query parameter by index. + const auto getQueryParamFn = [](const size_t index) { + return make<FunctionCall>("getQueryParam", makeSeq(Constant::int32(index))); + }; + + ABT scanNode = make<ScanNode>("root", "c1"); + + // Encode a condition using two query parameters (expressed as functions): + // "a" > param_0 AND "a" >= param_1 (observe param_1 comparison is inclusive). + ABT filterNode = make<FilterNode>( + make<EvalFilter>( + make<PathGet>("a", + make<PathTraverse>(make<PathComposeM>( + make<PathCompare>(Operations::Gt, getQueryParamFn(0)), + make<PathCompare>(Operations::Gte, getQueryParamFn(1))))), + make<Variable>("root")), + std::move(scanNode)); + + ABT rootNode = + make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(filterNode)); + + OptPhaseManager phaseManager( + {OptPhaseManager::OptPhase::MemoSubstitutionPhase, + OptPhaseManager::OptPhase::MemoExplorationPhase, + OptPhaseManager::OptPhase::MemoImplementationPhase}, + prefixId, + {{{"c1", + ScanDefinition{ + {}, + {{"index1", + makeIndexDefinition("a", CollationOp::Ascending, false /*isMultiKey*/)}}}}}}, + {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests}); + + ABT optimized = std::move(rootNode); + ASSERT_TRUE(phaseManager.optimize(optimized)); + ASSERT_EQ(4, phaseManager.getMemo().getStats()._physPlanExplorationCount); + + // Observe unioning of two index scans with complex expressions for bounds. This encodes: + // (max(param_0, param_1), +inf) U [param_0 > param_1 ? 
MaxKey : param_1, max(param_0, param_1)] + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | root\n" + "| RefBlock: \n" + "| Variable [root]\n" + "BinaryJoin [joinType: Inner, {rid_0}]\n" + "| | Const [true]\n" + "| LimitSkip []\n" + "| | limitSkip:\n" + "| | limit: 1\n" + "| | skip: 0\n" + "| Seek [ridProjection: rid_0, {'<root>': root}, c1]\n" + "| | BindBlock:\n" + "| | [root]\n" + "| | Source []\n" + "| RefBlock: \n" + "| Variable [rid_0]\n" + "GroupBy []\n" + "| | groupings: \n" + "| | RefBlock: \n" + "| | Variable [rid_0]\n" + "| aggregations: \n" + "Union []\n" + "| | BindBlock:\n" + "| | [rid_0]\n" + "| | Source []\n" + "| IndexScan [{'<rid>': rid_0}, scanDefName: c1, indexDefName: index1, interval: {(If [] " + "BinaryOp [Gte] FunctionCall [getQueryParam] Const [0] FunctionCall [getQueryParam] Const " + "[1] FunctionCall [getQueryParam] Const [0] FunctionCall [getQueryParam] Const [1], " + "+inf)}]\n" + "| BindBlock:\n" + "| [rid_0]\n" + "| Source []\n" + "IndexScan [{'<rid>': rid_0}, scanDefName: c1, indexDefName: index1, interval: {[If [] " + "BinaryOp [Gte] FunctionCall [getQueryParam] Const [0] FunctionCall [getQueryParam] Const " + "[1] Const [maxKey] FunctionCall [getQueryParam] Const [1], If [] BinaryOp [Gte] " + "FunctionCall [getQueryParam] Const [0] FunctionCall [getQueryParam] Const [1] " + "FunctionCall [getQueryParam] Const [0] FunctionCall [getQueryParam] Const [1]]}]\n" + " BindBlock:\n" + " [rid_0]\n" + " Source []\n", + optimized); +} + +TEST(PhysRewriter, FilterReorder) { + using namespace properties; + PrefixId prefixId; + + ABT result = make<ScanNode>("root", "c1"); + + PartialSchemaSelHints hints; + static constexpr size_t kFilterCount = 5; + for (size_t i = 0; i < kFilterCount; i++) { + ProjectionName projName = prefixId.getNextId("field"); + hints.emplace( + PartialSchemaKey{"root", + make<PathGet>(projName, make<PathTraverse>(make<PathIdentity>()))}, + 0.1 * (kFilterCount - i)); + result = make<FilterNode>( + make<EvalFilter>(make<PathGet>(std::move(projName), + make<PathTraverse>(make<PathCompare>( + Operations::Eq, Constant::int64(i)))), + make<Variable>("root")), + std::move(result)); + } + + ABT rootNode = + make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(result)); + + OptPhaseManager phaseManager( + {OptPhaseManager::OptPhase::MemoSubstitutionPhase, + OptPhaseManager::OptPhase::MemoExplorationPhase, + OptPhaseManager::OptPhase::MemoImplementationPhase}, + prefixId, + false /*requireRID*/, + {{{"c1", ScanDefinition{{}, {}}}}}, + std::make_unique<HintedCE>(std::move(hints)), + std::make_unique<DefaultCosting>(), + {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests}); + + ABT optimized = std::move(rootNode); + ASSERT_TRUE(phaseManager.optimize(optimized)); + ASSERT_EQ(2, phaseManager.getMemo().getStats()._physPlanExplorationCount); + + // Observe filters are ordered from most selective (lowest sel) to least selective (highest + // sel). 
+ ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | root\n" + "| RefBlock: \n" + "| Variable [root]\n" + "Filter []\n" + "| EvalFilter []\n" + "| | Variable [evalTemp_0]\n" + "| PathTraverse []\n" + "| PathCompare [Eq]\n" + "| Const [0]\n" + "Filter []\n" + "| EvalFilter []\n" + "| | Variable [evalTemp_1]\n" + "| PathTraverse []\n" + "| PathCompare [Eq]\n" + "| Const [1]\n" + "Filter []\n" + "| EvalFilter []\n" + "| | Variable [evalTemp_2]\n" + "| PathTraverse []\n" + "| PathCompare [Eq]\n" + "| Const [2]\n" + "Filter []\n" + "| EvalFilter []\n" + "| | Variable [evalTemp_3]\n" + "| PathTraverse []\n" + "| PathCompare [Eq]\n" + "| Const [3]\n" + "Filter []\n" + "| EvalFilter []\n" + "| | Variable [evalTemp_4]\n" + "| PathTraverse []\n" + "| PathCompare [Eq]\n" + "| Const [4]\n" + "PhysicalScan [{'<root>': root, 'field_0': evalTemp_0, 'field_1': evalTemp_1, " + "'field_2': " + "evalTemp_2, 'field_3': evalTemp_3, 'field_4': evalTemp_4}, c1]\n" + " BindBlock:\n" + " [evalTemp_0]\n" + " Source []\n" + " [evalTemp_1]\n" + " Source []\n" + " [evalTemp_2]\n" + " Source []\n" + " [evalTemp_3]\n" + " Source []\n" + " [evalTemp_4]\n" + " Source []\n" + " [root]\n" + " Source []\n", + optimized); +} + +TEST(PhysRewriter, CoveredScan) { + using namespace properties; + PrefixId prefixId; + + ABT scanNode = make<ScanNode>("root", "c1"); + + ABT evalNode1 = make<EvaluationNode>( + "pa", + make<EvalPath>(make<PathGet>("a", make<PathIdentity>()), make<Variable>("root")), + std::move(scanNode)); + + ABT rootNode = + make<RootNode>(ProjectionRequirement{ProjectionNameVector{"pa"}}, std::move(evalNode1)); + + OptPhaseManager phaseManager( + {OptPhaseManager::OptPhase::MemoSubstitutionPhase, + OptPhaseManager::OptPhase::MemoExplorationPhase, + OptPhaseManager::OptPhase::MemoImplementationPhase}, + prefixId, + {{{"c1", + ScanDefinition{ + {}, + {{"index1", + makeIndexDefinition("a", CollationOp::Ascending, false /*isMultiKey*/)}}}}}}, + {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests}); + + ABT optimized = std::move(rootNode); + ASSERT_TRUE(phaseManager.optimize(optimized)); + ASSERT_EQ(4, phaseManager.getMemo().getStats()._physPlanExplorationCount); + + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | pa\n" + "| RefBlock: \n" + "| Variable [pa]\n" + "IndexScan [{'<indexKey> 0': pa}, scanDefName: c1, indexDefName: index1, interval: " + "{(-inf, " + "+inf)}]\n" + " BindBlock:\n" + " [pa]\n" + " Source []\n", + optimized); +} + +TEST(PhysRewriter, EvalIndexing) { + using namespace properties; + PrefixId prefixId; + + ABT scanNode = make<ScanNode>("root", "c1"); + + ABT evalNode = make<EvaluationNode>( + "pa", + make<EvalPath>(make<PathGet>("a", make<PathIdentity>()), make<Variable>("root")), + std::move(scanNode)); + + ABT filterNode = + make<FilterNode>(make<EvalFilter>(make<PathCompare>(Operations::Gt, Constant::int64(1)), + make<Variable>("pa")), + std::move(evalNode)); + + ABT collationNode = make<CollationNode>(CollationRequirement({{"pa", CollationOp::Ascending}}), + std::move(filterNode)); + + ABT rootNode = + make<RootNode>(ProjectionRequirement{ProjectionNameVector{"pa"}}, std::move(collationNode)); + + { + OptPhaseManager phaseManager( + {OptPhaseManager::OptPhase::MemoSubstitutionPhase, + OptPhaseManager::OptPhase::MemoExplorationPhase, + OptPhaseManager::OptPhase::MemoImplementationPhase}, + prefixId, + {{{"c1", + ScanDefinition{ + {}, + {{"index1", + makeIndexDefinition("a", CollationOp::Ascending, false /*isMultiKey*/)}}}}}}, + {true 
/*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests}); + + ABT optimized = rootNode; + ASSERT_TRUE(phaseManager.optimize(optimized)); + ASSERT_BETWEEN(5, 10, phaseManager.getMemo().getStats()._physPlanExplorationCount); + + // Should not need a collation node. + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | pa\n" + "| RefBlock: \n" + "| Variable [pa]\n" + "IndexScan [{'<indexKey> 0': pa}, scanDefName: c1, indexDefName: index1, " + "interval: {(Const [1], +inf)}]\n" + " BindBlock:\n" + " [pa]\n" + " Source []\n", + optimized); + } + + { + // Index and collation node have incompatible ops. + OptPhaseManager phaseManager( + {OptPhaseManager::OptPhase::MemoSubstitutionPhase, + OptPhaseManager::OptPhase::MemoExplorationPhase, + OptPhaseManager::OptPhase::MemoImplementationPhase}, + prefixId, + {{{"c1", + ScanDefinition{ + {}, + {{"index1", + makeIndexDefinition("a", CollationOp::Clustered, false /*isMultiKey*/)}}}}}}, + {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests}); + + ABT optimized = rootNode; + ASSERT_TRUE(phaseManager.optimize(optimized)); + ASSERT_EQ(10, phaseManager.getMemo().getStats()._physPlanExplorationCount); + + // Index does not have the right collation and now we need a collation node. + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | pa\n" + "| RefBlock: \n" + "| Variable [pa]\n" + "Collation []\n" + "| | collation: \n" + "| | pa: Ascending\n" + "| RefBlock: \n" + "| Variable [pa]\n" + "IndexScan [{'<indexKey> 0': pa}, scanDefName: c1, indexDefName: index1, " + "interval: {(Const [1], +inf)}]\n" + " BindBlock:\n" + " [pa]\n" + " Source []\n", + optimized); + } +} + +TEST(PhysRewriter, EvalIndexing1) { + using namespace properties; + PrefixId prefixId; + + ABT scanNode = make<ScanNode>("root", "c1"); + + ABT evalNode = make<EvaluationNode>( + "pa", + make<EvalPath>(make<PathGet>("a", make<PathIdentity>()), make<Variable>("root")), + std::move(scanNode)); + + ABT filterNode = + make<FilterNode>(make<EvalFilter>(make<PathCompare>(Operations::Eq, Constant::int64(1)), + make<Variable>("pa")), + std::move(evalNode)); + + ABT collationNode = make<CollationNode>(CollationRequirement({{"pa", CollationOp::Ascending}}), + std::move(filterNode)); + + ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, + std::move(collationNode)); + + OptPhaseManager phaseManager( + {OptPhaseManager::OptPhase::MemoSubstitutionPhase, + OptPhaseManager::OptPhase::MemoExplorationPhase, + OptPhaseManager::OptPhase::MemoImplementationPhase}, + prefixId, + {{{"c1", + ScanDefinition{ + {}, + {{"index1", + makeIndexDefinition("a", CollationOp::Ascending, false /*isMultiKey*/)}}}}}}, + {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests}); + + ABT optimized = rootNode; + ASSERT_TRUE(phaseManager.optimize(optimized)); + ASSERT_EQ(8, phaseManager.getMemo().getStats()._physPlanExplorationCount); + + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | root\n" + "| RefBlock: \n" + "| Variable [root]\n" + "BinaryJoin [joinType: Inner, {rid_0}]\n" + "| | Const [true]\n" + "| LimitSkip []\n" + "| | limitSkip:\n" + "| | limit: 1\n" + "| | skip: 0\n" + "| Seek [ridProjection: rid_0, {'<root>': root}, c1]\n" + "| | BindBlock:\n" + "| | [root]\n" + "| | Source []\n" + "| RefBlock: \n" + "| Variable [rid_0]\n" + "IndexScan [{'<rid>': rid_0}, scanDefName: c1, indexDefName: index1, interval: {[Const " + "[1], Const [1]]}]\n" + " BindBlock:\n" + " [rid_0]\n" + " Source []\n", + optimized); 
+} + +TEST(PhysRewriter, EvalIndexing2) { + using namespace properties; + PrefixId prefixId; + + ABT scanNode = make<ScanNode>("root", "c1"); + + ABT evalNode1 = make<EvaluationNode>( + "pa1", + make<EvalPath>(make<PathGet>("a", make<PathIdentity>()), make<Variable>("root")), + std::move(scanNode)); + + ABT evalNode2 = make<EvaluationNode>( + "pa2", + make<EvalPath>(make<PathField>("a", make<PathConstant>(make<Variable>("pa1"))), + Constant::int32(0)), + std::move(evalNode1)); + + ABT evalNode3 = make<EvaluationNode>( + "pa3", + make<EvalPath>(make<PathGet>("a", make<PathIdentity>()), make<Variable>("pa2")), + std::move(evalNode2)); + + ABT collationNode = make<CollationNode>(CollationRequirement({{"pa3", CollationOp::Ascending}}), + std::move(evalNode3)); + + ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"pa2"}}, + std::move(collationNode)); + + OptPhaseManager phaseManager( + {OptPhaseManager::OptPhase::ConstEvalPre, + OptPhaseManager::OptPhase::PathFuse, + OptPhaseManager::OptPhase::MemoSubstitutionPhase, + OptPhaseManager::OptPhase::MemoExplorationPhase, + OptPhaseManager::OptPhase::MemoImplementationPhase}, + prefixId, + {{{"c1", + ScanDefinition{ + {}, + {{"index1", + makeIndexDefinition("a", CollationOp::Ascending, false /*isMultiKey*/)}}}}}}, + {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests}); + + ABT optimized = rootNode; + ASSERT_TRUE(phaseManager.optimize(optimized)); + ASSERT_BETWEEN(10, 20, phaseManager.getMemo().getStats()._physPlanExplorationCount); + + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | pa2\n" + "| RefBlock: \n" + "| Variable [pa2]\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [pa3]\n" + "| Variable [pa1]\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [pa2]\n" + "| EvalPath []\n" + "| | Const [0]\n" + "| PathField [a]\n" + "| PathConstant []\n" + "| Variable [pa1]\n" + "IndexScan [{'<indexKey> 0': pa1}, scanDefName: c1, indexDefName: index1, interval: " + "{(-inf, +inf)}]\n" + " BindBlock:\n" + " [pa1]\n" + " Source []\n", + optimized); +} + +TEST(PhysRewriter, MultiKeyIndex) { + using namespace properties; + PrefixId prefixId; + + ABT scanNode = make<ScanNode>("root", "c1"); + + ABT evalANode = make<EvaluationNode>( + "pa", + make<EvalPath>(make<PathGet>("a", make<PathIdentity>()), make<Variable>("root")), + std::move(scanNode)); + + ABT filterANode = + make<FilterNode>(make<EvalFilter>(make<PathCompare>(Operations::Eq, Constant::int64(1)), + make<Variable>("pa")), + std::move(evalANode)); + + ABT evalBNode = make<EvaluationNode>( + "pb", + make<EvalPath>(make<PathGet>("b", make<PathIdentity>()), make<Variable>("root")), + std::move(filterANode)); + + ABT filterBNode = + make<FilterNode>(make<EvalFilter>(make<PathCompare>(Operations::Gt, Constant::int64(2)), + make<Variable>("pb")), + std::move(evalBNode)); + + ABT collationNode = make<CollationNode>( + CollationRequirement({{"pa", CollationOp::Ascending}, {"pb", CollationOp::Ascending}}), + std::move(filterBNode)); + + ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, + std::move(collationNode)); + + OptPhaseManager phaseManager( + {OptPhaseManager::OptPhase::MemoSubstitutionPhase, + OptPhaseManager::OptPhase::MemoExplorationPhase, + OptPhaseManager::OptPhase::MemoImplementationPhase}, + prefixId, + {{{"c1", + ScanDefinition{ + {}, + {{"index1", makeIndexDefinition("a", CollationOp::Ascending, false /*isMultiKey*/)}, + {"index2", + makeIndexDefinition("b", CollationOp::Descending, false 
/*isMultiKey*/)}}}}}}, + {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests}); + + { + ABT optimized = rootNode; + + // Test RIDIntersect using only Group+Union. + phaseManager.getHints()._disableHashJoinRIDIntersect = true; + + ASSERT_TRUE(phaseManager.optimize(optimized)); + ASSERT_BETWEEN(15, 25, phaseManager.getMemo().getStats()._physPlanExplorationCount); + + // GroupBy+Union cannot propagate collation requirement, and we need a separate + // CollationNode. + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | root\n" + "| RefBlock: \n" + "| Variable [root]\n" + "Collation []\n" + "| | collation: \n" + "| | pa: Ascending\n" + "| | pb: Ascending\n" + "| RefBlock: \n" + "| Variable [pa]\n" + "| Variable [pb]\n" + "BinaryJoin [joinType: Inner, {rid_0}]\n" + "| | Const [true]\n" + "| LimitSkip []\n" + "| | limitSkip:\n" + "| | limit: 1\n" + "| | skip: 0\n" + "| Seek [ridProjection: rid_0, {'<root>': root}, c1]\n" + "| | BindBlock:\n" + "| | [root]\n" + "| | Source []\n" + "| RefBlock: \n" + "| Variable [rid_0]\n" + "Filter []\n" + "| EvalFilter []\n" + "| | FunctionCall [getArraySize]\n" + "| | Variable [sides_0]\n" + "| PathCompare [Eq]\n" + "| Const [2]\n" + "GroupBy []\n" + "| | groupings: \n" + "| | RefBlock: \n" + "| | Variable [rid_0]\n" + "| aggregations: \n" + "| [pa]\n" + "| FunctionCall [$max]\n" + "| Variable [unionTemp_0]\n" + "| [pb]\n" + "| FunctionCall [$max]\n" + "| Variable [unionTemp_1]\n" + "| [sides_0]\n" + "| FunctionCall [$addToSet]\n" + "| Variable [sideId_0]\n" + "Union []\n" + "| | BindBlock:\n" + "| | [rid_0]\n" + "| | Source []\n" + "| | [sideId_0]\n" + "| | Source []\n" + "| | [unionTemp_0]\n" + "| | Source []\n" + "| | [unionTemp_1]\n" + "| | Source []\n" + "| Evaluation []\n" + "| | BindBlock:\n" + "| | [unionTemp_1]\n" + "| | Variable [pb]\n" + "| Evaluation []\n" + "| | BindBlock:\n" + "| | [unionTemp_0]\n" + "| | Const [Nothing]\n" + "| Evaluation []\n" + "| | BindBlock:\n" + "| | [sideId_0]\n" + "| | Const [1]\n" + "| IndexScan [{'<indexKey> 0': pb, '<rid>': rid_0}, scanDefName: c1, " + "indexDefName: " + "index2, interval: {(Const [2], +inf)}]\n" + "| BindBlock:\n" + "| [pb]\n" + "| Source []\n" + "| [rid_0]\n" + "| Source []\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [unionTemp_1]\n" + "| Const [Nothing]\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [unionTemp_0]\n" + "| Variable [pa]\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [sideId_0]\n" + "| Const [0]\n" + "IndexScan [{'<indexKey> 0': pa, '<rid>': rid_0}, scanDefName: c1, indexDefName: " + "index1, interval: {[Const [1], Const [1]]}]\n" + " BindBlock:\n" + " [pa]\n" + " Source []\n" + " [rid_0]\n" + " Source []\n", + optimized); + } + + { + ABT optimized = rootNode; + + phaseManager.getHints()._disableGroupByAndUnionRIDIntersect = false; + phaseManager.getHints()._disableHashJoinRIDIntersect = false; + + ASSERT_TRUE(phaseManager.optimize(optimized)); + ASSERT_BETWEEN(15, 25, phaseManager.getMemo().getStats()._physPlanExplorationCount); + + // Index2 will be used in reverse direction. 
+ ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | root\n" + "| RefBlock: \n" + "| Variable [root]\n" + "BinaryJoin [joinType: Inner, {rid_0}]\n" + "| | Const [true]\n" + "| LimitSkip []\n" + "| | limitSkip:\n" + "| | limit: 1\n" + "| | skip: 0\n" + "| Seek [ridProjection: rid_0, {'<root>': root}, c1]\n" + "| | BindBlock:\n" + "| | [root]\n" + "| | Source []\n" + "| RefBlock: \n" + "| Variable [rid_0]\n" + "HashJoin [joinType: Inner]\n" + "| | Condition\n" + "| | rid_0 = rid_1\n" + "| Union []\n" + "| | BindBlock:\n" + "| | [rid_1]\n" + "| | Source []\n" + "| Evaluation []\n" + "| | BindBlock:\n" + "| | [rid_1]\n" + "| | Variable [rid_0]\n" + "| IndexScan [{'<rid>': rid_0}, scanDefName: c1, indexDefName: index2, interval: " + "{(Const [2], +inf)}, reversed]\n" + "| BindBlock:\n" + "| [rid_0]\n" + "| Source []\n" + "IndexScan [{'<rid>': rid_0}, scanDefName: c1, indexDefName: index1, interval: " + "{[Const [1], Const [1]]}]\n" + " BindBlock:\n" + " [rid_0]\n" + " Source []\n", + optimized); + } +} + +TEST(PhysRewriter, CompoundIndex1) { + using namespace properties; + PrefixId prefixId; + + ABT scanNode = make<ScanNode>("root", "c1"); + + ABT evalANode = make<EvaluationNode>( + "pa", + make<EvalPath>(make<PathGet>("a", make<PathIdentity>()), make<Variable>("root")), + std::move(scanNode)); + + ABT filterANode = make<FilterNode>( + make<EvalFilter>(make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(1))), + make<Variable>("pa")), + std::move(evalANode)); + + ABT evalBNode = make<EvaluationNode>( + "pb", + make<EvalPath>(make<PathGet>("b", make<PathIdentity>()), make<Variable>("root")), + std::move(filterANode)); + + ABT filterBNode = make<FilterNode>( + make<EvalFilter>(make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(2))), + make<Variable>("pb")), + std::move(evalBNode)); + + ABT filterCNode = make<FilterNode>( + make<EvalFilter>( + make<PathGet>( + "c", make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(3)))), + make<Variable>("root")), + std::move(filterBNode)); + + ABT filterDNode = make<FilterNode>( + make<EvalFilter>( + make<PathGet>( + "d", make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(4)))), + make<Variable>("root")), + std::move(filterCNode)); + + ABT rootNode = + make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(filterDNode)); + + OptPhaseManager phaseManager( + {OptPhaseManager::OptPhase::MemoSubstitutionPhase, + OptPhaseManager::OptPhase::MemoExplorationPhase, + OptPhaseManager::OptPhase::MemoImplementationPhase}, + prefixId, + {{{"c1", + ScanDefinition{ + {}, + {{"index1", + IndexDefinition{{{makeNonMultikeyIndexPath("a"), CollationOp::Ascending}, + {makeNonMultikeyIndexPath("c"), CollationOp::Descending}}, + false /*isMultiKey*/}}, + {"index2", + IndexDefinition{{{makeNonMultikeyIndexPath("b"), CollationOp::Ascending}, + {makeNonMultikeyIndexPath("d"), CollationOp::Ascending}}, + false /*isMultiKey*/}}}}}}}, + {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests}); + + ABT optimized = rootNode; + ASSERT_TRUE(phaseManager.optimize(optimized)); + ASSERT_BETWEEN(60, 110, phaseManager.getMemo().getStats()._physPlanExplorationCount); + + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | root\n" + "| RefBlock: \n" + "| Variable [root]\n" + "BinaryJoin [joinType: Inner, {rid_0}]\n" + "| | Const [true]\n" + "| LimitSkip []\n" + "| | limitSkip:\n" + "| | limit: 1\n" + "| | skip: 0\n" + "| Seek [ridProjection: rid_0, {'<root>': root}, 
c1]\n" + "| | BindBlock:\n" + "| | [root]\n" + "| | Source []\n" + "| RefBlock: \n" + "| Variable [rid_0]\n" + "MergeJoin []\n" + "| | | Condition\n" + "| | | rid_0 = rid_1\n" + "| | Collation\n" + "| | Ascending\n" + "| Union []\n" + "| | BindBlock:\n" + "| | [rid_1]\n" + "| | Source []\n" + "| Evaluation []\n" + "| | BindBlock:\n" + "| | [rid_1]\n" + "| | Variable [rid_0]\n" + "| IndexScan [{'<rid>': rid_0}, scanDefName: c1, indexDefName: index2, interval: {[Const " + "[2], Const [2]], [Const [4], Const [4]]}]\n" + "| BindBlock:\n" + "| [rid_0]\n" + "| Source []\n" + "IndexScan [{'<rid>': rid_0}, scanDefName: c1, indexDefName: index1, interval: {[Const " + "[1], Const [1]], [Const [3], Const [3]]}]\n" + " BindBlock:\n" + " [rid_0]\n" + " Source []\n", + optimized); +} + +TEST(PhysRewriter, CompoundIndex2) { + using namespace properties; + PrefixId prefixId; + + ABT scanNode = make<ScanNode>("root", "c1"); + + ABT evalANode = make<EvaluationNode>( + "pa", + make<EvalPath>(make<PathGet>("a", make<PathIdentity>()), make<Variable>("root")), + std::move(scanNode)); + + ABT filterANode = make<FilterNode>( + make<EvalFilter>(make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(1))), + make<Variable>("pa")), + std::move(evalANode)); + + ABT evalBNode = make<EvaluationNode>( + "pb", + make<EvalPath>(make<PathGet>("b", make<PathIdentity>()), make<Variable>("root")), + std::move(filterANode)); + + ABT filterBNode = make<FilterNode>( + make<EvalFilter>(make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(2))), + make<Variable>("pb")), + std::move(evalBNode)); + + ABT filterCNode = make<FilterNode>( + make<EvalFilter>( + make<PathGet>( + "c", make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(3)))), + make<Variable>("root")), + std::move(filterBNode)); + + ABT filterDNode = make<FilterNode>( + make<EvalFilter>( + make<PathGet>( + "d", make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(4)))), + make<Variable>("root")), + std::move(filterCNode)); + + ABT collationNode = make<CollationNode>( + CollationRequirement({{"pa", CollationOp::Ascending}, {"pb", CollationOp::Ascending}}), + std::move(filterDNode)); + + ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, + std::move(collationNode)); + + OptPhaseManager phaseManager( + {OptPhaseManager::OptPhase::MemoSubstitutionPhase, + OptPhaseManager::OptPhase::MemoExplorationPhase, + OptPhaseManager::OptPhase::MemoImplementationPhase}, + prefixId, + {{{"c1", + ScanDefinition{ + {}, + { + {"index1", + IndexDefinition{{{makeNonMultikeyIndexPath("a"), CollationOp::Ascending}, + {makeNonMultikeyIndexPath("c"), CollationOp::Descending}}, + false /*isMultiKey*/}}, + {"index2", + IndexDefinition{{{makeNonMultikeyIndexPath("b"), CollationOp::Ascending}, + {makeNonMultikeyIndexPath("d"), CollationOp::Ascending}}, + false /*isMultiKey*/}}, + }}}}}, + {true /*debugMode*/, 3 /*debugLevel*/, DebugInfo::kIterationLimitForTests}); + + ABT optimized = rootNode; + + ASSERT_TRUE(phaseManager.optimize(optimized)); + ASSERT_BETWEEN(100, 170, phaseManager.getMemo().getStats()._physPlanExplorationCount); + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | root\n" + "| RefBlock: \n" + "| Variable [root]\n" + "BinaryJoin [joinType: Inner, {rid_0}]\n" + "| | Const [true]\n" + "| LimitSkip []\n" + "| | limitSkip:\n" + "| | limit: 1\n" + "| | skip: 0\n" + "| Seek [ridProjection: rid_0, {'<root>': root}, c1]\n" + "| | BindBlock:\n" + "| | [root]\n" + "| | Source []\n" + "| 
RefBlock: \n" + "| Variable [rid_0]\n" + "MergeJoin []\n" + "| | | Condition\n" + "| | | rid_0 = rid_1\n" + "| | Collation\n" + "| | Ascending\n" + "| Union []\n" + "| | BindBlock:\n" + "| | [rid_1]\n" + "| | Source []\n" + "| Evaluation []\n" + "| | BindBlock:\n" + "| | [rid_1]\n" + "| | Variable [rid_0]\n" + "| IndexScan [{'<rid>': rid_0}, scanDefName: c1, indexDefName: index2, interval: " + "{[Const " + "[2], Const [2]], [Const [4], Const [4]]}]\n" + "| BindBlock:\n" + "| [rid_0]\n" + "| Source []\n" + "IndexScan [{'<rid>': rid_0}, scanDefName: c1, indexDefName: index1, interval: {[Const " + "[1], Const [1]], [Const [3], Const [3]]}]\n" + " BindBlock:\n" + " [rid_0]\n" + " Source []\n", + optimized); +} + +TEST(PhysRewriter, CompoundIndex3) { + using namespace properties; + PrefixId prefixId; + + ABT scanNode = make<ScanNode>("root", "c1"); + + ABT evalANode = make<EvaluationNode>( + "pa", + make<EvalPath>(make<PathGet>("a", make<PathIdentity>()), make<Variable>("root")), + std::move(scanNode)); + + ABT filterANode = make<FilterNode>( + make<EvalFilter>(make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(1))), + make<Variable>("pa")), + std::move(evalANode)); + + ABT evalBNode = make<EvaluationNode>( + "pb", + make<EvalPath>(make<PathGet>("b", make<PathIdentity>()), make<Variable>("root")), + std::move(filterANode)); + + ABT filterBNode = make<FilterNode>( + make<EvalFilter>(make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(2))), + make<Variable>("pb")), + std::move(evalBNode)); + + ABT filterCNode = make<FilterNode>( + make<EvalFilter>( + make<PathGet>( + "c", make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(3)))), + make<Variable>("root")), + std::move(filterBNode)); + + ABT filterDNode = make<FilterNode>( + make<EvalFilter>( + make<PathGet>( + "d", make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(4)))), + make<Variable>("root")), + std::move(filterCNode)); + + ABT collationNode = make<CollationNode>( + CollationRequirement({{"pa", CollationOp::Ascending}, {"pb", CollationOp::Ascending}}), + std::move(filterDNode)); + + ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, + std::move(collationNode)); + + OptPhaseManager phaseManager( + {OptPhaseManager::OptPhase::MemoSubstitutionPhase, + OptPhaseManager::OptPhase::MemoExplorationPhase, + OptPhaseManager::OptPhase::MemoImplementationPhase}, + prefixId, + {{{"c1", + ScanDefinition{{}, + {{"index1", + IndexDefinition{{{makeIndexPath("a"), CollationOp::Ascending}, + {makeIndexPath("c"), CollationOp::Descending}}, + true /*isMultiKey*/}}, + {"index2", + IndexDefinition{{{makeIndexPath("b"), CollationOp::Ascending}, + {makeIndexPath("d"), CollationOp::Ascending}}, + true /*isMultiKey*/}}}}}}}, + {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests}); + + ABT optimized = rootNode; + ASSERT_TRUE(phaseManager.optimize(optimized)); + ASSERT_BETWEEN(70, 110, phaseManager.getMemo().getStats()._physPlanExplorationCount); + + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | root\n" + "| RefBlock: \n" + "| Variable [root]\n" + "Collation []\n" + "| | collation: \n" + "| | pa: Ascending\n" + "| | pb: Ascending\n" + "| RefBlock: \n" + "| Variable [pa]\n" + "| Variable [pb]\n" + "BinaryJoin [joinType: Inner, {rid_0}]\n" + "| | Const [true]\n" + "| LimitSkip []\n" + "| | limitSkip:\n" + "| | limit: 1\n" + "| | skip: 0\n" + "| Seek [ridProjection: rid_0, {'<root>': root, 'a': pa, 'b': pb}, c1]\n" + "| | 
BindBlock:\n" + "| | [pa]\n" + "| | Source []\n" + "| | [pb]\n" + "| | Source []\n" + "| | [root]\n" + "| | Source []\n" + "| RefBlock: \n" + "| Variable [rid_0]\n" + "MergeJoin []\n" + "| | | Condition\n" + "| | | rid_0 = rid_1\n" + "| | Collation\n" + "| | Ascending\n" + "| Union []\n" + "| | BindBlock:\n" + "| | [rid_1]\n" + "| | Source []\n" + "| Evaluation []\n" + "| | BindBlock:\n" + "| | [rid_1]\n" + "| | Variable [rid_0]\n" + "| IndexScan [{'<rid>': rid_0}, scanDefName: c1, indexDefName: index2, interval: {[Const " + "[2], Const [2]], [Const [4], Const [4]]}]\n" + "| BindBlock:\n" + "| [rid_0]\n" + "| Source []\n" + "IndexScan [{'<rid>': rid_0}, scanDefName: c1, indexDefName: index1, interval: {[Const " + "[1], Const [1]], [Const [3], Const [3]]}]\n" + " BindBlock:\n" + " [rid_0]\n" + " Source []\n", + optimized); +} + +TEST(PhysRewriter, IndexBoundsIntersect) { + using namespace properties; + PrefixId prefixId; + + ABT scanNode = make<ScanNode>("root", "c1"); + + ABT filterNode1 = make<FilterNode>( + make<EvalFilter>( + make<PathGet>( + "b", make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(1)))), + make<Variable>("root")), + std::move(scanNode)); + + ABT filterNode2 = make<FilterNode>( + make<EvalFilter>( + make<PathComposeA>( + make<PathComposeM>(make<PathGet>("a", + make<PathTraverse>(make<PathCompare>( + Operations::Gt, Constant::int64(70)))), + make<PathGet>("a", + make<PathTraverse>(make<PathCompare>( + Operations::Lt, Constant::int64(90))))), + make<PathGet>( + "a", + make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(100))))), + make<Variable>("root")), + std::move(filterNode1)); + + ABT rootNode = + make<RootNode>(ProjectionRequirement{ProjectionNameVector{}}, std::move(filterNode2)); + + OptPhaseManager phaseManager( + {OptPhaseManager::OptPhase::MemoSubstitutionPhase, + OptPhaseManager::OptPhase::MemoExplorationPhase, + OptPhaseManager::OptPhase::MemoImplementationPhase}, + prefixId, + {{{"c1", + ScanDefinition{{}, + {{"index1", + IndexDefinition{{{makeIndexPath("a"), CollationOp::Ascending}, + {makeIndexPath("b"), CollationOp::Ascending}}, + true /*isMultiKey*/}}}}}}}, + {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests}); + + ABT optimized = rootNode; + ASSERT_TRUE(phaseManager.optimize(optimized)); + ASSERT_BETWEEN(10, 15, phaseManager.getMemo().getStats()._physPlanExplorationCount); + + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| RefBlock: \n" + "Filter []\n" + "| EvalFilter []\n" + "| | Variable [evalTemp_1]\n" + "| PathCompare [Eq]\n" + "| Const [1]\n" + "GroupBy []\n" + "| | groupings: \n" + "| | RefBlock: \n" + "| | Variable [rid_0]\n" + "| aggregations: \n" + "| [evalTemp_1]\n" + "| FunctionCall [$first]\n" + "| Variable [disjunction_0]\n" + "Union []\n" + "| | BindBlock:\n" + "| | [disjunction_0]\n" + "| | Source []\n" + "| | [rid_0]\n" + "| | Source []\n" + "| IndexScan [{'<indexKey> 1': disjunction_0, '<rid>': rid_0}, scanDefName: c1, " + "indexDefName: index1, interval: {[Const [100], Const [100]], (-inf, +inf)}]\n" + "| BindBlock:\n" + "| [disjunction_0]\n" + "| Source []\n" + "| [rid_0]\n" + "| Source []\n" + "Filter []\n" + "| EvalFilter []\n" + "| | FunctionCall [getArraySize]\n" + "| | Variable [sides_0]\n" + "| PathCompare [Eq]\n" + "| Const [2]\n" + "GroupBy []\n" + "| | groupings: \n" + "| | RefBlock: \n" + "| | Variable [rid_0]\n" + "| aggregations: \n" + "| [disjunction_0]\n" + "| FunctionCall [$first]\n" + "| Variable [conjunction_0]\n" + "| [sides_0]\n" + "| 
FunctionCall [$addToSet]\n" + "| Variable [sideId_0]\n" + "Union []\n" + "| | BindBlock:\n" + "| | [conjunction_0]\n" + "| | Source []\n" + "| | [rid_0]\n" + "| | Source []\n" + "| | [sideId_0]\n" + "| | Source []\n" + "| Evaluation []\n" + "| | BindBlock:\n" + "| | [sideId_0]\n" + "| | Const [1]\n" + "| IndexScan [{'<indexKey> 1': conjunction_0, '<rid>': rid_0}, scanDefName: c1, " + "indexDefName: index1, interval: {(-inf, Const [90]), (-inf, +inf)}]\n" + "| BindBlock:\n" + "| [conjunction_0]\n" + "| Source []\n" + "| [rid_0]\n" + "| Source []\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [sideId_0]\n" + "| Const [0]\n" + "IndexScan [{'<indexKey> 1': conjunction_0, '<rid>': rid_0}, scanDefName: c1, " + "indexDefName: index1, interval: {(Const [70], +inf), (-inf, +inf)}]\n" + " BindBlock:\n" + " [conjunction_0]\n" + " Source []\n" + " [rid_0]\n" + " Source []\n", + optimized); +} + +TEST(PhysRewriter, IndexBoundsIntersect1) { + using namespace properties; + PrefixId prefixId; + + ABT scanNode = make<ScanNode>("root", "c1"); + + ABT evalNode = make<EvaluationNode>( + "pa", + make<EvalPath>(make<PathGet>("a", make<PathIdentity>()), make<Variable>("root")), + std::move(scanNode)); + + ABT filterNode = make<FilterNode>( + make<EvalFilter>( + make<PathComposeM>( + make<PathTraverse>(make<PathCompare>(Operations::Gt, Constant::int64(70))), + make<PathTraverse>(make<PathCompare>(Operations::Lt, Constant::int64(90)))), + make<Variable>("pa")), + std::move(evalNode)); + + ABT collationNode = make<CollationNode>(CollationRequirement({{"pa", CollationOp::Ascending}}), + std::move(filterNode)); + + ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, + std::move(collationNode)); + + OptPhaseManager phaseManager( + {OptPhaseManager::OptPhase::MemoSubstitutionPhase, + OptPhaseManager::OptPhase::MemoExplorationPhase, + OptPhaseManager::OptPhase::MemoImplementationPhase}, + prefixId, + {{{"c1", + ScanDefinition{ + {}, + {{"index1", + IndexDefinition{{{makeNonMultikeyIndexPath("a"), CollationOp::Ascending}}, + false /*isMultiKey*/}}}}}}}, + {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests}); + + ABT optimized = rootNode; + ASSERT_TRUE(phaseManager.optimize(optimized)); + ASSERT_BETWEEN(15, 20, phaseManager.getMemo().getStats()._physPlanExplorationCount); + + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | root\n" + "| RefBlock: \n" + "| Variable [root]\n" + "BinaryJoin [joinType: Inner, {rid_0}]\n" + "| | Const [true]\n" + "| LimitSkip []\n" + "| | limitSkip:\n" + "| | limit: 1\n" + "| | skip: 0\n" + "| Seek [ridProjection: rid_0, {'<root>': root}, c1]\n" + "| | BindBlock:\n" + "| | [root]\n" + "| | Source []\n" + "| RefBlock: \n" + "| Variable [rid_0]\n" + "IndexScan [{'<rid>': rid_0}, scanDefName: c1, indexDefName: index1, interval: {(Const " + "[70], Const [90])}]\n" + " BindBlock:\n" + " [rid_0]\n" + " Source []\n", + optimized); +} + +TEST(PhysRewriter, IndexBoundsIntersect2) { + using namespace properties; + PrefixId prefixId; + + ABT scanNode = make<ScanNode>("root", "c1"); + + ABT evalNode = make<EvaluationNode>( + "pa", + make<EvalPath>(make<PathGet>("a", make<PathIdentity>()), make<Variable>("root")), + std::move(scanNode)); + + ABT filterNode = make<FilterNode>( + make<EvalFilter>(make<PathTraverse>(make<PathComposeM>( + make<PathCompare>(Operations::Gt, Constant::int64(70)), + make<PathCompare>(Operations::Lt, Constant::int64(90)))), + make<Variable>("pa")), + std::move(evalNode)); + + ABT rootNode = + 
make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(filterNode)); + + OptPhaseManager phaseManager( + {OptPhaseManager::OptPhase::MemoSubstitutionPhase, + OptPhaseManager::OptPhase::MemoExplorationPhase, + OptPhaseManager::OptPhase::MemoImplementationPhase}, + prefixId, + {{{"c1", + ScanDefinition{{}, + {{"index1", + IndexDefinition{{{makeIndexPath("a"), CollationOp::Ascending}}, + true /*isMultiKey*/}}}}}}}, + {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests}); + + ABT optimized = rootNode; + ASSERT_TRUE(phaseManager.optimize(optimized)); + ASSERT_EQ(6, phaseManager.getMemo().getStats()._physPlanExplorationCount); + + // Demonstrate we can intersect the bounds here because composition does not contain + // traverse. + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | root\n" + "| RefBlock: \n" + "| Variable [root]\n" + "BinaryJoin [joinType: Inner, {rid_0}]\n" + "| | Const [true]\n" + "| LimitSkip []\n" + "| | limitSkip:\n" + "| | limit: 1\n" + "| | skip: 0\n" + "| Seek [ridProjection: rid_0, {'<root>': root}, c1]\n" + "| | BindBlock:\n" + "| | [root]\n" + "| | Source []\n" + "| RefBlock: \n" + "| Variable [rid_0]\n" + "Unique []\n" + "| projections: \n" + "| rid_0\n" + "IndexScan [{'<rid>': rid_0}, scanDefName: c1, indexDefName: index1, interval: {(Const " + "[70], Const [90])}]\n" + " BindBlock:\n" + " [rid_0]\n" + " Source []\n", + optimized); +} + +TEST(PhysRewriter, IndexBoundsIntersect3) { + using namespace properties; + PrefixId prefixId; + + ABT scanNode = make<ScanNode>("root", "c1"); + + ABT filterNode = make<FilterNode>( + make<EvalFilter>( + make<PathGet>("a", + make<PathTraverse>(make<PathComposeM>( + make<PathGet>("b", + make<PathTraverse>(make<PathCompare>( + Operations::Gt, Constant::int64(70)))), + make<PathGet>("b", + make<PathTraverse>(make<PathCompare>( + Operations::Lt, Constant::int64(90))))))), + make<Variable>("root")), + std::move(scanNode)); + + ABT rootNode = + make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(filterNode)); + + OptPhaseManager phaseManager( + {OptPhaseManager::OptPhase::MemoSubstitutionPhase, + OptPhaseManager::OptPhase::MemoExplorationPhase, + OptPhaseManager::OptPhase::MemoImplementationPhase}, + prefixId, + {{{"c1", + ScanDefinition{ + {}, + {{"index1", + IndexDefinition{{{makeIndexPath(FieldPathType{"a", "b"}, true /*isMultiKey*/), + CollationOp::Ascending}}, + true /*isMultiKey*/}}}}}}}, + {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests}); + + ABT optimized = rootNode; + ASSERT_TRUE(phaseManager.optimize(optimized)); + ASSERT_EQ(4, phaseManager.getMemo().getStats()._physPlanExplorationCount); + + // We intersect indexes because the outer composition is over the same field ("b"). 
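+    // As the expected plan below shows, the two one-sided IndexScans are tagged with distinct
+    // sideId constants, unioned, grouped by rid_0 with $addToSet, and only RIDs whose side-id
+    // set has size 2 (i.e. matched by both bounds) survive to the Seek.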
+ ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | root\n" + "| RefBlock: \n" + "| Variable [root]\n" + "BinaryJoin [joinType: Inner, {rid_0}]\n" + "| | Const [true]\n" + "| LimitSkip []\n" + "| | limitSkip:\n" + "| | limit: 1\n" + "| | skip: 0\n" + "| Seek [ridProjection: rid_0, {'<root>': root}, c1]\n" + "| | BindBlock:\n" + "| | [root]\n" + "| | Source []\n" + "| RefBlock: \n" + "| Variable [rid_0]\n" + "Filter []\n" + "| EvalFilter []\n" + "| | FunctionCall [getArraySize]\n" + "| | Variable [sides_0]\n" + "| PathCompare [Eq]\n" + "| Const [2]\n" + "GroupBy []\n" + "| | groupings: \n" + "| | RefBlock: \n" + "| | Variable [rid_0]\n" + "| aggregations: \n" + "| [sides_0]\n" + "| FunctionCall [$addToSet]\n" + "| Variable [sideId_0]\n" + "Union []\n" + "| | BindBlock:\n" + "| | [rid_0]\n" + "| | Source []\n" + "| | [sideId_0]\n" + "| | Source []\n" + "| Evaluation []\n" + "| | BindBlock:\n" + "| | [sideId_0]\n" + "| | Const [1]\n" + "| IndexScan [{'<rid>': rid_0}, scanDefName: c1, indexDefName: index1, interval: " + "{(-inf, " + "Const [90])}]\n" + "| BindBlock:\n" + "| [rid_0]\n" + "| Source []\n" + "Evaluation []\n" + "| BindBlock:\n" + "| [sideId_0]\n" + "| Const [0]\n" + "IndexScan [{'<rid>': rid_0}, scanDefName: c1, indexDefName: index1, interval: {(Const " + "[70], +inf)}]\n" + " BindBlock:\n" + " [rid_0]\n" + " Source []\n", + optimized); +} + +TEST(PhysRewriter, IndexBoundsIntersect4) { + using namespace properties; + PrefixId prefixId; + + ABT scanNode = make<ScanNode>("root", "c1"); + + ABT filterNode = make<FilterNode>( + make<EvalFilter>( + make<PathGet>("a", + make<PathTraverse>(make<PathComposeM>( + make<PathGet>("b", + make<PathTraverse>(make<PathCompare>( + Operations::Gt, Constant::int64(70)))), + make<PathGet>("c", + make<PathTraverse>(make<PathCompare>( + Operations::Lt, Constant::int64(90))))))), + make<Variable>("root")), + std::move(scanNode)); + + ABT rootNode = + make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(filterNode)); + + OptPhaseManager phaseManager( + {OptPhaseManager::OptPhase::MemoSubstitutionPhase, + OptPhaseManager::OptPhase::MemoExplorationPhase, + OptPhaseManager::OptPhase::MemoImplementationPhase}, + prefixId, + {{{"c1", + ScanDefinition{ + {}, + {{"index1", + IndexDefinition{{{makeIndexPath(FieldPathType{"a", "b"}, true /*isMultiKey*/), + CollationOp::Ascending}}, + true /*isMultiKey*/}}, + {"index2", + IndexDefinition{{{makeIndexPath(FieldPathType{"a", "c"}, true /*isMultiKey*/), + CollationOp::Ascending}}, + true /*isMultiKey*/}}}}}}}, + {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests}); + + ABT optimized = rootNode; + ASSERT_TRUE(phaseManager.optimize(optimized)); + ASSERT_EQ(3, phaseManager.getMemo().getStats()._physPlanExplorationCount); + + // We do not intersect indexes because the outer composition is over the different fields. 
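+    // The expected plan below keeps the whole PathComposeM as a residual filter over a
+    // PhysicalScan of c1; neither index1 nor index2 is used.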
+ ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | root\n" + "| RefBlock: \n" + "| Variable [root]\n" + "Filter []\n" + "| EvalFilter []\n" + "| | Variable [root]\n" + "| PathGet [a]\n" + "| PathTraverse []\n" + "| PathComposeM []\n" + "| | PathGet [c]\n" + "| | PathTraverse []\n" + "| | PathCompare [Lt]\n" + "| | Const [90]\n" + "| PathGet [b]\n" + "| PathTraverse []\n" + "| PathCompare [Gt]\n" + "| Const [70]\n" + "PhysicalScan [{'<root>': root}, c1]\n" + " BindBlock:\n" + " [root]\n" + " Source []\n", + optimized); +} + +TEST(PhysRewriter, IndexResidualReq) { + using namespace properties; + PrefixId prefixId; + + ABT scanNode = make<ScanNode>("root", "c1"); + + ABT evalANode = make<EvaluationNode>( + "pa", + make<EvalPath>(make<PathGet>("a", make<PathIdentity>()), make<Variable>("root")), + std::move(scanNode)); + + ABT filterANode = + make<FilterNode>(make<EvalFilter>(make<PathCompare>(Operations::Gt, Constant::int64(0)), + make<Variable>("pa")), + std::move(evalANode)); + + ABT filterBNode = make<FilterNode>( + make<EvalFilter>( + make<PathGet>( + "b", make<PathGet>("c", make<PathCompare>(Operations::Gt, Constant::int64(0)))), + make<Variable>("root")), + std::move(filterANode)); + + ABT collationNode = make<CollationNode>(CollationRequirement({{"pa", CollationOp::Ascending}}), + std::move(filterBNode)); + + ABT rootNode = + make<RootNode>(ProjectionRequirement{ProjectionNameVector{"pa"}}, std::move(collationNode)); + + OptPhaseManager phaseManager( + {OptPhaseManager::OptPhase::MemoSubstitutionPhase, + OptPhaseManager::OptPhase::MemoExplorationPhase, + OptPhaseManager::OptPhase::MemoImplementationPhase}, + prefixId, + {{{"c1", + ScanDefinition{ + {}, + {{"index1", + IndexDefinition{{{makeNonMultikeyIndexPath("a"), CollationOp::Ascending}, + {makeNonMultikeyIndexPath("b"), CollationOp::Ascending}}, + false /*isMultiKey*/}}}}}}}, + {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests}); + + ABT optimized = rootNode; + ASSERT_TRUE(phaseManager.optimize(optimized)); + ASSERT_BETWEEN(10, 26, phaseManager.getMemo().getStats()._physPlanExplorationCount); + + // Make sure we can use the index to cover "b" while testing "b.c" with a separate filter. 
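+    // In the expected plan the second index key ("b") is bound to evalTemp_1 directly by the
+    // IndexScan, and the remaining "b.c > 0" predicate runs as a residual Filter above it, so
+    // no Seek into the collection is needed.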
+ ASSERT_EXPLAIN_PROPS_V2( + "Properties [cost: 0.070002, localCost: 0, adjustedCE: 10]\n" + "| | Logical:\n" + "| | cardinalityEstimate: \n" + "| | ce: 10\n" + "| | projections: \n" + "| | pa\n" + "| | root\n" + "| | indexingAvailability: \n" + "| | [groupId: 0, scanProjection: root, scanDefName: c1]\n" + "| | collectionAvailability: \n" + "| | c1\n" + "| | distributionAvailability: \n" + "| | distribution: \n" + "| | type: Centralized\n" + "| Physical:\n" + "| distribution: \n" + "| type: Centralized\n" + "| indexingRequirement: \n" + "| Complete, dedupRID\n" + "Root []\n" + "| | projections: \n" + "| | pa\n" + "| RefBlock: \n" + "| Variable [pa]\n" + "Properties [cost: 0.070002, localCost: 0.070002, adjustedCE: 10]\n" + "| | Logical:\n" + "| | cardinalityEstimate: \n" + "| | ce: 10\n" + "| | requirementCEs: \n" + "| | refProjection: root, path: 'PathGet [a] PathIdentity []', ce: 100\n" + "| | refProjection: root, path: 'PathGet [b] PathGet [c] PathIdentity []', " + "ce: 100\n" + "| | projections: \n" + "| | pa\n" + "| | root\n" + "| | indexingAvailability: \n" + "| | [groupId: 0, scanProjection: root, scanDefName: c1]\n" + "| | collectionAvailability: \n" + "| | c1\n" + "| | distributionAvailability: \n" + "| | distribution: \n" + "| | type: Centralized\n" + "| Physical:\n" + "| collation: \n" + "| pa: Ascending\n" + "| projections: \n" + "| pa\n" + "| distribution: \n" + "| type: Centralized\n" + "| indexingRequirement: \n" + "| Index, dedupRID\n" + "Filter []\n" + "| EvalFilter []\n" + "| | Variable [evalTemp_1]\n" + "| PathGet [c]\n" + "| PathCompare [Gt]\n" + "| Const [0]\n" + "IndexScan [{'<indexKey> 0': pa, '<indexKey> 1': evalTemp_1}, scanDefName: c1, " + "indexDefName: index1, interval: {(Const [0], +inf), (-inf, +inf)}]\n" + " BindBlock:\n" + " [evalTemp_1]\n" + " Source []\n" + " [pa]\n" + " Source []\n", + phaseManager); +} + +TEST(PhysRewriter, IndexResidualReq1) { + using namespace properties; + PrefixId prefixId; + + ABT scanNode = make<ScanNode>("root", "c1"); + + ABT filterANode = make<FilterNode>( + make<EvalFilter>(make<PathGet>("a", make<PathCompare>(Operations::Eq, Constant::int64(0))), + make<Variable>("root")), + std::move(scanNode)); + + ABT filterBNode = make<FilterNode>( + make<EvalFilter>(make<PathGet>("b", make<PathCompare>(Operations::Eq, Constant::int64(0))), + make<Variable>("root")), + std::move(filterANode)); + + ABT filterCNode = make<FilterNode>( + make<EvalFilter>(make<PathGet>("c", make<PathCompare>(Operations::Eq, Constant::int64(0))), + make<Variable>("root")), + std::move(filterBNode)); + + ABT evalDNode = make<EvaluationNode>( + "pd", + make<EvalPath>(make<PathGet>("d", make<PathIdentity>()), make<Variable>("root")), + std::move(filterCNode)); + + ABT collationNode = make<CollationNode>(CollationRequirement({{"pd", CollationOp::Ascending}}), + std::move(evalDNode)); + + ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, + std::move(collationNode)); + + OptPhaseManager phaseManager( + {OptPhaseManager::OptPhase::MemoSubstitutionPhase, + OptPhaseManager::OptPhase::MemoExplorationPhase, + OptPhaseManager::OptPhase::MemoImplementationPhase}, + prefixId, + {{{"c1", + ScanDefinition{ + {}, + {{"index1", + makeCompositeIndexDefinition({{"a", CollationOp::Ascending, false /*isMultiKey*/}, + {"b", CollationOp::Ascending, false /*isMultiKey*/}, + {"c", CollationOp::Ascending, false /*isMultiKey*/}, + {"d", CollationOp::Ascending, false /*isMultiKey*/}}, + false /*isMultiKey*/)}, + {"index2", + 
makeCompositeIndexDefinition({{"a", CollationOp::Ascending, false /*isMultiKey*/}, + {"b", CollationOp::Ascending, false /*isMultiKey*/}, + {"d", CollationOp::Ascending, false /*isMultiKey*/}}, + false /*isMultiKey*/)}, + {"index3", + makeCompositeIndexDefinition({{"a", CollationOp::Ascending, false /*isMultiKey*/}, + {"d", CollationOp::Ascending, false /*isMultiKey*/}}, + false /*isMultiKey*/)}}}}}}, + {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests}); + + ABT optimized = rootNode; + ASSERT_TRUE(phaseManager.optimize(optimized)); + ASSERT_BETWEEN(25, 30, phaseManager.getMemo().getStats()._physPlanExplorationCount); + + // Prefer index1 over index2 and index3 in order to cover all fields. + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | root\n" + "| RefBlock: \n" + "| Variable [root]\n" + "BinaryJoin [joinType: Inner, {rid_0}]\n" + "| | Const [true]\n" + "| LimitSkip []\n" + "| | limitSkip:\n" + "| | limit: 1\n" + "| | skip: 0\n" + "| Seek [ridProjection: rid_0, {'<root>': root}, c1]\n" + "| | BindBlock:\n" + "| | [root]\n" + "| | Source []\n" + "| RefBlock: \n" + "| Variable [rid_0]\n" + "IndexScan [{'<rid>': rid_0}, scanDefName: c1, indexDefName: index1, interval: {[Const " + "[0], Const [0]], [Const [0], Const [0]], [Const [0], Const [0]], (-inf, +inf)}]\n" + " BindBlock:\n" + " [rid_0]\n" + " Source []\n", + optimized); +} + +TEST(PhysRewriter, IndexResidualReq2) { + using namespace properties; + PrefixId prefixId; + + ABT scanNode = make<ScanNode>("root", "c1"); + + ABT filterANode = make<FilterNode>( + make<EvalFilter>( + make<PathGet>( + "a", make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(0)))), + make<Variable>("root")), + std::move(scanNode)); + + ABT filterBNode = make<FilterNode>( + make<EvalFilter>( + make<PathGet>( + "b", make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(0)))), + make<Variable>("root")), + std::move(filterANode)); + + ABT rootNode = + make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(filterBNode)); + + OptPhaseManager phaseManager( + {OptPhaseManager::OptPhase::MemoSubstitutionPhase, + OptPhaseManager::OptPhase::MemoExplorationPhase, + OptPhaseManager::OptPhase::MemoImplementationPhase}, + prefixId, + {{{"c1", + ScanDefinition{{}, + {{"index1", + makeCompositeIndexDefinition( + {{"a", CollationOp::Ascending, true /*isMultiKey*/}, + {"c", CollationOp::Ascending, true /*isMultiKey*/}, + {"b", CollationOp::Ascending, true /*isMultiKey*/}})}}}}}}, + {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests}); + + ABT optimized = rootNode; + ASSERT_TRUE(phaseManager.optimize(optimized)); + ASSERT_EQ(7, phaseManager.getMemo().getStats()._physPlanExplorationCount); + + // We can cover "b" with the index and filter before we Seek. 
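+    // The expected plan binds the third index key ("b") to evalTemp_4 and filters it before the
+    // Unique (RID dedup) and the Seek, so the residual predicate never touches the collection.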
+ ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | root\n" + "| RefBlock: \n" + "| Variable [root]\n" + "BinaryJoin [joinType: Inner, {rid_0}]\n" + "| | Const [true]\n" + "| LimitSkip []\n" + "| | limitSkip:\n" + "| | limit: 1\n" + "| | skip: 0\n" + "| Seek [ridProjection: rid_0, {'<root>': root}, c1]\n" + "| | BindBlock:\n" + "| | [root]\n" + "| | Source []\n" + "| RefBlock: \n" + "| Variable [rid_0]\n" + "Unique []\n" + "| projections: \n" + "| rid_0\n" + "Filter []\n" + "| EvalFilter []\n" + "| | Variable [evalTemp_4]\n" + "| PathCompare [Eq]\n" + "| Const [0]\n" + "IndexScan [{'<indexKey> 2': evalTemp_4, '<rid>': rid_0}, scanDefName: c1, " + "indexDefName: " + "index1, interval: {[Const [0], Const [0]], (-inf, +inf), (-inf, +inf)}]\n" + " BindBlock:\n" + " [evalTemp_4]\n" + " Source []\n" + " [rid_0]\n" + " Source []\n", + optimized); +} + +TEST(PhysRewriter, ElemMatchIndex) { + using namespace properties; + PrefixId prefixId; + + ABT scanNode = make<ScanNode>("root", "c1"); + + // This encodes an elemMatch with a conjunction >70 and <90. + ABT filterNode = make<FilterNode>( + make<EvalFilter>( + make<PathGet>( + "a", + make<PathComposeM>(make<PathArr>(), + make<PathTraverse>(make<PathComposeM>( + make<PathCompare>(Operations::Gt, Constant::int64(70)), + make<PathCompare>(Operations::Lt, Constant::int64(90)))))), + make<Variable>("root")), + std::move(scanNode)); + + ABT rootNode = + make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(filterNode)); + + OptPhaseManager phaseManager( + {OptPhaseManager::OptPhase::MemoSubstitutionPhase, + OptPhaseManager::OptPhase::MemoExplorationPhase, + OptPhaseManager::OptPhase::MemoImplementationPhase}, + prefixId, + {{{"c1", + ScanDefinition{{}, {{"index1", makeIndexDefinition("a", CollationOp::Ascending)}}}}}}, + {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests}); + + ABT optimized = rootNode; + ASSERT_TRUE(phaseManager.optimize(optimized)); + ASSERT_EQ(5, phaseManager.getMemo().getStats()._physPlanExplorationCount); + + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | root\n" + "| RefBlock: \n" + "| Variable [root]\n" + "Filter []\n" + "| EvalFilter []\n" + "| | Variable [root]\n" + "| PathGet [a]\n" + "| PathArr []\n" + "BinaryJoin [joinType: Inner, {rid_0}]\n" + "| | Const [true]\n" + "| LimitSkip []\n" + "| | limitSkip:\n" + "| | limit: 1\n" + "| | skip: 0\n" + "| Seek [ridProjection: rid_0, {'<root>': root}, c1]\n" + "| | BindBlock:\n" + "| | [root]\n" + "| | Source []\n" + "| RefBlock: \n" + "| Variable [rid_0]\n" + "Unique []\n" + "| projections: \n" + "| rid_0\n" + "IndexScan [{'<rid>': rid_0}, scanDefName: c1, indexDefName: index1, interval: {(Const " + "[70], Const [90])}]\n" + " BindBlock:\n" + " [rid_0]\n" + " Source []\n", + optimized); +} + +TEST(PhysRewriter, ElemMatchIndex1) { + using namespace properties; + PrefixId prefixId; + + ABT scanNode = make<ScanNode>("root", "c1"); + + ABT filterNode1 = make<FilterNode>( + make<EvalFilter>( + make<PathGet>( + "b", make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(1)))), + make<Variable>("root")), + std::move(scanNode)); + + // This encodes an elemMatch with a conjunction >70 and <90. 
+ ABT filterNode2 = make<FilterNode>( + make<EvalFilter>( + make<PathGet>( + "a", + make<PathComposeM>(make<PathArr>(), + make<PathTraverse>(make<PathComposeM>( + make<PathCompare>(Operations::Gt, Constant::int64(70)), + make<PathCompare>(Operations::Lt, Constant::int64(90)))))), + make<Variable>("root")), + std::move(filterNode1)); + + ABT rootNode = + make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(filterNode2)); + + OptPhaseManager phaseManager( + {OptPhaseManager::OptPhase::MemoSubstitutionPhase, + OptPhaseManager::OptPhase::MemoExplorationPhase, + OptPhaseManager::OptPhase::MemoImplementationPhase}, + prefixId, + {{{"c1", + ScanDefinition{{}, + {{"index1", + makeCompositeIndexDefinition( + {{"b", CollationOp::Ascending, true /*isMultiKey*/}, + {"a", CollationOp::Ascending, true /*isMultiKey*/}})}}}}}}, + {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests}); + + ABT optimized = rootNode; + ASSERT_TRUE(phaseManager.optimize(optimized)); + ASSERT_BETWEEN(5, 10, phaseManager.getMemo().getStats()._physPlanExplorationCount); + + // Demonstrate we can cover both the filter and the extracted elemMatch predicate with the + // index. + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | root\n" + "| RefBlock: \n" + "| Variable [root]\n" + "Filter []\n" + "| EvalFilter []\n" + "| | Variable [root]\n" + "| PathGet [a]\n" + "| PathArr []\n" + "BinaryJoin [joinType: Inner, {rid_0}]\n" + "| | Const [true]\n" + "| LimitSkip []\n" + "| | limitSkip:\n" + "| | limit: 1\n" + "| | skip: 0\n" + "| Seek [ridProjection: rid_0, {'<root>': root}, c1]\n" + "| | BindBlock:\n" + "| | [root]\n" + "| | Source []\n" + "| RefBlock: \n" + "| Variable [rid_0]\n" + "Unique []\n" + "| projections: \n" + "| rid_0\n" + "IndexScan [{'<rid>': rid_0}, scanDefName: c1, indexDefName: index1, interval: {[Const " + "[1], Const [1]], (Const [70], Const [90])}]\n" + " BindBlock:\n" + " [rid_0]\n" + " Source []\n", + optimized); +} + +TEST(PhysRewriter, ObjectElemMatch) { + using namespace properties; + PrefixId prefixId; + + ABT scanNode = make<ScanNode>("root", "c1"); + + ABT filterNode = make<FilterNode>( + make<EvalFilter>( + make<PathGet>( + "a", + make<PathComposeM>(make<PathArr>(), + make<PathTraverse>(make<PathComposeM>( + make<PathGet>("b", + make<PathTraverse>(make<PathCompare>( + Operations::Eq, Constant::int64(1)))), + make<PathGet>("c", + make<PathTraverse>(make<PathCompare>( + Operations::Eq, Constant::int64(2)))))))), + make<Variable>("root")), + std::move(scanNode)); + + ABT rootNode = + make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(filterNode)); + + OptPhaseManager phaseManager( + {OptPhaseManager::OptPhase::MemoSubstitutionPhase, + OptPhaseManager::OptPhase::MemoExplorationPhase, + OptPhaseManager::OptPhase::MemoImplementationPhase}, + prefixId, + {{{"c1", + ScanDefinition{{}, + {{"index1", + makeCompositeIndexDefinition( + {{"b", CollationOp::Ascending, true /*isMultiKey*/}, + {"a", CollationOp::Ascending, true /*isMultiKey*/}})}}}}}}, + {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests}); + + ABT optimized = rootNode; + ASSERT_TRUE(phaseManager.optimize(optimized)); + ASSERT_EQ(4, phaseManager.getMemo().getStats()._physPlanExplorationCount); + + // We currently cannot use indexes with ObjectElemMatch. 
+ ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | root\n" + "| RefBlock: \n" + "| Variable [root]\n" + "Filter []\n" + "| EvalFilter []\n" + "| | Variable [root]\n" + "| PathGet [a]\n" + "| PathTraverse []\n" + "| PathComposeM []\n" + "| | PathGet [c]\n" + "| | PathTraverse []\n" + "| | PathCompare [Eq]\n" + "| | Const [2]\n" + "| PathGet [b]\n" + "| PathTraverse []\n" + "| PathCompare [Eq]\n" + "| Const [1]\n" + "Filter []\n" + "| EvalFilter []\n" + "| | Variable [root]\n" + "| PathGet [a]\n" + "| PathArr []\n" + "PhysicalScan [{'<root>': root}, c1]\n" + " BindBlock:\n" + " [root]\n" + " Source []\n", + optimized); +} + +TEST(PhysRewriter, ParallelScan) { + using namespace properties; + PrefixId prefixId; + + ABT scanNode = make<ScanNode>("root", "c1"); + + ABT filterNode = make<FilterNode>( + make<EvalFilter>( + make<PathGet>( + "a", make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(1)))), + make<Variable>("root")), + std::move(scanNode)); + + ABT rootNode = + make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(filterNode)); + + OptPhaseManager phaseManager( + {OptPhaseManager::OptPhase::MemoSubstitutionPhase, + OptPhaseManager::OptPhase::MemoExplorationPhase, + OptPhaseManager::OptPhase::MemoImplementationPhase}, + prefixId, + {{{"c1", ScanDefinition{{}, {}, {DistributionType::UnknownPartitioning}}}}, + 5 /*numberOfPartitions*/}, + {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests}); + + ABT optimized = rootNode; + ASSERT_TRUE(phaseManager.optimize(optimized)); + ASSERT_EQ(4, phaseManager.getMemo().getStats()._physPlanExplorationCount); + + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | root\n" + "| RefBlock: \n" + "| Variable [root]\n" + "Exchange []\n" + "| | distribution: \n" + "| | type: Centralized\n" + "| RefBlock: \n" + "Filter []\n" + "| EvalFilter []\n" + "| | Variable [evalTemp_0]\n" + "| PathTraverse []\n" + "| PathCompare [Eq]\n" + "| Const [1]\n" + "PhysicalScan [{'<root>': root, 'a': evalTemp_0}, c1, parallel]\n" + " BindBlock:\n" + " [evalTemp_0]\n" + " Source []\n" + " [root]\n" + " Source []\n", + optimized); +} + +TEST(PhysRewriter, HashPartitioning) { + using namespace properties; + PrefixId prefixId; + + ABT scanNode = make<ScanNode>("root", "c1"); + + ABT projectionANode = make<EvaluationNode>( + "pa", + make<EvalPath>(make<PathGet>("a", make<PathIdentity>()), make<Variable>("root")), + std::move(scanNode)); + ABT projectionBNode = make<EvaluationNode>( + "pb", + make<EvalPath>(make<PathGet>("b", make<PathIdentity>()), make<Variable>("root")), + std::move(projectionANode)); + + ABT groupByNode = make<GroupByNode>(ProjectionNameVector{"pa"}, + ProjectionNameVector{"pc"}, + makeSeq(make<Variable>("pb")), + std::move(projectionBNode)); + + ABT rootNode = + make<RootNode>(ProjectionRequirement{ProjectionNameVector{"pc"}}, std::move(groupByNode)); + + OptPhaseManager phaseManager( + {OptPhaseManager::OptPhase::MemoSubstitutionPhase, + OptPhaseManager::OptPhase::MemoExplorationPhase, + OptPhaseManager::OptPhase::MemoImplementationPhase}, + prefixId, + {{{"c1", + ScanDefinition{{}, + {}, + {DistributionType::HashPartitioning, + makeSeq(make<PathGet>("a", make<PathIdentity>()))}}}}, + 5 /*numberOfPartitions*/}, + {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests}); + + ABT optimized = rootNode; + ASSERT_TRUE(phaseManager.optimize(optimized)); + ASSERT_BETWEEN(5, 10, phaseManager.getMemo().getStats()._physPlanExplorationCount); + + ASSERT_EXPLAIN_V2( 
+ "Root []\n" + "| | projections: \n" + "| | pc\n" + "| RefBlock: \n" + "| Variable [pc]\n" + "Exchange []\n" + "| | distribution: \n" + "| | type: Centralized\n" + "| RefBlock: \n" + "GroupBy []\n" + "| | groupings: \n" + "| | RefBlock: \n" + "| | Variable [pa]\n" + "| aggregations: \n" + "| [pc]\n" + "| Variable [pb]\n" + "PhysicalScan [{'a': pa, 'b': pb}, c1]\n" + " BindBlock:\n" + " [pa]\n" + " Source []\n" + " [pb]\n" + " Source []\n", + optimized); +} + +TEST(PhysRewriter, IndexPartitioning) { + using namespace properties; + PrefixId prefixId; + + ABT scanNode = make<ScanNode>("root", "c1"); + + ABT projectionANode = make<EvaluationNode>( + "pa", + make<EvalPath>(make<PathGet>("a", make<PathIdentity>()), make<Variable>("root")), + std::move(scanNode)); + + ABT filterANode = + make<FilterNode>(make<EvalFilter>(make<PathCompare>(Operations::Gt, Constant::int64(0)), + make<Variable>("pa")), + std::move(projectionANode)); + + ABT projectionBNode = make<EvaluationNode>( + "pb", + make<EvalPath>(make<PathGet>("b", make<PathIdentity>()), make<Variable>("root")), + std::move(filterANode)); + + ABT filterBNode = + make<FilterNode>(make<EvalFilter>(make<PathCompare>(Operations::Gt, Constant::int64(1)), + make<Variable>("pb")), + std::move(projectionBNode)); + + ABT groupByNode = make<GroupByNode>(ProjectionNameVector{"pa"}, + ProjectionNameVector{"pc"}, + makeSeq(make<Variable>("pb")), + std::move(filterBNode)); + + ABT rootNode = + make<RootNode>(ProjectionRequirement{ProjectionNameVector{"pc"}}, std::move(groupByNode)); + + OptPhaseManager phaseManager( + {OptPhaseManager::OptPhase::MemoSubstitutionPhase, + OptPhaseManager::OptPhase::MemoExplorationPhase, + OptPhaseManager::OptPhase::MemoImplementationPhase}, + prefixId, + {{{"c1", + ScanDefinition{ + {}, + {{"index1", + IndexDefinition{ + {{makeNonMultikeyIndexPath("a"), CollationOp::Ascending}}, + false /*isMultiKey*/, + {DistributionType::HashPartitioning, makeSeq(makeNonMultikeyIndexPath("a"))}, + {}}}}, + {DistributionType::HashPartitioning, makeSeq(makeNonMultikeyIndexPath("b"))}}}}, + 5 /*numberOfPartitions*/}, + {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests}); + + ABT optimized = rootNode; + ASSERT_TRUE(phaseManager.optimize(optimized)); + ASSERT_BETWEEN(75, 150, phaseManager.getMemo().getStats()._physPlanExplorationCount); + + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | pc\n" + "| RefBlock: \n" + "| Variable [pc]\n" + "Exchange []\n" + "| | distribution: \n" + "| | type: Centralized\n" + "| RefBlock: \n" + "GroupBy []\n" + "| | groupings: \n" + "| | RefBlock: \n" + "| | Variable [pa]\n" + "| aggregations: \n" + "| [pc]\n" + "| Variable [pb]\n" + "Exchange []\n" + "| | distribution: \n" + "| | type: HashPartitioning\n" + "| | projections: \n" + "| | pa\n" + "| RefBlock: \n" + "| Variable [pa]\n" + "BinaryJoin [joinType: Inner, {rid_0}]\n" + "| | Const [true]\n" + "| Filter []\n" + "| | EvalFilter []\n" + "| | | Variable [pb]\n" + "| | PathCompare [Gt]\n" + "| | Const [1]\n" + "| LimitSkip []\n" + "| | limitSkip:\n" + "| | limit: 1\n" + "| | skip: 0\n" + "| Seek [ridProjection: rid_0, {'b': pb}, c1]\n" + "| | BindBlock:\n" + "| | [pb]\n" + "| | Source []\n" + "| RefBlock: \n" + "| Variable [rid_0]\n" + "Exchange []\n" + "| | distribution: \n" + "| | type: RoundRobin\n" + "| RefBlock: \n" + "IndexScan [{'<indexKey> 0': pa, '<rid>': rid_0}, scanDefName: c1, indexDefName: index1, " + "interval: {(Const [0], +inf)}]\n" + " BindBlock:\n" + " [pa]\n" + " Source []\n" + " [rid_0]\n" + " Source 
[]\n", + optimized); +} + +TEST(PhysRewriter, IndexPartitioning1) { + using namespace properties; + PrefixId prefixId; + + ABT scanNode = make<ScanNode>("root", "c1"); + + ABT projectionANode = make<EvaluationNode>( + "pa", + make<EvalPath>(make<PathGet>("a", make<PathIdentity>()), make<Variable>("root")), + std::move(scanNode)); + + ABT filterANode = + make<FilterNode>(make<EvalFilter>(make<PathCompare>(Operations::Gt, Constant::int64(0)), + make<Variable>("pa")), + std::move(projectionANode)); + + ABT projectionBNode = make<EvaluationNode>( + "pb", + make<EvalPath>(make<PathGet>("b", make<PathIdentity>()), make<Variable>("root")), + std::move(filterANode)); + + ABT filterBNode = + make<FilterNode>(make<EvalFilter>(make<PathCompare>(Operations::Gt, Constant::int64(1)), + make<Variable>("pb")), + std::move(projectionBNode)); + + ABT groupByNode = make<GroupByNode>(ProjectionNameVector{"pa"}, + ProjectionNameVector{"pc"}, + makeSeq(make<Variable>("pb")), + std::move(filterBNode)); + + ABT rootNode = + make<RootNode>(ProjectionRequirement{ProjectionNameVector{"pc"}}, std::move(groupByNode)); + + OptPhaseManager phaseManager( + {OptPhaseManager::OptPhase::MemoSubstitutionPhase, + OptPhaseManager::OptPhase::MemoExplorationPhase, + OptPhaseManager::OptPhase::MemoImplementationPhase}, + prefixId, + {{{"c1", + ScanDefinition{ + {}, + {{"index1", + IndexDefinition{ + {{makeNonMultikeyIndexPath("a"), CollationOp::Ascending}}, + false /*isMultiKey*/, + {DistributionType::HashPartitioning, makeSeq(makeNonMultikeyIndexPath("a"))}, + {}}}, + {"index2", + IndexDefinition{ + {{makeNonMultikeyIndexPath("b"), CollationOp::Ascending}}, + false /*isMultiKey*/, + {DistributionType::HashPartitioning, makeSeq(makeNonMultikeyIndexPath("b"))}, + {}}}}, + {DistributionType::HashPartitioning, makeSeq(makeNonMultikeyIndexPath("c"))}}}}, + 5 /*numberOfPartitions*/}, + {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests}); + + ABT optimized = rootNode; + ASSERT_TRUE(phaseManager.optimize(optimized)); + ASSERT_BETWEEN(150, 300, phaseManager.getMemo().getStats()._physPlanExplorationCount); + + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | pc\n" + "| RefBlock: \n" + "| Variable [pc]\n" + "Exchange []\n" + "| | distribution: \n" + "| | type: Centralized\n" + "| RefBlock: \n" + "GroupBy []\n" + "| | groupings: \n" + "| | RefBlock: \n" + "| | Variable [pa]\n" + "| aggregations: \n" + "| [pc]\n" + "| Variable [pb]\n" + "HashJoin [joinType: Inner]\n" + "| | Condition\n" + "| | rid_0 = rid_3\n" + "| Union []\n" + "| | BindBlock:\n" + "| | [pa]\n" + "| | Source []\n" + "| | [rid_3]\n" + "| | Source []\n" + "| Evaluation []\n" + "| | BindBlock:\n" + "| | [rid_3]\n" + "| | Variable [rid_0]\n" + "| IndexScan [{'<indexKey> 0': pa, '<rid>': rid_0}, scanDefName: c1, indexDefName: " + "index1, interval: {(Const [0], +inf)}]\n" + "| BindBlock:\n" + "| [pa]\n" + "| Source []\n" + "| [rid_0]\n" + "| Source []\n" + "Exchange []\n" + "| | distribution: \n" + "| | type: HashPartitioning\n" + "| | projections: \n" + "| | pa\n" + "| RefBlock: \n" + "| Variable [pa]\n" + "IndexScan [{'<indexKey> 0': pb, '<rid>': rid_0}, scanDefName: c1, indexDefName: index2, " + "interval: {(Const [1], +inf)}]\n" + " BindBlock:\n" + " [pb]\n" + " Source []\n" + " [rid_0]\n" + " Source []\n", + optimized); +} + +TEST(PhysRewriter, LocalGlobalAgg) { + using namespace properties; + PrefixId prefixId; + + ABT scanNode = make<ScanNode>("root", "c1"); + + ABT evalANode = make<EvaluationNode>( + "pa", + 
make<EvalPath>(make<PathGet>("a", make<PathIdentity>()), make<Variable>("root")), + std::move(scanNode)); + ABT evalBNode = make<EvaluationNode>( + "pb", + make<EvalPath>(make<PathGet>("b", make<PathIdentity>()), make<Variable>("root")), + std::move(evalANode)); + + ABT groupByNode = + make<GroupByNode>(ProjectionNameVector{"pa"}, + ProjectionNameVector{"pc"}, + makeSeq(make<FunctionCall>("$sum", makeSeq(make<Variable>("pb")))), + std::move(evalBNode)); + + ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"pa", "pc"}}, + std::move(groupByNode)); + + OptPhaseManager phaseManager( + {OptPhaseManager::OptPhase::MemoSubstitutionPhase, + OptPhaseManager::OptPhase::MemoExplorationPhase, + OptPhaseManager::OptPhase::MemoImplementationPhase}, + prefixId, + {{{"c1", ScanDefinition{{}, {}, {DistributionType::UnknownPartitioning}}}}, + 5 /*numberOfPartitions*/}, + {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests}); + + ABT optimized = rootNode; + ASSERT_TRUE(phaseManager.optimize(optimized)); + ASSERT_BETWEEN(15, 25, phaseManager.getMemo().getStats()._physPlanExplorationCount); + + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | pa\n" + "| | pc\n" + "| RefBlock: \n" + "| Variable [pa]\n" + "| Variable [pc]\n" + "Exchange []\n" + "| | distribution: \n" + "| | type: Centralized\n" + "| RefBlock: \n" + "GroupBy [Global]\n" + "| | groupings: \n" + "| | RefBlock: \n" + "| | Variable [pa]\n" + "| aggregations: \n" + "| [pc]\n" + "| FunctionCall [$sum]\n" + "| Variable [preagg_0]\n" + "Exchange []\n" + "| | distribution: \n" + "| | type: HashPartitioning\n" + "| | projections: \n" + "| | pa\n" + "| RefBlock: \n" + "| Variable [pa]\n" + "GroupBy [Local]\n" + "| | groupings: \n" + "| | RefBlock: \n" + "| | Variable [pa]\n" + "| aggregations: \n" + "| [preagg_0]\n" + "| FunctionCall [$sum]\n" + "| Variable [pb]\n" + "PhysicalScan [{'a': pa, 'b': pb}, c1, parallel]\n" + " BindBlock:\n" + " [pa]\n" + " Source []\n" + " [pb]\n" + " Source []\n", + optimized); +} + +TEST(PhysRewriter, LocalGlobalAgg1) { + using namespace properties; + PrefixId prefixId; + + ABT scanNode = make<ScanNode>("root", "c1"); + + ABT evalBNode = make<EvaluationNode>( + "pb", + make<EvalPath>(make<PathGet>("b", make<PathIdentity>()), make<Variable>("root")), + std::move(scanNode)); + + ABT groupByNode = + make<GroupByNode>(ProjectionNameVector{}, + ProjectionNameVector{"pc"}, + makeSeq(make<FunctionCall>("$sum", makeSeq(make<Variable>("pb")))), + std::move(evalBNode)); + + ABT rootNode = + make<RootNode>(ProjectionRequirement{ProjectionNameVector{"pc"}}, std::move(groupByNode)); + + OptPhaseManager phaseManager( + {OptPhaseManager::OptPhase::MemoSubstitutionPhase, + OptPhaseManager::OptPhase::MemoExplorationPhase, + OptPhaseManager::OptPhase::MemoImplementationPhase}, + prefixId, + {{{"c1", ScanDefinition{{}, {}, {DistributionType::UnknownPartitioning}}}}, + 5 /*numberOfPartitions*/}, + {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests}); + + ABT optimized = rootNode; + ASSERT_TRUE(phaseManager.optimize(optimized)); + ASSERT_BETWEEN(5, 15, phaseManager.getMemo().getStats()._physPlanExplorationCount); + + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | pc\n" + "| RefBlock: \n" + "| Variable [pc]\n" + "GroupBy [Global]\n" + "| | groupings: \n" + "| | RefBlock: \n" + "| aggregations: \n" + "| [pc]\n" + "| FunctionCall [$sum]\n" + "| Variable [preagg_0]\n" + "Exchange []\n" + "| | distribution: \n" + "| | type: Centralized\n" + "| RefBlock: 
\n" + "GroupBy [Local]\n" + "| | groupings: \n" + "| | RefBlock: \n" + "| aggregations: \n" + "| [preagg_0]\n" + "| FunctionCall [$sum]\n" + "| Variable [pb]\n" + "PhysicalScan [{'b': pb}, c1, parallel]\n" + " BindBlock:\n" + " [pb]\n" + " Source []\n", + optimized); +} + +TEST(PhysRewriter, LocalLimitSkip) { + using namespace properties; + PrefixId prefixId; + + ABT scanNode = make<ScanNode>("root", "c1"); + + ABT limitSkipNode = make<LimitSkipNode>(LimitSkipRequirement{20, 10}, std::move(scanNode)); + ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, + std::move(limitSkipNode)); + + OptPhaseManager phaseManager( + {OptPhaseManager::OptPhase::MemoSubstitutionPhase, + OptPhaseManager::OptPhase::MemoExplorationPhase, + OptPhaseManager::OptPhase::MemoImplementationPhase}, + prefixId, + {{{"c1", ScanDefinition{{}, {}, {DistributionType::UnknownPartitioning}}}}, + 5 /*numberOfPartitions*/}, + {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests}); + + ABT optimized = rootNode; + ASSERT_TRUE(phaseManager.optimize(optimized)); + ASSERT_BETWEEN(5, 15, phaseManager.getMemo().getStats()._physPlanExplorationCount); + + ASSERT_EXPLAIN_PROPS_V2( + "Properties [cost: 0.0066022, localCost: 0, adjustedCE: 20]\n" + "| | Logical:\n" + "| | cardinalityEstimate: \n" + "| | ce: 20\n" + "| | projections: \n" + "| | root\n" + "| | collectionAvailability: \n" + "| | c1\n" + "| | distributionAvailability: \n" + "| | distribution: \n" + "| | type: Centralized\n" + "| | distribution: \n" + "| | type: UnknownPartitioning\n" + "| Physical:\n" + "| distribution: \n" + "| type: Centralized\n" + "Root []\n" + "| | projections: \n" + "| | root\n" + "| RefBlock: \n" + "| Variable [root]\n" + "Properties [cost: 0.0066022, localCost: 1e-06, adjustedCE: 30]\n" + "| | Logical:\n" + "| | cardinalityEstimate: \n" + "| | ce: 1000\n" + "| | projections: \n" + "| | root\n" + "| | indexingAvailability: \n" + "| | [groupId: 0, scanProjection: root, scanDefName: c1, possiblyEqPredsOnly]\n" + "| | collectionAvailability: \n" + "| | c1\n" + "| | distributionAvailability: \n" + "| | distribution: \n" + "| | type: UnknownPartitioning\n" + "| Physical:\n" + "| limitSkip:\n" + "| limit: 20\n" + "| skip: 10\n" + "| projections: \n" + "| root\n" + "| distribution: \n" + "| type: Centralized\n" + "| indexingRequirement: \n" + "| Complete, dedupRID\n" + "LimitSkip []\n" + "| limitSkip:\n" + "| limit: 20\n" + "| skip: 10\n" + "Properties [cost: 0.0066012, localCost: 0.003001, adjustedCE: 30]\n" + "| | Logical:\n" + "| | cardinalityEstimate: \n" + "| | ce: 1000\n" + "| | projections: \n" + "| | root\n" + "| | indexingAvailability: \n" + "| | [groupId: 0, scanProjection: root, scanDefName: c1, possiblyEqPredsOnly]\n" + "| | collectionAvailability: \n" + "| | c1\n" + "| | distributionAvailability: \n" + "| | distribution: \n" + "| | type: UnknownPartitioning\n" + "| Physical:\n" + "| projections: \n" + "| root\n" + "| distribution: \n" + "| type: Centralized\n" + "| indexingRequirement: \n" + "| Complete, dedupRID\n" + "| limitEstimate: 30\n" + "Exchange []\n" + "| | distribution: \n" + "| | type: Centralized\n" + "| RefBlock: \n" + "Properties [cost: 0.0036002, localCost: 0.0036002, adjustedCE: 30]\n" + "| | Logical:\n" + "| | cardinalityEstimate: \n" + "| | ce: 1000\n" + "| | projections: \n" + "| | root\n" + "| | indexingAvailability: \n" + "| | [groupId: 0, scanProjection: root, scanDefName: c1, possiblyEqPredsOnly]\n" + "| | collectionAvailability: \n" + "| | c1\n" + "| | 
distributionAvailability: \n" + "| | distribution: \n" + "| | type: UnknownPartitioning\n" + "| Physical:\n" + "| projections: \n" + "| root\n" + "| distribution: \n" + "| type: UnknownPartitioning, disableExchanges\n" + "| indexingRequirement: \n" + "| Complete, dedupRID\n" + "| limitEstimate: 30\n" + "PhysicalScan [{'<root>': root}, c1, parallel]\n" + " BindBlock:\n" + " [root]\n" + " Source []\n", + phaseManager); +} + +TEST(PhysRewriter, CollationLimit) { + using namespace properties; + PrefixId prefixId; + + ABT scanNode = make<ScanNode>("root", "c1"); + + ABT evalNode = make<EvaluationNode>( + "pa", + make<EvalPath>(make<PathGet>("a", make<PathIdentity>()), make<Variable>("root")), + std::move(scanNode)); + + ABT collationNode = make<CollationNode>(CollationRequirement({{"pa", CollationOp::Ascending}}), + std::move(evalNode)); + ABT limitSkipNode = make<LimitSkipNode>(LimitSkipRequirement{20, 0}, std::move(collationNode)); + + ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, + std::move(limitSkipNode)); + + OptPhaseManager phaseManager( + {OptPhaseManager::OptPhase::MemoSubstitutionPhase, + OptPhaseManager::OptPhase::MemoExplorationPhase, + OptPhaseManager::OptPhase::MemoImplementationPhase}, + prefixId, + {{{"c1", {{}, {}}}}}, + {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests}); + + ABT optimized = rootNode; + ASSERT_TRUE(phaseManager.optimize(optimized)); + ASSERT_BETWEEN(9, 11, phaseManager.getMemo().getStats()._physPlanExplorationCount); + + // We have a collation node with limit-skip physical properties. It will be lowered to a + // sort node with limit. + ASSERT_EXPLAIN_PROPS_V2( + "Properties [cost: 4.92193, localCost: 0, adjustedCE: 20]\n" + "| | Logical:\n" + "| | cardinalityEstimate: \n" + "| | ce: 20\n" + "| | projections: \n" + "| | pa\n" + "| | root\n" + "| | collectionAvailability: \n" + "| | c1\n" + "| | distributionAvailability: \n" + "| | distribution: \n" + "| | type: Centralized\n" + "| Physical:\n" + "| distribution: \n" + "| type: Centralized\n" + "Root []\n" + "| | projections: \n" + "| | root\n" + "| RefBlock: \n" + "| Variable [root]\n" + "Properties [cost: 4.92193, localCost: 4.32193, adjustedCE: 20]\n" + "| | Logical:\n" + "| | cardinalityEstimate: \n" + "| | ce: 1000\n" + "| | requirementCEs: \n" + "| | refProjection: root, path: 'PathGet [a] PathIdentity []', ce: " + "1000\n" + "| | projections: \n" + "| | pa\n" + "| | root\n" + "| | indexingAvailability: \n" + "| | [groupId: 0, scanProjection: root, scanDefName: c1]\n" + "| | collectionAvailability: \n" + "| | c1\n" + "| | distributionAvailability: \n" + "| | distribution: \n" + "| | type: Centralized\n" + "| Physical:\n" + "| collation: \n" + "| pa: Ascending\n" + "| limitSkip:\n" + "| limit: 20\n" + "| skip: 0\n" + "| projections: \n" + "| root\n" + "| distribution: \n" + "| type: Centralized\n" + "| indexingRequirement: \n" + "| Complete, dedupRID\n" + "Collation []\n" + "| | collation: \n" + "| | pa: Ascending\n" + "| RefBlock: \n" + "| Variable [pa]\n" + "Properties [cost: 0.600001, localCost: 0.600001, adjustedCE: 1000]\n" + "| | Logical:\n" + "| | cardinalityEstimate: \n" + "| | ce: 1000\n" + "| | requirementCEs: \n" + "| | refProjection: root, path: 'PathGet [a] PathIdentity []', ce: " + "1000\n" + "| | projections: \n" + "| | pa\n" + "| | root\n" + "| | indexingAvailability: \n" + "| | [groupId: 0, scanProjection: root, scanDefName: c1]\n" + "| | collectionAvailability: \n" + "| | c1\n" + "| | distributionAvailability: \n" + "| | 
distribution: \n" + "| | type: Centralized\n" + "| Physical:\n" + "| projections: \n" + "| root\n" + "| pa\n" + "| distribution: \n" + "| type: Centralized\n" + "| indexingRequirement: \n" + "| Complete, dedupRID\n" + "PhysicalScan [{'<root>': root, 'a': pa}, c1]\n" + " BindBlock:\n" + " [pa]\n" + " Source []\n" + " [root]\n" + " Source []\n", + phaseManager); +} + +TEST(PhysRewriter, PartialIndex1) { + using namespace properties; + PrefixId prefixId; + + ABT scanNode = make<ScanNode>("root", "c1"); + + ABT filterANode = make<FilterNode>( + make<EvalFilter>( + make<PathGet>( + "a", make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(3)))), + make<Variable>("root")), + std::move(scanNode)); + ABT filterBNode = make<FilterNode>( + make<EvalFilter>( + make<PathGet>( + "b", make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(2)))), + make<Variable>("root")), + std::move(filterANode)); + + ABT rootNode = + make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(filterBNode)); + + // TODO: Test cases where partial filter bound is a range which subsumes the query + // requirement + // TODO: (e.g. half open interval) + auto conversionResult = convertExprToPartialSchemaReq(make<EvalFilter>( + make<PathGet>("b", + make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(2)))), + make<Variable>("root"))); + ASSERT_TRUE(conversionResult._success); + ASSERT_FALSE(conversionResult._hasEmptyInterval); + + OptPhaseManager phaseManager( + {OptPhaseManager::OptPhase::MemoSubstitutionPhase, + OptPhaseManager::OptPhase::MemoExplorationPhase, + OptPhaseManager::OptPhase::MemoImplementationPhase}, + prefixId, + {{{"c1", + ScanDefinition{{}, + {{"index1", + IndexDefinition{{{makeIndexPath("a"), CollationOp::Ascending}}, + true /*isMultiKey*/, + {DistributionType::Centralized}, + std::move(conversionResult._reqMap)}}}}}}}, + {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests}); + + ABT optimized = rootNode; + ASSERT_TRUE(phaseManager.optimize(optimized)); + ASSERT_EQ(4, phaseManager.getMemo().getStats()._physPlanExplorationCount); + + // Partial schema requirement is not on an index field. We get a seek on this field. 
+ ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | root\n" + "| RefBlock: \n" + "| Variable [root]\n" + "BinaryJoin [joinType: Inner, {rid_0}]\n" + "| | Const [true]\n" + "| Filter []\n" + "| | EvalFilter []\n" + "| | | Variable [evalTemp_2]\n" + "| | PathTraverse []\n" + "| | PathCompare [Eq]\n" + "| | Const [2]\n" + "| LimitSkip []\n" + "| | limitSkip:\n" + "| | limit: 1\n" + "| | skip: 0\n" + "| Seek [ridProjection: rid_0, {'<root>': root, 'b': evalTemp_2}, c1]\n" + "| | BindBlock:\n" + "| | [evalTemp_2]\n" + "| | Source []\n" + "| | [root]\n" + "| | Source []\n" + "| RefBlock: \n" + "| Variable [rid_0]\n" + "IndexScan [{'<rid>': rid_0}, scanDefName: c1, indexDefName: index1, interval: {[Const " + "[3], Const [3]]}]\n" + " BindBlock:\n" + " [rid_0]\n" + " Source []\n", + optimized); +} + +TEST(PhysRewriter, PartialIndex2) { + using namespace properties; + PrefixId prefixId; + + ABT scanNode = make<ScanNode>("root", "c1"); + + ABT filterANode = make<FilterNode>( + make<EvalFilter>( + make<PathGet>( + "a", make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(3)))), + make<Variable>("root")), + std::move(scanNode)); + + ABT rootNode = + make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(filterANode)); + + auto conversionResult = convertExprToPartialSchemaReq(make<EvalFilter>( + make<PathGet>("a", + make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(3)))), + make<Variable>("root"))); + ASSERT_TRUE(conversionResult._success); + ASSERT_FALSE(conversionResult._hasEmptyInterval); + + OptPhaseManager phaseManager( + {OptPhaseManager::OptPhase::MemoSubstitutionPhase, + OptPhaseManager::OptPhase::MemoExplorationPhase, + OptPhaseManager::OptPhase::MemoImplementationPhase}, + prefixId, + {{{"c1", + ScanDefinition{{}, + {{"index1", + IndexDefinition{{{makeIndexPath("a"), CollationOp::Ascending}}, + true /*isMultiKey*/, + {DistributionType::Centralized}, + std::move(conversionResult._reqMap)}}}}}}}, + {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests}); + + ABT optimized = rootNode; + ASSERT_TRUE(phaseManager.optimize(optimized)); + ASSERT_EQ(4, phaseManager.getMemo().getStats()._physPlanExplorationCount); + + // Partial schema requirement on an index field. 
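+    // The query predicate matches the partial-index filter exactly, so the expected plan is a
+    // bare IndexScan over [3, 3] plus a Seek, with no residual Filter.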
+ ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | root\n" + "| RefBlock: \n" + "| Variable [root]\n" + "BinaryJoin [joinType: Inner, {rid_0}]\n" + "| | Const [true]\n" + "| LimitSkip []\n" + "| | limitSkip:\n" + "| | limit: 1\n" + "| | skip: 0\n" + "| Seek [ridProjection: rid_0, {'<root>': root}, c1]\n" + "| | BindBlock:\n" + "| | [root]\n" + "| | Source []\n" + "| RefBlock: \n" + "| Variable [rid_0]\n" + "IndexScan [{'<rid>': rid_0}, scanDefName: c1, indexDefName: index1, interval: {[Const " + "[3], Const [3]]}]\n" + " BindBlock:\n" + " [rid_0]\n" + " Source []\n", + optimized); +} + +TEST(PhysRewriter, PartialIndexReject) { + using namespace properties; + PrefixId prefixId; + + ABT scanNode = make<ScanNode>("root", "c1"); + + ABT filterANode = make<FilterNode>( + make<EvalFilter>( + make<PathGet>( + "a", make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(3)))), + make<Variable>("root")), + std::move(scanNode)); + ABT filterBNode = make<FilterNode>( + make<EvalFilter>( + make<PathGet>( + "b", make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(2)))), + make<Variable>("root")), + std::move(filterANode)); + + ABT rootNode = + make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(filterBNode)); + + auto conversionResult = convertExprToPartialSchemaReq(make<EvalFilter>( + make<PathGet>("b", + make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(4)))), + make<Variable>("root"))); + ASSERT_TRUE(conversionResult._success); + ASSERT_FALSE(conversionResult._hasEmptyInterval); + + OptPhaseManager phaseManager( + {OptPhaseManager::OptPhase::MemoSubstitutionPhase, + OptPhaseManager::OptPhase::MemoExplorationPhase, + OptPhaseManager::OptPhase::MemoImplementationPhase}, + prefixId, + {{{"c1", + ScanDefinition{{}, + {{"index1", + IndexDefinition{{{makeIndexPath("a"), CollationOp::Ascending}}, + true /*isMultiKey*/, + {DistributionType::Centralized}, + std::move(conversionResult._reqMap)}}}}}}}, + {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests}); + + ABT optimized = rootNode; + ASSERT_TRUE(phaseManager.optimize(optimized)); + ASSERT_EQ(3, phaseManager.getMemo().getStats()._physPlanExplorationCount); + + // Incompatible partial filter. Use scan. 
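+    // The query requires b == 2 while the index is filtered on b == 4, so the expected plan
+    // applies both predicates as residual Filters over a full PhysicalScan instead of index1.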
+ ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | root\n" + "| RefBlock: \n" + "| Variable [root]\n" + "Filter []\n" + "| EvalFilter []\n" + "| | Variable [evalTemp_1]\n" + "| PathTraverse []\n" + "| PathCompare [Eq]\n" + "| Const [2]\n" + "Filter []\n" + "| EvalFilter []\n" + "| | Variable [evalTemp_0]\n" + "| PathTraverse []\n" + "| PathCompare [Eq]\n" + "| Const [3]\n" + "PhysicalScan [{'<root>': root, 'a': evalTemp_0, 'b': evalTemp_1}, c1]\n" + " BindBlock:\n" + " [evalTemp_0]\n" + " Source []\n" + " [evalTemp_1]\n" + " Source []\n" + " [root]\n" + " Source []\n", + optimized); +} + +TEST(PhysRewriter, RequireRID) { + using namespace properties; + PrefixId prefixId; + + ABT scanNode = make<ScanNode>("root", "c1"); + + ABT filterNode = make<FilterNode>( + make<EvalFilter>( + make<PathGet>( + "a", make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(3)))), + make<Variable>("root")), + std::move(scanNode)); + + ABT rootNode = + make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(filterNode)); + + OptPhaseManager phaseManager( + {OptPhaseManager::OptPhase::MemoSubstitutionPhase, + OptPhaseManager::OptPhase::MemoExplorationPhase, + OptPhaseManager::OptPhase::MemoImplementationPhase}, + prefixId, + true /*requireRID*/, + {{{"c1", ScanDefinition{{}, {}}}}}, + std::make_unique<HeuristicCE>(), + std::make_unique<DefaultCosting>(), + {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests}); + + ABT optimized = rootNode; + ASSERT_TRUE(phaseManager.optimize(optimized)); + ASSERT_EQ(2, phaseManager.getMemo().getStats()._physPlanExplorationCount); + + // Make sure the Scan node returns rid. + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | root\n" + "| RefBlock: \n" + "| Variable [root]\n" + "Filter []\n" + "| EvalFilter []\n" + "| | Variable [evalTemp_0]\n" + "| PathTraverse []\n" + "| PathCompare [Eq]\n" + "| Const [3]\n" + "PhysicalScan [{'<rid>': rid_0, '<root>': root, 'a': evalTemp_0}, c1]\n" + " BindBlock:\n" + " [evalTemp_0]\n" + " Source []\n" + " [rid_0]\n" + " Source []\n" + " [root]\n" + " Source []\n", + optimized); +} + +TEST(PhysRewriter, UnionRewrite) { + using namespace properties; + + ABT scanNode1 = make<ScanNode>("ptest1", "test1"); + ABT scanNode2 = make<ScanNode>("ptest2", "test2"); + + // Each branch produces two projections, pUnion1 and pUnion2. 
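+    // Only pUnion1 is required by the Root, so the expected plan below prunes pUnion2 and each
+    // Union child becomes a PhysicalScan projecting just field "a".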
+ ABT evalNode1 = make<EvaluationNode>( + "pUnion1", + make<EvalPath>(make<PathGet>("a", make<PathIdentity>()), make<Variable>("ptest1")), + std::move(scanNode1)); + ABT evalNode2 = make<EvaluationNode>( + "pUnion2", + make<EvalPath>(make<PathGet>("b", make<PathIdentity>()), make<Variable>("ptest1")), + std::move(evalNode1)); + + ABT evalNode3 = make<EvaluationNode>( + "pUnion1", + make<EvalPath>(make<PathGet>("a", make<PathIdentity>()), make<Variable>("ptest2")), + std::move(scanNode2)); + ABT evalNode4 = make<EvaluationNode>( + "pUnion2", + make<EvalPath>(make<PathGet>("b", make<PathIdentity>()), make<Variable>("ptest2")), + std::move(evalNode3)); + + ABT unionNode = + make<UnionNode>(ProjectionNameVector{"pUnion1", "pUnion2"}, makeSeq(evalNode2, evalNode4)); + + ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"pUnion1"}}, + std::move(unionNode)); + + PrefixId prefixId; + OptPhaseManager phaseManager( + {OptPhaseManager::OptPhase::MemoSubstitutionPhase, + OptPhaseManager::OptPhase::MemoExplorationPhase, + OptPhaseManager::OptPhase::MemoImplementationPhase}, + prefixId, + {{{"test1", {{}, {}}}, {"test2", {{}, {}}}}}, + {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests}); + + ABT optimized = std::move(rootNode); + ASSERT_TRUE(phaseManager.optimize(optimized)); + ASSERT_EQ(4, phaseManager.getMemo().getStats()._physPlanExplorationCount); + + ASSERT_EXPLAIN_V2( + "Root []\n" + "| | projections: \n" + "| | pUnion1\n" + "| RefBlock: \n" + "| Variable [pUnion1]\n" + "Union []\n" + "| | BindBlock:\n" + "| | [pUnion1]\n" + "| | Source []\n" + "| PhysicalScan [{'a': pUnion1}, test2]\n" + "| BindBlock:\n" + "| [pUnion1]\n" + "| Source []\n" + "PhysicalScan [{'a': pUnion1}, test1]\n" + " BindBlock:\n" + " [pUnion1]\n" + " Source []\n", + optimized); +} + +} // namespace +} // namespace mongo::optimizer diff --git a/src/mongo/db/query/optimizer/props.cpp b/src/mongo/db/query/optimizer/props.cpp new file mode 100644 index 00000000000..c03825c4101 --- /dev/null +++ b/src/mongo/db/query/optimizer/props.cpp @@ -0,0 +1,373 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. 
+ */ + +#include "mongo/db/query/optimizer/props.h" +#include "mongo/db/query/optimizer/utils/utils.h" +#include "mongo/util/assert_util.h" + +namespace mongo::optimizer::properties { + +CollationRequirement::CollationRequirement(ProjectionCollationSpec spec) : _spec(std::move(spec)) { + ProjectionNameSet projections; + for (const auto& entry : _spec) { + uassert(6624021, "Repeated projection name", projections.insert(entry.first).second); + } +} + +CollationRequirement CollationRequirement::Empty = CollationRequirement(); + +bool CollationRequirement::operator==(const CollationRequirement& other) const { + return _spec == other._spec; +} + +const ProjectionCollationSpec& CollationRequirement::getCollationSpec() const { + return _spec; +} + +ProjectionCollationSpec& CollationRequirement::getCollationSpec() { + return _spec; +} + +bool CollationRequirement::hasClusteredOp() const { + for (const auto& [projName, op] : _spec) { + if (op == CollationOp::Clustered) { + return true; + } + } + return false; +} + +ProjectionNameSet CollationRequirement::getAffectedProjectionNames() const { + ProjectionNameSet result; + for (const auto& entry : _spec) { + result.insert(entry.first); + } + return result; +} + +LimitSkipRequirement::LimitSkipRequirement(const int64_t limit, const int64_t skip) + : _limit((limit < 0) ? kMaxVal : limit), _skip(skip) {} + +bool LimitSkipRequirement::operator==(const LimitSkipRequirement& other) const { + return _skip == other._skip && _limit == other._limit; +} + +int64_t LimitSkipRequirement::getLimit() const { + return _limit; +} + +int64_t LimitSkipRequirement::getSkip() const { + return _skip; +} + +int64_t LimitSkipRequirement::getAbsoluteLimit() const { + return hasLimit() ? (_skip + _limit) : kMaxVal; +} + +ProjectionNameSet LimitSkipRequirement::getAffectedProjectionNames() const { + return {}; +} + +bool LimitSkipRequirement::hasLimit() const { + return _limit != kMaxVal; +} + +ProjectionRequirement::ProjectionRequirement(ProjectionNameOrderPreservingSet projections) + : _projections(std::move(projections)) {} + +bool ProjectionRequirement::operator==(const ProjectionRequirement& other) const { + return _projections.isEqualIgnoreOrder(other.getProjections()); +} + +ProjectionNameSet ProjectionRequirement::getAffectedProjectionNames() const { + ProjectionNameSet result; + for (const ProjectionName& projection : _projections.getVector()) { + result.insert(projection); + } + return result; +} + +const ProjectionNameOrderPreservingSet& ProjectionRequirement::getProjections() const { + return _projections; +} + +ProjectionNameOrderPreservingSet& ProjectionRequirement::getProjections() { + return _projections; +} + +DistributionAndProjections::DistributionAndProjections(DistributionType type) + : DistributionAndProjections(type, {}) {} + +DistributionAndProjections::DistributionAndProjections(DistributionType type, + ProjectionNameVector projectionNames) + : _type(type), _projectionNames(std::move(projectionNames)) { + uassert(6624096, + "Must have projection names when distributed under hash or range partitioning", + (_type != DistributionType::HashPartitioning && + _type != DistributionType::RangePartitioning) || + !_projectionNames.empty()); +} + +bool DistributionAndProjections::operator==(const DistributionAndProjections& other) const { + return _type == other._type && _projectionNames == other._projectionNames; +} + +DistributionRequirement::DistributionRequirement( + DistributionAndProjections distributionAndProjections) + : 
_distributionAndProjections(std::move(distributionAndProjections)), + _disableExchanges(false) {} + +bool DistributionRequirement::operator==(const DistributionRequirement& other) const { + return _distributionAndProjections == other._distributionAndProjections && + _disableExchanges == other._disableExchanges; +} + +const DistributionAndProjections& DistributionRequirement::getDistributionAndProjections() const { + return _distributionAndProjections; +} + +DistributionAndProjections& DistributionRequirement::getDistributionAndProjections() { + return _distributionAndProjections; +} + +ProjectionNameSet DistributionRequirement::getAffectedProjectionNames() const { + ProjectionNameSet result; + for (const ProjectionName& projectionName : _distributionAndProjections._projectionNames) { + result.insert(projectionName); + } + return result; +} + +bool DistributionRequirement::getDisableExchanges() const { + return _disableExchanges; +} + +void DistributionRequirement::setDisableExchanges(const bool disableExchanges) { + _disableExchanges = disableExchanges; +} + +IndexingRequirement::IndexingRequirement() + : IndexingRequirement(IndexReqTarget::Complete, "", true /*dedupRID*/, {}) {} + +IndexingRequirement::IndexingRequirement(IndexReqTarget indexReqTarget, + bool needsRID, + bool dedupRID, + GroupIdType satisfiedPartialIndexesGroupId) + : _indexReqTarget(indexReqTarget), + _needsRID(needsRID), + _dedupRID(dedupRID), + _satisfiedPartialIndexesGroupId(std::move(satisfiedPartialIndexesGroupId)) { + uassert(6624097, + "Avoiding dedup is only allowed for Index target", + _dedupRID || _indexReqTarget == IndexReqTarget::Index); +} + +bool IndexingRequirement::operator==(const IndexingRequirement& other) const { + return _indexReqTarget == other._indexReqTarget && _needsRID == other._needsRID && + _dedupRID == other._dedupRID && + _satisfiedPartialIndexesGroupId == other._satisfiedPartialIndexesGroupId; +} + +ProjectionNameSet IndexingRequirement::getAffectedProjectionNames() const { + // Specifically not returning ridProjectionName (even if present). 
+ return {}; +} + +IndexReqTarget IndexingRequirement::getIndexReqTarget() const { + return _indexReqTarget; +} + +bool IndexingRequirement::getNeedsRID() const { + return _needsRID; +} + +bool IndexingRequirement::getDedupRID() const { + return _dedupRID; +} + +void IndexingRequirement::setDedupRID(const bool value) { + _dedupRID = value; +} + +const GroupIdType IndexingRequirement::getSatisfiedPartialIndexesGroupId() const { + return _satisfiedPartialIndexesGroupId; +} + +RepetitionEstimate::RepetitionEstimate(const CEType estimate) : _estimate(estimate) {} + +bool RepetitionEstimate::operator==(const RepetitionEstimate& other) const { + return _estimate == other._estimate; +} + +ProjectionNameSet RepetitionEstimate::getAffectedProjectionNames() const { + return {}; +} + +CEType RepetitionEstimate::getEstimate() const { + return _estimate; +} + +LimitEstimate::LimitEstimate(const CEType estimate) : _estimate(estimate) {} + +bool LimitEstimate::operator==(const LimitEstimate& other) const { + return _estimate == other._estimate; +} + +ProjectionNameSet LimitEstimate::getAffectedProjectionNames() const { + return {}; +} + +bool LimitEstimate::hasLimit() const { + return _estimate >= 0.0; +} + +CEType LimitEstimate::getEstimate() const { + return _estimate; +} + +ProjectionAvailability::ProjectionAvailability(ProjectionNameSet projections) + : _projections(std::move(projections)) {} + +bool ProjectionAvailability::operator==(const ProjectionAvailability& other) const { + return _projections == other._projections; +} + +const ProjectionNameSet& ProjectionAvailability::getProjections() const { + return _projections; +} + +CardinalityEstimate::CardinalityEstimate(const CEType estimate) + : _estimate(estimate), _partialSchemaKeyCEMap() {} + +bool CardinalityEstimate::operator==(const CardinalityEstimate& other) const { + return _estimate == other._estimate && _partialSchemaKeyCEMap == other._partialSchemaKeyCEMap; +} + +CEType CardinalityEstimate::getEstimate() const { + return _estimate; +} + +CEType& CardinalityEstimate::getEstimate() { + return _estimate; +} + +const PartialSchemaKeyCE& CardinalityEstimate::getPartialSchemaKeyCEMap() const { + return _partialSchemaKeyCEMap; +} + +PartialSchemaKeyCE& CardinalityEstimate::getPartialSchemaKeyCEMap() { + return _partialSchemaKeyCEMap; +} + +IndexingAvailability::IndexingAvailability(GroupIdType scanGroupId, + ProjectionName scanProjection, + std::string scanDefName, + const bool possiblyEqPredsOnly, + opt::unordered_set<std::string> satisfiedPartialIndexes) + : _scanGroupId(scanGroupId), + _scanProjection(std::move(scanProjection)), + _scanDefName(std::move(scanDefName)), + _possiblyEqPredsOnly(possiblyEqPredsOnly), + _satisfiedPartialIndexes(std::move(satisfiedPartialIndexes)) {} + +bool IndexingAvailability::operator==(const IndexingAvailability& other) const { + return _scanGroupId == other._scanGroupId && _scanProjection == other._scanProjection && + _scanDefName == other._scanDefName && _possiblyEqPredsOnly == other._possiblyEqPredsOnly && + _satisfiedPartialIndexes == other._satisfiedPartialIndexes; +} + +GroupIdType IndexingAvailability::getScanGroupId() const { + return _scanGroupId; +} + +const ProjectionName& IndexingAvailability::getScanProjection() const { + return _scanProjection; +} + +const std::string& IndexingAvailability::getScanDefName() const { + return _scanDefName; +} + +const opt::unordered_set<std::string>& IndexingAvailability::getSatisfiedPartialIndexes() const { + return _satisfiedPartialIndexes; +} + 
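For orientation, here is an editorial sketch (not part of this patch) of how the requirement classes implemented above are intended to be combined. It is written in the style of the unit tests added elsewhere in this commit and relies only on APIs visible in this diff: the makePhysProps/hasProperty/getPropertyConst helpers declared in props.h below, plus the LimitSkipRequirement semantics implemented above (a negative limit is normalized to kMaxVal, and the absolute limit is skip + limit). The test name and values are illustrative.

#include "mongo/db/query/optimizer/props.h"
#include "mongo/unittest/unittest.h"

TEST(PropsSketch, LimitSkipAndTypedLookup) {
    using namespace mongo::optimizer;
    using namespace mongo::optimizer::properties;

    // Centralized distribution plus a window of 5 documents after skipping 10.
    PhysProps props = makePhysProps(
        DistributionRequirement{DistributionAndProjections{DistributionType::Centralized}},
        LimitSkipRequirement{5 /*limit*/, 10 /*skip*/});

    // Typed lookup by property kind; no collation was requested here.
    ASSERT_FALSE(hasProperty<CollationRequirement>(props));

    // The plan under the enforcer must produce skip + limit = 15 documents.
    ASSERT_EQ(15, getPropertyConst<LimitSkipRequirement>(props).getAbsoluteLimit());

    // A negative limit is normalized to kMaxVal, i.e. "no limit".
    ASSERT_FALSE(LimitSkipRequirement(-1 /*limit*/, 0 /*skip*/).hasLimit());
}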
+opt::unordered_set<std::string>& IndexingAvailability::getSatisfiedPartialIndexes() { + return _satisfiedPartialIndexes; +} + +bool IndexingAvailability::getPossiblyEqPredsOnly() const { + return _possiblyEqPredsOnly; +} + +void IndexingAvailability::setPossiblyEqPredsOnly(const bool value) { + _possiblyEqPredsOnly = value; +} + +CollectionAvailability::CollectionAvailability(opt::unordered_set<std::string> scanDefSet) + : _scanDefSet(std::move(scanDefSet)) {} + +bool CollectionAvailability::operator==(const CollectionAvailability& other) const { + return _scanDefSet == other._scanDefSet; +} + +const opt::unordered_set<std::string>& CollectionAvailability::getScanDefSet() const { + return _scanDefSet; +} + +opt::unordered_set<std::string>& CollectionAvailability::getScanDefSet() { + return _scanDefSet; +} + +size_t DistributionHash::operator()( + const DistributionAndProjections& distributionAndProjections) const { + size_t result = 0; + updateHash(result, std::hash<DistributionType>()(distributionAndProjections._type)); + for (const ProjectionName& projectionName : distributionAndProjections._projectionNames) { + updateHash(result, std::hash<ProjectionName>()(projectionName)); + } + return result; +} + +DistributionAvailability::DistributionAvailability(DistributionSet distributionSet) + : _distributionSet(std::move(distributionSet)) {} + +bool DistributionAvailability::operator==(const DistributionAvailability& other) const { + return _distributionSet == other._distributionSet; +} + +const DistributionSet& DistributionAvailability::getDistributionSet() const { + return _distributionSet; +} + +DistributionSet& DistributionAvailability::getDistributionSet() { + return _distributionSet; +} + +} // namespace mongo::optimizer::properties diff --git a/src/mongo/db/query/optimizer/props.h b/src/mongo/db/query/optimizer/props.h new file mode 100644 index 00000000000..63c49f51774 --- /dev/null +++ b/src/mongo/db/query/optimizer/props.h @@ -0,0 +1,481 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. 
+ */ + +#pragma once + +#include <map> +#include <string> +#include <vector> + +#include "mongo/db/query/optimizer/algebra/operator.h" +#include "mongo/db/query/optimizer/algebra/polyvalue.h" +#include "mongo/db/query/optimizer/defs.h" +#include "mongo/db/query/optimizer/metadata.h" +#include "mongo/util/assert_util.h" + +namespace mongo::optimizer::properties { + +/** + * Tag for logical property types. + */ +class LogicalPropertyTag {}; + +/** + * Tag for physical property types. + */ +class PhysPropertyTag {}; + +/** + * Logical properties. + */ +class CardinalityEstimate; + +class ProjectionAvailability; +class IndexingAvailability; +class CollectionAvailability; +class DistributionAvailability; + +/** + * Physical properties. + */ +class CollationRequirement; +class LimitSkipRequirement; +class ProjectionRequirement; +class DistributionRequirement; +class IndexingRequirement; +class RepetitionEstimate; +class LimitEstimate; + +using LogicalProperty = algebra::PolyValue<CardinalityEstimate, + ProjectionAvailability, + IndexingAvailability, + CollectionAvailability, + DistributionAvailability>; + +using PhysProperty = algebra::PolyValue<CollationRequirement, + LimitSkipRequirement, + ProjectionRequirement, + DistributionRequirement, + IndexingRequirement, + RepetitionEstimate, + LimitEstimate>; + +using LogicalProps = opt::unordered_map<LogicalProperty::key_type, LogicalProperty>; +using PhysProps = opt::unordered_map<PhysProperty::key_type, PhysProperty>; + +template <typename T, typename... Args> +inline auto makeProperty(Args&&... args) { + if constexpr (std::is_base_of_v<LogicalPropertyTag, T>) { + return LogicalProperty::make<T>(std::forward<Args>(args)...); + } else if constexpr (std::is_base_of_v<PhysPropertyTag, T>) { + return PhysProperty::make<T>(std::forward<Args>(args)...); + } else { + static_assert("Unknown property type"); + } +} + +template <class P> +static constexpr auto getPropertyKey() { + if constexpr (std::is_base_of_v<LogicalPropertyTag, P>) { + return LogicalProperty::template tagOf<P>(); + } else if constexpr (std::is_base_of_v<PhysPropertyTag, P>) { + return PhysProperty::template tagOf<P>(); + } else { + static_assert("Unknown property type"); + } +} + +template <class P, class C> +bool hasProperty(const C& props) { + return props.find(getPropertyKey<P>()) != props.cend(); +} + +template <class P, class C> +P& getProperty(C& props) { + if (!hasProperty<P>(props)) { + uasserted(6624022, "Property type does not exist."); + } + return *props.at(getPropertyKey<P>()).template cast<P>(); +} + +template <class P, class C> +const P& getPropertyConst(const C& props) { + if (!hasProperty<P>(props)) { + uasserted(6624023, "Property type does not exist."); + } + return *props.at(getPropertyKey<P>()).template cast<P>(); +} + +template <class P, class C> +void removeProperty(C& props) { + props.erase(getPropertyKey<P>()); +} + +template <class P, class C> +bool setProperty(C& props, P property) { + return props.emplace(getPropertyKey<P>(), makeProperty<P>(std::move(property))).second; +} + +template <class P, class C> +void setPropertyOverwrite(C& props, P property) { + props.insert_or_assign(getPropertyKey<P>(), makeProperty<P>(std::move(property))); +} + +template <class C, typename... Args> +inline auto makeProps(Args&&... args) { + C props; + (setProperty(props, args), ...); + return props; +} + +template <typename... Args> +inline auto makeLogicalProps(Args&&... args) { + return makeProps<LogicalProps>(std::forward<Args>(args)...); +} + +template <typename... 
Args> +inline auto makePhysProps(Args&&... args) { + return makeProps<PhysProps>(std::forward<Args>(args)...); +} + +/** + * A physical property which specifies how the collection (or intermediate result) is required to be + * collated (sorted). + */ +class CollationRequirement final : public PhysPropertyTag { +public: + static CollationRequirement Empty; + + CollationRequirement() = default; + CollationRequirement(ProjectionCollationSpec spec); + + bool operator==(const CollationRequirement& other) const; + + const ProjectionCollationSpec& getCollationSpec() const; + ProjectionCollationSpec& getCollationSpec(); + + bool hasClusteredOp() const; + + ProjectionNameSet getAffectedProjectionNames() const; + +private: + ProjectionCollationSpec _spec; +}; + +/** + * A physical property which specifies what portion of the result in terms of window defined by the + * limit and skip is to be returned. + */ +class LimitSkipRequirement final : public PhysPropertyTag { +public: + static constexpr int64_t kMaxVal = std::numeric_limits<int64_t>::max(); + + LimitSkipRequirement(int64_t limit, int64_t skip); + + bool operator==(const LimitSkipRequirement& other) const; + + bool hasLimit() const; + + int64_t getLimit() const; + int64_t getSkip() const; + int64_t getAbsoluteLimit() const; + + ProjectionNameSet getAffectedProjectionNames() const; + +private: + // Max number of documents to return. Maximum integer value means unlimited. + int64_t _limit; + // Documents to skip before start returning in result. + int64_t _skip; +}; + +/** + * A physical property which specifies required projections to be returned as part of the result. + */ +class ProjectionRequirement final : public PhysPropertyTag { +public: + ProjectionRequirement(ProjectionNameOrderPreservingSet projections); + + bool operator==(const ProjectionRequirement& other) const; + + const ProjectionNameOrderPreservingSet& getProjections() const; + ProjectionNameOrderPreservingSet& getProjections(); + + ProjectionNameSet getAffectedProjectionNames() const; + +private: + ProjectionNameOrderPreservingSet _projections; +}; + +struct DistributionAndProjections { + DistributionAndProjections(DistributionType type); + DistributionAndProjections(DistributionType type, ProjectionNameVector projectionNames); + + bool operator==(const DistributionAndProjections& other) const; + + const DistributionType _type; + + /** + * Defined for hash and range-based partitioning. + */ + ProjectionNameVector _projectionNames; +}; + +/** + * A physical property which specifies how the result is to be distributed (or partitioned) amongst + * the computing partitions/nodes. + */ +class DistributionRequirement final : public PhysPropertyTag { +public: + DistributionRequirement(DistributionAndProjections distributionAndProjections); + + bool operator==(const DistributionRequirement& other) const; + + const DistributionAndProjections& getDistributionAndProjections() const; + DistributionAndProjections& getDistributionAndProjections(); + + ProjectionNameSet getAffectedProjectionNames() const; + + bool getDisableExchanges() const; + void setDisableExchanges(bool disableExchanges); + +private: + DistributionAndProjections _distributionAndProjections; + + // Heuristic used to disable exchanges right after Filter, Eval, and local GroupBy nodes. + bool _disableExchanges; +}; + +/** + * A physical property which describes if we intend to satisfy sargable predicates using an index. 
+ * With indexing requirement "Complete", we are requiring a regular physical + * scan (both rid and row). With "Seek" (where we must have a non-empty RID projection name), we are + * targeting a physical Seek. With "Index" (with or without RID projection name), we + * are targeting a physical IndexScan. If in this case we have set RID projection, then we have + * either gone for a Seek, or we have performed intersection. With empty RID we are targeting a + * covered index scan. + */ +class IndexingRequirement final : public PhysPropertyTag { +public: + IndexingRequirement(); + IndexingRequirement(IndexReqTarget indexReqTarget, + bool needsRID, + bool dedupRIDs, + GroupIdType satisfiedPartialIndexesGroupId); + + bool operator==(const IndexingRequirement& other) const; + + ProjectionNameSet getAffectedProjectionNames() const; + + IndexReqTarget getIndexReqTarget() const; + bool getNeedsRID() const; + + bool getDedupRID() const; + void setDedupRID(bool value); + + const GroupIdType getSatisfiedPartialIndexesGroupId() const; + +private: + const IndexReqTarget _indexReqTarget; + + // Do we need to return an RID projection. + const bool _needsRID; + + // If target == Index, specifies if we need to dedup RIDs. + // Prior RID intersection removes the need to dedup. + bool _dedupRID; + + // Set of indexes with partial indexes whose partial filters are satisfied considering the whole + // query. Points to a group where can interrogate IndexingAvailability to find the satisfied + // indexes. + const GroupIdType _satisfiedPartialIndexesGroupId; +}; + +/** + * A physical property that specifies how many times do we expect to execute the current subtree. + * Typically generated via a NLJ where it is set on the inner side to reflect the outer side's + * cardinality. This property affects costing of stateful physical operators such as sort and hash + * groupby. + */ +class RepetitionEstimate final : public PhysPropertyTag { +public: + RepetitionEstimate(CEType estimate); + + bool operator==(const RepetitionEstimate& other) const; + + ProjectionNameSet getAffectedProjectionNames() const; + + CEType getEstimate() const; + +private: + CEType _estimate; +}; + +/** + * A physical property that specifies that the we will consider only some approximate number of + * documents. Typically generated after enforcing a LimitSkipRequirement. This property affects + * costing of stateful physical operators such as sort and hash groupby. + */ +class LimitEstimate final : public PhysPropertyTag { +public: + LimitEstimate(CEType estimate); + + bool operator==(const LimitEstimate& other) const; + + ProjectionNameSet getAffectedProjectionNames() const; + + bool hasLimit() const; + CEType getEstimate() const; + +private: + CEType _estimate; +}; + +/** + * A logical property which specifies available projections for a given ABT tree. + */ +class ProjectionAvailability final : public LogicalPropertyTag { +public: + ProjectionAvailability(ProjectionNameSet projections); + + bool operator==(const ProjectionAvailability& other) const; + + const ProjectionNameSet& getProjections() const; + +private: + ProjectionNameSet _projections; +}; + +/** + * A logical property which provides an estimated row count for a given ABT tree. 
+ */ +class CardinalityEstimate final : public LogicalPropertyTag { +public: + CardinalityEstimate(CEType estimate); + + bool operator==(const CardinalityEstimate& other) const; + + CEType getEstimate() const; + CEType& getEstimate(); + + const PartialSchemaKeyCE& getPartialSchemaKeyCEMap() const; + PartialSchemaKeyCE& getPartialSchemaKeyCEMap(); + +private: + CEType _estimate; + + // Used for SargableNodes. Provide additional per partial schema key CE. + PartialSchemaKeyCE _partialSchemaKeyCEMap; +}; + +/** + * A logical property which specifies availability to index predicates in the ABT subtree and + * contains the scan projection. The projection and definition name are here for convenience: it can + * be retrieved using the scan group from the memo. + */ +class IndexingAvailability final : public LogicalPropertyTag { +public: + IndexingAvailability(GroupIdType scanGroupId, + ProjectionName scanProjection, + std::string scanDefName, + bool possiblyEqPredsOnly, + opt::unordered_set<std::string> satisfiedPartialIndexes); + + bool operator==(const IndexingAvailability& other) const; + + GroupIdType getScanGroupId() const; + const ProjectionName& getScanProjection() const; + const std::string& getScanDefName() const; + + const opt::unordered_set<std::string>& getSatisfiedPartialIndexes() const; + opt::unordered_set<std::string>& getSatisfiedPartialIndexes(); + + bool getPossiblyEqPredsOnly() const; + void setPossiblyEqPredsOnly(bool value); + +private: + const GroupIdType _scanGroupId; + const ProjectionName _scanProjection; + const std::string _scanDefName; + + // Specifies if all predicates in the current group and child group are "possibly" equalities. + // This is determined based on SargableNode exclusively containing equality intervals. + // The "possibly" part is due to 'Get "a" Id' being equivalent 'Get "a" Traverse Id' with a + // multi-key index. + bool _possiblyEqPredsOnly; + + // Set of indexes with partial indexes whose partial filters are satisfied for the current + // group. + opt::unordered_set<std::string> _satisfiedPartialIndexes; +}; + + +/** + * Logical property which specifies which collections (scanDefs) are available for a particular + * group. For example if the group contains a join of two tables, we would have (at least) two + * collections in the set. + */ +class CollectionAvailability final : public LogicalPropertyTag { +public: + CollectionAvailability(opt::unordered_set<std::string> scanDefSet); + + bool operator==(const CollectionAvailability& other) const; + + const opt::unordered_set<std::string>& getScanDefSet() const; + opt::unordered_set<std::string>& getScanDefSet(); + +private: + opt::unordered_set<std::string> _scanDefSet; +}; + +struct DistributionHash { + size_t operator()(const DistributionAndProjections& distributionAndProjections) const; +}; + +using DistributionSet = opt::unordered_set<DistributionAndProjections, DistributionHash>; + +/** + * Logical property which specifies promising projections and distributions to attempt to enforce + * during physical optimization. For example, a group containing a GroupByNode would add hash + * partitioning on the group-by projections. 
+ */ +class DistributionAvailability final : public LogicalPropertyTag { +public: + DistributionAvailability(DistributionSet distributionSet); + + bool operator==(const DistributionAvailability& other) const; + + const DistributionSet& getDistributionSet() const; + DistributionSet& getDistributionSet(); + +private: + DistributionSet _distributionSet; +}; + +} // namespace mongo::optimizer::properties diff --git a/src/mongo/db/query/optimizer/reference_tracker.cpp b/src/mongo/db/query/optimizer/reference_tracker.cpp new file mode 100644 index 00000000000..c6afaba7024 --- /dev/null +++ b/src/mongo/db/query/optimizer/reference_tracker.cpp @@ -0,0 +1,689 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include <iostream> + +#include "mongo/db/query/optimizer/cascades/memo.h" +#include "mongo/db/query/optimizer/reference_tracker.h" + +namespace mongo::optimizer { + +struct CollectedInfo { + using VarRefsMap = opt::unordered_map<std::string, opt::unordered_map<const Variable*, bool>>; + + /** + * All resolved variables so far + */ + opt::unordered_map<const Variable*, Definition> useMap; + + /** + * Definitions available for use in ancestor nodes (projections) + */ + DefinitionsMap defs; + + /** + * Free variables (i.e. so far not resolved) + */ + opt::unordered_map<std::string, std::vector<const Variable*>> freeVars; + + /** + * Per relational node projections. + */ + opt::unordered_map<const Node*, DefinitionsMap> nodeDefs; + + /** + * Support for tracking of the last local variable reference. + */ + VarRefsMap varLastRefs; + opt::unordered_set<const Variable*> lastRefs; + + /** + * This is a destructive merge, the 'other' will be siphoned out. + */ + void merge(CollectedInfo&& other) { + // Incoming (other) info has some definitions. So let's try to resolve our free variables. + if (!other.defs.empty() && !freeVars.empty()) { + for (auto&& [name, def] : other.defs) { + resolveFreeVars(name, def); + } + } + + // We have some definitions so let try to resolve other's free variables. 
+ if (!defs.empty() && !other.freeVars.empty()) { + for (auto&& [name, def] : defs) { + other.resolveFreeVars(name, def); + } + } + + useMap.merge(other.useMap); + + // It should be impossible to have duplicate Variable pointer so everything should be + // copied. + uassert(6624024, "use map is not empty", other.useMap.empty()); + + defs.merge(other.defs); + + // Projection names are globally unique so everything should be copied. + uassert(6624025, "duplicate projections", other.defs.empty()); + + for (auto&& [name, vars] : other.freeVars) { + auto& v = freeVars[name]; + v.insert(v.end(), vars.begin(), vars.end()); + } + other.freeVars.clear(); + + nodeDefs.merge(other.nodeDefs); + + // It should be impossible to have duplicate Node pointer so everything should be + // copied. + uassert(6624026, "duplicate nodes", other.nodeDefs.empty()); + + // Merge last references. + mergeLastRefs(std::move(other.varLastRefs)); + lastRefs.merge(other.lastRefs); + uassert(6624027, "duplicate lastRefs", other.lastRefs.empty()); + } + + /** + * Merges variable references from 'other' and adjust last references as needed. + */ + void mergeLastRefs(VarRefsMap&& other) { + mergeLastRefsImpl(std::move(other), false, true); + } + + /** + * Merges variable references from 'other' but keeps the last references from 'this'; i.e. it + * resets the 'other' side. + */ + void mergeKeepLastRefs(VarRefsMap&& other) { + mergeLastRefsImpl(std::move(other), true, false); + } + + /** + * Merges variable references from 'other' and keeps the last references from both sides. + */ + void unionLastRefs(VarRefsMap&& other) { + mergeLastRefsImpl(std::move(other), false, false); + } + + void mergeLastRefsImpl(VarRefsMap&& other, bool resetOther, bool resetBoth) { + for (auto otherIt = other.begin(), end = other.end(); otherIt != end;) { + if (auto localIt = varLastRefs.find(otherIt->first); localIt != varLastRefs.end()) { + // This variable is referenced in both sets and we resetOther when adjust it + // accordingly. + if (resetOther) { + for (auto& [k, isLastRef] : otherIt->second) { + isLastRef = false; + } + } + + // Merge the maps. + localIt->second.merge(otherIt->second); + + // This variable is referenced in both sets so it will be marked as NOT last if + // resetBoth is set. + if (resetBoth) { + for (auto& [k, isLastRef] : localIt->second) { + isLastRef = false; + } + } + other.erase(otherIt++); + } else { + ++otherIt; + } + } + varLastRefs.merge(other); + uassert(6624098, "varLastRefs must be empty", other.empty()); + } + + /** + * Records collected last variable references for a specific variable. + */ + void finalizeLastRefs(const std::string& name) { + if (auto it = varLastRefs.find(name); it != varLastRefs.end()) { + for (auto& [var, isLastRef] : it->second) { + if (isLastRef) { + lastRefs.emplace(var); + } + } + + // After the finalization the map is not needed anymore. + varLastRefs.erase(it); + } + } + + /** + * This is a destructive merge, the 'others' will be siphoned out. + */ + void merge(std::vector<CollectedInfo>&& others) { + for (auto& other : others) { + merge(std::move(other)); + } + } + + /** + * A special merge asserting that the 'other' has no defined projections. Expressions do not + * project anything, only Nodes do. + * + * We still have to track free variables though. 
+ */ + void mergeNoDefs(CollectedInfo&& other) { + other.assertEmptyDefs(); + merge(std::move(other)); + } + + static ProjectionNameSet getProjections(const DefinitionsMap& defs) { + ProjectionNameSet result; + + for (auto&& [k, v] : defs) { + result.emplace(k); + } + return result; + } + + ProjectionNameSet getProjections() const { + return getProjections(defs); + } + + void resolveFreeVars(const ProjectionName& name, const Definition& def) { + if (auto it = freeVars.find(name); it != freeVars.end()) { + for (const auto var : it->second) { + useMap.emplace(var, def); + } + freeVars.erase(it); + } + } + + void assertEmptyDefs() { + uassert(6624028, "Definitions must be empty", defs.empty()); + } +}; + +/** + * Collect all Variables into a set. + */ +class VariableCollector { +public: + template <typename T, typename... Ts> + void transport(const T& /*op*/, Ts&&... /*ts*/) {} + + void transport(const Variable& op) { + _result._variables.emplace(&op); + } + + void transport(const LambdaAbstraction& op, const ABT& /*bind*/) { + _result._definedVars.insert(op.varName()); + } + + void transport(const Let& op, const ABT& /*bind*/, const ABT& /*expr*/) { + _result._definedVars.insert(op.varName()); + } + + static VariableCollectorResult collect(const ABT& n) { + VariableCollector collector; + collector.collectInternal(n); + return std::move(collector._result); + } + +private: + void collectInternal(const ABT& n) { + algebra::transport<false>(n, *this); + } + + VariableCollectorResult _result; +}; + +struct Collector { + explicit Collector(const cascades::Memo* memo) : _memo(memo) {} + + template <typename T, typename... Ts> + CollectedInfo transport(const ABT&, const T& op, Ts&&... ts) { + CollectedInfo result{}; + (result.merge(std::forward<Ts>(ts)), ...); + + if constexpr (std::is_base_of_v<Node, T>) { + result.nodeDefs[&op] = result.defs; + } + + return result; + } + + CollectedInfo transport(const ABT& n, const Variable& variable) { + CollectedInfo result{}; + + // Every variable starts as a free variable until it is resolved. + result.freeVars[variable.name()].push_back(&variable); + + // Similarly, every variable starts as the last referencee until proven otherwise. + result.varLastRefs[variable.name()].emplace(&variable, true); + + return result; + } + + CollectedInfo transport(const ABT& n, + const Let& let, + CollectedInfo bindResult, + CollectedInfo inResult) { + CollectedInfo result{}; + + inResult.mergeKeepLastRefs(std::move(bindResult.varLastRefs)); + inResult.finalizeLastRefs(let.varName()); + + result.merge(std::move(bindResult)); + + // Local variables are not part of projections (i.e. we do not track them in defs) so + // resolve any free variables manually. + inResult.resolveFreeVars(let.varName(), Definition{n.ref(), let.bind().ref()}); + result.merge(std::move(inResult)); + + return result; + } + + CollectedInfo transport(const ABT& n, const LambdaAbstraction& lam, CollectedInfo inResult) { + CollectedInfo result{}; + + inResult.finalizeLastRefs(lam.varName()); + // Local variables are not part of projections (i.e. we do not track them in defs) so + // resolve any free variables manually. 
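(Editorial aside, not part of the patch.) The practical effect of resolving let- and lambda-bound names here, rather than through the projection definitions in 'defs', is observable through the public VariableEnvironment interface declared in reference_tracker.h below. A minimal sketch in the style of this commit's unit tests; the names are illustrative and the include paths assume the files added in this patch:

#include "mongo/db/query/optimizer/node.h"
#include "mongo/db/query/optimizer/reference_tracker.h"
#include "mongo/unittest/unittest.h"

TEST(ReferenceTrackerSketch, LetAndFreeVariables) {
    using namespace mongo::optimizer;

    // A Let binds "x"; its single use resolves to the Let's bind expression, so nothing is free.
    ABT letExpr = make<Let>("x", Constant::int64(1), make<Variable>("x"));
    ASSERT_FALSE(VariableEnvironment::build(letExpr).hasFreeVariables());

    // A variable that nothing defines stays in 'freeVars' and is reported as free.
    ABT dangling = make<Variable>("y");
    VariableEnvironment env = VariableEnvironment::build(dangling);
    ASSERT_TRUE(env.hasFreeVariables());
    ASSERT_EQ(size_t{1}, env.freeOccurences("y"));
}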
+ inResult.resolveFreeVars(lam.varName(), Definition{n.ref(), ABT::reference_type{}}); + result.merge(std::move(inResult)); + + return result; + } + + CollectedInfo transport(const ABT& n, + const If&, + CollectedInfo condResult, + CollectedInfo thenResult, + CollectedInfo elseResult) { + + CollectedInfo result{}; + + + result.unionLastRefs(std::move(thenResult.varLastRefs)); + result.unionLastRefs(std::move(elseResult.varLastRefs)); + result.mergeKeepLastRefs(std::move(condResult.varLastRefs)); + + result.merge(std::move(condResult)); + result.merge(std::move(thenResult)); + result.merge(std::move(elseResult)); + + return result; + } + + static CollectedInfo collectForScan(const ABT& n, + const Node& node, + const ExpressionBinder& binder, + CollectedInfo refs) { + CollectedInfo result{}; + + result.mergeNoDefs(std::move(refs)); + + for (size_t i = 0; i < binder.names().size(); i++) { + result.defs[binder.names()[i]] = Definition{n.ref(), binder.exprs()[i].ref()}; + } + result.nodeDefs[&node] = result.defs; + + return result; + } + + CollectedInfo transport(const ABT& n, const ScanNode& node, CollectedInfo /*bindResult*/) { + return collectForScan(n, node, node.binder(), {}); + } + + CollectedInfo transport(const ABT& n, const ValueScanNode& node, CollectedInfo /*bindResult*/) { + return collectForScan(n, node, node.binder(), {}); + } + + CollectedInfo transport(const ABT& n, + const PhysicalScanNode& node, + CollectedInfo /*bindResult*/) { + return collectForScan(n, node, node.binder(), {}); + } + + CollectedInfo transport(const ABT& n, const IndexScanNode& node, CollectedInfo /*bindResult*/) { + return collectForScan(n, node, node.binder(), {}); + } + + CollectedInfo transport(const ABT& n, + const SeekNode& node, + CollectedInfo /*bindResult*/, + CollectedInfo refResult) { + return collectForScan(n, node, node.binder(), std::move(refResult)); + } + + CollectedInfo transport(const ABT& n, + const MemoLogicalDelegatorNode& memoLogicalDelegatorNode) { + CollectedInfo result{}; + + uassert(6624029, "Uninitialized memo", _memo); + + auto& group = _memo->getGroup(memoLogicalDelegatorNode.getGroupId()); + + auto& projectionNames = group.binder().names(); + auto& projections = group.binder().exprs(); + for (size_t i = 0; i < projectionNames.size(); i++) { + result.defs[projectionNames.at(i)] = Definition{n.ref(), projections[i].ref()}; + } + + result.nodeDefs[&memoLogicalDelegatorNode] = result.defs; + + return result; + } + + CollectedInfo transport(const ABT& n, + const EvaluationNode& evaluationNode, + CollectedInfo childResult, + CollectedInfo exprResult) { + CollectedInfo result{}; + + // Make the definition available upstream. 
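(Editorial aside, not part of the patch, complementing the expression-level sketch above.) At the node level, the definition registered just below is what lets an upstream Variable that refers to the evaluated projection resolve. A hypothetical unit-test-style fragment built only from node constructors that already appear in this commit's tests:

#include "mongo/db/query/optimizer/node.h"
#include "mongo/db/query/optimizer/reference_tracker.h"
#include "mongo/unittest/unittest.h"

TEST(ReferenceTrackerSketch, EvaluationDefinesProjection) {
    using namespace mongo::optimizer;
    using namespace mongo::optimizer::properties;

    // The scan defines "root"; the EvaluationNode defines "pa" on top of it.
    ABT scan = make<ScanNode>("root", "c1");
    ABT eval = make<EvaluationNode>(
        "pa",
        make<EvalPath>(make<PathGet>("a", make<PathIdentity>()), make<Variable>("root")),
        std::move(scan));
    ABT root =
        make<RootNode>(ProjectionRequirement{ProjectionNameVector{"pa"}}, std::move(eval));

    VariableEnvironment env = VariableEnvironment::build(root);

    // Every Variable in the tree resolves to a definition, and the projection introduced
    // by the EvaluationNode is visible at the top of the tree.
    ASSERT_FALSE(env.hasFreeVariables());
    ASSERT_EQ(size_t{1}, env.topLevelProjections().count("pa"));
}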
+ uassert(6624030, + str::stream() << "Cannot overwrite project " << evaluationNode.getProjectionName(), + childResult.defs.count(evaluationNode.getProjectionName()) == 0); + + result.merge(std::move(childResult)); + result.mergeNoDefs(std::move(exprResult)); + + result.defs[evaluationNode.getProjectionName()] = + Definition{n.ref(), evaluationNode.getProjection().ref()}; + + result.nodeDefs[&evaluationNode] = result.defs; + + return result; + } + + CollectedInfo transport(const ABT& n, + const SargableNode& node, + CollectedInfo childResult, + CollectedInfo bindResult, + CollectedInfo /*refResult*/) { + CollectedInfo result{}; + + result.merge(std::move(childResult)); + result.mergeNoDefs(std::move(bindResult)); + + const auto& projectionNames = node.binder().names(); + const auto& projections = node.binder().exprs(); + for (size_t i = 0; i < projectionNames.size(); i++) { + result.defs[projectionNames.at(i)] = Definition{n.ref(), projections[i].ref()}; + } + + result.nodeDefs[&node] = result.defs; + + return result; + } + + CollectedInfo transport(const ABT& n, + const RIDIntersectNode& node, + CollectedInfo leftChildResult, + CollectedInfo rightChildResult) { + CollectedInfo result{}; + + rightChildResult.defs.erase(node.getScanProjectionName()); + + result.merge(std::move(leftChildResult)); + result.merge(std::move(rightChildResult)); + + result.nodeDefs[&node] = result.defs; + + return result; + } + + CollectedInfo transport(const ABT& n, + const BinaryJoinNode& binaryJoinNode, + CollectedInfo leftChildResult, + CollectedInfo rightChildResult, + CollectedInfo filterResult) { + CollectedInfo result{}; + + { + const ProjectionNameSet& leftProjections = leftChildResult.getProjections(); + for (const ProjectionName& boundProjectionName : + binaryJoinNode.getCorrelatedProjectionNames()) { + uassert(6624099, + "Correlated projections must exist in left child.", + leftProjections.find(boundProjectionName) != leftProjections.cend()); + } + } + + result.merge(std::move(leftChildResult)); + result.merge(std::move(rightChildResult)); + result.mergeNoDefs(std::move(filterResult)); + + result.nodeDefs[&binaryJoinNode] = result.defs; + + return result; + } + + CollectedInfo transport(const ABT& n, + const UnionNode& unionNode, + std::vector<CollectedInfo> childResults, + CollectedInfo bindResult, + CollectedInfo refsResult) { + CollectedInfo result{}; + + const auto& names = unionNode.binder().names(); + + refsResult.assertEmptyDefs(); + + // Merge children but disregard any defined projections. + // Note that refsResult follows the structure as built by buildUnionReferences. + size_t counter = 0; + for (auto& u : childResults) { + // Manually copy and resolve references of specific child. + for (const auto& name : names) { + uassert(6624031, "Union projection does not exist", u.defs.count(name) != 0); + u.useMap.emplace(refsResult.freeVars[name][counter], u.defs[name]); + } + u.defs.clear(); + result.merge(std::move(u)); + ++counter; + } + + result.mergeNoDefs(std::move(bindResult)); + + // Propagate union projections. 
+ const auto& defs = unionNode.binder().exprs(); + for (size_t idx = 0; idx < names.size(); ++idx) { + result.defs[names[idx]] = Definition{n.ref(), defs[idx].ref()}; + } + + result.nodeDefs[&unionNode] = result.defs; + + return result; + } + + CollectedInfo transport(const ABT& n, + const GroupByNode& groupNode, + CollectedInfo childResult, + CollectedInfo bindAggResult, + CollectedInfo refsAggResult, + CollectedInfo bindGbResult, + CollectedInfo refsGbResult) { + CollectedInfo result{}; + + // First resolve all variables from the inside point of view; i.e. agg expressions and group + // by expressions reference variables from the input child. + result.merge(std::move(refsAggResult)); + result.merge(std::move(refsGbResult)); + // Make a copy of 'childResult' as we need it later and 'merge' is destructive. + result.merge(CollectedInfo{childResult}); + + // GroupBy completely masks projected variables; i.e. outside expressions cannot reach + // inside the groupby. We will create a brand new set of projections from aggs and gbs here. + result.defs.clear(); + + const auto& aggs = groupNode.getAggregationProjectionNames(); + const auto& gbs = groupNode.getGroupByProjectionNames(); + for (size_t idx = 0; idx < aggs.size(); ++idx) { + uassert(6624032, + "Aggregation overwrites a child projection", + childResult.defs.count(aggs[idx]) == 0); + result.defs[aggs[idx]] = + Definition{n.ref(), groupNode.getAggregationProjections()[idx].ref()}; + } + + for (size_t idx = 0; idx < gbs.size(); ++idx) { + uassert(6624033, + "Group-by projection does not exist", + childResult.defs.count(gbs[idx]) != 0); + result.defs[gbs[idx]] = + Definition{n.ref(), groupNode.getGroupByProjections()[idx].ref()}; + } + + result.mergeNoDefs(std::move(bindAggResult)); + result.mergeNoDefs(std::move(bindGbResult)); + + result.nodeDefs[&groupNode] = result.defs; + + return result; + } + + CollectedInfo transport(const ABT& n, + const UnwindNode& unwindNode, + CollectedInfo childResult, + CollectedInfo bindResult, + CollectedInfo refsResult) { + CollectedInfo result{}; + + // First resolve all variables from the inside point of view. + result.merge(std::move(refsResult)); + result.merge(std::move(childResult)); + + const auto& name = unwindNode.getProjectionName(); + uassert(6624034, "Unwind projection does not exist", result.defs.count(name) != 0); + + // Redefine unwind projection. + result.defs[name] = Definition{n.ref(), unwindNode.getProjection().ref()}; + // Define unwind PID. 
+ result.defs[unwindNode.getPIDProjectionName()] = + Definition{n.ref(), unwindNode.getPIDProjection().ref()}; + + result.mergeNoDefs(std::move(bindResult)); + + result.nodeDefs[&unwindNode] = result.defs; + + return result; + } + + CollectedInfo collect(const ABT& n) { + return algebra::transport<true>(n, *this); + } + +private: + const cascades::Memo* _memo; +}; + +VariableEnvironment VariableEnvironment::build(const ABT& root, const cascades::Memo* memo) { + + Collector c(memo); + auto info = std::make_unique<CollectedInfo>(c.collect(root)); + + // std::cout << "useMap size " << info.useMap.size() << "\n"; + // std::cout << "defs size " << info.defs.size() << "\n"; + // std::cout << "freeVars size " << info.freeVars.size() << "\n"; + + return VariableEnvironment{std::move(info), memo}; +} + +void VariableEnvironment::rebuild(const ABT& root) { + _info = std::make_unique<CollectedInfo>(Collector{_memo}.collect(root)); +} + +VariableEnvironment::VariableEnvironment(std::unique_ptr<CollectedInfo> info, + const cascades::Memo* memo) + : _info(std::move(info)), _memo(memo) {} + +VariableEnvironment::~VariableEnvironment() {} + +Definition VariableEnvironment::getDefinition(const Variable* var) const { + auto it = _info->useMap.find(var); + if (it == _info->useMap.end()) { + return Definition(); + } + + return it->second; +} + +const DefinitionsMap& VariableEnvironment::getDefinitions(const Node* node) const { + auto it = _info->nodeDefs.find(node); + uassert(6624035, "node does not exist", it != _info->nodeDefs.end()); + + return it->second; +} + +bool VariableEnvironment::hasDefinitions(const Node* node) const { + return _info->nodeDefs.find(node) != _info->nodeDefs.cend(); +} + +ProjectionNameSet VariableEnvironment::getProjections(const Node* node) const { + return CollectedInfo::getProjections(getDefinitions(node)); +} + +const DefinitionsMap& VariableEnvironment::getDefinitions(ABT::reference_type node) const { + uassert(6624036, "Invalid node type", node.is<Node>()); + return getDefinitions(node.cast<Node>()); +} + +bool VariableEnvironment::hasDefinitions(ABT::reference_type node) const { + uassert(6624037, "Invalid node type", node.is<Node>()); + return hasDefinitions(node.cast<Node>()); +} + +ProjectionNameSet VariableEnvironment::topLevelProjections() const { + return _info->getProjections(); +} + +bool VariableEnvironment::hasFreeVariables() const { + return !_info->freeVars.empty(); +} + +size_t VariableEnvironment::freeOccurences(const std::string& variable) const { + auto it = _info->freeVars.find(variable); + if (it == _info->freeVars.end()) { + return 0; + } + + return it->second.size(); +} + +bool VariableEnvironment::isLastRef(const Variable* var) const { + if (_info->lastRefs.count(var)) { + return true; + } + + return false; +} + +VariableCollectorResult VariableEnvironment::getVariables(const ABT& n) { + return VariableCollector::collect(n); +} + +} // namespace mongo::optimizer diff --git a/src/mongo/db/query/optimizer/reference_tracker.h b/src/mongo/db/query/optimizer/reference_tracker.h new file mode 100644 index 00000000000..8d0e96e11aa --- /dev/null +++ b/src/mongo/db/query/optimizer/reference_tracker.h @@ -0,0 +1,113 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. 
+ * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#pragma once + +#include "mongo/db/query/optimizer/node.h" + +namespace mongo::optimizer { +/** + * Every Variable ABT conceptually references to a point in the ABT tree. The pointed tree is the + * definition of the variable. + */ +struct Definition { + /** + * Pointer to ABT that defines the variable. It can be any Node (e.g. ScanNode, EvaluationNode, + * etc.) or Expr (e.g. let expression, lambda expression). + */ + ABT::reference_type definedBy; + + /** + * Pointer to actual definition of variable. + */ + ABT::reference_type definition; +}; + +namespace cascades { +class Memo; +} + +struct CollectedInfo; +using DefinitionsMap = opt::unordered_map<ProjectionName, Definition>; + +struct VariableCollectorResult { + opt::unordered_set<const Variable*> _variables; + // TODO: consider using a variable environment instance for this, but does not seem to be always + // viable, especially with rewrites. + opt::unordered_set<std::string> _definedVars; +}; + +class VariableEnvironment { + VariableEnvironment(std::unique_ptr<CollectedInfo> info, const cascades::Memo* memo); + +public: + /** + * Build the environment for the given ABT tree. The environment is valid as long as the tree + * does not change. More specifically, if a variable defining node is removed from the tree then + * the environment becomes stale and has to be rebuild. + */ + static VariableEnvironment build(const ABT& root, const cascades::Memo* memo = nullptr); + void rebuild(const ABT& root); + + ~VariableEnvironment(); + + /** + * + */ + Definition getDefinition(const Variable* var) const; + + /** + * We may revisit what we return from here. 
+ */ + ProjectionNameSet topLevelProjections() const; + + const DefinitionsMap& getDefinitions(const Node* node) const; + + /** + * Per node projection names + */ + ProjectionNameSet getProjections(const Node* node) const; + bool hasDefinitions(const Node* node) const; + + const DefinitionsMap& getDefinitions(ABT::reference_type node) const; + bool hasDefinitions(ABT::reference_type node) const; + + bool hasFreeVariables() const; + size_t freeOccurences(const std::string& variable) const; + + bool isLastRef(const Variable* var) const; + + static VariableCollectorResult getVariables(const ABT& n); + +private: + std::unique_ptr<CollectedInfo> _info; + const cascades::Memo* _memo{nullptr}; +}; + +} // namespace mongo::optimizer diff --git a/src/mongo/db/query/optimizer/rewrites/const_eval.cpp b/src/mongo/db/query/optimizer/rewrites/const_eval.cpp new file mode 100644 index 00000000000..723bd29400d --- /dev/null +++ b/src/mongo/db/query/optimizer/rewrites/const_eval.cpp @@ -0,0 +1,565 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include "mongo/db/query/optimizer/rewrites/const_eval.h" +#include "mongo/db/query/optimizer/utils/utils.h" + +namespace mongo::optimizer { +bool ConstEval::optimize(ABT& n) { + invariant(_letRefs.empty()); + invariant(_projectRefs.empty()); + invariant(_singleRef.empty()); + invariant(_noRefProj.empty()); + invariant(!_inRefBlock); + invariant(_inCostlyCtx == 0); + invariant(_staleDefs.empty()); + invariant(_staleABTs.empty()); + invariant(_seenProjects.empty()); + invariant(_inlinedDefs.empty()); + + _changed = false; + + // We run the transport<true> that will pass the reference to ABT to specific transport + // functions. The reference serves as a conceptual 'this' pointer allowing the transport + // function to change the node itself. + algebra::transport<true>(n, *this); + + // Test if there are any projections with no references. 
If so remove them from the tree + removeUnusedEvalNodes(); + + invariant(_letRefs.empty()); + invariant(_projectRefs.empty()); + + while (_changed) { + _env.rebuild(n); + + if (_singleRef.empty() && _noRefProj.empty()) { + break; + } + _changed = false; + algebra::transport<true>(n, *this); + removeUnusedEvalNodes(); + } + + // TODO: should we be clearing here? + _singleRef.clear(); + + _staleDefs.clear(); + _staleABTs.clear(); + return _changed; +} + +void ConstEval::removeUnusedEvalNodes() { + for (auto&& [k, v] : _projectRefs) { + if (v.size() == 0) { + // Schedule node replacement as it has not references. + _noRefProj.emplace(k); + _changed = true; + } else if (v.size() == 1) { + // Do not inline nodes which can become Sargable. + // TODO: consider caching. + // TODO: consider deriving IndexingAvailability. + if (!_disableSargableInlining || + !convertExprToPartialSchemaReq(k->getProjection())._success) { + // Schedule node inlining as there is exactly one reference. + _singleRef.emplace(v.front()); + _changed = true; + } + } + } + + _projectRefs.clear(); + _seenProjects.clear(); + _inlinedDefs.clear(); +} + +void ConstEval::transport(ABT& n, const Variable& var) { + auto def = _env.getDefinition(&var); + + if (!def.definition.empty()) { + // See if we have already manipulated this definition and if so then use the newer version. + if (auto it = _staleDefs.find(def.definition); it != _staleDefs.end()) { + def.definition = it->second; + } + if (auto it = _staleDefs.find(def.definedBy); it != _staleDefs.end()) { + def.definedBy = it->second; + } + + if (auto constant = def.definition.cast<Constant>(); constant && !_inRefBlock) { + // If we find the definition and it is a simple constant then substitute the variable. + swapAndUpdate(n, def.definition); + } else if (auto variable = def.definition.cast<Variable>(); variable && !_inRefBlock) { + // This is a indirection to another variable. So we can skip, but first remember that we + // inlined this variable so that we won't try to replace it with a common expression and + // revert the inlining. + _inlinedDefs.emplace(def.definition); + swapAndUpdate(n, def.definition); + } else if (_singleRef.erase(&var)) { + // If this is the only reference to some expression then substitute the variable, but + // first remember that we inlined this expression so that we won't try to replace it + // with a common expression and revert the inlining. + _inlinedDefs.emplace(def.definition); + swapAndUpdate(n, def.definition); + } else if (auto let = def.definedBy.cast<Let>(); let) { + invariant(_letRefs.count(let)); + _letRefs[let].emplace_back(&var); + } else if (auto project = def.definedBy.cast<EvaluationNode>(); project) { + invariant(_projectRefs.count(project)); + _projectRefs[project].emplace_back(&var); + + // If we are in the ref block we do not want to inline even if there is only a single + // reference. Similarly, we do not want to inline any variable under traverse. + if (_inRefBlock || _inCostlyCtx > 0) { + _projectRefs[project].emplace_back(&var); + } + } + } +} + +void ConstEval::prepare(ABT&, const Let& let) { + _letRefs[&let] = {}; +} + +void ConstEval::transport(ABT& n, const Let& let, ABT& bind, ABT& in) { + auto& letRefs = _letRefs[&let]; + if (letRefs.size() == 0) { + // The bind expressions has not been referenced so it is dead code and the whole let + // expression can be removed; i.e. 
we implement a following rewrite: + // + // n == let var=<bind expr> in <in expr> + // + // v + // + // n == <in expr> + + // We don't want to make a copy of 'in' as it may be arbitrarily large. Also, we cannot + // move it out as it is part of the Let object and we do not want to invalidate any + // assumptions the Let may have about its structure. Hence we swap it for the "special" + // Blackhole object. The Blackhole does nothing, it just plugs the hole left in the 'in' + // place. + auto result = std::exchange(in, make<Blackhole>()); + + // Swap the current node (n) for the result. + swapAndUpdate(n, std::move(result)); + } else if (letRefs.size() == 1) { + // The bind expression has been referenced exactly once so schedule it for inlining. + _singleRef.emplace(letRefs.front()); + _changed = true; + } + _letRefs.erase(&let); +} + +void ConstEval::transport(ABT& n, const LambdaApplication& app, ABT& lam, ABT& arg) { + // If the 'lam' expression is LambdaAbstraction then we can do the inplace beta reduction. + // TODO - missing alpha conversion so for now assume globally unique names. + if (auto lambda = lam.cast<LambdaAbstraction>(); lambda) { + auto result = make<Let>(lambda->varName(), + std::exchange(arg, make<Blackhole>()), + std::exchange(lambda->getBody(), make<Blackhole>())); + + swapAndUpdate(n, std::move(result)); + } +} + +namespace fold_helpers { +using namespace sbe::value; + +template <class T> +sbe::value::Value constFoldNumberHelper(const sbe::value::TypeTags lhsTag, + const sbe::value::Value lhsValue, + const TypeTags rhsTag, + const sbe::value::Value rhsValue) { + const auto result = numericCast<T>(lhsTag, lhsValue) + numericCast<T>(rhsTag, rhsValue); + return bitcastFrom<T>(result); +} + +template <> +sbe::value::Value constFoldNumberHelper<Decimal128>(const TypeTags lhsTag, + const sbe::value::Value lhsValue, + const TypeTags rhsTag, + const sbe::value::Value rhsValue) { + const auto result = + numericCast<Decimal128>(lhsTag, lhsValue).add(numericCast<Decimal128>(rhsTag, rhsValue)); + return makeCopyDecimal(result).second; +} + +} // namespace fold_helpers + +// Specific transport for binary operation +// The const correctness is probably wrong (as const ABT& lhs, const ABT& rhs does not work for +// some reason but we can fix it later). +void ConstEval::transport(ABT& n, const BinaryOp& op, ABT& lhs, ABT& rhs) { + using namespace fold_helpers; + + switch (op.op()) { + case Operations::Add: { + // Let say we want to recognize ConstLhs + ConstRhs and replace it with the result of + // addition. + auto lhsConst = lhs.cast<Constant>(); + auto rhsConst = rhs.cast<Constant>(); + + if (lhsConst && rhsConst) { + auto [lhsTag, lhsValue] = lhsConst->get(); + auto [rhsTag, rhsValue] = rhsConst->get(); + + if (isNumber(lhsTag) && isNumber(rhsTag)) { + // So this is the addition operation and both arguments are number constants, + // hence we can compute the result. 
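+ // For example, Const [1] (NumberInt32) plus Const [2.5] (NumberDouble) widens to
+ // NumberDouble via getWidestNumericalType() below and folds to Const [3.5].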
+ + const TypeTags resultType = getWidestNumericalType(lhsTag, rhsTag); + sbe::value::Value resultValue; + + switch (resultType) { + case TypeTags::NumberInt32: { + resultValue = + constFoldNumberHelper<int32_t>(lhsTag, lhsValue, rhsTag, rhsValue); + break; + } + + case TypeTags::NumberInt64: { + resultValue = + constFoldNumberHelper<int64_t>(lhsTag, lhsValue, rhsTag, rhsValue); + break; + } + + case TypeTags::NumberDouble: { + resultValue = + constFoldNumberHelper<double>(lhsTag, lhsValue, rhsTag, rhsValue); + break; + } + + case TypeTags::NumberDecimal: { + resultValue = constFoldNumberHelper<Decimal128>( + lhsTag, lhsValue, rhsTag, rhsValue); + break; + } + + default: + MONGO_UNREACHABLE; + } + + // And this is the crucial step - we swap the current node (n) for the result. + swapAndUpdate(n, make<Constant>(resultType, resultValue)); + } + } + break; + } + + case Operations::Or: { + // Nothing and short-circuiting semantics of the 'or' operation in SBE allow us to + // interrogate 'lhs' only. + if (auto lhsConst = lhs.cast<Constant>(); lhsConst) { + auto [lhsTag, lhsValue] = lhsConst->get(); + if (lhsTag == sbe::value::TypeTags::Boolean && + !sbe::value::bitcastTo<bool>(lhsValue)) { + // false || rhs -> rhs + swapAndUpdate(n, std::exchange(rhs, make<Blackhole>())); + } else if (lhsTag == sbe::value::TypeTags::Boolean && + sbe::value::bitcastTo<bool>(lhsValue)) { + // true || rhs -> true + swapAndUpdate(n, Constant::boolean(true)); + } + } + break; + } + + case Operations::And: { + // Nothing and short-circuiting semantics of the 'and' operation in SBE allow us to + // interrogate 'lhs' only. + if (auto lhsConst = lhs.cast<Constant>(); lhsConst) { + auto [lhsTag, lhsValue] = lhsConst->get(); + if (lhsTag == sbe::value::TypeTags::Boolean && + !sbe::value::bitcastTo<bool>(lhsValue)) { + // false && rhs -> false + swapAndUpdate(n, Constant::boolean(false)); + } else if (lhsTag == sbe::value::TypeTags::Boolean && + sbe::value::bitcastTo<bool>(lhsValue)) { + // true && rhs -> rhs + swapAndUpdate(n, std::exchange(rhs, make<Blackhole>())); + } + } + break; + } + + case Operations::Eq: { + if (lhs == rhs) { + // If the subtrees are equal, we can conclude that their result is equal because we + // have only pure functions. + swapAndUpdate(n, Constant::boolean(true)); + } else if (lhs.is<Constant>() && rhs.is<Constant>()) { + // We have two constants which are not equal. 
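+ // For example, Const [1] == Const [2] folds to Const [false] here, while the branch
+ // above folds structurally identical pure subtrees (e.g. x + y == x + y) to Const [true].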
+ swapAndUpdate(n, Constant::boolean(false)); + } + break; + } + + case Operations::Lt: + case Operations::Lte: + case Operations::Gt: + case Operations::Gte: + case Operations::Cmp3w: { + const auto lhsConst = lhs.cast<Constant>(); + const auto rhsConst = rhs.cast<Constant>(); + + if (lhsConst) { + const auto [lhsTag, lhsVal] = lhsConst->get(); + + if (rhsConst) { + const auto [rhsTag, rhsVal] = rhsConst->get(); + + const auto [compareTag, compareVal] = + sbe::value::compareValue(lhsTag, lhsVal, rhsTag, rhsVal); + const auto cmpVal = sbe::value::bitcastTo<int32_t>(compareVal); + + switch (op.op()) { + case Operations::Lt: + swapAndUpdate(n, Constant::boolean(cmpVal < 0)); + break; + case Operations::Lte: + swapAndUpdate(n, Constant::boolean(cmpVal <= 0)); + break; + case Operations::Gt: + swapAndUpdate(n, Constant::boolean(cmpVal > 0)); + break; + case Operations::Gte: + swapAndUpdate(n, Constant::boolean(cmpVal >= 0)); + break; + case Operations::Cmp3w: + swapAndUpdate(n, Constant::int32(cmpVal)); + break; + + default: + MONGO_UNREACHABLE; + } + } else { + if (lhsTag == sbe::value::TypeTags::MinKey) { + switch (op.op()) { + case Operations::Lte: + swapAndUpdate(n, Constant::boolean(true)); + break; + case Operations::Gt: + swapAndUpdate(n, Constant::boolean(false)); + break; + + default: + break; + } + } else if (lhsTag == sbe::value::TypeTags::MaxKey) { + switch (op.op()) { + case Operations::Lt: + swapAndUpdate(n, Constant::boolean(false)); + break; + case Operations::Gte: + swapAndUpdate(n, Constant::boolean(true)); + break; + + default: + break; + } + } + } + } else if (rhsConst) { + const auto [rhsTag, rhsVal] = rhsConst->get(); + + if (rhsTag == sbe::value::TypeTags::MinKey) { + switch (op.op()) { + case Operations::Lt: + swapAndUpdate(n, Constant::boolean(false)); + break; + + case Operations::Gte: + swapAndUpdate(n, Constant::boolean(true)); + break; + + default: + break; + } + } else if (rhsTag == sbe::value::TypeTags::MaxKey) { + switch (op.op()) { + case Operations::Lte: + swapAndUpdate(n, Constant::boolean(true)); + break; + + case Operations::Gt: + swapAndUpdate(n, Constant::boolean(false)); + break; + + default: + break; + } + } + } + } + + default: + // Not implemented. + break; + } +} + +void ConstEval::transport(ABT& n, const FunctionCall& op, std::vector<ABT>& args) { + // We can simplify exists(constant) to true if the said constant is not Nothing. + if (op.name() == "exists" && args.size() == 1 && args[0].is<Constant>()) { + auto [tag, val] = args[0].cast<Constant>()->get(); + if (tag != sbe::value::TypeTags::Nothing) { + swapAndUpdate(n, Constant::boolean(true)); + } + } + + if (op.name() == "newArray") { + bool allConstants = true; + for (const ABT& arg : op.nodes()) { + if (!arg.is<Constant>()) { + allConstants = false; + break; + } + } + + if (allConstants) { + // All arguments are constants. Replace with an array constant. + + sbe::value::Array array; + for (const ABT& arg : op.nodes()) { + auto [tag, val] = arg.cast<Constant>()->get(); + // Copy the value before inserting into the array. + auto [tagCopy, valCopy] = sbe::value::copyValue(tag, val); + array.push_back(tagCopy, valCopy); + } + + auto [tag, val] = sbe::value::makeCopyArray(array); + swapAndUpdate(n, make<Constant>(tag, val)); + } + } +} + +void ConstEval::transport(ABT& n, const If& op, ABT& cond, ABT& thenBranch, ABT& elseBranch) { + // If the condition is a boolean constant we can simplify. 
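+ // For example, if (Const [true], t, e) -> t and if (Const [false], t, e) -> e. Conditions
+ // of any other constant type (including Nothing) are not folded and the If is left untouched.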
+ if (auto condConst = cond.cast<Constant>(); condConst) { + auto [condTag, condValue] = condConst->get(); + if (condTag == sbe::value::TypeTags::Boolean && sbe::value::bitcastTo<bool>(condValue)) { + // if true -> thenBranch + swapAndUpdate(n, std::exchange(thenBranch, make<Blackhole>())); + } else if (condTag == sbe::value::TypeTags::Boolean && + !sbe::value::bitcastTo<bool>(condValue)) { + // if false -> elseBranch + swapAndUpdate(n, std::exchange(elseBranch, make<Blackhole>())); + } + } +} + +void ConstEval::prepare(ABT&, const PathTraverse&) { + ++_inCostlyCtx; +} + +void ConstEval::transport(ABT&, const PathTraverse&, ABT&) { + --_inCostlyCtx; +} + +void ConstEval::prepare(ABT&, const LambdaAbstraction&) { + ++_inCostlyCtx; +} + +void ConstEval::transport(ABT&, const LambdaAbstraction&, ABT&) { + --_inCostlyCtx; +} + +void ConstEval::transport(ABT& n, const EvaluationNode& op, ABT& child, ABT& expr) { + if (_noRefProj.erase(&op)) { + // The evaluation node is unused so replace it with its own child. + if (_erasedProjNames != nullptr) { + _erasedProjNames->insert(op.getProjectionName()); + } + + // First, pull out the child and put in a blackhole. + auto result = std::exchange(child, make<Blackhole>()); + + // Replace the evaluation node itself with the extracted child. + swapAndUpdate(n, std::move(result)); + } else { + if (!_projectRefs.count(&op)) { + _projectRefs[&op] = {}; + } + + // Do not consider simple constants or variable references for elimination. + if (!op.getProjection().is<Constant>() && !op.getProjection().is<Variable>()) { + // Try to find a projection with the same expression as the current 'op' node and + // substitute it with a variable pointing to that source projection. + if (auto source = _seenProjects.find(&op); source != _seenProjects.end() && + // Make sure that the matched projection is visible to the current 'op'. + _env.getProjections(&op).count((*source)->getProjectionName()) && + // If we already inlined the matched projection, we don't want to use it as a source + // for common expression as it will negate the inlining. + !_inlinedDefs.count((*source)->getProjection().ref())) { + invariant(_projectRefs.count(*source)); + + auto var = make<Variable>((*source)->getProjectionName()); + // Source now will have an extra reference from the newly constructed projection. + _projectRefs[*source].emplace_back(var.cast<Variable>()); + + auto newN = make<EvaluationNode>(op.getProjectionName(), + std::move(var), + std::exchange(child, make<Blackhole>())); + // The new projection node should inherit the references from the old node. + _projectRefs[newN.cast<EvaluationNode>()] = std::move(_projectRefs[&op]); + _projectRefs.erase(&op); + + swapAndUpdate(n, std::move(newN)); + } else { + _seenProjects.emplace(&op); + } + } + } +} + +void ConstEval::prepare(ABT&, const References& refs) { + // It is structurally impossible to nest References nodes. + invariant(!_inRefBlock); + _inRefBlock = true; +} +void ConstEval::transport(ABT& n, const References& op, std::vector<ABT>&) { + invariant(_inRefBlock); + _inRefBlock = false; +} + +void ConstEval::swapAndUpdate(ABT& n, ABT newN) { + // Record the mapping from the old to the new. + invariant(_staleDefs.count(n.ref()) == 0); + invariant(_staleDefs.count(newN.ref()) == 0); + + _staleDefs[n.ref()] = newN.ref(); + + // Do the swap. 
+ std::swap(n, newN); + + // newN now contains the old ABT + _staleABTs.emplace_back(std::move(newN)); + + _changed = true; +} +} // namespace mongo::optimizer diff --git a/src/mongo/db/query/optimizer/rewrites/const_eval.h b/src/mongo/db/query/optimizer/rewrites/const_eval.h new file mode 100644 index 00000000000..42a1fdde77a --- /dev/null +++ b/src/mongo/db/query/optimizer/rewrites/const_eval.h @@ -0,0 +1,121 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#pragma once + +#include "mongo/db/query/optimizer/reference_tracker.h" +#include "mongo/db/query/optimizer/utils/abt_hash.h" + +namespace mongo::optimizer { +/** + * This is an example rewriter that does constant evaluation in-place. + */ +class ConstEval { +public: + ConstEval(VariableEnvironment& env, + const bool disableSargableInlining = false, + ProjectionNameSet* erasedProjNames = nullptr) + : _disableSargableInlining(disableSargableInlining), + _env(env), + _erasedProjNames(erasedProjNames) {} + + // The default noop transport. Note the first ABT& parameter. + template <typename T, typename... Ts> + void transport(ABT&, const T&, Ts&&...) {} + + void transport(ABT& n, const Variable& var); + + void prepare(ABT&, const Let& let); + void transport(ABT& n, const Let& let, ABT&, ABT& in); + void transport(ABT& n, const LambdaApplication& app, ABT& lam, ABT& arg); + void prepare(ABT&, const LambdaAbstraction&); + void transport(ABT&, const LambdaAbstraction&, ABT&); + + // Specific transport for binary operation + // The const correctness is probably wrong (as const ABT& lhs, const ABT& rhs does not work for + // some reason but we can fix it later). 
+ void transport(ABT& n, const BinaryOp& op, ABT& lhs, ABT& rhs); + void transport(ABT& n, const FunctionCall& op, std::vector<ABT>& args); + void transport(ABT& n, const If& op, ABT& cond, ABT& thenBranch, ABT& elseBranch); + void transport(ABT& n, const EvaluationNode& op, ABT& child, ABT& expr); + + void prepare(ABT&, const PathTraverse&); + void transport(ABT&, const PathTraverse&, ABT&); + + void prepare(ABT&, const References& refs); + void transport(ABT& n, const References& op, std::vector<ABT>&); + + // The tree is passed in as NON-const reference as we will be updating it. + bool optimize(ABT& n); + +private: + struct EvalNodeHash { + size_t operator()(const EvaluationNode* node) const { + return ABTHashGenerator::generate(node->getProjection()); + } + }; + + struct EvalNodeCompare { + size_t operator()(const EvaluationNode* lhs, const EvaluationNode* rhs) const { + return lhs->getProjection() == rhs->getProjection(); + } + }; + + struct RefHash { + size_t operator()(const ABT::reference_type& nodeRef) const { + return nodeRef.hash(); + } + }; + + void swapAndUpdate(ABT& n, ABT newN); + void removeUnusedEvalNodes(); + + // Controls if we can inline certain EvaluationNodes. + const bool _disableSargableInlining; + + VariableEnvironment& _env; + opt::unordered_set<const Variable*> _singleRef; + opt::unordered_set<const EvaluationNode*> _noRefProj; + opt::unordered_map<const Let*, std::vector<const Variable*>> _letRefs; + opt::unordered_map<const EvaluationNode*, std::vector<const Variable*>> _projectRefs; + opt::unordered_set<const EvaluationNode*, EvalNodeHash, EvalNodeCompare> _seenProjects; + opt::unordered_set<ABT::reference_type, RefHash> _inlinedDefs; + opt::unordered_map<ABT::reference_type, ABT::reference_type, RefHash> _staleDefs; + // We collect old ABTs in order to avoid the ABA problem. + std::vector<ABT> _staleABTs; + + bool _inRefBlock{false}; + size_t _inCostlyCtx{0}; + bool _changed{false}; + + // Optionally collect projection names from erased Eval nodes. + ProjectionNameSet* _erasedProjNames; +}; + +} // namespace mongo::optimizer diff --git a/src/mongo/db/query/optimizer/rewrites/path.cpp b/src/mongo/db/query/optimizer/rewrites/path.cpp new file mode 100644 index 00000000000..1378dadb6ae --- /dev/null +++ b/src/mongo/db/query/optimizer/rewrites/path.cpp @@ -0,0 +1,346 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. 
If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include "mongo/db/query/optimizer/rewrites/path.h" +#include "mongo/db/query/optimizer/utils/utils.h" + +namespace mongo::optimizer { +ABT::reference_type PathFusion::follow(ABT::reference_type n) { + if (auto var = n.cast<Variable>(); var) { + auto def = _env.getDefinition(var); + if (!def.definition.empty() && !def.definition.is<Source>()) { + return follow(def.definition); + } + } + + return n; +} + +bool PathFusion::fuse(ABT& lhs, const ABT& rhs) { + if (auto rhsComposeM = rhs.cast<PathComposeM>(); rhsComposeM != nullptr) { + if (_info[rhsComposeM->getPath1().cast<PathSyntaxSort>()]._isConst) { + if (fuse(lhs, rhsComposeM->getPath1())) { + return true; + } + if (_info[rhsComposeM->getPath2().cast<PathSyntaxSort>()]._isConst && + fuse(lhs, rhsComposeM->getPath2())) { + return true; + } + } + } + + if (auto lhsGet = lhs.cast<PathGet>(); lhsGet != nullptr) { + if (auto rhsField = rhs.cast<PathField>(); + rhsField != nullptr && lhsGet->name() == rhsField->name()) { + return fuse(lhsGet->getPath(), rhsField->getPath()); + } + + if (auto rhsKeep = rhs.cast<PathKeep>(); rhsKeep != nullptr) { + if (rhsKeep->getNames().count(lhsGet->name()) > 0) { + return true; + } + } + } + + if (auto lhsTraverse = lhs.cast<PathTraverse>(); lhsTraverse != nullptr) { + if (auto rhsTraverse = rhs.cast<PathTraverse>(); rhsTraverse != nullptr) { + return fuse(lhsTraverse->getPath(), rhsTraverse->getPath()); + } + + auto rhsType = _info[rhs.cast<PathSyntaxSort>()]._type; + if (rhsType != Type::unknown && rhsType != Type::array) { + auto result = std::exchange(lhsTraverse->getPath(), make<Blackhole>()); + std::swap(lhs, result); + return fuse(lhs, rhs); + } + } + + if (lhs.is<PathIdentity>()) { + lhs = rhs; + return true; + } + + if (rhs.is<PathLambda>()) { + lhs = make<PathComposeM>(std::move(lhs), rhs); + return true; + } + + if (auto rhsConst = rhs.cast<PathConstant>(); rhsConst != nullptr) { + if (auto lhsCmp = lhs.cast<PathCompare>(); lhsCmp != nullptr) { + auto result = make<PathConstant>(make<BinaryOp>( + lhsCmp->op(), + make<BinaryOp>(Operations::Cmp3w, rhsConst->getConstant(), lhsCmp->getVal()), + Constant::int64(0))); + + std::swap(lhs, result); + return true; + } + + lhs = make<PathComposeM>(rhs, std::move(lhs)); + return true; + } + + return false; +} + +bool PathFusion::optimize(ABT& root) { + _changed = false; + algebra::transport<true>(root, *this); + + if (_changed) { + _env.rebuild(root); + } + + return _changed; +} + +void PathFusion::transport(ABT& n, const PathConstant& path, ABT& c) { + CollectedInfo ci; + if (auto exprC = path.getConstant().cast<Constant>(); exprC) { + // Let's see what we can determine from the constant expression + auto [tag, val] = exprC->get(); + if (sbe::value::isObject(tag)) { + ci._type = Type::object; + } else if (sbe::value::isArray(tag)) { + ci._type = Type::array; + } else if (tag == sbe::value::TypeTags::Boolean) { + ci._type = Type::boolean; + } else if (tag == sbe::value::TypeTags::Nothing) { + ci._type = Type::nothing; + } else { + ci._type = Type::any; + } + } + + ci._isConst = true; + _info[&path] = ci; +} + +void PathFusion::transport(ABT& n, const PathCompare& path, ABT& c) { + CollectedInfo ci; + 
+ // TODO - follow up on Nothing and 3 value logic. Assume plain boolean for now. + ci._type = Type::boolean; + ci._isConst = _info[path.getVal().cast<PathSyntaxSort>()]._isConst; + _info[&path] = ci; +} + +void PathFusion::transport(ABT& n, const PathGet& get, ABT& path) { + // Get "a" Const <c> -> Const <c> + if (auto constPath = path.cast<PathConstant>(); constPath) { + // Pull out the constant path + auto result = std::exchange(path, make<Blackhole>()); + + // And swap it for the current node + std::swap(n, result); + + _changed = true; + } else { + auto it = _info.find(path.cast<PathSyntaxSort>()); + uassert(6624129, "expected to find path", it != _info.end()); + + // Simply move the info from the child. + _info[&get] = it->second; + } +} + +void PathFusion::transport(ABT& n, const PathField& field, ABT& path) { + auto it = _info.find(path.cast<PathSyntaxSort>()); + uassert(6624130, "expected to find path", it != _info.end()); + + CollectedInfo ci; + if (it->second._type == Type::unknown) { + // We don't know anything about the child. + ci._type = Type::unknown; + } else if (it->second._type == Type::nothing) { + // We are setting a field to nothing (aka drop) hence we do not know what the result could + // be (i.e. it all depends on the input). + ci._type = Type::unknown; + } else { + // The child produces bona fide value hence this will become an object. + ci._type = Type::object; + } + + ci._isConst = it->second._isConst; + _info[&field] = ci; +} + +void PathFusion::transport(ABT& n, const PathTraverse& path, ABT& inner) { + // Traverse is completely dependent on its input and we cannot say anything about it. + CollectedInfo ci; + ci._type = Type::unknown; + ci._isConst = false; + _info[&path] = ci; +} + +void PathFusion::transport(ABT& n, const PathComposeM& path, ABT& p1, ABT& p2) { + if (auto p1Const = p1.cast<PathConstant>(); p1Const != nullptr) { + switch (_kindCtx.back()) { + case Kind::filter: + n = make<PathConstant>(make<EvalFilter>(p2, p1Const->getConstant())); + _changed = true; + return; + + case Kind::project: + n = make<PathConstant>(make<EvalPath>(p2, p1Const->getConstant())); + _changed = true; + return; + + default: + MONGO_UNREACHABLE; + } + } + + if (auto p1Get = p1.cast<PathGet>(); p1Get != nullptr && p1Get->getPath().is<PathIdentity>()) { + // TODO: handle chain of Gets. 
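+ // For example, (Get "a" Id) * p2 -> Get "a" p2: getting field "a" and then applying p2
+ // is equivalent to applying p2 underneath the Get.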
+ n = make<PathGet>(p1Get->name(), std::move(p2)); + _changed = true; + return; + } + + if (p1.is<PathIdentity>()) { + // Id * p2 -> p2 + auto result = std::exchange(p2, make<Blackhole>()); + std::swap(n, result); + _changed = true; + return; + } else if (p2.is<PathIdentity>()) { + // p1 * Id -> p1 + auto result = std::exchange(p1, make<Blackhole>()); + std::swap(n, result); + _changed = true; + return; + } else if (_redundant.erase(p1.cast<PathSyntaxSort>())) { + auto result = std::exchange(p2, make<Blackhole>()); + std::swap(n, result); + _changed = true; + return; + } else if (_redundant.erase(p2.cast<PathSyntaxSort>())) { + auto result = std::exchange(p1, make<Blackhole>()); + std::swap(n, result); + _changed = true; + return; + } + + auto p1InfoIt = _info.find(p1.cast<PathSyntaxSort>()); + auto p2InfoIt = _info.find(p2.cast<PathSyntaxSort>()); + + uassert(6624131, "info must be defined", p1InfoIt != _info.end() && p2InfoIt != _info.end()); + + if (p1.is<PathDefault>() && p2InfoIt->second.isNotNothing()) { + // Default * Const e -> e (provided we can prove e is not Nothing and we can do that only + // when e is Constant expression) + auto result = std::exchange(p2, make<Blackhole>()); + std::swap(n, result); + _changed = true; + } else if (p2.is<PathDefault>() && p1InfoIt->second.isNotNothing()) { + // Const e * Default -> e (provided we can prove e is not Nothing and we can do that only + // when e is Constant expression) + auto result = std::exchange(p1, make<Blackhole>()); + std::swap(n, result); + _changed = true; + } else if (p2InfoIt->second._type == Type::object) { + auto left = collectComposed(p1); + for (auto l : left) { + if (l.is<PathObj>()) { + _redundant.emplace(l.cast<PathSyntaxSort>()); + _changed = true; + } + } + _info[&path] = p2InfoIt->second; + } else { + _info[&path] = p2InfoIt->second; + } +} + +void PathFusion::transport(ABT& n, const EvalPath& eval, ABT& path, ABT& input) { + auto realInput = follow(input); + // If we are evaluating const path then we can simply replace the whole expression with the + // result. + if (auto constPath = path.cast<PathConstant>(); constPath) { + // Pull out the constant out of the path + auto result = std::exchange(constPath->getConstant(), make<Blackhole>()); + + // And swap it for the current node + std::swap(n, result); + + _changed = true; + } else if (auto evalInput = realInput.cast<EvalPath>(); evalInput) { + // An input to 'this' EvalPath expression is another EvalPath so we may try to fuse the + // paths. + if (fuse(n.cast<EvalPath>()->getPath(), evalInput->getPath())) { + // We have fused paths so replace the input (by making a copy). + input = evalInput->getInput(); + + _changed = true; + } else if (auto evalImmediateInput = input.cast<EvalPath>(); + evalImmediateInput != nullptr) { + // Compose immediate EvalPath input. + n = make<EvalPath>( + make<PathComposeM>(std::move(evalImmediateInput->getPath()), std::move(path)), + std::move(evalImmediateInput->getInput())); + + _changed = true; + } + } + _kindCtx.pop_back(); +} + +void PathFusion::transport(ABT& n, const EvalFilter& eval, ABT& path, ABT& input) { + auto realInput = follow(input); + // If we are evaluating const path then we can simply replace the whole expression with the + // result. 
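+ // For example, EvalFilter (Const e) <input> -> e: the constant path ignores its input
+ // entirely, so the filter collapses to the constant expression itself.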
+ if (auto constPath = path.cast<PathConstant>(); constPath) { + // Pull out the constant out of the path + auto result = std::exchange(constPath->getConstant(), make<Blackhole>()); + + // And swap it for the current node + std::swap(n, result); + + _changed = true; + } else if (auto evalImmediateInput = input.cast<EvalFilter>(); evalImmediateInput != nullptr) { + // Compose immediate EvalFilter input. + n = make<EvalFilter>( + make<PathComposeM>(std::move(evalImmediateInput->getPath()), std::move(path)), + std::move(evalImmediateInput->getInput())); + + _changed = true; + } else if (auto evalInput = realInput.cast<EvalPath>(); evalInput) { + // An input to 'this' EvalFilter expression is another EvalPath so we may try to fuse the + // paths. + if (fuse(n.cast<EvalFilter>()->getPath(), evalInput->getPath())) { + // We have fused paths so replace the input (by making a copy). + input = evalInput->getInput(); + + _changed = true; + } + } + _kindCtx.pop_back(); +} +} // namespace mongo::optimizer diff --git a/src/mongo/db/query/optimizer/rewrites/path.h b/src/mongo/db/query/optimizer/rewrites/path.h new file mode 100644 index 00000000000..6e53fe5a7b5 --- /dev/null +++ b/src/mongo/db/query/optimizer/rewrites/path.h @@ -0,0 +1,94 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#pragma once + +#include "mongo/db/query/optimizer/reference_tracker.h" + +namespace mongo::optimizer { +class PathFusion { + enum class Type { unknown, nothing, object, array, boolean, any }; + enum class Kind { project, filter }; + + struct CollectedInfo { + bool isNotNothing() const { + return _type != Type::unknown && _type != Type::nothing; + } + + Type _type{Type::unknown}; + + // Is the result of the path independent of its input (e.g. can be if the paths terminates + // with PathConst, but not necessarily with PathIdentity). + bool _isConst{false}; + }; + +public: + PathFusion(VariableEnvironment& env) : _env(env) {} + + template <typename T, typename... Ts> + void transport(ABT&, const T& op, Ts&&...) 
{ + if constexpr (std::is_base_of_v<PathSyntaxSort, T>) { + _info[&op] = CollectedInfo{}; + } + } + + void transport(ABT& n, const PathConstant&, ABT& c); + void transport(ABT& n, const PathCompare&, ABT& c); + void transport(ABT& n, const PathGet&, ABT& path); + void transport(ABT& n, const PathField&, ABT& path); + void transport(ABT& n, const PathTraverse&, ABT& inner); + void transport(ABT& n, const PathComposeM&, ABT& p1, ABT& p2); + + void prepare(ABT& n, const EvalPath& eval) { + _kindCtx.push_back(Kind::project); + } + void transport(ABT& n, const EvalPath& eval, ABT& path, ABT& input); + void prepare(ABT& n, const EvalFilter& eval) { + _kindCtx.push_back(Kind::filter); + } + void transport(ABT& n, const EvalFilter& eval, ABT& path, ABT& input); + + bool optimize(ABT& root); + +private: + ABT::reference_type follow(ABT::reference_type n); + ABT::reference_type follow(const ABT& n) { + return follow(n.ref()); + } + bool fuse(ABT& lhs, const ABT& rhs); + + VariableEnvironment& _env; + opt::unordered_map<const PathSyntaxSort*, CollectedInfo> _info; + opt::unordered_set<const PathSyntaxSort*> _redundant; + + // A stack of context (either project or filter path) + std::vector<Kind> _kindCtx; + bool _changed{false}; +}; +} // namespace mongo::optimizer diff --git a/src/mongo/db/query/optimizer/rewrites/path_lower.cpp b/src/mongo/db/query/optimizer/rewrites/path_lower.cpp new file mode 100644 index 00000000000..c5eb294c6c8 --- /dev/null +++ b/src/mongo/db/query/optimizer/rewrites/path_lower.cpp @@ -0,0 +1,400 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. 
+ */ + +#include "mongo/db/query/optimizer/rewrites/path_lower.h" + + +namespace mongo::optimizer { +bool EvalPathLowering::optimize(ABT& n) { + _changed = false; + + algebra::transport<true>(n, *this); + + if (_changed) { + _env.rebuild(n); + } + + return _changed; +} + +void EvalPathLowering::transport(ABT& n, const PathConstant&, ABT& c) { + n = make<LambdaAbstraction>(_prefixId.getNextId("_"), std::exchange(c, make<Blackhole>())); + _changed = true; +} + +void EvalPathLowering::transport(ABT& n, const PathIdentity&) { + const std::string& name = _prefixId.getNextId("x"); + + n = make<LambdaAbstraction>(name, make<Variable>(name)); + _changed = true; +} + +void EvalPathLowering::transport(ABT& n, const PathLambda&, ABT& lam) { + n = std::exchange(lam, make<Blackhole>()); + _changed = true; +} + +void EvalPathLowering::transport(ABT& n, const PathDefault&, ABT& c) { + // if (exists(x), x, c) + const std::string& name = _prefixId.getNextId("valDefault"); + + n = make<LambdaAbstraction>( + name, + make<If>(make<FunctionCall>("exists", makeSeq(make<Variable>(name))), + make<Variable>(name), + std::exchange(c, make<Blackhole>()))); + _changed = true; +} + +void EvalPathLowering::transport(ABT& n, const PathCompare&, ABT& c) { + uasserted(6624132, "cannot lower compare in projection"); +} + +void EvalPathLowering::transport(ABT& n, const PathGet& p, ABT& inner) { + const std::string& name = _prefixId.getNextId("inputGet"); + + n = make<LambdaAbstraction>( + name, + make<LambdaApplication>( + std::exchange(inner, make<Blackhole>()), + make<FunctionCall>("getField", + makeSeq(make<Variable>(name), Constant::str(p.name()))))); + _changed = true; +} + +void EvalPathLowering::transport(ABT& n, const PathDrop& drop) { + // if (isObject(x), dropFields(x,...) , x) + // Alternatively, we can implement a special builtin function that does the comparison and drop. + const std::string& name = _prefixId.getNextId("valDrop"); + + std::vector<ABT> params; + params.emplace_back(make<Variable>(name)); + for (const auto& fieldName : drop.getNames()) { + params.emplace_back(Constant::str(fieldName)); + } + + n = make<LambdaAbstraction>( + name, + make<If>(make<FunctionCall>("isObject", makeSeq(make<Variable>(name))), + make<FunctionCall>("dropFields", std::move(params)), + make<Variable>(name))); + _changed = true; +} + +void EvalPathLowering::transport(ABT& n, const PathKeep& keep) { + // if (isObject(x), keepFields(x,...) , x) + // Alternatively, we can implement a special builtin function that does the comparison and drop. 
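+ // For example, PathKeep ["a", "b"] lowers to the lambda
+ //   \x. if (isObject(x), keepFields(x, "a", "b"), x)
+ // where x stands for the fresh "valKeep" variable generated below.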
+ const std::string& name = _prefixId.getNextId("valKeep"); + + std::vector<ABT> params; + params.emplace_back(make<Variable>(name)); + for (const auto& fieldName : keep.getNames()) { + params.emplace_back(Constant::str(fieldName)); + } + + n = make<LambdaAbstraction>( + name, + make<If>(make<FunctionCall>("isObject", makeSeq(make<Variable>(name))), + make<FunctionCall>("keepFields", std::move(params)), + make<Variable>(name))); + _changed = true; +} + +void EvalPathLowering::transport(ABT& n, const PathObj&) { + // if (isObject(x), x, Nothing) + const std::string& name = _prefixId.getNextId("valObj"); + + n = make<LambdaAbstraction>( + name, + make<If>(make<FunctionCall>("isObject", makeSeq(make<Variable>(name))), + make<Variable>(name), + Constant::nothing())); + _changed = true; +} + +void EvalPathLowering::transport(ABT& n, const PathArr&) { + // if (isArray(x), x, Nothing) + const std::string& name = _prefixId.getNextId("valArr"); + + n = make<LambdaAbstraction>( + name, + make<If>(make<FunctionCall>("isArray", makeSeq(make<Variable>(name))), + make<Variable>(name), + Constant::nothing())); + + _changed = true; +} + +void EvalPathLowering::transport(ABT& n, const PathTraverse&, ABT& inner) { + const std::string& name = _prefixId.getNextId("valTraverse"); + + n = make<LambdaAbstraction>( + name, + make<FunctionCall>("traverseP", + makeSeq(make<Variable>(name), std::exchange(inner, make<Blackhole>())))); + _changed = true; +} + +void EvalPathLowering::transport(ABT& n, const PathField& p, ABT& inner) { + const std::string& name = _prefixId.getNextId("inputField"); + const std::string& val = _prefixId.getNextId("valField"); + + n = make<LambdaAbstraction>( + name, + make<Let>( + val, + make<LambdaApplication>( + std::exchange(inner, make<Blackhole>()), + make<FunctionCall>("getField", + makeSeq(make<Variable>(name), Constant::str(p.name())))), + make<If>( + make<BinaryOp>(Operations::Or, + make<FunctionCall>("exists", makeSeq(make<Variable>(val))), + make<FunctionCall>("isObject", makeSeq(make<Variable>(name)))), + make<FunctionCall>( + "setField", + makeSeq(make<Variable>(name), Constant::str(p.name()), make<Variable>(val))), + make<Variable>(name)))); + + _changed = true; +} + +void EvalPathLowering::transport(ABT& n, const PathComposeM&, ABT& p1, ABT& p2) { + // p1 * p2 -> (p2 (p1 input)) + const std::string& name = _prefixId.getNextId("inputComposeM"); + + n = make<LambdaAbstraction>( + name, + make<LambdaApplication>( + std::exchange(p2, make<Blackhole>()), + make<LambdaApplication>(std::exchange(p1, make<Blackhole>()), make<Variable>(name)))); + + _changed = true; +} + +void EvalPathLowering::transport(ABT& n, const PathComposeA&, ABT& p1, ABT& p2) { + uasserted(6624133, "cannot lower additive composite in projection"); +} + +void EvalPathLowering::transport(ABT& n, const EvalPath&, ABT& path, ABT& input) { + // In order to completely dissolve EvalPath the incoming path must be lowered to an expression + // (lambda). 
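+ // For example, EvalPath (Id) <input>, with Id already lowered to the lambda \x. x,
+ // dissolves into the application (\x. x) <input> built below.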
+ uassert(6624134, "Incomplete evalpath lowering", path.is<LambdaAbstraction>()); + + n = make<LambdaApplication>(std::exchange(path, make<Blackhole>()), + std::exchange(input, make<Blackhole>())); + + _changed = true; +} + +bool EvalFilterLowering::optimize(ABT& n) { + _changed = false; + + algebra::transport<true>(n, *this); + + if (_changed) { + _env.rebuild(n); + } + + return _changed; +} + +void EvalFilterLowering::transport(ABT& n, const PathConstant&, ABT& c) { + n = make<LambdaAbstraction>(_prefixId.getNextId("_"), std::exchange(c, make<Blackhole>())); + _changed = true; +} + +void EvalFilterLowering::transport(ABT& n, const PathIdentity&) { + n = make<LambdaAbstraction>(_prefixId.getNextId("_"), Constant::boolean(true)); + _changed = true; + + // TODO - do we need an identity element for the additive composition? i.e. false constant + // Or should Identity be left undefined and removed by the PathFuse? +} + +void EvalFilterLowering::transport(ABT& n, const PathLambda&, ABT& lam) { + n = std::exchange(lam, make<Blackhole>()); + _changed = true; +} + +void EvalFilterLowering::transport(ABT& n, const PathDefault&, ABT& c) { + uasserted(6624135, "cannot lower default in filter"); +} + +void EvalFilterLowering::transport(ABT& n, const PathCompare& cmp, ABT& c) { + const std::string& name = _prefixId.getNextId("valCmp"); + + if (cmp.op() == Operations::Eq) { + n = make<LambdaAbstraction>( + name, + make<BinaryOp>(cmp.op(), make<Variable>(name), std::exchange(c, make<Blackhole>()))); + } else { + n = make<LambdaAbstraction>( + name, + make<BinaryOp>(cmp.op(), + make<BinaryOp>(Operations::Cmp3w, + make<Variable>(name), + std::exchange(c, make<Blackhole>())), + Constant::int64(0))); + } + + _changed = true; +} + +void EvalFilterLowering::transport(ABT& n, const PathGet& p, ABT& inner) { + const std::string& name = _prefixId.getNextId("inputGet"); + + int idx; + bool isNumber = NumberParser{}(p.name(), &idx).isOK(); + n = make<LambdaAbstraction>( + name, + make<LambdaApplication>( + std::exchange(inner, make<Blackhole>()), + make<FunctionCall>(isNumber ? 
"getFieldOrElement" : "getField", + makeSeq(make<Variable>(name), Constant::str(p.name()))))); + _changed = true; +} + +void EvalFilterLowering::transport(ABT& n, const PathDrop& drop) { + uasserted(6624136, "cannot lower drop in filter"); +} + +void EvalFilterLowering::transport(ABT& n, const PathKeep& keep) { + uasserted(6624137, "cannot lower keep in filter"); +} + +void EvalFilterLowering::transport(ABT& n, const PathObj&) { + const std::string& name = _prefixId.getNextId("valObj"); + n = make<LambdaAbstraction>(name, + make<FunctionCall>("isObject", makeSeq(make<Variable>(name)))); + _changed = true; +} + +void EvalFilterLowering::transport(ABT& n, const PathArr&) { + const std::string& name = _prefixId.getNextId("valArr"); + n = make<LambdaAbstraction>(name, make<FunctionCall>("isArray", makeSeq(make<Variable>(name)))); + _changed = true; +} + +void EvalFilterLowering::prepare(ABT& n, const PathTraverse& t) { + int idx; + // This is a bad hack that detect if a child is number path element + if (auto child = t.getPath().cast<PathGet>(); + child && NumberParser{}(child->name(), &idx).isOK()) { + _traverseStack.emplace_back(n.ref()); + } +} + +void EvalFilterLowering::transport(ABT& n, const PathTraverse&, ABT& inner) { + const std::string& name = _prefixId.getNextId("valTraverse"); + + ABT numberPath = Constant::boolean(false); + if (!_traverseStack.empty() && _traverseStack.back() == n.ref()) { + numberPath = Constant::boolean(true); + _traverseStack.pop_back(); + } + n = make<LambdaAbstraction>(name, + make<FunctionCall>("traverseF", + makeSeq(make<Variable>(name), + std::exchange(inner, make<Blackhole>()), + std::move(numberPath)))); + + _changed = true; +} + +void EvalFilterLowering::transport(ABT& n, const PathField& p, ABT& inner) { + uasserted(6624140, "cannot lower arr in filter"); +} + +void EvalFilterLowering::transport(ABT& n, const PathComposeM&, ABT& p1, ABT& p2) { + const std::string& name = _prefixId.getNextId("inputComposeM"); + + n = make<LambdaAbstraction>( + name, + make<If>( + make<FunctionCall>("fillEmpty", + makeSeq(make<LambdaApplication>(std::exchange(p1, make<Blackhole>()), + make<Variable>(name)), + Constant::boolean(false))), + make<LambdaApplication>(std::exchange(p2, make<Blackhole>()), make<Variable>(name)), + Constant::boolean(false))); + + _changed = true; +} + +void EvalFilterLowering::transport(ABT& n, const PathComposeA&, ABT& p1, ABT& p2) { + const std::string& name = _prefixId.getNextId("inputComposeA"); + + n = make<LambdaAbstraction>( + name, + make<If>( + make<FunctionCall>("fillEmpty", + makeSeq(make<LambdaApplication>(std::exchange(p1, make<Blackhole>()), + make<Variable>(name)), + Constant::boolean(false))), + Constant::boolean(true), + make<LambdaApplication>(std::exchange(p2, make<Blackhole>()), make<Variable>(name)))); + + _changed = true; +} + +void EvalFilterLowering::transport(ABT& n, const EvalFilter&, ABT& path, ABT& input) { + // In order to completely dissolve EvalFilter the incoming path must be lowered to an expression + // (lambda). 
+ uassert(6624141, "Incomplete evalfilter lowering", path.is<LambdaAbstraction>()); + + n = make<LambdaApplication>(std::exchange(path, make<Blackhole>()), + std::exchange(input, make<Blackhole>())); + + _changed = true; +} + +void PathLowering::transport(ABT& n, const EvalPath&, ABT&, ABT&) { + _changed = _changed || _project.optimize(n); +} + +void PathLowering::transport(ABT& n, const EvalFilter&, ABT&, ABT&) { + _changed = _changed || _filter.optimize(n); +} + +bool PathLowering::optimize(ABT& n) { + _changed = false; + + algebra::transport<true>(n, *this); + + // TODO investigate why we crash when this is removed. It should not be needed here. + if (_changed) { + _env.rebuild(n); + } + + return _changed; +} + + +} // namespace mongo::optimizer diff --git a/src/mongo/db/query/optimizer/rewrites/path_lower.h b/src/mongo/db/query/optimizer/rewrites/path_lower.h new file mode 100644 index 00000000000..59df0f50050 --- /dev/null +++ b/src/mongo/db/query/optimizer/rewrites/path_lower.h @@ -0,0 +1,158 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#pragma once + +#include "mongo/db/query/optimizer/reference_tracker.h" +#include "mongo/db/query/optimizer/utils/utils.h" + + +namespace mongo::optimizer { +/** + * This class lowers projection paths (aka EvalPath) to simple expressions. + */ +class EvalPathLowering { +public: + EvalPathLowering(PrefixId& prefixId, VariableEnvironment& env) + : _prefixId(prefixId), _env(env) {} + + // The default noop transport. + template <typename T, typename... Ts> + void transport(ABT&, const T&, Ts&&...) 
{ + static_assert(!std::is_base_of_v<PathSyntaxSort, T>, + "Path elements must define their transport"); + } + + void transport(ABT& n, const PathConstant&, ABT& c); + void transport(ABT& n, const PathIdentity&); + void transport(ABT& n, const PathLambda&, ABT& lam); + void transport(ABT& n, const PathDefault&, ABT& c); + void transport(ABT& n, const PathCompare&, ABT& c); + void transport(ABT& n, const PathDrop&); + void transport(ABT& n, const PathKeep&); + void transport(ABT& n, const PathObj&); + void transport(ABT& n, const PathArr&); + + void transport(ABT& n, const PathTraverse&, ABT& inner); + + void transport(ABT& n, const PathGet&, ABT& inner); + void transport(ABT& n, const PathField&, ABT& inner); + + void transport(ABT& n, const PathComposeM&, ABT& p1, ABT& p2); + void transport(ABT& n, const PathComposeA&, ABT& p1, ABT& p2); + + void transport(ABT& n, const EvalPath&, ABT& path, ABT& input); + + // The tree is passed in as NON-const reference as we will be updating it. + // Returns true if the tree changed. + bool optimize(ABT& n); + +private: + // We don't own these. + PrefixId& _prefixId; + VariableEnvironment& _env; + + bool _changed{false}; +}; + +/** + * This class lowers match/filter paths (aka EvalFilter) to simple expressions. + */ +class EvalFilterLowering { +public: + EvalFilterLowering(PrefixId& prefixId, VariableEnvironment& env) + : _prefixId(prefixId), _env(env) {} + + // The default noop transport. + template <typename T, typename... Ts> + void transport(ABT&, const T&, Ts&&...) { + static_assert(!std::is_base_of_v<PathSyntaxSort, T>, + "Path elements must define their transport"); + } + + void transport(ABT& n, const PathConstant&, ABT& c); + void transport(ABT& n, const PathIdentity&); + void transport(ABT& n, const PathLambda&, ABT& lam); + void transport(ABT& n, const PathDefault&, ABT& c); + void transport(ABT& n, const PathCompare&, ABT& c); + void transport(ABT& n, const PathDrop&); + void transport(ABT& n, const PathKeep&); + void transport(ABT& n, const PathObj&); + void transport(ABT& n, const PathArr&); + + void prepare(ABT& n, const PathTraverse& t); + void transport(ABT& n, const PathTraverse&, ABT& inner); + + void transport(ABT& n, const PathGet&, ABT& inner); + void transport(ABT& n, const PathField&, ABT& inner); + + void transport(ABT& n, const PathComposeM&, ABT& p1, ABT& p2); + void transport(ABT& n, const PathComposeA&, ABT& p1, ABT& p2); + + void transport(ABT& n, const EvalFilter&, ABT& path, ABT& input); + + // The tree is passed in as NON-const reference as we will be updating it. + // Returns true if the tree changed. + bool optimize(ABT& n); + +private: + // We don't own these. + PrefixId& _prefixId; + VariableEnvironment& _env; + + std::vector<ABT::reference_type> _traverseStack; + + bool _changed{false}; +}; + +class PathLowering { +public: + PathLowering(PrefixId& prefixId, VariableEnvironment& env) + : _prefixId(prefixId), _env(env), _project(_prefixId, _env), _filter(_prefixId, _env) {} + + // The default noop transport. + template <typename T, typename... Ts> + void transport(ABT&, const T&, Ts&&...) {} + + void transport(ABT& n, const EvalPath&, ABT&, ABT&); + void transport(ABT& n, const EvalFilter&, ABT&, ABT&); + + bool optimize(ABT& n); + +private: + // We don't own these. 
+ PrefixId& _prefixId; + VariableEnvironment& _env; + + EvalPathLowering _project; + EvalFilterLowering _filter; + + bool _changed{false}; +}; +} // namespace mongo::optimizer diff --git a/src/mongo/db/query/optimizer/rewrites/path_optimizer_test.cpp b/src/mongo/db/query/optimizer/rewrites/path_optimizer_test.cpp new file mode 100644 index 00000000000..23f2c1e8899 --- /dev/null +++ b/src/mongo/db/query/optimizer/rewrites/path_optimizer_test.cpp @@ -0,0 +1,881 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include "mongo/db/query/optimizer/explain.h" +#include "mongo/db/query/optimizer/rewrites/const_eval.h" +#include "mongo/db/query/optimizer/rewrites/path.h" +#include "mongo/db/query/optimizer/rewrites/path_lower.h" +#include "mongo/db/query/optimizer/utils/unit_test_utils.h" +#include "mongo/unittest/unittest.h" + +namespace mongo::optimizer { +namespace { +TEST(Path, Const) { + auto tree = make<EvalPath>(make<PathConstant>(Constant::int64(2)), make<Variable>("ptest")); + auto env = VariableEnvironment::build(tree); + + auto fusor = PathFusion(env); + fusor.optimize(tree); + + // The result must be Constant. + auto result = tree.cast<Constant>(); + ASSERT(result != nullptr); + + // And the value must be 2 + ASSERT_EQ(result->getValueInt64(), 2); +} + +TEST(Path, GetConst) { + // Get "any" Const 2 + auto tree = make<EvalPath>(make<PathGet>("any", make<PathConstant>(Constant::int64(2))), + make<Variable>("ptest")); + auto env = VariableEnvironment::build(tree); + + auto fusor = PathFusion(env); + fusor.optimize(tree); + + // The result must be Constant. 
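+    // The PathConstant ignores its input, so fusion drops the enclosing PathGet and the whole
+    // EvalPath folds to the constant, regardless of whether field "any" exists on "ptest".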
+ auto result = tree.cast<Constant>(); + ASSERT(result != nullptr); + + // And the value must be 2 + ASSERT_EQ(result->getValueInt64(), 2); +} + +TEST(Path, Fuse1) { + // Field "a" Const 2 + auto field = make<EvalPath>(make<PathField>("a", make<PathConstant>(Constant::int64(2))), + make<Variable>("root")); + + // Get "a" Id + auto get = make<EvalPath>(make<PathGet>("a", make<PathIdentity>()), make<Variable>("x")); + + // let x = (Field "a" Const 2 | root) + // in (Get "a" Id | x) + auto tree = make<Let>("x", std::move(field), std::move(get)); + auto env = VariableEnvironment::build(tree); + + // Run rewriters while things change + bool changed = false; + do { + changed = false; + if (PathFusion{env}.optimize(tree)) { + changed = true; + } + + if (ConstEval{env}.optimize(tree)) { + changed = true; + } + + } while (changed); + + // The result must be Constant. + auto result = tree.cast<Constant>(); + ASSERT(result != nullptr); + + // And the value must be 2 + ASSERT_EQ(result->getValueInt64(), 2); +} + +TEST(Path, Fuse2) { + auto scanNode = make<ScanNode>("root", "test"); + + // Field "a" Const 2 + auto field = make<EvalPath>(make<PathField>("a", make<PathConstant>(Constant::int64(2))), + make<Variable>("root")); + + auto project1 = make<EvaluationNode>("x", std::move(field), std::move(scanNode)); + + // Get "a" Id + auto get = make<EvalPath>(make<PathGet>("a", make<PathIdentity>()), make<Variable>("x")); + auto project2 = make<EvaluationNode>("y", std::move(get), std::move(project1)); + + auto tree = make<RootNode>(properties::ProjectionRequirement{ProjectionNameVector{"y"}}, + std::move(project2)); + + auto env = VariableEnvironment::build(tree); + { + ProjectionNameSet projSet = env.topLevelProjections(); + ProjectionNameSet expSet = {"x", "y", "root"}; + ASSERT(expSet == projSet); + ASSERT(!env.hasFreeVariables()); + } + + // Run rewriters while things change + bool changed = false; + do { + changed = false; + if (PathFusion{env}.optimize(tree)) { + changed = true; + } + + if (ConstEval{env}.optimize(tree)) { + changed = true; + } + + } while (changed); + + // After rewrites for x projection disappear from the tree. 
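+ // Fusing Get "a" over Field "a" Const 2 binds "y" directly to the constant; "x" is then unreferenced and its EvaluationNode is removed by ConstEval.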
+ { + ProjectionNameSet projSet = env.topLevelProjections(); + ProjectionNameSet expSet = {"root", "y"}; + ASSERT(expSet == projSet); + ASSERT(!env.hasFreeVariables()); + } +} + +TEST(Path, Fuse3) { + auto scanNode = make<ScanNode>("root", "test"); + + auto project0 = make<EvaluationNode>( + "z", + make<EvalPath>(make<PathGet>("z", make<PathIdentity>()), make<Variable>("root")), + std::move(scanNode)); + + // Field "a" Const Var "z" + auto field = make<EvalPath>(make<PathField>("a", make<PathConstant>(make<Variable>("z"))), + make<Variable>("root")); + auto project1 = make<EvaluationNode>("x", std::move(field), std::move(project0)); + + // Get "a" Traverse Const 2 + auto get = make<EvalPath>( + make<PathGet>("a", make<PathTraverse>(make<PathConstant>(Constant::int64(2)))), + make<Variable>("x")); + auto project2 = make<EvaluationNode>("y", std::move(get), std::move(project1)); + + auto tree = make<RootNode>(properties::ProjectionRequirement{ProjectionNameVector{"y"}}, + std::move(project2)); + ASSERT_EXPLAIN( + "Root []\n" + " projections: \n" + " y\n" + " RefBlock: \n" + " Variable [y]\n" + " Evaluation []\n" + " BindBlock:\n" + " [y]\n" + " EvalPath []\n" + " PathGet [a]\n" + " PathTraverse []\n" + " PathConstant []\n" + " Const [2]\n" + " Variable [x]\n" + " Evaluation []\n" + " BindBlock:\n" + " [x]\n" + " EvalPath []\n" + " PathField [a]\n" + " PathConstant []\n" + " Variable [z]\n" + " Variable [root]\n" + " Evaluation []\n" + " BindBlock:\n" + " [z]\n" + " EvalPath []\n" + " PathGet [z]\n" + " PathIdentity []\n" + " Variable [root]\n" + " Scan [test]\n" + " BindBlock:\n" + " [root]\n" + " Source []\n", + tree); + + auto env = VariableEnvironment::build(tree); + bool changed = false; + do { + changed = false; + if (PathFusion{env}.optimize(tree)) { + changed = true; + } + if (ConstEval{env}.optimize(tree)) { + changed = true; + } + } while (changed); + + ASSERT_EXPLAIN( + "Root []\n" + " projections: \n" + " y\n" + " RefBlock: \n" + " Variable [y]\n" + " Evaluation []\n" + " BindBlock:\n" + " [y]\n" + " EvalPath []\n" + " PathGet [z]\n" + " PathTraverse []\n" + " PathConstant []\n" + " Const [2]\n" + " Variable [root]\n" + " Scan [test]\n" + " BindBlock:\n" + " [root]\n" + " Source []\n", + tree); +} + +TEST(Path, Fuse4) { + auto scanNode = make<ScanNode>("root", "test"); + + auto project0 = make<EvaluationNode>( + "z", + make<EvalPath>(make<PathGet>("z", make<PathIdentity>()), make<Variable>("root")), + std::move(scanNode)); + + auto project01 = make<EvaluationNode>( + "z1", + make<EvalPath>(make<PathGet>("z1", make<PathIdentity>()), make<Variable>("root")), + std::move(project0)); + + auto project02 = make<EvaluationNode>( + "z2", + make<EvalPath>(make<PathGet>("z2", make<PathIdentity>()), make<Variable>("root")), + std::move(project01)); + + // Field "a" Const Var "z" * Field "b" Const Var "z1" + auto field = make<EvalPath>( + make<PathComposeM>( + make<PathField>("c", make<PathConstant>(make<Variable>("z2"))), + make<PathComposeM>(make<PathField>("a", make<PathConstant>(make<Variable>("z"))), + make<PathField>("b", make<PathConstant>(make<Variable>("z1"))))), + make<Variable>("root")); + auto project1 = make<EvaluationNode>("x", std::move(field), std::move(project02)); + + // Get "a" Traverse Const 2 + auto get = make<EvalPath>( + make<PathGet>("a", make<PathTraverse>(make<PathConstant>(Constant::int64(2)))), + make<Variable>("x")); + auto project2 = make<EvaluationNode>("y", std::move(get), std::move(project1)); + + auto tree = 
make<RootNode>(properties::ProjectionRequirement{ProjectionNameVector{"x", "y"}}, + std::move(project2)); + + ASSERT_EXPLAIN( + "Root []\n" + " projections: \n" + " x\n" + " y\n" + " RefBlock: \n" + " Variable [x]\n" + " Variable [y]\n" + " Evaluation []\n" + " BindBlock:\n" + " [y]\n" + " EvalPath []\n" + " PathGet [a]\n" + " PathTraverse []\n" + " PathConstant []\n" + " Const [2]\n" + " Variable [x]\n" + " Evaluation []\n" + " BindBlock:\n" + " [x]\n" + " EvalPath []\n" + " PathComposeM []\n" + " PathField [c]\n" + " PathConstant []\n" + " Variable [z2]\n" + " PathComposeM []\n" + " PathField [a]\n" + " PathConstant []\n" + " Variable [z]\n" + " PathField [b]\n" + " PathConstant []\n" + " Variable [z1]\n" + " Variable [root]\n" + " Evaluation []\n" + " BindBlock:\n" + " [z2]\n" + " EvalPath []\n" + " PathGet [z2]\n" + " PathIdentity []\n" + " Variable [root]\n" + " Evaluation []\n" + " BindBlock:\n" + " [z1]\n" + " EvalPath []\n" + " PathGet [z1]\n" + " PathIdentity []\n" + " Variable [root]\n" + " Evaluation []\n" + " BindBlock:\n" + " [z]\n" + " EvalPath []\n" + " PathGet [z]\n" + " PathIdentity []\n" + " Variable [root]\n" + " Scan [test]\n" + " BindBlock:\n" + " [root]\n" + " Source []\n", + tree); + + auto env = VariableEnvironment::build(tree); + bool changed = false; + do { + changed = false; + if (PathFusion{env}.optimize(tree)) { + changed = true; + } + if (ConstEval{env}.optimize(tree)) { + changed = true; + } + } while (changed); + + ASSERT_EXPLAIN( + "Root []\n" + " projections: \n" + " x\n" + " y\n" + " RefBlock: \n" + " Variable [x]\n" + " Variable [y]\n" + " Evaluation []\n" + " BindBlock:\n" + " [y]\n" + " EvalPath []\n" + " PathTraverse []\n" + " PathConstant []\n" + " Const [2]\n" + " Variable [z]\n" + " Evaluation []\n" + " BindBlock:\n" + " [x]\n" + " EvalPath []\n" + " PathComposeM []\n" + " PathField [c]\n" + " PathConstant []\n" + " EvalPath []\n" + " PathGet [z2]\n" + " PathIdentity []\n" + " Variable [root]\n" + " PathComposeM []\n" + " PathField [a]\n" + " PathConstant []\n" + " Variable [z]\n" + " PathField [b]\n" + " PathConstant []\n" + " EvalPath []\n" + " PathGet [z1]\n" + " PathIdentity []\n" + " Variable [root]\n" + " Variable [root]\n" + " Evaluation []\n" + " BindBlock:\n" + " [z]\n" + " EvalPath []\n" + " PathGet [z]\n" + " PathIdentity []\n" + " Variable [root]\n" + " Scan [test]\n" + " BindBlock:\n" + " [root]\n" + " Source []\n", + tree); +} + +TEST(Path, Fuse5) { + auto scanNode = make<ScanNode>("root", "test"); + + auto project = make<EvaluationNode>( + "x", + make<EvalPath>(make<PathKeep>(PathKeep::NameSet{"a", "b", "c"}), make<Variable>("root")), + std::move(scanNode)); + + // Get "a" Traverse Compare= 2 + auto filter = make<FilterNode>( + make<EvalFilter>( + make<PathGet>( + "a", make<PathTraverse>(make<PathCompare>(Operations::Eq, Constant::int64(2)))), + make<Variable>("x")), + std::move(project)); + + auto tree = make<RootNode>(properties::ProjectionRequirement{ProjectionNameVector{"x"}}, + std::move(filter)); + + ASSERT_EXPLAIN( + "Root []\n" + " projections: \n" + " x\n" + " RefBlock: \n" + " Variable [x]\n" + " Filter []\n" + " EvalFilter []\n" + " PathGet [a]\n" + " PathTraverse []\n" + " PathCompare [Eq]\n" + " Const [2]\n" + " Variable [x]\n" + " Evaluation []\n" + " BindBlock:\n" + " [x]\n" + " EvalPath []\n" + " PathKeep [a, b, c]\n" + " Variable [root]\n" + " Scan [test]\n" + " BindBlock:\n" + " [root]\n" + " Source []\n", + tree); + + auto env = VariableEnvironment::build(tree); + bool changed = false; + do { + changed = false; + 
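// Alternate path fusion and constant folding until neither rewrite changes the tree. +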
if (PathFusion{env}.optimize(tree)) { + changed = true; + } + if (ConstEval{env}.optimize(tree)) { + changed = true; + } + } while (changed); + + // The filter now refers directly to the root projection. + ASSERT_EXPLAIN( + "Root []\n" + " projections: \n" + " x\n" + " RefBlock: \n" + " Variable [x]\n" + " Filter []\n" + " EvalFilter []\n" + " PathGet [a]\n" + " PathTraverse []\n" + " PathCompare [Eq]\n" + " Const [2]\n" + " Variable [root]\n" + " Evaluation []\n" + " BindBlock:\n" + " [x]\n" + " EvalPath []\n" + " PathKeep [a, b, c]\n" + " Variable [root]\n" + " Scan [test]\n" + " BindBlock:\n" + " [root]\n" + " Source []\n", + tree); +} + +TEST(Path, Lower1) { + PrefixId prefixId; + + auto tree = make<EvalPath>(make<PathIdentity>(), make<Variable>("foo")); + auto env = VariableEnvironment::build(tree); + + // Run rewriters while things change + bool changed = false; + do { + changed = false; + if (PathLowering{prefixId, env}.optimize(tree)) { + changed = true; + } + if (ConstEval{env}.optimize(tree)) { + changed = true; + } + } while (changed); + + ASSERT(tree.is<Variable>()); + ASSERT_EQ(tree.cast<Variable>()->name(), "foo"); +} + +TEST(Path, Lower2) { + PrefixId prefixId; + + auto tree = make<EvalPath>(make<PathConstant>(Constant::int64(10)), make<Variable>("foo")); + auto env = VariableEnvironment::build(tree); + + // Run rewriters while things change + bool changed = false; + do { + changed = false; + if (PathLowering{prefixId, env}.optimize(tree)) { + changed = true; + } + if (ConstEval{env}.optimize(tree)) { + changed = true; + } + } while (changed); + + ASSERT(tree.is<Constant>()); + ASSERT_EQ(tree.cast<Constant>()->getValueInt64(), 10); +} + +TEST(Path, Lower3) { + PrefixId prefixId; + + auto tree = make<EvalPath>( + make<PathLambda>(make<LambdaAbstraction>( + "x", make<BinaryOp>(Operations::Add, make<Variable>("x"), Constant::int64(1)))), + Constant::int64(9)); + auto env = VariableEnvironment::build(tree); + + // Run rewriters while things change + bool changed = false; + do { + changed = false; + if (PathLowering{prefixId, env}.optimize(tree)) { + changed = true; + } + if (ConstEval{env}.optimize(tree)) { + changed = true; + } + } while (changed); + + ASSERT(tree.is<Constant>()); + ASSERT_EQ(tree.cast<Constant>()->getValueInt64(), 10); +} + +TEST(Path, Lower4) { + PrefixId prefixId; + + auto tree = make<EvalPath>( + make<PathGet>( + "fieldA", + make<PathGet>("fieldB", + /*make<PathIdentity>()*/ make<PathConstant>(Constant::int64(100)))), + make<Variable>("rootObj")); + auto env = VariableEnvironment::build(tree); + + // Run rewriters while things change + bool changed = false; + do { + changed = false; + if (PathLowering{prefixId, env}.optimize(tree)) { + changed = true; + } + if (ConstEval{env}.optimize(tree)) { + changed = true; + } + } while (changed); + + ASSERT(tree.is<Constant>()); + ASSERT_EQ(tree.cast<Constant>()->getValueInt64(), 100); +} + +TEST(Path, Lower5) { + PrefixId prefixId; + + auto tree = make<EvalPath>( + make<PathGet>( + "fieldA", + make<PathGet>( + "fieldB", + make<PathLambda>(make<LambdaAbstraction>( + "x", + make<BinaryOp>(Operations::Add, make<Variable>("x"), Constant::int64(1)))))), + make<Variable>("rootObj")); + auto env = VariableEnvironment::build(tree); + + // Run rewriters while things change + bool changed = false; + do { + changed = false; + if (PathLowering{prefixId, env}.optimize(tree)) { + changed = true; + } + if (ConstEval{env}.optimize(tree)) { + changed = true; + } + } while (changed); + + ASSERT_EXPLAIN( + "BinaryOp [Add]\n" + " 
FunctionCall [getField]\n" + " FunctionCall [getField]\n" + " Variable [rootObj]\n" + " Const [\"fieldA\"]\n" + " Const [\"fieldB\"]\n" + " Const [1]\n", + tree); +} + +TEST(Path, ProjElim1) { + PrefixId prefixId; + + auto scanNode = make<ScanNode>("root", "test"); + + auto expr1 = make<FunctionCall>("anyFunctionWillDo", makeSeq(make<Variable>("root"))); + auto project1 = make<EvaluationNode>("x", std::move(expr1), std::move(scanNode)); + + auto expr2 = make<Variable>("x"); + auto project2 = make<EvaluationNode>("y", std::move(expr2), std::move(project1)); + + auto tree = make<RootNode>(properties::ProjectionRequirement{ProjectionNameVector{"y"}}, + std::move(project2)); + + auto env = VariableEnvironment::build(tree); + + // Run rewriters while things change + bool changed = false; + do { + changed = false; + if (PathLowering{prefixId, env}.optimize(tree)) { + changed = true; + } + if (ConstEval{env}.optimize(tree)) { + changed = true; + } + } while (changed); + + ASSERT_EXPLAIN( + "Root []\n" + " projections: \n" + " y\n" + " RefBlock: \n" + " Variable [y]\n" + " Evaluation []\n" + " BindBlock:\n" + " [y]\n" + " FunctionCall [anyFunctionWillDo]\n" + " Variable [root]\n" + " Scan [test]\n" + " BindBlock:\n" + " [root]\n" + " Source []\n", + tree); +} + +TEST(Path, ProjElim2) { + PrefixId prefixId; + + auto scanNode = make<ScanNode>("root", "test"); + + auto expr1 = make<FunctionCall>("anyFunctionWillDo", makeSeq(make<Variable>("root"))); + auto project1 = make<EvaluationNode>("x", std::move(expr1), std::move(scanNode)); + + auto expr2 = make<Variable>("x"); + auto project2 = make<EvaluationNode>("y", std::move(expr2), std::move(project1)); + + auto tree = make<RootNode>(properties::ProjectionRequirement{{}}, std::move(project2)); + + auto env = VariableEnvironment::build(tree); + + // Run rewriters while things change + bool changed = false; + do { + changed = false; + if (PathLowering{prefixId, env}.optimize(tree)) { + changed = true; + } + if (ConstEval{env}.optimize(tree)) { + changed = true; + } + } while (changed); + + ASSERT_EXPLAIN( + "Root []\n" + " projections: \n" + " RefBlock: \n" + " Scan [test]\n" + " BindBlock:\n" + " [root]\n" + " Source []\n", + tree); +} + +TEST(Path, ProjElim3) { + auto node = make<ScanNode>("root", "test"); + std::string var = "root"; + for (int i = 0; i < 100; ++i) { + std::string newVar = "p" + std::to_string(i); + node = make<EvaluationNode>( + newVar, + // make<FunctionCall>("anyFunctionWillDo", makeSeq(make<Variable>(var))), + make<Variable>(var), + std::move(node)); + var = newVar; + } + + auto tree = make<RootNode>(properties::ProjectionRequirement{ProjectionNameVector{var}}, + std::move(node)); + + auto env = VariableEnvironment::build(tree); + + // Run rewriters while things change + bool changed = false; + do { + changed = false; + if (ConstEval{env}.optimize(tree)) { + changed = true; + } + } while (changed); + + ASSERT_EXPLAIN( + "Root []\n" + " projections: \n" + " p99\n" + " RefBlock: \n" + " Variable [p99]\n" + " Evaluation []\n" + " BindBlock:\n" + " [p99]\n" + " Variable [root]\n" + " Scan [test]\n" + " BindBlock:\n" + " [root]\n" + " Source []\n", + tree); +} + +TEST(Path, Lower6) { + PrefixId prefixId; + + auto tree = make<EvalPath>( + make<PathGet>("fieldA", make<PathGet>("fieldB", make<PathDefault>(Constant::int64(0)))), + make<Variable>("rootObj")); + auto env = VariableEnvironment::build(tree); + + // Run rewriters while things change + bool changed = false; + do { + changed = false; + if (PathLowering{prefixId, 
env}.optimize(tree)) { + changed = true; + } + if (ConstEval{env}.optimize(tree)) { + changed = true; + } + } while (changed); + + ASSERT_EXPLAIN( + "Let [valDefault_0]\n" + " FunctionCall [getField]\n" + " FunctionCall [getField]\n" + " Variable [rootObj]\n" + " Const [\"fieldA\"]\n" + " Const [\"fieldB\"]\n" + " If []\n" + " FunctionCall [exists]\n" + " Variable [valDefault_0]\n" + " Variable [valDefault_0]\n" + " Const [0]\n", + tree); +} + +TEST(Path, Lower7) { + PrefixId prefixId; + + auto tree = make<EvalPath>(make<PathGet>("fieldA", + make<PathTraverse>(make<PathGet>( + "fieldB", make<PathDefault>(Constant::int64(0))))), + make<Variable>("rootObj")); + auto env = VariableEnvironment::build(tree); + + // Run rewriters while things change + bool changed = false; + do { + changed = false; + if (PathLowering{prefixId, env}.optimize(tree)) { + changed = true; + } + if (ConstEval{env}.optimize(tree)) { + changed = true; + } + } while (changed); + + // Add some asserts on the shape of the tree or something. +} + +TEST(Path, Lower8) { + PrefixId prefixId; + + auto tree = make<EvalPath>( + make<PathComposeM>(make<PathIdentity>(), make<PathConstant>(Constant::int64(100))), + make<Variable>("rootObj")); + auto env = VariableEnvironment::build(tree); + + // Run rewriters while things change + bool changed = false; + do { + changed = false; + if (PathLowering{prefixId, env}.optimize(tree)) { + changed = true; + } + if (ConstEval{env}.optimize(tree)) { + changed = true; + } + } while (changed); + + ASSERT(tree.is<Constant>()); + ASSERT_EQ(tree.cast<Constant>()->getValueInt64(), 100); +} + +TEST(Path, Lower9) { + PrefixId prefixId; + + auto tree = make<EvalPath>(make<PathComposeM>(make<PathGet>("fieldA", make<PathIdentity>()), + make<PathConstant>(Constant::int64(100))), + make<Variable>("rootObj")); + auto env = VariableEnvironment::build(tree); + + // Run rewriters while things change + bool changed = false; + do { + changed = false; + if (PathLowering{prefixId, env}.optimize(tree)) { + changed = true; + } + if (ConstEval{env}.optimize(tree)) { + changed = true; + } + } while (changed); + + ASSERT(tree.is<Constant>()); + ASSERT_EQ(tree.cast<Constant>()->getValueInt64(), 100); +} + +TEST(Path, Lower10) { + PrefixId prefixId; + + auto tree = make<EvalPath>( + make<PathField>( + "fieldA", + make<PathTraverse>(make<PathField>("fieldB", make<PathDefault>(Constant::int64(0))))), + make<Variable>("rootObj")); + auto env = VariableEnvironment::build(tree); + + // Run rewriters while things change + bool changed = false; + do { + changed = false; + if (PathLowering{prefixId, env}.optimize(tree)) { + changed = true; + } + if (ConstEval{env}.optimize(tree)) { + changed = true; + } + } while (changed); + + // Add some asserts on the shape of the tree or something. +} + +} // namespace +} // namespace mongo::optimizer diff --git a/src/mongo/db/query/optimizer/syntax/expr.cpp b/src/mongo/db/query/optimizer/syntax/expr.cpp new file mode 100644 index 00000000000..1dcd85de61a --- /dev/null +++ b/src/mongo/db/query/optimizer/syntax/expr.cpp @@ -0,0 +1,128 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include "mongo/db/query/optimizer/syntax/expr.h" +#include "mongo/db/query/optimizer/node.h" + +namespace mongo::optimizer { + +using namespace sbe::value; + +Constant::Constant(TypeTags tag, sbe::value::Value val) : _tag(tag), _val(val) {} + +Constant::Constant(const Constant& other) { + auto [tag, val] = copyValue(other._tag, other._val); + _tag = tag; + _val = val; +} + +Constant::Constant(Constant&& other) noexcept { + _tag = other._tag; + _val = other._val; + + other._tag = TypeTags::Nothing; + other._val = 0; +} + +ABT Constant::str(std::string str) { + // Views are non-owning so we have to make a copy. + auto [tag, val] = makeNewString(str); + return make<Constant>(tag, val); +} + +ABT Constant::int32(int32_t valueInt32) { + return make<Constant>(TypeTags::NumberInt32, bitcastFrom<int32_t>(valueInt32)); +} + +ABT Constant::int64(int64_t valueInt64) { + return make<Constant>(TypeTags::NumberInt64, bitcastFrom<int64_t>(valueInt64)); +} + +ABT Constant::fromDouble(double value) { + return make<Constant>(TypeTags::NumberDouble, bitcastFrom<double>(value)); +} + +ABT Constant::emptyObject() { + auto [tag, val] = makeNewObject(); + return make<Constant>(tag, val); +} + +ABT Constant::emptyArray() { + auto [tag, val] = makeNewArray(); + return make<Constant>(tag, val); +} + +ABT Constant::nothing() { + return make<Constant>(TypeTags::Nothing, 0); +} + +ABT Constant::null() { + return make<Constant>(TypeTags::Null, 0); +} + +ABT Constant::boolean(bool b) { + return make<Constant>(TypeTags::Boolean, bitcastFrom<bool>(b)); +} + +ABT Constant::minKey() { + return make<Constant>(TypeTags::MinKey, 0); +} + +ABT Constant::maxKey() { + return make<Constant>(TypeTags::MaxKey, 0); +} + +Constant::~Constant() { + releaseValue(_tag, _val); +} + +bool Constant::operator==(const Constant& other) const { + const auto [compareTag, compareVal] = compareValue(_tag, _val, other._tag, other._val); + return sbe::value::bitcastTo<int32_t>(compareVal) == 0; +} + +bool Constant::isValueInt64() const { + return _tag == TypeTags::NumberInt64; +} + +int64_t Constant::getValueInt64() const { + uassert(6624057, "Constant value type is not int64_t", isValueInt64()); + return bitcastTo<int64_t>(_val); +} + +bool Constant::isValueInt32() const { + return _tag == TypeTags::NumberInt32; +} + +int32_t Constant::getValueInt32() const { + uassert(6624354, "Constant value type is not int32_t", isValueInt32()); + return bitcastTo<int32_t>(_val); +} + +} // namespace mongo::optimizer diff --git 
a/src/mongo/db/query/optimizer/syntax/expr.h b/src/mongo/db/query/optimizer/syntax/expr.h new file mode 100644 index 00000000000..a795625d59c --- /dev/null +++ b/src/mongo/db/query/optimizer/syntax/expr.h @@ -0,0 +1,350 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#pragma once + +#include <ostream> + +#include "mongo/db/exec/sbe/values/value.h" +#include "mongo/db/query/optimizer/syntax/syntax.h" + +namespace mongo::optimizer { + +class Constant final : public Operator<Constant, 0>, public ExpressionSyntaxSort { +public: + Constant(sbe::value::TypeTags tag, sbe::value::Value val); + + static ABT str(std::string str); + + static ABT int32(int32_t valueInt32); + static ABT int64(int64_t valueInt64); + + static ABT fromDouble(double value); + + static ABT emptyObject(); + static ABT emptyArray(); + + static ABT nothing(); + static ABT null(); + + static ABT boolean(bool b); + + static ABT minKey(); + static ABT maxKey(); + + ~Constant(); + + Constant(const Constant& other); + Constant(Constant&& other) noexcept; + + bool operator==(const Constant& other) const; + + auto get() const { + return std::pair{_tag, _val}; + } + + bool isValueInt64() const; + int64_t getValueInt64() const; + + bool isValueInt32() const; + int32_t getValueInt32() const; + + bool isNumber() const { + return sbe::value::isNumber(_tag); + } + + bool isNothing() const { + return _tag == sbe::value::TypeTags::Nothing; + } + +private: + sbe::value::TypeTags _tag; + sbe::value::Value _val; +}; + +class Variable final : public Operator<Variable, 0>, public ExpressionSyntaxSort { + std::string _name; + +public: + Variable(std::string inName) : _name(std::move(inName)) {} + + bool operator==(const Variable& other) const { + return _name == other._name; + } + + auto& name() const { + return _name; + } +}; + +class UnaryOp final : public Operator<UnaryOp, 1>, public ExpressionSyntaxSort { + using Base = Operator<UnaryOp, 1>; + Operations _op; + +public: + UnaryOp(Operations inOp, ABT inExpr) : Base(std::move(inExpr)), _op(inOp) { + assertExprSort(getChild()); + } + + bool operator==(const UnaryOp& other) const { + return _op == 
other._op && getChild() == other.getChild(); + } + + auto op() const { + return _op; + } + + const ABT& getChild() const { + return get<0>(); + } +}; + +class BinaryOp final : public Operator<BinaryOp, 2>, public ExpressionSyntaxSort { + using Base = Operator<BinaryOp, 2>; + Operations _op; + +public: + BinaryOp(Operations inOp, ABT inLhs, ABT inRhs) + : Base(std::move(inLhs), std::move(inRhs)), _op(inOp) { + assertExprSort(getLeftChild()); + assertExprSort(getRightChild()); + } + + bool operator==(const BinaryOp& other) const { + return _op == other._op && getLeftChild() == other.getLeftChild() && + getRightChild() == other.getRightChild(); + } + + auto op() const { + return _op; + } + + const ABT& getLeftChild() const { + return get<0>(); + } + + const ABT& getRightChild() const { + return get<1>(); + } +}; + +class If final : public Operator<If, 3>, public ExpressionSyntaxSort { + using Base = Operator<If, 3>; + +public: + If(ABT inCond, ABT inThen, ABT inElse) + : Base(std::move(inCond), std::move(inThen), std::move(inElse)) { + assertExprSort(getCondChild()); + assertExprSort(getThenChild()); + assertExprSort(getElseChild()); + } + + bool operator==(const If& other) const { + return getCondChild() == other.getCondChild() && getThenChild() == other.getThenChild() && + getElseChild() == other.getElseChild(); + } + + const ABT& getCondChild() const { + return get<0>(); + } + + const ABT& getThenChild() const { + return get<1>(); + } + + const ABT& getElseChild() const { + return get<2>(); + } +}; + +class Let final : public Operator<Let, 2>, public ExpressionSyntaxSort { + using Base = Operator<Let, 2>; + + std::string _varName; + +public: + Let(std::string var, ABT inBind, ABT inExpr) + : Base(std::move(inBind), std::move(inExpr)), _varName(std::move(var)) { + assertExprSort(bind()); + assertExprSort(in()); + } + + bool operator==(const Let& other) const { + return _varName == other._varName && bind() == other.bind() && in() == other.in(); + } + + auto& varName() const { + return _varName; + } + + const ABT& bind() const { + return get<0>(); + } + + const ABT& in() const { + return get<1>(); + } +}; + +class LambdaAbstraction final : public Operator<LambdaAbstraction, 1>, public ExpressionSyntaxSort { + using Base = Operator<LambdaAbstraction, 1>; + + std::string _varName; + +public: + LambdaAbstraction(std::string var, ABT inBody) + : Base(std::move(inBody)), _varName(std::move(var)) { + assertExprSort(getBody()); + } + + bool operator==(const LambdaAbstraction& other) const { + return _varName == other._varName && getBody() == other.getBody(); + } + + auto& varName() const { + return _varName; + } + + const ABT& getBody() const { + return get<0>(); + } + + ABT& getBody() { + return get<0>(); + } +}; + +class LambdaApplication final : public Operator<LambdaApplication, 2>, public ExpressionSyntaxSort { + using Base = Operator<LambdaApplication, 2>; + +public: + LambdaApplication(ABT inLambda, ABT inArgument) + : Base(std::move(inLambda), std::move(inArgument)) { + assertExprSort(getLambda()); + assertExprSort(getArgument()); + } + + bool operator==(const LambdaApplication& other) const { + return getLambda() == other.getLambda() && getArgument() == other.getArgument(); + } + + const ABT& getLambda() const { + return get<0>(); + } + + const ABT& getArgument() const { + return get<1>(); + } +}; + +class FunctionCall final : public OperatorDynamicHomogenous<FunctionCall>, + public ExpressionSyntaxSort { + using Base = OperatorDynamicHomogenous<FunctionCall>; + std::string _name; + 
+public: + FunctionCall(std::string inName, ABTVector inArgs) + : Base(std::move(inArgs)), _name(std::move(inName)) { + for (auto& a : nodes()) { + assertExprSort(a); + } + } + + bool operator==(const FunctionCall& other) const { + return _name == other._name && nodes() == other.nodes(); + } + + auto& name() const { + return _name; + } +}; + +class EvalPath final : public Operator<EvalPath, 2>, public ExpressionSyntaxSort { + using Base = Operator<EvalPath, 2>; + +public: + EvalPath(ABT inPath, ABT inInput) : Base(std::move(inPath), std::move(inInput)) { + assertPathSort(getPath()); + assertExprSort(getInput()); + } + + bool operator==(const EvalPath& other) const { + return getPath() == other.getPath() && getInput() == other.getInput(); + } + + const ABT& getPath() const { + return get<0>(); + } + + ABT& getPath() { + return get<0>(); + } + + const ABT& getInput() const { + return get<1>(); + } +}; + +class EvalFilter final : public Operator<EvalFilter, 2>, public ExpressionSyntaxSort { + using Base = Operator<EvalFilter, 2>; + +public: + EvalFilter(ABT inPath, ABT inInput) : Base(std::move(inPath), std::move(inInput)) { + assertPathSort(getPath()); + assertExprSort(getInput()); + } + + bool operator==(const EvalFilter& other) const { + return getPath() == other.getPath() && getInput() == other.getInput(); + } + + const ABT& getPath() const { + return get<0>(); + } + + ABT& getPath() { + return get<0>(); + } + + const ABT& getInput() const { + return get<1>(); + } +}; + +/** + * This class represents a source of values originating from relational nodes. + */ +class Source final : public Operator<Source, 0>, public ExpressionSyntaxSort { +public: + bool operator==(const Source& other) const { + return true; + } +}; + +} // namespace mongo::optimizer diff --git a/src/mongo/db/query/optimizer/syntax/path.h b/src/mongo/db/query/optimizer/syntax/path.h new file mode 100644 index 00000000000..4f6ec8775b4 --- /dev/null +++ b/src/mongo/db/query/optimizer/syntax/path.h @@ -0,0 +1,349 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. 
+ */ + +#pragma once + +#include <ostream> +#include <unordered_set> + +#include "mongo/db/query/optimizer/syntax/syntax.h" + + +namespace mongo::optimizer { + +/** + * A constant path element - any input value is disregarded and replaced by the result of (constant) + * expression. + * + * It could also be expressed as lambda that ignores its input: \ _ . c + */ +class PathConstant final : public Operator<PathConstant, 1>, public PathSyntaxSort { + using Base = Operator<PathConstant, 1>; + +public: + PathConstant(ABT inConstant) : Base(std::move(inConstant)) { + assertExprSort(getConstant()); + } + + bool operator==(const PathConstant& other) const { + return getConstant() == other.getConstant(); + } + + const ABT& getConstant() const { + return get<0>(); + } + + ABT& getConstant() { + return get<0>(); + } +}; + +/** + * A lambda path element - the expression must be a single argument lambda. The lambda is applied + * with the input value. + */ +class PathLambda final : public Operator<PathLambda, 1>, public PathSyntaxSort { + using Base = Operator<PathLambda, 1>; + +public: + PathLambda(ABT inLambda) : Base(std::move(inLambda)) { + assertExprSort(getLambda()); + } + + bool operator==(const PathLambda& other) const { + return getLambda() == other.getLambda(); + } + + const ABT& getLambda() const { + return get<0>(); + } +}; + +/** + * An identity path element - the input is not disturbed at all. + * + * It could also be expressed as lambda : \ x . x + */ +class PathIdentity final : public Operator<PathIdentity, 0>, public PathSyntaxSort { +public: + bool operator==(const PathIdentity& other) const { + return true; + } +}; + +/** + * A default path element - If input is Nothing then return the result of expression (assumed to + * return non-Nothing) otherwise return the input undisturbed. + */ +class PathDefault final : public Operator<PathDefault, 1>, public PathSyntaxSort { + using Base = Operator<PathDefault, 1>; + +public: + PathDefault(ABT inDefault) : Base(std::move(inDefault)) { + assertExprSort(getDefault()); + } + + bool operator==(const PathDefault& other) const { + return getDefault() == other.getDefault(); + } + + const ABT& getDefault() const { + return get<0>(); + } +}; + +/** + * A comparison path element - the input value is compared to the result of (constant) expression. + * The actual semantics (return value) depends on what component is evaluating the paths (i.e. + * filter or project). + */ +class PathCompare : public Operator<PathCompare, 1>, public PathSyntaxSort { + using Base = Operator<PathCompare, 1>; + + Operations _cmp; + +public: + PathCompare(Operations inCmp, ABT inVal) : Base(std::move(inVal)), _cmp(inCmp) { + assertExprSort(getVal()); + } + + bool operator==(const PathCompare& other) const { + return _cmp == other._cmp && getVal() == other.getVal(); + } + + auto op() const { + return _cmp; + } + + const ABT& getVal() const { + return get<0>(); + } + + ABT& getVal() { + return get<0>(); + } +}; + +/** + * A drop path element - drops fields from the input if it is an object otherwise returns it + * undisturbed. 
+ */ +class PathDrop final : public Operator<PathDrop, 0>, public PathSyntaxSort { +public: + using NameSet = opt::unordered_set<std::string>; + + PathDrop(NameSet inNames) : _names(std::move(inNames)) {} + + bool operator==(const PathDrop& other) const { + return _names == other._names; + } + + const NameSet& getNames() const { + return _names; + } + +private: + const NameSet _names; +}; + +/** + * A keep path element - keeps fields from the input if it is an object otherwise returns it + * undisturbed. + */ +class PathKeep final : public Operator<PathKeep, 0>, public PathSyntaxSort { +public: + using NameSet = opt::unordered_set<std::string>; + + PathKeep(NameSet inNames) : _names(std::move(inNames)) {} + + bool operator==(const PathKeep other) const { + return _names == other._names; + } + + const NameSet& getNames() const { + return _names; + } + +private: + const NameSet _names; +}; + +/** + * Returns input undisturbed if it is an object otherwise return Nothing. + */ +class PathObj final : public Operator<PathObj, 0>, public PathSyntaxSort { +public: + bool operator==(const PathObj& other) const { + return true; + } +}; + +/** + * Returns input undisturbed if it is an array otherwise return Nothing. + */ +class PathArr final : public Operator<PathArr, 0>, public PathSyntaxSort { +public: + bool operator==(const PathArr& other) const { + return true; + } +}; + +/** + * A traverse path element - apply the inner path to every element of an array. + */ +class PathTraverse final : public Operator<PathTraverse, 1>, public PathSyntaxSort { + using Base = Operator<PathTraverse, 1>; + +public: + PathTraverse(ABT inPath) : Base(std::move(inPath)) { + assertPathSort(getPath()); + } + + bool operator==(const PathTraverse& other) const { + return getPath() == other.getPath(); + } + + const ABT& getPath() const { + return get<0>(); + } + + ABT& getPath() { + return get<0>(); + } +}; + +/** + * A field path element - apply the inner path to an object field. + */ +class PathField final : public Operator<PathField, 1>, public PathSyntaxSort { + using Base = Operator<PathField, 1>; + std::string _name; + +public: + PathField(std::string inName, ABT inPath) : Base(std::move(inPath)), _name(std::move(inName)) { + assertPathSort(getPath()); + } + + bool operator==(const PathField& other) const { + return _name == other._name && getPath() == other.getPath(); + } + + auto& name() const { + return _name; + } + + const ABT& getPath() const { + return get<0>(); + } + + ABT& getPath() { + return get<0>(); + } +}; + +/** + * A get path element - similar to the path element. + */ +class PathGet final : public Operator<PathGet, 1>, public PathSyntaxSort { + using Base = Operator<PathGet, 1>; + std::string _name; + +public: + PathGet(std::string inName, ABT inPath) : Base(std::move(inPath)), _name(std::move(inName)) { + assertPathSort(getPath()); + } + + bool operator==(const PathGet& other) const { + return _name == other._name && getPath() == other.getPath(); + } + + auto& name() const { + return _name; + } + + const ABT& getPath() const { + return get<0>(); + } + + ABT& getPath() { + return get<0>(); + } +}; + +/** + * A multiplicative composition path element. 
+ */ +class PathComposeM final : public Operator<PathComposeM, 2>, public PathSyntaxSort { + using Base = Operator<PathComposeM, 2>; + +public: + PathComposeM(ABT inPath1, ABT inPath2) : Base(std::move(inPath1), std::move(inPath2)) { + assertPathSort(getPath1()); + assertPathSort(getPath2()); + } + + bool operator==(const PathComposeM& other) const { + return getPath1() == other.getPath1() && getPath2() == other.getPath2(); + } + + const ABT& getPath1() const { + return get<0>(); + } + + const ABT& getPath2() const { + return get<1>(); + } +}; + +/** + * An additive composition path element. + */ +class PathComposeA final : public Operator<PathComposeA, 2>, public PathSyntaxSort { + using Base = Operator<PathComposeA, 2>; + +public: + PathComposeA(ABT inPath1, ABT inPath2) : Base(std::move(inPath1), std::move(inPath2)) { + assertPathSort(getPath1()); + assertPathSort(getPath2()); + } + + bool operator==(const PathComposeA& other) const { + return getPath1() == other.getPath1() && getPath2() == other.getPath2(); + } + + const ABT& getPath1() const { + return get<0>(); + } + + const ABT& getPath2() const { + return get<1>(); + } +}; + +} // namespace mongo::optimizer diff --git a/src/mongo/db/query/optimizer/syntax/syntax.h b/src/mongo/db/query/optimizer/syntax/syntax.h new file mode 100644 index 00000000000..cfdd70ce4dc --- /dev/null +++ b/src/mongo/db/query/optimizer/syntax/syntax.h @@ -0,0 +1,263 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. 
+ */ + +#pragma once + +#include <string> + +#include "mongo/db/query/optimizer/algebra/operator.h" +#include "mongo/db/query/optimizer/algebra/polyvalue.h" +#include "mongo/db/query/optimizer/syntax/syntax_fwd_declare.h" +#include "mongo/db/query/optimizer/utils/printable_enum.h" +#include "mongo/util/assert_util.h" + +namespace mongo::optimizer { + +using ABT = algebra::PolyValue<Blackhole, + Constant, // expressions + Variable, + UnaryOp, + BinaryOp, + If, + Let, + LambdaAbstraction, + LambdaApplication, + FunctionCall, + EvalPath, + EvalFilter, + Source, + PathConstant, // path elements + PathLambda, + PathIdentity, + PathDefault, + PathCompare, + PathDrop, + PathKeep, + PathObj, + PathArr, + PathTraverse, + PathField, + PathGet, + PathComposeM, + PathComposeA, + ScanNode, // nodes + PhysicalScanNode, + ValueScanNode, + CoScanNode, + IndexScanNode, + SeekNode, + MemoLogicalDelegatorNode, + MemoPhysicalDelegatorNode, + FilterNode, + EvaluationNode, + SargableNode, + RIDIntersectNode, + BinaryJoinNode, + HashJoinNode, + MergeJoinNode, + UnionNode, + GroupByNode, + UnwindNode, + UniqueNode, + CollationNode, + LimitSkipNode, + ExchangeNode, + RootNode, + References, // utilities + ExpressionBinder>; + +template <typename Derived, size_t Arity> +using Operator = algebra::OpSpecificArity<ABT, Derived, Arity>; + +template <typename Derived, size_t Arity> +using OperatorDynamic = algebra::OpSpecificDynamicArity<ABT, Derived, Arity>; + +template <typename Derived> +using OperatorDynamicHomogenous = OperatorDynamic<Derived, 0>; + +using ABTVector = std::vector<ABT>; + +template <typename T, typename... Args> +inline auto make(Args&&... args) { + return ABT::make<T>(std::forward<Args>(args)...); +} + +template <typename... Args> +inline auto makeSeq(Args&&... args) { + ABTVector seq; + (seq.emplace_back(std::forward<Args>(args)), ...); + return seq; +} + +class ExpressionSyntaxSort {}; + +class PathSyntaxSort {}; + +inline void assertExprSort(const ABT& e) { + if (!e.is<ExpressionSyntaxSort>()) { + uasserted(6624058, "expression syntax sort expected"); + } +} + +inline void assertPathSort(const ABT& e) { + if (!e.is<PathSyntaxSort>()) { + uasserted(6624059, "path syntax sort expected"); + } +} + +inline bool operator!=(const ABT& left, const ABT& right) { + return !(left == right); +} + +#define PATHSYNTAX_OPNAMES(F) \ + /* comparison operations */ \ + F(Eq) \ + F(Neq) \ + F(Gt) \ + F(Gte) \ + F(Lt) \ + F(Lte) \ + F(Cmp3w) \ + \ + /* binary operations */ \ + F(Add) \ + F(Sub) \ + F(Mult) \ + F(Div) \ + \ + /* unary operations */ \ + F(Neg) \ + \ + /* logical operations */ \ + F(And) \ + F(Or) \ + F(Not) + +MAKE_PRINTABLE_ENUM(Operations, PATHSYNTAX_OPNAMES); +MAKE_PRINTABLE_ENUM_STRING_ARRAY(OperationsEnum, Operations, PATHSYNTAX_OPNAMES); +#undef PATHSYNTAX_OPNAMES + +inline constexpr bool isUnaryOp(Operations op) { + return op == Operations::Neg || op == Operations::Not; +} + +inline constexpr bool isBinaryOp(Operations op) { + return !isUnaryOp(op); +} + +/** + * This is a special inert ABT node. It is used by rewriters to preserve structural properties of + * nodes during in-place rewriting. + */ +class Blackhole final : public Operator<Blackhole, 0> { +public: + bool operator==(const Blackhole& other) const { + return true; + } +}; + +/** + * This is a helper structure that represents Node internal references. Some relational nodes + * implicitly reference named projections from its children. + * + * Canonical examples are: GROUP BY "a", ORDER BY "b", etc. 
+ * + * We want to capture these references. The rule of ABTs says that the ONLY way to reference a named + * entity is through the Variable class. The uniformity of the approach makes life much easier for + * the optimizer developers. + * On the other hand using Variables everywhere makes writing code more verbose, hence this helper. + */ +class References final : public OperatorDynamicHomogenous<References> { + using Base = OperatorDynamicHomogenous<References>; + +public: + /* + * Construct Variable objects out of provided vector of strings. + */ + References(const std::vector<std::string>& names) : Base(ABTVector{}) { + // Construct actual Variable objects from names and make them the children of this object. + for (const auto& name : names) { + nodes().emplace_back(make<Variable>(name)); + } + } + + /* + * Alternatively, construct references out of provided ABTs. This may be useful when the + * internal references are more complex then a simple string. We may consider e.g. GROUP BY + * (a+b). + */ + References(ABTVector refs) : Base(std::move(refs)) { + for (const auto& node : nodes()) { + assertExprSort(node); + } + } + + bool operator==(const References& other) const { + return nodes() == other.nodes(); + } +}; + +/** + * This class represents a unified way of binding identifiers to expressions. Every ABT node that + * introduces a new identifier must use this binder (i.e. all relational nodes adding new + * projections and expression nodes adding new local variables). + */ +class ExpressionBinder : public OperatorDynamicHomogenous<ExpressionBinder> { + using Base = OperatorDynamicHomogenous<ExpressionBinder>; + std::vector<std::string> _names; + +public: + ExpressionBinder(std::string name, ABT expr) : Base(makeSeq(std::move(expr))) { + _names.emplace_back(std::move(name)); + for (const auto& node : nodes()) { + assertExprSort(node); + } + } + + ExpressionBinder(std::vector<std::string> names, ABTVector exprs) + : Base(std::move(exprs)), _names(std::move(names)) { + for (const auto& node : nodes()) { + assertExprSort(node); + } + } + + bool operator==(const ExpressionBinder& other) const { + return _names == other._names && exprs() == other.exprs(); + } + + const std::vector<std::string>& names() const { + return _names; + } + + const ABTVector& exprs() const { + return nodes(); + } +}; + +} // namespace mongo::optimizer diff --git a/src/mongo/db/query/optimizer/syntax/syntax_fwd_declare.h b/src/mongo/db/query/optimizer/syntax/syntax_fwd_declare.h new file mode 100644 index 00000000000..b21781959e4 --- /dev/null +++ b/src/mongo/db/query/optimizer/syntax/syntax_fwd_declare.h @@ -0,0 +1,101 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. 
+ * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#pragma once + +namespace mongo::optimizer { + +/** + * Expressions + */ +class Blackhole; +class Constant; +class Variable; +class UnaryOp; +class BinaryOp; +class If; +class Let; +class LambdaAbstraction; +class LambdaApplication; +class FunctionCall; +class EvalPath; +class EvalFilter; +class Source; + +/** + * Path elements + */ +class PathConstant; +class PathLambda; +class PathIdentity; +class PathDefault; +class PathCompare; +class PathDrop; +class PathKeep; +class PathObj; +class PathArr; +class PathTraverse; +class PathField; +class PathGet; +class PathComposeM; +class PathComposeA; + +/** + * Nodes + */ +class ScanNode; +class PhysicalScanNode; +class ValueScanNode; +class CoScanNode; +class IndexScanNode; +class SeekNode; +class MemoLogicalDelegatorNode; +class MemoPhysicalDelegatorNode; +class FilterNode; +class EvaluationNode; +class SargableNode; +class RIDIntersectNode; +class BinaryJoinNode; +class HashJoinNode; +class MergeJoinNode; +class UnionNode; +class GroupByNode; +class UnwindNode; +class UniqueNode; +class CollationNode; +class LimitSkipNode; +class ExchangeNode; +class RootNode; + +/** + * Utility + */ +class References; +class ExpressionBinder; +} // namespace mongo::optimizer diff --git a/src/mongo/db/query/optimizer/utils/abt_hash.cpp b/src/mongo/db/query/optimizer/utils/abt_hash.cpp new file mode 100644 index 00000000000..9edb830f893 --- /dev/null +++ b/src/mongo/db/query/optimizer/utils/abt_hash.cpp @@ -0,0 +1,473 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. 
If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include "mongo/db/query/optimizer/utils/abt_hash.h" + +#include "mongo/db/query/optimizer/node.h" +#include "mongo/db/query/optimizer/utils/utils.h" + +namespace mongo::optimizer { + +static size_t computeCollationHash(const properties::CollationRequirement& prop) { + size_t collationHash = 17; + for (const auto& entry : prop.getCollationSpec()) { + updateHash(collationHash, std::hash<ProjectionName>()(entry.first)); + updateHash(collationHash, std::hash<CollationOp>()(entry.second)); + } + return collationHash; +} + +static size_t computeLimitSkipHash(const properties::LimitSkipRequirement& prop) { + size_t limitSkipHash = 17; + updateHash(limitSkipHash, std::hash<int64_t>()(prop.getLimit())); + updateHash(limitSkipHash, std::hash<int64_t>()(prop.getSkip())); + return limitSkipHash; +} + +static size_t computePropertyProjectionsHash(const ProjectionNameVector& projections) { + size_t resultHash = 17; + for (const ProjectionName& projection : projections) { + updateHashUnordered(resultHash, std::hash<ProjectionName>()(projection)); + } + return resultHash; +} + +static size_t computeProjectionRequirementHash(const properties::ProjectionRequirement& prop) { + return computePropertyProjectionsHash(prop.getProjections().getVector()); +} + +static size_t computeDistributionHash(const properties::DistributionRequirement& prop) { + size_t resultHash = 17; + const auto& distribAndProjections = prop.getDistributionAndProjections(); + updateHash(resultHash, std::hash<DistributionType>()(distribAndProjections._type)); + updateHash(resultHash, computePropertyProjectionsHash(distribAndProjections._projectionNames)); + return resultHash; +} + +static void updateBoundHash(size_t& result, const BoundRequirement& bound) { + updateHash(result, std::hash<bool>()(bound.isInclusive())); + if (!bound.isInfinite()) { + updateHash(result, ABTHashGenerator::generate(bound.getBound())); + } +}; + +template <class T> +class IntervalHasher { +public: + size_t computeHash(const IntervalRequirement& req) { + size_t result = 17; + updateBoundHash(result, req.getLowBound()); + updateBoundHash(result, req.getHighBound()); + return 17; + } + + size_t computeHash(const MultiKeyIntervalRequirement& req) { + size_t result = 19; + for (const auto& interval : req) { + updateHash(result, computeHash(interval)); + } + return result; + } + + size_t transport(const typename T::Atom& node) { + return computeHash(node.getExpr()); + } + + size_t transport(const typename T::Conjunction& node, const std::vector<size_t> childResults) { + size_t result = 31; + for (const size_t childResult : childResults) { + updateHash(result, childResult); + } + return result; + } + + size_t transport(const typename T::Disjunction& node, const std::vector<size_t> childResults) { + size_t result = 29; + for (const size_t childResult : childResults) { + updateHash(result, childResult); + } + return result; + } + + size_t compute(const typename T::Node& intervals) { + return algebra::transport<false>(intervals, *this); + } +}; + +static size_t computePartialSchemaReqHash(const PartialSchemaRequirements& reqMap) { + size_t result = 17; + + IntervalHasher<IntervalReqExpr> intervalHasher; + for (const auto& [key, req] : reqMap) { + updateHash(result, std::hash<ProjectionName>()(key._projectionName)); + updateHash(result, 
ABTHashGenerator::generate(key._path)); + updateHash(result, std::hash<ProjectionName>()(req.getBoundProjectionName())); + updateHash(result, intervalHasher.compute(req.getIntervals())); + } + return result; +} + +static size_t computeCandidateIndexMapHash(const CandidateIndexMap& map) { + size_t result = 17; + + IntervalHasher<MultiKeyIntervalReqExpr> intervalHasher; + for (const auto& [indexDefName, candidateIndexEntry] : map) { + updateHash(result, std::hash<std::string>()(indexDefName)); + + { + const auto& fieldProjectionMap = candidateIndexEntry._fieldProjectionMap; + updateHash(result, std::hash<ProjectionName>()(fieldProjectionMap._ridProjection)); + updateHash(result, std::hash<ProjectionName>()(fieldProjectionMap._rootProjection)); + for (const auto& fieldProjectionMapEntry : fieldProjectionMap._fieldProjections) { + updateHash(result, std::hash<FieldNameType>()(fieldProjectionMapEntry.first)); + updateHash(result, std::hash<ProjectionName>()(fieldProjectionMapEntry.second)); + } + } + updateHash(result, intervalHasher.compute(candidateIndexEntry._intervals)); + } + + return result; +} + +/** + * Hasher for ABT nodes. Used in conjunction with memo. + */ +class ABTHashTransporter { +public: + /** + * Nodes + */ + template <typename T, typename... Ts> + size_t transport(const T& /*node*/, Ts&&...) { + // Physical nodes do not currently need to implement hash. + static_assert(!canBeLogicalNode<T>(), "Logical node must implement its hash."); + uasserted(6624142, "must implement custom hash"); + } + + size_t transport(const References& references, std::vector<size_t> inResults) { + return computeHashSeq<1>(computeVectorHash(inResults)); + } + + size_t transport(const ExpressionBinder& binders, std::vector<size_t> inResults) { + return computeHashSeq<2>(computeVectorHash(binders.names()), computeVectorHash(inResults)); + } + + size_t transport(const ScanNode& node, size_t bindResult) { + return computeHashSeq<3>(std::hash<std::string>()(node.getScanDefName()), bindResult); + } + + size_t transport(const ValueScanNode& node, size_t bindResult) { + return computeHashSeq<46>(std::hash<size_t>()(node.getArraySize()), + ABTHashGenerator::generate(node.getValueArray()), + bindResult); + } + + size_t transport(const MemoLogicalDelegatorNode& node) { + return computeHashSeq<4>(std::hash<GroupIdType>()(node.getGroupId())); + } + + size_t transport(const FilterNode& node, size_t childResult, size_t filterResult) { + return computeHashSeq<5>(filterResult, childResult); + } + + size_t transport(const EvaluationNode& node, size_t childResult, size_t projectionResult) { + return computeHashSeq<6>(projectionResult, childResult); + } + + size_t transport(const SargableNode& node, + size_t childResult, + size_t /*bindResult*/, + size_t /*refResult*/) { + return computeHashSeq<44>(computePartialSchemaReqHash(node.getReqMap()), + computeCandidateIndexMapHash(node.getCandidateIndexMap()), + std::hash<IndexReqTarget>()(node.getTarget()), + childResult); + } + + size_t transport(const RIDIntersectNode& node, + size_t leftChildResult, + size_t rightChildResult) { + // Specifically always including children. + return computeHashSeq<45>(std::hash<ProjectionName>()(node.getScanProjectionName()), + std::hash<bool>()(node.hasLeftIntervals()), + std::hash<bool>()(node.hasRightIntervals()), + leftChildResult, + rightChildResult); + } + + size_t transport(const BinaryJoinNode& node, + size_t leftChildResult, + size_t rightChildResult, + size_t filterResult) { + // Specifically always including children. 
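+        // The seed constant passed to computeHashSeq (7 here) differs for every node and
+        // expression type handled by this transporter, which keeps different node kinds from
+        // colliding merely because their child results coincide.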
+ return computeHashSeq<7>(filterResult, leftChildResult, rightChildResult); + } + + size_t transport(const UnionNode& node, + std::vector<size_t> childResults, + size_t bindResult, + size_t refsResult) { + // Specifically always including children. + return computeHashSeq<9>(bindResult, refsResult, computeVectorHash(childResults)); + } + + size_t transport(const GroupByNode& node, + size_t childResult, + size_t bindAggResult, + size_t refsAggResult, + size_t bindGbResult, + size_t refsGbResult) { + return computeHashSeq<10>(bindAggResult, + refsAggResult, + bindGbResult, + refsGbResult, + std::hash<GroupNodeType>()(node.getType()), + childResult); + } + + size_t transport(const UnwindNode& node, + size_t childResult, + size_t bindResult, + size_t refsResult) { + return computeHashSeq<11>( + std::hash<bool>()(node.getRetainNonArrays()), bindResult, refsResult, childResult); + } + + size_t transport(const CollationNode& node, size_t childResult, size_t /*refsResult*/) { + return computeHashSeq<13>(computeCollationHash(node.getProperty()), childResult); + } + + size_t transport(const LimitSkipNode& node, size_t childResult) { + return computeHashSeq<14>(computeLimitSkipHash(node.getProperty()), childResult); + } + + size_t transport(const ExchangeNode& node, size_t childResult, size_t /*refsResult*/) { + return computeHashSeq<43>(computeDistributionHash(node.getProperty()), childResult); + } + + size_t transport(const RootNode& node, size_t childResult, size_t /*refsResult*/) { + return computeHashSeq<15>(computeProjectionRequirementHash(node.getProperty()), + childResult); + } + + /** + * Expressions + */ + size_t transport(const Blackhole& expr) { + return computeHashSeq<16>(); + } + + size_t transport(const Constant& expr) { + auto [tag, val] = expr.get(); + return computeHashSeq<17>(sbe::value::hashValue(tag, val)); + } + + size_t transport(const Variable& expr) { + return computeHashSeq<18>(std::hash<std::string>()(expr.name())); + } + + size_t transport(const UnaryOp& expr, size_t inResult) { + return computeHashSeq<19>(std::hash<Operations>()(expr.op()), inResult); + } + + size_t transport(const BinaryOp& expr, size_t leftResult, size_t rightResult) { + return computeHashSeq<20>(std::hash<Operations>()(expr.op()), leftResult, rightResult); + } + + size_t transport(const If& expr, size_t condResult, size_t thenResult, size_t elseResult) { + return computeHashSeq<21>(condResult, thenResult, elseResult); + } + + size_t transport(const Let& expr, size_t bindResult, size_t exprResult) { + return computeHashSeq<22>(std::hash<std::string>()(expr.varName()), bindResult, exprResult); + } + + size_t transport(const LambdaAbstraction& expr, size_t inResult) { + return computeHashSeq<23>(std::hash<std::string>()(expr.varName()), inResult); + } + + size_t transport(const LambdaApplication& expr, size_t lambdaResult, size_t argumentResult) { + return computeHashSeq<24>(lambdaResult, argumentResult); + } + + size_t transport(const FunctionCall& expr, std::vector<size_t> argResults) { + return computeHashSeq<25>(std::hash<std::string>()(expr.name()), + computeVectorHash(argResults)); + } + + size_t transport(const EvalPath& expr, size_t pathResult, size_t inputResult) { + return computeHashSeq<26>(pathResult, inputResult); + } + + size_t transport(const EvalFilter& expr, size_t pathResult, size_t inputResult) { + return computeHashSeq<27>(pathResult, inputResult); + } + + size_t transport(const Source& expr) { + return computeHashSeq<28>(); + } + + /** + * Paths + */ + size_t transport(const 
PathConstant& path, size_t inResult) { + return computeHashSeq<29>(inResult); + } + + size_t transport(const PathLambda& path, size_t inResult) { + return computeHashSeq<30>(inResult); + } + + size_t transport(const PathIdentity& path) { + return computeHashSeq<31>(); + } + + size_t transport(const PathDefault& path, size_t inResult) { + return computeHashSeq<32>(inResult); + } + + size_t transport(const PathCompare& path, size_t valueResult) { + return computeHashSeq<33>(std::hash<Operations>()(path.op()), valueResult); + } + + size_t transport(const PathDrop& path) { + size_t namesHash = 17; + for (const std::string& name : path.getNames()) { + updateHash(namesHash, std::hash<std::string>()(name)); + } + return computeHashSeq<34>(namesHash); + } + + size_t transport(const PathKeep& path) { + size_t namesHash = 17; + for (const std::string& name : path.getNames()) { + updateHash(namesHash, std::hash<std::string>()(name)); + } + return computeHashSeq<35>(namesHash); + } + + size_t transport(const PathObj& path) { + return computeHashSeq<36>(); + } + + size_t transport(const PathArr& path) { + return computeHashSeq<37>(); + } + + size_t transport(const PathTraverse& path, size_t inResult) { + return computeHashSeq<38>(inResult); + } + + size_t transport(const PathField& path, size_t inResult) { + return computeHashSeq<39>(std::hash<std::string>()(path.name()), inResult); + } + + size_t transport(const PathGet& path, size_t inResult) { + return computeHashSeq<40>(std::hash<std::string>()(path.name()), inResult); + } + + size_t transport(const PathComposeM& path, size_t leftResult, size_t rightResult) { + return computeHashSeq<41>(leftResult, rightResult); + } + + size_t transport(const PathComposeA& path, size_t leftResult, size_t rightResult) { + return computeHashSeq<42>(leftResult, rightResult); + } + + size_t generate(const ABT& node) { + return algebra::transport<false>(node, *this); + } + + size_t generate(const ABT::reference_type& nodeRef) { + return algebra::transport<false>(nodeRef, *this); + } +}; + +size_t ABTHashGenerator::generate(const ABT& node) { + ABTHashTransporter gen; + return gen.generate(node); +} + +size_t ABTHashGenerator::generate(const ABT::reference_type& nodeRef) { + ABTHashTransporter gen; + return gen.generate(nodeRef); +} + +class PhysPropsHasher { +public: + size_t operator()(const properties::PhysProperty&, + const properties::CollationRequirement& prop) { + return computeHashSeq<1>(computeCollationHash(prop)); + } + + size_t operator()(const properties::PhysProperty&, + const properties::LimitSkipRequirement& prop) { + return computeHashSeq<2>(computeLimitSkipHash(prop)); + } + + size_t operator()(const properties::PhysProperty&, + const properties::ProjectionRequirement& prop) { + return computeHashSeq<3>(computeProjectionRequirementHash(prop)); + } + + size_t operator()(const properties::PhysProperty&, + const properties::DistributionRequirement& prop) { + return computeHashSeq<4>(computeDistributionHash(prop)); + } + + size_t operator()(const properties::PhysProperty&, + const properties::IndexingRequirement& prop) { + return computeHashSeq<5>(std::hash<IndexReqTarget>()(prop.getIndexReqTarget()), + std::hash<bool>()(prop.getNeedsRID()), + std::hash<bool>()(prop.getDedupRID())); + } + + size_t operator()(const properties::PhysProperty&, const properties::RepetitionEstimate& prop) { + return computeHashSeq<6>(std::hash<CEType>()(prop.getEstimate())); + } + + size_t operator()(const properties::PhysProperty&, const properties::LimitEstimate& prop) { + 
return computeHashSeq<7>(std::hash<CEType>()(prop.getEstimate())); + } + + static size_t computeHash(const properties::PhysProps& props) { + PhysPropsHasher visitor; + size_t result = 17; + for (const auto& prop : props) { + updateHashUnordered(result, prop.second.visit(visitor)); + } + return result; + } +}; + +size_t ABTHashGenerator::generateForPhysProps(const properties::PhysProps& props) { + return PhysPropsHasher::computeHash(props); +} + +} // namespace mongo::optimizer diff --git a/src/mongo/db/query/optimizer/utils/abt_hash.h b/src/mongo/db/query/optimizer/utils/abt_hash.h new file mode 100644 index 00000000000..1f097bed3a1 --- /dev/null +++ b/src/mongo/db/query/optimizer/utils/abt_hash.h @@ -0,0 +1,45 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#pragma once + +#include "mongo/db/query/optimizer/props.h" +#include "mongo/db/query/optimizer/syntax/syntax.h" + +namespace mongo::optimizer { + +class ABTHashGenerator { +public: + static size_t generate(const ABT& node); + static size_t generate(const ABT::reference_type& nodeRef); + + static size_t generateForPhysProps(const properties::PhysProps& props); +}; + +} // namespace mongo::optimizer diff --git a/src/mongo/db/query/optimizer/utils/interval_utils.cpp b/src/mongo/db/query/optimizer/utils/interval_utils.cpp new file mode 100644 index 00000000000..5a0f11e3299 --- /dev/null +++ b/src/mongo/db/query/optimizer/utils/interval_utils.cpp @@ -0,0 +1,330 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. 
+ * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include "mongo/db/query/optimizer/utils/interval_utils.h" + +#include "mongo/db/exec/sbe/values/value.h" +#include "mongo/db/query/optimizer/node.h" +#include "mongo/db/query/optimizer/rewrites/const_eval.h" + + +namespace mongo::optimizer { + +void combineIntervalsDNF(const bool intersect, + IntervalReqExpr::Node& target, + const IntervalReqExpr::Node& source) { + if (target == source) { + // Intervals are the same. Leave target unchanged. + return; + } + + if (isIntervalReqFullyOpenDNF(target)) { + // Intersecting with fully open interval is redundant. + // Unioning with fully open interval results in a fully-open interval. + if (intersect) { + target = source; + } + return; + } + + if (isIntervalReqFullyOpenDNF(source)) { + // Intersecting with fully open interval is redundant. + // Unioning with fully open interval results in a fully-open interval. + if (!intersect) { + target = source; + } + return; + } + + IntervalReqExpr::NodeVector newDisjunction; + // Integrate both compound bounds. + if (intersect) { + // Intersection is analogous to polynomial multiplication. Using '.' to denote intersection + // and '+' to denote union. (a.b + c.d) . (e+f) = a.b.e + c.d.e + a.b.f + c.d.f + // TODO: in certain cases we can simplify further. For example if we only have scalars, we + // can simplify (-inf, 10) ^ (5, +inf) to (5, 10), but this does not work with arrays. + + for (const auto& sourceConjunction : source.cast<IntervalReqExpr::Disjunction>()->nodes()) { + const auto& sourceConjunctionIntervals = + sourceConjunction.cast<IntervalReqExpr::Conjunction>()->nodes(); + for (const auto& targetConjunction : + target.cast<IntervalReqExpr::Disjunction>()->nodes()) { + // TODO: handle case with targetConjunct fully open + // TODO: handle case with targetConjunct half-open and sourceConjuct equality. + // TODO: handle case with both targetConjunct and sourceConjuct equalities + // (different consts). + + auto newConjunctionIntervals = + targetConjunction.cast<IntervalReqExpr::Conjunction>()->nodes(); + std::copy(sourceConjunctionIntervals.cbegin(), + sourceConjunctionIntervals.cend(), + std::back_inserter(newConjunctionIntervals)); + newDisjunction.emplace_back(IntervalReqExpr::make<IntervalReqExpr::Conjunction>( + std::move(newConjunctionIntervals))); + } + } + } else { + // Unioning is analogous to polynomial addition. 
+ // (a.b + c.d) + (e+f) = a.b + c.d + e + f + newDisjunction = target.cast<IntervalReqExpr::Disjunction>()->nodes(); + for (const auto& sourceConjunction : source.cast<IntervalReqExpr::Disjunction>()->nodes()) { + newDisjunction.push_back(sourceConjunction); + } + } + target = IntervalReqExpr::make<IntervalReqExpr::Disjunction>(std::move(newDisjunction)); +} + +std::vector<IntervalRequirement> intersectIntervals(const IntervalRequirement& i1, + const IntervalRequirement& i2) { + // Handle trivial cases of intersection. + if (i1.isFullyOpen()) { + return {i2}; + } + if (i2.isFullyOpen()) { + return {i1}; + } + + const ABT low1 = + i1.getLowBound().isInfinite() ? Constant::minKey() : i1.getLowBound().getBound(); + const ABT high1 = + i1.getHighBound().isInfinite() ? Constant::maxKey() : i1.getHighBound().getBound(); + const ABT low2 = + i2.getLowBound().isInfinite() ? Constant::minKey() : i2.getLowBound().getBound(); + const ABT high2 = + i2.getHighBound().isInfinite() ? Constant::maxKey() : i2.getHighBound().getBound(); + + const auto foldFn = [](ABT expr) { + // Performs constant folding. + VariableEnvironment env = VariableEnvironment::build(expr); + ConstEval instance(env); + instance.optimize(expr); + return expr; + }; + const auto minMaxFn = [](const Operations op, const ABT& v1, const ABT& v2) { + // Encodes max(v1, v2). + return make<If>(make<BinaryOp>(op, v1, v2), v1, v2); + }; + const auto minMaxFn1 = [](const Operations op, const ABT& v1, const ABT& v2, const ABT& v3) { + // Encodes v1 op v2 ? v3 : v2 + return make<If>(make<BinaryOp>(op, v1, v2), v3, v2); + }; + + // In the simplest case our bound is (max(low1, low2), min(high1, high2)) if none of the bounds + // are inclusive. + const ABT maxLow = foldFn(minMaxFn(Operations::Gte, low1, low2)); + const ABT minHigh = foldFn(minMaxFn(Operations::Lte, high1, high2)); + if (foldFn(make<BinaryOp>(Operations::Gt, maxLow, minHigh)) == Constant::boolean(true)) { + // Low bound is greater than high bound. + return {}; + } + + const bool low1Inc = i1.getLowBound().isInclusive(); + const bool high1Inc = i1.getHighBound().isInclusive(); + const bool low2Inc = i2.getLowBound().isInclusive(); + const bool high2Inc = i2.getHighBound().isInclusive(); + + // We form a "main" result interval which is closed on any side with "agreement" between the two + // intervals. For example [low1, high1] ^ [low2, high2) -> [max(low1, low2), min(high1, high2)) + BoundRequirement lowBoundMain = (maxLow == Constant::minKey()) + ? BoundRequirement::makeInfinite() + : BoundRequirement{low1Inc && low2Inc, maxLow}; + BoundRequirement highBoundMain = (minHigh == Constant::maxKey()) + ? BoundRequirement::makeInfinite() + : BoundRequirement{high1Inc && high2Inc, minHigh}; + + const bool boundsEqual = + foldFn(make<BinaryOp>(Operations::Eq, maxLow, minHigh)) == Constant::boolean(true); + if (boundsEqual) { + if (low1Inc && high1Inc && low2Inc && high2Inc) { + // Point interval. + return {{std::move(lowBoundMain), std::move(highBoundMain)}}; + } + if ((!low1Inc && !low2Inc) || (!high1Inc && !high2Inc)) { + // Fully open on both sides. + return {}; + } + } + if (low1Inc == low2Inc && high1Inc == high2Inc) { + // Inclusion matches on both sides. + return {{std::move(lowBoundMain), std::move(highBoundMain)}}; + } + + // At this point we have intervals without inclusion agreement, for example + // [low1, high1) ^ (low2, high2]. We have the main result which in this case is the open + // (max(low1, low2), min(high1, high2)). 
Then we add an extra closed interval for each side with + // disagreement. For example for the lower sides we add: [low2 >= low1 ? MaxKey : low1, + // min(max(low1, low2), min(high1, high2)] This is a closed interval which would reduce to + // [max(low1, low2), max(low1, low2)] if low1 < low2. If low2 >= low1 the interval reduces to an + // empty one [MaxKey, min(max(low1, low2), min(high1, high2)] which will return no results from + // an index scan. We do not know that in general if we do not have constants (we cannot fold). + // + // If we can fold the extra interval, we exploit the fact that (max(low1, low2), + // min(high1, high2)) U [max(low1, low2), max(low1, low2)] is [max(low1, low2), min(high1, + // high2)) (observe left side is now closed). Then we create a similar auxiliary interval for + // the right side if there is disagreement on the inclusion. Finally, we attempt to fold both + // intervals. Should we conclude definitively that they are point intervals, we update the + // inclusion of the main interval for the respective side. + + std::vector<IntervalRequirement> result; + const auto addAuxInterval = [&](ABT low, ABT high, BoundRequirement& bound) { + IntervalRequirement interval{{true, low}, {true, high}}; + + const ABT comparison = foldFn(make<BinaryOp>(Operations::Lte, low, high)); + if (comparison == Constant::boolean(true)) { + if (interval.isEquality()) { + // We can determine the two bounds are equal. + bound.setInclusive(true); + } else { + result.push_back(std::move(interval)); + } + } else if (!comparison.is<Constant>()) { + // We cannot determine statically how the two bounds compare. + result.push_back(std::move(interval)); + } + }; + + if (low1Inc != low2Inc) { + const ABT low = foldFn(minMaxFn1( + Operations::Gte, low1Inc ? low2 : low1, low1Inc ? low1 : low2, Constant::maxKey())); + const ABT high = foldFn(minMaxFn(Operations::Lte, maxLow, minHigh)); + addAuxInterval(std::move(low), std::move(high), lowBoundMain); + } + + if (high1Inc != high2Inc) { + const ABT low = foldFn(minMaxFn(Operations::Gte, maxLow, minHigh)); + const ABT high = foldFn(minMaxFn1(Operations::Lte, + high1Inc ? high2 : high1, + high1Inc ? high1 : high2, + Constant::minKey())); + addAuxInterval(std::move(low), std::move(high), highBoundMain); + } + + if (!boundsEqual || (lowBoundMain.isInclusive() && highBoundMain.isInclusive())) { + // We add the main interval to the result as long as it is a valid point interval, or the + // bounds are not equal. 
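+        // For example, intersecting [5, 10) with (3, 10]: the main interval starts out as the
+        // open (5, 10); the low-side auxiliary interval folds to the point [5, 5], which flips
+        // the main low bound to inclusive, while the high-side auxiliary folds to an empty
+        // interval and is dropped. The final result is the single interval [5, 10).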
+ result.emplace_back(std::move(lowBoundMain), std::move(highBoundMain)); + } + return result; +} + +boost::optional<IntervalReqExpr::Node> intersectDNFIntervals( + const IntervalReqExpr::Node& intervalDNF) { + IntervalReqExpr::NodeVector disjuncts; + + for (const auto& disjunct : intervalDNF.cast<IntervalReqExpr::Disjunction>()->nodes()) { + const auto& conjuncts = disjunct.cast<IntervalReqExpr::Conjunction>()->nodes(); + uassert(6624149, "Empty disjunct in interval DNF.", !conjuncts.empty()); + + std::vector<IntervalRequirement> intersectedIntervalDisjunction; + bool isEmpty = false; + bool isFirst = true; + + for (const auto& conjunct : conjuncts) { + const auto& interval = conjunct.cast<IntervalReqExpr::Atom>()->getExpr(); + if (isFirst) { + isFirst = false; + intersectedIntervalDisjunction = {interval}; + } else { + std::vector<IntervalRequirement> newResult; + for (const auto& intersectedInterval : intersectedIntervalDisjunction) { + auto intersectionResult = intersectIntervals(intersectedInterval, interval); + newResult.insert( + newResult.end(), intersectionResult.cbegin(), intersectionResult.cend()); + } + if (newResult.empty()) { + // The intersection is empty, there is no need to process the remaining + // conjuncts + isEmpty = true; + break; + } + std::swap(intersectedIntervalDisjunction, newResult); + } + } + if (isEmpty) { + continue; // The whole conjunct is false (empty interval), skip it. + } + + for (const auto& interval : intersectedIntervalDisjunction) { + auto conjunction = + IntervalReqExpr::make<IntervalReqExpr::Conjunction>(IntervalReqExpr::makeSeq( + IntervalReqExpr::make<IntervalReqExpr::Atom>(std::move(interval)))); + disjuncts.emplace_back(conjunction); + } + } + + if (disjuncts.empty()) { + return {}; + } + return IntervalReqExpr::make<IntervalReqExpr::Disjunction>(std::move(disjuncts)); +} + +bool combineMultiKeyIntervalsDNF(MultiKeyIntervalReqExpr::Node& targetIntervals, + const IntervalReqExpr::Node& sourceIntervals) { + MultiKeyIntervalReqExpr::NodeVector newDisjunction; + + for (const auto& sourceConjunction : + sourceIntervals.cast<IntervalReqExpr::Disjunction>()->nodes()) { + for (const auto& targetConjunction : + targetIntervals.cast<MultiKeyIntervalReqExpr::Disjunction>()->nodes()) { + MultiKeyIntervalReqExpr::NodeVector newConjunction; + + for (const auto& sourceConjunct : + sourceConjunction.cast<IntervalReqExpr::Conjunction>()->nodes()) { + const auto& sourceInterval = + sourceConjunct.cast<IntervalReqExpr::Atom>()->getExpr(); + for (const auto& targetConjunct : + targetConjunction.cast<MultiKeyIntervalReqExpr::Conjunction>()->nodes()) { + const auto& targetInterval = + targetConjunct.cast<MultiKeyIntervalReqExpr::Atom>()->getExpr(); + if (!targetInterval.empty() && !targetInterval.back().isEquality() && + !sourceInterval.isFullyOpen()) { + // We do not have an equality prefix. Reject. 
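+                        // For example, on a composite index over (a, b), a range such as
+                        // [3, 5] on "a" cannot be extended with a bounded interval on "b":
+                        // only a fully open trailing interval may follow a non-equality
+                        // component, whereas an equality such as [3, 3] on "a" can be
+                        // extended freely.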
+ return {}; + } + + auto newInterval = targetInterval; + newInterval.push_back(sourceInterval); + newConjunction.emplace_back( + MultiKeyIntervalReqExpr::make<MultiKeyIntervalReqExpr::Atom>( + std::move(newInterval))); + } + } + + newDisjunction.emplace_back( + MultiKeyIntervalReqExpr::make<MultiKeyIntervalReqExpr::Conjunction>( + std::move(newConjunction))); + } + } + + targetIntervals = MultiKeyIntervalReqExpr::make<MultiKeyIntervalReqExpr::Disjunction>( + std::move(newDisjunction)); + return true; +} + +} // namespace mongo::optimizer diff --git a/src/mongo/db/query/optimizer/utils/interval_utils.h b/src/mongo/db/query/optimizer/utils/interval_utils.h new file mode 100644 index 00000000000..4186e0aa276 --- /dev/null +++ b/src/mongo/db/query/optimizer/utils/interval_utils.h @@ -0,0 +1,68 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#pragma once + +#include "mongo/db/query/optimizer/index_bounds.h" + +namespace mongo::optimizer { + +/** + * Intersects or unions two intervals without simplification which might depend on multi-keyness. + * Currently assumes intervals are in DNF. + * TODO: handle generic interval expressions (not necessarily DNF). + */ +void combineIntervalsDNF(bool intersect, + IntervalReqExpr::Node& target, + const IntervalReqExpr::Node& source); + +/** + * Intersect all intervals within each conjunction of intervals in a disjunction of intervals. + * Notice that all intervals reference the same path (which is an index field). + * Return a DNF of the intersected intervals, where there is at most one interval inside each + * conjunct. If the resulting interval is empty, return boost::none. + * The intervals themselves can contain Constants, Variables, or arbitrary arithmetic expressions. + * TODO: handle generic interval expressions (not necessarily DNF). + */ +boost::optional<IntervalReqExpr::Node> intersectDNFIntervals( + const IntervalReqExpr::Node& intervalDNF); + +/** + * Combines a source interval over a single path with a target multi-component interval. The + * multi-component interval is extended to contain an extra field. 
The resulting multi-component + * interval defined the boundaries over the index component used by the index access execution + * operator. If we fail to combine, the target multi-key interval is left unchanged. + * Currently we only support a single "equality prefix": 0+ equalities followed by at most + * inequality, and trailing open intervals. + * TODO: support Recursive Index Navigation. + */ +bool combineMultiKeyIntervalsDNF(MultiKeyIntervalReqExpr::Node& targetIntervals, + const IntervalReqExpr::Node& sourceIntervals); + +} // namespace mongo::optimizer diff --git a/src/mongo/db/query/optimizer/utils/memo_utils.cpp b/src/mongo/db/query/optimizer/utils/memo_utils.cpp new file mode 100644 index 00000000000..9dfa3e595a2 --- /dev/null +++ b/src/mongo/db/query/optimizer/utils/memo_utils.cpp @@ -0,0 +1,176 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include "mongo/db/query/optimizer/utils/memo_utils.h" + +#include "mongo/db/query/optimizer/cascades/memo.h" + + +namespace mongo::optimizer { + +ABT wrapConstFilter(ABT node) { + return make<FilterNode>(Constant::boolean(true), std::move(node)); +} + +ABT unwrapConstFilter(ABT node) { + if (auto nodePtr = node.cast<FilterNode>(); + nodePtr != nullptr && nodePtr->getFilter() == Constant::boolean(true)) { + return nodePtr->getChild(); + } + return node; +} + +class MemoLatestPlanExtractor { +public: + explicit MemoLatestPlanExtractor(const cascades::Memo& memo) : _memo(memo) {} + + /** + * Logical delegator node. + */ + void transport(ABT& n, + const MemoLogicalDelegatorNode& node, + opt::unordered_set<GroupIdType>& visitedGroups) { + n = extractLatest(node.getGroupId(), visitedGroups); + } + + /** + * Other ABT types. + */ + template <typename T, typename... Ts> + void transport(ABT& /*n*/, + const T& /*node*/, + opt::unordered_set<GroupIdType>& visitedGroups, + Ts&&...) 
{ + // noop + } + + ABT extractLatest(const GroupIdType groupId, opt::unordered_set<GroupIdType>& visitedGroups) { + const cascades::Group& group = _memo.getGroup(groupId); + if (!visitedGroups.insert(groupId).second) { + const GroupIdType scanGroupId = + properties::getPropertyConst<properties::IndexingAvailability>( + group._logicalProperties) + .getScanGroupId(); + uassert( + 6624357, "Visited the same non-scan group more than once", groupId == scanGroupId); + } + + ABT rootNode = group._logicalNodes.getVector().back(); + algebra::transport<true>(rootNode, *this, visitedGroups); + return rootNode; + } + +private: + const cascades::Memo& _memo; +}; + +ABT extractLatestPlan(const cascades::Memo& memo, const GroupIdType rootGroupId) { + MemoLatestPlanExtractor extractor(memo); + opt::unordered_set<GroupIdType> visitedGroups; + return extractor.extractLatest(rootGroupId, visitedGroups); +} + +class MemoPhysicalPlanExtractor { +public: + explicit MemoPhysicalPlanExtractor(const cascades::Memo& memo, + const Metadata& metadata, + NodeToGroupPropsMap& nodeToGroupPropsMap) + : _memo(memo), + _metadata(metadata), + _nodeToGroupPropsMap(nodeToGroupPropsMap), + _planNodeId(0) {} + + /** + * Physical delegator node. + */ + void transport(ABT& n, const MemoPhysicalDelegatorNode& node, const MemoPhysicalNodeId /*id*/) { + n = extract(node.getNodeId()); + } + + /** + * Other ABT types. + */ + template <typename T, typename... Ts> + void transport(ABT& /*n*/, const T& node, MemoPhysicalNodeId id, Ts&&...) { + if constexpr (std::is_base_of_v<Node, T>) { + using namespace properties; + + const cascades::Group& group = _memo.getGroup(id._groupId); + const auto& physicalResult = group._physicalNodes.at(id._index); + const auto& nodeInfo = *physicalResult._nodeInfo; + + LogicalProps logicalProps = group._logicalProperties; + PhysProps physProps = physicalResult._physProps; + if (!_metadata.isParallelExecution()) { + // Do not display availability and requirement if under centralized setting. + removeProperty<DistributionAvailability>(logicalProps); + removeProperty<DistributionRequirement>(physProps); + } + + _nodeToGroupPropsMap.emplace(&node, + NodeProps{_planNodeId++, + id, + std::move(logicalProps), + std::move(physProps), + nodeInfo._cost, + nodeInfo._localCost, + nodeInfo._adjustedCE}); + } + } + + ABT extract(const MemoPhysicalNodeId nodeId) { + const auto& result = _memo.getGroup(nodeId._groupId)._physicalNodes.at(nodeId._index); + uassert(6624143, + "Physical delegator must be pointing to an optimized result.", + result._nodeInfo.has_value()); + ABT node = result._nodeInfo->_node; + + algebra::transport<true>(node, *this, nodeId); + return node; + } + +private: + // We don't own this. + const cascades::Memo& _memo; + const Metadata& _metadata; + NodeToGroupPropsMap& _nodeToGroupPropsMap; + + int32_t _planNodeId; +}; + +std::pair<ABT, NodeToGroupPropsMap> extractPhysicalPlan(const MemoPhysicalNodeId id, + const Metadata& metadata, + const cascades::Memo& memo) { + NodeToGroupPropsMap resultMap; + MemoPhysicalPlanExtractor extractor(memo, metadata, resultMap); + ABT resultNode = extractor.extract(id); + return {std::move(resultNode), std::move(resultMap)}; +} + +} // namespace mongo::optimizer diff --git a/src/mongo/db/query/optimizer/utils/memo_utils.h b/src/mongo/db/query/optimizer/utils/memo_utils.h new file mode 100644 index 00000000000..af28a2c3ef8 --- /dev/null +++ b/src/mongo/db/query/optimizer/utils/memo_utils.h @@ -0,0 +1,75 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. 
+ * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#pragma once + +#include "mongo/db/query/optimizer/cascades/memo.h" +#include "mongo/db/query/optimizer/metadata.h" +#include "mongo/db/query/optimizer/node_defs.h" +#include "mongo/db/query/optimizer/utils/utils.h" + + +namespace mongo::optimizer { + +ABT wrapConstFilter(ABT node); +ABT unwrapConstFilter(ABT node); + +template <class ToAddType, class ToRemoveType> +static void addRemoveProjectionsToProperties(properties::PhysProps& properties, + const ToAddType& toAdd, + const ToRemoveType& toRemove) { + ProjectionNameOrderPreservingSet& projections = + properties::getProperty<properties::ProjectionRequirement>(properties).getProjections(); + for (const auto& varName : toRemove) { + projections.erase(varName); + } + for (const auto& varName : toAdd) { + projections.emplace_back(varName); + } +} + +template <class ToAddType> +static void addProjectionsToProperties(properties::PhysProps& properties, const ToAddType& toAdd) { + addRemoveProjectionsToProperties(properties, toAdd, ToAddType{}); +} + +/** + * Extracts the "latest" logical plan. Starting from the root group, we follow the last logical + * nodes. + */ +ABT extractLatestPlan(const cascades::Memo& memo, GroupIdType rootGroupId); + +/** + * Extracts a complete physical plan by inlining references to MemoPhysicalPlanNode. + */ +std::pair<ABT, NodeToGroupPropsMap> extractPhysicalPlan(MemoPhysicalNodeId id, + const Metadata& metadata, + const cascades::Memo& memo); + +} // namespace mongo::optimizer diff --git a/src/mongo/db/query/optimizer/utils/printable_enum.h b/src/mongo/db/query/optimizer/utils/printable_enum.h new file mode 100644 index 00000000000..873187f5a25 --- /dev/null +++ b/src/mongo/db/query/optimizer/utils/printable_enum.h @@ -0,0 +1,48 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. 
+ * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#pragma once + +namespace mongo::optimizer { + +// Declares a helper macro to print an enmumerator in an enum. +#define PRINTABLE_ENUMERATOR(value) value, +// Declares a helper macro to print an entry in a string array for an enum. +#define PRINTABLE_ENUM_STRING(value) #value, + +// Makes an enum class with a name based on a LIST macro of enumerators. +#define MAKE_PRINTABLE_ENUM(name, LIST) enum class name { LIST(PRINTABLE_ENUMERATOR) }; + +// Makes an array of enum names with a name based on a LIST macro of enumerators. +#define MAKE_PRINTABLE_ENUM_STRING_ARRAY(nspace, name, LIST) \ + namespace nspace { \ + constexpr const char* toString[] = {LIST(PRINTABLE_ENUM_STRING)}; \ + } // namespace nspace + +} // namespace mongo::optimizer diff --git a/src/mongo/db/query/optimizer/utils/unit_test_utils.cpp b/src/mongo/db/query/optimizer/utils/unit_test_utils.cpp new file mode 100644 index 00000000000..2f5252e573f --- /dev/null +++ b/src/mongo/db/query/optimizer/utils/unit_test_utils.cpp @@ -0,0 +1,157 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. 
If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include "mongo/db/query/optimizer/utils/unit_test_utils.h" + +#include "mongo/db/pipeline/abt/abt_document_source_visitor.h" +#include "mongo/db/pipeline/expression_context_for_test.h" +#include "mongo/db/query/optimizer/explain.h" +#include "mongo/db/query/optimizer/metadata.h" +#include "mongo/db/query/optimizer/node.h" +#include "mongo/unittest/temp_dir.h" +#include "mongo/unittest/unittest.h" + +namespace mongo::optimizer { + +static constexpr bool kDebugAsserts = false; + +void maybePrintABT(const ABT& abt) { + // Always print using the supported versions to make sure we don't crash. + const std::string strV1 = ExplainGenerator::explain(abt); + const std::string strV2 = ExplainGenerator::explainV2(abt); + auto [tag, val] = ExplainGenerator::explainBSON(abt); + sbe::value::ValueGuard vg(tag, val); + + if constexpr (kDebugAsserts) { + std::cout << "V1: " << strV1 << "\n"; + std::cout << "V2: " << strV2 << "\n"; + std::cout << "BSON: " << ExplainGenerator::printBSON(tag, val) << "\n"; + } +} + +ABT makeIndexPath(FieldPathType fieldPath, bool isMultiKey) { + ABT result = make<PathIdentity>(); + + for (size_t i = fieldPath.size(); i-- > 0;) { + if (isMultiKey) { + result = make<PathTraverse>(std::move(result)); + } + result = make<PathGet>(std::move(fieldPath.at(i)), std::move(result)); + } + + return result; +} + +ABT makeIndexPath(FieldNameType fieldName) { + return makeIndexPath(FieldPathType{std::move(fieldName)}); +} + +ABT makeNonMultikeyIndexPath(FieldNameType fieldName) { + return makeIndexPath(FieldPathType{std::move(fieldName)}, false /*isMultiKey*/); +} + +IndexDefinition makeIndexDefinition(FieldNameType fieldName, CollationOp op, bool isMultiKey) { + IndexCollationSpec idxCollSpec{ + IndexCollationEntry((isMultiKey ? makeIndexPath(std::move(fieldName)) + : makeNonMultikeyIndexPath(std::move(fieldName))), + op)}; + return IndexDefinition{std::move(idxCollSpec), isMultiKey}; +} + +IndexDefinition makeCompositeIndexDefinition(std::vector<TestIndexField> indexFields, + bool isMultiKey) { + IndexCollationSpec idxCollSpec; + for (auto& idxField : indexFields) { + idxCollSpec.emplace_back((idxField.isMultiKey + ? makeIndexPath(std::move(idxField.fieldName)) + : makeNonMultikeyIndexPath(std::move(idxField.fieldName))), + idxField.op); + } + return IndexDefinition{std::move(idxCollSpec), isMultiKey}; +} + +std::unique_ptr<mongo::Pipeline, mongo::PipelineDeleter> parsePipeline( + const NamespaceString& nss, + const std::string& inputPipeline, + OperationContextNoop& opCtx, + const std::vector<ExpressionContext::ResolvedNamespace>& involvedNss) { + const BSONObj inputBson = fromjson("{pipeline: " + inputPipeline + "}"); + + std::vector<BSONObj> rawPipeline; + for (auto&& stageElem : inputBson["pipeline"].Array()) { + ASSERT_EQUALS(stageElem.type(), BSONType::Object); + rawPipeline.push_back(stageElem.embeddedObject()); + } + + AggregateCommandRequest request(nss, rawPipeline); + boost::intrusive_ptr<ExpressionContextForTest> ctx( + new ExpressionContextForTest(&opCtx, request)); + + // Setup the resolved namespaces for other involved collections. 
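+    // Stages that reference other collections (e.g. $lookup or $unionWith) resolve them
+    // through the ExpressionContext while the pipeline is parsed.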
+ for (const auto& resolvedNss : involvedNss) { + ctx->setResolvedNamespace(resolvedNss.ns, resolvedNss); + } + + unittest::TempDir tempDir("ABTPipelineTest"); + ctx->tempDir = tempDir.path(); + + return Pipeline::parse(request.getPipeline(), ctx); +} + +ABT translatePipeline(const Metadata& metadata, + const std::string& pipelineStr, + ProjectionName scanProjName, + std::string scanDefName, + PrefixId& prefixId, + const std::vector<ExpressionContext::ResolvedNamespace>& involvedNss) { + OperationContextNoop opCtx; + auto pipeline = + parsePipeline(NamespaceString("a." + scanDefName), pipelineStr, opCtx, involvedNss); + return translatePipelineToABT(metadata, + *pipeline.get(), + scanProjName, + make<ScanNode>(scanProjName, std::move(scanDefName)), + prefixId); +} + +ABT translatePipeline(Metadata& metadata, + const std::string& pipelineStr, + std::string scanDefName, + PrefixId& prefixId) { + return translatePipeline( + metadata, pipelineStr, prefixId.getNextId("scan"), scanDefName, prefixId); +} + +ABT translatePipeline(const std::string& pipelineStr, std::string scanDefName) { + PrefixId prefixId; + Metadata metadata{{}}; + return translatePipeline(metadata, pipelineStr, std::move(scanDefName), prefixId); +} + +} // namespace mongo::optimizer diff --git a/src/mongo/db/query/optimizer/utils/unit_test_utils.h b/src/mongo/db/query/optimizer/utils/unit_test_utils.h new file mode 100644 index 00000000000..402a81618d8 --- /dev/null +++ b/src/mongo/db/query/optimizer/utils/unit_test_utils.h @@ -0,0 +1,99 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. 
+ */ + +#pragma once + +#include "mongo/db/operation_context_noop.h" +#include "mongo/db/pipeline/expression_context_for_test.h" +#include "mongo/db/pipeline/pipeline.h" +#include "mongo/db/query/optimizer/defs.h" +#include "mongo/db/query/optimizer/metadata.h" +#include "mongo/db/query/optimizer/utils/utils.h" + +namespace mongo::optimizer { + +void maybePrintABT(const ABT& abt); + +#define ASSERT_EXPLAIN(expected, abt) \ + maybePrintABT(abt); \ + ASSERT_EQ(expected, ExplainGenerator::explain(abt)) + +#define ASSERT_EXPLAIN_V2(expected, abt) \ + maybePrintABT(abt); \ + ASSERT_EQ(expected, ExplainGenerator::explainV2(abt)) + +#define ASSERT_EXPLAIN_BSON(expected, abt) \ + maybePrintABT(abt); \ + ASSERT_EQ(expected, ExplainGenerator::explainBSON(abt)) + +#define ASSERT_EXPLAIN_PROPS_V2(expected, phaseManager) \ + ASSERT_EQ(expected, \ + ExplainGenerator::explainV2( \ + make<MemoPhysicalDelegatorNode>(phaseManager.getPhysicalNodeId()), \ + true /*displayPhysicalProperties*/, \ + &phaseManager.getMemo())) + +#define ASSERT_EXPLAIN_MEMO(expected, memo) ASSERT_EQ(expected, ExplainGenerator::explainMemo(memo)) + +#define ASSERT_BETWEEN(a, b, value) \ + ASSERT_LTE(a, value); \ + ASSERT_GTE(b, value); + +struct TestIndexField { + FieldNameType fieldName; + CollationOp op; + bool isMultiKey; +}; + +ABT makeIndexPath(FieldPathType fieldPath, bool isMultiKey = true); + +ABT makeIndexPath(FieldNameType fieldName); +ABT makeNonMultikeyIndexPath(FieldNameType fieldName); + +IndexDefinition makeIndexDefinition(FieldNameType fieldName, + CollationOp op, + bool isMultiKey = true); +IndexDefinition makeCompositeIndexDefinition(std::vector<TestIndexField> indexFields, + bool isMultiKey = true); + +ABT translatePipeline(const Metadata& metadata, + const std::string& pipelineStr, + ProjectionName scanProjName, + std::string scanDefName, + PrefixId& prefixId, + const std::vector<ExpressionContext::ResolvedNamespace>& involvedNss = {}); + +ABT translatePipeline(Metadata& metadata, + const std::string& pipelineStr, + std::string scanDefName, + PrefixId& prefixId); + +ABT translatePipeline(const std::string& pipelineStr, std::string scanDefName = "collection"); + +} // namespace mongo::optimizer diff --git a/src/mongo/db/query/optimizer/utils/utils.cpp b/src/mongo/db/query/optimizer/utils/utils.cpp new file mode 100644 index 00000000000..5bf7d93db29 --- /dev/null +++ b/src/mongo/db/query/optimizer/utils/utils.cpp @@ -0,0 +1,1583 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. 
If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include "mongo/db/query/optimizer/utils/utils.h" + +#include "boost/none.hpp" +#include "mongo/db/query/optimizer/metadata.h" +#include "mongo/db/query/optimizer/reference_tracker.h" +#include "mongo/db/query/optimizer/syntax/path.h" +#include "mongo/db/query/optimizer/utils/interval_utils.h" +#include "mongo/db/storage/storage_parameters_gen.h" +#include "mongo/util/assert_util.h" + +namespace mongo::optimizer { + +size_t roundUpToNextPow2(const size_t v, const size_t maxPower) { + if (v == 0) { + return 0; + } + + size_t result = 1; + for (size_t power = 0; result < v && power < maxPower; result <<= 1, power++) { + } + return result; +} + +std::vector<ABT::reference_type> collectComposed(const ABT& n) { + if (auto comp = n.cast<PathComposeM>(); comp) { + auto lhs = collectComposed(comp->getPath1()); + auto rhs = collectComposed(comp->getPath2()); + lhs.insert(lhs.end(), rhs.begin(), rhs.end()); + + return lhs; + } + + return {n.ref()}; +} + +FieldNameType getSimpleField(const ABT& node) { + const PathGet* pathGet = node.cast<PathGet>(); + return pathGet != nullptr ? pathGet->name() : ""; +} + +std::string PrefixId::getNextId(const std::string& key) { + std::ostringstream os; + os << key << "_" << _idCounterPerKey[key]++; + return os.str(); +} + +ProjectionNameOrderedSet convertToOrderedSet(ProjectionNameSet unordered) { + ProjectionNameOrderedSet ordered; + for (const ProjectionName& projection : unordered) { + ordered.emplace(projection); + } + return ordered; +} + +opt::unordered_set<FieldNameType> toUnorderedFieldNameSet(std::set<FieldNameType> set) { + opt::unordered_set<FieldNameType> result; + for (auto&& fieldName : set) { + result.emplace(std::move(fieldName)); + } + return result; +} + +void combineLimitSkipProperties(properties::LimitSkipRequirement& aboveProp, + const properties::LimitSkipRequirement& belowProp) { + using namespace properties; + + const int64_t newAbsLimit = std::min<int64_t>( + aboveProp.hasLimit() ? (belowProp.getSkip() + aboveProp.getAbsoluteLimit()) + : LimitSkipRequirement::kMaxVal, + std::max<int64_t>(0, + belowProp.hasLimit() + ? (belowProp.getAbsoluteLimit() - aboveProp.getSkip()) + : LimitSkipRequirement::kMaxVal)); + + const int64_t newLimit = (newAbsLimit == LimitSkipRequirement::kMaxVal) + ? LimitSkipRequirement::kMaxVal + : (newAbsLimit - belowProp.getSkip()); + const int64_t newSkip = (newLimit == 0) ? 0 : belowProp.getSkip(); + aboveProp = {newLimit, newSkip}; +} + +/** + * Used to track references originating from a set of properties. 
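+ * For example, a projection requirement contributes the projections it requires, and a
+ * collation requirement contributes the projections it orders by.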
+ */ +class PropertiesAffectedColumnsExtractor { +public: + template <class T> + void operator()(const properties::PhysProperty&, const T& prop) { + for (const ProjectionName& projection : prop.getAffectedProjectionNames()) { + _projections.insert(projection); + } + } + + static ProjectionNameSet extract(const properties::PhysProps& properties) { + PropertiesAffectedColumnsExtractor extractor; + for (const auto& entry : properties) { + entry.second.visit(extractor); + } + return extractor._projections; + } + +private: + ProjectionNameSet _projections; +}; + +ProjectionNameSet extractReferencedColumns(const properties::PhysProps& properties) { + return PropertiesAffectedColumnsExtractor::extract(properties); +} + +bool areMultiKeyIntervalsEqualities(const MultiKeyIntervalRequirement& intervals) { + for (const auto& interval : intervals) { + if (!interval.isEquality()) { + return false; + } + } + return true; +} + +CollationSplitResult splitCollationSpec(const ProjectionCollationSpec& collationSpec, + const ProjectionNameSet& leftProjections, + const ProjectionNameSet& rightProjections) { + bool leftSide = true; + ProjectionCollationSpec leftCollationSpec; + ProjectionCollationSpec rightCollationSpec; + + for (const auto& collationEntry : collationSpec) { + const ProjectionName& projectionName = collationEntry.first; + + if (leftProjections.count(projectionName) > 0) { + if (!leftSide) { + // Left and right projections must complement and form prefix and suffix. + return {}; + } + leftCollationSpec.push_back(collationEntry); + } else if (rightProjections.count(projectionName) > 0) { + if (leftSide) { + leftSide = false; + } + rightCollationSpec.push_back(collationEntry); + } else { + uasserted(6624146, + "Collation projection must appear in either the left or the right " + "child projections"); + return {}; + } + } + + return {true /*validSplit*/, std::move(leftCollationSpec), std::move(rightCollationSpec)}; +} +/** + * Helper class used to extract variable references from a node. + */ +class NodeVariableTracker { +public: + template <typename T, typename... Ts> + VariableNameSetType walk(const T&, Ts&&...) { + static_assert(!std::is_base_of_v<Node, T>, "Nodes must implement variable tracking"); + + // Default case: no variables. 
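+        // Non-node children (expressions, paths, binders) contribute nothing here; each node
+        // overload below pulls its variables out of the relevant expression or references child
+        // via extractFromABT.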
+ return {}; + } + + VariableNameSetType walk(const ScanNode& /*node*/, const ABT& /*binds*/) { + return {}; + } + + VariableNameSetType walk(const ValueScanNode& /*node*/, const ABT& /*binds*/) { + return {}; + } + + VariableNameSetType walk(const PhysicalScanNode& /*node*/, const ABT& /*binds*/) { + return {}; + } + + VariableNameSetType walk(const CoScanNode& /*node*/) { + return {}; + } + + VariableNameSetType walk(const IndexScanNode& /*node*/, const ABT& /*binds*/) { + return {}; + } + + VariableNameSetType walk(const SeekNode& /*node*/, const ABT& /*binds*/, const ABT& refs) { + return extractFromABT(refs); + } + + VariableNameSetType walk(const MemoLogicalDelegatorNode& /*node*/) { + return {}; + } + + VariableNameSetType walk(const MemoPhysicalDelegatorNode& /*node*/) { + return {}; + } + + VariableNameSetType walk(const FilterNode& /*node*/, const ABT& /*child*/, const ABT& expr) { + return extractFromABT(expr); + } + + VariableNameSetType walk(const EvaluationNode& /*node*/, + const ABT& /*child*/, + const ABT& expr) { + return extractFromABT(expr); + } + + VariableNameSetType walk(const SargableNode& /*node*/, + const ABT& /*child*/, + const ABT& /*binds*/, + const ABT& refs) { + return extractFromABT(refs); + } + + VariableNameSetType walk(const RIDIntersectNode& /*node*/, + const ABT& /*leftChild*/, + const ABT& /*rightChild*/) { + return {}; + } + + VariableNameSetType walk(const BinaryJoinNode& /*node*/, + const ABT& /*leftChild*/, + const ABT& /*rightChild*/, + const ABT& expr) { + return extractFromABT(expr); + } + + VariableNameSetType walk(const HashJoinNode& /*node*/, + const ABT& /*leftChild*/, + const ABT& /*rightChild*/, + const ABT& refs) { + return extractFromABT(refs); + } + + VariableNameSetType walk(const MergeJoinNode& /*node*/, + const ABT& /*leftChild*/, + const ABT& /*rightChild*/, + const ABT& refs) { + return extractFromABT(refs); + } + + VariableNameSetType walk(const UnionNode& /*node*/, + const ABTVector& /*children*/, + const ABT& /*binder*/, + const ABT& refs) { + return extractFromABT(refs); + } + + VariableNameSetType walk(const GroupByNode& /*node*/, + const ABT& /*child*/, + const ABT& /*aggBinder*/, + const ABT& aggRefs, + const ABT& /*groupbyBinder*/, + const ABT& groupbyRefs) { + VariableNameSetType result; + extractFromABT(result, aggRefs); + extractFromABT(result, groupbyRefs); + return result; + } + + VariableNameSetType walk(const UnwindNode& /*node*/, + const ABT& /*child*/, + const ABT& /*binds*/, + const ABT& refs) { + return extractFromABT(refs); + } + + VariableNameSetType walk(const UniqueNode& /*node*/, const ABT& /*child*/, const ABT& refs) { + return extractFromABT(refs); + } + + VariableNameSetType walk(const CollationNode& /*node*/, const ABT& /*child*/, const ABT& refs) { + return extractFromABT(refs); + } + + VariableNameSetType walk(const LimitSkipNode& /*node*/, const ABT& /*child*/) { + return {}; + } + + VariableNameSetType walk(const ExchangeNode& /*node*/, const ABT& /*child*/, const ABT& refs) { + return extractFromABT(refs); + } + + VariableNameSetType walk(const RootNode& /*node*/, const ABT& /*child*/, const ABT& refs) { + return extractFromABT(refs); + } + + static VariableNameSetType collect(const ABT& n) { + NodeVariableTracker tracker; + return algebra::walk<false>(n, tracker); + } + +private: + void extractFromABT(VariableNameSetType& vars, const ABT& v) { + const auto& result = VariableEnvironment::getVariables(v); + for (const Variable* var : result._variables) { + if (result._definedVars.count(var->name()) == 
0) { + // We are interested in either free variables, or variables defined on other nodes. + vars.insert(var->name()); + } + } + } + + VariableNameSetType extractFromABT(const ABT& v) { + VariableNameSetType result; + extractFromABT(result, v); + return result; + } +}; + +VariableNameSetType collectVariableReferences(const ABT& n) { + return NodeVariableTracker::collect(n); +} + +PartialSchemaReqConversion::PartialSchemaReqConversion() + : _success(false), + _bound(), + _reqMap(), + _hasIntersected(false), + _hasTraversed(false), + _hasEmptyInterval(false) {} + +PartialSchemaReqConversion::PartialSchemaReqConversion(PartialSchemaRequirements reqMap) + : _success(true), + _bound(), + _reqMap(std::move(reqMap)), + _hasIntersected(false), + _hasTraversed(false), + _hasEmptyInterval(false) {} + +PartialSchemaReqConversion::PartialSchemaReqConversion(ABT bound) + : _success(true), + _bound(std::move(bound)), + _reqMap(), + _hasIntersected(false), + _hasTraversed(false), + _hasEmptyInterval(false) {} + +/** + * Helper class that builds PartialSchemaRequirements property from an EvalFilter or an EvalPath. + */ +class PartialSchemaReqConverter { +public: + PartialSchemaReqConverter() = default; + + PartialSchemaReqConversion handleEvalPathAndEvalFilter(PartialSchemaReqConversion pathResult, + PartialSchemaReqConversion inputResult) { + if (!pathResult._success || !inputResult._success) { + return {}; + } + if (pathResult._bound.has_value() || !inputResult._bound.has_value() || + !inputResult._reqMap.empty()) { + return {}; + } + + if (auto boundPtr = inputResult._bound->cast<Variable>(); boundPtr != nullptr) { + const ProjectionName& boundVarName = boundPtr->name(); + PartialSchemaRequirements newMap; + + for (auto& [key, req] : pathResult._reqMap) { + if (!key._projectionName.empty()) { + return {}; + } + newMap.emplace(PartialSchemaKey{boundVarName, key._path}, std::move(req)); + } + + PartialSchemaReqConversion result{std::move(newMap)}; + result._hasEmptyInterval = pathResult._hasEmptyInterval; + return result; + } + + return {}; + } + + PartialSchemaReqConversion transport(const ABT& n, + const EvalPath& evalPath, + PartialSchemaReqConversion pathResult, + PartialSchemaReqConversion inputResult) { + return handleEvalPathAndEvalFilter(std::move(pathResult), std::move(inputResult)); + } + + PartialSchemaReqConversion transport(const ABT& n, + const EvalFilter& evalFilter, + PartialSchemaReqConversion pathResult, + PartialSchemaReqConversion inputResult) { + return handleEvalPathAndEvalFilter(std::move(pathResult), std::move(inputResult)); + } + + static PartialSchemaReqConversion handleComposition(const bool isMultiplicative, + PartialSchemaReqConversion leftResult, + PartialSchemaReqConversion rightResult) { + if (!leftResult._success || !rightResult._success) { + return {}; + } + if (leftResult._bound.has_value() || rightResult._bound.has_value()) { + return {}; + } + + auto& leftReq = leftResult._reqMap; + auto& rightReq = rightResult._reqMap; + if (isMultiplicative) { + { + ProjectionRenames projectionRenames; + if (!intersectPartialSchemaReq(leftReq, rightReq, projectionRenames)) { + return {}; + } + if (!projectionRenames.empty()) { + return {}; + } + } + + if (!leftResult._hasTraversed && !rightResult._hasTraversed) { + // Intersect intervals only if we have not traversed. E.g. (-inf, 90) ^ (70, +inf) + // becomes (70, 90). 
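+ // If any intersection below yields an unsatisfiable interval, record the contradiction
+ // in _hasEmptyInterval so the caller can act on it.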
+ for (auto& [key, req] : leftReq) { + auto newIntervals = intersectDNFIntervals(req.getIntervals()); + if (newIntervals) { + req.getIntervals() = std::move(newIntervals.get()); + } else { + leftResult._hasEmptyInterval = true; + break; + } + } + } else if (leftReq.size() > 1) { + // Reject if we have traversed, and composed requirements are more than one. + return {}; + } + + leftResult._hasIntersected = true; + return leftResult; + } + + // We can only perform additive composition (OR) if we have a single matching key on both + // sides. + if (leftReq.size() != 1 || rightReq.size() != 1) { + return {}; + } + + auto leftEntry = leftReq.begin(); + auto rightEntry = rightReq.begin(); + if (!(leftEntry->first == rightEntry->first)) { + return {}; + } + + combineIntervalsDNF(false /*intersect*/, + leftEntry->second.getIntervals(), + rightEntry->second.getIntervals()); + return leftResult; + } + + PartialSchemaReqConversion transport(const ABT& n, + const PathComposeM& pathComposeM, + PartialSchemaReqConversion leftResult, + PartialSchemaReqConversion rightResult) { + return handleComposition( + true /*isMultiplicative*/, std::move(leftResult), std::move(rightResult)); + } + + PartialSchemaReqConversion transport(const ABT& n, + const PathComposeA& pathComposeA, + PartialSchemaReqConversion leftResult, + PartialSchemaReqConversion rightResult) { + return handleComposition( + false /*isMultiplicative*/, std::move(leftResult), std::move(rightResult)); + } + + template <class T> + static PartialSchemaReqConversion handleGetAndTraverse(const ABT& n, + PartialSchemaReqConversion inputResult) { + if (!inputResult._success) { + return {}; + } + if (inputResult._bound.has_value()) { + return {}; + } + + // New map has keys with appended paths. + PartialSchemaRequirements newMap; + + for (auto& entry : inputResult._reqMap) { + const ProjectionName& projectionName = entry.first._projectionName; + if (!projectionName.empty()) { + return {}; + } + + ABT path = entry.first._path; + + // Updated key path to be now rooted at n, with existing key path as child. + ABT appendedPath = n; + std::swap(appendedPath.cast<T>()->getPath(), path); + std::swap(path, appendedPath); + + newMap.emplace(PartialSchemaKey{projectionName, std::move(path)}, + std::move(entry.second)); + } + + inputResult._reqMap = std::move(newMap); + return inputResult; + } + + PartialSchemaReqConversion transport(const ABT& n, + const PathGet& pathGet, + PartialSchemaReqConversion inputResult) { + return handleGetAndTraverse<PathGet>(n, std::move(inputResult)); + } + + PartialSchemaReqConversion transport(const ABT& n, + const PathTraverse& pathTraverse, + PartialSchemaReqConversion inputResult) { + if (inputResult._reqMap.size() > 1) { + // Cannot append traverse if we have more than one requirement. 
+ return {}; + } + + PartialSchemaReqConversion result = + handleGetAndTraverse<PathTraverse>(n, std::move(inputResult)); + result._hasTraversed = true; + return result; + } + + PartialSchemaReqConversion transport(const ABT& n, + const PathCompare& pathCompare, + PartialSchemaReqConversion inputResult) { + if (!inputResult._success) { + return {}; + } + if (!inputResult._bound.has_value() || !inputResult._reqMap.empty()) { + return {}; + } + + const auto& bound = inputResult._bound; + bool lowBoundInclusive = false; + boost::optional<ABT> lowBound; + bool highBoundInclusive = false; + boost::optional<ABT> highBound; + + const Operations op = pathCompare.op(); + switch (op) { + case Operations::Eq: + lowBoundInclusive = true; + lowBound = bound; + highBoundInclusive = true; + highBound = bound; + break; + + case Operations::Lt: + case Operations::Lte: + lowBoundInclusive = false; + highBoundInclusive = op == Operations::Lte; + highBound = bound; + break; + + case Operations::Gt: + case Operations::Gte: + lowBoundInclusive = op == Operations::Gte; + lowBound = bound; + highBoundInclusive = false; + break; + + default: + // TODO handle other comparisons? + return {}; + } + + auto intervalExpr = IntervalReqExpr::makeSingularDNF(IntervalRequirement{ + {lowBoundInclusive, std::move(lowBound)}, {highBoundInclusive, std::move(highBound)}}); + return {PartialSchemaRequirements{ + {PartialSchemaKey{}, + PartialSchemaRequirement{"" /*boundProjectionName*/, std::move(intervalExpr)}}}}; + } + + PartialSchemaReqConversion transport(const ABT& n, const PathIdentity& pathIdentity) { + return {PartialSchemaRequirements{{{}, {}}}}; + } + + template <typename T, typename... Ts> + PartialSchemaReqConversion transport(const ABT& n, const T& node, Ts&&...) { + if constexpr (std::is_base_of_v<ExpressionSyntaxSort, T>) { + // We allow expressions to participate in bounds. + return {n}; + } + // General case. Reject conversion. + return {}; + } + + PartialSchemaReqConversion convert(const ABT& input) { + return algebra::transport<true>(input, *this); + } +}; + +PartialSchemaReqConversion convertExprToPartialSchemaReq(const ABT& expr) { + PartialSchemaReqConverter converter; + PartialSchemaReqConversion result = converter.convert(expr); + if (result._reqMap.empty()) { + result._success = false; + return result; + } + + for (const auto& entry : result._reqMap) { + if (entry.first.emptyPath() && isIntervalReqFullyOpenDNF(entry.second.getIntervals())) { + // We need to determine either path or interval (or both). + result._success = false; + return result; + } + } + return result; +} + +static bool intersectPartialSchemaReq(PartialSchemaRequirements& reqMap, + PartialSchemaKey key, + PartialSchemaRequirement req, + ProjectionRenames& projectionRenames) { + for (;;) { + bool merged = false; + PartialSchemaKey newKey; + PartialSchemaRequirement newReq; + + const bool reqHasBoundProj = req.hasBoundProjectionName(); + { + // Look for exact match on the path, and if found combine intervals. 
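+ // If the same path already has an entry, intersect the intervals and, when both sides
+ // bind a projection, record that the incoming bound projection aliases the existing one.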
+ auto it = reqMap.find(key); + if (it != reqMap.cend()) { + PartialSchemaRequirement& existingReq = it->second; + if (reqHasBoundProj) { + if (existingReq.hasBoundProjectionName()) { + projectionRenames.emplace(req.getBoundProjectionName(), + existingReq.getBoundProjectionName()); + } else { + existingReq.setBoundProjectionName(req.getBoundProjectionName()); + } + } + combineIntervalsDNF( + true /*intersect*/, existingReq.getIntervals(), req.getIntervals()); + return true; + } + } + + for (auto it = reqMap.begin(); it != reqMap.cend();) { + const auto& [existingKey, existingReq] = *it; + uassert(6624150, + "Existing key referring to new requirement.", + !reqHasBoundProj || + existingKey._projectionName != req.getBoundProjectionName()); + + if (existingReq.hasBoundProjectionName() && + key._projectionName == existingReq.getBoundProjectionName()) { + // The new key is referring to a projection the existing requirement binds. + if (reqHasBoundProj) { + return false; + } + + newKey = existingKey; + newReq = req; + + PathAppender appender(key._path); + appender.append(newKey._path); + + combineIntervalsDNF( + true /*intersect*/, newReq.getIntervals(), existingReq.getIntervals()); + + if (key._path.is<PathIdentity>()) { + newReq.setBoundProjectionName(existingReq.getBoundProjectionName()); + reqMap.erase(it++); + } else if (!isIntervalReqFullyOpenDNF(existingReq.getIntervals())) { + return false; + } + + merged = true; + break; + } else { + it++; + continue; + } + } + + if (merged) { + key = std::move(newKey); + req = std::move(newReq); + } else { + reqMap[key] = req; + break; + } + }; + + return true; +} + +bool intersectPartialSchemaReq(PartialSchemaRequirements& target, + const PartialSchemaRequirements& source, + ProjectionRenames& projectionRenames) { + for (const auto& [key, req] : source) { + if (!intersectPartialSchemaReq(target, key, req, projectionRenames)) { + return false; + } + } + + return true; +} + +std::string encodeIndexKeyName(const size_t indexField) { + std::ostringstream os; + os << kIndexKeyPrefix << " " << indexField; + return os.str(); +} + +size_t decodeIndexKeyName(const std::string& fieldName) { + std::istringstream is(fieldName); + + std::string prefix; + is >> prefix; + uassert(6624151, "Invalid index key prefix", prefix == kIndexKeyPrefix); + + int key; + is >> key; + return key; +} + +/** + * Checks if one index path is a prefix of another. Considers only Get, Traverse, and Id. + * Return the suffix that doesn't match. + */ +class PathSuffixExtactor { +public: + using ResultType = boost::optional<ABT::reference_type>; + + /** + * 'n' - The complete index path being compared to, can be modified if needed. + * 'node' - Same as 'n' but cast to a specific type by the caller in order to invoke the + * correct operator. + * 'other' - The query that is checked if it is a prefix of the index. + */ + ResultType operator()(const ABT& n, const PathGet& node, const ABT& other) { + if (auto otherGet = other.cast<PathGet>(); + otherGet != nullptr && otherGet->name() == node.name()) { + if (auto otherChildTraverse = otherGet->getPath().cast<PathTraverse>(); + otherChildTraverse != nullptr && !node.getPath().is<PathTraverse>()) { + // If a query path has a Traverse, but the index path doesn't, the query can + // still be evaluated by this index. Skip the Traverse node, and continue matching. 
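+ // e.g. index path Get "a" Id still matches query path Get "a" Traverse ..., with the
+ // query's Traverse stepped over.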
+ return node.getPath().visit(*this, otherChildTraverse->getPath()); + } else { + return node.getPath().visit(*this, otherGet->getPath()); + } + } + return {}; + } + + ResultType operator()(const ABT& n, const PathTraverse& node, const ABT& other) { + if (auto otherTraverse = other.cast<PathTraverse>(); otherTraverse != nullptr) { + return node.getPath().visit(*this, otherTraverse->getPath()); + } + return {}; + } + + ResultType operator()(const ABT& n, const PathIdentity& node, const ABT& other) { + return {other.ref()}; + } + + template <typename T, typename... Ts> + ResultType operator()(const ABT& /*n*/, const T& /*node*/, Ts&&...) { + uasserted(6624152, "Unexpected node type"); + } + + static ResultType check(const ABT& node, const ABT& candidatePrefix) { + PathSuffixExtactor instance; + return candidatePrefix.visit(instance, node); + } +}; + +/** + * Check if a path contains a Traverse node over the last path component. + */ +class PathTraverseChecker { +public: + PathTraverseChecker() {} + + /** + * PathIdentity is always the last node in a path. + * Return true if either 'node' is a PathTraverse just before the last node. + * If the input Traverse node is somewhere in the middle of the path, return + * the result of its child. + */ + bool transport(const ABT& /*n*/, const PathTraverse& node, bool childResult) { + return node.getPath().is<PathIdentity>() || childResult; + } + + bool transport(const ABT& /*n*/, const PathGet& node, bool childResult) { + return childResult; + } + + bool transport(const ABT& /*n*/, const PathIdentity& node) { + return false; + } + + template <typename T, typename... Ts> + bool transport(const ABT& /*n*/, const T& /*node*/, Ts&&...) { + uasserted(6624153, "Index paths only consist of Get, Traverse, and Id nodes."); + return false; + } + + /** + * Return true if the node before the last one in 'path' is a PathTraverse. + * Return false otherwise. + */ + bool check(const ABT& path) { + return algebra::transport<true>(path, *this); + } +}; + +void findMatchingSchemaRequirement(const PartialSchemaKey& indexKey, + const PartialSchemaRequirements& reqMap, + PartialSchemaKeySet& keySet, + PartialSchemaRequirement& req, + const bool setIntervalsAndBoundProj) { + for (const auto& [queryKey, queryReq] : reqMap) { + const auto pathSuffixResult = PathSuffixExtactor::check(queryKey._path, indexKey._path); + if (pathSuffixResult.has_value() && pathSuffixResult->is<PathIdentity>()) { + keySet.insert(queryKey); + + if (setIntervalsAndBoundProj) { + // Combine all matching requirements into a single entry. + if (queryReq.hasBoundProjectionName()) { + uassert(6624154, "Unexpected bound projection", !req.hasBoundProjectionName()); + req.setBoundProjectionName(queryReq.getBoundProjectionName()); + } + if (!isIntervalReqFullyOpenDNF(queryReq.getIntervals())) { + uassert(6624350, + "Unexpected intervals", + isIntervalReqFullyOpenDNF(req.getIntervals())); + req.getIntervals() = queryReq.getIntervals(); + } + } + } + } +} + +CandidateIndexMap computeCandidateIndexMap(PrefixId& prefixId, + const ProjectionName& scanProjectionName, + const PartialSchemaRequirements& reqMap, + const ScanDefinition& scanDef, + bool& hasEmptyInterval) { + CandidateIndexMap result; + hasEmptyInterval = false; + + for (const auto& [indexDefName, indexDef] : scanDef.getIndexDefs()) { + FieldProjectionMap indexProjectionMap; + auto intervals = MultiKeyIntervalReqExpr::makeSingularDNF(); // Singular empty interval. 
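+ // Requirements that cannot be answered by the index bounds alone are gathered below as
+ // residual requirements, to be re-applied on top of the index scan.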
+ PartialSchemaRequirements residualRequirements; + ProjectionNameSet residualRequirementsTempProjections; + ResidualKeyMap residualKeyMap; + opt::unordered_set<size_t> fieldsToCollate; + + PartialSchemaKeySet unsatisfiedKeys; + for (const auto& [key, req] : reqMap) { + unsatisfiedKeys.insert(key); + } + + // True if the paths from partial schema requirements form a strict prefix of the index + // collation. + bool isPrefix = true; + // If we formed bounds using at least one requirement (as opposed to having only residual + // requirements). + bool hasExactMatch = false; + bool indexSuitable = true; + + const IndexCollationSpec& indexCollationSpec = indexDef.getCollationSpec(); + for (size_t indexField = 0; indexField < indexCollationSpec.size(); indexField++) { + const auto& indexCollationEntry = indexCollationSpec.at(indexField); + PartialSchemaKey indexKey{scanProjectionName, indexCollationEntry._path}; + + PartialSchemaKeySet keySet; + PartialSchemaRequirement req; + findMatchingSchemaRequirement(indexKey, reqMap, keySet, req); + if (!keySet.empty() && isPrefix) { + hasExactMatch = true; + const PartialSchemaKey& queryKey = *keySet.begin(); + for (const auto& key : keySet) { + unsatisfiedKeys.erase(key); + } + + // The interval derived from the query (not from the index). + const auto& requiredInterval = req.getIntervals(); + + PathTraverseChecker pathChecker; + const bool indexPathContainsArrays = pathChecker.check(indexCollationEntry._path); + bool combineSuccess = false; + if (indexPathContainsArrays) { + combineSuccess = combineMultiKeyIntervalsDNF(intervals, requiredInterval); + } else { + auto intersectedIntervals = intersectDNFIntervals(requiredInterval); + if (intersectedIntervals.has_value()) { + combineSuccess = + combineMultiKeyIntervalsDNF(intervals, intersectedIntervals.get()); + } else { + if (indexDef.getPartialReqMap().empty()) { + hasEmptyInterval = true; + return CandidateIndexMap(); + } else { + // This is a partial index, so skip the empty interval, but consider the + // remaining indexes. + indexSuitable = false; + break; + } + } + } + + if (!combineSuccess) { + if (!combineMultiKeyIntervalsDNF(intervals, + IntervalReqExpr::makeSingularDNF())) { + uasserted(6624155, "Cannot combine with an open interval"); + } + + // Move interval into residual requirements. + if (req.hasBoundProjectionName()) { + PartialSchemaKey residualKey{req.getBoundProjectionName(), + make<PathIdentity>()}; + residualRequirements.emplace( + residualKey, PartialSchemaRequirement{"", req.getIntervals()}); + residualKeyMap.emplace(std::move(residualKey), queryKey); + residualRequirementsTempProjections.insert(req.getBoundProjectionName()); + } else { + ProjectionName tempProj = prefixId.getNextId("evalTemp"); + PartialSchemaKey residualKey{tempProj, make<PathIdentity>()}; + residualRequirements.emplace(residualKey, req); + residualKeyMap.emplace(std::move(residualKey), queryKey); + residualRequirementsTempProjections.insert(tempProj); + + // Include bounds projection into index spec. + if (!indexProjectionMap._fieldProjections + .emplace(encodeIndexKeyName(indexField), std::move(tempProj)) + .second) { + uasserted(6624156, "Duplicate field name"); + } + } + } + + // Include bounds projection into index spec. 
+ if (req.hasBoundProjectionName() && + !indexProjectionMap._fieldProjections + .emplace(encodeIndexKeyName(indexField), req.getBoundProjectionName()) + .second) { + uasserted(6624157, "Duplicate field name"); + } + + if (auto singularInterval = IntervalReqExpr::getSingularDNF(requiredInterval); + !singularInterval || !singularInterval->isEquality()) { + // We only care about collation of for non-equality intervals. + // Equivalently, it is sufficient for singular intervals to be clustered. + fieldsToCollate.insert(indexField); + } + } else { + bool foundPathPrefix = false; + for (const auto& queryKey : unsatisfiedKeys) { + const auto pathPrefixResult = + PathSuffixExtactor::check(queryKey._path, indexCollationEntry._path); + if (pathPrefixResult.has_value()) { + ProjectionName tempProj = prefixId.getNextId("evalTemp"); + PartialSchemaKey residualKey{tempProj, pathPrefixResult.get()}; + residualRequirements.emplace(residualKey, reqMap.at(queryKey)); + residualKeyMap.emplace(std::move(residualKey), queryKey); + residualRequirementsTempProjections.insert(tempProj); + + // Include bounds projection into index spec. + if (!indexProjectionMap._fieldProjections + .emplace(encodeIndexKeyName(indexField), std::move(tempProj)) + .second) { + uasserted(6624158, "Duplicate field name"); + } + + if (!combineMultiKeyIntervalsDNF(intervals, + IntervalReqExpr::makeSingularDNF())) { + uasserted(6624159, "Cannot combine with an open interval"); + } + + foundPathPrefix = true; + unsatisfiedKeys.erase(queryKey); + break; + } + } + + if (!foundPathPrefix) { + isPrefix = false; + if (!combineMultiKeyIntervalsDNF(intervals, + IntervalReqExpr::makeSingularDNF())) { + uasserted(6624160, "Cannot combine with an open interval"); + } + } + } + } + if (!indexSuitable) { + continue; + } + if (!hasExactMatch) { + continue; + } + if (!unsatisfiedKeys.empty()) { + continue; + } + + uassert(6624161, "Invalid map sizes", residualRequirements.size() == residualKeyMap.size()); + result.emplace(indexDefName, + CandidateIndexEntry{std::move(indexProjectionMap), + std::move(intervals), + std::move(residualRequirements), + std::move(residualRequirementsTempProjections), + std::move(residualKeyMap), + std::move(fieldsToCollate)}); + } + + return result; +} + +class PartialSchemaReqLowerTransport { +public: + ABT transport(const IntervalReqExpr::Atom& node) { + const auto& interval = node.getExpr(); + const auto& lowBound = interval.getLowBound(); + const auto& highBound = interval.getHighBound(); + + if (interval.isEquality()) { + return make<PathCompare>(Operations::Eq, lowBound.getBound()); + } + + ABT result = make<PathIdentity>(); + if (!lowBound.isInfinite()) { + maybeComposePath( + result, + make<PathCompare>(lowBound.isInclusive() ? Operations::Gte : Operations::Gt, + lowBound.getBound())); + } + if (!highBound.isInfinite()) { + maybeComposePath( + result, + make<PathCompare>(highBound.isInclusive() ? 
Operations::Lte : Operations::Lt, + highBound.getBound())); + } + return result; + } + + template <class Element> + ABT composeChildren(ABTVector childResults) { + ABT result = make<PathIdentity>(); + for (ABT& n : childResults) { + maybeComposePath<Element>(result, std::move(n)); + } + return result; + } + + ABT transport(const IntervalReqExpr::Conjunction& node, ABTVector childResults) { + return composeChildren<PathComposeM>(std::move(childResults)); + } + + ABT transport(const IntervalReqExpr::Disjunction& node, ABTVector childResults) { + return composeChildren<PathComposeA>(std::move(childResults)); + } + + ABT lower(const IntervalReqExpr::Node& intervals) { + return algebra::transport<false>(intervals, *this); + } +}; + +void lowerPartialSchemaRequirement(const PartialSchemaKey& key, + const PartialSchemaRequirement& req, + ABT& node, + const std::function<void(const ABT& node)>& visitor) { + PartialSchemaReqLowerTransport transport; + ABT path = transport.lower(req.getIntervals()); + const bool pathIsId = path.is<PathIdentity>(); + + if (req.hasBoundProjectionName()) { + node = make<EvaluationNode>(req.getBoundProjectionName(), + make<EvalPath>(key._path, make<Variable>(key._projectionName)), + std::move(node)); + visitor(node); + + if (!pathIsId) { + node = make<FilterNode>( + make<EvalFilter>(std::move(path), make<Variable>(req.getBoundProjectionName())), + std::move(node)); + visitor(node); + } + } else { + uassert( + 6624162, "If we do not have a bound projection, then we have a proper path", !pathIsId); + + PathAppender appender(std::move(path)); + path = key._path; + appender.append(path); + + node = + make<FilterNode>(make<EvalFilter>(std::move(path), make<Variable>(key._projectionName)), + std::move(node)); + visitor(node); + } +} + +void lowerPartialSchemaRequirements(const CEType baseCE, + const CEType scanGroupCE, + ResidualRequirements& requirements, + ABT& physNode, + NodeCEMap& nodeCEMap) { + sortResidualRequirements(requirements); + + CEType residualCE = baseCE; + for (const auto& [residualKey, residualReq, ce] : requirements) { + if (scanGroupCE > 0.0) { + residualCE *= ce / scanGroupCE; + } + lowerPartialSchemaRequirement(residualKey, residualReq, physNode, [&](const ABT& node) { + nodeCEMap.emplace(node.cast<Node>(), residualCE); + }); + } +} + +void computePhysicalScanParams(PrefixId& prefixId, + const PartialSchemaRequirements& reqMap, + const PartialSchemaKeyCE& partialSchemaKeyCEMap, + const ProjectionNameOrderPreservingSet& requiredProjections, + ResidualRequirements& residualRequirements, + ProjectionRenames& projectionRenames, + FieldProjectionMap& fieldProjectionMap, + bool& requiresRootProjection) { + for (const auto& [key, req] : reqMap) { + bool hasBoundProjection = req.hasBoundProjectionName(); + if (hasBoundProjection && !requiredProjections.find(req.getBoundProjectionName()).second) { + // Bound projection is not required, pretend we don't bind. + hasBoundProjection = false; + } + if (!hasBoundProjection && isIntervalReqFullyOpenDNF(req.getIntervals())) { + // Redundant requirement. + continue; + } + + const CEType keyCE = partialSchemaKeyCEMap.at(key); + if (auto pathGet = key._path.cast<PathGet>(); pathGet != nullptr) { + // Extract a new requirements path with removed simple paths. + // For example if we have a key Get "a" Traverse Compare = 0 we leave only + // Traverse Compare 0. 
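+ // The first path component is answered directly by the physical scan via the field
+ // projection map; any remaining path or interval is kept as a residual requirement.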
+ if (pathGet->getPath().is<PathIdentity>() && hasBoundProjection) { + const auto [it, inserted] = fieldProjectionMap._fieldProjections.emplace( + pathGet->name(), req.getBoundProjectionName()); + if (!inserted) { + projectionRenames.emplace(req.getBoundProjectionName(), it->second); + } + + if (!isIntervalReqFullyOpenDNF(req.getIntervals())) { + residualRequirements.emplace_back( + PartialSchemaKey{req.getBoundProjectionName(), make<PathIdentity>()}, + PartialSchemaRequirement{"", req.getIntervals()}, + keyCE); + } + } else { + ProjectionName tempProjName; + auto it = fieldProjectionMap._fieldProjections.find(pathGet->name()); + if (it == fieldProjectionMap._fieldProjections.cend()) { + tempProjName = prefixId.getNextId("evalTemp"); + fieldProjectionMap._fieldProjections.emplace(pathGet->name(), tempProjName); + } else { + tempProjName = it->second; + } + + residualRequirements.emplace_back( + PartialSchemaKey{std::move(tempProjName), pathGet->getPath()}, req, keyCE); + } + } else { + // Move other conditions into the residual map. + requiresRootProjection = true; + residualRequirements.emplace_back(key, req, keyCE); + } + } +} + +void sortResidualRequirements(ResidualRequirements& residualReq) { + // Sort residual requirements by estimated CE. + std::sort( + residualReq.begin(), + residualReq.end(), + [](const ResidualRequirement& x, const ResidualRequirement& y) { return x._ce < y._ce; }); +} + +void applyProjectionRenames(ProjectionRenames projectionRenames, + ABT& node, + const std::function<void(const ABT& node)>& visitor) { + for (auto&& [targetProjName, sourceProjName] : projectionRenames) { + node = make<EvaluationNode>( + std::move(targetProjName), make<Variable>(std::move(sourceProjName)), std::move(node)); + visitor(node); + } +} + +ABT lowerRIDIntersectGroupBy(PrefixId& prefixId, + const ProjectionName& ridProjName, + const CEType intersectedCE, + const CEType leftCE, + const CEType rightCE, + const properties::PhysProps& physProps, + const properties::PhysProps& leftPhysProps, + const properties::PhysProps& rightPhysProps, + ABT leftChild, + ABT rightChild, + NodeCEMap& nodeCEMap, + ChildPropsType& childProps) { + using namespace properties; + + const auto& leftProjections = + getPropertyConst<ProjectionRequirement>(leftPhysProps).getProjections(); + + ABTVector aggExpressions; + ProjectionNameVector aggProjectionNames; + + const ProjectionName sideIdProjectionName = prefixId.getNextId("sideId"); + const ProjectionName sideSetProjectionName = prefixId.getNextId("sides"); + + aggExpressions.emplace_back( + make<FunctionCall>("$addToSet", makeSeq(make<Variable>(sideIdProjectionName)))); + aggProjectionNames.push_back(sideSetProjectionName); + + leftChild = + make<EvaluationNode>(sideIdProjectionName, Constant::int64(0), std::move(leftChild)); + childProps.emplace_back(&leftChild.cast<EvaluationNode>()->getChild(), leftPhysProps); + nodeCEMap.emplace(leftChild.cast<Node>(), leftCE); + + rightChild = + make<EvaluationNode>(sideIdProjectionName, Constant::int64(1), std::move(rightChild)); + childProps.emplace_back(&rightChild.cast<EvaluationNode>()->getChild(), rightPhysProps); + nodeCEMap.emplace(rightChild.cast<Node>(), rightCE); + + ProjectionNameVector sortedProjections = + getPropertyConst<ProjectionRequirement>(physProps).getProjections().getVector(); + std::sort(sortedProjections.begin(), sortedProjections.end()); + + ProjectionNameVector unionProjections{ridProjName, sideIdProjectionName}; + for (const ProjectionName& projectionName : sortedProjections) { + if 
(projectionName == ridProjName) { + continue; + } + + ProjectionName tempProjectionName = prefixId.getNextId("unionTemp"); + unionProjections.push_back(tempProjectionName); + + if (leftProjections.find(projectionName).second) { + leftChild = make<EvaluationNode>( + tempProjectionName, make<Variable>(projectionName), std::move(leftChild)); + nodeCEMap.emplace(leftChild.cast<Node>(), leftCE); + + rightChild = make<EvaluationNode>( + tempProjectionName, Constant::nothing(), std::move(rightChild)); + nodeCEMap.emplace(rightChild.cast<Node>(), rightCE); + } else { + leftChild = + make<EvaluationNode>(tempProjectionName, Constant::nothing(), std::move(leftChild)); + nodeCEMap.emplace(leftChild.cast<Node>(), leftCE); + + rightChild = make<EvaluationNode>( + tempProjectionName, make<Variable>(projectionName), std::move(rightChild)); + nodeCEMap.emplace(rightChild.cast<Node>(), rightCE); + } + + aggExpressions.emplace_back( + make<FunctionCall>("$max", makeSeq(make<Variable>(tempProjectionName)))); + aggProjectionNames.push_back(projectionName); + } + + ABT result = make<UnionNode>(std::move(unionProjections), + makeSeq(std::move(leftChild), std::move(rightChild))); + nodeCEMap.emplace(result.cast<Node>(), leftCE + rightCE); + + result = make<GroupByNode>(ProjectionNameVector{ridProjName}, + std::move(aggProjectionNames), + std::move(aggExpressions), + std::move(result)); + nodeCEMap.emplace(result.cast<Node>(), intersectedCE); + + result = make<FilterNode>( + make<EvalFilter>( + make<PathCompare>(Operations::Eq, Constant::int64(2)), + make<FunctionCall>("getArraySize", makeSeq(make<Variable>(sideSetProjectionName)))), + std::move(result)); + nodeCEMap.emplace(result.cast<Node>(), intersectedCE); + + return result; +} + +ABT lowerRIDIntersectHashJoin(PrefixId& prefixId, + const ProjectionName& ridProjName, + const CEType intersectedCE, + const CEType leftCE, + const CEType rightCE, + const properties::PhysProps& leftPhysProps, + const properties::PhysProps& rightPhysProps, + ABT leftChild, + ABT rightChild, + NodeCEMap& nodeCEMap, + ChildPropsType& childProps) { + using namespace properties; + + ProjectionName rightRIDProjName = prefixId.getNextId("rid"); + rightChild = + make<EvaluationNode>(rightRIDProjName, make<Variable>(ridProjName), std::move(rightChild)); + ABT* rightChildPtr = &rightChild.cast<EvaluationNode>()->getChild(); + nodeCEMap.emplace(rightChild.cast<Node>(), rightCE); + + auto rightProjections = + getPropertyConst<ProjectionRequirement>(rightPhysProps).getProjections(); + rightProjections.erase(ridProjName); + rightProjections.emplace_back(rightRIDProjName); + ProjectionNameVector sortedProjections = rightProjections.getVector(); + std::sort(sortedProjections.begin(), sortedProjections.end()); + + // Use a union node to restrict the rid projection name coming from the right child in order + // to ensure we do not have the same rid from both children. This node is optimized away + // during lowering. 
+ rightChild = make<UnionNode>(std::move(sortedProjections), makeSeq(std::move(rightChild))); + nodeCEMap.emplace(rightChild.cast<Node>(), rightCE); + + ABT result = make<HashJoinNode>(JoinType::Inner, + ProjectionNameVector{ridProjName}, + ProjectionNameVector{std::move(rightRIDProjName)}, + std::move(leftChild), + std::move(rightChild)); + nodeCEMap.emplace(result.cast<Node>(), intersectedCE); + + childProps.emplace_back(&result.cast<HashJoinNode>()->getLeftChild(), leftPhysProps); + childProps.emplace_back(rightChildPtr, rightPhysProps); + + return result; +} + +ABT lowerRIDIntersectMergeJoin(PrefixId& prefixId, + const ProjectionName& ridProjName, + const CEType intersectedCE, + const CEType leftCE, + const CEType rightCE, + const properties::PhysProps& leftPhysProps, + const properties::PhysProps& rightPhysProps, + ABT leftChild, + ABT rightChild, + NodeCEMap& nodeCEMap, + ChildPropsType& childProps) { + using namespace properties; + + ProjectionName rightRIDProjName = prefixId.getNextId("rid"); + rightChild = + make<EvaluationNode>(rightRIDProjName, make<Variable>(ridProjName), std::move(rightChild)); + ABT* rightChildPtr = &rightChild.cast<EvaluationNode>()->getChild(); + nodeCEMap.emplace(rightChild.cast<Node>(), rightCE); + + auto rightProjections = + getPropertyConst<ProjectionRequirement>(rightPhysProps).getProjections(); + rightProjections.erase(ridProjName); + rightProjections.emplace_back(rightRIDProjName); + ProjectionNameVector sortedProjections = rightProjections.getVector(); + std::sort(sortedProjections.begin(), sortedProjections.end()); + + // Use a union node to restrict the rid projection name coming from the right child in order + // to ensure we do not have the same rid from both children. This node is optimized away + // during lowering. + rightChild = make<UnionNode>(std::move(sortedProjections), makeSeq(std::move(rightChild))); + nodeCEMap.emplace(rightChild.cast<Node>(), rightCE); + + ABT result = make<MergeJoinNode>(ProjectionNameVector{ridProjName}, + ProjectionNameVector{std::move(rightRIDProjName)}, + std::vector<CollationOp>{CollationOp::Ascending}, + std::move(leftChild), + std::move(rightChild)); + nodeCEMap.emplace(result.cast<Node>(), intersectedCE); + + childProps.emplace_back(&result.cast<MergeJoinNode>()->getLeftChild(), leftPhysProps); + childProps.emplace_back(rightChildPtr, rightPhysProps); + + return result; +} + +class IntervalLowerTransport { +public: + IntervalLowerTransport(PrefixId& prefixId, + const ProjectionName& ridProjName, + FieldProjectionMap indexProjectionMap, + const std::string& scanDefName, + const std::string& indexDefName, + const bool reverseOrder, + const CEType indexCE, + const CEType scanGroupCE, + NodeCEMap& nodeCEMap) + : _prefixId(prefixId), + _ridProjName(ridProjName), + _scanDefName(scanDefName), + _indexDefName(indexDefName), + _reverseOrder(reverseOrder), + _scanGroupCE(scanGroupCE), + _nodeCEMap(nodeCEMap) { + const SelectivityType indexSel = (scanGroupCE == 0.0) ? 
0.0 : (indexCE / _scanGroupCE); + _estimateStack.push_back(indexSel); + _fpmStack.push_back(std::move(indexProjectionMap)); + }; + + ABT transport(const MultiKeyIntervalReqExpr::Atom& node) { + ABT physicalIndexScan = make<IndexScanNode>( + _fpmStack.back(), + IndexSpecification{_scanDefName, _indexDefName, node.getExpr(), _reverseOrder}); + _nodeCEMap.emplace(physicalIndexScan.cast<Node>(), _scanGroupCE * _estimateStack.back()); + return physicalIndexScan; + } + + template <bool isConjunction> + void prepare(const size_t childCount) { + // Here we are assuming each conjunction and disjunction contribute uniformly to the total + // selectivity. + // TODO: consider estimates per individual interval. + + const SelectivityType parentSel = _estimateStack.back(); + SelectivityType childSel = 0.0; + if constexpr (isConjunction) { + childSel = (parentSel == 0.0) ? 0.0 : std::pow(parentSel, 1.0 / childCount); + } else { + childSel = _estimateStack.back() / childCount; + } + _estimateStack.push_back(childSel); + + FieldProjectionMap childMap = _fpmStack.back(); + if (childMap._ridProjection.empty()) { + childMap._ridProjection = _ridProjName; + } + if (childCount > 1) { + for (auto& [fieldName, projectionName] : childMap._fieldProjections) { + projectionName = _prefixId.getNextId(isConjunction ? "conjunction" : "disjunction"); + } + } + _fpmStack.push_back(std::move(childMap)); + } + + void prepare(const MultiKeyIntervalReqExpr::Conjunction& node) { + prepare<true /*isConjunction*/>(node.nodes().size()); + } + + template <bool isIntersect> + ABT implement(ABTVector inputs) { + _estimateStack.pop_back(); + const CEType ce = _scanGroupCE * _estimateStack.back(); + + auto innerMap = std::move(_fpmStack.back()); + _fpmStack.pop_back(); + auto outerMap = _fpmStack.back(); + + const size_t inputSize = inputs.size(); + if (inputSize == 1) { + return std::move(inputs.front()); + } + + ProjectionNameVector unionProjectionNames; + unionProjectionNames.push_back(innerMap._ridProjection); + for (const auto& [fieldName, projectionName] : innerMap._fieldProjections) { + unionProjectionNames.push_back(projectionName); + } + + ProjectionNameVector aggProjectionNames; + for (const auto& [fieldName, projectionName] : outerMap._fieldProjections) { + aggProjectionNames.push_back(projectionName); + } + + ABTVector aggExpressions; + for (const auto& [fieldName, projectionName] : innerMap._fieldProjections) { + aggExpressions.emplace_back( + make<FunctionCall>("$first", makeSeq(make<Variable>(projectionName)))); + } + + ProjectionName sideSetProjectionName; + if constexpr (isIntersect) { + const ProjectionName sideIdProjectionName = _prefixId.getNextId("sideId"); + unionProjectionNames.push_back(sideIdProjectionName); + sideSetProjectionName = _prefixId.getNextId("sides"); + + for (size_t index = 0; index < inputSize; index++) { + ABT& input = inputs.at(index); + input = make<EvaluationNode>( + sideIdProjectionName, Constant::int64(index), std::move(input)); + // Not relevant for cost. 
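+ // The constant side-id evaluation is costed as zero; the real cardinality estimates are
+ // attached to the union, group-by, and filter nodes constructed below.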
+ _nodeCEMap.emplace(input.cast<Node>(), 0.0); + } + + aggExpressions.emplace_back( + make<FunctionCall>("$addToSet", makeSeq(make<Variable>(sideIdProjectionName)))); + aggProjectionNames.push_back(sideSetProjectionName); + } + + ABT result = make<UnionNode>(std::move(unionProjectionNames), std::move(inputs)); + _nodeCEMap.emplace(result.cast<Node>(), ce); + + result = make<GroupByNode>(ProjectionNameVector{innerMap._ridProjection}, + std::move(aggProjectionNames), + std::move(aggExpressions), + std::move(result)); + _nodeCEMap.emplace(result.cast<Node>(), ce); + + if constexpr (isIntersect) { + result = make<FilterNode>( + make<EvalFilter>( + make<PathCompare>(Operations::Eq, Constant::int64(inputSize)), + make<FunctionCall>("getArraySize", + makeSeq(make<Variable>(sideSetProjectionName)))), + std::move(result)); + _nodeCEMap.emplace(result.cast<Node>(), ce); + } + return result; + } + + ABT transport(const MultiKeyIntervalReqExpr::Conjunction& node, ABTVector childResults) { + return implement<true /*isIntersect*/>(std::move(childResults)); + } + + void prepare(const MultiKeyIntervalReqExpr::Disjunction& node) { + prepare<false /*isConjunction*/>(node.nodes().size()); + } + + ABT transport(const MultiKeyIntervalReqExpr::Disjunction& node, ABTVector childResults) { + return implement<false /*isIntersect*/>(std::move(childResults)); + } + + ABT lower(const MultiKeyIntervalReqExpr::Node& intervals) { + return algebra::transport<false>(intervals, *this); + } + +private: + PrefixId& _prefixId; + const ProjectionName& _ridProjName; + const std::string& _scanDefName; + const std::string& _indexDefName; + const bool _reverseOrder; + const CEType _scanGroupCE; + NodeCEMap& _nodeCEMap; + + std::vector<SelectivityType> _estimateStack; + std::vector<FieldProjectionMap> _fpmStack; +}; + +ABT lowerIntervals(PrefixId& prefixId, + const ProjectionName& ridProjName, + FieldProjectionMap indexProjectionMap, + const std::string& scanDefName, + const std::string& indexDefName, + const MultiKeyIntervalReqExpr::Node& intervals, + const bool reverseOrder, + const CEType indexCE, + const CEType scanGroupCE, + NodeCEMap& nodeCEMap) { + IntervalLowerTransport lowerTransport(prefixId, + ridProjName, + std::move(indexProjectionMap), + scanDefName, + indexDefName, + reverseOrder, + indexCE, + scanGroupCE, + nodeCEMap); + return lowerTransport.lower(intervals); +} + +} // namespace mongo::optimizer diff --git a/src/mongo/db/query/optimizer/utils/utils.h b/src/mongo/db/query/optimizer/utils/utils.h new file mode 100644 index 00000000000..371dac2a48d --- /dev/null +++ b/src/mongo/db/query/optimizer/utils/utils.h @@ -0,0 +1,317 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. 
+ * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#pragma once + +#include "mongo/db/query/optimizer/defs.h" +#include "mongo/db/query/optimizer/node.h" +#include "mongo/db/query/optimizer/node_defs.h" +#include "mongo/db/query/optimizer/props.h" + +namespace mongo::optimizer { + +inline void updateHash(size_t& result, const size_t hash) { + result = 31 * result + hash; +} + +inline void updateHashUnordered(size_t& result, const size_t hash) { + result ^= hash; +} + +template <class T, class T1 = std::conditional_t<std::is_arithmetic_v<T>, const T, const T&>> +inline size_t computeVectorHash(const std::vector<T>& v) { + size_t result = 17; + for (T1 e : v) { + updateHash(result, std::hash<T>()(e)); + } + return result; +} + +template <int typeCode, typename... Args> +inline size_t computeHashSeq(const Args&... seq) { + size_t result = 17 + typeCode; + (updateHash(result, seq), ...); + return result; +} + +size_t roundUpToNextPow2(size_t v, size_t maxPower); + +std::vector<ABT::reference_type> collectComposed(const ABT& n); + +/** + * Returns the path represented by 'node' as a simple dotted string. Returns an empty string if + * 'node' is not a path. + */ +FieldNameType getSimpleField(const ABT& node); + +template <class Element = PathComposeM> +inline void maybeComposePath(ABT& composition, ABT child) { + if (child.is<PathIdentity>()) { + return; + } + if (composition.is<PathIdentity>()) { + composition = std::move(child); + return; + } + + composition = make<Element>(std::move(composition), std::move(child)); +} + +/** + * Used to vend out fresh ids for projection names. + */ +class PrefixId { +public: + std::string getNextId(const std::string& key); + +private: + opt::unordered_map<std::string, int> _idCounterPerKey; +}; + +ProjectionNameOrderedSet convertToOrderedSet(ProjectionNameSet unordered); + +opt::unordered_set<FieldNameType> toUnorderedFieldNameSet(std::set<FieldNameType> set); + + +void combineLimitSkipProperties(properties::LimitSkipRequirement& aboveProp, + const properties::LimitSkipRequirement& belowProp); + +/** + * Used to track references originating from a set of physical properties. + */ +ProjectionNameSet extractReferencedColumns(const properties::PhysProps& properties); + +/** + * Returns true if all components of the compound interval are equalities. + */ +bool areMultiKeyIntervalsEqualities(const MultiKeyIntervalRequirement& intervals); + +struct CollationSplitResult { + bool _validSplit = false; + ProjectionCollationSpec _leftCollation; + ProjectionCollationSpec _rightCollation; +}; + +/** + * Split a collation requirement between an outer (left) and inner (right) side. The outer side must + * be a prefix in the collation spec, and the right side a suffix. 
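+ * For example, collation (a, b) splits cleanly into left {a} and right {b}; if a left-side
+ * projection appears after a right-side one, the split is invalid.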
+ */ +CollationSplitResult splitCollationSpec(const ProjectionCollationSpec& collationSpec, + const ProjectionNameSet& leftProjections, + const ProjectionNameSet& rightProjections); + +/** + * Used to extract variable references from a node. + */ +using VariableNameSetType = opt::unordered_set<std::string>; +VariableNameSetType collectVariableReferences(const ABT& n); + +/** + * Appends a path to another path. Performs the append at PathIdentity elements. + */ +class PathAppender { +public: + PathAppender(ABT toAppend) : _toAppend(std::move(toAppend)) {} + + void transport(ABT& n, const PathIdentity& node) { + n = _toAppend; + } + + template <typename T, typename... Ts> + void transport(ABT& /*n*/, const T& /*node*/, Ts&&...) { + // noop + } + + void append(ABT& path) { + return algebra::transport<true>(path, *this); + } + +private: + ABT _toAppend; +}; + +struct PartialSchemaReqConversion { + PartialSchemaReqConversion(); + PartialSchemaReqConversion(PartialSchemaRequirements reqMap); + PartialSchemaReqConversion(ABT bound); + + // Is our current bottom-up conversion successful. If not shortcut to top. + bool _success; + + // If set, contains a Constant or Variable bound of an (yet unknown) interval. + boost::optional<ABT> _bound; + + // Requirements we have built so far. + PartialSchemaRequirements _reqMap; + + // Have we added a PathComposeM. + bool _hasIntersected; + + // Have we added a PathTraverse. + bool _hasTraversed; + + // If we have determined that we have a contradiction. + bool _hasEmptyInterval; +}; + +/** + * Takes an expression that comes from an Filter or Evaluation node, and attempt to convert + * to a PartialSchemaReqConversion. This is done independent of the availability of indexes. + * Essentially this means to extract intervals over paths whenever possible. + */ +PartialSchemaReqConversion convertExprToPartialSchemaReq(const ABT& expr); + +bool intersectPartialSchemaReq(PartialSchemaRequirements& target, + const PartialSchemaRequirements& source, + ProjectionRenames& projectionRenames); + + +/** + * Encode an index of an index field as a field name in order to use with a FieldProjectionMap. + */ +std::string encodeIndexKeyName(size_t indexField); + +/** + * Decode an field name as an index field. + */ +size_t decodeIndexKeyName(const std::string& fieldName); + +/** + * Given a partial schema key that specifies an index path, and a map of partial requirements + * created from sargable query conditions, return the partial requirement that matches the + * index path (and thus can be evaluated via this path). + */ +void findMatchingSchemaRequirement(const PartialSchemaKey& indexKey, + const PartialSchemaRequirements& reqMap, + PartialSchemaKeySet& keySet, + PartialSchemaRequirement& req, + bool setIntervalsAndBoundProj = true); + +/** + * Compute a mapping [indexName -> CandidateIndexEntry] that describes intervals that could be + * used for accessing each of the indexes in the map. The intervals themselves are derived from + * 'reqMap'. + * If the intersection of any of the interval requirements in 'reqMap' results in an empty + * interval, return an empty mappting and set 'hasEmptyInterval' to true. + * Otherwise return the computed mapping, and set 'hasEmptyInterval' to false. 
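+ * Each candidate entry also records any residual requirements which could not be answered by
+ * the index bounds and must be re-applied on top of the index scan.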
+ */ +CandidateIndexMap computeCandidateIndexMap(PrefixId& prefixId, + const ProjectionName& scanProjectionName, + const PartialSchemaRequirements& reqMap, + const ScanDefinition& scanDef, + bool& hasEmptyInterval); + +/** + * Used to lower a Sargable node to a subtree consisting of functionally equivalent Filter and Eval + * nodes. + */ +void lowerPartialSchemaRequirement(const PartialSchemaKey& key, + const PartialSchemaRequirement& req, + ABT& node, + const std::function<void(const ABT& node)>& visitor = + [](const ABT&) {}); + +void lowerPartialSchemaRequirements(CEType baseCE, + CEType scanGroupCE, + ResidualRequirements& requirements, + ABT& physNode, + NodeCEMap& nodeCEMap); + +void computePhysicalScanParams(PrefixId& prefixId, + const PartialSchemaRequirements& reqMap, + const PartialSchemaKeyCE& partialSchemaKeyCEMap, + const ProjectionNameOrderPreservingSet& requiredProjections, + ResidualRequirements& residualRequirements, + ProjectionRenames& projectionRenames, + FieldProjectionMap& fieldProjectionMap, + bool& requiresRootProjection); + +void sortResidualRequirements(ResidualRequirements& residualReq); + +void applyProjectionRenames(ProjectionRenames projectionRenames, + ABT& node, + const std::function<void(const ABT& node)>& visitor = [](const ABT&) { + }); + +/** + * Implements an RID Intersect node using Union and GroupBy. + */ +ABT lowerRIDIntersectGroupBy(PrefixId& prefixId, + const ProjectionName& ridProjName, + CEType intersectedCE, + CEType leftCE, + CEType rightCE, + const properties::PhysProps& physProps, + const properties::PhysProps& leftPhysProps, + const properties::PhysProps& rightPhysProps, + ABT leftChild, + ABT rightChild, + NodeCEMap& nodeCEMap, + ChildPropsType& childProps); + +/** + * Implements an RID Intersect node using a HashJoin. 
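+ * The right child's rid is projected under a fresh name and the two children are joined on
+ * rid equality.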
+ */ +ABT lowerRIDIntersectHashJoin(PrefixId& prefixId, + const ProjectionName& ridProjName, + CEType intersectedCE, + CEType leftCE, + CEType rightCE, + const properties::PhysProps& leftPhysProps, + const properties::PhysProps& rightPhysProps, + ABT leftChild, + ABT rightChild, + NodeCEMap& nodeCEMap, + ChildPropsType& childProps); + +ABT lowerRIDIntersectMergeJoin(PrefixId& prefixId, + const ProjectionName& ridProjName, + CEType intersectedCE, + CEType leftCE, + CEType rightCE, + const properties::PhysProps& leftPhysProps, + const properties::PhysProps& rightPhysProps, + ABT leftChild, + ABT rightChild, + NodeCEMap& nodeCEMap, + ChildPropsType& childProps); + +ABT lowerIntervals(PrefixId& prefixId, + const ProjectionName& ridProjName, + FieldProjectionMap indexProjectionMap, + const std::string& scanDefName, + const std::string& indexDefName, + const MultiKeyIntervalReqExpr::Node& intervals, + bool reverseOrder, + CEType indexCE, + CEType scanGroupCE, + NodeCEMap& nodeCEMap); + + +} // namespace mongo::optimizer diff --git a/src/mongo/db/query/plan_executor_factory.cpp b/src/mongo/db/query/plan_executor_factory.cpp index bd11182daa2..c72b768adb2 100644 --- a/src/mongo/db/query/plan_executor_factory.cpp +++ b/src/mongo/db/query/plan_executor_factory.cpp @@ -122,6 +122,7 @@ StatusWith<std::unique_ptr<PlanExecutor, PlanExecutor::Deleter>> make( std::unique_ptr<CanonicalQuery> cq, std::unique_ptr<QuerySolution> solution, std::pair<std::unique_ptr<sbe::PlanStage>, stage_builder::PlanStageData> root, + std::unique_ptr<optimizer::AbstractABTPrinter> optimizerData, const CollectionPtr* collection, size_t plannerOptions, NamespaceString nss, @@ -140,6 +141,7 @@ StatusWith<std::unique_ptr<PlanExecutor, PlanExecutor::Deleter>> make( return {{new PlanExecutorSBE( opCtx, std::move(cq), + std::move(optimizerData), {makeVector<sbe::plan_ranker::CandidatePlan>(sbe::plan_ranker::CandidatePlan{ std::move(solution), std::move(rootStage), std::move(data)}), 0}, @@ -169,6 +171,7 @@ StatusWith<std::unique_ptr<PlanExecutor, PlanExecutor::Deleter>> make( return {{new PlanExecutorSBE(opCtx, std::move(cq), + {}, std::move(candidates), *collection, plannerOptions & QueryPlannerParams::RETURN_OWNED_DATA, diff --git a/src/mongo/db/query/plan_executor_factory.h b/src/mongo/db/query/plan_executor_factory.h index 380e21cb1e4..17694489add 100644 --- a/src/mongo/db/query/plan_executor_factory.h +++ b/src/mongo/db/query/plan_executor_factory.h @@ -35,6 +35,7 @@ #include "mongo/db/exec/working_set.h" #include "mongo/db/pipeline/pipeline.h" #include "mongo/db/pipeline/plan_executor_pipeline.h" +#include "mongo/db/query/optimizer/explain_interface.h" #include "mongo/db/query/plan_executor.h" #include "mongo/db/query/plan_yield_policy_sbe.h" #include "mongo/db/query/query_solution.h" @@ -106,12 +107,14 @@ StatusWith<std::unique_ptr<PlanExecutor, PlanExecutor::Deleter>> make( /** * Constructs a PlanExecutor for the query 'cq' which will execute the SBE plan 'root'. A yield * policy can optionally be provided if the plan should automatically yield during execution. + * "optimizerData" is used to print optimizer ABT plans, and may be empty. 
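+ * When present, the resulting plan is reported in explain output under the "optimizerPlan" field.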
*/ StatusWith<std::unique_ptr<PlanExecutor, PlanExecutor::Deleter>> make( OperationContext* opCtx, std::unique_ptr<CanonicalQuery> cq, std::unique_ptr<QuerySolution> solution, std::pair<std::unique_ptr<sbe::PlanStage>, stage_builder::PlanStageData> root, + std::unique_ptr<optimizer::AbstractABTPrinter> optimizerData, const CollectionPtr* collection, size_t plannerOptions, NamespaceString nss, diff --git a/src/mongo/db/query/plan_executor_sbe.cpp b/src/mongo/db/query/plan_executor_sbe.cpp index 3b7a2db857f..e1dded76398 100644 --- a/src/mongo/db/query/plan_executor_sbe.cpp +++ b/src/mongo/db/query/plan_executor_sbe.cpp @@ -48,6 +48,7 @@ extern FailPoint planExecutorHangBeforeShouldWaitForInserts; PlanExecutorSBE::PlanExecutorSBE(OperationContext* opCtx, std::unique_ptr<CanonicalQuery> cq, + std::unique_ptr<optimizer::AbstractABTPrinter> optimizerData, sbe::CandidatePlans candidates, const CollectionPtr& collection, bool returnOwnedBson, @@ -103,8 +104,7 @@ PlanExecutorSBE::PlanExecutorSBE(OperationContext* opCtx, const auto isMultiPlan = candidates.plans.size() > 1; - uassert(5088500, "Query does not have a valid CanonicalQuery", _cq); - if (!_cq->getExpCtx()->explain) { + if (!_cq || !_cq->getExpCtx()->explain) { // If we're not in explain mode, there is no need to keep rejected candidate plans around. candidates.plans.clear(); } else { @@ -115,6 +115,7 @@ PlanExecutorSBE::PlanExecutorSBE(OperationContext* opCtx, _planExplainer = plan_explainer_factory::make(_root.get(), &_rootData, _solution.get(), + std::move(optimizerData), std::move(candidates.plans), isMultiPlan, std::move(_rootData.debugInfo)); diff --git a/src/mongo/db/query/plan_executor_sbe.h b/src/mongo/db/query/plan_executor_sbe.h index a0f7a62f57b..f20e399d36e 100644 --- a/src/mongo/db/query/plan_executor_sbe.h +++ b/src/mongo/db/query/plan_executor_sbe.h @@ -32,6 +32,7 @@ #include <queue> #include "mongo/db/exec/sbe/stages/stages.h" +#include "mongo/db/query/optimizer/explain_interface.h" #include "mongo/db/query/plan_executor.h" #include "mongo/db/query/plan_explainer_sbe.h" #include "mongo/db/query/plan_yield_policy_sbe.h" @@ -44,6 +45,7 @@ class PlanExecutorSBE final : public PlanExecutor { public: PlanExecutorSBE(OperationContext* opCtx, std::unique_ptr<CanonicalQuery> cq, + std::unique_ptr<optimizer::AbstractABTPrinter> optimizerData, sbe::CandidatePlans candidates, const CollectionPtr& collection, bool returnOwnedBson, diff --git a/src/mongo/db/query/plan_explainer_factory.cpp b/src/mongo/db/query/plan_explainer_factory.cpp index ef4858e1f7c..eecbba3908a 100644 --- a/src/mongo/db/query/plan_explainer_factory.cpp +++ b/src/mongo/db/query/plan_explainer_factory.cpp @@ -47,25 +47,32 @@ std::unique_ptr<PlanExplainer> make(PlanStage* root, const PlanEnumeratorExplain std::unique_ptr<PlanExplainer> make(sbe::PlanStage* root, const stage_builder::PlanStageData* data, const QuerySolution* solution) { - return make(root, data, solution, {}, false); + return make(root, data, solution, {}, {}, false); } std::unique_ptr<PlanExplainer> make(sbe::PlanStage* root, const stage_builder::PlanStageData* data, const QuerySolution* solution, + std::unique_ptr<optimizer::AbstractABTPrinter> optimizerData, std::vector<sbe::plan_ranker::CandidatePlan> rejectedCandidates, bool isMultiPlan) { // Pre-compute Debugging info for explain use. 
diff --git a/src/mongo/db/query/plan_executor_sbe.cpp b/src/mongo/db/query/plan_executor_sbe.cpp
index 3b7a2db857f..e1dded76398 100644
--- a/src/mongo/db/query/plan_executor_sbe.cpp
+++ b/src/mongo/db/query/plan_executor_sbe.cpp
@@ -48,6 +48,7 @@ extern FailPoint planExecutorHangBeforeShouldWaitForInserts;
 PlanExecutorSBE::PlanExecutorSBE(OperationContext* opCtx,
                                  std::unique_ptr<CanonicalQuery> cq,
+                                 std::unique_ptr<optimizer::AbstractABTPrinter> optimizerData,
                                  sbe::CandidatePlans candidates,
                                  const CollectionPtr& collection,
                                  bool returnOwnedBson,
@@ -103,8 +104,7 @@ PlanExecutorSBE::PlanExecutorSBE(OperationContext* opCtx,
     const auto isMultiPlan = candidates.plans.size() > 1;
 
-    uassert(5088500, "Query does not have a valid CanonicalQuery", _cq);
-    if (!_cq->getExpCtx()->explain) {
+    if (!_cq || !_cq->getExpCtx()->explain) {
         // If we're not in explain mode, there is no need to keep rejected candidate plans around.
         candidates.plans.clear();
     } else {
@@ -115,6 +115,7 @@ PlanExecutorSBE::PlanExecutorSBE(OperationContext* opCtx,
     _planExplainer = plan_explainer_factory::make(_root.get(),
                                                   &_rootData,
                                                   _solution.get(),
+                                                  std::move(optimizerData),
                                                   std::move(candidates.plans),
                                                   isMultiPlan,
                                                   std::move(_rootData.debugInfo));
diff --git a/src/mongo/db/query/plan_executor_sbe.h b/src/mongo/db/query/plan_executor_sbe.h
index a0f7a62f57b..f20e399d36e 100644
--- a/src/mongo/db/query/plan_executor_sbe.h
+++ b/src/mongo/db/query/plan_executor_sbe.h
@@ -32,6 +32,7 @@
 #include <queue>
 
 #include "mongo/db/exec/sbe/stages/stages.h"
+#include "mongo/db/query/optimizer/explain_interface.h"
 #include "mongo/db/query/plan_executor.h"
 #include "mongo/db/query/plan_explainer_sbe.h"
 #include "mongo/db/query/plan_yield_policy_sbe.h"
@@ -44,6 +45,7 @@ class PlanExecutorSBE final : public PlanExecutor {
 public:
     PlanExecutorSBE(OperationContext* opCtx,
                     std::unique_ptr<CanonicalQuery> cq,
+                    std::unique_ptr<optimizer::AbstractABTPrinter> optimizerData,
                     sbe::CandidatePlans candidates,
                     const CollectionPtr& collection,
                     bool returnOwnedBson,
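The constructor change above replaces a hard uassert with a null-tolerant check because plans produced by the new optimizer can reach PlanExecutorSBE without a CanonicalQuery. The guard depends on short-circuit evaluation, as in this minimal stand-alone illustration (the Query type and function name are invented for the example):

    #include <cassert>

    struct Query {
        bool explain = false;
    };

    // Equivalent in spirit to '!_cq || !_cq->getExpCtx()->explain': the null test
    // comes first so the pointer is never dereferenced when it is absent.
    bool shouldDropRejectedPlans(const Query* cq) {
        return cq == nullptr || !cq->explain;
    }

    int main() {
        assert(shouldDropRejectedPlans(nullptr));  // CQF path: no CanonicalQuery at all
        Query explainQuery{true};
        assert(!shouldDropRejectedPlans(&explainQuery));  // explain keeps rejected plans
        return 0;
    }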
diff --git a/src/mongo/db/query/plan_explainer_factory.cpp b/src/mongo/db/query/plan_explainer_factory.cpp
index ef4858e1f7c..eecbba3908a 100644
--- a/src/mongo/db/query/plan_explainer_factory.cpp
+++ b/src/mongo/db/query/plan_explainer_factory.cpp
@@ -47,25 +47,32 @@ std::unique_ptr<PlanExplainer> make(PlanStage* root, const PlanEnumeratorExplain
 std::unique_ptr<PlanExplainer> make(sbe::PlanStage* root,
                                     const stage_builder::PlanStageData* data,
                                     const QuerySolution* solution) {
-    return make(root, data, solution, {}, false);
+    return make(root, data, solution, {}, {}, false);
 }
 
 std::unique_ptr<PlanExplainer> make(sbe::PlanStage* root,
                                     const stage_builder::PlanStageData* data,
                                     const QuerySolution* solution,
+                                    std::unique_ptr<optimizer::AbstractABTPrinter> optimizerData,
                                     std::vector<sbe::plan_ranker::CandidatePlan> rejectedCandidates,
                                     bool isMultiPlan) {
     // Pre-compute Debugging info for explain use.
     auto debugInfoSBE = std::make_unique<plan_cache_debug_info::DebugInfoSBE>(
         plan_cache_util::buildDebugInfo(solution));
-    return std::make_unique<PlanExplainerSBE>(
-        root, data, solution, std::move(rejectedCandidates), isMultiPlan, std::move(debugInfoSBE));
+    return std::make_unique<PlanExplainerSBE>(root,
+                                              data,
+                                              solution,
+                                              std::move(optimizerData),
+                                              std::move(rejectedCandidates),
+                                              isMultiPlan,
+                                              std::move(debugInfoSBE));
 }
 
 std::unique_ptr<PlanExplainer> make(
     sbe::PlanStage* root,
     const stage_builder::PlanStageData* data,
     const QuerySolution* solution,
+    std::unique_ptr<optimizer::AbstractABTPrinter> optimizerData,
     std::vector<sbe::plan_ranker::CandidatePlan> rejectedCandidates,
     bool isMultiPlan,
     std::unique_ptr<plan_cache_debug_info::DebugInfoSBE> debugInfoSBE) {
@@ -77,7 +84,12 @@ std::unique_ptr<PlanExplainer> make(
         plan_cache_util::buildDebugInfo(solution));
     }
 
-    return std::make_unique<PlanExplainerSBE>(
-        root, data, solution, std::move(rejectedCandidates), isMultiPlan, std::move(debugInfoSBE));
+    return std::make_unique<PlanExplainerSBE>(root,
+                                              data,
+                                              solution,
+                                              std::move(optimizerData),
+                                              std::move(rejectedCandidates),
+                                              isMultiPlan,
+                                              std::move(debugInfoSBE));
 }
 } // namespace mongo::plan_explainer_factory
diff --git a/src/mongo/db/query/plan_explainer_factory.h b/src/mongo/db/query/plan_explainer_factory.h
index f0165f2c1cb..ddeb3b94e91 100644
--- a/src/mongo/db/query/plan_explainer_factory.h
+++ b/src/mongo/db/query/plan_explainer_factory.h
@@ -31,6 +31,7 @@
 #include "mongo/db/exec/plan_stage.h"
 #include "mongo/db/exec/sbe/stages/stages.h"
+#include "mongo/db/query/optimizer/explain_interface.h"
 #include "mongo/db/query/plan_enumerator_explain_info.h"
 #include "mongo/db/query/plan_explainer.h"
 #include "mongo/db/query/query_solution.h"
@@ -49,12 +50,14 @@ std::unique_ptr<PlanExplainer> make(sbe::PlanStage* root,
 std::unique_ptr<PlanExplainer> make(sbe::PlanStage* root,
                                     const stage_builder::PlanStageData* data,
                                     const QuerySolution* solution,
+                                    std::unique_ptr<optimizer::AbstractABTPrinter> optimizerData,
                                     std::vector<sbe::plan_ranker::CandidatePlan> rejectedCandidates,
                                     bool isMultiPlan);
 
 std::unique_ptr<PlanExplainer> make(sbe::PlanStage* root,
                                     const stage_builder::PlanStageData* data,
                                     const QuerySolution* solution,
+                                    std::unique_ptr<optimizer::AbstractABTPrinter> optimizerData,
                                     std::vector<sbe::plan_ranker::CandidatePlan> rejectedCandidates,
                                     bool isMultiPlan,
                                     std::unique_ptr<plan_cache_debug_info::DebugInfoSBE> debugInfo);
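The three-argument make() overload now forwards an empty printer and an empty candidate list, so existing call sites compile unchanged while the Cascades path supplies a real printer. As a usage sketch only, with hypothetical variable names and relying on the null-solution guards added to buildPlanStatsDetails below:

    // Classic SBE plan: no optimizer data, explain keeps emitting "queryPlan".
    auto classicExplainer = plan_explainer_factory::make(rootStage, &stageData, querySolution.get());

    // Cascades/ABT plan: no QuerySolution, but an ABT printer for "optimizerPlan".
    auto cqfExplainer = plan_explainer_factory::make(rootStage,
                                                     &stageData,
                                                     nullptr /* solution */,
                                                     std::move(abtPrinter),
                                                     {} /* rejectedCandidates */,
                                                     false /* isMultiPlan */);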
"mongo/db/query/optimizer/explain_interface.h" #include "mongo/db/query/plan_cache_debug_info.h" #include "mongo/db/query/plan_explainer.h" #include "mongo/db/query/query_solution.h" @@ -44,6 +45,7 @@ public: PlanExplainerSBE(const sbe::PlanStage* root, const stage_builder::PlanStageData* data, const QuerySolution* solution, + std::unique_ptr<optimizer::AbstractABTPrinter> optimizerData, std::vector<sbe::plan_ranker::CandidatePlan> rejectedCandidates, bool isMultiPlan, std::unique_ptr<plan_cache_debug_info::DebugInfoSBE> debugInfo) @@ -51,6 +53,7 @@ public: _root{root}, _rootData{data}, _solution{solution}, + _optimizerData(std::move(optimizerData)), _rejectedCandidates{std::move(rejectedCandidates)}, _isMultiPlan{isMultiPlan}, _debugInfo{std::move(debugInfo)} { @@ -81,11 +84,15 @@ private: return boost::none; } + boost::optional<BSONObj> buildCascadesPlan() const; + // These fields are are owned elsewhere (e.g. the PlanExecutor or CandidatePlan). const sbe::PlanStage* _root{nullptr}; const stage_builder::PlanStageData* _rootData{nullptr}; const QuerySolution* _solution{nullptr}; + const std::unique_ptr<optimizer::AbstractABTPrinter> _optimizerData; + const std::vector<sbe::plan_ranker::CandidatePlan> _rejectedCandidates; const bool _isMultiPlan{false}; // Pre-computed debugging info so we don't necessarily have to collect them from QuerySolution. diff --git a/src/mongo/db/query/query_feature_flags.idl b/src/mongo/db/query/query_feature_flags.idl index a15f8fbd4c4..e1ba84203f6 100644 --- a/src/mongo/db/query/query_feature_flags.idl +++ b/src/mongo/db/query/query_feature_flags.idl @@ -110,11 +110,16 @@ feature_flags: cpp_varname: gFeatureFlagPerShardCursor default: false + featureFlagCommonQueryFramework: + description: "Feature flag for allowing use of Cascades-based query optimizer" + cpp_varname: gfeatureFlagCommonQueryFramework + default: false + featureFlagLastPointQuery: description : "Feature flag for optimizing Last Point queries on time-series collections" cpp_varname: gfeatureFlagLastPointQuery default: false - + featureFlagChangeStreamPreAndPostImagesTimeBasedRetentionPolicy: description: "Feature flag to enable time based retention policy of point-in-time pre- and post-images of documents in change streams" cpp_varname: gFeatureFlagChangeStreamPreAndPostImagesTimeBasedRetentionPolicy diff --git a/src/mongo/db/query/query_knobs.idl b/src/mongo/db/query/query_knobs.idl index 0b30f179764..52952861b6e 100644 --- a/src/mongo/db/query/query_knobs.idl +++ b/src/mongo/db/query/query_knobs.idl @@ -641,3 +641,66 @@ server_parameters: cpp_varname: "internalEnableMultipleAutoGetCollections" cpp_vartype: AtomicWord<bool> default: false + + internalQueryEnableSamplingCardinalityEstimator: + description: "Set to use the sampling-based method for estimating cardinality." + set_at: [ startup, runtime ] + cpp_varname: "internalQueryEnableSamplingCardinalityEstimator" + cpp_vartype: AtomicWord<bool> + default: true + + internalQueryEnableCascadesOptimizer: + description: "Set to use the new optimizer path, must be used in conjunction with the feature flag." + set_at: [ startup, runtime ] + cpp_varname: "internalQueryEnableCascadesOptimizer" + cpp_vartype: AtomicWord<bool> + default: false + + internalCascadesOptimizerDisableScan: + description: "Disable full collection scans." + set_at: [ startup, runtime ] + cpp_varname: "internalCascadesOptimizerDisableScan" + cpp_vartype: AtomicWord<bool> + default: false + + internalCascadesOptimizerDisableIndexes: + description: "Disable index scan plans." 
diff --git a/src/mongo/db/query/query_feature_flags.idl b/src/mongo/db/query/query_feature_flags.idl
index a15f8fbd4c4..e1ba84203f6 100644
--- a/src/mongo/db/query/query_feature_flags.idl
+++ b/src/mongo/db/query/query_feature_flags.idl
@@ -110,11 +110,16 @@ feature_flags:
       cpp_varname: gFeatureFlagPerShardCursor
       default: false
 
+    featureFlagCommonQueryFramework:
+      description: "Feature flag for allowing use of Cascades-based query optimizer"
+      cpp_varname: gfeatureFlagCommonQueryFramework
+      default: false
+
     featureFlagLastPointQuery:
       description : "Feature flag for optimizing Last Point queries on time-series collections"
       cpp_varname: gfeatureFlagLastPointQuery
       default: false
-
+
     featureFlagChangeStreamPreAndPostImagesTimeBasedRetentionPolicy:
       description: "Feature flag to enable time based retention policy of point-in-time pre- and post-images of documents in change streams"
       cpp_varname: gFeatureFlagChangeStreamPreAndPostImagesTimeBasedRetentionPolicy
diff --git a/src/mongo/db/query/query_knobs.idl b/src/mongo/db/query/query_knobs.idl
index 0b30f179764..52952861b6e 100644
--- a/src/mongo/db/query/query_knobs.idl
+++ b/src/mongo/db/query/query_knobs.idl
@@ -641,3 +641,66 @@ server_parameters:
     cpp_varname: "internalEnableMultipleAutoGetCollections"
     cpp_vartype: AtomicWord<bool>
     default: false
+
+  internalQueryEnableSamplingCardinalityEstimator:
+    description: "Set to use the sampling-based method for estimating cardinality."
+    set_at: [ startup, runtime ]
+    cpp_varname: "internalQueryEnableSamplingCardinalityEstimator"
+    cpp_vartype: AtomicWord<bool>
+    default: true
+
+  internalQueryEnableCascadesOptimizer:
+    description: "Set to use the new optimizer path, must be used in conjunction with the feature flag."
+    set_at: [ startup, runtime ]
+    cpp_varname: "internalQueryEnableCascadesOptimizer"
+    cpp_vartype: AtomicWord<bool>
+    default: false
+
+  internalCascadesOptimizerDisableScan:
+    description: "Disable full collection scans."
+    set_at: [ startup, runtime ]
+    cpp_varname: "internalCascadesOptimizerDisableScan"
+    cpp_vartype: AtomicWord<bool>
+    default: false
+
+  internalCascadesOptimizerDisableIndexes:
+    description: "Disable index scan plans."
+    set_at: [ startup, runtime ]
+    cpp_varname: "internalCascadesOptimizerDisableIndexes"
+    cpp_vartype: AtomicWord<bool>
+    default: false
+
+  internalCascadesOptimizerDisableMergeJoinRIDIntersect:
+    description: "Disable index RID intersection via merge join."
+    set_at: [ startup, runtime ]
+    cpp_varname: "internalCascadesOptimizerDisableMergeJoinRIDIntersect"
+    cpp_vartype: AtomicWord<bool>
+    default: false
+
+  internalCascadesOptimizerDisableHashJoinRIDIntersect:
+    description: "Disable index RID intersection via hash join."
+    set_at: [ startup, runtime ]
+    cpp_varname: "internalCascadesOptimizerDisableHashJoinRIDIntersect"
+    cpp_vartype: AtomicWord<bool>
+    default: false
+
+  internalCascadesOptimizerDisableGroupByAndUnionRIDIntersect:
+    description: "Disable index RID intersection via group by and union."
+    set_at: [ startup, runtime ]
+    cpp_varname: "internalCascadesOptimizerDisableGroupByAndUnionRIDIntersect"
+    cpp_vartype: AtomicWord<bool>
+    default: false
+
+  internalCascadesOptimizerKeepRejectedPlans:
+    description: "Keep track of rejected plans in the memo."
+    set_at: [ startup, runtime ]
+    cpp_varname: "internalCascadesOptimizerKeepRejectedPlans"
+    cpp_vartype: AtomicWord<bool>
+    default: false
+
+  internalCascadesOptimizerDisableBranchAndBound:
+    description: "Disable cascades branch-and-bound strategy, and fully evaluate all plans."
+    set_at: [ startup, runtime ]
+    cpp_varname: "internalCascadesOptimizerDisableBranchAndBound"
+    cpp_vartype: AtomicWord<bool>
+    default: false
diff --git a/src/mongo/db/query/sbe_cached_solution_planner.cpp b/src/mongo/db/query/sbe_cached_solution_planner.cpp
index f9b7ba125ea..d0fd8db0b81 100644
--- a/src/mongo/db/query/sbe_cached_solution_planner.cpp
+++ b/src/mongo/db/query/sbe_cached_solution_planner.cpp
@@ -73,6 +73,7 @@ CandidatePlans CachedSolutionPlanner::plan(
         candidate.root.get(),
         &candidate.data,
         candidate.solution.get(),
+        {},    /* optimizedData */
        {},    /* rejectedCandidates */
        false, /* isMultiPlan */
        candidate.data.debugInfo
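Each knob declared above in query_knobs.idl uses cpp_vartype AtomicWord<bool>, so the generated server parameter is a process-wide atomic flag that can be set at startup or flipped at runtime and is consulted by the optimizer at plan time. A minimal sketch of that consumption pattern, with std::atomic<bool> standing in for the generated globals and with invented helper names (the required featureFlagCommonQueryFramework check is omitted):

    #include <atomic>

    // Stand-ins for the IDL-generated knobs; the real definitions are emitted from
    // query_knobs.idl under the cpp_varname given for each parameter.
    std::atomic<bool> internalQueryEnableCascadesOptimizer{false};
    std::atomic<bool> internalCascadesOptimizerDisableScan{false};

    // Hypothetical gating helper: the Cascades path is taken only when the knob is on.
    bool useCascadesOptimizer() {
        return internalQueryEnableCascadesOptimizer.load();
    }

    // Hypothetical pruning helper: collection-scan alternatives are skipped when disabled.
    bool allowCollectionScanPlans() {
        return !internalCascadesOptimizerDisableScan.load();
    }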