diff --git a/Framework/Core/include/Framework/CompletionPolicyHelpers.h b/Framework/Core/include/Framework/CompletionPolicyHelpers.h index 547add44560ea..aa336d040d30d 100644 --- a/Framework/Core/include/Framework/CompletionPolicyHelpers.h +++ b/Framework/Core/include/Framework/CompletionPolicyHelpers.h @@ -43,6 +43,11 @@ struct CompletionPolicyHelpers { /// When any of the parts of the record have been received, consume them. static CompletionPolicy consumeWhenAny(const char* name, CompletionPolicy::Matcher matcher); + +#if __has_include() + /// When any of the parts which has arrived has a refcount of 1. + static CompletionPolicy consumeWhenAnyZeroCount(const char* name, CompletionPolicy::Matcher matcher); +#endif /// Default matcher applies for all devices static CompletionPolicy consumeWhenAny(CompletionPolicy::Matcher matcher = [](auto const&) -> bool { return true; }) { diff --git a/Framework/Core/include/Framework/InputSpan.h b/Framework/Core/include/Framework/InputSpan.h index c435276c7134f..cf8c8acda6796 100644 --- a/Framework/Core/include/Framework/InputSpan.h +++ b/Framework/Core/include/Framework/InputSpan.h @@ -46,7 +46,7 @@ class InputSpan /// index and the buffer associated. /// @nofPartsGetter is the getter for the number of parts associated with an index /// @a size is the number of elements in the span. - InputSpan(std::function getter, std::function nofPartsGetter, size_t size); + InputSpan(std::function getter, std::function nofPartsGetter, std::function refCountGetter, size_t size); /// @a i-th element of the InputSpan [[nodiscard]] DataRef get(size_t i, size_t partidx = 0) const @@ -66,6 +66,18 @@ class InputSpan return mNofPartsGetter(i); } + // Get the refcount for a given part + [[nodiscard]] int getRefCount(size_t i) const + { + if (i >= mSize) { + return 0; + } + if (!mRefCountGetter) { + return -1; + } + return mRefCountGetter(i); + } + /// Number of elements in the InputSpan [[nodiscard]] size_t size() const { @@ -236,6 +248,7 @@ class InputSpan private: std::function mGetter; std::function mNofPartsGetter; + std::function mRefCountGetter; size_t mSize; }; diff --git a/Framework/Core/src/CompletionPolicy.cxx b/Framework/Core/src/CompletionPolicy.cxx index 9d92fd07e6f5a..ec8997e32c5db 100644 --- a/Framework/Core/src/CompletionPolicy.cxx +++ b/Framework/Core/src/CompletionPolicy.cxx @@ -26,7 +26,11 @@ std::vector { return { CompletionPolicyHelpers::consumeWhenAllOrdered("internal-dpl-aod-writer"), +#if __has_include() + CompletionPolicyHelpers::consumeWhenAnyZeroCount("internal-dpl-injected-dummy-sink", [](DeviceSpec const& s) { return s.name.find("internal-dpl-injected-dummy-sink") != std::string::npos; }), +#else CompletionPolicyHelpers::consumeWhenAny("internal-dpl-injected-dummy-sink", [](DeviceSpec const& s) { return s.name.find("internal-dpl-injected-dummy-sink") != std::string::npos; }), +#endif CompletionPolicyHelpers::consumeWhenAll()}; } diff --git a/Framework/Core/src/CompletionPolicyHelpers.cxx b/Framework/Core/src/CompletionPolicyHelpers.cxx index 9dd895a6fed6d..e682f9a7c7dd6 100644 --- a/Framework/Core/src/CompletionPolicyHelpers.cxx +++ b/Framework/Core/src/CompletionPolicyHelpers.cxx @@ -19,6 +19,9 @@ #include "Framework/TimingInfo.h" #include "DecongestionService.h" #include "Framework/Signpost.h" +#if __has_include() +#include +#endif #include #include @@ -249,6 +252,21 @@ CompletionPolicy CompletionPolicyHelpers::consumeExistingWhenAny(const char* nam }}; } +#if __has_include() +CompletionPolicy CompletionPolicyHelpers::consumeWhenAnyZeroCount(const char* name, CompletionPolicy::Matcher matcher) +{ + auto callback = [](InputSpan const& inputs, std::vector const&, ServiceRegistryRef& ref) -> CompletionPolicy::CompletionOp { + for (size_t i = 0; i < inputs.size(); ++i) { + if (inputs.get(i).header != nullptr && inputs.getRefCount(i) == 1) { + return CompletionPolicy::CompletionOp::Consume; + } + } + return CompletionPolicy::CompletionOp::Wait; + }; + return CompletionPolicy{name, matcher, callback, false}; +} +#endif + CompletionPolicy CompletionPolicyHelpers::consumeWhenAny(const char* name, CompletionPolicy::Matcher matcher) { auto callback = [](InputSpan const& inputs, std::vector const&, ServiceRegistryRef& ref) -> CompletionPolicy::CompletionOp { diff --git a/Framework/Core/src/DataProcessingDevice.cxx b/Framework/Core/src/DataProcessingDevice.cxx index 7f42805cfdb1e..ae25d8d3a915c 100644 --- a/Framework/Core/src/DataProcessingDevice.cxx +++ b/Framework/Core/src/DataProcessingDevice.cxx @@ -57,6 +57,9 @@ #include #include #include +#if __has_include() +#include +#endif #include #include #include @@ -1214,12 +1217,14 @@ void DataProcessingDevice::fillContext(DataProcessorContext& context, DeviceCont if (forwarded.matcher.lifetime != Lifetime::Condition) { onlyConditions = false; } +#if !__has_include() if (strncmp(DataSpecUtils::asConcreteOrigin(forwarded.matcher).str, "AOD", 3) == 0) { context.canForwardEarly = false; overriddenEarlyForward = true; LOG(detail) << "Cannot forward early because of AOD input: " << DataSpecUtils::describe(forwarded.matcher); break; } +#endif if (DataSpecUtils::partialMatch(forwarded.matcher, o2::header::DataDescription{"RAWDATA"}) && mProcessingPolicies.earlyForward == EarlyForwardPolicy::NORAW) { context.canForwardEarly = false; overriddenEarlyForward = true; @@ -2230,7 +2235,15 @@ bool DataProcessingDevice::tryDispatchComputation(ServiceRegistryRef ref, std::v auto nofPartsGetter = [¤tSetOfInputs](size_t i) -> size_t { return currentSetOfInputs[i].getNumberOfPairs(); }; - return InputSpan{getter, nofPartsGetter, currentSetOfInputs.size()}; +#if __has_include() + auto refCountGetter = [¤tSetOfInputs](size_t idx) -> int { + auto& header = static_cast(*currentSetOfInputs[idx].header(0)); + return header.GetRefCount(); + }; +#else + std::function refCountGetter = nullptr; +#endif + return InputSpan{getter, nofPartsGetter, refCountGetter, currentSetOfInputs.size()}; }; auto markInputsAsDone = [ref](TimesliceSlot slot) -> void { diff --git a/Framework/Core/src/DataRelayer.cxx b/Framework/Core/src/DataRelayer.cxx index 385d9a6c50c4a..f30866dc0aa1b 100644 --- a/Framework/Core/src/DataRelayer.cxx +++ b/Framework/Core/src/DataRelayer.cxx @@ -44,6 +44,10 @@ #include #include +#include +#if __has_include() +#include +#endif #include #include #include @@ -209,7 +213,15 @@ DataRelayer::ActivityStats DataRelayer::processDanglingInputs(std::vector(partial.size())}; +#if __has_include() + auto refCountGetter = [&partial](size_t idx) -> int { + auto& header = static_cast(*partial[idx].header(0)); + return header.GetRefCount(); + }; +#else + std::function refCountGetter = nullptr; +#endif + InputSpan span{getter, nPartsGetter, refCountGetter, static_cast(partial.size())}; // Setup the input span if (expirator.checker(services, timestamp.value, span) == false) { @@ -755,7 +767,15 @@ void DataRelayer::getReadyToProcess(std::vector& comp auto nPartsGetter = [&partial](size_t idx) { return partial[idx].size(); }; - InputSpan span{getter, nPartsGetter, static_cast(partial.size())}; +#if __has_include() + auto refCountGetter = [&partial](size_t idx) -> int { + auto& header = static_cast(*partial[idx].header(0)); + return header.GetRefCount(); + }; +#else + std::function refCountGetter = nullptr; +#endif + InputSpan span{getter, nPartsGetter, refCountGetter, static_cast(partial.size())}; CompletionPolicy::CompletionOp action = mCompletionPolicy.callbackFull(span, mInputs, mContext); auto& variables = mTimesliceIndex.getVariablesForSlot(slot); diff --git a/Framework/Core/src/InputSpan.cxx b/Framework/Core/src/InputSpan.cxx index 510b55cd0b9b9..d1dffc85602a5 100644 --- a/Framework/Core/src/InputSpan.cxx +++ b/Framework/Core/src/InputSpan.cxx @@ -29,8 +29,11 @@ InputSpan::InputSpan(std::function getter, size_t size) { } -InputSpan::InputSpan(std::function getter, std::function nofPartsGetter, size_t size) - : mGetter{getter}, mNofPartsGetter{nofPartsGetter}, mSize{size} +InputSpan::InputSpan(std::function getter, + std::function nofPartsGetter, + std::function refCountGetter, + size_t size) + : mGetter{getter}, mNofPartsGetter{nofPartsGetter}, mRefCountGetter(refCountGetter), mSize{size} { } diff --git a/Framework/Core/test/test_InputRecordWalker.cxx b/Framework/Core/test/test_InputRecordWalker.cxx index 5b9004a1a9366..9af3c0dd2dbe2 100644 --- a/Framework/Core/test/test_InputRecordWalker.cxx +++ b/Framework/Core/test/test_InputRecordWalker.cxx @@ -42,7 +42,7 @@ struct DataSet { auto payload = static_cast(this->messages[i].second.at(2 * part + 1)->data()); return DataRef{nullptr, header, payload}; }, - [this](size_t i) { return i < this->messages.size() ? messages[i].second.size() / 2 : 0; }, this->messages.size()}, + [this](size_t i) { return i < this->messages.size() ? messages[i].second.size() / 2 : 0; }, nullptr, this->messages.size()}, record{schema, span, registry}, values{std::move(v)} { diff --git a/Framework/Core/test/test_InputSpan.cxx b/Framework/Core/test/test_InputSpan.cxx index 0622ad898d249..c5682aea80b6c 100644 --- a/Framework/Core/test/test_InputSpan.cxx +++ b/Framework/Core/test/test_InputSpan.cxx @@ -37,7 +37,7 @@ TEST_CASE("TestInputSpan") return inputs[i].size() / 2; }; - InputSpan span{getter, nPartsGetter, inputs.size()}; + InputSpan span{getter, nPartsGetter, nullptr, inputs.size()}; REQUIRE(span.size() == inputs.size()); routeNo = 0; for (; routeNo < span.size(); ++routeNo) { diff --git a/Framework/Utils/test/RawPageTestData.h b/Framework/Utils/test/RawPageTestData.h index 684fc4d0cf8a3..a6b800f7cba32 100644 --- a/Framework/Utils/test/RawPageTestData.h +++ b/Framework/Utils/test/RawPageTestData.h @@ -47,7 +47,9 @@ struct DataSet { auto payload = static_cast(this->messages[i].at(2 * part + 1)->data()); return DataRef{nullptr, header, payload}; }, - [this](size_t i) { return i < this->messages.size() ? messages[i].size() / 2 : 0; }, this->messages.size()}, + [this](size_t i) { return i < this->messages.size() ? messages[i].size() / 2 : 0; }, + nullptr, + this->messages.size()}, record{schema, span, registry}, values{std::move(v)} { @@ -63,5 +65,5 @@ struct DataSet { using AmendRawDataHeader = std::function; DataSet createData(std::vector const& inputspecs, std::vector const& dataheaders, AmendRawDataHeader amendRdh = nullptr); -} // namespace o2::framework +} // namespace o2::framework::test #endif // FRAMEWORK_UTILS_RAWPAGETESTDATA_H