summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMark Benvenuto <mark.benvenuto@mongodb.com>2021-01-21 20:32:40 -0500
committerEvergreen Agent <no-reply@evergreen.mongodb.com>2021-03-15 16:10:09 +0000
commita71d3efdc76842dda106e73c71d28b003013abcf (patch)
tree6c4724a2f400f0167d3787fa0316a2eea592320c
parentb8e3b2a500c3467c4ded2211ffc13bc6fd004e75 (diff)
downloadmongo-a71d3efdc76842dda106e73c71d28b003013abcf.tar.gz
SERVER-52651 Add FLE Support for Azure
(cherry picked from commit 5724443a4cf42d369714a86ee76de0a41f02bfd8)
-rw-r--r--jstests/client_encrypt/fle_azure_faults.js175
-rwxr-xr-xjstests/client_encrypt/lib/kms_http_server_azure.py239
-rw-r--r--jstests/client_encrypt/lib/mock_kms.js18
-rw-r--r--src/mongo/shell/SConscript1
-rw-r--r--src/mongo/shell/kms.cpp4
-rw-r--r--src/mongo/shell/kms.idl62
-rw-r--r--src/mongo/shell/kms_azure.cpp301
-rw-r--r--src/mongo/shell/kms_network.cpp6
-rw-r--r--src/mongo/shell/kms_network.h4
9 files changed, 799 insertions, 11 deletions
diff --git a/jstests/client_encrypt/fle_azure_faults.js b/jstests/client_encrypt/fle_azure_faults.js
new file mode 100644
index 00000000000..dfcf121844a
--- /dev/null
+++ b/jstests/client_encrypt/fle_azure_faults.js
@@ -0,0 +1,175 @@
+/**
+ * Verify the Azure KMS implementation can handle a buggy KMS.
+ */
+
+load("jstests/client_encrypt/lib/mock_kms.js");
+load('jstests/ssl/libs/ssl_helpers.js');
+
+(function() {
+"use strict";
+
+const x509_options = {
+ sslMode: "requireSSL",
+ sslPEMKeyFile: SERVER_CERT,
+ sslCAFile: CA_CERT
+};
+
+const mockKey = {
+ keyName: "my_key",
+ keyVaultEndpoint: "https://localhost:80",
+};
+
+const randomAlgorithm = "AEAD_AES_256_CBC_HMAC_SHA_512-Random";
+
+const conn = MongoRunner.runMongod(x509_options);
+const test = conn.getDB("test");
+const collection = test.coll;
+
+function runKMS(mock_kms, func) {
+ mock_kms.start();
+
+ const azureKMS = {
+ tenantId: "my_tentant",
+ clientId: "access@mongodb.com",
+ clientSecret: "secret",
+ identityPlatformEndpoint: mock_kms.getURL(),
+ };
+
+ const clientSideFLEOptions = {
+ kmsProviders: {
+ azure: azureKMS,
+ },
+ keyVaultNamespace: "test.coll",
+ schemaMap: {},
+ };
+
+ const shell = Mongo(conn.host, clientSideFLEOptions);
+ const cleanCacheShell = Mongo(conn.host, clientSideFLEOptions);
+
+ collection.drop();
+
+ func(shell, cleanCacheShell);
+
+ mock_kms.stop();
+}
+
+// OAuth faults must be tested first so a cached token cannot be used
+function testBadOAuthRequestResult() {
+ const mock_kms = new MockKMSServerAzure(FAULT_OAUTH, false);
+
+ runKMS(mock_kms, (shell) => {
+ const keyVault = shell.getKeyVault();
+
+ const error = assert.throws(() => keyVault.createKey("azure", mockKey, ["mongoKey"]));
+ assert.eq(
+ error,
+ "Error: code 9: FailedToParse: Expecting '{': offset:0 of:Internal Error of some sort.");
+ });
+}
+
+testBadOAuthRequestResult();
+
+function testBadOAuthRequestError() {
+ const mock_kms = new MockKMSServerAzure(FAULT_OAUTH_CORRECT_FORMAT, false);
+
+ runKMS(mock_kms, (shell) => {
+ const keyVault = shell.getKeyVault();
+
+ const error = assert.throws(() => keyVault.createKey("azure", mockKey, ["mongoKey"]));
+ assert.commandFailedWithCode(error, [ErrorCodes.OperationFailed]);
+ assert.eq(
+ error,
+ "Error: Failed to make oauth request: Azure OAuth Error : FAULT_OAUTH_CORRECT_FORMAT");
+ });
+}
+
+testBadOAuthRequestError();
+
+function testBadEncryptResult() {
+ const mock_kms = new MockKMSServerAzure(FAULT_ENCRYPT, false);
+
+ runKMS(mock_kms, (shell) => {
+ const keyVault = shell.getKeyVault();
+ mockKey.keyVaultEndpoint = mock_kms.getEndpoint();
+
+ assert.throws(() => keyVault.createKey("azure", mockKey, ["mongoKey"]));
+ assert.eq(keyVault.getKeys("mongoKey").toArray().length, 0);
+ });
+}
+
+testBadEncryptResult();
+
+function testBadEncryptError() {
+ const mock_kms = new MockKMSServerAzure(FAULT_ENCRYPT_CORRECT_FORMAT, false);
+
+ runKMS(mock_kms, (shell) => {
+ const keyVault = shell.getKeyVault();
+ mockKey.keyVaultEndpoint = mock_kms.getEndpoint();
+
+ let error = assert.throws(() => keyVault.createKey("azure", mockKey, ["mongoKey"]));
+ assert.commandFailedWithCode(error, [5265103]);
+ });
+}
+
+testBadEncryptError();
+
+function testBadDecryptResult() {
+ const mock_kms = new MockKMSServerAzure(FAULT_DECRYPT, false);
+
+ runKMS(mock_kms, (shell) => {
+ const keyVault = shell.getKeyVault();
+ mockKey.keyVaultEndpoint = mock_kms.getEndpoint();
+
+ const keyId = keyVault.createKey("azure", mockKey, ["mongoKey"]);
+ const str = "mongo";
+ assert.throws(() => {
+ const encStr = shell.getClientEncryption().encrypt(keyId, str, randomAlgorithm);
+ });
+ });
+}
+
+testBadDecryptResult();
+
+function testBadDecryptKeyResult() {
+ const mock_kms = new MockKMSServerAzure(FAULT_DECRYPT_WRONG_KEY, true);
+
+ runKMS(mock_kms, (shell, cleanCacheShell) => {
+ const keyVault = shell.getKeyVault();
+ mockKey.keyVaultEndpoint = mock_kms.getEndpoint();
+
+ keyVault.createKey("azure", mockKey, ["mongoKey"]);
+ const keyId = keyVault.getKeys("mongoKey").toArray()[0]._id;
+ const str = "mongo";
+ const encStr = shell.getClientEncryption().encrypt(keyId, str, randomAlgorithm);
+
+ mock_kms.enableFaults();
+
+ assert.throws(() => {
+ let str = cleanCacheShell.decrypt(encStr);
+ });
+ });
+}
+
+testBadDecryptKeyResult();
+
+function testBadDecryptError() {
+ const mock_kms = new MockKMSServerAzure(FAULT_DECRYPT_CORRECT_FORMAT, false);
+
+ runKMS(mock_kms, (shell) => {
+ const keyVault = shell.getKeyVault();
+ mockKey.keyVaultEndpoint = mock_kms.getEndpoint();
+
+ keyVault.createKey("azure", mockKey, ["mongoKey"]);
+ const keyId = keyVault.getKeys("mongoKey").toArray()[0]._id;
+ const str = "mongo";
+ let error = assert.throws(() => {
+ const encStr = shell.getClientEncryption().encrypt(keyId, str, randomAlgorithm);
+ });
+ assert.commandFailedWithCode(error, [5265103]);
+ });
+}
+
+testBadDecryptError();
+
+MongoRunner.stopMongod(conn);
+})();
diff --git a/jstests/client_encrypt/lib/kms_http_server_azure.py b/jstests/client_encrypt/lib/kms_http_server_azure.py
new file mode 100755
index 00000000000..af2b7b4a016
--- /dev/null
+++ b/jstests/client_encrypt/lib/kms_http_server_azure.py
@@ -0,0 +1,239 @@
+#! /usr/bin/env python3
+"""Mock Azure KMS Endpoint."""
+import argparse
+import base64
+import http
+import json
+import logging
+import urllib.parse
+import sys
+
+import kms_http_common
+
+SUPPORTED_FAULT_TYPES = [
+ kms_http_common.FAULT_ENCRYPT,
+ kms_http_common.FAULT_ENCRYPT_CORRECT_FORMAT,
+ kms_http_common.FAULT_DECRYPT,
+ kms_http_common.FAULT_DECRYPT_CORRECT_FORMAT,
+ kms_http_common.FAULT_DECRYPT_WRONG_KEY,
+ kms_http_common.FAULT_OAUTH,
+ kms_http_common.FAULT_OAUTH_CORRECT_FORMAT,
+]
+
+SECRET_PREFIX = "00SECRET"
+FAKE_OAUTH_TOKEN = "omg_im_an_oauth_token"
+
+URL_PATH_OAUTH_AUDIENCE = "/token"
+URL_PATH_OAUTH_SCOPE = "/auth/cloudkms"
+URL_PATH_MOCK_KEY = "/keys/my_key/"
+
+
+class AzureKmsHandler(kms_http_common.KmsHandlerBase):
+ """
+ Handle requests from Azure KMS Monitoring and test commands
+ """
+
+ def do_POST(self):
+ """Serve a POST request."""
+ print("Received POST: " + self.path)
+ parts = urllib.parse.urlsplit(self.path)
+ path = parts[2]
+
+ if path == "/my_tentant/oauth2/v2.0/token":
+ self._do_oauth_request()
+ elif path.startswith(URL_PATH_MOCK_KEY):
+ self._do_operation()
+ else:
+ self.send_response(http.HTTPStatus.NOT_FOUND)
+ self.end_headers()
+ self.wfile.write("Unknown URL".encode())
+
+ def _do_operation(self):
+ clen = int(self.headers.get("content-length"))
+
+ raw_input = self.rfile.read(clen)
+
+ print(f"RAW INPUT: {str(raw_input)}")
+
+ if not self.headers["Host"] == "localhost":
+ data = "Unexpected host"
+ self._send_reply(data.encode("utf-8"))
+
+ if not self.headers["Authorization"] == f"Bearer {FAKE_OAUTH_TOKEN}":
+ data = "Unexpected bearer token"
+ self._send_reply(data.encode("utf-8"))
+
+ parts = urllib.parse.urlsplit(self.path)
+ operation = parts.path.split('/')[-1]
+
+ if operation == "wrapkey":
+ self._do_encrypt(raw_input)
+ elif operation == "unwrapkey":
+ self._do_decrypt(raw_input)
+ else:
+ self._send_reply(f"Unknown operation: {operation}".encode("utf-8"))
+
+ def _do_encrypt(self, raw_input):
+ request = json.loads(raw_input)
+
+ print(request)
+
+ plaintext = request["value"]
+
+ ciphertext = SECRET_PREFIX.encode() + plaintext.encode()
+ ciphertext = base64.urlsafe_b64encode(ciphertext).decode()
+
+ if kms_http_common.fault_type and kms_http_common.fault_type.startswith(kms_http_common.FAULT_ENCRYPT) \
+ and not kms_http_common.disable_faults:
+ return self._do_encrypt_faults(ciphertext)
+
+ response = {
+ "value": ciphertext,
+ "kid": "my_key",
+ }
+
+ self._send_reply(json.dumps(response).encode('utf-8'))
+
+ def _do_encrypt_faults(self, raw_ciphertext):
+ kms_http_common.stats.fault_calls += 1
+
+ if kms_http_common.fault_type == kms_http_common.FAULT_ENCRYPT:
+ self._send_reply("Internal Error of some sort.".encode(), http.HTTPStatus.INTERNAL_SERVER_ERROR)
+ return
+ elif kms_http_common.fault_type == kms_http_common.FAULT_ENCRYPT_CORRECT_FORMAT:
+ response = {
+ "error": {
+ "code": "bad",
+ "message": "Error encrypting message",
+ }
+ }
+ self._send_reply(json.dumps(response).encode('utf-8'))
+ return
+
+ raise ValueError("Unknown Fault Type: " + kms_http_common.fault_type)
+
+ def _do_decrypt(self, raw_input):
+ request = json.loads(raw_input)
+ blob = base64.urlsafe_b64decode(request["value"]).decode()
+
+ print("FOUND SECRET: " + blob)
+
+ # our "encrypted" values start with the word SECRET_PREFIX otherwise they did not come from us
+ if not blob.startswith(SECRET_PREFIX):
+ raise ValueError()
+
+ blob = blob[len(SECRET_PREFIX):]
+
+ if kms_http_common.fault_type and kms_http_common.fault_type.startswith(kms_http_common.FAULT_DECRYPT) \
+ and not kms_http_common.disable_faults:
+ return self._do_decrypt_faults(blob)
+
+ response = {
+ "kid": "my_key",
+ "value": blob,
+ }
+
+ self._send_reply(json.dumps(response).encode('utf-8'))
+
+ def _do_decrypt_faults(self, blob):
+ kms_http_common.stats.fault_calls += 1
+
+ if kms_http_common.fault_type == kms_http_common.FAULT_DECRYPT:
+ self._send_reply("Internal Error of some sort.".encode(), http.HTTPStatus.INTERNAL_SERVER_ERROR)
+ return
+ elif kms_http_common.fault_type == kms_http_common.FAULT_DECRYPT_WRONG_KEY:
+ response = {
+ "kid": "my_key",
+ "value": "ta7DXE7J0OiCRw03dYMJSeb8nVF5qxTmZ9zWmjuX4zW/SOorSCaY8VMTWG+cRInMx/rr/+QeVw2WjU2IpOSvMg==",
+ }
+ self._send_reply(json.dumps(response).encode('utf-8'))
+ return
+ elif kms_http_common.fault_type == kms_http_common.FAULT_DECRYPT_CORRECT_FORMAT:
+ response = {
+ "error": {
+ "code": "bad",
+ "message": "Error decrypting message",
+ }
+ }
+ self._send_reply(json.dumps(response).encode('utf-8'))
+ return
+
+ raise ValueError("Unknown Fault Type: " + kms_http_common.fault_type)
+
+ def _do_oauth_request(self):
+ clen = int(self.headers.get('content-length'))
+
+ raw_input = self.rfile.read(clen)
+
+ print(f"RAW INPUT: {str(raw_input)}")
+
+ if not self.headers["Host"] == "localhost":
+ data = "Unexpected host"
+ self._send_reply(data.encode("utf-8"))
+
+ if kms_http_common.fault_type and kms_http_common.fault_type.startswith(kms_http_common.FAULT_OAUTH) \
+ and not kms_http_common.disable_faults:
+ return self._do_oauth_faults()
+
+ response = {
+ "access_token": FAKE_OAUTH_TOKEN,
+ "scope": self.headers["Host"] + URL_PATH_OAUTH_SCOPE,
+ "token_type": "Bearer",
+ "expires_in": 3600,
+ }
+
+ self._send_reply(json.dumps(response).encode("utf-8"))
+
+ def _do_oauth_faults(self):
+ kms_http_common.stats.fault_calls += 1
+
+ if kms_http_common.fault_type == kms_http_common.FAULT_OAUTH:
+ self._send_reply("Internal Error of some sort.".encode(), http.HTTPStatus.INTERNAL_SERVER_ERROR)
+ return
+ elif kms_http_common.fault_type == kms_http_common.FAULT_OAUTH_CORRECT_FORMAT:
+ response = {
+ "error": "Azure OAuth Error",
+ "error_description": "FAULT_OAUTH_CORRECT_FORMAT",
+ "error_uri": "https://mongodb.com/whoopsies.pdf",
+ }
+ self._send_reply(json.dumps(response).encode('utf-8'))
+ return
+
+ raise ValueError("Unknown Fault Type: " + kms_http_common.fault_type)
+
+
+def main():
+ """Main Method."""
+ parser = argparse.ArgumentParser(description='MongoDB Mock Azure KMS Endpoint.')
+
+ parser.add_argument('-p', '--port', type=int, default=8000, help="Port to listen on")
+
+ parser.add_argument('-v', '--verbose', action='count', help="Enable verbose tracing")
+
+ parser.add_argument('--fault', type=str, help="Type of fault to inject")
+
+ parser.add_argument('--disable-faults', action='store_true', help="Disable faults on startup")
+
+ parser.add_argument('--ca_file', type=str, required=True, help="TLS CA PEM file")
+
+ parser.add_argument('--cert_file', type=str, required=True, help="TLS Server PEM file")
+
+ args = parser.parse_args()
+ if args.verbose:
+ logging.basicConfig(level=logging.DEBUG)
+
+ if args.fault:
+ if args.fault not in SUPPORTED_FAULT_TYPES:
+ print("Unsupported fault type %s, supports types are %s" % (args.fault, SUPPORTED_FAULT_TYPES))
+ sys.exit(1)
+
+ kms_http_common.fault_type = args.fault
+
+ if args.disable_faults:
+ kms_http_common.disable_faults = True
+
+ kms_http_common.run(args.port, args.cert_file, args.ca_file, AzureKmsHandler)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/jstests/client_encrypt/lib/mock_kms.js b/jstests/client_encrypt/lib/mock_kms.js
index b5d61986dd2..2dabdd96916 100644
--- a/jstests/client_encrypt/lib/mock_kms.js
+++ b/jstests/client_encrypt/lib/mock_kms.js
@@ -146,7 +146,7 @@ class MockKMSServerAWS {
}
/**
- * Get the URL.
+ * Get the URL. Prefixed with https://
*
* @return {string} url of http server
*/
@@ -155,6 +155,15 @@ class MockKMSServerAWS {
}
/**
+ * Get the endpoint. A "<host>:<port>".
+ *
+ * @return {string} url of http server
+ */
+ getEndpoint() {
+ return "localhost:" + this.port;
+ }
+
+ /**
* Stop the web server
*/
stop() {
@@ -168,3 +177,10 @@ class MockKMSServerGCP extends MockKMSServerAWS {
this.web_server_py = "jstests/client_encrypt/lib/kms_http_server_gcp.py";
}
}
+
+class MockKMSServerAzure extends MockKMSServerAWS {
+ constructor(fault_type, disableFaultsOnStartup) {
+ super(fault_type, disableFaultsOnStartup);
+ this.web_server_py = "jstests/client_encrypt/lib/kms_http_server_azure.py";
+ }
+}
diff --git a/src/mongo/shell/SConscript b/src/mongo/shell/SConscript
index 6bbfeccbf3f..675c09e7b07 100644
--- a/src/mongo/shell/SConscript
+++ b/src/mongo/shell/SConscript
@@ -141,6 +141,7 @@ kmsEnv.Library(
source=[
"kms.cpp",
"kms_aws.cpp",
+ "kms_azure.cpp",
"kms_gcp.cpp",
"kms_local.cpp",
"kms_network.cpp",
diff --git a/src/mongo/shell/kms.cpp b/src/mongo/shell/kms.cpp
index 2883e9bd0ee..4bea437a267 100644
--- a/src/mongo/shell/kms.cpp
+++ b/src/mongo/shell/kms.cpp
@@ -42,7 +42,9 @@ HostAndPort parseUrl(StringData url) {
// URL: https://(host):(port)
//
constexpr StringData urlPrefix = "https://"_sd;
- uassert(51140, "KMS URL must start with https://", url.startsWith(urlPrefix));
+ uassert(51140,
+ str::stream() << "KMS URL must start with https://, URL: " << url,
+ url.startsWith(urlPrefix));
StringData hostAndPort = url.substr(urlPrefix.size());
diff --git a/src/mongo/shell/kms.idl b/src/mongo/shell/kms.idl
index 47129a51baa..53ac55f2c7b 100644
--- a/src/mongo/shell/kms.idl
+++ b/src/mongo/shell/kms.idl
@@ -38,6 +38,7 @@ enums:
type: string
values:
aws: "aws"
+ azure: "azure"
gcp: "gcp"
local: "local"
@@ -64,6 +65,24 @@ structs:
type: string
optional: true
+ azureKMSError:
+ description: "Azure KMS Error"
+ strict: false
+ fields:
+ code: string
+ message: string
+
+ # Options passed to Mongo() javascript constructor
+ azureKMS:
+ description: "Azure KMS config"
+ fields:
+ tenantId: string
+ clientId: string
+ clientSecret: string
+ identityPlatformEndpoint:
+ type: string
+ optional: true
+
# Documented here: https://cloud.google.com/apis/design/errors#http_mapping
gcpKMSError:
description: "GCP KMS Error"
@@ -96,6 +115,9 @@ structs:
aws:
type: awsKMS
optional: true
+ azure:
+ type: azureKMS
+ optional: true
gcp:
type: gcpKMS
optional: true
@@ -160,6 +182,38 @@ structs:
masterKey:
type: awsMasterKey
+ azureEncryptResponse:
+ description: "Response from Azure KMS wrapKey request"
+ strict: false
+ fields:
+ kid: string
+ value: string
+
+ azureDecryptResponse:
+ description: "Response from Azure KMS unwrapKey request"
+ strict: false
+ fields:
+ kid: string
+ value: string
+
+ azureMasterKey:
+ description: "Azure KMS Key Store Description"
+ fields:
+ provider:
+ type: string
+ default: '"azure"'
+ keyName: string
+ keyVersion:
+ type: string
+ optional: true
+ keyVaultEndpoint: string
+
+ azureMasterKeyAndMaterial:
+ description: "Azure KMS Key Material Description"
+ fields:
+ keyMaterial: bindata_generic
+ masterKey: azureMasterKey
+
gcpEncryptResponse:
description: "Response from GCP KMS Encrypt request"
strict: false
@@ -241,10 +295,10 @@ structs:
access_token: string
token_type: string
# Expires_in is in seconds
- expires_in:
+ expires_in:
type: int
optional: true
- scope:
+ scope:
type: string
optional: true
@@ -254,10 +308,10 @@ structs:
strict: false
fields:
error: string
- error_description:
+ error_description:
type: string
optional: true
- error_uri:
+ error_uri:
type: string
optional: true
diff --git a/src/mongo/shell/kms_azure.cpp b/src/mongo/shell/kms_azure.cpp
new file mode 100644
index 00000000000..ae8c933f931
--- /dev/null
+++ b/src/mongo/shell/kms_azure.cpp
@@ -0,0 +1,301 @@
+/**
+ * Copyright (C) 2021-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "kms_message/kms_request.h"
+#include "mongo/shell/kms_gen.h"
+#define MONGO_LOGV2_DEFAULT_COMPONENT ::logv2::LogComponent::kControl
+
+#include "mongo/platform/basic.h"
+
+#include <fmt/format.h>
+#include <kms_message/kms_azure_request.h>
+#include <kms_message/kms_b64.h>
+#include <kms_message/kms_message.h>
+
+#include "mongo/bson/json.h"
+#include "mongo/shell/kms.h"
+#include "mongo/shell/kms_network.h"
+#include "mongo/util/net/hostandport.h"
+
+namespace mongo {
+namespace {
+
+using namespace fmt::literals;
+
+constexpr auto kAzureKms = "azure"_sd;
+
+// Default endpoints for Azure
+constexpr auto kDefaultIdentityPlatformEndpoint = "login.microsoftonline.com"_sd;
+// Since scope is passed as URL parameter, it needs to be escaped and kms_message does not escape
+// it.
+constexpr auto kDefaultOAuthScope = "https%3A%2F%2Fvault.azure.net%2F.default"_sd;
+
+struct AzureConfig {
+ // ID for the user in Azure
+ std::string tenantId;
+
+ // ID for the application in Azure
+ std::string clientId;
+
+ // Secret key for the application in Azure
+ std::string clientSecret;
+
+ // Options to pass to kms-message
+ UniqueKmsRequestOpts opts;
+};
+
+/**
+ * Manages OAuth token requests and caching
+ */
+class AzureKMSOAuthService final : public KMSOAuthService {
+public:
+ AzureKMSOAuthService(const AzureConfig& config,
+ HostAndPort endpoint,
+ std::shared_ptr<SSLManagerInterface> sslManager)
+ : KMSOAuthService(endpoint, sslManager), _config(config) {}
+
+protected:
+ UniqueKmsRequest getOAuthRequest() final {
+ auto request =
+ UniqueKmsRequest(kms_azure_request_oauth_new(_oAuthEndpoint.host().c_str(),
+ kDefaultOAuthScope.toString().c_str(),
+ _config.tenantId.c_str(),
+ _config.clientId.c_str(),
+ _config.clientSecret.c_str(),
+ _config.opts.get()));
+
+ const char* msg = kms_request_get_error(request.get());
+ uassert(5265101, "Internal Azure KMS Error: {}"_format(msg), msg == nullptr);
+
+ return request;
+ }
+
+private:
+ const AzureConfig& _config;
+};
+
+/**
+ * Manages SSL information and config for how to talk to Azure KMS.
+ */
+class AzureKMSService final : public KMSService {
+public:
+ AzureKMSService() = default;
+
+ static std::unique_ptr<KMSService> create(const AzureKMS& config);
+
+ SecureVector<uint8_t> decrypt(ConstDataRange cdr, BSONObj masterKey) final;
+
+ BSONObj encryptDataKeyByBSONObj(ConstDataRange cdr, BSONObj keyId) final;
+
+ StringData name() const final {
+ return kAzureKms;
+ }
+
+private:
+ template <typename AzureResponseT>
+ std::unique_ptr<uint8_t, decltype(std::free)*> makeRequest(kms_request_t* request,
+ const HostAndPort& keyVaultEndpoint,
+ size_t* raw_len);
+
+private:
+ // SSL Manager
+ std::shared_ptr<SSLManagerInterface> _sslManager;
+
+ // Azure configuration settings
+ AzureConfig _config;
+
+ // Service fr managing OAuth requests and token cache
+ std::unique_ptr<AzureKMSOAuthService> _oauthService;
+};
+
+std::unique_ptr<KMSService> AzureKMSService::create(const AzureKMS& config) {
+ auto azureKMS = std::make_unique<AzureKMSService>();
+
+ SSLParams params;
+ getSSLParamsForNetworkKMS(&params);
+
+ HostAndPort identityPlatformHostAndPort(kDefaultIdentityPlatformEndpoint.toString(), 443);
+ if (config.getIdentityPlatformEndpoint().has_value()) {
+ // Leave the CA file empty so we default to system CA but for local testing allow it to
+ // inherit the CA file.
+ params.sslCAFile = sslGlobalParams.sslCAFile;
+ identityPlatformHostAndPort = parseUrl(config.getIdentityPlatformEndpoint().get());
+ }
+
+ azureKMS->_sslManager = SSLManagerInterface::create(params, false);
+
+ azureKMS->_config.opts = UniqueKmsRequestOpts(kms_request_opt_new());
+ kms_request_opt_set_provider(azureKMS->_config.opts.get(), KMS_REQUEST_PROVIDER_AZURE);
+
+ azureKMS->_config.clientSecret = config.getClientSecret().toString();
+
+ azureKMS->_config.clientId = config.getClientId().toString();
+
+ azureKMS->_config.tenantId = config.getTenantId().toString();
+
+ azureKMS->_oauthService = std::make_unique<AzureKMSOAuthService>(
+ azureKMS->_config, identityPlatformHostAndPort, azureKMS->_sslManager);
+
+ return azureKMS;
+}
+
+HostAndPort parseEndpoint(StringData endpoint) {
+ HostAndPort host(endpoint);
+
+ if (host.hasPort()) {
+ return host;
+ }
+
+ return {host.host(), 443};
+}
+
+template <typename AzureResponseT>
+std::unique_ptr<uint8_t, decltype(std::free)*> AzureKMSService::makeRequest(
+ kms_request_t* request, const HostAndPort& keyVaultEndpoint, size_t* raw_len) {
+ auto buffer = UniqueKmsCharBuffer(kms_request_to_string(request));
+ auto buffer_len = strlen(buffer.get());
+ KMSNetworkConnection connection(_sslManager.get());
+ auto response =
+ connection.makeOneRequest(keyVaultEndpoint, ConstDataRange(buffer.get(), buffer_len));
+
+ auto body = kms_response_get_body(response.get(), nullptr);
+
+ BSONObj obj = fromjson(body);
+
+ if (obj.hasField("error")) {
+ AzureKMSError azureResponse;
+ try {
+ azureResponse =
+ AzureKMSError::parse(IDLParserErrorContext("azureError"), obj["error"].Obj());
+ } catch (DBException& dbe) {
+ uasserted(5265102,
+ "Azure KMS failed to parse error message: {}, Response : {}"_format(
+ dbe.toString(), obj.toString()));
+ }
+
+ uasserted(5265103,
+ "Azure KMS failed, response: {} : {}"_format(azureResponse.getCode(),
+ azureResponse.getMessage()));
+ }
+
+ auto azureResponse = AzureResponseT::parse(IDLParserErrorContext("azureResponse"), obj);
+
+ auto b64Url = azureResponse.getValue().toString();
+ std::unique_ptr<uint8_t, decltype(std::free)*> raw_str(
+ kms_message_b64url_to_raw(b64Url.c_str(), raw_len), std::free);
+ uassert(5265104, "Azure KMS failed to convert key blob from base64 URL.", raw_str != nullptr);
+
+ return raw_str;
+}
+
+
+SecureVector<uint8_t> AzureKMSService::decrypt(ConstDataRange cdr, BSONObj masterKey) {
+ auto azureMasterKey = AzureMasterKey::parse(IDLParserErrorContext("azureMasterKey"), masterKey);
+ StringData bearerToken = _oauthService->getBearerToken();
+
+ HostAndPort keyVaultEndpoint = parseEndpoint(azureMasterKey.getKeyVaultEndpoint());
+
+ auto request = UniqueKmsRequest(kms_azure_request_unwrapkey_new(
+ keyVaultEndpoint.host().c_str(),
+ bearerToken.toString().c_str(),
+ azureMasterKey.getKeyName().toString().c_str(),
+ azureMasterKey.getKeyVersion().value_or(""_sd).toString().c_str(),
+ reinterpret_cast<const uint8_t*>(cdr.data()),
+ cdr.length(),
+ _config.opts.get()));
+
+ size_t raw_len;
+ auto raw_str = makeRequest<AzureDecryptResponse>(request.get(), keyVaultEndpoint, &raw_len);
+
+ return kmsResponseToSecureVector(
+ StringData(reinterpret_cast<const char*>(raw_str.get()), raw_len));
+}
+
+BSONObj AzureKMSService::encryptDataKeyByBSONObj(ConstDataRange cdr, BSONObj keyId) {
+ StringData bearerToken = _oauthService->getBearerToken();
+ AzureMasterKey masterKey =
+ AzureMasterKey::parse(IDLParserErrorContext("azureMasterKey"), keyId);
+
+ HostAndPort keyVaultEndpoint = parseEndpoint(masterKey.getKeyVaultEndpoint());
+
+ auto request = UniqueKmsRequest(
+ kms_azure_request_wrapkey_new(keyVaultEndpoint.host().c_str(),
+ bearerToken.toString().c_str(),
+ masterKey.getKeyName().toString().c_str(),
+ masterKey.getKeyVersion().value_or(""_sd).toString().c_str(),
+ reinterpret_cast<const uint8_t*>(cdr.data()),
+ cdr.length(),
+ _config.opts.get()));
+
+ size_t raw_len;
+ auto raw_str = makeRequest<AzureDecryptResponse>(request.get(), keyVaultEndpoint, &raw_len);
+
+ auto dataKey =
+ kmsResponseToVector(StringData(reinterpret_cast<const char*>(raw_str.get()), raw_len));
+
+ AzureMasterKeyAndMaterial keyAndMaterial;
+ keyAndMaterial.setKeyMaterial(std::move(dataKey));
+ keyAndMaterial.setMasterKey(std::move(masterKey));
+
+ return keyAndMaterial.toBSON();
+}
+
+/**
+ * Factory for AzureKMSService if user specifies azure config to mongo() JS constructor.
+ */
+class AzureKMSServiceFactory final : public KMSServiceFactory {
+public:
+ AzureKMSServiceFactory() = default;
+ ~AzureKMSServiceFactory() = default;
+
+ std::unique_ptr<KMSService> create(const BSONObj& config) final {
+ auto field = config[KmsProviders::kAzureFieldName];
+ if (field.eoo()) {
+ return nullptr;
+ }
+
+ uassert(5265106,
+ "Misconfigured Azure KMS Config: {}"_format(field.toString()),
+ field.type() == BSONType::Object);
+
+ auto obj = field.Obj();
+ return AzureKMSService::create(AzureKMS::parse(IDLParserErrorContext("root"), obj));
+ }
+};
+
+} // namespace
+
+MONGO_INITIALIZER(KMSRegisterAzure)(::mongo::InitializerContext*) {
+ kms_message_init();
+ KMSServiceController::registerFactory(KMSProviderEnum::azure,
+ std::make_unique<AzureKMSServiceFactory>());
+ return Status::OK();
+}
+
+} // namespace mongo
diff --git a/src/mongo/shell/kms_network.cpp b/src/mongo/shell/kms_network.cpp
index 7f0b05d814c..9efa4cb2929 100644
--- a/src/mongo/shell/kms_network.cpp
+++ b/src/mongo/shell/kms_network.cpp
@@ -117,7 +117,7 @@ void getSSLParamsForNetworkKMS(SSLParams* params) {
std::vector({SSLParams::Protocols::TLS1_0, SSLParams::Protocols::TLS1_1});
}
-std::vector<uint8_t> kmsResponseToVector(const std::string& str) {
+std::vector<uint8_t> kmsResponseToVector(StringData str) {
std::vector<uint8_t> blob;
std::transform(std::begin(str), std::end(str), std::back_inserter(blob), [](auto c) {
@@ -127,8 +127,8 @@ std::vector<uint8_t> kmsResponseToVector(const std::string& str) {
return blob;
}
-SecureVector<uint8_t> kmsResponseToSecureVector(const std::string& str) {
- SecureVector<uint8_t> blob(str.length());
+SecureVector<uint8_t> kmsResponseToSecureVector(StringData str) {
+ SecureVector<uint8_t> blob(str.size());
std::transform(std::begin(str), std::end(str), blob->data(), [](auto c) {
return static_cast<uint8_t>(c);
diff --git a/src/mongo/shell/kms_network.h b/src/mongo/shell/kms_network.h
index dc512e51e65..724412f869b 100644
--- a/src/mongo/shell/kms_network.h
+++ b/src/mongo/shell/kms_network.h
@@ -73,12 +73,12 @@ void getSSLParamsForNetworkKMS(SSLParams*);
/**
* Converts a base64 encoded KMS response to a vector of bytes.
*/
-std::vector<uint8_t> kmsResponseToVector(const std::string& str);
+std::vector<uint8_t> kmsResponseToVector(StringData str);
/**
* Converts a base64 encoded KMS response to a securely allocated vector of bytes.
*/
-SecureVector<uint8_t> kmsResponseToSecureVector(const std::string& str);
+SecureVector<uint8_t> kmsResponseToSecureVector(StringData str);
/**
* Base class for KMS services that use OAuth for authorization.