diff options
author | Davis Haupt <davis.haupt@mongodb.com> | 2022-12-05 22:25:27 +0000 |
---|---|---|
committer | Evergreen Agent <no-reply@evergreen.mongodb.com> | 2022-12-05 23:49:52 +0000 |
commit | 1040c746da55b24f02f3f9715e91c57aeeccc8e6 (patch) | |
tree | 293a10b847748e9d16fc8e97a8f3b6b443464ff5 | |
parent | dca926506ac48399abba8a764328c43d08b010a6 (diff) | |
download | mongo-1040c746da55b24f02f3f9715e91c57aeeccc8e6.tar.gz |
SERVER-25823 add bitwise AND, OR, XOR to the agg language
-rw-r--r-- | jstests/aggregation/expressions/bitwise.js | 49 | ||||
-rw-r--r-- | src/mongo/db/exec/document_value/value.cpp | 21 | ||||
-rw-r--r-- | src/mongo/db/exec/document_value/value.h | 4 | ||||
-rw-r--r-- | src/mongo/db/pipeline/abt/agg_expression_visitor.cpp | 11 | ||||
-rw-r--r-- | src/mongo/db/pipeline/expression.cpp | 5 | ||||
-rw-r--r-- | src/mongo/db/pipeline/expression.h | 146 | ||||
-rw-r--r-- | src/mongo/db/pipeline/expression_dependencies.cpp | 3 | ||||
-rw-r--r-- | src/mongo/db/pipeline/expression_test.cpp | 146 | ||||
-rw-r--r-- | src/mongo/db/pipeline/expression_visitor.h | 9 | ||||
-rw-r--r-- | src/mongo/db/query/cqf_command_utils.cpp | 11 | ||||
-rw-r--r-- | src/mongo/db/query/sbe_stage_builder_expression.cpp | 17 | ||||
-rw-r--r-- | src/mongo/util/safe_num.h | 2 |
12 files changed, 415 insertions, 9 deletions
diff --git a/jstests/aggregation/expressions/bitwise.js b/jstests/aggregation/expressions/bitwise.js index e006977044d..ece467c145b 100644 --- a/jstests/aggregation/expressions/bitwise.js +++ b/jstests/aggregation/expressions/bitwise.js @@ -11,12 +11,53 @@ const coll = db[collName]; coll.drop(); assert.commandWorked(coll.insert([ - {_id: 0, a: NumberInt(0), b: NumberInt(127)}, - {_id: 1, a: NumberInt(1), b: NumberInt(2)}, - {_id: 2, a: NumberInt(2), b: NumberInt(3)}, - {_id: 3, a: NumberInt(3), b: NumberInt(5)}, + {_id: 0, a: NumberInt(0), b: NumberInt(127), c: [NumberInt(0), NumberInt(127)]}, + {_id: 1, a: NumberInt(1), b: NumberInt(2), c: [NumberInt(1), NumberInt(2)]}, + {_id: 2, a: NumberInt(2), b: NumberInt(3), c: [NumberInt(2), NumberInt(3)]}, + {_id: 3, a: NumberInt(3), b: NumberInt(5), c: [NumberInt(3), NumberInt(5)]}, ])); +function runAndAssert(expression, expectedResult) { + assertArrayEq({ + actual: coll.aggregate([{$project: {r: {[expression]: ["$a", "$b"]}}}]) + .toArray() + .map(doc => doc.r), + expected: expectedResult + }); +} + +runAndAssert("$bitAnd", [0, 0, 2, 1]); +runAndAssert("$bitOr", [127, 3, 3, 7]); +runAndAssert("$bitXor", [127, 3, 1, 6]); + +for (const operator of ["$bitAnd", "$bitOr", "$bitXor"]) { + for (const operand + of [Number(12.0), NumberDecimal("12"), "$c", ["$c"], [[NumberInt(1), NumberInt(2)]]]) { + assert.commandFailedWithCode(coll.runCommand({ + aggregate: collName, + cursor: {}, + pipeline: [{ + $project: { + r: {[operator]: ["$a", operand]}, + } + }] + }), + ErrorCodes.TypeMismatch); + } + for (const argument of ["$c", ["$c"], [[NumberInt(1), NumberInt(2)]]]) { + assert.commandFailedWithCode(coll.runCommand({ + aggregate: collName, + cursor: {}, + pipeline: [{ + $project: { + r: {[operator]: argument}, + } + }] + }), + ErrorCodes.TypeMismatch); + } +} + assertArrayEq({ actual: coll.aggregate([ {$project: {r: {$bitNot: "$a"}}}, diff --git a/src/mongo/db/exec/document_value/value.cpp b/src/mongo/db/exec/document_value/value.cpp index 97573281dbb..ccb612032f1 100644 --- a/src/mongo/db/exec/document_value/value.cpp +++ b/src/mongo/db/exec/document_value/value.cpp @@ -285,6 +285,27 @@ Value::Value(const vector<Document>& vec) : _storage(Array) { _storage.putVector(std::move(storageVec)); } +Value::Value(const SafeNum& value) : _storage(value.type()) { + switch (value.type()) { + case EOO: + break; + case NumberInt: + _storage.intValue = value._value.int32Val; + break; + case NumberLong: + _storage.longValue = value._value.int64Val; + break; + case NumberDouble: + _storage.doubleValue = value._value.doubleVal; + break; + case NumberDecimal: + _storage.putDecimal(Decimal128(value._value.decimalVal)); + break; + default: + MONGO_UNREACHABLE; + } +} + Value Value::createIntOrLong(long long longValue) { int intValue = longValue; if (intValue != longValue) { diff --git a/src/mongo/db/exec/document_value/value.h b/src/mongo/db/exec/document_value/value.h index 7da4b971c34..ab2e2c02934 100644 --- a/src/mongo/db/exec/document_value/value.h +++ b/src/mongo/db/exec/document_value/value.h @@ -33,6 +33,7 @@ #include "mongo/base/string_data.h" #include "mongo/db/exec/document_value/value_internal.h" #include "mongo/util/concepts.h" +#include "mongo/util/safe_num.h" #include "mongo/util/uuid.h" namespace mongo { @@ -146,6 +147,9 @@ public: /// Deep-convert from BSONElement to Value explicit Value(const BSONElement& elem); + /// Create a value from a SafeNum. + explicit Value(const SafeNum& value); + /** Construct a long or integer-valued Value. * diff --git a/src/mongo/db/pipeline/abt/agg_expression_visitor.cpp b/src/mongo/db/pipeline/abt/agg_expression_visitor.cpp index b27230c2abf..35f5c6291c8 100644 --- a/src/mongo/db/pipeline/abt/agg_expression_visitor.cpp +++ b/src/mongo/db/pipeline/abt/agg_expression_visitor.cpp @@ -86,11 +86,18 @@ public: void visit(const ExpressionArrayElemAt* expr) override final { unsupportedExpression(expr->getOpName()); } - + void visit(const ExpressionBitAnd* expr) override final { + unsupportedExpression("bitAnd"); + } + void visit(const ExpressionBitOr* expr) override final { + unsupportedExpression("bitOr"); + } + void visit(const ExpressionBitXor* expr) override final { + unsupportedExpression("bitXor"); + } void visit(const ExpressionBitNot* expr) override final { unsupportedExpression(expr->getOpName()); } - void visit(const ExpressionFirst* expr) override final { unsupportedExpression(expr->getOpName()); } diff --git a/src/mongo/db/pipeline/expression.cpp b/src/mongo/db/pipeline/expression.cpp index c773c25073e..0579f25da80 100644 --- a/src/mongo/db/pipeline/expression.cpp +++ b/src/mongo/db/pipeline/expression.cpp @@ -8070,6 +8070,11 @@ const char* ExpressionBitNot::getOpName() const { return "$bitNot"; } +/* ------------------------- $bitAnd, $bitOr, and $bitXor ------------------------ */ + +REGISTER_STABLE_EXPRESSION(bitAnd, ExpressionBitAnd::parse); +REGISTER_STABLE_EXPRESSION(bitOr, ExpressionBitOr::parse); +REGISTER_STABLE_EXPRESSION(bitXor, ExpressionBitXor::parse); MONGO_INITIALIZER_GROUP(BeginExpressionRegistration, ("default"), ("EndExpressionRegistration")) MONGO_INITIALIZER_GROUP(EndExpressionRegistration, ("BeginExpressionRegistration"), ()) diff --git a/src/mongo/db/pipeline/expression.h b/src/mongo/db/pipeline/expression.h index 697d4ac0597..4411d9d5d40 100644 --- a/src/mongo/db/pipeline/expression.h +++ b/src/mongo/db/pipeline/expression.h @@ -4350,6 +4350,151 @@ public: } }; +template <typename SubClass> +class ExpressionBitwise : public ExpressionVariadic<SubClass> { +public: + explicit ExpressionBitwise(ExpressionContext* const expCtx) + : ExpressionVariadic<SubClass>(expCtx) {} + + ExpressionBitwise(ExpressionContext* const expCtx, Expression::ExpressionVector&& children) + : ExpressionVariadic<SubClass>(expCtx, std::move(children)) {} + + ExpressionNary::Associativity getAssociativity() const final { + return ExpressionNary::Associativity::kFull; + } + + bool isCommutative() const final { + return true; + } + + Value evaluate(const Document& root, Variables* variables) const final { + auto result = this->getIdentity(); + for (auto&& child : this->_children) { + Value val = child->evaluate(root, variables); + if (val.nullish()) { + return Value(BSONNULL); + } + auto valNum = uassertStatusOK(safeNumFromValue(val)); + result = doOperation(result, valNum); + } + return Value(result); + } + +private: + StatusWith<SafeNum> safeNumFromValue(const Value& val) const { + switch (val.getType()) { + case NumberInt: + return val.getInt(); + case NumberLong: + return (int64_t)val.getLong(); + default: + return Status(ErrorCodes::TypeMismatch, + str::stream() + << this->getOpName() << " only supports int and long operands."); + } + } + + virtual SafeNum doOperation(const SafeNum& a, const SafeNum& b) const = 0; + virtual SafeNum getIdentity() const = 0; +}; + +class ExpressionBitAnd final : public ExpressionBitwise<ExpressionBitAnd> { +public: + SafeNum doOperation(const SafeNum& a, const SafeNum& b) const final { + return a.bitAnd(b); + } + + SafeNum getIdentity() const final { + return -1; // In two's complement, this is all 1's. + } + + const char* getOpName() const final { + return "$bitAnd"; + }; + + explicit ExpressionBitAnd(ExpressionContext* const expCtx) + : ExpressionBitwise<ExpressionBitAnd>(expCtx) { + expCtx->sbeCompatible = false; + } + + ExpressionBitAnd(ExpressionContext* const expCtx, ExpressionVector&& children) + : ExpressionBitwise<ExpressionBitAnd>(expCtx, std::move(children)) { + expCtx->sbeCompatible = false; + } + + void acceptVisitor(ExpressionMutableVisitor* visitor) final { + return visitor->visit(this); + } + + void acceptVisitor(ExpressionConstVisitor* visitor) const final { + return visitor->visit(this); + } +}; + +class ExpressionBitOr final : public ExpressionBitwise<ExpressionBitOr> { +public: + SafeNum doOperation(const SafeNum& a, const SafeNum& b) const final { + return a.bitOr(b); + } + + SafeNum getIdentity() const final { + return 0; + } + + const char* getOpName() const final { + return "$bitOr"; + }; + + explicit ExpressionBitOr(ExpressionContext* const expCtx) + : ExpressionBitwise<ExpressionBitOr>(expCtx) { + expCtx->sbeCompatible = false; + } + + ExpressionBitOr(ExpressionContext* const expCtx, ExpressionVector&& children) + : ExpressionBitwise<ExpressionBitOr>(expCtx, std::move(children)) { + expCtx->sbeCompatible = false; + } + void acceptVisitor(ExpressionMutableVisitor* visitor) final { + return visitor->visit(this); + } + + void acceptVisitor(ExpressionConstVisitor* visitor) const final { + return visitor->visit(this); + } +}; + +class ExpressionBitXor final : public ExpressionBitwise<ExpressionBitXor> { +public: + SafeNum doOperation(const SafeNum& a, const SafeNum& b) const final { + return a.bitXor(b); + } + + SafeNum getIdentity() const final { + return 0; + } + + const char* getOpName() const final { + return "$bitXor"; + }; + + explicit ExpressionBitXor(ExpressionContext* const expCtx) + : ExpressionBitwise<ExpressionBitXor>(expCtx) { + expCtx->sbeCompatible = false; + } + + ExpressionBitXor(ExpressionContext* const expCtx, ExpressionVector&& children) + : ExpressionBitwise<ExpressionBitXor>(expCtx, std::move(children)) { + expCtx->sbeCompatible = false; + } + + void acceptVisitor(ExpressionMutableVisitor* visitor) final { + return visitor->visit(this); + } + + void acceptVisitor(ExpressionConstVisitor* visitor) const final { + return visitor->visit(this); + } +}; class ExpressionBitNot final : public ExpressionSingleNumericArg<ExpressionBitNot> { public: explicit ExpressionBitNot(ExpressionContext* const expCtx) @@ -4372,5 +4517,4 @@ public: return visitor->visit(this); } }; - } // namespace mongo diff --git a/src/mongo/db/pipeline/expression_dependencies.cpp b/src/mongo/db/pipeline/expression_dependencies.cpp index 402aa655ba8..6bddc5cbb3f 100644 --- a/src/mongo/db/pipeline/expression_dependencies.cpp +++ b/src/mongo/db/pipeline/expression_dependencies.cpp @@ -56,6 +56,9 @@ public: void visit(const ExpressionAnyElementTrue*) {} void visit(const ExpressionArray*) {} void visit(const ExpressionArrayElemAt*) {} + void visit(const ExpressionBitAnd*) {} + void visit(const ExpressionBitOr*) {} + void visit(const ExpressionBitXor*) {} void visit(const ExpressionBitNot*) {} void visit(const ExpressionFirst*) {} void visit(const ExpressionLast*) {} diff --git a/src/mongo/db/pipeline/expression_test.cpp b/src/mongo/db/pipeline/expression_test.cpp index f053aaecd5b..245f8d7e9ac 100644 --- a/src/mongo/db/pipeline/expression_test.cpp +++ b/src/mongo/db/pipeline/expression_test.cpp @@ -4280,6 +4280,152 @@ TEST(ExpressionFLETest, ParseAndSerializeBetween) { } })"); ASSERT_BSONOBJ_EQ(value.getDocument().toBson(), roundTripExpr); } +TEST(ExpressionBitAndTest, BitAndCorrectness) { + assertExpectedResults("$bitAnd", + { + // Explicit correctness cases. + {{0b0, 0b0}, 0b0}, + {{0b0, 0b1}, 0b0}, + {{0b1, 0b0}, 0b0}, + {{0b1, 0b1}, 0b1}, + + {{0b00, 0b00}, 0b00}, + {{0b00, 0b01}, 0b00}, + {{0b01, 0b00}, 0b00}, + {{0b01, 0b01}, 0b01}, + + {{0b00, 0b00}, 0b00}, + {{0b00, 0b11}, 0b00}, + {{0b11, 0b00}, 0b00}, + {{0b11, 0b11}, 0b11}, + }); +} + +TEST(ExpressionBitAndTest, BitAndInt) { + assertExpectedResults( + "$bitAnd", + { + // Empty operand list should evaluate to the identity for the operation. + {{}, -1}, + // Singleton cases. + {{0}, 0}, + {{256}, 256}, + // Binary cases + {{5, 2}, 5 & 2}, + {{255, 0}, 255 & 0}, + // Ternary cases + {{5, 2, 10}, 5 & 2 & 10}, + }); +} + +TEST(ExpressionBitAndTest, BitAndLong) { + assertExpectedResults("$bitAnd", + { + // Singleton cases. + {{0LL}, 0LL}, + {{1LL << 40}, 1LL << 40}, + {{256LL}, 256LL}, + // Binary cases. + {{5LL, 2LL}, 5LL & 2LL}, + {{255LL, 0LL}, 255LL & 0LL}, + // Ternary cases. + {{5, 2, 10}, 5 & 2 & 10}, + }); +} + +TEST(ExpressionBitAndTest, BitAndMixedTypes) { + // Any NumberLong widens the resulting type to NumberLong. + assertExpectedResults("$bitAnd", + { + // Binary cases + {{5LL, 2}, 5LL & 2}, + {{5, 2LL}, 5 & 2LL}, + {{255LL, 0}, 255LL & 0}, + {{255, 0LL}, 255 & 0LL}, + }); +} + +TEST(ExpressionBitOrTest, BitOrInt) { + assertExpectedResults("$bitOr", + { + {{}, 0}, + // Singleton cases. + {{0}, 0}, + {{256}, 256}, + // Binary cases + {{5, 2}, 5 | 2}, + {{255, 0}, 255 | 0}, + // Ternary cases + {{5, 2, 10}, 5 | 2 | 10}, + }); +} + +TEST(ExpressionBitOrTest, BitOrLong) { + assertExpectedResults("$bitOr", + { + // Singleton cases. + {{0LL}, 0LL}, + {{256LL}, 256LL}, + // Binary cases. + {{5LL, 2LL}, 5LL | 2LL}, + {{255LL, 0LL}, 255LL | 0LL}, + // Ternary cases. + {{5, 2, 10}, 5 | 2 | 10}, + }); +} + +TEST(ExpressionBitOrTest, BitOrMixedTypes) { + // Any NumberLong widens the resulting type to NumberLong. + assertExpectedResults("$bitOr", + { + // Binary cases + {{5LL, 2}, 5LL | 2}, + {{5, 2LL}, 5 | 2LL}, + {{255LL, 0}, 255LL | 0}, + {{255, 0LL}, 255 | 0LL}, + }); +} + +TEST(ExpressionBitXorTest, BitXorInt) { + assertExpectedResults("$bitXor", + { + {{}, 0}, + // Singleton cases. + {{0}, 0}, + {{256}, 256}, + // Binary cases + {{5, 2}, 5 ^ 2}, + {{255, 0}, 255 ^ 0}, + // Ternary cases + {{5, 2, 10}, 5 ^ 2 ^ 10}, + }); +} + +TEST(ExpressionBitXorTest, BitXorLong) { + assertExpectedResults("$bitXor", + { + // Singleton cases. + {{0LL}, 0LL}, + {{256LL}, 256LL}, + // Binary cases. + {{5LL, 2LL}, 5LL ^ 2LL}, + {{255LL, 0LL}, 255LL ^ 0LL}, + // Ternary cases. + {{5, 2, 10}, 5 ^ 2 ^ 10}, + }); +} + +TEST(ExpressionBitXorTest, BitXorMixedTypes) { + // Any NumberLong widens the resulting type to NumberLong. + assertExpectedResults("$bitXor", + { + // Binary cases + {{5LL, 2}, 5LL ^ 2}, + {{5, 2LL}, 5 ^ 2LL}, + {{255LL, 0}, 255LL ^ 0}, + {{255, 0LL}, 255 ^ 0LL}, + }); +} TEST(ExpressionBitNotTest, Int) { int min = numeric_limits<int>::min(); diff --git a/src/mongo/db/pipeline/expression_visitor.h b/src/mongo/db/pipeline/expression_visitor.h index 80e1d5a7467..63ae8568cff 100644 --- a/src/mongo/db/pipeline/expression_visitor.h +++ b/src/mongo/db/pipeline/expression_visitor.h @@ -168,6 +168,9 @@ class ExpressionDateSubtract; class ExpressionDateTrunc; class ExpressionGetField; class ExpressionSetField; +class ExpressionBitAnd; +class ExpressionBitOr; +class ExpressionBitXor; class AccumulatorAvg; class AccumulatorFirstN; @@ -218,6 +221,9 @@ public: virtual void visit(expression_walker::MaybeConstPtr<IsConst, ExpressionTestApiVersion>) = 0; virtual void visit(expression_walker::MaybeConstPtr<IsConst, ExpressionArray>) = 0; virtual void visit(expression_walker::MaybeConstPtr<IsConst, ExpressionArrayElemAt>) = 0; + virtual void visit(expression_walker::MaybeConstPtr<IsConst, ExpressionBitAnd>) = 0; + virtual void visit(expression_walker::MaybeConstPtr<IsConst, ExpressionBitOr>) = 0; + virtual void visit(expression_walker::MaybeConstPtr<IsConst, ExpressionBitXor>) = 0; virtual void visit(expression_walker::MaybeConstPtr<IsConst, ExpressionBitNot>) = 0; virtual void visit(expression_walker::MaybeConstPtr<IsConst, ExpressionFirst>) = 0; virtual void visit(expression_walker::MaybeConstPtr<IsConst, ExpressionLast>) = 0; @@ -402,6 +408,9 @@ struct SelectiveConstExpressionVisitorBase : public ExpressionConstVisitor { void visit(const ExpressionAnyElementTrue*) override {} void visit(const ExpressionArray*) override {} void visit(const ExpressionArrayElemAt*) override {} + void visit(const ExpressionBitAnd*) override {} + void visit(const ExpressionBitOr*) override {} + void visit(const ExpressionBitXor*) override {} void visit(const ExpressionBitNot*) override {} void visit(const ExpressionFirst*) override {} void visit(const ExpressionLast*) override {} diff --git a/src/mongo/db/query/cqf_command_utils.cpp b/src/mongo/db/query/cqf_command_utils.cpp index 7eac72c4333..37c2d7b617d 100644 --- a/src/mongo/db/query/cqf_command_utils.cpp +++ b/src/mongo/db/query/cqf_command_utils.cpp @@ -398,6 +398,17 @@ public: unsupportedExpression(); } + void visit(const ExpressionBitAnd* expr) override final { + unsupportedExpression(); + } + + void visit(const ExpressionBitOr* expr) override final { + unsupportedExpression(); + } + + void visit(const ExpressionBitXor* expr) override final { + unsupportedExpression(); + } void visit(const ExpressionBitNot* expr) override final { unsupportedExpression(); } diff --git a/src/mongo/db/query/sbe_stage_builder_expression.cpp b/src/mongo/db/query/sbe_stage_builder_expression.cpp index 0fffd085892..6aa0711e4da 100644 --- a/src/mongo/db/query/sbe_stage_builder_expression.cpp +++ b/src/mongo/db/query/sbe_stage_builder_expression.cpp @@ -329,6 +329,9 @@ public: void visit(const ExpressionAnyElementTrue* expr) final {} void visit(const ExpressionArray* expr) final {} void visit(const ExpressionArrayElemAt* expr) final {} + void visit(const ExpressionBitAnd* expr) final {} + void visit(const ExpressionBitOr* expr) final {} + void visit(const ExpressionBitXor* expr) final {} void visit(const ExpressionBitNot* expr) final {} void visit(const ExpressionFirst* expr) final {} void visit(const ExpressionLast* expr) final {} @@ -507,6 +510,9 @@ public: void visit(const ExpressionAnyElementTrue* expr) final {} void visit(const ExpressionArray* expr) final {} void visit(const ExpressionArrayElemAt* expr) final {} + void visit(const ExpressionBitAnd* expr) final {} + void visit(const ExpressionBitOr* expr) final {} + void visit(const ExpressionBitXor* expr) final {} void visit(const ExpressionBitNot* expr) final {} void visit(const ExpressionFirst* expr) final {} void visit(const ExpressionLast* expr) final {} @@ -951,11 +957,18 @@ public: sbe::makeE<sbe::ELocalBind>(frameId, std::move(binds), std::move(arrayElemAtExpr)), std::move(stage)); } - + void visit(const ExpressionBitAnd* expr) final { + unsupportedExpression(expr->getOpName()); + } + void visit(const ExpressionBitOr* expr) final { + unsupportedExpression(expr->getOpName()); + } + void visit(const ExpressionBitXor* expr) final { + unsupportedExpression(expr->getOpName()); + } void visit(const ExpressionBitNot* expr) final { unsupportedExpression(expr->getOpName()); } - void visit(const ExpressionFirst* expr) final { buildArrayAccessByConstantIndex(_context, expr->getOpName(), 0); } diff --git a/src/mongo/util/safe_num.h b/src/mongo/util/safe_num.h index 7f16cd036f3..19d281dfc05 100644 --- a/src/mongo/util/safe_num.h +++ b/src/mongo/util/safe_num.h @@ -41,6 +41,7 @@ namespace mutablebson { class Element; class Document; } // namespace mutablebson +class Value; /** * SafeNum holds and does arithmetic on a number in a safe way, handling overflow @@ -156,6 +157,7 @@ public: friend class mutablebson::Element; friend class mutablebson::Document; + friend class Value; /** * Appends contents to given BSONObjBuilder. |