Skip to content

Commit 7838505

Browse files
add tls variant
1 parent 1c1b6e1 commit 7838505

File tree

4 files changed

+74
-56
lines changed

4 files changed

+74
-56
lines changed

Cargo.lock

Lines changed: 7 additions & 7 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/runtime/stream.rs

Lines changed: 54 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,10 @@ use std::{
66
time::Duration,
77
};
88

9-
use tokio::{io::AsyncWrite, net::TcpStream};
9+
use tokio::{
10+
io::{AsyncRead, AsyncWrite},
11+
net::TcpStream,
12+
};
1013

1114
use crate::{
1215
error::{Error, ErrorKind, Result},
@@ -33,7 +36,7 @@ pub(crate) enum AsyncStream {
3336
Tcp(TcpStream),
3437

3538
/// A TLS connection over TCP.
36-
Tls(TlsStream),
39+
Tls(TlsStream<TcpStream>),
3740

3841
/// A Unix domain socket connection.
3942
#[cfg(unix)]
@@ -42,6 +45,10 @@ pub(crate) enum AsyncStream {
4245
/// A connection to a SOCKS5 proxy.
4346
#[cfg(feature = "socks5-proxy")]
4447
Socks5(fast_socks5::client::Socks5Stream<TcpStream>),
48+
49+
/// A TLS connection to a SOCKS5 proxy.
50+
#[cfg(feature = "socks5-proxy")]
51+
Socks5Tls(TlsStream<fast_socks5::client::Socks5Stream<TcpStream>>),
4552
}
4653

4754
#[derive(Clone, Debug)]
@@ -54,8 +61,10 @@ pub(crate) struct Proxy {
5461
#[cfg(feature = "socks5-proxy")]
5562
impl Proxy {
5663
pub(crate) fn from_client_options(options: &crate::options::ClientOptions) -> Option<Self> {
64+
static DEFAULT_SOCKS5_PROXY_PORT: u16 = 1080;
65+
5766
let host = options.proxy_host.as_ref()?;
58-
let port = options.proxy_port.unwrap_or(1080);
67+
let port = options.proxy_port.unwrap_or(DEFAULT_SOCKS5_PROXY_PORT);
5968
let authentication = match (&options.proxy_username, &options.proxy_password) {
6069
(Some(username), Some(password)) => Some((username.clone(), password.clone())),
6170
// ClientOptions::validate will return an error if the username and password are not
@@ -70,37 +79,28 @@ impl Proxy {
7079

7180
async fn connect(
7281
&self,
73-
target_address: ServerAddress,
74-
connect_timeout: Duration,
75-
) -> Result<AsyncStream> {
82+
host: String,
83+
port: Option<u16>,
84+
) -> Result<fast_socks5::client::Socks5Stream<TcpStream>> {
85+
use crate::options::DEFAULT_PORT;
7686
use fast_socks5::{
7787
client::{Config, Socks5Stream},
7888
SocksError,
7989
};
8090

81-
let mut config = Config::default();
82-
config.set_connect_timeout(connect_timeout.as_secs());
83-
84-
let ServerAddress::Tcp { host, port } = target_address else {
85-
// this condition is checked in ClientOptions::validate
86-
return Err(Error::internal(format!(
87-
"attempted to connect to proxy server with non-TCP address"
88-
)));
89-
};
90-
let port = port.unwrap_or(27107);
91-
91+
let port = port.unwrap_or(DEFAULT_PORT);
9292
let stream = if let Some((username, password)) = self.authentication.as_ref() {
9393
Socks5Stream::connect_with_password(
9494
&self.address,
9595
host,
9696
port,
9797
username.clone(),
9898
password.clone(),
99-
config,
99+
Config::default(),
100100
)
101101
.await
102102
} else {
103-
Socks5Stream::connect(&self.address, host, port, config).await
103+
Socks5Stream::connect(&self.address, host, port, Config::default()).await
104104
}
105105
.map_err(|error| {
106106
if let SocksError::Io(io_error) = error {
@@ -111,7 +111,7 @@ impl Proxy {
111111
}
112112
}
113113
})?;
114-
Ok(AsyncStream::Socks5(stream))
114+
Ok(stream)
115115
}
116116
}
117117
impl AsyncStream {
@@ -121,14 +121,21 @@ impl AsyncStream {
121121
connect_timeout: Duration,
122122
#[cfg(feature = "socks5-proxy")] proxy: Option<&Proxy>,
123123
) -> Result<Self> {
124-
#[cfg(feature = "socks5-proxy")]
125-
if let Some(proxy) = proxy {
126-
return proxy.connect(address, connect_timeout).await;
127-
}
128-
129-
runtime::timeout(connect_timeout, async {
124+
let connect = async {
130125
match &address {
131-
ServerAddress::Tcp { host, .. } => {
126+
#[allow(unused)] // port is unused when socks5-proxy is not enabled
127+
ServerAddress::Tcp { host, port } => {
128+
#[cfg(feature = "socks5-proxy")]
129+
if let Some(proxy) = proxy {
130+
let inner = proxy.connect(host.clone(), port.clone()).await?;
131+
return match tls_cfg {
132+
Some(cfg) => {
133+
Ok(AsyncStream::Socks5Tls(tls_connect(host, inner, cfg).await?))
134+
}
135+
None => Ok(AsyncStream::Socks5(inner)),
136+
};
137+
}
138+
132139
let resolved: Vec<_> = runtime::resolve_address(&address).await?.collect();
133140
if resolved.is_empty() {
134141
return Err(ErrorKind::DnsResolve {
@@ -149,8 +156,9 @@ impl AsyncStream {
149156
tokio::net::UnixStream::connect(path.as_path()).await?,
150157
)),
151158
}
152-
})
153-
.await?
159+
};
160+
161+
runtime::timeout(connect_timeout, connect).await?
154162
}
155163
}
156164

@@ -247,22 +255,22 @@ fn interleave<T>(left: Vec<T>, right: Vec<T>) -> Vec<T> {
247255
out
248256
}
249257

250-
impl tokio::io::AsyncRead for AsyncStream {
258+
impl AsyncRead for AsyncStream {
251259
fn poll_read(
252260
mut self: Pin<&mut Self>,
253261
cx: &mut Context<'_>,
254262
buf: &mut tokio::io::ReadBuf<'_>,
255263
) -> Poll<std::io::Result<()>> {
256264
match self.deref_mut() {
257265
Self::Null => Poll::Ready(Ok(())),
258-
Self::Tcp(ref mut inner) => tokio::io::AsyncRead::poll_read(Pin::new(inner), cx, buf),
259-
Self::Tls(ref mut inner) => tokio::io::AsyncRead::poll_read(Pin::new(inner), cx, buf),
266+
Self::Tcp(ref mut inner) => AsyncRead::poll_read(Pin::new(inner), cx, buf),
267+
Self::Tls(ref mut inner) => AsyncRead::poll_read(Pin::new(inner), cx, buf),
260268
#[cfg(unix)]
261-
Self::Unix(ref mut inner) => tokio::io::AsyncRead::poll_read(Pin::new(inner), cx, buf),
269+
Self::Unix(ref mut inner) => AsyncRead::poll_read(Pin::new(inner), cx, buf),
262270
#[cfg(feature = "socks5-proxy")]
263-
Self::Socks5(ref mut inner) => {
264-
tokio::io::AsyncRead::poll_read(Pin::new(inner), cx, buf)
265-
}
271+
Self::Socks5(ref mut inner) => AsyncRead::poll_read(Pin::new(inner), cx, buf),
272+
#[cfg(feature = "socks5-proxy")]
273+
Self::Socks5Tls(ref mut inner) => AsyncRead::poll_read(Pin::new(inner), cx, buf),
266274
}
267275
}
268276
}
@@ -281,6 +289,8 @@ impl AsyncWrite for AsyncStream {
281289
Self::Unix(ref mut inner) => AsyncWrite::poll_write(Pin::new(inner), cx, buf),
282290
#[cfg(feature = "socks5-proxy")]
283291
Self::Socks5(ref mut inner) => AsyncWrite::poll_write(Pin::new(inner), cx, buf),
292+
#[cfg(feature = "socks5-proxy")]
293+
Self::Socks5Tls(ref mut inner) => AsyncWrite::poll_write(Pin::new(inner), cx, buf),
284294
}
285295
}
286296

@@ -293,6 +303,8 @@ impl AsyncWrite for AsyncStream {
293303
Self::Unix(ref mut inner) => AsyncWrite::poll_flush(Pin::new(inner), cx),
294304
#[cfg(feature = "socks5-proxy")]
295305
Self::Socks5(ref mut inner) => AsyncWrite::poll_flush(Pin::new(inner), cx),
306+
#[cfg(feature = "socks5-proxy")]
307+
Self::Socks5Tls(ref mut inner) => AsyncWrite::poll_flush(Pin::new(inner), cx),
296308
}
297309
}
298310

@@ -305,6 +317,8 @@ impl AsyncWrite for AsyncStream {
305317
Self::Unix(ref mut inner) => Pin::new(inner).poll_shutdown(cx),
306318
#[cfg(feature = "socks5-proxy")]
307319
Self::Socks5(ref mut inner) => Pin::new(inner).poll_shutdown(cx),
320+
#[cfg(feature = "socks5-proxy")]
321+
Self::Socks5Tls(ref mut inner) => Pin::new(inner).poll_shutdown(cx),
308322
}
309323
}
310324

@@ -321,6 +335,8 @@ impl AsyncWrite for AsyncStream {
321335
Self::Unix(ref mut inner) => Pin::new(inner).poll_write_vectored(cx, bufs),
322336
#[cfg(feature = "socks5-proxy")]
323337
Self::Socks5(ref mut inner) => Pin::new(inner).poll_write_vectored(cx, bufs),
338+
#[cfg(feature = "socks5-proxy")]
339+
Self::Socks5Tls(ref mut inner) => Pin::new(inner).poll_write_vectored(cx, bufs),
324340
}
325341
}
326342

@@ -333,6 +349,8 @@ impl AsyncWrite for AsyncStream {
333349
Self::Unix(ref inner) => inner.is_write_vectored(),
334350
#[cfg(feature = "socks5-proxy")]
335351
Self::Socks5(ref inner) => inner.is_write_vectored(),
352+
#[cfg(feature = "socks5-proxy")]
353+
Self::Socks5Tls(ref inner) => inner.is_write_vectored(),
336354
}
337355
}
338356
}

src/runtime/tls_openssl.rs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,15 @@ use openssl::{
44
error::ErrorStack,
55
ssl::{SslConnector, SslFiletype, SslMethod, SslVerifyMode},
66
};
7-
use tokio::net::TcpStream;
7+
use tokio::io::{AsyncRead, AsyncWrite};
88
use tokio_openssl::SslStream;
99

1010
use crate::{
1111
client::options::TlsOptions,
1212
error::{Error, ErrorKind, Result},
1313
};
1414

15-
pub(super) type TlsStream = SslStream<TcpStream>;
15+
pub(super) type TlsStream<T> = SslStream<T>;
1616

1717
/// Configuration required to use TLS. Creating this is expensive, so its best to cache this value
1818
/// and reuse it for multiple connections.
@@ -40,11 +40,11 @@ impl TlsConfig {
4040
}
4141
}
4242

43-
pub(super) async fn tls_connect(
43+
pub(super) async fn tls_connect<T: AsyncRead + AsyncWrite + Unpin>(
4444
host: &str,
45-
tcp_stream: TcpStream,
45+
tcp_stream: T,
4646
cfg: &TlsConfig,
47-
) -> Result<TlsStream> {
47+
) -> Result<TlsStream<T>> {
4848
let mut stream = make_ssl_stream(host, tcp_stream, cfg).map_err(|err| {
4949
Error::from(ErrorKind::InvalidTlsConfig {
5050
message: err.to_string(),
@@ -120,11 +120,11 @@ fn make_openssl_connector(cfg: TlsOptions) -> Result<SslConnector> {
120120
Ok(builder.build())
121121
}
122122

123-
fn make_ssl_stream(
123+
fn make_ssl_stream<T: AsyncRead + AsyncWrite>(
124124
host: &str,
125-
tcp_stream: TcpStream,
125+
tcp_stream: T,
126126
cfg: &TlsConfig,
127-
) -> std::result::Result<SslStream<TcpStream>, ErrorStack> {
127+
) -> std::result::Result<SslStream<T>, ErrorStack> {
128128
let ssl = cfg
129129
.connector
130130
.configure()?

src/runtime/tls_rustls.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ use rustls::{
1212
Error as TlsError,
1313
RootCertStore,
1414
};
15-
use tokio::net::TcpStream;
15+
use tokio::io::{AsyncRead, AsyncWrite};
1616
use tokio_rustls::TlsConnector;
1717
use webpki_roots::TLS_SERVER_ROOTS;
1818

@@ -21,7 +21,7 @@ use crate::{
2121
error::{ErrorKind, Result},
2222
};
2323

24-
pub(super) type TlsStream = tokio_rustls::client::TlsStream<TcpStream>;
24+
pub(super) type TlsStream<T> = tokio_rustls::client::TlsStream<T>;
2525

2626
/// Configuration required to use TLS. Creating this is expensive, so its best to cache this value
2727
/// and reuse it for multiple connections.
@@ -42,11 +42,11 @@ impl TlsConfig {
4242
}
4343
}
4444

45-
pub(super) async fn tls_connect(
45+
pub(super) async fn tls_connect<T: AsyncRead + AsyncWrite + Unpin>(
4646
host: &str,
47-
tcp_stream: TcpStream,
47+
tcp_stream: T,
4848
cfg: &TlsConfig,
49-
) -> Result<TlsStream> {
49+
) -> Result<TlsStream<T>> {
5050
let name = ServerName::try_from(host)
5151
.map_err(|e| ErrorKind::DnsResolve {
5252
message: format!("could not resolve {host:?}: {e}"),

0 commit comments

Comments
 (0)