Skip to content

Commit 276193f

Browse files
committed
[AMDGPU] Optionally Use AMDGPU RPTrackers during scheduling
1 parent 5cb6b15 commit 276193f

12 files changed

+1672
-79
lines changed

llvm/lib/Target/AMDGPU/GCNIterativeScheduler.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -480,7 +480,7 @@ void GCNIterativeScheduler::scheduleLegacyMaxOccupancy(
480480
LLVM_DEBUG(dbgs() << "Scheduling using default scheduler, "
481481
"target occupancy = "
482482
<< TgtOcc << '\n');
483-
GCNMaxOccupancySchedStrategy LStrgy(Context);
483+
GCNMaxOccupancySchedStrategy LStrgy(Context, /*IsLegacyScheduler=*/true);
484484
unsigned FinalOccupancy = std::min(Occ, MFI->getOccupancy());
485485

486486
for (int I = 0; I < NumPasses; ++I) {

llvm/lib/Target/AMDGPU/GCNRegPressure.cpp

Lines changed: 176 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,63 @@ collectVirtualRegUses(SmallVectorImpl<RegisterMaskPair> &RegMaskPairs,
296296
}
297297
}
298298

299+
/// Mostly copy/paste from CodeGen/RegisterPressure.cpp
300+
static LaneBitmask getLanesWithProperty(
301+
const LiveIntervals &LIS, const MachineRegisterInfo &MRI,
302+
bool TrackLaneMasks, Register RegUnit, SlotIndex Pos,
303+
LaneBitmask SafeDefault,
304+
function_ref<bool(const LiveRange &LR, SlotIndex Pos)> Property) {
305+
if (RegUnit.isVirtual()) {
306+
const LiveInterval &LI = LIS.getInterval(RegUnit);
307+
LaneBitmask Result;
308+
if (TrackLaneMasks && LI.hasSubRanges()) {
309+
for (const LiveInterval::SubRange &SR : LI.subranges()) {
310+
if (Property(SR, Pos))
311+
Result |= SR.LaneMask;
312+
}
313+
} else if (Property(LI, Pos)) {
314+
Result = TrackLaneMasks ? MRI.getMaxLaneMaskForVReg(RegUnit)
315+
: LaneBitmask::getAll();
316+
}
317+
318+
return Result;
319+
}
320+
321+
const LiveRange *LR = LIS.getCachedRegUnit(RegUnit);
322+
if (LR == nullptr)
323+
return SafeDefault;
324+
return Property(*LR, Pos) ? LaneBitmask::getAll() : LaneBitmask::getNone();
325+
}
326+
327+
/// Mostly copy/paste from CodeGen/RegisterPressure.cpp
328+
/// Helper to find a vreg use between two indices {PriorUseIdx, NextUseIdx}.
329+
/// The query starts with a lane bitmask which gets lanes/bits removed for every
330+
/// use we find.
331+
static LaneBitmask findUseBetween(unsigned Reg, LaneBitmask LastUseMask,
332+
SlotIndex PriorUseIdx, SlotIndex NextUseIdx,
333+
const MachineRegisterInfo &MRI,
334+
const SIRegisterInfo *TRI,
335+
const LiveIntervals *LIS,
336+
bool Upward = false) {
337+
for (const MachineOperand &MO : MRI.use_nodbg_operands(Reg)) {
338+
if (MO.isUndef())
339+
continue;
340+
const MachineInstr *MI = MO.getParent();
341+
SlotIndex InstSlot = LIS->getInstructionIndex(*MI).getRegSlot();
342+
bool InRange = Upward ? (InstSlot > PriorUseIdx && InstSlot <= NextUseIdx)
343+
: (InstSlot >= PriorUseIdx && InstSlot < NextUseIdx);
344+
if (!InRange)
345+
continue;
346+
347+
unsigned SubRegIdx = MO.getSubReg();
348+
LaneBitmask UseMask = TRI->getSubRegIndexLaneMask(SubRegIdx);
349+
LastUseMask &= ~UseMask;
350+
if (LastUseMask.none())
351+
return LaneBitmask::getNone();
352+
}
353+
return LastUseMask;
354+
}
355+
299356
///////////////////////////////////////////////////////////////////////////////
300357
// GCNRPTracker
301358

@@ -354,17 +411,28 @@ void GCNRPTracker::reset(const MachineInstr &MI,
354411
MaxPressure = CurPressure = getRegPressure(*MRI, LiveRegs);
355412
}
356413

357-
////////////////////////////////////////////////////////////////////////////////
358-
// GCNUpwardRPTracker
359-
360-
void GCNUpwardRPTracker::reset(const MachineRegisterInfo &MRI_,
361-
const LiveRegSet &LiveRegs_) {
414+
void GCNRPTracker::reset(const MachineRegisterInfo &MRI_,
415+
const LiveRegSet &LiveRegs_) {
362416
MRI = &MRI_;
363417
LiveRegs = LiveRegs_;
364418
LastTrackedMI = nullptr;
365419
MaxPressure = CurPressure = getRegPressure(MRI_, LiveRegs_);
366420
}
367421

422+
/// Mostly copy/paste from CodeGen/RegisterPressure.cpp
423+
LaneBitmask GCNRPTracker::getLastUsedLanes(Register RegUnit,
424+
SlotIndex Pos) const {
425+
return getLanesWithProperty(
426+
LIS, *MRI, true, RegUnit, Pos.getBaseIndex(), LaneBitmask::getNone(),
427+
[](const LiveRange &LR, SlotIndex Pos) {
428+
const LiveRange::Segment *S = LR.getSegmentContaining(Pos);
429+
return S != nullptr && S->end == Pos.getRegSlot();
430+
});
431+
}
432+
433+
////////////////////////////////////////////////////////////////////////////////
434+
// GCNUpwardRPTracker
435+
368436
void GCNUpwardRPTracker::recede(const MachineInstr &MI) {
369437
assert(MRI && "call reset first");
370438

@@ -441,25 +509,37 @@ bool GCNDownwardRPTracker::reset(const MachineInstr &MI,
441509
return true;
442510
}
443511

444-
bool GCNDownwardRPTracker::advanceBeforeNext() {
512+
bool GCNDownwardRPTracker::advanceBeforeNext(MachineInstr *MI,
513+
bool UseInternalIterator) {
445514
assert(MRI && "call reset first");
446-
if (!LastTrackedMI)
447-
return NextMI == MBBEnd;
448-
449-
assert(NextMI == MBBEnd || !NextMI->isDebugInstr());
515+
SlotIndex SI;
516+
const MachineInstr *CurrMI;
517+
if (UseInternalIterator) {
518+
if (!LastTrackedMI)
519+
return NextMI == MBBEnd;
520+
521+
assert(NextMI == MBBEnd || !NextMI->isDebugInstr());
522+
CurrMI = LastTrackedMI;
523+
524+
SI = NextMI == MBBEnd
525+
? LIS.getInstructionIndex(*LastTrackedMI).getDeadSlot()
526+
: LIS.getInstructionIndex(*NextMI).getBaseIndex();
527+
} else { //! UseInternalIterator
528+
SI = LIS.getInstructionIndex(*MI).getBaseIndex();
529+
CurrMI = MI;
530+
}
450531

451-
SlotIndex SI = NextMI == MBBEnd
452-
? LIS.getInstructionIndex(*LastTrackedMI).getDeadSlot()
453-
: LIS.getInstructionIndex(*NextMI).getBaseIndex();
454532
assert(SI.isValid());
455533

456534
// Remove dead registers or mask bits.
457535
SmallSet<Register, 8> SeenRegs;
458-
for (auto &MO : LastTrackedMI->operands()) {
536+
for (auto &MO : CurrMI->operands()) {
459537
if (!MO.isReg() || !MO.getReg().isVirtual())
460538
continue;
461539
if (MO.isUse() && !MO.readsReg())
462540
continue;
541+
if (!UseInternalIterator && MO.isDef())
542+
continue;
463543
if (!SeenRegs.insert(MO.getReg()).second)
464544
continue;
465545
const LiveInterval &LI = LIS.getInterval(MO.getReg());
@@ -492,15 +572,22 @@ bool GCNDownwardRPTracker::advanceBeforeNext() {
492572

493573
LastTrackedMI = nullptr;
494574

495-
return NextMI == MBBEnd;
575+
return UseInternalIterator && (NextMI == MBBEnd);
496576
}
497577

498-
void GCNDownwardRPTracker::advanceToNext() {
499-
LastTrackedMI = &*NextMI++;
500-
NextMI = skipDebugInstructionsForward(NextMI, MBBEnd);
578+
void GCNDownwardRPTracker::advanceToNext(MachineInstr *MI,
579+
bool UseInternalIterator) {
580+
if (UseInternalIterator) {
581+
LastTrackedMI = &*NextMI++;
582+
NextMI = skipDebugInstructionsForward(NextMI, MBBEnd);
583+
} else {
584+
LastTrackedMI = MI;
585+
}
586+
587+
const MachineInstr *CurrMI = LastTrackedMI;
501588

502589
// Add new registers or mask bits.
503-
for (const auto &MO : LastTrackedMI->all_defs()) {
590+
for (const auto &MO : CurrMI->all_defs()) {
504591
Register Reg = MO.getReg();
505592
if (!Reg.isVirtual())
506593
continue;
@@ -513,11 +600,16 @@ void GCNDownwardRPTracker::advanceToNext() {
513600
MaxPressure = max(MaxPressure, CurPressure);
514601
}
515602

516-
bool GCNDownwardRPTracker::advance() {
517-
if (NextMI == MBBEnd)
603+
/// Advance the tracker over one instruction.
///
/// With \p UseInternalIterator set, the tracker follows its own NextMI
/// iterator and returns false without doing anything once NextMI has reached
/// MBBEnd. Otherwise \p MI is the instruction to advance over.
///
/// Both advanceBeforeNext() and advanceToNext() are forwarded the same
/// \p MI / \p UseInternalIterator pair. When the internal iterator is not
/// used, advanceBeforeNext() is run once more in internal-iterator mode after
/// the instruction has been tracked.
///
/// \returns true if the tracker was advanced, false if the internal iterator
/// was already at the end of the block.
bool GCNDownwardRPTracker::advance(MachineInstr *MI, bool UseInternalIterator) {
604+
if (UseInternalIterator && NextMI == MBBEnd)
518605
return false;
519-
advanceBeforeNext();
520-
advanceToNext();
606+
607+
advanceBeforeNext(MI, UseInternalIterator);
608+
advanceToNext(MI, UseInternalIterator);
609+
if (!UseInternalIterator) {
610+
// We must remove any dead def lanes from the current RP
611+
advanceBeforeNext(MI, true);
612+
}
521613
return true;
522614
}
523615

@@ -559,6 +651,67 @@ Printable llvm::reportMismatch(const GCNRPTracker::LiveRegSet &LISLR,
559651
});
560652
}
561653

654+
/// Compute the register pressure that results from applying the uses and defs
/// of \p MI to the tracker's current pressure. The tracker itself is not
/// modified: the computation works on a copy of CurPressure, which is
/// returned.
///
/// For every virtual register used by \p MI, the lanes that are last-used at
/// \p MI and have no other use between the tracker's current position and
/// \p MI are removed from the register's live mask. For every virtual
/// register defined by \p MI, the defined lanes are added to its live mask.
///
/// \param MI  Instruction to evaluate; must not be a debug or pseudo
///            instruction.
/// \param TRI Register info used to collect the operands of \p MI and to map
///            subregister indices to lane masks.
/// \returns the adjusted copy of the current pressure.
GCNRegPressure
GCNDownwardRPTracker::bumpDownwardPressure(const MachineInstr *MI,
                                           const SIRegisterInfo *TRI) const {
  assert(!MI->isDebugOrPseudoInstr() && "Expect a nondebug instruction.");

  const SlotIndex SlotIdx = LIS.getInstructionIndex(*MI).getRegSlot();

  // Account for register pressure similar to RegPressureTracker::recede().
  RegisterOperands RegOpers;
  RegOpers.collect(*MI, *TRI, *MRI, true, /*IgnoreDead=*/false);
  RegOpers.adjustLaneLiveness(LIS, *MRI, SlotIdx);
  GCNRegPressure TempPressure = CurPressure;

  // Slot of the tracker's current position: the first non-debug instruction at
  // or after LastTrackedMI (or the start of the block if nothing has been
  // tracked yet), or the end of the block. It does not depend on the operand
  // being examined, so it is computed on first demand and then reused rather
  // than being recomputed for every use.
  SlotIndex CurrIdx;
  bool CurrIdxKnown = false;
  auto GetCurrIdx = [&]() -> SlotIndex {
    if (!CurrIdxKnown) {
      const MachineBasicBlock *MBB = MI->getParent();
      MachineBasicBlock::const_iterator IdxPos = skipDebugInstructionsForward(
          LastTrackedMI ? LastTrackedMI : MBB->begin(), MBB->end());
      if (IdxPos == MBB->end())
        CurrIdx = LIS.getMBBEndIdx(MBB);
      else
        CurrIdx = LIS.getInstructionIndex(*IdxPos).getRegSlot();
      CurrIdxKnown = true;
    }
    return CurrIdx;
  };

  for (const RegisterMaskPair &Use : RegOpers.Uses) {
    Register Reg = Use.RegUnit;
    if (!Reg.isVirtual())
      continue;
    LaneBitmask LastUseMask = getLastUsedLanes(Reg, SlotIdx);
    if (LastUseMask.none())
      continue;
    // The LastUseMask is queried from the liveness information of an
    // instruction which may be further down the schedule. Some lanes may
    // actually not be last uses for the current position.
    // FIXME: allow the caller to pass in the list of vreg uses that remain
    // to be bottom-scheduled to avoid searching uses at each query.
    LastUseMask = findUseBetween(Reg, LastUseMask, GetCurrIdx(), SlotIdx, *MRI,
                                 TRI, &LIS);
    if (LastUseMask.none())
      continue;

    // Drop the last-used lanes from whatever is currently live for Reg.
    LaneBitmask LiveMask =
        LiveRegs.contains(Reg) ? LiveRegs.at(Reg) : LaneBitmask(0);
    LaneBitmask NewMask = LiveMask & ~LastUseMask;
    TempPressure.inc(Reg, LiveMask, NewMask, *MRI);
  }

  // Generate liveness for defs.
  for (const RegisterMaskPair &Def : RegOpers.Defs) {
    Register Reg = Def.RegUnit;
    if (!Reg.isVirtual())
      continue;
    LaneBitmask LiveMask =
        LiveRegs.contains(Reg) ? LiveRegs.at(Reg) : LaneBitmask(0);
    LaneBitmask NewMask = LiveMask | Def.LaneMask;
    TempPressure.inc(Reg, LiveMask, NewMask, *MRI);
  }

  return TempPressure;
}
714+
562715
bool GCNUpwardRPTracker::isValid() const {
563716
const auto &SI = LIS.getInstructionIndex(*LastTrackedMI).getBaseIndex();
564717
const auto LISLR = llvm::getLiveRegs(SI, LIS, *MRI);

0 commit comments

Comments
 (0)