Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions .evergreen/run-kms-tls-tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,27 @@ echo "Running KMS TLS tests"
cp ${JAVA_HOME}/lib/security/cacerts mongo-truststore
${JAVA_HOME}/bin/keytool -importcert -trustcacerts -file ${DRIVERS_TOOLS}/.evergreen/x509gen/ca.pem -keystore mongo-truststore -storepass changeit -storetype JKS -noprompt

# Create keystore from server.pem to emulate KMS server in tests.
openssl pkcs12 -export \
-in ${DRIVERS_TOOLS}/.evergreen/x509gen/server.pem \
-out server.p12 \
-password pass:test

export GRADLE_EXTRA_VARS="-Pssl.enabled=true -Pssl.trustStoreType=jks -Pssl.trustStore=`pwd`/mongo-truststore -Pssl.trustStorePassword=changeit"
export KMS_TLS_ERROR_TYPE=${KMS_TLS_ERROR_TYPE}

./gradlew -version

./gradlew --stacktrace --info ${GRADLE_EXTRA_VARS} -Dorg.mongodb.test.uri=${MONGODB_URI} \
-Dorg.mongodb.test.kms.tls.error.type=${KMS_TLS_ERROR_TYPE} \
-Dorg.mongodb.test.kms.keystore.location="$(pwd)" \
driver-sync:cleanTest driver-sync:test --tests ClientSideEncryptionKmsTlsTest
first=$?
echo $first

./gradlew --stacktrace --info ${GRADLE_EXTRA_VARS} -Dorg.mongodb.test.uri=${MONGODB_URI} \
-Dorg.mongodb.test.kms.tls.error.type=${KMS_TLS_ERROR_TYPE} \
-Dorg.mongodb.test.kms.keystore.location="$(pwd)" \
driver-reactive-streams:cleanTest driver-reactive-streams:test --tests ClientSideEncryptionKmsTlsTest
second=$?
echo $second
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package com.mongodb.reactivestreams.client.internal.crypt;

