1+ #include < functional>
12#include < stdexcept>
23
34#include < lz4frame.h>
5+ #include < zstd.h>
46
57#include " compression_impl.hpp"
68
@@ -15,7 +17,7 @@ namespace sparrow_ipc
1517 case CompressionType::LZ4_FRAME:
1618 return org::apache::arrow::flatbuf::CompressionType::LZ4_FRAME;
1719 case CompressionType::ZSTD:
18- throw std::invalid_argument ( " Compression using zstd is not supported yet. " ) ;
20+ return org::apache::arrow::flatbuf::CompressionType::ZSTD ;
1921 default :
2022 throw std::invalid_argument (" Unsupported compression type." );
2123 }
@@ -28,7 +30,7 @@ namespace sparrow_ipc
2830 case org::apache::arrow::flatbuf::CompressionType::LZ4_FRAME:
2931 return CompressionType::LZ4_FRAME;
3032 case org::apache::arrow::flatbuf::CompressionType::ZSTD:
31- throw std::invalid_argument ( " Compression using zstd is not supported yet. " ) ;
33+ return CompressionType::ZSTD ;
3234 default :
3335 throw std::invalid_argument (" Unsupported compression type." );
3436 }
@@ -37,6 +39,9 @@ namespace sparrow_ipc
3739
3840 namespace
3941 {
42+ using compress_func = std::function<std::vector<uint8_t >(std::span<const uint8_t >)>;
43+ using decompress_func = std::function<std::vector<uint8_t >(std::span<const uint8_t >, int64_t )>;
44+
4045 std::vector<std::uint8_t > lz4_compress (std::span<const std::uint8_t > data)
4146 {
4247 const std::int64_t uncompressed_size = data.size ();
@@ -58,7 +63,7 @@ namespace sparrow_ipc
5863 LZ4F_createDecompressionContext (&dctx, LZ4F_VERSION);
5964 size_t compressed_size_in_out = data.size ();
6065 size_t decompressed_size_in_out = decompressed_size;
61- size_t result = LZ4F_decompress (dctx, decompressed_data.data (), &decompressed_size_in_out, data.data (), &compressed_size_in_out, nullptr );
66+ const size_t result = LZ4F_decompress (dctx, decompressed_data.data (), &decompressed_size_in_out, data.data (), &compressed_size_in_out, nullptr );
6267 if (LZ4F_isError (result) || (decompressed_size_in_out != (size_t )decompressed_size))
6368 {
6469 throw std::runtime_error (" Failed to decompress data with LZ4 frame format" );
@@ -67,6 +72,31 @@ namespace sparrow_ipc
6772 return decompressed_data;
6873 }
6974
75+ std::vector<std::uint8_t > zstd_compress (std::span<const std::uint8_t > data)
76+ {
77+ const std::int64_t uncompressed_size = data.size ();
78+ const size_t max_compressed_size = ZSTD_compressBound (uncompressed_size);
79+ std::vector<std::uint8_t > compressed_data (max_compressed_size);
80+ const size_t compressed_size = ZSTD_compress (compressed_data.data (), max_compressed_size, data.data (), uncompressed_size, 1 );
81+ if (ZSTD_isError (compressed_size))
82+ {
83+ throw std::runtime_error (" Failed to compress data with ZSTD" );
84+ }
85+ compressed_data.resize (compressed_size);
86+ return compressed_data;
87+ }
88+
89+ std::vector<std::uint8_t > zstd_decompress (std::span<const std::uint8_t > data, const std::int64_t decompressed_size)
90+ {
91+ std::vector<std::uint8_t > decompressed_data (decompressed_size);
92+ const size_t result = ZSTD_decompress (decompressed_data.data (), decompressed_size, data.data (), data.size ());
93+ if (ZSTD_isError (result) || (result != (size_t )decompressed_size))
94+ {
95+ throw std::runtime_error (" Failed to decompress data with ZSTD" );
96+ }
97+ return decompressed_data;
98+ }
99+
70100 // TODO These functions could be moved to serialize_utils and deserialize_utils if preferred
71101 // as they are handling the header size
72102 std::vector<std::uint8_t > uncompressed_data_with_header (std::span<const std::uint8_t > data)
@@ -79,10 +109,10 @@ namespace sparrow_ipc
79109 return result;
80110 }
81111
82- std::vector<std::uint8_t > lz4_compress_with_header (std::span<const std::uint8_t > data)
112+ std::vector<std::uint8_t > compress_with_header (std::span<const std::uint8_t > data, compress_func comp_func )
83113 {
84114 const std::int64_t original_size = data.size ();
85- auto compressed_body = lz4_compress (data);
115+ auto compressed_body = comp_func (data);
86116
87117 if (compressed_body.size () >= static_cast <size_t >(original_size))
88118 {
@@ -96,7 +126,7 @@ namespace sparrow_ipc
96126 return result;
97127 }
98128
99- std::variant<std::vector<std::uint8_t >, std::span<const std::uint8_t >> lz4_decompress_with_header (std::span<const std::uint8_t > data)
129+ std::variant<std::vector<std::uint8_t >, std::span<const std::uint8_t >> decompress_with_header (std::span<const std::uint8_t > data, decompress_func decomp_func )
100130 {
101131 if (data.size () < details::CompressionHeaderSize)
102132 {
@@ -110,7 +140,7 @@ namespace sparrow_ipc
110140 return compressed_data;
111141 }
112142
113- return lz4_decompress (compressed_data, decompressed_size);
143+ return decomp_func (compressed_data, decompressed_size);
114144 }
115145
116146 std::span<const uint8_t > get_body_from_uncompressed_data (std::span<const uint8_t > data)
@@ -129,11 +159,11 @@ namespace sparrow_ipc
129159 {
130160 case CompressionType::LZ4_FRAME:
131161 {
132- return lz4_compress_with_header (data);
162+ return compress_with_header (data, lz4_compress );
133163 }
134164 case CompressionType::ZSTD:
135165 {
136- throw std::invalid_argument ( " Compression using zstd is not supported yet. " );
166+ return compress_with_header (data, zstd_compress );
137167 }
138168 default :
139169 return uncompressed_data_with_header (data);
@@ -151,11 +181,11 @@ namespace sparrow_ipc
151181 {
152182 case CompressionType::LZ4_FRAME:
153183 {
154- return lz4_decompress_with_header (data);
184+ return decompress_with_header (data, lz4_decompress );
155185 }
156186 case CompressionType::ZSTD:
157187 {
158- throw std::invalid_argument ( " Decompression using zstd is not supported yet. " );
188+ return decompress_with_header (data, zstd_decompress );
159189 }
160190 default :
161191 {
0 commit comments