diff --git a/Framework/Core/include/Framework/ASoA.h b/Framework/Core/include/Framework/ASoA.h index 8af872a64176d..65fd12b3e6df3 100644 --- a/Framework/Core/include/Framework/ASoA.h +++ b/Framework/Core/include/Framework/ASoA.h @@ -1389,76 +1389,69 @@ consteval static bool relatedBySortedIndex() namespace o2::framework { -template -struct PresliceBase { - constexpr static bool sorted = SORTED; + +struct PreslicePolicyBase { + const std::string binding; + StringPair bindingKey; + + bool isMissing() const; + StringPair const& getBindingKey() const; +}; + +struct PreslicePolicySorted : public PreslicePolicyBase { + void updateSliceInfo(SliceInfoPtr&& si); + + SliceInfoPtr sliceInfo; + std::shared_ptr getSliceFor(int value, std::shared_ptr const& input, uint64_t& offset) const; +}; + +struct PreslicePolicyGeneral : public PreslicePolicyBase { + void updateSliceInfo(SliceInfoUnsortedPtr&& si); + + SliceInfoUnsortedPtr sliceInfo; + gsl::span getSliceFor(int value) const; +}; + +template +struct PresliceBase : public Policy { constexpr static bool optional = OPT; using target_t = T; const std::string binding; PresliceBase(expressions::BindingNode index_) - : binding{o2::soa::getLabelFromTypeForKey(index_.name)}, - bindingKey{binding, index_.name} {} - - void updateSliceInfo(std::conditional_t&& si) + : Policy{PreslicePolicyBase{{o2::soa::getLabelFromTypeForKey(std::string{index_.name})}, std::make_pair(o2::soa::getLabelFromTypeForKey(std::string{index_.name}), std::string{index_.name})}, {}} { - sliceInfo = si; } std::shared_ptr getSliceFor(int value, std::shared_ptr const& input, uint64_t& offset) const { if constexpr (OPT) { - if (isMissing()) { + if (Policy::isMissing()) { return nullptr; } } - if constexpr (SORTED) { - auto [offset_, count] = sliceInfo.getSliceFor(value); - auto output = input->Slice(offset_, count); - offset = static_cast(offset_); - return output; - } else { - static_assert(SORTED, "Wrong method called for unsorted cache"); - } + return Policy::getSliceFor(value, input, offset); } gsl::span getSliceFor(int value) const { if constexpr (OPT) { - if (isMissing()) { + if (Policy::isMissing()) { return {}; } } - if constexpr (!SORTED) { - return sliceInfo.getSliceFor(value); - } else { - static_assert(!SORTED, "Wrong method called for sorted cache"); - } + return Policy::getSliceFor(value); } - - bool isMissing() const - { - return binding == "[MISSING]"; - } - - StringPair const& getBindingKey() const - { - return bindingKey; - } - - std::conditional_t sliceInfo; - - StringPair bindingKey; }; template -using PresliceUnsorted = PresliceBase; +using PresliceUnsorted = PresliceBase; template -using PresliceUnsortedOptional = PresliceBase; +using PresliceUnsortedOptional = PresliceBase; template -using Preslice = PresliceBase; +using Preslice = PresliceBase; template -using PresliceOptional = PresliceBase; +using PresliceOptional = PresliceBase; } // namespace o2::framework @@ -1497,96 +1490,84 @@ static consteval auto extractBindings(framework::pack) SelectionVector selectionToVector(gandiva::Selection const& sel); -template -auto doSliceBy(T const* table, o2::framework::PresliceBase const& container, int value) +template + requires std::same_as && (o2::soa::is_binding_compatible_v()) +auto doSliceBy(T const* table, o2::framework::PresliceBase const& container, int value) { - if constexpr (o2::soa::is_binding_compatible_v()) { - if constexpr (OPT) { - if (container.isMissing()) { - missingOptionalPreslice(getLabelFromType>().data(), container.bindingKey.second.c_str()); - } - } - if constexpr (SORTED) { - uint64_t offset = 0; - auto out = container.getSliceFor(value, table->asArrowTable(), offset); - auto t = typename T::self_t({out}, offset); - table->copyIndexBindings(t); - t.bindInternalIndicesTo(table); - return t; - } else { - auto selection = container.getSliceFor(value); - if constexpr (soa::is_filtered_table) { - auto t = soa::Filtered({table->asArrowTable()}, selection); - table->copyIndexBindings(t); - t.bindInternalIndicesTo(table); - t.intersectWithSelection(table->getSelectedRows()); // intersect filters - return t; - } else { - auto t = soa::Filtered({table->asArrowTable()}, selection); - table->copyIndexBindings(t); - t.bindInternalIndicesTo(table); - return t; - } + if constexpr (OPT) { + if (container.isMissing()) { + missingOptionalPreslice(getLabelFromType>().data(), container.bindingKey.second.c_str()); } - } else { - if constexpr (SORTED) { - static_assert(o2::framework::always_static_assert_v, "Wrong Preslice<> entry used: incompatible type"); - } else { - static_assert(o2::framework::always_static_assert_v, "Wrong PresliceUnsorted<> entry used: incompatible type"); + } + uint64_t offset = 0; + auto out = container.getSliceFor(value, table->asArrowTable(), offset); + auto t = typename T::self_t({out}, offset); + table->copyIndexBindings(t); + t.bindInternalIndicesTo(table); + return t; +} + +template +auto doSliceByHelper(T const* table, gsl::span const& selection) +{ + auto t = soa::Filtered({table->asArrowTable()}, selection); + table->copyIndexBindings(t); + t.bindInternalIndicesTo(table); + t.intersectWithSelection(table->getSelectedRows()); // intersect filters + return t; +} + +template + requires(!soa::is_filtered_table) +auto doSliceByHelper(T const* table, gsl::span const& selection) +{ + auto t = soa::Filtered({table->asArrowTable()}, selection); + table->copyIndexBindings(t); + t.bindInternalIndicesTo(table); + return t; +} + +template + requires std::same_as && (o2::soa::is_binding_compatible_v()) +auto doSliceBy(T const* table, o2::framework::PresliceBase const& container, int value) +{ + if constexpr (OPT) { + if (container.isMissing()) { + missingOptionalPreslice(getLabelFromType>().data(), container.bindingKey.second.c_str()); } } + auto selection = container.getSliceFor(value); + return doSliceByHelper(table, selection); } -template +SelectionVector sliceSelection(gsl::span const& mSelectedRows, int64_t nrows, uint64_t offset); + +template auto prepareFilteredSlice(T const* table, std::shared_ptr slice, uint64_t offset) { if (offset >= static_cast(table->tableSize())) { - if constexpr (soa::is_filtered_table) { - Filtered fresult{{{slice}}, SelectionVector{}, 0}; - table->copyIndexBindings(fresult); - return fresult; - } else { - typename T::self_t fresult{{{slice}}, SelectionVector{}, 0}; - table->copyIndexBindings(fresult); - return fresult; - } - } - auto start = offset; - auto end = start + slice->num_rows(); - auto mSelectedRows = table->getSelectedRows(); - auto start_iterator = std::lower_bound(mSelectedRows.begin(), mSelectedRows.end(), start); - auto stop_iterator = std::lower_bound(start_iterator, mSelectedRows.end(), end); - SelectionVector slicedSelection{start_iterator, stop_iterator}; - std::transform(slicedSelection.begin(), slicedSelection.end(), slicedSelection.begin(), - [&start](int64_t idx) { - return idx - static_cast(start); - }); - if constexpr (soa::is_filtered_table) { - Filtered fresult{{{slice}}, std::move(slicedSelection), start}; - table->copyIndexBindings(fresult); - return fresult; - } else { - typename T::self_t fresult{{{slice}}, std::move(slicedSelection), start}; + Filtered fresult{{{slice}}, SelectionVector{}, 0}; table->copyIndexBindings(fresult); return fresult; } + auto slicedSelection = sliceSelection(table->getSelectedRows(), slice->num_rows(), offset); + Filtered fresult{{{slice}}, std::move(slicedSelection), offset}; + table->copyIndexBindings(fresult); + return fresult; } -template -auto doFilteredSliceBy(T const* table, o2::framework::PresliceBase const& container, int value) +template + requires(o2::soa::is_binding_compatible_v()) +auto doFilteredSliceBy(T const* table, o2::framework::PresliceBase const& container, int value) { - if constexpr (o2::soa::is_binding_compatible_v()) { - if constexpr (OPT) { - if (container.isMissing()) { - missingOptionalPreslice(getLabelFromType().data(), container.bindingKey.second.c_str()); - } + if constexpr (OPT) { + if (container.isMissing()) { + missingOptionalPreslice(getLabelFromType().data(), container.bindingKey.second.c_str()); } - uint64_t offset = 0; - auto slice = container.getSliceFor(value, table->asArrowTable(), offset); - return prepareFilteredSlice(table, slice, offset); - } else { - static_assert(o2::framework::always_static_assert_v, "Wrong Preslice<> entry used: incompatible type"); } + uint64_t offset = 0; + auto slice = container.getSliceFor(value, table->asArrowTable(), offset); + return prepareFilteredSlice(table, slice, offset); } template @@ -2099,8 +2080,8 @@ class Table return doSliceByCachedUnsorted(this, node, value, cache); } - template - auto sliceBy(o2::framework::PresliceBase const& container, int value) const + template + auto sliceBy(o2::framework::PresliceBase const& container, int value) const { return doSliceBy(this, container, value); } @@ -3201,8 +3182,8 @@ struct JoinFull : Table, D, o2::aod::Hash<"JOIN"_h>, Ts. return doSliceByCachedUnsorted(this, node, value, cache); } - template - auto sliceBy(o2::framework::PresliceBase const& container, int value) const + template + auto sliceBy(o2::framework::PresliceBase const& container, int value) const { return doSliceBy(this, container, value); } @@ -3463,14 +3444,16 @@ class FilteredBase : public T return doSliceByCachedUnsorted(this, node, value, cache); } - template - auto sliceBy(o2::framework::PresliceBase const& container, int value) const + template + auto sliceBy(o2::framework::PresliceBase const& container, int value) const { - if constexpr (SORTED) { - return doFilteredSliceBy(this, container, value); - } else { - return doSliceBy(this, container, value); - } + return doFilteredSliceBy(this, container, value); + } + + template + auto sliceBy(o2::framework::PresliceBase const& container, int value) const + { + return doSliceBy(this, container, value); } auto select(framework::expressions::Filter const& f) const @@ -3697,14 +3680,16 @@ class Filtered : public FilteredBase return doSliceByCachedUnsorted(this, node, value, cache); } - template - auto sliceBy(o2::framework::PresliceBase const& container, int value) const + template + auto sliceBy(o2::framework::PresliceBase const& container, int value) const { - if constexpr (SORTED) { - return doFilteredSliceBy(this, container, value); - } else { - return doSliceBy(this, container, value); - } + return doFilteredSliceBy(this, container, value); + } + + template + auto sliceBy(o2::framework::PresliceBase const& container, int value) const + { + return doSliceBy(this, container, value); } auto select(framework::expressions::Filter const& f) const @@ -3864,14 +3849,16 @@ class Filtered> : public FilteredBase return doSliceByCachedUnsorted(this, node, value, cache); } - template - auto sliceBy(o2::framework::PresliceBase const& container, int value) const + template + auto sliceBy(o2::framework::PresliceBase const& container, int value) const { - if constexpr (SORTED) { - return doFilteredSliceBy(this, container, value); - } else { - return doSliceBy(this, container, value); - } + return doFilteredSliceBy(this, container, value); + } + + template + auto sliceBy(o2::framework::PresliceBase const& container, int value) const + { + return doSliceBy(this, container, value); } private: diff --git a/Framework/Core/include/Framework/AnalysisHelpers.h b/Framework/Core/include/Framework/AnalysisHelpers.h index d84c9714b2f30..bb7e5e14aaa75 100644 --- a/Framework/Core/include/Framework/AnalysisHelpers.h +++ b/Framework/Core/include/Framework/AnalysisHelpers.h @@ -652,8 +652,8 @@ struct Partition { return mFiltered->sliceByCachedUnsorted(node, value, cache); } - template - [[nodiscard]] auto sliceBy(o2::framework::PresliceBase const& container, int value) const + template + [[nodiscard]] auto sliceBy(o2::framework::PresliceBase const& container, int value) const { return mFiltered->sliceBy(container, value); } diff --git a/Framework/Core/include/Framework/AnalysisManagers.h b/Framework/Core/include/Framework/AnalysisManagers.h index e0dd21708e841..30ebf1799b227 100644 --- a/Framework/Core/include/Framework/AnalysisManagers.h +++ b/Framework/Core/include/Framework/AnalysisManagers.h @@ -645,44 +645,60 @@ struct PresliceManager { } }; -template -struct PresliceManager> { - static bool registerCache(PresliceBase& container, std::vector& bsks, std::vector& bsksU) +template +struct PresliceManager> { + static bool registerCache(PresliceBase& container, std::vector& bsks, std::vector&) + requires std::same_as { if constexpr (OPT) { if (container.binding == "[MISSING]") { return true; } } - if constexpr (SORTED) { - auto locate = std::find_if(bsks.begin(), bsks.end(), [&](auto const& entry) { return (entry.first == container.bindingKey.first) && (entry.second == container.bindingKey.second); }); - if (locate == bsks.end()) { - bsks.emplace_back(container.getBindingKey()); - } - return true; - } else { - auto locate = std::find_if(bsksU.begin(), bsksU.end(), [&](auto const& entry) { return (entry.first == container.bindingKey.first) && (entry.second == container.bindingKey.second); }); - if (locate == bsksU.end()) { - bsksU.emplace_back(container.getBindingKey()); + auto locate = std::find_if(bsks.begin(), bsks.end(), [&](auto const& entry) { return (entry.first == container.bindingKey.first) && (entry.second == container.bindingKey.second); }); + if (locate == bsks.end()) { + bsks.emplace_back(container.getBindingKey()); + } + return true; + } + + static bool registerCache(PresliceBase& container, std::vector&, std::vector& bsksU) + requires std::same_as + { + if constexpr (OPT) { + if (container.binding == "[MISSING]") { + return true; } - return true; } + auto locate = std::find_if(bsksU.begin(), bsksU.end(), [&](auto const& entry) { return (entry.first == container.bindingKey.first) && (entry.second == container.bindingKey.second); }); + if (locate == bsksU.end()) { + bsksU.emplace_back(container.getBindingKey()); + } + return true; } - static bool updateSliceInfo(PresliceBase& container, ArrowTableSlicingCache& cache) + static bool updateSliceInfo(PresliceBase& container, ArrowTableSlicingCache& cache) + requires std::same_as { if constexpr (OPT) { if (container.binding == "[MISSING]") { return true; } } - if constexpr (SORTED) { - container.updateSliceInfo(cache.getCacheFor(container.getBindingKey())); - return true; - } else { - container.updateSliceInfo(cache.getCacheUnsortedFor(container.getBindingKey())); - return true; + container.updateSliceInfo(cache.getCacheFor(container.getBindingKey())); + return true; + } + + static bool updateSliceInfo(PresliceBase& container, ArrowTableSlicingCache& cache) + requires std::same_as + { + if constexpr (OPT) { + if (container.binding == "[MISSING]") { + return true; + } } + container.updateSliceInfo(cache.getCacheUnsortedFor(container.getBindingKey())); + return true; } }; } // namespace o2::framework diff --git a/Framework/Core/src/ASoA.cxx b/Framework/Core/src/ASoA.cxx index a37d0f33891e7..276a592d87895 100644 --- a/Framework/Core/src/ASoA.cxx +++ b/Framework/Core/src/ASoA.cxx @@ -50,6 +50,20 @@ SelectionVector selectionToVector(gandiva::Selection const& sel) return rows; } +SelectionVector sliceSelection(gsl::span const& mSelectedRows, int64_t nrows, uint64_t offset) +{ + auto start = offset; + auto end = start + nrows; + auto start_iterator = std::lower_bound(mSelectedRows.begin(), mSelectedRows.end(), start); + auto stop_iterator = std::lower_bound(start_iterator, mSelectedRows.end(), end); + SelectionVector slicedSelection{start_iterator, stop_iterator}; + std::transform(slicedSelection.begin(), slicedSelection.end(), slicedSelection.begin(), + [&start](int64_t idx) { + return idx - static_cast(start); + }); + return slicedSelection; +} + std::shared_ptr ArrowHelpers::joinTables(std::vector>&& tables) { if (tables.size() == 1) { @@ -177,4 +191,37 @@ std::string strToUpper(std::string&& str) std::transform(str.begin(), str.end(), str.begin(), [](unsigned char c) { return std::toupper(c); }); return str; } + +bool PreslicePolicyBase::isMissing() const +{ + return binding == "[MISSING]"; +} + +StringPair const& PreslicePolicyBase::getBindingKey() const +{ + return bindingKey; +} + +void PreslicePolicySorted::updateSliceInfo(SliceInfoPtr&& si) +{ + sliceInfo = si; +} + +void PreslicePolicyGeneral::updateSliceInfo(SliceInfoUnsortedPtr&& si) +{ + sliceInfo = si; +} + +std::shared_ptr PreslicePolicySorted::getSliceFor(int value, std::shared_ptr const& input, uint64_t& offset) const +{ + auto [offset_, count] = this->sliceInfo.getSliceFor(value); + auto output = input->Slice(offset_, count); + offset = static_cast(offset_); + return output; +} + +gsl::span PreslicePolicyGeneral::getSliceFor(int value) const +{ + return this->sliceInfo.getSliceFor(value); +} } // namespace o2::framework