Skip to content

Commit 5558c55

Browse files
authored
Handle zstd compression (#63)
* Handle compression with zstd
* Refactor
* Add tests for ZSTD
* Refactor compression params in tests
* Use SOURCE_SUBDIR for zstd
* Add aliases
* Copy zstd dll for tests on windows
* Add conditions on zstd targets
* Add more conditions
* Do the same for examples
1 parent 5ad8ebb commit 5558c55

14 files changed

+295
-119
lines changed

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,7 @@ target_link_libraries(sparrow-ipc
257257
flatbuffers::flatbuffers
258258
PRIVATE
259259
lz4::lz4
260+
zstd::libzstd
260261
)
261262

262263
# Ensure generated headers are available when building sparrow-ipc

cmake/external_dependencies.cmake

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,31 @@ if(NOT TARGET lz4::lz4)
123123
add_library(lz4::lz4 ALIAS lz4)
124124
endif()
125125

126+
find_package_or_fetch(
127+
PACKAGE_NAME zstd
128+
GIT_REPOSITORY https://github.com/facebook/zstd.git
129+
TAG v1.5.7
130+
SOURCE_SUBDIR build/cmake
131+
CMAKE_ARGS
132+
"ZSTD_BUILD_PROGRAMS=OFF"
133+
)
134+
135+
# Provide a uniform `zstd::libzstd` target regardless of how zstd was obtained.
#
# Depending on the zstd package/build, the library may be exposed as a
# namespaced target (`zstd::libzstd_shared` / `zstd::libzstd_static`) or as a
# plain target (`libzstd_shared` / `libzstd_static`). The variant matching
# SPARROW_IPC_BUILD_SHARED is preferred; if it is not available, the other
# variant is used instead of silently leaving `zstd::libzstd` undefined.
if(NOT TARGET zstd::libzstd)
    if(SPARROW_IPC_BUILD_SHARED)
        set(_sparrow_ipc_zstd_candidates
            zstd::libzstd_shared libzstd_shared
            zstd::libzstd_static libzstd_static)
    else()
        set(_sparrow_ipc_zstd_candidates
            zstd::libzstd_static libzstd_static
            zstd::libzstd_shared libzstd_shared)
    endif()

    # Alias the first candidate that actually exists.
    foreach(_sparrow_ipc_zstd_candidate IN LISTS _sparrow_ipc_zstd_candidates)
        if(TARGET ${_sparrow_ipc_zstd_candidate})
            add_library(zstd::libzstd ALIAS ${_sparrow_ipc_zstd_candidate})
            break()
        endif()
    endforeach()

    # Fail at configure time with an actionable message rather than later with
    # an unresolved `zstd::libzstd` in target_link_libraries.
    if(NOT TARGET zstd::libzstd)
        message(FATAL_ERROR
            "zstd was found, but none of the expected targets exist "
            "(zstd::libzstd, zstd::libzstd_shared, zstd::libzstd_static, "
            "libzstd_shared, libzstd_static).")
    endif()

    unset(_sparrow_ipc_zstd_candidate)
    unset(_sparrow_ipc_zstd_candidates)
endif()
150+
126151
if(SPARROW_IPC_BUILD_TESTS)
127152
find_package_or_fetch(
128153
PACKAGE_NAME doctest

conanfile.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def requirements(self):
4646
self.requires("sparrow/1.0.0")
4747
self.requires(f"flatbuffers/{self._flatbuffers_version}")
4848
self.requires("lz4/1.9.4")
49-
#self.requires("zstd/1.5.5")
49+
self.requires("zstd/1.5.7")
5050
if self.options.get_safe("build_tests"):
5151
self.test_requires("doctest/2.4.12")
5252

environment-dev.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ dependencies:
99
# Libraries dependencies
1010
- flatbuffers
1111
- lz4-c
12+
- zstd
1213
- nlohmann_json
1314
- sparrow-devel
1415
- sparrow-json-reader

examples/CMakeLists.txt

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,16 +31,40 @@ add_dependencies(write_and_read_streams generate_flatbuffers_headers)
3131

3232
# Optional: Copy to build directory for easy execution
3333
if(WIN32)
34+
set(ZSTD_DLL_TARGET "")
35+
if(TARGET libzstd_shared AND SPARROW_IPC_BUILD_SHARED) # Building deps from src case: shared target without namespace
36+
set(ZSTD_DLL_TARGET libzstd_shared)
37+
elseif(TARGET libzstd_static AND NOT SPARROW_IPC_BUILD_SHARED) # Building deps from src case: static target without namespace
38+
set(ZSTD_DLL_TARGET libzstd_static)
39+
endif()
40+
3441
# On Windows, copy required DLLs
35-
add_custom_command(
36-
TARGET write_and_read_streams POST_BUILD
42+
set(DLL_COPY_COMMANDS "") # Initialize a list to hold all copy commands
43+
# Add unconditional copy commands
44+
list(APPEND DLL_COPY_COMMANDS
3745
COMMAND ${CMAKE_COMMAND} -E copy_if_different
3846
"$<TARGET_FILE:sparrow::sparrow>"
3947
"$<TARGET_FILE_DIR:write_and_read_streams>"
4048
COMMAND ${CMAKE_COMMAND} -E copy_if_different
4149
"$<TARGET_FILE:sparrow-ipc>"
4250
"$<TARGET_FILE_DIR:write_and_read_streams>"
43-
COMMENT "Copying sparrow and sparrow-ipc DLLs to example executable directory"
51+
)
52+
53+
# Conditionally add ZSTD copy command
54+
if(ZSTD_DLL_TARGET)
55+
list(APPEND DLL_COPY_COMMANDS
56+
COMMAND ${CMAKE_COMMAND} -E copy_if_different
57+
"$<TARGET_FILE:${ZSTD_DLL_TARGET}>"
58+
"$<TARGET_FILE_DIR:write_and_read_streams>"
59+
)
60+
else()
61+
message(WARNING "ZSTD DLL will not be copied for examples.")
62+
endif()
63+
64+
add_custom_command(
65+
TARGET write_and_read_streams POST_BUILD
66+
${DLL_COPY_COMMANDS}
67+
COMMENT "Copying required DLLs to example executable directory"
4468
)
4569
endif()
4670

src/compression.cpp

Lines changed: 41 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
#include <functional>
12
#include <stdexcept>
23

34
#include <lz4frame.h>
5+
#include <zstd.h>
46

57
#include "compression_impl.hpp"
68

@@ -15,7 +17,7 @@ namespace sparrow_ipc
1517
case CompressionType::LZ4_FRAME:
1618
return org::apache::arrow::flatbuf::CompressionType::LZ4_FRAME;
1719
case CompressionType::ZSTD:
18-
throw std::invalid_argument("Compression using zstd is not supported yet.");
20+
return org::apache::arrow::flatbuf::CompressionType::ZSTD;
1921
default:
2022
throw std::invalid_argument("Unsupported compression type.");
2123
}
@@ -28,7 +30,7 @@ namespace sparrow_ipc
2830
case org::apache::arrow::flatbuf::CompressionType::LZ4_FRAME:
2931
return CompressionType::LZ4_FRAME;
3032
case org::apache::arrow::flatbuf::CompressionType::ZSTD:
31-
throw std::invalid_argument("Compression using zstd is not supported yet.");
33+
return CompressionType::ZSTD;
3234
default:
3335
throw std::invalid_argument("Unsupported compression type.");
3436
}
@@ -37,6 +39,9 @@ namespace sparrow_ipc
3739

3840
namespace
3941
{
42+
using compress_func = std::function<std::vector<uint8_t>(std::span<const uint8_t>)>;
43+
using decompress_func = std::function<std::vector<uint8_t>(std::span<const uint8_t>, int64_t)>;
44+
4045
std::vector<std::uint8_t> lz4_compress(std::span<const std::uint8_t> data)
4146
{
4247
const std::int64_t uncompressed_size = data.size();
@@ -58,7 +63,7 @@ namespace sparrow_ipc
5863
LZ4F_createDecompressionContext(&dctx, LZ4F_VERSION);
5964
size_t compressed_size_in_out = data.size();
6065
size_t decompressed_size_in_out = decompressed_size;
61-
size_t result = LZ4F_decompress(dctx, decompressed_data.data(), &decompressed_size_in_out, data.data(), &compressed_size_in_out, nullptr);
66+
const size_t result = LZ4F_decompress(dctx, decompressed_data.data(), &decompressed_size_in_out, data.data(), &compressed_size_in_out, nullptr);
6267
if (LZ4F_isError(result) || (decompressed_size_in_out != (size_t)decompressed_size))
6368
{
6469
throw std::runtime_error("Failed to decompress data with LZ4 frame format");
@@ -67,6 +72,31 @@ namespace sparrow_ipc
6772
return decompressed_data;
6873
}
6974

75+
std::vector<std::uint8_t> zstd_compress(std::span<const std::uint8_t> data)
76+
{
77+
const std::int64_t uncompressed_size = data.size();
78+
const size_t max_compressed_size = ZSTD_compressBound(uncompressed_size);
79+
std::vector<std::uint8_t> compressed_data(max_compressed_size);
80+
const size_t compressed_size = ZSTD_compress(compressed_data.data(), max_compressed_size, data.data(), uncompressed_size, 1);
81+
if (ZSTD_isError(compressed_size))
82+
{
83+
throw std::runtime_error("Failed to compress data with ZSTD");
84+
}
85+
compressed_data.resize(compressed_size);
86+
return compressed_data;
87+
}
88+
89+
std::vector<std::uint8_t> zstd_decompress(std::span<const std::uint8_t> data, const std::int64_t decompressed_size)
90+
{
91+
std::vector<std::uint8_t> decompressed_data(decompressed_size);
92+
const size_t result = ZSTD_decompress(decompressed_data.data(), decompressed_size, data.data(), data.size());
93+
if (ZSTD_isError(result) || (result != (size_t)decompressed_size))
94+
{
95+
throw std::runtime_error("Failed to decompress data with ZSTD");
96+
}
97+
return decompressed_data;
98+
}
99+
70100
// TODO These functions could be moved to serialize_utils and deserialize_utils if preferred
71101
// as they are handling the header size
72102
std::vector<std::uint8_t> uncompressed_data_with_header(std::span<const std::uint8_t> data)
@@ -79,10 +109,10 @@ namespace sparrow_ipc
79109
return result;
80110
}
81111

82-
std::vector<std::uint8_t> lz4_compress_with_header(std::span<const std::uint8_t> data)
112+
std::vector<std::uint8_t> compress_with_header(std::span<const std::uint8_t> data, compress_func comp_func)
83113
{
84114
const std::int64_t original_size = data.size();
85-
auto compressed_body = lz4_compress(data);
115+
auto compressed_body = comp_func(data);
86116

87117
if (compressed_body.size() >= static_cast<size_t>(original_size))
88118
{
@@ -96,7 +126,7 @@ namespace sparrow_ipc
96126
return result;
97127
}
98128

99-
std::variant<std::vector<std::uint8_t>, std::span<const std::uint8_t>> lz4_decompress_with_header(std::span<const std::uint8_t> data)
129+
std::variant<std::vector<std::uint8_t>, std::span<const std::uint8_t>> decompress_with_header(std::span<const std::uint8_t> data, decompress_func decomp_func)
100130
{
101131
if (data.size() < details::CompressionHeaderSize)
102132
{
@@ -110,7 +140,7 @@ namespace sparrow_ipc
110140
return compressed_data;
111141
}
112142

113-
return lz4_decompress(compressed_data, decompressed_size);
143+
return decomp_func(compressed_data, decompressed_size);
114144
}
115145

116146
std::span<const uint8_t> get_body_from_uncompressed_data(std::span<const uint8_t> data)
@@ -129,11 +159,11 @@ namespace sparrow_ipc
129159
{
130160
case CompressionType::LZ4_FRAME:
131161
{
132-
return lz4_compress_with_header(data);
162+
return compress_with_header(data, lz4_compress);
133163
}
134164
case CompressionType::ZSTD:
135165
{
136-
throw std::invalid_argument("Compression using zstd is not supported yet.");
166+
return compress_with_header(data, zstd_compress);
137167
}
138168
default:
139169
return uncompressed_data_with_header(data);
@@ -151,11 +181,11 @@ namespace sparrow_ipc
151181
{
152182
case CompressionType::LZ4_FRAME:
153183
{
154-
return lz4_decompress_with_header(data);
184+
return decompress_with_header(data, lz4_decompress);
155185
}
156186
case CompressionType::ZSTD:
157187
{
158-
throw std::invalid_argument("Decompression using zstd is not supported yet.");
188+
return decompress_with_header(data, zstd_decompress);
159189
}
160190
default:
161191
{

tests/CMakeLists.txt

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,17 @@ target_link_libraries(${test_target}
3030

3131
if(WIN32)
3232
find_package(date) # For copying DLLs
33-
add_custom_command(
34-
TARGET ${test_target} POST_BUILD
33+
34+
set(ZSTD_DLL_TARGET "")
35+
if(TARGET libzstd_shared AND SPARROW_IPC_BUILD_SHARED) # Building deps from src case: shared target without namespace
36+
set(ZSTD_DLL_TARGET libzstd_shared)
37+
elseif(TARGET libzstd_static AND NOT SPARROW_IPC_BUILD_SHARED) # Building deps from src case: static target without namespace
38+
set(ZSTD_DLL_TARGET libzstd_static)
39+
endif()
40+
41+
set(DLL_COPY_COMMANDS "") # Initialize a list to hold all copy commands
42+
# Add unconditional copy commands
43+
list(APPEND DLL_COPY_COMMANDS
3544
COMMAND ${CMAKE_COMMAND} -E copy
3645
"$<TARGET_FILE:sparrow::sparrow>"
3746
"$<TARGET_FILE_DIR:${test_target}>"
@@ -44,7 +53,23 @@ if(WIN32)
4453
COMMAND ${CMAKE_COMMAND} -E copy
4554
"$<TARGET_FILE:date::date-tz>"
4655
"$<TARGET_FILE_DIR:${test_target}>"
47-
COMMENT "Copying sparrow and sparrow-ipc DLLs to executable directory"
56+
)
57+
58+
# Conditionally add ZSTD copy command
59+
if(ZSTD_DLL_TARGET)
60+
list(APPEND DLL_COPY_COMMANDS
61+
COMMAND ${CMAKE_COMMAND} -E copy
62+
"$<TARGET_FILE:${ZSTD_DLL_TARGET}>"
63+
"$<TARGET_FILE_DIR:${test_target}>"
64+
)
65+
else()
66+
message(WARNING "ZSTD DLL will not be copied for tests.")
67+
endif()
68+
69+
add_custom_command(
70+
TARGET ${test_target} POST_BUILD
71+
${DLL_COPY_COMMANDS}
72+
COMMENT "Copying required DLLs to executable directory"
4873
)
4974
endif()
5075

tests/include/sparrow_ipc_tests_helpers.hpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,37 @@
11
#pragma once
22

3+
#include <optional>
4+
35
#include <doctest/doctest.h>
46

57
#include <sparrow/record_batch.hpp>
68

9+
#include "sparrow_ipc/compression.hpp"
710

811
namespace sparrow_ipc
912
{
1013
namespace sp = sparrow;
1114

15+
// Tag types that each expose a compile-time `type` constant naming one codec.
// Presumably meant as type parameters for templated test cases — they are not
// used in the visible hunks; confirm against the test sources.
struct Lz4Compression { static constexpr CompressionType type = CompressionType::LZ4_FRAME; };
16+
struct ZstdCompression { static constexpr CompressionType type = CompressionType::ZSTD; };
17+
18+
// One row of a compression test table.
struct CompressionParams
19+
{
20+
// Codec under test; std::nullopt is used for the "Uncompressed" row.
std::optional<CompressionType> type;
21+
// Human-readable label; the serializer tests pass it to SUBCASE().
const char* name;
22+
};
23+
24+
// Every configuration: no compression plus each supported codec.
// NOTE(review): std::array is used here but <array> is not among the includes
// visible in this header — confirm it is reached transitively or include it.
inline constexpr std::array<CompressionParams, 3> compression_params = {{
25+
{ std::nullopt, "Uncompressed" },
26+
{ CompressionType::LZ4_FRAME, "LZ4" },
27+
{ CompressionType::ZSTD, "ZSTD" }
28+
}};
29+
30+
// Only the rows that carry a codec, so callers may call `type.value()`
// without checking for std::nullopt.
inline constexpr std::array<CompressionParams, 2> compression_only_params = {{
31+
{ CompressionType::LZ4_FRAME, "LZ4" },
32+
{ CompressionType::ZSTD, "ZSTD" }
33+
}};
34+
1235
template <typename T1, typename T2>
1336
void compare_metadata(const T1& arr1, const T2& arr2)
1437
{

tests/test_chunk_memory_serializer.cpp

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -18,18 +18,6 @@ namespace sparrow_ipc
1818
SUBCASE("Valid record batch, with and without compression")
1919
{
2020
auto rb = create_compressible_test_record_batch();
21-
std::vector<std::vector<uint8_t>> chunks_compressed;
22-
chunked_memory_output_stream stream_compressed(chunks_compressed);
23-
24-
chunk_serializer serializer_compressed(stream_compressed, CompressionType::LZ4_FRAME);
25-
serializer_compressed << rb;
26-
27-
// After construction with single record batch, should have schema + record batch
28-
CHECK_EQ(chunks_compressed.size(), 2);
29-
CHECK_GT(chunks_compressed[0].size(), 0); // Schema message
30-
CHECK_GT(chunks_compressed[1].size(), 0); // Record batch message
31-
CHECK_GT(stream_compressed.size(), 0);
32-
3321
std::vector<std::vector<uint8_t>> chunks_uncompressed;
3422
chunked_memory_output_stream stream_uncompressed(chunks_uncompressed);
3523

@@ -41,11 +29,28 @@ namespace sparrow_ipc
4129
CHECK_GT(chunks_uncompressed[1].size(), 0); // Record batch message
4230
CHECK_GT(stream_uncompressed.size(), 0);
4331

44-
// Check that schema size is the same
45-
CHECK_EQ(chunks_compressed[0].size(), chunks_uncompressed[0].size());
46-
47-
// Check that compressed record batch is smaller
48-
CHECK_LT(chunks_compressed[1].size(), chunks_uncompressed[1].size());
32+
for (const auto& p : compression_only_params)
33+
{
34+
SUBCASE(p.name)
35+
{
36+
std::vector<std::vector<uint8_t>> chunks_compressed;
37+
chunked_memory_output_stream stream_compressed(chunks_compressed);
38+
chunk_serializer serializer_compressed(stream_compressed, p.type.value());
39+
serializer_compressed << rb;
40+
41+
// After construction with single record batch, should have schema + record batch
42+
CHECK_EQ(chunks_compressed.size(), 2);
43+
CHECK_GT(chunks_compressed[0].size(), 0); // Schema message
44+
CHECK_GT(chunks_compressed[1].size(), 0); // Record batch message
45+
CHECK_GT(stream_compressed.size(), 0);
46+
47+
// Check that schema size is the same
48+
CHECK_EQ(chunks_compressed[0].size(), chunks_uncompressed[0].size());
49+
50+
// Check that compressed record batch is smaller
51+
CHECK_LT(chunks_compressed[1].size(), chunks_uncompressed[1].size());
52+
}
53+
}
4954
}
5055

5156
SUBCASE("Empty record batch")

0 commit comments

Comments
 (0)