
Commit 0330d18

[RISCV] Fix mgather -> riscv.masked.strided.load combine not extending indices
This fixes the miscompile reported in #82430 by telling isSimpleVIDSequence to sign extend to XLen instead of to the type of the indices, since the "sequence" of indices generated by a strided load will be at XLen. This was the simplest way I could think of to get isSimpleVIDSequence to treat the indices as if they were zero extended to XLenVT. Another way we could do this is to refactor the "get constant integers" part out of isSimpleVIDSequence and handle the values as APInts, so we can zero extend them separately.
1 parent 2cd59bd commit 0330d18
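
As a rough sketch of the width problem the message describes (a standalone illustration, not LLVM code; the helper signExtend and the constant 128 are invented for the example), interpreting an index constant at the narrow element width instead of at XLen can flip the sign of the addend the combine folds into the base pointer, which matches the addi a0, a0, -128 vs. addi a0, a0, 128 difference in the updated test below:

#include <cstdint>
#include <iostream>

// Sign-extend the low `Bits` bits of Val to a 64-bit signed value.
int64_t signExtend(uint64_t Val, unsigned Bits) {
  return int64_t(Val << (64 - Bits)) >> (64 - Bits);
}

int main() {
  uint64_t IndexConstant = 128; // hypothetical addend recovered from an index sequence
  // Interpreted at an 8-bit element width, 128 reads back as -128, so a
  // negative base offset would be folded into the pointer.
  std::cout << signExtend(IndexConstant, 8) << "\n";  // prints -128
  // Interpreted at XLen (64 bits here), the value stays 128, matching the
  // address arithmetic the gather actually performs.
  std::cout << signExtend(IndexConstant, 64) << "\n"; // prints 128
}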

2 files changed: +14 -14 lines changed


llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 12 additions & 8 deletions
@@ -3240,15 +3240,16 @@ static std::optional<uint64_t> getExactInteger(const APFloat &APF,
 // Note that this method will also match potentially unappealing index
 // sequences, like <i32 0, i32 50939494>, however it is left to the caller to
 // determine whether this is worth generating code for.
-static std::optional<VIDSequence> isSimpleVIDSequence(SDValue Op) {
+static std::optional<VIDSequence> isSimpleVIDSequence(SDValue Op,
+                                                      unsigned EltSizeInBits) {
   unsigned NumElts = Op.getNumOperands();
   assert(Op.getOpcode() == ISD::BUILD_VECTOR && "Unexpected BUILD_VECTOR");
   bool IsInteger = Op.getValueType().isInteger();
 
   std::optional<unsigned> SeqStepDenom;
   std::optional<int64_t> SeqStepNum, SeqAddend;
   std::optional<std::pair<uint64_t, unsigned>> PrevElt;
-  unsigned EltSizeInBits = Op.getValueType().getScalarSizeInBits();
+  assert(EltSizeInBits >= Op.getValueType().getScalarSizeInBits());
   for (unsigned Idx = 0; Idx < NumElts; Idx++) {
     // Assume undef elements match the sequence; we just have to be careful
     // when interpolating across them.
@@ -3261,14 +3262,14 @@ static std::optional<VIDSequence> isSimpleVIDSequence(SDValue Op) {
       if (!isa<ConstantSDNode>(Op.getOperand(Idx)))
         return std::nullopt;
       Val = Op.getConstantOperandVal(Idx) &
-            maskTrailingOnes<uint64_t>(EltSizeInBits);
+            maskTrailingOnes<uint64_t>(Op.getScalarValueSizeInBits());
     } else {
       // The BUILD_VECTOR must be all constants.
       if (!isa<ConstantFPSDNode>(Op.getOperand(Idx)))
         return std::nullopt;
       if (auto ExactInteger = getExactInteger(
               cast<ConstantFPSDNode>(Op.getOperand(Idx))->getValueAPF(),
-              EltSizeInBits))
+              Op.getScalarValueSizeInBits()))
         Val = *ExactInteger;
       else
         return std::nullopt;
@@ -3324,11 +3325,11 @@ static std::optional<VIDSequence> isSimpleVIDSequence(SDValue Op) {
       uint64_t Val;
       if (IsInteger) {
         Val = Op.getConstantOperandVal(Idx) &
-              maskTrailingOnes<uint64_t>(EltSizeInBits);
+              maskTrailingOnes<uint64_t>(Op.getScalarValueSizeInBits());
       } else {
         Val = *getExactInteger(
             cast<ConstantFPSDNode>(Op.getOperand(Idx))->getValueAPF(),
-            EltSizeInBits);
+            Op.getScalarValueSizeInBits());
       }
       uint64_t ExpectedVal =
           (int64_t)(Idx * (uint64_t)*SeqStepNum) / *SeqStepDenom;
@@ -3598,7 +3599,7 @@ static SDValue lowerBuildVectorOfConstants(SDValue Op, SelectionDAG &DAG,
   // Try and match index sequences, which we can lower to the vid instruction
   // with optional modifications. An all-undef vector is matched by
   // getSplatValue, above.
-  if (auto SimpleVID = isSimpleVIDSequence(Op)) {
+  if (auto SimpleVID = isSimpleVIDSequence(Op, Op.getScalarValueSizeInBits())) {
     int64_t StepNumerator = SimpleVID->StepNumerator;
     unsigned StepDenominator = SimpleVID->StepDenominator;
     int64_t Addend = SimpleVID->Addend;
@@ -15978,7 +15979,10 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
 
     if (Index.getOpcode() == ISD::BUILD_VECTOR &&
         MGN->getExtensionType() == ISD::NON_EXTLOAD && isTypeLegal(VT)) {
-      if (std::optional<VIDSequence> SimpleVID = isSimpleVIDSequence(Index);
+      // The sequence will be XLenVT, not the type of Index. Tell
+      // isSimpleVIDSequence this so we avoid overflow.
+      if (std::optional<VIDSequence> SimpleVID =
+              isSimpleVIDSequence(Index, Subtarget.getXLen());
           SimpleVID && SimpleVID->StepDenominator == 1) {
         const int64_t StepNumerator = SimpleVID->StepNumerator;
         const int64_t Addend = SimpleVID->Addend;

llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-gather.ll

Lines changed: 2 additions & 6 deletions
@@ -15086,22 +15086,18 @@ define <32 x i64> @mgather_strided_split(ptr %base) {
   ret <32 x i64> %x
 }
 
-; FIXME: This is a miscompile triggered by the mgather ->
-; riscv.masked.strided.load combine. In order for it to trigger we need either a
-; strided gather that RISCVGatherScatterLowering doesn't pick up, or a new
-; strided gather generated by the widening sew combine.
 define <4 x i32> @masked_gather_widen_sew_negative_stride(ptr %base) {
 ; RV32V-LABEL: masked_gather_widen_sew_negative_stride:
 ; RV32V:       # %bb.0:
-; RV32V-NEXT:    addi a0, a0, -128
+; RV32V-NEXT:    addi a0, a0, 128
 ; RV32V-NEXT:    li a1, -128
 ; RV32V-NEXT:    vsetivli zero, 2, e64, m1, ta, ma
 ; RV32V-NEXT:    vlse64.v v8, (a0), a1
 ; RV32V-NEXT:    ret
 ;
 ; RV64V-LABEL: masked_gather_widen_sew_negative_stride:
 ; RV64V:       # %bb.0:
-; RV64V-NEXT:    addi a0, a0, -128
+; RV64V-NEXT:    addi a0, a0, 128
 ; RV64V-NEXT:    li a1, -128
 ; RV64V-NEXT:    vsetivli zero, 2, e64, m1, ta, ma
 ; RV64V-NEXT:    vlse64.v v8, (a0), a1
