Skip to content

Commit e5d02dd

Browse files
authored
[CK][AMDGPU] Verify dominance when rewriting spills to registers (#641)
1 parent 4b33e3d commit e5d02dd

File tree

1 file changed

+88
-4
lines changed

1 file changed

+88
-4
lines changed

llvm/lib/Target/AMDGPU/AMDGPURewriteAGPRCopyMFMA.cpp

Lines changed: 88 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
#include "llvm/CodeGen/LiveIntervals.h"
3131
#include "llvm/CodeGen/LiveRegMatrix.h"
3232
#include "llvm/CodeGen/LiveStacks.h"
33+
#include "llvm/CodeGen/MachineDominators.h"
3334
#include "llvm/CodeGen/MachineFrameInfo.h"
3435
#include "llvm/CodeGen/MachineFunctionPass.h"
3536
#include "llvm/CodeGen/VirtRegMap.h"
@@ -58,6 +59,7 @@ class AMDGPURewriteAGPRCopyMFMAImpl {
5859
LiveIntervals &LIS;
5960
LiveStacks &LSS;
6061
const RegisterClassInfo &RegClassInfo;
62+
MachineDominatorTree &MDT;
6163

6264
bool attemptReassignmentsToAGPR(SmallSetVector<Register, 4> &InterferingRegs,
6365
MCPhysReg PrefPhysReg) const;
@@ -66,10 +68,11 @@ class AMDGPURewriteAGPRCopyMFMAImpl {
6668
AMDGPURewriteAGPRCopyMFMAImpl(MachineFunction &MF, VirtRegMap &VRM,
6769
LiveRegMatrix &LRM, LiveIntervals &LIS,
6870
LiveStacks &LSS,
69-
const RegisterClassInfo &RegClassInfo)
71+
const RegisterClassInfo &RegClassInfo,
72+
MachineDominatorTree &MDT)
7073
: MF(MF), ST(MF.getSubtarget<GCNSubtarget>()), TII(*ST.getInstrInfo()),
7174
TRI(*ST.getRegisterInfo()), MRI(MF.getRegInfo()), VRM(VRM), LRM(LRM),
72-
LIS(LIS), LSS(LSS), RegClassInfo(RegClassInfo) {}
75+
LIS(LIS), LSS(LSS), RegClassInfo(RegClassInfo), MDT(MDT) {}
7376

7477
bool isRewriteCandidate(const MachineInstr &MI) const {
7578
return TII.isMAI(MI) && AMDGPU::getMFMASrcCVDstAGPROp(MI.getOpcode()) != -1;
@@ -515,6 +518,82 @@ void AMDGPURewriteAGPRCopyMFMAImpl::eliminateSpillsOfReassignedVGPRs() const {
515518
if (SpillReferences == SpillSlotReferences.end())
516519
continue;
517520

521+
// For each spill reload, every path from entry to the reload must pass
522+
// through at least one spill store to the same stack slot.
523+
SmallVector<MachineInstr *, 4> Stores, Loads;
524+
Stores.reserve(SpillReferences->second.size());
525+
Loads.reserve(SpillReferences->second.size());
526+
for (MachineInstr *MI : SpillReferences->second) {
527+
if (MI->mayStore())
528+
Stores.push_back(MI);
529+
else if (MI->mayLoad())
530+
Loads.push_back(MI);
531+
}
532+
533+
SmallPtrSet<MachineBasicBlock *, 4> StoreBlocks;
534+
for (MachineInstr *S : Stores)
535+
if (MDT.isReachableFromEntry(S->getParent()))
536+
StoreBlocks.insert(S->getParent());
537+
538+
if (StoreBlocks.empty()) {
539+
LLVM_DEBUG(dbgs() << "Skipping " << printReg(Slot, &TRI)
540+
<< ": no reachable stores\n");
541+
continue;
542+
}
543+
544+
// Compute blocks reachable from entry without passing through a store
545+
// block.
546+
SmallPtrSet<MachineBasicBlock *, 16> StoreFreeReachable;
547+
SmallVector<MachineBasicBlock *, 16> Worklist;
548+
549+
MachineBasicBlock &EntryMBB = MF.front();
550+
Worklist.push_back(&EntryMBB);
551+
StoreFreeReachable.insert(&EntryMBB);
552+
553+
while (!Worklist.empty()) {
554+
MachineBasicBlock *MBB = Worklist.pop_back_val();
555+
if (StoreBlocks.contains(MBB))
556+
continue;
557+
558+
for (MachineBasicBlock *Succ : MBB->successors()) {
559+
if (StoreFreeReachable.insert(Succ).second)
560+
Worklist.push_back(Succ);
561+
}
562+
}
563+
564+
auto IsLoadJointlyDominatedByStores = [&](MachineInstr *LoadMI) -> bool {
565+
MachineBasicBlock *LoadMBB = LoadMI->getParent();
566+
if (!MDT.isReachableFromEntry(LoadMBB))
567+
return true;
568+
569+
// Check if every path passed through a store block.
570+
if (!StoreFreeReachable.contains(LoadMBB))
571+
return true;
572+
573+
// Otherwise, there exists a path to this block that has not seen any
574+
// store yet. We must ensure that within this block there is a store to
575+
// this slot before the load.
576+
for (MachineInstr &MI : *LoadMBB) {
577+
if (&MI == LoadMI)
578+
break;
579+
if (MI.mayStore()) {
580+
for (MachineOperand &MO : MI.operands()) {
581+
if (MO.isFI() && MO.getIndex() == Slot)
582+
return true;
583+
}
584+
}
585+
}
586+
587+
return false;
588+
};
589+
590+
if (!llvm::all_of(Loads, IsLoadJointlyDominatedByStores)) {
591+
LLVM_DEBUG(
592+
dbgs() << "Skipping " << printReg(Slot, &TRI)
593+
<< ": some reachable load not jointly dominated by stores\n");
594+
continue;
595+
}
596+
518597
const TargetRegisterClass *RC = LSS.getIntervalRegClass(Slot);
519598

520599
LLVM_DEBUG(dbgs() << "Trying to eliminate " << printReg(Slot, &TRI)
@@ -603,11 +682,13 @@ class AMDGPURewriteAGPRCopyMFMALegacy : public MachineFunctionPass {
603682
AU.addRequired<VirtRegMapWrapperLegacy>();
604683
AU.addRequired<LiveRegMatrixWrapperLegacy>();
605684
AU.addRequired<LiveStacksWrapperLegacy>();
685+
AU.addRequired<MachineDominatorTreeWrapperPass>();
606686

607687
AU.addPreserved<LiveIntervalsWrapperPass>();
608688
AU.addPreserved<VirtRegMapWrapperLegacy>();
609689
AU.addPreserved<LiveRegMatrixWrapperLegacy>();
610690
AU.addPreserved<LiveStacksWrapperLegacy>();
691+
AU.addPreserved<MachineDominatorTreeWrapperPass>();
611692

612693
AU.setPreservesAll();
613694
MachineFunctionPass::getAnalysisUsage(AU);
@@ -622,6 +703,7 @@ INITIALIZE_PASS_DEPENDENCY(LiveIntervalsWrapperPass)
622703
INITIALIZE_PASS_DEPENDENCY(VirtRegMapWrapperLegacy)
623704
INITIALIZE_PASS_DEPENDENCY(LiveRegMatrixWrapperLegacy)
624705
INITIALIZE_PASS_DEPENDENCY(LiveStacksWrapperLegacy)
706+
INITIALIZE_PASS_DEPENDENCY(MachineDominatorTreeWrapperPass)
625707
INITIALIZE_PASS_END(AMDGPURewriteAGPRCopyMFMALegacy, DEBUG_TYPE,
626708
"AMDGPU Rewrite AGPR-Copy-MFMA", false, false)
627709

@@ -641,7 +723,8 @@ bool AMDGPURewriteAGPRCopyMFMALegacy::runOnMachineFunction(
641723
auto &LRM = getAnalysis<LiveRegMatrixWrapperLegacy>().getLRM();
642724
auto &LIS = getAnalysis<LiveIntervalsWrapperPass>().getLIS();
643725
auto &LSS = getAnalysis<LiveStacksWrapperLegacy>().getLS();
644-
AMDGPURewriteAGPRCopyMFMAImpl Impl(MF, VRM, LRM, LIS, LSS, RegClassInfo);
726+
auto &MDT = getAnalysis<MachineDominatorTreeWrapperPass>().getDomTree();
727+
AMDGPURewriteAGPRCopyMFMAImpl Impl(MF, VRM, LRM, LIS, LSS, RegClassInfo, MDT);
645728
return Impl.run(MF);
646729
}
647730

@@ -652,10 +735,11 @@ AMDGPURewriteAGPRCopyMFMAPass::run(MachineFunction &MF,
652735
LiveRegMatrix &LRM = MFAM.getResult<LiveRegMatrixAnalysis>(MF);
653736
LiveIntervals &LIS = MFAM.getResult<LiveIntervalsAnalysis>(MF);
654737
LiveStacks &LSS = MFAM.getResult<LiveStacksAnalysis>(MF);
738+
MachineDominatorTree &MDT = MFAM.getResult<MachineDominatorTreeAnalysis>(MF);
655739
RegisterClassInfo RegClassInfo;
656740
RegClassInfo.runOnMachineFunction(MF);
657741

658-
AMDGPURewriteAGPRCopyMFMAImpl Impl(MF, VRM, LRM, LIS, LSS, RegClassInfo);
742+
AMDGPURewriteAGPRCopyMFMAImpl Impl(MF, VRM, LRM, LIS, LSS, RegClassInfo, MDT);
659743
if (!Impl.run(MF))
660744
return PreservedAnalyses::all();
661745
auto PA = getMachineFunctionPassPreservedAnalyses();

0 commit comments

Comments
 (0)