summaryrefslogtreecommitdiff
path: root/src/mongo/shell/encrypted_dbclient_base.cpp
diff options
context:
space:
mode:
authorMark Benvenuto <mark.benvenuto@mongodb.com>2019-06-18 00:19:33 -0400
committerMark Benvenuto <mark.benvenuto@mongodb.com>2019-06-18 00:19:33 -0400
commitc436b8090417baf847143e97f5d221285b1898e1 (patch)
treebb5369a793c8ec6646b20e92b1a0f85b7979682f /src/mongo/shell/encrypted_dbclient_base.cpp
parent5eda33f9fa40a1a17f9f63f904a8c147700d648c (diff)
downloadmongo-c436b8090417baf847143e97f5d221285b1898e1.tar.gz
SERVER-41644 Expose explicit encryption helpers in community shell
Diffstat (limited to 'src/mongo/shell/encrypted_dbclient_base.cpp')
-rw-r--r--src/mongo/shell/encrypted_dbclient_base.cpp685
1 files changed, 685 insertions, 0 deletions
diff --git a/src/mongo/shell/encrypted_dbclient_base.cpp b/src/mongo/shell/encrypted_dbclient_base.cpp
new file mode 100644
index 00000000000..34ae4a11b4a
--- /dev/null
+++ b/src/mongo/shell/encrypted_dbclient_base.cpp
@@ -0,0 +1,685 @@
+/**
+ * Copyright (C) 2019-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/platform/basic.h"
+
+#include "mongo/shell/encrypted_dbclient_base.h"
+
+#include "mongo/base/data_cursor.h"
+#include "mongo/base/data_type_validated.h"
+#include "mongo/bson/bson_depth.h"
+#include "mongo/client/dbclient_base.h"
+#include "mongo/crypto/aead_encryption.h"
+#include "mongo/crypto/symmetric_crypto.h"
+#include "mongo/db/client.h"
+#include "mongo/db/commands.h"
+#include "mongo/db/matcher/schema/encrypt_schema_gen.h"
+#include "mongo/db/namespace_string.h"
+#include "mongo/rpc/object_check.h"
+#include "mongo/rpc/op_msg_rpc_impls.h"
+#include "mongo/scripting/mozjs/bindata.h"
+#include "mongo/scripting/mozjs/implscope.h"
+#include "mongo/scripting/mozjs/maxkey.h"
+#include "mongo/scripting/mozjs/minkey.h"
+#include "mongo/scripting/mozjs/mongo.h"
+#include "mongo/scripting/mozjs/objectwrapper.h"
+#include "mongo/scripting/mozjs/valuereader.h"
+#include "mongo/scripting/mozjs/valuewriter.h"
+#include "mongo/shell/encrypted_shell_options.h"
+#include "mongo/shell/kms.h"
+#include "mongo/shell/kms_gen.h"
+#include "mongo/shell/shell_options.h"
+#include "mongo/util/lru_cache.h"
+
+namespace mongo {
+
+EncryptedShellGlobalParams encryptedShellGlobalParams;
+
+namespace {
+constexpr Duration kCacheInvalidationTime = Minutes(1);
+
+
+ImplicitEncryptedDBClientCallback* implicitEncryptedDBClientCallback{nullptr};
+
+
+} // namespace
+
+void setImplicitEncryptedDBClientCallback(ImplicitEncryptedDBClientCallback* callback) {
+ implicitEncryptedDBClientCallback = callback;
+}
+
+static void validateCollection(JSContext* cx, JS::HandleValue value) {
+ uassert(ErrorCodes::BadValue,
+ "Collection object must be provided to ClientSideFLEOptions",
+ !(value.isNull() || value.isUndefined()));
+
+ JS::RootedValue coll(cx, value);
+
+ uassert(31043,
+ "The collection object in ClientSideFLEOptions is invalid",
+ mozjs::getScope(cx)->getProto<mozjs::DBCollectionInfo>().instanceOf(coll));
+}
+
+EncryptedDBClientBase::EncryptedDBClientBase(std::unique_ptr<DBClientBase> conn,
+ ClientSideFLEOptions encryptionOptions,
+ JS::HandleValue collection,
+ JSContext* cx)
+ : _conn(std::move(conn)), _encryptionOptions(std::move(encryptionOptions)), _cx(cx) {
+ validateCollection(cx, collection);
+ _collection = JS::Heap<JS::Value>(collection);
+ uassert(31078,
+ "Cannot use WriteMode Legacy with Field Level Encryption",
+ shellGlobalParams.writeMode != "legacy");
+};
+
+std::string EncryptedDBClientBase::getServerAddress() const {
+ return _conn->getServerAddress();
+}
+
+bool EncryptedDBClientBase::call(Message& toSend,
+ Message& response,
+ bool assertOk,
+ std::string* actualServer) {
+ return _conn->call(toSend, response, assertOk, actualServer);
+}
+
+void EncryptedDBClientBase::say(Message& toSend, bool isRetry, std::string* actualServer) {
+ MONGO_UNREACHABLE;
+}
+
+bool EncryptedDBClientBase::lazySupported() const {
+ return _conn->lazySupported();
+}
+
+std::pair<rpc::UniqueReply, DBClientBase*> EncryptedDBClientBase::runCommandWithTarget(
+ OpMsgRequest request) {
+ return _conn->runCommandWithTarget(std::move(request));
+}
+
+
+/**
+ *
+ * This function reads the data from the CDR and returns a copy
+ * constructed and owned BSONObject.
+ *
+ */
+BSONObj EncryptedDBClientBase::validateBSONElement(ConstDataRange out, uint8_t bsonType) {
+ if (bsonType == BSONType::Object) {
+ ConstDataRangeCursor cdc = ConstDataRangeCursor(out);
+ BSONObj valueObj;
+
+ valueObj = cdc.readAndAdvance<Validated<BSONObj>>();
+ return valueObj.getOwned();
+ } else {
+ auto valueString = "value"_sd;
+
+ // The size here is to construct a new BSON document and validate the
+ // total size of the object. The first four bytes is for the size of an
+ // int32_t, then a space for the type of the first element, then the space
+ // for the value string and the the 0x00 terminated field name, then the
+ // size of the actual data, then the last byte for the end document character,
+ // also 0x00.
+ size_t docLength = sizeof(int32_t) + 1 + valueString.size() + 1 + out.length() + 1;
+ BufBuilder builder;
+ builder.reserveBytes(docLength);
+
+ uassert(ErrorCodes::BadValue,
+ "invalid decryption value",
+ docLength < std::numeric_limits<int32_t>::max());
+
+ builder.appendNum(static_cast<uint32_t>(docLength));
+ builder.appendChar(static_cast<uint8_t>(bsonType));
+ builder.appendStr(valueString, true);
+ builder.appendBuf(out.data(), out.length());
+ builder.appendChar('\0');
+
+ ConstDataRangeCursor cdc =
+ ConstDataRangeCursor(ConstDataRange(builder.buf(), builder.len()));
+ BSONObj elemWrapped = cdc.readAndAdvance<Validated<BSONObj>>();
+ return elemWrapped.getOwned();
+ }
+}
+
+std::string EncryptedDBClientBase::toString() const {
+ return _conn->toString();
+}
+
+int EncryptedDBClientBase::getMinWireVersion() {
+ return _conn->getMinWireVersion();
+}
+
+int EncryptedDBClientBase::getMaxWireVersion() {
+ return _conn->getMaxWireVersion();
+}
+
+void EncryptedDBClientBase::generateDataKey(JSContext* cx, JS::CallArgs args) {
+ if (args.length() != 2) {
+ uasserted(ErrorCodes::BadValue, "generateDataKey requires 2 arg");
+ }
+
+ if (!args.get(0).isString()) {
+ uasserted(ErrorCodes::BadValue, "1st param to generateDataKey has to be a string");
+ }
+
+ if (!args.get(1).isString()) {
+ uasserted(ErrorCodes::BadValue, "2nd param to generateDataKey has to be a string");
+ }
+
+ std::string kmsProvider = mozjs::ValueWriter(cx, args.get(0)).toString();
+ std::string clientMasterKey = mozjs::ValueWriter(cx, args.get(1)).toString();
+
+ std::unique_ptr<KMSService> kmsService = KMSServiceController::createFromClient(
+ kmsProvider, _encryptionOptions.getKmsProviders().toBSON());
+
+ SecureVector<uint8_t> dataKey(crypto::kFieldLevelEncryptionKeySize);
+ auto res = crypto::engineRandBytes(dataKey->data(), dataKey->size());
+ uassert(31042, "Error generating data key: " + res.codeString(), res.isOK());
+
+ BSONObj obj = kmsService->encryptDataKey(ConstDataRange(dataKey->data(), dataKey->size()),
+ clientMasterKey);
+
+ mozjs::ValueReader(cx, args.rval()).fromBSON(obj, nullptr, false);
+}
+
+void EncryptedDBClientBase::getDataKeyCollection(JSContext* cx, JS::CallArgs args) {
+ if (args.length() != 0) {
+ uasserted(ErrorCodes::BadValue, "getDataKeyCollection does not take any params");
+ }
+ args.rval().set(_collection.get());
+}
+
+void EncryptedDBClientBase::encrypt(mozjs::MozJSImplScope* scope,
+ JSContext* cx,
+ JS::CallArgs args) {
+ // Input Validation
+ uassert(ErrorCodes::BadValue, "encrypt requires 3 args", args.length() == 3);
+
+ if (!(args.get(1).isObject() || args.get(1).isString() || args.get(1).isNumber() ||
+ args.get(1).isBoolean())) {
+ uasserted(ErrorCodes::BadValue,
+ "Second parameter must be an object, string, number, or bool");
+ }
+
+ uassert(ErrorCodes::BadValue, "Third parameter must be a string", args.get(2).isString());
+ auto algorithmStr = mozjs::ValueWriter(cx, args.get(2)).toString();
+ FleAlgorithmInt algorithm;
+
+ if (StringData(algorithmStr) == FleAlgorithm_serializer(FleAlgorithmEnum::kRandom)) {
+ algorithm = FleAlgorithmInt::kRandom;
+ } else if (StringData(algorithmStr) ==
+ FleAlgorithm_serializer(FleAlgorithmEnum::kDeterministic)) {
+ algorithm = FleAlgorithmInt::kDeterministic;
+ } else {
+ uasserted(ErrorCodes::BadValue, "Third parameter must be the FLE Algorithm type");
+ }
+
+ // Extract the UUID from the callArgs
+ auto binData = getBinDataArg(scope, cx, args, 0, BinDataType::newUUID);
+ UUID uuid = UUID::fromCDR(ConstDataRange(binData.data(), binData.size()));
+ BSONType bsonType = BSONType::EOO;
+
+ BufBuilder plaintext;
+ if (args.get(1).isObject()) {
+ JS::RootedObject rootedObj(cx, &args.get(1).toObject());
+ auto jsclass = JS_GetClass(rootedObj);
+
+ if (strcmp(jsclass->name, "Object") == 0 || strcmp(jsclass->name, "Array") == 0) {
+ uassert(ErrorCodes::BadValue,
+ "Cannot deterministically encrypt object or array types.",
+ algorithm != FleAlgorithmInt::kDeterministic);
+
+ // If it is a JS Object, then we can extract all the information by simply calling
+ // ValueWriter.toBSON and setting the type bit, which is what is happening below.
+ BSONObj valueObj = mozjs::ValueWriter(cx, args.get(1)).toBSON();
+ plaintext.appendBuf(valueObj.objdata(), valueObj.objsize());
+ if (strcmp(jsclass->name, "Array") == 0) {
+ bsonType = BSONType::Array;
+ } else {
+ bsonType = BSONType::Object;
+ }
+
+ } else if (scope->getProto<mozjs::MinKeyInfo>().getJSClass() == jsclass ||
+ scope->getProto<mozjs::MaxKeyInfo>().getJSClass() == jsclass ||
+ scope->getProto<mozjs::DBRefInfo>().getJSClass() == jsclass) {
+ uasserted(ErrorCodes::BadValue, "Second parameter cannot be MinKey, MaxKey, or DBRef");
+ } else {
+ if (scope->getProto<mozjs::NumberDecimalInfo>().getJSClass() == jsclass) {
+ uassert(ErrorCodes::BadValue,
+ "Cannot deterministically encrypt NumberDecimal type objects.",
+ algorithm != FleAlgorithmInt::kDeterministic);
+ }
+
+ if (scope->getProto<mozjs::CodeInfo>().getJSClass() == jsclass) {
+ uassert(ErrorCodes::BadValue,
+ "Cannot deterministically encrypt Code type objects.",
+ algorithm != FleAlgorithmInt::kDeterministic);
+ }
+
+ // If it is one of our Mongo defined types, then we have to use the ValueWriter
+ // writeThis function, which takes in a set of WriteFieldRecursionFrames (setting
+ // a limit on how many times we can recursively dig into an object's nested
+ // structure)
+ // and writes the value out to a BSONObjBuilder. We can then extract that
+ // information
+ // from the object by building it and pulling out the first element, which is the
+ // object we are trying to get.
+ mozjs::ObjectWrapper::WriteFieldRecursionFrames frames;
+ frames.emplace(cx, rootedObj.get(), nullptr, StringData{});
+ BSONObjBuilder builder;
+ mozjs::ValueWriter(cx, args.get(1)).writeThis(&builder, "value"_sd, &frames);
+
+ BSONObj object = builder.obj();
+ auto elem = object.getField("value"_sd);
+
+ plaintext.appendBuf(elem.value(), elem.valuesize());
+ bsonType = elem.type();
+ }
+
+ } else if (args.get(1).isString()) {
+ std::string valueStr = mozjs::ValueWriter(cx, args.get(1)).toString();
+ if (valueStr.size() + 1 > std::numeric_limits<uint32_t>::max()) {
+ uasserted(ErrorCodes::BadValue, "Plaintext string to encrypt too long.");
+ }
+
+ plaintext.appendNum(static_cast<uint32_t>(valueStr.size() + 1));
+ plaintext.appendStr(valueStr, true);
+ bsonType = BSONType::String;
+
+ } else if (args.get(1).isNumber()) {
+ uassert(ErrorCodes::BadValue,
+ "Cannot deterministically encrypt Floating Point numbers.",
+ algorithm != FleAlgorithmInt::kDeterministic);
+
+ double valueNum = mozjs::ValueWriter(cx, args.get(1)).toNumber();
+ plaintext.appendNum(valueNum);
+ bsonType = BSONType::NumberDouble;
+ } else if (args.get(1).isBoolean()) {
+ uassert(ErrorCodes::BadValue,
+ "Cannot deterministically encrypt booleans.",
+ algorithm != FleAlgorithmInt::kDeterministic);
+
+ bool boolean = mozjs::ValueWriter(cx, args.get(1)).toBoolean();
+ if (boolean) {
+ plaintext.appendChar(0x01);
+ } else {
+ plaintext.appendChar(0x00);
+ }
+ bsonType = BSONType::Bool;
+ } else {
+ uasserted(ErrorCodes::BadValue, "Cannot encrypt valuetype provided.");
+ }
+ ConstDataRange plaintextRange(plaintext.buf(), plaintext.len());
+
+ auto key = getDataKey(uuid);
+ std::vector<uint8_t> fleBlob =
+ encryptWithKey(uuid, key, plaintextRange, bsonType, FleAlgorithmInt_serializer(algorithm));
+
+ // Prepare the return value
+ std::string blobStr = base64::encode(reinterpret_cast<char*>(fleBlob.data()), fleBlob.size());
+ JS::AutoValueArray<2> arr(cx);
+
+ arr[0].setInt32(BinDataType::Encrypt);
+ mozjs::ValueReader(cx, arr[1]).fromStringData(blobStr);
+ scope->getProto<mozjs::BinDataInfo>().newInstance(arr, args.rval());
+}
+
+void EncryptedDBClientBase::decrypt(mozjs::MozJSImplScope* scope,
+ JSContext* cx,
+ JS::CallArgs args) {
+ uassert(ErrorCodes::BadValue, "decrypt requires one argument", args.length() == 1);
+ uassert(ErrorCodes::BadValue,
+ "decrypt argument must be a BinData subtype Encrypt object",
+ args.get(0).isObject());
+
+ if (!scope->getProto<mozjs::BinDataInfo>().instanceOf(args.get(0))) {
+ uasserted(ErrorCodes::BadValue,
+ "decrypt argument must be a BinData subtype Encrypt object");
+ }
+
+ JS::RootedObject obj(cx, &args.get(0).get().toObject());
+ std::vector<uint8_t> binData = getBinDataArg(scope, cx, args, 0, BinDataType::Encrypt);
+
+ uassert(
+ ErrorCodes::BadValue, "Ciphertext blob too small", binData.size() > kAssociatedDataLength);
+ uassert(ErrorCodes::BadValue,
+ "Ciphertext blob algorithm unknown",
+ (FleAlgorithmInt(binData[0]) == FleAlgorithmInt::kDeterministic ||
+ FleAlgorithmInt(binData[0]) == FleAlgorithmInt::kRandom));
+
+ ConstDataRange uuidCdr = ConstDataRange(&binData[1], UUID::kNumBytes);
+ UUID uuid = UUID::fromCDR(uuidCdr);
+
+ auto key = getDataKey(uuid);
+ std::vector<uint8_t> out(binData.size() - kAssociatedDataLength);
+ size_t outLen = out.size();
+
+ auto decryptStatus = crypto::aeadDecrypt(*key,
+ &binData[kAssociatedDataLength],
+ binData.size() - kAssociatedDataLength,
+ &binData[0],
+ kAssociatedDataLength,
+ out.data(),
+ &outLen);
+ if (!decryptStatus.isOK()) {
+ uasserted(decryptStatus.code(), decryptStatus.reason());
+ }
+
+ uint8_t bsonType = binData[17];
+ BSONObj parent;
+ BSONObj decryptedObj = validateBSONElement(ConstDataRange(out.data(), outLen), bsonType);
+ if (bsonType == BSONType::Object) {
+ mozjs::ValueReader(cx, args.rval()).fromBSON(decryptedObj, &parent, true);
+ } else {
+ mozjs::ValueReader(cx, args.rval())
+ .fromBSONElement(decryptedObj.firstElement(), parent, true);
+ }
+}
+
+void EncryptedDBClientBase::trace(JSTracer* trc) {
+ JS::TraceEdge(trc, &_collection, "collection object");
+}
+
+JS::Value EncryptedDBClientBase::getCollection() const {
+ return _collection.get();
+}
+
+
+std::unique_ptr<DBClientCursor> EncryptedDBClientBase::query(const NamespaceStringOrUUID& nsOrUuid,
+ Query query,
+ int nToReturn,
+ int nToSkip,
+ const BSONObj* fieldsToReturn,
+ int queryOptions,
+ int batchSize) {
+ return _conn->query(
+ nsOrUuid, query, nToReturn, nToSkip, fieldsToReturn, queryOptions, batchSize);
+}
+
+bool EncryptedDBClientBase::isFailed() const {
+ return _conn->isFailed();
+}
+
+bool EncryptedDBClientBase::isStillConnected() {
+ return _conn->isStillConnected();
+}
+
+ConnectionString::ConnectionType EncryptedDBClientBase::type() const {
+ return _conn->type();
+}
+
+double EncryptedDBClientBase::getSoTimeout() const {
+ return _conn->getSoTimeout();
+}
+
+bool EncryptedDBClientBase::isReplicaSetMember() const {
+ return _conn->isReplicaSetMember();
+}
+
+bool EncryptedDBClientBase::isMongos() const {
+ return _conn->isMongos();
+}
+
+NamespaceString EncryptedDBClientBase::getCollectionNS() {
+ JS::RootedValue fullNameRooted(_cx);
+ JS::RootedObject collectionRooted(_cx, &_collection.get().toObject());
+ JS_GetProperty(_cx, collectionRooted, "_fullName", &fullNameRooted);
+ if (!fullNameRooted.isString()) {
+ uasserted(ErrorCodes::BadValue, "Collection object is incomplete.");
+ }
+ std::string fullName = mozjs::ValueWriter(_cx, fullNameRooted).toString();
+ NamespaceString fullNameNS = NamespaceString(fullName);
+ uassert(ErrorCodes::BadValue,
+ str::stream() << "Invalid namespace: " << fullName,
+ fullNameNS.isValid());
+ return fullNameNS;
+}
+
+std::vector<uint8_t> EncryptedDBClientBase::getBinDataArg(
+ mozjs::MozJSImplScope* scope, JSContext* cx, JS::CallArgs args, int index, BinDataType type) {
+ if (!args.get(index).isObject() ||
+ !scope->getProto<mozjs::BinDataInfo>().instanceOf(args.get(index))) {
+ uasserted(ErrorCodes::BadValue, "First parameter must be a BinData object");
+ }
+
+ mozjs::ObjectWrapper o(cx, args.get(index));
+
+ auto binType = BinDataType(static_cast<int>(o.getNumber(mozjs::InternedString::type)));
+ uassert(ErrorCodes::BadValue,
+ str::stream() << "Incorrect bindata type, expected" << typeName(type) << " but got "
+ << typeName(binType),
+ binType == type);
+ auto str = static_cast<std::string*>(JS_GetPrivate(args.get(index).toObjectOrNull()));
+ uassert(ErrorCodes::BadValue, "Cannot call getter on BinData prototype", str);
+ std::string string = base64::decode(*str);
+ return std::vector<uint8_t>(string.data(), string.data() + string.length());
+}
+
+std::shared_ptr<SymmetricKey> EncryptedDBClientBase::getDataKey(const UUID& uuid) {
+ auto ts_new = Date_t::now();
+
+ if (_datakeyCache.hasKey(uuid)) {
+ auto[key, ts] = _datakeyCache.find(uuid)->second;
+ if (ts_new - ts < kCacheInvalidationTime) {
+ return key;
+ } else {
+ _datakeyCache.erase(uuid);
+ }
+ }
+ auto key = getDataKeyFromDisk(uuid);
+ _datakeyCache.add(uuid, std::make_pair(key, ts_new));
+ return key;
+}
+
+std::shared_ptr<SymmetricKey> EncryptedDBClientBase::getDataKeyFromDisk(const UUID& uuid) {
+ NamespaceString fullNameNS = getCollectionNS();
+ BSONObj dataKeyObj = _conn->findOne(fullNameNS.ns(), QUERY("_id" << uuid));
+ if (dataKeyObj.isEmpty()) {
+ uasserted(ErrorCodes::BadValue, "Invalid keyID.");
+ }
+
+ auto keyStoreRecord = KeyStoreRecord::parse(IDLParserErrorContext("root"), dataKeyObj);
+ if (dataKeyObj.hasField("version"_sd)) {
+ uassert(ErrorCodes::BadValue,
+ "Invalid version, must be either 0 or undefined",
+ dataKeyObj.getIntField("version"_sd) == 0);
+ }
+
+ BSONElement elem = dataKeyObj.getField("keyMaterial"_sd);
+ uassert(ErrorCodes::BadValue, "Invalid key.", elem.isBinData(BinDataType::BinDataGeneral));
+ uassert(ErrorCodes::BadValue,
+ "Invalid version, must be either 0 or undefined",
+ keyStoreRecord.getVersion() == 0);
+
+ auto dataKey = keyStoreRecord.getKeyMaterial();
+ uassert(ErrorCodes::BadValue, "Invalid data key.", dataKey.length() != 0);
+
+ std::unique_ptr<KMSService> kmsService = KMSServiceController::createFromDisk(
+ _encryptionOptions.getKmsProviders().toBSON(), keyStoreRecord.getMasterKey());
+ SecureVector<uint8_t> decryptedKey =
+ kmsService->decrypt(dataKey, keyStoreRecord.getMasterKey());
+ return std::make_shared<SymmetricKey>(
+ std::move(decryptedKey), crypto::aesAlgorithm, "kms_encryption");
+}
+
+std::vector<uint8_t> EncryptedDBClientBase::encryptWithKey(UUID uuid,
+ const std::shared_ptr<SymmetricKey>& key,
+ ConstDataRange plaintext,
+ BSONType bsonType,
+ int32_t algorithm) {
+ // As per the description of the encryption algorithm for FLE, the
+ // associated data is constructed of the following -
+ // associatedData[0] = the FleAlgorithmEnum
+ // - either a 1 or a 2 depending on whether the iv is provided.
+ // associatedData[1-16] = the uuid in bytes
+ // associatedData[17] = the bson type
+
+ ConstDataRange uuidCdr = uuid.toCDR();
+ uint64_t outputLength = crypto::aeadCipherOutputLength(plaintext.length());
+ std::vector<uint8_t> outputBuffer(kAssociatedDataLength + outputLength);
+ outputBuffer[0] = static_cast<uint8_t>(algorithm);
+ std::memcpy(&outputBuffer[1], uuidCdr.data(), uuidCdr.length());
+ outputBuffer[17] = static_cast<uint8_t>(bsonType);
+ uassertStatusOK(crypto::aeadEncrypt(*key,
+ reinterpret_cast<const uint8_t*>(plaintext.data()),
+ plaintext.length(),
+ outputBuffer.data(),
+ 18,
+ // The ciphertext starts 18 bytes into the output
+ // buffer, as described above.
+ outputBuffer.data() + 18,
+ outputLength));
+ return outputBuffer;
+}
+
+namespace {
+
+/**
+ * Constructs a collection object from a namespace, passed in to the nsString parameter.
+ * The client is the connection to a database in which you want to create the collection.
+ * The collection parameter gets set to a javascript collection object.
+ */
+void createCollectionObject(JSContext* cx,
+ JS::HandleValue client,
+ StringData nsString,
+ JS::MutableHandleValue collection) {
+ invariant(!client.isNull() && !client.isUndefined());
+
+ auto ns = NamespaceString(nsString);
+ uassert(ErrorCodes::BadValue,
+ "Invalid keystore namespace.",
+ ns.isValid() && NamespaceString::validCollectionName(ns.coll()));
+
+ auto scope = mozjs::getScope(cx);
+
+ // The collection object requires a database object to be constructed as well.
+ JS::RootedValue databaseRV(cx);
+ JS::AutoValueArray<2> databaseArgs(cx);
+
+ databaseArgs[0].setObject(client.toObject());
+ mozjs::ValueReader(cx, databaseArgs[1]).fromStringData(ns.db());
+ scope->getProto<mozjs::DBInfo>().newInstance(databaseArgs, &databaseRV);
+
+ invariant(databaseRV.isObject());
+ auto databaseObj = databaseRV.toObjectOrNull();
+
+ JS::AutoValueArray<4> collectionArgs(cx);
+ collectionArgs[0].setObject(client.toObject());
+ collectionArgs[1].setObject(*databaseObj);
+ mozjs::ValueReader(cx, collectionArgs[2]).fromStringData(ns.coll());
+ mozjs::ValueReader(cx, collectionArgs[3]).fromStringData(ns.ns());
+
+ scope->getProto<mozjs::DBCollectionInfo>().newInstance(collectionArgs, collection);
+}
+
+// The parameters required to start FLE on the shell. The current connection is passed in as a
+// parameter to create the keyvault collection object if one is not provided.
+std::unique_ptr<DBClientBase> createEncryptedDBClientBase(std::unique_ptr<DBClientBase> conn,
+ JS::HandleValue arg,
+ JS::HandleObject mongoConnection,
+ JSContext* cx) {
+
+ uassert(
+ 31038, "Invalid Client Side Encryption parameters.", arg.isObject() || arg.isUndefined());
+
+ static constexpr auto keyVaultClientFieldId = "keyVaultClient";
+
+ if (!arg.isObject() && encryptedShellGlobalParams.awsAccessKeyId.empty()) {
+ return conn;
+ }
+
+ ClientSideFLEOptions encryptionOptions;
+ JS::RootedValue client(cx);
+ JS::RootedValue collection(cx);
+
+ if (!arg.isObject()) {
+ // If arg is not an object, but one of the required encryptedShellGlobalParams
+ // is defined, the user is trying to start an encrypted client with command line
+ // parameters.
+
+ AwsKMS awsKms = AwsKMS(encryptedShellGlobalParams.awsAccessKeyId,
+ encryptedShellGlobalParams.awsSecretAccessKey);
+
+ awsKms.setUrl(StringData(encryptedShellGlobalParams.awsKmsURL));
+
+ awsKms.setSessionToken(StringData(encryptedShellGlobalParams.awsSessionToken));
+
+ KmsProviders kmsProviders;
+ kmsProviders.setAws(awsKms);
+
+ // The mongoConnection object will never be null.
+ // If the encrypted shell is started through command line parameters, then the user must
+ // default to the implicit connection for the keyvault collection.
+ client.setObjectOrNull(mongoConnection.get());
+
+ // Because we cannot add a schemaMap object through the command line, we set the
+ // schemaMap object in ClientSideFLEOptions to be null so we know to always use
+ // remote schemas.
+ encryptionOptions = ClientSideFLEOptions(
+ encryptedShellGlobalParams.keyVaultNamespace, std::move(kmsProviders), BSONObj());
+ } else {
+ uassert(ErrorCodes::BadValue,
+ "Collection object must be passed to Field Level Encryption Options",
+ arg.isObject());
+
+ const BSONObj obj = mozjs::ValueWriter(cx, arg).toBSON();
+ encryptionOptions = encryptionOptions.parse(IDLParserErrorContext("root"), obj);
+
+ // IDL does not perform a deep copy of BSONObjs when parsing, so we must get an
+ // owned copy of the schemaMap.
+ encryptionOptions.setSchemaMap(encryptionOptions.getSchemaMap().getOwned());
+
+ // This logic tries to extract the client from the args. If the connection object is defined
+ // in the ClientSideFLEOptions struct, then the client will extract it and set itself to be
+ // that. Else, the client will default to the implicit connection.
+ JS::RootedObject handleObject(cx, &arg.toObject());
+ JS_GetProperty(cx, handleObject, keyVaultClientFieldId, &client);
+ if (client.isNull() || client.isUndefined()) {
+ client.setObjectOrNull(mongoConnection.get());
+ }
+ }
+
+ createCollectionObject(cx, client, encryptionOptions.getKeyVaultNamespace(), &collection);
+
+ if (implicitEncryptedDBClientCallback != nullptr) {
+ return implicitEncryptedDBClientCallback(
+ std::move(conn), encryptionOptions, collection, cx);
+ }
+
+ std::unique_ptr<EncryptedDBClientBase> base =
+ std::make_unique<EncryptedDBClientBase>(std::move(conn), encryptionOptions, collection, cx);
+ return std::move(base);
+}
+
+MONGO_INITIALIZER(setCallbacksForEncryptedDBClientBase)(InitializerContext*) {
+ mongo::mozjs::setEncryptedDBClientCallback(createEncryptedDBClientBase);
+ return Status::OK();
+}
+
+} // namespace
+} // namespace mongo