Skip to content
Draft
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions .evergreen/run-kms-tls-tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,27 @@ echo "Running KMS TLS tests"
cp ${JAVA_HOME}/lib/security/cacerts mongo-truststore
${JAVA_HOME}/bin/keytool -importcert -trustcacerts -file ${DRIVERS_TOOLS}/.evergreen/x509gen/ca.pem -keystore mongo-truststore -storepass changeit -storetype JKS -noprompt

# Create keystore from server.pem to emulate KMS server in tests.
openssl pkcs12 -export \
-in ${DRIVERS_TOOLS}/.evergreen/x509gen/server.pem \
-out server.p12 \
-password pass:test

export GRADLE_EXTRA_VARS="-Pssl.enabled=true -Pssl.trustStoreType=jks -Pssl.trustStore=`pwd`/mongo-truststore -Pssl.trustStorePassword=changeit"
export KMS_TLS_ERROR_TYPE=${KMS_TLS_ERROR_TYPE}

./gradlew -version

./gradlew --stacktrace --info ${GRADLE_EXTRA_VARS} -Dorg.mongodb.test.uri=${MONGODB_URI} \
-Dorg.mongodb.test.kms.tls.error.type=${KMS_TLS_ERROR_TYPE} \
-Dorg.mongodb.test.kms.keystore.location="$(pwd)" \
driver-sync:cleanTest driver-sync:test --tests ClientSideEncryptionKmsTlsTest
first=$?
echo $first

./gradlew --stacktrace --info ${GRADLE_EXTRA_VARS} -Dorg.mongodb.test.uri=${MONGODB_URI} \
-Dorg.mongodb.test.kms.tls.error.type=${KMS_TLS_ERROR_TYPE} \
-Dorg.mongodb.test.kms.keystore.location="$(pwd)" \
driver-reactive-streams:cleanTest driver-reactive-streams:test --tests ClientSideEncryptionKmsTlsTest
second=$?
echo $second
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package com.mongodb.reactivestreams.client.internal.crypt;

