Skip to content

[AMDGPU] Optionally Use GCNRPTrackers during scheduling #93090

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion llvm/lib/Target/AMDGPU/GCNIterativeScheduler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,7 @@ void GCNIterativeScheduler::scheduleLegacyMaxOccupancy(
LLVM_DEBUG(dbgs() << "Scheduling using default scheduler, "
"target occupancy = "
<< TgtOcc << '\n');
GCNMaxOccupancySchedStrategy LStrgy(Context);
GCNMaxOccupancySchedStrategy LStrgy(Context, /*IsLegacyScheduler=*/true);
unsigned FinalOccupancy = std::min(Occ, MFI->getOccupancy());

for (int I = 0; I < NumPasses; ++I) {
Expand Down
199 changes: 176 additions & 23 deletions llvm/lib/Target/AMDGPU/GCNRegPressure.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,63 @@ collectVirtualRegUses(SmallVectorImpl<RegisterMaskPair> &RegMaskPairs,
}
}

/// Mostly copy/paste from CodeGen/RegisterPressure.cpp
static LaneBitmask getLanesWithProperty(
const LiveIntervals &LIS, const MachineRegisterInfo &MRI,
bool TrackLaneMasks, Register RegUnit, SlotIndex Pos,
LaneBitmask SafeDefault,
function_ref<bool(const LiveRange &LR, SlotIndex Pos)> Property) {
if (RegUnit.isVirtual()) {
const LiveInterval &LI = LIS.getInterval(RegUnit);
LaneBitmask Result;
if (TrackLaneMasks && LI.hasSubRanges()) {
for (const LiveInterval::SubRange &SR : LI.subranges()) {
if (Property(SR, Pos))
Result |= SR.LaneMask;
}
} else if (Property(LI, Pos)) {
Result = TrackLaneMasks ? MRI.getMaxLaneMaskForVReg(RegUnit)
: LaneBitmask::getAll();
}

return Result;
}

const LiveRange *LR = LIS.getCachedRegUnit(RegUnit);
if (LR == nullptr)
return SafeDefault;
return Property(*LR, Pos) ? LaneBitmask::getAll() : LaneBitmask::getNone();
}

/// Mostly copy/paste from CodeGen/RegisterPressure.cpp
/// Helper to find a vreg use between two indices {PriorUseIdx, NextUseIdx}.
/// The query starts with a lane bitmask which gets lanes/bits removed for every
/// use we find.
static LaneBitmask findUseBetween(unsigned Reg, LaneBitmask LastUseMask,
SlotIndex PriorUseIdx, SlotIndex NextUseIdx,
const MachineRegisterInfo &MRI,
const SIRegisterInfo *TRI,
const LiveIntervals *LIS,
bool Upward = false) {
for (const MachineOperand &MO : MRI.use_nodbg_operands(Reg)) {
if (MO.isUndef())
continue;
Comment on lines +338 to +339
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No isUndef check, it is redundant with the readsReg case for uses and broken of the def of subregister case

const MachineInstr *MI = MO.getParent();
SlotIndex InstSlot = LIS->getInstructionIndex(*MI).getRegSlot();
bool InRange = Upward ? (InstSlot > PriorUseIdx && InstSlot <= NextUseIdx)
: (InstSlot >= PriorUseIdx && InstSlot < NextUseIdx);
if (!InRange)
continue;

unsigned SubRegIdx = MO.getSubReg();
LaneBitmask UseMask = TRI->getSubRegIndexLaneMask(SubRegIdx);
LastUseMask &= ~UseMask;
if (LastUseMask.none())
return LaneBitmask::getNone();
}
return LastUseMask;
}

///////////////////////////////////////////////////////////////////////////////
// GCNRPTracker

Expand Down Expand Up @@ -354,17 +411,28 @@ void GCNRPTracker::reset(const MachineInstr &MI,
MaxPressure = CurPressure = getRegPressure(*MRI, LiveRegs);
}

////////////////////////////////////////////////////////////////////////////////
// GCNUpwardRPTracker

void GCNUpwardRPTracker::reset(const MachineRegisterInfo &MRI_,
const LiveRegSet &LiveRegs_) {
void GCNRPTracker::reset(const MachineRegisterInfo &MRI_,
const LiveRegSet &LiveRegs_) {
MRI = &MRI_;
LiveRegs = LiveRegs_;
LastTrackedMI = nullptr;
MaxPressure = CurPressure = getRegPressure(MRI_, LiveRegs_);
}

/// Mostly copy/paste from CodeGen/RegisterPressure.cpp
LaneBitmask GCNRPTracker::getLastUsedLanes(Register RegUnit,
SlotIndex Pos) const {
return getLanesWithProperty(
LIS, *MRI, true, RegUnit, Pos.getBaseIndex(), LaneBitmask::getNone(),
[](const LiveRange &LR, SlotIndex Pos) {
const LiveRange::Segment *S = LR.getSegmentContaining(Pos);
return S != nullptr && S->end == Pos.getRegSlot();
});
}

////////////////////////////////////////////////////////////////////////////////
// GCNUpwardRPTracker

void GCNUpwardRPTracker::recede(const MachineInstr &MI) {
assert(MRI && "call reset first");

Expand Down Expand Up @@ -441,25 +509,37 @@ bool GCNDownwardRPTracker::reset(const MachineInstr &MI,
return true;
}

bool GCNDownwardRPTracker::advanceBeforeNext() {
bool GCNDownwardRPTracker::advanceBeforeNext(MachineInstr *MI,
bool UseInternalIterator) {
assert(MRI && "call reset first");
if (!LastTrackedMI)
return NextMI == MBBEnd;

assert(NextMI == MBBEnd || !NextMI->isDebugInstr());
SlotIndex SI;
const MachineInstr *CurrMI;
if (UseInternalIterator) {
if (!LastTrackedMI)
return NextMI == MBBEnd;

assert(NextMI == MBBEnd || !NextMI->isDebugInstr());
CurrMI = LastTrackedMI;

SI = NextMI == MBBEnd
? LIS.getInstructionIndex(*LastTrackedMI).getDeadSlot()
: LIS.getInstructionIndex(*NextMI).getBaseIndex();
} else { //! UseInternalIterator
SI = LIS.getInstructionIndex(*MI).getBaseIndex();
CurrMI = MI;
}

SlotIndex SI = NextMI == MBBEnd
? LIS.getInstructionIndex(*LastTrackedMI).getDeadSlot()
: LIS.getInstructionIndex(*NextMI).getBaseIndex();
assert(SI.isValid());

// Remove dead registers or mask bits.
SmallSet<Register, 8> SeenRegs;
for (auto &MO : LastTrackedMI->operands()) {
for (auto &MO : CurrMI->operands()) {
if (!MO.isReg() || !MO.getReg().isVirtual())
continue;
if (MO.isUse() && !MO.readsReg())
continue;
if (!UseInternalIterator && MO.isDef())
continue;
if (!SeenRegs.insert(MO.getReg()).second)
continue;
const LiveInterval &LI = LIS.getInterval(MO.getReg());
Expand Down Expand Up @@ -492,15 +572,22 @@ bool GCNDownwardRPTracker::advanceBeforeNext() {

LastTrackedMI = nullptr;

return NextMI == MBBEnd;
return UseInternalIterator && (NextMI == MBBEnd);
}

void GCNDownwardRPTracker::advanceToNext() {
LastTrackedMI = &*NextMI++;
NextMI = skipDebugInstructionsForward(NextMI, MBBEnd);
void GCNDownwardRPTracker::advanceToNext(MachineInstr *MI,
bool UseInternalIterator) {
if (UseInternalIterator) {
LastTrackedMI = &*NextMI++;
NextMI = skipDebugInstructionsForward(NextMI, MBBEnd);
} else {
LastTrackedMI = MI;
}

const MachineInstr *CurrMI = LastTrackedMI;

// Add new registers or mask bits.
for (const auto &MO : LastTrackedMI->all_defs()) {
for (const auto &MO : CurrMI->all_defs()) {
Register Reg = MO.getReg();
if (!Reg.isVirtual())
continue;
Expand All @@ -513,11 +600,16 @@ void GCNDownwardRPTracker::advanceToNext() {
MaxPressure = max(MaxPressure, CurPressure);
}

bool GCNDownwardRPTracker::advance() {
if (NextMI == MBBEnd)
bool GCNDownwardRPTracker::advance(MachineInstr *MI, bool UseInternalIterator) {
if (UseInternalIterator && NextMI == MBBEnd)
return false;
advanceBeforeNext();
advanceToNext();

advanceBeforeNext(MI, UseInternalIterator);
advanceToNext(MI, UseInternalIterator);
if (!UseInternalIterator) {
// We must remove any dead def lanes from the current RP
advanceBeforeNext(MI, true);
}
return true;
}

Expand Down Expand Up @@ -559,6 +651,67 @@ Printable llvm::reportMismatch(const GCNRPTracker::LiveRegSet &LISLR,
});
}

GCNRegPressure
GCNDownwardRPTracker::bumpDownwardPressure(const MachineInstr *MI,
const SIRegisterInfo *TRI) const {
assert(!MI->isDebugOrPseudoInstr() && "Expect a nondebug instruction.");

SlotIndex SlotIdx;
SlotIdx = LIS.getInstructionIndex(*MI).getRegSlot();

// Account for register pressure similar to RegPressureTracker::recede().
RegisterOperands RegOpers;
RegOpers.collect(*MI, *TRI, *MRI, true, /*IgnoreDead=*/false);
RegOpers.adjustLaneLiveness(LIS, *MRI, SlotIdx);
GCNRegPressure TempPressure = CurPressure;

for (const RegisterMaskPair &Use : RegOpers.Uses) {
Register Reg = Use.RegUnit;
if (!Reg.isVirtual())
continue;
LaneBitmask LastUseMask = getLastUsedLanes(Reg, SlotIdx);
if (LastUseMask.none())
continue;
// The LastUseMask is queried from the liveness information of instruction
// which may be further down the schedule. Some lanes may actually not be
// last uses for the current position.
// FIXME: allow the caller to pass in the list of vreg uses that remain
// to be bottom-scheduled to avoid searching uses at each query.
SlotIndex CurrIdx;
const MachineBasicBlock *MBB = MI->getParent();
MachineBasicBlock::const_iterator IdxPos = skipDebugInstructionsForward(
LastTrackedMI ? LastTrackedMI : MBB->begin(), MBB->end());
if (IdxPos == MBB->end()) {
CurrIdx = LIS.getMBBEndIdx(MBB);
} else {
CurrIdx = LIS.getInstructionIndex(*IdxPos).getRegSlot();
}

LastUseMask =
findUseBetween(Reg, LastUseMask, CurrIdx, SlotIdx, *MRI, TRI, &LIS);
if (LastUseMask.none())
continue;

LaneBitmask LiveMask =
LiveRegs.contains(Reg) ? LiveRegs.at(Reg) : LaneBitmask(0);
LaneBitmask NewMask = LiveMask & ~LastUseMask;
TempPressure.inc(Reg, LiveMask, NewMask, *MRI);
}

// Generate liveness for defs.
for (const RegisterMaskPair &Def : RegOpers.Defs) {
Register Reg = Def.RegUnit;
if (!Reg.isVirtual())
continue;
LaneBitmask LiveMask =
LiveRegs.contains(Reg) ? LiveRegs.at(Reg) : LaneBitmask(0);
LaneBitmask NewMask = LiveMask | Def.LaneMask;
TempPressure.inc(Reg, LiveMask, NewMask, *MRI);
}

return TempPressure;
}

bool GCNUpwardRPTracker::isValid() const {
const auto &SI = LIS.getInstructionIndex(*LastTrackedMI).getBaseIndex();
const auto LISLR = llvm::getLiveRegs(SI, LIS, *MRI);
Expand Down
Loading
Loading