diff --git a/llvm/lib/Target/AMDGPU/GCNSchedStrategy.cpp b/llvm/lib/Target/AMDGPU/GCNSchedStrategy.cpp
index 94d93390d0916..db8dd45ec0441 100644
--- a/llvm/lib/Target/AMDGPU/GCNSchedStrategy.cpp
+++ b/llvm/lib/Target/AMDGPU/GCNSchedStrategy.cpp
@@ -58,6 +58,11 @@ static cl::opt<bool>
                         "Wave Limited (amdgpu-limit-wave-threshold)."),
                cl::init(false));
 
+static cl::opt<bool> GCNTrackers(
+    "amdgpu-use-amdgpu-trackers", cl::Hidden,
+    cl::desc("Use the AMDGPU specific RPTrackers during scheduling"),
+    cl::init(false));
+
 const unsigned ScheduleMetrics::ScaleFactor = 100;
 
 GCNSchedStrategy::GCNSchedStrategy(const MachineSchedContext *C)
@@ -484,7 +489,8 @@ GCNScheduleDAGMILive::GCNScheduleDAGMILive(
     MachineSchedContext *C, std::unique_ptr<MachineSchedStrategy> S)
     : ScheduleDAGMILive(C, std::move(S)), ST(MF.getSubtarget<GCNSubtarget>()),
       MFI(*MF.getInfo<SIMachineFunctionInfo>()),
