@@ -296,6 +296,63 @@ collectVirtualRegUses(SmallVectorImpl<RegisterMaskPair> &RegMaskPairs,
296
296
}
297
297
}
298
298
299
+ // / Mostly copy/paste from CodeGen/RegisterPressure.cpp
300
+ static LaneBitmask getLanesWithProperty (
301
+ const LiveIntervals &LIS, const MachineRegisterInfo &MRI,
302
+ bool TrackLaneMasks, Register RegUnit, SlotIndex Pos,
303
+ LaneBitmask SafeDefault,
304
+ function_ref<bool (const LiveRange &LR, SlotIndex Pos)> Property) {
305
+ if (RegUnit.isVirtual ()) {
306
+ const LiveInterval &LI = LIS.getInterval (RegUnit);
307
+ LaneBitmask Result;
308
+ if (TrackLaneMasks && LI.hasSubRanges ()) {
309
+ for (const LiveInterval::SubRange &SR : LI.subranges ()) {
310
+ if (Property (SR, Pos))
311
+ Result |= SR.LaneMask ;
312
+ }
313
+ } else if (Property (LI, Pos)) {
314
+ Result = TrackLaneMasks ? MRI.getMaxLaneMaskForVReg (RegUnit)
315
+ : LaneBitmask::getAll ();
316
+ }
317
+
318
+ return Result;
319
+ }
320
+
321
+ const LiveRange *LR = LIS.getCachedRegUnit (RegUnit);
322
+ if (LR == nullptr )
323
+ return SafeDefault;
324
+ return Property (*LR, Pos) ? LaneBitmask::getAll () : LaneBitmask::getNone ();
325
+ }
326
+
327
+ // / Mostly copy/paste from CodeGen/RegisterPressure.cpp
328
+ // / Helper to find a vreg use between two indices {PriorUseIdx, NextUseIdx}.
329
+ // / The query starts with a lane bitmask which gets lanes/bits removed for every
330
+ // / use we find.
331
+ static LaneBitmask findUseBetween (unsigned Reg, LaneBitmask LastUseMask,
332
+ SlotIndex PriorUseIdx, SlotIndex NextUseIdx,
333
+ const MachineRegisterInfo &MRI,
334
+ const SIRegisterInfo *TRI,
335
+ const LiveIntervals *LIS,
336
+ bool Upward = false ) {
337
+ for (const MachineOperand &MO : MRI.use_nodbg_operands (Reg)) {
338
+ if (MO.isUndef ())
339
+ continue ;
340
+ const MachineInstr *MI = MO.getParent ();
341
+ SlotIndex InstSlot = LIS->getInstructionIndex (*MI).getRegSlot ();
342
+ bool InRange = Upward ? (InstSlot > PriorUseIdx && InstSlot <= NextUseIdx)
343
+ : (InstSlot >= PriorUseIdx && InstSlot < NextUseIdx);
344
+ if (!InRange)
345
+ continue ;
346
+
347
+ unsigned SubRegIdx = MO.getSubReg ();
348
+ LaneBitmask UseMask = TRI->getSubRegIndexLaneMask (SubRegIdx);
349
+ LastUseMask &= ~UseMask;
350
+ if (LastUseMask.none ())
351
+ return LaneBitmask::getNone ();
352
+ }
353
+ return LastUseMask;
354
+ }
355
+
299
356
// /////////////////////////////////////////////////////////////////////////////
300
357
// GCNRPTracker
301
358
@@ -354,17 +411,28 @@ void GCNRPTracker::reset(const MachineInstr &MI,
354
411
MaxPressure = CurPressure = getRegPressure (*MRI, LiveRegs);
355
412
}
356
413
357
- // //////////////////////////////////////////////////////////////////////////////
358
- // GCNUpwardRPTracker
359
-
360
- void GCNUpwardRPTracker::reset (const MachineRegisterInfo &MRI_,
361
- const LiveRegSet &LiveRegs_) {
414
+ void GCNRPTracker::reset (const MachineRegisterInfo &MRI_,
415
+ const LiveRegSet &LiveRegs_) {
362
416
MRI = &MRI_;
363
417
LiveRegs = LiveRegs_;
364
418
LastTrackedMI = nullptr ;
365
419
MaxPressure = CurPressure = getRegPressure (MRI_, LiveRegs_);
366
420
}
367
421
422
+ // / Mostly copy/paste from CodeGen/RegisterPressure.cpp
423
+ LaneBitmask GCNRPTracker::getLastUsedLanes (Register RegUnit,
424
+ SlotIndex Pos) const {
425
+ return getLanesWithProperty (
426
+ LIS, *MRI, true , RegUnit, Pos.getBaseIndex (), LaneBitmask::getNone (),
427
+ [](const LiveRange &LR, SlotIndex Pos) {
428
+ const LiveRange::Segment *S = LR.getSegmentContaining (Pos);
429
+ return S != nullptr && S->end == Pos.getRegSlot ();
430
+ });
431
+ }
432
+
433
+ // //////////////////////////////////////////////////////////////////////////////
434
+ // GCNUpwardRPTracker
435
+
368
436
void GCNUpwardRPTracker::recede (const MachineInstr &MI) {
369
437
assert (MRI && " call reset first" );
370
438
@@ -441,25 +509,37 @@ bool GCNDownwardRPTracker::reset(const MachineInstr &MI,
441
509
return true ;
442
510
}
443
511
444
- bool GCNDownwardRPTracker::advanceBeforeNext () {
512
+ bool GCNDownwardRPTracker::advanceBeforeNext (MachineInstr *MI,
513
+ bool UseInternalIterator) {
445
514
assert (MRI && " call reset first" );
446
- if (!LastTrackedMI)
447
- return NextMI == MBBEnd;
448
-
449
- assert (NextMI == MBBEnd || !NextMI->isDebugInstr ());
515
+ SlotIndex SI;
516
+ const MachineInstr *CurrMI;
517
+ if (UseInternalIterator) {
518
+ if (!LastTrackedMI)
519
+ return NextMI == MBBEnd;
520
+
521
+ assert (NextMI == MBBEnd || !NextMI->isDebugInstr ());
522
+ CurrMI = LastTrackedMI;
523
+
524
+ SI = NextMI == MBBEnd
525
+ ? LIS.getInstructionIndex (*LastTrackedMI).getDeadSlot ()
526
+ : LIS.getInstructionIndex (*NextMI).getBaseIndex ();
527
+ } else { // ! UseInternalIterator
528
+ SI = LIS.getInstructionIndex (*MI).getBaseIndex ();
529
+ CurrMI = MI;
530
+ }
450
531
451
- SlotIndex SI = NextMI == MBBEnd
452
- ? LIS.getInstructionIndex (*LastTrackedMI).getDeadSlot ()
453
- : LIS.getInstructionIndex (*NextMI).getBaseIndex ();
454
532
assert (SI.isValid ());
455
533
456
534
// Remove dead registers or mask bits.
457
535
SmallSet<Register, 8 > SeenRegs;
458
- for (auto &MO : LastTrackedMI ->operands ()) {
536
+ for (auto &MO : CurrMI ->operands ()) {
459
537
if (!MO.isReg () || !MO.getReg ().isVirtual ())
460
538
continue ;
461
539
if (MO.isUse () && !MO.readsReg ())
462
540
continue ;
541
+ if (!UseInternalIterator && MO.isDef ())
542
+ continue ;
463
543
if (!SeenRegs.insert (MO.getReg ()).second )
464
544
continue ;
465
545
const LiveInterval &LI = LIS.getInterval (MO.getReg ());
@@ -492,15 +572,22 @@ bool GCNDownwardRPTracker::advanceBeforeNext() {
492
572
493
573
LastTrackedMI = nullptr ;
494
574
495
- return NextMI == MBBEnd;
575
+ return UseInternalIterator && ( NextMI == MBBEnd) ;
496
576
}
497
577
498
- void GCNDownwardRPTracker::advanceToNext () {
499
- LastTrackedMI = &*NextMI++;
500
- NextMI = skipDebugInstructionsForward (NextMI, MBBEnd);
578
+ void GCNDownwardRPTracker::advanceToNext (MachineInstr *MI,
579
+ bool UseInternalIterator) {
580
+ if (UseInternalIterator) {
581
+ LastTrackedMI = &*NextMI++;
582
+ NextMI = skipDebugInstructionsForward (NextMI, MBBEnd);
583
+ } else {
584
+ LastTrackedMI = MI;
585
+ }
586
+
587
+ const MachineInstr *CurrMI = LastTrackedMI;
501
588
502
589
// Add new registers or mask bits.
503
- for (const auto &MO : LastTrackedMI ->all_defs ()) {
590
+ for (const auto &MO : CurrMI ->all_defs ()) {
504
591
Register Reg = MO.getReg ();
505
592
if (!Reg.isVirtual ())
506
593
continue ;
@@ -513,11 +600,16 @@ void GCNDownwardRPTracker::advanceToNext() {
513
600
MaxPressure = max (MaxPressure, CurPressure);
514
601
}
515
602
516
- bool GCNDownwardRPTracker::advance () {
517
- if (NextMI == MBBEnd)
603
+ bool GCNDownwardRPTracker::advance (MachineInstr *MI, bool UseInternalIterator ) {
604
+ if (UseInternalIterator && NextMI == MBBEnd)
518
605
return false ;
519
- advanceBeforeNext ();
520
- advanceToNext ();
606
+
607
+ advanceBeforeNext (MI, UseInternalIterator);
608
+ advanceToNext (MI, UseInternalIterator);
609
+ if (!UseInternalIterator) {
610
+ // We must remove any dead def lanes from the current RP
611
+ advanceBeforeNext (MI, true );
612
+ }
521
613
return true ;
522
614
}
523
615
@@ -559,6 +651,67 @@ Printable llvm::reportMismatch(const GCNRPTracker::LiveRegSet &LISLR,
559
651
});
560
652
}
561
653
654
+ GCNRegPressure
655
+ GCNDownwardRPTracker::bumpDownwardPressure (const MachineInstr *MI,
656
+ const SIRegisterInfo *TRI) const {
657
+ assert (!MI->isDebugOrPseudoInstr () && " Expect a nondebug instruction." );
658
+
659
+ SlotIndex SlotIdx;
660
+ SlotIdx = LIS.getInstructionIndex (*MI).getRegSlot ();
661
+
662
+ // Account for register pressure similar to RegPressureTracker::recede().
663
+ RegisterOperands RegOpers;
664
+ RegOpers.collect (*MI, *TRI, *MRI, true , /* IgnoreDead=*/ false );
665
+ RegOpers.adjustLaneLiveness (LIS, *MRI, SlotIdx);
666
+ GCNRegPressure TempPressure = CurPressure;
667
+
668
+ for (const RegisterMaskPair &Use : RegOpers.Uses ) {
669
+ Register Reg = Use.RegUnit ;
670
+ if (!Reg.isVirtual ())
671
+ continue ;
672
+ LaneBitmask LastUseMask = getLastUsedLanes (Reg, SlotIdx);
673
+ if (LastUseMask.none ())
674
+ continue ;
675
+ // The LastUseMask is queried from the liveness information of instruction
676
+ // which may be further down the schedule. Some lanes may actually not be
677
+ // last uses for the current position.
678
+ // FIXME: allow the caller to pass in the list of vreg uses that remain
679
+ // to be bottom-scheduled to avoid searching uses at each query.
680
+ SlotIndex CurrIdx;
681
+ const MachineBasicBlock *MBB = MI->getParent ();
682
+ MachineBasicBlock::const_iterator IdxPos = skipDebugInstructionsForward (
683
+ LastTrackedMI ? LastTrackedMI : MBB->begin (), MBB->end ());
684
+ if (IdxPos == MBB->end ()) {
685
+ CurrIdx = LIS.getMBBEndIdx (MBB);
686
+ } else {
687
+ CurrIdx = LIS.getInstructionIndex (*IdxPos).getRegSlot ();
688
+ }
689
+
690
+ LastUseMask =
691
+ findUseBetween (Reg, LastUseMask, CurrIdx, SlotIdx, *MRI, TRI, &LIS);
692
+ if (LastUseMask.none ())
693
+ continue ;
694
+
695
+ LaneBitmask LiveMask =
696
+ LiveRegs.contains (Reg) ? LiveRegs.at (Reg) : LaneBitmask (0 );
697
+ LaneBitmask NewMask = LiveMask & ~LastUseMask;
698
+ TempPressure.inc (Reg, LiveMask, NewMask, *MRI);
699
+ }
700
+
701
+ // Generate liveness for defs.
702
+ for (const RegisterMaskPair &Def : RegOpers.Defs ) {
703
+ Register Reg = Def.RegUnit ;
704
+ if (!Reg.isVirtual ())
705
+ continue ;
706
+ LaneBitmask LiveMask =
707
+ LiveRegs.contains (Reg) ? LiveRegs.at (Reg) : LaneBitmask (0 );
708
+ LaneBitmask NewMask = LiveMask | Def.LaneMask ;
709
+ TempPressure.inc (Reg, LiveMask, NewMask, *MRI);
710
+ }
711
+
712
+ return TempPressure;
713
+ }
714
+
562
715
bool GCNUpwardRPTracker::isValid () const {
563
716
const auto &SI = LIS.getInstructionIndex (*LastTrackedMI).getBaseIndex ();
564
717
const auto LISLR = llvm::getLiveRegs (SI, LIS, *MRI);
0 commit comments