summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--buildscripts/resmokeconfig/suites/client_encrypt.yml19
-rw-r--r--etc/evergreen.yml11
-rw-r--r--jstests/client_encrypt/fle_aws_faults.js142
-rw-r--r--jstests/client_encrypt/fle_command_line_encryption.js41
-rw-r--r--jstests/client_encrypt/fle_encrypt_decrypt_shell.js112
-rw-r--r--jstests/client_encrypt/fle_key_faults.js94
-rw-r--r--jstests/client_encrypt/fle_keys.js75
-rw-r--r--jstests/client_encrypt/fle_valid_fle_options.js68
-rw-r--r--jstests/client_encrypt/lib/fle_command_line_explicit_encryption.js84
-rw-r--r--jstests/client_encrypt/lib/kms_http_common.py21
-rw-r--r--jstests/client_encrypt/lib/kms_http_control.py52
-rwxr-xr-xjstests/client_encrypt/lib/kms_http_server.py298
-rw-r--r--jstests/client_encrypt/lib/mock_kms.js161
-rw-r--r--src/mongo/SConscript2
-rw-r--r--src/mongo/crypto/SConscript33
-rw-r--r--src/mongo/crypto/aead_encryption.cpp391
-rw-r--r--src/mongo/crypto/aead_encryption.h94
-rw-r--r--src/mongo/crypto/aead_encryption_test.cpp152
-rw-r--r--src/mongo/crypto/symmetric_crypto.cpp107
-rw-r--r--src/mongo/crypto/symmetric_crypto.h196
-rw-r--r--src/mongo/crypto/symmetric_crypto_apple.cpp183
-rw-r--r--src/mongo/crypto/symmetric_crypto_openssl.cpp255
-rw-r--r--src/mongo/crypto/symmetric_crypto_windows.cpp335
-rw-r--r--src/mongo/crypto/symmetric_key.cpp100
-rw-r--r--src/mongo/crypto/symmetric_key.h146
-rw-r--r--src/mongo/db/storage/storage_engine_lock_file_posix.cpp6
-rw-r--r--src/mongo/db/storage/storage_engine_lock_file_windows.cpp6
-rw-r--r--src/mongo/shell/SConscript74
-rw-r--r--src/mongo/shell/encrypted_dbclient_base.cpp685
-rw-r--r--src/mongo/shell/encrypted_dbclient_base.h171
-rw-r--r--src/mongo/shell/encrypted_shell_options.h45
-rw-r--r--src/mongo/shell/fle_shell_options.idl37
-rw-r--r--src/mongo/shell/keyvault.js106
-rw-r--r--src/mongo/shell/kms.cpp80
-rw-r--r--src/mongo/shell/kms.h135
-rw-r--r--src/mongo/shell/kms.idl164
-rw-r--r--src/mongo/shell/kms_aws.cpp461
-rw-r--r--src/mongo/shell/kms_local.cpp153
-rw-r--r--src/mongo/shell/kms_shell.cpp52
-rw-r--r--src/mongo/shell/kms_test.cpp86
40 files changed, 5422 insertions, 11 deletions
diff --git a/buildscripts/resmokeconfig/suites/client_encrypt.yml b/buildscripts/resmokeconfig/suites/client_encrypt.yml
new file mode 100644
index 00000000000..4a06a017a34
--- /dev/null
+++ b/buildscripts/resmokeconfig/suites/client_encrypt.yml
@@ -0,0 +1,19 @@
+test_kind: js_test
+
+selector:
+ roots:
+ - jstests/client_encrypt/*.js
+
+executor:
+ config:
+ shell_options:
+ nodb: ''
+ readMode: commands
+ ssl: ''
+ tlsAllowInvalidHostnames: ''
+ tlsAllowInvalidCertificates: ''
+ tlsCAFile: jstests/libs/ca.pem
+ tlsCertificateKeyFile: jstests/libs/client.pem
+ hooks:
+ - class: CleanEveryN
+ n: 20
diff --git a/etc/evergreen.yml b/etc/evergreen.yml
index 35798c9f162..91ea7bb198f 100644
--- a/etc/evergreen.yml
+++ b/etc/evergreen.yml
@@ -7239,6 +7239,16 @@ tasks:
resmoke_jobs_max: 1
- <<: *task_template
+ name: client_encrypt
+ tags: ["ssl"]
+ commands:
+ - func: "do setup"
+ - func: "run tests"
+ vars:
+ resmoke_args: --suites=client_encrypt --storageEngine=wiredTiger
+ resmoke_jobs_max: 1
+
+- <<: *task_template
name: fle
tags: ["encrypt"]
commands:
@@ -10700,7 +10710,6 @@ buildvariants:
distros:
- rhel62-large
- name: disk_wiredtiger
- - name: fle
- name: free_monitoring
- name: .jscore .common
- name: .jstestfuzz .common
diff --git a/jstests/client_encrypt/fle_aws_faults.js b/jstests/client_encrypt/fle_aws_faults.js
new file mode 100644
index 00000000000..bee9586ca43
--- /dev/null
+++ b/jstests/client_encrypt/fle_aws_faults.js
@@ -0,0 +1,142 @@
+/**
+ * Verify the AWS KMS implementation can handle a buggy KMS.
+ */
+
+load("jstests/client_encrypt/lib/mock_kms.js");
+load('jstests/ssl/libs/ssl_helpers.js');
+
+(function() {
+ "use strict";
+
+ const x509_options = {sslMode: "requireSSL", sslPEMKeyFile: SERVER_CERT, sslCAFile: CA_CERT};
+
+ const randomAlgorithm = "AEAD_AES_256_CBC_HMAC_SHA_512-Random";
+
+ const conn = MongoRunner.runMongod(x509_options);
+ const test = conn.getDB("test");
+ const collection = test.coll;
+
+ function runKMS(mock_kms, func) {
+ mock_kms.start();
+
+ const awsKMS = {
+ accessKeyId: "access",
+ secretAccessKey: "secret",
+ url: mock_kms.getURL(),
+ };
+
+ const clientSideFLEOptions = {
+ kmsProviders: {
+ aws: awsKMS,
+ },
+ keyVaultNamespace: "test.coll",
+ schemaMap: {}
+ };
+
+ const shell = Mongo(conn.host, clientSideFLEOptions);
+ const cleanCacheShell = Mongo(conn.host, clientSideFLEOptions);
+
+ collection.drop();
+
+ func(shell, cleanCacheShell);
+
+ mock_kms.stop();
+ }
+
+ function testBadEncryptResult(fault) {
+ const mock_kms = new MockKMSServer(fault, false);
+
+ runKMS(mock_kms, (shell) => {
+ const keyVault = shell.getKeyVault();
+
+ assert.throws(() => keyVault.createKey(
+ "aws", "arn:aws:kms:us-east-1:fake:fake:fake", ["mongoKey"]));
+ assert.eq(keyVault.getKeys("mongoKey").toArray().length, 0);
+ });
+ }
+
+ testBadEncryptResult(FAULT_ENCRYPT);
+ testBadEncryptResult(FAULT_ENCRYPT_WRONG_FIELDS);
+ testBadEncryptResult(FAULT_ENCRYPT_BAD_BASE64);
+
+ function testBadEncryptError() {
+ const mock_kms = new MockKMSServer(FAULT_ENCRYPT_CORRECT_FORMAT, false);
+
+ runKMS(mock_kms, (shell) => {
+ const keyVault = shell.getKeyVault();
+
+ let error =
+ assert.throws(() => keyVault.createKey(
+ "aws", "arn:aws:kms:us-east-1:fake:fake:fake", ["mongoKey"]));
+ assert.commandFailedWithCode(error, [51224]);
+ assert.eq(
+ error,
+ "Error: AWS KMS failed to encrypt: NotFoundException : Error encrypting message");
+ });
+ }
+
+ testBadEncryptError();
+
+ function testBadDecryptResult(fault) {
+ const mock_kms = new MockKMSServer(fault, false);
+
+ runKMS(mock_kms, (shell) => {
+ const keyVault = shell.getKeyVault();
+ assert.writeOK(
+ keyVault.createKey("aws", "arn:aws:kms:us-east-1:fake:fake:fake", ["mongoKey"]));
+ const keyId = keyVault.getKeys("mongoKey").toArray()[0]._id;
+ const str = "mongo";
+ assert.throws(() => {
+ const encStr = shell.encrypt(keyId, str, randomAlgorithm);
+ });
+ });
+ }
+
+ testBadDecryptResult(FAULT_DECRYPT);
+
+ function testBadDecryptKeyResult(fault) {
+ const mock_kms = new MockKMSServer(fault, true);
+
+ runKMS(mock_kms, (shell, cleanCacheShell) => {
+ const keyVault = shell.getKeyVault();
+
+ assert.writeOK(
+ keyVault.createKey("aws", "arn:aws:kms:us-east-1:fake:fake:fake", ["mongoKey"]));
+ const keyId = keyVault.getKeys("mongoKey").toArray()[0]._id;
+ const str = "mongo";
+ const encStr = shell.encrypt(keyId, str, randomAlgorithm);
+
+ mock_kms.enableFaults();
+
+ assert.throws(() => {
+ var str = cleanCacheShell.decrypt(encStr);
+ });
+
+ });
+ }
+
+ testBadDecryptKeyResult(FAULT_DECRYPT_WRONG_KEY);
+
+ function testBadDecryptError() {
+ const mock_kms = new MockKMSServer(FAULT_DECRYPT_CORRECT_FORMAT, false);
+
+ runKMS(mock_kms, (shell) => {
+ const keyVault = shell.getKeyVault();
+ assert.writeOK(
+ keyVault.createKey("aws", "arn:aws:kms:us-east-1:fake:fake:fake", ["mongoKey"]));
+ const keyId = keyVault.getKeys("mongoKey").toArray()[0]._id;
+ const str = "mongo";
+ let error = assert.throws(() => {
+ const encStr = shell.encrypt(keyId, str, randomAlgorithm);
+ });
+ assert.commandFailedWithCode(error, [51225]);
+ assert.eq(
+ error,
+ "Error: AWS KMS failed to decrypt: NotFoundException : Error decrypting message");
+ });
+ }
+
+ testBadDecryptError();
+
+ MongoRunner.stopMongod(conn);
+}()); \ No newline at end of file
diff --git a/jstests/client_encrypt/fle_command_line_encryption.js b/jstests/client_encrypt/fle_command_line_encryption.js
new file mode 100644
index 00000000000..9113f9f2d74
--- /dev/null
+++ b/jstests/client_encrypt/fle_command_line_encryption.js
@@ -0,0 +1,41 @@
+/*
+ * This file tests an encrypted shell started using command line parameters.
+ *
+ */
+load('jstests/ssl/libs/ssl_helpers.js');
+
+(function() {
+
+ const x509_options = {sslMode: "requireSSL", sslPEMKeyFile: SERVER_CERT, sslCAFile: CA_CERT};
+ const conn = MongoRunner.runMongod(x509_options);
+
+ const shellOpts = [
+ "mongo",
+ "--host",
+ conn.host,
+ "--port",
+ conn.port,
+ "--tls",
+ "--sslPEMKeyFile",
+ CLIENT_CERT,
+ "--sslCAFile",
+ CA_CERT,
+ "--tlsAllowInvalidHostnames",
+ "--awsAccessKeyId",
+ "access",
+ "--awsSecretAccessKey",
+ "secret",
+ "--keyVaultNamespace",
+ "test.coll",
+ "--kmsURL",
+ "https://localhost:8000",
+ ];
+
+ const testFiles = [
+ "jstests/client_encrypt/lib/fle_command_line_explicit_encryption.js",
+ ];
+
+ for (const file of testFiles) {
+ runMongoProgram(...shellOpts, file);
+ }
+}()); \ No newline at end of file
diff --git a/jstests/client_encrypt/fle_encrypt_decrypt_shell.js b/jstests/client_encrypt/fle_encrypt_decrypt_shell.js
new file mode 100644
index 00000000000..79ffb87cc19
--- /dev/null
+++ b/jstests/client_encrypt/fle_encrypt_decrypt_shell.js
@@ -0,0 +1,112 @@
+/**
+ * Check the functionality of encrypt and decrypt functions in KeyStore.js
+ */
+load("jstests/client_encrypt/lib/mock_kms.js");
+load('jstests/ssl/libs/ssl_helpers.js');
+
+(function() {
+ "use strict";
+
+ const mock_kms = new MockKMSServer();
+ mock_kms.start();
+
+ const x509_options = {sslMode: "requireSSL", sslPEMKeyFile: SERVER_CERT, sslCAFile: CA_CERT};
+
+ const conn = MongoRunner.runMongod(x509_options);
+ const test = conn.getDB("test");
+ const collection = test.coll;
+
+ const awsKMS = {
+ accessKeyId: "access",
+ secretAccessKey: "secret",
+ url: mock_kms.getURL(),
+ };
+
+ let localKMS = {
+ key: BinData(
+ 0,
+ "/i8ytmWQuCe1zt3bIuVa4taPGKhqasVp0/0yI4Iy0ixQPNmeDF1J5qPUbBYoueVUJHMqj350eRTwztAWXuBdSQ=="),
+ };
+
+ const clientSideFLEOptions = {
+ kmsProviders: {
+ aws: awsKMS,
+ local: localKMS,
+ },
+ keyVaultNamespace: "test.coll",
+ schemaMap: {}
+ };
+
+ const kmsTypes = ["aws", "local"];
+
+ const randomAlgorithm = "AEAD_AES_256_CBC_HMAC_SHA_512-Random";
+ const deterministicAlgorithm = "AEAD_AES_256_CBC_HMAC_SHA_512-Deterministic";
+ const encryptionAlgorithms = [randomAlgorithm, deterministicAlgorithm];
+
+ const passTestCases = [
+ "mongo",
+ NumberLong(13),
+ NumberInt(23),
+ UUID(),
+ ISODate(),
+ new Date('December 17, 1995 03:24:00'),
+ BinData(2, '1234'),
+ new Timestamp(1, 2),
+ new ObjectId(),
+ new DBPointer("mongo", new ObjectId()),
+ /test/
+ ];
+
+ const failDeterministic = [
+ true,
+ false,
+ 12,
+ NumberDecimal(0.1234),
+ ["this is an array"],
+ {"value": "mongo"},
+ Code("function() { return true; }")
+ ];
+
+ const failTestCases = [null, undefined, MinKey(), MaxKey(), DBRef("test", "test", "test")];
+
+ const shell = Mongo(conn.host, clientSideFLEOptions);
+ const keyVault = shell.getKeyVault();
+
+ // Testing for every combination of (kmsType, algorithm, javascriptVariable)
+ for (const kmsType of kmsTypes) {
+ for (const encryptionAlgorithm of encryptionAlgorithms) {
+ collection.drop();
+
+ assert.writeOK(
+ keyVault.createKey(kmsType, "arn:aws:kms:us-east-1:fake:fake:fake", ['mongoKey']));
+ const keyId = keyVault.getKeyByAltName("mongoKey").toArray()[0]._id;
+
+ let pass;
+ let fail;
+ if (encryptionAlgorithm === randomAlgorithm) {
+ pass = [...passTestCases, ...failDeterministic];
+ fail = failTestCases;
+ } else if (encryptionAlgorithm === deterministicAlgorithm) {
+ pass = passTestCases;
+ fail = [...failTestCases, ...failDeterministic];
+ }
+
+ for (const passTestCase of pass) {
+ const encPassTestCase = shell.encrypt(keyId, passTestCase, encryptionAlgorithm);
+ assert.eq(passTestCase, shell.decrypt(encPassTestCase));
+
+ if (encryptionAlgorithm === deterministicAlgorithm) {
+ assert.eq(encPassTestCase,
+ shell.encrypt(keyId, passTestCase, encryptionAlgorithm));
+ }
+ }
+
+ for (const failTestCase of fail) {
+ assert.throws(shell.encrypt, [keyId, failTestCase, encryptionAlgorithm]);
+ }
+ }
+ }
+
+ MongoRunner.stopMongod(conn);
+ mock_kms.stop();
+}()); \ No newline at end of file
diff --git a/jstests/client_encrypt/fle_key_faults.js b/jstests/client_encrypt/fle_key_faults.js
new file mode 100644
index 00000000000..5f2fdcab08a
--- /dev/null
+++ b/jstests/client_encrypt/fle_key_faults.js
@@ -0,0 +1,94 @@
+/**
+ * Verify the KMS support handles a buggy Key Store
+ */
+
+load("jstests/client_encrypt/lib/mock_kms.js");
+load('jstests/ssl/libs/ssl_helpers.js');
+
+(function() {
+ "use strict";
+
+ const mock_kms = new MockKMSServer();
+ mock_kms.start();
+
+ const x509_options = {sslMode: "requireSSL", sslPEMKeyFile: SERVER_CERT, sslCAFile: CA_CERT};
+
+ const conn = MongoRunner.runMongod(x509_options);
+ const test = conn.getDB("test");
+ const collection = test.coll;
+
+ const awsKMS = {
+ accessKeyId: "access",
+ secretAccessKey: "secret",
+ url: mock_kms.getURL(),
+ };
+
+ var localKMS = {
+ key: BinData(
+ 0,
+ "/i8ytmWQuCe1zt3bIuVa4taPGKhqasVp0/0yI4Iy0ixQPNmeDF1J5qPUbBYoueVUJHMqj350eRTwztAWXuBdSQ=="),
+ };
+
+ const clientSideFLEOptions = {
+ kmsProviders: {
+ aws: awsKMS,
+ local: localKMS,
+ },
+ keyVaultNamespace: "test.coll",
+ schemaMap: {}
+ };
+
+ function testFault(kmsType, func) {
+ collection.drop();
+
+ const shell = Mongo(conn.host, clientSideFLEOptions);
+ const keyVault = shell.getKeyVault();
+
+ assert.writeOK(
+ keyVault.createKey(kmsType, "arn:aws:kms:us-east-1:fake:fake:fake", ['mongoKey']));
+ const keyId = keyVault.getKeyByAltName("mongoKey").toArray()[0]._id;
+
+ func(keyId, shell);
+ }
+
+ function testFaults(func) {
+ const kmsTypes = ["aws", "local"];
+
+ for (const kmsType of kmsTypes) {
+ testFault(kmsType, func);
+ }
+ }
+
+ // Negative - drop the key vault collection
+ testFaults((keyId, shell) => {
+ collection.drop();
+
+ const str = "mongo";
+ assert.throws(() => {
+ const encStr = shell.encrypt(keyId, str);
+ });
+ });
+
+ // Negative - delete the keys
+ testFaults((keyId, shell) => {
+ collection.deleteMany({});
+
+ const str = "mongo";
+ assert.throws(() => {
+ const encStr = shell.encrypt(keyId, str);
+ });
+ });
+
+ // Negative - corrupt the master key with an unkown provider
+ testFaults((keyId, shell) => {
+ collection.updateMany({}, {$set: {"masterKey.provider": "fake"}});
+
+ const str = "mongo";
+ assert.throws(() => {
+ const encStr = shell.encrypt(keyId, str);
+ });
+ });
+
+ MongoRunner.stopMongod(conn);
+ mock_kms.stop();
+}()); \ No newline at end of file
diff --git a/jstests/client_encrypt/fle_keys.js b/jstests/client_encrypt/fle_keys.js
new file mode 100644
index 00000000000..875615ac9a8
--- /dev/null
+++ b/jstests/client_encrypt/fle_keys.js
@@ -0,0 +1,75 @@
+/**
+ * Check functionality of KeyVault.js
+ */
+
+load("jstests/client_encrypt/lib/mock_kms.js");
+load('jstests/ssl/libs/ssl_helpers.js');
+
+(function() {
+ "use strict";
+
+ const mock_kms = new MockKMSServer();
+ mock_kms.start();
+
+ const x509_options = {sslMode: "requireSSL", sslPEMKeyFile: SERVER_CERT, sslCAFile: CA_CERT};
+
+ const conn = MongoRunner.runMongod(x509_options);
+ const test = conn.getDB("test");
+ const collection = test.coll;
+
+ const awsKMS = {
+ accessKeyId: "access",
+ secretAccessKey: "secret",
+ url: mock_kms.getURL(),
+ };
+
+ const clientSideFLEOptions = {
+ kmsProviders: {
+ aws: awsKMS,
+ },
+ keyVaultNamespace: "test.coll",
+ schemaMap: {}
+ };
+
+ const conn_str = "mongodb://" + conn.host + "/?ssl=true";
+ const shell = Mongo(conn_str, clientSideFLEOptions);
+ const keyVault = shell.getKeyVault();
+
+ var key = keyVault.createKey("aws", "arn:aws:kms:us-east-1:fake:fake:fake", ['mongoKey']);
+ assert.eq(1, keyVault.getKeys().itcount());
+
+ var result = keyVault.createKey("aws", "arn:aws:kms:us-east-4:fake:fake:fake", {});
+ assert.eq("TypeError: key alternate names must be of Array type.", result);
+
+ result = keyVault.createKey("aws", "arn:aws:kms:us-east-5:fake:fake:fake", [1]);
+ assert.eq("TypeError: items in key alternate names must be of String type.", result);
+
+ assert.eq(1, keyVault.getKeyByAltName("mongoKey").itcount());
+
+ var keyId = keyVault.getKeyByAltName("mongoKey").toArray()[0]._id;
+
+ keyVault.addKeyAlternateName(keyId, "mongoKey2");
+
+ assert.eq(1, keyVault.getKeyByAltName("mongoKey2").itcount());
+ assert.eq(2, keyVault.getKey(keyId).toArray()[0].keyAltNames.length);
+ assert.eq(1, keyVault.getKeys().itcount());
+
+ result = keyVault.addKeyAlternateName(keyId, [2]);
+ assert.eq("TypeError: key alternate name cannot be object or array type.", result);
+
+ keyVault.removeKeyAlternateName(keyId, "mongoKey2");
+ assert.eq(1, keyVault.getKey(keyId).toArray()[0].keyAltNames.length);
+
+ result = keyVault.deleteKey(keyId);
+ assert.eq(0, keyVault.getKey(keyId).itcount());
+ assert.eq(0, keyVault.getKeys().itcount());
+
+ assert.writeOK(keyVault.createKey("aws", "arn:aws:kms:us-east-1:fake:fake:fake1"));
+ assert.writeOK(keyVault.createKey("aws", "arn:aws:kms:us-east-2:fake:fake:fake2"));
+ assert.writeOK(keyVault.createKey("aws", "arn:aws:kms:us-east-3:fake:fake:fake3"));
+
+ assert.eq(3, keyVault.getKeys().itcount());
+
+ MongoRunner.stopMongod(conn);
+ mock_kms.stop();
+}()); \ No newline at end of file
diff --git a/jstests/client_encrypt/fle_valid_fle_options.js b/jstests/client_encrypt/fle_valid_fle_options.js
new file mode 100644
index 00000000000..2189501ad00
--- /dev/null
+++ b/jstests/client_encrypt/fle_valid_fle_options.js
@@ -0,0 +1,68 @@
+
+load("jstests/client_encrypt/lib/mock_kms.js");
+load('jstests/ssl/libs/ssl_helpers.js');
+
+(function() {
+ "use strict";
+
+ const mock_kms = new MockKMSServer();
+ mock_kms.start();
+
+ const randomAlgorithm = "AEAD_AES_256_CBC_HMAC_SHA_512-Random";
+ const deterministicAlgorithm = "AEAD_AES_256_CBC_HMAC_SHA_512-Deterministic";
+
+ const x509_options =
+ {sslMode: "requireSSL", sslPEMKeyFile: SERVER_CERT, sslCAFile: CA_CERT, vvvvv: ""};
+
+ const conn = MongoRunner.runMongod(x509_options);
+ const unencryptedDatabase = conn.getDB("test");
+ const collection = unencryptedDatabase.keystore;
+
+ const awsKMS = {
+ accessKeyId: "access",
+ secretAccessKey: "secret",
+ url: mock_kms.getURL(),
+ };
+
+ const clientSideFLEOptionsFail = [
+ {
+ kmsProviders: {
+ aws: awsKMS,
+ },
+ schemaMap: {},
+ },
+ {
+ kmsProviders: {
+ aws: awsKMS,
+ },
+ keyVaultNamespace: "test.keystore",
+ },
+ {
+ keyVaultNamespace: "test.keystore",
+ schemaMap: {},
+ },
+ ];
+
+ clientSideFLEOptionsFail.forEach(element => {
+ assert.throws(Mongo, [conn.host, element]);
+ });
+
+ const clientSideFLEOptionsPass = [
+ {
+ kmsProviders: {
+ aws: awsKMS,
+ },
+ keyVaultNamespace: "test.keystore",
+ schemaMap: {},
+ },
+ ];
+
+ clientSideFLEOptionsPass.forEach(element => {
+ assert.doesNotThrow(() => {
+ Mongo(conn.host, element);
+ });
+ });
+
+ MongoRunner.stopMongod(conn);
+ mock_kms.stop();
+}());
diff --git a/jstests/client_encrypt/lib/fle_command_line_explicit_encryption.js b/jstests/client_encrypt/lib/fle_command_line_explicit_encryption.js
new file mode 100644
index 00000000000..0ca10b2057c
--- /dev/null
+++ b/jstests/client_encrypt/lib/fle_command_line_explicit_encryption.js
@@ -0,0 +1,84 @@
+/**
+* Check the functionality of encrypt and decrypt functions in KeyVault.js. This test is run by
+* jstests/fle/fle_command_line_encryption.js.
+*/
+
+load("jstests/client_encrypt/lib/mock_kms.js");
+
+(function() {
+ "use strict";
+
+ const mock_kms = new MockKMSServer();
+ mock_kms.start();
+
+ const shell = Mongo();
+ const keyVault = shell.getKeyVault();
+
+ const test = shell.getDB("test");
+ const collection = test.coll;
+
+ const randomAlgorithm = "AEAD_AES_256_CBC_HMAC_SHA_512-Random";
+ const deterministicAlgorithm = "AEAD_AES_256_CBC_HMAC_SHA_512-Deterministic";
+ const encryptionAlgorithms = [randomAlgorithm, deterministicAlgorithm];
+
+ const passTestCases = [
+ "mongo",
+ NumberLong(13),
+ NumberInt(23),
+ UUID(),
+ ISODate(),
+ new Date('December 17, 1995 03:24:00'),
+ BinData(2, '1234'),
+ new Timestamp(1, 2),
+ new ObjectId(),
+ new DBPointer("mongo", new ObjectId()),
+ /test/
+ ];
+
+ const failDeterministic = [
+ true,
+ false,
+ 12,
+ NumberDecimal(0.1234),
+ ["this is an array"],
+ {"value": "mongo"},
+ Code("function() { return true; }")
+ ];
+
+ const failTestCases = [null, undefined, MinKey(), MaxKey(), DBRef("test", "test", "test")];
+
+ // Testing for every combination of (algorithm, javascriptVariable)
+ for (const encryptionAlgorithm of encryptionAlgorithms) {
+ collection.drop();
+
+ assert.writeOK(
+ keyVault.createKey("aws", "arn:aws:kms:us-east-1:fake:fake:fake", ['mongoKey']));
+ const keyId = keyVault.getKeyByAltName("mongoKey").toArray()[0]._id;
+
+ let pass;
+ let fail;
+ if (encryptionAlgorithm === randomAlgorithm) {
+ pass = [...passTestCases, ...failDeterministic];
+ fail = failTestCases;
+ } else if (encryptionAlgorithm === deterministicAlgorithm) {
+ pass = passTestCases;
+ fail = [...failTestCases, ...failDeterministic];
+ }
+
+ for (const passTestCase of pass) {
+ const encPassTestCase = shell.encrypt(keyId, passTestCase, encryptionAlgorithm);
+ assert.eq(passTestCase, shell.decrypt(encPassTestCase));
+
+ if (encryptionAlgorithm == deterministicAlgorithm) {
+ assert.eq(encPassTestCase, shell.encrypt(keyId, passTestCase, encryptionAlgorithm));
+ }
+ }
+
+ for (const failTestCase of fail) {
+ assert.throws(shell.encrypt, [keyId, failTestCase, encryptionAlgorithm]);
+ }
+ }
+
+ mock_kms.stop();
+ print("Test completed with no errors.");
+}()); \ No newline at end of file
diff --git a/jstests/client_encrypt/lib/kms_http_common.py b/jstests/client_encrypt/lib/kms_http_common.py
new file mode 100644
index 00000000000..aaef6a8ad69
--- /dev/null
+++ b/jstests/client_encrypt/lib/kms_http_common.py
@@ -0,0 +1,21 @@
+"""Common code for mock kms http endpoint."""
+import json
+
+URL_PATH_STATS = "/stats"
+URL_DISABLE_FAULTS = "/disable_faults"
+URL_ENABLE_FAULTS = "/enable_faults"
+
+class Stats:
+ """Stats class shared between client and server."""
+
+ def __init__(self):
+ self.encrypt_calls = 0
+ self.decrypt_calls = 0
+ self.fault_calls = 0
+
+ def __repr__(self):
+ return json.dumps({
+ 'decrypts': self.decrypt_calls,
+ 'encrypts': self.encrypt_calls,
+ 'faults': self.fault_calls,
+ })
diff --git a/jstests/client_encrypt/lib/kms_http_control.py b/jstests/client_encrypt/lib/kms_http_control.py
new file mode 100644
index 00000000000..2f62780fb77
--- /dev/null
+++ b/jstests/client_encrypt/lib/kms_http_control.py
@@ -0,0 +1,52 @@
+#! /usr/bin/env python3
+"""
+Python script to interact with mock AWS KMS HTTP server.
+"""
+
+import argparse
+import json
+import logging
+import sys
+import urllib.request
+import ssl
+
+import kms_http_common
+
+def main():
+ """Main entry point."""
+ parser = argparse.ArgumentParser(description='MongoDB Mock AWS KMS Endpoint.')
+
+ parser.add_argument('-p', '--port', type=int, default=8000, help="Port to listen on")
+
+ parser.add_argument('-v', '--verbose', action='count', help="Enable verbose tracing")
+
+ parser.add_argument('--ca_file', type=str, required=True, help="TLS CA PEM file")
+
+ parser.add_argument('--query', type=str, help="Query endpoint <name>")
+
+ args = parser.parse_args()
+ if args.verbose:
+ logging.basicConfig(level=logging.DEBUG)
+
+ url_str = "https://localhost:" + str(args.port)
+ if args.query == "stats":
+ url_str += kms_http_common.URL_PATH_STATS
+ elif args.query == "disable_faults":
+ url_str += kms_http_common.URL_DISABLE_FAULTS
+ elif args.query == "enable_faults":
+ url_str += kms_http_common.URL_ENABLE_FAULTS
+ else:
+ print("Unknown query type")
+ sys.exit(1)
+
+ context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH, cafile=args.ca_file)
+
+ with urllib.request.urlopen(url_str, context=context) as f:
+ print(f.read().decode('utf-8'))
+
+ sys.exit(0)
+
+
+if __name__ == '__main__':
+
+ main()
diff --git a/jstests/client_encrypt/lib/kms_http_server.py b/jstests/client_encrypt/lib/kms_http_server.py
new file mode 100755
index 00000000000..e2414bd8574
--- /dev/null
+++ b/jstests/client_encrypt/lib/kms_http_server.py
@@ -0,0 +1,298 @@
+#! /usr/bin/env python3
+"""Mock AWS KMS Endpoint."""
+
+import argparse
+import collections
+import base64
+import http.server
+import json
+import logging
+import socketserver
+import sys
+import urllib.parse
+import ssl
+
+import kms_http_common
+
+SECRET_PREFIX = "00SECRET"
+
+# Pass this data out of band instead of storing it in AwsKmsHandler since the
+# BaseHTTPRequestHandler does not call the methods as object methods but as class methods. This
+# means there is not self.
+stats = kms_http_common.Stats()
+disable_faults = False
+fault_type = None
+
+"""Fault which causes encrypt to return 500."""
+FAULT_ENCRYPT = "fault_encrypt"
+
+"""Fault which causes encrypt to return an error that contains a type and message"""
+FAULT_ENCRYPT_CORRECT_FORMAT = "fault_encrypt_correct_format"
+
+"""Fault which causes encrypt to return wrong fields in JSON."""
+FAULT_ENCRYPT_WRONG_FIELDS = "fault_encrypt_wrong_fields"
+
+"""Fault which causes encrypt to return bad BASE64."""
+FAULT_ENCRYPT_BAD_BASE64 = "fault_encrypt_bad_base64"
+
+"""Fault which causes decrypt to return 500."""
+FAULT_DECRYPT = "fault_decrypt"
+
+"""Fault which causes decrypt to return an error that contains a type and message"""
+FAULT_DECRYPT_CORRECT_FORMAT = "fault_decrypt_correct_format"
+
+"""Fault which causes decrypt to return wrong key."""
+FAULT_DECRYPT_WRONG_KEY = "fault_decrypt_wrong_key"
+
+
+# List of supported fault types
+SUPPORTED_FAULT_TYPES = [
+ FAULT_ENCRYPT,
+ FAULT_ENCRYPT_CORRECT_FORMAT,
+ FAULT_ENCRYPT_WRONG_FIELDS,
+ FAULT_ENCRYPT_BAD_BASE64,
+ FAULT_DECRYPT,
+ FAULT_DECRYPT_CORRECT_FORMAT,
+ FAULT_DECRYPT_WRONG_KEY,
+]
+
+class AwsKmsHandler(http.server.BaseHTTPRequestHandler):
+ """
+ Handle requests from AWS KMS Monitoring and test commands
+ """
+ protocol_version = "HTTP/1.1"
+
+ def do_GET(self):
+ """Serve a Test GET request."""
+ parts = urllib.parse.urlsplit(self.path)
+ path = parts[2]
+
+ if path == kms_http_common.URL_PATH_STATS:
+ self._do_stats()
+ elif path == kms_http_common.URL_DISABLE_FAULTS:
+ self._do_disable_faults()
+ elif path == kms_http_common.URL_ENABLE_FAULTS:
+ self._do_enable_faults()
+ else:
+ self.send_response(http.HTTPStatus.NOT_FOUND)
+ self.end_headers()
+ self.wfile.write("Unknown URL".encode())
+
+ def do_POST(self):
+ """Serve a POST request."""
+ parts = urllib.parse.urlsplit(self.path)
+ path = parts[2]
+
+ if path == "/":
+ self._do_post()
+ else:
+ self.send_response(http.HTTPStatus.NOT_FOUND)
+ self.end_headers()
+ self.wfile.write("Unknown URL".encode())
+
+ def _send_reply(self, data, status=http.HTTPStatus.OK):
+ print("Sending Response: " + data.decode())
+
+ self.send_response(status)
+ self.send_header("content-type", "application/octet-stream")
+ self.send_header("Content-Length", str(len(data)))
+ self.end_headers()
+
+ self.wfile.write(data)
+
+ def _do_post(self):
+ global stats
+ clen = int(self.headers.get('content-length'))
+
+ raw_input = self.rfile.read(clen)
+
+ print("RAW INPUT: " + str(raw_input))
+
+ # X-Amz-Target: TrentService.Encrypt
+ aws_operation = self.headers['X-Amz-Target']
+
+ if aws_operation == "TrentService.Encrypt":
+ stats.encrypt_calls += 1
+ self._do_encrypt(raw_input)
+ elif aws_operation == "TrentService.Decrypt":
+ stats.decrypt_calls += 1
+ self._do_decrypt(raw_input)
+ else:
+ data = "Unknown AWS Operation"
+ self._send_reply(data.encode("utf-8"))
+
+ def _do_encrypt(self, raw_input):
+ request = json.loads(raw_input)
+
+ print(request)
+
+ plaintext = request["Plaintext"]
+ keyid = request["KeyId"]
+
+ ciphertext = SECRET_PREFIX.encode() + plaintext.encode()
+ ciphertext = base64.b64encode(ciphertext).decode()
+
+ if fault_type and fault_type.startswith(FAULT_ENCRYPT) and not disable_faults:
+ return self._do_encrypt_faults(ciphertext)
+
+ response = {
+ "CiphertextBlob" : ciphertext,
+ "KeyId" : keyid,
+ }
+
+ self._send_reply(json.dumps(response).encode('utf-8'))
+
+ def _do_encrypt_faults(self, raw_ciphertext):
+ stats.fault_calls += 1
+
+ if fault_type == FAULT_ENCRYPT:
+ self._send_reply("Internal Error of some sort.".encode(), http.HTTPStatus.INTERNAL_SERVER_ERROR)
+ return
+ elif fault_type == FAULT_ENCRYPT_WRONG_FIELDS:
+ response = {
+ "SomeBlob" : raw_ciphertext,
+ "KeyId" : "foo",
+ }
+
+ self._send_reply(json.dumps(response).encode('utf-8'))
+ return
+ elif fault_type == FAULT_ENCRYPT_BAD_BASE64:
+ response = {
+ "CiphertextBlob" : "foo",
+ "KeyId" : "foo",
+ }
+
+ self._send_reply(json.dumps(response).encode('utf-8'))
+ return
+ elif fault_type == FAULT_ENCRYPT_CORRECT_FORMAT:
+ response = {
+ "__type" : "NotFoundException",
+ "message" : "Error encrypting message",
+ }
+
+ self._send_reply(json.dumps(response).encode('utf-8'))
+ return
+
+ raise ValueError("Unknown Fault Type: " + fault_type)
+
+ def _do_decrypt(self, raw_input):
+ request = json.loads(raw_input)
+ blob = base64.b64decode(request["CiphertextBlob"]).decode()
+
+ print("FOUND SECRET: " + blob)
+
+ # our "encrypted" values start with the word SECRET_PREFIX otherwise they did not come from us
+ if not blob.startswith(SECRET_PREFIX):
+ raise ValueError()
+
+ blob = blob[len(SECRET_PREFIX):]
+
+ if fault_type and fault_type.startswith(FAULT_DECRYPT) and not disable_faults:
+ return self._do_decrypt_faults(blob)
+
+ response = {
+ "Plaintext" : blob,
+ "KeyId" : "Not a clue",
+ }
+
+ self._send_reply(json.dumps(response).encode('utf-8'))
+
+ def _do_decrypt_faults(self, blob):
+ stats.fault_calls += 1
+
+ if fault_type == FAULT_DECRYPT:
+ self._send_reply("Internal Error of some sort.".encode(), http.HTTPStatus.INTERNAL_SERVER_ERROR)
+ return
+ elif fault_type == FAULT_DECRYPT_WRONG_KEY:
+ response = {
+ "Plaintext" : "ta7DXE7J0OiCRw03dYMJSeb8nVF5qxTmZ9zWmjuX4zW/SOorSCaY8VMTWG+cRInMx/rr/+QeVw2WjU2IpOSvMg==",
+ "KeyId" : "Not a clue",
+ }
+
+ self._send_reply(json.dumps(response).encode('utf-8'))
+ return
+ elif fault_type == FAULT_DECRYPT_CORRECT_FORMAT:
+ response = {
+ "__type" : "NotFoundException",
+ "message" : "Error decrypting message",
+ }
+
+ self._send_reply(json.dumps(response).encode('utf-8'))
+ return
+
+ raise ValueError("Unknown Fault Type: " + fault_type)
+
+ def _send_header(self):
+ self.send_response(http.HTTPStatus.OK)
+ self.send_header("content-type", "application/octet-stream")
+ self.end_headers()
+
+ def _do_stats(self):
+ self._send_header()
+
+ self.wfile.write(str(stats).encode('utf-8'))
+
+ def _do_disable_faults(self):
+ global disable_faults
+ disable_faults = True
+ self._send_header()
+
+ def _do_enable_faults(self):
+ global disable_faults
+ disable_faults = False
+ self._send_header()
+
+def run(port, cert_file, ca_file, server_class=http.server.HTTPServer, handler_class=AwsKmsHandler):
+ """Run web server."""
+ server_address = ('', port)
+
+ httpd = server_class(server_address, handler_class)
+
+ httpd.socket = ssl.wrap_socket (httpd.socket,
+ certfile=cert_file,
+ ca_certs=ca_file, server_side=True)
+
+ print("Mock KMS Web Server Listening on %s" % (str(server_address)))
+
+ httpd.serve_forever()
+
+
+def main():
+ """Main Method."""
+ global fault_type
+ global disable_faults
+
+ parser = argparse.ArgumentParser(description='MongoDB Mock AWS KMS Endpoint.')
+
+ parser.add_argument('-p', '--port', type=int, default=8000, help="Port to listen on")
+
+ parser.add_argument('-v', '--verbose', action='count', help="Enable verbose tracing")
+
+ parser.add_argument('--fault', type=str, help="Type of fault to inject")
+
+ parser.add_argument('--disable-faults', action='store_true', help="Disable faults on startup")
+
+ parser.add_argument('--ca_file', type=str, required=True, help="TLS CA PEM file")
+
+ parser.add_argument('--cert_file', type=str, required=True, help="TLS Server PEM file")
+
+ args = parser.parse_args()
+ if args.verbose:
+ logging.basicConfig(level=logging.DEBUG)
+
+ if args.fault:
+ if args.fault not in SUPPORTED_FAULT_TYPES:
+ print("Unsupported fault type %s, supports types are %s" % (args.fault, SUPPORTED_FAULT_TYPES))
+ sys.exit(1)
+
+ fault_type = args.fault
+
+ if args.disable_faults:
+ disable_faults = True
+
+ run(args.port, args.cert_file, args.ca_file)
+
+
+if __name__ == '__main__':
+
+ main()
diff --git a/jstests/client_encrypt/lib/mock_kms.js b/jstests/client_encrypt/lib/mock_kms.js
new file mode 100644
index 00000000000..a7f34c37312
--- /dev/null
+++ b/jstests/client_encrypt/lib/mock_kms.js
@@ -0,0 +1,161 @@
+/**
+ * Starts a mock KMS Server to test
+ * FLE encryption and decryption.
+ */
+
+// These faults must match the list of faults in kms_http_server.py, see the
+// SUPPORTED_FAULT_TYPES list in kms_http_server.py
+const FAULT_ENCRYPT = "fault_encrypt";
+const FAULT_ENCRYPT_CORRECT_FORMAT = "fault_encrypt_correct_format";
+const FAULT_ENCRYPT_WRONG_FIELDS = "fault_encrypt_wrong_fields";
+const FAULT_ENCRYPT_BAD_BASE64 = "fault_encrypt_bad_base64";
+const FAULT_DECRYPT = "fault_decrypt";
+const FAULT_DECRYPT_CORRECT_FORMAT = "fault_decrypt_correct_format";
+const FAULT_DECRYPT_WRONG_KEY = "fault_decrypt_wrong_key";
+
+const DISABLE_FAULTS = "disable_faults";
+const ENABLE_FAULTS = "enable_faults";
+
+class MockKMSServer {
+ /**
+ * Create a new webserver.
+ *
+ * @param {string} fault_type
+ * @param {bool} disableFaultsOnStartup optionally disable fault on startup
+ */
+ constructor(fault_type, disableFaultsOnStartup) {
+ this.python = "python3";
+ this.disableFaultsOnStartup = disableFaultsOnStartup || false;
+ this.fault_type = fault_type;
+
+ if (_isWindows()) {
+ this.python = "python.exe";
+ }
+
+ print("Using python interpreter: " + this.python);
+
+ this.ca_file = "jstests/libs/ca.pem";
+ this.server_cert_file = "jstests/libs/server.pem";
+ this.web_server_py = "jstests/client_encrypt/lib/kms_http_server.py";
+ this.control_py = "jstests/client_encrypt/lib/kms_http_control.py";
+ this.port = -1;
+ }
+
+ /**
+ * Start a web server
+ */
+ start() {
+ this.port = allocatePort();
+ print("Mock Web server is listening on port: " + this.port);
+
+ let args = [
+ this.python,
+ "-u",
+ this.web_server_py,
+ "--port=" + this.port,
+ "--ca_file=" + this.ca_file,
+ "--cert_file=" + this.server_cert_file
+ ];
+ if (this.fault_type) {
+ args.push("--fault=" + this.fault_type);
+ if (this.disableFaultsOnStartup) {
+ args.push("--disable-faults");
+ }
+ }
+
+ this.pid = _startMongoProgram({args: args});
+ assert(checkProgram(this.pid));
+
+ assert.soon(function() {
+ return rawMongoProgramOutput().search("Mock KMS Web Server Listening") !== -1;
+ });
+ sleep(1000);
+ print("Mock KMS Server successfully started");
+ }
+
+ _runCommand(cmd) {
+ let ret = 0;
+ if (_isWindows()) {
+ ret = runProgram('cmd.exe', '/c', cmd);
+ } else {
+ ret = runProgram('/bin/sh', '-c', cmd);
+ }
+
+ assert.eq(ret, 0);
+ }
+
+ /**
+ * Query the HTTP server.
+ *
+ * @param {string} query type
+ *
+ * @return {object} Object representation of JSON from the server.
+ */
+ query(query) {
+ const out_file = "out_" + this.port + ".txt";
+ const python_command = this.python + " -u " + this.control_py + " --port=" + this.port +
+ " --ca_file=" + this.ca_file + " --query=" + query + " > " + out_file;
+
+ this._runCommand(python_command);
+
+ const result = cat(out_file);
+
+ try {
+ return JSON.parse(result);
+ } catch (e) {
+ jsTestLog("Failed to parse: " + result + "\n" + result);
+ throw e;
+ }
+ }
+
+ /**
+ * Control the HTTP server.
+ *
+ * @param {string} query type
+ */
+ control(query) {
+ const python_command = this.python + " -u " + this.control_py + " --port=" + this.port +
+ " --ca_file=" + this.ca_file + " --query=" + query;
+
+ this._runCommand(python_command);
+ }
+
+ /**
+ * Disable Faults
+ */
+ disableFaults() {
+ this.control(DISABLE_FAULTS);
+ }
+
+ /**
+ * Enable Faults
+ */
+ enableFaults() {
+ this.control(ENABLE_FAULTS);
+ }
+
+ /**
+ * Query the stats page for the HTTP server.
+ *
+ * @return {object} Object representation of JSON from the server.
+ */
+ queryStats() {
+ return this.query("stats");
+ }
+
+ /**
+ * Get the URL.
+ *
+ * @return {string} url of http server
+ */
+ getURL() {
+ return "https://localhost:" + this.port;
+ }
+
+ /**
+ * Stop the web server
+ */
+ stop() {
+ stopMongoProgramByPid(this.pid);
+ }
+}
diff --git a/src/mongo/SConscript b/src/mongo/SConscript
index c56ffe4bfe1..6dfd113c4ad 100644
--- a/src/mongo/SConscript
+++ b/src/mongo/SConscript
@@ -622,6 +622,8 @@ if not has_option('noshell') and usemozjs:
"shell_core",
"db/server_options_core",
"client/clientdriver_network",
+ "shell/kms_shell" if get_option('ssl') == 'on' else '',
+ "shell/encrypted_dbclient" if get_option('ssl') == 'on' else '',
"$BUILD_DIR/mongo/util/password",
'$BUILD_DIR/mongo/db/storage/duplicate_key_error_info',
"$BUILD_DIR/mongo/db/views/resolved_view",
diff --git a/src/mongo/crypto/SConscript b/src/mongo/crypto/SConscript
index 97f73e3356a..5cc05d65e2d 100644
--- a/src/mongo/crypto/SConscript
+++ b/src/mongo/crypto/SConscript
@@ -76,3 +76,36 @@ env.CppUnitTest('mechanism_scram_test',
'$BUILD_DIR/mongo/base/secure_allocator',
'sha_block_${MONGO_CRYPTO}',
])
+
+
+env.Library(target='symmetric_crypto',
+ source=[
+ 'symmetric_crypto.cpp',
+ 'symmetric_crypto_${MONGO_CRYPTO}.cpp',
+ 'symmetric_key.cpp',
+ ],
+ LIBDEPS=[
+ '$BUILD_DIR/mongo/base/secure_allocator',
+ '$BUILD_DIR/mongo/util/net/ssl_manager',
+ '$BUILD_DIR/mongo/util/secure_zero_memory',
+ ],
+)
+
+env.Library(
+ target="aead_encryption",
+ source=[
+ "aead_encryption.cpp",
+ ],
+ LIBDEPS=[
+ 'symmetric_crypto',
+ '$BUILD_DIR/mongo/db/matcher/expressions',
+ ],
+)
+
+env.CppUnitTest(
+ target='aead_encryption_test',
+ source='aead_encryption_test.cpp',
+ LIBDEPS=[
+ 'aead_encryption',
+ ]
+)
diff --git a/src/mongo/crypto/aead_encryption.cpp b/src/mongo/crypto/aead_encryption.cpp
new file mode 100644
index 00000000000..77ec7ed41c4
--- /dev/null
+++ b/src/mongo/crypto/aead_encryption.cpp
@@ -0,0 +1,391 @@
+/**
+ * Copyright (C) 2019-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/platform/basic.h"
+
+#include "mongo/crypto/aead_encryption.h"
+
+#include "mongo/base/data_view.h"
+#include "mongo/crypto/sha512_block.h"
+#include "mongo/crypto/symmetric_crypto.h"
+#include "mongo/db/matcher/schema/encrypt_schema_gen.h"
+#include "mongo/util/secure_compare_memory.h"
+
+namespace mongo {
+namespace crypto {
+
+namespace {
+constexpr size_t kHmacOutSize = 32;
+constexpr size_t kIVSize = 16;
+
+// AssociatedData can be 2^24 bytes but since there needs to be room for the ciphertext in the
+// object, a value of 1<<16 was decided to cap the maximum size of AssociatedData.
+constexpr int kMaxAssociatedDataLength = 1 << 16;
+
+size_t aesCBCCipherOutputLength(size_t plainTextLen) {
+ return aesBlockSize * (1 + plainTextLen / aesBlockSize);
+}
+
+std::pair<size_t, size_t> aesCBCExpectedPlaintextLen(size_t cipherTextLength) {
+ return {cipherTextLength - aesCBCIVSize - aesBlockSize, cipherTextLength - aesCBCIVSize};
+}
+
+void aeadGenerateIV(const SymmetricKey* key, uint8_t* buffer, size_t bufferLen) {
+ if (bufferLen < aesCBCIVSize) {
+ fassert(51235, "IV buffer is too small for selected mode");
+ }
+
+ auto status = engineRandBytes(buffer, aesCBCIVSize);
+ if (!status.isOK()) {
+ fassert(51236, status);
+ }
+}
+
+Status _aesEncrypt(const SymmetricKey& key,
+ const std::uint8_t* in,
+ std::size_t inLen,
+ std::uint8_t* out,
+ std::size_t outLen,
+ std::size_t* resultLen,
+ bool ivProvided) try {
+
+ if (!ivProvided) {
+ aeadGenerateIV(&key, out, aesCBCIVSize);
+ }
+
+ auto encryptor =
+ uassertStatusOK(SymmetricEncryptor::create(key, aesMode::cbc, out, aesCBCIVSize));
+
+ const size_t dataSize = outLen - aesCBCIVSize;
+ uint8_t* data = out + aesCBCIVSize;
+
+ const auto updateLen = uassertStatusOK(encryptor->update(in, inLen, data, dataSize));
+ const auto finalLen =
+ uassertStatusOK(encryptor->finalize(data + updateLen, dataSize - updateLen));
+ const auto len = updateLen + finalLen;
+
+ // Some cipher modes, such as GCM, will know in advance exactly how large their ciphertexts will
+ // be. Others, like CBC, will have an upper bound. When this is true, we must allocate enough
+ // memory to store the worst case. We must then set the actual size of the ciphertext so that
+ // the buffer it has been written to may be serialized.
+ invariant(len <= dataSize);
+ *resultLen = aesCBCIVSize + len;
+
+ // Check the returned length, including block size padding
+ if (len != aesCBCCipherOutputLength(inLen)) {
+ return {ErrorCodes::BadValue,
+ str::stream() << "Encrypt error, expected cipher text of length "
+ << aesCBCCipherOutputLength(inLen)
+ << " but found "
+ << len};
+ }
+
+ return Status::OK();
+} catch (const AssertionException& ex) {
+ return ex.toStatus();
+}
+
+Status _aesDecrypt(const SymmetricKey& key,
+ ConstDataRange in,
+ std::uint8_t* out,
+ std::size_t outLen,
+ std::size_t* resultLen) try {
+ // Check the plaintext buffer can fit the product of decryption
+ auto[lowerBound, upperBound] = aesCBCExpectedPlaintextLen(in.length());
+ if (upperBound > outLen) {
+ return {ErrorCodes::BadValue,
+ str::stream() << "Cleartext buffer of size " << outLen
+ << " too small for output which can be as large as "
+ << upperBound
+ << "]"};
+ }
+
+ const uint8_t* dataPtr = reinterpret_cast<const std::uint8_t*>(in.data());
+
+ auto decryptor =
+ uassertStatusOK(SymmetricDecryptor::create(key, aesMode::cbc, dataPtr, aesCBCIVSize));
+
+ const size_t dataSize = in.length() - aesCBCIVSize;
+ const uint8_t* data = dataPtr + aesCBCIVSize;
+
+ const auto updateLen = uassertStatusOK(decryptor->update(data, dataSize, out, outLen));
+
+ const auto finalLen = uassertStatusOK(decryptor->finalize(out + updateLen, outLen - updateLen));
+
+ *resultLen = updateLen + finalLen;
+ invariant(*resultLen <= outLen);
+
+ // Check the returned length, excluding headers block padding
+ if (*resultLen < lowerBound || *resultLen > upperBound) {
+ return {ErrorCodes::BadValue,
+ str::stream() << "Decrypt error, expected clear text length in interval"
+ << "["
+ << lowerBound
+ << ","
+ << upperBound
+ << "]"
+ << "but found "
+ << *resultLen};
+ }
+
+ return Status::OK();
+} catch (const AssertionException& ex) {
+ return ex.toStatus();
+}
+
+} // namespace
+
+size_t aeadCipherOutputLength(size_t plainTextLen) {
+ // To calculate the size of the byte, we divide by the byte size and add 2 for padding
+ // (1 for the attached IV, and 1 for the extra padding). The algorithm will add padding even
+ // if the len is a multiple of the byte size, so if the len divides cleanly it will be
+ // 32 bytes longer than the original, which is 16 bytes as padding and 16 bytes for the
+ // IV. For things that don't divide cleanly, the cast takes care of floor dividing so it will
+ // be 0 < x < 16 bytes added for padding and 16 bytes added for the IV.
+ size_t aesOutLen = aesBlockSize * (plainTextLen / aesBlockSize + 2);
+ return aesOutLen + kHmacOutSize;
+}
+
+Status aeadEncrypt(const SymmetricKey& key,
+ const uint8_t* in,
+ const size_t inLen,
+ const uint8_t* associatedData,
+ const uint64_t associatedDataLen,
+ uint8_t* out,
+ size_t outLen) {
+
+ if (associatedDataLen >= kMaxAssociatedDataLength) {
+ return Status(ErrorCodes::BadValue,
+ str::stream()
+ << "AssociatedData for encryption is too large. Cannot be larger than "
+ << kMaxAssociatedDataLength
+ << " bytes.");
+ }
+
+ // According to the rfc on AES encryption, the associatedDataLength is defined as the
+ // number of bits in associatedData in BigEndian format. This is what the code segment
+ // below describes.
+ // RFC: (https://tools.ietf.org/html/draft-mcgrew-aead-aes-cbc-hmac-sha2-01#section-2.1)
+ std::array<uint8_t, sizeof(uint64_t)> dataLenBitsEncodedStorage;
+ DataRange dataLenBitsEncoded(dataLenBitsEncodedStorage);
+ dataLenBitsEncoded.write<BigEndian<uint64_t>>(associatedDataLen * 8);
+
+ auto keySize = key.getKeySize();
+ if (keySize < kAeadAesHmacKeySize) {
+ return Status(ErrorCodes::BadValue,
+ "AEAD encryption key too short. "
+ "Must be either 64 or 96 bytes.");
+ }
+
+ ConstDataRange aeadKey(key.getKey(), kAeadAesHmacKeySize);
+
+ if (key.getKeySize() == kAeadAesHmacKeySize) {
+ // local key store key encryption
+ return aeadEncryptWithIV(aeadKey,
+ in,
+ inLen,
+ nullptr,
+ 0,
+ associatedData,
+ associatedDataLen,
+ dataLenBitsEncoded,
+ out,
+ outLen);
+ }
+
+ if (key.getKeySize() != kFieldLevelEncryptionKeySize) {
+ return Status(ErrorCodes::BadValue, "Invalid key size.");
+ }
+
+ if (in == nullptr || !in) {
+ return Status(ErrorCodes::BadValue, "Invalid AEAD plaintext input.");
+ }
+
+ if (key.getAlgorithm() != aesAlgorithm) {
+ return Status(ErrorCodes::BadValue, "Invalid algorithm for key.");
+ }
+
+ ConstDataRange hmacCDR(nullptr, 0);
+ SHA512Block hmacOutput;
+ if (static_cast<int>(associatedData[0]) ==
+ FleAlgorithmInt_serializer(FleAlgorithmInt::kDeterministic)) {
+ const uint8_t* ivKey = key.getKey() + kAeadAesHmacKeySize;
+ hmacOutput = SHA512Block::computeHmac(ivKey,
+ sym256KeySize,
+ {ConstDataRange(associatedData, associatedDataLen),
+ dataLenBitsEncoded,
+ ConstDataRange(in, inLen)});
+
+ static_assert(SHA512Block::kHashLength >= kIVSize,
+ "Invalid AEAD parameters. Generated IV too short.");
+
+ hmacCDR = ConstDataRange(hmacOutput.data(), kIVSize);
+ }
+ return aeadEncryptWithIV(aeadKey,
+ in,
+ inLen,
+ reinterpret_cast<const uint8_t*>(hmacCDR.data()),
+ hmacCDR.length(),
+ associatedData,
+ associatedDataLen,
+ dataLenBitsEncoded,
+ out,
+ outLen);
+}
+
+Status aeadEncryptWithIV(ConstDataRange key,
+ const uint8_t* in,
+ const size_t inLen,
+ const uint8_t* iv,
+ const size_t ivLen,
+ const uint8_t* associatedData,
+ const uint64_t associatedDataLen,
+ ConstDataRange dataLenBitsEncoded,
+ uint8_t* out,
+ size_t outLen) {
+ if (key.length() != kAeadAesHmacKeySize) {
+ return Status(ErrorCodes::BadValue, "Invalid key size.");
+ }
+
+ if (!(in && out)) {
+ return Status(ErrorCodes::BadValue, "Invalid AEAD parameters.");
+ }
+
+ if (outLen != aeadCipherOutputLength(inLen)) {
+ return Status(ErrorCodes::BadValue, "Invalid output buffer size.");
+ }
+
+ if (associatedDataLen >= kMaxAssociatedDataLength) {
+ return Status(ErrorCodes::BadValue,
+ str::stream()
+ << "AssociatedData for encryption is too large. Cannot be larger than "
+ << kMaxAssociatedDataLength
+ << " bytes.");
+ }
+
+ const uint8_t* macKey = reinterpret_cast<const uint8_t*>(key.data());
+ const uint8_t* encKey = reinterpret_cast<const uint8_t*>(key.data() + sym256KeySize);
+
+ size_t aesOutLen = outLen - kHmacOutSize;
+
+ size_t cipherTextLen = 0;
+
+ SymmetricKey symEncKey(encKey, sym256KeySize, aesAlgorithm, "aesKey", 1);
+
+ bool ivProvided = false;
+ if (ivLen != 0) {
+ invariant(ivLen == 16);
+ std::copy(iv, iv + ivLen, out);
+ ivProvided = true;
+ }
+
+ auto sEncrypt = _aesEncrypt(symEncKey, in, inLen, out, aesOutLen, &cipherTextLen, ivProvided);
+
+ if (!sEncrypt.isOK()) {
+ return sEncrypt;
+ }
+
+ SHA512Block hmacOutput =
+ SHA512Block::computeHmac(macKey,
+ sym256KeySize,
+ {ConstDataRange(associatedData, associatedDataLen),
+ ConstDataRange(out, cipherTextLen),
+ dataLenBitsEncoded});
+
+ std::copy(hmacOutput.data(), hmacOutput.data() + kHmacOutSize, out + cipherTextLen);
+ return Status::OK();
+}
+
+Status aeadDecrypt(const SymmetricKey& key,
+ const uint8_t* cipherText,
+ const size_t cipherLen,
+ const uint8_t* associatedData,
+ const uint64_t associatedDataLen,
+ uint8_t* out,
+ size_t* outLen) {
+ if (key.getKeySize() < kAeadAesHmacKeySize) {
+ return Status(ErrorCodes::BadValue, "Invalid key size.");
+ }
+
+ if (!(cipherText && out)) {
+ return Status(ErrorCodes::BadValue, "Invalid AEAD parameters.");
+ }
+
+ if ((*outLen) != cipherLen) {
+ return Status(ErrorCodes::BadValue, "Output buffer must be as long as the cipherText.");
+ }
+
+ if (associatedDataLen >= kMaxAssociatedDataLength) {
+ return Status(ErrorCodes::BadValue,
+ str::stream()
+ << "AssociatedData for encryption is too large. Cannot be larger than "
+ << kMaxAssociatedDataLength
+ << " bytes.");
+ }
+
+ const uint8_t* macKey = key.getKey();
+ const uint8_t* encKey = key.getKey() + sym256KeySize;
+
+ if (cipherLen < kHmacOutSize) {
+ return Status(ErrorCodes::BadValue, "Ciphertext is not long enough.");
+ }
+ size_t aesLen = cipherLen - kHmacOutSize;
+
+ // According to the rfc on AES encryption, the associatedDataLength is defined as the
+ // number of bits in associatedData in BigEndian format. This is what the code segment
+ // below describes.
+ std::array<uint8_t, sizeof(uint64_t)> dataLenBitsEncodedStorage;
+ DataRange dataLenBitsEncoded(dataLenBitsEncodedStorage);
+ dataLenBitsEncoded.write<BigEndian<uint64_t>>(associatedDataLen * 8);
+
+ SHA512Block hmacOutput =
+ SHA512Block::computeHmac(macKey,
+ sym256KeySize,
+ {ConstDataRange(associatedData, associatedDataLen),
+ ConstDataRange(cipherText, aesLen),
+ dataLenBitsEncoded});
+
+ if (consttimeMemEqual(reinterpret_cast<const unsigned char*>(hmacOutput.data()),
+ reinterpret_cast<const unsigned char*>(cipherText + aesLen),
+ kHmacOutSize) == false) {
+ return Status(ErrorCodes::BadValue, "HMAC data authentication failed.");
+ }
+
+ SymmetricKey symEncKey(encKey, sym256KeySize, aesAlgorithm, key.getKeyId(), 1);
+
+ auto sDecrypt = _aesDecrypt(symEncKey, ConstDataRange(cipherText, aesLen), out, aesLen, outLen);
+ if (!sDecrypt.isOK()) {
+ return sDecrypt;
+ }
+
+ return Status::OK();
+}
+
+} // namespace crypto
+} // namespace mongo
diff --git a/src/mongo/crypto/aead_encryption.h b/src/mongo/crypto/aead_encryption.h
new file mode 100644
index 00000000000..c5fb79479e6
--- /dev/null
+++ b/src/mongo/crypto/aead_encryption.h
@@ -0,0 +1,94 @@
+/**
+ * Copyright (C) 2019-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include <cstddef>
+#include <cstdint>
+
+#include "mongo/base/data_view.h"
+#include "mongo/base/status.h"
+#include "mongo/crypto/symmetric_key.h"
+
+namespace mongo {
+namespace crypto {
+
+/**
+ * Constants used in the AEAD function
+ */
+
+constexpr size_t kFieldLevelEncryptionKeySize = 96;
+constexpr size_t kAeadAesHmacKeySize = 64;
+
+/**
+ * Returns the length of the ciphertext output given the plaintext length. Only for AEAD.
+ */
+size_t aeadCipherOutputLength(size_t plainTextLen);
+
+
+/**
+ * Encrypts the plaintext using following the AEAD_AES_256_CBC_HMAC_SHA_512 encryption
+ * algorithm. Writes output to out.
+ */
+Status aeadEncrypt(const SymmetricKey& key,
+ const uint8_t* in,
+ const size_t inLen,
+ const uint8_t* associatedData,
+ const uint64_t associatedDataLen,
+ uint8_t* out,
+ size_t outLen);
+
+/**
+ * Internal calls for the aeadEncryption algorithm. Only used for testing.
+ */
+Status aeadEncryptWithIV(ConstDataRange key,
+ const uint8_t* in,
+ const size_t inLen,
+ const uint8_t* iv,
+ const size_t ivLen,
+ const uint8_t* associatedData,
+ const uint64_t associatedDataLen,
+ ConstDataRange dataLenBitsEncodedStorage,
+ uint8_t* out,
+ size_t outLen);
+
+/**
+ * Decrypts the cipherText using AEAD_AES_256_CBC_HMAC_SHA_512 decryption. Writes output
+ * to out.
+ */
+Status aeadDecrypt(const SymmetricKey& key,
+ const uint8_t* cipherText,
+ const size_t cipherLen,
+ const uint8_t* associatedData,
+ const uint64_t associatedDataLen,
+ uint8_t* out,
+ size_t* outLen);
+
+} // namespace crypto
+} // namespace mongo
diff --git a/src/mongo/crypto/aead_encryption_test.cpp b/src/mongo/crypto/aead_encryption_test.cpp
new file mode 100644
index 00000000000..28177f05b82
--- /dev/null
+++ b/src/mongo/crypto/aead_encryption_test.cpp
@@ -0,0 +1,152 @@
+/**
+ * Copyright (C) 2019-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include <algorithm>
+
+#include "mongo/unittest/death_test.h"
+#include "mongo/unittest/unittest.h"
+
+#include "aead_encryption.h"
+
+namespace mongo {
+namespace {
+
+// The first test is to ensure that the length of the cipher is correct when
+// calling AEAD encrypt.
+TEST(AEAD, aeadCipherOutputLength) {
+ size_t plainTextLen = 16;
+ auto cipherLen = crypto::aeadCipherOutputLength(plainTextLen);
+ ASSERT_EQ(cipherLen, size_t(80));
+
+ plainTextLen = 10;
+ cipherLen = crypto::aeadCipherOutputLength(plainTextLen);
+ ASSERT_EQ(cipherLen, size_t(64));
+}
+
+TEST(AEAD, EncryptAndDecrypt) {
+ // Test case from RFC:
+ // https://tools.ietf.org/html/draft-mcgrew-aead-aes-cbc-hmac-sha2-05#section-5.4
+
+ const uint8_t aesAlgorithm = 0x1;
+
+ std::array<uint8_t, 64> symKey = {
+ 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c,
+ 0x0d, 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19,
+ 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26,
+ 0x27, 0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, 0x30, 0x31, 0x32, 0x33,
+ 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x3a, 0x3b, 0x3c, 0x3d, 0x3e, 0x3f};
+
+ SecureVector<uint8_t> aesVector = SecureVector<uint8_t>(symKey.begin(), symKey.end());
+ SymmetricKey key = SymmetricKey(aesVector, aesAlgorithm, "aeadEncryptDecryptTest");
+
+ const std::array<uint8_t, 128> plainTextTest = {
+ 0x41, 0x20, 0x63, 0x69, 0x70, 0x68, 0x65, 0x72, 0x20, 0x73, 0x79, 0x73, 0x74, 0x65, 0x6d,
+ 0x20, 0x6d, 0x75, 0x73, 0x74, 0x20, 0x6e, 0x6f, 0x74, 0x20, 0x62, 0x65, 0x20, 0x72, 0x65,
+ 0x71, 0x75, 0x69, 0x72, 0x65, 0x64, 0x20, 0x74, 0x6f, 0x20, 0x62, 0x65, 0x20, 0x73, 0x65,
+ 0x63, 0x72, 0x65, 0x74, 0x2c, 0x20, 0x61, 0x6e, 0x64, 0x20, 0x69, 0x74, 0x20, 0x6d, 0x75,
+ 0x73, 0x74, 0x20, 0x62, 0x65, 0x20, 0x61, 0x62, 0x6c, 0x65, 0x20, 0x74, 0x6f, 0x20, 0x66,
+ 0x61, 0x6c, 0x6c, 0x20, 0x69, 0x6e, 0x74, 0x6f, 0x20, 0x74, 0x68, 0x65, 0x20, 0x68, 0x61,
+ 0x6e, 0x64, 0x73, 0x20, 0x6f, 0x66, 0x20, 0x74, 0x68, 0x65, 0x20, 0x65, 0x6e, 0x65, 0x6d,
+ 0x79, 0x20, 0x77, 0x69, 0x74, 0x68, 0x6f, 0x75, 0x74, 0x20, 0x69, 0x6e, 0x63, 0x6f, 0x6e,
+ 0x76, 0x65, 0x6e, 0x69, 0x65, 0x6e, 0x63, 0x65};
+
+ std::array<uint8_t, 192> cryptoBuffer = {};
+
+ std::array<uint8_t, 16> iv = {0x1a,
+ 0xf3,
+ 0x8c,
+ 0x2d,
+ 0xc2,
+ 0xb9,
+ 0x6f,
+ 0xfd,
+ 0xd8,
+ 0x66,
+ 0x94,
+ 0x09,
+ 0x23,
+ 0x41,
+ 0xbc,
+ 0x04};
+
+ std::array<uint8_t, 42> associatedData = {
+ 0x54, 0x68, 0x65, 0x20, 0x73, 0x65, 0x63, 0x6f, 0x6e, 0x64, 0x20, 0x70, 0x72, 0x69,
+ 0x6e, 0x63, 0x69, 0x70, 0x6c, 0x65, 0x20, 0x6f, 0x66, 0x20, 0x41, 0x75, 0x67, 0x75,
+ 0x73, 0x74, 0x65, 0x20, 0x4b, 0x65, 0x72, 0x63, 0x6b, 0x68, 0x6f, 0x66, 0x66, 0x73};
+
+ const size_t dataLen = 42;
+
+ std::array<uint8_t, sizeof(uint64_t)> dataLenBitsEncodedStorage;
+ DataRange dataLenBitsEncoded(dataLenBitsEncodedStorage);
+ dataLenBitsEncoded.write<BigEndian<uint64_t>>(dataLen * 8);
+
+ const size_t outLen = crypto::aeadCipherOutputLength(128);
+
+ ASSERT_OK(crypto::aeadEncryptWithIV(symKey,
+ plainTextTest.data(),
+ plainTextTest.size(),
+ iv.data(),
+ iv.size(),
+ associatedData.data(),
+ dataLen,
+ dataLenBitsEncoded,
+ cryptoBuffer.data(),
+ outLen));
+
+ std::array<uint8_t, 192> cryptoBufferTest = {
+ 0x1a, 0xf3, 0x8c, 0x2d, 0xc2, 0xb9, 0x6f, 0xfd, 0xd8, 0x66, 0x94, 0x09, 0x23, 0x41, 0xbc,
+ 0x04, 0x4a, 0xff, 0xaa, 0xad, 0xb7, 0x8c, 0x31, 0xc5, 0xda, 0x4b, 0x1b, 0x59, 0x0d, 0x10,
+ 0xff, 0xbd, 0x3d, 0xd8, 0xd5, 0xd3, 0x02, 0x42, 0x35, 0x26, 0x91, 0x2d, 0xa0, 0x37, 0xec,
+ 0xbc, 0xc7, 0xbd, 0x82, 0x2c, 0x30, 0x1d, 0xd6, 0x7c, 0x37, 0x3b, 0xcc, 0xb5, 0x84, 0xad,
+ 0x3e, 0x92, 0x79, 0xc2, 0xe6, 0xd1, 0x2a, 0x13, 0x74, 0xb7, 0x7f, 0x07, 0x75, 0x53, 0xdf,
+ 0x82, 0x94, 0x10, 0x44, 0x6b, 0x36, 0xeb, 0xd9, 0x70, 0x66, 0x29, 0x6a, 0xe6, 0x42, 0x7e,
+ 0xa7, 0x5c, 0x2e, 0x08, 0x46, 0xa1, 0x1a, 0x09, 0xcc, 0xf5, 0x37, 0x0d, 0xc8, 0x0b, 0xfe,
+ 0xcb, 0xad, 0x28, 0xc7, 0x3f, 0x09, 0xb3, 0xa3, 0xb7, 0x5e, 0x66, 0x2a, 0x25, 0x94, 0x41,
+ 0x0a, 0xe4, 0x96, 0xb2, 0xe2, 0xe6, 0x60, 0x9e, 0x31, 0xe6, 0xe0, 0x2c, 0xc8, 0x37, 0xf0,
+ 0x53, 0xd2, 0x1f, 0x37, 0xff, 0x4f, 0x51, 0x95, 0x0b, 0xbe, 0x26, 0x38, 0xd0, 0x9d, 0xd7,
+ 0xa4, 0x93, 0x09, 0x30, 0x80, 0x6d, 0x07, 0x03, 0xb1, 0xf6, 0x4d, 0xd3, 0xb4, 0xc0, 0x88,
+ 0xa7, 0xf4, 0x5c, 0x21, 0x68, 0x39, 0x64, 0x5b, 0x20, 0x12, 0xbf, 0x2e, 0x62, 0x69, 0xa8,
+ 0xc5, 0x6a, 0x81, 0x6d, 0xbc, 0x1b, 0x26, 0x77, 0x61, 0x95, 0x5b, 0xc5};
+
+ ASSERT_EQ(0, std::memcmp(cryptoBuffer.data(), cryptoBufferTest.data(), 192));
+
+ std::array<uint8_t, 192> plainText = {};
+ size_t plainTextDecryptLen = 192;
+ ASSERT_OK(crypto::aeadDecrypt(key,
+ cryptoBuffer.data(),
+ cryptoBuffer.size(),
+ associatedData.data(),
+ dataLen,
+ plainText.data(),
+ &plainTextDecryptLen));
+
+ ASSERT_EQ(0, std::memcmp(plainText.data(), plainTextTest.data(), 128));
+}
+} // namespace
+} // namespace mongo
diff --git a/src/mongo/crypto/symmetric_crypto.cpp b/src/mongo/crypto/symmetric_crypto.cpp
new file mode 100644
index 00000000000..32d888cfbbb
--- /dev/null
+++ b/src/mongo/crypto/symmetric_crypto.cpp
@@ -0,0 +1,107 @@
+/**
+ * Copyright (C) 2019-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#define MONGO_LOG_DEFAULT_COMPONENT ::mongo::logger::LogComponent::kDefault
+
+#include "mongo/platform/basic.h"
+
+#include "mongo/crypto/symmetric_crypto.h"
+
+#include <memory>
+
+#include "mongo/base/data_cursor.h"
+#include "mongo/base/init.h"
+#include "mongo/base/status.h"
+#include "mongo/crypto/symmetric_key.h"
+#include "mongo/platform/random.h"
+#include "mongo/util/assert_util.h"
+#include "mongo/util/log.h"
+#include "mongo/util/net/ssl_manager.h"
+#include "mongo/util/str.h"
+
+namespace mongo {
+namespace crypto {
+
+namespace {
+std::unique_ptr<SecureRandom> random;
+} // namespace
+
+MONGO_INITIALIZER(CreateKeyEntropySource)(InitializerContext* context) {
+ random = std::unique_ptr<SecureRandom>(SecureRandom::create());
+ return Status::OK();
+}
+
+size_t aesGetIVSize(crypto::aesMode mode) {
+ switch (mode) {
+ case crypto::aesMode::cbc:
+ return crypto::aesCBCIVSize;
+ case crypto::aesMode::gcm:
+ return crypto::aesGCMIVSize;
+ default:
+ fassertFailed(4053);
+ }
+}
+
+aesMode getCipherModeFromString(const std::string& mode) {
+ if (mode == aes256CBCName) {
+ return aesMode::cbc;
+ } else if (mode == aes256GCMName) {
+ return aesMode::gcm;
+ } else {
+ MONGO_UNREACHABLE;
+ }
+}
+
+std::string getStringFromCipherMode(aesMode mode) {
+ if (mode == aesMode::cbc) {
+ return aes256CBCName;
+ } else if (mode == aesMode::gcm) {
+ return aes256GCMName;
+ } else {
+ MONGO_UNREACHABLE;
+ }
+}
+
+SymmetricKey aesGenerate(size_t keySize, SymmetricKeyId keyId) {
+ invariant(keySize == sym256KeySize);
+
+ SecureVector<uint8_t> key(keySize);
+
+ size_t offset = 0;
+ while (offset < keySize) {
+ std::uint64_t randomValue = random->nextInt64();
+ memcpy(key->data() + offset, &randomValue, sizeof(randomValue));
+ offset += sizeof(randomValue);
+ }
+
+ return SymmetricKey(std::move(key), aesAlgorithm, std::move(keyId));
+}
+
+} // namespace crypto
+} // namespace mongo
diff --git a/src/mongo/crypto/symmetric_crypto.h b/src/mongo/crypto/symmetric_crypto.h
new file mode 100644
index 00000000000..350675a1763
--- /dev/null
+++ b/src/mongo/crypto/symmetric_crypto.h
@@ -0,0 +1,196 @@
+/**
+ * Copyright (C) 2019-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include <cstddef>
+#include <cstdint>
+#include <set>
+#include <string>
+
+#include "mongo/base/status.h"
+#include "mongo/base/status_with.h"
+#include "mongo/crypto/symmetric_key.h"
+
+namespace mongo {
+namespace crypto {
+
+/**
+ * Encryption algorithm identifiers and block sizes
+ */
+constexpr uint8_t aesAlgorithm = 0x1;
+
+/**
+ * Block and key sizes
+ */
+constexpr size_t aesBlockSize = 16;
+constexpr size_t sym256KeySize = 32;
+
+/**
+ * Min and max symmetric key lengths
+ */
+constexpr size_t minKeySize = 16;
+constexpr size_t maxKeySize = 32;
+
+/**
+ * CBC fixed constants
+ */
+constexpr size_t aesCBCIVSize = aesBlockSize;
+
+/**
+ * GCM tunable parameters
+ */
+constexpr size_t aesGCMTagSize = 12;
+constexpr size_t aesGCMIVSize = 12;
+
+/**
+ * Encryption mode identifiers
+ */
+enum class aesMode : uint8_t { cbc, gcm };
+
+/**
+ * Algorithm names which this module recognizes
+ */
+const std::string aes256CBCName = "AES256-CBC";
+const std::string aes256GCMName = "AES256-GCM";
+
+aesMode getCipherModeFromString(const std::string& mode);
+std::string getStringFromCipherMode(aesMode);
+
+/**
+ * Generates a new, random, symmetric key for use with AES.
+ */
+SymmetricKey aesGenerate(size_t keySize, SymmetricKeyId keyId);
+
+/* Platform specific engines should implement these. */
+
+/**
+ * Interface to a symmetric cryptography engine.
+ * For use with encrypting payloads.
+ */
+class SymmetricEncryptor {
+public:
+ virtual ~SymmetricEncryptor() = default;
+
+ /**
+ * Process a chunk of data from <in> and store the ciphertext in <out>.
+ * Returns the number of bytes written to <out> which will not exceed <outLen>.
+ * Because <inLen> for this and/or previous calls may not lie on a block boundary,
+ * the number of bytes written to <out> may be more or less than <inLen>.
+ */
+ virtual StatusWith<size_t> update(const uint8_t* in,
+ size_t inLen,
+ uint8_t* out,
+ size_t outLen) = 0;
+
+ /**
+ * Append Additional AuthenticatedData (AAD) to a GCM encryption stream.
+ */
+ virtual Status addAuthenticatedData(const uint8_t* in, size_t inLen) = 0;
+
+ /**
+ * Finish an encryption by flushing any buffered bytes for a partial cipherblock to <out>.
+ * Returns the number of bytes written, not to exceed <outLen>.
+ */
+ virtual StatusWith<size_t> finalize(uint8_t* out, size_t outLen) = 0;
+
+ /**
+ * For aesMode::gcm, writes the GCM tag to <out>.
+ * Returns the number of bytes used, not to exceed <outLen>.
+ */
+ virtual StatusWith<size_t> finalizeTag(uint8_t* out, size_t outLen) = 0;
+
+ /**
+ * Create an instance of a SymmetricEncryptor object from the currently available
+ * cipher engine (e.g. OpenSSL).
+ */
+ static StatusWith<std::unique_ptr<SymmetricEncryptor>> create(const SymmetricKey& key,
+ aesMode mode,
+ const uint8_t* iv,
+ size_t inLen);
+};
+
+/**
+ * Interface to a symmetric cryptography engine.
+ * For use with encrypting payloads.
+ */
+class SymmetricDecryptor {
+public:
+ virtual ~SymmetricDecryptor() = default;
+
+ /**
+ * Process a chunk of data from <in> and store the decrypted text in <out>.
+ * Returns the number of bytes written to <out> which will not exceed <outLen>.
+ * Because <inLen> for this and/or previous calls may not lie on a block boundary,
+ * the number of bytes written to <out> may be more or less than <inLen>.
+ */
+ virtual StatusWith<size_t> update(const uint8_t* in,
+ size_t inLen,
+ uint8_t* out,
+ size_t outLen) = 0;
+
+ /**
+ * For aesMode::gcm, inform the cipher engine of additional authenticated data (AAD).
+ */
+ virtual Status addAuthenticatedData(const uint8_t* in, size_t inLen) = 0;
+
+ /**
+ * For aesMode::gcm, informs the cipher engine of the GCM tag associated with this data stream.
+ */
+ virtual Status updateTag(const uint8_t* tag, size_t tagLen) = 0;
+
+ /**
+ * Finish an decryption by flushing any buffered bytes for a partial cipherblock to <out>.
+ * Returns the number of bytes written, not to exceed <outLen>.
+ */
+ virtual StatusWith<size_t> finalize(uint8_t* out, size_t outLen) = 0;
+
+ /**
+ * Create an instance of a SymmetricDecryptor object from the currently available
+ * cipher engine (e.g. OpenSSL).
+ */
+ static StatusWith<std::unique_ptr<SymmetricDecryptor>> create(const SymmetricKey& key,
+ aesMode mode,
+ const uint8_t* iv,
+ size_t ivLen);
+};
+
+/**
+ * Returns a list of cipher modes supported by the cipher engine.
+ * e.g. {"AES256-CBC", "AES256-GCM"}
+ */
+std::set<std::string> getSupportedSymmetricAlgorithms();
+
+/**
+ * Generate a quantity of random bytes from the cipher engine.
+ */
+Status engineRandBytes(uint8_t* buffer, size_t len);
+
+} // namespace crypto
+} // namespace mongo
diff --git a/src/mongo/crypto/symmetric_crypto_apple.cpp b/src/mongo/crypto/symmetric_crypto_apple.cpp
new file mode 100644
index 00000000000..9ca5c9c0b1e
--- /dev/null
+++ b/src/mongo/crypto/symmetric_crypto_apple.cpp
@@ -0,0 +1,183 @@
+/**
+ * Copyright (C) 2019-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/platform/basic.h"
+
+#include <CommonCrypto/CommonCryptor.h>
+#include <Security/Security.h>
+#include <memory>
+#include <set>
+
+#include "mongo/base/init.h"
+#include "mongo/base/status.h"
+#include "mongo/crypto/symmetric_crypto.h"
+#include "mongo/crypto/symmetric_key.h"
+#include "mongo/platform/random.h"
+#include "mongo/util/assert_util.h"
+#include "mongo/util/str.h"
+
+namespace mongo {
+namespace crypto {
+
+namespace {
+
+template <typename Parent>
+class SymmetricImplApple : public Parent {
+public:
+ SymmetricImplApple(const SymmetricKey& key, aesMode mode, const uint8_t* iv, size_t ivLen)
+ : _ctx(nullptr, CCCryptorRelease) {
+ static_assert(
+ std::is_same<Parent, SymmetricEncryptor>::value ||
+ std::is_same<Parent, SymmetricDecryptor>::value,
+ "SymmetricImplApple must inherit from SymmetricEncryptor or SymmetricDecryptor");
+
+ uassert(ErrorCodes::UnsupportedFormat,
+ "Native crypto on this platform only supports AES256-CBC",
+ mode == aesMode::cbc);
+
+ // Note: AES256 uses a 256byte keysize,
+ // but is still functionally a 128bit block algorithm.
+ // Therefore we expect a 128 bit block length.
+ uassert(ErrorCodes::BadValue,
+ str::stream() << "Invalid ivlen for selected algorithm, expected "
+ << kCCBlockSizeAES128
+ << ", got "
+ << ivLen,
+ ivLen == kCCBlockSizeAES128);
+
+ CCCryptorRef context = nullptr;
+ constexpr auto op =
+ std::is_same<Parent, SymmetricEncryptor>::value ? kCCEncrypt : kCCDecrypt;
+ const auto status = CCCryptorCreate(op,
+ kCCAlgorithmAES,
+ kCCOptionPKCS7Padding,
+ key.getKey(),
+ key.getKeySize(),
+ iv,
+ &context);
+ uassert(ErrorCodes::UnknownError,
+ str::stream() << "CCCryptorCreate failure: " << status,
+ status == kCCSuccess);
+
+ _ctx.reset(context);
+ }
+
+ StatusWith<size_t> update(const uint8_t* in, size_t inLen, uint8_t* out, size_t outLen) final {
+ size_t outUsed = 0;
+ const auto status = CCCryptorUpdate(_ctx.get(), in, inLen, out, outLen, &outUsed);
+ if (status != kCCSuccess) {
+ return Status(ErrorCodes::UnknownError,
+ str::stream() << "Unable to perform CCCryptorUpdate: " << status);
+ }
+ return outUsed;
+ }
+
+ Status addAuthenticatedData(const uint8_t* in, size_t inLen) final {
+ fassert(51128, inLen == 0);
+ return Status::OK();
+ }
+
+ StatusWith<size_t> finalize(uint8_t* out, size_t outLen) final {
+ size_t outUsed = 0;
+ const auto status = CCCryptorFinal(_ctx.get(), out, outLen, &outUsed);
+ if (status != kCCSuccess) {
+ return Status(ErrorCodes::UnknownError,
+ str::stream() << "Unable to perform CCCryptorFinal: " << status);
+ }
+ return outUsed;
+ }
+
+private:
+ std::unique_ptr<_CCCryptor, decltype(&CCCryptorRelease)> _ctx;
+};
+
+class SymmetricEncryptorApple : public SymmetricImplApple<SymmetricEncryptor> {
+public:
+ using SymmetricImplApple::SymmetricImplApple;
+
+ StatusWith<size_t> finalizeTag(uint8_t* out, size_t outLen) final {
+ // CBC only, no tag to create.
+ return 0;
+ }
+};
+
+
+class SymmetricDecryptorApple : public SymmetricImplApple<SymmetricDecryptor> {
+public:
+ using SymmetricImplApple::SymmetricImplApple;
+
+ Status updateTag(const uint8_t* tag, size_t tagLen) final {
+ // CBC only, no tag to verify.
+ if (tagLen > 0) {
+ return {ErrorCodes::BadValue, "Unexpected tag for non-gcm cipher"};
+ }
+ return Status::OK();
+ }
+};
+
+} // namespace
+
+std::set<std::string> getSupportedSymmetricAlgorithms() {
+ return {aes256CBCName};
+}
+
+Status engineRandBytes(uint8_t* buffer, size_t len) {
+ auto result = SecRandomCopyBytes(kSecRandomDefault, len, buffer);
+ if (result != errSecSuccess) {
+ return {ErrorCodes::UnknownError,
+ str::stream() << "Failed generating random bytes: " << result};
+ } else {
+ return Status::OK();
+ }
+}
+
+StatusWith<std::unique_ptr<SymmetricEncryptor>> SymmetricEncryptor::create(const SymmetricKey& key,
+ aesMode mode,
+ const uint8_t* iv,
+ size_t ivLen) try {
+ std::unique_ptr<SymmetricEncryptor> encryptor =
+ std::make_unique<SymmetricEncryptorApple>(key, mode, iv, ivLen);
+ return std::move(encryptor);
+} catch (const DBException& e) {
+ return e.toStatus();
+}
+
+StatusWith<std::unique_ptr<SymmetricDecryptor>> SymmetricDecryptor::create(const SymmetricKey& key,
+ aesMode mode,
+ const uint8_t* iv,
+ size_t ivLen) try {
+ std::unique_ptr<SymmetricDecryptor> decryptor =
+ std::make_unique<SymmetricDecryptorApple>(key, mode, iv, ivLen);
+ return std::move(decryptor);
+} catch (const DBException& e) {
+ return e.toStatus();
+}
+
+} // namespace crypto
+} // namespace mongo
diff --git a/src/mongo/crypto/symmetric_crypto_openssl.cpp b/src/mongo/crypto/symmetric_crypto_openssl.cpp
new file mode 100644
index 00000000000..6329331a511
--- /dev/null
+++ b/src/mongo/crypto/symmetric_crypto_openssl.cpp
@@ -0,0 +1,255 @@
+/**
+ * Copyright (C) 2019-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#define MONGO_LOG_DEFAULT_COMPONENT ::mongo::logger::LogComponent::kStorage
+
+#include "mongo/platform/basic.h"
+
+#include <memory>
+#include <openssl/rand.h>
+#include <set>
+
+#include "mongo/base/data_cursor.h"
+#include "mongo/base/init.h"
+#include "mongo/base/status.h"
+#include "mongo/crypto/symmetric_crypto.h"
+#include "mongo/crypto/symmetric_key.h"
+#include "mongo/platform/random.h"
+#include "mongo/util/assert_util.h"
+#include "mongo/util/log.h"
+#include "mongo/util/net/ssl_manager.h"
+#include "mongo/util/str.h"
+
+namespace mongo {
+namespace crypto {
+
+namespace {
+template <typename Init>
+void initCipherContext(
+ EVP_CIPHER_CTX* ctx, const SymmetricKey& key, aesMode mode, const uint8_t* iv, Init init) {
+ const auto keySize = key.getKeySize();
+ const EVP_CIPHER* cipher = nullptr;
+ if (keySize == sym256KeySize) {
+ if (mode == crypto::aesMode::cbc) {
+ cipher = EVP_get_cipherbyname("aes-256-cbc");
+ } else if (mode == crypto::aesMode::gcm) {
+ cipher = EVP_get_cipherbyname("aes-256-gcm");
+ }
+ }
+ uassert(ErrorCodes::BadValue,
+ str::stream() << "Unrecognized AES key size/cipher mode. Size: " << keySize << " Mode: "
+ << getStringFromCipherMode(mode),
+ cipher);
+
+ const bool initOk = (1 == init(ctx, cipher, nullptr, key.getKey(), iv));
+ uassert(ErrorCodes::UnknownError,
+ str::stream() << SSLManagerInterface::getSSLErrorMessage(ERR_get_error()),
+ initOk);
+}
+
+class SymmetricEncryptorOpenSSL : public SymmetricEncryptor {
+public:
+ SymmetricEncryptorOpenSSL(const SymmetricKey& key, aesMode mode, const uint8_t* iv)
+ : _ctx(EVP_CIPHER_CTX_new(), EVP_CIPHER_CTX_free), _mode(mode) {
+ initCipherContext(_ctx.get(), key, mode, iv, EVP_EncryptInit_ex);
+ }
+
+ StatusWith<size_t> update(const uint8_t* in, size_t inLen, uint8_t* out, size_t outLen) final {
+ int len = 0;
+ if (1 != EVP_EncryptUpdate(_ctx.get(), out, &len, in, inLen)) {
+ return Status(ErrorCodes::UnknownError,
+ str::stream()
+ << SSLManagerInterface::getSSLErrorMessage(ERR_get_error()));
+ }
+ return static_cast<size_t>(len);
+ }
+
+ Status addAuthenticatedData(const uint8_t* in, size_t inLen) final {
+ fassert(51126, _mode == crypto::aesMode::gcm);
+
+ auto swUpdate = update(in, inLen, nullptr, 0);
+ if (!swUpdate.isOK()) {
+ return swUpdate.getStatus();
+ }
+
+ const auto len = swUpdate.getValue();
+ if (len != inLen) {
+ return {ErrorCodes::InternalError,
+ str::stream() << "Unexpected write length while appending AAD: " << len};
+ }
+
+ return Status::OK();
+ }
+
+ StatusWith<size_t> finalize(uint8_t* out, size_t outLen) final {
+ int len = 0;
+ if (1 != EVP_EncryptFinal_ex(_ctx.get(), out, &len)) {
+ return Status(ErrorCodes::UnknownError,
+ str::stream()
+ << SSLManagerInterface::getSSLErrorMessage(ERR_get_error()));
+ }
+ return static_cast<size_t>(len);
+ }
+
+ StatusWith<size_t> finalizeTag(uint8_t* out, size_t outLen) final {
+ if (_mode == aesMode::gcm) {
+#ifdef EVP_CTRL_GCM_GET_TAG
+ if (1 != EVP_CIPHER_CTX_ctrl(_ctx.get(), EVP_CTRL_GCM_GET_TAG, outLen, out)) {
+ return Status(ErrorCodes::UnknownError,
+ str::stream()
+ << SSLManagerInterface::getSSLErrorMessage(ERR_get_error()));
+ }
+ return crypto::aesGCMTagSize;
+#else
+ return Status(ErrorCodes::UnsupportedFormat, "GCM support is not available");
+#endif
+ }
+
+ // Otherwise, not a tagged cipher mode, write nothing.
+ return 0;
+ }
+
+private:
+ std::unique_ptr<EVP_CIPHER_CTX, decltype(&EVP_CIPHER_CTX_free)> _ctx;
+ const aesMode _mode;
+};
+
+class SymmetricDecryptorOpenSSL : public SymmetricDecryptor {
+public:
+ SymmetricDecryptorOpenSSL(const SymmetricKey& key, aesMode mode, const uint8_t* iv)
+ : _ctx(EVP_CIPHER_CTX_new(), EVP_CIPHER_CTX_free), _mode(mode) {
+ initCipherContext(_ctx.get(), key, mode, iv, EVP_DecryptInit_ex);
+ }
+
+ StatusWith<size_t> update(const uint8_t* in, size_t inLen, uint8_t* out, size_t outLen) final {
+ int len = 0;
+ if (1 != EVP_DecryptUpdate(_ctx.get(), out, &len, in, inLen)) {
+ return Status(ErrorCodes::UnknownError,
+ str::stream()
+ << SSLManagerInterface::getSSLErrorMessage(ERR_get_error()));
+ }
+ return static_cast<size_t>(len);
+ }
+
+ Status addAuthenticatedData(const uint8_t* in, size_t inLen) final {
+ fassert(51125, _mode == crypto::aesMode::gcm);
+
+ auto swUpdate = update(in, inLen, nullptr, 0);
+ if (!swUpdate.isOK()) {
+ return swUpdate.getStatus();
+ }
+
+ const auto len = swUpdate.getValue();
+ if (len != inLen) {
+ return {ErrorCodes::InternalError,
+ str::stream() << "Unexpected write length while appending AAD: " << len};
+ }
+
+ return Status::OK();
+ }
+
+ StatusWith<size_t> finalize(uint8_t* out, size_t outLen) final {
+ int len = 0;
+ if (1 != EVP_DecryptFinal_ex(_ctx.get(), out, &len)) {
+ return Status(ErrorCodes::UnknownError,
+ str::stream()
+ << SSLManagerInterface::getSSLErrorMessage(ERR_get_error()));
+ }
+ return static_cast<size_t>(len);
+ }
+
+ Status updateTag(const uint8_t* tag, size_t tagLen) final {
+ // validateEncryptionOption asserts that platforms without GCM will never start in GCM mode
+ if (_mode == aesMode::gcm) {
+#ifdef EVP_CTRL_GCM_GET_TAG
+ if (1 != EVP_CIPHER_CTX_ctrl(
+ _ctx.get(), EVP_CTRL_GCM_SET_TAG, tagLen, const_cast<uint8_t*>(tag))) {
+ return Status(ErrorCodes::UnknownError,
+ str::stream()
+ << "Unable to set GCM tag: "
+ << SSLManagerInterface::getSSLErrorMessage(ERR_get_error()));
+ }
+#else
+ return {ErrorCodes::UnsupportedFormat, "GCM support is not available"};
+#endif
+ } else if (tagLen != 0) {
+ return {ErrorCodes::BadValue, "Unexpected tag for non-gcm cipher"};
+ }
+
+ return Status::OK();
+ }
+
+private:
+ std::unique_ptr<EVP_CIPHER_CTX, decltype(&EVP_CIPHER_CTX_free)> _ctx;
+ const aesMode _mode;
+};
+
+} // namespace
+
+std::set<std::string> getSupportedSymmetricAlgorithms() {
+#if defined(EVP_CTRL_GCM_GET_TAG) && !defined(__APPLE__)
+ return {aes256CBCName, aes256GCMName};
+#else
+ return {aes256CBCName};
+#endif
+}
+
+Status engineRandBytes(uint8_t* buffer, size_t len) {
+ if (RAND_bytes(reinterpret_cast<unsigned char*>(buffer), len) == 1) {
+ return Status::OK();
+ }
+ return {ErrorCodes::UnknownError,
+ str::stream() << "Unable to acquire random bytes from OpenSSL: "
+ << SSLManagerInterface::getSSLErrorMessage(ERR_get_error())};
+}
+
+StatusWith<std::unique_ptr<SymmetricEncryptor>> SymmetricEncryptor::create(const SymmetricKey& key,
+ aesMode mode,
+ const uint8_t* iv,
+ size_t ivLen) try {
+ std::unique_ptr<SymmetricEncryptor> encryptor =
+ std::make_unique<SymmetricEncryptorOpenSSL>(key, mode, iv);
+ return std::move(encryptor);
+} catch (const DBException& e) {
+ return e.toStatus();
+}
+
+StatusWith<std::unique_ptr<SymmetricDecryptor>> SymmetricDecryptor::create(const SymmetricKey& key,
+ aesMode mode,
+ const uint8_t* iv,
+ size_t ivLen) try {
+ std::unique_ptr<SymmetricDecryptor> decryptor =
+ std::make_unique<SymmetricDecryptorOpenSSL>(key, mode, iv);
+ return std::move(decryptor);
+} catch (const DBException& e) {
+ return e.toStatus();
+}
+
+} // namespace crypto
+} // namespace mongo
diff --git a/src/mongo/crypto/symmetric_crypto_windows.cpp b/src/mongo/crypto/symmetric_crypto_windows.cpp
new file mode 100644
index 00000000000..25dd5f304b8
--- /dev/null
+++ b/src/mongo/crypto/symmetric_crypto_windows.cpp
@@ -0,0 +1,335 @@
+/**
+ * Copyright (C) 2019-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#define MONGO_LOG_DEFAULT_COMPONENT ::mongo::logger::LogComponent::kStorage
+
+#include "mongo/platform/basic.h"
+
+#include <memory>
+#include <vector>
+
+#include "mongo/base/secure_allocator.h"
+#include "mongo/base/status.h"
+#include "mongo/crypto/symmetric_crypto.h"
+#include "mongo/crypto/symmetric_key.h"
+#include "mongo/platform/shared_library.h"
+#include "mongo/util/assert_util.h"
+#include "mongo/util/log.h"
+#include "mongo/util/str.h"
+
+namespace mongo {
+namespace crypto {
+
+namespace {
+
+// RtlNtStatusToDosError function, only available via GetProcAddress
+using pRtlNtStatusToDosError = ULONG(WINAPI*)(NTSTATUS Status);
+
+std::string statusWithDescription(NTSTATUS status) {
+ auto swLib = SharedLibrary::create("ntdll.dll");
+ if (swLib.getStatus().isOK()) {
+
+ auto swFunc =
+ swLib.getValue()->getFunctionAs<pRtlNtStatusToDosError>("RtlNtStatusToDosError");
+ if (swFunc.isOK()) {
+
+ pRtlNtStatusToDosError RtlNtStatusToDosErrorFunc = swFunc.getValue();
+ ULONG errorCode = RtlNtStatusToDosErrorFunc(status);
+
+ if (errorCode != ERROR_MR_MID_NOT_FOUND) {
+ return errnoWithDescription(errorCode);
+ }
+ }
+ }
+
+ return str::stream() << "Failed to get error message for NTSTATUS: " << status;
+}
+
+struct AlgoInfo {
+ BCRYPT_ALG_HANDLE algo;
+ DWORD keyBlobSize;
+};
+
+/**
+ * Initialize crypto algorithms from default system CNG provider.
+ */
+class BCryptCryptoLoader {
+public:
+ BCryptCryptoLoader() {
+ loadAlgo(_algoAESCBC, BCRYPT_AES_ALGORITHM, BCRYPT_CHAIN_MODE_CBC);
+
+ auto status =
+ ::BCryptOpenAlgorithmProvider(&_random, BCRYPT_RNG_ALGORITHM, MS_PRIMITIVE_PROVIDER, 0);
+ invariant(status == STATUS_SUCCESS);
+ }
+
+ ~BCryptCryptoLoader() {
+ invariant(BCryptCloseAlgorithmProvider(_algoAESCBC.algo, 0) == STATUS_SUCCESS);
+ invariant(BCryptCloseAlgorithmProvider(_random, 0) == STATUS_SUCCESS);
+ }
+
+ AlgoInfo& getAlgo(aesMode mode) {
+ switch (mode) {
+ case aesMode::cbc:
+ return _algoAESCBC;
+ default:
+ MONGO_UNREACHABLE;
+ }
+ }
+
+ BCRYPT_ALG_HANDLE getRandom() {
+ return _random;
+ }
+
+private:
+ void loadAlgo(AlgoInfo& algo, const wchar_t* name, const wchar_t* chainingMode) {
+ NTSTATUS status = BCryptOpenAlgorithmProvider(&algo.algo, name, MS_PRIMITIVE_PROVIDER, 0);
+ invariant(status == STATUS_SUCCESS);
+
+ status = BCryptSetProperty(algo.algo,
+ BCRYPT_CHAINING_MODE,
+ reinterpret_cast<PUCHAR>(const_cast<wchar_t*>(chainingMode)),
+ sizeof(wchar_t) * wcslen(chainingMode),
+ 0);
+ invariant(status == STATUS_SUCCESS);
+
+ DWORD cbOutput = sizeof(algo.keyBlobSize);
+ status = BCryptGetProperty(algo.algo,
+ BCRYPT_OBJECT_LENGTH,
+ reinterpret_cast<PUCHAR>(&algo.keyBlobSize),
+ cbOutput,
+ &cbOutput,
+ 0);
+ invariant(status == STATUS_SUCCESS);
+ }
+
+private:
+ AlgoInfo _algoAESCBC;
+ BCRYPT_ALG_HANDLE _random;
+};
+
+static BCryptCryptoLoader& getBCryptCryptoLoader() {
+ static BCryptCryptoLoader loader;
+ return loader;
+}
+
+/**
+ * Base class to support initialize symmetric key buffers and state.
+ */
+template <typename Parent>
+class SymmetricImplWindows : public Parent {
+public:
+ SymmetricImplWindows(const SymmetricKey& key, aesMode mode, const uint8_t* iv, size_t ivLen)
+ : _keyHandle(INVALID_HANDLE_VALUE), _mode(mode) {
+ AlgoInfo& algo = getBCryptCryptoLoader().getAlgo(mode);
+
+
+ // Initialize key storage buffers
+ _keyObjectBuf->resize(algo.keyBlobSize);
+
+ SecureVector<unsigned char> keyBlob;
+ keyBlob->reserve(sizeof(BCRYPT_KEY_DATA_BLOB_HEADER) + key.getKeySize());
+
+ BCRYPT_KEY_DATA_BLOB_HEADER blobHeader;
+ blobHeader.dwMagic = BCRYPT_KEY_DATA_BLOB_MAGIC;
+ blobHeader.dwVersion = BCRYPT_KEY_DATA_BLOB_VERSION1;
+ blobHeader.cbKeyData = key.getKeySize();
+
+ std::copy(reinterpret_cast<uint8_t*>(&blobHeader),
+ reinterpret_cast<uint8_t*>(&blobHeader) + sizeof(BCRYPT_KEY_DATA_BLOB_HEADER),
+ std::back_inserter(*keyBlob));
+
+ std::copy(key.getKey(), key.getKey() + key.getKeySize(), std::back_inserter(*keyBlob));
+
+ NTSTATUS status = BCryptImportKey(algo.algo,
+ NULL,
+ BCRYPT_KEY_DATA_BLOB,
+ &_keyHandle,
+ _keyObjectBuf->data(),
+ _keyObjectBuf->size(),
+ keyBlob->data(),
+ keyBlob->size(),
+ 0);
+ uassert(ErrorCodes::OperationFailed,
+ str::stream() << "ImportKey failed: " << statusWithDescription(status),
+ status == STATUS_SUCCESS);
+
+ std::copy(iv, iv + ivLen, std::back_inserter(_iv));
+ }
+
+ ~SymmetricImplWindows() {
+ if (_keyHandle != INVALID_HANDLE_VALUE) {
+ BCryptDestroyKey(_keyHandle);
+ }
+ }
+
+ Status addAuthenticatedData(const uint8_t* in, size_t inLen) final {
+ fassert(51127, inLen == 0);
+ return Status::OK();
+ }
+
+protected:
+ const aesMode _mode;
+
+ // Buffers for key data
+ BCRYPT_KEY_HANDLE _keyHandle;
+
+ SecureVector<unsigned char> _keyObjectBuf;
+
+ // Buffer for CBC data
+ std::vector<unsigned char> _iv;
+};
+
+class SymmetricEncryptorWindows : public SymmetricImplWindows<SymmetricEncryptor> {
+public:
+ using SymmetricImplWindows::SymmetricImplWindows;
+
+ StatusWith<size_t> update(const uint8_t* in, size_t inLen, uint8_t* out, size_t outLen) final {
+ ULONG len = 0;
+
+ NTSTATUS status = BCryptEncrypt(_keyHandle,
+ const_cast<PUCHAR>(in),
+ inLen,
+ NULL,
+ _iv.data(),
+ _iv.size(),
+ out,
+ outLen,
+ &len,
+ BCRYPT_BLOCK_PADDING);
+
+ if (status != STATUS_SUCCESS) {
+ return Status{ErrorCodes::OperationFailed,
+ str::stream() << "Encrypt failed: " << statusWithDescription(status)};
+ }
+
+ return static_cast<size_t>(len);
+ }
+
+ StatusWith<size_t> finalize(uint8_t* out, size_t outLen) final {
+ // No finalize needed
+ return 0;
+ }
+
+ StatusWith<size_t> finalizeTag(uint8_t* out, size_t outLen) final {
+ // Not a tagged cipher mode, write nothing.
+ return 0;
+ }
+};
+
+class SymmetricDecryptorWindows : public SymmetricImplWindows<SymmetricDecryptor> {
+public:
+ using SymmetricImplWindows::SymmetricImplWindows;
+
+ StatusWith<size_t> update(const uint8_t* in, size_t inLen, uint8_t* out, size_t outLen) final {
+ ULONG len = 0;
+
+ NTSTATUS status = BCryptDecrypt(_keyHandle,
+ const_cast<PUCHAR>(in),
+ inLen,
+ NULL,
+ _iv.data(),
+ _iv.size(),
+ out,
+ outLen,
+ &len,
+ BCRYPT_BLOCK_PADDING);
+
+ if (status != STATUS_SUCCESS) {
+ return Status{ErrorCodes::OperationFailed,
+ str::stream() << "Decrypt failed: " << statusWithDescription(status)};
+ }
+
+ return static_cast<size_t>(len);
+ }
+
+ StatusWith<size_t> finalize(uint8_t* out, size_t outLen) final {
+ return 0;
+ }
+
+ Status updateTag(const uint8_t* tag, size_t tagLen) final {
+ return Status::OK();
+ }
+};
+
+} // namespace
+
+std::set<std::string> getSupportedSymmetricAlgorithms() {
+ return {aes256CBCName};
+}
+
+Status engineRandBytes(uint8_t* buffer, size_t len) {
+ NTSTATUS status = BCryptGenRandom(getBCryptCryptoLoader().getRandom(), buffer, len, 0);
+ if (status == STATUS_SUCCESS) {
+ return Status::OK();
+ }
+
+ return {ErrorCodes::UnknownError,
+ str::stream() << "Unable to acquire random bytes from BCrypt: "
+ << statusWithDescription(status)};
+}
+
+StatusWith<std::unique_ptr<SymmetricEncryptor>> SymmetricEncryptor::create(const SymmetricKey& key,
+ aesMode mode,
+ const uint8_t* iv,
+ size_t ivLen) {
+ if (mode != aesMode::cbc) {
+ return Status(ErrorCodes::UnsupportedFormat,
+ "Native crypto on this platform only supports AES256-CBC");
+ }
+
+ try {
+ std::unique_ptr<SymmetricEncryptor> encryptor =
+ std::make_unique<SymmetricEncryptorWindows>(key, mode, iv, ivLen);
+ return std::move(encryptor);
+ } catch (const DBException& e) {
+ return e.toStatus();
+ }
+}
+
+StatusWith<std::unique_ptr<SymmetricDecryptor>> SymmetricDecryptor::create(const SymmetricKey& key,
+ aesMode mode,
+ const uint8_t* iv,
+ size_t ivLen) {
+ if (mode != aesMode::cbc) {
+ return Status(ErrorCodes::UnsupportedFormat,
+ "Native crypto on this platform only supports AES256-CBC");
+ }
+
+ try {
+ std::unique_ptr<SymmetricDecryptor> decryptor =
+ std::make_unique<SymmetricDecryptorWindows>(key, mode, iv, ivLen);
+ return std::move(decryptor);
+ } catch (const DBException& e) {
+ return e.toStatus();
+ }
+}
+
+} // namespace crypto
+} // namespace mongo
diff --git a/src/mongo/crypto/symmetric_key.cpp b/src/mongo/crypto/symmetric_key.cpp
new file mode 100644
index 00000000000..a2bf2526bea
--- /dev/null
+++ b/src/mongo/crypto/symmetric_key.cpp
@@ -0,0 +1,100 @@
+/**
+ * Copyright (C) 2019-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#define MONGO_LOG_DEFAULT_COMPONENT ::mongo::logger::LogComponent::kStorage
+
+#include "mongo/platform/basic.h"
+
+#include "mongo/crypto/symmetric_key.h"
+
+#include <cstring>
+
+#include "mongo/crypto/symmetric_crypto.h"
+#include "mongo/util/log.h"
+#include "mongo/util/secure_zero_memory.h"
+#include "mongo/util/str.h"
+
+namespace mongo {
+
+std::string SymmetricKeyId::_initStrRep() const {
+ return str::stream() << _name << " (" << _id << ")";
+}
+
+const std::string& SymmetricKeyId::toString() const {
+ if (!_strRep.empty()) {
+ return _strRep;
+ } else {
+ return _name;
+ }
+}
+
+SymmetricKey::SymmetricKey(const uint8_t* key,
+ size_t keySize,
+ uint32_t algorithm,
+ SymmetricKeyId keyId,
+ uint32_t initializationCount)
+ : _algorithm(algorithm),
+ _keySize(keySize),
+ _key(key, key + keySize),
+ _keyId(std::move(keyId)),
+ _initializationCount(initializationCount),
+ _invocationCount(0) {
+ if (_keySize < crypto::minKeySize || _keySize > crypto::maxKeySize) {
+ error() << "Attempt to construct symmetric key of invalid size: " << _keySize;
+ return;
+ }
+}
+
+SymmetricKey::SymmetricKey(SecureVector<uint8_t> key, uint32_t algorithm, SymmetricKeyId keyId)
+ : _algorithm(algorithm),
+ _keySize(key->size()),
+ _key(std::move(key)),
+ _keyId(std::move(keyId)),
+ _initializationCount(1),
+ _invocationCount(0) {}
+
+SymmetricKey::SymmetricKey(SymmetricKey&& sk)
+ : _algorithm(sk._algorithm),
+ _keySize(sk._keySize),
+ _key(std::move(sk._key)),
+ _keyId(std::move(sk._keyId)),
+ _initializationCount(sk._initializationCount),
+ _invocationCount(sk._invocationCount.load()) {}
+
+SymmetricKey& SymmetricKey::operator=(SymmetricKey&& sk) {
+ _algorithm = sk._algorithm;
+ _keySize = sk._keySize;
+ _key = std::move(sk._key);
+ _keyId = std::move(sk._keyId);
+ _initializationCount = sk._initializationCount;
+ _invocationCount.store(sk._invocationCount.load());
+
+ return *this;
+}
+} // namespace mongo
diff --git a/src/mongo/crypto/symmetric_key.h b/src/mongo/crypto/symmetric_key.h
new file mode 100644
index 00000000000..b09a35778b2
--- /dev/null
+++ b/src/mongo/crypto/symmetric_key.h
@@ -0,0 +1,146 @@
+/**
+ * Copyright (C) 2019-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include <cstdint>
+#include <memory>
+
+#include "mongo/base/secure_allocator.h"
+#include "mongo/platform/atomic_word.h"
+
+namespace mongo {
+class Status;
+
+class SymmetricKeyId {
+public:
+ using id_type = std::uint64_t;
+
+ template <typename StringLike>
+ SymmetricKeyId(const StringLike& name, id_type id)
+ : _id(id), _name(name), _strRep(_initStrRep()) {}
+
+ template <typename StringLike>
+ SymmetricKeyId(const StringLike& name) : _name(name) {}
+
+ const std::string& toString() const;
+
+ bool operator==(const SymmetricKeyId& other) const {
+ return _id == other._id && _name == other._name;
+ }
+
+ bool operator!=(const SymmetricKeyId& other) const {
+ return !(*this == other);
+ }
+
+ const boost::optional<id_type>& id() const {
+ return _id;
+ }
+
+ const std::string& name() const {
+ return _name;
+ }
+
+private:
+ std::string _initStrRep() const;
+
+ boost::optional<id_type> _id;
+ std::string _name;
+ std::string _strRep;
+};
+
+/**
+ * Class representing a symmetric key
+ */
+class SymmetricKey {
+ SymmetricKey(const SymmetricKey&) = delete;
+ SymmetricKey& operator=(const SymmetricKey&) = delete;
+
+public:
+ SymmetricKey(const uint8_t* key,
+ size_t keySize,
+ uint32_t algorithm,
+ SymmetricKeyId keyId,
+ uint32_t initializationCount);
+ SymmetricKey(SecureVector<uint8_t> key, uint32_t algorithm, SymmetricKeyId keyId);
+
+ SymmetricKey(SymmetricKey&&);
+ SymmetricKey& operator=(SymmetricKey&&);
+
+ ~SymmetricKey() = default;
+
+ int getAlgorithm() const {
+ return _algorithm;
+ }
+
+ size_t getKeySize() const {
+ return _keySize;
+ }
+
+ // Return the number of times the key has been retrieved from the key store
+ uint32_t getInitializationCount() const {
+ return _initializationCount;
+ }
+
+ uint32_t incrementAndGetInitializationCount() {
+ _initializationCount++;
+ return _initializationCount;
+ }
+
+ uint64_t getAndIncrementInvocationCount() const {
+ return _invocationCount.fetchAndAdd(1);
+ }
+
+ const uint8_t* getKey() const {
+ return _key->data();
+ }
+
+ const SymmetricKeyId& getKeyId() const {
+ return _keyId;
+ }
+
+ void setKeyId(SymmetricKeyId keyId) {
+ _keyId = std::move(keyId);
+ }
+
+private:
+ int _algorithm;
+
+ size_t _keySize;
+
+ SecureVector<uint8_t> _key;
+
+ SymmetricKeyId _keyId;
+
+ uint32_t _initializationCount;
+ mutable AtomicWord<unsigned long long> _invocationCount;
+};
+
+using UniqueSymmetricKey = std::unique_ptr<SymmetricKey>;
+} // namespace mongo
diff --git a/src/mongo/db/storage/storage_engine_lock_file_posix.cpp b/src/mongo/db/storage/storage_engine_lock_file_posix.cpp
index 5399a0b2f7e..b39b0503547 100644
--- a/src/mongo/db/storage/storage_engine_lock_file_posix.cpp
+++ b/src/mongo/db/storage/storage_engine_lock_file_posix.cpp
@@ -206,15 +206,13 @@ Status StorageEngineLockFile::writeString(StringData str) {
if (bytesWritten < 0) {
int errorcode = errno;
return Status(ErrorCodes::FileStreamFailed,
- str::stream() << "Unable to write string " << str << " to file: "
- << _filespec
+ str::stream() << "Unable to write string " << str << " to file: " << _filespec
<< ' '
<< errnoWithDescription(errorcode));
} else if (bytesWritten == 0) {
return Status(ErrorCodes::FileStreamFailed,
- str::stream() << "Unable to write string " << str << " to file: "
- << _filespec
+ str::stream() << "Unable to write string " << str << " to file: " << _filespec
<< " no data written.");
}
diff --git a/src/mongo/db/storage/storage_engine_lock_file_windows.cpp b/src/mongo/db/storage/storage_engine_lock_file_windows.cpp
index 72abd6a68bf..2be6f11bb03 100644
--- a/src/mongo/db/storage/storage_engine_lock_file_windows.cpp
+++ b/src/mongo/db/storage/storage_engine_lock_file_windows.cpp
@@ -170,14 +170,12 @@ Status StorageEngineLockFile::writeString(StringData str) {
NULL) == FALSE) {
int errorcode = GetLastError();
return Status(ErrorCodes::FileStreamFailed,
- str::stream() << "Unable to write string " << str << " to file: "
- << _filespec
+ str::stream() << "Unable to write string " << str << " to file: " << _filespec
<< ' '
<< errnoWithDescription(errorcode));
} else if (bytesWritten == 0) {
return Status(ErrorCodes::FileStreamFailed,
- str::stream() << "Unable to write string " << str << " to file: "
- << _filespec
+ str::stream() << "Unable to write string " << str << " to file: " << _filespec
<< " no data written.");
}
diff --git a/src/mongo/shell/SConscript b/src/mongo/shell/SConscript
index b8e6855a860..e9bf5d5abdf 100644
--- a/src/mongo/shell/SConscript
+++ b/src/mongo/shell/SConscript
@@ -1,6 +1,9 @@
# -*- mode: python; -*-
-Import("env")
+Import([
+ 'env',
+ 'get_option'
+])
env = env.Clone()
@@ -61,7 +64,8 @@ env.JSHeader(
"shardingtest.js",
"servers_misc.js",
"replsettest.js",
- "bridge.js"
+ "bridge.js",
+ "keyvault.js",
],
)
@@ -147,3 +151,69 @@ env.CppUnitTest(
'$BUILD_DIR/mongo/util/signal_handlers',
]
)
+
+kmsEnv = env.Clone()
+
+kmsEnv.InjectThirdParty(libraries=['kms-message'])
+
+kmsEnv.Library(
+ target="kms",
+ source=[
+ "kms.cpp",
+ "kms_aws.cpp",
+ "kms_local.cpp",
+ kmsEnv.Idlc("kms.idl")[0],
+ ],
+ LIBDEPS=[
+ '$BUILD_DIR/mongo/base/secure_allocator',
+ ],
+ LIBDEPS_PRIVATE=[
+ '$BUILD_DIR/mongo/base',
+ '$BUILD_DIR/mongo/crypto/aead_encryption',
+ '$BUILD_DIR/mongo/db/commands/test_commands_enabled',
+ '$BUILD_DIR/mongo/util/net/network',
+ '$BUILD_DIR/mongo/util/net/socket',
+ '$BUILD_DIR/mongo/util/net/ssl_manager',
+ '$BUILD_DIR/mongo/util/net/ssl_options',
+ '$BUILD_DIR/third_party/shim_kms_message',
+ ],
+)
+
+env.CppUnitTest(
+ target='kms_test',
+ source='kms_test.cpp',
+ LIBDEPS=[
+ 'kms',
+ ]
+)
+
+env.Library(
+ target="kms_shell",
+ source=[
+ "kms_shell.cpp",
+ ],
+ LIBDEPS_PRIVATE=[
+ '$BUILD_DIR/mongo/shell_core',
+ 'kms',
+ ],
+)
+
+scriptingEnv = env.Clone()
+scriptingEnv.InjectMozJS()
+
+scriptingEnv.Library(
+ target="encrypted_dbclient",
+ source=[
+ "encrypted_dbclient_base.cpp",
+ scriptingEnv.Idlc("fle_shell_options.idl")[0],
+ ],
+ LIBDEPS_PRIVATE=[
+ '$BUILD_DIR/mongo/crypto/aead_encryption',
+ '$BUILD_DIR/mongo/crypto/symmetric_crypto',
+ '$BUILD_DIR/mongo/client/clientdriver_minimal',
+ '$BUILD_DIR/mongo/scripting/scripting',
+ '$BUILD_DIR/mongo/shell/shell_options_register',
+ '$BUILD_DIR/third_party/shim_mozjs',
+ 'kms',
+ ],
+)
diff --git a/src/mongo/shell/encrypted_dbclient_base.cpp b/src/mongo/shell/encrypted_dbclient_base.cpp
new file mode 100644
index 00000000000..34ae4a11b4a
--- /dev/null
+++ b/src/mongo/shell/encrypted_dbclient_base.cpp
@@ -0,0 +1,685 @@
+/**
+ * Copyright (C) 2019-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/platform/basic.h"
+
+#include "mongo/shell/encrypted_dbclient_base.h"
+
+#include "mongo/base/data_cursor.h"
+#include "mongo/base/data_type_validated.h"
+#include "mongo/bson/bson_depth.h"
+#include "mongo/client/dbclient_base.h"
+#include "mongo/crypto/aead_encryption.h"
+#include "mongo/crypto/symmetric_crypto.h"
+#include "mongo/db/client.h"
+#include "mongo/db/commands.h"
+#include "mongo/db/matcher/schema/encrypt_schema_gen.h"
+#include "mongo/db/namespace_string.h"
+#include "mongo/rpc/object_check.h"
+#include "mongo/rpc/op_msg_rpc_impls.h"
+#include "mongo/scripting/mozjs/bindata.h"
+#include "mongo/scripting/mozjs/implscope.h"
+#include "mongo/scripting/mozjs/maxkey.h"
+#include "mongo/scripting/mozjs/minkey.h"
+#include "mongo/scripting/mozjs/mongo.h"
+#include "mongo/scripting/mozjs/objectwrapper.h"
+#include "mongo/scripting/mozjs/valuereader.h"
+#include "mongo/scripting/mozjs/valuewriter.h"
+#include "mongo/shell/encrypted_shell_options.h"
+#include "mongo/shell/kms.h"
+#include "mongo/shell/kms_gen.h"
+#include "mongo/shell/shell_options.h"
+#include "mongo/util/lru_cache.h"
+
+namespace mongo {
+
+EncryptedShellGlobalParams encryptedShellGlobalParams;
+
+namespace {
+constexpr Duration kCacheInvalidationTime = Minutes(1);
+
+
+ImplicitEncryptedDBClientCallback* implicitEncryptedDBClientCallback{nullptr};
+
+
+} // namespace
+
+void setImplicitEncryptedDBClientCallback(ImplicitEncryptedDBClientCallback* callback) {
+ implicitEncryptedDBClientCallback = callback;
+}
+
+static void validateCollection(JSContext* cx, JS::HandleValue value) {
+ uassert(ErrorCodes::BadValue,
+ "Collection object must be provided to ClientSideFLEOptions",
+ !(value.isNull() || value.isUndefined()));
+
+ JS::RootedValue coll(cx, value);
+
+ uassert(31043,
+ "The collection object in ClientSideFLEOptions is invalid",
+ mozjs::getScope(cx)->getProto<mozjs::DBCollectionInfo>().instanceOf(coll));
+}
+
+EncryptedDBClientBase::EncryptedDBClientBase(std::unique_ptr<DBClientBase> conn,
+ ClientSideFLEOptions encryptionOptions,
+ JS::HandleValue collection,
+ JSContext* cx)
+ : _conn(std::move(conn)), _encryptionOptions(std::move(encryptionOptions)), _cx(cx) {
+ validateCollection(cx, collection);
+ _collection = JS::Heap<JS::Value>(collection);
+ uassert(31078,
+ "Cannot use WriteMode Legacy with Field Level Encryption",
+ shellGlobalParams.writeMode != "legacy");
+};
+
+std::string EncryptedDBClientBase::getServerAddress() const {
+ return _conn->getServerAddress();
+}
+
+bool EncryptedDBClientBase::call(Message& toSend,
+ Message& response,
+ bool assertOk,
+ std::string* actualServer) {
+ return _conn->call(toSend, response, assertOk, actualServer);
+}
+
+void EncryptedDBClientBase::say(Message& toSend, bool isRetry, std::string* actualServer) {
+ MONGO_UNREACHABLE;
+}
+
+bool EncryptedDBClientBase::lazySupported() const {
+ return _conn->lazySupported();
+}
+
+std::pair<rpc::UniqueReply, DBClientBase*> EncryptedDBClientBase::runCommandWithTarget(
+ OpMsgRequest request) {
+ return _conn->runCommandWithTarget(std::move(request));
+}
+
+
+/**
+ *
+ * This function reads the data from the CDR and returns a copy
+ * constructed and owned BSONObject.
+ *
+ */
+BSONObj EncryptedDBClientBase::validateBSONElement(ConstDataRange out, uint8_t bsonType) {
+ if (bsonType == BSONType::Object) {
+ ConstDataRangeCursor cdc = ConstDataRangeCursor(out);
+ BSONObj valueObj;
+
+ valueObj = cdc.readAndAdvance<Validated<BSONObj>>();
+ return valueObj.getOwned();
+ } else {
+ auto valueString = "value"_sd;
+
+ // The size here is to construct a new BSON document and validate the
+ // total size of the object. The first four bytes is for the size of an
+ // int32_t, then a space for the type of the first element, then the space
+ // for the value string and the the 0x00 terminated field name, then the
+ // size of the actual data, then the last byte for the end document character,
+ // also 0x00.
+ size_t docLength = sizeof(int32_t) + 1 + valueString.size() + 1 + out.length() + 1;
+ BufBuilder builder;
+ builder.reserveBytes(docLength);
+
+ uassert(ErrorCodes::BadValue,
+ "invalid decryption value",
+ docLength < std::numeric_limits<int32_t>::max());
+
+ builder.appendNum(static_cast<uint32_t>(docLength));
+ builder.appendChar(static_cast<uint8_t>(bsonType));
+ builder.appendStr(valueString, true);
+ builder.appendBuf(out.data(), out.length());
+ builder.appendChar('\0');
+
+ ConstDataRangeCursor cdc =
+ ConstDataRangeCursor(ConstDataRange(builder.buf(), builder.len()));
+ BSONObj elemWrapped = cdc.readAndAdvance<Validated<BSONObj>>();
+ return elemWrapped.getOwned();
+ }
+}
+
+std::string EncryptedDBClientBase::toString() const {
+ return _conn->toString();
+}
+
+int EncryptedDBClientBase::getMinWireVersion() {
+ return _conn->getMinWireVersion();
+}
+
+int EncryptedDBClientBase::getMaxWireVersion() {
+ return _conn->getMaxWireVersion();
+}
+
+void EncryptedDBClientBase::generateDataKey(JSContext* cx, JS::CallArgs args) {
+ if (args.length() != 2) {
+ uasserted(ErrorCodes::BadValue, "generateDataKey requires 2 arg");
+ }
+
+ if (!args.get(0).isString()) {
+ uasserted(ErrorCodes::BadValue, "1st param to generateDataKey has to be a string");
+ }
+
+ if (!args.get(1).isString()) {
+ uasserted(ErrorCodes::BadValue, "2nd param to generateDataKey has to be a string");
+ }
+
+ std::string kmsProvider = mozjs::ValueWriter(cx, args.get(0)).toString();
+ std::string clientMasterKey = mozjs::ValueWriter(cx, args.get(1)).toString();
+
+ std::unique_ptr<KMSService> kmsService = KMSServiceController::createFromClient(
+ kmsProvider, _encryptionOptions.getKmsProviders().toBSON());
+
+ SecureVector<uint8_t> dataKey(crypto::kFieldLevelEncryptionKeySize);
+ auto res = crypto::engineRandBytes(dataKey->data(), dataKey->size());
+ uassert(31042, "Error generating data key: " + res.codeString(), res.isOK());
+
+ BSONObj obj = kmsService->encryptDataKey(ConstDataRange(dataKey->data(), dataKey->size()),
+ clientMasterKey);
+
+ mozjs::ValueReader(cx, args.rval()).fromBSON(obj, nullptr, false);
+}
+
+void EncryptedDBClientBase::getDataKeyCollection(JSContext* cx, JS::CallArgs args) {
+ if (args.length() != 0) {
+ uasserted(ErrorCodes::BadValue, "getDataKeyCollection does not take any params");
+ }
+ args.rval().set(_collection.get());
+}
+
+void EncryptedDBClientBase::encrypt(mozjs::MozJSImplScope* scope,
+ JSContext* cx,
+ JS::CallArgs args) {
+ // Input Validation
+ uassert(ErrorCodes::BadValue, "encrypt requires 3 args", args.length() == 3);
+
+ if (!(args.get(1).isObject() || args.get(1).isString() || args.get(1).isNumber() ||
+ args.get(1).isBoolean())) {
+ uasserted(ErrorCodes::BadValue,
+ "Second parameter must be an object, string, number, or bool");
+ }
+
+ uassert(ErrorCodes::BadValue, "Third parameter must be a string", args.get(2).isString());
+ auto algorithmStr = mozjs::ValueWriter(cx, args.get(2)).toString();
+ FleAlgorithmInt algorithm;
+
+ if (StringData(algorithmStr) == FleAlgorithm_serializer(FleAlgorithmEnum::kRandom)) {
+ algorithm = FleAlgorithmInt::kRandom;
+ } else if (StringData(algorithmStr) ==
+ FleAlgorithm_serializer(FleAlgorithmEnum::kDeterministic)) {
+ algorithm = FleAlgorithmInt::kDeterministic;
+ } else {
+ uasserted(ErrorCodes::BadValue, "Third parameter must be the FLE Algorithm type");
+ }
+
+ // Extract the UUID from the callArgs
+ auto binData = getBinDataArg(scope, cx, args, 0, BinDataType::newUUID);
+ UUID uuid = UUID::fromCDR(ConstDataRange(binData.data(), binData.size()));
+ BSONType bsonType = BSONType::EOO;
+
+ BufBuilder plaintext;
+ if (args.get(1).isObject()) {
+ JS::RootedObject rootedObj(cx, &args.get(1).toObject());
+ auto jsclass = JS_GetClass(rootedObj);
+
+ if (strcmp(jsclass->name, "Object") == 0 || strcmp(jsclass->name, "Array") == 0) {
+ uassert(ErrorCodes::BadValue,
+ "Cannot deterministically encrypt object or array types.",
+ algorithm != FleAlgorithmInt::kDeterministic);
+
+ // If it is a JS Object, then we can extract all the information by simply calling
+ // ValueWriter.toBSON and setting the type bit, which is what is happening below.
+ BSONObj valueObj = mozjs::ValueWriter(cx, args.get(1)).toBSON();
+ plaintext.appendBuf(valueObj.objdata(), valueObj.objsize());
+ if (strcmp(jsclass->name, "Array") == 0) {
+ bsonType = BSONType::Array;
+ } else {
+ bsonType = BSONType::Object;
+ }
+
+ } else if (scope->getProto<mozjs::MinKeyInfo>().getJSClass() == jsclass ||
+ scope->getProto<mozjs::MaxKeyInfo>().getJSClass() == jsclass ||
+ scope->getProto<mozjs::DBRefInfo>().getJSClass() == jsclass) {
+ uasserted(ErrorCodes::BadValue, "Second parameter cannot be MinKey, MaxKey, or DBRef");
+ } else {
+ if (scope->getProto<mozjs::NumberDecimalInfo>().getJSClass() == jsclass) {
+ uassert(ErrorCodes::BadValue,
+ "Cannot deterministically encrypt NumberDecimal type objects.",
+ algorithm != FleAlgorithmInt::kDeterministic);
+ }
+
+ if (scope->getProto<mozjs::CodeInfo>().getJSClass() == jsclass) {
+ uassert(ErrorCodes::BadValue,
+ "Cannot deterministically encrypt Code type objects.",
+ algorithm != FleAlgorithmInt::kDeterministic);
+ }
+
+ // If it is one of our Mongo defined types, then we have to use the ValueWriter
+ // writeThis function, which takes in a set of WriteFieldRecursionFrames (setting
+ // a limit on how many times we can recursively dig into an object's nested
+ // structure)
+ // and writes the value out to a BSONObjBuilder. We can then extract that
+ // information
+ // from the object by building it and pulling out the first element, which is the
+ // object we are trying to get.
+ mozjs::ObjectWrapper::WriteFieldRecursionFrames frames;
+ frames.emplace(cx, rootedObj.get(), nullptr, StringData{});
+ BSONObjBuilder builder;
+ mozjs::ValueWriter(cx, args.get(1)).writeThis(&builder, "value"_sd, &frames);
+
+ BSONObj object = builder.obj();
+ auto elem = object.getField("value"_sd);
+
+ plaintext.appendBuf(elem.value(), elem.valuesize());
+ bsonType = elem.type();
+ }
+
+ } else if (args.get(1).isString()) {
+ std::string valueStr = mozjs::ValueWriter(cx, args.get(1)).toString();
+ if (valueStr.size() + 1 > std::numeric_limits<uint32_t>::max()) {
+ uasserted(ErrorCodes::BadValue, "Plaintext string to encrypt too long.");
+ }
+
+ plaintext.appendNum(static_cast<uint32_t>(valueStr.size() + 1));
+ plaintext.appendStr(valueStr, true);
+ bsonType = BSONType::String;
+
+ } else if (args.get(1).isNumber()) {
+ uassert(ErrorCodes::BadValue,
+ "Cannot deterministically encrypt Floating Point numbers.",
+ algorithm != FleAlgorithmInt::kDeterministic);
+
+ double valueNum = mozjs::ValueWriter(cx, args.get(1)).toNumber();
+ plaintext.appendNum(valueNum);
+ bsonType = BSONType::NumberDouble;
+ } else if (args.get(1).isBoolean()) {
+ uassert(ErrorCodes::BadValue,
+ "Cannot deterministically encrypt booleans.",
+ algorithm != FleAlgorithmInt::kDeterministic);
+
+ bool boolean = mozjs::ValueWriter(cx, args.get(1)).toBoolean();
+ if (boolean) {
+ plaintext.appendChar(0x01);
+ } else {
+ plaintext.appendChar(0x00);
+ }
+ bsonType = BSONType::Bool;
+ } else {
+ uasserted(ErrorCodes::BadValue, "Cannot encrypt valuetype provided.");
+ }
+ ConstDataRange plaintextRange(plaintext.buf(), plaintext.len());
+
+ auto key = getDataKey(uuid);
+ std::vector<uint8_t> fleBlob =
+ encryptWithKey(uuid, key, plaintextRange, bsonType, FleAlgorithmInt_serializer(algorithm));
+
+ // Prepare the return value
+ std::string blobStr = base64::encode(reinterpret_cast<char*>(fleBlob.data()), fleBlob.size());
+ JS::AutoValueArray<2> arr(cx);
+
+ arr[0].setInt32(BinDataType::Encrypt);
+ mozjs::ValueReader(cx, arr[1]).fromStringData(blobStr);
+ scope->getProto<mozjs::BinDataInfo>().newInstance(arr, args.rval());
+}
+
+void EncryptedDBClientBase::decrypt(mozjs::MozJSImplScope* scope,
+ JSContext* cx,
+ JS::CallArgs args) {
+ uassert(ErrorCodes::BadValue, "decrypt requires one argument", args.length() == 1);
+ uassert(ErrorCodes::BadValue,
+ "decrypt argument must be a BinData subtype Encrypt object",
+ args.get(0).isObject());
+
+ if (!scope->getProto<mozjs::BinDataInfo>().instanceOf(args.get(0))) {
+ uasserted(ErrorCodes::BadValue,
+ "decrypt argument must be a BinData subtype Encrypt object");
+ }
+
+ JS::RootedObject obj(cx, &args.get(0).get().toObject());
+ std::vector<uint8_t> binData = getBinDataArg(scope, cx, args, 0, BinDataType::Encrypt);
+
+ uassert(
+ ErrorCodes::BadValue, "Ciphertext blob too small", binData.size() > kAssociatedDataLength);
+ uassert(ErrorCodes::BadValue,
+ "Ciphertext blob algorithm unknown",
+ (FleAlgorithmInt(binData[0]) == FleAlgorithmInt::kDeterministic ||
+ FleAlgorithmInt(binData[0]) == FleAlgorithmInt::kRandom));
+
+ ConstDataRange uuidCdr = ConstDataRange(&binData[1], UUID::kNumBytes);
+ UUID uuid = UUID::fromCDR(uuidCdr);
+
+ auto key = getDataKey(uuid);
+ std::vector<uint8_t> out(binData.size() - kAssociatedDataLength);
+ size_t outLen = out.size();
+
+ auto decryptStatus = crypto::aeadDecrypt(*key,
+ &binData[kAssociatedDataLength],
+ binData.size() - kAssociatedDataLength,
+ &binData[0],
+ kAssociatedDataLength,
+ out.data(),
+ &outLen);
+ if (!decryptStatus.isOK()) {
+ uasserted(decryptStatus.code(), decryptStatus.reason());
+ }
+
+ uint8_t bsonType = binData[17];
+ BSONObj parent;
+ BSONObj decryptedObj = validateBSONElement(ConstDataRange(out.data(), outLen), bsonType);
+ if (bsonType == BSONType::Object) {
+ mozjs::ValueReader(cx, args.rval()).fromBSON(decryptedObj, &parent, true);
+ } else {
+ mozjs::ValueReader(cx, args.rval())
+ .fromBSONElement(decryptedObj.firstElement(), parent, true);
+ }
+}
+
+void EncryptedDBClientBase::trace(JSTracer* trc) {
+ JS::TraceEdge(trc, &_collection, "collection object");
+}
+
+JS::Value EncryptedDBClientBase::getCollection() const {
+ return _collection.get();
+}
+
+
+std::unique_ptr<DBClientCursor> EncryptedDBClientBase::query(const NamespaceStringOrUUID& nsOrUuid,
+ Query query,
+ int nToReturn,
+ int nToSkip,
+ const BSONObj* fieldsToReturn,
+ int queryOptions,
+ int batchSize) {
+ return _conn->query(
+ nsOrUuid, query, nToReturn, nToSkip, fieldsToReturn, queryOptions, batchSize);
+}
+
+bool EncryptedDBClientBase::isFailed() const {
+ return _conn->isFailed();
+}
+
+bool EncryptedDBClientBase::isStillConnected() {
+ return _conn->isStillConnected();
+}
+
+ConnectionString::ConnectionType EncryptedDBClientBase::type() const {
+ return _conn->type();
+}
+
+double EncryptedDBClientBase::getSoTimeout() const {
+ return _conn->getSoTimeout();
+}
+
+bool EncryptedDBClientBase::isReplicaSetMember() const {
+ return _conn->isReplicaSetMember();
+}
+
+bool EncryptedDBClientBase::isMongos() const {
+ return _conn->isMongos();
+}
+
+NamespaceString EncryptedDBClientBase::getCollectionNS() {
+ JS::RootedValue fullNameRooted(_cx);
+ JS::RootedObject collectionRooted(_cx, &_collection.get().toObject());
+ JS_GetProperty(_cx, collectionRooted, "_fullName", &fullNameRooted);
+ if (!fullNameRooted.isString()) {
+ uasserted(ErrorCodes::BadValue, "Collection object is incomplete.");
+ }
+ std::string fullName = mozjs::ValueWriter(_cx, fullNameRooted).toString();
+ NamespaceString fullNameNS = NamespaceString(fullName);
+ uassert(ErrorCodes::BadValue,
+ str::stream() << "Invalid namespace: " << fullName,
+ fullNameNS.isValid());
+ return fullNameNS;
+}
+
+std::vector<uint8_t> EncryptedDBClientBase::getBinDataArg(
+ mozjs::MozJSImplScope* scope, JSContext* cx, JS::CallArgs args, int index, BinDataType type) {
+ if (!args.get(index).isObject() ||
+ !scope->getProto<mozjs::BinDataInfo>().instanceOf(args.get(index))) {
+ uasserted(ErrorCodes::BadValue, "First parameter must be a BinData object");
+ }
+
+ mozjs::ObjectWrapper o(cx, args.get(index));
+
+ auto binType = BinDataType(static_cast<int>(o.getNumber(mozjs::InternedString::type)));
+ uassert(ErrorCodes::BadValue,
+ str::stream() << "Incorrect bindata type, expected" << typeName(type) << " but got "
+ << typeName(binType),
+ binType == type);
+ auto str = static_cast<std::string*>(JS_GetPrivate(args.get(index).toObjectOrNull()));
+ uassert(ErrorCodes::BadValue, "Cannot call getter on BinData prototype", str);
+ std::string string = base64::decode(*str);
+ return std::vector<uint8_t>(string.data(), string.data() + string.length());
+}
+
+std::shared_ptr<SymmetricKey> EncryptedDBClientBase::getDataKey(const UUID& uuid) {
+ auto ts_new = Date_t::now();
+
+ if (_datakeyCache.hasKey(uuid)) {
+ auto[key, ts] = _datakeyCache.find(uuid)->second;
+ if (ts_new - ts < kCacheInvalidationTime) {
+ return key;
+ } else {
+ _datakeyCache.erase(uuid);
+ }
+ }
+ auto key = getDataKeyFromDisk(uuid);
+ _datakeyCache.add(uuid, std::make_pair(key, ts_new));
+ return key;
+}
+
+std::shared_ptr<SymmetricKey> EncryptedDBClientBase::getDataKeyFromDisk(const UUID& uuid) {
+ NamespaceString fullNameNS = getCollectionNS();
+ BSONObj dataKeyObj = _conn->findOne(fullNameNS.ns(), QUERY("_id" << uuid));
+ if (dataKeyObj.isEmpty()) {
+ uasserted(ErrorCodes::BadValue, "Invalid keyID.");
+ }
+
+ auto keyStoreRecord = KeyStoreRecord::parse(IDLParserErrorContext("root"), dataKeyObj);
+ if (dataKeyObj.hasField("version"_sd)) {
+ uassert(ErrorCodes::BadValue,
+ "Invalid version, must be either 0 or undefined",
+ dataKeyObj.getIntField("version"_sd) == 0);
+ }
+
+ BSONElement elem = dataKeyObj.getField("keyMaterial"_sd);
+ uassert(ErrorCodes::BadValue, "Invalid key.", elem.isBinData(BinDataType::BinDataGeneral));
+ uassert(ErrorCodes::BadValue,
+ "Invalid version, must be either 0 or undefined",
+ keyStoreRecord.getVersion() == 0);
+
+ auto dataKey = keyStoreRecord.getKeyMaterial();
+ uassert(ErrorCodes::BadValue, "Invalid data key.", dataKey.length() != 0);
+
+ std::unique_ptr<KMSService> kmsService = KMSServiceController::createFromDisk(
+ _encryptionOptions.getKmsProviders().toBSON(), keyStoreRecord.getMasterKey());
+ SecureVector<uint8_t> decryptedKey =
+ kmsService->decrypt(dataKey, keyStoreRecord.getMasterKey());
+ return std::make_shared<SymmetricKey>(
+ std::move(decryptedKey), crypto::aesAlgorithm, "kms_encryption");
+}
+
+std::vector<uint8_t> EncryptedDBClientBase::encryptWithKey(UUID uuid,
+ const std::shared_ptr<SymmetricKey>& key,
+ ConstDataRange plaintext,
+ BSONType bsonType,
+ int32_t algorithm) {
+ // As per the description of the encryption algorithm for FLE, the
+ // associated data is constructed of the following -
+ // associatedData[0] = the FleAlgorithmEnum
+ // - either a 1 or a 2 depending on whether the iv is provided.
+ // associatedData[1-16] = the uuid in bytes
+ // associatedData[17] = the bson type
+
+ ConstDataRange uuidCdr = uuid.toCDR();
+ uint64_t outputLength = crypto::aeadCipherOutputLength(plaintext.length());
+ std::vector<uint8_t> outputBuffer(kAssociatedDataLength + outputLength);
+ outputBuffer[0] = static_cast<uint8_t>(algorithm);
+ std::memcpy(&outputBuffer[1], uuidCdr.data(), uuidCdr.length());
+ outputBuffer[17] = static_cast<uint8_t>(bsonType);
+ uassertStatusOK(crypto::aeadEncrypt(*key,
+ reinterpret_cast<const uint8_t*>(plaintext.data()),
+ plaintext.length(),
+ outputBuffer.data(),
+ 18,
+ // The ciphertext starts 18 bytes into the output
+ // buffer, as described above.
+ outputBuffer.data() + 18,
+ outputLength));
+ return outputBuffer;
+}
+
+namespace {
+
+/**
+ * Constructs a collection object from a namespace, passed in to the nsString parameter.
+ * The client is the connection to a database in which you want to create the collection.
+ * The collection parameter gets set to a javascript collection object.
+ */
+void createCollectionObject(JSContext* cx,
+ JS::HandleValue client,
+ StringData nsString,
+ JS::MutableHandleValue collection) {
+ invariant(!client.isNull() && !client.isUndefined());
+
+ auto ns = NamespaceString(nsString);
+ uassert(ErrorCodes::BadValue,
+ "Invalid keystore namespace.",
+ ns.isValid() && NamespaceString::validCollectionName(ns.coll()));
+
+ auto scope = mozjs::getScope(cx);
+
+ // The collection object requires a database object to be constructed as well.
+ JS::RootedValue databaseRV(cx);
+ JS::AutoValueArray<2> databaseArgs(cx);
+
+ databaseArgs[0].setObject(client.toObject());
+ mozjs::ValueReader(cx, databaseArgs[1]).fromStringData(ns.db());
+ scope->getProto<mozjs::DBInfo>().newInstance(databaseArgs, &databaseRV);
+
+ invariant(databaseRV.isObject());
+ auto databaseObj = databaseRV.toObjectOrNull();
+
+ JS::AutoValueArray<4> collectionArgs(cx);
+ collectionArgs[0].setObject(client.toObject());
+ collectionArgs[1].setObject(*databaseObj);
+ mozjs::ValueReader(cx, collectionArgs[2]).fromStringData(ns.coll());
+ mozjs::ValueReader(cx, collectionArgs[3]).fromStringData(ns.ns());
+
+ scope->getProto<mozjs::DBCollectionInfo>().newInstance(collectionArgs, collection);
+}
+
+// The parameters required to start FLE on the shell. The current connection is passed in as a
+// parameter to create the keyvault collection object if one is not provided.
+std::unique_ptr<DBClientBase> createEncryptedDBClientBase(std::unique_ptr<DBClientBase> conn,
+ JS::HandleValue arg,
+ JS::HandleObject mongoConnection,
+ JSContext* cx) {
+
+ uassert(
+ 31038, "Invalid Client Side Encryption parameters.", arg.isObject() || arg.isUndefined());
+
+ static constexpr auto keyVaultClientFieldId = "keyVaultClient";
+
+ if (!arg.isObject() && encryptedShellGlobalParams.awsAccessKeyId.empty()) {
+ return conn;
+ }
+
+ ClientSideFLEOptions encryptionOptions;
+ JS::RootedValue client(cx);
+ JS::RootedValue collection(cx);
+
+ if (!arg.isObject()) {
+ // If arg is not an object, but one of the required encryptedShellGlobalParams
+ // is defined, the user is trying to start an encrypted client with command line
+ // parameters.
+
+ AwsKMS awsKms = AwsKMS(encryptedShellGlobalParams.awsAccessKeyId,
+ encryptedShellGlobalParams.awsSecretAccessKey);
+
+ awsKms.setUrl(StringData(encryptedShellGlobalParams.awsKmsURL));
+
+ awsKms.setSessionToken(StringData(encryptedShellGlobalParams.awsSessionToken));
+
+ KmsProviders kmsProviders;
+ kmsProviders.setAws(awsKms);
+
+ // The mongoConnection object will never be null.
+ // If the encrypted shell is started through command line parameters, then the user must
+ // default to the implicit connection for the keyvault collection.
+ client.setObjectOrNull(mongoConnection.get());
+
+ // Because we cannot add a schemaMap object through the command line, we set the
+ // schemaMap object in ClientSideFLEOptions to be null so we know to always use
+ // remote schemas.
+ encryptionOptions = ClientSideFLEOptions(
+ encryptedShellGlobalParams.keyVaultNamespace, std::move(kmsProviders), BSONObj());
+ } else {
+ uassert(ErrorCodes::BadValue,
+ "Collection object must be passed to Field Level Encryption Options",
+ arg.isObject());
+
+ const BSONObj obj = mozjs::ValueWriter(cx, arg).toBSON();
+ encryptionOptions = encryptionOptions.parse(IDLParserErrorContext("root"), obj);
+
+ // IDL does not perform a deep copy of BSONObjs when parsing, so we must get an
+ // owned copy of the schemaMap.
+ encryptionOptions.setSchemaMap(encryptionOptions.getSchemaMap().getOwned());
+
+ // This logic tries to extract the client from the args. If the connection object is defined
+ // in the ClientSideFLEOptions struct, then the client will extract it and set itself to be
+ // that. Else, the client will default to the implicit connection.
+ JS::RootedObject handleObject(cx, &arg.toObject());
+ JS_GetProperty(cx, handleObject, keyVaultClientFieldId, &client);
+ if (client.isNull() || client.isUndefined()) {
+ client.setObjectOrNull(mongoConnection.get());
+ }
+ }
+
+ createCollectionObject(cx, client, encryptionOptions.getKeyVaultNamespace(), &collection);
+
+ if (implicitEncryptedDBClientCallback != nullptr) {
+ return implicitEncryptedDBClientCallback(
+ std::move(conn), encryptionOptions, collection, cx);
+ }
+
+ std::unique_ptr<EncryptedDBClientBase> base =
+ std::make_unique<EncryptedDBClientBase>(std::move(conn), encryptionOptions, collection, cx);
+ return std::move(base);
+}
+
+MONGO_INITIALIZER(setCallbacksForEncryptedDBClientBase)(InitializerContext*) {
+ mongo::mozjs::setEncryptedDBClientCallback(createEncryptedDBClientBase);
+ return Status::OK();
+}
+
+} // namespace
+} // namespace mongo
diff --git a/src/mongo/shell/encrypted_dbclient_base.h b/src/mongo/shell/encrypted_dbclient_base.h
new file mode 100644
index 00000000000..f72b8f6cddc
--- /dev/null
+++ b/src/mongo/shell/encrypted_dbclient_base.h
@@ -0,0 +1,171 @@
+/**
+ * Copyright (C) 2019-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/platform/basic.h"
+
+#include "mongo/base/data_cursor.h"
+#include "mongo/base/data_type_validated.h"
+#include "mongo/bson/bson_depth.h"
+#include "mongo/client/dbclient_base.h"
+#include "mongo/crypto/aead_encryption.h"
+#include "mongo/crypto/symmetric_crypto.h"
+#include "mongo/db/client.h"
+#include "mongo/db/commands.h"
+#include "mongo/db/matcher/schema/encrypt_schema_gen.h"
+#include "mongo/db/namespace_string.h"
+#include "mongo/rpc/object_check.h"
+#include "mongo/rpc/op_msg_rpc_impls.h"
+#include "mongo/scripting/mozjs/bindata.h"
+#include "mongo/scripting/mozjs/implscope.h"
+#include "mongo/scripting/mozjs/maxkey.h"
+#include "mongo/scripting/mozjs/minkey.h"
+#include "mongo/scripting/mozjs/mongo.h"
+#include "mongo/scripting/mozjs/objectwrapper.h"
+#include "mongo/scripting/mozjs/valuereader.h"
+#include "mongo/scripting/mozjs/valuewriter.h"
+#include "mongo/shell/encrypted_shell_options.h"
+#include "mongo/shell/kms.h"
+#include "mongo/shell/kms_gen.h"
+#include "mongo/shell/shell_options.h"
+#include "mongo/util/lru_cache.h"
+
+namespace mongo {
+
+constexpr std::size_t kEncryptedDBCacheSize = 50;
+
+constexpr int kAssociatedDataLength = 18;
+constexpr uint8_t kIntentToEncryptBit = 0x00;
+constexpr uint8_t kDeterministicEncryptionBit = 0x01;
+constexpr uint8_t kRandomEncryptionBit = 0x02;
+
+class EncryptedDBClientBase : public DBClientBase, public mozjs::EncryptionCallbacks {
+public:
+ EncryptedDBClientBase(std::unique_ptr<DBClientBase> conn,
+ ClientSideFLEOptions encryptionOptions,
+ JS::HandleValue collection,
+ JSContext* cx);
+
+
+ std::string getServerAddress() const final;
+
+ bool call(Message& toSend, Message& response, bool assertOk, std::string* actualServer) final;
+
+ void say(Message& toSend, bool isRetry, std::string* actualServer) final;
+
+ bool lazySupported() const final;
+
+ using DBClientBase::runCommandWithTarget;
+ virtual std::pair<rpc::UniqueReply, DBClientBase*> runCommandWithTarget(
+ OpMsgRequest request) override;
+ std::string toString() const final;
+
+ int getMinWireVersion() final;
+
+ int getMaxWireVersion() final;
+
+ using EncryptionCallbacks::generateDataKey;
+ void generateDataKey(JSContext* cx, JS::CallArgs args) final;
+
+ using EncryptionCallbacks::getDataKeyCollection;
+ void getDataKeyCollection(JSContext* cx, JS::CallArgs args) final;
+
+ using EncryptionCallbacks::encrypt;
+ void encrypt(mozjs::MozJSImplScope* scope, JSContext* cx, JS::CallArgs args) final;
+
+ using EncryptionCallbacks::decrypt;
+ void decrypt(mozjs::MozJSImplScope* scope, JSContext* cx, JS::CallArgs args) final;
+
+ using EncryptionCallbacks::trace;
+ void trace(JSTracer* trc) final;
+
+ using DBClientBase::query;
+ std::unique_ptr<DBClientCursor> query(const NamespaceStringOrUUID& nsOrUuid,
+ Query query,
+ int nToReturn,
+ int nToSkip,
+ const BSONObj* fieldsToReturn,
+ int queryOptions,
+ int batchSize) final;
+
+ bool isFailed() const final;
+
+ bool isStillConnected() final;
+
+ ConnectionString::ConnectionType type() const final;
+
+ double getSoTimeout() const final;
+
+ bool isReplicaSetMember() const final;
+
+ bool isMongos() const final;
+
+protected:
+ JS::Value getCollection() const;
+
+ BSONObj validateBSONElement(ConstDataRange out, uint8_t bsonType);
+
+ NamespaceString getCollectionNS();
+
+ std::shared_ptr<SymmetricKey> getDataKey(const UUID& uuid);
+
+ std::vector<uint8_t> encryptWithKey(UUID uuid,
+ const std::shared_ptr<SymmetricKey>& key,
+ ConstDataRange plaintext,
+ BSONType bsonType,
+ int32_t algorithm);
+
+private:
+ std::vector<uint8_t> getBinDataArg(mozjs::MozJSImplScope* scope,
+ JSContext* cx,
+ JS::CallArgs args,
+ int index,
+ BinDataType type);
+
+ std::shared_ptr<SymmetricKey> getDataKeyFromDisk(const UUID& uuid);
+
+protected:
+ std::unique_ptr<DBClientBase> _conn;
+ ClientSideFLEOptions _encryptionOptions;
+
+private:
+ LRUCache<UUID, std::pair<std::shared_ptr<SymmetricKey>, Date_t>, UUID::Hash> _datakeyCache{
+ kEncryptedDBCacheSize};
+ JS::Heap<JS::Value> _collection;
+ JSContext* _cx;
+};
+
+using ImplicitEncryptedDBClientCallback =
+ std::unique_ptr<DBClientBase>(std::unique_ptr<DBClientBase> conn,
+ ClientSideFLEOptions encryptionOptions,
+ JS::HandleValue collection,
+ JSContext* cx);
+void setImplicitEncryptedDBClientCallback(ImplicitEncryptedDBClientCallback* callback);
+
+
+} // namespace mongo
diff --git a/src/mongo/shell/encrypted_shell_options.h b/src/mongo/shell/encrypted_shell_options.h
new file mode 100644
index 00000000000..f839c637d9a
--- /dev/null
+++ b/src/mongo/shell/encrypted_shell_options.h
@@ -0,0 +1,45 @@
+/**
+ * Copyright (C) 2019-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include <string>
+
+namespace mongo {
+
+struct EncryptedShellGlobalParams {
+ std::string awsAccessKeyId;
+ std::string awsSecretAccessKey;
+ std::string awsSessionToken;
+ std::string keyVaultNamespace;
+ std::string awsKmsURL;
+};
+
+extern EncryptedShellGlobalParams encryptedShellGlobalParams;
+}
diff --git a/src/mongo/shell/fle_shell_options.idl b/src/mongo/shell/fle_shell_options.idl
new file mode 100644
index 00000000000..64578aaa69b
--- /dev/null
+++ b/src/mongo/shell/fle_shell_options.idl
@@ -0,0 +1,37 @@
+# Copyright (C) 2019-present MongoDB, Inc.
+
+global:
+ cpp_namespace: "mongo"
+ configs:
+ section: 'FLE AWS Options'
+ source: [ cli ]
+ cpp_includes:
+ - mongo/shell/encrypted_shell_options.h
+
+configs:
+ "awsAccessKeyId":
+ description: "AWS Access Key for FLE Amazon KMS"
+ arg_vartype: String
+ cpp_varname: encryptedShellGlobalParams.awsAccessKeyId
+ requires: [ "awsSecretAccessKey", "keyVaultNamespace" ]
+ "awsSecretAccessKey":
+ description: "AWS Secret Key for FLE Amazon KMS"
+ arg_vartype: String
+ cpp_varname: encryptedShellGlobalParams.awsSecretAccessKey
+ redact: true
+ requires: [ "awsAccessKeyId", "keyVaultNamespace" ]
+ "awsSessionToken":
+ description: "Optional AWS Session Token ID"
+ arg_vartype: String
+ cpp_varname: encryptedShellGlobalParams.awsSessionToken
+ requires: [ "awsAccessKeyId", "awsSecretAccessKey", "keyVaultNamespace" ]
+ "keyVaultNamespace":
+ description: "database.collection to store encrypted FLE parameters"
+ arg_vartype: String
+ cpp_varname: encryptedShellGlobalParams.keyVaultNamespace
+ requires: [ "awsAccessKeyId", "awsSecretAccessKey" ]
+ "kmsURL":
+ description: "Test parameter to override the URL for KMS"
+ arg_vartype: String
+ cpp_varname: encryptedShellGlobalParams.awsKmsURL
+ requires: [ "awsAccessKeyId", "awsSecretAccessKey", "keyVaultNamespace" ]
diff --git a/src/mongo/shell/keyvault.js b/src/mongo/shell/keyvault.js
new file mode 100644
index 00000000000..cf91a39cd2d
--- /dev/null
+++ b/src/mongo/shell/keyvault.js
@@ -0,0 +1,106 @@
+// Class that allows the mongo shell to talk to the mongodb KeyVault.
+// Loaded only into the enterprise module.
+
+Mongo.prototype.getKeyVault = function() {
+ return new KeyVault(this);
+};
+
+class KeyVault {
+ constructor(mongo) {
+ this.mongo = mongo;
+ var collection = mongo.getDataKeyCollection();
+ this.keyColl = collection;
+ this.keyColl.createIndex(
+ {keyAltNames: 1},
+ {unique: true, partialFilterExpression: {keyAltNames: {$exists: true}}});
+ }
+
+ createKey(kmsProvider, customerMasterKey, keyAltNames = undefined) {
+ if (typeof kmsProvider !== "string") {
+ return "TypeError: kmsProvider must be of String type.";
+ }
+
+ if (typeof customerMasterKey !== "string") {
+ return "TypeError: customer master key must be of String type.";
+ }
+
+ var masterKeyAndMaterial = this.mongo.generateDataKey(kmsProvider, customerMasterKey);
+ var masterKey = masterKeyAndMaterial.masterKey;
+
+ var current = ISODate();
+
+ var doc = {
+ "_id": UUID(),
+ "keyMaterial": masterKeyAndMaterial.keyMaterial,
+ "creationDate": current,
+ "updateDate": current,
+ "status": NumberInt(0),
+ "version": NumberLong(0),
+ "masterKey": masterKey,
+ };
+
+ if (keyAltNames) {
+ if (!Array.isArray(keyAltNames)) {
+ return "TypeError: key alternate names must be of Array type.";
+ }
+
+ let i = 0;
+ for (i = 0; i < keyAltNames.length; i++) {
+ if (typeof keyAltNames[i] !== "string") {
+ return "TypeError: items in key alternate names must be of String type.";
+ }
+ }
+
+ doc.keyAltNames = keyAltNames;
+ }
+
+ return this.keyColl.insert(doc);
+ }
+
+ getKey(keyId) {
+ return this.keyColl.find({"_id": keyId});
+ }
+
+ getKeyByAltName(keyAltName) {
+ return this.keyColl.find({"keyAltNames": keyAltName});
+ }
+
+ deleteKey(keyId) {
+ return this.keyColl.deleteOne({"_id": keyId});
+ }
+
+ getKeys() {
+ return this.keyColl.find();
+ }
+
+ addKeyAlternateName(keyId, keyAltName) {
+ // keyAltName is not allowed to be an array or an object. In javascript,
+ // typeof array is object.
+ if (typeof keyAltName === "object") {
+ return "TypeError: key alternate name cannot be object or array type.";
+ }
+ return this.keyColl.findAndModify({
+ query: {"_id": keyId},
+ update: {$push: {"keyAltNames": keyAltName}, $currentDate: {"updateDate": true}},
+ });
+ }
+
+ removeKeyAlternateName(keyId, keyAltName) {
+ if (typeof keyAltName === "object") {
+ return "TypeError: key alternate name cannot be object or array type.";
+ }
+ const ret = this.keyColl.findAndModify({
+ query: {"_id": keyId},
+ update: {$pull: {"keyAltNames": keyAltName}, $currentDate: {"updateDate": true}}
+ });
+
+ if (ret != null && ret.keyAltNames.length === 1 && ret.keyAltNames[0] === keyAltName) {
+ // Remove the empty array to prevent duplicate key violations
+ return this.keyColl.findAndModify({
+ query: {"_id": keyId, "keyAltNames": undefined},
+ update: {$unset: {"keyAltNames": ""}, $currentDate: {"updateDate": true}}
+ });
+ }
+ return ret;
+ }
+} \ No newline at end of file
diff --git a/src/mongo/shell/kms.cpp b/src/mongo/shell/kms.cpp
new file mode 100644
index 00000000000..ed7bb0e934f
--- /dev/null
+++ b/src/mongo/shell/kms.cpp
@@ -0,0 +1,80 @@
+/**
+ * Copyright (C) 2019-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "kms.h"
+
+#include "mongo/platform/random.h"
+#include "mongo/shell/kms_gen.h"
+#include "mongo/util/net/hostandport.h"
+#include "mongo/util/text.h"
+
+
+namespace mongo {
+
+HostAndPort parseUrl(StringData url) {
+ // Treat the URL as a host and port
+ // URL: https://(host):(port)
+ //
+ constexpr StringData urlPrefix = "https://"_sd;
+ uassert(51140, "AWS KMS URL must start with https://", url.startsWith(urlPrefix));
+
+ StringData hostAndPort = url.substr(urlPrefix.size());
+
+ return HostAndPort(hostAndPort);
+}
+
+stdx::unordered_map<KMSProviderEnum, std::unique_ptr<KMSServiceFactory>>
+ KMSServiceController::_factories;
+
+void KMSServiceController::registerFactory(KMSProviderEnum provider,
+ std::unique_ptr<KMSServiceFactory> factory) {
+ auto ret = _factories.insert({provider, std::move(factory)});
+ invariant(ret.second);
+}
+
+std::unique_ptr<KMSService> KMSServiceController::createFromClient(StringData kmsProvider,
+ const BSONObj& config) {
+ KMSProviderEnum provider =
+ KMSProvider_parse(IDLParserErrorContext("client fle options"), kmsProvider);
+
+ auto service = _factories.at(provider)->create(config);
+ uassert(51192, str::stream() << "Cannot find client kms provider " << kmsProvider, service);
+ return service;
+}
+
+std::unique_ptr<KMSService> KMSServiceController::createFromDisk(const BSONObj& config,
+ const BSONObj& masterKey) {
+ auto providerObj = masterKey.getStringField("provider"_sd);
+ auto provider = KMSProvider_parse(IDLParserErrorContext("root"), providerObj);
+ auto service = _factories.at(provider)->create(config);
+ uassert(51193, str::stream() << "Cannot find disk kms provider " << providerObj, service);
+ return service;
+}
+
+} // namespace mongo
diff --git a/src/mongo/shell/kms.h b/src/mongo/shell/kms.h
new file mode 100644
index 00000000000..1c69cac9764
--- /dev/null
+++ b/src/mongo/shell/kms.h
@@ -0,0 +1,135 @@
+/**
+ * Copyright (C) 2019-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include <cstddef>
+#include <memory>
+#include <vector>
+
+#include "mongo/base/data_range.h"
+#include "mongo/base/secure_allocator.h"
+#include "mongo/base/string_data.h"
+#include "mongo/bson/bsonobj.h"
+#include "mongo/shell/kms_gen.h"
+#include "mongo/stdx/unordered_map.h"
+#include "mongo/util/net/hostandport.h"
+
+namespace mongo {
+
+/**
+ * KMSService
+ *
+ * Represents a Key Management Service. May be a local file KMS or remote.
+ *
+ * Responsible for securely encrypting and decrypting data. The encrypted data is treated as a
+ * blockbox by callers.
+ */
+class KMSService {
+public:
+ virtual ~KMSService() = default;
+
+ /**
+ * Encrypt a plaintext with the specified key and return a encrypted blob.
+ */
+ virtual std::vector<uint8_t> encrypt(ConstDataRange cdr, StringData keyId) = 0;
+
+ /**
+ * Decrypt an encrypted blob and return the plaintext.
+ */
+ virtual SecureVector<uint8_t> decrypt(ConstDataRange cdr, BSONObj masterKey) = 0;
+
+ /**
+ * Encrypt a data key with the specified key and return a BSONObj that describes what needs to
+ * be store in the key vault.
+ *
+ * {
+ * keyMaterial : "<ciphertext>""
+ * masterKey : {
+ * provider : "<provider_name>"
+ * ... <provider specific fields>
+ * }
+ * }
+ */
+ virtual BSONObj encryptDataKey(ConstDataRange cdr, StringData keyId) = 0;
+};
+
+/**
+ * KMSService Factory
+ *
+ * Provides static registration of KMSService.
+ */
+class KMSServiceFactory {
+public:
+ virtual ~KMSServiceFactory() = default;
+
+ /**
+ * Create an instance of the KMS service
+ */
+ virtual std::unique_ptr<KMSService> create(const BSONObj& config) = 0;
+};
+
+/**
+ * KMSService Controller
+ *
+ * Provides static registration of KMSServiceFactory
+ */
+class KMSServiceController {
+public:
+ /**
+ * Create an instance of the KMS service
+ */
+ static void registerFactory(KMSProviderEnum provider,
+ std::unique_ptr<KMSServiceFactory> factory);
+
+
+ /**
+ * Creates a KMS Service for the specified provider with the config.
+ */
+ static std::unique_ptr<KMSService> createFromClient(StringData kmsProvider,
+ const BSONObj& config);
+
+ /**
+ * Creates a KMS Service with the given mongo constructor options and key vault record.
+ */
+ static std::unique_ptr<KMSService> createFromDisk(const BSONObj& config,
+ const BSONObj& kmsProvider);
+
+private:
+ static stdx::unordered_map<KMSProviderEnum, std::unique_ptr<KMSServiceFactory>> _factories;
+};
+
+/**
+ * Parse a basic url of "https://host:port" to a HostAndPort.
+ *
+ * Does not support URL encoding or anything else.
+ */
+HostAndPort parseUrl(StringData url);
+
+} // namespace mongo
diff --git a/src/mongo/shell/kms.idl b/src/mongo/shell/kms.idl
new file mode 100644
index 00000000000..c49bad7f423
--- /dev/null
+++ b/src/mongo/shell/kms.idl
@@ -0,0 +1,164 @@
+# Copyright (C) 2019-present MongoDB, Inc.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the Server Side Public License, version 1,
+# as published by MongoDB, Inc.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# Server Side Public License for more details.
+#
+# You should have received a copy of the Server Side Public License
+# along with this program. If not, see
+# <http://www.mongodb.com/licensing/server-side-public-license>.
+#
+# As a special exception, the copyright holders give permission to link the
+# code of portions of this program with the OpenSSL library under certain
+# conditions as described in each individual source file and distribute
+# linked combinations including the program with the OpenSSL library. You
+# must comply with the Server Side Public License in all respects for
+# all of the code used other than as permitted herein. If you modify file(s)
+# with this exception, you may extend this exception to your version of the
+# file(s), but you are not obligated to do so. If you do not wish to do so,
+# delete this exception statement from your version. If you delete this
+# exception statement from all source files in the program, then also delete
+# it in the license file.
+#
+
+global:
+ cpp_namespace: "mongo"
+
+imports:
+ - "mongo/idl/basic_types.idl"
+
+enums:
+ KMSProvider:
+ description: "Enumeration of supported KMS Providers"
+ type: string
+ values:
+ aws: "aws"
+ local: "local"
+
+structs:
+ awsKMSError:
+ description: "AWS KMS error"
+ fields:
+ __type:
+ type: string
+ cpp_name: type
+ message: string
+
+ # Options passed to Mongo() javascript constructor
+ awsKMS:
+ description: "AWS KMS config"
+ fields:
+ accessKeyId: string
+ secretAccessKey: string
+ sessionToken:
+ type: string
+ optional: true
+ url:
+ type: string
+ optional: true
+
+ # Options passed to Mongo() javascript constructor
+ localKMS:
+ description: "Local KMS config"
+ fields:
+ key: bindata_generic
+
+ kmsProviders:
+ description: "Supported KMS Providers"
+ strict: true
+ fields:
+ aws:
+ type: awsKMS
+ optional: true
+ local:
+ type: localKMS
+ optional: true
+
+ clientSideFLEOptions:
+ description: "FLE Options inputted through the Mongo constructor in the shell"
+ fields:
+ keyVaultClient: #Parsed as a JSHandleValue, not through IDL
+ type: void
+ ignore: true
+ keyVaultNamespace:
+ type: string
+ kmsProviders: kmsProviders
+ schemaMap:
+ type: object
+ bypassAutoEncryption:
+ type: bool
+ optional: true
+
+ awsEncryptResponse:
+ description: "Response from AWS KMS Encrypt request, i.e. TrentService.Encrypt"
+ fields:
+ CiphertextBlob:
+ type: string
+ KeyId:
+ type: string
+
+ awsDecryptResponse:
+ description: "Response from AWS KMS Decrypt request, i.e. TrentService.Decrypt"
+ fields:
+ Plaintext:
+ type: string
+ KeyId:
+ type: string
+
+ awsMasterKey:
+ description: "AWS KMS Key Store Description"
+ fields:
+ provider:
+ type: string
+ default: '"aws"'
+ key:
+ type: string
+ region:
+ type: string
+ endpoint:
+ type: string
+
+ awsMasterKeyAndMaterial:
+ description: "AWS KMS Key Material Description"
+ fields:
+ keyMaterial:
+ type: bindata_generic
+ masterKey:
+ type: awsMasterKey
+
+ localMasterKey:
+ description: "Local KMS Key Store Description"
+ fields:
+ provider:
+ type: string
+ default: '"local"'
+
+ localMasterKeyAndMaterial:
+ description: "Local KMS Key Material Description"
+ fields:
+ keyMaterial:
+ type: bindata_generic
+ masterKey:
+ type: localMasterKey
+
+ keyStoreRecord:
+ description: "A V0 Key Store Record"
+ fields:
+ _id: uuid
+ keyMaterial: bindata_generic
+ creationDate: date
+ updateDate: date
+ status: int
+ version:
+ type: long
+ default: 0
+ masterKey: object
+ keyAltNames:
+ type: array<string>
+ ignore: true
+
diff --git a/src/mongo/shell/kms_aws.cpp b/src/mongo/shell/kms_aws.cpp
new file mode 100644
index 00000000000..167f4ceae56
--- /dev/null
+++ b/src/mongo/shell/kms_aws.cpp
@@ -0,0 +1,461 @@
+/**
+ * Copyright (C) 2019-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#define MONGO_LOG_DEFAULT_COMPONENT ::mongo::logger::LogComponent::kControl
+
+#include <kms_message/kms_message.h>
+
+#include <stdlib.h>
+
+#include "mongo/base/init.h"
+#include "mongo/base/parse_number.h"
+#include "mongo/base/secure_allocator.h"
+#include "mongo/base/status_with.h"
+#include "mongo/bson/json.h"
+#include "mongo/db/commands/test_commands_enabled.h"
+#include "mongo/shell/kms.h"
+#include "mongo/shell/kms_gen.h"
+#include "mongo/util/base64.h"
+#include "mongo/util/log.h"
+#include "mongo/util/net/hostandport.h"
+#include "mongo/util/net/sock.h"
+#include "mongo/util/net/ssl_manager.h"
+#include "mongo/util/net/ssl_options.h"
+#include "mongo/util/text.h"
+#include "mongo/util/time_support.h"
+
+namespace mongo {
+namespace {
+
+/**
+ * Free kms_request_t
+ */
+struct kms_request_tFree {
+ void operator()(kms_request_t* p) noexcept {
+ if (p) {
+ ::kms_request_destroy(p);
+ }
+ }
+};
+
+using UniqueKmsRequest = std::unique_ptr<kms_request_t, kms_request_tFree>;
+
+/**
+ * Free kms_response_parser_t
+ */
+struct kms_response_parser_tFree {
+ void operator()(kms_response_parser_t* p) noexcept {
+ if (p) {
+ ::kms_response_parser_destroy(p);
+ }
+ }
+};
+
+using UniqueKmsResponseParser = std::unique_ptr<kms_response_parser_t, kms_response_parser_tFree>;
+
+/**
+ * Free kms_response_t
+ */
+struct kms_response_tFree {
+ void operator()(kms_response_t* p) noexcept {
+ if (p) {
+ ::kms_response_destroy(p);
+ }
+ }
+};
+
+using UniqueKmsResponse = std::unique_ptr<kms_response_t, kms_response_tFree>;
+
+/**
+ * Free kms_char_buffer
+ */
+struct kms_char_free {
+ void operator()(char* x) {
+ kms_request_free_string(x);
+ }
+};
+
+using UniqueKmsCharBuffer = std::unique_ptr<char, kms_char_free>;
+
+/**
+ * Make a request to a AWS HTTP endpoint.
+ *
+ * Does not maintain a persistent HTTP connection.
+ */
+class AWSConnection {
+public:
+ AWSConnection(SSLManagerInterface* ssl)
+ : _sslManager(ssl), _socket(std::make_unique<Socket>(10, logger::LogSeverity::Log())) {}
+
+ UniqueKmsResponse makeOneRequest(const HostAndPort& host, ConstDataRange request);
+
+private:
+ UniqueKmsResponse sendRequest(ConstDataRange request);
+
+ void connect(const HostAndPort& host);
+
+private:
+ // SSL Manager for connections
+ SSLManagerInterface* _sslManager;
+
+ // Synchronous socket
+ std::unique_ptr<Socket> _socket;
+};
+
+/**
+ * AWS configuration settings
+ */
+struct AWSConfig {
+ // AWS_ACCESS_KEY_ID
+ std::string accessKeyId;
+
+ // AWS_SECRET_ACCESS_KEY
+ SecureString secretAccessKey;
+
+ // Optional AWS_SESSION_TOKEN for AWS STS tokens
+ boost::optional<std::string> sessionToken;
+};
+
+/**
+ * Manages SSL information and config for how to talk to AWS KMS.
+ */
+class AWSKMSService : public KMSService {
+public:
+ AWSKMSService() = default;
+ ~AWSKMSService() final = default;
+
+ static std::unique_ptr<KMSService> create(const AwsKMS& config);
+
+ std::vector<uint8_t> encrypt(ConstDataRange cdr, StringData kmsKeyId) final;
+
+ SecureVector<uint8_t> decrypt(ConstDataRange cdr, BSONObj masterKey) final;
+
+ BSONObj encryptDataKey(ConstDataRange cdr, StringData keyId) final;
+
+private:
+ void initRequest(kms_request_t* request, StringData region);
+
+private:
+ // SSL Manager
+ std::unique_ptr<SSLManagerInterface> _sslManager;
+
+ // Server to connect to
+ HostAndPort _server;
+
+ // AWS configuration settings
+ AWSConfig _config;
+};
+
+void uassertKmsRequestInternal(kms_request_t* request, bool ok) {
+ if (!ok) {
+ const char* msg = kms_request_get_error(request);
+ uasserted(51135, str::stream() << "Internal AWS KMS Error: " << msg);
+ }
+}
+
+#define uassertKmsRequest(X) uassertKmsRequestInternal(request, (X));
+
+void AWSKMSService::initRequest(kms_request_t* request, StringData region) {
+
+ // use current time
+ uassertKmsRequest(kms_request_set_date(request, nullptr));
+
+ uassertKmsRequest(kms_request_set_region(request, region.toString().c_str()));
+
+ // kms is always the name of the service
+ uassertKmsRequest(kms_request_set_service(request, "kms"));
+
+ uassertKmsRequest(kms_request_set_access_key_id(request, _config.accessKeyId.c_str()));
+ uassertKmsRequest(kms_request_set_secret_key(request, _config.secretAccessKey->c_str()));
+
+ if (!_config.sessionToken.value_or("").empty()) {
+ // TODO: move this into kms-message
+ uassertKmsRequest(kms_request_add_header_field(
+ request, "X-Amz-Security-Token", _config.sessionToken.get().c_str()));
+ }
+}
+
+std::vector<uint8_t> toVector(const std::string& str) {
+ std::vector<uint8_t> blob;
+
+ std::transform(std::begin(str), std::end(str), std::back_inserter(blob), [](auto c) {
+ return static_cast<uint8_t>(c);
+ });
+
+ return blob;
+}
+
+SecureVector<uint8_t> toSecureVector(const std::string& str) {
+ SecureVector<uint8_t> blob(str.length());
+
+ std::transform(std::begin(str), std::end(str), blob->data(), [](auto c) {
+ return static_cast<uint8_t>(c);
+ });
+
+ return blob;
+}
+
+/**
+ * Takes in a CMK of the format arn:partition:service:region:account-id:resource (minimum). We
+ * care about extracting the region. This function ensures that there are at least 6 partitions,
+ * parses the provider, and returns a pair of provider and the region.
+ */
+std::string parseCMK(StringData cmk) {
+ std::vector<std::string> cmkTokenized = StringSplitter::split(cmk.toString(), ":");
+ uassert(31040, "Invalid AWS KMS Customer Master Key.", cmkTokenized.size() > 5);
+ return cmkTokenized[3];
+}
+
+HostAndPort getDefaultHost(StringData region) {
+ std::string hostname = str::stream() << "kms." << region << ".amazonaws.com";
+ return HostAndPort(hostname, 443);
+}
+
+std::vector<uint8_t> AWSKMSService::encrypt(ConstDataRange cdr, StringData kmsKeyId) {
+ auto request =
+ UniqueKmsRequest(kms_encrypt_request_new(reinterpret_cast<const uint8_t*>(cdr.data()),
+ cdr.length(),
+ kmsKeyId.toString().c_str(),
+ NULL));
+
+ auto region = parseCMK(kmsKeyId);
+
+ if (_server.empty()) {
+ _server = getDefaultHost(region);
+ }
+
+ initRequest(request.get(), region);
+
+ auto buffer = UniqueKmsCharBuffer(kms_request_get_signed(request.get()));
+ auto buffer_len = strlen(buffer.get());
+
+ AWSConnection connection(_sslManager.get());
+ auto response = connection.makeOneRequest(_server, ConstDataRange(buffer.get(), buffer_len));
+
+ auto body = kms_response_get_body(response.get());
+
+ BSONObj obj = fromjson(body);
+
+ auto field = obj["__type"];
+
+ if (!field.eoo()) {
+ auto awsResponse = AwsKMSError::parse(IDLParserErrorContext("root"), obj);
+
+ uasserted(51224,
+ str::stream() << "AWS KMS failed to encrypt: " << awsResponse.getType() << " : "
+ << awsResponse.getMessage());
+ }
+
+ auto awsResponse = AwsEncryptResponse::parse(IDLParserErrorContext("root"), obj);
+
+ auto blobStr = base64::decode(awsResponse.getCiphertextBlob().toString());
+
+ return toVector(blobStr);
+}
+
+BSONObj AWSKMSService::encryptDataKey(ConstDataRange cdr, StringData keyId) {
+ auto dataKey = encrypt(cdr, keyId);
+
+ AwsMasterKey masterKey;
+ masterKey.setKey(keyId);
+ masterKey.setRegion(parseCMK(keyId));
+ masterKey.setEndpoint(_server.toString());
+
+ AwsMasterKeyAndMaterial keyAndMaterial;
+ keyAndMaterial.setKeyMaterial(dataKey);
+ keyAndMaterial.setMasterKey(masterKey);
+
+ return keyAndMaterial.toBSON();
+}
+
+SecureVector<uint8_t> AWSKMSService::decrypt(ConstDataRange cdr, BSONObj masterKey) {
+ auto awsMasterKey = AwsMasterKey::parse(IDLParserErrorContext("root"), masterKey);
+
+ auto request = UniqueKmsRequest(kms_decrypt_request_new(
+ reinterpret_cast<const uint8_t*>(cdr.data()), cdr.length(), nullptr));
+
+ initRequest(request.get(), awsMasterKey.getRegion());
+
+ if (_server.empty()) {
+ _server = getDefaultHost(awsMasterKey.getRegion());
+ }
+
+ auto buffer = UniqueKmsCharBuffer(kms_request_get_signed(request.get()));
+ auto buffer_len = strlen(buffer.get());
+ AWSConnection connection(_sslManager.get());
+ auto response = connection.makeOneRequest(_server, ConstDataRange(buffer.get(), buffer_len));
+
+ auto body = kms_response_get_body(response.get());
+
+ BSONObj obj = fromjson(body);
+
+ auto field = obj["__type"];
+
+ if (!field.eoo()) {
+ auto awsResponse = AwsKMSError::parse(IDLParserErrorContext("root"), obj);
+
+ uasserted(51225,
+ str::stream() << "AWS KMS failed to decrypt: " << awsResponse.getType() << " : "
+ << awsResponse.getMessage());
+ }
+
+ auto awsResponse = AwsDecryptResponse::parse(IDLParserErrorContext("root"), obj);
+
+ auto blobStr = base64::decode(awsResponse.getPlaintext().toString());
+
+ return toSecureVector(blobStr);
+}
+
+void AWSConnection::connect(const HostAndPort& host) {
+ SockAddr server(host.host().c_str(), host.port(), AF_UNSPEC);
+
+ uassert(51136,
+ str::stream() << "AWS KMS server address " << host.host() << " is invalid.",
+ server.isValid());
+
+ uassert(51137,
+ str::stream() << "Could not connect to AWS KMS server " << server.toString(),
+ _socket->connect(server));
+
+ uassert(51138,
+ str::stream() << "Failed to perform SSL handshake with the AWS KMS server "
+ << host.toString(),
+ _socket->secure(_sslManager, host.host()));
+}
+
+// Sends a request message to the AWS KMS server and creates a KMS Response.
+UniqueKmsResponse AWSConnection::sendRequest(ConstDataRange request) {
+ std::array<char, 512> resp;
+
+ _socket->send(
+ reinterpret_cast<const char*>(request.data()), request.length(), "AWS KMS request");
+
+ auto parser = UniqueKmsResponseParser(kms_response_parser_new());
+ int bytes_to_read = 0;
+
+ while ((bytes_to_read = kms_response_parser_wants_bytes(parser.get(), resp.size())) > 0) {
+ bytes_to_read = std::min(bytes_to_read, static_cast<int>(resp.size()));
+ bytes_to_read = _socket->unsafe_recv(resp.data(), bytes_to_read);
+
+ uassert(51139,
+ "kms_response_parser_feed failed",
+ kms_response_parser_feed(
+ parser.get(), reinterpret_cast<uint8_t*>(resp.data()), bytes_to_read));
+ }
+
+ auto response = UniqueKmsResponse(kms_response_parser_get_response(parser.get()));
+
+ return response;
+}
+
+UniqueKmsResponse AWSConnection::makeOneRequest(const HostAndPort& host, ConstDataRange request) {
+ connect(host);
+
+ auto resp = sendRequest(request);
+
+ _socket->close();
+
+ return resp;
+}
+
+boost::optional<std::string> toString(boost::optional<StringData> str) {
+ if (str) {
+ return {str.get().toString()};
+ }
+ return boost::none;
+}
+
+std::unique_ptr<KMSService> AWSKMSService::create(const AwsKMS& config) {
+ auto awsKMS = std::make_unique<AWSKMSService>();
+
+ SSLParams params;
+ params.sslPEMKeyFile = "";
+ params.sslPEMKeyPassword = "";
+ params.sslClusterFile = "";
+ params.sslClusterPassword = "";
+ params.sslCAFile = "";
+
+ params.sslCRLFile = "";
+
+ // Copy the rest from the global SSL manager options.
+ params.sslFIPSMode = sslGlobalParams.sslFIPSMode;
+
+ // KMS servers never should have invalid certificates
+ params.sslAllowInvalidCertificates = false;
+ params.sslAllowInvalidHostnames = false;
+
+ params.sslDisabledProtocols =
+ std::vector({SSLParams::Protocols::TLS1_0, SSLParams::Protocols::TLS1_1});
+
+ // Leave the CA file empty so we default to system CA but for local testing allow it to inherit
+ // the CA file.
+ if (!config.getUrl().value_or("").empty()) {
+ params.sslCAFile = sslGlobalParams.sslCAFile;
+ awsKMS->_server = parseUrl(config.getUrl().get());
+ }
+
+ awsKMS->_sslManager = SSLManagerInterface::create(params, false);
+
+ awsKMS->_config.accessKeyId = config.getAccessKeyId().toString();
+
+ awsKMS->_config.secretAccessKey = config.getSecretAccessKey().toString();
+
+ awsKMS->_config.sessionToken = toString(config.getSessionToken());
+
+ return awsKMS;
+}
+
+/**
+ * Factory for AWSKMSService if user specifies aws config to mongo() JS constructor.
+ */
+class AWSKMSServiceFactory final : public KMSServiceFactory {
+public:
+ AWSKMSServiceFactory() = default;
+ ~AWSKMSServiceFactory() = default;
+
+ std::unique_ptr<KMSService> create(const BSONObj& config) final {
+ auto field = config[KmsProviders::kAwsFieldName];
+ if (field.eoo()) {
+ return nullptr;
+ }
+ auto obj = field.Obj();
+ return AWSKMSService::create(AwsKMS::parse(IDLParserErrorContext("root"), obj));
+ }
+};
+
+} // namspace
+
+MONGO_INITIALIZER(KMSRegister)(::mongo::InitializerContext* context) {
+ kms_message_init();
+ KMSServiceController::registerFactory(KMSProviderEnum::aws,
+ std::make_unique<AWSKMSServiceFactory>());
+ return Status::OK();
+}
+
+} // namespace mongo
diff --git a/src/mongo/shell/kms_local.cpp b/src/mongo/shell/kms_local.cpp
new file mode 100644
index 00000000000..6fec6511e1b
--- /dev/null
+++ b/src/mongo/shell/kms_local.cpp
@@ -0,0 +1,153 @@
+/**
+ * Copyright (C) 2019-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include <kms_message/kms_message.h>
+
+#include <stdlib.h>
+
+#include "mongo/base/init.h"
+#include "mongo/base/secure_allocator.h"
+#include "mongo/base/status_with.h"
+#include "mongo/bson/json.h"
+#include "mongo/crypto/aead_encryption.h"
+#include "mongo/crypto/symmetric_crypto.h"
+#include "mongo/crypto/symmetric_key.h"
+#include "mongo/shell/kms.h"
+#include "mongo/shell/kms_gen.h"
+#include "mongo/util/base64.h"
+
+namespace mongo {
+namespace {
+
+/**
+ * Manages Local KMS Information
+ */
+class LocalKMSService : public KMSService {
+public:
+ LocalKMSService(SymmetricKey key) : _key(std::move(key)) {}
+ ~LocalKMSService() final = default;
+
+ static std::unique_ptr<KMSService> create(const LocalKMS& config);
+
+ std::vector<uint8_t> encrypt(ConstDataRange cdr, StringData kmsKeyId) final;
+
+ SecureVector<uint8_t> decrypt(ConstDataRange cdr, BSONObj masterKey) final;
+
+ BSONObj encryptDataKey(ConstDataRange cdr, StringData keyId) final;
+
+private:
+ // Key that wraps all KMS encrypted data
+ SymmetricKey _key;
+};
+
+std::vector<uint8_t> LocalKMSService::encrypt(ConstDataRange cdr, StringData kmsKeyId) {
+ std::vector<std::uint8_t> ciphertext(crypto::aeadCipherOutputLength(cdr.length()));
+
+ uassertStatusOK(crypto::aeadEncrypt(_key,
+ reinterpret_cast<const uint8_t*>(cdr.data()),
+ cdr.length(),
+ nullptr,
+ 0,
+ ciphertext.data(),
+ ciphertext.size()));
+
+ return ciphertext;
+}
+
+BSONObj LocalKMSService::encryptDataKey(ConstDataRange cdr, StringData keyId) {
+ auto dataKey = encrypt(cdr, keyId);
+
+ LocalMasterKey masterKey;
+
+ LocalMasterKeyAndMaterial keyAndMaterial;
+ keyAndMaterial.setKeyMaterial(dataKey);
+ keyAndMaterial.setMasterKey(masterKey);
+
+ return keyAndMaterial.toBSON();
+}
+
+SecureVector<uint8_t> LocalKMSService::decrypt(ConstDataRange cdr, BSONObj masterKey) {
+ SecureVector<uint8_t> plaintext(cdr.length());
+
+ size_t outLen = plaintext->size();
+ uassertStatusOK(crypto::aeadDecrypt(_key,
+ reinterpret_cast<const uint8_t*>(cdr.data()),
+ cdr.length(),
+ nullptr,
+ 0,
+ plaintext->data(),
+ &outLen));
+ plaintext->resize(outLen);
+
+ return plaintext;
+}
+
+std::unique_ptr<KMSService> LocalKMSService::create(const LocalKMS& config) {
+ uassert(51237,
+ str::stream() << "Local KMS key must be 64 bytes, found " << config.getKey().length()
+ << " bytes instead",
+ config.getKey().length() == crypto::kAeadAesHmacKeySize);
+
+ SecureVector<uint8_t> aesVector = SecureVector<uint8_t>(
+ config.getKey().data(), config.getKey().data() + config.getKey().length());
+ SymmetricKey key = SymmetricKey(aesVector, crypto::aesAlgorithm, "local");
+
+ auto localKMS = std::make_unique<LocalKMSService>(std::move(key));
+
+ return localKMS;
+}
+
+/**
+ * Factory for LocalKMSService if user specifies local config to mongo() JS constructor.
+ */
+class LocalKMSServiceFactory final : public KMSServiceFactory {
+public:
+ LocalKMSServiceFactory() = default;
+ ~LocalKMSServiceFactory() = default;
+
+ std::unique_ptr<KMSService> create(const BSONObj& config) final {
+ auto field = config[KmsProviders::kLocalFieldName];
+ if (field.eoo()) {
+ return nullptr;
+ }
+
+ auto obj = field.Obj();
+ return LocalKMSService::create(LocalKMS::parse(IDLParserErrorContext("root"), obj));
+ }
+};
+
+} // namspace
+
+MONGO_INITIALIZER(LocalKMSRegister)(::mongo::InitializerContext* context) {
+ KMSServiceController::registerFactory(KMSProviderEnum::local,
+ std::make_unique<LocalKMSServiceFactory>());
+ return Status::OK();
+}
+
+} // namespace mongo
diff --git a/src/mongo/shell/kms_shell.cpp b/src/mongo/shell/kms_shell.cpp
new file mode 100644
index 00000000000..f05fbdf3c8e
--- /dev/null
+++ b/src/mongo/shell/kms_shell.cpp
@@ -0,0 +1,52 @@
+/**
+ * Copyright (C) 2019-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/base/init.h"
+#include "mongo/scripting/engine.h"
+#include "mongo/shell/shell_utils.h"
+
+namespace mongo {
+
+namespace JSFiles {
+extern const JSFile keyvault;
+}
+
+namespace {
+
+void callback_fn(Scope& scope) {
+ scope.execSetup(JSFiles::keyvault);
+}
+
+MONGO_INITIALIZER(setKeyvaultCallback)(InitializerContext*) {
+ shell_utils::setEnterpriseShellCallback(mongo::callback_fn);
+ return Status::OK();
+}
+
+} // namespace
+} // namespace mongo
diff --git a/src/mongo/shell/kms_test.cpp b/src/mongo/shell/kms_test.cpp
new file mode 100644
index 00000000000..fd3284acf6e
--- /dev/null
+++ b/src/mongo/shell/kms_test.cpp
@@ -0,0 +1,86 @@
+/**
+ * Copyright (C) 2019-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/platform/basic.h"
+
+#include "kms.h"
+
+#include "mongo/base/data_range.h"
+#include "mongo/bson/bsonmisc.h"
+#include "mongo/bson/bsonobj.h"
+#include "mongo/unittest/unittest.h"
+
+namespace mongo {
+namespace {
+
+bool isEquals(ConstDataRange left, ConstDataRange right) {
+ return std::equal(
+ left.data(), left.data() + left.length(), right.data(), right.data() + right.length());
+}
+
+
+// Negative: incorrect key size
+TEST(KmsTest, TestBadKey) {
+ std::array<uint8_t, 3> key{0x1, 0x2, 0x3};
+ BSONObj config =
+ BSON("local" << BSON("key" << BSONBinData(key.data(), key.size(), BinDataGeneral)));
+
+ ASSERT_THROWS(KMSServiceController::createFromClient("local", config), AssertionException);
+}
+
+// Positive: Test Encrypt works
+TEST(KmsTest, TestGoodKey) {
+ std::array<uint8_t, 64> key = {0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a,
+ 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15,
+ 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20,
+ 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2a, 0x2b,
+ 0x2c, 0x2d, 0x2e, 0x2f, 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36,
+ 0x37, 0x38, 0x39, 0x3a, 0x3b, 0x3c, 0x3d, 0x3e, 0x3f};
+
+ BSONObj config =
+ BSON("local" << BSON("key" << BSONBinData(key.data(), key.size(), BinDataGeneral)));
+
+ auto service = KMSServiceController::createFromClient("local", config);
+
+ auto myKey = "My Secret Key"_sd;
+
+ auto material = service->encryptDataKey(ConstDataRange(myKey.rawData(), myKey.size()), "");
+
+ LocalMasterKeyAndMaterial glob =
+ LocalMasterKeyAndMaterial::parse(IDLParserErrorContext("root"), material);
+
+ auto keyMaterial = glob.getKeyMaterial();
+
+ auto plaintext = service->decrypt(keyMaterial, BSONObj());
+
+ ASSERT_TRUE(isEquals(myKey.toString(), *plaintext));
+}
+
+} // namespace
+} // namespace mongo