Skip to content

Commit a3a44bf

Browse files
authored
[llvm][ProfDataUtils] Provide getNumBranchWeights API (#90146)
As suggested in https://github.com/llvm/llvm-project/pull/86609/files#r1556689262 an API for getting the number of branch weights directly from the MD node would be useful in a variety of checks, and keeps the logic within ProfDataUtils.
1 parent 75ac887 commit a3a44bf

File tree

4 files changed

+14
-14
lines changed

4 files changed

+14
-14
lines changed

llvm/include/llvm/IR/ProfDataUtils.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,8 @@ bool hasBranchWeightOrigin(const MDNode *ProfileData);
6666
/// Return the offset to the first branch weight data
6767
unsigned getBranchWeightOffset(const MDNode *ProfileData);
6868

69+
unsigned getNumBranchWeights(const MDNode &ProfileData);
70+
6971
/// Extract branch weights from MD_prof metadata
7072
///
7173
/// \param ProfileData A pointer to an MDNode.

llvm/lib/IR/Instructions.cpp

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4002,11 +4002,7 @@ void SwitchInstProfUpdateWrapper::init() {
40024002
if (!ProfileData)
40034003
return;
40044004

4005-
// FIXME: This check belongs in ProfDataUtils. Its almost equivalent to
4006-
// getValidBranchWeightMDNode(), but the need to use llvm_unreachable
4007-
// makes them slightly different.
4008-
if (ProfileData->getNumOperands() !=
4009-
SI.getNumSuccessors() + getBranchWeightOffset(ProfileData)) {
4005+
if (getNumBranchWeights(*ProfileData) != SI.getNumSuccessors()) {
40104006
llvm_unreachable("number of prof branch_weights metadata operands does "
40114007
"not correspond to number of succesors");
40124008
}

llvm/lib/IR/ProfDataUtils.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,10 @@ unsigned getBranchWeightOffset(const MDNode *ProfileData) {
142142
return hasBranchWeightOrigin(ProfileData) ? 2 : 1;
143143
}
144144

145+
unsigned getNumBranchWeights(const MDNode &ProfileData) {
146+
return ProfileData.getNumOperands() - getBranchWeightOffset(&ProfileData);
147+
}
148+
145149
MDNode *getBranchWeightMDNode(const Instruction &I) {
146150
auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
147151
if (!isBranchWeightMD(ProfileData))
@@ -151,9 +155,7 @@ MDNode *getBranchWeightMDNode(const Instruction &I) {
151155

152156
MDNode *getValidBranchWeightMDNode(const Instruction &I) {
153157
auto *ProfileData = getBranchWeightMDNode(I);
154-
auto Offset = getBranchWeightOffset(ProfileData);
155-
if (ProfileData &&
156-
ProfileData->getNumOperands() == Offset + I.getNumSuccessors())
158+
if (ProfileData && getNumBranchWeights(*ProfileData) == I.getNumSuccessors())
157159
return ProfileData;
158160
return nullptr;
159161
}

llvm/lib/IR/Verifier.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4818,10 +4818,9 @@ void Verifier::visitProfMetadata(Instruction &I, MDNode *MD) {
48184818

48194819
// Check consistency of !prof branch_weights metadata.
48204820
if (ProfName == "branch_weights") {
4821-
unsigned int Offset = getBranchWeightOffset(MD);
4821+
unsigned NumBranchWeights = getNumBranchWeights(*MD);
48224822
if (isa<InvokeInst>(&I)) {
4823-
Check(MD->getNumOperands() == (1 + Offset) ||
4824-
MD->getNumOperands() == (2 + Offset),
4823+
Check(NumBranchWeights == 1 || NumBranchWeights == 2,
48254824
"Wrong number of InvokeInst branch_weights operands", MD);
48264825
} else {
48274826
unsigned ExpectedNumOperands = 0;
@@ -4841,10 +4840,11 @@ void Verifier::visitProfMetadata(Instruction &I, MDNode *MD) {
48414840
CheckFailed("!prof branch_weights are not allowed for this instruction",
48424841
MD);
48434842

4844-
Check(MD->getNumOperands() == Offset + ExpectedNumOperands,
4845-
"Wrong number of operands", MD);
4843+
Check(NumBranchWeights == ExpectedNumOperands, "Wrong number of operands",
4844+
MD);
48464845
}
4847-
for (unsigned i = Offset; i < MD->getNumOperands(); ++i) {
4846+
for (unsigned i = getBranchWeightOffset(MD); i < MD->getNumOperands();
4847+
++i) {
48484848
auto &MDO = MD->getOperand(i);
48494849
Check(MDO, "second operand should not be null", MD);
48504850
Check(mdconst::dyn_extract<ConstantInt>(MDO),

0 commit comments

Comments
 (0)