Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,24 @@ pub struct SpeedTestCLIOptions {
pub completion: Option<Shell>,
}

impl Default for SpeedTestCLIOptions {
fn default() -> Self {
Self {
nr_tests: 10,
nr_latency_tests: 25,
max_payload_size: PayloadSize::M25,
output_format: OutputFormat::StdOut,
verbose: false,
ipv4: None,
ipv6: None,
disable_dynamic_max_payload_size: false,
download_only: false,
upload_only: false,
completion: None,
}
}
}

impl SpeedTestCLIOptions {
/// Returns whether download tests should be performed
pub fn should_download(&self) -> bool {
Expand Down
118 changes: 100 additions & 18 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,24 +26,106 @@ fn main() {
if options.output_format == OutputFormat::StdOut {
println!("Starting Cloudflare speed test");
}
let client;

let client = match build_http_client(&options) {
Ok(client) => client,
Err(e) => {
eprintln!("Error: {}", e);
std::process::exit(1);
}
};

speed_test(client, options);
}

fn build_http_client(options: &SpeedTestCLIOptions) -> Result<reqwest::blocking::Client, String> {
let mut builder =
reqwest::blocking::Client::builder().timeout(std::time::Duration::from_secs(30));

if let Some(ref ip) = options.ipv4 {
client = reqwest::blocking::Client::builder()
.local_address(ip.parse::<IpAddr>().expect("Invalid IPv4 address"))
.timeout(std::time::Duration::from_secs(30))
.build();
let ip_addr = ip
.parse::<IpAddr>()
.map_err(|e| format!("Invalid IPv4 address '{}': {}", ip, e))?;
builder = builder.local_address(ip_addr);
} else if let Some(ref ip) = options.ipv6 {
client = reqwest::blocking::Client::builder()
.local_address(ip.parse::<IpAddr>().expect("Invalid IPv6 address"))
.timeout(std::time::Duration::from_secs(30))
.build();
} else {
client = reqwest::blocking::Client::builder()
.timeout(std::time::Duration::from_secs(30))
.build();
}
speed_test(
client.expect("Failed to initialize reqwest client"),
options,
);
let ip_addr = ip
.parse::<IpAddr>()
.map_err(|e| format!("Invalid IPv6 address '{}': {}", ip, e))?;
builder = builder.local_address(ip_addr);
}

builder
.build()
.map_err(|e| format!("Failed to initialize HTTP client: {}", e))
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_build_http_client_invalid_ipv4() {
let options = SpeedTestCLIOptions {
ipv4: Some("invalid-ip".to_string()),
ipv6: None,
..Default::default()
};

let result = build_http_client(&options);
assert!(result.is_err());
let err = result.unwrap_err();
assert!(err.contains("Invalid IPv4 address"));
assert!(err.contains("invalid-ip"));
}

#[test]
fn test_build_http_client_invalid_ipv6() {
let options = SpeedTestCLIOptions {
ipv4: None,
ipv6: Some("invalid-ipv6".to_string()),
..Default::default()
};

let result = build_http_client(&options);
assert!(result.is_err());
let err = result.unwrap_err();
assert!(err.contains("Invalid IPv6 address"));
assert!(err.contains("invalid-ipv6"));
}

#[test]
fn test_build_http_client_valid_ipv4() {
let options = SpeedTestCLIOptions {
ipv4: Some("127.0.0.1".to_string()),
ipv6: None,
..Default::default()
};

let result = build_http_client(&options);
assert!(result.is_ok());
}

#[test]
fn test_build_http_client_valid_ipv6() {
let options = SpeedTestCLIOptions {
ipv4: None,
ipv6: Some("::1".to_string()),
..Default::default()
};

let result = build_http_client(&options);
assert!(result.is_ok());
}

#[test]
fn test_build_http_client_no_ip() {
let options = SpeedTestCLIOptions {
ipv4: None,
ipv6: None,
..Default::default()
};

let result = build_http_client(&options);
assert!(result.is_ok());
}
}
192 changes: 160 additions & 32 deletions src/speedtest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -182,27 +182,61 @@ pub fn test_latency(client: &Client) -> f64 {
let req_builder = client.get(url);

let start = Instant::now();
let mut response = req_builder.send().expect("failed to get response");
let mut response = match req_builder.send() {
Ok(res) => res,
Err(e) => {
log::error!("Failed to get response for latency test: {}", e);
return 0.0;
}
};
let _status_code = response.status();
// Drain body to complete the request; ignore errors.
let _ = std::io::copy(&mut response, &mut std::io::sink());
let total_ms = start.elapsed().as_secs_f64() * 1_000.0;

let re = Regex::new(r"cfRequestDuration;dur=([\d.]+)").unwrap();
let server_timing = response
.headers()
.get("Server-Timing")
.expect("No Server-Timing in response header")
.to_str()
.unwrap();
let cf_req_duration: f64 = re
.captures(server_timing)
.unwrap()
.get(1)
.unwrap()
.as_str()
.parse()
.unwrap();
// Try to extract cfRequestDuration from Server-Timing header
let re = match Regex::new(r"cfRequestDuration;dur=([\d.]+)") {
Ok(re) => re,
Err(e) => {
log::error!("Failed to compile regex: {}", e);
return total_ms;
}
};

let server_timing = match response.headers().get("Server-Timing") {
Some(header_value) => match header_value.to_str() {
Ok(s) => s,
Err(e) => {
log::error!("Failed to convert Server-Timing header to string: {}", e);
return total_ms;
}
},
None => {
log::debug!("No Server-Timing header in response");
return total_ms;
}
};

let cf_req_duration: f64 = match re.captures(server_timing) {
Some(captures) => match captures.get(1) {
Some(dur_match) => match dur_match.as_str().parse::<f64>() {
Ok(parsed) => parsed,
Err(e) => {
log::error!("Failed to parse cfRequestDuration: {}", e);
return total_ms;
}
},
None => {
log::debug!("No cfRequestDuration found in Server-Timing header");
return total_ms;
}
},
None => {
log::debug!("Server-Timing header doesn't match expected format");
return total_ms;
}
};

let mut req_latency = total_ms - cf_req_duration;
log::debug!(
"latency debug: total_ms={total_ms:.3} cf_req_duration_ms={cf_req_duration:.3} req_latency_total={req_latency:.3} server_timing={server_timing}"
Expand Down Expand Up @@ -271,14 +305,19 @@ pub fn test_upload(client: &Client, payload_size_bytes: usize, output_format: Ou
let url = &format!("{BASE_URL}/{UPLOAD_URL}");
let payload: Vec<u8> = vec![1; payload_size_bytes];
let req_builder = client.post(url).body(payload);
let (mut response, status_code, mbits, duration) = {
let start = Instant::now();
let response = req_builder.send().expect("failed to get response");
let status_code = response.status();
let duration = start.elapsed();
let mbits = (payload_size_bytes as f64 * 8.0 / 1_000_000.0) / duration.as_secs_f64();
(response, status_code, mbits, duration)

let start = Instant::now();
let mut response = match req_builder.send() {
Ok(res) => res,
Err(e) => {
log::error!("Failed to send upload request: {}", e);
return 0.0;
}
};
let status_code = response.status();
let duration = start.elapsed();
let mbits = (payload_size_bytes as f64 * 8.0 / 1_000_000.0) / duration.as_secs_f64();

// Drain response after timing so we don't skew upload measurement.
let _ = std::io::copy(&mut response, &mut std::io::sink());
if output_format == OutputFormat::StdOut {
Expand All @@ -294,16 +333,21 @@ pub fn test_download(
) -> f64 {
let url = &format!("{BASE_URL}/{DOWNLOAD_URL}{payload_size_bytes}");
let req_builder = client.get(url);
let (status_code, mbits, duration) = {
let start = Instant::now();
let mut response = req_builder.send().expect("failed to get response");
let status_code = response.status();
// Stream the body to avoid buffering the full payload in memory.
let _ = std::io::copy(&mut response, &mut std::io::sink());
let duration = start.elapsed();
let mbits = (payload_size_bytes as f64 * 8.0 / 1_000_000.0) / duration.as_secs_f64();
(status_code, mbits, duration)

let start = Instant::now();
let mut response = match req_builder.send() {
Ok(res) => res,
Err(e) => {
log::error!("Failed to send download request: {}", e);
return 0.0;
}
};
let status_code = response.status();
// Stream the body to avoid buffering the full payload in memory.
let _ = std::io::copy(&mut response, &mut std::io::sink());
let duration = start.elapsed();
let mbits = (payload_size_bytes as f64 * 8.0 / 1_000_000.0) / duration.as_secs_f64();

if output_format == OutputFormat::StdOut {
print_current_speed(mbits, duration, status_code, payload_size_bytes);
}
Expand Down Expand Up @@ -506,4 +550,88 @@ mod tests {
let result = fetch_metadata(&client);
assert!(result.is_err());
}

#[test]
fn test_test_latency_with_mock_client() {
// This test verifies that test_latency handles errors gracefully
// We can't easily mock a failing client in this test setup,
// but we can verify the function doesn't panic
use std::time::Duration;

let client = reqwest::blocking::Client::builder()
.timeout(Duration::from_millis(1))
.build()
.unwrap();

// This should either return a value or timeout gracefully
let result = test_latency(&client);
// The function should return some value (could be 0.0 if it fails)
assert!(result >= 0.0);
}

#[test]
fn test_test_upload_with_mock_client() {
// Test that test_upload handles errors gracefully
use std::time::Duration;

let client = reqwest::blocking::Client::builder()
.timeout(Duration::from_millis(1))
.build()
.unwrap();

// This should either return a value or handle timeout gracefully
let result = test_upload(&client, 1000, OutputFormat::None);
// The function should return some value (could be 0.0 if it fails)
assert!(result >= 0.0);
}

#[test]
fn test_test_download_with_mock_client() {
// Test that test_download handles errors gracefully
use std::time::Duration;

let client = reqwest::blocking::Client::builder()
.timeout(Duration::from_millis(1))
.build()
.unwrap();

// This should either return a value or handle timeout gracefully
let result = test_download(&client, 1000, OutputFormat::None);
// The function should return some value (could be 0.0 if it fails)
assert!(result >= 0.0);
}

#[test]
fn test_server_timing_header_parsing() {
// Test the Server-Timing header parsing logic
// We'll test the regex and parsing separately since we can't easily mock responses
use regex::Regex;

let re = Regex::new(r"cfRequestDuration;dur=([\d.]+)").unwrap();

// Test valid Server-Timing header
let valid_header = "cfRequestDuration;dur=12.34";
let captures = re.captures(valid_header).unwrap();
let dur_match = captures.get(1).unwrap();
let parsed = dur_match.as_str().parse::<f64>().unwrap();
assert_eq!(parsed, 12.34);

// Test header with multiple values
let multi_header = "cfRequestDuration;dur=56.78, other;dur=99.99";
let captures = re.captures(multi_header).unwrap();
let dur_match = captures.get(1).unwrap();
let parsed = dur_match.as_str().parse::<f64>().unwrap();
assert_eq!(parsed, 56.78);

// Test header without cfRequestDuration
let no_cf_header = "other;dur=99.99";
let captures = re.captures(no_cf_header);
assert!(captures.is_none());

// Test malformed duration - use a value that can't be parsed
let malformed_header = "cfRequestDuration;dur=not-a-number";
let captures = re.captures(malformed_header);
// This should not match the regex at all since it contains no digits
assert!(captures.is_none());
}
}
Loading