diff --git a/typesense/Cargo.toml b/typesense/Cargo.toml index f86b5dd5..5f8fc248 100644 --- a/typesense/Cargo.toml +++ b/typesense/Cargo.toml @@ -48,6 +48,8 @@ trybuild = "1.0.42" # native-only dev deps [target.'cfg(not(target_arch = "wasm32"))'.dev-dependencies] tokio = { workspace = true} +tokio-rustls = "0.26" +rcgen = "0.14" wiremock = "0.6" # wasm test deps @@ -64,4 +66,4 @@ required-features = ["derive"] [[test]] name = "client" -path = "tests/client/mod.rs" \ No newline at end of file +path = "tests/client/mod.rs" diff --git a/typesense/src/client/mod.rs b/typesense/src/client/mod.rs index f42cfb71..e6fe5576 100644 --- a/typesense/src/client/mod.rs +++ b/typesense/src/client/mod.rs @@ -210,6 +210,7 @@ impl Client { /// - **healthcheck_interval**: 60 seconds. /// - **retry_policy**: Exponential backoff with a maximum of 3 retries. (disabled on WASM) /// - **connection_timeout**: 5 seconds. (disabled on WASM) + /// - **additional_root_certificates**: None. (not available on WASM) #[builder] pub fn new( /// The Typesense API key used for authentication. @@ -235,6 +236,11 @@ impl Client { #[builder(default = Duration::from_secs(5))] /// The timeout for each individual network request. connection_timeout: Duration, + + #[cfg(not(target_arch = "wasm32"))] + #[builder(default = vec![])] + /// The list of custom headers to add to each request. + additional_root_certificates: Vec, ) -> Result { let is_nearest_node_set = nearest_node.is_some(); @@ -248,12 +254,16 @@ impl Client { .expect("Failed to build reqwest client"); #[cfg(not(target_arch = "wasm32"))] - let http_client = ReqwestMiddlewareClientBuilder::new( - reqwest::Client::builder() - .timeout(connection_timeout) - .build() - .expect("Failed to build reqwest client"), - ) + let http_client = ReqwestMiddlewareClientBuilder::new({ + let builder = reqwest::Client::builder().timeout(connection_timeout); + let builder = additional_root_certificates + .iter() + .fold(builder, |builder, certificate| { + builder.add_root_certificate(certificate.clone()) + }); + + builder.build().expect("Failed to build reqwest client") + }) .with(RetryTransientMiddleware::new_with_policy(retry_policy)) .build(); diff --git a/typesense/tests/client/mod.rs b/typesense/tests/client/mod.rs index 41882a93..5f557522 100644 --- a/typesense/tests/client/mod.rs +++ b/typesense/tests/client/mod.rs @@ -10,6 +10,8 @@ mod operations_test; mod presets_test; mod stemming_dictionaries_test; mod stopwords_test; +#[cfg(not(target_arch = "wasm32"))] +mod tls_certificate_test; use std::time::Duration; use typesense::{Client, ExponentialBackoff}; diff --git a/typesense/tests/client/tls_certificate_test.rs b/typesense/tests/client/tls_certificate_test.rs new file mode 100644 index 00000000..797d7d6e --- /dev/null +++ b/typesense/tests/client/tls_certificate_test.rs @@ -0,0 +1,106 @@ +use std::{ + net::{IpAddr, Ipv4Addr}, + sync::Arc, + time::Duration, +}; +use tokio::{ + io::{AsyncReadExt, AsyncWriteExt as _}, + net::TcpListener, +}; +use tokio_rustls::{ + TlsAcceptor, + rustls::{ + self, ServerConfig, + pki_types::{CertificateDer, PrivateKeyDer}, + }, +}; +use typesense::ExponentialBackoff; + +#[tokio::test] +async fn test_tls_certificate() { + rustls::crypto::aws_lc_rs::default_provider() + .install_default() + .expect("Failed to install crypto provider"); + + let api_key = "xxx-api-key"; + + // generate a self-signed key pair and build TLS config out of it + let (cert, key) = generate_self_signed_cert(); + let tls_config = ServerConfig::builder() + .with_no_client_auth() + .with_single_cert(vec![cert.clone()], key) + .expect("failed to build TLS config"); + + let localhost = IpAddr::V4(Ipv4Addr::LOCALHOST); + let listener = TcpListener::bind((localhost, 0)) + .await + .expect("Failed to bind to address"); + let server_addr = listener.local_addr().expect("Failed to get local address"); + + // spawn a handler which handles one /health request over a TLS connection + let handler = tokio::spawn(mock_node_handler(listener, tls_config, api_key)); + + let client_cert = reqwest::Certificate::from_der(&cert) + .expect("Failed to convert certificate to Certificate"); + let client = typesense::Client::builder() + .nodes(vec![format!("https://localhost:{}", server_addr.port())]) + .api_key(api_key) + .additional_root_certificates(vec![client_cert]) + .healthcheck_interval(Duration::from_secs(9001)) // we'll do a healthcheck manually + .retry_policy(ExponentialBackoff::builder().build_with_max_retries(0)) // no retries + .connection_timeout(Duration::from_secs(1)) // short + .build() + .expect("Failed to create Typesense client"); + + // request /health + client + .operations() + .health() + .await + .expect("Failed to get collection health"); + + handler.await.expect("Failed to join handler"); +} + +fn generate_self_signed_cert() -> (CertificateDer<'static>, PrivateKeyDer<'static>) { + let pair = rcgen::generate_simple_self_signed(["localhost".into()]) + .expect("Failed to generate self-signed certificate"); + let cert = pair.cert.der().clone(); + let signing_key = pair.signing_key.serialize_der(); + let signing_key = PrivateKeyDer::try_from(signing_key) + .expect("Failed to convert signing key to PrivateKeyDer"); + (cert, signing_key) +} + +async fn mock_node_handler(listener: TcpListener, tls_config: ServerConfig, api_key: &'static str) { + let tls_acceptor = TlsAcceptor::from(Arc::new(tls_config)); + let (stream, _addr) = listener + .accept() + .await + .expect("Failed to accept connection"); + let mut stream = tls_acceptor + .accept(stream) + .await + .expect("Failed to accept TLS connection"); + + let mut buf = vec![0u8; 1024]; + stream + .read(&mut buf[..]) + .await + .expect("Failed to read request"); + let request = String::from_utf8(buf).expect("Failed to parse request as UTF-8"); + assert!(request.contains("/health")); + assert!(request.contains(api_key)); + + // mock a /health response + let response = r#"HTTP/1.1 200 OK\r\n\ +Content-Type: application/json;\r\n\ +Connection: close\r\n + +{"ok": true}"#; + stream + .write_all(&response.as_bytes()) + .await + .expect("Failed to write to stream"); + stream.shutdown().await.expect("Failed to shutdown stream"); +}