diff --git a/DataFormats/Headers/include/Headers/DataHeader.h b/DataFormats/Headers/include/Headers/DataHeader.h index b44f41c5d3cb3..e43dbcbd09f81 100644 --- a/DataFormats/Headers/include/Headers/DataHeader.h +++ b/DataFormats/Headers/include/Headers/DataHeader.h @@ -373,7 +373,8 @@ struct BaseHeader { uint32_t flags; struct { uint32_t flagsNextHeader : 1, // do we have a next header after this one? - flagsReserved : 15, // reserved for future use + flagsReserved : 14, // reserved for future use. MUST be filled with 0s. + flagsDisabled : 1, // header should be ignored if this is 1 flagsDerivedHeader : 16; // reserved for usage by the derived header }; }; @@ -467,7 +468,9 @@ auto get(const std::byte* buffer, size_t /*len*/ = 0) // otherwise, we keep the code related to the exception outside the header file. // Note: Can not check on size because the O2 data model requires variable size headers // to be supported. - if (current->sanityCheck(HeaderValueType::sVersion)) { + if (current->sanityCheck(HeaderValueType::sVersion) && current->flagsDisabled == 0) { + // If the first header matches and it's enabled, we return it + // otherwise we look for more. return reinterpret_cast(current); } } @@ -475,7 +478,10 @@ auto get(const std::byte* buffer, size_t /*len*/ = 0) while ((current = current->next())) { prev = current; if (current->description == HeaderValueType::sHeaderType) { - if (current->sanityCheck(HeaderValueType::sVersion)) { + // This is needed to allow disabling some headers from being picked up + // even if they are matching. This is handy to have a quick + // way to disable a sub headers without having to drop them. + if (current->sanityCheck(HeaderValueType::sVersion) && current->flagsDisabled == 0) { return reinterpret_cast(current); } } diff --git a/Framework/Core/CMakeLists.txt b/Framework/Core/CMakeLists.txt index ce8fbb0dc55f7..fe8a91eaa0449 100644 --- a/Framework/Core/CMakeLists.txt +++ b/Framework/Core/CMakeLists.txt @@ -224,6 +224,7 @@ add_executable(o2-test-framework-core test/test_FairMQOptionsRetriever.cxx test/test_FairMQResizableBuffer.cxx test/test_FairMQ.cxx + test/test_ForwardInputs.cxx test/test_FrameworkDataFlowToDDS.cxx test/test_FrameworkDataFlowToO2Control.cxx test/test_Graphviz.cxx diff --git a/Framework/Core/include/Framework/DataProcessingHelpers.h b/Framework/Core/include/Framework/DataProcessingHelpers.h index d8d8b7caf9d0a..122b53976c035 100644 --- a/Framework/Core/include/Framework/DataProcessingHelpers.h +++ b/Framework/Core/include/Framework/DataProcessingHelpers.h @@ -12,6 +12,10 @@ #define O2_FRAMEWORK_DATAPROCESSINGHELPERS_H_ #include +#include "Framework/TimesliceSlot.h" +#include "Framework/TimesliceIndex.h" +#include +#include namespace o2::framework { @@ -23,6 +27,9 @@ struct OutputChannelSpec; struct OutputChannelState; struct ProcessingPolicies; struct DeviceSpec; +struct FairMQDeviceProxy; +struct MessageSet; +struct ChannelIndex; enum struct StreamingState; enum struct TransitionHandlingState; @@ -45,7 +52,9 @@ struct DataProcessingHelpers { static bool hasOnlyGenerated(DeviceSpec const& spec); /// starts the EoS timers and returns the new TransitionHandlingState in case as new state is requested static TransitionHandlingState updateStateTransition(ServiceRegistryRef const& ref, ProcessingPolicies const& policies); + /// Helper to route messages for forwarding + static std::vector routeForwardedMessages(FairMQDeviceProxy& proxy, TimesliceSlot slot, std::vector& currentSetOfInputs, + TimesliceIndex::OldestOutputInfo oldestTimeslice, bool copy, bool consume = true); }; - } // namespace o2::framework #endif // O2_FRAMEWORK_DATAPROCESSINGHELPERS_H_ diff --git a/Framework/Core/src/DataProcessingDevice.cxx b/Framework/Core/src/DataProcessingDevice.cxx index 3b430378dc0b0..ec03aaf97d078 100644 --- a/Framework/Core/src/DataProcessingDevice.cxx +++ b/Framework/Core/src/DataProcessingDevice.cxx @@ -550,76 +550,6 @@ void on_signal_callback(uv_signal_t* handle, int signum) O2_SIGNPOST_END(device, sid, "signal_state", "Done processing signals."); } -static auto toBeForwardedHeader = [](void* header) -> bool { - // If is now possible that the record is not complete when - // we forward it, because of a custom completion policy. - // this means that we need to skip the empty entries in the - // record for being forwarded. - if (header == nullptr) { - return false; - } - auto sih = o2::header::get(header); - if (sih) { - return false; - } - - auto dih = o2::header::get(header); - if (dih) { - return false; - } - - auto dh = o2::header::get(header); - if (!dh) { - return false; - } - auto dph = o2::header::get(header); - if (!dph) { - return false; - } - return true; -}; - -static auto toBeforwardedMessageSet = [](std::vector& cachedForwardingChoices, - FairMQDeviceProxy& proxy, - std::unique_ptr& header, - std::unique_ptr& payload, - size_t total, - bool consume) { - if (header.get() == nullptr) { - // Missing an header is not an error anymore. - // it simply means that we did not receive the - // given input, but we were asked to - // consume existing, so we skip it. - return false; - } - if (payload.get() == nullptr && consume == true) { - // If the payload is not there, it means we already - // processed it with ConsumeExisiting. Therefore we - // need to do something only if this is the last consume. - header.reset(nullptr); - return false; - } - - auto fdph = o2::header::get(header->GetData()); - if (fdph == nullptr) { - LOG(error) << "Data is missing DataProcessingHeader"; - return false; - } - auto fdh = o2::header::get(header->GetData()); - if (fdh == nullptr) { - LOG(error) << "Data is missing DataHeader"; - return false; - } - - // We need to find the forward route only for the first - // part of a split payload. All the others will use the same. - // but always check if we have a sequence of multiple payloads - if (fdh->splitPayloadIndex == 0 || fdh->splitPayloadParts <= 1 || total > 1) { - proxy.getMatchingForwardChannelIndexes(cachedForwardingChoices, *fdh, fdph->startTime); - } - return cachedForwardingChoices.empty() == false; -}; - struct DecongestionContext { ServiceRegistryRef ref; TimesliceIndex::OldestOutputInfo oldestTimeslice; @@ -660,67 +590,9 @@ auto decongestionCallbackLate = [](AsyncTask& task, size_t aid) -> void { static auto forwardInputs = [](ServiceRegistryRef registry, TimesliceSlot slot, std::vector& currentSetOfInputs, TimesliceIndex::OldestOutputInfo oldestTimeslice, bool copy, bool consume = true) { auto& proxy = registry.get(); - // we collect all messages per forward in a map and send them together - std::vector forwardedParts; - forwardedParts.resize(proxy.getNumForwards()); - std::vector cachedForwardingChoices{}; - O2_SIGNPOST_ID_GENERATE(sid, forwarding); - O2_SIGNPOST_START(forwarding, sid, "forwardInputs", "Starting forwarding for slot %zu with oldestTimeslice %zu %{public}s%{public}s%{public}s", - slot.index, oldestTimeslice.timeslice.value, copy ? "with copy" : "", copy && consume ? " and " : "", consume ? "with consume" : ""); - - for (size_t ii = 0, ie = currentSetOfInputs.size(); ii < ie; ++ii) { - auto& messageSet = currentSetOfInputs[ii]; - // In case the messageSet is empty, there is nothing to be done. - if (messageSet.size() == 0) { - continue; - } - if (!toBeForwardedHeader(messageSet.header(0)->GetData())) { - continue; - } - cachedForwardingChoices.clear(); - - for (size_t pi = 0; pi < currentSetOfInputs[ii].size(); ++pi) { - auto& messageSet = currentSetOfInputs[ii]; - auto& header = messageSet.header(pi); - auto& payload = messageSet.payload(pi); - auto total = messageSet.getNumberOfPayloads(pi); + auto forwardedParts = DataProcessingHelpers::routeForwardedMessages(proxy, slot, currentSetOfInputs, oldestTimeslice, copy); - if (!toBeforwardedMessageSet(cachedForwardingChoices, proxy, header, payload, total, consume)) { - continue; - } - - // In case of more than one forward route, we need to copy the message. - // This will eventually use the same mamory if running with the same backend. - if (cachedForwardingChoices.size() > 1) { - copy = true; - } - auto* dh = o2::header::get(header->GetData()); - auto* dph = o2::header::get(header->GetData()); - - if (copy) { - for (auto& cachedForwardingChoice : cachedForwardingChoices) { - auto&& newHeader = header->GetTransport()->CreateMessage(); - O2_SIGNPOST_EVENT_EMIT(forwarding, sid, "forwardInputs", "Forwarding a copy of %{public}s to route %d.", - fmt::format("{}/{}/{}@timeslice:{} tfCounter:{}", dh->dataOrigin, dh->dataDescription, dh->subSpecification, dph->startTime, dh->tfCounter).c_str(), cachedForwardingChoice.value); - newHeader->Copy(*header); - forwardedParts[cachedForwardingChoice.value].AddPart(std::move(newHeader)); - - for (size_t payloadIndex = 0; payloadIndex < messageSet.getNumberOfPayloads(pi); ++payloadIndex) { - auto&& newPayload = header->GetTransport()->CreateMessage(); - newPayload->Copy(*messageSet.payload(pi, payloadIndex)); - forwardedParts[cachedForwardingChoice.value].AddPart(std::move(newPayload)); - } - } - } else { - O2_SIGNPOST_EVENT_EMIT(forwarding, sid, "forwardInputs", "Forwarding %{public}s to route %d.", - fmt::format("{}/{}/{}@timeslice:{} tfCounter:{}", dh->dataOrigin, dh->dataDescription, dh->subSpecification, dph->startTime, dh->tfCounter).c_str(), cachedForwardingChoices.back().value); - forwardedParts[cachedForwardingChoices.back().value].AddPart(std::move(messageSet.header(pi))); - for (size_t payloadIndex = 0; payloadIndex < messageSet.getNumberOfPayloads(pi); ++payloadIndex) { - forwardedParts[cachedForwardingChoices.back().value].AddPart(std::move(messageSet.payload(pi, payloadIndex))); - } - } - } - } + O2_SIGNPOST_ID_GENERATE(sid, forwarding); O2_SIGNPOST_EVENT_EMIT(forwarding, sid, "forwardInputs", "Forwarding %zu messages", forwardedParts.size()); for (int fi = 0; fi < proxy.getNumForwardChannels(); fi++) { if (forwardedParts[fi].Size() == 0) { diff --git a/Framework/Core/src/DataProcessingHelpers.cxx b/Framework/Core/src/DataProcessingHelpers.cxx index e144f426372b1..02d5a6c03845e 100644 --- a/Framework/Core/src/DataProcessingHelpers.cxx +++ b/Framework/Core/src/DataProcessingHelpers.cxx @@ -16,6 +16,7 @@ #include "MemoryResources/MemoryResources.h" #include "Framework/FairMQDeviceProxy.h" #include "Headers/DataHeader.h" +#include "Headers/DataHeaderHelpers.h" #include "Headers/Stack.h" #include "Framework/Logger.h" #include "Framework/SendingPolicy.h" @@ -31,6 +32,8 @@ #include "Framework/ControlService.h" #include "Framework/DataProcessingContext.h" #include "Framework/DeviceStateEnums.h" +#include "Headers/DataHeader.h" +#include "Framework/DataProcessingHeader.h" #include #include @@ -41,6 +44,7 @@ O2_DECLARE_DYNAMIC_LOG(device); // Stream which keeps track of the calibration lifetime logic O2_DECLARE_DYNAMIC_LOG(calibration); +O2_DECLARE_DYNAMIC_LOG(forwarding); namespace o2::framework { @@ -217,4 +221,129 @@ TransitionHandlingState DataProcessingHelpers::updateStateTransition(ServiceRegi } } +auto DataProcessingHelpers::routeForwardedMessages(FairMQDeviceProxy& proxy, TimesliceSlot slot, + std::vector& currentSetOfInputs, TimesliceIndex::OldestOutputInfo oldestTimeslice, + const bool copyByDefault, bool consume) -> std::vector +{ + // we collect all messages per forward in a map and send them together + std::vector forwardedParts; + forwardedParts.resize(proxy.getNumForwards()); + std::vector forwardingChoices{}; + O2_SIGNPOST_ID_GENERATE(sid, forwarding); + O2_SIGNPOST_START(forwarding, sid, "forwardInputs", "Starting forwarding for slot %zu with oldestTimeslice %zu %{public}s%{public}s%{public}s", + slot.index, oldestTimeslice.timeslice.value, copyByDefault ? "with copy" : "", copyByDefault && consume ? " and " : "", consume ? "with consume" : ""); + + for (size_t ii = 0, ie = currentSetOfInputs.size(); ii < ie; ++ii) { + auto& messageSet = currentSetOfInputs[ii]; + + for (size_t pi = 0; pi < messageSet.size(); ++pi) { + auto& header = messageSet.header(pi); + + // If is now possible that the record is not complete when + // we forward it, because of a custom completion policy. + // this means that we need to skip the empty entries in the + // record for being forwarded. + if (header->GetData() == nullptr) { + continue; + } + + auto dph = o2::header::get(header->GetData()); + auto dh = o2::header::get(header->GetData()); + + if (dph == nullptr || dh == nullptr) { + // Complain only if this is not an out-of-band message + auto dih = o2::header::get(header->GetData()); + auto sih = o2::header::get(header->GetData()); + if (dih == nullptr || sih == nullptr) { + LOGP(error, "Data is missing {}{}{}", + dph ? "DataProcessingHeader" : "", dph || dh ? "and" : "", dh ? "DataHeader" : ""); + } + continue; + } + + auto& payload = messageSet.payload(pi); + + if (payload.get() == nullptr && consume == true) { + // If the payload is not there, it means we already + // processed it with ConsumeExisiting. Therefore we + // need to do something only if this is the last consume. + header.reset(nullptr); + continue; + } + + // We need to find the forward route only for the first + // part of a split payload. All the others will use the same. + // Therefore, we reset and recompute the forwarding choice: + // + // - If this is the first payload of a [header0][payload0][header0][payload1] sequence, + // which is actually always created and handled together + // - If the message is not a multipart (splitPayloadParts 0) or has only one part + // - If it's a message of the kind [header0][payload1][payload2][payload3]... and therefore + // we will already use the same choice in the for loop below. + if (dh->splitPayloadIndex == 0 || dh->splitPayloadParts <= 1 || messageSet.getNumberOfPayloads(pi) > 0) { + forwardingChoices.clear(); + proxy.getMatchingForwardChannelIndexes(forwardingChoices, *dh, dph->startTime); + } + + if (forwardingChoices.empty()) { + // Nothing to forward go to the next messageset + continue; + } + + // In case of more than one forward route, we need to copy the message. + // This will eventually use the same memory if running with the same backend. + if (copyByDefault || forwardingChoices.size()) { + for (auto& choice : forwardingChoices) { + auto&& newHeader = header->GetTransport()->CreateMessage(); + O2_SIGNPOST_EVENT_EMIT(forwarding, sid, "forwardInputs", "Forwarding a copy of %{public}s to route %d.", + fmt::format("{}/{}/{}@timeslice:{} tfCounter:{}", dh->dataOrigin, dh->dataDescription, dh->subSpecification, dph->startTime, dh->tfCounter).c_str(), choice.value); + newHeader->Copy(*header); + auto dih = o2::header::get(newHeader->GetData()); + if (dih) { + const_cast(dih)->flagsDisabled = 1; + } + auto sih = o2::header::get(newHeader->GetData()); + if (sih) { + const_cast(sih)->flagsDisabled = 1; + } + forwardedParts[choice.value].AddPart(std::move(newHeader)); + + for (size_t payloadIndex = 0; payloadIndex < messageSet.getNumberOfPayloads(pi); ++payloadIndex) { + auto&& newPayload = header->GetTransport()->CreateMessage(); + newPayload->Copy(*messageSet.payload(pi, payloadIndex)); + forwardedParts[choice.value].AddPart(std::move(newPayload)); + } + } + } else { + O2_SIGNPOST_EVENT_EMIT(forwarding, sid, "forwardInputs", "Forwarding %{public}s to route %d.", + fmt::format("{}/{}/{}@timeslice:{} tfCounter:{}", dh->dataOrigin, dh->dataDescription, dh->subSpecification, dph->startTime, dh->tfCounter).c_str(), forwardingChoices.back().value); + auto dih = o2::header::get(messageSet.header(pi)->GetData()); + auto sih = o2::header::get(messageSet.header(pi)->GetData()); + // We need to copy the header if it has extra timeframe accounting + // information attached to it, so that we can disable it without having + // a race condition in shared memory. + if (dih || sih) { + auto&& newHeader = header->GetTransport()->CreateMessage(); + newHeader->Copy(*header); + auto dih = o2::header::get(newHeader->GetData()); + if (dih) { + const_cast(dih)->flagsDisabled = 1; + } + auto sih = o2::header::get(newHeader->GetData()); + if (sih) { + const_cast(sih)->flagsDisabled = 1; + } + forwardedParts[forwardingChoices.back().value].AddPart(std::move(newHeader)); + } else { + forwardedParts[forwardingChoices.back().value].AddPart(std::move(messageSet.header(pi))); + } + for (size_t payloadIndex = 0; payloadIndex < messageSet.getNumberOfPayloads(pi); ++payloadIndex) { + forwardedParts[forwardingChoices.back().value].AddPart(std::move(messageSet.payload(pi, payloadIndex))); + } + } + } + } + return forwardedParts; +}; + } // namespace o2::framework diff --git a/Framework/Core/test/test_ForwardInputs.cxx b/Framework/Core/test/test_ForwardInputs.cxx new file mode 100644 index 0000000000000..cf9f933bbe176 --- /dev/null +++ b/Framework/Core/test/test_ForwardInputs.cxx @@ -0,0 +1,688 @@ +// Copyright 2019-2025 CERN and copyright holders of ALICE O2. +// See https://alice-o2.web.cern.ch/copyright for details of the copyright holders. +// All rights not expressly granted are reserved. +// +// This software is distributed under the terms of the GNU General Public +// License v3 (GPL Version 3), copied verbatim in the file "COPYING". +// +// In applying this license CERN does not waive the privileges and immunities +// granted to it by virtue of its status as an Intergovernmental Organization +// or submit itself to any jurisdiction. + +#include +#include "Headers/DataHeader.h" +#include "Framework/DataProcessingHeader.h" +#include "Framework/DataProcessingHelpers.h" +#include "Framework/SourceInfoHeader.h" +#include "Framework/DomainInfoHeader.h" +#include "Framework/ServiceRegistry.h" +#include "Framework/ServiceRegistryRef.h" +#include "Framework/Signpost.h" +#include "Framework/MessageSet.h" +#include "Framework/FairMQDeviceProxy.h" +#include "Headers/Stack.h" +#include "MemoryResources/MemoryResources.h" +#include +#include +#include + +O2_DECLARE_DYNAMIC_LOG(forwarding); +using namespace o2::framework; + +TEST_CASE("ForwardInputsEmpty") +{ + o2::header::DataHeader dh; + dh.dataDescription = "CLUSTERS"; + dh.dataOrigin = "TPC"; + dh.subSpecification = 0; + dh.splitPayloadIndex = 0; + dh.splitPayloadParts = 1; + + o2::framework::DataProcessingHeader dph{0, 1}; + auto transport = fair::mq::TransportFactory::CreateTransportFactory("zeromq"); + + bool consume = true; + bool copyByDefault = true; + FairMQDeviceProxy proxy; + + TimesliceIndex::OldestOutputInfo oldestTimeslice{.timeslice = {1}}; + std::vector currentSetOfInputs; + TimesliceSlot slot{0}; + + auto result = o2::framework::DataProcessingHelpers::routeForwardedMessages(proxy, slot, currentSetOfInputs, oldestTimeslice, copyByDefault, consume); + REQUIRE(result.empty()); +} + +TEST_CASE("ForwardInputsSingleMessageSingleRoute") +{ + o2::header::DataHeader dh; + dh.dataOrigin = "TST"; + dh.dataDescription = "A"; + dh.subSpecification = 0; + dh.splitPayloadIndex = 0; + dh.splitPayloadParts = 1; + + o2::framework::DataProcessingHeader dph{0, 1}; + std::vector channels{ + fair::mq::Channel("from_A_to_B")}; + + bool consume = true; + bool copyByDefault = true; + FairMQDeviceProxy proxy; + std::vector routes{ForwardRoute{ + .timeslice = 0, + .maxTimeslices = 1, + .matcher = {"binding", ConcreteDataMatcher{"TST", "A", 0}}, + .channel = "from_A_to_B", + .policy = nullptr, + }}; + + auto findChannelByName = [&channels](std::string const& channelName) -> fair::mq::Channel& { + for (auto& channel : channels) { + if (channel.GetName() == channelName) { + return channel; + } + } + throw std::runtime_error("Channel not found"); + }; + + proxy.bind({}, {}, routes, findChannelByName, nullptr); + + TimesliceIndex::OldestOutputInfo oldestTimeslice{.timeslice = {0}}; + std::vector currentSetOfInputs; + MessageSet messageSet; + + auto transport = fair::mq::TransportFactory::CreateTransportFactory("zeromq"); + fair::mq::MessagePtr payload(transport->CreateMessage()); + auto channelAlloc = o2::pmr::getTransportAllocator(transport.get()); + auto header = o2::pmr::getMessage(o2::header::Stack{channelAlloc, dh, dph}); + messageSet.add(PartRef{std::move(header), std::move(payload)}); + REQUIRE(messageSet.size() == 1); + currentSetOfInputs.emplace_back(std::move(messageSet)); + + TimesliceSlot slot{0}; + + auto result = o2::framework::DataProcessingHelpers::routeForwardedMessages(proxy, slot, currentSetOfInputs, oldestTimeslice, copyByDefault, consume); + REQUIRE(result.size() == 1); // One route + REQUIRE(result[0].Size() == 2); // Two messages for that route +} + +TEST_CASE("ForwardInputsSingleMessageSingleRouteAtEOS") +{ + o2::header::DataHeader dh; + dh.dataOrigin = "TST"; + dh.dataDescription = "A"; + dh.subSpecification = 0; + dh.splitPayloadIndex = 0; + dh.splitPayloadParts = 1; + + o2::framework::DataProcessingHeader dph{0, 1}; + + o2::framework::SourceInfoHeader sih{}; + + std::vector channels{ + fair::mq::Channel("from_A_to_B")}; + + bool consume = true; + bool copyByDefault = true; + FairMQDeviceProxy proxy; + std::vector routes{ForwardRoute{ + .timeslice = 0, + .maxTimeslices = 1, + .matcher = {"binding", ConcreteDataMatcher{"TST", "A", 0}}, + .channel = "from_A_to_B", + .policy = nullptr, + }}; + + auto findChannelByName = [&channels](std::string const& channelName) -> fair::mq::Channel& { + for (auto& channel : channels) { + if (channel.GetName() == channelName) { + return channel; + } + } + throw std::runtime_error("Channel not found"); + }; + + proxy.bind({}, {}, routes, findChannelByName, nullptr); + + TimesliceIndex::OldestOutputInfo oldestTimeslice{.timeslice = {0}}; + std::vector currentSetOfInputs; + MessageSet messageSet; + + auto transport = fair::mq::TransportFactory::CreateTransportFactory("zeromq"); + fair::mq::MessagePtr payload(transport->CreateMessage()); + auto channelAlloc = o2::pmr::getTransportAllocator(transport.get()); + auto header = o2::pmr::getMessage(o2::header::Stack{channelAlloc, dh, dph, sih}); + REQUIRE(o2::header::get(header->GetData())); + messageSet.add(PartRef{std::move(header), std::move(payload)}); + REQUIRE(messageSet.size() == 1); + currentSetOfInputs.emplace_back(std::move(messageSet)); + + TimesliceSlot slot{0}; + + auto result = o2::framework::DataProcessingHelpers::routeForwardedMessages(proxy, slot, currentSetOfInputs, oldestTimeslice, copyByDefault, consume); + REQUIRE(result.size() == 1); // One route + REQUIRE(result[0].Size() == 2); + // Correct behavior below: + // REQUIRE(result[0].Size() == 2); + // REQUIRE(o2::header::get(result[0].At(0)->GetData()) == nullptr); +} + +TEST_CASE("ForwardInputsSingleMessageSingleRouteWithOldestPossible") +{ + o2::header::DataHeader dh; + dh.dataOrigin = "TST"; + dh.dataDescription = "A"; + dh.subSpecification = 0; + dh.splitPayloadIndex = 0; + dh.splitPayloadParts = 1; + + o2::framework::DataProcessingHeader dph{0, 1}; + + o2::framework::DomainInfoHeader dih{}; + + std::vector channels{ + fair::mq::Channel("from_A_to_B")}; + + bool consume = true; + bool copyByDefault = true; + FairMQDeviceProxy proxy; + std::vector routes{ForwardRoute{ + .timeslice = 0, + .maxTimeslices = 1, + .matcher = {"binding", ConcreteDataMatcher{"TST", "A", 0}}, + .channel = "from_A_to_B", + .policy = nullptr, + }}; + + auto findChannelByName = [&channels](std::string const& channelName) -> fair::mq::Channel& { + for (auto& channel : channels) { + if (channel.GetName() == channelName) { + return channel; + } + } + throw std::runtime_error("Channel not found"); + }; + + proxy.bind({}, {}, routes, findChannelByName, nullptr); + + TimesliceIndex::OldestOutputInfo oldestTimeslice{.timeslice = {0}}; + std::vector currentSetOfInputs; + MessageSet messageSet; + + auto transport = fair::mq::TransportFactory::CreateTransportFactory("zeromq"); + fair::mq::MessagePtr payload(transport->CreateMessage()); + auto channelAlloc = o2::pmr::getTransportAllocator(transport.get()); + auto header = o2::pmr::getMessage(o2::header::Stack{channelAlloc, dh, dph, dih}); + REQUIRE(o2::header::get(header->GetData())); + messageSet.add(PartRef{std::move(header), std::move(payload)}); + REQUIRE(messageSet.size() == 1); + currentSetOfInputs.emplace_back(std::move(messageSet)); + + TimesliceSlot slot{0}; + + auto result = o2::framework::DataProcessingHelpers::routeForwardedMessages(proxy, slot, currentSetOfInputs, oldestTimeslice, copyByDefault, consume); + REQUIRE(result.size() == 1); // One route + REQUIRE(result[0].Size() == 2); + REQUIRE(o2::header::get(result[0].At(0)->GetData()) == nullptr); // it should not have the end of stream +} + +TEST_CASE("ForwardInputsSingleMessageMultipleRoutes") +{ + o2::header::DataHeader dh; + dh.dataOrigin = "TST"; + dh.dataDescription = "A"; + dh.subSpecification = 0; + dh.splitPayloadIndex = 0; + dh.splitPayloadParts = 1; + + o2::framework::DataProcessingHeader dph{0, 1}; + + std::vector channels{ + fair::mq::Channel("from_A_to_B"), + fair::mq::Channel("from_A_to_C"), + }; + + bool consume = true; + bool copyByDefault = true; + FairMQDeviceProxy proxy; + std::vector routes{ + ForwardRoute{ + .timeslice = 0, + .maxTimeslices = 1, + .matcher = {"binding", ConcreteDataMatcher{"TST", "A", 0}}, + .channel = "from_A_to_B", + .policy = nullptr, + }, + ForwardRoute{ + .timeslice = 0, + .maxTimeslices = 1, + .matcher = {"binding2", ConcreteDataMatcher{"TST", "A", 0}}, + .channel = "from_A_to_C", + .policy = nullptr, + }}; + + auto findChannelByName = [&channels](std::string const& channelName) -> fair::mq::Channel& { + for (auto& channel : channels) { + if (channel.GetName() == channelName) { + return channel; + } + } + throw std::runtime_error("Channel not found"); + }; + + proxy.bind({}, {}, routes, findChannelByName, nullptr); + + TimesliceIndex::OldestOutputInfo oldestTimeslice{.timeslice = {0}}; + std::vector currentSetOfInputs; + MessageSet messageSet; + + auto transport = fair::mq::TransportFactory::CreateTransportFactory("zeromq"); + fair::mq::MessagePtr payload(transport->CreateMessage()); + auto channelAlloc = o2::pmr::getTransportAllocator(transport.get()); + auto header = o2::pmr::getMessage(o2::header::Stack{channelAlloc, dh, dph}); + messageSet.add(PartRef{std::move(header), std::move(payload)}); + REQUIRE(messageSet.size() == 1); + currentSetOfInputs.emplace_back(std::move(messageSet)); + + TimesliceSlot slot{0}; + + auto result = o2::framework::DataProcessingHelpers::routeForwardedMessages(proxy, slot, currentSetOfInputs, oldestTimeslice, copyByDefault, consume); + REQUIRE(result.size() == 2); // Two routes + REQUIRE(result[0].Size() == 2); // Two messages per route + REQUIRE(result[1].Size() == 0); // Only the first DPL matched channel matters +} + +TEST_CASE("ForwardInputsSingleMessageMultipleRoutesExternals") +{ + o2::header::DataHeader dh; + dh.dataOrigin = "TST"; + dh.dataDescription = "A"; + dh.subSpecification = 0; + dh.splitPayloadIndex = 0; + dh.splitPayloadParts = 1; + + o2::framework::DataProcessingHeader dph{0, 1}; + + std::vector channels{ + fair::mq::Channel("external"), + fair::mq::Channel("from_A_to_C"), + }; + + bool consume = true; + bool copyByDefault = true; + FairMQDeviceProxy proxy; + std::vector routes{ + ForwardRoute{ + .timeslice = 0, + .maxTimeslices = 1, + .matcher = {"binding", ConcreteDataMatcher{"TST", "A", 0}}, + .channel = "external", + .policy = nullptr, + }, + ForwardRoute{ + .timeslice = 0, + .maxTimeslices = 1, + .matcher = {"binding2", ConcreteDataMatcher{"TST", "A", 0}}, + .channel = "from_A_to_C", + .policy = nullptr, + }}; + + auto findChannelByName = [&channels](std::string const& channelName) -> fair::mq::Channel& { + for (auto& channel : channels) { + if (channel.GetName() == channelName) { + return channel; + } + } + throw std::runtime_error("Channel not found"); + }; + + proxy.bind({}, {}, routes, findChannelByName, nullptr); + + TimesliceIndex::OldestOutputInfo oldestTimeslice{.timeslice = {0}}; + std::vector currentSetOfInputs; + MessageSet messageSet; + + auto transport = fair::mq::TransportFactory::CreateTransportFactory("zeromq"); + fair::mq::MessagePtr payload(transport->CreateMessage()); + auto channelAlloc = o2::pmr::getTransportAllocator(transport.get()); + auto header = o2::pmr::getMessage(o2::header::Stack{channelAlloc, dh, dph}); + messageSet.add(PartRef{std::move(header), std::move(payload)}); + REQUIRE(messageSet.size() == 1); + currentSetOfInputs.emplace_back(std::move(messageSet)); + + TimesliceSlot slot{0}; + + auto result = o2::framework::DataProcessingHelpers::routeForwardedMessages(proxy, slot, currentSetOfInputs, oldestTimeslice, copyByDefault, consume); + REQUIRE(result.size() == 2); // Two routes + REQUIRE(result[0].Size() == 2); // With external matching channels, we need to copy and then forward + REQUIRE(result[1].Size() == 2); // +} + +TEST_CASE("ForwardInputsMultiMessageMultipleRoutes") +{ + o2::header::DataHeader dh1; + dh1.dataOrigin = "TST"; + dh1.dataDescription = "A"; + dh1.subSpecification = 0; + dh1.splitPayloadIndex = 0; + dh1.splitPayloadParts = 1; + + o2::header::DataHeader dh2; + dh2.dataOrigin = "TST"; + dh2.dataDescription = "B"; + dh2.subSpecification = 0; + dh2.splitPayloadIndex = 0; + dh2.splitPayloadParts = 1; + + o2::framework::DataProcessingHeader dph{0, 1}; + + std::vector channels{ + fair::mq::Channel("from_A_to_B"), + fair::mq::Channel("from_A_to_C"), + }; + + bool consume = true; + bool copyByDefault = true; + FairMQDeviceProxy proxy; + std::vector routes{ + ForwardRoute{ + .timeslice = 0, + .maxTimeslices = 1, + .matcher = {"binding", ConcreteDataMatcher{"TST", "A", 0}}, + .channel = "from_A_to_B", + .policy = nullptr, + }, + ForwardRoute{ + .timeslice = 0, + .maxTimeslices = 1, + .matcher = {"binding2", ConcreteDataMatcher{"TST", "B", 0}}, + .channel = "from_A_to_C", + .policy = nullptr, + }}; + + auto findChannelByName = [&channels](std::string const& channelName) -> fair::mq::Channel& { + for (auto& channel : channels) { + if (channel.GetName() == channelName) { + return channel; + } + } + throw std::runtime_error("Channel not found"); + }; + + proxy.bind({}, {}, routes, findChannelByName, nullptr); + + TimesliceIndex::OldestOutputInfo oldestTimeslice{.timeslice = {0}}; + std::vector currentSetOfInputs; + + auto transport = fair::mq::TransportFactory::CreateTransportFactory("zeromq"); + fair::mq::MessagePtr payload1(transport->CreateMessage()); + fair::mq::MessagePtr payload2(transport->CreateMessage()); + auto channelAlloc = o2::pmr::getTransportAllocator(transport.get()); + auto header1 = o2::pmr::getMessage(o2::header::Stack{channelAlloc, dh1, dph}); + MessageSet messageSet1; + messageSet1.add(PartRef{std::move(header1), std::move(payload1)}); + REQUIRE(messageSet1.size() == 1); + + auto header2 = o2::pmr::getMessage(o2::header::Stack{channelAlloc, dh2, dph}); + MessageSet messageSet2; + messageSet2.add(PartRef{std::move(header2), std::move(payload2)}); + REQUIRE(messageSet2.size() == 1); + currentSetOfInputs.emplace_back(std::move(messageSet1)); + currentSetOfInputs.emplace_back(std::move(messageSet2)); + REQUIRE(currentSetOfInputs.size() == 2); + + TimesliceSlot slot{0}; + + auto result = o2::framework::DataProcessingHelpers::routeForwardedMessages(proxy, slot, currentSetOfInputs, oldestTimeslice, copyByDefault, consume); + REQUIRE(result.size() == 2); // Two routes + REQUIRE(result[0].Size() == 2); // + REQUIRE(result[1].Size() == 2); // +} + +TEST_CASE("ForwardInputsSingleMessageMultipleRoutesOnlyOneMatches") +{ + o2::header::DataHeader dh; + dh.dataOrigin = "TST"; + dh.dataDescription = "A"; + dh.subSpecification = 0; + dh.splitPayloadIndex = 0; + dh.splitPayloadParts = 1; + + o2::framework::DataProcessingHeader dph{0, 1}; + + std::vector channels{ + fair::mq::Channel("from_A_to_B"), + fair::mq::Channel("from_A_to_C"), + }; + + bool consume = true; + bool copyByDefault = true; + FairMQDeviceProxy proxy; + std::vector routes{ + ForwardRoute{ + .timeslice = 0, + .maxTimeslices = 1, + .matcher = {"binding", ConcreteDataMatcher{"TST", "B", 0}}, + .channel = "from_A_to_B", + .policy = nullptr, + }, + ForwardRoute{ + .timeslice = 0, + .maxTimeslices = 1, + .matcher = {"binding", ConcreteDataMatcher{"TST", "A", 0}}, + .channel = "from_A_to_C", + .policy = nullptr, + }}; + + auto findChannelByName = [&channels](std::string const& channelName) -> fair::mq::Channel& { + for (auto& channel : channels) { + if (channel.GetName() == channelName) { + return channel; + } + } + throw std::runtime_error("Channel not found"); + }; + + proxy.bind({}, {}, routes, findChannelByName, nullptr); + + TimesliceIndex::OldestOutputInfo oldestTimeslice{.timeslice = {0}}; + std::vector currentSetOfInputs; + MessageSet messageSet; + + auto transport = fair::mq::TransportFactory::CreateTransportFactory("zeromq"); + fair::mq::MessagePtr payload(transport->CreateMessage()); + auto channelAlloc = o2::pmr::getTransportAllocator(transport.get()); + auto header = o2::pmr::getMessage(o2::header::Stack{channelAlloc, dh, dph}); + messageSet.add(PartRef{std::move(header), std::move(payload)}); + REQUIRE(messageSet.size() == 1); + currentSetOfInputs.emplace_back(std::move(messageSet)); + + TimesliceSlot slot{0}; + + auto result = o2::framework::DataProcessingHelpers::routeForwardedMessages(proxy, slot, currentSetOfInputs, oldestTimeslice, copyByDefault, consume); + REQUIRE(result.size() == 2); // Two routes + REQUIRE(result[0].Size() == 0); // Two messages per route + REQUIRE(result[1].Size() == 2); // Two messages per route +} + +TEST_CASE("ForwardInputsSplitPayload") +{ + o2::header::DataHeader dh; + dh.dataOrigin = "TST"; + dh.dataDescription = "A"; + dh.subSpecification = 0; + dh.splitPayloadIndex = 0; + dh.splitPayloadParts = 2; + + o2::header::DataHeader dh2; + dh2.dataOrigin = "TST"; + dh2.dataDescription = "B"; + dh2.subSpecification = 0; + dh2.splitPayloadIndex = 0; + dh2.splitPayloadParts = 1; + + o2::framework::DataProcessingHeader dph{0, 1}; + + std::vector channels{ + fair::mq::Channel("from_A_to_B"), + fair::mq::Channel("from_A_to_C"), + }; + + bool consume = true; + bool copyByDefault = true; + FairMQDeviceProxy proxy; + std::vector routes{ + ForwardRoute{ + .timeslice = 0, + .maxTimeslices = 1, + .matcher = {"binding", ConcreteDataMatcher{"TST", "B", 0}}, + .channel = "from_A_to_B", + .policy = nullptr, + }, + ForwardRoute{ + .timeslice = 0, + .maxTimeslices = 1, + .matcher = {"binding", ConcreteDataMatcher{"TST", "A", 0}}, + .channel = "from_A_to_C", + .policy = nullptr, + }}; + + auto findChannelByName = [&channels](std::string const& channelName) -> fair::mq::Channel& { + for (auto& channel : channels) { + if (channel.GetName() == channelName) { + return channel; + } + } + throw std::runtime_error("Channel not found"); + }; + + proxy.bind({}, {}, routes, findChannelByName, nullptr); + + TimesliceIndex::OldestOutputInfo oldestTimeslice{.timeslice = {0}}; + std::vector currentSetOfInputs; + MessageSet messageSet; + + auto transport = fair::mq::TransportFactory::CreateTransportFactory("zeromq"); + fair::mq::MessagePtr payload1(transport->CreateMessage()); + fair::mq::MessagePtr payload2(transport->CreateMessage()); + auto channelAlloc = o2::pmr::getTransportAllocator(transport.get()); + auto header = o2::pmr::getMessage(o2::header::Stack{channelAlloc, dh, dph}); + std::vector> messages; + messages.push_back(std::move(header)); + messages.push_back(std::move(payload1)); + messages.push_back(std::move(payload2)); + auto fillMessages = [&messages](size_t t) -> fair::mq::MessagePtr { + return std::move(messages[t]); + }; + messageSet.add(fillMessages, 3); + auto header2 = o2::pmr::getMessage(o2::header::Stack{channelAlloc, dh2, dph}); + PartRef part{std::move(header2), transport->CreateMessage()}; + messageSet.add(std::move(part)); + + REQUIRE(messageSet.size() == 2); + currentSetOfInputs.emplace_back(std::move(messageSet)); + + TimesliceSlot slot{0}; + + auto result = o2::framework::DataProcessingHelpers::routeForwardedMessages(proxy, slot, currentSetOfInputs, oldestTimeslice, copyByDefault, consume); + REQUIRE(result.size() == 2); // Two routes + CHECK(result[0].Size() == 2); // No messages on this route + CHECK(result[1].Size() == 3); +} + +TEST_CASE("ForwardInputEOSSingleRoute") +{ + o2::framework::SourceInfoHeader sih{}; + + std::vector channels{ + fair::mq::Channel("from_A_to_B")}; + + bool consume = true; + bool copyByDefault = true; + FairMQDeviceProxy proxy; + std::vector routes{ForwardRoute{ + .timeslice = 0, + .maxTimeslices = 1, + .matcher = {"binding", ConcreteDataMatcher{"TST", "A", 0}}, + .channel = "from_A_to_B", + .policy = nullptr, + }}; + + auto findChannelByName = [&channels](std::string const& channelName) -> fair::mq::Channel& { + for (auto& channel : channels) { + if (channel.GetName() == channelName) { + return channel; + } + } + throw std::runtime_error("Channel not found"); + }; + + proxy.bind({}, {}, routes, findChannelByName, nullptr); + + TimesliceIndex::OldestOutputInfo oldestTimeslice{.timeslice = {0}}; + std::vector currentSetOfInputs; + MessageSet messageSet; + + auto transport = fair::mq::TransportFactory::CreateTransportFactory("zeromq"); + fair::mq::MessagePtr payload(transport->CreateMessage()); + auto channelAlloc = o2::pmr::getTransportAllocator(transport.get()); + auto header = o2::pmr::getMessage(o2::header::Stack{channelAlloc, sih}); + messageSet.add(PartRef{std::move(header), std::move(payload)}); + REQUIRE(messageSet.size() == 1); + currentSetOfInputs.emplace_back(std::move(messageSet)); + + TimesliceSlot slot{0}; + + auto result = o2::framework::DataProcessingHelpers::routeForwardedMessages(proxy, slot, currentSetOfInputs, oldestTimeslice, copyByDefault, consume); + REQUIRE(result.size() == 1); // One route + REQUIRE(result[0].Size() == 0); // Oldest possible timeframe should not be forwarded +} + +TEST_CASE("ForwardInputOldestPossibleSingleRoute") +{ + o2::framework::DomainInfoHeader dih{}; + + std::vector channels{ + fair::mq::Channel("from_A_to_B")}; + + bool consume = true; + bool copyByDefault = true; + FairMQDeviceProxy proxy; + std::vector routes{ForwardRoute{ + .timeslice = 0, + .maxTimeslices = 1, + .matcher = {"binding", ConcreteDataMatcher{"TST", "A", 0}}, + .channel = "from_A_to_B", + .policy = nullptr, + }}; + + auto findChannelByName = [&channels](std::string const& channelName) -> fair::mq::Channel& { + for (auto& channel : channels) { + if (channel.GetName() == channelName) { + return channel; + } + } + throw std::runtime_error("Channel not found"); + }; + + proxy.bind({}, {}, routes, findChannelByName, nullptr); + + TimesliceIndex::OldestOutputInfo oldestTimeslice{.timeslice = {0}}; + std::vector currentSetOfInputs; + MessageSet messageSet; + + auto transport = fair::mq::TransportFactory::CreateTransportFactory("zeromq"); + fair::mq::MessagePtr payload(transport->CreateMessage()); + auto channelAlloc = o2::pmr::getTransportAllocator(transport.get()); + auto header = o2::pmr::getMessage(o2::header::Stack{channelAlloc, dih}); + messageSet.add(PartRef{std::move(header), std::move(payload)}); + REQUIRE(messageSet.size() == 1); + currentSetOfInputs.emplace_back(std::move(messageSet)); + + TimesliceSlot slot{0}; + + auto result = o2::framework::DataProcessingHelpers::routeForwardedMessages(proxy, slot, currentSetOfInputs, oldestTimeslice, copyByDefault, consume); + REQUIRE(result.size() == 1); // One route + REQUIRE(result[0].Size() == 0); // Oldest possible timeframe should not be forwarded +}