Skip to content

Commit bc0fea0

Browse files
committed
[SDAG] Allow scalable vectors in ComputeKnownBits
This is the SelectionDAG equivalent of D136470, and is thus an alternate patch to D128159. The basic idea here is that we track a single lane for scalable vectors which corresponds to an unknown number of lanes at runtime. This is enough for us to perform lane-wise reasoning on many arithmetic operations. This patch also includes an implementation for SPLAT_VECTOR since, without it, the lane-wise reasoning has no base case. The original patch which inspired this (D128159) also included STEP_VECTOR. I plan to do that as a separate patch. Differential Revision: https://reviews.llvm.org/D137140
1 parent 2656fb3 commit bc0fea0

File tree

4 files changed

+64
-66
lines changed

4 files changed

+64
-66
lines changed

llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp

Lines changed: 43 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2910,14 +2910,10 @@ const APInt *SelectionDAG::getValidMaximumShiftAmountConstant(
29102910
KnownBits SelectionDAG::computeKnownBits(SDValue Op, unsigned Depth) const {
29112911
EVT VT = Op.getValueType();
29122912

2913-
// TOOD: Until we have a plan for how to represent demanded elements for
2914-
// scalable vectors, we can just bail out for now.
2915-
if (Op.getValueType().isScalableVector()) {
2916-
unsigned BitWidth = Op.getScalarValueSizeInBits();
2917-
return KnownBits(BitWidth);
2918-
}
2919-
2920-
APInt DemandedElts = VT.isVector()
2913+
// Since the number of lanes in a scalable vector is unknown at compile time,
2914+
// we track one bit which is implicitly broadcast to all lanes. This means
2915+
// that all lanes in a scalable vector are considered demanded.
2916+
APInt DemandedElts = VT.isFixedLengthVector()
29212917
? APInt::getAllOnes(VT.getVectorNumElements())
29222918
: APInt(1, 1);
29232919
return computeKnownBits(Op, DemandedElts, Depth);
@@ -2932,11 +2928,6 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
29322928

29332929
KnownBits Known(BitWidth); // Don't know anything.
29342930

2935-
// TOOD: Until we have a plan for how to represent demanded elements for
2936-
// scalable vectors, we can just bail out for now.
2937-
if (Op.getValueType().isScalableVector())
2938-
return Known;
2939-
29402931
if (auto *C = dyn_cast<ConstantSDNode>(Op)) {
29412932
// We know all of the bits for a constant!
29422933
return KnownBits::makeConstant(C->getAPIntValue());
@@ -2951,7 +2942,7 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
29512942

29522943
KnownBits Known2;
29532944
unsigned NumElts = DemandedElts.getBitWidth();
2954-
assert((!Op.getValueType().isVector() ||
2945+
assert((!Op.getValueType().isFixedLengthVector() ||
29552946
NumElts == Op.getValueType().getVectorNumElements()) &&
29562947
"Unexpected vector size");
29572948

@@ -2963,7 +2954,18 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
29632954
case ISD::MERGE_VALUES:
29642955
return computeKnownBits(Op.getOperand(Op.getResNo()), DemandedElts,
29652956
Depth + 1);
2957+
case ISD::SPLAT_VECTOR: {
2958+
SDValue SrcOp = Op.getOperand(0);
2959+
Known = computeKnownBits(SrcOp, Depth + 1);
2960+
if (SrcOp.getValueSizeInBits() != BitWidth) {
2961+
assert(SrcOp.getValueSizeInBits() > BitWidth &&
2962+
"Expected SPLAT_VECTOR implicit truncation");
2963+
Known = Known.trunc(BitWidth);
2964+
}
2965+
break;
2966+
}
29662967
case ISD::BUILD_VECTOR:
2968+
assert(!Op.getValueType().isScalableVector());
29672969
// Collect the known bits that are shared by every demanded vector element.
29682970
Known.Zero.setAllBits(); Known.One.setAllBits();
29692971
for (unsigned i = 0, e = Op.getNumOperands(); i != e; ++i) {
@@ -2989,6 +2991,7 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
29892991
}
29902992
break;
29912993
case ISD::VECTOR_SHUFFLE: {
2994+
assert(!Op.getValueType().isScalableVector());
29922995
// Collect the known bits that are shared by every vector element referenced
29932996
// by the shuffle.
29942997
APInt DemandedLHS, DemandedRHS;
@@ -3016,6 +3019,8 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
30163019
break;
30173020
}
30183021
case ISD::CONCAT_VECTORS: {
3022+
if (Op.getValueType().isScalableVector())
3023+
break;
30193024
// Split DemandedElts and test each of the demanded subvectors.
30203025
Known.Zero.setAllBits(); Known.One.setAllBits();
30213026
EVT SubVectorVT = Op.getOperand(0).getValueType();
@@ -3036,6 +3041,8 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
30363041
break;
30373042
}
30383043
case ISD::INSERT_SUBVECTOR: {
3044+
if (Op.getValueType().isScalableVector())
3045+
break;
30393046
// Demand any elements from the subvector and the remainder from the src its
30403047
// inserted into.
30413048
SDValue Src = Op.getOperand(0);
@@ -3063,7 +3070,7 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
30633070
// Offset the demanded elts by the subvector index.
30643071
SDValue Src = Op.getOperand(0);
30653072
// Bail until we can represent demanded elements for scalable vectors.
3066-
if (Src.getValueType().isScalableVector())
3073+
if (Op.getValueType().isScalableVector() || Src.getValueType().isScalableVector())
30673074
break;
30683075
uint64_t Idx = Op.getConstantOperandVal(1);
30693076
unsigned NumSrcElts = Src.getValueType().getVectorNumElements();
@@ -3072,6 +3079,8 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
30723079
break;
30733080
}
30743081
case ISD::SCALAR_TO_VECTOR: {
3082+
if (Op.getValueType().isScalableVector())
3083+
break;
30753084
// We know about scalar_to_vector as much as we know about it source,
30763085
// which becomes the first element of otherwise unknown vector.
30773086
if (DemandedElts != 1)
@@ -3085,6 +3094,9 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
30853094
break;
30863095
}
30873096
case ISD::BITCAST: {
3097+
if (Op.getValueType().isScalableVector())
3098+
break;
3099+
30883100
SDValue N0 = Op.getOperand(0);
30893101
EVT SubVT = N0.getValueType();
30903102
unsigned SubBitWidth = SubVT.getScalarSizeInBits();
@@ -3406,7 +3418,8 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
34063418
if (ISD::isNON_EXTLoad(LD) && Cst) {
34073419
// Determine any common known bits from the loaded constant pool value.
34083420
Type *CstTy = Cst->getType();
3409-
if ((NumElts * BitWidth) == CstTy->getPrimitiveSizeInBits()) {
3421+
if ((NumElts * BitWidth) == CstTy->getPrimitiveSizeInBits() &&
3422+
!Op.getValueType().isScalableVector()) {
34103423
// If its a vector splat, then we can (quickly) reuse the scalar path.
34113424
// NOTE: We assume all elements match and none are UNDEF.
34123425
if (CstTy->isVectorTy()) {
@@ -3480,6 +3493,8 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
34803493
break;
34813494
}
34823495
case ISD::ZERO_EXTEND_VECTOR_INREG: {
3496+
if (Op.getValueType().isScalableVector())
3497+
break;
34833498
EVT InVT = Op.getOperand(0).getValueType();
34843499
APInt InDemandedElts = DemandedElts.zext(InVT.getVectorNumElements());
34853500
Known = computeKnownBits(Op.getOperand(0), InDemandedElts, Depth + 1);
@@ -3492,6 +3507,8 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
34923507
break;
34933508
}
34943509
case ISD::SIGN_EXTEND_VECTOR_INREG: {
3510+
if (Op.getValueType().isScalableVector())
3511+
break;
34953512
EVT InVT = Op.getOperand(0).getValueType();
34963513
APInt InDemandedElts = DemandedElts.zext(InVT.getVectorNumElements());
34973514
Known = computeKnownBits(Op.getOperand(0), InDemandedElts, Depth + 1);
@@ -3508,6 +3525,8 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
35083525
break;
35093526
}
35103527
case ISD::ANY_EXTEND_VECTOR_INREG: {
3528+
if (Op.getValueType().isScalableVector())
3529+
break;
35113530
EVT InVT = Op.getOperand(0).getValueType();
35123531
APInt InDemandedElts = DemandedElts.zext(InVT.getVectorNumElements());
35133532
Known = computeKnownBits(Op.getOperand(0), InDemandedElts, Depth + 1);
@@ -3673,6 +3692,9 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
36733692
break;
36743693
}
36753694
case ISD::INSERT_VECTOR_ELT: {
3695+
if (Op.getValueType().isScalableVector())
3696+
break;
3697+
36763698
// If we know the element index, split the demand between the
36773699
// source vector and the inserted element, otherwise assume we need
36783700
// the original demanded vector elements and the value.
@@ -3839,6 +3861,11 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
38393861
case ISD::INTRINSIC_WO_CHAIN:
38403862
case ISD::INTRINSIC_W_CHAIN:
38413863
case ISD::INTRINSIC_VOID:
3864+
// TODO: Probably okay to remove after audit; here to reduce change size
3865+
// in initial enablement patch for scalable vectors
3866+
if (Op.getValueType().isScalableVector())
3867+
break;
3868+
38423869
// Allow the target to implement this method for its nodes.
38433870
TLI->computeKnownBitsForTargetNode(Op, Known, DemandedElts, *this, Depth);
38443871
break;

llvm/test/CodeGen/AArch64/sve-intrinsics-index.ll

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,8 @@ define <vscale x 2 x i64> @index_ii_range() {
5555
define <vscale x 8 x i16> @index_ii_range_combine(i16 %a) {
5656
; CHECK-LABEL: index_ii_range_combine:
5757
; CHECK: // %bb.0:
58-
; CHECK-NEXT: index z0.h, #2, #8
58+
; CHECK-NEXT: index z0.h, #0, #8
59+
; CHECK-NEXT: orr z0.h, z0.h, #0x2
5960
; CHECK-NEXT: ret
6061
%val = insertelement <vscale x 8 x i16> poison, i16 2, i32 0
6162
%val1 = shufflevector <vscale x 8 x i16> %val, <vscale x 8 x i16> poison, <vscale x 8 x i32> zeroinitializer

llvm/test/CodeGen/AArch64/sve-intrinsics-perm-select.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -574,7 +574,7 @@ define <vscale x 2 x i64> @dupq_i64_range(<vscale x 2 x i64> %a) {
574574
; CHECK: // %bb.0:
575575
; CHECK-NEXT: index z1.d, #0, #1
576576
; CHECK-NEXT: and z1.d, z1.d, #0x1
577-
; CHECK-NEXT: add z1.d, z1.d, #8 // =0x8
577+
; CHECK-NEXT: orr z1.d, z1.d, #0x8
578578
; CHECK-NEXT: tbl z0.d, { z0.d }, z1.d
579579
; CHECK-NEXT: ret
580580
%out = call <vscale x 2 x i64> @llvm.aarch64.sve.dupq.lane.nxv2i64(<vscale x 2 x i64> %a, i64 4)

llvm/test/CodeGen/AArch64/sve-umulo-sdnode.ll

Lines changed: 18 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,10 @@ define <vscale x 2 x i8> @umulo_nxv2i8(<vscale x 2 x i8> %x, <vscale x 2 x i8> %
99
; CHECK-NEXT: ptrue p0.d
1010
; CHECK-NEXT: and z1.d, z1.d, #0xff
1111
; CHECK-NEXT: and z0.d, z0.d, #0xff
12-
; CHECK-NEXT: movprfx z2, z0
13-
; CHECK-NEXT: mul z2.d, p0/m, z2.d, z1.d
14-
; CHECK-NEXT: umulh z0.d, p0/m, z0.d, z1.d
15-
; CHECK-NEXT: lsr z1.d, z2.d, #8
16-
; CHECK-NEXT: cmpne p1.d, p0/z, z0.d, #0
12+
; CHECK-NEXT: mul z0.d, p0/m, z0.d, z1.d
13+
; CHECK-NEXT: lsr z1.d, z0.d, #8
1714
; CHECK-NEXT: cmpne p0.d, p0/z, z1.d, #0
18-
; CHECK-NEXT: sel p0.b, p0, p0.b, p1.b
19-
; CHECK-NEXT: mov z2.d, p0/m, #0 // =0x0
20-
; CHECK-NEXT: mov z0.d, z2.d
15+
; CHECK-NEXT: mov z0.d, p0/m, #0 // =0x0
2116
; CHECK-NEXT: ret
2217
%a = call { <vscale x 2 x i8>, <vscale x 2 x i1> } @llvm.umul.with.overflow.nxv2i8(<vscale x 2 x i8> %x, <vscale x 2 x i8> %y)
2318
%b = extractvalue { <vscale x 2 x i8>, <vscale x 2 x i1> } %a, 0
@@ -34,15 +29,10 @@ define <vscale x 4 x i8> @umulo_nxv4i8(<vscale x 4 x i8> %x, <vscale x 4 x i8> %
3429
; CHECK-NEXT: ptrue p0.s
3530
; CHECK-NEXT: and z1.s, z1.s, #0xff
3631
; CHECK-NEXT: and z0.s, z0.s, #0xff
37-
; CHECK-NEXT: movprfx z2, z0
38-
; CHECK-NEXT: mul z2.s, p0/m, z2.s, z1.s
39-
; CHECK-NEXT: umulh z0.s, p0/m, z0.s, z1.s
40-
; CHECK-NEXT: lsr z1.s, z2.s, #8
41-
; CHECK-NEXT: cmpne p1.s, p0/z, z0.s, #0
32+
; CHECK-NEXT: mul z0.s, p0/m, z0.s, z1.s
33+
; CHECK-NEXT: lsr z1.s, z0.s, #8
4234
; CHECK-NEXT: cmpne p0.s, p0/z, z1.s, #0
43-
; CHECK-NEXT: sel p0.b, p0, p0.b, p1.b
44-
; CHECK-NEXT: mov z2.s, p0/m, #0 // =0x0
45-
; CHECK-NEXT: mov z0.d, z2.d
35+
; CHECK-NEXT: mov z0.s, p0/m, #0 // =0x0
4636
; CHECK-NEXT: ret
4737
%a = call { <vscale x 4 x i8>, <vscale x 4 x i1> } @llvm.umul.with.overflow.nxv4i8(<vscale x 4 x i8> %x, <vscale x 4 x i8> %y)
4838
%b = extractvalue { <vscale x 4 x i8>, <vscale x 4 x i1> } %a, 0
@@ -59,15 +49,10 @@ define <vscale x 8 x i8> @umulo_nxv8i8(<vscale x 8 x i8> %x, <vscale x 8 x i8> %
5949
; CHECK-NEXT: ptrue p0.h
6050
; CHECK-NEXT: and z1.h, z1.h, #0xff
6151
; CHECK-NEXT: and z0.h, z0.h, #0xff
62-
; CHECK-NEXT: movprfx z2, z0
63-
; CHECK-NEXT: mul z2.h, p0/m, z2.h, z1.h
64-
; CHECK-NEXT: umulh z0.h, p0/m, z0.h, z1.h
65-
; CHECK-NEXT: lsr z1.h, z2.h, #8
66-
; CHECK-NEXT: cmpne p1.h, p0/z, z0.h, #0
52+
; CHECK-NEXT: mul z0.h, p0/m, z0.h, z1.h
53+
; CHECK-NEXT: lsr z1.h, z0.h, #8
6754
; CHECK-NEXT: cmpne p0.h, p0/z, z1.h, #0
68-
; CHECK-NEXT: sel p0.b, p0, p0.b, p1.b
69-
; CHECK-NEXT: mov z2.h, p0/m, #0 // =0x0
70-
; CHECK-NEXT: mov z0.d, z2.d
55+
; CHECK-NEXT: mov z0.h, p0/m, #0 // =0x0
7156
; CHECK-NEXT: ret
7257
%a = call { <vscale x 8 x i8>, <vscale x 8 x i1> } @llvm.umul.with.overflow.nxv8i8(<vscale x 8 x i8> %x, <vscale x 8 x i8> %y)
7358
%b = extractvalue { <vscale x 8 x i8>, <vscale x 8 x i1> } %a, 0
@@ -164,15 +149,10 @@ define <vscale x 2 x i16> @umulo_nxv2i16(<vscale x 2 x i16> %x, <vscale x 2 x i1
164149
; CHECK-NEXT: ptrue p0.d
165150
; CHECK-NEXT: and z1.d, z1.d, #0xffff
166151
; CHECK-NEXT: and z0.d, z0.d, #0xffff
167-
; CHECK-NEXT: movprfx z2, z0
168-
; CHECK-NEXT: mul z2.d, p0/m, z2.d, z1.d
169-
; CHECK-NEXT: umulh z0.d, p0/m, z0.d, z1.d
170-
; CHECK-NEXT: lsr z1.d, z2.d, #16
171-
; CHECK-NEXT: cmpne p1.d, p0/z, z0.d, #0
152+
; CHECK-NEXT: mul z0.d, p0/m, z0.d, z1.d
153+
; CHECK-NEXT: lsr z1.d, z0.d, #16
172154
; CHECK-NEXT: cmpne p0.d, p0/z, z1.d, #0
173-
; CHECK-NEXT: sel p0.b, p0, p0.b, p1.b
174-
; CHECK-NEXT: mov z2.d, p0/m, #0 // =0x0
175-
; CHECK-NEXT: mov z0.d, z2.d
155+
; CHECK-NEXT: mov z0.d, p0/m, #0 // =0x0
176156
; CHECK-NEXT: ret
177157
%a = call { <vscale x 2 x i16>, <vscale x 2 x i1> } @llvm.umul.with.overflow.nxv2i16(<vscale x 2 x i16> %x, <vscale x 2 x i16> %y)
178158
%b = extractvalue { <vscale x 2 x i16>, <vscale x 2 x i1> } %a, 0
@@ -189,15 +169,10 @@ define <vscale x 4 x i16> @umulo_nxv4i16(<vscale x 4 x i16> %x, <vscale x 4 x i1
189169
; CHECK-NEXT: ptrue p0.s
190170
; CHECK-NEXT: and z1.s, z1.s, #0xffff
191171
; CHECK-NEXT: and z0.s, z0.s, #0xffff
192-
; CHECK-NEXT: movprfx z2, z0
193-
; CHECK-NEXT: mul z2.s, p0/m, z2.s, z1.s
194-
; CHECK-NEXT: umulh z0.s, p0/m, z0.s, z1.s
195-
; CHECK-NEXT: lsr z1.s, z2.s, #16
196-
; CHECK-NEXT: cmpne p1.s, p0/z, z0.s, #0
172+
; CHECK-NEXT: mul z0.s, p0/m, z0.s, z1.s
173+
; CHECK-NEXT: lsr z1.s, z0.s, #16
197174
; CHECK-NEXT: cmpne p0.s, p0/z, z1.s, #0
198-
; CHECK-NEXT: sel p0.b, p0, p0.b, p1.b
199-
; CHECK-NEXT: mov z2.s, p0/m, #0 // =0x0
200-
; CHECK-NEXT: mov z0.d, z2.d
175+
; CHECK-NEXT: mov z0.s, p0/m, #0 // =0x0
201176
; CHECK-NEXT: ret
202177
%a = call { <vscale x 4 x i16>, <vscale x 4 x i1> } @llvm.umul.with.overflow.nxv4i16(<vscale x 4 x i16> %x, <vscale x 4 x i16> %y)
203178
%b = extractvalue { <vscale x 4 x i16>, <vscale x 4 x i1> } %a, 0
@@ -294,15 +269,10 @@ define <vscale x 2 x i32> @umulo_nxv2i32(<vscale x 2 x i32> %x, <vscale x 2 x i3
294269
; CHECK-NEXT: ptrue p0.d
295270
; CHECK-NEXT: and z1.d, z1.d, #0xffffffff
296271
; CHECK-NEXT: and z0.d, z0.d, #0xffffffff
297-
; CHECK-NEXT: movprfx z2, z0
298-
; CHECK-NEXT: mul z2.d, p0/m, z2.d, z1.d
299-
; CHECK-NEXT: umulh z0.d, p0/m, z0.d, z1.d
300-
; CHECK-NEXT: lsr z1.d, z2.d, #32
301-
; CHECK-NEXT: cmpne p1.d, p0/z, z0.d, #0
272+
; CHECK-NEXT: mul z0.d, p0/m, z0.d, z1.d
273+
; CHECK-NEXT: lsr z1.d, z0.d, #32
302274
; CHECK-NEXT: cmpne p0.d, p0/z, z1.d, #0
303-
; CHECK-NEXT: sel p0.b, p0, p0.b, p1.b
304-
; CHECK-NEXT: mov z2.d, p0/m, #0 // =0x0
305-
; CHECK-NEXT: mov z0.d, z2.d
275+
; CHECK-NEXT: mov z0.d, p0/m, #0 // =0x0
306276
; CHECK-NEXT: ret
307277
%a = call { <vscale x 2 x i32>, <vscale x 2 x i1> } @llvm.umul.with.overflow.nxv2i32(<vscale x 2 x i32> %x, <vscale x 2 x i32> %y)
308278
%b = extractvalue { <vscale x 2 x i32>, <vscale x 2 x i1> } %a, 0

0 commit comments

Comments
 (0)