diff --git a/Framework/AnalysisSupport/src/TTreePlugin.cxx b/Framework/AnalysisSupport/src/TTreePlugin.cxx index e84a053d58d60..abc08526815cc 100644 --- a/Framework/AnalysisSupport/src/TTreePlugin.cxx +++ b/Framework/AnalysisSupport/src/TTreePlugin.cxx @@ -14,6 +14,8 @@ #include "Framework/Signpost.h" #include "Framework/Endian.h" #include +#include +#include #include #include #include @@ -23,6 +25,8 @@ #include #include #include +#include +#include O2_DECLARE_DYNAMIC_LOG(root_arrow_fs); @@ -91,6 +95,7 @@ arrow::Result SingleTreeFileSystem::GetFileInfo(std::string return result; } +// A fragment which holds a tree class TTreeFileFragment : public arrow::dataset::FileFragment { public: @@ -101,6 +106,13 @@ class TTreeFileFragment : public arrow::dataset::FileFragment : FileFragment(std::move(source), std::move(format), std::move(partition_expression), std::move(physical_schema)) { } + + std::unique_ptr& GetTree() + { + auto topFs = std::dynamic_pointer_cast(source().filesystem()); + auto treeFs = std::dynamic_pointer_cast(topFs->GetSubFilesystem(source())); + return treeFs->GetTree(source()); + } }; class TTreeFileFormat : public arrow::dataset::FileFormat @@ -158,9 +170,9 @@ class TTreeFileFormat : public arrow::dataset::FileFormat class TTreeOutputStream : public arrow::io::OutputStream { public: - // Using a pointer means that the tree itself is owned by another + // Using a pointer means that the tree itself is owned by another // class - TTreeOutputStream(TTree *, std::string branchPrefix); + TTreeOutputStream(TTree*, std::string branchPrefix); arrow::Status Close() override; @@ -245,33 +257,70 @@ struct TTreeObjectReadingImplementation : public RootArrowFactoryPlugin { } }; +struct BranchFieldMapping { + int mainBranchIdx; + int vlaIdx; + int datasetFieldIdx; +}; + arrow::Result TTreeFileFormat::ScanBatchesAsync( const std::shared_ptr& options, const std::shared_ptr& fragment) const { - // Get the fragment as a TTreeFragment. This might be PART of a TTree. - auto treeFragment = std::dynamic_pointer_cast(fragment); // This is the schema we want to read auto dataset_schema = options->dataset_schema; - auto generator = [pool = options->pool, treeFragment, dataset_schema, &totalCompressedSize = mTotCompressedSize, + auto generator = [pool = options->pool, fragment, dataset_schema, &totalCompressedSize = mTotCompressedSize, &totalUncompressedSize = mTotUncompressedSize]() -> arrow::Future> { - auto schema = treeFragment->format()->Inspect(treeFragment->source()); - std::vector> columns; std::vector> fields = dataset_schema->fields(); - auto physical_schema = *treeFragment->ReadPhysicalSchema(); + auto physical_schema = *fragment->ReadPhysicalSchema(); + + auto fs = std::dynamic_pointer_cast(fragment->source().filesystem()); + // Actually get the TTree from the ROOT file. + auto treeFs = std::dynamic_pointer_cast(fs->GetSubFilesystem(fragment->source())); + + if (dataset_schema->num_fields() > physical_schema->num_fields()) { + throw runtime_error_f("One TTree must have all the fields requested in a table"); + } + + // Register physical fields into the cache + std::vector mappings; + + for (int fi = 0; fi < dataset_schema->num_fields(); ++fi) { + auto dataset_field = dataset_schema->field(fi); + int physicalFieldIdx = physical_schema->GetFieldIndex(dataset_field->name()); + + if (physicalFieldIdx < 0) { + throw runtime_error_f("Cannot find physical field associated to %s", dataset_field->name().c_str()); + } + if (physicalFieldIdx > 1 && physical_schema->field(physicalFieldIdx - 1)->name().ends_with("_size")) { + mappings.push_back({physicalFieldIdx, physicalFieldIdx - 1, fi}); + } else { + mappings.push_back({physicalFieldIdx, -1, fi}); + } + } + + auto& tree = treeFs->GetTree(fragment->source()); + tree->SetCacheSize(25000000); + auto branches = tree->GetListOfBranches(); + for (auto& mapping : mappings) { + tree->AddBranchToCache((TBranch*)branches->At(mapping.mainBranchIdx), false); + if (mapping.vlaIdx != -1) { + tree->AddBranchToCache((TBranch*)branches->At(mapping.vlaIdx), false); + } + } + tree->StopCacheLearningPhase(); static TBufferFile buffer{TBuffer::EMode::kWrite, 4 * 1024 * 1024}; - auto containerFS = std::dynamic_pointer_cast(treeFragment->source().filesystem()); - auto fs = std::dynamic_pointer_cast(containerFS->GetSubFilesystem(treeFragment->source())); int64_t rows = -1; - auto& tree = fs->GetTree(treeFragment->source()); - for (auto& field : fields) { + for (size_t mi = 0; mi < mappings.size(); ++mi) { + BranchFieldMapping mapping = mappings[mi]; // The field actually on disk - auto physicalField = physical_schema->GetFieldByName(field->name()); - TBranch* branch = tree->GetBranch(physicalField->name().c_str()); + auto datasetField = dataset_schema->field(mapping.datasetFieldIdx); + auto physicalField = physical_schema->field(mapping.mainBranchIdx); + auto* branch = (TBranch*)branches->At(mapping.mainBranchIdx); assert(branch); buffer.Reset(); auto totalEntries = branch->GetEntries(); @@ -284,12 +333,12 @@ arrow::Result TTreeFileFormat::ScanBatchesAsync( arrow::Status status; int readEntries = 0; std::shared_ptr array; - auto listType = std::dynamic_pointer_cast(physicalField->type()); - if (physicalField->type() == arrow::boolean() || - (listType && physicalField->type()->field(0)->type() == arrow::boolean())) { + auto listType = std::dynamic_pointer_cast(datasetField->type()); + if (datasetField->type() == arrow::boolean() || + (listType && datasetField->type()->field(0)->type() == arrow::boolean())) { if (listType) { std::unique_ptr builder = nullptr; - auto status = arrow::MakeBuilder(pool, physicalField->type()->field(0)->type(), &builder); + auto status = arrow::MakeBuilder(pool, datasetField->type()->field(0)->type(), &builder); if (!status.ok()) { throw runtime_error("Cannot create value builder"); } @@ -316,7 +365,7 @@ arrow::Result TTreeFileFormat::ScanBatchesAsync( } } else if (listType == nullptr) { std::unique_ptr builder = nullptr; - auto status = arrow::MakeBuilder(pool, physicalField->type(), &builder); + auto status = arrow::MakeBuilder(pool, datasetField->type(), &builder); if (!status.ok()) { throw runtime_error("Cannot create builder"); } @@ -340,8 +389,6 @@ arrow::Result TTreeFileFormat::ScanBatchesAsync( } } } else { - // other types: use serialized read to build arrays directly. - auto typeSize = physicalField->type()->byte_width(); // This is needed for branches which have not been persisted. auto bytes = branch->GetTotBytes(); auto branchSize = bytes ? bytes : 1000000; @@ -349,7 +396,7 @@ arrow::Result TTreeFileFormat::ScanBatchesAsync( if (!result.ok()) { throw runtime_error("Cannot allocate values buffer"); } - std::shared_ptr arrowValuesBuffer = std::move(result).ValueUnsafe(); + std::shared_ptr arrowValuesBuffer = result.MoveValueUnsafe(); auto ptr = arrowValuesBuffer->mutable_data(); if (ptr == nullptr) { throw runtime_error("Invalid buffer"); @@ -363,23 +410,14 @@ arrow::Result TTreeFileFormat::ScanBatchesAsync( std::span offsets; int size = 0; uint32_t totalSize = 0; - TBranch* mSizeBranch = nullptr; - int64_t listSize = 1; - if (auto fixedSizeList = std::dynamic_pointer_cast(physicalField->type())) { - listSize = fixedSizeList->list_size(); - typeSize = fixedSizeList->field(0)->type()->byte_width(); - } else if (auto vlaListType = std::dynamic_pointer_cast(physicalField->type())) { - listSize = -1; - typeSize = vlaListType->field(0)->type()->byte_width(); - } - if (listSize == -1) { - mSizeBranch = branch->GetTree()->GetBranch((std::string{branch->GetName()} + "_size").c_str()); + if (mapping.vlaIdx != -1) { + auto* mSizeBranch = (TBranch*)branches->At(mapping.vlaIdx); offsetBuffer = std::make_unique(TBuffer::EMode::kWrite, 4 * 1024 * 1024); result = arrow::AllocateResizableBuffer((totalEntries + 1) * (int64_t)sizeof(int), pool); if (!result.ok()) { throw runtime_error("Cannot allocate offset buffer"); } - arrowOffsetBuffer = std::move(result).ValueUnsafe(); + arrowOffsetBuffer = result.MoveValueUnsafe(); unsigned char* ptrOffset = arrowOffsetBuffer->mutable_data(); auto* tPtrOffset = reinterpret_cast(ptrOffset); offsets = std::span{tPtrOffset, tPtrOffset + totalEntries + 1}; @@ -398,9 +436,19 @@ arrow::Result TTreeFileFormat::ScanBatchesAsync( readEntries = 0; } + int typeSize = physicalField->type()->byte_width(); + int64_t listSize = 1; + if (auto fixedSizeList = std::dynamic_pointer_cast(datasetField->type())) { + listSize = fixedSizeList->list_size(); + typeSize = physicalField->type()->field(0)->type()->byte_width(); + } else if (mapping.vlaIdx != -1) { + typeSize = physicalField->type()->field(0)->type()->byte_width(); + listSize = -1; + } + while (readEntries < totalEntries) { auto readLast = branch->GetBulkRead().GetEntriesSerialized(readEntries, buffer); - if (listSize == -1) { + if (mapping.vlaIdx != -1) { size = offsets[readEntries + readLast] - offsets[readEntries]; } else { size = readLast * listSize; @@ -412,18 +460,15 @@ arrow::Result TTreeFileFormat::ScanBatchesAsync( if (listSize >= 1) { totalSize = readEntries * listSize; } - std::shared_ptr varray; - switch (listSize) { - case -1: - varray = std::make_shared(physicalField->type()->field(0)->type(), totalSize, arrowValuesBuffer); - array = std::make_shared(physicalField->type(), readEntries, arrowOffsetBuffer, varray); - break; - case 1: - array = std::make_shared(physicalField->type(), readEntries, arrowValuesBuffer); - break; - default: - varray = std::make_shared(physicalField->type()->field(0)->type(), totalSize, arrowValuesBuffer); - array = std::make_shared(physicalField->type(), readEntries, varray); + if (listSize == 1) { + array = std::make_shared(datasetField->type(), readEntries, arrowValuesBuffer); + } else { + auto varray = std::make_shared(datasetField->type()->field(0)->type(), totalSize, arrowValuesBuffer); + if (mapping.vlaIdx != -1) { + array = std::make_shared(datasetField->type(), readEntries, arrowOffsetBuffer, varray); + } else { + array = std::make_shared(datasetField->type(), readEntries, varray); + } } } @@ -534,9 +579,12 @@ auto arrowTypeFromROOT(EDataType type, int size) } } +// This is a datatype for branches which implies +struct RootTransientIndexType : arrow::ExtensionType { +}; + arrow::Result> TTreeFileFormat::Inspect(const arrow::dataset::FileSource& source) const { - arrow::Schema schema{{}}; auto fs = std::dynamic_pointer_cast(source.filesystem()); // Actually get the TTree from the ROOT file. auto treeFs = std::dynamic_pointer_cast(fs->GetSubFilesystem(source)); @@ -548,51 +596,37 @@ arrow::Result> TTreeFileFormat::Inspect(const arr auto branches = tree->GetListOfBranches(); auto n = branches->GetEntries(); - std::vector branchInfos; + std::vector> fields; + + bool prevIsSize = false; for (auto i = 0; i < n; ++i) { auto branch = static_cast(branches->At(i)); - auto name = std::string{branch->GetName()}; - auto pos = name.find("_size"); - if (pos != std::string::npos) { - name.erase(pos); - branchInfos.emplace_back(BranchInfo{name, (TBranch*)nullptr, true}); + std::string name = branch->GetName(); + if (prevIsSize && fields.back()->name() != name + "_size") { + throw runtime_error_f("Unexpected layout for VLA container %s.", branch->GetName()); + } + + if (name.ends_with("_size")) { + fields.emplace_back(std::make_shared(name, arrow::int32())); + prevIsSize = true; } else { - auto lookup = std::find_if(branchInfos.begin(), branchInfos.end(), [&](BranchInfo const& bi) { - return bi.name == name; - }); - if (lookup == branchInfos.end()) { - branchInfos.emplace_back(BranchInfo{name, branch, false}); + static TClass* cls; + EDataType type; + branch->GetExpectedType(cls, type); + + if (prevIsSize) { + fields.emplace_back(std::make_shared(name, arrowTypeFromROOT(type, -1))); } else { - lookup->ptr = branch; + auto listSize = static_cast(branch->GetListOfLeaves()->At(0))->GetLenStatic(); + fields.emplace_back(std::make_shared(name, arrowTypeFromROOT(type, listSize))); } + prevIsSize = false; } } - std::vector> fields; - tree->SetCacheSize(25000000); - for (auto& bi : branchInfos) { - static TClass* cls; - EDataType type; - bi.ptr->GetExpectedType(cls, type); - auto listSize = -1; - if (!bi.mVLA) { - listSize = static_cast(bi.ptr->GetListOfLeaves()->At(0))->GetLenStatic(); - } - auto field = std::make_shared(bi.ptr->GetName(), arrowTypeFromROOT(type, listSize)); - fields.push_back(field); - - tree->AddBranchToCache(bi.ptr); - if (strncmp(bi.ptr->GetName(), "fIndexArray", strlen("fIndexArray")) == 0) { - std::string sizeBranchName = bi.ptr->GetName(); - sizeBranchName += "_size"; - auto* sizeBranch = (TBranch*)tree->GetBranch(sizeBranchName.c_str()); - if (sizeBranch) { - tree->AddBranchToCache(sizeBranch); - } - } + if (fields.back()->name().ends_with("_size")) { + throw runtime_error_f("Missing values for VLA indices %s.", fields.back()->name().c_str()); } - tree->StopCacheLearningPhase(); - return std::make_shared(fields); } @@ -601,9 +635,8 @@ arrow::Result> TTreeFileFormat::Ma arrow::dataset::FileSource source, arrow::compute::Expression partition_expression, std::shared_ptr physical_schema) { - std::shared_ptr format = std::make_shared(mTotCompressedSize, mTotUncompressedSize); - auto fragment = std::make_shared(std::move(source), std::move(format), + auto fragment = std::make_shared(std::move(source), std::dynamic_pointer_cast(shared_from_this()), std::move(partition_expression), std::move(physical_schema)); return std::dynamic_pointer_cast(fragment); diff --git a/Framework/Core/test/test_Root2ArrowTable.cxx b/Framework/Core/test/test_Root2ArrowTable.cxx index 04a8d91303f0e..ebc854d1d6dc0 100644 --- a/Framework/Core/test/test_Root2ArrowTable.cxx +++ b/Framework/Core/test/test_Root2ArrowTable.cxx @@ -384,6 +384,24 @@ bool validateSchema(std::shared_ptr schema) return true; } +bool validatePhysicalSchema(std::shared_ptr schema) +{ + REQUIRE(schema->num_fields() == 12); + REQUIRE(schema->field(0)->type()->id() == arrow::float32()->id()); + REQUIRE(schema->field(1)->type()->id() == arrow::float32()->id()); + REQUIRE(schema->field(2)->type()->id() == arrow::float32()->id()); + REQUIRE(schema->field(3)->type()->id() == arrow::float64()->id()); + REQUIRE(schema->field(4)->type()->id() == arrow::int32()->id()); + REQUIRE(schema->field(5)->type()->id() == arrow::fixed_size_list(arrow::float32(), 3)->id()); + REQUIRE(schema->field(6)->type()->id() == arrow::fixed_size_list(arrow::int32(), 2)->id()); + REQUIRE(schema->field(7)->type()->id() == arrow::boolean()->id()); + REQUIRE(schema->field(8)->type()->id() == arrow::fixed_size_list(arrow::boolean(), 2)->id()); + REQUIRE(schema->field(9)->type()->id() == arrow::int32()->id()); + REQUIRE(schema->field(10)->type()->id() == arrow::list(arrow::int32())->id()); + REQUIRE(schema->field(11)->type()->id() == arrow::int8()->id()); + return true; +} + TEST_CASE("RootTree2Dataset") { using namespace o2::framework; @@ -502,12 +520,22 @@ TEST_CASE("RootTree2Dataset") arrow::dataset::FileSource source("DF_2/tracks", fs); REQUIRE(format->IsSupported(source) == true); - auto schemaOpt = format->Inspect(source); - REQUIRE(schemaOpt.ok()); - auto schema = *schemaOpt; + auto physicalSchema = format->Inspect(source); + REQUIRE(physicalSchema.ok()); + REQUIRE(validatePhysicalSchema(*physicalSchema)); + // Create the dataset schema rather than using the physical one + std::vector> fields; + for (auto& field : (*(physicalSchema))->fields()) { + if (field->name().ends_with("_size")) { + continue; + } + fields.push_back(field); + } + std::shared_ptr schema = std::make_shared(fields); + validateSchema(schema); - auto fragment = format->MakeFragment(source, {}, schema); + auto fragment = format->MakeFragment(source, {}, *physicalSchema); REQUIRE(fragment.ok()); auto options = std::make_shared(); options->dataset_schema = schema; @@ -545,12 +573,22 @@ TEST_CASE("RootTree2Dataset") auto schemaOptWritten = format->Inspect(source); REQUIRE(schemaOptWritten.ok()); auto schemaWritten = *schemaOptWritten; - REQUIRE(validateSchema(schemaWritten)); - auto fragmentWritten = format->MakeFragment(source, {}, schema); + REQUIRE(validatePhysicalSchema(schemaWritten)); + std::vector> fields; + for (auto& field : schemaWritten->fields()) { + if (field->name().ends_with("_size")) { + continue; + } + fields.push_back(field); + } + std::shared_ptr schema = std::make_shared(fields); + REQUIRE(validateSchema(schema)); + + auto fragmentWritten = format->MakeFragment(source, {}, *physicalSchema); REQUIRE(fragmentWritten.ok()); auto optionsWritten = std::make_shared(); - options->dataset_schema = schemaWritten; + options->dataset_schema = schema; auto scannerWritten = format->ScanBatchesAsync(optionsWritten, *fragment); REQUIRE(scannerWritten.ok()); auto batchesWritten = (*scanner)();