diff options
author | Mark Benvenuto <mark.benvenuto@mongodb.com> | 2021-01-21 20:32:40 -0500 |
---|---|---|
committer | Evergreen Agent <no-reply@evergreen.mongodb.com> | 2021-03-15 16:10:09 +0000 |
commit | a71d3efdc76842dda106e73c71d28b003013abcf (patch) | |
tree | 6c4724a2f400f0167d3787fa0316a2eea592320c | |
parent | b8e3b2a500c3467c4ded2211ffc13bc6fd004e75 (diff) | |
download | mongo-a71d3efdc76842dda106e73c71d28b003013abcf.tar.gz |
SERVER-52651 Add FLE Support for Azure
(cherry picked from commit 5724443a4cf42d369714a86ee76de0a41f02bfd8)
-rw-r--r-- | jstests/client_encrypt/fle_azure_faults.js | 175 | ||||
-rwxr-xr-x | jstests/client_encrypt/lib/kms_http_server_azure.py | 239 | ||||
-rw-r--r-- | jstests/client_encrypt/lib/mock_kms.js | 18 | ||||
-rw-r--r-- | src/mongo/shell/SConscript | 1 | ||||
-rw-r--r-- | src/mongo/shell/kms.cpp | 4 | ||||
-rw-r--r-- | src/mongo/shell/kms.idl | 62 | ||||
-rw-r--r-- | src/mongo/shell/kms_azure.cpp | 301 | ||||
-rw-r--r-- | src/mongo/shell/kms_network.cpp | 6 | ||||
-rw-r--r-- | src/mongo/shell/kms_network.h | 4 |
9 files changed, 799 insertions, 11 deletions
diff --git a/jstests/client_encrypt/fle_azure_faults.js b/jstests/client_encrypt/fle_azure_faults.js new file mode 100644 index 00000000000..dfcf121844a --- /dev/null +++ b/jstests/client_encrypt/fle_azure_faults.js @@ -0,0 +1,175 @@ +/** + * Verify the Azure KMS implementation can handle a buggy KMS. + */ + +load("jstests/client_encrypt/lib/mock_kms.js"); +load('jstests/ssl/libs/ssl_helpers.js'); + +(function() { +"use strict"; + +const x509_options = { + sslMode: "requireSSL", + sslPEMKeyFile: SERVER_CERT, + sslCAFile: CA_CERT +}; + +const mockKey = { + keyName: "my_key", + keyVaultEndpoint: "https://localhost:80", +}; + +const randomAlgorithm = "AEAD_AES_256_CBC_HMAC_SHA_512-Random"; + +const conn = MongoRunner.runMongod(x509_options); +const test = conn.getDB("test"); +const collection = test.coll; + +function runKMS(mock_kms, func) { + mock_kms.start(); + + const azureKMS = { + tenantId: "my_tentant", + clientId: "access@mongodb.com", + clientSecret: "secret", + identityPlatformEndpoint: mock_kms.getURL(), + }; + + const clientSideFLEOptions = { + kmsProviders: { + azure: azureKMS, + }, + keyVaultNamespace: "test.coll", + schemaMap: {}, + }; + + const shell = Mongo(conn.host, clientSideFLEOptions); + const cleanCacheShell = Mongo(conn.host, clientSideFLEOptions); + + collection.drop(); + + func(shell, cleanCacheShell); + + mock_kms.stop(); +} + +// OAuth faults must be tested first so a cached token cannot be used +function testBadOAuthRequestResult() { + const mock_kms = new MockKMSServerAzure(FAULT_OAUTH, false); + + runKMS(mock_kms, (shell) => { + const keyVault = shell.getKeyVault(); + + const error = assert.throws(() => keyVault.createKey("azure", mockKey, ["mongoKey"])); + assert.eq( + error, + "Error: code 9: FailedToParse: Expecting '{': offset:0 of:Internal Error of some sort."); + }); +} + +testBadOAuthRequestResult(); + +function testBadOAuthRequestError() { + const mock_kms = new MockKMSServerAzure(FAULT_OAUTH_CORRECT_FORMAT, false); + + runKMS(mock_kms, (shell) => { + const keyVault = shell.getKeyVault(); + + const error = assert.throws(() => keyVault.createKey("azure", mockKey, ["mongoKey"])); + assert.commandFailedWithCode(error, [ErrorCodes.OperationFailed]); + assert.eq( + error, + "Error: Failed to make oauth request: Azure OAuth Error : FAULT_OAUTH_CORRECT_FORMAT"); + }); +} + +testBadOAuthRequestError(); + +function testBadEncryptResult() { + const mock_kms = new MockKMSServerAzure(FAULT_ENCRYPT, false); + + runKMS(mock_kms, (shell) => { + const keyVault = shell.getKeyVault(); + mockKey.keyVaultEndpoint = mock_kms.getEndpoint(); + + assert.throws(() => keyVault.createKey("azure", mockKey, ["mongoKey"])); + assert.eq(keyVault.getKeys("mongoKey").toArray().length, 0); + }); +} + +testBadEncryptResult(); + +function testBadEncryptError() { + const mock_kms = new MockKMSServerAzure(FAULT_ENCRYPT_CORRECT_FORMAT, false); + + runKMS(mock_kms, (shell) => { + const keyVault = shell.getKeyVault(); + mockKey.keyVaultEndpoint = mock_kms.getEndpoint(); + + let error = assert.throws(() => keyVault.createKey("azure", mockKey, ["mongoKey"])); + assert.commandFailedWithCode(error, [5265103]); + }); +} + +testBadEncryptError(); + +function testBadDecryptResult() { + const mock_kms = new MockKMSServerAzure(FAULT_DECRYPT, false); + + runKMS(mock_kms, (shell) => { + const keyVault = shell.getKeyVault(); + mockKey.keyVaultEndpoint = mock_kms.getEndpoint(); + + const keyId = keyVault.createKey("azure", mockKey, ["mongoKey"]); + const str = "mongo"; + assert.throws(() => { + const encStr = shell.getClientEncryption().encrypt(keyId, str, randomAlgorithm); + }); + }); +} + +testBadDecryptResult(); + +function testBadDecryptKeyResult() { + const mock_kms = new MockKMSServerAzure(FAULT_DECRYPT_WRONG_KEY, true); + + runKMS(mock_kms, (shell, cleanCacheShell) => { + const keyVault = shell.getKeyVault(); + mockKey.keyVaultEndpoint = mock_kms.getEndpoint(); + + keyVault.createKey("azure", mockKey, ["mongoKey"]); + const keyId = keyVault.getKeys("mongoKey").toArray()[0]._id; + const str = "mongo"; + const encStr = shell.getClientEncryption().encrypt(keyId, str, randomAlgorithm); + + mock_kms.enableFaults(); + + assert.throws(() => { + let str = cleanCacheShell.decrypt(encStr); + }); + }); +} + +testBadDecryptKeyResult(); + +function testBadDecryptError() { + const mock_kms = new MockKMSServerAzure(FAULT_DECRYPT_CORRECT_FORMAT, false); + + runKMS(mock_kms, (shell) => { + const keyVault = shell.getKeyVault(); + mockKey.keyVaultEndpoint = mock_kms.getEndpoint(); + + keyVault.createKey("azure", mockKey, ["mongoKey"]); + const keyId = keyVault.getKeys("mongoKey").toArray()[0]._id; + const str = "mongo"; + let error = assert.throws(() => { + const encStr = shell.getClientEncryption().encrypt(keyId, str, randomAlgorithm); + }); + assert.commandFailedWithCode(error, [5265103]); + }); +} + +testBadDecryptError(); + +MongoRunner.stopMongod(conn); +})(); diff --git a/jstests/client_encrypt/lib/kms_http_server_azure.py b/jstests/client_encrypt/lib/kms_http_server_azure.py new file mode 100755 index 00000000000..af2b7b4a016 --- /dev/null +++ b/jstests/client_encrypt/lib/kms_http_server_azure.py @@ -0,0 +1,239 @@ +#! /usr/bin/env python3 +"""Mock Azure KMS Endpoint.""" +import argparse +import base64 +import http +import json +import logging +import urllib.parse +import sys + +import kms_http_common + +SUPPORTED_FAULT_TYPES = [ + kms_http_common.FAULT_ENCRYPT, + kms_http_common.FAULT_ENCRYPT_CORRECT_FORMAT, + kms_http_common.FAULT_DECRYPT, + kms_http_common.FAULT_DECRYPT_CORRECT_FORMAT, + kms_http_common.FAULT_DECRYPT_WRONG_KEY, + kms_http_common.FAULT_OAUTH, + kms_http_common.FAULT_OAUTH_CORRECT_FORMAT, +] + +SECRET_PREFIX = "00SECRET" +FAKE_OAUTH_TOKEN = "omg_im_an_oauth_token" + +URL_PATH_OAUTH_AUDIENCE = "/token" +URL_PATH_OAUTH_SCOPE = "/auth/cloudkms" +URL_PATH_MOCK_KEY = "/keys/my_key/" + + +class AzureKmsHandler(kms_http_common.KmsHandlerBase): + """ + Handle requests from Azure KMS Monitoring and test commands + """ + + def do_POST(self): + """Serve a POST request.""" + print("Received POST: " + self.path) + parts = urllib.parse.urlsplit(self.path) + path = parts[2] + + if path == "/my_tentant/oauth2/v2.0/token": + self._do_oauth_request() + elif path.startswith(URL_PATH_MOCK_KEY): + self._do_operation() + else: + self.send_response(http.HTTPStatus.NOT_FOUND) + self.end_headers() + self.wfile.write("Unknown URL".encode()) + + def _do_operation(self): + clen = int(self.headers.get("content-length")) + + raw_input = self.rfile.read(clen) + + print(f"RAW INPUT: {str(raw_input)}") + + if not self.headers["Host"] == "localhost": + data = "Unexpected host" + self._send_reply(data.encode("utf-8")) + + if not self.headers["Authorization"] == f"Bearer {FAKE_OAUTH_TOKEN}": + data = "Unexpected bearer token" + self._send_reply(data.encode("utf-8")) + + parts = urllib.parse.urlsplit(self.path) + operation = parts.path.split('/')[-1] + + if operation == "wrapkey": + self._do_encrypt(raw_input) + elif operation == "unwrapkey": + self._do_decrypt(raw_input) + else: + self._send_reply(f"Unknown operation: {operation}".encode("utf-8")) + + def _do_encrypt(self, raw_input): + request = json.loads(raw_input) + + print(request) + + plaintext = request["value"] + + ciphertext = SECRET_PREFIX.encode() + plaintext.encode() + ciphertext = base64.urlsafe_b64encode(ciphertext).decode() + + if kms_http_common.fault_type and kms_http_common.fault_type.startswith(kms_http_common.FAULT_ENCRYPT) \ + and not kms_http_common.disable_faults: + return self._do_encrypt_faults(ciphertext) + + response = { + "value": ciphertext, + "kid": "my_key", + } + + self._send_reply(json.dumps(response).encode('utf-8')) + + def _do_encrypt_faults(self, raw_ciphertext): + kms_http_common.stats.fault_calls += 1 + + if kms_http_common.fault_type == kms_http_common.FAULT_ENCRYPT: + self._send_reply("Internal Error of some sort.".encode(), http.HTTPStatus.INTERNAL_SERVER_ERROR) + return + elif kms_http_common.fault_type == kms_http_common.FAULT_ENCRYPT_CORRECT_FORMAT: + response = { + "error": { + "code": "bad", + "message": "Error encrypting message", + } + } + self._send_reply(json.dumps(response).encode('utf-8')) + return + + raise ValueError("Unknown Fault Type: " + kms_http_common.fault_type) + + def _do_decrypt(self, raw_input): + request = json.loads(raw_input) + blob = base64.urlsafe_b64decode(request["value"]).decode() + + print("FOUND SECRET: " + blob) + + # our "encrypted" values start with the word SECRET_PREFIX otherwise they did not come from us + if not blob.startswith(SECRET_PREFIX): + raise ValueError() + + blob = blob[len(SECRET_PREFIX):] + + if kms_http_common.fault_type and kms_http_common.fault_type.startswith(kms_http_common.FAULT_DECRYPT) \ + and not kms_http_common.disable_faults: + return self._do_decrypt_faults(blob) + + response = { + "kid": "my_key", + "value": blob, + } + + self._send_reply(json.dumps(response).encode('utf-8')) + + def _do_decrypt_faults(self, blob): + kms_http_common.stats.fault_calls += 1 + + if kms_http_common.fault_type == kms_http_common.FAULT_DECRYPT: + self._send_reply("Internal Error of some sort.".encode(), http.HTTPStatus.INTERNAL_SERVER_ERROR) + return + elif kms_http_common.fault_type == kms_http_common.FAULT_DECRYPT_WRONG_KEY: + response = { + "kid": "my_key", + "value": "ta7DXE7J0OiCRw03dYMJSeb8nVF5qxTmZ9zWmjuX4zW/SOorSCaY8VMTWG+cRInMx/rr/+QeVw2WjU2IpOSvMg==", + } + self._send_reply(json.dumps(response).encode('utf-8')) + return + elif kms_http_common.fault_type == kms_http_common.FAULT_DECRYPT_CORRECT_FORMAT: + response = { + "error": { + "code": "bad", + "message": "Error decrypting message", + } + } + self._send_reply(json.dumps(response).encode('utf-8')) + return + + raise ValueError("Unknown Fault Type: " + kms_http_common.fault_type) + + def _do_oauth_request(self): + clen = int(self.headers.get('content-length')) + + raw_input = self.rfile.read(clen) + + print(f"RAW INPUT: {str(raw_input)}") + + if not self.headers["Host"] == "localhost": + data = "Unexpected host" + self._send_reply(data.encode("utf-8")) + + if kms_http_common.fault_type and kms_http_common.fault_type.startswith(kms_http_common.FAULT_OAUTH) \ + and not kms_http_common.disable_faults: + return self._do_oauth_faults() + + response = { + "access_token": FAKE_OAUTH_TOKEN, + "scope": self.headers["Host"] + URL_PATH_OAUTH_SCOPE, + "token_type": "Bearer", + "expires_in": 3600, + } + + self._send_reply(json.dumps(response).encode("utf-8")) + + def _do_oauth_faults(self): + kms_http_common.stats.fault_calls += 1 + + if kms_http_common.fault_type == kms_http_common.FAULT_OAUTH: + self._send_reply("Internal Error of some sort.".encode(), http.HTTPStatus.INTERNAL_SERVER_ERROR) + return + elif kms_http_common.fault_type == kms_http_common.FAULT_OAUTH_CORRECT_FORMAT: + response = { + "error": "Azure OAuth Error", + "error_description": "FAULT_OAUTH_CORRECT_FORMAT", + "error_uri": "https://mongodb.com/whoopsies.pdf", + } + self._send_reply(json.dumps(response).encode('utf-8')) + return + + raise ValueError("Unknown Fault Type: " + kms_http_common.fault_type) + + +def main(): + """Main Method.""" + parser = argparse.ArgumentParser(description='MongoDB Mock Azure KMS Endpoint.') + + parser.add_argument('-p', '--port', type=int, default=8000, help="Port to listen on") + + parser.add_argument('-v', '--verbose', action='count', help="Enable verbose tracing") + + parser.add_argument('--fault', type=str, help="Type of fault to inject") + + parser.add_argument('--disable-faults', action='store_true', help="Disable faults on startup") + + parser.add_argument('--ca_file', type=str, required=True, help="TLS CA PEM file") + + parser.add_argument('--cert_file', type=str, required=True, help="TLS Server PEM file") + + args = parser.parse_args() + if args.verbose: + logging.basicConfig(level=logging.DEBUG) + + if args.fault: + if args.fault not in SUPPORTED_FAULT_TYPES: + print("Unsupported fault type %s, supports types are %s" % (args.fault, SUPPORTED_FAULT_TYPES)) + sys.exit(1) + + kms_http_common.fault_type = args.fault + + if args.disable_faults: + kms_http_common.disable_faults = True + + kms_http_common.run(args.port, args.cert_file, args.ca_file, AzureKmsHandler) + + +if __name__ == '__main__': + main() diff --git a/jstests/client_encrypt/lib/mock_kms.js b/jstests/client_encrypt/lib/mock_kms.js index b5d61986dd2..2dabdd96916 100644 --- a/jstests/client_encrypt/lib/mock_kms.js +++ b/jstests/client_encrypt/lib/mock_kms.js @@ -146,7 +146,7 @@ class MockKMSServerAWS { } /** - * Get the URL. + * Get the URL. Prefixed with https:// * * @return {string} url of http server */ @@ -155,6 +155,15 @@ class MockKMSServerAWS { } /** + * Get the endpoint. A "<host>:<port>". + * + * @return {string} url of http server + */ + getEndpoint() { + return "localhost:" + this.port; + } + + /** * Stop the web server */ stop() { @@ -168,3 +177,10 @@ class MockKMSServerGCP extends MockKMSServerAWS { this.web_server_py = "jstests/client_encrypt/lib/kms_http_server_gcp.py"; } } + +class MockKMSServerAzure extends MockKMSServerAWS { + constructor(fault_type, disableFaultsOnStartup) { + super(fault_type, disableFaultsOnStartup); + this.web_server_py = "jstests/client_encrypt/lib/kms_http_server_azure.py"; + } +} diff --git a/src/mongo/shell/SConscript b/src/mongo/shell/SConscript index 6bbfeccbf3f..675c09e7b07 100644 --- a/src/mongo/shell/SConscript +++ b/src/mongo/shell/SConscript @@ -141,6 +141,7 @@ kmsEnv.Library( source=[ "kms.cpp", "kms_aws.cpp", + "kms_azure.cpp", "kms_gcp.cpp", "kms_local.cpp", "kms_network.cpp", diff --git a/src/mongo/shell/kms.cpp b/src/mongo/shell/kms.cpp index 2883e9bd0ee..4bea437a267 100644 --- a/src/mongo/shell/kms.cpp +++ b/src/mongo/shell/kms.cpp @@ -42,7 +42,9 @@ HostAndPort parseUrl(StringData url) { // URL: https://(host):(port) // constexpr StringData urlPrefix = "https://"_sd; - uassert(51140, "KMS URL must start with https://", url.startsWith(urlPrefix)); + uassert(51140, + str::stream() << "KMS URL must start with https://, URL: " << url, + url.startsWith(urlPrefix)); StringData hostAndPort = url.substr(urlPrefix.size()); diff --git a/src/mongo/shell/kms.idl b/src/mongo/shell/kms.idl index 47129a51baa..53ac55f2c7b 100644 --- a/src/mongo/shell/kms.idl +++ b/src/mongo/shell/kms.idl @@ -38,6 +38,7 @@ enums: type: string values: aws: "aws" + azure: "azure" gcp: "gcp" local: "local" @@ -64,6 +65,24 @@ structs: type: string optional: true + azureKMSError: + description: "Azure KMS Error" + strict: false + fields: + code: string + message: string + + # Options passed to Mongo() javascript constructor + azureKMS: + description: "Azure KMS config" + fields: + tenantId: string + clientId: string + clientSecret: string + identityPlatformEndpoint: + type: string + optional: true + # Documented here: https://cloud.google.com/apis/design/errors#http_mapping gcpKMSError: description: "GCP KMS Error" @@ -96,6 +115,9 @@ structs: aws: type: awsKMS optional: true + azure: + type: azureKMS + optional: true gcp: type: gcpKMS optional: true @@ -160,6 +182,38 @@ structs: masterKey: type: awsMasterKey + azureEncryptResponse: + description: "Response from Azure KMS wrapKey request" + strict: false + fields: + kid: string + value: string + + azureDecryptResponse: + description: "Response from Azure KMS unwrapKey request" + strict: false + fields: + kid: string + value: string + + azureMasterKey: + description: "Azure KMS Key Store Description" + fields: + provider: + type: string + default: '"azure"' + keyName: string + keyVersion: + type: string + optional: true + keyVaultEndpoint: string + + azureMasterKeyAndMaterial: + description: "Azure KMS Key Material Description" + fields: + keyMaterial: bindata_generic + masterKey: azureMasterKey + gcpEncryptResponse: description: "Response from GCP KMS Encrypt request" strict: false @@ -241,10 +295,10 @@ structs: access_token: string token_type: string # Expires_in is in seconds - expires_in: + expires_in: type: int optional: true - scope: + scope: type: string optional: true @@ -254,10 +308,10 @@ structs: strict: false fields: error: string - error_description: + error_description: type: string optional: true - error_uri: + error_uri: type: string optional: true diff --git a/src/mongo/shell/kms_azure.cpp b/src/mongo/shell/kms_azure.cpp new file mode 100644 index 00000000000..ae8c933f931 --- /dev/null +++ b/src/mongo/shell/kms_azure.cpp @@ -0,0 +1,301 @@ +/** + * Copyright (C) 2021-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include "kms_message/kms_request.h" +#include "mongo/shell/kms_gen.h" +#define MONGO_LOGV2_DEFAULT_COMPONENT ::logv2::LogComponent::kControl + +#include "mongo/platform/basic.h" + +#include <fmt/format.h> +#include <kms_message/kms_azure_request.h> +#include <kms_message/kms_b64.h> +#include <kms_message/kms_message.h> + +#include "mongo/bson/json.h" +#include "mongo/shell/kms.h" +#include "mongo/shell/kms_network.h" +#include "mongo/util/net/hostandport.h" + +namespace mongo { +namespace { + +using namespace fmt::literals; + +constexpr auto kAzureKms = "azure"_sd; + +// Default endpoints for Azure +constexpr auto kDefaultIdentityPlatformEndpoint = "login.microsoftonline.com"_sd; +// Since scope is passed as URL parameter, it needs to be escaped and kms_message does not escape +// it. +constexpr auto kDefaultOAuthScope = "https%3A%2F%2Fvault.azure.net%2F.default"_sd; + +struct AzureConfig { + // ID for the user in Azure + std::string tenantId; + + // ID for the application in Azure + std::string clientId; + + // Secret key for the application in Azure + std::string clientSecret; + + // Options to pass to kms-message + UniqueKmsRequestOpts opts; +}; + +/** + * Manages OAuth token requests and caching + */ +class AzureKMSOAuthService final : public KMSOAuthService { +public: + AzureKMSOAuthService(const AzureConfig& config, + HostAndPort endpoint, + std::shared_ptr<SSLManagerInterface> sslManager) + : KMSOAuthService(endpoint, sslManager), _config(config) {} + +protected: + UniqueKmsRequest getOAuthRequest() final { + auto request = + UniqueKmsRequest(kms_azure_request_oauth_new(_oAuthEndpoint.host().c_str(), + kDefaultOAuthScope.toString().c_str(), + _config.tenantId.c_str(), + _config.clientId.c_str(), + _config.clientSecret.c_str(), + _config.opts.get())); + + const char* msg = kms_request_get_error(request.get()); + uassert(5265101, "Internal Azure KMS Error: {}"_format(msg), msg == nullptr); + + return request; + } + +private: + const AzureConfig& _config; +}; + +/** + * Manages SSL information and config for how to talk to Azure KMS. + */ +class AzureKMSService final : public KMSService { +public: + AzureKMSService() = default; + + static std::unique_ptr<KMSService> create(const AzureKMS& config); + + SecureVector<uint8_t> decrypt(ConstDataRange cdr, BSONObj masterKey) final; + + BSONObj encryptDataKeyByBSONObj(ConstDataRange cdr, BSONObj keyId) final; + + StringData name() const final { + return kAzureKms; + } + +private: + template <typename AzureResponseT> + std::unique_ptr<uint8_t, decltype(std::free)*> makeRequest(kms_request_t* request, + const HostAndPort& keyVaultEndpoint, + size_t* raw_len); + +private: + // SSL Manager + std::shared_ptr<SSLManagerInterface> _sslManager; + + // Azure configuration settings + AzureConfig _config; + + // Service fr managing OAuth requests and token cache + std::unique_ptr<AzureKMSOAuthService> _oauthService; +}; + +std::unique_ptr<KMSService> AzureKMSService::create(const AzureKMS& config) { + auto azureKMS = std::make_unique<AzureKMSService>(); + + SSLParams params; + getSSLParamsForNetworkKMS(¶ms); + + HostAndPort identityPlatformHostAndPort(kDefaultIdentityPlatformEndpoint.toString(), 443); + if (config.getIdentityPlatformEndpoint().has_value()) { + // Leave the CA file empty so we default to system CA but for local testing allow it to + // inherit the CA file. + params.sslCAFile = sslGlobalParams.sslCAFile; + identityPlatformHostAndPort = parseUrl(config.getIdentityPlatformEndpoint().get()); + } + + azureKMS->_sslManager = SSLManagerInterface::create(params, false); + + azureKMS->_config.opts = UniqueKmsRequestOpts(kms_request_opt_new()); + kms_request_opt_set_provider(azureKMS->_config.opts.get(), KMS_REQUEST_PROVIDER_AZURE); + + azureKMS->_config.clientSecret = config.getClientSecret().toString(); + + azureKMS->_config.clientId = config.getClientId().toString(); + + azureKMS->_config.tenantId = config.getTenantId().toString(); + + azureKMS->_oauthService = std::make_unique<AzureKMSOAuthService>( + azureKMS->_config, identityPlatformHostAndPort, azureKMS->_sslManager); + + return azureKMS; +} + +HostAndPort parseEndpoint(StringData endpoint) { + HostAndPort host(endpoint); + + if (host.hasPort()) { + return host; + } + + return {host.host(), 443}; +} + +template <typename AzureResponseT> +std::unique_ptr<uint8_t, decltype(std::free)*> AzureKMSService::makeRequest( + kms_request_t* request, const HostAndPort& keyVaultEndpoint, size_t* raw_len) { + auto buffer = UniqueKmsCharBuffer(kms_request_to_string(request)); + auto buffer_len = strlen(buffer.get()); + KMSNetworkConnection connection(_sslManager.get()); + auto response = + connection.makeOneRequest(keyVaultEndpoint, ConstDataRange(buffer.get(), buffer_len)); + + auto body = kms_response_get_body(response.get(), nullptr); + + BSONObj obj = fromjson(body); + + if (obj.hasField("error")) { + AzureKMSError azureResponse; + try { + azureResponse = + AzureKMSError::parse(IDLParserErrorContext("azureError"), obj["error"].Obj()); + } catch (DBException& dbe) { + uasserted(5265102, + "Azure KMS failed to parse error message: {}, Response : {}"_format( + dbe.toString(), obj.toString())); + } + + uasserted(5265103, + "Azure KMS failed, response: {} : {}"_format(azureResponse.getCode(), + azureResponse.getMessage())); + } + + auto azureResponse = AzureResponseT::parse(IDLParserErrorContext("azureResponse"), obj); + + auto b64Url = azureResponse.getValue().toString(); + std::unique_ptr<uint8_t, decltype(std::free)*> raw_str( + kms_message_b64url_to_raw(b64Url.c_str(), raw_len), std::free); + uassert(5265104, "Azure KMS failed to convert key blob from base64 URL.", raw_str != nullptr); + + return raw_str; +} + + +SecureVector<uint8_t> AzureKMSService::decrypt(ConstDataRange cdr, BSONObj masterKey) { + auto azureMasterKey = AzureMasterKey::parse(IDLParserErrorContext("azureMasterKey"), masterKey); + StringData bearerToken = _oauthService->getBearerToken(); + + HostAndPort keyVaultEndpoint = parseEndpoint(azureMasterKey.getKeyVaultEndpoint()); + + auto request = UniqueKmsRequest(kms_azure_request_unwrapkey_new( + keyVaultEndpoint.host().c_str(), + bearerToken.toString().c_str(), + azureMasterKey.getKeyName().toString().c_str(), + azureMasterKey.getKeyVersion().value_or(""_sd).toString().c_str(), + reinterpret_cast<const uint8_t*>(cdr.data()), + cdr.length(), + _config.opts.get())); + + size_t raw_len; + auto raw_str = makeRequest<AzureDecryptResponse>(request.get(), keyVaultEndpoint, &raw_len); + + return kmsResponseToSecureVector( + StringData(reinterpret_cast<const char*>(raw_str.get()), raw_len)); +} + +BSONObj AzureKMSService::encryptDataKeyByBSONObj(ConstDataRange cdr, BSONObj keyId) { + StringData bearerToken = _oauthService->getBearerToken(); + AzureMasterKey masterKey = + AzureMasterKey::parse(IDLParserErrorContext("azureMasterKey"), keyId); + + HostAndPort keyVaultEndpoint = parseEndpoint(masterKey.getKeyVaultEndpoint()); + + auto request = UniqueKmsRequest( + kms_azure_request_wrapkey_new(keyVaultEndpoint.host().c_str(), + bearerToken.toString().c_str(), + masterKey.getKeyName().toString().c_str(), + masterKey.getKeyVersion().value_or(""_sd).toString().c_str(), + reinterpret_cast<const uint8_t*>(cdr.data()), + cdr.length(), + _config.opts.get())); + + size_t raw_len; + auto raw_str = makeRequest<AzureDecryptResponse>(request.get(), keyVaultEndpoint, &raw_len); + + auto dataKey = + kmsResponseToVector(StringData(reinterpret_cast<const char*>(raw_str.get()), raw_len)); + + AzureMasterKeyAndMaterial keyAndMaterial; + keyAndMaterial.setKeyMaterial(std::move(dataKey)); + keyAndMaterial.setMasterKey(std::move(masterKey)); + + return keyAndMaterial.toBSON(); +} + +/** + * Factory for AzureKMSService if user specifies azure config to mongo() JS constructor. + */ +class AzureKMSServiceFactory final : public KMSServiceFactory { +public: + AzureKMSServiceFactory() = default; + ~AzureKMSServiceFactory() = default; + + std::unique_ptr<KMSService> create(const BSONObj& config) final { + auto field = config[KmsProviders::kAzureFieldName]; + if (field.eoo()) { + return nullptr; + } + + uassert(5265106, + "Misconfigured Azure KMS Config: {}"_format(field.toString()), + field.type() == BSONType::Object); + + auto obj = field.Obj(); + return AzureKMSService::create(AzureKMS::parse(IDLParserErrorContext("root"), obj)); + } +}; + +} // namespace + +MONGO_INITIALIZER(KMSRegisterAzure)(::mongo::InitializerContext*) { + kms_message_init(); + KMSServiceController::registerFactory(KMSProviderEnum::azure, + std::make_unique<AzureKMSServiceFactory>()); + return Status::OK(); +} + +} // namespace mongo diff --git a/src/mongo/shell/kms_network.cpp b/src/mongo/shell/kms_network.cpp index 7f0b05d814c..9efa4cb2929 100644 --- a/src/mongo/shell/kms_network.cpp +++ b/src/mongo/shell/kms_network.cpp @@ -117,7 +117,7 @@ void getSSLParamsForNetworkKMS(SSLParams* params) { std::vector({SSLParams::Protocols::TLS1_0, SSLParams::Protocols::TLS1_1}); } -std::vector<uint8_t> kmsResponseToVector(const std::string& str) { +std::vector<uint8_t> kmsResponseToVector(StringData str) { std::vector<uint8_t> blob; std::transform(std::begin(str), std::end(str), std::back_inserter(blob), [](auto c) { @@ -127,8 +127,8 @@ std::vector<uint8_t> kmsResponseToVector(const std::string& str) { return blob; } -SecureVector<uint8_t> kmsResponseToSecureVector(const std::string& str) { - SecureVector<uint8_t> blob(str.length()); +SecureVector<uint8_t> kmsResponseToSecureVector(StringData str) { + SecureVector<uint8_t> blob(str.size()); std::transform(std::begin(str), std::end(str), blob->data(), [](auto c) { return static_cast<uint8_t>(c); diff --git a/src/mongo/shell/kms_network.h b/src/mongo/shell/kms_network.h index dc512e51e65..724412f869b 100644 --- a/src/mongo/shell/kms_network.h +++ b/src/mongo/shell/kms_network.h @@ -73,12 +73,12 @@ void getSSLParamsForNetworkKMS(SSLParams*); /** * Converts a base64 encoded KMS response to a vector of bytes. */ -std::vector<uint8_t> kmsResponseToVector(const std::string& str); +std::vector<uint8_t> kmsResponseToVector(StringData str); /** * Converts a base64 encoded KMS response to a securely allocated vector of bytes. */ -SecureVector<uint8_t> kmsResponseToSecureVector(const std::string& str); +SecureVector<uint8_t> kmsResponseToSecureVector(StringData str); /** * Base class for KMS services that use OAuth for authorization. |