diff options
Diffstat (limited to 'src/mongo')
-rw-r--r-- | src/mongo/shell/kms_aws.cpp | 13 |
1 files changed, 8 insertions, 5 deletions
diff --git a/src/mongo/shell/kms_aws.cpp b/src/mongo/shell/kms_aws.cpp index b923a59355c..1417b847781 100644 --- a/src/mongo/shell/kms_aws.cpp +++ b/src/mongo/shell/kms_aws.cpp @@ -159,7 +159,7 @@ public: BSONObj encryptDataKey(ConstDataRange cdr, StringData keyId) final; private: - void initRequest(kms_request_t* request, StringData region); + void initRequest(kms_request_t* request, StringData host, StringData region); private: // SSL Manager @@ -181,7 +181,7 @@ void uassertKmsRequestInternal(kms_request_t* request, bool ok) { #define uassertKmsRequest(X) uassertKmsRequestInternal(request, (X)); -void AWSKMSService::initRequest(kms_request_t* request, StringData region) { +void AWSKMSService::initRequest(kms_request_t* request, StringData host, StringData region) { // use current time uassertKmsRequest(kms_request_set_date(request, nullptr)); @@ -194,6 +194,9 @@ void AWSKMSService::initRequest(kms_request_t* request, StringData region) { uassertKmsRequest(kms_request_set_access_key_id(request, _config.accessKeyId.c_str())); uassertKmsRequest(kms_request_set_secret_key(request, _config.secretAccessKey->c_str())); + // Set host to be the host we are targeting instead of defaulting to kms.<region>.amazonaws.com + uassertKmsRequest(kms_request_add_header_field(request, "Host", host.toString().c_str())); + if (!_config.sessionToken.value_or("").empty()) { // TODO: move this into kms-message uassertKmsRequest(kms_request_add_header_field( @@ -250,7 +253,7 @@ std::vector<uint8_t> AWSKMSService::encrypt(ConstDataRange cdr, StringData kmsKe _server = getDefaultHost(region); } - initRequest(request.get(), region); + initRequest(request.get(), _server.host(), region); auto buffer = UniqueKmsCharBuffer(kms_request_get_signed(request.get())); auto buffer_len = strlen(buffer.get()); @@ -300,12 +303,12 @@ SecureVector<uint8_t> AWSKMSService::decrypt(ConstDataRange cdr, BSONObj masterK auto request = UniqueKmsRequest(kms_decrypt_request_new( reinterpret_cast<const uint8_t*>(cdr.data()), cdr.length(), nullptr)); - initRequest(request.get(), awsMasterKey.getRegion()); - if (_server.empty()) { _server = getDefaultHost(awsMasterKey.getRegion()); } + initRequest(request.get(), _server.host(), awsMasterKey.getRegion()); + auto buffer = UniqueKmsCharBuffer(kms_request_get_signed(request.get())); auto buffer_len = strlen(buffer.get()); AWSConnection connection(_sslManager.get()); |