
Commit 089b389

[SME] Stop RA from coalescing COPY instructions that cross an smstart/smstop boundary.
This patch introduces a 'COALESCER_BARRIER', a pseudo node that expands to a 'nop' but stops the register allocator from coalescing a COPY node when its use/def crosses an SMSTART or SMSTOP instruction. For example:

  %0:fpr64 = COPY killed $d0
  undef %2.dsub:zpr = COPY %0    // <- Do not coalesce this COPY
  ADJCALLSTACKDOWN 0, 0
  MSRpstatesvcrImm1 1, 0, csr_aarch64_smstartstop, implicit-def dead $d0
  $d0 = COPY killed %0
  BL @use_f64, csr_aarch64_aapcs

If the COPY were coalesced, then:

  $d0 = COPY killed %0

would be replaced by:

  $d0 = COPY killed %2.dsub

which means the whole ZPR register would be live up to the call, causing the MSRpstatesvcrImm1 (smstop) to spill/reload the ZPR register:

  str q0, [sp]     // 16-byte Folded Spill
  smstop sm
  ldr z0, [sp]     // 16-byte Folded Reload
  bl use_f64

This would be incorrect for two reasons:
1. The program may load more data than it has allocated.
2. If there are other SVE objects on the stack, the compiler might use the 'mul vl' addressing modes to access the spill location.

By disabling the coalescing, we get the desired result:

  str d0, [sp, #8] // 8-byte Folded Spill
  smstop sm
  ldr d0, [sp, #8] // 8-byte Folded Reload
  bl use_f64
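For context, here is a minimal IR sketch of the situation the barrier guards against (this example is not part of the commit; the caller name is hypothetical, while @use_f64 and the "aarch64_pstate_sm_enabled" attribute mirror the commit message and the tests below). A streaming function passes a double to a normal callee, so the compiler must wrap the call in smstop/smstart while keeping the argument alive in d0 across the mode change:

  ; Hypothetical reproducer: streaming caller, non-streaming callee taking an FP argument.
  ; The double %x must survive the smstop emitted before the call and must not be
  ; widened to a full Z register by the coalescer.
  declare void @use_f64(double)

  define void @streaming_caller(double %x) "aarch64_pstate_sm_enabled" nounwind {
  entry:
    call void @use_f64(double %x)
    ret void
  }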
1 parent 44035cc commit 089b389

11 files changed: 519 additions & 367 deletions

llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp

Lines changed: 6 additions & 0 deletions
@@ -1528,6 +1528,12 @@ bool AArch64ExpandPseudo::expandMI(MachineBasicBlock &MBB,
     NextMBBI = MBB.end(); // The NextMBBI iterator is invalidated.
     return true;
   }
+  case AArch64::COALESCER_BARRIER_FPR16:
+  case AArch64::COALESCER_BARRIER_FPR32:
+  case AArch64::COALESCER_BARRIER_FPR64:
+  case AArch64::COALESCER_BARRIER_FPR128:
+    MI.eraseFromParent();
+    return true;
   case AArch64::LD1B_2Z_IMM_PSEUDO:
     return expandMultiVecPseudo(
         MBB, MBBI, AArch64::ZPR2RegClass, AArch64::ZPR2StridedRegClass,

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 19 additions & 2 deletions
@@ -2338,6 +2338,7 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
   switch ((AArch64ISD::NodeType)Opcode) {
   case AArch64ISD::FIRST_NUMBER:
     break;
+    MAKE_CASE(AArch64ISD::COALESCER_BARRIER)
     MAKE_CASE(AArch64ISD::SMSTART)
     MAKE_CASE(AArch64ISD::SMSTOP)
     MAKE_CASE(AArch64ISD::RESTORE_ZA)
@@ -7090,13 +7091,18 @@ void AArch64TargetLowering::saveVarArgRegisters(CCState &CCInfo,
   }
 }
 
+static bool isPassedInFPR(EVT VT) {
+  return VT.isFixedLengthVector() ||
+         (VT.isFloatingPoint() && !VT.isScalableVector());
+}
+
 /// LowerCallResult - Lower the result values of a call into the
 /// appropriate copies out of appropriate physical registers.
 SDValue AArch64TargetLowering::LowerCallResult(
     SDValue Chain, SDValue InGlue, CallingConv::ID CallConv, bool isVarArg,
     const SmallVectorImpl<CCValAssign> &RVLocs, const SDLoc &DL,
     SelectionDAG &DAG, SmallVectorImpl<SDValue> &InVals, bool isThisReturn,
-    SDValue ThisVal) const {
+    SDValue ThisVal, bool RequiresSMChange) const {
   DenseMap<unsigned, SDValue> CopiedRegs;
   // Copy all of the result registers out of their specified physreg.
   for (unsigned i = 0; i != RVLocs.size(); ++i) {
@@ -7141,6 +7147,10 @@ SDValue AArch64TargetLowering::LowerCallResult(
       break;
     }
 
+    if (RequiresSMChange && isPassedInFPR(VA.getValVT()))
+      Val = DAG.getNode(AArch64ISD::COALESCER_BARRIER, DL, Val.getValueType(),
+                        Val);
+
     InVals.push_back(Val);
   }
 
@@ -7829,6 +7839,12 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
         return ArgReg.Reg == VA.getLocReg();
       });
     } else {
+      // Add an extra level of indirection for streaming mode changes by
+      // using a pseudo copy node that cannot be rematerialised between a
+      // smstart/smstop and the call by the simple register coalescer.
+      if (RequiresSMChange && isPassedInFPR(Arg.getValueType()))
+        Arg = DAG.getNode(AArch64ISD::COALESCER_BARRIER, DL,
+                          Arg.getValueType(), Arg);
       RegsToPass.emplace_back(VA.getLocReg(), Arg);
       RegsUsed.insert(VA.getLocReg());
       const TargetOptions &Options = DAG.getTarget().Options;
@@ -8063,7 +8079,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
   // return.
   SDValue Result = LowerCallResult(Chain, InGlue, CallConv, IsVarArg, RVLocs,
                                    DL, DAG, InVals, IsThisReturn,
-                                   IsThisReturn ? OutVals[0] : SDValue());
+                                   IsThisReturn ? OutVals[0] : SDValue(),
+                                   RequiresSMChange.has_value());
 
   if (!Ins.empty())
     InGlue = Result.getValue(Result->getNumValues() - 1);

llvm/lib/Target/AArch64/AArch64ISelLowering.h

Lines changed: 3 additions & 1 deletion
@@ -58,6 +58,8 @@ enum NodeType : unsigned {
 
   CALL_BTI, // Function call followed by a BTI instruction.
 
+  COALESCER_BARRIER,
+
   SMSTART,
   SMSTOP,
   RESTORE_ZA,
@@ -1021,7 +1023,7 @@ class AArch64TargetLowering : public TargetLowering {
                           const SmallVectorImpl<CCValAssign> &RVLocs,
                           const SDLoc &DL, SelectionDAG &DAG,
                           SmallVectorImpl<SDValue> &InVals, bool isThisReturn,
-                          SDValue ThisVal) const;
+                          SDValue ThisVal, bool RequiresSMChange) const;
 
   SDValue LowerLOAD(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerSTORE(SDValue Op, SelectionDAG &DAG) const;

llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp

Lines changed: 35 additions & 0 deletions
@@ -1012,6 +1012,8 @@ bool AArch64RegisterInfo::shouldCoalesce(
     MachineInstr *MI, const TargetRegisterClass *SrcRC, unsigned SubReg,
     const TargetRegisterClass *DstRC, unsigned DstSubReg,
     const TargetRegisterClass *NewRC, LiveIntervals &LIS) const {
+  MachineRegisterInfo &MRI = MI->getMF()->getRegInfo();
+
   if (MI->isCopy() &&
       ((DstRC->getID() == AArch64::GPR64RegClassID) ||
        (DstRC->getID() == AArch64::GPR64commonRegClassID)) &&
@@ -1020,5 +1022,38 @@
     // which implements a 32 to 64 bit zero extension
     // which relies on the upper 32 bits being zeroed.
     return false;
+
+  auto IsCoalescerBarrier = [](const MachineInstr &MI) {
+    switch (MI.getOpcode()) {
+    case AArch64::COALESCER_BARRIER_FPR16:
+    case AArch64::COALESCER_BARRIER_FPR32:
+    case AArch64::COALESCER_BARRIER_FPR64:
+    case AArch64::COALESCER_BARRIER_FPR128:
+      return true;
+    default:
+      return false;
+    }
+  };
+
+  // For calls that temporarily have to toggle streaming mode as part of the
+  // call-sequence, we need to be more careful when coalescing copy instructions
+  // so that we don't end up coalescing the NEON/FP result or argument register
+  // with a whole Z-register, such that after coalescing the register allocator
+  // will try to spill/reload the entire Z register.
+  //
+  // We do this by checking if the node has any defs/uses that are COALESCER_BARRIER
+  // pseudos. These are 'nops' in practice, but they exist to instruct the
+  // coalescer to avoid coalescing the copy.
+  if (MI->isCopy() && SubReg != DstSubReg &&
+      (AArch64::ZPRRegClass.hasSubClassEq(DstRC) ||
+       AArch64::ZPRRegClass.hasSubClassEq(SrcRC))) {
+    unsigned SrcReg = MI->getOperand(1).getReg();
+    if (any_of(MRI.def_instructions(SrcReg), IsCoalescerBarrier))
+      return false;
+    unsigned DstReg = MI->getOperand(0).getReg();
+    if (any_of(MRI.use_nodbg_instructions(DstReg), IsCoalescerBarrier))
+      return false;
+  }
+
   return true;
 }

llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td

Lines changed: 22 additions & 0 deletions
@@ -22,6 +22,8 @@ def AArch64_restore_za : SDNode<"AArch64ISD::RESTORE_ZA", SDTypeProfile<0, 3,
                                                 [SDTCisInt<0>, SDTCisPtrTy<1>]>,
                                 [SDNPHasChain, SDNPSideEffect, SDNPVariadic,
                                  SDNPOptInGlue]>;
+def AArch64CoalescerBarrier
+    : SDNode<"AArch64ISD::COALESCER_BARRIER", SDTypeProfile<1, 1, []>, []>;
 
 //===----------------------------------------------------------------------===//
 // Instruction naming conventions.
@@ -183,6 +185,26 @@ def : Pat<(int_aarch64_sme_set_tpidr2 i64:$val),
           (MSR 0xde85, GPR64:$val)>;
 def : Pat<(i64 (int_aarch64_sme_get_tpidr2)),
           (MRS 0xde85)>;
+
+multiclass CoalescerBarrierPseudo<RegisterClass rc, list<ValueType> vts> {
+  def NAME : Pseudo<(outs rc:$dst), (ins rc:$idx), []>, Sched<[]> {
+    let Constraints = "$dst = $idx";
+  }
+  foreach vt = vts in {
+    def : Pat<(vt (AArch64CoalescerBarrier (vt rc:$idx))),
+              (!cast<Instruction>(NAME) rc:$idx)>;
+  }
+}
+
+multiclass CoalescerBarriers {
+  defm _FPR16 : CoalescerBarrierPseudo<FPR16, [f16]>;
+  defm _FPR32 : CoalescerBarrierPseudo<FPR32, [f32]>;
+  defm _FPR64 : CoalescerBarrierPseudo<FPR64, [f64, v8i8, v4i16, v2i32, v1i64, v4f16, v2f32, v1f64, v4bf16]>;
+  defm _FPR128 : CoalescerBarrierPseudo<FPR128, [f128, v16i8, v8i16, v4i32, v2i64, v8f16, v4f32, v2f64, v8bf16]>;
+}
+
+defm COALESCER_BARRIER : CoalescerBarriers;
+
 } // End let Predicates = [HasSME]
 
 // Pseudo to match to smstart/smstop. This expands:

llvm/test/CodeGen/AArch64/sme-disable-gisel-fisel.ll

Lines changed: 12 additions & 8 deletions
@@ -23,9 +23,9 @@ define double @nonstreaming_caller_streaming_callee(double %x) nounwind noinline
 ; CHECK-FISEL-NEXT: bl streaming_callee
 ; CHECK-FISEL-NEXT: str d0, [sp, #8] // 8-byte Folded Spill
 ; CHECK-FISEL-NEXT: smstop sm
+; CHECK-FISEL-NEXT: ldr d1, [sp, #8] // 8-byte Folded Reload
 ; CHECK-FISEL-NEXT: adrp x8, .LCPI0_0
 ; CHECK-FISEL-NEXT: ldr d0, [x8, :lo12:.LCPI0_0]
-; CHECK-FISEL-NEXT: ldr d1, [sp, #8] // 8-byte Folded Reload
 ; CHECK-FISEL-NEXT: fadd d0, d1, d0
 ; CHECK-FISEL-NEXT: ldr x30, [sp, #80] // 8-byte Folded Reload
 ; CHECK-FISEL-NEXT: ldp d9, d8, [sp, #64] // 16-byte Folded Reload
@@ -49,9 +49,9 @@ define double @nonstreaming_caller_streaming_callee(double %x) nounwind noinline
 ; CHECK-GISEL-NEXT: bl streaming_callee
 ; CHECK-GISEL-NEXT: str d0, [sp, #8] // 8-byte Folded Spill
 ; CHECK-GISEL-NEXT: smstop sm
+; CHECK-GISEL-NEXT: ldr d1, [sp, #8] // 8-byte Folded Reload
 ; CHECK-GISEL-NEXT: mov x8, #4631107791820423168 // =0x4045000000000000
 ; CHECK-GISEL-NEXT: fmov d0, x8
-; CHECK-GISEL-NEXT: ldr d1, [sp, #8] // 8-byte Folded Reload
 ; CHECK-GISEL-NEXT: fadd d0, d1, d0
 ; CHECK-GISEL-NEXT: ldr x30, [sp, #80] // 8-byte Folded Reload
 ; CHECK-GISEL-NEXT: ldp d9, d8, [sp, #64] // 16-byte Folded Reload
@@ -82,9 +82,9 @@ define double @streaming_caller_nonstreaming_callee(double %x) nounwind noinline
 ; CHECK-COMMON-NEXT: bl normal_callee
 ; CHECK-COMMON-NEXT: str d0, [sp, #8] // 8-byte Folded Spill
 ; CHECK-COMMON-NEXT: smstart sm
+; CHECK-COMMON-NEXT: ldr d1, [sp, #8] // 8-byte Folded Reload
 ; CHECK-COMMON-NEXT: mov x8, #4631107791820423168 // =0x4045000000000000
 ; CHECK-COMMON-NEXT: fmov d0, x8
-; CHECK-COMMON-NEXT: ldr d1, [sp, #8] // 8-byte Folded Reload
 ; CHECK-COMMON-NEXT: fadd d0, d1, d0
 ; CHECK-COMMON-NEXT: ldr x30, [sp, #80] // 8-byte Folded Reload
 ; CHECK-COMMON-NEXT: ldp d9, d8, [sp, #64] // 16-byte Folded Reload
@@ -110,14 +110,16 @@ define double @locally_streaming_caller_normal_callee(double %x) nounwind noinli
 ; CHECK-COMMON-NEXT: str x30, [sp, #96] // 8-byte Folded Spill
 ; CHECK-COMMON-NEXT: str d0, [sp, #24] // 8-byte Folded Spill
 ; CHECK-COMMON-NEXT: smstart sm
+; CHECK-COMMON-NEXT: ldr d0, [sp, #24] // 8-byte Folded Reload
+; CHECK-COMMON-NEXT: str d0, [sp, #24] // 8-byte Folded Spill
 ; CHECK-COMMON-NEXT: smstop sm
 ; CHECK-COMMON-NEXT: ldr d0, [sp, #24] // 8-byte Folded Reload
 ; CHECK-COMMON-NEXT: bl normal_callee
 ; CHECK-COMMON-NEXT: str d0, [sp, #16] // 8-byte Folded Spill
 ; CHECK-COMMON-NEXT: smstart sm
+; CHECK-COMMON-NEXT: ldr d1, [sp, #16] // 8-byte Folded Reload
 ; CHECK-COMMON-NEXT: mov x8, #4631107791820423168 // =0x4045000000000000
 ; CHECK-COMMON-NEXT: fmov d0, x8
-; CHECK-COMMON-NEXT: ldr d1, [sp, #16] // 8-byte Folded Reload
 ; CHECK-COMMON-NEXT: fadd d0, d1, d0
 ; CHECK-COMMON-NEXT: str d0, [sp, #8] // 8-byte Folded Spill
 ; CHECK-COMMON-NEXT: smstop sm
@@ -329,9 +331,9 @@ define fp128 @f128_call_sm(fp128 %a, fp128 %b) "aarch64_pstate_sm_enabled" nounw
 ; CHECK-COMMON-NEXT: stp d11, d10, [sp, #64] // 16-byte Folded Spill
 ; CHECK-COMMON-NEXT: stp d9, d8, [sp, #80] // 16-byte Folded Spill
 ; CHECK-COMMON-NEXT: str x30, [sp, #96] // 8-byte Folded Spill
-; CHECK-COMMON-NEXT: stp q0, q1, [sp] // 32-byte Folded Spill
+; CHECK-COMMON-NEXT: stp q1, q0, [sp] // 32-byte Folded Spill
 ; CHECK-COMMON-NEXT: smstop sm
-; CHECK-COMMON-NEXT: ldp q0, q1, [sp] // 32-byte Folded Reload
+; CHECK-COMMON-NEXT: ldp q1, q0, [sp] // 32-byte Folded Reload
 ; CHECK-COMMON-NEXT: bl __addtf3
 ; CHECK-COMMON-NEXT: str q0, [sp, #16] // 16-byte Folded Spill
 ; CHECK-COMMON-NEXT: smstart sm
@@ -390,9 +392,9 @@ define float @frem_call_sm(float %a, float %b) "aarch64_pstate_sm_enabled" nounw
 ; CHECK-COMMON-NEXT: stp d11, d10, [sp, #48] // 16-byte Folded Spill
 ; CHECK-COMMON-NEXT: stp d9, d8, [sp, #64] // 16-byte Folded Spill
 ; CHECK-COMMON-NEXT: str x30, [sp, #80] // 8-byte Folded Spill
-; CHECK-COMMON-NEXT: stp s0, s1, [sp, #8] // 8-byte Folded Spill
+; CHECK-COMMON-NEXT: stp s1, s0, [sp, #8] // 8-byte Folded Spill
 ; CHECK-COMMON-NEXT: smstop sm
-; CHECK-COMMON-NEXT: ldp s0, s1, [sp, #8] // 8-byte Folded Reload
+; CHECK-COMMON-NEXT: ldp s1, s0, [sp, #8] // 8-byte Folded Reload
 ; CHECK-COMMON-NEXT: bl fmodf
 ; CHECK-COMMON-NEXT: str s0, [sp, #12] // 4-byte Folded Spill
 ; CHECK-COMMON-NEXT: smstart sm
@@ -420,7 +422,9 @@ define float @frem_call_sm_compat(float %a, float %b) "aarch64_pstate_sm_compati
 ; CHECK-COMMON-NEXT: stp x30, x19, [sp, #80] // 16-byte Folded Spill
 ; CHECK-COMMON-NEXT: stp s0, s1, [sp, #8] // 8-byte Folded Spill
 ; CHECK-COMMON-NEXT: bl __arm_sme_state
+; CHECK-COMMON-NEXT: ldp s2, s0, [sp, #8] // 8-byte Folded Reload
 ; CHECK-COMMON-NEXT: and x19, x0, #0x1
+; CHECK-COMMON-NEXT: stp s2, s0, [sp, #8] // 8-byte Folded Spill
 ; CHECK-COMMON-NEXT: tbz w19, #0, .LBB12_2
 ; CHECK-COMMON-NEXT: // %bb.1:
 ; CHECK-COMMON-NEXT: smstop sm
