author     Yoonsoo Kim <yoonsoo.kim@mongodb.com>  2022-03-14 23:31:56 +0000
committer  Evergreen Agent <no-reply@evergreen.mongodb.com>  2022-03-15 00:20:27 +0000
commit     a9b05b292e6d92ad08244578b7d48a6ea7506b90 (patch)
tree       4a8038d208c3b49792a8e1770c7cedb2f6b0f6d3
parent     8d7a536c1e60452028501e7871bace126aee1d65 (diff)
download   mongo-a9b05b292e6d92ad08244578b7d48a6ea7506b90.tar.gz
SERVER-63260 Fix incorrect $avg result when merging partial results
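
The shard-side $avg accumulator now serializes the full state of its partial
sum in a new 'ps' field (alongside the legacy 'subTotal'/'subTotalError'
fields), so the merge side no longer loses type or precision information when
combining partial results.

A minimal shell sketch of the affected query shape, using the same data as the
new test (collection and field names are illustrative):

    // Mixed double/decimal values whose doubles cancel out could previously
    // lose the decimal remainder when partial averages were merged on a
    // sharded cluster.
    db.c.insert([
        {k: 0, n: 1e+34},
        {k: 0, n: NumberDecimal("0.1")},
        {k: 0, n: NumberDecimal("0.01")},
        {k: 0, n: -1e+34}
    ]);
    // With the fix, this returns a decimal average of about 0.0275
    // ((0.1 + 0.01) / 4) instead of a result with the decimal part dropped.
    db.c.aggregate([{$group: {_id: "$k", o: {$avg: "$n"}}}]);
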
-rw-r--r--  jstests/noPassthrough/accumulator_bug_fix.js           620
-rw-r--r--  jstests/noPassthrough/sum_accumulator_bug_fix.js       415
-rw-r--r--  src/mongo/db/pipeline/accumulator.h                      3
-rw-r--r--  src/mongo/db/pipeline/accumulator_avg.cpp               87
-rw-r--r--  src/mongo/db/pipeline/accumulator_sum.cpp              123
-rw-r--r--  src/mongo/db/query/sbe_stage_builder_accumulator.cpp    19
6 files changed, 772 insertions(+), 495 deletions(-)
diff --git a/jstests/noPassthrough/accumulator_bug_fix.js b/jstests/noPassthrough/accumulator_bug_fix.js
new file mode 100644
index 00000000000..ef6289282ac
--- /dev/null
+++ b/jstests/noPassthrough/accumulator_bug_fix.js
@@ -0,0 +1,620 @@
+/**
+ * Tests whether the incorrect result bug in the $sum and $avg accumulators is fixed on both
+ * engines under FCV 6.0.
+ *
+ * @tags: [
+ * requires_fcv_60,
+ * requires_sharding,
+ * ]
+ */
+(function() {
+'use strict';
+
+(function testAccumulatorWhenSpillingOnClassicEngine() {
+ const verifyAccumulatorSpillingResult = (testDesc, accSpec) => {
+ const conn = MongoRunner.runMongod();
+
+ const db = conn.getDB(jsTestName());
+ const coll = db.spilling;
+
+ for (let i = 0; i < 100; ++i) {
+ assert.commandWorked(coll.insert([
+ {k: i, n: 1e+34},
+ {k: i, n: NumberDecimal("0.1")},
+ {k: i, n: NumberDecimal("0.01")},
+ {k: i, n: -1e+34}
+ ]));
+ }
+
+        // Turns on the classic engine.
+ assert.commandWorked(
+ db.adminCommand({setParameter: 1, internalQueryForceClassicEngine: true}));
+
+ const pipeline = [{$group: {_id: "$k", o: accSpec}}, {$group: {_id: "$o"}}];
+
+        // The results computed without spilling serve as the expected results.
+ const expectedRes = coll.aggregate(pipeline).toArray();
+
+        // Lowers the memory limit so that DocumentSourceGroup will spill.
+ assert.commandWorked(
+ db.adminCommand({setParameter: 1, internalDocumentSourceGroupMaxMemoryBytes: 1000}),
+ testDesc);
+
+        // Makes sure that DocumentSourceGroup must spill: the query fails when disk use is not
+        // allowed.
+ assert.commandFailedWithCode(
+ coll.runCommand(
+ {aggregate: coll.getName(), pipeline: pipeline, cursor: {}, allowDiskUse: false}),
+ ErrorCodes.QueryExceededMemoryLimitNoDiskUseAllowed,
+ testDesc);
+
+ const classicSpillingRes = coll.aggregate(pipeline, {allowDiskUse: true}).toArray();
+ assert.eq(classicSpillingRes, expectedRes, testDesc);
+
+ MongoRunner.stopMongod(conn);
+ };
+
+ verifyAccumulatorSpillingResult("Verifying $sum spilling bug is fixed on the classic engine",
+ {$sum: "$n"});
+
+ verifyAccumulatorSpillingResult("Verifying $avg spilling bug is fixed on the classic engine",
+ {$avg: "$n"});
+}());
+
+(function testOverTheWireDataFormatOnBothEngines() {
+ const conn = MongoRunner.runMongod();
+
+ const db = conn.getDB(jsTestName());
+ const coll = db.spilling;
+
+ const verifyOverTheWireDataFormatOnBothEngines = (testDesc, pipeline, expectedRes) => {
+ const aggCmd = {
+ aggregate: coll.getName(),
+ pipeline: pipeline,
+ needsMerge: true,
+ fromMongos: true,
+ cursor: {}
+ };
+
+        // Turns on the classic engine.
+ assert.commandWorked(
+ db.adminCommand({setParameter: 1, internalQueryForceClassicEngine: true}));
+ const classicRes = assert.commandWorked(db.runCommand(aggCmd)).cursor.firstBatch;
+ assert.eq(classicRes, expectedRes, testDesc);
+
+        // Turns off the classic engine.
+ assert.commandWorked(
+ db.adminCommand({setParameter: 1, internalQueryForceClassicEngine: false}));
+ const sbeRes = assert.commandWorked(db.runCommand(aggCmd)).cursor.firstBatch;
+ assert.eq(sbeRes, expectedRes, testDesc);
+ };
+
+ (function testOverTheWireDataFormat() {
+ const pipelineWithSum = [{$group: {_id: null, o: {$sum: "$n"}}}];
+ const pipelineWithAvg = [{$group: {_id: null, o: {$avg: "$n"}}}];
+
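+        // The expected partial sums below follow the serialized format
+        // [nonDecimalTotalTypeId, sum, addend], with a NumberDecimal total appended as a fourth
+        // element when the overall total type is decimal.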
+ assert.commandWorked(coll.insert({n: NumberInt(1)}));
+ let expectedPartialSum = [
+ NumberInt(16), // The type id for NumberInt
+ 1.0, // sum
+ 0.0 // addend
+ ];
+ verifyOverTheWireDataFormatOnBothEngines(
+ "Partial sum of an int", pipelineWithSum, [{_id: null, o: expectedPartialSum}]);
+ verifyOverTheWireDataFormatOnBothEngines(
+ "Partial avg of an int", pipelineWithAvg, [{
+ _id: null,
+ o: {subTotal: 1.0, count: NumberLong(1), subTotalError: 0.0, ps: expectedPartialSum}
+ }]);
+
+ assert.commandWorked(coll.insert({n: NumberLong(1)}));
+ expectedPartialSum = [
+ NumberInt(18), // The type id for NumberLong
+ 2.0, // sum
+ 0.0 // addend
+ ];
+ verifyOverTheWireDataFormatOnBothEngines("Partial sum of an int and a long",
+ pipelineWithSum,
+ [{_id: null, o: expectedPartialSum}]);
+ verifyOverTheWireDataFormatOnBothEngines(
+ "Partial avg of an int and a long", pipelineWithAvg, [{
+ _id: null,
+ o: {subTotal: 2.0, count: NumberLong(2), subTotalError: 0.0, ps: expectedPartialSum}
+ }]);
+
+ assert.commandWorked(coll.insert({n: NumberLong("9223372036854775807")}));
+ expectedPartialSum = [
+ NumberInt(18), // The type id for NumberLong
+ 9223372036854775808.0, // sum
+ 1.0 // addend
+ ];
+ verifyOverTheWireDataFormatOnBothEngines("Partial sum of an int/a long/the long max",
+ pipelineWithSum,
+ [{_id: null, o: expectedPartialSum}]);
+ verifyOverTheWireDataFormatOnBothEngines(
+ "Partial avg of an int/a long/the long max", pipelineWithAvg, [{
+ _id: null,
+ o: {
+ subTotal: 9223372036854775808.0,
+ count: NumberLong(3),
+ subTotalError: 1.0,
+ ps: expectedPartialSum
+ }
+ }]);
+
+        // A double can always express 15 significant digits precisely. So, 1.0 + 0.00000000000001
+        // is precisely expressed by the 'addend' element.
+ assert.commandWorked(coll.insert({n: 0.00000000000001}));
+ expectedPartialSum = [
+ NumberInt(1), // The type id for NumberDouble
+ 9223372036854775808.0, // sum
+ 1.00000000000001 // addend
+ ];
+ verifyOverTheWireDataFormatOnBothEngines(
+ "Partial sum of mixed data leading to a number that a double can't express",
+ pipelineWithSum,
+ [{_id: null, o: expectedPartialSum}]);
+ verifyOverTheWireDataFormatOnBothEngines(
+ "Partial avg of mixed data leading to a number that a double can't express",
+ pipelineWithAvg,
+ [{
+ _id: null,
+ o: {
+ subTotal: 9223372036854775808.0,
+ count: NumberLong(4),
+ subTotalError: 1.00000000000001,
+ ps: expectedPartialSum
+ }
+ }]);
+
+ assert.commandWorked(coll.insert({n: NumberDecimal("1.0")}));
+ expectedPartialSum = [
+ NumberInt(1), // The type id for NumberDouble
+ 9223372036854775808.0, // sum
+ 1.00000000000001, // addend
+ NumberDecimal("1.0")
+ ];
+ verifyOverTheWireDataFormatOnBothEngines("Partial sum of mixed data which has a decimal",
+ pipelineWithSum,
+ [{_id: null, o: expectedPartialSum}]);
+ verifyOverTheWireDataFormatOnBothEngines(
+ "Partial avg of mixed data which has a decimal", pipelineWithAvg, [{
+ _id: null,
+ o: {
+ subTotal: NumberDecimal("9223372036854775810.000000000000010"),
+ count: NumberLong(5),
+ ps: expectedPartialSum
+ }
+ }]);
+
+ assert(coll.drop());
+
+ assert.commandWorked(coll.insert([{n: Number.MAX_VALUE}, {n: Number.MAX_VALUE}]));
+ expectedPartialSum = [
+ NumberInt(1), // The type id for NumberDouble
+ Infinity, // sum
+ NaN // addend
+ ];
+ verifyOverTheWireDataFormatOnBothEngines(
+ "Partial sum of two double max", pipelineWithSum, [{_id: null, o: expectedPartialSum}]);
+ verifyOverTheWireDataFormatOnBothEngines(
+ "Partial avg of two double max", pipelineWithAvg, [{
+ _id: null,
+ o: {
+ subTotal: Infinity,
+ count: NumberLong(2),
+ subTotalError: NaN,
+ ps: expectedPartialSum
+ }
+ }]);
+
+ assert(coll.drop());
+
+ assert.commandWorked(coll.insert([{n: NumberDecimal("1.0")}, {n: 1.0}]));
+ expectedPartialSum = [
+ NumberInt(1), // The type id for NumberDouble
+ 1.0, // sum
+ 0.0, // addend
+ NumberDecimal("1.0")
+ ];
+ verifyOverTheWireDataFormatOnBothEngines("Partial sum of a decimal and a double",
+ pipelineWithSum,
+ [{_id: null, o: expectedPartialSum}]);
+ verifyOverTheWireDataFormatOnBothEngines(
+ "Partial avg of a decimal and a double", pipelineWithAvg, [{
+ _id: null,
+ o: {subTotal: NumberDecimal("2.0"), count: NumberLong(2), ps: expectedPartialSum}
+ }]);
+ }());
+
+ MongoRunner.stopMongod(conn);
+}());
+
+(function testShardedAccumulatorOnBothEngines() {
+ const st = new ShardingTest({shards: 2});
+
+ const db = st.getDB(jsTestName());
+ const dbAtShard0 = st.shard0.getDB(jsTestName());
+ const dbAtShard1 = st.shard1.getDB(jsTestName());
+
+ // Makes sure that the test db is sharded.
+ assert.commandWorked(st.s0.adminCommand({enableSharding: db.getName()}));
+
+ let verifyShardedAccumulatorResultsOnBothEngine = (testDesc, coll, pipeline, expectedRes) => {
+        // Switches to the classic engine on the shards.
+ assert.commandWorked(
+ dbAtShard0.adminCommand({setParameter: 1, internalQueryForceClassicEngine: true}));
+ assert.commandWorked(
+ dbAtShard1.adminCommand({setParameter: 1, internalQueryForceClassicEngine: true}));
+
+        // Verifies that the classic engine's results are the same as the expected results.
+ const classicRes = coll.aggregate(pipeline).toArray();
+ assert.eq(classicRes, expectedRes, testDesc);
+
+        // Switches to the SBE engine on the shards.
+ assert.commandWorked(
+ dbAtShard0.adminCommand({setParameter: 1, internalQueryForceClassicEngine: false}));
+ assert.commandWorked(
+ dbAtShard1.adminCommand({setParameter: 1, internalQueryForceClassicEngine: false}));
+
+        // Verifies that the SBE engine's results are the same as the expected results.
+ const sbeRes = coll.aggregate(pipeline).toArray();
+ assert.eq(sbeRes, expectedRes, testDesc);
+ };
+
+ let shardCollectionByHashing = coll => {
+ coll.drop();
+
+ // Makes sure that the collection is sharded.
+ assert.commandWorked(
+ st.s0.adminCommand({shardCollection: coll.getFullName(), key: {_id: "hashed"}}));
+
+ return coll;
+ };
+
+ let hashShardedColl = shardCollectionByHashing(db.partial_sum);
+ let unshardedColl = db.partial_sum2;
+
+ for (let i = 0; i < 3; ++i) {
+ assert.commandWorked(hashShardedColl.insert([
+ {k: i, n: 1e+34},
+ {k: i, n: NumberDecimal("0.1")},
+ {k: i, n: NumberDecimal("0.01")},
+ {k: i, n: -1e+34}
+ ]));
+ assert.commandWorked(unshardedColl.insert([
+ {k: i, n: 1e+34},
+ {k: i, n: NumberDecimal("0.1")},
+ {k: i, n: NumberDecimal("0.01")},
+ {k: i, n: -1e+34}
+ ]));
+ }
+
+ let pipeline = [{$group: {_id: "$k", s: {$sum: "$n"}}}, {$group: {_id: "$s"}}];
+
+    // The results on an unsharded collection serve as the expected results.
+ let expectedRes = unshardedColl.aggregate(pipeline).toArray();
+ verifyShardedAccumulatorResultsOnBothEngine(
+ "Sharded sum for mixed data by which only decimal sum survive",
+ hashShardedColl,
+ pipeline,
+ expectedRes);
+
+ pipeline = [{$group: {_id: "$k", s: {$avg: "$n"}}}, {$group: {_id: "$s"}}];
+
+    // The results on an unsharded collection serve as the expected results.
+ expectedRes = unshardedColl.aggregate(pipeline).toArray();
+ verifyShardedAccumulatorResultsOnBothEngine(
+ "Sharded avg for mixed data by which only decimal sum survive",
+ hashShardedColl,
+ pipeline,
+ expectedRes);
+
+ const int32Max = 2147483647;
+ const numberIntMax = NumberInt(int32Max);
+ const numberLongMax = NumberLong("9223372036854775807");
+ const verySmallDecimal = NumberDecimal("1e-15");
+ const veryLargeDecimal = NumberDecimal("1e+33");
+
+ // This value is precisely representable by a double.
+ const doubleClosestToLongMax = 9223372036854775808.0;
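+    // It equals 2^63; the NumberLong max (2^63 - 1) itself has no exact double representation.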
+ [{
+ testDesc: "No documents evaluated",
+ inputs: [{}],
+ expectedRes: [{_id: null, o: NumberInt(0)}]
+ },
+ {
+ testDesc: "An int",
+ inputs: [{n: NumberInt(10)}],
+ expectedRes: [{_id: null, o: NumberInt(10)}]
+ },
+ {
+ testDesc: "a long",
+ inputs: [{n: NumberLong(10)}],
+ expectedRes: [{_id: null, o: NumberLong(10)}]
+ },
+ {testDesc: "A double", inputs: [{n: 10.0}], expectedRes: [{_id: null, o: 10.0}]},
+ {
+ testDesc: "A long that cannot be expressed as an int",
+ inputs: [{n: NumberLong("60000000000")}],
+ expectedRes: [{_id: null, o: NumberLong("60000000000")}]
+ },
+ {
+ testDesc: "A non integer valued double",
+ inputs: [{n: 7.5}],
+ expectedRes: [{_id: null, o: 7.5}]
+ },
+ {testDesc: "A nan double", inputs: [{n: NaN}], expectedRes: [{_id: null, o: NaN}]},
+ {testDesc: "A -nan double", inputs: [{n: -NaN}], expectedRes: [{_id: null, o: -NaN}]},
+ {
+ testDesc: "A infinity double",
+ inputs: [{n: Infinity}],
+ expectedRes: [{_id: null, o: Infinity}]
+ },
+ {
+ testDesc: "A -infinity double",
+ inputs: [{n: -Infinity}],
+ expectedRes: [{_id: null, o: -Infinity}]
+ },
+ {
+ testDesc: "Two ints are summed",
+ inputs: [{n: NumberInt(4)}, {n: NumberInt(5)}],
+ expectedRes: [{_id: null, o: NumberInt(9)}]
+ },
+ {
+ testDesc: "An int and a long",
+ inputs: [{n: NumberInt(4)}, {n: NumberLong(5)}],
+ expectedRes: [{_id: null, o: NumberLong(9)}]
+ },
+ {
+ testDesc: "Two longs",
+ inputs: [{n: NumberLong(4)}, {n: NumberLong(5)}],
+ expectedRes: [{_id: null, o: NumberLong(9)}]
+ },
+ {
+ testDesc: "An int and a double",
+ inputs: [{n: NumberInt(4)}, {n: 5.5}],
+ expectedRes: [{_id: null, o: 9.5}]
+ },
+ {
+ testDesc: "A long and a double",
+ inputs: [{n: NumberLong(4)}, {n: 5.5}],
+ expectedRes: [{_id: null, o: 9.5}]
+ },
+ {testDesc: "Two doubles", inputs: [{n: 2.5}, {n: 5.5}], expectedRes: [{_id: null, o: 8.0}]},
+ {
+ testDesc: "An int, a long, and a double",
+ inputs: [{n: NumberInt(5)}, {n: NumberLong(99)}, {n: 0.2}],
+ expectedRes: [{_id: null, o: 104.2}]
+ },
+ {
+ testDesc: "Two decimals",
+ inputs: [{n: NumberDecimal("-10.100")}, {n: NumberDecimal("20.200")}],
+ expectedRes: [{_id: null, o: NumberDecimal("10.100")}]
+ },
+ {
+ testDesc: "Two longs and a decimal",
+ inputs: [{n: NumberLong(10)}, {n: NumberLong(10)}, {n: NumberDecimal("10.000")}],
+ expectedRes: [{_id: null, o: NumberDecimal("30.000")}]
+ },
+ {
+ testDesc: "A double and a decimal",
+ inputs: [{n: 2.5}, {n: NumberDecimal("2.5")}],
+ expectedRes: [{_id: null, o: NumberDecimal("5.0")}]
+ },
+ {
+ testDesc: "An int, long, double and decimal",
+ inputs: [{n: NumberInt(10)}, {n: NumberLong(10)}, {n: 10.5}, {n: NumberDecimal("9.6")}],
+ expectedRes: [{_id: null, o: NumberDecimal("40.1")}]
+ },
+ {
+ testDesc: "A long max and a very small decimal resulting in 34 digits",
+ inputs: [{n: numberLongMax}, {n: verySmallDecimal}],
+ expectedRes: [{_id: null, o: NumberDecimal("9223372036854775807.000000000000001")}]
+ },
+ {
+ testDesc: "A long and a very large decimal resulting in 34 digits",
+ inputs: [{n: NumberLong(1)}, {n: veryLargeDecimal}],
+ expectedRes: [{_id: null, o: NumberDecimal("1000000000000000000000000000000001")}]
+ },
+ {
+ testDesc:
+ "The double closest to the long max and a very small decimal resulting in 34 digits",
+ inputs: [{n: doubleClosestToLongMax}, {n: verySmallDecimal}],
+ expectedRes: [{_id: null, o: NumberDecimal("9223372036854775808.000000000000001")}]
+ },
+ {
+ testDesc: "A double and a very large decimal resulting in 34 digits",
+ inputs: [{n: 1.0}, {n: veryLargeDecimal}],
+ expectedRes: [{_id: null, o: NumberDecimal("1000000000000000000000000000000001")}]
+ },
+ {
+ testDesc: "A negative value is summed",
+ inputs: [{n: NumberInt(5)}, {n: -8.5}],
+ expectedRes: [{_id: null, o: -3.5}]
+ },
+ {
+ testDesc: "A long and a negative int are summed",
+ inputs: [{n: NumberLong(5)}, {n: NumberInt(-6)}],
+ expectedRes: [{_id: null, o: NumberLong(-1)}]
+ },
+ {
+ testDesc: "Two ints do not overflow",
+ inputs: [{n: numberIntMax}, {n: NumberInt(10)}],
+ expectedRes: [{_id: null, o: NumberLong(int32Max + 10)}]
+ },
+ {
+ testDesc: "Two negative ints do not overflow",
+ inputs: [{n: NumberInt(-int32Max)}, {n: NumberInt(-10)}],
+ expectedRes: [{_id: null, o: NumberLong(-int32Max - 10)}]
+ },
+ {
+ testDesc: "An int and a long do not trigger an int overflow",
+ inputs: [{n: numberIntMax}, {n: NumberLong(1)}],
+ expectedRes: [{_id: null, o: NumberLong(int32Max + 1)}]
+ },
+ {
+ testDesc: "An int and a double do not trigger an int overflow",
+ inputs: [{n: numberIntMax}, {n: 1.0}],
+ expectedRes: [{_id: null, o: int32Max + 1.0}]
+ },
+ {
+ testDesc: "An int and a long overflow into a double",
+ inputs: [{n: NumberInt(1)}, {n: numberLongMax}],
+ expectedRes: [{_id: null, o: doubleClosestToLongMax}]
+ },
+ {
+ testDesc: "Two longs overflow into a double",
+ inputs: [{n: numberLongMax}, {n: numberLongMax}],
+ expectedRes: [{_id: null, o: doubleClosestToLongMax * 2}]
+ },
+ {
+ testDesc: "A long and a double do not trigger a long overflow",
+ inputs: [{n: numberLongMax}, {n: 1.0}],
+ expectedRes: [{_id: null, o: doubleClosestToLongMax}]
+ },
+ {
+ testDesc: "Two doubles overflow to infinity",
+ inputs: [{n: Number.MAX_VALUE}, {n: Number.MAX_VALUE}],
+ expectedRes: [{_id: null, o: Infinity}]
+ },
+ {
+ testDesc: "Two large integers do not overflow if a double is added later",
+ inputs: [{n: numberLongMax}, {n: numberLongMax}, {n: 1.0}],
+ expectedRes: [{_id: null, o: doubleClosestToLongMax * 2}]
+ },
+ {
+ testDesc: "An int and a NaN double",
+ inputs: [{n: NumberInt(4)}, {n: NaN}],
+ expectedRes: [{_id: null, o: NaN}]
+ },
+ {
+ testDesc: "Null values are ignored",
+ inputs: [{n: NumberInt(5)}, {n: null}],
+ expectedRes: [{_id: null, o: NumberInt(5)}]
+ },
+ {
+ testDesc: "Missing values are ignored",
+ inputs: [{n: NumberInt(9)}, {}],
+ expectedRes: [{_id: null, o: NumberInt(9)}]
+ }].forEach(({testDesc, inputs, expectedRes}) => {
+ hashShardedColl.drop();
+ assert.commandWorked(hashShardedColl.insert(inputs));
+
+ verifyShardedAccumulatorResultsOnBothEngine(
+ testDesc, hashShardedColl, [{$group: {_id: null, o: {$sum: "$n"}}}], expectedRes);
+ });
+
+ [{testDesc: "No documents evaluated", inputs: [{}], expectedRes: [{_id: null, o: null}]},
+ {
+ testDesc: "One int value is converted to double",
+ inputs: [{n: NumberInt(3)}],
+ expectedRes: [{_id: null, o: 3.0}]
+ },
+ {
+ testDesc: "One long value is converted to double",
+ inputs: [{n: NumberLong(-4)}],
+ expectedRes: [{_id: null, o: -4.0}]
+ },
+ {testDesc: "One double value", inputs: [{n: 22.6}], expectedRes: [{_id: null, o: 22.6}]},
+ {
+ testDesc: "Averaging two ints",
+ inputs: [{n: NumberInt(10)}, {n: NumberInt(11)}],
+ expectedRes: [{_id: null, o: 10.5}]
+ },
+ {
+ testDesc: "Averaging two longs",
+ inputs: [{n: NumberLong(10)}, {n: NumberLong(11)}],
+ expectedRes: [{_id: null, o: 10.5}]
+ },
+ {
+ testDesc: "Averaging two doubles",
+ inputs: [{n: 10.0}, {n: 11.0}],
+ expectedRes: [{_id: null, o: 10.5}]
+ },
+ {
+ testDesc: "The average of an int and a double is a double",
+ inputs: [{n: NumberInt(10)}, {n: 11.0}],
+ expectedRes: [{_id: null, o: 10.5}]
+ },
+ {
+ testDesc: "The average of a long and a double is a double",
+ inputs: [{n: NumberLong(10)}, {n: 11.0}],
+ expectedRes: [{_id: null, o: 10.5}]
+ },
+ {
+ testDesc: "The average of an int and a long is a double",
+ inputs: [{n: NumberInt(5)}, {n: NumberLong(3)}],
+ expectedRes: [{_id: null, o: 4.0}]
+ },
+ {
+ testDesc: "Averaging an int, long, and double",
+ inputs: [{n: NumberInt(1)}, {n: NumberLong(2)}, {n: 6.0}],
+ expectedRes: [{_id: null, o: 3.0}]
+ },
+ {
+ testDesc: "Unlike $sum, two ints do not overflow in the 'total' portion of the average",
+ inputs: [{n: numberIntMax}, {n: numberIntMax}],
+ expectedRes: [{_id: null, o: int32Max}]
+ },
+ {
+ testDesc: "Two longs do overflow in the 'total' portion of the average",
+ inputs: [{n: numberLongMax}, {n: numberLongMax}],
+ expectedRes: [{_id: null, o: doubleClosestToLongMax}]
+ },
+ {
+ testDesc: "Averaging an Infinity and a number",
+ inputs: [{n: Infinity}, {n: 1}],
+ expectedRes: [{_id: null, o: Infinity}]
+ },
+ {
+ testDesc: "Averaging two Infinities",
+ inputs: [{n: Infinity}, {n: Infinity}],
+ expectedRes: [{_id: null, o: Infinity}]
+ },
+ {
+ testDesc: "Averaging an Infinity and an NaN",
+ inputs: [{n: Infinity}, {n: NaN}],
+ expectedRes: [{_id: null, o: NaN}]
+ },
+ {
+ testDesc: "Averaging an NaN and a number",
+ inputs: [{n: NaN}, {n: 1}],
+ expectedRes: [{_id: null, o: NaN}]
+ },
+ {
+ testDesc: "Averaging two NaNs",
+ inputs: [{n: NaN}, {n: NaN}],
+ expectedRes: [{_id: null, o: NaN}]
+ },
+ {
+ testDesc: "Averaging two decimals",
+ inputs: [
+ {n: NumberDecimal("-1234567890.1234567889")},
+ {n: NumberDecimal("-1234567890.1234567891")}
+ ],
+ expectedRes: [{_id: null, o: NumberDecimal("-1234567890.1234567890")}]
+ },
+ {
+ testDesc: "Averaging two longs and a decimal results in an accurate decimal result",
+ inputs: [
+ {n: NumberLong("1234567890123456788")},
+ {n: NumberLong("1234567890123456789")},
+ {n: NumberDecimal("1234567890123456790.037037036703702")}
+ ],
+ expectedRes: [{_id: null, o: NumberDecimal("1234567890123456789.012345678901234")}]
+ },
+ {
+ testDesc: "Averaging a double and a decimal",
+ inputs: [{n: 1.0e22}, {n: NumberDecimal("9999999999999999999999.9999999999")}],
+ expectedRes: [{_id: null, o: NumberDecimal("9999999999999999999999.99999999995")}]
+ },
+ ].forEach(({testDesc, inputs, expectedRes}) => {
+ hashShardedColl.drop();
+ assert.commandWorked(hashShardedColl.insert(inputs));
+
+ verifyShardedAccumulatorResultsOnBothEngine(
+ testDesc, hashShardedColl, [{$group: {_id: null, o: {$avg: "$n"}}}], expectedRes);
+ });
+
+ st.stop();
+}());
+}());
diff --git a/jstests/noPassthrough/sum_accumulator_bug_fix.js b/jstests/noPassthrough/sum_accumulator_bug_fix.js
deleted file mode 100644
index 69a0bb07a01..00000000000
--- a/jstests/noPassthrough/sum_accumulator_bug_fix.js
+++ /dev/null
@@ -1,415 +0,0 @@
-/**
- * Tests whether $sum accumulator incorrect result bug is fixed on both engines under FCV 6.0.
- *
- * @tags: [
- * requires_fcv_60,
- * requires_sharding,
- * ]
- */
-(function() {
-'use strict';
-
-(function testSumWhenSpillingOnClassicEngine() {
- const conn = MongoRunner.runMongod();
-
- const db = conn.getDB(jsTestName());
- const coll = db.spilling;
-
- for (let i = 0; i < 100; ++i) {
- assert.commandWorked(coll.insert([
- {k: i, n: 1e+34},
- {k: i, n: NumberDecimal("0.1")},
- {k: i, n: NumberDecimal("0.01")},
- {k: i, n: -1e+34}
- ]));
- }
-
- // Turns on the classical engine.
- db.adminCommand({setParameter: 1, internalQueryForceClassicEngine: true});
-
- const pipeline = [{$group: {_id: "$k", o: {$sum: "$n"}}}, {$group: {_id: "$o"}}];
- // The results when not spilling is the expected results.
- const expectedRes = coll.aggregate(pipeline).toArray();
-
- // Has the document source group spill.
- assert.commandWorked(
- db.adminCommand({setParameter: 1, internalDocumentSourceGroupMaxMemoryBytes: 1000}));
-
- // Makes sure that the document source group will spill.
- assert.commandFailedWithCode(
- coll.runCommand(
- {aggregate: coll.getName(), pipeline: pipeline, cursor: {}, allowDiskUse: false}),
- ErrorCodes.QueryExceededMemoryLimitNoDiskUseAllowed);
-
- const classicSpillingRes = coll.aggregate(pipeline, {allowDiskUse: true}).toArray();
- assert.eq(classicSpillingRes, expectedRes);
-
- MongoRunner.stopMongod(conn);
-}());
-
-(function testOverTheWireDataFormatOnBothEngines() {
- const conn = MongoRunner.runMongod();
-
- const db = conn.getDB(jsTestName());
- assert.commandWorked(db.dropDatabase());
- const coll = db.spilling;
- const pipeline = [{$group: {_id: null, o: {$sum: "$n"}}}];
- const aggCmd = {
- aggregate: coll.getName(),
- pipeline: pipeline,
- needsMerge: true,
- fromMongos: true,
- cursor: {}
- };
-
- const verifyOverTheWireDataFormatOnBothEngines = (expectedRes) => {
- // Turns on the classical engine.
- assert.commandWorked(
- db.adminCommand({setParameter: 1, internalQueryForceClassicEngine: true}));
- const classicRes = assert.commandWorked(db.runCommand(aggCmd)).cursor.firstBatch;
- assert.eq(classicRes, expectedRes);
-
- // Turns off the classical engine.
- assert.commandWorked(
- db.adminCommand({setParameter: 1, internalQueryForceClassicEngine: false}));
- const sbeRes = assert.commandWorked(db.runCommand(aggCmd)).cursor.firstBatch;
- assert.eq(sbeRes, expectedRes);
- };
-
- assert.commandWorked(coll.insert({n: NumberInt(1)}));
- verifyOverTheWireDataFormatOnBothEngines([{
- _id: null,
- o: [
- NumberInt(16), // The type id for NumberInt
- 1.0, // sum
- 0.0 // addend
- ]
- }]);
-
- assert.commandWorked(coll.insert({n: NumberLong(1)}));
- verifyOverTheWireDataFormatOnBothEngines([{
- _id: null,
- o: [
- NumberInt(18), // The type id for NumberLong
- 2.0, // sum
- 0.0 // addend
- ]
- }]);
-
- assert.commandWorked(coll.insert({n: NumberLong("9223372036854775807")}));
- verifyOverTheWireDataFormatOnBothEngines([{
- _id: null,
- o: [
- NumberInt(18), // The type id for NumberLong
- 9223372036854775808.0, // sum
- 1.0 // addend
- ]
- }]);
-
- // A double can always expresses 15 digits precisely. So, 1.0 + 0.00000000000001 is precisely
- // expressed by the 'addend' element.
- assert.commandWorked(coll.insert({n: 0.00000000000001}));
- verifyOverTheWireDataFormatOnBothEngines([{
- _id: null,
- o: [
- NumberInt(1), // The type id for NumberDouble
- 9223372036854775808.0, // sum
- 1.00000000000001 // addend
- ]
- }]);
-
- assert.commandWorked(coll.insert({n: NumberDecimal("1.0")}));
- verifyOverTheWireDataFormatOnBothEngines([{
- _id: null,
- o: [
- NumberInt(1), // The type id for NumberDouble
- 9223372036854775808.0, // sum
- 1.00000000000001, // addend
- NumberDecimal("1.0")
- ]
- }]);
-
- assert(coll.drop());
-
- assert.commandWorked(coll.insert([{n: Number.MAX_VALUE}, {n: Number.MAX_VALUE}]));
- verifyOverTheWireDataFormatOnBothEngines([{
- _id: null,
- o: [
- NumberInt(1), // The type id for NumberDouble
- Infinity, // sum
- NaN // addend
- ]
- }]);
-
- MongoRunner.stopMongod(conn);
-}());
-
-(function testShardedSumOnBothEngines() {
- const st = new ShardingTest({shards: 2});
-
- const db = st.getDB(jsTestName());
- assert.commandWorked(db.dropDatabase());
- const dbAtShard0 = st.shard0.getDB(jsTestName());
- const dbAtShard1 = st.shard1.getDB(jsTestName());
-
- // Makes sure that the test db is sharded.
- assert.commandWorked(st.s0.adminCommand({enableSharding: db.getName()}));
-
- let verifyShardedSumResultsOnBothEngine = (testDesc, coll, pipeline, expectedRes) => {
- // Turns to the classic engine at the shards.
- assert.commandWorked(
- dbAtShard0.adminCommand({setParameter: 1, internalQueryForceClassicEngine: true}));
- assert.commandWorked(
- dbAtShard1.adminCommand({setParameter: 1, internalQueryForceClassicEngine: true}));
-
- // Verifies that the classic engine's results are same as the expected results.
- const classicRes = coll.aggregate(pipeline).toArray();
- assert.eq(classicRes, expectedRes, testDesc);
-
- // Turns to the SBE engine at the shards.
- assert.commandWorked(
- dbAtShard0.adminCommand({setParameter: 1, internalQueryForceClassicEngine: false}));
- assert.commandWorked(
- dbAtShard1.adminCommand({setParameter: 1, internalQueryForceClassicEngine: false}));
-
- // Verifies that the SBE engine's results are same as the expected results.
- const sbeRes = coll.aggregate(pipeline).toArray();
- assert.eq(sbeRes, expectedRes, testDesc);
- };
-
- let shardCollectionByHashing = coll => {
- coll.drop();
-
- // Makes sure that the collection is sharded.
- assert.commandWorked(
- st.s0.adminCommand({shardCollection: coll.getFullName(), key: {_id: "hashed"}}));
-
- return coll;
- };
-
- let hashShardedColl = shardCollectionByHashing(db.partial_sum);
- let unshardedColl = db.partial_sum2;
-
- for (let i = 0; i < 3; ++i) {
- assert.commandWorked(hashShardedColl.insert([
- {k: i, n: 1e+34},
- {k: i, n: NumberDecimal("0.1")},
- {k: i, n: NumberDecimal("0.01")},
- {k: i, n: -1e+34}
- ]));
- assert.commandWorked(unshardedColl.insert([
- {k: i, n: 1e+34},
- {k: i, n: NumberDecimal("0.1")},
- {k: i, n: NumberDecimal("0.01")},
- {k: i, n: -1e+34}
- ]));
- }
-
- const pipeline = [{$group: {_id: "$k", s: {$sum: "$n"}}}, {$group: {_id: "$s"}}];
-
- // The results on an unsharded collection is the expected results.
- const expectedRes = unshardedColl.aggregate(pipeline).toArray();
- verifyShardedSumResultsOnBothEngine(
- "Sharded sum for mixed data by which only decimal sum survive",
- hashShardedColl,
- pipeline,
- expectedRes);
-
- const int32Max = 2147483647;
- const numberIntMax = NumberInt(int32Max);
- const numberLongMax = NumberLong("9223372036854775807");
- const verySmallDecimal = NumberDecimal("1e-15");
- const veryLargeDecimal = NumberDecimal("1e+33");
-
- // This value is precisely representable by a double.
- const doubleClosestToLongMax = 9223372036854775808.0;
- [{
- testDesc: "No documents evaluated",
- inputs: [{}],
- expectedRes: [{_id: null, o: NumberInt(0)}]
- },
- {
- testDesc: "An int",
- inputs: [{n: NumberInt(10)}],
- expectedRes: [{_id: null, o: NumberInt(10)}]
- },
- {
- testDesc: "a long",
- inputs: [{n: NumberLong(10)}],
- expectedRes: [{_id: null, o: NumberLong(10)}]
- },
- {testDesc: "A double", inputs: [{n: 10.0}], expectedRes: [{_id: null, o: 10.0}]},
- {
- testDesc: "A long that cannot be expressed as an int",
- inputs: [{n: NumberLong("60000000000")}],
- expectedRes: [{_id: null, o: NumberLong("60000000000")}]
- },
- {
- testDesc: "A non integer valued double",
- inputs: [{n: 7.5}],
- expectedRes: [{_id: null, o: 7.5}]
- },
- {testDesc: "A nan double", inputs: [{n: NaN}], expectedRes: [{_id: null, o: NaN}]},
- {testDesc: "A -nan double", inputs: [{n: -NaN}], expectedRes: [{_id: null, o: -NaN}]},
- {
- testDesc: "A infinity double",
- inputs: [{n: Infinity}],
- expectedRes: [{_id: null, o: Infinity}]
- },
- {
- testDesc: "A -infinity double",
- inputs: [{n: -Infinity}],
- expectedRes: [{_id: null, o: -Infinity}]
- },
- {
- testDesc: "Two ints are summed",
- inputs: [{n: NumberInt(4)}, {n: NumberInt(5)}],
- expectedRes: [{_id: null, o: NumberInt(9)}]
- },
- {
- testDesc: "An int and a long",
- inputs: [{n: NumberInt(4)}, {n: NumberLong(5)}],
- expectedRes: [{_id: null, o: NumberLong(9)}]
- },
- {
- testDesc: "Two longs",
- inputs: [{n: NumberLong(4)}, {n: NumberLong(5)}],
- expectedRes: [{_id: null, o: NumberLong(9)}]
- },
- {
- testDesc: "An int and a double",
- inputs: [{n: NumberInt(4)}, {n: 5.5}],
- expectedRes: [{_id: null, o: 9.5}]
- },
- {
- testDesc: "A long and a double",
- inputs: [{n: NumberLong(4)}, {n: 5.5}],
- expectedRes: [{_id: null, o: 9.5}]
- },
- {testDesc: "Two doubles", inputs: [{n: 2.5}, {n: 5.5}], expectedRes: [{_id: null, o: 8.0}]},
- {
- testDesc: "An int, a long, and a double",
- inputs: [{n: NumberInt(5)}, {n: NumberLong(99)}, {n: 0.2}],
- expectedRes: [{_id: null, o: 104.2}]
- },
- {
- testDesc: "Two decimals",
- inputs: [{n: NumberDecimal("-10.100")}, {n: NumberDecimal("20.200")}],
- expectedRes: [{_id: null, o: NumberDecimal("10.100")}]
- },
- {
- testDesc: "Two longs and a decimal",
- inputs: [{n: NumberLong(10)}, {n: NumberLong(10)}, {n: NumberDecimal("10.000")}],
- expectedRes: [{_id: null, o: NumberDecimal("30.000")}]
- },
- {
- testDesc: "A double and a decimal",
- inputs: [{n: 2.5}, {n: NumberDecimal("2.5")}],
- expectedRes: [{_id: null, o: NumberDecimal("5.0")}]
- },
- {
- testDesc: "An int, long, double and decimal",
- inputs: [{n: NumberInt(10)}, {n: NumberLong(10)}, {n: 10.5}, {n: NumberDecimal("9.6")}],
- expectedRes: [{_id: null, o: NumberDecimal("40.1")}]
- },
- {
- testDesc: "A long max and a very small decimal resulting in 34 digits",
- inputs: [{n: numberLongMax}, {n: verySmallDecimal}],
- expectedRes: [{_id: null, o: NumberDecimal("9223372036854775807.000000000000001")}]
- },
- {
- testDesc: "A long and a very large decimal resulting in 34 digits",
- inputs: [{n: NumberLong(1)}, {n: veryLargeDecimal}],
- expectedRes: [{_id: null, o: NumberDecimal("1000000000000000000000000000000001")}]
- },
- {
- testDesc:
- "The double closest to the long max and a very small decimal resulting in 34 digits",
- inputs: [{n: doubleClosestToLongMax}, {n: verySmallDecimal}],
- expectedRes: [{_id: null, o: NumberDecimal("9223372036854775808.000000000000001")}]
- },
- {
- testDesc: "A double and a very large decimal resulting in 34 digits",
- inputs: [{n: 1.0}, {n: veryLargeDecimal}],
- expectedRes: [{_id: null, o: NumberDecimal("1000000000000000000000000000000001")}]
- },
- {
- testDesc: "A negative value is summed",
- inputs: [{n: NumberInt(5)}, {n: -8.5}],
- expectedRes: [{_id: null, o: -3.5}]
- },
- {
- testDesc: "A long and a negative int are summed",
- inputs: [{n: NumberLong(5)}, {n: NumberInt(-6)}],
- expectedRes: [{_id: null, o: NumberLong(-1)}]
- },
- {
- testDesc: "Two ints do not overflow",
- inputs: [{n: numberIntMax}, {n: NumberInt(10)}],
- expectedRes: [{_id: null, o: NumberLong(int32Max + 10)}]
- },
- {
- testDesc: "Two negative ints do not overflow",
- inputs: [{n: NumberInt(-int32Max)}, {n: NumberInt(-10)}],
- expectedRes: [{_id: null, o: NumberLong(-int32Max - 10)}]
- },
- {
- testDesc: "An int and a long do not trigger an int overflow",
- inputs: [{n: numberIntMax}, {n: NumberLong(1)}],
- expectedRes: [{_id: null, o: NumberLong(int32Max + 1)}]
- },
- {
- testDesc: "An int and a double do not trigger an int overflow",
- inputs: [{n: numberIntMax}, {n: 1.0}],
- expectedRes: [{_id: null, o: int32Max + 1.0}]
- },
- {
- testDesc: "An int and a long overflow into a double",
- inputs: [{n: NumberInt(1)}, {n: numberLongMax}],
- expectedRes: [{_id: null, o: doubleClosestToLongMax}]
- },
- {
- testDesc: "Two longs overflow into a double",
- inputs: [{n: numberLongMax}, {n: numberLongMax}],
- expectedRes: [{_id: null, o: doubleClosestToLongMax * 2}]
- },
- {
- testDesc: "A long and a double do not trigger a long overflow",
- inputs: [{n: numberLongMax}, {n: 1.0}],
- expectedRes: [{_id: null, o: doubleClosestToLongMax}]
- },
- {
- testDesc: "Two doubles overflow to infinity",
- inputs: [{n: Number.MAX_VALUE}, {n: Number.MAX_VALUE}],
- expectedRes: [{_id: null, o: Infinity}]
- },
- {
- testDesc: "Two large integers do not overflow if a double is added later",
- inputs: [{n: numberLongMax}, {n: numberLongMax}, {n: 1.0}],
- expectedRes: [{_id: null, o: doubleClosestToLongMax * 2}]
- },
- {
- testDesc: "An int and a NaN double",
- inputs: [{n: NumberInt(4)}, {n: NaN}],
- expectedRes: [{_id: null, o: NaN}]
- },
- {
- testDesc: "Null values are ignored",
- inputs: [{n: NumberInt(5)}, {n: null}],
- expectedRes: [{_id: null, o: NumberInt(5)}]
- },
- {
- testDesc: "Missing values are ignored",
- inputs: [{n: NumberInt(9)}, {}],
- expectedRes: [{_id: null, o: NumberInt(9)}]
- }].forEach(({testDesc, inputs, expectedRes}) => {
- hashShardedColl.drop();
- assert.commandWorked(hashShardedColl.insert(inputs));
-
- verifyShardedSumResultsOnBothEngine(
- testDesc, hashShardedColl, [{$group: {_id: null, o: {$sum: "$n"}}}], expectedRes);
- });
-
- st.stop();
-}());
-}());
diff --git a/src/mongo/db/pipeline/accumulator.h b/src/mongo/db/pipeline/accumulator.h
index 463001b665a..114ccc260ce 100644
--- a/src/mongo/db/pipeline/accumulator.h
+++ b/src/mongo/db/pipeline/accumulator.h
@@ -382,7 +382,8 @@ private:
*/
Decimal128 _getDecimalTotal() const;
- bool _isDecimal;
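+    // The widest numeric type seen so far: '_totalType' covers all inputs, while
+    // '_nonDecimalTotalType' covers only the non-decimal inputs.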
+ BSONType _totalType = NumberInt;
+ BSONType _nonDecimalTotalType = NumberInt;
DoubleDoubleSummation _nonDecimalTotal;
Decimal128 _decimalTotal;
long long _count;
diff --git a/src/mongo/db/pipeline/accumulator_avg.cpp b/src/mongo/db/pipeline/accumulator_avg.cpp
index daddd56f502..a12b0da4361 100644
--- a/src/mongo/db/pipeline/accumulator_avg.cpp
+++ b/src/mongo/db/pipeline/accumulator_avg.cpp
@@ -29,6 +29,7 @@
#include "mongo/platform/basic.h"
+#include "mongo/db/exec/sbe/accumulator_sum_value_enum.h"
#include "mongo/db/pipeline/accumulator.h"
#include "mongo/db/exec/document_value/document.h"
@@ -49,33 +50,73 @@ REGISTER_STABLE_EXPRESSION(avg, ExpressionFromAccumulator<AccumulatorAvg>::parse
REGISTER_REMOVABLE_WINDOW_FUNCTION(avg, AccumulatorAvg, WindowFunctionAvg);
namespace {
+// TODO SERVER-64227 Remove 'subTotal' and 'subTotalError' fields when we branch for 6.1 because all
+// nodes in a sharded cluster will use the new data format.
const char subTotalName[] = "subTotal";
const char subTotalErrorName[] = "subTotalError"; // Used for extra precision
+const char partialSumName[] = "ps"; // Used for the full state of partial sum
const char countName[] = "count";
} // namespace
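+
+// 'applyPartialSum()' and 'serializePartialSum()' are defined in accumulator_sum.cpp and are
+// shared between the $sum and $avg accumulators.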
+void applyPartialSum(const std::vector<Value>& arr,
+ BSONType& nonDecimalTotalType,
+ BSONType& totalType,
+ DoubleDoubleSummation& nonDecimalTotal,
+ Decimal128& decimalTotal);
+
+Value serializePartialSum(BSONType nonDecimalTotalType,
+ BSONType totalType,
+ const DoubleDoubleSummation& nonDecimalTotal,
+ const Decimal128& decimalTotal);
+
void AccumulatorAvg::processInternal(const Value& input, bool merging) {
if (merging) {
        // We expect an object that contains both a subtotal and a count. Additionally, there may
        // be an error value that allows for additional precision.
// 'input' is what getValue(true) produced below.
verify(input.getType() == Object);
- // We're recursively adding the subtotal to get the proper type treatment, but this only
- // increments the count by one, so adjust the count afterwards. Similarly for 'error'.
- processInternal(input[subTotalName], false);
- _count += input[countName].getLong() - 1;
- Value error = input[subTotalErrorName];
- if (!error.missing()) {
- processInternal(error, false);
- _count--; // The error correction only adjusts the total, not the number of items.
+
+ // TODO SERVER-64227 Remove 'if' block when we branch for 6.1 because all nodes in a sharded
+        // cluster will use the new data format.
+ if (auto partialSumVal = input[partialSumName]; partialSumVal.missing()) {
+ // We're recursively adding the subtotal to get the proper type treatment, but this only
+ // increments the count by one, so adjust the count afterwards. Similarly for 'error'.
+ processInternal(input[subTotalName], false);
+ _count += input[countName].getLong() - 1;
+ Value error = input[subTotalErrorName];
+ if (!error.missing()) {
+ processInternal(error, false);
+ _count--; // The error correction only adjusts the total, not the number of items.
+ }
+ } else {
+            // The merge side must be ready to process the full state of a partial sum from the
+            // shard side if a shard chooses to send it. See Accumulator::getValue() for details.
+ applyPartialSum(partialSumVal.getArray(),
+ _nonDecimalTotalType,
+ _totalType,
+ _nonDecimalTotal,
+ _decimalTotal);
+ _count += input[countName].getLong();
}
+
return;
}
+ if (!input.numeric()) {
+ return;
+ }
+
+ _totalType = Value::getWidestNumeric(_totalType, input.getType());
+
+    // Track the non-decimal total's type so that the type information can also be serialized
+    // for 'toBeMerged' scenarios.
+ if (input.getType() != NumberDecimal) {
+ _nonDecimalTotalType = Value::getWidestNumeric(_nonDecimalTotalType, input.getType());
+ }
+
switch (input.getType()) {
case NumberDecimal:
_decimalTotal = _decimalTotal.add(input.getDecimal());
- _isDecimal = true;
break;
case NumberLong:
// Avoid summation using double as that loses precision.
@@ -88,8 +129,7 @@ void AccumulatorAvg::processInternal(const Value& input, bool merging) {
_nonDecimalTotal.addDouble(input.getDouble());
break;
default:
- dassert(!input.numeric());
- return;
+ MONGO_UNREACHABLE;
}
_count++;
}
@@ -104,32 +144,39 @@ Decimal128 AccumulatorAvg::_getDecimalTotal() const {
Value AccumulatorAvg::getValue(bool toBeMerged) {
if (toBeMerged) {
- if (_isDecimal)
- return Value(Document{{subTotalName, _getDecimalTotal()}, {countName, _count}});
+ auto partialSumVal =
+ serializePartialSum(_nonDecimalTotalType, _totalType, _nonDecimalTotal, _decimalTotal);
+ if (_totalType == NumberDecimal) {
+ return Value(Document{{subTotalName, _getDecimalTotal()},
+ {countName, _count},
+ {partialSumName, partialSumVal}});
+ }
- double total, error;
- std::tie(total, error) = _nonDecimalTotal.getDoubleDouble();
- return Value(
- Document{{subTotalName, total}, {countName, _count}, {subTotalErrorName, error}});
+ auto [total, error] = _nonDecimalTotal.getDoubleDouble();
+ return Value(Document{{subTotalName, total},
+ {countName, _count},
+ {subTotalErrorName, error},
+ {partialSumName, partialSumVal}});
}
if (_count == 0)
return Value(BSONNULL);
- if (_isDecimal)
+ if (_totalType == NumberDecimal)
return Value(_getDecimalTotal().divide(Decimal128(static_cast<int64_t>(_count))));
return Value(_nonDecimalTotal.getDouble() / static_cast<double>(_count));
}
AccumulatorAvg::AccumulatorAvg(ExpressionContext* const expCtx)
- : AccumulatorState(expCtx), _isDecimal(false), _count(0) {
+ : AccumulatorState(expCtx), _count(0) {
// This is a fixed size AccumulatorState so we never need to update this
_memUsageBytes = sizeof(*this);
}
void AccumulatorAvg::reset() {
- _isDecimal = false;
+ _totalType = NumberInt;
+ _nonDecimalTotalType = NumberInt;
_nonDecimalTotal = {};
_decimalTotal = {};
_count = 0;
diff --git a/src/mongo/db/pipeline/accumulator_sum.cpp b/src/mongo/db/pipeline/accumulator_sum.cpp
index 158cf727453..303f791d9ed 100644
--- a/src/mongo/db/pipeline/accumulator_sum.cpp
+++ b/src/mongo/db/pipeline/accumulator_sum.cpp
@@ -61,6 +61,46 @@ const char subTotalName[] = "subTotal";
const char subTotalErrorName[] = "subTotalError"; // Used for extra precision.
} // namespace
+void applyPartialSum(const std::vector<Value>& arr,
+ BSONType& nonDecimalTotalType,
+ BSONType& totalType,
+ DoubleDoubleSummation& nonDecimalTotal,
+ Decimal128& decimalTotal) {
+ tassert(6294002,
+ "The partial sum's first element must be an int",
+ arr[AggSumValueElems::kNonDecimalTotalTag].getType() == NumberInt);
+ nonDecimalTotalType = Value::getWidestNumeric(
+ nonDecimalTotalType,
+ static_cast<BSONType>(arr[AggSumValueElems::kNonDecimalTotalTag].getInt()));
+ totalType = Value::getWidestNumeric(totalType, nonDecimalTotalType);
+
+ tassert(6294003,
+ "The partial sum's second element must be a double",
+ arr[AggSumValueElems::kNonDecimalTotalSum].getType() == NumberDouble);
+ tassert(6294004,
+ "The partial sum's third element must be a double",
+ arr[AggSumValueElems::kNonDecimalTotalAddend].getType() == NumberDouble);
+
+ auto sum = arr[AggSumValueElems::kNonDecimalTotalSum].getDouble();
+ auto addend = arr[AggSumValueElems::kNonDecimalTotalAddend].getDouble();
+ nonDecimalTotal.addDouble(sum);
+
+    // If 'sum' is +/-Inf and 'addend' is NaN, adding both would turn 'nonDecimalTotal' into
+    // NaN, which differs from the unsharded behavior. So, 'addend' is not added when 'sum' is
+    // infinite and 'addend' is NaN. This logic is not added to 'DoubleDoubleSummation' because
+    // the behavior is specific to the sharded $sum.
+ if (std::isfinite(sum) || !std::isnan(addend)) {
+ nonDecimalTotal.addDouble(addend);
+ }
+
+ if (arr.size() == AggSumValueElems::kMaxSizeOfArray) {
+ totalType = NumberDecimal;
+ tassert(6294005,
+ "The partial sum's last element must be a decimal",
+ arr[AggSumValueElems::kDecimalTotal].getType() == NumberDecimal);
+ decimalTotal = decimalTotal.add(arr[AggSumValueElems::kDecimalTotal].getDecimal());
+ }
+}
void AccumulatorSum::processInternal(const Value& input, bool merging) {
if (!input.numeric()) {
@@ -81,44 +121,13 @@ void AccumulatorSum::processInternal(const Value& input, bool merging) {
break;
// The merge-side must be ready to process the full state of a partial sum from a
// shard-side if a shard chooses to do so. See Accumulator::getValue() for details.
- case Array: {
- auto&& arr = input.getArray();
- tassert(6294002,
- "The partial sum's first element must be an int",
- arr[AggSumValueElems::kNonDecimalTotalTag].getType() == NumberInt);
- nonDecimalTotalType = Value::getWidestNumeric(
- nonDecimalTotalType,
- static_cast<BSONType>(arr[AggSumValueElems::kNonDecimalTotalTag].getInt()));
- totalType = Value::getWidestNumeric(totalType, nonDecimalTotalType);
-
- tassert(6294003,
- "The partial sum's second element must be a double",
- arr[AggSumValueElems::kNonDecimalTotalSum].getType() == NumberDouble);
- tassert(6294004,
- "The partial sum's third element must be a double",
- arr[AggSumValueElems::kNonDecimalTotalAddend].getType() == NumberDouble);
-
- auto sum = arr[AggSumValueElems::kNonDecimalTotalSum].getDouble();
- auto addend = arr[AggSumValueElems::kNonDecimalTotalAddend].getDouble();
- nonDecimalTotal.addDouble(sum);
- // If sum is +=INF and addend is +=NAN, 'nonDecimalTotal' becomes NAN after adding
- // INF and NAN, which is different from the unsharded behavior. So, does not add
- // 'addend' when sum == INF and addend == NAN. Does not add this logic to
- // 'DoubleDoubleSummation' because this behavior is specific to sharded $sum.
- if (std::isfinite(sum) || !std::isnan(addend)) {
- nonDecimalTotal.addDouble(addend);
- }
-
- if (arr.size() == AggSumValueElems::kMaxSizeOfArray) {
- totalType = NumberDecimal;
- tassert(6294005,
- "The partial sum's last element must be a decimal",
- arr[AggSumValueElems::kDecimalTotal].getType() == NumberDecimal);
- decimalTotal =
- decimalTotal.add(arr[AggSumValueElems::kDecimalTotal].getDecimal());
- }
+ case Array:
+ applyPartialSum(input.getArray(),
+ nonDecimalTotalType,
+ totalType,
+ nonDecimalTotal,
+ decimalTotal);
break;
- }
default:
MONGO_UNREACHABLE;
}
@@ -130,8 +139,8 @@ void AccumulatorSum::processInternal(const Value& input, bool merging) {
// Keep the nonDecimalTotal's type so that the type information can be serialized too for
// 'toBeMerged' scenarios.
- if (totalType != NumberDecimal) {
- nonDecimalTotalType = totalType;
+ if (input.getType() != NumberDecimal) {
+ nonDecimalTotalType = Value::getWidestNumeric(nonDecimalTotalType, input.getType());
}
switch (input.getType()) {
case NumberLong:
@@ -155,6 +164,27 @@ intrusive_ptr<AccumulatorState> AccumulatorSum::create(ExpressionContext* const
return new AccumulatorSum(expCtx);
}
+Value serializePartialSum(BSONType nonDecimalTotalType,
+ BSONType totalType,
+ const DoubleDoubleSummation& nonDecimalTotal,
+ const Decimal128& decimalTotal) {
+ auto [sum, addend] = nonDecimalTotal.getDoubleDouble();
+
+ // The partial sum is serialized in the following form.
+ //
+ // [nonDecimalTotalType, sum, addend, decimalTotal]
+ //
+ // Presence of the 'decimalTotal' element indicates that the total type of the partial sum
+ // is 'NumberDecimal'.
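+    //
+    // For example, after summing NumberInt(1) and NumberLong(1), the partial sum serializes as
+    // [18, 2.0, 0.0] (18 is the NumberLong type id); adding NumberDecimal("1.0") would append a
+    // trailing element: [18, 2.0, 0.0, NumberDecimal("1.0")].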
+ auto valueArrayStream = ValueArrayStream();
+ valueArrayStream << static_cast<int>(nonDecimalTotalType) << sum << addend;
+ if (totalType == NumberDecimal) {
+ valueArrayStream << decimalTotal;
+ }
+
+ return valueArrayStream.done();
+}
+
Value AccumulatorSum::getValue(bool toBeMerged) {
// Serialize the full state of the partial sum result to avoid incorrect results for certain
// data set which are composed of 'NumberDecimal' values which cancel each other when being
@@ -177,20 +207,7 @@ Value AccumulatorSum::getValue(bool toBeMerged) {
auto canUseNewPartialResultFormat = fcv.isVersionInitialized() &&
fcv.isGreaterThanOrEqualTo(multiversion::FeatureCompatibilityVersion::kVersion_6_0);
if (canUseNewPartialResultFormat && toBeMerged) {
- auto [sum, addend] = nonDecimalTotal.getDoubleDouble();
-
- // The partial sum is serialized in the following form.
- //
- // [nonDecimalTotalType, sum, addend, decimalTotal]
- //
- // Presence of the 'decimalTotal' element indicates that the total type of the partial sum
- // is 'NumberDecimal'.
- auto valueArrayStream = ValueArrayStream();
- valueArrayStream << static_cast<int>(nonDecimalTotalType) << sum << addend;
- if (totalType == NumberDecimal) {
- valueArrayStream << decimalTotal;
- }
- return valueArrayStream.done();
+ return serializePartialSum(nonDecimalTotalType, totalType, nonDecimalTotal, decimalTotal);
}
switch (totalType) {
diff --git a/src/mongo/db/query/sbe_stage_builder_accumulator.cpp b/src/mongo/db/query/sbe_stage_builder_accumulator.cpp
index 1622ec873ac..e9dbc9762a1 100644
--- a/src/mongo/db/query/sbe_stage_builder_accumulator.cpp
+++ b/src/mongo/db/query/sbe_stage_builder_accumulator.cpp
@@ -185,10 +185,12 @@ std::pair<std::unique_ptr<sbe::EExpression>, EvalStage> buildFinalizeAvg(
// To support the sharding behavior, the mongos splits $group into two separate $group
// stages one at the mongos-side and the other at the shard-side. This stage builder builds
// the shard-side plan. The shard-side $avg accumulator is responsible to return the partial
- // avg in the form of {subTotal: val1, count: val2} when the type of sum is decimal or
- // {subTotal: val1, count: val2, subTotalError: val3} when the type of sum is non-decimal.
+ // avg in the form of {subTotal: val1, count: val2, ps: array_val} when the type of sum is
+ // decimal or {subTotal: val1, count: val2, subTotalError: val3, ps: array_val} when the
+ // type of sum is non-decimal.
auto sumResult = makeVariable(aggSlots[0]);
auto countResult = makeVariable(aggSlots[1]);
+ auto partialSumExpr = makeFunction("doubleDoublePartialSumFinalize", sumResult->clone());
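+    // 'doubleDoublePartialSumFinalize' emits the full serialized partial sum (the
+    // [nonDecimalTotalType, sum, addend, decimalTotal?] array) used for the 'ps' field below.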
// Existence of 'kDecimalTotal' element in the sum result means the type of sum result is
// decimal.
@@ -198,14 +200,18 @@ std::pair<std::unique_ptr<sbe::EExpression>, EvalStage> buildFinalizeAvg(
sumResult->clone(),
makeConstant(sbe::value::TypeTags::NumberInt32,
static_cast<int>(AggSumValueElems::kDecimalTotal))));
- // Returns {subTotal: val1, count: val2} if the type of the sum result is decimal.
+ // Returns {subTotal: val1, count: val2, ps: array_val} if the type of the sum result is
+ // decimal.
+ // TODO SERVER-64227 Remove 'subTotal' and 'subTotalError' fields when we branch for 6.1
+    // because all nodes in a sharded cluster will use the new data format.
auto thenExpr = makeNewObjFunction(
FieldPair{"subTotal"_sd,
// 'doubleDoubleSumFinalize' returns the sum, adding decimal
// sum and non-decimal sum.
makeFunction("doubleDoubleSumFinalize", sumResult->clone())},
- FieldPair{"count"_sd, countResult->clone()});
- // Returns {subTotal: val1, count: val2: subTotalError: val3} otherwise.
+ FieldPair{"count"_sd, countResult->clone()},
+ FieldPair{"ps"_sd, partialSumExpr->clone()});
+    // Returns {subTotal: val1, count: val2, subTotalError: val3, ps: array_val} otherwise.
auto elseExpr = makeNewObjFunction(
FieldPair{"subTotal"_sd,
makeFunction(
@@ -219,7 +225,8 @@ std::pair<std::unique_ptr<sbe::EExpression>, EvalStage> buildFinalizeAvg(
sumResult->clone(),
makeConstant(sbe::value::TypeTags::NumberInt32,
static_cast<int>(
- AggSumValueElems::kNonDecimalTotalAddend)))});
+ AggSumValueElems::kNonDecimalTotalAddend)))},
+ FieldPair{"ps"_sd, partialSumExpr->clone()});
auto partialAvgFinalize =
sbe::makeE<sbe::EIf>(std::move(ifCondExpr), std::move(thenExpr), std::move(elseExpr));