diff --git a/Framework/Core/include/Framework/ArrowTableSlicingCache.h b/Framework/Core/include/Framework/ArrowTableSlicingCache.h index 292a67023fc5e..41d6b33e48476 100644 --- a/Framework/Core/include/Framework/ArrowTableSlicingCache.h +++ b/Framework/Core/include/Framework/ArrowTableSlicingCache.h @@ -21,8 +21,8 @@ namespace o2::framework using ListVector = std::vector>; struct SliceInfoPtr { - gsl::span values; - gsl::span counts; + gsl::span offsets; + gsl::span sizes; std::pair getSliceFor(int value) const; }; @@ -66,6 +66,8 @@ struct ArrowTableSlicingCache { Cache bindingsKeys; std::vector>> values; std::vector>> counts; + std::vector> offsets; + std::vector> sizes; Cache bindingsKeysUnsorted; std::vector> valuesUnsorted; diff --git a/Framework/Core/include/Framework/GroupSlicer.h b/Framework/Core/include/Framework/GroupSlicer.h index b8436314b057e..112bf7e147ff0 100644 --- a/Framework/Core/include/Framework/GroupSlicer.h +++ b/Framework/Core/include/Framework/GroupSlicer.h @@ -246,9 +246,7 @@ struct GroupSlicer { pos = position; } // optimized split - auto oc = sliceInfos[index].getSliceFor(pos); - uint64_t offset = oc.first; - auto count = oc.second; + auto [offset, count] = sliceInfos[index].getSliceFor(pos); auto groupedElementsTable = originalTable.rawSlice(offset, offset + count - 1); groupedElementsTable.bindInternalIndicesTo(&originalTable); return groupedElementsTable; diff --git a/Framework/Core/src/ArrowTableSlicingCache.cxx b/Framework/Core/src/ArrowTableSlicingCache.cxx index e001e293c4733..26bb9bcee80eb 100644 --- a/Framework/Core/src/ArrowTableSlicingCache.cxx +++ b/Framework/Core/src/ArrowTableSlicingCache.cxx @@ -32,28 +32,14 @@ void updatePairList(Cache& list, std::string const& binding, std::string const& std::pair SliceInfoPtr::getSliceFor(int value) const { int64_t offset = 0; - if (values.empty()) { + if (offsets.empty()) { return {offset, 0}; } - int64_t p = static_cast(values.size()) - 1; - while (values[p] < 0) { - --p; - if (p < 0) { - return {offset, 0}; - } - } - - if (value > values[p]) { + if ((size_t)value >= offsets.size()) { return {offset, 0}; } - for (auto i = 0U; i < values.size(); ++i) { - if (values[i] == value) { - return {offset, counts[i]}; - } - offset += counts[i]; - } - return {offset, 0}; + return {offsets[value], sizes[value]}; } gsl::span SliceInfoUnsortedPtr::getSliceFor(int value) const @@ -84,6 +70,8 @@ ArrowTableSlicingCache::ArrowTableSlicingCache(Cache&& bsks, Cache&& bsksUnsorte { values.resize(bindingsKeys.size()); counts.resize(bindingsKeys.size()); + offsets.resize(bindingsKeys.size()); + sizes.resize(bindingsKeys.size()); valuesUnsorted.resize(bindingsKeysUnsorted.size()); groups.resize(bindingsKeysUnsorted.size()); @@ -97,6 +85,10 @@ void ArrowTableSlicingCache::setCaches(Cache&& bsks, Cache&& bsksUnsorted) values.resize(bindingsKeys.size()); counts.clear(); counts.resize(bindingsKeys.size()); + offsets.clear(); + offsets.resize(bindingsKeys.size()); + sizes.clear(); + sizes.resize(bindingsKeys.size()); valuesUnsorted.clear(); valuesUnsorted.resize(bindingsKeysUnsorted.size()); groups.clear(); @@ -105,9 +97,11 @@ void ArrowTableSlicingCache::setCaches(Cache&& bsks, Cache&& bsksUnsorted) arrow::Status ArrowTableSlicingCache::updateCacheEntry(int pos, std::shared_ptr const& table) { + values[pos].reset(); + counts[pos].reset(); + offsets[pos].clear(); + sizes[pos].clear(); if (table->num_rows() == 0) { - values[pos].reset(); - counts[pos].reset(); return arrow::Status::OK(); } auto& [b, k, e] = bindingsKeys[pos]; @@ -125,6 +119,31 @@ arrow::Status ArrowTableSlicingCache::updateCacheEntry(int pos, std::shared_ptr< counts[pos].reset(); values[pos] = std::make_shared>(pair.field(0)->data()); counts[pos] = std::make_shared>(pair.field(1)->data()); + + int maxValue = -1; + for (auto i = values[pos]->length() - 1; i >= 0; --i) { + if (values[pos]->Value(i) < 0) { + continue; + } else { + maxValue = values[pos]->Value(i); + break; + } + } + + offsets[pos].resize(maxValue + 1); + sizes[pos].resize(maxValue + 1); + std::fill(offsets[pos].begin(), offsets[pos].end(), 0); + std::fill(sizes[pos].begin(), sizes[pos].end(), 0); + int64_t offset = 0; + for (auto i = 0U; i < values[pos]->length(); ++i) { + auto value = values[pos]->Value(i); + auto count = counts[pos]->Value(i); + if (value >= 0) { + offsets[pos][value] = offset; + sizes[pos][value] = count; + } + offset += count; + } return arrow::Status::OK(); } @@ -221,14 +240,14 @@ SliceInfoPtr ArrowTableSlicingCache::getCacheForPos(int pos) const { if (values[pos] == nullptr && counts[pos] == nullptr) { return { - {}, - {} // + {}, // + {} // }; } return { - {reinterpret_cast(values[pos]->values()->data()), static_cast(values[pos]->length())}, - {reinterpret_cast(counts[pos]->values()->data()), static_cast(counts[pos]->length())} // + gsl::span{offsets[pos].data(), offsets[pos].size()}, // + gsl::span(sizes[pos].data(), sizes[pos].size()) // }; } diff --git a/Framework/Core/test/test_GroupSlicer.cxx b/Framework/Core/test/test_GroupSlicer.cxx index 091c21eeae229..2f21d7dd17975 100644 --- a/Framework/Core/test/test_GroupSlicer.cxx +++ b/Framework/Core/test/test_GroupSlicer.cxx @@ -245,8 +245,8 @@ TEST_CASE("GroupSlicerMismatchedGroups") if (i == 3 || i == 10 || i == 12 || i == 16 || i == 19) { continue; } - for (auto j = 0.f; j < 5; j += 0.5f) { - trksWriter(0, i, 0.5f * j); + for (auto j = 0; j < 10; ++j) { + trksWriter(0, i, 0.5f * (j / 2.)); } } auto trkTable = builderT.finalize(); @@ -260,21 +260,19 @@ TEST_CASE("GroupSlicerMismatchedGroups") auto s = slices.updateCacheEntry(0, trkTable); o2::framework::GroupSlicer g(e, tt, slices); - auto count = 0; for (auto& slice : g) { auto as = slice.associatedTables(); auto gg = slice.groupingElement(); - REQUIRE(gg.globalIndex() == count); + REQUIRE(gg.globalIndex() == (int64_t)slice.position); auto trks = std::get(as); - if (count == 3 || count == 10 || count == 12 || count == 16 || count == 19) { + if (slice.position == 3 || slice.position == 10 || slice.position == 12 || slice.position == 16 || slice.position == 19) { REQUIRE(trks.size() == 0); } else { REQUIRE(trks.size() == 10); } for (auto& trk : trks) { - REQUIRE(trk.eventId() == count); + REQUIRE(trk.eventId() == (int64_t)slice.position); } - ++count; } } @@ -299,8 +297,8 @@ TEST_CASE("GroupSlicerMismatchedUnassignedGroups") ++skip; continue; } - for (auto j = 0.f; j < 5; j += 0.5f) { - trksWriter(0, i, 0.5f * j); + for (auto j = 0; j < 10; ++j) { + trksWriter(0, i, 0.5f * (j / 2.)); } } for (auto i = 0; i < 5; ++i) { @@ -510,7 +508,7 @@ TEST_CASE("GroupSlicerMismatchedUnsortedFilteredGroupsWithSelfIndex") { TableBuilder builderE; auto evtsWriter = builderE.cursor(); - for (auto i = 0; i < 20; ++i) { + for (auto i = 0; i < 10; ++i) { evtsWriter(0, i, 0.5f * i, 2.f * i, 3.f * i); } auto evtTable = builderE.finalize(); @@ -523,7 +521,6 @@ TEST_CASE("GroupSlicerMismatchedUnsortedFilteredGroupsWithSelfIndex") std::uniform_int_distribution<> distrib(0, 99); for (auto i = 0; i < 100; ++i) { - filler[0] = distrib(gen); filler[1] = distrib(gen); if (filler[0] > filler[1]) { @@ -541,7 +538,6 @@ TEST_CASE("GroupSlicerMismatchedUnsortedFilteredGroupsWithSelfIndex") auto thingsTable = builderT.finalize(); aod::Events e{evtTable}; - // aod::Parts p{partsTable}; aod::Things t{thingsTable}; using FilteredParts = soa::Filtered; auto size = distrib(gen);