import com.mongodb.MongoException;
import com.mongodb.MongoOperationTimeoutException;
import com.mongodb.MongoSocketException;
import com.mongodb.MongoSocketReadTimeoutException;
Expand Down Expand Up @@ -131,6 +132,11 @@ private void streamRead(final Stream stream, final MongoKeyDecryptor keyDecrypto

@Override
public void completed(final Integer integer, final Void aVoid) {
if (integer == -1) {
sink.error(new MongoException(
"Unexpected end of stream from KMS provider " + keyDecryptor.getKmsProvider()));
return;
}
buffer.flip();
try {
keyDecryptor.feed(buffer.asNIO());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -369,14 +369,17 @@ private void decryptKey(final MongoKeyDecryptor keyDecryptor, @Nullable final Ti
while (bytesNeeded > 0) {
byte[] bytes = new byte[bytesNeeded];
int bytesRead = inputStream.read(bytes, 0, bytes.length);
if (bytesRead == -1) {
throw new MongoException("Unexpected end of stream from KMS provider " + keyDecryptor.getKmsProvider());
}
keyDecryptor.feed(ByteBuffer.wrap(bytes, 0, bytesRead));
bytesNeeded = keyDecryptor.bytesNeeded();
}
}
}

private MongoException wrapInMongoException(final Throwable t) {
if (t instanceof MongoException) {
if (t instanceof MongoClientException) {
return (MongoException) t;
} else {
return new MongoClientException("Exception in encryption library: " + t.getMessage(), t);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,31 +20,40 @@
import com.mongodb.MongoClientException;
import com.mongodb.client.model.vault.DataKeyOptions;
import com.mongodb.client.vault.ClientEncryption;
import com.mongodb.fixture.EncryptionFixture;
import com.mongodb.lang.NonNull;
import com.mongodb.lang.Nullable;
import org.bson.BsonDocument;
import org.junit.jupiter.api.Test;

import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLServerSocket;
import javax.net.ssl.SSLServerSocketFactory;
import javax.net.ssl.TrustManager;
import javax.net.ssl.X509TrustManager;
import java.io.IOException;
import java.net.Socket;
import java.security.KeyManagementException;
import java.security.NoSuchAlgorithmException;
import java.security.cert.CertificateException;
import java.security.cert.CertificateExpiredException;
import java.security.cert.X509Certificate;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;

import static com.mongodb.ClusterFixture.getEnv;
import static com.mongodb.ClusterFixture.hasEncryptionTestsEnabled;
import static com.mongodb.client.Fixture.getMongoClientSettings;
import static java.util.Objects.requireNonNull;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;
import static org.junit.jupiter.api.Assumptions.assumeTrue;

public abstract class AbstractClientSideEncryptionKmsTlsTest {

private static final String SYSTEM_PROPERTY_KEY = "org.mongodb.test.kms.tls.error.type";
Expand Down Expand Up @@ -128,7 +137,7 @@ public void testInvalidKmsCertificate() {
* See <a href="https://github.com/mongodb/specifications/tree/master/source/client-side-encryption/tests#11-kms-tls-options-tests">
* 11. KMS TLS Options Tests</a>.
*/
@Test()
@Test
public void testThatCustomSslContextIsUsed() {
assumeTrue(hasEncryptionTestsEnabled());

Expand Down Expand Up @@ -165,34 +174,67 @@ public void testThatCustomSslContextIsUsed() {
}
}

/**
* Not a prose spec test. However, it is additional test case for better coverage.
*/
@Test
public void testUnexpectedEndOfStreamFromKmsProvider() throws Exception {
int kmsPort = 5555;
ClientEncryptionSettings clientEncryptionSettings = ClientEncryptionSettings.builder()
.keyVaultMongoClientSettings(getMongoClientSettings())
.keyVaultNamespace("keyvault.datakeys")
.kmsProviders(new HashMap<String, Map<String, Object>>() {{
put("kmip", new HashMap<String, Object>() {{
put("endpoint", "localhost:" + kmsPort);
}});
}})
.build();

Thread serverThread = null;
try (ClientEncryption clientEncryption = getClientEncryption(clientEncryptionSettings)) {
serverThread = startKmsServerSimulatingEof(EncryptionFixture.buildSslContextFromKeyStore(
System.getProperty("org.mongodb.test.kms.keystore.location"),
"server.p12"), kmsPort);

MongoClientException mongoException = assertThrows(MongoClientException.class,
() -> clientEncryption.createDataKey("kmip", new DataKeyOptions()));
assertEquals("Exception in encryption library: Unexpected end of stream from KMS provider kmip",
mongoException.getMessage());
} finally {
if (serverThread != null) {
serverThread.interrupt();
}
}
}

private HashMap<String, Map<String, Object>> getKmsProviders() {
return new HashMap<String, Map<String, Object>>() {{
put("aws", new HashMap<String, Object>() {{
put("aws", new HashMap<String, Object>() {{
put("accessKeyId", getEnv("AWS_ACCESS_KEY_ID"));
put("secretAccessKey", getEnv("AWS_SECRET_ACCESS_KEY"));
}});
put("aws:named", new HashMap<String, Object>() {{
put("aws:named", new HashMap<String, Object>() {{
put("accessKeyId", getEnv("AWS_ACCESS_KEY_ID"));
put("secretAccessKey", getEnv("AWS_SECRET_ACCESS_KEY"));
}});
put("azure", new HashMap<String, Object>() {{
put("azure", new HashMap<String, Object>() {{
put("tenantId", getEnv("AZURE_TENANT_ID"));
put("clientId", getEnv("AZURE_CLIENT_ID"));
put("clientSecret", getEnv("AZURE_CLIENT_SECRET"));
put("identityPlatformEndpoint", "login.microsoftonline.com:443");
}});
put("azure:named", new HashMap<String, Object>() {{
put("azure:named", new HashMap<String, Object>() {{
put("tenantId", getEnv("AZURE_TENANT_ID"));
put("clientId", getEnv("AZURE_CLIENT_ID"));
put("clientSecret", getEnv("AZURE_CLIENT_SECRET"));
put("identityPlatformEndpoint", "login.microsoftonline.com:443");
}});
put("gcp", new HashMap<String, Object>() {{
put("gcp", new HashMap<String, Object>() {{
put("email", getEnv("GCP_EMAIL"));
put("privateKey", getEnv("GCP_PRIVATE_KEY"));
put("endpoint", "oauth2.googleapis.com:443");
}});
put("gcp:named", new HashMap<String, Object>() {{
put("gcp:named", new HashMap<String, Object>() {{
put("email", getEnv("GCP_EMAIL"));
put("privateKey", getEnv("GCP_PRIVATE_KEY"));
put("endpoint", "oauth2.googleapis.com:443");
Expand Down Expand Up @@ -257,5 +299,29 @@ public void checkServerTrusted(final X509Certificate[] certs, final String authT
throw new RuntimeException(e);
}
}

private Thread startKmsServerSimulatingEof(final SSLContext sslContext, final int kmsPort)
throws Exception {
CompletableFuture<Void> confirmListening = new CompletableFuture<>();
Thread serverThread = new Thread(() -> {
try {
SSLServerSocketFactory serverSocketFactory = sslContext.getServerSocketFactory();
try (SSLServerSocket sslServerSocket = (SSLServerSocket) serverSocketFactory.createServerSocket(kmsPort)) {
sslServerSocket.setNeedClientAuth(false);
confirmListening.complete(null);
try (Socket accept = sslServerSocket.accept()) {
accept.setSoTimeout(10000);
accept.getInputStream().read();
}
}
} catch (IOException e) {
throw new RuntimeException(e);
}
}, "KMIP-EOF-Fake-Server");
serverThread.setDaemon(true);
serverThread.start();
confirmListening.get(TimeUnit.SECONDS.toMillis(10), TimeUnit.MILLISECONDS);
return serverThread;
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import com.mongodb.client.Fixture;
import com.mongodb.client.MongoClient;
import com.mongodb.connection.NettyTransportSettings;
import com.mongodb.fixture.EncryptionFixture;
import io.netty.handler.ssl.SslContextBuilder;
import io.netty.handler.ssl.SslProvider;
import org.junit.jupiter.api.extension.ConditionEvaluationResult;
Expand All @@ -34,14 +35,6 @@

import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLContext;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.security.KeyStore;
import java.security.KeyStoreException;
import java.security.NoSuchAlgorithmException;
import java.security.UnrecoverableKeyException;
import java.security.cert.CertificateException;
import java.util.stream.Stream;

import static com.mongodb.AuthenticationMechanism.MONGODB_X509;
Expand All @@ -52,7 +45,6 @@
@ExtendWith(AbstractX509AuthenticationTest.X509AuthenticationPropertyCondition.class)
public abstract class AbstractX509AuthenticationTest {

private static final String KEYSTORE_PASSWORD = "test";
protected abstract MongoClient createMongoClient(MongoClientSettings mongoClientSettings);

private static Stream<Arguments> shouldAuthenticateWithClientCertificate() throws Exception {
Expand Down Expand Up @@ -128,22 +120,11 @@ private static Stream<Arguments> getArgumentForKeystore(final String keystoreFil
}

private static SSLContext buildSslContextFromKeyStore(final String keystoreFileName) throws Exception {
KeyManagerFactory keyManagerFactory = getKeyManagerFactory(keystoreFileName);
SSLContext sslContext = SSLContext.getInstance("TLS");
sslContext.init(keyManagerFactory.getKeyManagers(), null, null);
return sslContext;
return EncryptionFixture.buildSslContextFromKeyStore(getKeystoreLocation(), keystoreFileName);
}

private static KeyManagerFactory getKeyManagerFactory(final String keystoreFileName)
throws KeyStoreException, IOException, NoSuchAlgorithmException, CertificateException, UnrecoverableKeyException {
KeyStore ks = KeyStore.getInstance("PKCS12");
try (FileInputStream fis = new FileInputStream(getKeystoreLocation() + File.separator + keystoreFileName)) {
ks.load(fis, KEYSTORE_PASSWORD.toCharArray());
}
KeyManagerFactory keyManagerFactory = KeyManagerFactory.getInstance(
KeyManagerFactory.getDefaultAlgorithm());
keyManagerFactory.init(ks, KEYSTORE_PASSWORD.toCharArray());
return keyManagerFactory;
private static KeyManagerFactory getKeyManagerFactory(final String keystoreFileName) throws Exception {
return EncryptionFixture.getKeyManagerFactory(getKeystoreLocation(), keystoreFileName);
}

private static String getKeystoreLocation() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@

package com.mongodb.fixture;

import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLContext;
import java.io.File;
import java.io.FileInputStream;
import java.security.KeyStore;
import java.util.HashMap;
import java.util.Map;

Expand Down Expand Up @@ -73,6 +78,23 @@ public static Map<String, Map<String, Object>> getKmsProviders(final KmsProvider
}};
}

public static KeyManagerFactory getKeyManagerFactory(final String keystoreLocation, final String keystoreFileName) throws Exception {
KeyStore ks = KeyStore.getInstance("PKCS12");
try (FileInputStream fis = new FileInputStream(keystoreLocation + File.separator + keystoreFileName)) {
ks.load(fis, "test".toCharArray());
}
KeyManagerFactory keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm());
keyManagerFactory.init(ks, "test".toCharArray());
return keyManagerFactory;
}

public static SSLContext buildSslContextFromKeyStore(final String keystoreLocation, final String keystoreFileName) throws Exception {
KeyManagerFactory keyManagerFactory = getKeyManagerFactory(keystoreLocation, keystoreFileName);
SSLContext sslContext = SSLContext.getInstance("TLS");
sslContext.init(keyManagerFactory.getKeyManagers(), null, null);
return sslContext;
}

public enum KmsProviderType {
LOCAL,
AWS,
Expand Down