summaryrefslogtreecommitdiff
path: root/src/mongo/db/query/fle/server_rewrite_test.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/mongo/db/query/fle/server_rewrite_test.cpp')
-rw-r--r--src/mongo/db/query/fle/server_rewrite_test.cpp307
1 files changed, 300 insertions, 7 deletions
diff --git a/src/mongo/db/query/fle/server_rewrite_test.cpp b/src/mongo/db/query/fle/server_rewrite_test.cpp
index cb81656dcb6..034de8f0aa9 100644
--- a/src/mongo/db/query/fle/server_rewrite_test.cpp
+++ b/src/mongo/db/query/fle/server_rewrite_test.cpp
@@ -31,7 +31,9 @@
#include <memory>
#include "mongo/bson/bsonelement.h"
+#include "mongo/bson/bsonmisc.h"
#include "mongo/bson/bsonobjbuilder.h"
+#include "mongo/bson/bsontypes.h"
#include "mongo/db/matcher/expression_leaf.h"
#include "mongo/db/pipeline/expression_context_for_test.h"
#include "mongo/db/query/fle/server_rewrite.h"
@@ -42,9 +44,19 @@
namespace mongo {
namespace {
-class MockFLEQueryRewriter : public fle::FLEQueryRewriter {
+class BasicMockFLEQueryRewriter : public fle::FLEQueryRewriter {
public:
- MockFLEQueryRewriter() : fle::FLEQueryRewriter(new ExpressionContextForTest()), _tags() {}
+ BasicMockFLEQueryRewriter() : fle::FLEQueryRewriter(new ExpressionContextForTest()) {}
+
+ BSONObj rewriteMatchExpressionForTest(const BSONObj& obj) {
+ auto res = rewriteMatchExpression(obj);
+ return res ? res.get() : obj;
+ }
+};
+
+class MockFLEQueryRewriter : public BasicMockFLEQueryRewriter {
+public:
+ MockFLEQueryRewriter() : _tags() {}
bool isFleFindPayload(const BSONElement& fleFindPayload) const override {
return _encryptedFields.find(fleFindPayload.fieldNameStringData()) !=
@@ -56,11 +68,6 @@ public:
_tags[fieldvalue] = tags;
}
- BSONObj rewriteMatchExpressionForTest(const BSONObj& obj) {
- auto res = rewriteMatchExpression(obj);
- return res ? res.get() : obj;
- }
-
private:
BSONObj rewritePayloadAsTags(BSONElement fleFindPayload) const override {
ASSERT(fleFindPayload.isNumber()); // Only accept numbers as mock FFPs.
@@ -72,6 +79,7 @@ private:
std::map<std::pair<StringData, int>, BSONObj> _tags;
std::set<StringData> _encryptedFields;
};
+
class FLEServerRewriteTest : public unittest::Test {
public:
FLEServerRewriteTest() {}
@@ -361,5 +369,290 @@ TEST_F(FLEServerRewriteTest, ComparisonToObjectIgnored) {
}
}
+template <typename T>
+std::vector<uint8_t> toEncryptedVector(EncryptedBinDataType dt, T t) {
+ BSONObj obj = t.toBSON();
+
+ std::vector<uint8_t> buf(obj.objsize() + 1);
+ buf[0] = static_cast<uint8_t>(dt);
+
+ std::copy(obj.objdata(), obj.objdata() + obj.objsize(), buf.data() + 1);
+
+ return buf;
+}
+
+template <typename T>
+void toEncryptedBinData(StringData field, EncryptedBinDataType dt, T t, BSONObjBuilder* builder) {
+ auto buf = toEncryptedVector(dt, t);
+
+ builder->appendBinData(field, buf.size(), BinDataType::Encrypt, buf.data());
+}
+
+constexpr auto kIndexKeyId = "12345678-1234-9876-1234-123456789012"_sd;
+constexpr auto kUserKeyId = "ABCDEFAB-1234-9876-1234-123456789012"_sd;
+static UUID indexKeyId = uassertStatusOK(UUID::parse(kIndexKeyId.toString()));
+static UUID userKeyId = uassertStatusOK(UUID::parse(kUserKeyId.toString()));
+
+std::vector<char> testValue = {0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19};
+std::vector<char> testValue2 = {0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29};
+
+const FLEIndexKey& getIndexKey() {
+ static std::string indexVec = hexblob::decode(
+ "7dbfebc619aa68a659f64b8e23ccd21644ac326cb74a26840c3d2420176c40ae088294d00ad6cae9684237b21b754cf503f085c25cd320bf035c3417416e1e6fe3d9219f79586582112740b2add88e1030d91926ae8afc13ee575cfb8bb965b7"_sd);
+ static FLEIndexKey indexKey(KeyMaterial(indexVec.begin(), indexVec.end()));
+ return indexKey;
+}
+
+const FLEUserKey& getUserKey() {
+ static std::string userVec = hexblob::decode(
+ "a7ddbc4c8be00d51f68d9d8e485f351c8edc8d2206b24d8e0e1816d005fbe520e489125047d647b0d8684bfbdbf09c304085ed086aba6c2b2b1677ccc91ced8847a733bf5e5682c84b3ee7969e4a5fe0e0c21e5e3ee190595a55f83147d8de2a"_sd);
+ static FLEUserKey userKey(KeyMaterial(userVec.begin(), userVec.end()));
+ return userKey;
+}
+
+
+BSONObj generateFFP(StringData path, int value) {
+ auto indexKey = getIndexKey();
+ FLEIndexKeyAndId indexKeyAndId(indexKey.data, indexKeyId);
+ auto userKey = getUserKey();
+ FLEUserKeyAndId userKeyAndId(userKey.data, indexKeyId);
+
+ BSONObj doc = BSON("value" << value);
+ auto element = doc.firstElement();
+ auto fpp = FLEClientCrypto::serializeFindPayload(indexKeyAndId, userKeyAndId, element, 0);
+
+ BSONObjBuilder builder;
+ toEncryptedBinData(path, EncryptedBinDataType::kFLE2FindEqualityPayload, fpp, &builder);
+ return builder.obj();
+}
+
+class FLEServerHighCardRewriteTest : public unittest::Test {
+public:
+ FLEServerHighCardRewriteTest() {}
+
+ void setUp() override {}
+
+ void tearDown() override {}
+
+protected:
+ BasicMockFLEQueryRewriter _mock;
+};
+
+
+TEST_F(FLEServerHighCardRewriteTest, HighCard_TopLevel_Equality) {
+ _mock.setForceHighCardinalityForTest();
+
+ auto match = generateFFP("ssn", 1);
+ auto expected = fromjson(R"({
+ "$expr": {
+ "$_internalFleEq": {
+ "field": "$ssn",
+ "edc": {
+ "$binary": {
+ "base64": "CEWSmQID7SfwyAUI3ZkSFkATKryDQfnxXEOGad5d4Rsg",
+ "subType": "6"
+ }
+ },
+ "counter": {
+ "$numberLong": "0"
+ },
+ "server": {
+ "$binary": {
+ "base64": "COuac/eRLYakKX6B0vZ1r3QodOQFfjqJD+xlGiPu4/Ps",
+ "subType": "6"
+ }
+ }
+ }
+ }
+})");
+
+ auto actual = _mock.rewriteMatchExpressionForTest(match);
+ ASSERT_BSONOBJ_EQ(actual, expected);
+}
+
+
+TEST_F(FLEServerHighCardRewriteTest, HighCard_TopLevel_In) {
+ _mock.setForceHighCardinalityForTest();
+
+ auto ffp1 = generateFFP("ssn", 1);
+ auto ffp2 = generateFFP("ssn", 2);
+ auto ffp3 = generateFFP("ssn", 3);
+ auto expected = fromjson(R"({
+ "$or": [
+ {
+ "$expr": {
+ "$_internalFleEq": {
+ "field": "$ssn",
+ "edc": {
+ "$binary": {
+ "base64": "CEWSmQID7SfwyAUI3ZkSFkATKryDQfnxXEOGad5d4Rsg",
+ "subType": "6"
+ }
+ },
+ "counter": {
+ "$numberLong": "0"
+ },
+ "server": {
+ "$binary": {
+ "base64": "COuac/eRLYakKX6B0vZ1r3QodOQFfjqJD+xlGiPu4/Ps",
+ "subType": "6"
+ }
+ }
+ }
+ }
+ },
+ {
+ "$expr": {
+ "$_internalFleEq": {
+ "field": "$ssn",
+ "edc": {
+ "$binary": {
+ "base64": "CLpCo6rNuYMVT+6n1HCX15MNrVYDNqf6udO46ayo43Sw",
+ "subType": "6"
+ }
+ },
+ "counter": {
+ "$numberLong": "0"
+ },
+ "server": {
+ "$binary": {
+ "base64": "COuac/eRLYakKX6B0vZ1r3QodOQFfjqJD+xlGiPu4/Ps",
+ "subType": "6"
+ }
+ }
+ }
+ }
+ },
+ {
+ "$expr": {
+ "$_internalFleEq": {
+ "field": "$ssn",
+ "edc": {
+ "$binary": {
+ "base64": "CPi44oCQHnNDeRqHsNLzbdCeHt2DK/wCly0g2dxU5fqN",
+ "subType": "6"
+ }
+ },
+ "counter": {
+ "$numberLong": "0"
+ },
+ "server": {
+ "$binary": {
+ "base64": "COuac/eRLYakKX6B0vZ1r3QodOQFfjqJD+xlGiPu4/Ps",
+ "subType": "6"
+ }
+ }
+ }
+ }
+ }
+ ]
+})");
+
+ auto match =
+ BSON("ssn" << BSON("$in" << BSON_ARRAY(ffp1.firstElement()
+ << ffp2.firstElement() << ffp3.firstElement())));
+
+ auto actual = _mock.rewriteMatchExpressionForTest(match);
+ ASSERT_BSONOBJ_EQ(actual, expected);
+}
+
+
+TEST_F(FLEServerHighCardRewriteTest, HighCard_TopLevel_Expr) {
+
+ _mock.setForceHighCardinalityForTest();
+
+ auto ffp = generateFFP("$ssn", 1);
+ int len;
+ auto v = ffp.firstElement().binDataClean(len);
+ auto match = BSON("$expr" << BSON("$eq" << BSON_ARRAY(ffp.firstElement().fieldName()
+ << BSONBinData(v, len, Encrypt))));
+
+ auto expected = fromjson(R"({ "$expr": {
+ "$_internalFleEq": {
+ "field": "$ssn",
+ "edc": {
+ "$binary": {
+ "base64": "CEWSmQID7SfwyAUI3ZkSFkATKryDQfnxXEOGad5d4Rsg",
+ "subType": "6"
+ }
+ },
+ "counter": {
+ "$numberLong": "0"
+ },
+ "server": {
+ "$binary": {
+ "base64": "COuac/eRLYakKX6B0vZ1r3QodOQFfjqJD+xlGiPu4/Ps",
+ "subType": "6"
+ }
+ }
+ }
+ }
+ })");
+
+ auto actual = _mock.rewriteMatchExpressionForTest(match);
+ ASSERT_BSONOBJ_EQ(actual, expected);
+}
+
+TEST_F(FLEServerHighCardRewriteTest, HighCard_TopLevel_Expr_In) {
+
+ _mock.setForceHighCardinalityForTest();
+
+ auto ffp = generateFFP("$ssn", 1);
+ int len;
+ auto v = ffp.firstElement().binDataClean(len);
+
+ auto ffp2 = generateFFP("$ssn", 1);
+ int len2;
+ auto v2 = ffp2.firstElement().binDataClean(len2);
+
+ auto match = BSON(
+ "$expr" << BSON("$in" << BSON_ARRAY(ffp.firstElement().fieldName()
+ << BSON_ARRAY(BSONBinData(v, len, Encrypt)
+ << BSONBinData(v2, len2, Encrypt)))));
+
+ auto expected = fromjson(R"({ "$expr": { "$or" : [ {
+ "$_internalFleEq": {
+ "field": "$ssn",
+ "edc": {
+ "$binary": {
+ "base64": "CEWSmQID7SfwyAUI3ZkSFkATKryDQfnxXEOGad5d4Rsg",
+ "subType": "6"
+ }
+ },
+ "counter": {
+ "$numberLong": "0"
+ },
+ "server": {
+ "$binary": {
+ "base64": "COuac/eRLYakKX6B0vZ1r3QodOQFfjqJD+xlGiPu4/Ps",
+ "subType": "6"
+ }
+ }
+ }},
+ {
+ "$_internalFleEq": {
+ "field": "$ssn",
+ "edc": {
+ "$binary": {
+ "base64": "CEWSmQID7SfwyAUI3ZkSFkATKryDQfnxXEOGad5d4Rsg",
+ "subType": "6"
+ }
+ },
+ "counter": {
+ "$numberLong": "0"
+ },
+ "server": {
+ "$binary": {
+ "base64": "COuac/eRLYakKX6B0vZ1r3QodOQFfjqJD+xlGiPu4/Ps",
+ "subType": "6"
+ }
+ }
+ }}
+ ]}})");
+
+ auto actual = _mock.rewriteMatchExpressionForTest(match);
+ ASSERT_BSONOBJ_EQ(actual, expected);
+}
+
} // namespace
} // namespace mongo