Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ rmp-serde = { version = "1.3" }
rustls-pemfile = { version = "2.2" }
rustls-pki-types = { version = "1.12" }
serde = { version = "1.0" }
socket2 = { version = "0.6" }
thiserror = { version = "1.0" }
tokio = { version = "1.39", features = ["net", "rt"] }
tokio-rustls = { version = "0.26" }
Expand Down
4 changes: 4 additions & 0 deletions example-messagepack/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ async fn run_main() -> Result<(), Box<dyn std::error::Error>> {
.unwrap_or_else(|_| "0.0.0.0:9000".to_string())
.parse()?,
DemoRpcSocketService,
4 << 20,
1 << 20,
128,
64 << 10,
)
.await?;
server.set_max_queued_outbound_messages(512);
Expand Down
4 changes: 4 additions & 0 deletions example-proto-tls/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,10 @@ async fn run_main() -> Result<(), Box<dyn std::error::Error>> {
DemoRpcSocketService {
tls_acceptor: Arc::new(server_config).into(),
},
4 << 20,
1 << 20,
128,
64 << 10,
)
.await?;
server.set_max_queued_outbound_messages(512);
Expand Down
4 changes: 4 additions & 0 deletions example-proto/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ async fn run_main() -> Result<(), Box<dyn std::error::Error>> {
.unwrap_or_else(|_| "0.0.0.0:9000".to_string())
.parse()?,
DemoRpcSocketService,
4 << 20,
1 << 20,
128,
64 << 10,
)
.await?;
server.set_max_queued_outbound_messages(512);
Expand Down
13 changes: 9 additions & 4 deletions protosocket-connection/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ pub struct Connection<Bindings: ConnectionBindings> {
receive_buffer_unread_index: usize,
receive_buffer: Vec<u8>,
max_buffer_length: usize,
buffer_allocation_increment: usize,
deserializer: Bindings::Deserializer,
serializer: Bindings::Serializer,
reactor: Bindings::Reactor,
Expand Down Expand Up @@ -125,6 +126,7 @@ where
deserializer: Bindings::Deserializer,
serializer: Bindings::Serializer,
max_buffer_length: usize,
buffer_allocation_increment: usize,
max_queued_send_messages: usize,
outbound_messages: mpsc::Receiver<<Bindings::Serializer as Serializer>::Message>,
reactor: Bindings::Reactor,
Expand All @@ -140,6 +142,7 @@ where
receive_buffer: Vec::new(),
max_buffer_length,
receive_buffer_unread_index: 0,
buffer_allocation_increment,
deserializer,
serializer,
reactor,
Expand All @@ -148,12 +151,14 @@ where

/// ensure buffer state and read from the inbound stream
fn poll_read_inbound(&mut self, context: &mut Context<'_>) -> ReadBufferState {
const BUFFER_INCREMENT: usize = 1 << 20;
if self.receive_buffer.len() < self.max_buffer_length
&& self.receive_buffer.len() - self.receive_buffer_unread_index < BUFFER_INCREMENT
&& self.receive_buffer.len() - self.receive_buffer_unread_index
< self.buffer_allocation_increment
{
self.receive_buffer
.resize(self.receive_buffer.len() + BUFFER_INCREMENT, 0);
self.receive_buffer.resize(
self.receive_buffer.len() + self.buffer_allocation_increment,
0,
);
}

if 0 < self.receive_buffer.len() - self.receive_buffer_unread_index {
Expand Down
3 changes: 3 additions & 0 deletions protosocket-prost/src/prost_client_registry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use crate::{ProstClientConnectionBindings, ProstSerializer};
#[derive(Debug, Clone)]
pub struct ClientRegistry<TConnector = TcpConnector> {
max_buffer_length: usize,
buffer_allocation_increment: usize,
max_queued_outbound_messages: usize,
runtime: tokio::runtime::Handle,
stream_connector: TConnector,
Expand Down Expand Up @@ -44,6 +45,7 @@ where
Self {
max_buffer_length: 4 * (1 << 20),
max_queued_outbound_messages: 256,
buffer_allocation_increment: 1 << 20,
runtime,
stream_connector: connector,
}
Expand Down Expand Up @@ -91,6 +93,7 @@ where
ProstSerializer::default(),
ProstSerializer::default(),
self.max_buffer_length,
self.buffer_allocation_increment,
self.max_queued_outbound_messages,
outbound_messages,
message_reactor,
Expand Down
1 change: 1 addition & 0 deletions protosocket-rpc/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ futures = { workspace = true }
k-lock = { workspace = true }
log = { workspace = true }
rustls-pki-types = { workspace = true }
socket2 = { workspace = true, features = ["all"] }
tokio = { workspace = true }
tokio-rustls = { workspace = true }
tokio-util = { workspace = true }
Expand Down
10 changes: 10 additions & 0 deletions protosocket-rpc/src/client/configuration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ impl tokio_rustls::rustls::client::danger::ServerCertVerifier for DoNothingVerif
#[derive(Debug, Clone)]
pub struct Configuration<TStreamConnector> {
max_buffer_length: usize,
buffer_allocation_increment: usize,
max_queued_outbound_messages: usize,
stream_connector: TStreamConnector,
}
Expand All @@ -182,6 +183,7 @@ where
log::trace!("new client configuration");
Self {
max_buffer_length: 4 * (1 << 20), // 4 MiB
buffer_allocation_increment: 1 << 20,
max_queued_outbound_messages: 256,
stream_connector,
}
Expand All @@ -200,6 +202,13 @@ where
pub fn max_queued_outbound_messages(&mut self, max_queued_outbound_messages: usize) {
self.max_queued_outbound_messages = max_queued_outbound_messages;
}

/// Amount of buffer to allocate at one time when buffer needs extension.
///
/// Default: 1MiB
pub fn buffer_allocation_increment(&mut self, buffer_allocation_increment: usize) {
self.buffer_allocation_increment = buffer_allocation_increment;
}
}

/// Connect a new protosocket rpc client to a server
Expand Down Expand Up @@ -247,6 +256,7 @@ where
Deserializer::default(),
Serializer::default(),
configuration.max_buffer_length,
configuration.buffer_allocation_increment,
configuration.max_queued_outbound_messages,
outbound_messages,
message_reactor,
Expand Down
33 changes: 30 additions & 3 deletions protosocket-rpc/src/server/socket_server.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::ffi::c_int;
use std::future::Future;
use std::io::Error;
use std::pin::Pin;
Expand Down Expand Up @@ -27,6 +28,7 @@ where
socket_server: TSocketService,
listener: tokio::net::TcpListener,
max_buffer_length: usize,
buffer_allocation_increment: usize,
max_queued_outbound_messages: usize,
}

Expand All @@ -38,13 +40,36 @@ where
pub async fn new(
address: std::net::SocketAddr,
socket_server: TSocketService,
max_buffer_length: usize,
buffer_allocation_increment: usize,
max_queued_outbound_messages: usize,
listen_backlog: u32,
) -> crate::Result<Self> {
let listener = tokio::net::TcpListener::bind(address).await?;
let socket = socket2::Socket::new(
match address {
std::net::SocketAddr::V4(_) => socket2::Domain::IPV4,
std::net::SocketAddr::V6(_) => socket2::Domain::IPV6,
},
socket2::Type::STREAM,
None,
)?;

socket.set_nonblocking(true)?;
socket.set_tcp_nodelay(true)?;
socket.set_keepalive(true)?;
socket.set_reuse_port(true)?;
socket.set_reuse_address(true)?;

socket.bind(&address.into())?;
socket.listen(listen_backlog as c_int)?;

let listener = tokio::net::TcpListener::from_std(socket.into())?;
Ok(Self {
socket_server,
listener,
max_buffer_length: 16 * (2 << 20),
max_queued_outbound_messages: 128,
max_buffer_length,
buffer_allocation_increment,
max_queued_outbound_messages,
})
}

Expand Down Expand Up @@ -84,6 +109,7 @@ where
let serializer = self.socket_server.serializer();
let max_buffer_length = self.max_buffer_length;
let max_queued_outbound_messages = self.max_queued_outbound_messages;
let buffer_allocation_increment = self.buffer_allocation_increment;

let stream_future = self.socket_server.accept_stream(stream);

Expand All @@ -97,6 +123,7 @@ where
deserializer,
serializer,
max_buffer_length,
buffer_allocation_increment,
max_queued_outbound_messages,
outbound_messages_receiver,
submitter,
Expand Down
3 changes: 3 additions & 0 deletions protosocket-server/src/connection_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ pub struct ProtosocketServer<Connector: ServerConnector> {
connector: Connector,
listener: tokio::net::TcpListener,
max_buffer_length: usize,
buffer_allocation_increment: usize,
max_queued_outbound_messages: usize,
runtime: tokio::runtime::Handle,
}
Expand All @@ -78,6 +79,7 @@ impl<Connector: ServerConnector> ProtosocketServer<Connector> {
listener,
max_buffer_length: 16 * (2 << 20),
max_queued_outbound_messages: 128,
buffer_allocation_increment: 1 << 20,
runtime,
})
}
Expand Down Expand Up @@ -114,6 +116,7 @@ impl<Connector: ServerConnector> Future for ProtosocketServer<Connector> {
self.connector.deserializer(),
self.connector.serializer(),
self.max_buffer_length,
self.buffer_allocation_increment,
self.max_queued_outbound_messages,
outbound_messages,
reactor,
Expand Down
Loading