Skip to content

Commit e0e8f96

Browse files
committed
[llvm][profdata][NFC] Support 64-bit weights in ProfDataUtils
Since some places, like SimplifyCFG work with 64-bit weights, we supply an API in ProfDataUtils to extract the weights accordingly. We change the API slightly to disambiguate the 64 bit version from the 32 bit version. Pull Request: llvm#86607
1 parent 350bda4 commit e0e8f96

File tree

4 files changed

+38
-24
lines changed

4 files changed

+38
-24
lines changed

llvm/include/llvm/IR/ProfDataUtils.h

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,14 @@ bool extractBranchWeights(const MDNode *ProfileData,
6565
SmallVectorImpl<uint32_t> &Weights);
6666

6767
/// Faster version of extractBranchWeights() that skips checks and must only
68-
/// be called with "branch_weights" metadata nodes.
69-
void extractFromBranchWeightMD(const MDNode *ProfileData,
70-
SmallVectorImpl<uint32_t> &Weights);
68+
/// be called with "branch_weights" metadata nodes. Supports uint32_t.
69+
void extractFromBranchWeightMD32(const MDNode *ProfileData,
70+
SmallVectorImpl<uint32_t> &Weights);
71+
72+
/// Faster version of extractBranchWeights() that skips checks and must only
73+
/// be called with "branch_weights" metadata nodes. Supports uint64_t.
74+
void extractFromBranchWeightMD64(const MDNode *ProfileData,
75+
SmallVectorImpl<uint64_t> &Weights);
7176

7277
/// Extract branch weights attatched to an Instruction
7378
///

llvm/lib/IR/ProfDataUtils.cpp

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,26 @@ bool isTargetMD(const MDNode *ProfData, const char *Name, unsigned MinOps) {
6565
return ProfDataName->getString().equals(Name);
6666
}
6767

68+
template <typename T,
69+
typename = typename std::enable_if<std::is_arithmetic_v<T>>>
70+
static void extractFromBranchWeightMD(const MDNode *ProfileData,
71+
SmallVectorImpl<T> &Weights) {
72+
assert(isBranchWeightMD(ProfileData) && "wrong metadata");
73+
74+
unsigned NOps = ProfileData->getNumOperands();
75+
assert(WeightsIdx < NOps && "Weights Index must be less than NOps.");
76+
Weights.resize(NOps - WeightsIdx);
77+
78+
for (unsigned Idx = WeightsIdx, E = NOps; Idx != E; ++Idx) {
79+
ConstantInt *Weight =
80+
mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx));
81+
assert(Weight && "Malformed branch_weight in MD_prof node");
82+
assert(Weight->getValue().getActiveBits() <= 32 &&
83+
"Too many bits for uint32_t");
84+
Weights[Idx - WeightsIdx] = Weight->getZExtValue();
85+
}
86+
}
87+
6888
} // namespace
6989

7090
namespace llvm {
@@ -100,22 +120,14 @@ MDNode *getValidBranchWeightMDNode(const Instruction &I) {
100120
return nullptr;
101121
}
102122

103-
void extractFromBranchWeightMD(const MDNode *ProfileData,
104-
SmallVectorImpl<uint32_t> &Weights) {
105-
assert(isBranchWeightMD(ProfileData) && "wrong metadata");
106-
107-
unsigned NOps = ProfileData->getNumOperands();
108-
assert(WeightsIdx < NOps && "Weights Index must be less than NOps.");
109-
Weights.resize(NOps - WeightsIdx);
123+
void extractFromBranchWeightMD32(const MDNode *ProfileData,
124+
SmallVectorImpl<uint32_t> &Weights) {
125+
extractFromBranchWeightMD(ProfileData, Weights);
126+
}
110127

111-
for (unsigned Idx = WeightsIdx, E = NOps; Idx != E; ++Idx) {
112-
ConstantInt *Weight =
113-
mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx));
114-
assert(Weight && "Malformed branch_weight in MD_prof node");
115-
assert(Weight->getValue().getActiveBits() <= 32 &&
116-
"Too many bits for uint32_t");
117-
Weights[Idx - WeightsIdx] = Weight->getZExtValue();
118-
}
128+
void extractFromBranchWeightMD64(const MDNode *ProfileData,
129+
SmallVectorImpl<uint64_t> &Weights) {
130+
extractFromBranchWeightMD(ProfileData, Weights);
119131
}
120132

121133
bool extractBranchWeights(const MDNode *ProfileData,

llvm/lib/Transforms/Utils/LoopRotationUtils.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,7 @@ static void updateBranchWeights(BranchInst &PreHeaderBI, BranchInst &LoopBI,
287287
return;
288288

289289
SmallVector<uint32_t, 2> Weights;
290-
extractFromBranchWeightMD(WeightMD, Weights);
290+
extractFromBranchWeightMD32(WeightMD, Weights);
291291
if (Weights.size() != 2)
292292
return;
293293
uint32_t OrigLoopExitWeight = Weights[0];

llvm/lib/Transforms/Utils/SimplifyCFG.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1065,11 +1065,8 @@ static int ConstantIntSortPredicate(ConstantInt *const *P1,
10651065
static void GetBranchWeights(Instruction *TI,
10661066
SmallVectorImpl<uint64_t> &Weights) {
10671067
MDNode *MD = TI->getMetadata(LLVMContext::MD_prof);
1068-
assert(MD);
1069-
for (unsigned i = 1, e = MD->getNumOperands(); i < e; ++i) {
1070-
ConstantInt *CI = mdconst::extract<ConstantInt>(MD->getOperand(i));
1071-
Weights.push_back(CI->getValue().getZExtValue());
1072-
}
1068+
assert(MD && "Invalid branch-weight metadata");
1069+
extractFromBranchWeightMD64(MD, Weights);
10731070

10741071
// If TI is a conditional eq, the default case is the false case,
10751072
// and the corresponding branch-weight data is at index 2. We swap the

0 commit comments

Comments
 (0)