diff options
-rwxr-xr-x | jstests/client_encrypt/lib/kms_http_server.py | 47 | ||||
-rw-r--r-- | src/mongo/shell/kms_aws.cpp | 13 |
2 files changed, 54 insertions, 6 deletions
diff --git a/jstests/client_encrypt/lib/kms_http_server.py b/jstests/client_encrypt/lib/kms_http_server.py index e2414bd8574..57b855a818a 100755 --- a/jstests/client_encrypt/lib/kms_http_server.py +++ b/jstests/client_encrypt/lib/kms_http_server.py @@ -2,8 +2,8 @@ """Mock AWS KMS Endpoint.""" import argparse -import collections import base64 +import collections import http.server import json import logging @@ -12,6 +12,10 @@ import sys import urllib.parse import ssl +from botocore.auth import SigV4Auth, S3SigV4Auth +from botocore.awsrequest import AWSRequest +from botocore.credentials import Credentials + import kms_http_common SECRET_PREFIX = "00SECRET" @@ -56,6 +60,13 @@ SUPPORTED_FAULT_TYPES = [ FAULT_DECRYPT_WRONG_KEY, ] +def get_dict_subset(headers, subset): + ret = {} + for header in headers.keys(): + if header.lower() in subset.lower(): + ret[header] = headers[header] + return ret + class AwsKmsHandler(http.server.BaseHTTPRequestHandler): """ Handle requests from AWS KMS Monitoring and test commands @@ -108,6 +119,15 @@ class AwsKmsHandler(http.server.BaseHTTPRequestHandler): print("RAW INPUT: " + str(raw_input)) + if not self.headers["Host"] == "localhost": + data = "Unexpected host" + self._send_reply(data.encode("utf-8")) + + if not self._validate_signature(self.headers, raw_input): + data = "Bad Signature" + self._send_reply(data.encode("utf-8")) + + # X-Amz-Target: TrentService.Encrypt aws_operation = self.headers['X-Amz-Target'] @@ -121,6 +141,31 @@ class AwsKmsHandler(http.server.BaseHTTPRequestHandler): data = "Unknown AWS Operation" self._send_reply(data.encode("utf-8")) + + def _validate_signature(self, headers, raw_input): + auth_header = headers["Authorization"] + signed_headers_start = auth_header.find("SignedHeaders") + signed_headers = auth_header[signed_headers_start:auth_header.find(",", signed_headers_start)] + signed_headers_dict = get_dict_subset(headers, signed_headers) + + request = AWSRequest(method="POST", url="/", data=raw_input, headers=signed_headers_dict) + # SigV4Auth assumes this header exists even though it is not required by the algorithm + request.context['timestamp'] = headers['X-Amz-Date'] + + region_start = auth_header.find("Credential=access/") + len("Credential=access/YYYYMMDD/") + region = auth_header[region_start:auth_header.find("/", region_start)] + + credentials = Credentials("access", "secret") + auth = SigV4Auth(credentials, "kms", region) + string_to_sign = auth.string_to_sign(request, auth.canonical_request(request)) + expected_signature = auth.signature(string_to_sign, request) + + signature_headers_start = auth_header.find("Signature=") + len("Signature=") + actual_signature = auth_header[signature_headers_start:] + + return expected_signature == actual_signature + + def _do_encrypt(self, raw_input): request = json.loads(raw_input) diff --git a/src/mongo/shell/kms_aws.cpp b/src/mongo/shell/kms_aws.cpp index b923a59355c..1417b847781 100644 --- a/src/mongo/shell/kms_aws.cpp +++ b/src/mongo/shell/kms_aws.cpp @@ -159,7 +159,7 @@ public: BSONObj encryptDataKey(ConstDataRange cdr, StringData keyId) final; private: - void initRequest(kms_request_t* request, StringData region); + void initRequest(kms_request_t* request, StringData host, StringData region); private: // SSL Manager @@ -181,7 +181,7 @@ void uassertKmsRequestInternal(kms_request_t* request, bool ok) { #define uassertKmsRequest(X) uassertKmsRequestInternal(request, (X)); -void AWSKMSService::initRequest(kms_request_t* request, StringData region) { +void AWSKMSService::initRequest(kms_request_t* request, StringData host, StringData region) { // use current time uassertKmsRequest(kms_request_set_date(request, nullptr)); @@ -194,6 +194,9 @@ void AWSKMSService::initRequest(kms_request_t* request, StringData region) { uassertKmsRequest(kms_request_set_access_key_id(request, _config.accessKeyId.c_str())); uassertKmsRequest(kms_request_set_secret_key(request, _config.secretAccessKey->c_str())); + // Set host to be the host we are targeting instead of defaulting to kms.<region>.amazonaws.com + uassertKmsRequest(kms_request_add_header_field(request, "Host", host.toString().c_str())); + if (!_config.sessionToken.value_or("").empty()) { // TODO: move this into kms-message uassertKmsRequest(kms_request_add_header_field( @@ -250,7 +253,7 @@ std::vector<uint8_t> AWSKMSService::encrypt(ConstDataRange cdr, StringData kmsKe _server = getDefaultHost(region); } - initRequest(request.get(), region); + initRequest(request.get(), _server.host(), region); auto buffer = UniqueKmsCharBuffer(kms_request_get_signed(request.get())); auto buffer_len = strlen(buffer.get()); @@ -300,12 +303,12 @@ SecureVector<uint8_t> AWSKMSService::decrypt(ConstDataRange cdr, BSONObj masterK auto request = UniqueKmsRequest(kms_decrypt_request_new( reinterpret_cast<const uint8_t*>(cdr.data()), cdr.length(), nullptr)); - initRequest(request.get(), awsMasterKey.getRegion()); - if (_server.empty()) { _server = getDefaultHost(awsMasterKey.getRegion()); } + initRequest(request.get(), _server.host(), awsMasterKey.getRegion()); + auto buffer = UniqueKmsCharBuffer(kms_request_get_signed(request.get())); auto buffer_len = strlen(buffer.get()); AWSConnection connection(_sslManager.get()); |