summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rwxr-xr-xjstests/client_encrypt/lib/kms_http_server.py47
-rw-r--r--src/mongo/shell/kms_aws.cpp13
2 files changed, 54 insertions, 6 deletions
diff --git a/jstests/client_encrypt/lib/kms_http_server.py b/jstests/client_encrypt/lib/kms_http_server.py
index e2414bd8574..57b855a818a 100755
--- a/jstests/client_encrypt/lib/kms_http_server.py
+++ b/jstests/client_encrypt/lib/kms_http_server.py
@@ -2,8 +2,8 @@
"""Mock AWS KMS Endpoint."""
import argparse
-import collections
import base64
+import collections
import http.server
import json
import logging
@@ -12,6 +12,10 @@ import sys
import urllib.parse
import ssl
+from botocore.auth import SigV4Auth, S3SigV4Auth
+from botocore.awsrequest import AWSRequest
+from botocore.credentials import Credentials
+
import kms_http_common
SECRET_PREFIX = "00SECRET"
@@ -56,6 +60,13 @@ SUPPORTED_FAULT_TYPES = [
FAULT_DECRYPT_WRONG_KEY,
]
+def get_dict_subset(headers, subset):
+ ret = {}
+ for header in headers.keys():
+ if header.lower() in subset.lower():
+ ret[header] = headers[header]
+ return ret
+
class AwsKmsHandler(http.server.BaseHTTPRequestHandler):
"""
Handle requests from AWS KMS Monitoring and test commands
@@ -108,6 +119,15 @@ class AwsKmsHandler(http.server.BaseHTTPRequestHandler):
print("RAW INPUT: " + str(raw_input))
+ if not self.headers["Host"] == "localhost":
+ data = "Unexpected host"
+ self._send_reply(data.encode("utf-8"))
+
+ if not self._validate_signature(self.headers, raw_input):
+ data = "Bad Signature"
+ self._send_reply(data.encode("utf-8"))
+
+
# X-Amz-Target: TrentService.Encrypt
aws_operation = self.headers['X-Amz-Target']
@@ -121,6 +141,31 @@ class AwsKmsHandler(http.server.BaseHTTPRequestHandler):
data = "Unknown AWS Operation"
self._send_reply(data.encode("utf-8"))
+
+ def _validate_signature(self, headers, raw_input):
+ auth_header = headers["Authorization"]
+ signed_headers_start = auth_header.find("SignedHeaders")
+ signed_headers = auth_header[signed_headers_start:auth_header.find(",", signed_headers_start)]
+ signed_headers_dict = get_dict_subset(headers, signed_headers)
+
+ request = AWSRequest(method="POST", url="/", data=raw_input, headers=signed_headers_dict)
+ # SigV4Auth assumes this header exists even though it is not required by the algorithm
+ request.context['timestamp'] = headers['X-Amz-Date']
+
+ region_start = auth_header.find("Credential=access/") + len("Credential=access/YYYYMMDD/")
+ region = auth_header[region_start:auth_header.find("/", region_start)]
+
+ credentials = Credentials("access", "secret")
+ auth = SigV4Auth(credentials, "kms", region)
+ string_to_sign = auth.string_to_sign(request, auth.canonical_request(request))
+ expected_signature = auth.signature(string_to_sign, request)
+
+ signature_headers_start = auth_header.find("Signature=") + len("Signature=")
+ actual_signature = auth_header[signature_headers_start:]
+
+ return expected_signature == actual_signature
+
+
def _do_encrypt(self, raw_input):
request = json.loads(raw_input)
diff --git a/src/mongo/shell/kms_aws.cpp b/src/mongo/shell/kms_aws.cpp
index b923a59355c..1417b847781 100644
--- a/src/mongo/shell/kms_aws.cpp
+++ b/src/mongo/shell/kms_aws.cpp
@@ -159,7 +159,7 @@ public:
BSONObj encryptDataKey(ConstDataRange cdr, StringData keyId) final;
private:
- void initRequest(kms_request_t* request, StringData region);
+ void initRequest(kms_request_t* request, StringData host, StringData region);
private:
// SSL Manager
@@ -181,7 +181,7 @@ void uassertKmsRequestInternal(kms_request_t* request, bool ok) {
#define uassertKmsRequest(X) uassertKmsRequestInternal(request, (X));
-void AWSKMSService::initRequest(kms_request_t* request, StringData region) {
+void AWSKMSService::initRequest(kms_request_t* request, StringData host, StringData region) {
// use current time
uassertKmsRequest(kms_request_set_date(request, nullptr));
@@ -194,6 +194,9 @@ void AWSKMSService::initRequest(kms_request_t* request, StringData region) {
uassertKmsRequest(kms_request_set_access_key_id(request, _config.accessKeyId.c_str()));
uassertKmsRequest(kms_request_set_secret_key(request, _config.secretAccessKey->c_str()));
+ // Set host to be the host we are targeting instead of defaulting to kms.<region>.amazonaws.com
+ uassertKmsRequest(kms_request_add_header_field(request, "Host", host.toString().c_str()));
+
if (!_config.sessionToken.value_or("").empty()) {
// TODO: move this into kms-message
uassertKmsRequest(kms_request_add_header_field(
@@ -250,7 +253,7 @@ std::vector<uint8_t> AWSKMSService::encrypt(ConstDataRange cdr, StringData kmsKe
_server = getDefaultHost(region);
}
- initRequest(request.get(), region);
+ initRequest(request.get(), _server.host(), region);
auto buffer = UniqueKmsCharBuffer(kms_request_get_signed(request.get()));
auto buffer_len = strlen(buffer.get());
@@ -300,12 +303,12 @@ SecureVector<uint8_t> AWSKMSService::decrypt(ConstDataRange cdr, BSONObj masterK
auto request = UniqueKmsRequest(kms_decrypt_request_new(
reinterpret_cast<const uint8_t*>(cdr.data()), cdr.length(), nullptr));
- initRequest(request.get(), awsMasterKey.getRegion());
-
if (_server.empty()) {
_server = getDefaultHost(awsMasterKey.getRegion());
}
+ initRequest(request.get(), _server.host(), awsMasterKey.getRegion());
+
auto buffer = UniqueKmsCharBuffer(kms_request_get_signed(request.get()));
auto buffer_len = strlen(buffer.get());
AWSConnection connection(_sslManager.get());