diff --git a/.evergreen/run-kms-tls-tests.sh b/.evergreen/run-kms-tls-tests.sh index df3a38c0eec..bdc716fc86f 100755 --- a/.evergreen/run-kms-tls-tests.sh +++ b/.evergreen/run-kms-tls-tests.sh @@ -17,6 +17,12 @@ echo "Running KMS TLS tests" cp ${JAVA_HOME}/lib/security/cacerts mongo-truststore ${JAVA_HOME}/bin/keytool -importcert -trustcacerts -file ${DRIVERS_TOOLS}/.evergreen/x509gen/ca.pem -keystore mongo-truststore -storepass changeit -storetype JKS -noprompt +# Create keystore from server.pem to emulate KMS server in tests. +openssl pkcs12 -export \ + -in ${DRIVERS_TOOLS}/.evergreen/x509gen/server.pem \ + -out server.p12 \ + -password pass:test + export GRADLE_EXTRA_VARS="-Pssl.enabled=true -Pssl.trustStoreType=jks -Pssl.trustStore=`pwd`/mongo-truststore -Pssl.trustStorePassword=changeit" export KMS_TLS_ERROR_TYPE=${KMS_TLS_ERROR_TYPE} @@ -24,12 +30,14 @@ export KMS_TLS_ERROR_TYPE=${KMS_TLS_ERROR_TYPE} ./gradlew --stacktrace --info ${GRADLE_EXTRA_VARS} -Dorg.mongodb.test.uri=${MONGODB_URI} \ -Dorg.mongodb.test.kms.tls.error.type=${KMS_TLS_ERROR_TYPE} \ + -Dorg.mongodb.test.kms.keystore.location="$(pwd)" \ driver-sync:cleanTest driver-sync:test --tests ClientSideEncryptionKmsTlsTest first=$? echo $first ./gradlew --stacktrace --info ${GRADLE_EXTRA_VARS} -Dorg.mongodb.test.uri=${MONGODB_URI} \ -Dorg.mongodb.test.kms.tls.error.type=${KMS_TLS_ERROR_TYPE} \ + -Dorg.mongodb.test.kms.keystore.location="$(pwd)" \ driver-reactive-streams:cleanTest driver-reactive-streams:test --tests ClientSideEncryptionKmsTlsTest second=$? echo $second diff --git a/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/crypt/KeyManagementService.java b/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/crypt/KeyManagementService.java index b82dd590618..67ebf421c9c 100644 --- a/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/crypt/KeyManagementService.java +++ b/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/crypt/KeyManagementService.java @@ -16,6 +16,7 @@ package com.mongodb.reactivestreams.client.internal.crypt; +import com.mongodb.MongoException; import com.mongodb.MongoOperationTimeoutException; import com.mongodb.MongoSocketException; import com.mongodb.MongoSocketReadTimeoutException; @@ -131,6 +132,11 @@ private void streamRead(final Stream stream, final MongoKeyDecryptor keyDecrypto @Override public void completed(final Integer integer, final Void aVoid) { + if (integer == -1) { + sink.error(new MongoException( + "Unexpected end of stream from KMS provider " + keyDecryptor.getKmsProvider())); + return; + } buffer.flip(); try { keyDecryptor.feed(buffer.asNIO()); diff --git a/driver-sync/src/main/com/mongodb/client/internal/Crypt.java b/driver-sync/src/main/com/mongodb/client/internal/Crypt.java index ae7a75ae626..67fac13770c 100644 --- a/driver-sync/src/main/com/mongodb/client/internal/Crypt.java +++ b/driver-sync/src/main/com/mongodb/client/internal/Crypt.java @@ -369,6 +369,9 @@ private void decryptKey(final MongoKeyDecryptor keyDecryptor, @Nullable final Ti while (bytesNeeded > 0) { byte[] bytes = new byte[bytesNeeded]; int bytesRead = inputStream.read(bytes, 0, bytes.length); + if (bytesRead == -1) { + throw new MongoException("Unexpected end of stream from KMS provider " + keyDecryptor.getKmsProvider()); + } keyDecryptor.feed(ByteBuffer.wrap(bytes, 0, bytesRead)); bytesNeeded = keyDecryptor.bytesNeeded(); } @@ -376,7 +379,7 @@ private void decryptKey(final MongoKeyDecryptor keyDecryptor, @Nullable final Ti } private MongoException wrapInMongoException(final Throwable t) { - if (t instanceof MongoException) { + if (t instanceof MongoClientException) { return (MongoException) t; } else { return new MongoClientException("Exception in encryption library: " + t.getMessage(), t); diff --git a/driver-sync/src/test/functional/com/mongodb/client/AbstractClientSideEncryptionKmsTlsTest.java b/driver-sync/src/test/functional/com/mongodb/client/AbstractClientSideEncryptionKmsTlsTest.java index 6e0b5957dea..fb8b6682590 100644 --- a/driver-sync/src/test/functional/com/mongodb/client/AbstractClientSideEncryptionKmsTlsTest.java +++ b/driver-sync/src/test/functional/com/mongodb/client/AbstractClientSideEncryptionKmsTlsTest.java @@ -20,14 +20,19 @@ import com.mongodb.MongoClientException; import com.mongodb.client.model.vault.DataKeyOptions; import com.mongodb.client.vault.ClientEncryption; +import com.mongodb.fixture.EncryptionFixture; import com.mongodb.lang.NonNull; import com.mongodb.lang.Nullable; import org.bson.BsonDocument; import org.junit.jupiter.api.Test; import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLServerSocket; +import javax.net.ssl.SSLServerSocketFactory; import javax.net.ssl.TrustManager; import javax.net.ssl.X509TrustManager; +import java.io.IOException; +import java.net.Socket; import java.security.KeyManagementException; import java.security.NoSuchAlgorithmException; import java.security.cert.CertificateException; @@ -35,16 +40,20 @@ import java.security.cert.X509Certificate; import java.util.HashMap; import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.TimeUnit; import static com.mongodb.ClusterFixture.getEnv; import static com.mongodb.ClusterFixture.hasEncryptionTestsEnabled; import static com.mongodb.client.Fixture.getMongoClientSettings; import static java.util.Objects.requireNonNull; +import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; import static org.junit.jupiter.api.Assumptions.assumeTrue; + public abstract class AbstractClientSideEncryptionKmsTlsTest { private static final String SYSTEM_PROPERTY_KEY = "org.mongodb.test.kms.tls.error.type"; @@ -128,7 +137,7 @@ public void testInvalidKmsCertificate() { * See * 11. KMS TLS Options Tests. */ - @Test() + @Test public void testThatCustomSslContextIsUsed() { assumeTrue(hasEncryptionTestsEnabled()); @@ -165,34 +174,67 @@ public void testThatCustomSslContextIsUsed() { } } + /** + * Not a prose spec test. However, it is additional test case for better coverage. + */ + @Test + public void testUnexpectedEndOfStreamFromKmsProvider() throws Exception { + int kmsPort = 5555; + ClientEncryptionSettings clientEncryptionSettings = ClientEncryptionSettings.builder() + .keyVaultMongoClientSettings(getMongoClientSettings()) + .keyVaultNamespace("keyvault.datakeys") + .kmsProviders(new HashMap>() {{ + put("kmip", new HashMap() {{ + put("endpoint", "localhost:" + kmsPort); + }}); + }}) + .build(); + + Thread serverThread = null; + try (ClientEncryption clientEncryption = getClientEncryption(clientEncryptionSettings)) { + serverThread = startKmsServerSimulatingEof(EncryptionFixture.buildSslContextFromKeyStore( + System.getProperty("org.mongodb.test.kms.keystore.location"), + "server.p12"), kmsPort); + + MongoClientException mongoException = assertThrows(MongoClientException.class, + () -> clientEncryption.createDataKey("kmip", new DataKeyOptions())); + assertEquals("Exception in encryption library: Unexpected end of stream from KMS provider kmip", + mongoException.getMessage()); + } finally { + if (serverThread != null) { + serverThread.interrupt(); + } + } + } + private HashMap> getKmsProviders() { return new HashMap>() {{ - put("aws", new HashMap() {{ + put("aws", new HashMap() {{ put("accessKeyId", getEnv("AWS_ACCESS_KEY_ID")); put("secretAccessKey", getEnv("AWS_SECRET_ACCESS_KEY")); }}); - put("aws:named", new HashMap() {{ + put("aws:named", new HashMap() {{ put("accessKeyId", getEnv("AWS_ACCESS_KEY_ID")); put("secretAccessKey", getEnv("AWS_SECRET_ACCESS_KEY")); }}); - put("azure", new HashMap() {{ + put("azure", new HashMap() {{ put("tenantId", getEnv("AZURE_TENANT_ID")); put("clientId", getEnv("AZURE_CLIENT_ID")); put("clientSecret", getEnv("AZURE_CLIENT_SECRET")); put("identityPlatformEndpoint", "login.microsoftonline.com:443"); }}); - put("azure:named", new HashMap() {{ + put("azure:named", new HashMap() {{ put("tenantId", getEnv("AZURE_TENANT_ID")); put("clientId", getEnv("AZURE_CLIENT_ID")); put("clientSecret", getEnv("AZURE_CLIENT_SECRET")); put("identityPlatformEndpoint", "login.microsoftonline.com:443"); }}); - put("gcp", new HashMap() {{ + put("gcp", new HashMap() {{ put("email", getEnv("GCP_EMAIL")); put("privateKey", getEnv("GCP_PRIVATE_KEY")); put("endpoint", "oauth2.googleapis.com:443"); }}); - put("gcp:named", new HashMap() {{ + put("gcp:named", new HashMap() {{ put("email", getEnv("GCP_EMAIL")); put("privateKey", getEnv("GCP_PRIVATE_KEY")); put("endpoint", "oauth2.googleapis.com:443"); @@ -257,5 +299,29 @@ public void checkServerTrusted(final X509Certificate[] certs, final String authT throw new RuntimeException(e); } } + + private Thread startKmsServerSimulatingEof(final SSLContext sslContext, final int kmsPort) + throws Exception { + CompletableFuture confirmListening = new CompletableFuture<>(); + Thread serverThread = new Thread(() -> { + try { + SSLServerSocketFactory serverSocketFactory = sslContext.getServerSocketFactory(); + try (SSLServerSocket sslServerSocket = (SSLServerSocket) serverSocketFactory.createServerSocket(kmsPort)) { + sslServerSocket.setNeedClientAuth(false); + confirmListening.complete(null); + try (Socket accept = sslServerSocket.accept()) { + accept.setSoTimeout(10000); + accept.getInputStream().read(); + } + } + } catch (IOException e) { + throw new RuntimeException(e); + } + }, "KMIP-EOF-Fake-Server"); + serverThread.setDaemon(true); + serverThread.start(); + confirmListening.get(TimeUnit.SECONDS.toMillis(10), TimeUnit.MILLISECONDS); + return serverThread; + } } diff --git a/driver-sync/src/test/functional/com/mongodb/client/auth/AbstractX509AuthenticationTest.java b/driver-sync/src/test/functional/com/mongodb/client/auth/AbstractX509AuthenticationTest.java index 0d003210f3d..e325e9a23f8 100644 --- a/driver-sync/src/test/functional/com/mongodb/client/auth/AbstractX509AuthenticationTest.java +++ b/driver-sync/src/test/functional/com/mongodb/client/auth/AbstractX509AuthenticationTest.java @@ -22,6 +22,7 @@ import com.mongodb.client.Fixture; import com.mongodb.client.MongoClient; import com.mongodb.connection.NettyTransportSettings; +import com.mongodb.fixture.EncryptionFixture; import io.netty.handler.ssl.SslContextBuilder; import io.netty.handler.ssl.SslProvider; import org.junit.jupiter.api.extension.ConditionEvaluationResult; @@ -34,14 +35,6 @@ import javax.net.ssl.KeyManagerFactory; import javax.net.ssl.SSLContext; -import java.io.File; -import java.io.FileInputStream; -import java.io.IOException; -import java.security.KeyStore; -import java.security.KeyStoreException; -import java.security.NoSuchAlgorithmException; -import java.security.UnrecoverableKeyException; -import java.security.cert.CertificateException; import java.util.stream.Stream; import static com.mongodb.AuthenticationMechanism.MONGODB_X509; @@ -52,7 +45,6 @@ @ExtendWith(AbstractX509AuthenticationTest.X509AuthenticationPropertyCondition.class) public abstract class AbstractX509AuthenticationTest { - private static final String KEYSTORE_PASSWORD = "test"; protected abstract MongoClient createMongoClient(MongoClientSettings mongoClientSettings); private static Stream shouldAuthenticateWithClientCertificate() throws Exception { @@ -128,22 +120,11 @@ private static Stream getArgumentForKeystore(final String keystoreFil } private static SSLContext buildSslContextFromKeyStore(final String keystoreFileName) throws Exception { - KeyManagerFactory keyManagerFactory = getKeyManagerFactory(keystoreFileName); - SSLContext sslContext = SSLContext.getInstance("TLS"); - sslContext.init(keyManagerFactory.getKeyManagers(), null, null); - return sslContext; + return EncryptionFixture.buildSslContextFromKeyStore(getKeystoreLocation(), keystoreFileName); } - private static KeyManagerFactory getKeyManagerFactory(final String keystoreFileName) - throws KeyStoreException, IOException, NoSuchAlgorithmException, CertificateException, UnrecoverableKeyException { - KeyStore ks = KeyStore.getInstance("PKCS12"); - try (FileInputStream fis = new FileInputStream(getKeystoreLocation() + File.separator + keystoreFileName)) { - ks.load(fis, KEYSTORE_PASSWORD.toCharArray()); - } - KeyManagerFactory keyManagerFactory = KeyManagerFactory.getInstance( - KeyManagerFactory.getDefaultAlgorithm()); - keyManagerFactory.init(ks, KEYSTORE_PASSWORD.toCharArray()); - return keyManagerFactory; + private static KeyManagerFactory getKeyManagerFactory(final String keystoreFileName) throws Exception { + return EncryptionFixture.getKeyManagerFactory(getKeystoreLocation(), keystoreFileName); } private static String getKeystoreLocation() { diff --git a/driver-sync/src/test/functional/com/mongodb/fixture/EncryptionFixture.java b/driver-sync/src/test/functional/com/mongodb/fixture/EncryptionFixture.java index f6edb9a14ed..7530bf3d7c0 100644 --- a/driver-sync/src/test/functional/com/mongodb/fixture/EncryptionFixture.java +++ b/driver-sync/src/test/functional/com/mongodb/fixture/EncryptionFixture.java @@ -18,6 +18,11 @@ package com.mongodb.fixture; +import javax.net.ssl.KeyManagerFactory; +import javax.net.ssl.SSLContext; +import java.io.File; +import java.io.FileInputStream; +import java.security.KeyStore; import java.util.HashMap; import java.util.Map; @@ -73,6 +78,23 @@ public static Map> getKmsProviders(final KmsProvider }}; } + public static KeyManagerFactory getKeyManagerFactory(final String keystoreLocation, final String keystoreFileName) throws Exception { + KeyStore ks = KeyStore.getInstance("PKCS12"); + try (FileInputStream fis = new FileInputStream(keystoreLocation + File.separator + keystoreFileName)) { + ks.load(fis, "test".toCharArray()); + } + KeyManagerFactory keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()); + keyManagerFactory.init(ks, "test".toCharArray()); + return keyManagerFactory; + } + + public static SSLContext buildSslContextFromKeyStore(final String keystoreLocation, final String keystoreFileName) throws Exception { + KeyManagerFactory keyManagerFactory = getKeyManagerFactory(keystoreLocation, keystoreFileName); + SSLContext sslContext = SSLContext.getInstance("TLS"); + sslContext.init(keyManagerFactory.getKeyManagers(), null, null); + return sslContext; + } + public enum KmsProviderType { LOCAL, AWS,