Skip to content

Commit 06e295a

Browse files
committed
[WIP][DAG] Add legalization handling for ABDS/ABDU
Still WIP, but posting early to give other teams visibility. ABD patterns are now always matched pre-legalization, and TargetLowering::expandABD is used to expand them again during legalization.
1 parent 7c137f7 commit 06e295a

20 files changed

+1198
-1038
lines changed

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4140,13 +4140,13 @@ SDValue DAGCombiner::visitSUB(SDNode *N) {
41404140
}
41414141

41424142
// smax(a,b) - smin(a,b) --> abds(a,b)
4143-
if (hasOperation(ISD::ABDS, VT) &&
4143+
if ((!LegalOperations || hasOperation(ISD::ABDS, VT)) &&
41444144
sd_match(N0, m_SMax(m_Value(A), m_Value(B))) &&
41454145
sd_match(N1, m_SMin(m_Specific(A), m_Specific(B))))
41464146
return DAG.getNode(ISD::ABDS, DL, VT, A, B);
41474147

41484148
// umax(a,b) - umin(a,b) --> abdu(a,b)
4149-
if (hasOperation(ISD::ABDU, VT) &&
4149+
if ((!LegalOperations || hasOperation(ISD::ABDU, VT)) &&
41504150
sd_match(N0, m_UMax(m_Value(A), m_Value(B))) &&
41514151
sd_match(N1, m_UMin(m_Specific(A), m_Specific(B))))
41524152
return DAG.getNode(ISD::ABDU, DL, VT, A, B);
@@ -10914,7 +10914,8 @@ SDValue DAGCombiner::foldABSToABD(SDNode *N, const SDLoc &DL) {
1091410914
(Opc0 != ISD::ZERO_EXTEND && Opc0 != ISD::SIGN_EXTEND &&
1091510915
Opc0 != ISD::SIGN_EXTEND_INREG)) {
1091610916
// fold (abs (sub nsw x, y)) -> abds(x, y)
10917-
if (AbsOp1->getFlags().hasNoSignedWrap() && hasOperation(ISD::ABDS, VT) &&
10917+
if (AbsOp1->getFlags().hasNoSignedWrap() &&
10918+
(!LegalOperations || hasOperation(ISD::ABDS, VT)) &&
1091810919
TLI.preferABDSToABSWithNSW(VT)) {
1091910920
SDValue ABD = DAG.getNode(ISD::ABDS, DL, VT, Op0, Op1);
1092010921
return DAG.getZExtOrTrunc(ABD, DL, SrcVT);
@@ -10936,7 +10937,8 @@ SDValue DAGCombiner::foldABSToABD(SDNode *N, const SDLoc &DL) {
1093610937
// fold abs(zext(x) - zext(y)) -> zext(abdu(x, y))
1093710938
EVT MaxVT = VT0.bitsGT(VT1) ? VT0 : VT1;
1093810939
if ((VT0 == MaxVT || Op0->hasOneUse()) &&
10939-
(VT1 == MaxVT || Op1->hasOneUse()) && hasOperation(ABDOpcode, MaxVT)) {
10940+
(VT1 == MaxVT || Op1->hasOneUse()) &&
10941+
(!LegalOperations || hasOperation(ABDOpcode, MaxVT))) {
1094010942
SDValue ABD = DAG.getNode(ABDOpcode, DL, MaxVT,
1094110943
DAG.getNode(ISD::TRUNCATE, DL, MaxVT, Op0),
1094210944
DAG.getNode(ISD::TRUNCATE, DL, MaxVT, Op1));
@@ -10946,7 +10948,7 @@ SDValue DAGCombiner::foldABSToABD(SDNode *N, const SDLoc &DL) {
1094610948

1094710949
// fold abs(sext(x) - sext(y)) -> abds(sext(x), sext(y))
1094810950
// fold abs(zext(x) - zext(y)) -> abdu(zext(x), zext(y))
10949-
if (hasOperation(ABDOpcode, VT)) {
10951+
if (!LegalOperations || hasOperation(ABDOpcode, VT)) {
1095010952
SDValue ABD = DAG.getNode(ABDOpcode, DL, VT, Op0, Op1);
1095110953
return DAG.getZExtOrTrunc(ABD, DL, SrcVT);
1095210954
}
@@ -12315,7 +12317,7 @@ SDValue DAGCombiner::visitVSELECT(SDNode *N) {
1231512317
N1.getOperand(1) == N2.getOperand(0)) {
1231612318
bool IsSigned = isSignedIntSetCC(CC);
1231712319
unsigned ABDOpc = IsSigned ? ISD::ABDS : ISD::ABDU;
12318-
if (hasOperation(ABDOpc, VT)) {
12320+
if (!LegalOperations || hasOperation(ABDOpc, VT)) {
1231912321
switch (CC) {
1232012322
case ISD::SETGT:
1232112323
case ISD::SETGE:

llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,13 +188,15 @@ void DAGTypeLegalizer::PromoteIntegerResult(SDNode *N, unsigned ResNo) {
188188
case ISD::VP_SUB:
189189
case ISD::VP_MUL: Res = PromoteIntRes_SimpleIntBinOp(N); break;
190190

191+
case ISD::ABDS:
191192
case ISD::VP_SMIN:
192193
case ISD::VP_SMAX:
193194
case ISD::SDIV:
194195
case ISD::SREM:
195196
case ISD::VP_SDIV:
196197
case ISD::VP_SREM: Res = PromoteIntRes_SExtIntBinOp(N); break;
197198

199+
case ISD::ABDU:
198200
case ISD::VP_UMIN:
199201
case ISD::VP_UMAX:
200202
case ISD::UDIV:
@@ -2660,6 +2662,8 @@ void DAGTypeLegalizer::ExpandIntegerResult(SDNode *N, unsigned ResNo) {
26602662
case ISD::PARITY: ExpandIntRes_PARITY(N, Lo, Hi); break;
26612663
case ISD::Constant: ExpandIntRes_Constant(N, Lo, Hi); break;
26622664
case ISD::ABS: ExpandIntRes_ABS(N, Lo, Hi); break;
2665+
case ISD::ABDS:
2666+
case ISD::ABDU: ExpandIntRes_ABD(N, Lo, Hi); break;
26632667
case ISD::CTLZ_ZERO_UNDEF:
26642668
case ISD::CTLZ: ExpandIntRes_CTLZ(N, Lo, Hi); break;
26652669
case ISD::CTPOP: ExpandIntRes_CTPOP(N, Lo, Hi); break;
@@ -3706,6 +3710,11 @@ void DAGTypeLegalizer::ExpandIntRes_CTLZ(SDNode *N,
37063710
Hi = DAG.getConstant(0, dl, NVT);
37073711
}
37083712

3713+
void DAGTypeLegalizer::ExpandIntRes_ABD(SDNode *N, SDValue &Lo, SDValue &Hi) {
3714+
SDValue Result = TLI.expandABD(N, DAG);
3715+
SplitInteger(Result, Lo, Hi);
3716+
}
3717+
37093718
void DAGTypeLegalizer::ExpandIntRes_CTPOP(SDNode *N,
37103719
SDValue &Lo, SDValue &Hi) {
37113720
SDLoc dl(N);

llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -424,6 +424,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
424424
void ExpandIntRes_AssertZext (SDNode *N, SDValue &Lo, SDValue &Hi);
425425
void ExpandIntRes_Constant (SDNode *N, SDValue &Lo, SDValue &Hi);
426426
void ExpandIntRes_ABS (SDNode *N, SDValue &Lo, SDValue &Hi);
427+
void ExpandIntRes_ABD (SDNode *N, SDValue &Lo, SDValue &Hi);
427428
void ExpandIntRes_CTLZ (SDNode *N, SDValue &Lo, SDValue &Hi);
428429
void ExpandIntRes_CTPOP (SDNode *N, SDValue &Lo, SDValue &Hi);
429430
void ExpandIntRes_CTTZ (SDNode *N, SDValue &Lo, SDValue &Hi);

llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,8 @@ void DAGTypeLegalizer::ScalarizeVectorResult(SDNode *N, unsigned ResNo) {
136136
case ISD::FMINIMUM:
137137
case ISD::FMAXIMUM:
138138
case ISD::FLDEXP:
139+
case ISD::ABDS:
140+
case ISD::ABDU:
139141
case ISD::SMIN:
140142
case ISD::SMAX:
141143
case ISD::UMIN:
@@ -1171,6 +1173,8 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
11711173
case ISD::MUL: case ISD::VP_MUL:
11721174
case ISD::MULHS:
11731175
case ISD::MULHU:
1176+
case ISD::ABDS:
1177+
case ISD::ABDU:
11741178
case ISD::FADD: case ISD::VP_FADD:
11751179
case ISD::FSUB: case ISD::VP_FSUB:
11761180
case ISD::FMUL: case ISD::VP_FMUL:
@@ -4231,6 +4235,8 @@ void DAGTypeLegalizer::WidenVectorResult(SDNode *N, unsigned ResNo) {
42314235
case ISD::MUL: case ISD::VP_MUL:
42324236
case ISD::MULHS:
42334237
case ISD::MULHU:
4238+
case ISD::ABDS:
4239+
case ISD::ABDU:
42344240
case ISD::OR: case ISD::VP_OR:
42354241
case ISD::SUB: case ISD::VP_SUB:
42364242
case ISD::XOR: case ISD::VP_XOR:

llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6923,6 +6923,8 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
69236923
assert(VT.isInteger() && "This operator does not apply to FP types!");
69246924
assert(N1.getValueType() == N2.getValueType() &&
69256925
N1.getValueType() == VT && "Binary operator types must match!");
6926+
if (VT.isVector() && VT.getVectorElementType() == MVT::i1)
6927+
return getNode(ISD::XOR, DL, VT, N1, N2);
69266928
break;
69276929
case ISD::SMIN:
69286930
case ISD::UMAX:

llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9228,6 +9228,15 @@ SDValue TargetLowering::expandABD(SDNode *N, SelectionDAG &DAG) const {
92289228
DAG.getNode(ISD::USUBSAT, dl, VT, LHS, RHS),
92299229
DAG.getNode(ISD::USUBSAT, dl, VT, RHS, LHS));
92309230

9231+
// If the subtract doesn't overflow then just use abs(sub())
9232+
// NOTE: don't use frozen operands for value tracking.
9233+
if (DAG.willNotOverflowSub(IsSigned, N->getOperand(0), N->getOperand(1)))
9234+
return DAG.getNode(ISD::ABS, dl, VT,
9235+
DAG.getNode(ISD::SUB, dl, VT, LHS, RHS));
9236+
if (DAG.willNotOverflowSub(IsSigned, N->getOperand(1), N->getOperand(0)))
9237+
return DAG.getNode(ISD::ABS, dl, VT,
9238+
DAG.getNode(ISD::SUB, dl, VT, RHS, LHS));
9239+
92319240
EVT CCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
92329241
ISD::CondCode CC = IsSigned ? ISD::CondCode::SETGT : ISD::CondCode::SETUGT;
92339242
SDValue Cmp = DAG.getSetCC(dl, CCVT, LHS, RHS, CC);
@@ -9241,6 +9250,11 @@ SDValue TargetLowering::expandABD(SDNode *N, SelectionDAG &DAG) const {
92419250
return DAG.getNode(ISD::SUB, dl, VT, Cmp, Xor);
92429251
}
92439252

9253+
// FIXME: Should really try to split the vector in case it's legal on a
9254+
// subvector.
9255+
if (VT.isVector() && !isOperationLegalOrCustom(ISD::VSELECT, VT))
9256+
return DAG.UnrollVectorOp(N);
9257+
92449258
// abds(lhs, rhs) -> select(sgt(lhs,rhs), sub(lhs,rhs), sub(rhs,lhs))
92459259
// abdu(lhs, rhs) -> select(ugt(lhs,rhs), sub(lhs,rhs), sub(rhs,lhs))
92469260
return DAG.getSelect(dl, VT, Cmp, DAG.getNode(ISD::SUB, dl, VT, LHS, RHS),

llvm/test/CodeGen/AArch64/arm64-csel.ll

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,9 @@ define i32@foo4(i32 %a) nounwind ssp {
6464
define i32@foo5(i32 %a, i32 %b) nounwind ssp {
6565
; CHECK-LABEL: foo5:
6666
; CHECK: // %bb.0: // %entry
67-
; CHECK-NEXT: subs w8, w0, w1
68-
; CHECK-NEXT: cneg w0, w8, mi
67+
; CHECK-NEXT: sub w8, w1, w0
68+
; CHECK-NEXT: subs w9, w0, w1
69+
; CHECK-NEXT: csel w0, w9, w8, gt
6970
; CHECK-NEXT: ret
7071
entry:
7172
%sub = sub nsw i32 %a, %b
@@ -97,12 +98,13 @@ l.else:
9798
define i32 @foo7(i32 %a, i32 %b) nounwind {
9899
; CHECK-LABEL: foo7:
99100
; CHECK: // %bb.0: // %entry
100-
; CHECK-NEXT: subs w8, w0, w1
101-
; CHECK-NEXT: cneg w9, w8, mi
102-
; CHECK-NEXT: cmn w8, #1
103-
; CHECK-NEXT: csel w10, w9, w0, lt
104-
; CHECK-NEXT: cmp w8, #0
105-
; CHECK-NEXT: csel w0, w10, w9, ge
101+
; CHECK-NEXT: sub w8, w1, w0
102+
; CHECK-NEXT: subs w9, w0, w1
103+
; CHECK-NEXT: csel w8, w9, w8, gt
104+
; CHECK-NEXT: cmn w9, #1
105+
; CHECK-NEXT: csel w10, w8, w0, lt
106+
; CHECK-NEXT: cmp w9, #0
107+
; CHECK-NEXT: csel w0, w10, w8, ge
106108
; CHECK-NEXT: ret
107109
entry:
108110
%sub = sub nsw i32 %a, %b

llvm/test/CodeGen/AArch64/arm64-vabs.ll

Lines changed: 6 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1799,29 +1799,16 @@ define <2 x i128> @uabd_i64(<2 x i64> %a, <2 x i64> %b) {
17991799
; CHECK: // %bb.0:
18001800
; CHECK-NEXT: mov.d x8, v0[1]
18011801
; CHECK-NEXT: mov.d x9, v1[1]
1802+
; CHECK-NEXT: mov x1, xzr
18021803
; CHECK-NEXT: fmov x10, d0
18031804
; CHECK-NEXT: fmov x11, d1
1804-
; CHECK-NEXT: asr x12, x10, #63
1805-
; CHECK-NEXT: asr x13, x11, #63
1805+
; CHECK-NEXT: mov x3, xzr
1806+
; CHECK-NEXT: sub x12, x11, x10
18061807
; CHECK-NEXT: subs x10, x10, x11
1807-
; CHECK-NEXT: asr x11, x8, #63
1808-
; CHECK-NEXT: asr x14, x9, #63
1809-
; CHECK-NEXT: sbc x12, x12, x13
1808+
; CHECK-NEXT: csel x0, x10, x12, gt
1809+
; CHECK-NEXT: sub x10, x9, x8
18101810
; CHECK-NEXT: subs x8, x8, x9
1811-
; CHECK-NEXT: sbc x9, x11, x14
1812-
; CHECK-NEXT: asr x13, x12, #63
1813-
; CHECK-NEXT: asr x11, x9, #63
1814-
; CHECK-NEXT: eor x10, x10, x13
1815-
; CHECK-NEXT: eor x8, x8, x11
1816-
; CHECK-NEXT: eor x9, x9, x11
1817-
; CHECK-NEXT: subs x2, x8, x11
1818-
; CHECK-NEXT: eor x8, x12, x13
1819-
; CHECK-NEXT: sbc x3, x9, x11
1820-
; CHECK-NEXT: subs x9, x10, x13
1821-
; CHECK-NEXT: fmov d0, x9
1822-
; CHECK-NEXT: sbc x1, x8, x13
1823-
; CHECK-NEXT: mov.d v0[1], x1
1824-
; CHECK-NEXT: fmov x0, d0
1811+
; CHECK-NEXT: csel x2, x8, x10, gt
18251812
; CHECK-NEXT: ret
18261813
%aext = sext <2 x i64> %a to <2 x i128>
18271814
%bext = sext <2 x i64> %b to <2 x i128>

0 commit comments

Comments
 (0)