diff --git a/.evergreen/run-kms-tls-tests.sh b/.evergreen/run-kms-tls-tests.sh
index df3a38c0eec..bdc716fc86f 100755
--- a/.evergreen/run-kms-tls-tests.sh
+++ b/.evergreen/run-kms-tls-tests.sh
@@ -17,6 +17,12 @@ echo "Running KMS TLS tests"
cp ${JAVA_HOME}/lib/security/cacerts mongo-truststore
${JAVA_HOME}/bin/keytool -importcert -trustcacerts -file ${DRIVERS_TOOLS}/.evergreen/x509gen/ca.pem -keystore mongo-truststore -storepass changeit -storetype JKS -noprompt
+# Create keystore from server.pem to emulate KMS server in tests.
+openssl pkcs12 -export \
+ -in ${DRIVERS_TOOLS}/.evergreen/x509gen/server.pem \
+ -out server.p12 \
+ -password pass:test
+
export GRADLE_EXTRA_VARS="-Pssl.enabled=true -Pssl.trustStoreType=jks -Pssl.trustStore=`pwd`/mongo-truststore -Pssl.trustStorePassword=changeit"
export KMS_TLS_ERROR_TYPE=${KMS_TLS_ERROR_TYPE}
@@ -24,12 +30,14 @@ export KMS_TLS_ERROR_TYPE=${KMS_TLS_ERROR_TYPE}
./gradlew --stacktrace --info ${GRADLE_EXTRA_VARS} -Dorg.mongodb.test.uri=${MONGODB_URI} \
-Dorg.mongodb.test.kms.tls.error.type=${KMS_TLS_ERROR_TYPE} \
+ -Dorg.mongodb.test.kms.keystore.location="$(pwd)" \
driver-sync:cleanTest driver-sync:test --tests ClientSideEncryptionKmsTlsTest
first=$?
echo $first
./gradlew --stacktrace --info ${GRADLE_EXTRA_VARS} -Dorg.mongodb.test.uri=${MONGODB_URI} \
-Dorg.mongodb.test.kms.tls.error.type=${KMS_TLS_ERROR_TYPE} \
+ -Dorg.mongodb.test.kms.keystore.location="$(pwd)" \
driver-reactive-streams:cleanTest driver-reactive-streams:test --tests ClientSideEncryptionKmsTlsTest
second=$?
echo $second
diff --git a/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/crypt/KeyManagementService.java b/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/crypt/KeyManagementService.java
index b82dd590618..67ebf421c9c 100644
--- a/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/crypt/KeyManagementService.java
+++ b/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/crypt/KeyManagementService.java
@@ -16,6 +16,7 @@
package com.mongodb.reactivestreams.client.internal.crypt;
+import com.mongodb.MongoException;
import com.mongodb.MongoOperationTimeoutException;
import com.mongodb.MongoSocketException;
import com.mongodb.MongoSocketReadTimeoutException;
@@ -131,6 +132,11 @@ private void streamRead(final Stream stream, final MongoKeyDecryptor keyDecrypto
@Override
public void completed(final Integer integer, final Void aVoid) {
+ if (integer == -1) {
+ sink.error(new MongoException(
+ "Unexpected end of stream from KMS provider " + keyDecryptor.getKmsProvider()));
+ return;
+ }
buffer.flip();
try {
keyDecryptor.feed(buffer.asNIO());
diff --git a/driver-sync/src/main/com/mongodb/client/internal/Crypt.java b/driver-sync/src/main/com/mongodb/client/internal/Crypt.java
index ae7a75ae626..67fac13770c 100644
--- a/driver-sync/src/main/com/mongodb/client/internal/Crypt.java
+++ b/driver-sync/src/main/com/mongodb/client/internal/Crypt.java
@@ -369,6 +369,9 @@ private void decryptKey(final MongoKeyDecryptor keyDecryptor, @Nullable final Ti
while (bytesNeeded > 0) {
byte[] bytes = new byte[bytesNeeded];
int bytesRead = inputStream.read(bytes, 0, bytes.length);
+ if (bytesRead == -1) {
+ throw new MongoException("Unexpected end of stream from KMS provider " + keyDecryptor.getKmsProvider());
+ }
keyDecryptor.feed(ByteBuffer.wrap(bytes, 0, bytesRead));
bytesNeeded = keyDecryptor.bytesNeeded();
}
@@ -376,7 +379,7 @@ private void decryptKey(final MongoKeyDecryptor keyDecryptor, @Nullable final Ti
}
private MongoException wrapInMongoException(final Throwable t) {
- if (t instanceof MongoException) {
+ if (t instanceof MongoClientException) {
return (MongoException) t;
} else {
return new MongoClientException("Exception in encryption library: " + t.getMessage(), t);
diff --git a/driver-sync/src/test/functional/com/mongodb/client/AbstractClientSideEncryptionKmsTlsTest.java b/driver-sync/src/test/functional/com/mongodb/client/AbstractClientSideEncryptionKmsTlsTest.java
index 6e0b5957dea..fb8b6682590 100644
--- a/driver-sync/src/test/functional/com/mongodb/client/AbstractClientSideEncryptionKmsTlsTest.java
+++ b/driver-sync/src/test/functional/com/mongodb/client/AbstractClientSideEncryptionKmsTlsTest.java
@@ -20,14 +20,19 @@
import com.mongodb.MongoClientException;
import com.mongodb.client.model.vault.DataKeyOptions;
import com.mongodb.client.vault.ClientEncryption;
+import com.mongodb.fixture.EncryptionFixture;
import com.mongodb.lang.NonNull;
import com.mongodb.lang.Nullable;
import org.bson.BsonDocument;
import org.junit.jupiter.api.Test;
import javax.net.ssl.SSLContext;
+import javax.net.ssl.SSLServerSocket;
+import javax.net.ssl.SSLServerSocketFactory;
import javax.net.ssl.TrustManager;
import javax.net.ssl.X509TrustManager;
+import java.io.IOException;
+import java.net.Socket;
import java.security.KeyManagementException;
import java.security.NoSuchAlgorithmException;
import java.security.cert.CertificateException;
@@ -35,16 +40,20 @@
import java.security.cert.X509Certificate;
import java.util.HashMap;
import java.util.Map;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.TimeUnit;
import static com.mongodb.ClusterFixture.getEnv;
import static com.mongodb.ClusterFixture.hasEncryptionTestsEnabled;
import static com.mongodb.client.Fixture.getMongoClientSettings;
import static java.util.Objects.requireNonNull;
+import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;
import static org.junit.jupiter.api.Assumptions.assumeTrue;
+
public abstract class AbstractClientSideEncryptionKmsTlsTest {
private static final String SYSTEM_PROPERTY_KEY = "org.mongodb.test.kms.tls.error.type";
@@ -128,7 +137,7 @@ public void testInvalidKmsCertificate() {
* See
* 11. KMS TLS Options Tests.
*/
- @Test()
+ @Test
public void testThatCustomSslContextIsUsed() {
assumeTrue(hasEncryptionTestsEnabled());
@@ -165,34 +174,67 @@ public void testThatCustomSslContextIsUsed() {
}
}
+ /**
+ * Not a prose spec test. However, it is additional test case for better coverage.
+ */
+ @Test
+ public void testUnexpectedEndOfStreamFromKmsProvider() throws Exception {
+ int kmsPort = 5555;
+ ClientEncryptionSettings clientEncryptionSettings = ClientEncryptionSettings.builder()
+ .keyVaultMongoClientSettings(getMongoClientSettings())
+ .keyVaultNamespace("keyvault.datakeys")
+ .kmsProviders(new HashMap>() {{
+ put("kmip", new HashMap() {{
+ put("endpoint", "localhost:" + kmsPort);
+ }});
+ }})
+ .build();
+
+ Thread serverThread = null;
+ try (ClientEncryption clientEncryption = getClientEncryption(clientEncryptionSettings)) {
+ serverThread = startKmsServerSimulatingEof(EncryptionFixture.buildSslContextFromKeyStore(
+ System.getProperty("org.mongodb.test.kms.keystore.location"),
+ "server.p12"), kmsPort);
+
+ MongoClientException mongoException = assertThrows(MongoClientException.class,
+ () -> clientEncryption.createDataKey("kmip", new DataKeyOptions()));
+ assertEquals("Exception in encryption library: Unexpected end of stream from KMS provider kmip",
+ mongoException.getMessage());
+ } finally {
+ if (serverThread != null) {
+ serverThread.interrupt();
+ }
+ }
+ }
+
private HashMap> getKmsProviders() {
return new HashMap>() {{
- put("aws", new HashMap() {{
+ put("aws", new HashMap() {{
put("accessKeyId", getEnv("AWS_ACCESS_KEY_ID"));
put("secretAccessKey", getEnv("AWS_SECRET_ACCESS_KEY"));
}});
- put("aws:named", new HashMap() {{
+ put("aws:named", new HashMap() {{
put("accessKeyId", getEnv("AWS_ACCESS_KEY_ID"));
put("secretAccessKey", getEnv("AWS_SECRET_ACCESS_KEY"));
}});
- put("azure", new HashMap() {{
+ put("azure", new HashMap() {{
put("tenantId", getEnv("AZURE_TENANT_ID"));
put("clientId", getEnv("AZURE_CLIENT_ID"));
put("clientSecret", getEnv("AZURE_CLIENT_SECRET"));
put("identityPlatformEndpoint", "login.microsoftonline.com:443");
}});
- put("azure:named", new HashMap() {{
+ put("azure:named", new HashMap() {{
put("tenantId", getEnv("AZURE_TENANT_ID"));
put("clientId", getEnv("AZURE_CLIENT_ID"));
put("clientSecret", getEnv("AZURE_CLIENT_SECRET"));
put("identityPlatformEndpoint", "login.microsoftonline.com:443");
}});
- put("gcp", new HashMap() {{
+ put("gcp", new HashMap() {{
put("email", getEnv("GCP_EMAIL"));
put("privateKey", getEnv("GCP_PRIVATE_KEY"));
put("endpoint", "oauth2.googleapis.com:443");
}});
- put("gcp:named", new HashMap() {{
+ put("gcp:named", new HashMap() {{
put("email", getEnv("GCP_EMAIL"));
put("privateKey", getEnv("GCP_PRIVATE_KEY"));
put("endpoint", "oauth2.googleapis.com:443");
@@ -257,5 +299,29 @@ public void checkServerTrusted(final X509Certificate[] certs, final String authT
throw new RuntimeException(e);
}
}
+
+ private Thread startKmsServerSimulatingEof(final SSLContext sslContext, final int kmsPort)
+ throws Exception {
+ CompletableFuture confirmListening = new CompletableFuture<>();
+ Thread serverThread = new Thread(() -> {
+ try {
+ SSLServerSocketFactory serverSocketFactory = sslContext.getServerSocketFactory();
+ try (SSLServerSocket sslServerSocket = (SSLServerSocket) serverSocketFactory.createServerSocket(kmsPort)) {
+ sslServerSocket.setNeedClientAuth(false);
+ confirmListening.complete(null);
+ try (Socket accept = sslServerSocket.accept()) {
+ accept.setSoTimeout(10000);
+ accept.getInputStream().read();
+ }
+ }
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }, "KMIP-EOF-Fake-Server");
+ serverThread.setDaemon(true);
+ serverThread.start();
+ confirmListening.get(TimeUnit.SECONDS.toMillis(10), TimeUnit.MILLISECONDS);
+ return serverThread;
+ }
}
diff --git a/driver-sync/src/test/functional/com/mongodb/client/auth/AbstractX509AuthenticationTest.java b/driver-sync/src/test/functional/com/mongodb/client/auth/AbstractX509AuthenticationTest.java
index 0d003210f3d..e325e9a23f8 100644
--- a/driver-sync/src/test/functional/com/mongodb/client/auth/AbstractX509AuthenticationTest.java
+++ b/driver-sync/src/test/functional/com/mongodb/client/auth/AbstractX509AuthenticationTest.java
@@ -22,6 +22,7 @@
import com.mongodb.client.Fixture;
import com.mongodb.client.MongoClient;
import com.mongodb.connection.NettyTransportSettings;
+import com.mongodb.fixture.EncryptionFixture;
import io.netty.handler.ssl.SslContextBuilder;
import io.netty.handler.ssl.SslProvider;
import org.junit.jupiter.api.extension.ConditionEvaluationResult;
@@ -34,14 +35,6 @@
import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLContext;
-import java.io.File;
-import java.io.FileInputStream;
-import java.io.IOException;
-import java.security.KeyStore;
-import java.security.KeyStoreException;
-import java.security.NoSuchAlgorithmException;
-import java.security.UnrecoverableKeyException;
-import java.security.cert.CertificateException;
import java.util.stream.Stream;
import static com.mongodb.AuthenticationMechanism.MONGODB_X509;
@@ -52,7 +45,6 @@
@ExtendWith(AbstractX509AuthenticationTest.X509AuthenticationPropertyCondition.class)
public abstract class AbstractX509AuthenticationTest {
- private static final String KEYSTORE_PASSWORD = "test";
protected abstract MongoClient createMongoClient(MongoClientSettings mongoClientSettings);
private static Stream shouldAuthenticateWithClientCertificate() throws Exception {
@@ -128,22 +120,11 @@ private static Stream getArgumentForKeystore(final String keystoreFil
}
private static SSLContext buildSslContextFromKeyStore(final String keystoreFileName) throws Exception {
- KeyManagerFactory keyManagerFactory = getKeyManagerFactory(keystoreFileName);
- SSLContext sslContext = SSLContext.getInstance("TLS");
- sslContext.init(keyManagerFactory.getKeyManagers(), null, null);
- return sslContext;
+ return EncryptionFixture.buildSslContextFromKeyStore(getKeystoreLocation(), keystoreFileName);
}
- private static KeyManagerFactory getKeyManagerFactory(final String keystoreFileName)
- throws KeyStoreException, IOException, NoSuchAlgorithmException, CertificateException, UnrecoverableKeyException {
- KeyStore ks = KeyStore.getInstance("PKCS12");
- try (FileInputStream fis = new FileInputStream(getKeystoreLocation() + File.separator + keystoreFileName)) {
- ks.load(fis, KEYSTORE_PASSWORD.toCharArray());
- }
- KeyManagerFactory keyManagerFactory = KeyManagerFactory.getInstance(
- KeyManagerFactory.getDefaultAlgorithm());
- keyManagerFactory.init(ks, KEYSTORE_PASSWORD.toCharArray());
- return keyManagerFactory;
+ private static KeyManagerFactory getKeyManagerFactory(final String keystoreFileName) throws Exception {
+ return EncryptionFixture.getKeyManagerFactory(getKeystoreLocation(), keystoreFileName);
}
private static String getKeystoreLocation() {
diff --git a/driver-sync/src/test/functional/com/mongodb/fixture/EncryptionFixture.java b/driver-sync/src/test/functional/com/mongodb/fixture/EncryptionFixture.java
index f6edb9a14ed..7530bf3d7c0 100644
--- a/driver-sync/src/test/functional/com/mongodb/fixture/EncryptionFixture.java
+++ b/driver-sync/src/test/functional/com/mongodb/fixture/EncryptionFixture.java
@@ -18,6 +18,11 @@
package com.mongodb.fixture;
+import javax.net.ssl.KeyManagerFactory;
+import javax.net.ssl.SSLContext;
+import java.io.File;
+import java.io.FileInputStream;
+import java.security.KeyStore;
import java.util.HashMap;
import java.util.Map;
@@ -73,6 +78,23 @@ public static Map> getKmsProviders(final KmsProvider
}};
}
+ public static KeyManagerFactory getKeyManagerFactory(final String keystoreLocation, final String keystoreFileName) throws Exception {
+ KeyStore ks = KeyStore.getInstance("PKCS12");
+ try (FileInputStream fis = new FileInputStream(keystoreLocation + File.separator + keystoreFileName)) {
+ ks.load(fis, "test".toCharArray());
+ }
+ KeyManagerFactory keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm());
+ keyManagerFactory.init(ks, "test".toCharArray());
+ return keyManagerFactory;
+ }
+
+ public static SSLContext buildSslContextFromKeyStore(final String keystoreLocation, final String keystoreFileName) throws Exception {
+ KeyManagerFactory keyManagerFactory = getKeyManagerFactory(keystoreLocation, keystoreFileName);
+ SSLContext sslContext = SSLContext.getInstance("TLS");
+ sslContext.init(keyManagerFactory.getKeyManagers(), null, null);
+ return sslContext;
+ }
+
public enum KmsProviderType {
LOCAL,
AWS,