@@ -477,6 +477,20 @@ static std::vector<std::pair<isl::id, TensorGroupsInfo>> sortTensorGroupMap(
477477 return groupLists;
478478}
479479
480+ /* Sorts the given vector of tensor groups in place following the number of
481+ * references in the group in decreasing order. This prioritize groups with
482+ * more references as they are more likely to benefit from promotion.
483+ */
484+ static void sortTensorGroups (TensorGroupsInfo& tensorGroups) {
485+ std::sort (
486+ tensorGroups.begin (),
487+ tensorGroups.end (),
488+ [](const std::unique_ptr<TensorReferenceGroup>& group1,
489+ const std::unique_ptr<TensorReferenceGroup>& group2) {
490+ return group1->referenceIds ().size () > group2->referenceIds ().size ();
491+ });
492+ }
493+
480494/*
481495 * Promote to shared memory in "scop" below "node". Use at most
482496 * "remainingMemory" bytes, and update the variable to reflect the amount of
@@ -512,15 +526,7 @@ void promoteToSharedBelow(
512526
513527 for (auto & tensorGroups : groupLists) {
514528 auto tensorId = tensorGroups.first ;
515- // Sort the reference groups to prioritize groups with more references as
516- // they are more likely to benefit from promotion.
517- std::sort (
518- tensorGroups.second .begin (),
519- tensorGroups.second .end (),
520- [](const std::unique_ptr<TensorReferenceGroup>& group1,
521- const std::unique_ptr<TensorReferenceGroup>& group2) {
522- return group1->referenceIds ().size () > group2->referenceIds ().size ();
523- });
529+ sortTensorGroups (tensorGroups.second );
524530
525531 for (auto & group : tensorGroups.second ) {
526532 auto sizes = group->approximationSizes ();
0 commit comments