Skip to content

Commit 6246b49

Browse files
authored
[RISCV] Select ISD::AVGCEILS/AVGFLOORS as vaadd. (#92839)
I believe the behaviors are equivalent, assuming the following describes them correctly. AVGFLOORS sign-extends the inputs by 1 bit, adds them, then performs an arithmetic shift right by 1 before truncating to the original bit width; this matches vaadd with the rdn (round-down) rounding mode. AVGCEILS sign-extends the inputs by 1 bit, adds them, then performs an arithmetic shift right by 1, adding 1 to the shifted value if the bit shifted out was 1, before truncating to the original bit width; this matches vaadd with the rnu (round-to-nearest-up) rounding mode. I think this wasn't implemented previously because there was some confusion about what "average" means: some may expect averaging to round towards zero, but there is no way to express that in a single RISC-V instruction or with the existing SelectionDAG nodes. Related issue riscvarchive/riscv-v-spec#935
1 parent 50dbbe5 commit 6246b49

File tree

6 files changed

+42
-30
lines changed

6 files changed

+42
-30
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -844,8 +844,9 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
844844
VT, Custom);
845845
setOperationAction({ISD::FP_TO_SINT_SAT, ISD::FP_TO_UINT_SAT}, VT,
846846
Custom);
847-
setOperationAction({ISD::AVGFLOORU, ISD::AVGCEILU, ISD::SADDSAT,
848-
ISD::UADDSAT, ISD::SSUBSAT, ISD::USUBSAT},
847+
setOperationAction({ISD::AVGFLOORS, ISD::AVGFLOORU, ISD::AVGCEILS,
848+
ISD::AVGCEILU, ISD::SADDSAT, ISD::UADDSAT,
849+
ISD::SSUBSAT, ISD::USUBSAT},
849850
VT, Legal);
850851

851852
// Integer VTs are lowered as a series of "RISCVISD::TRUNCATE_VECTOR_VL"
@@ -1237,8 +1238,9 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
12371238
if (VT.getVectorElementType() != MVT::i64 || Subtarget.hasStdExtV())
12381239
setOperationAction({ISD::MULHS, ISD::MULHU}, VT, Custom);
12391240

1240-
setOperationAction({ISD::AVGFLOORU, ISD::AVGCEILU, ISD::SADDSAT,
1241-
ISD::UADDSAT, ISD::SSUBSAT, ISD::USUBSAT},
1241+
setOperationAction({ISD::AVGFLOORS, ISD::AVGFLOORU, ISD::AVGCEILS,
1242+
ISD::AVGCEILU, ISD::SADDSAT, ISD::UADDSAT,
1243+
ISD::SSUBSAT, ISD::USUBSAT},
12421244
VT, Custom);
12431245

12441246
setOperationAction(ISD::VSELECT, VT, Custom);
@@ -5841,7 +5843,9 @@ static unsigned getRISCVVLOp(SDValue Op) {
58415843
OP_CASE(UADDSAT)
58425844
OP_CASE(SSUBSAT)
58435845
OP_CASE(USUBSAT)
5846+
OP_CASE(AVGFLOORS)
58445847
OP_CASE(AVGFLOORU)
5848+
OP_CASE(AVGCEILS)
58455849
OP_CASE(AVGCEILU)
58465850
OP_CASE(FADD)
58475851
OP_CASE(FSUB)
@@ -5956,7 +5960,7 @@ static bool hasMergeOp(unsigned Opcode) {
59565960
Opcode <= RISCVISD::LAST_RISCV_STRICTFP_OPCODE &&
59575961
"not a RISC-V target specific op");
59585962
static_assert(RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP ==
5959-
126 &&
5963+
128 &&
59605964
RISCVISD::LAST_RISCV_STRICTFP_OPCODE -
59615965
ISD::FIRST_TARGET_STRICTFP_OPCODE ==
59625966
21 &&
@@ -5982,7 +5986,7 @@ static bool hasMaskOp(unsigned Opcode) {
59825986
Opcode <= RISCVISD::LAST_RISCV_STRICTFP_OPCODE &&
59835987
"not a RISC-V target specific op");
59845988
static_assert(RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP ==
5985-
126 &&
5989+
128 &&
59865990
RISCVISD::LAST_RISCV_STRICTFP_OPCODE -
59875991
ISD::FIRST_TARGET_STRICTFP_OPCODE ==
59885992
21 &&
@@ -6882,7 +6886,9 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
68826886
!Subtarget.hasVInstructionsF16()))
68836887
return SplitVectorOp(Op, DAG);
68846888
[[fallthrough]];
6889+
case ISD::AVGFLOORS:
68856890
case ISD::AVGFLOORU:
6891+
case ISD::AVGCEILS:
68866892
case ISD::AVGCEILU:
68876893
case ISD::SMIN:
68886894
case ISD::SMAX:
@@ -19958,7 +19964,9 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
1995819964
NODE_NAME_CASE(UDIV_VL)
1995919965
NODE_NAME_CASE(UREM_VL)
1996019966
NODE_NAME_CASE(XOR_VL)
19967+
NODE_NAME_CASE(AVGFLOORS_VL)
1996119968
NODE_NAME_CASE(AVGFLOORU_VL)
19969+
NODE_NAME_CASE(AVGCEILS_VL)
1996219970
NODE_NAME_CASE(AVGCEILU_VL)
1996319971
NODE_NAME_CASE(SADDSAT_VL)
1996419972
NODE_NAME_CASE(UADDSAT_VL)

llvm/lib/Target/RISCV/RISCVISelLowering.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,8 +264,12 @@ enum NodeType : unsigned {
264264
SSUBSAT_VL,
265265
USUBSAT_VL,
266266

267+
// Averaging adds of signed integers.
268+
AVGFLOORS_VL,
267269
// Averaging adds of unsigned integers.
268270
AVGFLOORU_VL,
271+
// Rounding averaging adds of signed integers.
272+
AVGCEILS_VL,
269273
// Rounding averaging adds of unsigned integers.
270274
AVGCEILU_VL,
271275

llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -881,17 +881,17 @@ multiclass VPatMultiplyAddSDNode_VV_VX<SDNode op, string instruction_name> {
881881
}
882882
}
883883

884-
multiclass VPatAVGADD_VV_VX_RM<SDNode vop, int vxrm> {
884+
multiclass VPatAVGADD_VV_VX_RM<SDNode vop, int vxrm, string suffix = ""> {
885885
foreach vti = AllIntegerVectors in {
886886
let Predicates = GetVTypePredicates<vti>.Predicates in {
887887
def : Pat<(vop (vti.Vector vti.RegClass:$rs1),
888888
(vti.Vector vti.RegClass:$rs2)),
889-
(!cast<Instruction>("PseudoVAADDU_VV_"#vti.LMul.MX)
889+
(!cast<Instruction>("PseudoVAADD"#suffix#"_VV_"#vti.LMul.MX)
890890
(vti.Vector (IMPLICIT_DEF)), vti.RegClass:$rs1, vti.RegClass:$rs2,
891891
vxrm, vti.AVL, vti.Log2SEW, TA_MA)>;
892892
def : Pat<(vop (vti.Vector vti.RegClass:$rs1),
893893
(vti.Vector (SplatPat (XLenVT GPR:$rs2)))),
894-
(!cast<Instruction>("PseudoVAADDU_VX_"#vti.LMul.MX)
894+
(!cast<Instruction>("PseudoVAADD"#suffix#"_VX_"#vti.LMul.MX)
895895
(vti.Vector (IMPLICIT_DEF)), vti.RegClass:$rs1, GPR:$rs2,
896896
vxrm, vti.AVL, vti.Log2SEW, TA_MA)>;
897897
}
@@ -1163,8 +1163,10 @@ defm : VPatBinarySDNode_VV_VX<ssubsat, "PseudoVSSUB">;
11631163
defm : VPatBinarySDNode_VV_VX<usubsat, "PseudoVSSUBU">;
11641164

11651165
// 12.2. Vector Single-Width Averaging Add and Subtract
1166-
defm : VPatAVGADD_VV_VX_RM<avgflooru, 0b10>;
1167-
defm : VPatAVGADD_VV_VX_RM<avgceilu, 0b00>;
1166+
defm : VPatAVGADD_VV_VX_RM<avgfloors, 0b10>;
1167+
defm : VPatAVGADD_VV_VX_RM<avgflooru, 0b10, suffix = "U">;
1168+
defm : VPatAVGADD_VV_VX_RM<avgceils, 0b00>;
1169+
defm : VPatAVGADD_VV_VX_RM<avgceilu, 0b00, suffix = "U">;
11681170

11691171
// 12.5. Vector Narrowing Fixed-Point Clip Instructions
11701172
multiclass VPatTruncSatClipSDNode<VTypeInfo vti, VTypeInfo wti> {

llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,9 @@ def riscv_ctlz_vl : SDNode<"RISCVISD::CTLZ_VL", SDT_RISCVIntUnOp_VL>
111111
def riscv_cttz_vl : SDNode<"RISCVISD::CTTZ_VL", SDT_RISCVIntUnOp_VL>;
112112
def riscv_ctpop_vl : SDNode<"RISCVISD::CTPOP_VL", SDT_RISCVIntUnOp_VL>;
113113

114+
def riscv_avgfloors_vl : SDNode<"RISCVISD::AVGFLOORS_VL", SDT_RISCVIntBinOp_VL, [SDNPCommutative]>;
114115
def riscv_avgflooru_vl : SDNode<"RISCVISD::AVGFLOORU_VL", SDT_RISCVIntBinOp_VL, [SDNPCommutative]>;
116+
def riscv_avgceils_vl : SDNode<"RISCVISD::AVGCEILS_VL", SDT_RISCVIntBinOp_VL, [SDNPCommutative]>;
115117
def riscv_avgceilu_vl : SDNode<"RISCVISD::AVGCEILU_VL", SDT_RISCVIntBinOp_VL, [SDNPCommutative]>;
116118
def riscv_saddsat_vl : SDNode<"RISCVISD::SADDSAT_VL", SDT_RISCVIntBinOp_VL, [SDNPCommutative]>;
117119
def riscv_uaddsat_vl : SDNode<"RISCVISD::UADDSAT_VL", SDT_RISCVIntBinOp_VL, [SDNPCommutative]>;
@@ -2073,19 +2075,19 @@ multiclass VPatSlide1VL_VF<SDNode vop, string instruction_name> {
20732075
}
20742076
}
20752077

2076-
multiclass VPatAVGADDVL_VV_VX_RM<SDNode vop, int vxrm> {
2078+
multiclass VPatAVGADDVL_VV_VX_RM<SDNode vop, int vxrm, string suffix = ""> {
20772079
foreach vti = AllIntegerVectors in {
20782080
let Predicates = GetVTypePredicates<vti>.Predicates in {
20792081
def : Pat<(vop (vti.Vector vti.RegClass:$rs1),
20802082
(vti.Vector vti.RegClass:$rs2),
20812083
vti.RegClass:$merge, (vti.Mask V0), VLOpFrag),
2082-
(!cast<Instruction>("PseudoVAADDU_VV_"#vti.LMul.MX#"_MASK")
2084+
(!cast<Instruction>("PseudoVAADD"#suffix#"_VV_"#vti.LMul.MX#"_MASK")
20832085
vti.RegClass:$merge, vti.RegClass:$rs1, vti.RegClass:$rs2,
20842086
(vti.Mask V0), vxrm, GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
20852087
def : Pat<(vop (vti.Vector vti.RegClass:$rs1),
20862088
(vti.Vector (SplatPat (XLenVT GPR:$rs2))),
20872089
vti.RegClass:$merge, (vti.Mask V0), VLOpFrag),
2088-
(!cast<Instruction>("PseudoVAADDU_VX_"#vti.LMul.MX#"_MASK")
2090+
(!cast<Instruction>("PseudoVAADD"#suffix#"_VX_"#vti.LMul.MX#"_MASK")
20892091
vti.RegClass:$merge, vti.RegClass:$rs1, GPR:$rs2,
20902092
(vti.Mask V0), vxrm, GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
20912093
}
@@ -2369,8 +2371,10 @@ defm : VPatBinaryVL_VV_VX<riscv_ssubsat_vl, "PseudoVSSUB">;
23692371
defm : VPatBinaryVL_VV_VX<riscv_usubsat_vl, "PseudoVSSUBU">;
23702372

23712373
// 12.2. Vector Single-Width Averaging Add and Subtract
2372-
defm : VPatAVGADDVL_VV_VX_RM<riscv_avgflooru_vl, 0b10>;
2373-
defm : VPatAVGADDVL_VV_VX_RM<riscv_avgceilu_vl, 0b00>;
2374+
defm : VPatAVGADDVL_VV_VX_RM<riscv_avgfloors_vl, 0b10>;
2375+
defm : VPatAVGADDVL_VV_VX_RM<riscv_avgflooru_vl, 0b10, suffix="U">;
2376+
defm : VPatAVGADDVL_VV_VX_RM<riscv_avgceils_vl, 0b00>;
2377+
defm : VPatAVGADDVL_VV_VX_RM<riscv_avgceilu_vl, 0b00, suffix="U">;
23742378

23752379
// 12.5. Vector Narrowing Fixed-Point Clip Instructions
23762380
multiclass VPatTruncSatClipVL<VTypeInfo vti, VTypeInfo wti> {

llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vaaddu.ll

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,9 @@ define <8 x i8> @vaaddu_vx_v8i8_floor(<8 x i8> %x, i8 %y) {
3838
define <8 x i8> @vaaddu_vv_v8i8_floor_sexti16(<8 x i8> %x, <8 x i8> %y) {
3939
; CHECK-LABEL: vaaddu_vv_v8i8_floor_sexti16:
4040
; CHECK: # %bb.0:
41+
; CHECK-NEXT: csrwi vxrm, 2
4142
; CHECK-NEXT: vsetivli zero, 8, e8, mf2, ta, ma
42-
; CHECK-NEXT: vwadd.vv v10, v8, v9
43-
; CHECK-NEXT: vnsrl.wi v8, v10, 1
43+
; CHECK-NEXT: vaadd.vv v8, v8, v9
4444
; CHECK-NEXT: ret
4545
%xzv = sext <8 x i8> %x to <8 x i16>
4646
%yzv = sext <8 x i8> %y to <8 x i16>
@@ -248,12 +248,9 @@ define <8 x i8> @vaaddu_vx_v8i8_ceil(<8 x i8> %x, i8 %y) {
248248
define <8 x i8> @vaaddu_vv_v8i8_ceil_sexti16(<8 x i8> %x, <8 x i8> %y) {
249249
; CHECK-LABEL: vaaddu_vv_v8i8_ceil_sexti16:
250250
; CHECK: # %bb.0:
251+
; CHECK-NEXT: csrwi vxrm, 0
251252
; CHECK-NEXT: vsetivli zero, 8, e8, mf2, ta, ma
252-
; CHECK-NEXT: vwadd.vv v10, v8, v9
253-
; CHECK-NEXT: vsetvli zero, zero, e16, m1, ta, ma
254-
; CHECK-NEXT: vadd.vi v8, v10, 1
255-
; CHECK-NEXT: vsetvli zero, zero, e8, mf2, ta, ma
256-
; CHECK-NEXT: vnsrl.wi v8, v8, 1
253+
; CHECK-NEXT: vaadd.vv v8, v8, v9
257254
; CHECK-NEXT: ret
258255
%xzv = sext <8 x i8> %x to <8 x i16>
259256
%yzv = sext <8 x i8> %y to <8 x i16>

llvm/test/CodeGen/RISCV/rvv/vaaddu-sdnode.ll

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,9 @@ define <vscale x 8 x i8> @vaaddu_vx_nxv8i8_floor(<vscale x 8 x i8> %x, i8 %y) {
3737
define <vscale x 8 x i8> @vaaddu_vv_nxv8i8_floor_sexti16(<vscale x 8 x i8> %x, <vscale x 8 x i8> %y) {
3838
; CHECK-LABEL: vaaddu_vv_nxv8i8_floor_sexti16:
3939
; CHECK: # %bb.0:
40+
; CHECK-NEXT: csrwi vxrm, 2
4041
; CHECK-NEXT: vsetvli a0, zero, e8, m1, ta, ma
41-
; CHECK-NEXT: vwadd.vv v10, v8, v9
42-
; CHECK-NEXT: vnsrl.wi v8, v10, 1
42+
; CHECK-NEXT: vaadd.vv v8, v8, v9
4343
; CHECK-NEXT: ret
4444
%xzv = sext <vscale x 8 x i8> %x to <vscale x 8 x i16>
4545
%yzv = sext <vscale x 8 x i8> %y to <vscale x 8 x i16>
@@ -226,12 +226,9 @@ define <vscale x 8 x i8> @vaaddu_vx_nxv8i8_ceil(<vscale x 8 x i8> %x, i8 %y) {
226226
define <vscale x 8 x i8> @vaaddu_vv_nxv8i8_ceil_sexti16(<vscale x 8 x i8> %x, <vscale x 8 x i8> %y) {
227227
; CHECK-LABEL: vaaddu_vv_nxv8i8_ceil_sexti16:
228228
; CHECK: # %bb.0:
229+
; CHECK-NEXT: csrwi vxrm, 0
229230
; CHECK-NEXT: vsetvli a0, zero, e8, m1, ta, ma
230-
; CHECK-NEXT: vwadd.vv v10, v8, v9
231-
; CHECK-NEXT: vsetvli zero, zero, e16, m2, ta, ma
232-
; CHECK-NEXT: vadd.vi v10, v10, 1
233-
; CHECK-NEXT: vsetvli zero, zero, e8, m1, ta, ma
234-
; CHECK-NEXT: vnsrl.wi v8, v10, 1
231+
; CHECK-NEXT: vaadd.vv v8, v8, v9
235232
; CHECK-NEXT: ret
236233
%xzv = sext <vscale x 8 x i8> %x to <vscale x 8 x i16>
237234
%yzv = sext <vscale x 8 x i8> %y to <vscale x 8 x i16>

0 commit comments

Comments
 (0)