import com.mongodb.MongoException;
import com.mongodb.MongoOperationTimeoutException;
import com.mongodb.MongoSocketException;
import com.mongodb.MongoSocketReadTimeoutException;
Expand Down Expand Up @@ -131,6 +132,11 @@ private void streamRead(final Stream stream, final MongoKeyDecryptor keyDecrypto

@Override
public void completed(final Integer integer, final Void aVoid) {
if (integer == -1) {
sink.error(new MongoException(
"Unexpected end of stream from KMS provider " + keyDecryptor.getKmsProvider()));
return;
}
buffer.flip();
try {
keyDecryptor.feed(buffer.asNIO());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -369,14 +369,17 @@ private void decryptKey(final MongoKeyDecryptor keyDecryptor, @Nullable final Ti
while (bytesNeeded > 0) {
byte[] bytes = new byte[bytesNeeded];
int bytesRead = inputStream.read(bytes, 0, bytes.length);
if (bytesRead == -1) {
throw new MongoException("Unexpected end of stream from KMS provider " + keyDecryptor.getKmsProvider());
}
keyDecryptor.feed(ByteBuffer.wrap(bytes, 0, bytesRead));
bytesNeeded = keyDecryptor.bytesNeeded();
}
}
}

private MongoException wrapInMongoException(final Throwable t) {
if (t instanceof MongoException) {
if (t instanceof MongoClientException) {
return (MongoException) t;
} else {
return new MongoClientException("Exception in encryption library: " + t.getMessage(), t);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,33 +18,43 @@

import com.mongodb.ClientEncryptionSettings;
import com.mongodb.MongoClientException;
import com.mongodb.MongoException;
import com.mongodb.client.model.vault.DataKeyOptions;
import com.mongodb.client.vault.ClientEncryption;
import com.mongodb.fixture.EncryptionFixture;
import com.mongodb.lang.NonNull;
import com.mongodb.lang.Nullable;
import org.bson.BsonDocument;
import org.junit.jupiter.api.Test;

import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLServerSocket;
import javax.net.ssl.SSLServerSocketFactory;
import javax.net.ssl.TrustManager;
import javax.net.ssl.X509TrustManager;
import java.io.IOException;
import java.net.Socket;
import java.security.KeyManagementException;
import java.security.NoSuchAlgorithmException;
import java.security.cert.CertificateException;
import java.security.cert.CertificateExpiredException;
import java.security.cert.X509Certificate;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;

import static com.mongodb.ClusterFixture.getEnv;
import static com.mongodb.ClusterFixture.hasEncryptionTestsEnabled;
import static com.mongodb.client.Fixture.getMongoClientSettings;
import static java.util.Objects.requireNonNull;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;
import static org.junit.jupiter.api.Assumptions.assumeTrue;

public abstract class AbstractClientSideEncryptionKmsTlsTest {

private static final String SYSTEM_PROPERTY_KEY = "org.mongodb.test.kms.tls.error.type";
Expand Down Expand Up @@ -128,7 +138,7 @@ public void testInvalidKmsCertificate() {
* See <a href="https://github.com/mongodb/specifications/tree/master/source/client-side-encryption/tests#11-kms-tls-options-tests">
* 11. KMS TLS Options Tests</a>.
*/
@Test()
@Test
public void testThatCustomSslContextIsUsed() {
assumeTrue(hasEncryptionTestsEnabled());

Expand Down Expand Up @@ -165,34 +175,67 @@ public void testThatCustomSslContextIsUsed() {
}
}

/**
* Not a prose spec test. However, it is additional test case for better coverage.
*/
@Test
public void testUnexpectedEndOfStreamFromKmsProvider() {
int kmsPort = 5555;
ClientEncryptionSettings clientEncryptionSettings = ClientEncryptionSettings.builder()
.keyVaultMongoClientSettings(getMongoClientSettings())
.keyVaultNamespace("keyvault.datakeys")
.kmsProviders(new HashMap<String, Map<String, Object>>() {{
put("kmip", new HashMap<String, Object>() {{
put("endpoint", "localhost:" + kmsPort);
}});
}})
.build();

Thread serverThread = null;
try (ClientEncryption clientEncryption = getClientEncryption(clientEncryptionSettings)) {
serverThread = startKmsServerSimulatingEof(EncryptionFixture.buildSslContextFromKeyStore(
System.getProperty("org.mongodb.test.kms.keystore.location"),
"server.p12"), kmsPort);

MongoException mongoException = assertThrows(MongoClientException.class,
() -> clientEncryption.createDataKey("kmip", new DataKeyOptions()));
assertEquals("Exception in encryption library: Unexpected end of stream from KMS provider kmip",
mongoException.getMessage());
} catch (Throwable e) {
if (serverThread != null) {
serverThread.interrupt();
}
}
}

private HashMap<String, Map<String, Object>> getKmsProviders() {
return new HashMap<String, Map<String, Object>>() {{
put("aws", new HashMap<String, Object>() {{
put("aws", new HashMap<String, Object>() {{
put("accessKeyId", getEnv("AWS_ACCESS_KEY_ID"));
put("secretAccessKey", getEnv("AWS_SECRET_ACCESS_KEY"));
}});
put("aws:named", new HashMap<String, Object>() {{
put("aws:named", new HashMap<String, Object>() {{
put("accessKeyId", getEnv("AWS_ACCESS_KEY_ID"));
put("secretAccessKey", getEnv("AWS_SECRET_ACCESS_KEY"));
}});
put("azure", new HashMap<String, Object>() {{
put("azure", new HashMap<String, Object>() {{
put("tenantId", getEnv("AZURE_TENANT_ID"));
put("clientId", getEnv("AZURE_CLIENT_ID"));
put("clientSecret", getEnv("AZURE_CLIENT_SECRET"));
put("identityPlatformEndpoint", "login.microsoftonline.com:443");
}});
put("azure:named", new HashMap<String, Object>() {{
put("azure:named", new HashMap<String, Object>() {{
put("tenantId", getEnv("AZURE_TENANT_ID"));
put("clientId", getEnv("AZURE_CLIENT_ID"));
put("clientSecret", getEnv("AZURE_CLIENT_SECRET"));
put("identityPlatformEndpoint", "login.microsoftonline.com:443");
}});
put("gcp", new HashMap<String, Object>() {{
put("gcp", new HashMap<String, Object>() {{
put("email", getEnv("GCP_EMAIL"));
put("privateKey", getEnv("GCP_PRIVATE_KEY"));
put("endpoint", "oauth2.googleapis.com:443");
}});
put("gcp:named", new HashMap<String, Object>() {{
put("gcp:named", new HashMap<String, Object>() {{
put("email", getEnv("GCP_EMAIL"));
put("privateKey", getEnv("GCP_PRIVATE_KEY"));
put("endpoint", "oauth2.googleapis.com:443");
Expand Down Expand Up @@ -257,5 +300,30 @@ public void checkServerTrusted(final X509Certificate[] certs, final String authT
throw new RuntimeException(e);
}
}

private Thread startKmsServerSimulatingEof(final SSLContext sslContext, final int kmsPort)
throws Exception {
CompletableFuture<Void> confirmListening = new CompletableFuture<>();
Thread serverThread = new Thread(() -> {
try {
SSLServerSocketFactory serverSocketFactory = sslContext.getServerSocketFactory();
SSLServerSocket sslServerSocket = (SSLServerSocket) serverSocketFactory.createServerSocket(kmsPort);
sslServerSocket.setNeedClientAuth(false);
confirmListening.complete(null);

Socket accept = sslServerSocket.accept();
accept.setSoTimeout(10000);
accept.getInputStream().read();
accept.close();
sslServerSocket.close();
} catch (IOException e) {
throw new RuntimeException(e);
}
}, "KMIP-EOF-Fake-Server");
serverThread.setDaemon(true);
serverThread.start();
confirmListening.get(TimeUnit.SECONDS.toMillis(10), TimeUnit.MILLISECONDS);
return serverThread;
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import com.mongodb.client.Fixture;
import com.mongodb.client.MongoClient;
import com.mongodb.connection.NettyTransportSettings;
import com.mongodb.fixture.EncryptionFixture;
import io.netty.handler.ssl.SslContextBuilder;
import io.netty.handler.ssl.SslProvider;
import org.junit.jupiter.api.extension.ConditionEvaluationResult;
Expand All @@ -34,14 +35,6 @@

import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLContext;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.security.KeyStore;
import java.security.KeyStoreException;
import java.security.NoSuchAlgorithmException;
import java.security.UnrecoverableKeyException;
import java.security.cert.CertificateException;
import java.util.stream.Stream;

import static com.mongodb.AuthenticationMechanism.MONGODB_X509;
Expand All @@ -52,7 +45,6 @@
@ExtendWith(AbstractX509AuthenticationTest.X509AuthenticationPropertyCondition.class)
public abstract class AbstractX509AuthenticationTest {

private static final String KEYSTORE_PASSWORD = "test";
protected abstract MongoClient createMongoClient(MongoClientSettings mongoClientSettings);

private static Stream<Arguments> shouldAuthenticateWithClientCertificate() throws Exception {
Expand Down Expand Up @@ -128,22 +120,11 @@ private static Stream<Arguments> getArgumentForKeystore(final String keystoreFil
}

private static SSLContext buildSslContextFromKeyStore(final String keystoreFileName) throws Exception {
KeyManagerFactory keyManagerFactory = getKeyManagerFactory(keystoreFileName);
SSLContext sslContext = SSLContext.getInstance("TLS");
sslContext.init(keyManagerFactory.getKeyManagers(), null, null);
return sslContext;
return EncryptionFixture.buildSslContextFromKeyStore(getKeystoreLocation(), keystoreFileName);
}

private static KeyManagerFactory getKeyManagerFactory(final String keystoreFileName)
throws KeyStoreException, IOException, NoSuchAlgorithmException, CertificateException, UnrecoverableKeyException {
KeyStore ks = KeyStore.getInstance("PKCS12");
try (FileInputStream fis = new FileInputStream(getKeystoreLocation() + File.separator + keystoreFileName)) {
ks.load(fis, KEYSTORE_PASSWORD.toCharArray());
}
KeyManagerFactory keyManagerFactory = KeyManagerFactory.getInstance(
KeyManagerFactory.getDefaultAlgorithm());
keyManagerFactory.init(ks, KEYSTORE_PASSWORD.toCharArray());
return keyManagerFactory;
private static KeyManagerFactory getKeyManagerFactory(final String keystoreFileName) throws Exception {
return EncryptionFixture.getKeyManagerFactory(getKeystoreLocation(), keystoreFileName);
}

private static String getKeystoreLocation() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@

package com.mongodb.fixture;

import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLContext;
import java.io.File;
import java.io.FileInputStream;
import java.security.KeyStore;
import java.util.HashMap;
import java.util.Map;

Expand Down Expand Up @@ -73,6 +78,23 @@ public static Map<String, Map<String, Object>> getKmsProviders(final KmsProvider
}};
}

public static KeyManagerFactory getKeyManagerFactory(final String keystoreLocation, final String keystoreFileName) throws Exception {
KeyStore ks = KeyStore.getInstance("PKCS12");
try (FileInputStream fis = new FileInputStream(keystoreLocation + File.separator + keystoreFileName)) {
ks.load(fis, "test".toCharArray());
}
KeyManagerFactory keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm());
keyManagerFactory.init(ks, "test".toCharArray());
return keyManagerFactory;
}

public static SSLContext buildSslContextFromKeyStore(final String keystoreLocation, final String keystoreFileName) throws Exception {
KeyManagerFactory keyManagerFactory = getKeyManagerFactory(keystoreLocation, keystoreFileName);
SSLContext sslContext = SSLContext.getInstance("TLS");
sslContext.init(keyManagerFactory.getKeyManagers(), null, null);
return sslContext;
}

public enum KmsProviderType {
LOCAL,
AWS,
Expand Down