summaryrefslogtreecommitdiff
path: root/src/mongo/shell/kms_aws.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/mongo/shell/kms_aws.cpp')
-rw-r--r--src/mongo/shell/kms_aws.cpp13
1 files changed, 8 insertions, 5 deletions
diff --git a/src/mongo/shell/kms_aws.cpp b/src/mongo/shell/kms_aws.cpp
index b923a59355c..1417b847781 100644
--- a/src/mongo/shell/kms_aws.cpp
+++ b/src/mongo/shell/kms_aws.cpp
@@ -159,7 +159,7 @@ public:
BSONObj encryptDataKey(ConstDataRange cdr, StringData keyId) final;
private:
- void initRequest(kms_request_t* request, StringData region);
+ void initRequest(kms_request_t* request, StringData host, StringData region);
private:
// SSL Manager
@@ -181,7 +181,7 @@ void uassertKmsRequestInternal(kms_request_t* request, bool ok) {
#define uassertKmsRequest(X) uassertKmsRequestInternal(request, (X));
-void AWSKMSService::initRequest(kms_request_t* request, StringData region) {
+void AWSKMSService::initRequest(kms_request_t* request, StringData host, StringData region) {
// use current time
uassertKmsRequest(kms_request_set_date(request, nullptr));
@@ -194,6 +194,9 @@ void AWSKMSService::initRequest(kms_request_t* request, StringData region) {
uassertKmsRequest(kms_request_set_access_key_id(request, _config.accessKeyId.c_str()));
uassertKmsRequest(kms_request_set_secret_key(request, _config.secretAccessKey->c_str()));
+ // Set host to be the host we are targeting instead of defaulting to kms.<region>.amazonaws.com
+ uassertKmsRequest(kms_request_add_header_field(request, "Host", host.toString().c_str()));
+
if (!_config.sessionToken.value_or("").empty()) {
// TODO: move this into kms-message
uassertKmsRequest(kms_request_add_header_field(
@@ -250,7 +253,7 @@ std::vector<uint8_t> AWSKMSService::encrypt(ConstDataRange cdr, StringData kmsKe
_server = getDefaultHost(region);
}
- initRequest(request.get(), region);
+ initRequest(request.get(), _server.host(), region);
auto buffer = UniqueKmsCharBuffer(kms_request_get_signed(request.get()));
auto buffer_len = strlen(buffer.get());
@@ -300,12 +303,12 @@ SecureVector<uint8_t> AWSKMSService::decrypt(ConstDataRange cdr, BSONObj masterK
auto request = UniqueKmsRequest(kms_decrypt_request_new(
reinterpret_cast<const uint8_t*>(cdr.data()), cdr.length(), nullptr));
- initRequest(request.get(), awsMasterKey.getRegion());
-
if (_server.empty()) {
_server = getDefaultHost(awsMasterKey.getRegion());
}
+ initRequest(request.get(), _server.host(), awsMasterKey.getRegion());
+
auto buffer = UniqueKmsCharBuffer(kms_request_get_signed(request.get()));
auto buffer_len = strlen(buffer.get());
AWSConnection connection(_sslManager.get());