@@ -446,6 +446,37 @@ bool isInThreadMappedScope(
446446 return false ;
447447}
448448
449+ static std::vector<std::pair<isl::id, TensorGroupsInfo>> sortTensorGroupMap (
450+ TensorGroups&& groupMap) {
451+ // Prepare groups for sorting, to have specified order necessary for
452+ // reproducibility and tests.
453+ using TensorGroupList = std::pair<isl::id, TensorGroupsInfo>;
454+ std::vector<TensorGroupList> groupLists (
455+ std::make_move_iterator (groupMap.begin ()),
456+ std::make_move_iterator (groupMap.end ()));
457+
458+ // Computes the total number of references in all groups.
459+ auto refsCount = [](const TensorGroupsInfo& info) {
460+ size_t refs = 0 ;
461+ for (auto const & group : info) {
462+ refs += group->referenceIds ().size ();
463+ }
464+ return refs;
465+ };
466+
467+ // Sort by the total number of references, then by name. Because names are
468+ // guarenteed to be unique, the order is total.
469+ std::sort (
470+ groupLists.begin (),
471+ groupLists.end (),
472+ [refsCount](const TensorGroupList& l1, const TensorGroupList& l2) {
473+ auto r1 = refsCount (l1.second );
474+ auto r2 = refsCount (l2.second );
475+ return r1 == r2 ? l1.first .get_name () < l2.first .get_name () : r1 < r2;
476+ });
477+ return groupLists;
478+ }
479+
449480/*
450481 * Promote to shared memory in "scop" below "node". Use at most
451482 * "remainingMemory" bytes, and update the variable to reflect the amount of
@@ -474,37 +505,11 @@ void promoteToSharedBelow(
474505 auto partialSched = partialSchedule (root, node);
475506 auto mapping = collectMappingsTo<mapping::BlockId>(scop);
476507
477- auto groupMap = TensorReferenceGroup::accessedWithin (
478- partialSched.intersect_domain (mapping), scop.body );
508+ auto groupLists = sortTensorGroupMap ( TensorReferenceGroup::accessedWithin (
509+ partialSched.intersect_domain (mapping), scop.body )) ;
479510 // Pure affine schedule without (mapping) filters.
480511 auto partialSchedMupa = partialScheduleMupa (root, node);
481512
482- // Prepare groups for sorting, to have specified order necessary for
483- // reproducibility and tests.
484- using TensorGroupList = std::pair<isl::id, TensorGroupsInfo>;
485- std::vector<TensorGroupList> groupLists (
486- std::make_move_iterator (groupMap.begin ()),
487- std::make_move_iterator (groupMap.end ()));
488-
489- // Computes the total number of references in all groups.
490- auto refsCount = [](const TensorGroupsInfo& info) {
491- size_t refs = 0 ;
492- for (auto const & group : info) {
493- refs += group->referenceIds ().size ();
494- }
495- return refs;
496- };
497-
498- // Sort by the total number of references, then by name. Because names are
499- // guarenteed to be unique, the order is total.
500- std::sort (
501- groupLists.begin (),
502- groupLists.end (),
503- [refsCount](const TensorGroupList& l1, const TensorGroupList& l2) {
504- auto r1 = refsCount (l1.second );
505- auto r2 = refsCount (l2.second );
506- return r1 == r2 ? l1.first .get_name () < l2.first .get_name () : r1 < r2;
507- });
508513 for (auto & tensorGroups : groupLists) {
509514 auto tensorId = tensorGroups.first ;
510515 // Sort the reference groups to prioritize groups with more references as
0 commit comments