Skip to content

Commit 7538df9

Browse files
authored
[llvm][profdata][NFC] Support 64-bit weights in ProfDataUtils (#86607)
Since some places, like SimplifyCFG, work with 64-bit weights, we supply an API in ProfDataUtils to extract the weights accordingly. We change the API slightly to disambiguate the 64-bit version from the 32-bit version.
1 parent 05d04f0 commit 7538df9

File tree

4 files changed

+38
-24
lines changed

4 files changed

+38
-24
lines changed

llvm/include/llvm/IR/ProfDataUtils.h

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,14 @@ bool extractBranchWeights(const MDNode *ProfileData,
6565
SmallVectorImpl<uint32_t> &Weights);
6666

6767
/// Faster version of extractBranchWeights() that skips checks and must only
68-
/// be called with "branch_weights" metadata nodes.
69-
void extractFromBranchWeightMD(const MDNode *ProfileData,
70-
SmallVectorImpl<uint32_t> &Weights);
68+
/// be called with "branch_weights" metadata nodes. Supports uint32_t.
69+
void extractFromBranchWeightMD32(const MDNode *ProfileData,
70+
SmallVectorImpl<uint32_t> &Weights);
71+
72+
/// Faster version of extractBranchWeights() that skips checks and must only
73+
/// be called with "branch_weights" metadata nodes. Supports uint64_t.
74+
void extractFromBranchWeightMD64(const MDNode *ProfileData,
75+
SmallVectorImpl<uint64_t> &Weights);
7176

7277
/// Extract branch weights attatched to an Instruction
7378
///

llvm/lib/IR/ProfDataUtils.cpp

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,26 @@ bool isTargetMD(const MDNode *ProfData, const char *Name, unsigned MinOps) {
6565
return ProfDataName->getString() == Name;
6666
}
6767

68+
template <typename T,
69+
typename = typename std::enable_if<std::is_arithmetic_v<T>>>
70+
static void extractFromBranchWeightMD(const MDNode *ProfileData,
71+
SmallVectorImpl<T> &Weights) {
72+
assert(isBranchWeightMD(ProfileData) && "wrong metadata");
73+
74+
unsigned NOps = ProfileData->getNumOperands();
75+
assert(WeightsIdx < NOps && "Weights Index must be less than NOps.");
76+
Weights.resize(NOps - WeightsIdx);
77+
78+
for (unsigned Idx = WeightsIdx, E = NOps; Idx != E; ++Idx) {
79+
ConstantInt *Weight =
80+
mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx));
81+
assert(Weight && "Malformed branch_weight in MD_prof node");
82+
assert(Weight->getValue().getActiveBits() <= 32 &&
83+
"Too many bits for uint32_t");
84+
Weights[Idx - WeightsIdx] = Weight->getZExtValue();
85+
}
86+
}
87+
6888
} // namespace
6989

7090
namespace llvm {
@@ -100,22 +120,14 @@ MDNode *getValidBranchWeightMDNode(const Instruction &I) {
100120
return nullptr;
101121
}
102122

103-
void extractFromBranchWeightMD(const MDNode *ProfileData,
104-
SmallVectorImpl<uint32_t> &Weights) {
105-
assert(isBranchWeightMD(ProfileData) && "wrong metadata");
106-
107-
unsigned NOps = ProfileData->getNumOperands();
108-
assert(WeightsIdx < NOps && "Weights Index must be less than NOps.");
109-
Weights.resize(NOps - WeightsIdx);
123+
void extractFromBranchWeightMD32(const MDNode *ProfileData,
124+
SmallVectorImpl<uint32_t> &Weights) {
125+
extractFromBranchWeightMD(ProfileData, Weights);
126+
}
110127

111-
for (unsigned Idx = WeightsIdx, E = NOps; Idx != E; ++Idx) {
112-
ConstantInt *Weight =
113-
mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx));
114-
assert(Weight && "Malformed branch_weight in MD_prof node");
115-
assert(Weight->getValue().getActiveBits() <= 32 &&
116-
"Too many bits for uint32_t");
117-
Weights[Idx - WeightsIdx] = Weight->getZExtValue();
118-
}
128+
void extractFromBranchWeightMD64(const MDNode *ProfileData,
129+
SmallVectorImpl<uint64_t> &Weights) {
130+
extractFromBranchWeightMD(ProfileData, Weights);
119131
}
120132

121133
bool extractBranchWeights(const MDNode *ProfileData,

llvm/lib/Transforms/Utils/LoopRotationUtils.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,7 @@ static void updateBranchWeights(BranchInst &PreHeaderBI, BranchInst &LoopBI,
287287
return;
288288

289289
SmallVector<uint32_t, 2> Weights;
290-
extractFromBranchWeightMD(WeightMD, Weights);
290+
extractFromBranchWeightMD32(WeightMD, Weights);
291291
if (Weights.size() != 2)
292292
return;
293293
uint32_t OrigLoopExitWeight = Weights[0];

llvm/lib/Transforms/Utils/SimplifyCFG.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1066,11 +1066,8 @@ static int ConstantIntSortPredicate(ConstantInt *const *P1,
10661066
static void GetBranchWeights(Instruction *TI,
10671067
SmallVectorImpl<uint64_t> &Weights) {
10681068
MDNode *MD = TI->getMetadata(LLVMContext::MD_prof);
1069-
assert(MD);
1070-
for (unsigned i = 1, e = MD->getNumOperands(); i < e; ++i) {
1071-
ConstantInt *CI = mdconst::extract<ConstantInt>(MD->getOperand(i));
1072-
Weights.push_back(CI->getValue().getZExtValue());
1073-
}
1069+
assert(MD && "Invalid branch-weight metadata");
1070+
extractFromBranchWeightMD64(MD, Weights);
10741071

10751072
// If TI is a conditional eq, the default case is the false case,
10761073
// and the corresponding branch-weight data is at index 2. We swap the

0 commit comments

Comments
 (0)