Skip to content

release/18.x: [RISCV] Fix mgather -> riscv.masked.strided.load combine not extending indices (#82506) #82572

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 12 additions & 8 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3192,15 +3192,16 @@ static std::optional<uint64_t> getExactInteger(const APFloat &APF,
// Note that this method will also match potentially unappealing index
// sequences, like <i32 0, i32 50939494>, however it is left to the caller to
// determine whether this is worth generating code for.
static std::optional<VIDSequence> isSimpleVIDSequence(SDValue Op) {
static std::optional<VIDSequence> isSimpleVIDSequence(SDValue Op,
unsigned EltSizeInBits) {
unsigned NumElts = Op.getNumOperands();
assert(Op.getOpcode() == ISD::BUILD_VECTOR && "Unexpected BUILD_VECTOR");
bool IsInteger = Op.getValueType().isInteger();

std::optional<unsigned> SeqStepDenom;
std::optional<int64_t> SeqStepNum, SeqAddend;
std::optional<std::pair<uint64_t, unsigned>> PrevElt;
unsigned EltSizeInBits = Op.getValueType().getScalarSizeInBits();
assert(EltSizeInBits >= Op.getValueType().getScalarSizeInBits());
for (unsigned Idx = 0; Idx < NumElts; Idx++) {
// Assume undef elements match the sequence; we just have to be careful
// when interpolating across them.
Expand All @@ -3213,14 +3214,14 @@ static std::optional<VIDSequence> isSimpleVIDSequence(SDValue Op) {
if (!isa<ConstantSDNode>(Op.getOperand(Idx)))
return std::nullopt;
Val = Op.getConstantOperandVal(Idx) &
maskTrailingOnes<uint64_t>(EltSizeInBits);
maskTrailingOnes<uint64_t>(Op.getScalarValueSizeInBits());
} else {
// The BUILD_VECTOR must be all constants.
if (!isa<ConstantFPSDNode>(Op.getOperand(Idx)))
return std::nullopt;
if (auto ExactInteger = getExactInteger(
cast<ConstantFPSDNode>(Op.getOperand(Idx))->getValueAPF(),
EltSizeInBits))
Op.getScalarValueSizeInBits()))
Val = *ExactInteger;
else
return std::nullopt;
Expand Down Expand Up @@ -3276,11 +3277,11 @@ static std::optional<VIDSequence> isSimpleVIDSequence(SDValue Op) {
uint64_t Val;
if (IsInteger) {
Val = Op.getConstantOperandVal(Idx) &
maskTrailingOnes<uint64_t>(EltSizeInBits);
maskTrailingOnes<uint64_t>(Op.getScalarValueSizeInBits());
} else {
Val = *getExactInteger(
cast<ConstantFPSDNode>(Op.getOperand(Idx))->getValueAPF(),
EltSizeInBits);
Op.getScalarValueSizeInBits());
}
uint64_t ExpectedVal =
(int64_t)(Idx * (uint64_t)*SeqStepNum) / *SeqStepDenom;
Expand Down Expand Up @@ -3550,7 +3551,7 @@ static SDValue lowerBuildVectorOfConstants(SDValue Op, SelectionDAG &DAG,
// Try and match index sequences, which we can lower to the vid instruction
// with optional modifications. An all-undef vector is matched by
// getSplatValue, above.
if (auto SimpleVID = isSimpleVIDSequence(Op)) {
if (auto SimpleVID = isSimpleVIDSequence(Op, Op.getScalarValueSizeInBits())) {
int64_t StepNumerator = SimpleVID->StepNumerator;
unsigned StepDenominator = SimpleVID->StepDenominator;
int64_t Addend = SimpleVID->Addend;
Expand Down Expand Up @@ -15562,7 +15563,10 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,

if (Index.getOpcode() == ISD::BUILD_VECTOR &&
MGN->getExtensionType() == ISD::NON_EXTLOAD && isTypeLegal(VT)) {
if (std::optional<VIDSequence> SimpleVID = isSimpleVIDSequence(Index);
// The sequence will be XLenVT, not the type of Index. Tell
// isSimpleVIDSequence this so we avoid overflow.
if (std::optional<VIDSequence> SimpleVID =
isSimpleVIDSequence(Index, Subtarget.getXLen());
SimpleVID && SimpleVID->StepDenominator == 1) {
const int64_t StepNumerator = SimpleVID->StepNumerator;
const int64_t Addend = SimpleVID->Addend;
Expand Down
43 changes: 43 additions & 0 deletions llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-gather.ll
Original file line number Diff line number Diff line change
Expand Up @@ -15086,5 +15086,48 @@ define <32 x i64> @mgather_strided_split(ptr %base) {
ret <32 x i64> %x
}

; Regression test for a masked gather of four i32 elements addressed by the
; constant element indices <34, 35, 0, 1>, i.e. byte offsets 136, 140, 0 and 4
; from %base. The mask is an all-true splat and the passthru is poison.
; The RV32V and RV64V expectations below load the two adjacent pairs as two
; e64 elements with one strided load that starts at %base + 136 and uses a
; stride of -136 bytes. The ZVE32F configurations are expected to fall back to
; an indexed load (RV32) or to scalar loads plus slides (RV64).
; NOTE(review): the intrinsic name used in the call ends in v32p0 although the
; pointer operand is <4 x ptr>; confirm whether v4p0 was intended.
define <4 x i32> @masked_gather_widen_sew_negative_stride(ptr %base) {
; RV32V-LABEL: masked_gather_widen_sew_negative_stride:
; RV32V: # %bb.0:
; RV32V-NEXT: addi a0, a0, 136
; RV32V-NEXT: li a1, -136
; RV32V-NEXT: vsetivli zero, 2, e64, m1, ta, ma
; RV32V-NEXT: vlse64.v v8, (a0), a1
; RV32V-NEXT: ret
;
; RV64V-LABEL: masked_gather_widen_sew_negative_stride:
; RV64V: # %bb.0:
; RV64V-NEXT: addi a0, a0, 136
; RV64V-NEXT: li a1, -136
; RV64V-NEXT: vsetivli zero, 2, e64, m1, ta, ma
; RV64V-NEXT: vlse64.v v8, (a0), a1
; RV64V-NEXT: ret
;
; RV32ZVE32F-LABEL: masked_gather_widen_sew_negative_stride:
; RV32ZVE32F: # %bb.0:
; RV32ZVE32F-NEXT: lui a1, 16393
; RV32ZVE32F-NEXT: addi a1, a1, -888
; RV32ZVE32F-NEXT: vsetivli zero, 4, e32, m1, ta, ma
; RV32ZVE32F-NEXT: vmv.s.x v9, a1
; RV32ZVE32F-NEXT: vluxei8.v v8, (a0), v9
; RV32ZVE32F-NEXT: ret
;
; RV64ZVE32F-LABEL: masked_gather_widen_sew_negative_stride:
; RV64ZVE32F: # %bb.0:
; RV64ZVE32F-NEXT: addi a1, a0, 136
; RV64ZVE32F-NEXT: lw a2, 140(a0)
; RV64ZVE32F-NEXT: lw a3, 0(a0)
; RV64ZVE32F-NEXT: lw a0, 4(a0)
; RV64ZVE32F-NEXT: vsetivli zero, 4, e32, m1, ta, ma
; RV64ZVE32F-NEXT: vlse32.v v8, (a1), zero
; RV64ZVE32F-NEXT: vslide1down.vx v8, v8, a2
; RV64ZVE32F-NEXT: vslide1down.vx v8, v8, a3
; RV64ZVE32F-NEXT: vslide1down.vx v8, v8, a0
; RV64ZVE32F-NEXT: ret
%ptrs = getelementptr i32, ptr %base, <4 x i64> <i64 34, i64 35, i64 0, i64 1>
%x = call <4 x i32> @llvm.masked.gather.v4i32.v32p0(<4 x ptr> %ptrs, i32 8, <4 x i1> shufflevector(<4 x i1> insertelement(<4 x i1> poison, i1 true, i32 0), <4 x i1> poison, <4 x i32> zeroinitializer), <4 x i32> poison)
ret <4 x i32> %x
}

;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
; RV64: {{.*}}