diff --git a/Framework/Core/include/Framework/ASoA.h b/Framework/Core/include/Framework/ASoA.h index 10c1fc4ac3ceb..13560bd22c054 100644 --- a/Framework/Core/include/Framework/ASoA.h +++ b/Framework/Core/include/Framework/ASoA.h @@ -34,7 +34,6 @@ #include #include #include // IWYU pragma: export -#include namespace o2::framework { @@ -53,6 +52,12 @@ void dereferenceWithWrongType(const char* getter, const char* target); void missingFilterDeclaration(int hash, int ai); void notBoundTable(const char* tableName); void* extractCCDBPayload(char* payload, size_t size, TClass const* cl, const char* what); + +template +auto createFieldsFromColumns(framework::pack) +{ + return std::vector>{C::asArrowField()...}; +} } // namespace o2::soa namespace o2::soa @@ -248,6 +253,11 @@ struct TableMetadata { return -1; } } + + static std::shared_ptr getSchema() + { + return std::make_shared([](framework::pack&& p) { return o2::soa::createFieldsFromColumns(p); }(persistent_columns_t{})); + } }; template @@ -406,12 +416,6 @@ struct Binding { } }; -template -auto createFieldsFromColumns(framework::pack) -{ - return std::vector>{C::asArrowField()...}; -} - using SelectionVector = std::vector; template @@ -686,7 +690,7 @@ struct Column { static auto asArrowField() { - return std::make_shared(inherited_t::mLabel, framework::expressions::concreteArrowType(framework::expressions::selectArrowType())); + return std::make_shared(inherited_t::mLabel, soa::asArrowDataType()); } /// FIXME: rather than keeping this public we should have a protected diff --git a/Framework/Core/include/Framework/AnalysisHelpers.h b/Framework/Core/include/Framework/AnalysisHelpers.h index fa82151c6e756..660149b2154e1 100644 --- a/Framework/Core/include/Framework/AnalysisHelpers.h +++ b/Framework/Core/include/Framework/AnalysisHelpers.h @@ -29,7 +29,7 @@ namespace o2::framework { std::string serializeProjectors(std::vector& projectors); -std::string serializeSchema(std::shared_ptr& schema); +std::string serializeSchema(std::shared_ptr schema); } // namespace o2::framework namespace o2::soa @@ -44,6 +44,16 @@ constexpr auto tableRef2ConfigParamSpec() {"\"\""}}; } +template +constexpr auto tableRef2Schema() +{ + return o2::framework::ConfigParamSpec{ + std::string{"input-schema:"} + o2::aod::label(), + framework::VariantType::String, + framework::serializeSchema(o2::aod::MetadataTrait>::metadata::getSchema()), + {"\"\""}}; +} + namespace { template @@ -56,6 +66,16 @@ inline constexpr auto getSources() }.template operator()(); } +template +inline constexpr auto getSourceSchemas() +{ + return [] refs>() { + return [](std::index_sequence) { + return std::vector{soa::tableRef2Schema()...}; + }(std::make_index_sequence()); + }.template operator()(); +} + template inline constexpr auto getCCDBUrls() { @@ -73,11 +93,19 @@ template constexpr auto getInputMetadata() -> std::vector { std::vector inputMetadata; + auto inputSources = getSources(); std::sort(inputSources.begin(), inputSources.end(), [](framework::ConfigParamSpec const& a, framework::ConfigParamSpec const& b) { return a.name < b.name; }); auto last = std::unique(inputSources.begin(), inputSources.end(), [](framework::ConfigParamSpec const& a, framework::ConfigParamSpec const& b) { return a.name == b.name; }); inputSources.erase(last, inputSources.end()); inputMetadata.insert(inputMetadata.end(), inputSources.begin(), inputSources.end()); + + auto inputSchemas = getSourceSchemas(); + std::sort(inputSchemas.begin(), inputSchemas.end(), [](framework::ConfigParamSpec const& a, framework::ConfigParamSpec const& b) { return a.name < b.name; }); + last = std::unique(inputSchemas.begin(), inputSchemas.end(), [](framework::ConfigParamSpec const& a, framework::ConfigParamSpec const& b) { return a.name == b.name; }); + inputSchemas.erase(last, inputSchemas.end()); + inputMetadata.insert(inputMetadata.end(), inputSchemas.begin(), inputSchemas.end()); + return inputMetadata; } @@ -115,11 +143,8 @@ constexpr auto getExpressionMetadata() -> std::vector(o2::soa::createFieldsFromColumns(expression_pack_t{})); - auto json = framework::serializeProjectors(projectors); - return {framework::ConfigParamSpec{"projectors", framework::VariantType::String, json, {"\"\""}}, - framework::ConfigParamSpec{"schema", framework::VariantType::String, framework::serializeSchema(schema), {"\"\""}}}; + return {framework::ConfigParamSpec{"projectors", framework::VariantType::String, json, {"\"\""}}}; } template @@ -141,6 +166,9 @@ constexpr auto tableRef2InputSpec() metadata.insert(metadata.end(), ccdbMetadata.begin(), ccdbMetadata.end()); auto p = getExpressionMetadata>::metadata>(); metadata.insert(metadata.end(), p.begin(), p.end()); + if constexpr (!soa::with_ccdb_urls>::metadata>) { + metadata.emplace_back(framework::ConfigParamSpec{"schema", framework::VariantType::String, framework::serializeSchema(o2::aod::MetadataTrait>::metadata::getSchema()), {"\"\""}}); + } return framework::InputSpec{ o2::aod::label(), diff --git a/Framework/Core/include/Framework/ArrowTypes.h b/Framework/Core/include/Framework/ArrowTypes.h index 6fd70113fede7..2673472a81152 100644 --- a/Framework/Core/include/Framework/ArrowTypes.h +++ b/Framework/Core/include/Framework/ArrowTypes.h @@ -11,6 +11,7 @@ #ifndef O2_FRAMEWORK_ARROWTYPES_H #define O2_FRAMEWORK_ARROWTYPES_H +#include "Framework/Traits.h" #include "arrow/type_fwd.h" #include @@ -117,5 +118,54 @@ template using arrow_array_for_t = typename arrow_array_for::type; template using value_for_t = typename arrow_array_for::value_type; + +template +using array_element_t = std::decay_t()[0])>; + +template +std::shared_ptr asArrowDataType(int list_size = 1) +{ + auto typeGenerator = [](std::shared_ptr const& type, int list_size) -> std::shared_ptr { + switch (list_size) { + case -1: + return arrow::list(type); + case 1: + return std::move(type); + default: + return arrow::fixed_size_list(type, list_size); + } + }; + + if constexpr (std::is_arithmetic_v) { + if constexpr (std::same_as) { + return typeGenerator(arrow::boolean(), list_size); + } else if constexpr (std::same_as) { + return typeGenerator(arrow::uint8(), list_size); + } else if constexpr (std::same_as) { + return typeGenerator(arrow::uint16(), list_size); + } else if constexpr (std::same_as) { + return typeGenerator(arrow::uint32(), list_size); + } else if constexpr (std::same_as) { + return typeGenerator(arrow::uint64(), list_size); + } else if constexpr (std::same_as) { + return typeGenerator(arrow::int8(), list_size); + } else if constexpr (std::same_as) { + return typeGenerator(arrow::int16(), list_size); + } else if constexpr (std::same_as) { + return typeGenerator(arrow::int32(), list_size); + } else if constexpr (std::same_as) { + return typeGenerator(arrow::int64(), list_size); + } else if constexpr (std::same_as) { + return typeGenerator(arrow::float32(), list_size); + } else if constexpr (std::same_as) { + return typeGenerator(arrow::float64(), list_size); + } + } else if constexpr (std::is_bounded_array_v) { + return asArrowDataType>(std::extent_v); + } else if constexpr (o2::framework::is_specialization_v) { + return asArrowDataType(-1); + } + return nullptr; +} } // namespace o2::soa #endif // O2_FRAMEWORK_ARROWTYPES_H diff --git a/Framework/Core/src/AODReaderHelpers.cxx b/Framework/Core/src/AODReaderHelpers.cxx index 09ec16a93b087..cf019ee218f73 100644 --- a/Framework/Core/src/AODReaderHelpers.cxx +++ b/Framework/Core/src/AODReaderHelpers.cxx @@ -143,13 +143,14 @@ struct Maker { std::vector labels; std::vector> expressions; std::shared_ptr projector = nullptr; - std::shared_ptr schema; + std::shared_ptr schema = nullptr; + std::shared_ptr inputSchema = nullptr; header::DataOrigin origin; header::DataDescription description; header::DataHeader::SubSpecificationType version; - std::shared_ptr make(ProcessingContext& pc) + std::shared_ptr make(ProcessingContext& pc) const { std::vector> originals; for (auto const& label : labels) { @@ -159,15 +160,6 @@ struct Maker { if (fullTable->num_rows() == 0) { return arrow::Table::MakeEmpty(schema).ValueOrDie(); } - if (projector == nullptr) { - auto s = gandiva::Projector::Make( - fullTable->schema(), - expressions, - &projector); - if (!s.ok()) { - throw o2::framework::runtime_error_f("Failed to create projector: %s", s.ToString().c_str()); - } - } return spawnerHelper(fullTable, schema, binding.c_str(), schema->num_fields(), projector); } @@ -200,24 +192,21 @@ struct Spawnable { iws.clear(); iws.str(loc->defaultValue.get()); outputSchema = ArrowJSONHelpers::read(iws); + o2::framework::addLabelToSchema(outputSchema, binding.c_str()); + std::vector> schemas; for (auto& i : spec.metadata) { - if (i.name.starts_with("input:")) { - labels.emplace_back(i.name.substr(6)); + if (i.name.starts_with("input-schema:")) { + labels.emplace_back(i.name.substr(13)); + iws.clear(); + auto json = i.defaultValue.get(); + iws.str(json); + schemas.emplace_back(ArrowJSONHelpers::read(iws)); } } - std::vector> fields; - for (auto& p : projectors) { - expressions::walk(p.node.get(), - [&fields](expressions::Node* n) mutable { - if (n->self.index() == 1) { - auto& b = std::get(n->self); - if (std::find_if(fields.begin(), fields.end(), [&b](std::shared_ptr const& field) { return field->name() == b.name; }) == fields.end()) { - fields.emplace_back(std::make_shared(b.name, expressions::concreteArrowType(b.type))); - } - } - }); + for (auto& s : schemas) { + std::copy(s->fields().begin(), s->fields().end(), std::back_inserter(fields)); } inputSchema = std::make_shared(fields); @@ -233,20 +222,28 @@ struct Spawnable { } } - std::shared_ptr makeProjector() + std::shared_ptr makeProjector() const { - return expressions::createProjectorHelper(projectors.size(), projectors.data(), inputSchema, outputSchema->fields()); + std::shared_ptr p = nullptr; + auto s = gandiva::Projector::Make( + inputSchema, + expressions, + &p); + if (!s.ok()) { + throw o2::framework::runtime_error_f("Failed to create projector: %s", s.ToString().c_str()); + } + return p; } - Maker createMaker() + Maker createMaker() const { - o2::framework::addLabelToSchema(outputSchema, binding.c_str()); return { binding, labels, expressions, - nullptr, + makeProjector(), outputSchema, + inputSchema, origin, description, version}; diff --git a/Framework/Core/src/AnalysisHelpers.cxx b/Framework/Core/src/AnalysisHelpers.cxx index 4f78cc42f3f98..63923008f5a70 100644 --- a/Framework/Core/src/AnalysisHelpers.cxx +++ b/Framework/Core/src/AnalysisHelpers.cxx @@ -35,7 +35,7 @@ std::string serializeProjectors(std::vector& return osm.str(); } -std::string serializeSchema(std::shared_ptr& schema) +std::string serializeSchema(std::shared_ptr schema) { std::stringstream osm; ArrowJSONHelpers::write(osm, schema); diff --git a/Framework/Core/src/AnalysisSupportHelpers.cxx b/Framework/Core/src/AnalysisSupportHelpers.cxx index 7cfab22885671..b5c898faa515a 100644 --- a/Framework/Core/src/AnalysisSupportHelpers.cxx +++ b/Framework/Core/src/AnalysisSupportHelpers.cxx @@ -219,7 +219,6 @@ void AnalysisSupportHelpers::addMissingOutputsToAnalysisCCDBFetcher( // FIXME: good enough for now... for (auto& i : input.metadata) { if ((i.type == VariantType::String) && (i.name.find("input:") != std::string::npos)) { - auto value = i.defaultValue.get(); auto spec = DataSpecUtils::fromMetadataString(i.defaultValue.get()); auto j = std::find_if(publisher.inputs.begin(), publisher.inputs.end(), [&](auto x) { return x.binding == spec.binding; }); if (j == publisher.inputs.end()) { diff --git a/Framework/Core/src/ExpressionJSONHelpers.cxx b/Framework/Core/src/ExpressionJSONHelpers.cxx index 8d4907a721f7e..28685fecad468 100644 --- a/Framework/Core/src/ExpressionJSONHelpers.cxx +++ b/Framework/Core/src/ExpressionJSONHelpers.cxx @@ -637,6 +637,18 @@ void o2::framework::ExpressionJSONHelpers::write(std::ostream& o, std::vector arrowDataTypeFromId(atype::type type, int list_size = 1, atype::type element = atype::NA) +{ + switch (list_size) { + case -1: + return arrow::list(expressions::concreteArrowType(element)); + case 1: + return expressions::concreteArrowType(type); + default: + return arrow::fixed_size_list(expressions::concreteArrowType(element), list_size); + } +} + struct SchemaReader : public rapidjson::BaseReaderHandler, SchemaReader> { using Ch = rapidjson::UTF8<>::Ch; using SizeType = rapidjson::SizeType; @@ -658,6 +670,8 @@ struct SchemaReader : public rapidjson::BaseReaderHandler, Sch std::string name; atype::type type; + atype::type element; + int list_size = 1; SchemaReader() { @@ -706,6 +720,12 @@ struct SchemaReader : public rapidjson::BaseReaderHandler, Sch if (currentKey.compare("type") == 0) { return true; } + if (currentKey.compare("size") == 0) { + return true; + } + if (currentKey.compare("element") == 0) { + return true; + } } states.push(State::IN_ERROR); @@ -721,6 +741,9 @@ struct SchemaReader : public rapidjson::BaseReaderHandler, Sch if (states.top() == State::IN_LIST) { states.push(State::IN_FIELD); + list_size = 1; + element = atype::NA; + type = atype::NA; return true; } @@ -734,7 +757,7 @@ struct SchemaReader : public rapidjson::BaseReaderHandler, Sch if (states.top() == State::IN_FIELD) { states.pop(); // add a field - fields.emplace_back(std::make_shared(name, expressions::concreteArrowType(type))); + fields.emplace_back(std::make_shared(name, arrowDataTypeFromId(type, list_size, element))); return true; } @@ -754,6 +777,14 @@ struct SchemaReader : public rapidjson::BaseReaderHandler, Sch type = (atype::type)i; return true; } + if (currentKey.compare("element") == 0) { + element = (atype::type)i; + return true; + } + if (currentKey.compare("size") == 0) { + list_size = i; + return true; + } } states.push(State::IN_ERROR); @@ -777,6 +808,10 @@ struct SchemaReader : public rapidjson::BaseReaderHandler, Sch bool Int(int i) { debug << "Int(" << i << ")" << std::endl; + if (states.top() == State::IN_FIELD && currentKey.compare("size") == 0) { + list_size = i; + return true; + } return Uint(i); } }; @@ -791,7 +826,7 @@ std::shared_ptr o2::framework::ArrowJSONHelpers::read(std::istrea bool ok = reader.Parse(isw, sreader); if (!ok) { - throw framework::runtime_error_f("Cannot parse serialized Expression, error: %s at offset: %d", rapidjson::GetParseError_En(reader.GetParseErrorCode()), reader.GetErrorOffset()); + throw framework::runtime_error_f("Cannot parse serialized Schema, error: %s at offset: %d", rapidjson::GetParseError_En(reader.GetParseErrorCode()), reader.GetErrorOffset()); } return sreader.schema; } @@ -804,6 +839,20 @@ void writeSchema(rapidjson::Writer& w, arrow::Schema* w.StartObject(); w.Key("name"); w.String(f->name().c_str()); + auto fixedList = dynamic_cast(f->type().get()); + if (fixedList != nullptr) { + w.Key("size"); + w.Int(fixedList->list_size()); + w.Key("element"); + w.Int(fixedList->field(0)->type()->id()); + } + auto varList = dynamic_cast(f->type().get()); + if (varList != nullptr) { + w.Key("size"); + w.Int(-1); + w.Key("element"); + w.Int(varList->field(0)->type()->id()); + } w.Key("type"); w.Int(f->type()->id()); w.EndObject(); diff --git a/Framework/Core/test/test_Expressions.cxx b/Framework/Core/test/test_Expressions.cxx index 41be7d53d2276..b4a65fb0c7b48 100644 --- a/Framework/Core/test/test_Expressions.cxx +++ b/Framework/Core/test/test_Expressions.cxx @@ -454,4 +454,33 @@ TEST_CASE("TestExpressionSerialization") ism.str(osm.str()); auto newSchemap = ArrowJSONHelpers::read(ism); REQUIRE(schemap->ToString() == newSchemap->ToString()); + + osm.clear(); + osm.str(""); + ArrowJSONHelpers::write(osm, schemap1); + + ism.clear(); + ism.str(osm.str()); + auto newSchemap1 = ArrowJSONHelpers::read(ism); + REQUIRE(schemap1->ToString() == newSchemap1->ToString()); + + osm.clear(); + osm.str(""); + auto realisticSchema = std::make_shared(o2::soa::createFieldsFromColumns(o2::aod::MetadataTrait>::metadata::persistent_columns_t{})); + ArrowJSONHelpers::write(osm, realisticSchema); + + ism.clear(); + ism.str(osm.str()); + auto restoredSchema = ArrowJSONHelpers::read(ism); + REQUIRE(realisticSchema->ToString() == restoredSchema->ToString()); + + osm.clear(); + osm.str(""); + auto realisticSchema1 = std::make_shared(o2::soa::createFieldsFromColumns(o2::aod::MetadataTrait>::metadata::persistent_columns_t{})); + ArrowJSONHelpers::write(osm, realisticSchema1); + + ism.clear(); + ism.str(osm.str()); + auto restoredSchema1 = ArrowJSONHelpers::read(ism); + REQUIRE(realisticSchema1->ToString() == restoredSchema1->ToString()); }