
[RISCV] Unify vsetvli compatibility logic in forward and backwards passes #71657

Closed
wants to merge 2 commits
117 changes: 56 additions & 61 deletions llvm/lib/Target/RISCV/RISCVInsertVSETVLI.cpp
@@ -419,6 +419,17 @@ DemandedFields getDemanded(const MachineInstr &MI,
return Res;
}

static MachineInstr *isADDIX0(Register Reg, const MachineRegisterInfo &MRI) {
if (Reg == RISCV::X0)
return nullptr;
if (MachineInstr *MI = MRI.getVRegDef(Reg);
MI && MI->getOpcode() == RISCV::ADDI && MI->getOperand(1).isReg() &&
MI->getOperand(2).isImm() && MI->getOperand(1).getReg() == RISCV::X0 &&
MI->getOperand(2).getImm() != 0)
return MI;
return nullptr;
}

/// Defines the abstract state with which the forward dataflow models the
/// values of the VL and VTYPE registers after insertion.
class VSETVLIInfo {
@@ -431,6 +442,7 @@ class VSETVLIInfo {
Uninitialized,
AVLIsReg,
AVLIsImm,
PreserveVL, // vsetvli x0, x0
Unknown,
} State = Uninitialized;

@@ -466,6 +478,8 @@ class VSETVLIInfo {
State = AVLIsImm;
}

void setPreserveVL() { State = PreserveVL; }

bool hasAVLImm() const { return State == AVLIsImm; }
bool hasAVLReg() const { return State == AVLIsReg; }
Register getAVLReg() const {
@@ -486,11 +500,7 @@ class VSETVLIInfo {
if (hasAVLReg()) {
if (getAVLReg() == RISCV::X0)
return true;
if (MachineInstr *MI = MRI.getVRegDef(getAVLReg());
MI && MI->getOpcode() == RISCV::ADDI &&
MI->getOperand(1).isReg() && MI->getOperand(2).isImm() &&
MI->getOperand(1).getReg() == RISCV::X0 &&
MI->getOperand(2).getImm() != 0)
if (isADDIX0(getAVLReg(), MRI))
return true;
return false;
}
@@ -579,8 +589,11 @@ class VSETVLIInfo {
// Determine whether the vector instructions requirements represented by
// Require are compatible with the previous vsetvli instruction represented
// by this. MI is the instruction whose requirements we're considering.
// The instruction represented by Require should come after this, unless
// OrderReversed is true.
bool isCompatible(const DemandedFields &Used, const VSETVLIInfo &Require,
const MachineRegisterInfo &MRI) const {
const MachineRegisterInfo &MRI,
bool OrderReversed = false) const {
assert(isValid() && Require.isValid() &&
"Can't compare invalid VSETVLIInfos");
assert(!Require.SEWLMULRatioOnly &&
@@ -593,11 +606,15 @@ class VSETVLIInfo {
if (SEWLMULRatioOnly)
return false;

if (Used.VLAny && !hasSameAVL(Require))
return false;
// If the VL will be preserved, then we don't need to check the AVL.
const uint8_t EndState = OrderReversed ? State : Require.State;
if (EndState != PreserveVL) {
if (Used.VLAny && !hasSameAVL(Require))
return false;

if (Used.VLZeroness && !hasEquallyZeroAVL(Require, MRI))
return false;
if (Used.VLZeroness && !hasEquallyZeroAVL(Require, MRI))
return false;
}

return hasCompatibleVTYPE(Used, Require);
}
@@ -849,9 +866,11 @@ static VSETVLIInfo getInfoForVSETVLI(const MachineInstr &MI) {
assert(MI.getOpcode() == RISCV::PseudoVSETVLI ||
MI.getOpcode() == RISCV::PseudoVSETVLIX0);
Register AVLReg = MI.getOperand(1).getReg();
assert((AVLReg != RISCV::X0 || MI.getOperand(0).getReg() != RISCV::X0) &&
"Can't handle X0, X0 vsetvli yet");
NewInfo.setAVLReg(AVLReg);

if (AVLReg == RISCV::X0 && MI.getOperand(0).getReg() == RISCV::X0)
NewInfo.setPreserveVL();
else
NewInfo.setAVLReg(AVLReg);
}
NewInfo.setVTYPE(MI.getOperand(2).getImm());

@@ -1426,52 +1445,9 @@ static void doUnion(DemandedFields &A, DemandedFields B) {
A.MaskPolicy |= B.MaskPolicy;
}

static bool isNonZeroAVL(const MachineOperand &MO) {
if (MO.isReg())
return RISCV::X0 == MO.getReg();
assert(MO.isImm());
return 0 != MO.getImm();
}

// Return true if we can mutate PrevMI to match MI without changing any the
// fields which would be observed.
static bool canMutatePriorConfig(const MachineInstr &PrevMI,
const MachineInstr &MI,
const DemandedFields &Used) {
// If the VL values aren't equal, return false if either a) the former is
// demanded, or b) we can't rewrite the former to be the later for
// implementation reasons.
if (!isVLPreservingConfig(MI)) {
if (Used.VLAny)
return false;

// We don't bother to handle the equally zero case here as it's largely
// uninteresting.
if (Used.VLZeroness) {
if (isVLPreservingConfig(PrevMI))
return false;
if (!isNonZeroAVL(MI.getOperand(1)) ||
!isNonZeroAVL(PrevMI.getOperand(1)))
return false;
}

// TODO: Track whether the register is defined between
// PrevMI and MI.
if (MI.getOperand(1).isReg() &&
RISCV::X0 != MI.getOperand(1).getReg())
return false;
}

if (!PrevMI.getOperand(2).isImm() || !MI.getOperand(2).isImm())
return false;

auto PriorVType = PrevMI.getOperand(2).getImm();
auto VType = MI.getOperand(2).getImm();
return areCompatibleVTYPEs(PriorVType, VType, Used);
}

void RISCVInsertVSETVLI::doLocalPostpass(MachineBasicBlock &MBB) {
MachineInstr *NextMI = nullptr;
VSETVLIInfo NextInfo;
// We can have arbitrary code in successors, so VL and VTYPE
// must be considered demanded.
DemandedFields Used;
@@ -1482,6 +1458,7 @@ void RISCVInsertVSETVLI::doLocalPostpass(MachineBasicBlock &MBB) {

if (!isVectorConfigInstr(MI)) {
doUnion(Used, getDemanded(MI, MRI, ST));
transferAfter(NextInfo, MI);
continue;
}

@@ -1495,14 +1472,31 @@ void RISCVInsertVSETVLI::doLocalPostpass(MachineBasicBlock &MBB) {
ToDelete.push_back(&MI);
// Leave NextMI unchanged
continue;
} else if (canMutatePriorConfig(MI, *NextMI, Used)) {
} else if (NextInfo.isCompatible(Used, getInfoForVSETVLI(MI), *MRI,
true)) {
if (!isVLPreservingConfig(*NextMI)) {
MI.getOperand(0).setReg(NextMI->getOperand(0).getReg());
MI.getOperand(0).setIsDead(false);

MachineOperand &AVL = MI.getOperand(1);
// If the old AVL was only used by MI, it's dead.
if (AVL.isReg() && AVL.getReg().isVirtual() &&
MRI->hasOneNonDBGUse(AVL.getReg()))
MRI->getVRegDef(AVL.getReg())->eraseFromParent();
Contributor Author:
This looks like it could have been pulled out into a separate patch but it has no effect on main, because we currently don't consider these cases as compatible
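For readers following the thread: a minimal standalone sketch of the cleanup being referred to, pulled into a hypothetical helper. The names and surrounding backend context are assumed from the diff above; this is not part of the patch itself.

```cpp
#include "llvm/CodeGen/MachineInstr.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
using namespace llvm;

// When a vsetvli's AVL operand is rewritten, the virtual register that
// previously supplied the AVL may become dead. If the vsetvli was its only
// non-debug user, the defining instruction (e.g. an ADDI materializing a
// small immediate) can be erased as well.
static void eraseDeadAVLDef(const MachineOperand &AVL,
                            const MachineRegisterInfo &MRI) {
  if (AVL.isReg() && AVL.getReg().isVirtual() &&
      MRI.hasOneNonDBGUse(AVL.getReg()))
    MRI.getVRegDef(AVL.getReg())->eraseFromParent();
}
```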


if (NextMI->getOperand(1).isImm())
MI.getOperand(1).ChangeToImmediate(NextMI->getOperand(1).getImm());
else
MI.getOperand(1).ChangeToRegister(NextMI->getOperand(1).getReg(), false);
AVL.ChangeToImmediate(NextMI->getOperand(1).getImm());
else {
// NextMI may have an AVL (addi x0, imm) whilst MI might have a
// different non-zero AVL. But the AVLs may be considered
// compatible. So hoist it up to MI in case it's not already
// dominated by it. See hasNonZeroAVL.
if (MachineInstr *ADDI =
isADDIX0(NextMI->getOperand(1).getReg(), *MRI))
ADDI->moveBefore(&MI);

AVL.ChangeToRegister(NextMI->getOperand(1).getReg(), false);
}
MI.setDesc(NextMI->getDesc());
}
MI.getOperand(2).setImm(NextMI->getOperand(2).getImm());
@@ -1511,6 +1505,7 @@ void RISCVInsertVSETVLI::doLocalPostpass(MachineBasicBlock &MBB) {
}
}
NextMI = &MI;
NextInfo = getInfoForVSETVLI(MI);
Used = getDemanded(MI, MRI, ST);
}

10 changes: 3 additions & 7 deletions llvm/test/CodeGen/RISCV/rvv/fixed-vectors-insert.ll
@@ -51,10 +51,8 @@ define <32 x i32> @insertelt_v32i32_0(<32 x i32> %a, i32 %y) {
define <32 x i32> @insertelt_v32i32_4(<32 x i32> %a, i32 %y) {
; CHECK-LABEL: insertelt_v32i32_4:
; CHECK: # %bb.0:
; CHECK-NEXT: li a1, 32
; CHECK-NEXT: vsetvli zero, a1, e32, m1, ta, ma
; CHECK-NEXT: vmv.s.x v16, a0
; CHECK-NEXT: vsetivli zero, 5, e32, m2, tu, ma
; CHECK-NEXT: vmv.s.x v16, a0
; CHECK-NEXT: vslideup.vi v8, v16, 4
; CHECK-NEXT: ret
%b = insertelement <32 x i32> %a, i32 %y, i32 4
@@ -65,9 +63,8 @@ define <32 x i32> @insertelt_v32i32_31(<32 x i32> %a, i32 %y) {
; CHECK-LABEL: insertelt_v32i32_31:
; CHECK: # %bb.0:
; CHECK-NEXT: li a1, 32
; CHECK-NEXT: vsetvli zero, a1, e32, m1, ta, ma
; CHECK-NEXT: vmv.s.x v16, a0
; CHECK-NEXT: vsetvli zero, a1, e32, m8, ta, ma
; CHECK-NEXT: vmv.s.x v16, a0
; CHECK-NEXT: vslideup.vi v8, v16, 31
; CHECK-NEXT: ret
%b = insertelement <32 x i32> %a, i32 %y, i32 31
@@ -103,9 +100,8 @@ define <64 x i32> @insertelt_v64i32_63(<64 x i32> %a, i32 %y) {
; CHECK-LABEL: insertelt_v64i32_63:
; CHECK: # %bb.0:
; CHECK-NEXT: li a1, 32
; CHECK-NEXT: vsetvli zero, a1, e32, m1, ta, ma
; CHECK-NEXT: vmv.s.x v24, a0
; CHECK-NEXT: vsetvli zero, a1, e32, m8, ta, ma
; CHECK-NEXT: vmv.s.x v24, a0
; CHECK-NEXT: vslideup.vi v16, v24, 31
; CHECK-NEXT: ret
%b = insertelement <64 x i32> %a, i32 %y, i32 63
35 changes: 8 additions & 27 deletions llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-gather.ll
@@ -12418,12 +12418,11 @@ define <32 x i8> @mgather_baseidx_v32i8(ptr %base, <32 x i8> %idxs, <32 x i1> %m
; RV64ZVE32F-NEXT: andi a2, a1, 1
; RV64ZVE32F-NEXT: beqz a2, .LBB98_2
; RV64ZVE32F-NEXT: # %bb.1: # %cond.load
; RV64ZVE32F-NEXT: vsetvli zero, zero, e8, mf4, ta, ma
; RV64ZVE32F-NEXT: li a2, 32
; RV64ZVE32F-NEXT: vsetvli zero, a2, e8, mf4, tu, ma
; RV64ZVE32F-NEXT: vmv.x.s a2, v8
; RV64ZVE32F-NEXT: add a2, a0, a2
; RV64ZVE32F-NEXT: lbu a2, 0(a2)
; RV64ZVE32F-NEXT: li a3, 32
; RV64ZVE32F-NEXT: vsetvli zero, a3, e8, mf4, tu, ma
; RV64ZVE32F-NEXT: vmv.s.x v10, a2
; RV64ZVE32F-NEXT: .LBB98_2: # %else
; RV64ZVE32F-NEXT: andi a2, a1, 2
@@ -12452,14 +12451,11 @@ define <32 x i8> @mgather_baseidx_v32i8(ptr %base, <32 x i8> %idxs, <32 x i1> %m
; RV64ZVE32F-NEXT: andi a2, a1, 16
; RV64ZVE32F-NEXT: beqz a2, .LBB98_8
; RV64ZVE32F-NEXT: .LBB98_7: # %cond.load10
; RV64ZVE32F-NEXT: vsetivli zero, 1, e8, mf4, ta, ma
; RV64ZVE32F-NEXT: vsetivli zero, 5, e8, m1, tu, ma
; RV64ZVE32F-NEXT: vmv.x.s a2, v13
; RV64ZVE32F-NEXT: add a2, a0, a2
; RV64ZVE32F-NEXT: lbu a2, 0(a2)
; RV64ZVE32F-NEXT: li a3, 32
; RV64ZVE32F-NEXT: vsetvli zero, a3, e8, mf4, ta, ma
; RV64ZVE32F-NEXT: vmv.s.x v12, a2
; RV64ZVE32F-NEXT: vsetivli zero, 5, e8, m1, tu, ma
; RV64ZVE32F-NEXT: vslideup.vi v10, v12, 4
; RV64ZVE32F-NEXT: .LBB98_8: # %else11
; RV64ZVE32F-NEXT: vsetivli zero, 8, e8, m1, ta, ma
@@ -12592,14 +12588,11 @@ define <32 x i8> @mgather_baseidx_v32i8(ptr %base, <32 x i8> %idxs, <32 x i1> %m
; RV64ZVE32F-NEXT: slli a2, a1, 43
; RV64ZVE32F-NEXT: bgez a2, .LBB98_32
; RV64ZVE32F-NEXT: .LBB98_31: # %cond.load58
; RV64ZVE32F-NEXT: vsetivli zero, 1, e8, mf4, ta, ma
; RV64ZVE32F-NEXT: vsetivli zero, 21, e8, m2, tu, ma
; RV64ZVE32F-NEXT: vmv.x.s a2, v9
; RV64ZVE32F-NEXT: add a2, a0, a2
; RV64ZVE32F-NEXT: lbu a2, 0(a2)
; RV64ZVE32F-NEXT: li a3, 32
; RV64ZVE32F-NEXT: vsetvli zero, a3, e8, mf4, ta, ma
; RV64ZVE32F-NEXT: vmv.s.x v12, a2
; RV64ZVE32F-NEXT: vsetivli zero, 21, e8, m2, tu, ma
; RV64ZVE32F-NEXT: vslideup.vi v10, v12, 20
; RV64ZVE32F-NEXT: .LBB98_32: # %else59
; RV64ZVE32F-NEXT: vsetivli zero, 8, e8, m1, ta, ma
@@ -12742,14 +12735,11 @@ define <32 x i8> @mgather_baseidx_v32i8(ptr %base, <32 x i8> %idxs, <32 x i1> %m
; RV64ZVE32F-NEXT: andi a2, a1, 256
; RV64ZVE32F-NEXT: beqz a2, .LBB98_13
; RV64ZVE32F-NEXT: .LBB98_53: # %cond.load22
; RV64ZVE32F-NEXT: vsetivli zero, 1, e8, mf4, ta, ma
; RV64ZVE32F-NEXT: vsetivli zero, 9, e8, m1, tu, ma
; RV64ZVE32F-NEXT: vmv.x.s a2, v12
; RV64ZVE32F-NEXT: add a2, a0, a2
; RV64ZVE32F-NEXT: lbu a2, 0(a2)
; RV64ZVE32F-NEXT: li a3, 32
; RV64ZVE32F-NEXT: vsetvli zero, a3, e8, mf4, ta, ma
; RV64ZVE32F-NEXT: vmv.s.x v13, a2
; RV64ZVE32F-NEXT: vsetivli zero, 9, e8, m1, tu, ma
; RV64ZVE32F-NEXT: vslideup.vi v10, v13, 8
; RV64ZVE32F-NEXT: andi a2, a1, 512
; RV64ZVE32F-NEXT: bnez a2, .LBB98_14
@@ -12777,14 +12767,11 @@ define <32 x i8> @mgather_baseidx_v32i8(ptr %base, <32 x i8> %idxs, <32 x i1> %m
; RV64ZVE32F-NEXT: slli a2, a1, 47
; RV64ZVE32F-NEXT: bgez a2, .LBB98_26
; RV64ZVE32F-NEXT: .LBB98_56: # %cond.load46
; RV64ZVE32F-NEXT: vsetivli zero, 1, e8, mf4, ta, ma
; RV64ZVE32F-NEXT: vsetivli zero, 17, e8, m2, tu, ma
; RV64ZVE32F-NEXT: vmv.x.s a2, v8
; RV64ZVE32F-NEXT: add a2, a0, a2
; RV64ZVE32F-NEXT: lbu a2, 0(a2)
; RV64ZVE32F-NEXT: li a3, 32
; RV64ZVE32F-NEXT: vsetvli zero, a3, e8, mf4, ta, ma
; RV64ZVE32F-NEXT: vmv.s.x v12, a2
; RV64ZVE32F-NEXT: vsetivli zero, 17, e8, m2, tu, ma
; RV64ZVE32F-NEXT: vslideup.vi v10, v12, 16
; RV64ZVE32F-NEXT: slli a2, a1, 46
; RV64ZVE32F-NEXT: bltz a2, .LBB98_27
@@ -12835,14 +12822,11 @@ define <32 x i8> @mgather_baseidx_v32i8(ptr %base, <32 x i8> %idxs, <32 x i1> %m
; RV64ZVE32F-NEXT: slli a2, a1, 39
; RV64ZVE32F-NEXT: bgez a2, .LBB98_37
; RV64ZVE32F-NEXT: .LBB98_61: # %cond.load70
; RV64ZVE32F-NEXT: vsetivli zero, 1, e8, mf4, ta, ma
; RV64ZVE32F-NEXT: vsetivli zero, 25, e8, m2, tu, ma
; RV64ZVE32F-NEXT: vmv.x.s a2, v8
; RV64ZVE32F-NEXT: add a2, a0, a2
; RV64ZVE32F-NEXT: lbu a2, 0(a2)
; RV64ZVE32F-NEXT: li a3, 32
; RV64ZVE32F-NEXT: vsetvli zero, a3, e8, mf4, ta, ma
; RV64ZVE32F-NEXT: vmv.s.x v12, a2
; RV64ZVE32F-NEXT: vsetivli zero, 25, e8, m2, tu, ma
; RV64ZVE32F-NEXT: vslideup.vi v10, v12, 24
; RV64ZVE32F-NEXT: slli a2, a1, 38
; RV64ZVE32F-NEXT: bltz a2, .LBB98_38
@@ -12870,14 +12854,11 @@ define <32 x i8> @mgather_baseidx_v32i8(ptr %base, <32 x i8> %idxs, <32 x i1> %m
; RV64ZVE32F-NEXT: slli a2, a1, 35
; RV64ZVE32F-NEXT: bgez a2, .LBB98_42
; RV64ZVE32F-NEXT: .LBB98_64: # %cond.load82
; RV64ZVE32F-NEXT: vsetivli zero, 1, e8, mf4, ta, ma
; RV64ZVE32F-NEXT: vsetivli zero, 29, e8, m2, tu, ma
; RV64ZVE32F-NEXT: vmv.x.s a2, v9
; RV64ZVE32F-NEXT: add a2, a0, a2
; RV64ZVE32F-NEXT: lbu a2, 0(a2)
; RV64ZVE32F-NEXT: li a3, 32
; RV64ZVE32F-NEXT: vsetvli zero, a3, e8, mf4, ta, ma
; RV64ZVE32F-NEXT: vmv.s.x v12, a2
; RV64ZVE32F-NEXT: vsetivli zero, 29, e8, m2, tu, ma
; RV64ZVE32F-NEXT: vslideup.vi v10, v12, 28
; RV64ZVE32F-NEXT: slli a2, a1, 34
; RV64ZVE32F-NEXT: bltz a2, .LBB98_43
4 changes: 2 additions & 2 deletions llvm/test/CodeGen/RISCV/rvv/vsetvli-insert.ll
@@ -329,9 +329,9 @@ entry:
define double @test17(i64 %avl, <vscale x 1 x double> %a, <vscale x 1 x double> %b) nounwind {
; CHECK-LABEL: test17:
; CHECK: # %bb.0: # %entry
; CHECK-NEXT: vsetvli a0, a0, e64, m1, ta, ma
; CHECK-NEXT: vfmv.f.s fa5, v8
; CHECK-NEXT: vsetvli a0, a0, e32, mf2, ta, ma
Contributor:
This case can be far more optimized I think. Only one vsetvli zero, a0, e64, m1, ta, ma is needed here?

Contributor Author:
Would it need to be vsetvli a0, a0, e64, m1, ta, ma since the resulting VL is used?
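To illustrate the constraint being discussed: the destination of the vsetvli can only be rewritten to x0 (discarding the computed VL) when nothing reads that result. A hypothetical helper expressing that check is sketched below; it is not part of this patch or of the pass, and assumes the usual RISC-V backend context (RISCV::X0 from the target's generated register enum, e.g. via "RISCV.h").

```cpp
#include "llvm/CodeGen/MachineInstr.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
using namespace llvm;

// Returns true if the VL result of this vsetvli is unused, so the
// destination could be replaced with x0 without changing behaviour.
static bool canDropVLOutput(const MachineInstr &VSetVLI,
                            const MachineRegisterInfo &MRI) {
  Register VL = VSetVLI.getOperand(0).getReg();
  if (VL == RISCV::X0)
    return true;                 // The VL result is already discarded.
  // A virtual register with no non-debug readers is dead.
  return VL.isVirtual() && MRI.use_nodbg_empty(VL);
}
```

In test17 the returned VL (a0) feeds the second vsetvli, so by this reasoning the destination has to stay a GPR rather than zero.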

; CHECK-NEXT: vsetvli zero, a0, e64, m1, ta, ma
; CHECK-NEXT: vfmv.f.s fa5, v8
; CHECK-NEXT: vfadd.vv v8, v8, v9
; CHECK-NEXT: vfmv.f.s fa4, v8
; CHECK-NEXT: fadd.d fa0, fa5, fa4