diff --git a/include/pysa/branching/branching.hpp b/include/pysa/branching/branching.hpp index 7743d0b..e8a53c3 100644 --- a/include/pysa/branching/branching.hpp +++ b/include/pysa/branching/branching.hpp @@ -40,6 +40,10 @@ specific language governing permissions and limitations under the License. namespace pysa::branching { +enum struct BranchResult{ + BRANCHOK, // default exit status + BRANCHEXIT // terminate all threads +}; /** * @brief Split a collection of branches into two. * Branches is a container that supports efficient .front() and .pop_front() @@ -96,7 +100,7 @@ void branching_impl(const Function &fn, Branches &branches, // Define core auto core_ = [fn = fn, &branches](std::size_t idx, auto &&stop) { - fn(branches[idx], stop); + return fn(branches[idx], stop); }; // Initialize threads @@ -133,12 +137,30 @@ void branching_impl(const Function &fn, Branches &branches, return std::tuple{min_, max_}; }; + bool exit_all = false; // Avoid a race condition where count_n_branches_() may be 0 temporarily at // start std::this_thread::sleep_for(sleep_time); // Keep going if there are still branches or the stop signal is off while (count_n_branches_() && !*stop) { + for (auto& t: threads_){ + if(t.is_ready()){ + if(t.get() == BranchResult::BRANCHEXIT){ + exit_all = true; + } + } + } + if(exit_all){ +#ifndef NDEBUG + std::cerr << "# Brancher exit " << std::endl; +#endif + for (auto& t: threads_){ + if(t.is_running()) + t.stop(); + } + return; + } // Propagate branches between two threads if (const auto [ei_, ni_] = balance_indexes_(); ei_ && ni_) { const auto e_idx_ = ei_.value(); diff --git a/include/pysa/dpll/dpll.hpp b/include/pysa/dpll/dpll.hpp index 882070e..547661b 100644 --- a/include/pysa/dpll/dpll.hpp +++ b/include/pysa/dpll/dpll.hpp @@ -23,8 +23,8 @@ specific language governing permissions and limitations under the License. namespace pysa::branching { -template -void DPLL_(Branches &&branches, Collect &&collect, ConstStopPtr stop) { +template +BranchResult DPLL_(Branches &&branches, Collect &&collect, ConstStopPtr stop) { // While there are still branches ... while (std::size(branches) && !*stop) { // Get last branch (depth first) @@ -45,11 +45,18 @@ void DPLL_(Branches &&branches, Collect &&collect, ConstStopPtr stop) { branches.splice(std::end(branches), branch_.branch()); // Collect - collect(std::move(branch_)); + if constexpr (exit_on_first){ + if(collect(std::move(branch_))){ + return BranchResult::BRANCHEXIT; + } + } else { + collect(std::move(branch_)); + } } + return BranchResult::BRANCHOK; } -template auto DPLL(Branches &&branches, Collect &&collect, Args &&...args) { /* @@ -59,7 +66,7 @@ auto DPLL(Branches &&branches, Collect &&collect, Args &&...args) { // Get brancher return branching( [collect](auto &&branches, auto &&stop) { - DPLL_(branches, collect, stop); + return DPLL_(branches, collect, stop); }, std::forward(branches), std::forward(args)...); } diff --git a/tests/test_branching.cpp b/tests/test_branching.cpp index 88d0744..0940f78 100644 --- a/tests/test_branching.cpp +++ b/tests/test_branching.cpp @@ -53,9 +53,10 @@ int main() { { // Use one thread pysa::branching::TestBranching(28, 1, true); - + pysa::branching::TestBranching(28, 1, true); // Use number of threads provided by the implementation pysa::branching::TestBranching(30, 0, true); + pysa::branching::TestBranching(30, 0, true); } #ifdef USE_MPI MPI_Barrier(mpi_comm_world); diff --git a/tests/test_branching.hpp b/tests/test_branching.hpp index 5580f6d..14062fa 100644 --- a/tests/test_branching.hpp +++ b/tests/test_branching.hpp @@ -62,6 +62,7 @@ struct Branch { using Branches = std::list; +template void TestBranching(const std::size_t n, const std::size_t n_threads = 0, const bool verbose = false) { // How to collect the branches @@ -72,11 +73,14 @@ void TestBranching(const std::size_t n, const std::size_t n_threads = 0, if (CheckBranch(branch)) { const std::scoped_lock lock_(mutex_); collected_.push_back(branch.state); + return true; + } else { + return false; } }; // Get branches - auto brancher_ = DPLL(Branches{Branch{n, 0, 0}}, collect_, n_threads); + auto brancher_ = DPLL(Branches{Branch{n, 0, 0}}, collect_, n_threads); // Start brancher auto it_ = std::chrono::high_resolution_clock::now(); @@ -98,18 +102,20 @@ void TestBranching(const std::size_t n, const std::size_t n_threads = 0, .count() << std::endl; - // Sort collected numbers - std::sort(std::begin(collected_), std::end(collected_)); + if constexpr (!stop_on_first){ + // Sort collected numbers + std::sort(std::begin(collected_), std::end(collected_)); - // Get head - auto head_ = std::cbegin(collected_); + // Get head + auto head_ = std::cbegin(collected_); - // Check results - for (std::size_t i_ = 0, end_ = std::size_t{1} << n; i_ < end_; ++i_) - if (CheckBranch(Branch{n, i_, 0})) assert(*head_++ == i_); + // Check results + for (std::size_t i_ = 0, end_ = std::size_t{1} << n; i_ < end_; ++i_) + if (CheckBranch(Branch{n, i_, 0})) assert(*head_++ == i_); - // All numbers should have been checked at this point - assert(head_ == std::cend(collected_)); + // All numbers should have been checked at this point + assert(head_ == std::cend(collected_)); + } } #ifdef USE_MPI