-      StartingOccupancy(MFI.getOccupancy()), MinOccupancy(StartingOccupancy) {
+      StartingOccupancy(MFI.getOccupancy()), MinOccupancy(StartingOccupancy),
+      RegionLiveOuts(this, /*IsLiveOut=*/true) {
 
   LLVM_DEBUG(dbgs() << "Starting occupancy is " << StartingOccupancy << ".\n");
   if (RelaxedOcc) {
@@ -526,6 +532,17 @@ GCNScheduleDAGMILive::getRealRegPressure(unsigned RegionIdx) const {
   return RPTracker.moveMaxPressure();
 }
 
+// Return the instruction that keys a region's live-out set: walk backwards
+// from RegionEnd (or from the block's last instruction when RegionEnd is the
+// block end) over debug instructions, stopping at RegionBegin.
+static MachineInstr *getLastMIForRegion(MachineBasicBlock::iterator RegionBegin,
+                                        MachineBasicBlock::iterator RegionEnd) {
+  auto REnd = RegionEnd == RegionBegin->getParent()->end()
+                  ? std::prev(RegionEnd)
+                  : RegionEnd;
+  return &*skipDebugInstructionsBackward(REnd, RegionBegin);
+}
+
 void GCNScheduleDAGMILive::computeBlockPressure(unsigned RegionIdx,
                                                 const MachineBasicBlock *MBB) {
   GCNDownwardRPTracker RPTracker(*LIS);
@@ -600,20 +617,49 @@ void GCNScheduleDAGMILive::computeBlockPressure(unsigned RegionIdx,
 }
 
 DenseMap<MachineInstr *, GCNRPTracker::LiveRegSet>
-GCNScheduleDAGMILive::getBBLiveInMap() const {
+GCNScheduleDAGMILive::getRegionLiveInMap() const {
   assert(!Regions.empty());
-  std::vector<MachineInstr *> BBStarters;
-  BBStarters.reserve(Regions.size());
+  std::vector<MachineInstr *> RegionFirstMIs;
+  RegionFirstMIs.reserve(Regions.size());
   auto I = Regions.rbegin(), E = Regions.rend();
   auto *BB = I->first->getParent();
   do {
     auto *MI = &*skipDebugInstructionsForward(I->first, I->second);
-    BBStarters.push_back(MI);
+    RegionFirstMIs.push_back(MI);
     do {
       ++I;
     } while (I != E && I->first->getParent() == BB);
   } while (I != E);
-  return getLiveRegMap(BBStarters, false /*After*/, *LIS);
+  return getLiveRegMap(RegionFirstMIs, /*After=*/false, *LIS);
+}
+
+DenseMap<MachineInstr *, GCNRPTracker::LiveRegSet>
+GCNScheduleDAGMILive::getRegionLiveOutMap() const {
+  assert(!Regions.empty());
+  std::vector<MachineInstr *> RegionLastMIs;
+  RegionLastMIs.reserve(Regions.size());
+  for (auto &[RegionBegin, RegionEnd] : reverse(Regions))
+    RegionLastMIs.push_back(getLastMIForRegion(RegionBegin, RegionEnd));
+
+  return getLiveRegMap(RegionLastMIs, /*After=*/true, *LIS);
+}
+
+void RegionPressureMap::buildLiveRegMap() {
+  assert(DAG && "RegionPressureMap was not bound to a scheduling DAG");
+  IdxToInstruction.clear();
+
+  BBLiveRegMap =
+      IsLiveOut ? DAG->getRegionLiveOutMap() : DAG->getRegionLiveInMap();
+  // NOTE(review): for live-ins the key below is the raw region begin, while
+  // getRegionLiveInMap() keys on the first non-debug instruction and skips
+  // some regions; confirm the keys agree before relying on live-in lookups.
+  for (unsigned I = 0; I < DAG->Regions.size(); I++) {
+    MachineInstr *RegionKey =
+        IsLiveOut
+            ? getLastMIForRegion(DAG->Regions[I].first, DAG->Regions[I].second)
+            : &*DAG->Regions[I].first;
+    IdxToInstruction[I] = RegionKey;
+  }
 }
 
 void GCNScheduleDAGMILive::finalizeSchedule() {
@@ -639,8 +685,11 @@ void GCNScheduleDAGMILive::finalizeSchedule() {
 void GCNScheduleDAGMILive::runSchedStages() {
   LLVM_DEBUG(dbgs() << "All regions recorded, starting actual scheduling.\n");
 
-  if (!Regions.empty())
-    BBLiveInMap = getBBLiveInMap();
+  if (!Regions.empty()) {
+    BBLiveInMap = getRegionLiveInMap();
+    if (GCNTrackers)
+      RegionLiveOuts.buildLiveRegMap();
+  }
 
   GCNSchedStrategy &S = static_cast<GCNSchedStrategy &>(*SchedImpl);
   while (S.advanceStage()) {
diff --git a/llvm/lib/Target/AMDGPU/GCNSchedStrategy.h b/llvm/lib/Target/AMDGPU/GCNSchedStrategy.h
index 2084aae4128ff..a71f3fd0dd469 100644
--- a/llvm/lib/Target/AMDGPU/GCNSchedStrategy.h
+++ b/llvm/lib/Target/AMDGPU/GCNSchedStrategy.h
@@ -163,6 +163,36 @@ inline raw_ostream &operator<<(raw_ostream &OS, const ScheduleMetrics &Sm) {
   return OS;
 }
 
+class GCNScheduleDAGMILive;
+// Holds per-region live register sets (live-ins or live-outs, as selected by
+// IsLiveOut) for a GCNScheduleDAGMILive, looked up by region index.
+class RegionPressureMap {
+  GCNScheduleDAGMILive *DAG = nullptr;
+  // The live in/out pressure as indexed by the first or last MI in the region
+  // before scheduling.
+  DenseMap<MachineInstr *, GCNRPTracker::LiveRegSet> BBLiveRegMap;
+  // The mapping of RegionIdx to key instruction
+  DenseMap<unsigned, MachineInstr *> IdxToInstruction;
+  // Whether we are calculating LiveOuts or LiveIns
+  bool IsLiveOut = false;
+
+public:
+  // A default-constructed map has no DAG and must be re-assigned before
+  // buildLiveRegMap() is called.
+  RegionPressureMap() = default;
+  RegionPressureMap(GCNScheduleDAGMILive *GCNDAG, bool LiveOut)
+      : DAG(GCNDAG), IsLiveOut(LiveOut) {}
+  // Build the Instr->LiveReg and RegionIdx->Instr maps
+  void buildLiveRegMap();
+
+  // Retrieve the LiveReg for a given RegionIdx
+  GCNRPTracker::LiveRegSet &getLiveRegsForRegionIdx(unsigned RegionIdx) {
+    assert(IdxToInstruction.find(RegionIdx) != IdxToInstruction.end());
+    MachineInstr *Key = IdxToInstruction[RegionIdx];
+    return BBLiveRegMap[Key];
+  }
+};
+
 class GCNScheduleDAGMILive final : public ScheduleDAGMILive {
   friend class GCNSchedStage;
   friend class OccInitialScheduleStage;
@@ -170,6 +200,7 @@ class GCNScheduleDAGMILive final : public ScheduleDAGMILive {
   friend class ClusteredLowOccStage;
   friend class PreRARematStage;
   friend class ILPInitialScheduleStage;
+  friend class RegionPressureMap;
 
   const GCNSubtarget &ST;
 
@@ -211,9 +242,22 @@ class GCNScheduleDAGMILive final : public ScheduleDAGMILive {
   // Temporary basic block live-in cache.
   DenseMap<const MachineBasicBlock *, GCNRPTracker::LiveRegSet> MBBLiveIns;
 
+  // The map of the initial first region instruction to region live in registers
   DenseMap<MachineInstr *, GCNRPTracker::LiveRegSet> BBLiveInMap;
 
-  DenseMap<MachineInstr *, GCNRPTracker::LiveRegSet> getBBLiveInMap() const;
+  // Calculate the map of the initial first region instruction to region live in
+  // registers
+  DenseMap<MachineInstr *, GCNRPTracker::LiveRegSet> getRegionLiveInMap() const;
+
+  // Calculate the map of the initial last region instruction to region live out
+  // registers
+  DenseMap<MachineInstr *, GCNRPTracker::LiveRegSet>
+  getRegionLiveOutMap() const;
+
+  // The live out registers per region. These are internally stored as a map of
+  // the initial last region instruction to region live out registers, but can
+  // be retrieved with the regionIdx by calls to getLiveRegsForRegionIdx.
+  RegionPressureMap RegionLiveOuts;
 
   // Return current region pressure.
   GCNRegPressure getRealRegPressure(unsigned RegionIdx) const;