diff options
31 files changed, 3386 insertions, 41 deletions
diff --git a/etc/evergreen_yml_components/definitions.yml b/etc/evergreen_yml_components/definitions.yml index 936bb57c918..6d3c4af378c 100644 --- a/etc/evergreen_yml_components/definitions.yml +++ b/etc/evergreen_yml_components/definitions.yml @@ -1394,6 +1394,15 @@ functions: args: - "./src/evergreen/jepsen_test_fail.sh" + "load aws test credentials": + - *f_expansions_write + - command: subprocess.exec + params: + binary: bash + silent: true + args: + - "./src/evergreen/functions/aws_test_credentials_load.sh" + "setup jstestfuzz": - *f_expansions_write - command: subprocess.exec @@ -6458,6 +6467,7 @@ tasks: tags: ["encrypt", "patch_build"] commands: - func: "do setup" + - func: "load aws test credentials" - func: "run tests" vars: resmoke_jobs_max: 1 @@ -6467,6 +6477,7 @@ tasks: tags: ["encrypt", "patch_build"] commands: - func: "do setup" + - func: "load aws test credentials" - func: "run tests" - <<: *task_template diff --git a/evergreen/functions/aws_test_credentials_load.sh b/evergreen/functions/aws_test_credentials_load.sh new file mode 100644 index 00000000000..289c125f9b5 --- /dev/null +++ b/evergreen/functions/aws_test_credentials_load.sh @@ -0,0 +1,15 @@ +DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" > /dev/null 2>&1 && pwd)" +. "$DIR/../prelude.sh" + +cd src + +set -o errexit +echo "const AWS_KMS_SECRET_ID = '${aws_kms_access_key_id}';" >> src/mongo/db/modules/enterprise/jstests/fle/lib/aws_secrets.js +echo "const AWS_KMS_SECRET_KEY = '${aws_kms_secret_access_key}';" >> src/mongo/db/modules/enterprise/jstests/fle/lib/aws_secrets.js + +echo "const KMS_GCP_EMAIL = '${kms_gcp_email}'; " >> src/mongo/db/modules/enterprise/jstests/fle/lib/aws_secrets.js +echo "const KMS_GCP_PRIVATEKEY = '${kms_gcp_privatekey}'; " >> src/mongo/db/modules/enterprise/jstests/fle/lib/aws_secrets.js + +echo "const KMS_AZURE_TENANT_ID = '${kms_azure_tenant_id}';" >> src/mongo/db/modules/enterprise/jstests/fle/lib/aws_secrets.js +echo "const KMS_AZURE_CLIENT_ID = '${kms_azure_client_id}';" >> src/mongo/db/modules/enterprise/jstests/fle/lib/aws_secrets.js +echo "const KMS_AZURE_CLIENT_SECRET = '${kms_azure_client_secret}';" >> src/mongo/db/modules/enterprise/jstests/fle/lib/aws_secrets.js diff --git a/jstests/client_encrypt/fle_auto_decrypt.js b/jstests/client_encrypt/fle_auto_decrypt.js index 650125e0fb7..b74a2ad421a 100644 --- a/jstests/client_encrypt/fle_auto_decrypt.js +++ b/jstests/client_encrypt/fle_auto_decrypt.js @@ -1,6 +1,7 @@ // Test to ensure that the client community shell auto decrypts an encrypted field // stored in the database if it has the correct credentials. +load("jstests/client_encrypt/lib/mock_kms.js"); load('jstests/ssl/libs/ssl_helpers.js'); (function() { diff --git a/jstests/client_encrypt/fle_aws_faults.js b/jstests/client_encrypt/fle_aws_faults.js new file mode 100644 index 00000000000..28b06a55bee --- /dev/null +++ b/jstests/client_encrypt/fle_aws_faults.js @@ -0,0 +1,164 @@ +/** + * Verify the AWS KMS implementation can handle a buggy KMS. + */ + +load("jstests/client_encrypt/lib/mock_kms.js"); +load('jstests/ssl/libs/ssl_helpers.js'); + +(function() { +"use strict"; + +const x509_options = { + sslMode: "requireSSL", + sslPEMKeyFile: SERVER_CERT, + sslCAFile: CA_CERT +}; + +const randomAlgorithm = "AEAD_AES_256_CBC_HMAC_SHA_512-Random"; + +const conn = MongoRunner.runMongod(x509_options); +const test = conn.getDB("test"); +const collection = test.coll; + +function runKMS(mock_kms, func) { + mock_kms.start(); + + const awsKMS = { + accessKeyId: "access", + secretAccessKey: "secret", + url: mock_kms.getURL(), + }; + + const clientSideFLEOptions = { + kmsProviders: { + aws: awsKMS, + }, + keyVaultNamespace: "test.coll", + schemaMap: {} + }; + + const shell = Mongo(conn.host, clientSideFLEOptions); + const cleanCacheShell = Mongo(conn.host, clientSideFLEOptions); + + collection.drop(); + + func(shell, cleanCacheShell); + + mock_kms.stop(); +} + +function testWrongKeyType() { + const awsKMS = {accessKeyId: "access", secretAccessKey: "secret", url: "localhost:8000"}; + + const clientSideFLEOptions = { + kmsProviders: { + aws: awsKMS, + }, + keyVaultNamespace: "test.coll", + schemaMap: {} + }; + + const shell = Mongo(conn.host, clientSideFLEOptions); + + collection.drop(); + + const keyVault = shell.getKeyVault(); + + assert.throws(() => keyVault.createKey( + "aws", + {"region": "us-east-1", "key": "arn:aws:kms:us-east-1:fake:fake:fake"}, + ["mongoKey"])); +} + +testWrongKeyType(); + +function testBadEncryptResult(fault) { + const mock_kms = new MockKMSServerAWS(fault, false); + + runKMS(mock_kms, (shell) => { + const keyVault = shell.getKeyVault(); + + assert.throws( + () => keyVault.createKey("aws", "arn:aws:kms:us-east-1:fake:fake:fake", ["mongoKey"])); + assert.eq(keyVault.getKeys("mongoKey").toArray().length, 0); + }); +} + +testBadEncryptResult(FAULT_ENCRYPT); +testBadEncryptResult(FAULT_ENCRYPT_WRONG_FIELDS); +testBadEncryptResult(FAULT_ENCRYPT_BAD_BASE64); + +function testBadEncryptError() { + const mock_kms = new MockKMSServerAWS(FAULT_ENCRYPT_CORRECT_FORMAT, false); + + runKMS(mock_kms, (shell) => { + const keyVault = shell.getKeyVault(); + + let error = assert.throws( + () => keyVault.createKey("aws", "arn:aws:kms:us-east-1:fake:fake:fake", ["mongoKey"])); + assert.commandFailedWithCode(error, [51224]); + assert.eq(error, + "Error: AWS KMS failed to encrypt: NotFoundException : Error encrypting message"); + }); +} + +testBadEncryptError(); + +function testBadDecryptResult(fault) { + const mock_kms = new MockKMSServerAWS(fault, false); + + runKMS(mock_kms, (shell) => { + const keyVault = shell.getKeyVault(); + const keyId = + keyVault.createKey("aws", "arn:aws:kms:us-east-1:fake:fake:fake", ["mongoKey"]); + const str = "mongo"; + assert.throws(() => { + const encStr = shell.getClientEncryption().encrypt(keyId, str, randomAlgorithm); + }); + }); +} + +testBadDecryptResult(FAULT_DECRYPT); + +function testBadDecryptKeyResult(fault) { + const mock_kms = new MockKMSServerAWS(fault, true); + + runKMS(mock_kms, (shell, cleanCacheShell) => { + const keyVault = shell.getKeyVault(); + + keyVault.createKey("aws", "arn:aws:kms:us-east-1:fake:fake:fake", ["mongoKey"]); + const keyId = keyVault.getKeys("mongoKey").toArray()[0]._id; + const str = "mongo"; + const encStr = shell.getClientEncryption().encrypt(keyId, str, randomAlgorithm); + + mock_kms.enableFaults(); + + assert.throws(() => { + var str = cleanCacheShell.decrypt(encStr); + }); + }); +} + +testBadDecryptKeyResult(FAULT_DECRYPT_WRONG_KEY); + +function testBadDecryptError() { + const mock_kms = new MockKMSServerAWS(FAULT_DECRYPT_CORRECT_FORMAT, false); + + runKMS(mock_kms, (shell) => { + const keyVault = shell.getKeyVault(); + keyVault.createKey("aws", "arn:aws:kms:us-east-1:fake:fake:fake", ["mongoKey"]); + const keyId = keyVault.getKeys("mongoKey").toArray()[0]._id; + const str = "mongo"; + let error = assert.throws(() => { + const encStr = shell.getClientEncryption().encrypt(keyId, str, randomAlgorithm); + }); + assert.commandFailedWithCode(error, [51225]); + assert.eq(error, + "Error: AWS KMS failed to decrypt: NotFoundException : Error decrypting message"); + }); +} + +testBadDecryptError(); + +MongoRunner.stopMongod(conn); +}());
\ No newline at end of file diff --git a/jstests/client_encrypt/fle_azure_faults.js b/jstests/client_encrypt/fle_azure_faults.js new file mode 100644 index 00000000000..dfcf121844a --- /dev/null +++ b/jstests/client_encrypt/fle_azure_faults.js @@ -0,0 +1,175 @@ +/** + * Verify the Azure KMS implementation can handle a buggy KMS. + */ + +load("jstests/client_encrypt/lib/mock_kms.js"); +load('jstests/ssl/libs/ssl_helpers.js'); + +(function() { +"use strict"; + +const x509_options = { + sslMode: "requireSSL", + sslPEMKeyFile: SERVER_CERT, + sslCAFile: CA_CERT +}; + +const mockKey = { + keyName: "my_key", + keyVaultEndpoint: "https://localhost:80", +}; + +const randomAlgorithm = "AEAD_AES_256_CBC_HMAC_SHA_512-Random"; + +const conn = MongoRunner.runMongod(x509_options); +const test = conn.getDB("test"); +const collection = test.coll; + +function runKMS(mock_kms, func) { + mock_kms.start(); + + const azureKMS = { + tenantId: "my_tentant", + clientId: "access@mongodb.com", + clientSecret: "secret", + identityPlatformEndpoint: mock_kms.getURL(), + }; + + const clientSideFLEOptions = { + kmsProviders: { + azure: azureKMS, + }, + keyVaultNamespace: "test.coll", + schemaMap: {}, + }; + + const shell = Mongo(conn.host, clientSideFLEOptions); + const cleanCacheShell = Mongo(conn.host, clientSideFLEOptions); + + collection.drop(); + + func(shell, cleanCacheShell); + + mock_kms.stop(); +} + +// OAuth faults must be tested first so a cached token cannot be used +function testBadOAuthRequestResult() { + const mock_kms = new MockKMSServerAzure(FAULT_OAUTH, false); + + runKMS(mock_kms, (shell) => { + const keyVault = shell.getKeyVault(); + + const error = assert.throws(() => keyVault.createKey("azure", mockKey, ["mongoKey"])); + assert.eq( + error, + "Error: code 9: FailedToParse: Expecting '{': offset:0 of:Internal Error of some sort."); + }); +} + +testBadOAuthRequestResult(); + +function testBadOAuthRequestError() { + const mock_kms = new MockKMSServerAzure(FAULT_OAUTH_CORRECT_FORMAT, false); + + runKMS(mock_kms, (shell) => { + const keyVault = shell.getKeyVault(); + + const error = assert.throws(() => keyVault.createKey("azure", mockKey, ["mongoKey"])); + assert.commandFailedWithCode(error, [ErrorCodes.OperationFailed]); + assert.eq( + error, + "Error: Failed to make oauth request: Azure OAuth Error : FAULT_OAUTH_CORRECT_FORMAT"); + }); +} + +testBadOAuthRequestError(); + +function testBadEncryptResult() { + const mock_kms = new MockKMSServerAzure(FAULT_ENCRYPT, false); + + runKMS(mock_kms, (shell) => { + const keyVault = shell.getKeyVault(); + mockKey.keyVaultEndpoint = mock_kms.getEndpoint(); + + assert.throws(() => keyVault.createKey("azure", mockKey, ["mongoKey"])); + assert.eq(keyVault.getKeys("mongoKey").toArray().length, 0); + }); +} + +testBadEncryptResult(); + +function testBadEncryptError() { + const mock_kms = new MockKMSServerAzure(FAULT_ENCRYPT_CORRECT_FORMAT, false); + + runKMS(mock_kms, (shell) => { + const keyVault = shell.getKeyVault(); + mockKey.keyVaultEndpoint = mock_kms.getEndpoint(); + + let error = assert.throws(() => keyVault.createKey("azure", mockKey, ["mongoKey"])); + assert.commandFailedWithCode(error, [5265103]); + }); +} + +testBadEncryptError(); + +function testBadDecryptResult() { + const mock_kms = new MockKMSServerAzure(FAULT_DECRYPT, false); + + runKMS(mock_kms, (shell) => { + const keyVault = shell.getKeyVault(); + mockKey.keyVaultEndpoint = mock_kms.getEndpoint(); + + const keyId = keyVault.createKey("azure", mockKey, ["mongoKey"]); + const str = "mongo"; + assert.throws(() => { + const encStr = shell.getClientEncryption().encrypt(keyId, str, randomAlgorithm); + }); + }); +} + +testBadDecryptResult(); + +function testBadDecryptKeyResult() { + const mock_kms = new MockKMSServerAzure(FAULT_DECRYPT_WRONG_KEY, true); + + runKMS(mock_kms, (shell, cleanCacheShell) => { + const keyVault = shell.getKeyVault(); + mockKey.keyVaultEndpoint = mock_kms.getEndpoint(); + + keyVault.createKey("azure", mockKey, ["mongoKey"]); + const keyId = keyVault.getKeys("mongoKey").toArray()[0]._id; + const str = "mongo"; + const encStr = shell.getClientEncryption().encrypt(keyId, str, randomAlgorithm); + + mock_kms.enableFaults(); + + assert.throws(() => { + let str = cleanCacheShell.decrypt(encStr); + }); + }); +} + +testBadDecryptKeyResult(); + +function testBadDecryptError() { + const mock_kms = new MockKMSServerAzure(FAULT_DECRYPT_CORRECT_FORMAT, false); + + runKMS(mock_kms, (shell) => { + const keyVault = shell.getKeyVault(); + mockKey.keyVaultEndpoint = mock_kms.getEndpoint(); + + keyVault.createKey("azure", mockKey, ["mongoKey"]); + const keyId = keyVault.getKeys("mongoKey").toArray()[0]._id; + const str = "mongo"; + let error = assert.throws(() => { + const encStr = shell.getClientEncryption().encrypt(keyId, str, randomAlgorithm); + }); + assert.commandFailedWithCode(error, [5265103]); + }); +} + +testBadDecryptError(); + +MongoRunner.stopMongod(conn); +})(); diff --git a/jstests/client_encrypt/fle_encrypt_decrypt_shell.js b/jstests/client_encrypt/fle_encrypt_decrypt_shell.js index 27b6e3e5a51..e586a25320b 100644 --- a/jstests/client_encrypt/fle_encrypt_decrypt_shell.js +++ b/jstests/client_encrypt/fle_encrypt_decrypt_shell.js @@ -1,11 +1,15 @@ /** * Check the functionality of encrypt and decrypt functions in KeyStore.js */ +load("jstests/client_encrypt/lib/mock_kms.js"); load('jstests/ssl/libs/ssl_helpers.js'); (function() { "use strict"; +const mock_kms = new MockKMSServerAWS(); +mock_kms.start(); + const x509_options = { sslMode: "requireSSL", sslPEMKeyFile: SERVER_CERT, @@ -16,6 +20,12 @@ const conn = MongoRunner.runMongod(x509_options); const test = conn.getDB("test"); const collection = test.coll; +const awsKMS = { + accessKeyId: "access", + secretAccessKey: "secret", + url: mock_kms.getURL(), +}; + let localKMS = { key: BinData( 0, @@ -24,13 +34,14 @@ let localKMS = { const clientSideFLEOptions = { kmsProviders: { + aws: awsKMS, local: localKMS, }, keyVaultNamespace: "test.coll", schemaMap: {} }; -const kmsTypes = ["local"]; +const kmsTypes = ["aws", "local"]; const randomAlgorithm = "AEAD_AES_256_CBC_HMAC_SHA_512-Random"; const deterministicAlgorithm = "AEAD_AES_256_CBC_HMAC_SHA_512-Deterministic"; @@ -113,4 +124,5 @@ for (const kmsType of kmsTypes) { } MongoRunner.stopMongod(conn); +mock_kms.stop(); }()); diff --git a/jstests/client_encrypt/fle_gcp_faults.js b/jstests/client_encrypt/fle_gcp_faults.js new file mode 100644 index 00000000000..9bff963dfeb --- /dev/null +++ b/jstests/client_encrypt/fle_gcp_faults.js @@ -0,0 +1,190 @@ +/** + * Verify the GCP KMS implementation can handle a buggy KMS. + */ + +load("jstests/client_encrypt/lib/mock_kms.js"); +load('jstests/ssl/libs/ssl_helpers.js'); + +(function() { +"use strict"; + +const x509_options = { + sslMode: "requireSSL", + sslPEMKeyFile: SERVER_CERT, + sslCAFile: CA_CERT +}; + +const mockKey = { + projectId: "mock", + location: "global", + keyRing: "mock-key-ring", + keyName: "mock-key" +}; + +const randomAlgorithm = "AEAD_AES_256_CBC_HMAC_SHA_512-Random"; + +const conn = MongoRunner.runMongod(x509_options); +const test = conn.getDB("test"); +const collection = test.coll; + +function runKMS(mock_kms, func) { + mock_kms.start(); + + const gcpKMS = { + email: "access@mongodb.com", + endpoint: mock_kms.getURL(), + // Arbitrary private key generated by openssl encoded in base64. A correctly formatted key + // is required so the shell doesn't reject it, but the mock server doesn't actually use the + // key to encrypt/decrypt. + // The key is generated with `openssl genpkey -algorithm RSA -pkeyopt rsa_keygen_bits:2048` + // which fulfills Windows's stricter requirements for RSA private keys + privateKey: + "MIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQDDF3EoKeYp4GKgVckkSxe9Hr81TX/GHW+vSkwxZpI0LOia" + + "rffIbv1IZ8xBD+HEL16E60pPZDYyNsmvNe6HRacSukTBM1peO5DrbM0VglEudoKc7TIT6trKFb1srd6WDUoGLZ6Xgm2KCmqZesV9VnWe" + + "aLmbTa3TAZVn8nfDoRdLVl2ecE8BxrH63niMLQZl8qhcmeu61RQkOHbupTTovgH3PkufCOqgB01QHEcKgvPQBWqQvjms66U6Nfko2GMi" + + "V5lPWWsJVBfNFU96M63QRxmjPWRoNNhui4xGPcm8u9GXo0p+83Ct2NZ7o3SxpaVUQaC2a2vbjADSI9G95fxM9z/zAgMBAAECggEBALOu" + + "/pi4ZnXBZfU4rcaQpy+XhxKH65xD9l6jdqO1TgliJ2Z3vpTLrNqoR1bRUuYHnu2bbFjM+qGrFn0aljPe8i9sgfDT5HKQODytfAJIgY7i" + + "tg/k40+26oZgGZRkW3MmkDw6fiwbg9o1F9N+YTC8lh4tZG3m0KdceQhBKQ90amkGaWsunC5/ZyI7Ip8JFFmkmoZHpiKJPYHI3dMNQyJK" + + "cExSkw050UILIblPeA0AeDrtSBrz/5KfPKy0Wsh3cpuM9lE4dqupCnvp/GUZcSoG4NLg+I1FHQ/FkvY8AXtadES0JAty3E9GC60IWwvg" + + "tdajFo4e1wlfyLgCFhAdsxslBokCgYEA7Emng/zBBF6IecuZWlekM286sUhLE2fSFRg+i5CjqsIJ0yAAIKAy5Qs8PV0TXtgLEg7M0nxG" + + "9mnstkN5M2FSD4qEPe/teOJSzxgRavvqQCMtB9ABtz3nsXDQZbr5+/0bUxOsDbYW22t8dCavZ1u1SG+eC5PH/AGXLATBqR8Lfy8CgYEA" + + "0134dkbZzy07z5cHDzTmsBZnhJt8W+Hix+5IViSnw3W+OdJFjMmUW8OmmBoMqzl69EaIU44bNpcLE1UOqNa74OGdNMid3p6JLUPSsiuh" + + "9J1tH03+ZvubJQZiDiXDyXO7OeAX3ZQv0GAnK+LpYClcI50DwE655zzdbIn39iiuun0CgYEAn9XDBzl2p6n6z8i117LpVBGttjac8meM" + + "aNCZnncc/2l6k+JVs7wqMV3ERg4sCEBEXNa+HrQKnK1Sfoht+B+hDvo4Ml2WWete8M/rGF+IOhKRZ3OBdZ7el90kW2x7pcW1MiFghXXj" + + "SFIRQdDZXiVfH7zBQDubUBETXadqCSkC8ekCgYA6F22eNELQqgHyP/P0vflZFA9HZuR67E5D3L2Mz248TjQF+ECdPRnFTrSOwToSJS4h" + + "zPDS5g+cpU6p9Yqd5MamO9vVEf4xnSjeg/F4fn14mXvQSsNM0oIFXwe8E60HxQMEGQ72GzA4+PRLH4Y8o6FrOFA7nmeBojzJA/JeeTfs" + + "kQKBgQCh2Y1oIk7VcaP8F5vg1tw8VUuPadU3WqtDscenAli5Syp3ngDoIjtVII6mM8DxIE9tOrI+F653T/xcuVXjnPIymZh5LQttNABW" + + "ZE1RXp16C1uFsg3F9U1wIljep+F6D/mjySkoaM8PwE1miLwThPGvgt2YSCyuig1OMOzOXlIUnQ==", + }; + + const clientSideFLEOptions = { + kmsProviders: { + gcp: gcpKMS, + }, + keyVaultNamespace: "test.coll", + schemaMap: {}, + }; + + const shell = Mongo(conn.host, clientSideFLEOptions); + const cleanCacheShell = Mongo(conn.host, clientSideFLEOptions); + + collection.drop(); + + func(shell, cleanCacheShell); + + mock_kms.stop(); +} + +// OAuth faults must be tested first so a cached token can't be used +function testBadOAuthRequestResult() { + const mock_kms = new MockKMSServerGCP(FAULT_OAUTH, false); + + runKMS(mock_kms, (shell) => { + const keyVault = shell.getKeyVault(); + + const error = assert.throws(() => keyVault.createKey("gcp", mockKey, ["mongoKey"])); + assert.eq( + error, + "Error: code 9: FailedToParse: Expecting '{': offset:0 of:Internal Error of some sort."); + }); +} + +testBadOAuthRequestResult(); + +function testBadOAuthRequestError() { + const mock_kms = new MockKMSServerGCP(FAULT_OAUTH_CORRECT_FORMAT, false); + + runKMS(mock_kms, (shell) => { + const keyVault = shell.getKeyVault(); + + const error = assert.throws(() => keyVault.createKey("gcp", mockKey, ["mongoKey"])); + assert.commandFailedWithCode(error, [ErrorCodes.OperationFailed]); + assert.eq( + error, + "Error: Failed to make oauth request: GCP OAuth Error : FAULT_OAUTH_CORRECT_FORMAT"); + }); +} + +testBadOAuthRequestError(); + +function testBadEncryptResult() { + const mock_kms = new MockKMSServerGCP(FAULT_ENCRYPT, false); + + runKMS(mock_kms, (shell) => { + const keyVault = shell.getKeyVault(); + + assert.throws(() => keyVault.createKey("gcp", mockKey, ["mongoKey"])); + assert.eq(keyVault.getKeys("mongoKey").toArray().length, 0); + }); +} + +testBadEncryptResult(); + +function testBadEncryptError() { + const mock_kms = new MockKMSServerGCP(FAULT_ENCRYPT_CORRECT_FORMAT, false); + + runKMS(mock_kms, (shell) => { + const keyVault = shell.getKeyVault(); + let error = assert.throws(() => keyVault.createKey("gcp", mockKey, ["mongoKey"])); + assert.commandFailedWithCode(error, [5256006]); + }); +} + +testBadEncryptError(); + +function testBadDecryptResult() { + const mock_kms = new MockKMSServerGCP(FAULT_DECRYPT, false); + + runKMS(mock_kms, (shell) => { + const keyVault = shell.getKeyVault(); + const keyId = keyVault.createKey("gcp", mockKey, ["mongoKey"]); + const str = "mongo"; + assert.throws(() => { + const encStr = shell.getClientEncryption().encrypt(keyId, str, randomAlgorithm); + }); + }); +} + +testBadDecryptResult(); + +function testBadDecryptKeyResult() { + const mock_kms = new MockKMSServerGCP(FAULT_DECRYPT_WRONG_KEY, true); + + runKMS(mock_kms, (shell, cleanCacheShell) => { + const keyVault = shell.getKeyVault(); + + keyVault.createKey("gcp", mockKey, ["mongoKey"]); + const keyId = keyVault.getKeys("mongoKey").toArray()[0]._id; + const str = "mongo"; + const encStr = shell.getClientEncryption().encrypt(keyId, str, randomAlgorithm); + + mock_kms.enableFaults(); + + assert.throws(() => { + let str = cleanCacheShell.decrypt(encStr); + }); + }); +} + +testBadDecryptKeyResult(); + +function testBadDecryptError() { + const mock_kms = new MockKMSServerGCP(FAULT_DECRYPT_CORRECT_FORMAT, false); + + runKMS(mock_kms, (shell) => { + const keyVault = shell.getKeyVault(); + + keyVault.createKey("gcp", mockKey, ["mongoKey"]); + const keyId = keyVault.getKeys("mongoKey").toArray()[0]._id; + const str = "mongo"; + let error = assert.throws(() => { + const encStr = shell.getClientEncryption().encrypt(keyId, str, randomAlgorithm); + }); + assert.commandFailedWithCode(error, [5256008]); + }); +} + +testBadDecryptError(); + +MongoRunner.stopMongod(conn); +})(); diff --git a/jstests/client_encrypt/fle_key_faults.js b/jstests/client_encrypt/fle_key_faults.js index 204ee277ec5..d58804460e6 100644 --- a/jstests/client_encrypt/fle_key_faults.js +++ b/jstests/client_encrypt/fle_key_faults.js @@ -2,11 +2,15 @@ * Verify the KMS support handles a buggy Key Store */ +load("jstests/client_encrypt/lib/mock_kms.js"); load('jstests/ssl/libs/ssl_helpers.js'); (function() { "use strict"; +const mock_kms = new MockKMSServerAWS(); +mock_kms.start(); + const x509_options = { sslMode: "requireSSL", sslPEMKeyFile: SERVER_CERT, @@ -17,6 +21,12 @@ const conn = MongoRunner.runMongod(x509_options); const test = conn.getDB("test"); const collection = test.coll; +const awsKMS = { + accessKeyId: "access", + secretAccessKey: "secret", + url: mock_kms.getURL(), +}; + var localKMS = { key: BinData( 0, @@ -25,6 +35,7 @@ var localKMS = { const clientSideFLEOptions = { kmsProviders: { + aws: awsKMS, local: localKMS, }, keyVaultNamespace: "test.coll", @@ -44,7 +55,7 @@ function testFault(kmsType, func) { } function testFaults(func) { - const kmsTypes = ["local"]; + const kmsTypes = ["aws", "local"]; for (const kmsType of kmsTypes) { testFault(kmsType, func); @@ -82,4 +93,5 @@ testFaults((keyId, shell) => { }); MongoRunner.stopMongod(conn); +mock_kms.stop(); }());
\ No newline at end of file diff --git a/jstests/client_encrypt/fle_keys.js b/jstests/client_encrypt/fle_keys.js index 93fb0000478..cbe32b6b586 100644 --- a/jstests/client_encrypt/fle_keys.js +++ b/jstests/client_encrypt/fle_keys.js @@ -2,11 +2,15 @@ * Check functionality of KeyVault.js */ +load("jstests/client_encrypt/lib/mock_kms.js"); load('jstests/ssl/libs/ssl_helpers.js'); (function() { "use strict"; +const mock_kms = new MockKMSServerAWS(); +mock_kms.start(); + const x509_options = { sslMode: "requireSSL", sslPEMKeyFile: SERVER_CERT, @@ -17,15 +21,15 @@ const conn = MongoRunner.runMongod(x509_options); const test = conn.getDB("test"); const collection = test.coll; -const localKMS = { - key: BinData( - 0, - "tu9jUCBqZdwCelwE/EAm/4WqdxrSMi04B8e9uAV+m30rI1J2nhKZZtQjdvsSCwuI4erR6IEcEK+5eGUAODv43NDNIR9QheT2edWFewUfHKsl9cnzTc86meIzOmYl6drp"), +const awsKMS = { + accessKeyId: "access", + secretAccessKey: "secret", + url: mock_kms.getURL(), }; const clientSideFLEOptions = { kmsProviders: { - local: localKMS, + aws: awsKMS, }, keyVaultNamespace: "test.coll", schemaMap: {} @@ -35,13 +39,13 @@ const conn_str = "mongodb://" + conn.host + "/?ssl=true"; const shell = Mongo(conn_str, clientSideFLEOptions); const keyVault = shell.getKeyVault(); -keyVault.createKey("local", ['mongoKey']); +keyVault.createKey("aws", "arn:aws:kms:us-east-1:fake:fake:fake", ['mongoKey']); assert.eq(1, keyVault.getKeys().itcount()); -var result = keyVault.createKey("local", "fake", {}); +var result = keyVault.createKey("aws", "arn:aws:kms:us-east-4:fake:fake:fake", {}); assert.eq("TypeError: key alternate names must be of Array type.", result); -result = keyVault.createKey("local", [1]); +result = keyVault.createKey("aws", "arn:aws:kms:us-east-5:fake:fake:fake", [1]); assert.eq("TypeError: items in key alternate names must be of String type.", result); assert.eq(1, keyVault.getKeyByAltName("mongoKey").itcount()); @@ -68,11 +72,12 @@ result = keyVault.deleteKey(keyId); assert.eq(0, keyVault.getKey(keyId).itcount()); assert.eq(0, keyVault.getKeys().itcount()); -keyVault.createKey("local", ['mongoKey1']); -keyVault.createKey("local", ['mongoKey2']); -keyVault.createKey("local", ['mongoKey3']); +keyVault.createKey("aws", "arn:aws:kms:us-east-1:fake:fake:fake1"); +keyVault.createKey("aws", "arn:aws:kms:us-east-2:fake:fake:fake2"); +keyVault.createKey("aws", "arn:aws:kms:us-east-3:fake:fake:fake3"); assert.eq(3, keyVault.getKeys().itcount()); MongoRunner.stopMongod(conn); +mock_kms.stop(); }());
\ No newline at end of file diff --git a/jstests/client_encrypt/fle_valid_fle_options.js b/jstests/client_encrypt/fle_valid_fle_options.js index 74e77443ec2..e95eb89795f 100644 --- a/jstests/client_encrypt/fle_valid_fle_options.js +++ b/jstests/client_encrypt/fle_valid_fle_options.js @@ -1,9 +1,13 @@ +load("jstests/client_encrypt/lib/mock_kms.js"); load('jstests/ssl/libs/ssl_helpers.js'); (function() { "use strict"; +const mock_kms = new MockKMSServerAWS(); +mock_kms.start(); + const randomAlgorithm = "AEAD_AES_256_CBC_HMAC_SHA_512-Random"; const deterministicAlgorithm = "AEAD_AES_256_CBC_HMAC_SHA_512-Deterministic"; @@ -18,16 +22,16 @@ const conn = MongoRunner.runMongod(x509_options); const unencryptedDatabase = conn.getDB("test"); const collection = unencryptedDatabase.keystore; -const localKMS = { - key: BinData( - 0, - "tu9jUCBqZdwCelwE/EAm/4WqdxrSMi04B8e9uAV+m30rI1J2nhKZZtQjdvsSCwuI4erR6IEcEK+5eGUAODv43NDNIR9QheT2edWFewUfHKsl9cnzTc86meIzOmYl6drp"), +const awsKMS = { + accessKeyId: "access", + secretAccessKey: "secret", + url: mock_kms.getURL(), }; const clientSideFLEOptionsFail = [ { kmsProviders: { - local: localKMS, + aws: awsKMS, }, schemaMap: {}, }, @@ -44,7 +48,7 @@ clientSideFLEOptionsFail.forEach(element => { const clientSideFLEOptionsPass = [ { kmsProviders: { - local: localKMS, + aws: awsKMS, }, keyVaultNamespace: "test.keystore", schemaMap: {}, @@ -58,4 +62,5 @@ clientSideFLEOptionsPass.forEach(element => { }); MongoRunner.stopMongod(conn); +mock_kms.stop(); }()); diff --git a/jstests/client_encrypt/lib/fle_command_line_explicit_encryption.js b/jstests/client_encrypt/lib/fle_command_line_explicit_encryption.js index 3cef4036e5f..1e89f348917 100644 --- a/jstests/client_encrypt/lib/fle_command_line_explicit_encryption.js +++ b/jstests/client_encrypt/lib/fle_command_line_explicit_encryption.js @@ -3,9 +3,14 @@ * jstests/fle/fle_command_line_encryption.js. */ +load("jstests/client_encrypt/lib/mock_kms.js"); + (function() { "use strict"; +const mock_kms = new MockKMSServerAWS(); +mock_kms.start(); + const shell = Mongo(); const keyVault = shell.getKeyVault(); @@ -46,7 +51,7 @@ const failTestCases = [null, undefined, MinKey(), MaxKey(), DBRef("test", "test" for (const encryptionAlgorithm of encryptionAlgorithms) { collection.drop(); - keyVault.createKey("local", ['mongoKey']); + keyVault.createKey("aws", "arn:aws:kms:us-east-1:fake:fake:fake", ['mongoKey']); const keyId = keyVault.getKeyByAltName("mongoKey").toArray()[0]._id; let pass; @@ -73,5 +78,6 @@ for (const encryptionAlgorithm of encryptionAlgorithms) { } } +mock_kms.stop(); print("Test completed with no errors."); }()); diff --git a/jstests/client_encrypt/lib/kms_http_common.py b/jstests/client_encrypt/lib/kms_http_common.py new file mode 100644 index 00000000000..2ee0828d9cf --- /dev/null +++ b/jstests/client_encrypt/lib/kms_http_common.py @@ -0,0 +1,148 @@ +"""Common code for mock kms http endpoint.""" +import http.server +import json +import ssl +import urllib.parse +from abc import abstractmethod + +URL_PATH_STATS = "/stats" +URL_DISABLE_FAULTS = "/disable_faults" +URL_ENABLE_FAULTS = "/enable_faults" + +"""Fault which causes encrypt to return 500.""" +FAULT_ENCRYPT = "fault_encrypt" + +"""Fault which causes encrypt to return an error that contains a type and message""" +FAULT_ENCRYPT_CORRECT_FORMAT = "fault_encrypt_correct_format" + +"""Fault which causes encrypt to return wrong fields in JSON.""" +FAULT_ENCRYPT_WRONG_FIELDS = "fault_encrypt_wrong_fields" + +"""Fault which causes encrypt to return bad BASE64.""" +FAULT_ENCRYPT_BAD_BASE64 = "fault_encrypt_bad_base64" + +"""Fault which causes decrypt to return 500.""" +FAULT_DECRYPT = "fault_decrypt" + +"""Fault which causes decrypt to return an error that contains a type and message""" +FAULT_DECRYPT_CORRECT_FORMAT = "fault_decrypt_correct_format" + +"""Fault which causes decrypt to return wrong key.""" +FAULT_DECRYPT_WRONG_KEY = "fault_decrypt_wrong_key" + +"""Fault which causes an OAuth request to return an 500.""" +FAULT_OAUTH = "fault_oauth" + +"""Fault which causes an OAuth request to return an error response""" +FAULT_OAUTH_CORRECT_FORMAT = "fault_oauth_correct_format" + + +class Stats: + """Stats class shared between client and server.""" + + def __init__(self): + self.encrypt_calls = 0 + self.decrypt_calls = 0 + self.fault_calls = 0 + + def __repr__(self): + return json.dumps({ + 'decrypts': self.decrypt_calls, + 'encrypts': self.encrypt_calls, + 'faults': self.fault_calls, + }) + + +class KmsHandlerBase(http.server.BaseHTTPRequestHandler): + protocol_version = "HTTP/1.1" + + def do_GET(self): + """Serve a Test GET request.""" + print("Received GET: " + self.path) + parts = urllib.parse.urlsplit(self.path) + path = parts[2] + + if path == URL_PATH_STATS: + self._do_stats() + elif path == URL_DISABLE_FAULTS: + self._do_disable_faults() + elif path == URL_ENABLE_FAULTS: + self._do_enable_faults() + else: + self.send_response(http.HTTPStatus.NOT_FOUND) + self.end_headers() + self.wfile.write("Unknown URL".encode()) + + @abstractmethod + def do_POST(self): + """Serve a POST request.""" + pass + + def _send_reply(self, data, status=http.HTTPStatus.OK): + print("Sending Response: " + data.decode()) + + self.send_response(status) + self.send_header("content-type", "application/octet-stream") + self.send_header("Content-Length", str(len(data))) + self.end_headers() + + self.wfile.write(data) + + @abstractmethod + def _do_encrypt(self, raw_input): + pass + + @abstractmethod + def _do_encrypt_faults(self, raw_ciphertext): + pass + + @abstractmethod + def _do_decrypt(self, raw_input): + pass + + @abstractmethod + def _do_decrypt_faults(self, blob): + pass + + def _send_header(self): + self.send_response(http.HTTPStatus.OK) + self.send_header("content-type", "application/octet-stream") + self.end_headers() + + def _do_stats(self): + self._send_header() + + self.wfile.write(str(stats).encode('utf-8')) + + def _do_disable_faults(self): + global disable_faults + disable_faults = True + self._send_header() + + def _do_enable_faults(self): + global disable_faults + disable_faults = False + self._send_header() + + +def run(port, cert_file, ca_file, handler_class, server_class=http.server.HTTPServer): + """Run web server.""" + server_address = ('', port) + + httpd = server_class(server_address, handler_class) + + httpd.socket = ssl.wrap_socket(httpd.socket, + certfile=cert_file, + ca_certs=ca_file, server_side=True) + + print(f"Mock KMS Web Server Listening on {str(server_address)}") + + httpd.serve_forever() + + +# Pass this data out of band instead of storing it in AwsKmsHandler since the +# BaseHTTPRequestHandler does not call the methods as object methods but as class methods. This +# means there is not self. +stats = Stats() +disable_faults = False +fault_type = None diff --git a/jstests/client_encrypt/lib/kms_http_control.py b/jstests/client_encrypt/lib/kms_http_control.py new file mode 100644 index 00000000000..c290ae80a25 --- /dev/null +++ b/jstests/client_encrypt/lib/kms_http_control.py @@ -0,0 +1,52 @@ +#! /usr/bin/env python3 +""" +Python script to interact with mock AWS KMS HTTP server. +""" + +import argparse +import json +import logging +import sys +import urllib.request +import ssl + +import kms_http_common + +def main(): + """Main entry point.""" + parser = argparse.ArgumentParser(description='MongoDB Mock KMS Endpoint.') + + parser.add_argument('-p', '--port', type=int, default=8000, help="Port to listen on") + + parser.add_argument('-v', '--verbose', action='count', help="Enable verbose tracing") + + parser.add_argument('--ca_file', type=str, required=True, help="TLS CA PEM file") + + parser.add_argument('--query', type=str, help="Query endpoint <name>") + + args = parser.parse_args() + if args.verbose: + logging.basicConfig(level=logging.DEBUG) + + url_str = "https://localhost:" + str(args.port) + if args.query == "stats": + url_str += kms_http_common.URL_PATH_STATS + elif args.query == "disable_faults": + url_str += kms_http_common.URL_DISABLE_FAULTS + elif args.query == "enable_faults": + url_str += kms_http_common.URL_ENABLE_FAULTS + else: + print("Unknown query type") + sys.exit(1) + + context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH, cafile=args.ca_file) + + with urllib.request.urlopen(url_str, context=context) as f: + print(f.read().decode('utf-8')) + + sys.exit(0) + + +if __name__ == '__main__': + + main() diff --git a/jstests/client_encrypt/lib/kms_http_server.py b/jstests/client_encrypt/lib/kms_http_server.py new file mode 100755 index 00000000000..489b6230424 --- /dev/null +++ b/jstests/client_encrypt/lib/kms_http_server.py @@ -0,0 +1,244 @@ +#! /usr/bin/env python3 +"""Mock AWS KMS Endpoint.""" + +import argparse +import base64 +import http.server +import json +import logging +import sys +import urllib.parse + +from botocore.auth import SigV4Auth +from botocore.awsrequest import AWSRequest +from botocore.credentials import Credentials + +import kms_http_common + +SECRET_PREFIX = "00SECRET" + +# List of supported fault types +SUPPORTED_FAULT_TYPES = [ + kms_http_common.FAULT_ENCRYPT, + kms_http_common.FAULT_ENCRYPT_CORRECT_FORMAT, + kms_http_common.FAULT_ENCRYPT_WRONG_FIELDS, + kms_http_common.FAULT_ENCRYPT_BAD_BASE64, + kms_http_common.FAULT_DECRYPT, + kms_http_common.FAULT_DECRYPT_CORRECT_FORMAT, + kms_http_common.FAULT_DECRYPT_WRONG_KEY, +] + +def get_dict_subset(headers, subset): + ret = {} + for header in headers.keys(): + if header.lower() in subset.lower(): + ret[header] = headers[header] + return ret + +class AwsKmsHandler(kms_http_common.KmsHandlerBase): + """ + Handle requests from AWS KMS Monitoring and test commands + """ + + def do_POST(self): + print("Received POST: " + self.path) + parts = urllib.parse.urlsplit(self.path) + path = parts[2] + + if path == "/": + self._do_post() + else: + self.send_response(http.HTTPStatus.NOT_FOUND) + self.end_headers() + self.wfile.write("Unknown URL".encode()) + + def _do_post(self): + clen = int(self.headers.get('content-length')) + + raw_input = self.rfile.read(clen) + + print("RAW INPUT: " + str(raw_input)) + + if not self.headers["Host"] == "localhost": + data = "Unexpected host" + self._send_reply(data.encode("utf-8")) + + if not self._validate_signature(self.headers, raw_input): + data = "Bad Signature" + self._send_reply(data.encode("utf-8")) + + # X-Amz-Target: TrentService.Encrypt + aws_operation = self.headers['X-Amz-Target'] + + if aws_operation == "TrentService.Encrypt": + kms_http_common.stats.encrypt_calls += 1 + self._do_encrypt(raw_input) + elif aws_operation == "TrentService.Decrypt": + kms_http_common.stats.decrypt_calls += 1 + self._do_decrypt(raw_input) + else: + data = "Unknown AWS Operation" + self._send_reply(data.encode("utf-8")) + + def _validate_signature(self, headers, raw_input): + auth_header = headers["Authorization"] + signed_headers_start = auth_header.find("SignedHeaders") + signed_headers = auth_header[signed_headers_start:auth_header.find(",", signed_headers_start)] + signed_headers_dict = get_dict_subset(headers, signed_headers) + + request = AWSRequest(method="POST", url="/", data=raw_input, headers=signed_headers_dict) + # SigV4Auth assumes this header exists even though it is not required by the algorithm + request.context['timestamp'] = headers['X-Amz-Date'] + + region_start = auth_header.find("Credential=access/") + len("Credential=access/YYYYMMDD/") + region = auth_header[region_start:auth_header.find("/", region_start)] + + credentials = Credentials("access", "secret") + auth = SigV4Auth(credentials, "kms", region) + string_to_sign = auth.string_to_sign(request, auth.canonical_request(request)) + expected_signature = auth.signature(string_to_sign, request) + + signature_headers_start = auth_header.find("Signature=") + len("Signature=") + actual_signature = auth_header[signature_headers_start:] + + return expected_signature == actual_signature + + def _do_encrypt(self, raw_input): + request = json.loads(raw_input) + + print(request) + + plaintext = request["Plaintext"] + keyid = request["KeyId"] + + ciphertext = SECRET_PREFIX.encode() + plaintext.encode() + ciphertext = base64.b64encode(ciphertext).decode() + + if kms_http_common.fault_type and kms_http_common.fault_type.startswith(kms_http_common.FAULT_ENCRYPT) \ + and not kms_http_common.disable_faults: + return self._do_encrypt_faults(ciphertext) + + response = { + "CiphertextBlob" : ciphertext, + "KeyId" : keyid, + } + + self._send_reply(json.dumps(response).encode('utf-8')) + + def _do_encrypt_faults(self, raw_ciphertext): + kms_http_common.stats.fault_calls += 1 + + if kms_http_common.fault_type == kms_http_common.FAULT_ENCRYPT: + self._send_reply("Internal Error of some sort.".encode(), http.HTTPStatus.INTERNAL_SERVER_ERROR) + return + elif kms_http_common.fault_type == kms_http_common.FAULT_ENCRYPT_WRONG_FIELDS: + response = { + "SomeBlob" : raw_ciphertext, + "KeyId" : "foo", + } + + self._send_reply(json.dumps(response).encode('utf-8')) + return + elif kms_http_common.fault_type == kms_http_common.FAULT_ENCRYPT_BAD_BASE64: + response = { + "CiphertextBlob" : "foo", + "KeyId" : "foo", + } + + self._send_reply(json.dumps(response).encode('utf-8')) + return + elif kms_http_common.fault_type == kms_http_common.FAULT_ENCRYPT_CORRECT_FORMAT: + response = { + "__type" : "NotFoundException", + "Message" : "Error encrypting message", + } + + self._send_reply(json.dumps(response).encode('utf-8')) + return + + raise ValueError("Unknown Fault Type: " + kms_http_common.fault_type) + + def _do_decrypt(self, raw_input): + request = json.loads(raw_input) + blob = base64.b64decode(request["CiphertextBlob"]).decode() + + print("FOUND SECRET: " + blob) + + # our "encrypted" values start with the word SECRET_PREFIX otherwise they did not come from us + if not blob.startswith(SECRET_PREFIX): + raise ValueError() + + blob = blob[len(SECRET_PREFIX):] + + if kms_http_common.fault_type and kms_http_common.fault_type.startswith(kms_http_common.FAULT_DECRYPT) \ + and not kms_http_common.disable_faults: + return self._do_decrypt_faults(blob) + + response = { + "Plaintext" : blob, + "KeyId" : "Not a clue", + } + + self._send_reply(json.dumps(response).encode('utf-8')) + + def _do_decrypt_faults(self, blob): + kms_http_common.stats.fault_calls += 1 + + if kms_http_common.fault_type == kms_http_common.FAULT_DECRYPT: + self._send_reply("Internal Error of some sort.".encode(), http.HTTPStatus.INTERNAL_SERVER_ERROR) + return + elif kms_http_common.fault_type == kms_http_common.FAULT_DECRYPT_WRONG_KEY: + response = { + "Plaintext" : "ta7DXE7J0OiCRw03dYMJSeb8nVF5qxTmZ9zWmjuX4zW/SOorSCaY8VMTWG+cRInMx/rr/+QeVw2WjU2IpOSvMg==", + "KeyId" : "Not a clue", + } + + self._send_reply(json.dumps(response).encode('utf-8')) + return + elif kms_http_common.fault_type == kms_http_common.FAULT_DECRYPT_CORRECT_FORMAT: + response = { + "__type" : "NotFoundException", + "Message" : "Error decrypting message", + } + + self._send_reply(json.dumps(response).encode('utf-8')) + return + + raise ValueError("Unknown Fault Type: " + kms_http_common.fault_type) + +def main(): + """Main Method.""" + parser = argparse.ArgumentParser(description='MongoDB Mock AWS KMS Endpoint.') + + parser.add_argument('-p', '--port', type=int, default=8000, help="Port to listen on") + + parser.add_argument('-v', '--verbose', action='count', help="Enable verbose tracing") + + parser.add_argument('--fault', type=str, help="Type of fault to inject") + + parser.add_argument('--disable-faults', action='store_true', help="Disable faults on startup") + + parser.add_argument('--ca_file', type=str, required=True, help="TLS CA PEM file") + + parser.add_argument('--cert_file', type=str, required=True, help="TLS Server PEM file") + + args = parser.parse_args() + if args.verbose: + logging.basicConfig(level=logging.DEBUG) + + if args.fault: + if args.fault not in SUPPORTED_FAULT_TYPES: + print("Unsupported fault type %s, supports types are %s" % (args.fault, SUPPORTED_FAULT_TYPES)) + sys.exit(1) + + kms_http_common.fault_type = args.fault + + if args.disable_faults: + kms_http_common.disable_faults = True + + kms_http_common.run(args.port, args.cert_file, args.ca_file, AwsKmsHandler) + + +if __name__ == '__main__': + + main() diff --git a/jstests/client_encrypt/lib/kms_http_server_azure.py b/jstests/client_encrypt/lib/kms_http_server_azure.py new file mode 100755 index 00000000000..af2b7b4a016 --- /dev/null +++ b/jstests/client_encrypt/lib/kms_http_server_azure.py @@ -0,0 +1,239 @@ +#! /usr/bin/env python3 +"""Mock Azure KMS Endpoint.""" +import argparse +import base64 +import http +import json +import logging +import urllib.parse +import sys + +import kms_http_common + +SUPPORTED_FAULT_TYPES = [ + kms_http_common.FAULT_ENCRYPT, + kms_http_common.FAULT_ENCRYPT_CORRECT_FORMAT, + kms_http_common.FAULT_DECRYPT, + kms_http_common.FAULT_DECRYPT_CORRECT_FORMAT, + kms_http_common.FAULT_DECRYPT_WRONG_KEY, + kms_http_common.FAULT_OAUTH, + kms_http_common.FAULT_OAUTH_CORRECT_FORMAT, +] + +SECRET_PREFIX = "00SECRET" +FAKE_OAUTH_TOKEN = "omg_im_an_oauth_token" + +URL_PATH_OAUTH_AUDIENCE = "/token" +URL_PATH_OAUTH_SCOPE = "/auth/cloudkms" +URL_PATH_MOCK_KEY = "/keys/my_key/" + + +class AzureKmsHandler(kms_http_common.KmsHandlerBase): + """ + Handle requests from Azure KMS Monitoring and test commands + """ + + def do_POST(self): + """Serve a POST request.""" + print("Received POST: " + self.path) + parts = urllib.parse.urlsplit(self.path) + path = parts[2] + + if path == "/my_tentant/oauth2/v2.0/token": + self._do_oauth_request() + elif path.startswith(URL_PATH_MOCK_KEY): + self._do_operation() + else: + self.send_response(http.HTTPStatus.NOT_FOUND) + self.end_headers() + self.wfile.write("Unknown URL".encode()) + + def _do_operation(self): + clen = int(self.headers.get("content-length")) + + raw_input = self.rfile.read(clen) + + print(f"RAW INPUT: {str(raw_input)}") + + if not self.headers["Host"] == "localhost": + data = "Unexpected host" + self._send_reply(data.encode("utf-8")) + + if not self.headers["Authorization"] == f"Bearer {FAKE_OAUTH_TOKEN}": + data = "Unexpected bearer token" + self._send_reply(data.encode("utf-8")) + + parts = urllib.parse.urlsplit(self.path) + operation = parts.path.split('/')[-1] + + if operation == "wrapkey": + self._do_encrypt(raw_input) + elif operation == "unwrapkey": + self._do_decrypt(raw_input) + else: + self._send_reply(f"Unknown operation: {operation}".encode("utf-8")) + + def _do_encrypt(self, raw_input): + request = json.loads(raw_input) + + print(request) + + plaintext = request["value"] + + ciphertext = SECRET_PREFIX.encode() + plaintext.encode() + ciphertext = base64.urlsafe_b64encode(ciphertext).decode() + + if kms_http_common.fault_type and kms_http_common.fault_type.startswith(kms_http_common.FAULT_ENCRYPT) \ + and not kms_http_common.disable_faults: + return self._do_encrypt_faults(ciphertext) + + response = { + "value": ciphertext, + "kid": "my_key", + } + + self._send_reply(json.dumps(response).encode('utf-8')) + + def _do_encrypt_faults(self, raw_ciphertext): + kms_http_common.stats.fault_calls += 1 + + if kms_http_common.fault_type == kms_http_common.FAULT_ENCRYPT: + self._send_reply("Internal Error of some sort.".encode(), http.HTTPStatus.INTERNAL_SERVER_ERROR) + return + elif kms_http_common.fault_type == kms_http_common.FAULT_ENCRYPT_CORRECT_FORMAT: + response = { + "error": { + "code": "bad", + "message": "Error encrypting message", + } + } + self._send_reply(json.dumps(response).encode('utf-8')) + return + + raise ValueError("Unknown Fault Type: " + kms_http_common.fault_type) + + def _do_decrypt(self, raw_input): + request = json.loads(raw_input) + blob = base64.urlsafe_b64decode(request["value"]).decode() + + print("FOUND SECRET: " + blob) + + # our "encrypted" values start with the word SECRET_PREFIX otherwise they did not come from us + if not blob.startswith(SECRET_PREFIX): + raise ValueError() + + blob = blob[len(SECRET_PREFIX):] + + if kms_http_common.fault_type and kms_http_common.fault_type.startswith(kms_http_common.FAULT_DECRYPT) \ + and not kms_http_common.disable_faults: + return self._do_decrypt_faults(blob) + + response = { + "kid": "my_key", + "value": blob, + } + + self._send_reply(json.dumps(response).encode('utf-8')) + + def _do_decrypt_faults(self, blob): + kms_http_common.stats.fault_calls += 1 + + if kms_http_common.fault_type == kms_http_common.FAULT_DECRYPT: + self._send_reply("Internal Error of some sort.".encode(), http.HTTPStatus.INTERNAL_SERVER_ERROR) + return + elif kms_http_common.fault_type == kms_http_common.FAULT_DECRYPT_WRONG_KEY: + response = { + "kid": "my_key", + "value": "ta7DXE7J0OiCRw03dYMJSeb8nVF5qxTmZ9zWmjuX4zW/SOorSCaY8VMTWG+cRInMx/rr/+QeVw2WjU2IpOSvMg==", + } + self._send_reply(json.dumps(response).encode('utf-8')) + return + elif kms_http_common.fault_type == kms_http_common.FAULT_DECRYPT_CORRECT_FORMAT: + response = { + "error": { + "code": "bad", + "message": "Error decrypting message", + } + } + self._send_reply(json.dumps(response).encode('utf-8')) + return + + raise ValueError("Unknown Fault Type: " + kms_http_common.fault_type) + + def _do_oauth_request(self): + clen = int(self.headers.get('content-length')) + + raw_input = self.rfile.read(clen) + + print(f"RAW INPUT: {str(raw_input)}") + + if not self.headers["Host"] == "localhost": + data = "Unexpected host" + self._send_reply(data.encode("utf-8")) + + if kms_http_common.fault_type and kms_http_common.fault_type.startswith(kms_http_common.FAULT_OAUTH) \ + and not kms_http_common.disable_faults: + return self._do_oauth_faults() + + response = { + "access_token": FAKE_OAUTH_TOKEN, + "scope": self.headers["Host"] + URL_PATH_OAUTH_SCOPE, + "token_type": "Bearer", + "expires_in": 3600, + } + + self._send_reply(json.dumps(response).encode("utf-8")) + + def _do_oauth_faults(self): + kms_http_common.stats.fault_calls += 1 + + if kms_http_common.fault_type == kms_http_common.FAULT_OAUTH: + self._send_reply("Internal Error of some sort.".encode(), http.HTTPStatus.INTERNAL_SERVER_ERROR) + return + elif kms_http_common.fault_type == kms_http_common.FAULT_OAUTH_CORRECT_FORMAT: + response = { + "error": "Azure OAuth Error", + "error_description": "FAULT_OAUTH_CORRECT_FORMAT", + "error_uri": "https://mongodb.com/whoopsies.pdf", + } + self._send_reply(json.dumps(response).encode('utf-8')) + return + + raise ValueError("Unknown Fault Type: " + kms_http_common.fault_type) + + +def main(): + """Main Method.""" + parser = argparse.ArgumentParser(description='MongoDB Mock Azure KMS Endpoint.') + + parser.add_argument('-p', '--port', type=int, default=8000, help="Port to listen on") + + parser.add_argument('-v', '--verbose', action='count', help="Enable verbose tracing") + + parser.add_argument('--fault', type=str, help="Type of fault to inject") + + parser.add_argument('--disable-faults', action='store_true', help="Disable faults on startup") + + parser.add_argument('--ca_file', type=str, required=True, help="TLS CA PEM file") + + parser.add_argument('--cert_file', type=str, required=True, help="TLS Server PEM file") + + args = parser.parse_args() + if args.verbose: + logging.basicConfig(level=logging.DEBUG) + + if args.fault: + if args.fault not in SUPPORTED_FAULT_TYPES: + print("Unsupported fault type %s, supports types are %s" % (args.fault, SUPPORTED_FAULT_TYPES)) + sys.exit(1) + + kms_http_common.fault_type = args.fault + + if args.disable_faults: + kms_http_common.disable_faults = True + + kms_http_common.run(args.port, args.cert_file, args.ca_file, AzureKmsHandler) + + +if __name__ == '__main__': + main() diff --git a/jstests/client_encrypt/lib/kms_http_server_gcp.py b/jstests/client_encrypt/lib/kms_http_server_gcp.py new file mode 100755 index 00000000000..8ce42081e01 --- /dev/null +++ b/jstests/client_encrypt/lib/kms_http_server_gcp.py @@ -0,0 +1,240 @@ +#! /usr/bin/env python3 +"""Mock GCP KMS Endpoint.""" +import argparse +import base64 +import http +import json +import logging +import urllib.parse +import sys + +import kms_http_common + +SUPPORTED_FAULT_TYPES = [ + kms_http_common.FAULT_ENCRYPT, + kms_http_common.FAULT_ENCRYPT_CORRECT_FORMAT, + kms_http_common.FAULT_DECRYPT, + kms_http_common.FAULT_DECRYPT_CORRECT_FORMAT, + kms_http_common.FAULT_DECRYPT_WRONG_KEY, + kms_http_common.FAULT_OAUTH, + kms_http_common.FAULT_OAUTH_CORRECT_FORMAT, +] + +SECRET_PREFIX = "00SECRET" +FAKE_OAUTH_TOKEN = "omg_im_an_oauth_token" + +URL_PATH_OAUTH_AUDIENCE = "/token" +URL_PATH_OAUTH_SCOPE = "/auth/cloudkms" +URL_PATH_MOCK_KEY = "/v1/projects/mock/locations/global/keyRings/mock-key-ring/cryptoKeys/mock-key" + + +class GcpKmsHandler(kms_http_common.KmsHandlerBase): + """ + Handle requests from GCP KMS Monitoring and test commands + """ + + def do_POST(self): + """Serve a POST request.""" + print("Received POST: " + self.path) + parts = urllib.parse.urlsplit(self.path) + path = parts[2] + + if path == "/token": + self._do_oauth_request() + elif path.startswith(URL_PATH_MOCK_KEY): + self._do_operation() + else: + self.send_response(http.HTTPStatus.NOT_FOUND) + self.end_headers() + self.wfile.write("Unknown URL".encode()) + + def _do_operation(self): + clen = int(self.headers.get("content-length")) + + raw_input = self.rfile.read(clen) + + print(f"RAW INPUT: {str(raw_input)}") + + if not self.headers["Host"] == "localhost": + data = "Unexpected host" + self._send_reply(data.encode("utf-8")) + + if not self.headers["Authorization"] == f"Bearer {FAKE_OAUTH_TOKEN}": + data = "Unexpected bearer token" + self._send_reply(data.encode("utf-8")) + + parts = urllib.parse.urlsplit(self.path) + path = parts[2] + operation = path.split(":")[1] + + if operation == "encrypt": + self._do_encrypt(raw_input) + elif operation == "decrypt": + self._do_decrypt(raw_input) + else: + self._send_reply(f"Unknown operation: {operation}".encode("utf-8")) + + def _do_encrypt(self, raw_input): + request = json.loads(raw_input) + + print(request) + + plaintext = request["plaintext"] + + ciphertext = SECRET_PREFIX.encode() + plaintext.encode() + ciphertext = base64.b64encode(ciphertext).decode() + + if kms_http_common.fault_type and kms_http_common.fault_type.startswith(kms_http_common.FAULT_ENCRYPT) \ + and not kms_http_common.disable_faults: + return self._do_encrypt_faults(ciphertext) + + response = { + "ciphertext": ciphertext, + "name": "mockEncryptResponse", + } + + self._send_reply(json.dumps(response).encode('utf-8')) + + def _do_encrypt_faults(self, raw_ciphertext): + kms_http_common.stats.fault_calls += 1 + + if kms_http_common.fault_type == kms_http_common.FAULT_ENCRYPT: + self._send_reply("Internal Error of some sort.".encode(), http.HTTPStatus.INTERNAL_SERVER_ERROR) + return + elif kms_http_common.fault_type == kms_http_common.FAULT_ENCRYPT_CORRECT_FORMAT: + response = { + "error": { + "code": 1337, + "message": "Error encrypting message", + "status": "Dummy Status", + } + } + self._send_reply(json.dumps(response).encode('utf-8')) + return + + raise ValueError("Unknown Fault Type: " + kms_http_common.fault_type) + + def _do_decrypt(self, raw_input): + request = json.loads(raw_input) + blob = base64.b64decode(request["ciphertext"]).decode() + + print("FOUND SECRET: " + blob) + + # our "encrypted" values start with the word SECRET_PREFIX otherwise they did not come from us + if not blob.startswith(SECRET_PREFIX): + raise ValueError() + + blob = blob[len(SECRET_PREFIX):] + + if kms_http_common.fault_type and kms_http_common.fault_type.startswith(kms_http_common.FAULT_DECRYPT) \ + and not kms_http_common.disable_faults: + return self._do_decrypt_faults(blob) + + response = { + "plaintext": blob, + } + + self._send_reply(json.dumps(response).encode('utf-8')) + + def _do_decrypt_faults(self, blob): + kms_http_common.stats.fault_calls += 1 + + if kms_http_common.fault_type == kms_http_common.FAULT_DECRYPT: + self._send_reply("Internal Error of some sort.".encode(), http.HTTPStatus.INTERNAL_SERVER_ERROR) + return + elif kms_http_common.fault_type == kms_http_common.FAULT_DECRYPT_WRONG_KEY: + response = { + "plaintext": "ta7DXE7J0OiCRw03dYMJSeb8nVF5qxTmZ9zWmjuX4zW/SOorSCaY8VMTWG+cRInMx/rr/+QeVw2WjU2IpOSvMg==", + } + self._send_reply(json.dumps(response).encode('utf-8')) + return + elif kms_http_common.fault_type == kms_http_common.FAULT_DECRYPT_CORRECT_FORMAT: + response = { + "error": { + "code": 9001, + "message": "Error decrypting message", + "status": "Dummy Status", + } + } + self._send_reply(json.dumps(response).encode('utf-8')) + return + + raise ValueError("Unknown Fault Type: " + kms_http_common.fault_type) + + def _do_oauth_request(self): + clen = int(self.headers.get('content-length')) + + raw_input = self.rfile.read(clen) + + print(f"RAW INPUT: {str(raw_input)}") + + if not self.headers["Host"] == "localhost": + data = "Unexpected host" + self._send_reply(data.encode("utf-8")) + + if kms_http_common.fault_type and kms_http_common.fault_type.startswith(kms_http_common.FAULT_OAUTH) \ + and not kms_http_common.disable_faults: + return self._do_oauth_faults() + + response = { + "access_token": FAKE_OAUTH_TOKEN, + "scope": self.headers["Host"] + URL_PATH_OAUTH_SCOPE, + "token_type": "Bearer", + "expires_in": 3600, + } + + self._send_reply(json.dumps(response).encode("utf-8")) + + def _do_oauth_faults(self): + kms_http_common.stats.fault_calls += 1 + + if kms_http_common.fault_type == kms_http_common.FAULT_OAUTH: + self._send_reply("Internal Error of some sort.".encode(), http.HTTPStatus.INTERNAL_SERVER_ERROR) + return + elif kms_http_common.fault_type == kms_http_common.FAULT_OAUTH_CORRECT_FORMAT: + response = { + "error": "GCP OAuth Error", + "error_description": "FAULT_OAUTH_CORRECT_FORMAT", + "error_uri": "https://mongodb.com/whoopsies.pdf", + } + self._send_reply(json.dumps(response).encode('utf-8')) + return + + raise ValueError("Unknown Fault Type: " + kms_http_common.fault_type) + + +def main(): + """Main Method.""" + parser = argparse.ArgumentParser(description='MongoDB Mock GCP KMS Endpoint.') + + parser.add_argument('-p', '--port', type=int, default=8000, help="Port to listen on") + + parser.add_argument('-v', '--verbose', action='count', help="Enable verbose tracing") + + parser.add_argument('--fault', type=str, help="Type of fault to inject") + + parser.add_argument('--disable-faults', action='store_true', help="Disable faults on startup") + + parser.add_argument('--ca_file', type=str, required=True, help="TLS CA PEM file") + + parser.add_argument('--cert_file', type=str, required=True, help="TLS Server PEM file") + + args = parser.parse_args() + if args.verbose: + logging.basicConfig(level=logging.DEBUG) + + if args.fault: + if args.fault not in SUPPORTED_FAULT_TYPES: + print("Unsupported fault type %s, supports types are %s" % (args.fault, SUPPORTED_FAULT_TYPES)) + sys.exit(1) + + kms_http_common.fault_type = args.fault + + if args.disable_faults: + kms_http_common.disable_faults = True + + kms_http_common.run(args.port, args.cert_file, args.ca_file, GcpKmsHandler) + + +if __name__ == '__main__': + main() diff --git a/jstests/client_encrypt/lib/mock_kms.js b/jstests/client_encrypt/lib/mock_kms.js new file mode 100644 index 00000000000..d1fd64abc3d --- /dev/null +++ b/jstests/client_encrypt/lib/mock_kms.js @@ -0,0 +1,188 @@ +/** + * Starts a mock KMS Server to test + * FLE encryption and decryption. + */ + +// These faults must match the list of faults in kms_http_server.py, see the +// SUPPORTED_FAULT_TYPES list in kms_http_server.py +const FAULT_ENCRYPT = "fault_encrypt"; +const FAULT_ENCRYPT_CORRECT_FORMAT = "fault_encrypt_correct_format"; +const FAULT_ENCRYPT_WRONG_FIELDS = "fault_encrypt_wrong_fields"; +const FAULT_ENCRYPT_BAD_BASE64 = "fault_encrypt_bad_base64"; +const FAULT_DECRYPT = "fault_decrypt"; +const FAULT_DECRYPT_CORRECT_FORMAT = "fault_decrypt_correct_format"; +const FAULT_DECRYPT_WRONG_KEY = "fault_decrypt_wrong_key"; +const FAULT_OAUTH = "fault_oauth"; +const FAULT_OAUTH_CORRECT_FORMAT = "fault_oauth_correct_format"; + +const DISABLE_FAULTS = "disable_faults"; +const ENABLE_FAULTS = "enable_faults"; + +class MockKMSServerAWS { + /** + * Create a new webserver. + * + * @param {string} fault_type + * @param {bool} disableFaultsOnStartup optionally disable fault on startup + */ + constructor(fault_type, disableFaultsOnStartup) { + this.python = "python3"; + this.disableFaultsOnStartup = disableFaultsOnStartup || false; + this.fault_type = fault_type; + + if (_isWindows()) { + this.python = "python.exe"; + } + + print("Using python interpreter: " + this.python); + + this.ca_file = "jstests/libs/ca.pem"; + this.server_cert_file = "jstests/libs/server.pem"; + this.web_server_py = "jstests/client_encrypt/lib/kms_http_server.py"; + this.control_py = "jstests/client_encrypt/lib/kms_http_control.py"; + this.port = -1; + } + + /** + * Start a web server + */ + start() { + this.port = allocatePort(); + print("Mock Web server is listening on port: " + this.port); + + let args = [ + this.python, + "-u", + this.web_server_py, + "--port=" + this.port, + "--ca_file=" + this.ca_file, + "--cert_file=" + this.server_cert_file + ]; + if (this.fault_type) { + args.push("--fault=" + this.fault_type); + if (this.disableFaultsOnStartup) { + args.push("--disable-faults"); + } + } + + clearRawMongoProgramOutput(); + + this.pid = _startMongoProgram({args: args}); + assert(checkProgram(this.pid)); + + assert.soon(function() { + return rawMongoProgramOutput().search("Mock KMS Web Server Listening") !== -1; + }); + sleep(1000); + print("Mock KMS Server successfully started"); + } + + _runCommand(cmd) { + let ret = 0; + if (_isWindows()) { + ret = runProgram('cmd.exe', '/c', cmd); + } else { + ret = runProgram('/bin/sh', '-c', cmd); + } + + assert.eq(ret, 0); + } + + /** + * Query the HTTP server. + * + * @param {string} query type + * + * @return {object} Object representation of JSON from the server. + */ + query(query) { + const out_file = "out_" + this.port + ".txt"; + const python_command = this.python + " -u " + this.control_py + " --port=" + this.port + + " --ca_file=" + this.ca_file + " --query=" + query + " > " + out_file; + + this._runCommand(python_command); + + const result = cat(out_file); + + try { + return JSON.parse(result); + } catch (e) { + jsTestLog("Failed to parse: " + result + "\n" + result); + throw e; + } + } + + /** + * Control the HTTP server. + * + * @param {string} query type + */ + control(query) { + const python_command = this.python + " -u " + this.control_py + " --port=" + this.port + + " --ca_file=" + this.ca_file + " --query=" + query; + + this._runCommand(python_command); + } + + /** + * Disable Faults + */ + disableFaults() { + this.control(DISABLE_FAULTS); + } + + /** + * Enable Faults + */ + enableFaults() { + this.control(ENABLE_FAULTS); + } + + /** + * Query the stats page for the HTTP server. + * + * @return {object} Object representation of JSON from the server. + */ + queryStats() { + return this.query("stats"); + } + + /** + * Get the URL. Prefixed with https:// + * + * @return {string} url of http server + */ + getURL() { + return "https://localhost:" + this.port; + } + + /** + * Get the endpoint. A "<host>:<port>". + * + * @return {string} url of http server + */ + getEndpoint() { + return "localhost:" + this.port; + } + + /** + * Stop the web server + */ + stop() { + stopMongoProgramByPid(this.pid); + } +} + +class MockKMSServerGCP extends MockKMSServerAWS { + constructor(fault_type, disableFaultsOnStartup) { + super(fault_type, disableFaultsOnStartup); + this.web_server_py = "jstests/client_encrypt/lib/kms_http_server_gcp.py"; + } +} + +class MockKMSServerAzure extends MockKMSServerAWS { + constructor(fault_type, disableFaultsOnStartup) { + super(fault_type, disableFaultsOnStartup); + this.web_server_py = "jstests/client_encrypt/lib/kms_http_server_azure.py"; + } +} diff --git a/src/mongo/shell/SConscript b/src/mongo/shell/SConscript index 1bae6d96446..1179246d35c 100644 --- a/src/mongo/shell/SConscript +++ b/src/mongo/shell/SConscript @@ -148,34 +148,33 @@ env.Library( ], ) -env.Library( - target="kms_idl", - source=[ - "kms.idl", - ], - LIBDEPS_PRIVATE=[ - '$BUILD_DIR/mongo/base', - '$BUILD_DIR/mongo/idl/idl_parser', - ], -) - if get_option('ssl') == 'on': kmsEnv = env.Clone() + kmsEnv.InjectThirdParty(libraries=['kms-message']) + kmsEnv.Library( target="kms", source=[ "kms.cpp", + "kms_aws.cpp", + "kms_azure.cpp", + "kms_gcp.cpp", "kms_local.cpp", + "kms_network.cpp", + "kms.idl", ], LIBDEPS=[ '$BUILD_DIR/mongo/base/secure_allocator', - 'kms_idl', ], LIBDEPS_PRIVATE=[ '$BUILD_DIR/mongo/base', '$BUILD_DIR/mongo/crypto/aead_encryption', '$BUILD_DIR/mongo/db/commands/test_commands_enabled', + '$BUILD_DIR/mongo/util/net/network', + '$BUILD_DIR/mongo/util/net/ssl_manager', + '$BUILD_DIR/mongo/util/net/ssl_options', + '$BUILD_DIR/third_party/shim_kms_message', ], ) @@ -206,6 +205,7 @@ if get_option('ssl') == 'on': target="encrypted_dbclient", source=[ "encrypted_dbclient_base.cpp", + "fle_shell_options.idl", ], LIBDEPS_PRIVATE=[ '$BUILD_DIR/mongo/client/clientdriver_minimal', diff --git a/src/mongo/shell/encrypted_dbclient_base.cpp b/src/mongo/shell/encrypted_dbclient_base.cpp index 844c44fea36..db14d4f8fc2 100644 --- a/src/mongo/shell/encrypted_dbclient_base.cpp +++ b/src/mongo/shell/encrypted_dbclient_base.cpp @@ -58,6 +58,7 @@ #include "mongo/scripting/mozjs/objectwrapper.h" #include "mongo/scripting/mozjs/valuereader.h" #include "mongo/scripting/mozjs/valuewriter.h" +#include "mongo/shell/encrypted_shell_options.h" #include "mongo/shell/kms.h" #include "mongo/shell/kms_gen.h" #include "mongo/shell/shell_options.h" @@ -65,6 +66,8 @@ namespace mongo { +EncryptedShellGlobalParams encryptedShellGlobalParams; + namespace { constexpr Duration kCacheInvalidationTime = Minutes(1); @@ -804,7 +807,7 @@ std::unique_ptr<DBClientBase> createEncryptedDBClientBase(std::unique_ptr<DBClie static constexpr auto keyVaultClientFieldId = "keyVaultClient"; - if (!arg.isObject()) { + if (!arg.isObject() && encryptedShellGlobalParams.awsAccessKeyId.empty()) { return conn; } @@ -812,7 +815,32 @@ std::unique_ptr<DBClientBase> createEncryptedDBClientBase(std::unique_ptr<DBClie JS::RootedValue client(cx); JS::RootedValue collection(cx); - { + if (!arg.isObject()) { + // If arg is not an object, but one of the required encryptedShellGlobalParams + // is defined, the user is trying to start an encrypted client with command line + // parameters. + + AwsKMS awsKms = AwsKMS(encryptedShellGlobalParams.awsAccessKeyId, + encryptedShellGlobalParams.awsSecretAccessKey); + + awsKms.setUrl(StringData(encryptedShellGlobalParams.awsKmsURL)); + + awsKms.setSessionToken(StringData(encryptedShellGlobalParams.awsSessionToken)); + + KmsProviders kmsProviders; + kmsProviders.setAws(awsKms); + + // The mongoConnection object will never be null. + // If the encrypted shell is started through command line parameters, then the user must + // default to the implicit connection for the keyvault collection. + client.setObjectOrNull(mongoConnection.get()); + + // Because we cannot add a schemaMap object through the command line, we set the + // schemaMap object in ClientSideFLEOptions to be null so we know to always use + // remote schemas. + encryptionOptions = ClientSideFLEOptions(encryptedShellGlobalParams.keyVaultNamespace, + std::move(kmsProviders)); + } else { uassert(ErrorCodes::BadValue, "Collection object must be passed to Field Level Encryption Options", arg.isObject()); diff --git a/src/mongo/shell/encrypted_dbclient_base.h b/src/mongo/shell/encrypted_dbclient_base.h index 4f00b5b0e4b..76a24afc10f 100644 --- a/src/mongo/shell/encrypted_dbclient_base.h +++ b/src/mongo/shell/encrypted_dbclient_base.h @@ -51,6 +51,7 @@ #include "mongo/scripting/mozjs/objectwrapper.h" #include "mongo/scripting/mozjs/valuereader.h" #include "mongo/scripting/mozjs/valuewriter.h" +#include "mongo/shell/encrypted_shell_options.h" #include "mongo/shell/kms.h" #include "mongo/shell/kms_gen.h" #include "mongo/shell/shell_options.h" diff --git a/src/mongo/shell/encrypted_shell_options.h b/src/mongo/shell/encrypted_shell_options.h new file mode 100644 index 00000000000..b4b30aba2fe --- /dev/null +++ b/src/mongo/shell/encrypted_shell_options.h @@ -0,0 +1,45 @@ +/** + * Copyright (C) 2019-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#pragma once + +#include <string> + +namespace mongo { + +struct EncryptedShellGlobalParams { + std::string awsAccessKeyId; + std::string awsSecretAccessKey; + std::string awsSessionToken; + std::string keyVaultNamespace; + std::string awsKmsURL; +}; + +extern EncryptedShellGlobalParams encryptedShellGlobalParams; +} // namespace mongo diff --git a/src/mongo/shell/fle_shell_options.idl b/src/mongo/shell/fle_shell_options.idl new file mode 100644 index 00000000000..64578aaa69b --- /dev/null +++ b/src/mongo/shell/fle_shell_options.idl @@ -0,0 +1,37 @@ +# Copyright (C) 2019-present MongoDB, Inc. + +global: + cpp_namespace: "mongo" + configs: + section: 'FLE AWS Options' + source: [ cli ] + cpp_includes: + - mongo/shell/encrypted_shell_options.h + +configs: + "awsAccessKeyId": + description: "AWS Access Key for FLE Amazon KMS" + arg_vartype: String + cpp_varname: encryptedShellGlobalParams.awsAccessKeyId + requires: [ "awsSecretAccessKey", "keyVaultNamespace" ] + "awsSecretAccessKey": + description: "AWS Secret Key for FLE Amazon KMS" + arg_vartype: String + cpp_varname: encryptedShellGlobalParams.awsSecretAccessKey + redact: true + requires: [ "awsAccessKeyId", "keyVaultNamespace" ] + "awsSessionToken": + description: "Optional AWS Session Token ID" + arg_vartype: String + cpp_varname: encryptedShellGlobalParams.awsSessionToken + requires: [ "awsAccessKeyId", "awsSecretAccessKey", "keyVaultNamespace" ] + "keyVaultNamespace": + description: "database.collection to store encrypted FLE parameters" + arg_vartype: String + cpp_varname: encryptedShellGlobalParams.keyVaultNamespace + requires: [ "awsAccessKeyId", "awsSecretAccessKey" ] + "kmsURL": + description: "Test parameter to override the URL for KMS" + arg_vartype: String + cpp_varname: encryptedShellGlobalParams.awsKmsURL + requires: [ "awsAccessKeyId", "awsSecretAccessKey", "keyVaultNamespace" ] diff --git a/src/mongo/shell/kms.cpp b/src/mongo/shell/kms.cpp index f379c5f43e2..792be3ab1f3 100644 --- a/src/mongo/shell/kms.cpp +++ b/src/mongo/shell/kms.cpp @@ -37,6 +37,20 @@ namespace mongo { +HostAndPort parseUrl(StringData url) { + // Treat the URL as a host and port + // URL: https://(host):(port) + // + constexpr StringData urlPrefix = "https://"_sd; + uassert(51140, + str::stream() << "KMS URL must start with https://, URL: " << url, + url.startsWith(urlPrefix)); + + StringData hostAndPort = url.substr(urlPrefix.size()); + + return HostAndPort(hostAndPort); +} + BSONObj KMSService::encryptDataKeyByString(ConstDataRange cdr, StringData keyId) { uasserted(5380101, str::stream() << "Customer Master Keys for " << name() diff --git a/src/mongo/shell/kms.h b/src/mongo/shell/kms.h index a67b82db283..0560100feea 100644 --- a/src/mongo/shell/kms.h +++ b/src/mongo/shell/kms.h @@ -37,7 +37,6 @@ #include "mongo/base/secure_allocator.h" #include "mongo/base/string_data.h" #include "mongo/bson/bsonobj.h" -#include "mongo/crypto/symmetric_key.h" #include "mongo/shell/kms_gen.h" #include "mongo/stdx/unordered_map.h" #include "mongo/util/net/hostandport.h" @@ -85,8 +84,6 @@ public: * needs to be store in the key vault. */ virtual BSONObj encryptDataKeyByBSONObj(ConstDataRange cdr, BSONObj keyId); - - virtual SymmetricKey& getMasterKey() = 0; }; /** @@ -134,4 +131,11 @@ private: static stdx::unordered_map<KMSProviderEnum, std::unique_ptr<KMSServiceFactory>> _factories; }; +/** + * Parse a basic url of "https://host:port" to a HostAndPort. + * + * Does not support URL encoding or anything else. + */ +HostAndPort parseUrl(StringData url); + } // namespace mongo diff --git a/src/mongo/shell/kms.idl b/src/mongo/shell/kms.idl index 80875714545..5124a22c26f 100644 --- a/src/mongo/shell/kms.idl +++ b/src/mongo/shell/kms.idl @@ -37,9 +37,71 @@ enums: description: "Enumeration of supported KMS Providers" type: string values: + aws: "aws" + azure: "azure" + gcp: "gcp" local: "local" structs: + awsKMSError: + description: "AWS KMS error" + strict: false + fields: + __type: + type: string + cpp_name: type + Message: string + + # Options passed to Mongo() javascript constructor + awsKMS: + description: "AWS KMS config" + fields: + accessKeyId: string + secretAccessKey: string + sessionToken: + type: string + optional: true + url: + type: string + optional: true + + azureKMSError: + description: "Azure KMS Error" + strict: false + fields: + code: string + message: string + + # Options passed to Mongo() javascript constructor + azureKMS: + description: "Azure KMS config" + fields: + tenantId: string + clientId: string + clientSecret: string + identityPlatformEndpoint: + type: string + optional: true + + # Documented here: https://cloud.google.com/apis/design/errors#http_mapping + gcpKMSError: + description: "GCP KMS Error" + strict: false + fields: + code: int + message: string + status: string + + # Options passed to Mongo() javascript constructor + gcpKMS: + description: "GCP KMS config" + fields: + email: string + endpoint: + type: string + optional: true + privateKey: string + # Options passed to Mongo() javascript constructor localKMS: description: "Local KMS config" @@ -50,6 +112,15 @@ structs: description: "Supported KMS Providers" strict: true fields: + aws: + type: awsKMS + optional: true + azure: + type: azureKMS + optional: true + gcp: + type: gcpKMS + optional: true local: type: localKMS optional: true @@ -70,6 +141,121 @@ structs: type: bool optional: true + awsEncryptResponse: + description: "Response from AWS KMS Encrypt request, i.e. TrentService.Encrypt" + strict: false + fields: + CiphertextBlob: + type: string + KeyId: + type: string + + awsDecryptResponse: + description: "Response from AWS KMS Decrypt request, i.e. TrentService.Decrypt" + # Nov 13, 2019 they added EncryptionAlgorithm but it is not documented + strict: false + fields: + Plaintext: + type: string + KeyId: + type: string + + awsMasterKey: + description: "AWS KMS Key Store Description" + fields: + provider: + type: string + default: '"aws"' + key: + type: string + region: + type: string + endpoint: + type: string + optional: true + + awsMasterKeyAndMaterial: + description: "AWS KMS Key Material Description" + fields: + keyMaterial: + type: bindata_generic + masterKey: + type: awsMasterKey + + azureEncryptResponse: + description: "Response from Azure KMS wrapKey request" + strict: false + fields: + kid: string + value: string + + azureDecryptResponse: + description: "Response from Azure KMS unwrapKey request" + strict: false + fields: + kid: string + value: string + + azureMasterKey: + description: "Azure KMS Key Store Description" + fields: + provider: + type: string + default: '"azure"' + keyName: string + keyVersion: + type: string + optional: true + keyVaultEndpoint: string + + azureMasterKeyAndMaterial: + description: "Azure KMS Key Material Description" + fields: + keyMaterial: bindata_generic + masterKey: azureMasterKey + + gcpEncryptResponse: + description: "Response from GCP KMS Encrypt request" + strict: false + fields: + name: string + ciphertext: string + + gcpDecryptResponse: + description: "Response from GCP KMS Decrypt request" + strict: false + fields: + plaintext: string + + gcpMasterKey: + description: "GCP KMS Key Store Description" + fields: + provider: + type: string + default: '"gcp"' + keyName: + type: string + keyRing: + type: string + keyVersion: + type: string + optional: true + location: + type: string + projectId: + type: string + endpoint: + type: string + optional: true + + gcpMasterKeyAndMaterial: + description: "GCP KMS Key Material Description" + fields: + keyMaterial: + type: bindata_generic + masterKey: + type: gcpMasterKey + localMasterKey: description: "Local KMS Key Store Description" fields: @@ -101,4 +287,31 @@ structs: type: array<string> ignore: true + # Defined in 4.2.2. in RFC 6749 + OAuthResponse: + description: "An oauth response with a token" + strict: false + fields: + access_token: string + token_type: string + # Expires_in is in seconds + expires_in: + type: int + optional: true + scope: + type: string + optional: true + + # Defined in 4.2.2.1. in RFC 6749 + OAuthErrorResponse: + description: "An oauth response with a token" + strict: false + fields: + error: string + error_description: + type: string + optional: true + error_uri: + type: string + optional: true diff --git a/src/mongo/shell/kms_aws.cpp b/src/mongo/shell/kms_aws.cpp new file mode 100644 index 00000000000..5d81566348a --- /dev/null +++ b/src/mongo/shell/kms_aws.cpp @@ -0,0 +1,322 @@ +/** + * Copyright (C) 2019-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + + +#include <kms_message/kms_message.h> + +#include <stdlib.h> + +#include "mongo/base/init.h" +#include "mongo/base/parse_number.h" +#include "mongo/base/secure_allocator.h" +#include "mongo/base/status_with.h" +#include "mongo/bson/json.h" +#include "mongo/shell/kms.h" +#include "mongo/shell/kms_gen.h" +#include "mongo/shell/kms_network.h" +#include "mongo/util/base64.h" +#include "mongo/util/kms_message_support.h" +#include "mongo/util/net/hostandport.h" +#include "mongo/util/net/ssl_manager.h" +#include "mongo/util/net/ssl_options.h" +#include "mongo/util/text.h" +#include "mongo/util/time_support.h" + +#define MONGO_LOGV2_DEFAULT_COMPONENT ::mongo::logv2::LogComponent::kControl + + +namespace mongo { +namespace { + +/** + * AWS configuration settings + */ +struct AWSConfig { + // AWS_ACCESS_KEY_ID + std::string accessKeyId; + + // AWS_SECRET_ACCESS_KEY + SecureString secretAccessKey; + + // Optional AWS_SESSION_TOKEN for AWS STS tokens + boost::optional<std::string> sessionToken; +}; + +constexpr auto kAwsKms = "aws"_sd; + +/** + * Manages SSL information and config for how to talk to AWS KMS. + */ +class AWSKMSService final : public KMSService { +public: + AWSKMSService() = default; + + StringData name() const { + return kAwsKms; + } + + static std::unique_ptr<KMSService> create(const AwsKMS& config); + + SecureVector<uint8_t> decrypt(ConstDataRange cdr, BSONObj masterKey) final; + + BSONObj encryptDataKeyByString(ConstDataRange cdr, StringData keyId) final; + +private: + std::vector<uint8_t> encrypt(ConstDataRange cdr, StringData kmsKeyId); + + void initRequest(kms_request_t* request, StringData host, StringData region); + +private: + // SSL Manager + std::shared_ptr<SSLManagerInterface> _sslManager; + + // Server to connect to + HostAndPort _server; + + // AWS configuration settings + AWSConfig _config; +}; + +void uassertKmsRequestInternal(kms_request_t* request, bool ok) { + if (!ok) { + const char* msg = kms_request_get_error(request); + uasserted(51135, str::stream() << "Internal AWS KMS Error: " << msg); + } +} + +#define uassertKmsRequest(X) uassertKmsRequestInternal(request, (X)); + +void AWSKMSService::initRequest(kms_request_t* request, StringData host, StringData region) { + + // use current time + uassertKmsRequest(kms_request_set_date(request, nullptr)); + + uassertKmsRequest(kms_request_set_region(request, region.toString().c_str())); + + // kms is always the name of the service + uassertKmsRequest(kms_request_set_service(request, "kms")); + + uassertKmsRequest(kms_request_set_access_key_id(request, _config.accessKeyId.c_str())); + uassertKmsRequest(kms_request_set_secret_key(request, _config.secretAccessKey->c_str())); + + // Set host to be the host we are targeting instead of defaulting to kms.<region>.amazonaws.com + uassertKmsRequest(kms_request_add_header_field(request, "Host", host.toString().c_str())); + + if (!_config.sessionToken.value_or("").empty()) { + // TODO: move this into kms-message + uassertKmsRequest(kms_request_add_header_field( + request, "X-Amz-Security-Token", _config.sessionToken.value().c_str())); + } +} + + +/** + * Takes in a CMK of the format arn:partition:service:region:account-id:resource (minimum). We + * care about extracting the region. This function ensures that there are at least 6 partitions, + * parses the provider, and returns a pair of provider and the region. + */ +std::string parseCMK(StringData cmk) { + std::vector<std::string> cmkTokenized = StringSplitter::split(cmk.toString(), ":"); + uassert(31040, "Invalid AWS KMS Customer Master Key.", cmkTokenized.size() > 5); + return cmkTokenized[3]; +} + +HostAndPort getDefaultHost(StringData region) { + std::string hostname = str::stream() << "kms." << region << ".amazonaws.com"; + return HostAndPort(hostname, 443); +} + +std::vector<uint8_t> AWSKMSService::encrypt(ConstDataRange cdr, StringData kmsKeyId) { + auto request = + UniqueKmsRequest(kms_encrypt_request_new(reinterpret_cast<const uint8_t*>(cdr.data()), + cdr.length(), + kmsKeyId.toString().c_str(), + NULL)); + + auto region = parseCMK(kmsKeyId); + + if (_server.empty()) { + _server = getDefaultHost(region); + } + + initRequest(request.get(), _server.host(), region); + + auto buffer = UniqueKmsCharBuffer(kms_request_get_signed(request.get())); + auto buffer_len = strlen(buffer.get()); + + KMSNetworkConnection connection(_sslManager.get()); + auto response = connection.makeOneRequest(_server, ConstDataRange(buffer.get(), buffer_len)); + + auto body = kms_response_get_body(response.get(), nullptr); + + BSONObj obj = fromjson(body); + + auto field = obj["__type"]; + + if (!field.eoo()) { + AwsKMSError awsResponse; + try { + awsResponse = AwsKMSError::parse(IDLParserContext("awsEncryptError"), obj); + } catch (DBException& dbe) { + uasserted(51274, + str::stream() << "AWS KMS failed to parse error message: " << dbe.toString() + << ", Response : " << obj); + } + + uasserted(51224, + str::stream() << "AWS KMS failed to encrypt: " << awsResponse.getType() << " : " + << awsResponse.getMessage()); + } + + auto awsResponse = AwsEncryptResponse::parse(IDLParserContext("awsEncryptResponse"), obj); + + auto blobStr = base64::decode(awsResponse.getCiphertextBlob().toString()); + + return kmsResponseToVector(blobStr); +} + +BSONObj AWSKMSService::encryptDataKeyByString(ConstDataRange cdr, StringData keyId) { + auto dataKey = encrypt(cdr, keyId); + + AwsMasterKey masterKey; + masterKey.setKey(keyId); + masterKey.setRegion(parseCMK(keyId)); + masterKey.setEndpoint(boost::optional<StringData>(_server.toString())); + + AwsMasterKeyAndMaterial keyAndMaterial; + keyAndMaterial.setKeyMaterial(dataKey); + keyAndMaterial.setMasterKey(masterKey); + + + return keyAndMaterial.toBSON(); +} + +SecureVector<uint8_t> AWSKMSService::decrypt(ConstDataRange cdr, BSONObj masterKey) { + auto awsMasterKey = AwsMasterKey::parse(IDLParserContext("awsMasterKey"), masterKey); + + auto request = UniqueKmsRequest(kms_decrypt_request_new( + reinterpret_cast<const uint8_t*>(cdr.data()), cdr.length(), nullptr)); + + if (_server.empty()) { + _server = getDefaultHost(awsMasterKey.getRegion()); + } + + initRequest(request.get(), _server.host(), awsMasterKey.getRegion()); + + auto buffer = UniqueKmsCharBuffer(kms_request_get_signed(request.get())); + auto buffer_len = strlen(buffer.get()); + KMSNetworkConnection connection(_sslManager.get()); + auto response = connection.makeOneRequest(_server, ConstDataRange(buffer.get(), buffer_len)); + + auto body = kms_response_get_body(response.get(), nullptr); + + BSONObj obj = fromjson(body); + + auto field = obj["__type"]; + + if (!field.eoo()) { + AwsKMSError awsResponse; + try { + awsResponse = AwsKMSError::parse(IDLParserContext("awsDecryptError"), obj); + } catch (DBException& dbe) { + uasserted(51275, + str::stream() << "AWS KMS failed to parse error message: " << dbe.toString() + << ", Response : " << obj); + } + + uasserted(51225, + str::stream() << "AWS KMS failed to decrypt: " << awsResponse.getType() << " : " + << awsResponse.getMessage()); + } + + auto awsResponse = AwsDecryptResponse::parse(IDLParserContext("awsDecryptResponse"), obj); + + auto blobStr = base64::decode(awsResponse.getPlaintext().toString()); + + return kmsResponseToSecureVector(blobStr); +} + +boost::optional<std::string> toString(boost::optional<StringData> str) { + if (str) { + return {str.value().toString()}; + } + return boost::none; +} + +std::unique_ptr<KMSService> AWSKMSService::create(const AwsKMS& config) { + auto awsKMS = std::make_unique<AWSKMSService>(); + + SSLParams params; + getSSLParamsForNetworkKMS(¶ms); + + // Leave the CA file empty so we default to system CA but for local testing allow it to inherit + // the CA file. + if (!config.getUrl().value_or("").empty()) { + params.sslCAFile = sslGlobalParams.sslCAFile; + awsKMS->_server = parseUrl(config.getUrl().value()); + } + + awsKMS->_sslManager = SSLManagerInterface::create(params, false); + + awsKMS->_config.accessKeyId = config.getAccessKeyId().toString(); + + awsKMS->_config.secretAccessKey = config.getSecretAccessKey().toString(); + + awsKMS->_config.sessionToken = toString(config.getSessionToken()); + + return awsKMS; +} + +/** + * Factory for AWSKMSService if user specifies aws config to mongo() JS constructor. + */ +class AWSKMSServiceFactory final : public KMSServiceFactory { +public: + AWSKMSServiceFactory() = default; + ~AWSKMSServiceFactory() = default; + + std::unique_ptr<KMSService> create(const BSONObj& config) final { + auto field = config[KmsProviders::kAwsFieldName]; + if (field.eoo()) { + return nullptr; + } + auto obj = field.Obj(); + return AWSKMSService::create(AwsKMS::parse(IDLParserContext("root"), obj)); + } +}; + +} // namespace + +MONGO_INITIALIZER(KMSRegisterAWS)(::mongo::InitializerContext*) { + kms_message_init(); + KMSServiceController::registerFactory(KMSProviderEnum::aws, + std::make_unique<AWSKMSServiceFactory>()); +} + +} // namespace mongo diff --git a/src/mongo/shell/kms_azure.cpp b/src/mongo/shell/kms_azure.cpp new file mode 100644 index 00000000000..459e1f812ae --- /dev/null +++ b/src/mongo/shell/kms_azure.cpp @@ -0,0 +1,301 @@ +/** + * Copyright (C) 2021-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include "kms_message/kms_request.h" +#include "mongo/shell/kms_gen.h" + +#include "mongo/platform/basic.h" + +#include <fmt/format.h> +#include <kms_message/kms_azure_request.h> +#include <kms_message/kms_b64.h> +#include <kms_message/kms_message.h> + +#include "mongo/bson/json.h" +#include "mongo/shell/kms.h" +#include "mongo/shell/kms_network.h" +#include "mongo/util/net/hostandport.h" + +#define MONGO_LOGV2_DEFAULT_COMPONENT ::logv2::LogComponent::kControl + + +namespace mongo { +namespace { + +using namespace fmt::literals; + +constexpr auto kAzureKms = "azure"_sd; + +// Default endpoints for Azure +constexpr auto kDefaultIdentityPlatformEndpoint = "login.microsoftonline.com"_sd; +// Since scope is passed as URL parameter, it needs to be escaped and kms_message does not escape +// it. +constexpr auto kDefaultOAuthScope = "https%3A%2F%2Fvault.azure.net%2F.default"_sd; + +struct AzureConfig { + // ID for the user in Azure + std::string tenantId; + + // ID for the application in Azure + std::string clientId; + + // Secret key for the application in Azure + std::string clientSecret; + + // Options to pass to kms-message + UniqueKmsRequestOpts opts; +}; + +/** + * Manages OAuth token requests and caching + */ +class AzureKMSOAuthService final : public KMSOAuthService { +public: + AzureKMSOAuthService(const AzureConfig& config, + HostAndPort endpoint, + std::shared_ptr<SSLManagerInterface> sslManager) + : KMSOAuthService(endpoint, sslManager), _config(config) {} + +protected: + UniqueKmsRequest getOAuthRequest() final { + auto request = + UniqueKmsRequest(kms_azure_request_oauth_new(_oAuthEndpoint.host().c_str(), + kDefaultOAuthScope.toString().c_str(), + _config.tenantId.c_str(), + _config.clientId.c_str(), + _config.clientSecret.c_str(), + _config.opts.get())); + + const char* msg = kms_request_get_error(request.get()); + uassert(5265101, "Internal Azure KMS Error: {}"_format(msg), msg == nullptr); + + return request; + } + +private: + const AzureConfig& _config; +}; + +/** + * Manages SSL information and config for how to talk to Azure KMS. + */ +class AzureKMSService final : public KMSService { +public: + AzureKMSService() = default; + + static std::unique_ptr<KMSService> create(const AzureKMS& config); + + SecureVector<uint8_t> decrypt(ConstDataRange cdr, BSONObj masterKey) final; + + BSONObj encryptDataKeyByBSONObj(ConstDataRange cdr, BSONObj keyId) final; + + StringData name() const final { + return kAzureKms; + } + +private: + template <typename AzureResponseT> + std::unique_ptr<uint8_t, decltype(std::free)*> makeRequest(kms_request_t* request, + const HostAndPort& keyVaultEndpoint, + size_t* raw_len); + +private: + // SSL Manager + std::shared_ptr<SSLManagerInterface> _sslManager; + + // Azure configuration settings + AzureConfig _config; + + // Service fr managing OAuth requests and token cache + std::unique_ptr<AzureKMSOAuthService> _oauthService; +}; + +std::unique_ptr<KMSService> AzureKMSService::create(const AzureKMS& config) { + auto azureKMS = std::make_unique<AzureKMSService>(); + + SSLParams params; + getSSLParamsForNetworkKMS(¶ms); + + HostAndPort identityPlatformHostAndPort(kDefaultIdentityPlatformEndpoint.toString(), 443); + if (config.getIdentityPlatformEndpoint().has_value()) { + // Leave the CA file empty so we default to system CA but for local testing allow it to + // inherit the CA file. + params.sslCAFile = sslGlobalParams.sslCAFile; + identityPlatformHostAndPort = parseUrl(config.getIdentityPlatformEndpoint().value()); + } + + azureKMS->_sslManager = SSLManagerInterface::create(params, false); + + azureKMS->_config.opts = UniqueKmsRequestOpts(kms_request_opt_new()); + kms_request_opt_set_provider(azureKMS->_config.opts.get(), KMS_REQUEST_PROVIDER_AZURE); + + azureKMS->_config.clientSecret = config.getClientSecret().toString(); + + azureKMS->_config.clientId = config.getClientId().toString(); + + azureKMS->_config.tenantId = config.getTenantId().toString(); + + azureKMS->_oauthService = std::make_unique<AzureKMSOAuthService>( + azureKMS->_config, identityPlatformHostAndPort, azureKMS->_sslManager); + + return azureKMS; +} + +HostAndPort parseEndpoint(StringData endpoint) { + HostAndPort host(endpoint); + + if (host.hasPort()) { + return host; + } + + return {host.host(), 443}; +} + +template <typename AzureResponseT> +std::unique_ptr<uint8_t, decltype(std::free)*> AzureKMSService::makeRequest( + kms_request_t* request, const HostAndPort& keyVaultEndpoint, size_t* raw_len) { + auto buffer = UniqueKmsCharBuffer(kms_request_to_string(request)); + auto buffer_len = strlen(buffer.get()); + KMSNetworkConnection connection(_sslManager.get()); + auto response = + connection.makeOneRequest(keyVaultEndpoint, ConstDataRange(buffer.get(), buffer_len)); + + auto body = kms_response_get_body(response.get(), nullptr); + + BSONObj obj = fromjson(body); + + if (obj.hasField("error")) { + AzureKMSError azureResponse; + try { + azureResponse = + AzureKMSError::parse(IDLParserContext("azureError"), obj["error"].Obj()); + } catch (DBException& dbe) { + uasserted(5265102, + "Azure KMS failed to parse error message: {}, Response : {}"_format( + dbe.toString(), obj.toString())); + } + + uasserted(5265103, + "Azure KMS failed, response: {} : {}"_format(azureResponse.getCode(), + azureResponse.getMessage())); + } + + auto azureResponse = AzureResponseT::parse(IDLParserContext("azureResponse"), obj); + + auto b64Url = azureResponse.getValue().toString(); + std::unique_ptr<uint8_t, decltype(std::free)*> raw_str( + kms_message_b64url_to_raw(b64Url.c_str(), raw_len), std::free); + uassert(5265104, "Azure KMS failed to convert key blob from base64 URL.", raw_str != nullptr); + + return raw_str; +} + + +SecureVector<uint8_t> AzureKMSService::decrypt(ConstDataRange cdr, BSONObj masterKey) { + auto azureMasterKey = AzureMasterKey::parse(IDLParserContext("azureMasterKey"), masterKey); + StringData bearerToken = _oauthService->getBearerToken(); + + HostAndPort keyVaultEndpoint = parseEndpoint(azureMasterKey.getKeyVaultEndpoint()); + + auto request = UniqueKmsRequest(kms_azure_request_unwrapkey_new( + keyVaultEndpoint.host().c_str(), + bearerToken.toString().c_str(), + azureMasterKey.getKeyName().toString().c_str(), + azureMasterKey.getKeyVersion().value_or(""_sd).toString().c_str(), + reinterpret_cast<const uint8_t*>(cdr.data()), + cdr.length(), + _config.opts.get())); + + size_t raw_len; + auto raw_str = makeRequest<AzureDecryptResponse>(request.get(), keyVaultEndpoint, &raw_len); + + return kmsResponseToSecureVector( + StringData(reinterpret_cast<const char*>(raw_str.get()), raw_len)); +} + +BSONObj AzureKMSService::encryptDataKeyByBSONObj(ConstDataRange cdr, BSONObj keyId) { + StringData bearerToken = _oauthService->getBearerToken(); + AzureMasterKey masterKey = AzureMasterKey::parse(IDLParserContext("azureMasterKey"), keyId); + + HostAndPort keyVaultEndpoint = parseEndpoint(masterKey.getKeyVaultEndpoint()); + + auto request = UniqueKmsRequest( + kms_azure_request_wrapkey_new(keyVaultEndpoint.host().c_str(), + bearerToken.toString().c_str(), + masterKey.getKeyName().toString().c_str(), + masterKey.getKeyVersion().value_or(""_sd).toString().c_str(), + reinterpret_cast<const uint8_t*>(cdr.data()), + cdr.length(), + _config.opts.get())); + + size_t raw_len; + auto raw_str = makeRequest<AzureDecryptResponse>(request.get(), keyVaultEndpoint, &raw_len); + + auto dataKey = + kmsResponseToVector(StringData(reinterpret_cast<const char*>(raw_str.get()), raw_len)); + + AzureMasterKeyAndMaterial keyAndMaterial; + keyAndMaterial.setKeyMaterial(std::move(dataKey)); + keyAndMaterial.setMasterKey(std::move(masterKey)); + + return keyAndMaterial.toBSON(); +} + +/** + * Factory for AzureKMSService if user specifies azure config to mongo() JS constructor. + */ +class AzureKMSServiceFactory final : public KMSServiceFactory { +public: + AzureKMSServiceFactory() = default; + ~AzureKMSServiceFactory() = default; + + std::unique_ptr<KMSService> create(const BSONObj& config) final { + auto field = config[KmsProviders::kAzureFieldName]; + if (field.eoo()) { + return nullptr; + } + + uassert(5265106, + "Misconfigured Azure KMS Config: {}"_format(field.toString()), + field.type() == BSONType::Object); + + auto obj = field.Obj(); + return AzureKMSService::create(AzureKMS::parse(IDLParserContext("root"), obj)); + } +}; + +} // namespace + +MONGO_INITIALIZER(KMSRegisterAzure)(::mongo::InitializerContext*) { + kms_message_init(); + KMSServiceController::registerFactory(KMSProviderEnum::azure, + std::make_unique<AzureKMSServiceFactory>()); +} + +} // namespace mongo diff --git a/src/mongo/shell/kms_gcp.cpp b/src/mongo/shell/kms_gcp.cpp new file mode 100644 index 00000000000..81ec2737f7e --- /dev/null +++ b/src/mongo/shell/kms_gcp.cpp @@ -0,0 +1,339 @@ +/** + * Copyright (C) 2021-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + + +#include "mongo/platform/basic.h" + +#include <fmt/format.h> +#include <kms_message/kms_gcp_request.h> +#include <kms_message/kms_message.h> + +#include "mongo/bson/json.h" +#include "mongo/shell/kms.h" +#include "mongo/shell/kms_network.h" +#include "mongo/util/net/ssl_manager.h" +#include "mongo/util/net/ssl_options.h" + +#define MONGO_LOGV2_DEFAULT_COMPONENT ::mongo::logv2::LogComponent::kControl + + +namespace mongo { +namespace { + +using namespace fmt::literals; + +constexpr StringData kGcpKms = "gcp"_sd; + +// Default endpoints for GCP +constexpr StringData kDefaultOauthEndpoint = "oauth2.googleapis.com"_sd; +constexpr StringData kDefaultOauthScope = "https://www.googleapis.com/auth/cloudkms"_sd; +constexpr StringData kGcpKMSEndpoint = "https://cloudkms.googleapis.com:443"_sd; + +// Field names for BSON objects containing key vault information +constexpr StringData kProjectIdField = "projectId"_sd; +constexpr StringData kLocationIdField = "location"_sd; +constexpr StringData kKeyRingField = "keyRing"_sd; +constexpr StringData kKeyNameField = "keyName"_sd; +constexpr StringData kKeyVerisionField = "keyVersion"_sd; + +/** + * GCP configuration settings + */ +struct GCPConfig { + // E-mail address that will be used for GCP OAuth requests. + std::string email; + + // PKCS#8 private key + std::string privateKey; + + // Options to pass to GCP KMS requests + UniqueKmsRequestOpts opts; +}; + +void uassertKmsRequestInternal(kms_request_t* request, bool ok) { + if (!ok) { + const char* msg = kms_request_get_error(request); + uasserted(5265000, str::stream() << "Internal GCP KMS Error: " << msg); + } +} + +#define uassertKmsRequest(X) uassertKmsRequestInternal(request, (X)) + +/** + * Manages OAuth token requests and caching + */ +class GCPKMSOAuthService final : public KMSOAuthService { +public: + GCPKMSOAuthService(const GCPConfig& config, + HostAndPort endpoint, + std::shared_ptr<SSLManagerInterface> sslManager) + : KMSOAuthService(endpoint, sslManager), _config(config) {} + +protected: + UniqueKmsRequest getOAuthRequest() { + std::string audience = "https://{}/token"_format(_oAuthEndpoint.host()); + std::string scope; + if (_oAuthEndpoint.host() != kDefaultOauthEndpoint.toString()) { + scope = "https://www.{}/auth/cloudkms"_format(_oAuthEndpoint.toString()); + } else { + scope = kDefaultOauthScope.toString(); + } + uassert(5365009, + str::stream() << "Internal GCP KMS Error: Private key not encoded in base64.", + base64::validate(_config.privateKey)); + std::string privateKeyDecoded = base64::decode(_config.privateKey); + + auto request = UniqueKmsRequest(kms_gcp_request_oauth_new(_oAuthEndpoint.host().c_str(), + _config.email.c_str(), + audience.c_str(), + scope.c_str(), + privateKeyDecoded.c_str(), + privateKeyDecoded.size(), + _config.opts.get())); + + const char* msg = kms_request_get_error(request.get()); + uassert(5265003, str::stream() << "Internal GCP KMS Error: " << msg, msg == nullptr); + + return request; + } + +private: + const GCPConfig& _config; +}; + +/** + * Manages SSL information and config for how to talk to GCP KMS. + */ +class GCPKMSService final : public KMSService { +public: + GCPKMSService() = default; + + static std::unique_ptr<KMSService> create(const GcpKMS&); + + + SecureVector<uint8_t> decrypt(ConstDataRange cdr, BSONObj masterKey) final; + + BSONObj encryptDataKeyByBSONObj(ConstDataRange cdr, BSONObj keyId) final; + + void configureOauthService(HostAndPort endpoint); + + StringData name() const { + return kGcpKms; + } + +private: + std::vector<uint8_t> encrypt(ConstDataRange cdr, const BSONObj& kmsKeyId); + +private: + // SSL Manager + std::shared_ptr<SSLManagerInterface> _sslManager; + + // Server to connect to + HostAndPort _server; + + // GCP configuration settings + GCPConfig _config; + + // Service for managing oauth requests and token cache + std::unique_ptr<GCPKMSOAuthService> _oauthService; +}; + +std::vector<uint8_t> GCPKMSService::encrypt(ConstDataRange cdr, const BSONObj& kmsKeyId) { + StringData bearerToken = _oauthService->getBearerToken(); + GcpMasterKey masterKey = GcpMasterKey::parse(IDLParserContext("gcpMasterKey"), kmsKeyId); + + auto request = UniqueKmsRequest(kms_gcp_request_encrypt_new( + _server.host().c_str(), + bearerToken.toString().c_str(), + masterKey.getProjectId().toString().c_str(), + masterKey.getLocation().toString().c_str(), + masterKey.getKeyRing().toString().c_str(), + masterKey.getKeyName().toString().c_str(), + masterKey.getKeyVersion().has_value() ? masterKey.getKeyVersion().value().toString().c_str() + : nullptr, + reinterpret_cast<const uint8_t*>(cdr.data()), + cdr.length(), + _config.opts.get())); + + auto buffer = UniqueKmsCharBuffer(kms_request_to_string(request.get())); + auto buffer_len = strlen(buffer.get()); + + KMSNetworkConnection connection(_sslManager.get()); + auto response = connection.makeOneRequest(_server, ConstDataRange(buffer.get(), buffer_len)); + + auto body = kms_response_get_body(response.get(), nullptr); + + BSONObj obj = fromjson(body); + + if (obj.hasField("error")) { + GcpKMSError gcpResponse; + try { + gcpResponse = + GcpKMSError::parse(IDLParserContext("gcpEncryptError"), obj["error"].Obj()); + } catch (DBException& dbe) { + uasserted(5265005, + str::stream() << "GCP KMS failed to parse error message: " << dbe.toString() + << ", Response : " << obj); + } + + uasserted(5256006, + str::stream() << "GCP KMS failed to encrypt: " << gcpResponse.getCode() << " " + << gcpResponse.getStatus() << " : " << gcpResponse.getMessage()); + } + + auto gcpResponce = GcpEncryptResponse::parse(IDLParserContext("gcpEncryptResponse"), obj); + + auto blobStr = base64::decode(gcpResponce.getCiphertext()); + + return kmsResponseToVector(blobStr); +} + +SecureVector<uint8_t> GCPKMSService::decrypt(ConstDataRange cdr, BSONObj masterKey) { + auto gcpMasterKey = GcpMasterKey::parse(IDLParserContext("gcpMasterKey"), masterKey); + StringData bearerToken = _oauthService->getBearerToken(); + + auto request = + UniqueKmsRequest(kms_gcp_request_decrypt_new(_server.host().c_str(), + bearerToken.toString().c_str(), + gcpMasterKey.getProjectId().toString().c_str(), + gcpMasterKey.getLocation().toString().c_str(), + gcpMasterKey.getKeyRing().toString().c_str(), + gcpMasterKey.getKeyName().toString().c_str(), + reinterpret_cast<const uint8_t*>(cdr.data()), + cdr.length(), + _config.opts.get())); + + auto buffer = UniqueKmsCharBuffer(kms_request_to_string(request.get())); + auto buffer_len = strlen(buffer.get()); + KMSNetworkConnection connection(_sslManager.get()); + auto response = connection.makeOneRequest(_server, ConstDataRange(buffer.get(), buffer_len)); + + auto body = kms_response_get_body(response.get(), nullptr); + + BSONObj obj = fromjson(body); + + if (obj.hasField("error")) { + GcpKMSError gcpResponse; + try { + gcpResponse = + GcpKMSError::parse(IDLParserContext("gcpDecryptError"), obj["error"].Obj()); + } catch (DBException& dbe) { + uasserted(5265007, + str::stream() << "GCP KMS failed to parse error message: " << dbe.toString() + << ", Response : " << obj); + } + + uasserted(5256008, + str::stream() << "GCP KMS failed to decrypt: " << gcpResponse.getCode() << " " + << gcpResponse.getStatus() << " : " << gcpResponse.getMessage()); + } + + auto gcpResponce = GcpDecryptResponse::parse(IDLParserContext("gcpDecryptResponse"), obj); + + auto blobStr = base64::decode(gcpResponce.getPlaintext()); + + return kmsResponseToSecureVector(blobStr); +} + +BSONObj GCPKMSService::encryptDataKeyByBSONObj(ConstDataRange cdr, BSONObj keyId) { + auto dataKey = encrypt(cdr, keyId); + + GcpMasterKey masterKey = GcpMasterKey::parse(IDLParserContext("gcpMasterKey"), keyId); + + GcpMasterKeyAndMaterial keyAndMaterial; + keyAndMaterial.setKeyMaterial(std::move(dataKey)); + keyAndMaterial.setMasterKey(std::move(masterKey)); + + return keyAndMaterial.toBSON(); +} + +std::unique_ptr<KMSService> GCPKMSService::create(const GcpKMS& config) { + auto gcpKMS = std::make_unique<GCPKMSService>(); + + SSLParams params; + getSSLParamsForNetworkKMS(¶ms); + + HostAndPort oauthHostAndPort(kDefaultOauthEndpoint.toString(), 443); + if (config.getEndpoint().has_value()) { + // Leave the CA file empty so we default to system CA but for local testing allow it to + // inherit the CA file. + params.sslCAFile = sslGlobalParams.sslCAFile; + oauthHostAndPort = parseUrl(config.getEndpoint().value()); + } + + gcpKMS->_sslManager = SSLManagerInterface::create(params, false); + + gcpKMS->configureOauthService(oauthHostAndPort); + + gcpKMS->_server = parseUrl(config.getEndpoint().value_or(kGcpKMSEndpoint)); + + gcpKMS->_config.email = config.getEmail().toString(); + + gcpKMS->_config.opts = UniqueKmsRequestOpts(kms_request_opt_new()); + kms_request_opt_set_provider(gcpKMS->_config.opts.get(), KMS_REQUEST_PROVIDER_GCP); + + gcpKMS->_config.privateKey = config.getPrivateKey().toString(); + + return gcpKMS; +} + +void GCPKMSService::configureOauthService(HostAndPort endpoint) { + _oauthService = std::make_unique<GCPKMSOAuthService>(_config, endpoint, _sslManager); +} + +/** + * Factory for GCPKMSService if user specifies gcp config to mongo() JS constructor. + */ +class GCPKMSServiceFactory final : public KMSServiceFactory { +public: + GCPKMSServiceFactory() = default; + ~GCPKMSServiceFactory() = default; + + std::unique_ptr<KMSService> create(const BSONObj& config) final { + auto field = config[KmsProviders::kGcpFieldName]; + if (field.eoo()) { + return nullptr; + } + uassert(5265009, + "Misconfigured GCP KMS Config: {}"_format(field.toString()), + field.type() == BSONType::Object); + auto obj = field.Obj(); + return GCPKMSService::create(GcpKMS::parse(IDLParserContext("root"), obj)); + } +}; + +} // namespace + +MONGO_INITIALIZER(KMSRegisterGCP)(::mongo::InitializerContext*) { + kms_message_init(); + KMSServiceController::registerFactory(KMSProviderEnum::gcp, + std::make_unique<GCPKMSServiceFactory>()); +} + +} // namespace mongo diff --git a/src/mongo/shell/kms_local.cpp b/src/mongo/shell/kms_local.cpp index 085fe4328e1..f67230ed45f 100644 --- a/src/mongo/shell/kms_local.cpp +++ b/src/mongo/shell/kms_local.cpp @@ -27,6 +27,8 @@ * it in the license file. */ +#include <kms_message/kms_message.h> + #include <stdlib.h> #include "mongo/base/init.h" @@ -62,10 +64,6 @@ public: BSONObj encryptDataKeyByString(ConstDataRange cdr, StringData keyId) final; - SymmetricKey& getMasterKey() final { - return _key; - } - private: std::vector<uint8_t> encrypt(ConstDataRange cdr, StringData kmsKeyId); diff --git a/src/mongo/shell/kms_network.cpp b/src/mongo/shell/kms_network.cpp new file mode 100644 index 00000000000..686b66ceaeb --- /dev/null +++ b/src/mongo/shell/kms_network.cpp @@ -0,0 +1,211 @@ +/** + * Copyright (C) 2020-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + + +#include "mongo/platform/basic.h" + +#include "mongo/shell/kms_network.h" + +#include "mongo/bson/json.h" +#include "mongo/shell/kms.h" +#include "mongo/shell/kms_gen.h" +#include "mongo/util/text.h" + +#define MONGO_LOGV2_DEFAULT_COMPONENT ::mongo::logv2::LogComponent::kControl + + +namespace mongo { + +void KMSNetworkConnection::connect(const HostAndPort& host) { + auto makeAddress = [](const auto& host) -> SockAddr { + try { + return SockAddr::create(host.host().c_str(), host.port(), AF_UNSPEC); + } catch (const DBException& ex) { + uasserted(51136, "Unable to resolve KMS server address" + causedBy(ex)); + } + + MONGO_UNREACHABLE; + }; + + auto addr = makeAddress(host); + + size_t attempt = 0; + constexpr size_t kMaxAttempts = 20; + while (!_socket->connect(addr)) { + ++attempt; + if (attempt > kMaxAttempts) { + uasserted(51137, + str::stream() << "Could not connect to KMS server " << addr.toString()); + } + } + + if (!_socket->secure(_sslManager, host.host())) { + uasserted(51138, + str::stream() << "Failed to perform SSL handshake with the KMS server " + << addr.toString()); + } +} + +// Sends a request message to the KMS server and creates a KMS Response. +UniqueKmsResponse KMSNetworkConnection::sendRequest(ConstDataRange request) { + std::array<char, 512> resp; + + _socket->send(reinterpret_cast<const char*>(request.data()), request.length(), "KMS request"); + + auto parser = UniqueKmsResponseParser(kms_response_parser_new()); + int bytes_to_read = 0; + + while ((bytes_to_read = kms_response_parser_wants_bytes(parser.get(), resp.size())) > 0) { + bytes_to_read = std::min(bytes_to_read, static_cast<int>(resp.size())); + bytes_to_read = _socket->unsafe_recv(resp.data(), bytes_to_read); + + uassert(51139, + "kms_response_parser_feed failed", + kms_response_parser_feed( + parser.get(), reinterpret_cast<uint8_t*>(resp.data()), bytes_to_read)); + } + + auto response = UniqueKmsResponse(kms_response_parser_get_response(parser.get())); + + return response; +} + +UniqueKmsResponse KMSNetworkConnection::makeOneRequest(const HostAndPort& host, + ConstDataRange request) { + connect(host); + + auto resp = sendRequest(request); + + _socket->close(); + + return resp; +} + +void getSSLParamsForNetworkKMS(SSLParams* params) { + params->sslPEMKeyFile = ""; + params->sslPEMKeyPassword = ""; + params->sslClusterFile = ""; + params->sslClusterPassword = ""; + params->sslCAFile = ""; + + params->sslCRLFile = ""; + + // Copy the rest from the global SSL manager options. + params->sslFIPSMode = sslGlobalParams.sslFIPSMode; + + // KMS servers never should have invalid certificates + params->sslAllowInvalidCertificates = false; + params->sslAllowInvalidHostnames = false; + + params->sslDisabledProtocols = + std::vector({SSLParams::Protocols::TLS1_0, SSLParams::Protocols::TLS1_1}); +} + +std::vector<uint8_t> kmsResponseToVector(StringData str) { + std::vector<uint8_t> blob; + + std::transform(std::begin(str), std::end(str), std::back_inserter(blob), [](auto c) { + return static_cast<uint8_t>(c); + }); + + return blob; +} + +SecureVector<uint8_t> kmsResponseToSecureVector(StringData str) { + SecureVector<uint8_t> blob(str.size()); + + std::transform(std::begin(str), std::end(str), blob->data(), [](auto c) { + return static_cast<uint8_t>(c); + }); + + return blob; +} + +StringData KMSOAuthService::getBearerToken() { + + if (!_cachedToken.empty() && _expirationDateTime > Date_t::now()) { + return _cachedToken; + } + + makeBearerTokenRequest(); + + return _cachedToken; +} + +void KMSOAuthService::makeBearerTokenRequest() { + UniqueKmsRequest request = getOAuthRequest(); + + auto buffer = UniqueKmsCharBuffer(kms_request_to_string(request.get())); + auto buffer_len = strlen(buffer.get()); + + KMSNetworkConnection connection(_sslManager.get()); + auto response = + connection.makeOneRequest(_oAuthEndpoint, ConstDataRange(buffer.get(), buffer_len)); + + auto body = kms_response_get_body(response.get(), nullptr); + + BSONObj obj = fromjson(body); + + auto field = obj[OAuthErrorResponse::kErrorFieldName]; + + if (!field.eoo()) { + OAuthErrorResponse oAuthErrorResponse; + try { + oAuthErrorResponse = OAuthErrorResponse::parse(IDLParserContext("oauthError"), obj); + } catch (DBException& dbe) { + uasserted(ErrorCodes::FailedToParse, + str::stream() << "Failed to parse error message: " << dbe.toString() + << ", Response : " << obj); + } + + std::string description; + if (oAuthErrorResponse.getError_description().has_value()) { + description = str::stream() + << " : " << oAuthErrorResponse.getError_description().value().toString(); + } + uasserted(ErrorCodes::OperationFailed, + str::stream() << "Failed to make oauth request: " << oAuthErrorResponse.getError() + << description); + } + + auto kmsResponse = OAuthResponse::parse(IDLParserContext("OAuthResponse"), obj); + + _cachedToken = kmsResponse.getAccess_token().toString(); + + // Offset the expiration time by a the socket timeout as proxy for round-trip time to the OAuth + // server. This approximation will compute the expiration time a litte earlier then needed but + // will ensure that it uses a stale bearer token. + Seconds requestBufferTime = 2 * Seconds((int)KMSNetworkConnection::so_timeout_seconds); + + // expires_in is optional but Azure and GCP always return it but to be safe, we pick a default + _expirationDateTime = + Date_t::now() + Seconds(kmsResponse.getExpires_in().value_or(60)) - requestBufferTime; +} + +} // namespace mongo diff --git a/src/mongo/shell/kms_network.h b/src/mongo/shell/kms_network.h new file mode 100644 index 00000000000..724412f869b --- /dev/null +++ b/src/mongo/shell/kms_network.h @@ -0,0 +1,125 @@ +/** + * Copyright (C) 2020-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include <string> + +#include "mongo/util/kms_message_support.h" +#include "mongo/util/net/sock.h" +#include "mongo/util/net/ssl_manager.h" +#include "mongo/util/net/ssl_options.h" +#include "mongo/util/time_support.h" + +namespace mongo { + +/** + * Make a request to an HTTP endpoint. + * + * Does not maintain a persistent HTTP connection. + */ +class KMSNetworkConnection { +public: + static constexpr double so_timeout_seconds = 10; + + KMSNetworkConnection(SSLManagerInterface* ssl) + : _sslManager(ssl), + _socket(std::make_unique<Socket>(so_timeout_seconds, logv2::LogSeverity::Info())) {} + + UniqueKmsResponse makeOneRequest(const HostAndPort& host, ConstDataRange request); + +private: + UniqueKmsResponse sendRequest(ConstDataRange request); + + void connect(const HostAndPort& host); + +private: + // SSL Manager for connections + SSLManagerInterface* _sslManager; + + // Synchronous socket + std::unique_ptr<Socket> _socket; +}; + +/** + * Creates an initial SSLParams object for KMS over the network. + */ +void getSSLParamsForNetworkKMS(SSLParams*); + +/** + * Converts a base64 encoded KMS response to a vector of bytes. + */ +std::vector<uint8_t> kmsResponseToVector(StringData str); + +/** + * Converts a base64 encoded KMS response to a securely allocated vector of bytes. + */ +SecureVector<uint8_t> kmsResponseToSecureVector(StringData str); + +/** + * Base class for KMS services that use OAuth for authorization. + * + * Each service only talks to one OAuth endpoint so caching bearer tokens is simple. + */ +class KMSOAuthService { +public: + KMSOAuthService(HostAndPort oAuthEndpoint, std::shared_ptr<SSLManagerInterface> sslManager) + : _oAuthEndpoint(oAuthEndpoint), _sslManager(sslManager) {} + + /** + * Get a bearer token to use to make requests. It may be cached. + */ + StringData getBearerToken(); + +protected: + /** + * Construct a valid kms request for retrieving a new OAuth Bearer token + */ + virtual UniqueKmsRequest getOAuthRequest() = 0; + +protected: + // OAuth Service endpoint + HostAndPort _oAuthEndpoint; + +private: + /** + * Make a TLS request to a service to fetch a bearer token. + */ + void makeBearerTokenRequest(); + +private: + // SSL Manager + std::shared_ptr<SSLManagerInterface> _sslManager; + + // Cached access token + std::string _cachedToken; + + // Expiration datetime of access token + Date_t _expirationDateTime; +}; + +} // namespace mongo |