From 76daad31c5adab5d3b85db9e043ad51868bc6b8f Mon Sep 17 00:00:00 2001 From: Artem Konovalov Date: Fri, 21 Nov 2025 14:42:48 -0300 Subject: [PATCH 1/5] compression --- .github/workflows/sqlx.yml | 24 +- Cargo.lock | 68 +++++ Cargo.toml | 5 +- README.md | 2 + sqlx-mysql/Cargo.toml | 5 + sqlx-mysql/src/connection/compression.rs | 245 ++++++++++++++++++ sqlx-mysql/src/connection/establish.rs | 8 +- sqlx-mysql/src/connection/mod.rs | 1 + sqlx-mysql/src/connection/stream.rs | 51 +++- sqlx-mysql/src/connection/tls.rs | 7 +- sqlx-mysql/src/lib.rs | 2 +- sqlx-mysql/src/options/mod.rs | 124 +++++++++ sqlx-mysql/src/options/parse.rs | 51 +++- sqlx-mysql/src/protocol/compressed_packet.rs | 108 ++++++++ .../protocol/connect/handshake_response.rs | 15 +- sqlx-mysql/src/protocol/mod.rs | 4 + tests/mysql/mysql.rs | 145 ++++++++++- tests/x.py | 8 +- 18 files changed, 827 insertions(+), 46 deletions(-) create mode 100644 sqlx-mysql/src/connection/compression.rs create mode 100644 sqlx-mysql/src/protocol/compressed_packet.rs diff --git a/.github/workflows/sqlx.yml b/.github/workflows/sqlx.yml index b2f81b75ad..58d449a128 100644 --- a/.github/workflows/sqlx.yml +++ b/.github/workflows/sqlx.yml @@ -343,7 +343,7 @@ jobs: - uses: Swatinem/rust-cache@v2 - - run: cargo build --features mysql,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} + - run: cargo build --features mysql,mysql-compression,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} - run: docker compose -f tests/docker-compose.yml run -d -p 3306:3306 --name mysql_${{ matrix.mysql }} mysql_${{ matrix.mysql }} - run: sleep 60 @@ -354,7 +354,7 @@ jobs: - run: > cargo test --no-default-features - --features any,mysql,macros,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} + --features any,mysql,macros,mysql-compression,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} env: DATABASE_URL: mysql://root:password@localhost:3306/sqlx?ssl-mode=disabled SQLX_OFFLINE_DIR: .sqlx @@ -365,7 +365,7 @@ jobs: cargo test --test mysql-test-attr --no-default-features - --features any,mysql,macros,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} + --features any,mysql,macros,mysql-compression,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} env: DATABASE_URL: mysql://root:password@localhost:3306/sqlx?ssl-mode=disabled SQLX_OFFLINE_DIR: .sqlx @@ -376,7 +376,7 @@ jobs: run: > cargo test --no-default-features - --features any,mysql,macros,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} + --features any,mysql,macros,mysql-compression,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} env: DATABASE_URL: mysql://root:password@localhost:3306/sqlx SQLX_OFFLINE_DIR: .sqlx @@ -390,7 +390,7 @@ jobs: cargo build --no-default-features --test mysql-macros - --features any,mysql,macros,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} + --features any,mysql,macros,mysql-compression,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} env: SQLX_OFFLINE: true SQLX_OFFLINE_DIR: .sqlx @@ -402,7 +402,7 @@ jobs: cargo test --no-default-features --test mysql-macros - --features any,mysql,macros,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} + --features any,mysql,macros,mysql-compression,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} env: DATABASE_URL: mysql://root:password@localhost:3306/sqlx SQLX_OFFLINE: true @@ -421,7 +421,7 @@ jobs: run: > cargo test --no-default-features - --features any,mysql,macros,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} + --features any,mysql,macros,mysql-ompression,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} env: DATABASE_URL: mysql://root@localhost:3306/sqlx?sslmode=verify_ca&ssl-ca=.%2Ftests%2Fcerts%2Fca.crt&ssl-key=.%2Ftests%2Fcerts%2Fkeys%2Fclient.key&ssl-cert=.%2Ftests%2Fcerts%2Fclient.crt RUSTFLAGS: --cfg mysql_${{ matrix.mysql }} @@ -444,7 +444,7 @@ jobs: - uses: Swatinem/rust-cache@v2 - - run: cargo build --features mysql,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} + - run: cargo build --features mysql,mysql-compression,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} - run: docker compose -f tests/docker-compose.yml run -d -p 3306:3306 --name mariadb_${{ matrix.mariadb }} mariadb_${{ matrix.mariadb }} - run: sleep 30 @@ -455,7 +455,7 @@ jobs: - run: > cargo test --no-default-features - --features any,mysql,macros,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} + --features any,mysql,macros,mysql-compression,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} env: DATABASE_URL: mysql://root:password@localhost:3306/sqlx SQLX_OFFLINE_DIR: .sqlx @@ -466,7 +466,7 @@ jobs: cargo test --test mysql-test-attr --no-default-features - --features any,mysql,macros,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} + --features any,mysql,macros,mysql-compression,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} env: DATABASE_URL: mysql://root:password@localhost:3306/sqlx SQLX_OFFLINE_DIR: .sqlx @@ -491,7 +491,7 @@ jobs: cargo test --no-default-features --test mysql-macros - --features any,mysql,macros,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} + --features any,mysql,macros,mysql-compression,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} env: DATABASE_URL: mysql://root:password@localhost:3306/sqlx SQLX_OFFLINE: true @@ -510,7 +510,7 @@ jobs: run: > cargo test --no-default-features - --features any,mysql,macros,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} + --features any,mysql,macros,mysql-compression,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} env: DATABASE_URL: mysql://root@localhost:3306/sqlx?sslmode=verify_ca&ssl-ca=.%2Ftests%2Fcerts%2Fca.crt&ssl-key=.%2Ftests%2Fcerts%2Fkeys%2Fclient.key&ssl-cert=.%2Ftests%2Fcerts%2Fclient.crt RUSTFLAGS: --cfg mariadb="${{ matrix.mariadb }}" diff --git a/Cargo.lock b/Cargo.lock index 78e40f0c12..cc33848e1a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1020,6 +1020,15 @@ version = "2.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "19d374276b40fb8bbdee95aef7c7fa6b5316ec764510eb64b8dd0e2ed0d7e7f5" +[[package]] +name = "crc32fast" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9481c1c90cbf2ac953f07c8d4a58aa3945c425b7185c9154d67a65e4230da511" +dependencies = [ + "cfg-if", +] + [[package]] name = "criterion" version = "0.5.1" @@ -1373,6 +1382,17 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "flate2" +version = "1.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfe33edd8e85a12a67454e37f8c75e730830d83e313556ab9ebf9ee7fbeb3bfb" +dependencies = [ + "crc32fast", + "libz-sys", + "miniz_oxide", +] + [[package]] name = "float-cmp" version = "0.9.0" @@ -2160,6 +2180,17 @@ dependencies = [ "vcpkg", ] +[[package]] +name = "libz-sys" +version = "1.1.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "15d118bbf3771060e7311cc7bb0545b01d08a8b4a7de949198dec1fa0ca1c0f7" +dependencies = [ + "cc", + "pkg-config", + "vcpkg", +] + [[package]] name = "linux-raw-sys" version = "0.4.15" @@ -2266,6 +2297,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1fa76a2c86f704bdb222d66965fb3d63269ce38518b83cb0575fca855ebb6316" dependencies = [ "adler2", + "simd-adler32", ] [[package]] @@ -3431,6 +3463,12 @@ dependencies = [ "rand_core", ] +[[package]] +name = "simd-adler32" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d66dc143e6b11c1eddc06d5c423cfc97062865baf299914ab64caa38182078fe" + [[package]] name = "simdutf8" version = "0.1.5" @@ -3903,6 +3941,7 @@ dependencies = [ "digest", "dotenvy", "either", + "flate2", "futures-channel", "futures-core", "futures-io", @@ -3931,6 +3970,7 @@ dependencies = [ "tracing", "uuid", "whoami", + "zstd", ] [[package]] @@ -5290,3 +5330,31 @@ dependencies = [ "quote", "syn 2.0.104", ] + +[[package]] +name = "zstd" +version = "0.13.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e91ee311a569c327171651566e07972200e76fcfe2242a4fa446149a3881c08a" +dependencies = [ + "zstd-safe", +] + +[[package]] +name = "zstd-safe" +version = "7.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f49c4d5f0abb602a93fb8736af2a4f4dd9512e36f7f570d66e65ff867ed3b9d" +dependencies = [ + "zstd-sys", +] + +[[package]] +name = "zstd-sys" +version = "2.0.16+zstd.1.5.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91e19ebc2adc8f83e43039e79776e3fda8ca919132d68a1fed6a5faca2683748" +dependencies = [ + "cc", + "pkg-config", +] diff --git a/Cargo.toml b/Cargo.toml index 00d5d656c1..6d5ec3cc4c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -161,6 +161,9 @@ uuid = ["sqlx-core/uuid", "sqlx-macros?/uuid", "sqlx-mysql?/uuid", "sqlx-postgre regexp = ["sqlx-sqlite?/regexp"] bstr = ["sqlx-core/bstr"] +# compression +mysql-compression = ["sqlx-mysql/compression"] + [workspace.dependencies] # Core Crates sqlx-core = { version = "=0.9.0-alpha.1", path = "sqlx-core" } @@ -359,7 +362,7 @@ required-features = ["sqlite"] [[test]] name = "mysql" path = "tests/mysql/mysql.rs" -required-features = ["mysql"] +required-features = ["mysql", "compression"] [[test]] name = "mysql-types" diff --git a/README.md b/README.md index f1e53cdced..2700b9aef9 100644 --- a/README.md +++ b/README.md @@ -177,6 +177,8 @@ be removed in the future. - `mysql`: Add support for the MySQL/MariaDB database server. +- `mysql-compression`: Add compression support for MySQL/MariaDB database server. + - `mssql`: Add support for the MSSQL database server. - `sqlite`: Add support for the self-contained [SQLite](https://sqlite.org/) database engine with SQLite bundled and statically-linked. diff --git a/sqlx-mysql/Cargo.toml b/sqlx-mysql/Cargo.toml index ee9512b61e..d9eb8eea64 100644 --- a/sqlx-mysql/Cargo.toml +++ b/sqlx-mysql/Cargo.toml @@ -14,6 +14,7 @@ json = ["sqlx-core/json", "serde"] any = ["sqlx-core/any"] offline = ["sqlx-core/offline", "serde/derive"] migrate = ["sqlx-core/migrate"] +compression = ["zstd", "flate2"] # Type Integration features bigdecimal = ["dep:bigdecimal", "sqlx-core/bigdecimal"] @@ -67,6 +68,10 @@ stringprep = "0.1.2" tracing = { version = "0.1.37", features = ["log"] } whoami = { version = "1.2.1", default-features = false } +# Compression +zstd = { version = "0.13.3", optional = true, default-features = false, features = ["zdict_builder"] } +flate2 = { version = "1.1.5", optional = true, default-features = false, features = ["rust_backend", "zlib"] } + dotenvy.workspace = true thiserror.workspace = true diff --git a/sqlx-mysql/src/connection/compression.rs b/sqlx-mysql/src/connection/compression.rs new file mode 100644 index 0000000000..2fdac04874 --- /dev/null +++ b/sqlx-mysql/src/connection/compression.rs @@ -0,0 +1,245 @@ +use crate::protocol::Capabilities; +#[cfg(feature = "compression")] +use crate::Compression; +use crate::CompressionConfig; +#[cfg(feature = "compression")] +use compressed_stream::CompressedStream; +use sqlx_core::io::{ProtocolDecode, ProtocolEncode}; +use sqlx_core::net::{BufferedSocket, Socket}; +use sqlx_core::Error; + +pub(crate) struct CompressionMySqlStream> { + stream: CompressionStream, + pub(crate) socket: BufferedSocket, +} + +impl CompressionMySqlStream { + pub(crate) fn not_compressed(socket: BufferedSocket) -> Self { + let stream = CompressionStream::NotCompressed(NoCompressionStream {}); + Self { stream, socket } + } + + #[cfg(feature = "compression")] + fn compressed(socket: BufferedSocket, compression: CompressionConfig) -> Self { + let stream = CompressionStream::Compressed(CompressedStream::new(compression)); + Self { stream, socket } + } + + pub(crate) fn create( + socket: BufferedSocket, + #[cfg_attr(not(feature = "compression"), allow(unused_variables))] + capabilities: &Capabilities, + compression: Option, + ) -> Self { + match compression { + #[cfg(feature = "compression")] + Some(c) if c.is_supported(&capabilities) => { + CompressionMySqlStream::compressed(socket, c) + } + _ => CompressionMySqlStream::not_compressed(socket), + } + } + + pub(crate) fn boxed(self) -> CompressionMySqlStream> { + CompressionMySqlStream { + socket: self.socket.boxed(), + stream: self.stream, + } + } + + pub(crate) async fn read_with<'de, T, C>( + &mut self, + byte_len: usize, + context: C, + ) -> Result + where + T: ProtocolDecode<'de, C>, + { + match self.stream { + CompressionStream::NotCompressed(ref mut s) => { + s.read_with(byte_len, context, &mut self.socket).await + } + #[cfg(feature = "compression")] + CompressionStream::Compressed(ref mut s) => { + s.read_with(byte_len, context, &mut self.socket).await + } + } + } + + pub(crate) fn write_with<'en, 'stream, T>( + &mut self, + value: T, + context: (Capabilities, &'stream mut u8), + ) -> Result<(), Error> + where + T: ProtocolEncode<'en, (Capabilities, &'stream mut u8)>, + { + match self.stream { + CompressionStream::NotCompressed(ref mut s) => { + s.write_with(value, context, &mut self.socket) + } + #[cfg(feature = "compression")] + CompressionStream::Compressed(ref mut s) => { + s.write_with(value, context, &mut self.socket) + } + } + } +} + +enum CompressionStream { + NotCompressed(NoCompressionStream), + #[cfg(feature = "compression")] + Compressed(CompressedStream), +} + +struct NoCompressionStream {} +impl NoCompressionStream { + async fn read_with<'de, T, C, S: Socket>( + &mut self, + byte_len: usize, + context: C, + buffered_socket: &mut BufferedSocket, + ) -> Result + where + T: ProtocolDecode<'de, C>, + { + buffered_socket.read_with(byte_len, context).await + } + + fn write_with<'en, 'stream, T, C, S: Socket>( + &mut self, + packet: T, + context: C, + buffered_socket: &mut BufferedSocket, + ) -> Result<(), Error> + where + T: ProtocolEncode<'en, C>, + { + buffered_socket.write_with(packet, context) + } +} + +#[cfg(feature = "compression")] +mod compressed_stream { + use crate::protocol::{CompressedPacket, CompressedPacketContext}; + use crate::CompressionConfig; + use bytes::{Buf, BufMut, Bytes, BytesMut}; + use sqlx_core::io::{ProtocolDecode, ProtocolEncode}; + use sqlx_core::net::{BufferedSocket, Socket}; + use sqlx_core::Error; + use std::cmp::min; + + pub(crate) struct CompressedStream { + compression: CompressionConfig, + sequence_id: u8, + last_read_packet: Option, + } + + impl CompressedStream { + pub(crate) fn new(compression: CompressionConfig) -> Self { + Self { + sequence_id: 0, + last_read_packet: None, + compression, + } + } + + async fn receive_packet( + &mut self, + buffered_socket: &mut BufferedSocket, + ) -> Result { + let mut header: Bytes = buffered_socket.read(7).await?; + #[allow(clippy::cast_possible_truncation)] + let compressed_payload_length = header.get_uint_le(3) as usize; + let sequence_id = header.get_u8(); + let uncompressed_payload_length = header.get_uint_le(3); + + self.sequence_id = sequence_id.wrapping_add(1); + + let packet = if uncompressed_payload_length > 0 { + let compressed_context = CompressedPacketContext { + nested_context: (), + sequence_id: &mut self.sequence_id, + compression: self.compression, + }; + let compressed_payload: CompressedPacket = buffered_socket + .read_with(compressed_payload_length, compressed_context) + .await?; + + compressed_payload.0 + } else { + let uncompressed_payload: Bytes = buffered_socket + .read_with(compressed_payload_length, ()) + .await?; + + uncompressed_payload + }; + + Ok(packet) + } + + pub(crate) async fn read_with<'de, T, C, S: Socket>( + &mut self, + byte_len: usize, + context: C, + buffered_socket: &mut BufferedSocket, + ) -> Result + where + T: ProtocolDecode<'de, C>, + { + let mut result_buffer = BytesMut::with_capacity(byte_len); + while result_buffer.len() != byte_len { + let current_packet = match self.last_read_packet.as_mut() { + None => { + let received_packet = self.receive_packet(buffered_socket).await?; + self.last_read_packet = Some(received_packet); + self.last_read_packet.as_mut().unwrap() + } + Some(p) => p, + }; + + let remaining_bytes_count = byte_len.saturating_sub(result_buffer.len()); + let available_bytes_count = min(current_packet.len(), remaining_bytes_count); + let chunk = current_packet.split_to(available_bytes_count); + result_buffer.put_slice(chunk.chunk()); + + if current_packet.is_empty() { + self.last_read_packet = None + } + } + + T::decode_with(result_buffer.freeze(), context) + } + + pub(crate) fn write_with<'en, T, C, S: Socket>( + &mut self, + packet: T, + context: C, + buffered_socket: &mut BufferedSocket, + ) -> Result<(), Error> + where + T: ProtocolEncode<'en, C>, + { + self.sequence_id = 0; + let compressed_packet = CompressedPacket(packet); + buffered_socket.write_with( + compressed_packet, + CompressedPacketContext { + nested_context: context, + sequence_id: &mut self.sequence_id, + compression: self.compression, + }, + ) + } + } +} + +#[cfg(feature = "compression")] +impl CompressionConfig { + fn is_supported(&self, capabilities: &Capabilities) -> bool { + match self.0 { + Compression::Zlib => capabilities.contains(Capabilities::COMPRESS), + Compression::Zstd => capabilities.contains(Capabilities::ZSTD_COMPRESSION_ALGORITHM), + } + } +} diff --git a/sqlx-mysql/src/connection/establish.rs b/sqlx-mysql/src/connection/establish.rs index f61654d876..1ca62c4571 100644 --- a/sqlx-mysql/src/connection/establish.rs +++ b/sqlx-mysql/src/connection/establish.rs @@ -1,6 +1,3 @@ -use bytes::buf::Buf; -use bytes::Bytes; - use crate::common::StatementCache; use crate::connection::{tls, MySqlConnectionInner, MySqlStream, MAX_PACKET_SIZE}; use crate::error::Error; @@ -10,6 +7,8 @@ use crate::protocol::connect::{ }; use crate::protocol::Capabilities; use crate::{MySqlConnectOptions, MySqlConnection, MySqlSslMode}; +use bytes::buf::Buf; +use bytes::Bytes; impl MySqlConnection { pub(crate) async fn establish(options: &MySqlConnectOptions) -> Result { @@ -112,6 +111,7 @@ impl<'a> DoHandshake<'a> { database: options.database.as_deref(), auth_plugin: plugin, auth_response: auth_response.as_deref(), + compression: options.compression, })?; stream.flush().await?; @@ -121,7 +121,7 @@ impl<'a> DoHandshake<'a> { match packet[0] { 0x00 => { let _ok = packet.ok()?; - + stream = stream.maybe_enable_compression(options); break; } diff --git a/sqlx-mysql/src/connection/mod.rs b/sqlx-mysql/src/connection/mod.rs index 569ad32722..8d4a69db34 100644 --- a/sqlx-mysql/src/connection/mod.rs +++ b/sqlx-mysql/src/connection/mod.rs @@ -16,6 +16,7 @@ use crate::transaction::Transaction; use crate::{MySql, MySqlConnectOptions}; mod auth; +mod compression; mod establish; mod executor; mod stream; diff --git a/sqlx-mysql/src/connection/stream.rs b/sqlx-mysql/src/connection/stream.rs index ff931b2f46..7f72a85cd7 100644 --- a/sqlx-mysql/src/connection/stream.rs +++ b/sqlx-mysql/src/connection/stream.rs @@ -1,19 +1,21 @@ use std::collections::VecDeque; use std::ops::{Deref, DerefMut}; -use bytes::{Buf, Bytes, BytesMut}; - +use crate::connection::compression::CompressionMySqlStream; use crate::error::Error; use crate::io::MySqlBufExt; use crate::io::{ProtocolDecode, ProtocolEncode}; use crate::net::{BufferedSocket, Socket}; +#[cfg(feature = "compression")] +use crate::options::Compression; use crate::protocol::response::{EofPacket, ErrPacket, OkPacket, Status}; use crate::protocol::{Capabilities, Packet}; use crate::{MySqlConnectOptions, MySqlDatabaseError}; +use bytes::{Buf, Bytes, BytesMut}; pub struct MySqlStream> { // Wrapping the socket in `Box` allows us to unsize in-place. - pub(crate) socket: BufferedSocket, + pub(crate) compression_stream: CompressionMySqlStream, pub(crate) server_version: (u16, u16, u16), pub(super) capabilities: Capabilities, pub(crate) sequence_id: u8, @@ -49,19 +51,27 @@ impl MySqlStream { capabilities |= Capabilities::CONNECT_WITH_DB; } + #[cfg(feature = "compression")] + if let Some(compression) = options.compression { + match compression.0 { + Compression::Zlib => capabilities |= Capabilities::COMPRESS, + Compression::Zstd => capabilities |= Capabilities::ZSTD_COMPRESSION_ALGORITHM, + } + } + Self { waiting: VecDeque::new(), capabilities, server_version: (0, 0, 0), sequence_id: 0, - socket: BufferedSocket::new(socket), + compression_stream: CompressionMySqlStream::not_compressed(BufferedSocket::new(socket)), is_tls: false, } } pub(crate) async fn wait_until_ready(&mut self) -> Result<(), Error> { - if !self.socket.write_buffer().is_empty() { - self.socket.flush().await?; + if !self.write_buffer().is_empty() { + self.flush().await?; } while !self.waiting.is_empty() { @@ -112,7 +122,7 @@ impl MySqlStream { where T: ProtocolEncode<'en, Capabilities>, { - self.socket + self.compression_stream .write_with(Packet(payload), (self.capabilities, &mut self.sequence_id)) } @@ -120,7 +130,7 @@ impl MySqlStream { // https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_basic_packets.html // https://mariadb.com/kb/en/library/0-packet/#standard-packet - let mut header: Bytes = self.socket.read(4).await?; + let mut header: Bytes = self.compression_stream.read_with(4, ()).await?; // cannot overflow #[allow(clippy::cast_possible_truncation)] @@ -129,9 +139,7 @@ impl MySqlStream { self.sequence_id = sequence_id.wrapping_add(1); - let payload: Bytes = self.socket.read(packet_size).await?; - - // TODO: packet compression + let payload: Bytes = self.compression_stream.read_with(packet_size, ()).await?; Ok(payload) } @@ -207,7 +215,22 @@ impl MySqlStream { pub fn boxed_socket(self) -> MySqlStream { MySqlStream { - socket: self.socket.boxed(), + compression_stream: self.compression_stream.boxed(), + server_version: self.server_version, + capabilities: self.capabilities, + sequence_id: self.sequence_id, + waiting: self.waiting, + is_tls: self.is_tls, + } + } + + pub fn maybe_enable_compression(self, options: &MySqlConnectOptions) -> Self { + MySqlStream { + compression_stream: CompressionMySqlStream::create( + self.compression_stream.socket, + &self.capabilities, + options.compression, + ), server_version: self.server_version, capabilities: self.capabilities, sequence_id: self.sequence_id, @@ -221,12 +244,12 @@ impl Deref for MySqlStream { type Target = BufferedSocket; fn deref(&self) -> &Self::Target { - &self.socket + &self.compression_stream.socket } } impl DerefMut for MySqlStream { fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.socket + &mut self.compression_stream.socket } } diff --git a/sqlx-mysql/src/connection/tls.rs b/sqlx-mysql/src/connection/tls.rs index 9034fbd63a..b363b19c32 100644 --- a/sqlx-mysql/src/connection/tls.rs +++ b/sqlx-mysql/src/connection/tls.rs @@ -1,3 +1,4 @@ +use crate::connection::compression::CompressionMySqlStream; use crate::connection::{MySqlStream, Waiting}; use crate::error::Error; use crate::net::tls::TlsConfig; @@ -74,7 +75,7 @@ pub(super) async fn maybe_upgrade( stream.flush().await?; tls::handshake( - stream.socket.into_inner(), + stream.compression_stream.socket.into_inner(), tls_config, MapStream { server_version: stream.server_version, @@ -91,7 +92,9 @@ impl WithSocket for MapStream { async fn with_socket(self, socket: S) -> Self::Output { MySqlStream { - socket: BufferedSocket::new(Box::new(socket)), + compression_stream: CompressionMySqlStream::not_compressed(BufferedSocket::new( + Box::new(socket), + )), server_version: self.server_version, capabilities: self.capabilities, sequence_id: self.sequence_id, diff --git a/sqlx-mysql/src/lib.rs b/sqlx-mysql/src/lib.rs index 7aa14256f3..da4b7ae715 100644 --- a/sqlx-mysql/src/lib.rs +++ b/sqlx-mysql/src/lib.rs @@ -42,7 +42,7 @@ pub use column::MySqlColumn; pub use connection::MySqlConnection; pub use database::MySql; pub use error::MySqlDatabaseError; -pub use options::{MySqlConnectOptions, MySqlSslMode}; +pub use options::{Compression, CompressionConfig, MySqlConnectOptions, MySqlSslMode}; pub use query_result::MySqlQueryResult; pub use row::MySqlRow; pub use statement::MySqlStatement; diff --git a/sqlx-mysql/src/options/mod.rs b/sqlx-mysql/src/options/mod.rs index 421bfb700e..6f1cd61ef1 100644 --- a/sqlx-mysql/src/options/mod.rs +++ b/sqlx-mysql/src/options/mod.rs @@ -1,3 +1,5 @@ +#[cfg(feature = "compression")] +use sqlx_core::Error; use std::path::{Path, PathBuf}; mod connect; @@ -80,6 +82,93 @@ pub struct MySqlConnectOptions { pub(crate) no_engine_substitution: bool, pub(crate) timezone: Option, pub(crate) set_names: bool, + pub(crate) compression: Option, +} + +#[derive(Debug, Clone, Copy, Eq, PartialEq)] +pub struct CompressionConfig( + pub(crate) Compression, + #[cfg_attr(not(feature = "compression"), allow(dead_code))] pub(crate) u8, +); + +#[derive(Debug, Clone, Copy, Eq, PartialEq)] +pub enum Compression { + #[cfg(feature = "compression")] + Zlib, + #[cfg(feature = "compression")] + Zstd, +} + +#[cfg(feature = "compression")] +impl Compression { + /// Selects a default compression level optimized for both encoding speed and output size. + pub fn default(self) -> CompressionConfig { + match self { + Compression::Zlib => CompressionConfig(self, 5), + Compression::Zstd => CompressionConfig(self, 11), + } + } + + /// Optimize for the best speed of encoding. + pub fn fast(self) -> CompressionConfig { + CompressionConfig(self, 1) + } + + /// Optimize for the size of data being encoded. + pub fn best(self) -> CompressionConfig { + match self { + Compression::Zlib => CompressionConfig(self, 9), + Compression::Zstd => CompressionConfig(self, 22), + } + } + + /// Sets the compression level for the current algorithm. + /// + /// Each compression method supports its own valid range of levels: + /// + /// - **Zstd:** `1` to `22` + /// - **Zlib:** `1` to `9` + /// + /// If the provided level is valid for the selected algorithm, a new + /// [`CompressionConfig`] is returned. + /// If the level is out of range, an [`Error::Configuration`] is returned. + /// + /// # Returns + /// + /// - `Ok(CompressionConfig)` if the level is valid + /// - `Err(Error)` if the level is invalid + /// + /// # Examples + /// + /// ```rust + /// # use sqlx_mysql::Compression; + /// + /// let ok = Compression::Zstd.level(5); + /// assert!(ok.is_ok()); + /// + /// let bad = Compression::Zlib.level(42); + /// assert!(bad.is_err()); + /// ``` + pub fn level(self, value: u8) -> Result { + let range = match self { + Compression::Zstd => 1..=22, + Compression::Zlib => 1..=9, + }; + + range + .contains(&value) + .then_some(CompressionConfig(self, value)) + .ok_or_else(|| { + Error::Configuration( + format!( + "Illegal compression level for {self:?}: expected {}..={}, got {value}", + range.start(), + range.end() + ) + .into(), + ) + }) + } } impl Default for MySqlConnectOptions { @@ -111,6 +200,7 @@ impl MySqlConnectOptions { no_engine_substitution: true, timezone: Some(String::from("+00:00")), set_names: true, + compression: None, } } @@ -414,6 +504,24 @@ impl MySqlConnectOptions { self.set_names = flag_val; self } + + /// Sets the compression mode for the connection. + /// + /// Data is uncompressed by default. + /// Ensure that the server supports the selected compression algorithm; + /// if it does not, the client will fall back to uncompressed mode. + /// + /// # Example + /// + /// ```rust + /// # use sqlx_mysql::{MySqlConnectOptions, Compression}; + /// let options = MySqlConnectOptions::new() + /// .compression(Compression::Zlib.fast()); + /// ``` + pub fn compression(mut self, compression: CompressionConfig) -> Self { + self.compression = Some(compression); + self + } } impl MySqlConnectOptions { @@ -526,4 +634,20 @@ impl MySqlConnectOptions { pub fn get_collation(&self) -> Option<&str> { self.collation.as_deref() } + + /// Get compression + /// + /// # Example + /// + /// ```rust + /// #![cfg(feature = "compression")] + /// # use sqlx_mysql::{Compression, CompressionConfig, MySqlConnectOptions}; + /// let options = MySqlConnectOptions::new() + /// .compression(Compression::Zlib.fast()); + /// + /// assert!(options.get_compression().is_some()); + /// ``` + pub fn get_compression(&self) -> Option { + self.compression + } } diff --git a/sqlx-mysql/src/options/parse.rs b/sqlx-mysql/src/options/parse.rs index e31ddc46d4..68ccabfd19 100644 --- a/sqlx-mysql/src/options/parse.rs +++ b/sqlx-mysql/src/options/parse.rs @@ -1,11 +1,11 @@ -use std::str::FromStr; - +use super::MySqlConnectOptions; +use crate::error::Error; +#[cfg(feature = "compression")] +use crate::Compression; +use crate::MySqlSslMode; use percent_encoding::{percent_decode_str, utf8_percent_encode, NON_ALPHANUMERIC}; use sqlx_core::Url; - -use crate::{error::Error, MySqlSslMode}; - -use super::MySqlConnectOptions; +use std::str::FromStr; impl MySqlConnectOptions { pub(crate) fn parse_from_url(url: &Url) -> Result { @@ -80,6 +80,29 @@ impl MySqlConnectOptions { options = options.timezone(Some(value.to_string())); } + #[cfg(feature = "compression")] + "compression" => { + let (algorithm, level) = value.split_once(":").ok_or_else(|| { + Error::Configuration( + format!( + "Invalid compression parameter. Expected algorithm:level, but got '{}'", + value + ) + .into(), + ) + })?; + let compression = match algorithm { + "zlib" => Ok(Compression::Zlib), + "zstd" => Ok(Compression::Zstd), + _ => Err(Error::Configuration( + format!("Unknown compression algorithm: {}", algorithm).into(), + )), + }?; + let compression_config = + compression.level(level.parse().map_err(Error::config)?)?; + options = options.compression(compression_config); + } + _ => {} } } @@ -197,3 +220,19 @@ fn it_parses_timezone() { .unwrap(); assert_eq!(opts.timezone.as_deref(), Some("+08:00")); } + +#[test] +#[cfg(feature = "compression")] +fn it_parses_compression() { + let opts: MySqlConnectOptions = "mysql://user:password@hostname/database?compression=zstd:10" + .parse() + .unwrap(); + + assert_eq!(opts.compression, Compression::Zstd.level(10).ok()); + + let opts: MySqlConnectOptions = "mysql://user:password@hostname/database?compression=zlib:2" + .parse() + .unwrap(); + + assert_eq!(opts.compression, Compression::Zlib.level(2).ok()); +} diff --git a/sqlx-mysql/src/protocol/compressed_packet.rs b/sqlx-mysql/src/protocol/compressed_packet.rs new file mode 100644 index 0000000000..0dbc3d36cd --- /dev/null +++ b/sqlx-mysql/src/protocol/compressed_packet.rs @@ -0,0 +1,108 @@ +use crate::error::Error; +use crate::io::ProtocolEncode; +use crate::options::Compression; +use crate::CompressionConfig; +use bytes::{BufMut, Bytes}; +use flate2::read::ZlibDecoder; +use flate2::{write::ZlibEncoder, Compression as ZlibCompression}; +use sqlx_core::io::ProtocolDecode; +use std::io::{Cursor, Read, Write}; + +#[derive(Debug)] +pub(crate) struct CompressedPacket(pub(crate) T); + +pub(crate) struct CompressedPacketContext<'cs, C> { + pub(crate) nested_context: C, + pub(crate) sequence_id: &'cs mut u8, + pub(crate) compression: CompressionConfig, +} + +impl<'en, 'compressed_stream, T, C> + ProtocolEncode<'en, CompressedPacketContext<'compressed_stream, C>> for CompressedPacket +where + T: ProtocolEncode<'en, C>, +{ + fn encode_with( + &self, + buf: &mut Vec, + context: CompressedPacketContext<'compressed_stream, C>, + ) -> Result<(), Error> { + let mut uncompressed_payload = Vec::with_capacity(0xFF_FF_FF); + self.0 + .encode_with(&mut uncompressed_payload, context.nested_context)?; + + let mut chunks = uncompressed_payload.chunks(0xFF_FF_FF); + for chunk in chunks.by_ref() { + add_packet(buf, *context.sequence_id, &context.compression, chunk)?; + *context.sequence_id = context.sequence_id.wrapping_add(1); + } + + Ok(()) + } +} + +fn add_packet( + buf: &mut Vec, + sequence_id: u8, + compression: &CompressionConfig, + uncompressed_chunk: &[u8], +) -> Result<(), Error> { + let offset = buf.len(); + buf.extend_from_slice(&[0; 7]); + + let compressed_payload_length = compress(compression, uncompressed_chunk, buf)?; + + let mut header = Vec::with_capacity(7); + header.put_uint_le(compressed_payload_length as u64, 3); + header.put_u8(sequence_id); + header.put_uint_le(uncompressed_chunk.len() as u64, 3); + buf[offset..offset + 7].copy_from_slice(&header); + + Ok(()) +} + +impl<'compressed_stream, C> ProtocolDecode<'_, CompressedPacketContext<'compressed_stream, C>> + for CompressedPacket +{ + fn decode_with( + buf: Bytes, + context: CompressedPacketContext<'compressed_stream, C>, + ) -> Result { + decompress(&context.compression, buf.as_ref()).map(|d| CompressedPacket(Bytes::from(d))) + } +} + +fn compress( + compression: &CompressionConfig, + input: &[u8], + output: &mut Vec, +) -> Result { + let offset = output.len(); + let mut cursor = Cursor::new(output); + cursor.set_position(offset as u64); + + let cursor = match compression { + CompressionConfig(Compression::Zlib, level) => { + let mut encoder = ZlibEncoder::new(cursor, ZlibCompression::new(*level as u32)); + let _ = encoder.write(input)?; + encoder.finish()? + } + CompressionConfig(Compression::Zstd, level) => { + zstd::stream::copy_encode(input, &mut cursor, *level as i32)?; + cursor + } + }; + + Ok(cursor.get_ref().len().saturating_sub(offset)) +} + +fn decompress(compression: &CompressionConfig, bytes: &[u8]) -> Result, Error> { + match compression.0 { + Compression::Zlib => { + let mut out = Vec::with_capacity(bytes.len() * 2); + ZlibDecoder::new(bytes).read_to_end(&mut out)?; + Ok(out) + } + Compression::Zstd => Ok(zstd::stream::decode_all(bytes)?), + } +} diff --git a/sqlx-mysql/src/protocol/connect/handshake_response.rs b/sqlx-mysql/src/protocol/connect/handshake_response.rs index 6911419d98..c5d1bcc3d9 100644 --- a/sqlx-mysql/src/protocol/connect/handshake_response.rs +++ b/sqlx-mysql/src/protocol/connect/handshake_response.rs @@ -1,9 +1,11 @@ use crate::io::MySqlBufMutExt; use crate::io::{BufMutExt, ProtocolEncode}; +#[cfg(feature = "compression")] +use crate::options::Compression; use crate::protocol::auth::AuthPlugin; use crate::protocol::connect::ssl_request::SslRequest; use crate::protocol::Capabilities; - +use crate::CompressionConfig; // https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse // https://mariadb.com/kb/en/connection/#client-handshake-response @@ -25,6 +27,10 @@ pub struct HandshakeResponse<'a> { /// Opaque authentication response pub auth_response: Option<&'a [u8]>, + + /// compression algorithm + #[cfg_attr(not(feature = "compression"), allow(dead_code))] + pub compression: Option, } impl ProtocolEncode<'_, Capabilities> for HandshakeResponse<'_> { @@ -77,6 +83,13 @@ impl ProtocolEncode<'_, Capabilities> for HandshakeResponse<'_> { } } + #[cfg(feature = "compression")] + if context.contains(Capabilities::ZSTD_COMPRESSION_ALGORITHM) { + if let Some(CompressionConfig(Compression::Zstd, level)) = self.compression { + buf.push(level) + } + } + Ok(()) } } diff --git a/sqlx-mysql/src/protocol/mod.rs b/sqlx-mysql/src/protocol/mod.rs index d1860f5c65..325ce456f4 100644 --- a/sqlx-mysql/src/protocol/mod.rs +++ b/sqlx-mysql/src/protocol/mod.rs @@ -1,5 +1,7 @@ pub(crate) mod auth; mod capabilities; +#[cfg(feature = "compression")] +mod compressed_packet; pub(crate) mod connect; mod packet; pub(crate) mod response; @@ -8,5 +10,7 @@ pub(crate) mod statement; pub(crate) mod text; pub(crate) use capabilities::Capabilities; +#[cfg(feature = "compression")] +pub(crate) use compressed_packet::{CompressedPacket, CompressedPacketContext}; pub(crate) use packet::Packet; pub(crate) use row::Row; diff --git a/tests/mysql/mysql.rs b/tests/mysql/mysql.rs index 5d6a5ef233..cc5d2b5eab 100644 --- a/tests/mysql/mysql.rs +++ b/tests/mysql/mysql.rs @@ -3,7 +3,7 @@ use futures_util::TryStreamExt; use sqlx::mysql::{MySql, MySqlConnection, MySqlPool, MySqlPoolOptions, MySqlRow}; use sqlx::{Column, Connection, Executor, Row, SqlSafeStr, Statement, TypeInfo}; use sqlx_core::connection::ConnectOptions; -use sqlx_mysql::MySqlConnectOptions; +use sqlx_mysql::{Compression, MySqlConnectOptions}; use sqlx_test::{new, setup_if_needed}; use std::env; use url::Url; @@ -39,6 +39,64 @@ async fn it_connects_without_password() -> anyhow::Result<()> { Ok(()) } +#[sqlx_macros::test] +async fn it_connects_with_zlib_compression() -> anyhow::Result<()> { + let url = Url::parse(&env::var("DATABASE_URL").context("expected DATABASE_URL")?) + .context("error parsing DATABASE_URL")?; + let mut conn = MySqlConnectOptions::from_url(&url)? + .compression(Compression::Zlib.default()) + .connect() + .await?; + + let rows = sqlx::raw_sql(r#"SHOW SESSION STATUS LIKE 'Compression'"#) + .fetch_all(&mut conn) + .await?; + + let result = rows + .first() + .map(|r| r.try_get::(1).unwrap_or_default()) + .unwrap_or_default(); + + assert!(!rows.is_empty()); + assert_eq!(result, "ON"); + + Ok(()) +} + +#[sqlx_macros::test] +#[cfg(all( + not(any( + mariadb = "verylatest", + mariadb = "10_6", + mariadb = "10_11", + mariadb = "11_4", + mariadb = "11_8", + )), + feature = "mysql" +))] +async fn it_connects_with_zstd_compression() -> anyhow::Result<()> { + let url = Url::parse(&env::var("DATABASE_URL").context("expected DATABASE_URL")?) + .context("error parsing DATABASE_URL")?; + let mut conn = MySqlConnectOptions::from_url(&url)? + .compression(Compression::Zstd.default()) + .connect() + .await?; + + let rows = sqlx::raw_sql(r#"SHOW SESSION STATUS LIKE 'Compression'"#) + .fetch_all(&mut conn) + .await?; + + let result = rows + .first() + .map(|r| r.try_get::(1).unwrap_or_default()) + .unwrap_or_default(); + + assert!(!rows.is_empty()); + assert_eq!(result, "ON"); + + Ok(()) +} + #[sqlx_macros::test] async fn it_maths() -> anyhow::Result<()> { let mut conn = new::().await?; @@ -560,6 +618,91 @@ CREATE TEMPORARY TABLE large_table (data LONGBLOB); Ok(()) } +#[sqlx_macros::test] +#[cfg(all( + not(any( + mariadb = "verylatest", + mariadb = "10_6", + mariadb = "10_11", + mariadb = "11_4", + mariadb = "11_8", + )), + feature = "mysql" +))] +async fn it_can_handle_split_packets_with_zstd_compression() -> anyhow::Result<()> { + let url = Url::parse(&env::var("DATABASE_URL").context("expected DATABASE_URL")?) + .context("error parsing DATABASE_URL")?; + + let options = MySqlConnectOptions::from_url(&url)?.compression(Compression::Zstd.best()); + + // This will only take effect on new connections + options + .connect() + .await? + .execute("SET GLOBAL max_allowed_packet = 4294967297") + .await?; + + let mut conn = MySqlConnectOptions::from_url(&url)? + .compression(Compression::Zstd.best()) + .connect() + .await?; + conn.execute(r#" CREATE TEMPORARY TABLE large_table (data LONGBLOB);"#) + .await?; + + let data = vec![0x41; 0xFF_FF_FF * 2]; + + sqlx::query("INSERT INTO large_table (data) VALUES (?)") + .bind(&data) + .execute(&mut conn) + .await?; + + let ret: Vec = sqlx::query_scalar("SELECT * FROM large_table") + .fetch_one(&mut conn) + .await?; + + assert_eq!(ret, data); + + Ok(()) +} + +#[sqlx_macros::test] +async fn it_can_handle_split_packets_with_zlib_compression() -> anyhow::Result<()> { + let url = Url::parse(&env::var("DATABASE_URL").context("expected DATABASE_URL")?) + .context("error parsing DATABASE_URL")?; + + let options = MySqlConnectOptions::from_url(&url)?.compression(Compression::Zlib.best()); + + // This will only take effect on new connections + options + .connect() + .await? + .execute("SET GLOBAL max_allowed_packet = 4294967297") + .await?; + + let mut conn = MySqlConnectOptions::from_url(&url)? + .compression(Compression::Zstd.best()) + .connect() + .await?; + + conn.execute(r#"CREATE TEMPORARY TABLE large_table (data LONGBLOB);"#) + .await?; + + let data = vec![0x41; 0xFF_FF_FF * 2]; + + sqlx::query("INSERT INTO large_table (data) VALUES (?)") + .bind(&data) + .execute(&mut conn) + .await?; + + let ret: Vec = sqlx::query_scalar("SELECT * FROM large_table") + .fetch_one(&mut conn) + .await?; + + assert_eq!(ret, data); + + Ok(()) +} + #[sqlx_macros::test] async fn test_shrink_buffers() -> anyhow::Result<()> { // We don't really have a good way to test that `.shrink_buffers()` functions as expected diff --git a/tests/x.py b/tests/x.py index e1308f2fa4..7b01ce0f54 100755 --- a/tests/x.py +++ b/tests/x.py @@ -211,7 +211,7 @@ def run(command, comment=None, env=None, service=None, tag=None, args=None, data # https://github.com/docker-library/mysql/issues/567 if not(version == "5_7" and tls == "rustls"): run( - f"cargo test --no-default-features --features any,mysql,macros,_unstable-all-types,runtime-{runtime},tls-{tls}", + f"cargo test --no-default-features --features any,mysql,macros,mysql-compression,_unstable-all-types,runtime-{runtime},tls-{tls}", comment=f"test mysql {version}", service=f"mysql_{version}", tag=f"mysql_{version}" if runtime == "async-std" else f"mysql_{version}_{runtime}", @@ -220,7 +220,7 @@ def run(command, comment=None, env=None, service=None, tag=None, args=None, data ## +client-ssl if tls != "none" and not(version == "5_7" and tls == "rustls"): run( - f"cargo test --no-default-features --features any,mysql,macros,_unstable-all-types,runtime-{runtime},tls-{tls}", + f"cargo test --no-default-features --features any,mysql,macros,mysql-compression,_unstable-all-types,runtime-{runtime},tls-{tls}", comment=f"test mysql {version}_client_ssl no-password", database_url_args="sslmode=verify_ca&ssl-ca=.%2Ftests%2Fcerts%2Fca.crt&ssl-key=.%2Ftests%2Fcerts%2Fkeys%2Fclient.key&ssl-cert=.%2Ftests%2Fcerts%2Fclient.crt", service=f"mysql_{version}_client_ssl", @@ -233,7 +233,7 @@ def run(command, comment=None, env=None, service=None, tag=None, args=None, data for version in ["verylatest", "10_11", "10_6", "10_5", "10_4"]: run( - f"cargo test --no-default-features --features any,mysql,macros,_unstable-all-types,runtime-{runtime},tls-{tls}", + f"cargo test --no-default-features --features any,mysql,macros,mysql-compression,_unstable-all-types,runtime-{runtime},tls-{tls}", comment=f"test mariadb {version}", service=f"mariadb_{version}", tag=f"mariadb_{version}" if runtime == "async-std" else f"mariadb_{version}_{runtime}", @@ -242,7 +242,7 @@ def run(command, comment=None, env=None, service=None, tag=None, args=None, data ## +client-ssl if tls != "none": run( - f"cargo test --no-default-features --features any,mysql,macros,_unstable-all-types,runtime-{runtime},tls-{tls}", + f"cargo test --no-default-features --features any,mysql,macros,mysql-compression,_unstable-all-types,runtime-{runtime},tls-{tls}", comment=f"test mariadb {version}_client_ssl no-password", database_url_args="sslmode=verify_ca&ssl-ca=.%2Ftests%2Fcerts%2Fca.crt&ssl-key=%2Ftests%2Fcerts%2Fkeys%2Fclient.key&ssl-cert=.%2Ftests%2Fcerts%2Fclient.crt", service=f"mariadb_{version}_client_ssl", From 12a719f8c081f874a598a053bc99e6cdadf512e1 Mon Sep 17 00:00:00 2001 From: Artem Konovalov Date: Fri, 21 Nov 2025 16:06:09 -0300 Subject: [PATCH 2/5] fix typo --- .github/workflows/sqlx.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/sqlx.yml b/.github/workflows/sqlx.yml index 58d449a128..69a93897a0 100644 --- a/.github/workflows/sqlx.yml +++ b/.github/workflows/sqlx.yml @@ -421,7 +421,7 @@ jobs: run: > cargo test --no-default-features - --features any,mysql,macros,mysql-ompression,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} + --features any,mysql,macros,mysql-compression,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} env: DATABASE_URL: mysql://root@localhost:3306/sqlx?sslmode=verify_ca&ssl-ca=.%2Ftests%2Fcerts%2Fca.crt&ssl-key=.%2Ftests%2Fcerts%2Fkeys%2Fclient.key&ssl-cert=.%2Ftests%2Fcerts%2Fclient.crt RUSTFLAGS: --cfg mysql_${{ matrix.mysql }} From 5579f5ffcc762419d16d82c1b8ca7da8173020ba Mon Sep 17 00:00:00 2001 From: Artem Konovalov Date: Sat, 22 Nov 2025 08:20:09 -0300 Subject: [PATCH 3/5] add test --- sqlx-mysql/src/options/parse.rs | 32 ++++++++++++++++++++++++++++++-- 1 file changed, 30 insertions(+), 2 deletions(-) diff --git a/sqlx-mysql/src/options/parse.rs b/sqlx-mysql/src/options/parse.rs index 68ccabfd19..56cf7aa6be 100644 --- a/sqlx-mysql/src/options/parse.rs +++ b/sqlx-mysql/src/options/parse.rs @@ -1,8 +1,8 @@ use super::MySqlConnectOptions; use crate::error::Error; -#[cfg(feature = "compression")] -use crate::Compression; use crate::MySqlSslMode; +#[cfg(feature = "compression")] +use crate::{Compression, CompressionConfig}; use percent_encoding::{percent_decode_str, utf8_percent_encode, NON_ALPHANUMERIC}; use sqlx_core::Url; use std::str::FromStr; @@ -166,6 +166,15 @@ impl MySqlConnectOptions { .append_pair("socket", &socket.to_string_lossy()); } + #[cfg(feature = "compression")] + if let Some(compression_config) = &self.compression { + let value = match compression_config { + CompressionConfig(Compression::Zstd, level) => format!("zstd:{}", level), + CompressionConfig(Compression::Zlib, level) => format!("zlib:{}", level), + }; + url.query_pairs_mut().append_pair("compression", &value); + } + url } } @@ -208,6 +217,25 @@ fn it_returns_the_parsed_url() { assert_eq!(expected_url, opts.build_url()); } +#[test] +#[cfg(feature = "compression")] +fn it_returns_the_build_url_with_compression_param() { + let url = "mysql://username:p@ssw0rd@hostname:3306/database"; + let opts = MySqlConnectOptions::from_str(url) + .unwrap() + .compression(Compression::Zstd.fast()); + + let mut expected_url = Url::parse(url).unwrap(); + let mut query_string = String::new(); + // MySqlConnectOptions defaults + query_string += "ssl-mode=PREFERRED&charset=utf8mb4&statement-cache-capacity=100"; + query_string += "&compression=zstd%3A1"; + + expected_url.set_query(Some(&query_string)); + + assert_eq!(expected_url, opts.build_url()); +} + #[test] fn it_parses_timezone() { let opts: MySqlConnectOptions = "mysql://user:password@hostname/database?timezone=%2B08:00" From 6fb8b2c154b1ad136c45e4580a499917960524e9 Mon Sep 17 00:00:00 2001 From: Artem Konovalov Date: Sat, 22 Nov 2025 08:26:05 -0300 Subject: [PATCH 4/5] remove code duplication from tests --- tests/mysql/mysql.rs | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/tests/mysql/mysql.rs b/tests/mysql/mysql.rs index cc5d2b5eab..0c05195ae5 100644 --- a/tests/mysql/mysql.rs +++ b/tests/mysql/mysql.rs @@ -642,10 +642,8 @@ async fn it_can_handle_split_packets_with_zstd_compression() -> anyhow::Result<( .execute("SET GLOBAL max_allowed_packet = 4294967297") .await?; - let mut conn = MySqlConnectOptions::from_url(&url)? - .compression(Compression::Zstd.best()) - .connect() - .await?; + let mut conn = options.await?; + conn.execute(r#" CREATE TEMPORARY TABLE large_table (data LONGBLOB);"#) .await?; @@ -679,10 +677,7 @@ async fn it_can_handle_split_packets_with_zlib_compression() -> anyhow::Result<( .execute("SET GLOBAL max_allowed_packet = 4294967297") .await?; - let mut conn = MySqlConnectOptions::from_url(&url)? - .compression(Compression::Zstd.best()) - .connect() - .await?; + let mut conn = options.await?; conn.execute(r#"CREATE TEMPORARY TABLE large_table (data LONGBLOB);"#) .await?; From da30e8dfced90f57a37b75270198395d421a02c2 Mon Sep 17 00:00:00 2001 From: Artem Konovalov Date: Sat, 22 Nov 2025 10:48:52 -0300 Subject: [PATCH 5/5] simplify logic of CompressionMySqlStream --- sqlx-mysql/src/connection/compression.rs | 39 +++--------------------- 1 file changed, 4 insertions(+), 35 deletions(-) diff --git a/sqlx-mysql/src/connection/compression.rs b/sqlx-mysql/src/connection/compression.rs index 2fdac04874..8e7c03f0be 100644 --- a/sqlx-mysql/src/connection/compression.rs +++ b/sqlx-mysql/src/connection/compression.rs @@ -15,7 +15,7 @@ pub(crate) struct CompressionMySqlStream> { impl CompressionMySqlStream { pub(crate) fn not_compressed(socket: BufferedSocket) -> Self { - let stream = CompressionStream::NotCompressed(NoCompressionStream {}); + let stream = CompressionStream::NotCompressed; Self { stream, socket } } @@ -56,9 +56,7 @@ impl CompressionMySqlStream { T: ProtocolDecode<'de, C>, { match self.stream { - CompressionStream::NotCompressed(ref mut s) => { - s.read_with(byte_len, context, &mut self.socket).await - } + CompressionStream::NotCompressed => self.socket.read_with(byte_len, context).await, #[cfg(feature = "compression")] CompressionStream::Compressed(ref mut s) => { s.read_with(byte_len, context, &mut self.socket).await @@ -75,9 +73,7 @@ impl CompressionMySqlStream { T: ProtocolEncode<'en, (Capabilities, &'stream mut u8)>, { match self.stream { - CompressionStream::NotCompressed(ref mut s) => { - s.write_with(value, context, &mut self.socket) - } + CompressionStream::NotCompressed => self.socket.write_with(value, context), #[cfg(feature = "compression")] CompressionStream::Compressed(ref mut s) => { s.write_with(value, context, &mut self.socket) @@ -87,38 +83,11 @@ impl CompressionMySqlStream { } enum CompressionStream { - NotCompressed(NoCompressionStream), + NotCompressed, #[cfg(feature = "compression")] Compressed(CompressedStream), } -struct NoCompressionStream {} -impl NoCompressionStream { - async fn read_with<'de, T, C, S: Socket>( - &mut self, - byte_len: usize, - context: C, - buffered_socket: &mut BufferedSocket, - ) -> Result - where - T: ProtocolDecode<'de, C>, - { - buffered_socket.read_with(byte_len, context).await - } - - fn write_with<'en, 'stream, T, C, S: Socket>( - &mut self, - packet: T, - context: C, - buffered_socket: &mut BufferedSocket, - ) -> Result<(), Error> - where - T: ProtocolEncode<'en, C>, - { - buffered_socket.write_with(packet, context) - } -} - #[cfg(feature = "compression")] mod compressed_stream { use crate::protocol::{CompressedPacket, CompressedPacketContext};