diff --git a/conanfile.py b/conanfile.py index a6840562..17d303b6 100644 --- a/conanfile.py +++ b/conanfile.py @@ -43,7 +43,7 @@ def configure(self): self.options.rm_safe("fPIC") def requirements(self): - self.requires("sparrow/1.0.0") + self.requires("sparrow/1.2.0", options={"json_reader": True}) self.requires(f"flatbuffers/{self._flatbuffers_version}") self.requires("lz4/1.9.4") self.requires("zstd/1.5.7") diff --git a/include/sparrow_ipc/deserialize_decimal_array.hpp b/include/sparrow_ipc/deserialize_decimal_array.hpp new file mode 100644 index 00000000..9bed6b32 --- /dev/null +++ b/include/sparrow_ipc/deserialize_decimal_array.hpp @@ -0,0 +1,67 @@ +#pragma once + +#include + +#include +#include + +#include "Message_generated.h" +#include "sparrow_ipc/arrow_interface/arrow_array.hpp" +#include "sparrow_ipc/arrow_interface/arrow_schema.hpp" +#include "sparrow_ipc/deserialize_utils.hpp" + +namespace sparrow_ipc +{ + template + [[nodiscard]] sparrow::decimal_array deserialize_non_owning_decimal( + const org::apache::arrow::flatbuf::RecordBatch& record_batch, + std::span body, + std::string_view name, + const std::optional>& metadata, + size_t& buffer_index, + int32_t scale, + int32_t precision + ) + { + constexpr std::size_t sizeof_decimal = sizeof(typename T::integer_type); + std::string format_str = "d:" + std::to_string(precision) + "," + std::to_string(scale); + if constexpr (sizeof_decimal != 16) // We don't need to specify the size for 128-bit + // decimals + { + format_str += "," + std::to_string(sizeof_decimal * 8); + } + + ArrowSchema schema = make_non_owning_arrow_schema( + format_str, + name.data(), + metadata, + std::nullopt, + 0, + nullptr, + nullptr + ); + const auto [bitmap_ptr, null_count] = utils::get_bitmap_pointer_and_null_count( + record_batch, + body + ); + + const auto buffer_metadata = record_batch.buffers()->Get(buffer_index++); + if ((body.size() < (buffer_metadata->offset() + buffer_metadata->length()))) + { + throw std::runtime_error("Data buffer exceeds body size"); + } + auto buffer_ptr = const_cast(body.data() + buffer_metadata->offset()); + std::vector buffers = {bitmap_ptr, buffer_ptr}; + ArrowArray array = make_non_owning_arrow_array( + record_batch.length(), + null_count, + 0, + std::move(buffers), + 0, + nullptr, + nullptr + ); + sparrow::arrow_proxy ap{std::move(array), std::move(schema)}; + return sparrow::decimal_array(std::move(ap)); + } +} \ No newline at end of file diff --git a/include/sparrow_ipc/utils.hpp b/include/sparrow_ipc/utils.hpp index 63f1fb89..75f96b70 100644 --- a/include/sparrow_ipc/utils.hpp +++ b/include/sparrow_ipc/utils.hpp @@ -2,7 +2,10 @@ #include #include +#include #include +#include +#include #include @@ -13,6 +16,39 @@ namespace sparrow_ipc::utils // Aligns a value to the next multiple of 8, as required by the Arrow IPC format for message bodies SPARROW_IPC_API size_t align_to_8(const size_t n); + /** + * @brief Extracts words after ':' separated by ',' from a string. + * + * This function finds the position of ':' in the input string and then + * splits the remaining part by ',' to extract individual words. + * + * @param str Input string to parse (e.g., "prefix:word1,word2,word3") + * @return std::vector Vector of string views containing the extracted words + * Returns an empty vector if ':' is not found or if there are no words after it + * + * @example + * extract_words_after_colon("d:128,10") returns {"128", "10"} + * extract_words_after_colon("w:256") returns {"256"} + * extract_words_after_colon("no_colon") returns {} + */ + SPARROW_IPC_API std::vector extract_words_after_colon(std::string_view str); + + /** + * @brief Parse a string_view to int32_t using std::from_chars. + * + * This function converts a string view to a 32-bit integer using std::from_chars + * for efficient parsing. + * + * @param str The string view to parse + * @return std::optional The parsed integer value, or std::nullopt if parsing fails + * + * @example + * parse_to_int32("123") returns std::optional(123) + * parse_to_int32("abc") returns std::nullopt + * parse_to_int32("") returns std::nullopt + */ + SPARROW_IPC_API std::optional parse_to_int32(std::string_view str); + /** * @brief Checks if all record batches in a collection have consistent structure. * diff --git a/src/deserialize.cpp b/src/deserialize.cpp index 92063de1..43a1e1e7 100644 --- a/src/deserialize.cpp +++ b/src/deserialize.cpp @@ -2,6 +2,7 @@ #include +#include "sparrow_ipc/deserialize_decimal_array.hpp" #include "sparrow_ipc/deserialize_fixedsizebinary_array.hpp" #include "sparrow_ipc/deserialize_primitive_array.hpp" #include "sparrow_ipc/deserialize_variable_size_binary_array.hpp" @@ -205,6 +206,69 @@ namespace sparrow_ipc ) ); break; + case org::apache::arrow::flatbuf::Type::Decimal: + { + const auto decimal_field = field->type_as_Decimal(); + const auto scale = decimal_field->scale(); + const auto precision = decimal_field->precision(); + if (decimal_field->bitWidth() == 32) + { + arrays.emplace_back( + deserialize_non_owning_decimal>( + record_batch, + encapsulated_message.body(), + name, + metadata, + buffer_index, + scale, + precision + ) + ); + } + else if (decimal_field->bitWidth() == 64) + { + arrays.emplace_back( + deserialize_non_owning_decimal>( + record_batch, + encapsulated_message.body(), + name, + metadata, + buffer_index, + scale, + precision + ) + ); + } + else if (decimal_field->bitWidth() == 128) + { + arrays.emplace_back( + deserialize_non_owning_decimal>( + record_batch, + encapsulated_message.body(), + name, + metadata, + buffer_index, + scale, + precision + ) + ); + } + else if (decimal_field->bitWidth() == 256) + { + arrays.emplace_back( + deserialize_non_owning_decimal>( + record_batch, + encapsulated_message.body(), + name, + metadata, + buffer_index, + scale, + precision + ) + ); + } + break; + } default: throw std::runtime_error("Unsupported type."); } diff --git a/src/utils.cpp b/src/utils.cpp index 73db1369..f6ce8d71 100644 --- a/src/utils.cpp +++ b/src/utils.cpp @@ -1,36 +1,100 @@ #include "sparrow_ipc/utils.hpp" #include +#include +#include +#include namespace sparrow_ipc::utils { - std::optional parse_format(std::string_view format_str, std::string_view sep) + namespace { - // Find the position of the delimiter - const auto sep_pos = format_str.find(sep); - if (sep_pos == std::string_view::npos) + // Parse the format string + // The format string is expected to be "w:size", "+w:size", "d:precision,scale", etc + std::optional parse_format(std::string_view format_str, std::string_view sep) { - return std::nullopt; - } + // Find the position of the delimiter + const auto sep_pos = format_str.find(sep); + if (sep_pos == std::string_view::npos) + { + return std::nullopt; + } - std::string_view substr_str(format_str.data() + sep_pos + 1, format_str.size() - sep_pos - 1); + std::string_view substr_str(format_str.data() + sep_pos + 1, format_str.size() - sep_pos - 1); - int32_t substr_size = 0; - const auto [ptr, ec] = std::from_chars( - substr_str.data(), - substr_str.data() + substr_str.size(), - substr_size - ); + int32_t substr_size = 0; + const auto [ptr, ec] = std::from_chars( + substr_str.data(), + substr_str.data() + substr_str.size(), + substr_size + ); - if (ec != std::errc() || ptr != substr_str.data() + substr_str.size()) - { - return std::nullopt; + if (ec != std::errc() || ptr != substr_str.data() + substr_str.size()) + { + return std::nullopt; + } + return substr_size; } - return substr_size; } - size_t align_to_8(const size_t n) + namespace utils { - return (n + 7) & -8; + int64_t align_to_8(const int64_t n) + { + return (n + 7) & -8; + } + + std::vector extract_words_after_colon(std::string_view str) + { + std::vector result; + + // Find the position of ':' + const auto colon_pos = str.find(':'); + if (colon_pos == std::string_view::npos) + { + return result; // Return empty vector if ':' not found + } + + // Get the substring after ':' + std::string_view remaining = str.substr(colon_pos + 1); + + // If nothing after ':', return empty vector + if (remaining.empty()) + { + return result; + } + + // Split by ',' + size_t start = 0; + size_t comma_pos = remaining.find(','); + + while (comma_pos != std::string_view::npos) + { + result.push_back(remaining.substr(start, comma_pos - start)); + start = comma_pos + 1; + comma_pos = remaining.find(',', start); + } + + // Add the last word (or the only word if no comma was found) + result.push_back(remaining.substr(start)); + + return result; + } + + std::optional parse_to_int32(std::string_view str) + { + int32_t value = 0; + const auto [ptr, ec] = std::from_chars( + str.data(), + str.data() + str.size(), + value + ); + + if (ec != std::errc() || ptr != str.data() + str.size()) + { + return std::nullopt; + } + return value; + } } } diff --git a/tests/test_de_serialization_with_files.cpp b/tests/test_de_serialization_with_files.cpp index 8cb74e8f..ad0149b4 100644 --- a/tests/test_de_serialization_with_files.cpp +++ b/tests/test_de_serialization_with_files.cpp @@ -43,6 +43,10 @@ const std::vector files_paths_to_test_with_lz4_compressio const std::vector files_paths_to_test_with_zstd_compression = { tests_resources_files_path_with_compression / "generated_zstd", tests_resources_files_path_with_compression/ "generated_uncompressible_zstd", + tests_resources_files_path / "generated_decimal32", + tests_resources_files_path / "generated_decimal64", + tests_resources_files_path / "generated_decimal", + tests_resources_files_path / "generated_decimal256", }; size_t get_number_of_batches(const std::filesystem::path& json_path) diff --git a/tests/test_utils.cpp b/tests/test_utils.cpp index 0619d68a..caf9b091 100644 --- a/tests/test_utils.cpp +++ b/tests/test_utils.cpp @@ -15,4 +15,134 @@ namespace sparrow_ipc CHECK_EQ(utils::align_to_8(15), 16); CHECK_EQ(utils::align_to_8(16), 16); } + + TEST_CASE("extract_words_after_colon") + { + SUBCASE("Basic case with multiple words") + { + auto result = utils::extract_words_after_colon("d:128,10"); + REQUIRE_EQ(result.size(), 2); + CHECK_EQ(result[0], "128"); + CHECK_EQ(result[1], "10"); + } + + SUBCASE("Single word after colon") + { + auto result = utils::extract_words_after_colon("w:256"); + REQUIRE_EQ(result.size(), 1); + CHECK_EQ(result[0], "256"); + } + + SUBCASE("Three words") + { + auto result = utils::extract_words_after_colon("d:10,5,128"); + REQUIRE_EQ(result.size(), 3); + CHECK_EQ(result[0], "10"); + CHECK_EQ(result[1], "5"); + CHECK_EQ(result[2], "128"); + } + + SUBCASE("No colon in string") + { + auto result = utils::extract_words_after_colon("no_colon"); + CHECK_EQ(result.size(), 0); + } + + SUBCASE("Colon at end") + { + auto result = utils::extract_words_after_colon("prefix:"); + CHECK_EQ(result.size(), 0); + } + + SUBCASE("Empty string") + { + auto result = utils::extract_words_after_colon(""); + CHECK_EQ(result.size(), 0); + } + + SUBCASE("Only colon and comma") + { + auto result = utils::extract_words_after_colon(":,"); + REQUIRE_EQ(result.size(), 2); + CHECK_EQ(result[0], ""); + CHECK_EQ(result[1], ""); + } + + SUBCASE("Complex prefix") + { + auto result = utils::extract_words_after_colon("prefix:word1,word2,word3"); + REQUIRE_EQ(result.size(), 3); + CHECK_EQ(result[0], "word1"); + CHECK_EQ(result[1], "word2"); + CHECK_EQ(result[2], "word3"); + } + } + + TEST_CASE("parse_to_int32") + { + SUBCASE("Valid positive integer") + { + auto result = utils::parse_to_int32("123"); + REQUIRE(result.has_value()); + CHECK_EQ(result.value(), 123); + } + + SUBCASE("Valid negative integer") + { + auto result = utils::parse_to_int32("-456"); + REQUIRE(result.has_value()); + CHECK_EQ(result.value(), -456); + } + + SUBCASE("Zero") + { + auto result = utils::parse_to_int32("0"); + REQUIRE(result.has_value()); + CHECK_EQ(result.value(), 0); + } + + SUBCASE("Large valid number") + { + auto result = utils::parse_to_int32("2147483647"); // INT32_MAX + REQUIRE(result.has_value()); + CHECK_EQ(result.value(), 2147483647); + } + + SUBCASE("Invalid - not a number") + { + auto result = utils::parse_to_int32("abc"); + CHECK_FALSE(result.has_value()); + } + + SUBCASE("Invalid - empty string") + { + auto result = utils::parse_to_int32(""); + CHECK_FALSE(result.has_value()); + } + + SUBCASE("Invalid - partial number with text") + { + auto result = utils::parse_to_int32("123abc"); + CHECK_FALSE(result.has_value()); + } + + SUBCASE("Invalid - text with number") + { + auto result = utils::parse_to_int32("abc123"); + CHECK_FALSE(result.has_value()); + } + + SUBCASE("Invalid - just a sign") + { + auto result = utils::parse_to_int32("-"); + CHECK_FALSE(result.has_value()); + } + + SUBCASE("Valid with leading zeros") + { + auto result = utils::parse_to_int32("00123"); + REQUIRE(result.has_value()); + CHECK_EQ(result.value(), 123); + } + } }