llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-llvm-ir Author: Paul Kirth (ilovepi) <details> <summary>Changes</summary> As suggested in https://github.com/llvm/llvm-project/pull/86609/files#r1556689262, an API for getting the number of branch weights directly from the MD node would be useful in a variety of checks, and would keep the logic within ProfDataUtils. --- Full diff: https://github.com/llvm/llvm-project/pull/90146.diff 4 Files Affected: - (modified) llvm/include/llvm/IR/ProfDataUtils.h (+2) - (modified) llvm/lib/IR/Instructions.cpp (+1-5) - (modified) llvm/lib/IR/ProfDataUtils.cpp (+5-3) - (modified) llvm/lib/IR/Verifier.cpp (+5-5) ``````````diff diff --git a/llvm/include/llvm/IR/ProfDataUtils.h b/llvm/include/llvm/IR/ProfDataUtils.h index 3c761bdc1bf3e9..7008d3240feded 100644 --- a/llvm/include/llvm/IR/ProfDataUtils.h +++ b/llvm/include/llvm/IR/ProfDataUtils.h @@ -66,6 +66,8 @@ bool hasBranchWeightProvenance(const MDNode *ProfileData); /// Return the offset to the first branch weight data unsigned getBranchWeightOffset(const MDNode *ProfileData); +unsigned getNumBranchWeights(const MDNode &ProfileData); + /// Extract branch weights from MD_prof metadata /// /// \param ProfileData A pointer to an MDNode. diff --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp index 650d32ac17fc2b..a14d6758cad1d8 100644 --- a/llvm/lib/IR/Instructions.cpp +++ b/llvm/lib/IR/Instructions.cpp @@ -5165,11 +5165,7 @@ void SwitchInstProfUpdateWrapper::init() { if (!ProfileData) return; - // FIXME: This check belongs in ProfDataUtils. Its almost equivalent to - // getValidBranchWeightMDNode(), but the need to use llvm_unreachable - // makes them slightly different. 
- if (ProfileData->getNumOperands() != - SI.getNumSuccessors() + getBranchWeightOffset(ProfileData)) { + if (getNumBranchWeights(*ProfileData) != SI.getNumSuccessors()) { llvm_unreachable("number of prof branch_weights metadata operands does " "not correspond to number of succesors"); } diff --git a/llvm/lib/IR/ProfDataUtils.cpp b/llvm/lib/IR/ProfDataUtils.cpp index cd219c22e3dfe6..9544ea85b93d96 100644 --- a/llvm/lib/IR/ProfDataUtils.cpp +++ b/llvm/lib/IR/ProfDataUtils.cpp @@ -123,6 +123,10 @@ unsigned getBranchWeightOffset(const MDNode *ProfileData) { return hasBranchWeightProvenance(ProfileData) ? 2 : 1; } +unsigned getNumBranchWeights(const MDNode &ProfileData) { + return ProfileData.getNumOperands() - getBranchWeightOffset(&ProfileData); +} + MDNode *getBranchWeightMDNode(const Instruction &I) { auto *ProfileData = I.getMetadata(LLVMContext::MD_prof); if (!isBranchWeightMD(ProfileData)) @@ -132,9 +136,7 @@ MDNode *getBranchWeightMDNode(const Instruction &I) { MDNode *getValidBranchWeightMDNode(const Instruction &I) { auto *ProfileData = getBranchWeightMDNode(I); - auto Offset = getBranchWeightOffset(ProfileData); - if (ProfileData && - ProfileData->getNumOperands() == Offset + I.getNumSuccessors()) + if (ProfileData && getNumBranchWeights(*ProfileData) == I.getNumSuccessors()) return ProfileData; return nullptr; } diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp index 4a142be71eec41..ecccb1790ff8ff 100644 --- a/llvm/lib/IR/Verifier.cpp +++ b/llvm/lib/IR/Verifier.cpp @@ -4787,10 +4787,9 @@ void Verifier::visitProfMetadata(Instruction &I, MDNode *MD) { // Check consistency of !prof branch_weights metadata. 
if (ProfName.equals("branch_weights")) { - unsigned int Offset = getBranchWeightOffset(I); + unsigned NumBranchWeights = getNumBranchWeights(*MD); if (isa<InvokeInst>(&I)) { - Check(MD->getNumOperands() == (1 + Offset) || - MD->getNumOperands() == (2 + Offset), + Check(NumBranchWeights == 1 || NumBranchWeights == 2, "Wrong number of InvokeInst branch_weights operands", MD); } else { unsigned ExpectedNumOperands = 0; @@ -4810,10 +4809,11 @@ void Verifier::visitProfMetadata(Instruction &I, MDNode *MD) { CheckFailed("!prof branch_weights are not allowed for this instruction", MD); - Check(MD->getNumOperands() == Offset + ExpectedNumOperands, + Check(NumBranchWeights == ExpectedNumOperands, "Wrong number of operands", MD); } - for (unsigned i = Offset; i < MD->getNumOperands(); ++i) { + for (unsigned i = getBranchWeightOffset(MD); i < MD->getNumOperands(); + ++i) { auto &MDO = MD->getOperand(i); Check(MDO, "second operand should not be null", MD); Check(mdconst::dyn_extract<ConstantInt>(MDO), `````````` </details> https://github.com/llvm/llvm-project/pull/90146 _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits