Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion typesense/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ trybuild = "1.0.42"
# native-only dev deps
[target.'cfg(not(target_arch = "wasm32"))'.dev-dependencies]
tokio = { workspace = true}
tokio-rustls = "0.26"
rcgen = "0.14"
wiremock = "0.6"

# wasm test deps
Expand All @@ -64,4 +66,4 @@ required-features = ["derive"]

[[test]]
name = "client"
path = "tests/client/mod.rs"
path = "tests/client/mod.rs"
22 changes: 16 additions & 6 deletions typesense/src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@
/// - **healthcheck_interval**: 60 seconds.
/// - **retry_policy**: Exponential backoff with a maximum of 3 retries. (disabled on WASM)
/// - **connection_timeout**: 5 seconds. (disabled on WASM)
/// - **additional_root_certificates**: None. (not available on WASM)
#[builder]
pub fn new(
/// The Typesense API key used for authentication.
Expand All @@ -231,10 +232,15 @@
healthcheck_interval: Duration,
#[builder(default = ExponentialBackoff::builder().build_with_max_retries(3))]
/// The retry policy for transient network errors on a *single* node.
retry_policy: ExponentialBackoff,

Check warning on line 235 in typesense/src/client/mod.rs

View workflow job for this annotation

GitHub Actions / tests (ubuntu, wasm32-unknown-unknown)

unused variable: `retry_policy`

Check warning on line 235 in typesense/src/client/mod.rs

View workflow job for this annotation

GitHub Actions / tests (ubuntu, wasm32-unknown-unknown)

unused variable: `retry_policy`

Check warning on line 235 in typesense/src/client/mod.rs

View workflow job for this annotation

GitHub Actions / tests (ubuntu, wasm32-unknown-unknown)

unused variable: `retry_policy`
#[builder(default = Duration::from_secs(5))]
/// The timeout for each individual network request.
connection_timeout: Duration,

Check warning on line 238 in typesense/src/client/mod.rs

View workflow job for this annotation

GitHub Actions / tests (ubuntu, wasm32-unknown-unknown)

unused variable: `connection_timeout`

Check warning on line 238 in typesense/src/client/mod.rs

View workflow job for this annotation

GitHub Actions / tests (ubuntu, wasm32-unknown-unknown)

unused variable: `connection_timeout`

Check warning on line 238 in typesense/src/client/mod.rs

View workflow job for this annotation

GitHub Actions / tests (ubuntu, wasm32-unknown-unknown)

unused variable: `connection_timeout`

#[cfg(not(target_arch = "wasm32"))]
#[builder(default = vec![])]
/// The list of custom headers to add to each request.
additional_root_certificates: Vec<reqwest::Certificate>,
) -> Result<Self, &'static str> {
let is_nearest_node_set = nearest_node.is_some();

Expand All @@ -248,12 +254,16 @@
.expect("Failed to build reqwest client");

#[cfg(not(target_arch = "wasm32"))]
let http_client = ReqwestMiddlewareClientBuilder::new(
reqwest::Client::builder()
.timeout(connection_timeout)
.build()
.expect("Failed to build reqwest client"),
)
let http_client = ReqwestMiddlewareClientBuilder::new({
let builder = reqwest::Client::builder().timeout(connection_timeout);
let builder = additional_root_certificates
.iter()
.fold(builder, |builder, certificate| {
builder.add_root_certificate(certificate.clone())
});

builder.build().expect("Failed to build reqwest client")
})
.with(RetryTransientMiddleware::new_with_policy(retry_policy))
.build();

Expand Down
2 changes: 2 additions & 0 deletions typesense/tests/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ mod operations_test;
mod presets_test;
mod stemming_dictionaries_test;
mod stopwords_test;
#[cfg(not(target_arch = "wasm32"))]
mod tls_certificate_test;

use std::time::Duration;
use typesense::{Client, ExponentialBackoff};
Expand Down
106 changes: 106 additions & 0 deletions typesense/tests/client/tls_certificate_test.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
use std::{
net::{IpAddr, Ipv4Addr},
sync::Arc,
time::Duration,
};
use tokio::{
io::{AsyncReadExt, AsyncWriteExt as _},
net::TcpListener,
};
use tokio_rustls::{
TlsAcceptor,
rustls::{
self, ServerConfig,
pki_types::{CertificateDer, PrivateKeyDer},
},
};
use typesense::ExponentialBackoff;

#[tokio::test]
async fn test_tls_certificate() {
rustls::crypto::aws_lc_rs::default_provider()
.install_default()
.expect("Failed to install crypto provider");

let api_key = "xxx-api-key";

// generate a self-signed key pair and build TLS config out of it
let (cert, key) = generate_self_signed_cert();
let tls_config = ServerConfig::builder()
.with_no_client_auth()
.with_single_cert(vec![cert.clone()], key)
.expect("failed to build TLS config");

let localhost = IpAddr::V4(Ipv4Addr::LOCALHOST);
let listener = TcpListener::bind((localhost, 0))
.await
.expect("Failed to bind to address");
let server_addr = listener.local_addr().expect("Failed to get local address");

// spawn a handler which handles one /health request over a TLS connection
let handler = tokio::spawn(mock_node_handler(listener, tls_config, api_key));

let client_cert = reqwest::Certificate::from_der(&cert)
.expect("Failed to convert certificate to Certificate");
let client = typesense::Client::builder()
.nodes(vec![format!("https://localhost:{}", server_addr.port())])
.api_key(api_key)
.additional_root_certificates(vec![client_cert])
.healthcheck_interval(Duration::from_secs(9001)) // we'll do a healthcheck manually
.retry_policy(ExponentialBackoff::builder().build_with_max_retries(0)) // no retries
.connection_timeout(Duration::from_secs(1)) // short
.build()
.expect("Failed to create Typesense client");

// request /health
client
.operations()
.health()
.await
.expect("Failed to get collection health");

handler.await.expect("Failed to join handler");
}

fn generate_self_signed_cert() -> (CertificateDer<'static>, PrivateKeyDer<'static>) {
let pair = rcgen::generate_simple_self_signed(["localhost".into()])
.expect("Failed to generate self-signed certificate");
let cert = pair.cert.der().clone();
let signing_key = pair.signing_key.serialize_der();
let signing_key = PrivateKeyDer::try_from(signing_key)
.expect("Failed to convert signing key to PrivateKeyDer");
(cert, signing_key)
}

async fn mock_node_handler(listener: TcpListener, tls_config: ServerConfig, api_key: &'static str) {
let tls_acceptor = TlsAcceptor::from(Arc::new(tls_config));
let (stream, _addr) = listener
.accept()
.await
.expect("Failed to accept connection");
let mut stream = tls_acceptor
.accept(stream)
.await
.expect("Failed to accept TLS connection");

let mut buf = vec![0u8; 1024];
stream
.read(&mut buf[..])
.await
.expect("Failed to read request");
let request = String::from_utf8(buf).expect("Failed to parse request as UTF-8");
assert!(request.contains("/health"));
assert!(request.contains(api_key));

// mock a /health response
let response = r#"HTTP/1.1 200 OK\r\n\
Content-Type: application/json;\r\n\
Connection: close\r\n

{"ok": true}"#;
stream
.write_all(&response.as_bytes())
.await
.expect("Failed to write to stream");
stream.shutdown().await.expect("Failed to shutdown stream");
}
Loading