summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJames Cohan <james.cohan@10gen.com>2015-07-01 10:17:35 -0400
committerJames Cohan <james.cohan@10gen.com>2015-07-01 11:45:39 -0400
commit4f4e36b69bfd6ea8615e2c0786a4d1dbca502a74 (patch)
tree45eb0b659079dde1a47743d671ca78832d37fe1a
parent9a743c6fb3c143cc45ea89842243e480fdf229a3 (diff)
downloadmongo-4f4e36b69bfd6ea8615e2c0786a4d1dbca502a74.tar.gz
SERVER-8568 Add $sqrt aggregation expression
-rw-r--r--jstests/aggregation/bugs/server8568.js43
-rw-r--r--src/mongo/db/pipeline/expression.cpp26
-rw-r--r--src/mongo/db/pipeline/expression.h4
3 files changed, 71 insertions, 2 deletions
diff --git a/jstests/aggregation/bugs/server8568.js b/jstests/aggregation/bugs/server8568.js
new file mode 100644
index 00000000000..a03472ad4fa
--- /dev/null
+++ b/jstests/aggregation/bugs/server8568.js
@@ -0,0 +1,43 @@
+// SERVER-8568: Adding $sqrt expression
+
+// For assertErrorCode.
+load('jstests/aggregation/extras/utils.js');
+
+(function() {
+ 'use strict';
+ var coll = db.sqrt;
+ coll.drop();
+ assert.writeOK(coll.insert({_id: 0}));
+
+ // Helper for testing that op returns expResult.
+ function testOp(op, expResult) {
+ var pipeline = [{$project: {_id: 0, result: op}}];
+ assert.eq(coll.aggregate(pipeline).toArray(), [{result: expResult}]);
+ }
+
+ // Helper for testing that op results in error with code errorCode.
+ function testError(op, errorCode) {
+ var pipeline = [{$project: {_id: 0, result: op}}];
+ assertErrorCode(coll, pipeline, errorCode);
+ }
+
+ // Valid input: Numeric arg >= 0, null, or NaN.
+
+ testOp({$sqrt: [100]}, 10);
+ testOp({$sqrt: [0]}, 0);
+ // All types converted to doubles.
+ testOp({$sqrt: [NumberLong("100")]}, 10);
+ // LLONG_MAX is converted to a double.
+ testOp({$sqrt: [NumberLong("9223372036854775807")]}, 3037000499.97605);
+ // Null inputs result in null.
+ testOp({$sqrt: [null]}, null);
+ // NaN inputs result in NaN.
+ testOp({$sqrt: [NaN]}, NaN);
+
+ // Invalid input: non-numeric/non-null, arg is negative.
+
+ // Arg must be numeric or null.
+ testError({$sqrt: ["string"]}, 28715);
+ // Args cannot be negative.
+ testError({$sqrt: [-1]}, 28714);
+}()); \ No newline at end of file
diff --git a/src/mongo/db/pipeline/expression.cpp b/src/mongo/db/pipeline/expression.cpp
index 8db17f3152a..2ce8b92cbbb 100644
--- a/src/mongo/db/pipeline/expression.cpp
+++ b/src/mongo/db/pipeline/expression.cpp
@@ -1321,7 +1321,6 @@ void ExpressionObject::addToDocument(MutableDocument& out,
if (dynamic_cast<ExpressionObject*>(it->second.get()) && pValue.getDocument().empty())
continue;
-
out.addField(fieldName, pValue);
}
}
@@ -2407,7 +2406,6 @@ intrusive_ptr<Expression> ExpressionSetIsSubset::optimize() {
return new Optimized(arrayToSet(rhs), vpOperand);
}
-
return optimized;
}
@@ -2470,6 +2468,30 @@ const char* ExpressionSize::getOpName() const {
return "$size";
}
+/* ----------------------- ExpressionSqrt ---------------------------- */
+
+Value ExpressionSqrt::evaluateInternal(Variables* vars) const {
+ Value argVal = vpOperand[0]->evaluateInternal(vars);
+ if (argVal.nullish())
+ return Value(BSONNULL);
+
+ uassert(28715,
+ str::stream() << "$sqrt only supports numeric types, not "
+ << typeName(argVal.getType()),
+ argVal.numeric());
+
+ double argDouble = argVal.coerceToDouble();
+ uassert(28714,
+ "$sqrt's argument must be greater than or equal to 0",
+ argDouble >= 0 || std::isnan(argDouble));
+ return Value(sqrt(argDouble));
+}
+
+REGISTER_EXPRESSION("$sqrt", ExpressionSqrt::parse);
+const char* ExpressionSqrt::getOpName() const {
+ return "$sqrt";
+}
+
/* ----------------------- ExpressionStrcasecmp ---------------------------- */
Value ExpressionStrcasecmp::evaluateInternal(Variables* vars) const {
diff --git a/src/mongo/db/pipeline/expression.h b/src/mongo/db/pipeline/expression.h
index 3650efc6655..330487b4342 100644
--- a/src/mongo/db/pipeline/expression.h
+++ b/src/mongo/db/pipeline/expression.h
@@ -989,6 +989,10 @@ public:
const char* getOpName() const final;
};
+class ExpressionSqrt final : public ExpressionFixedArity<ExpressionSqrt, 1> {
+ Value evaluateInternal(Variables* vars) const final;
+ const char* getOpName() const final;
+};
class ExpressionStrcasecmp final : public ExpressionFixedArity<ExpressionStrcasecmp, 2> {
public: