paulkirth updated this revision to Diff 464332.
paulkirth added a comment.

Rebase, and change the implementation to include the provenance information
directly in the MD_prof metadata (an "expected" string operand following
"branch_weights").


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D131306/new/

https://reviews.llvm.org/D131306

Files:
  clang/test/CodeGenCXX/attr-likelihood-if-branch-weights.cpp
  clang/test/CodeGenCXX/attr-likelihood-if-vs-builtin-expect.cpp
  llvm/include/llvm/IR/FixedMetadataKinds.def
  llvm/include/llvm/IR/MDBuilder.h
  llvm/include/llvm/IR/ProfDataUtils.h
  llvm/lib/Analysis/BranchProbabilityInfo.cpp
  llvm/lib/CodeGen/CodeGenPrepare.cpp
  llvm/lib/IR/Instruction.cpp
  llvm/lib/IR/Instructions.cpp
  llvm/lib/IR/MDBuilder.cpp
  llvm/lib/IR/ProfDataUtils.cpp
  llvm/lib/IR/Verifier.cpp
  llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
  llvm/lib/Transforms/Scalar/LoopPredication.cpp
  llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp
  llvm/lib/Transforms/Utils/Local.cpp
  llvm/lib/Transforms/Utils/MisExpect.cpp
  llvm/lib/Transforms/Utils/SimplifyCFG.cpp
  llvm/test/Transforms/LowerExpectIntrinsic/basic.ll
  llvm/test/Transforms/LowerExpectIntrinsic/expect-with-probability.ll
  llvm/test/Transforms/LowerExpectIntrinsic/expect_nonboolean.ll
  llvm/test/Transforms/LowerExpectIntrinsic/phi_merge.ll
  llvm/test/Transforms/LowerExpectIntrinsic/phi_or.ll
  llvm/test/Transforms/LowerExpectIntrinsic/phi_tern.ll

Index: llvm/test/Transforms/LowerExpectIntrinsic/phi_tern.ll
===================================================================
--- llvm/test/Transforms/LowerExpectIntrinsic/phi_tern.ll
+++ llvm/test/Transforms/LowerExpectIntrinsic/phi_tern.ll
@@ -53,4 +53,4 @@
 
 !0 = !{!"clang version 5.0.0 (trunk 302965)"}
 
-; CHECK: [[WEIGHT]] = !{!"branch_weights", i32 1, i32 2000}
+; CHECK: [[WEIGHT]] = !{!"branch_weights", !"expected", i32 1, i32 2000}
Index: llvm/test/Transforms/LowerExpectIntrinsic/phi_or.ll
===================================================================
--- llvm/test/Transforms/LowerExpectIntrinsic/phi_or.ll
+++ llvm/test/Transforms/LowerExpectIntrinsic/phi_or.ll
@@ -99,5 +99,5 @@
 
 
 !0 = !{!"clang version 5.0.0 (trunk 302965)"}
-; CHECK: [[WEIGHT]] = !{!"branch_weights", i32 2000, i32 1}
-; CHECK: [[WEIGHT2]] = !{!"branch_weights", i32 1, i32 2000}
+; CHECK: [[WEIGHT]] = !{!"branch_weights", !"expected", i32 2000, i32 1}
+; CHECK: [[WEIGHT2]] = !{!"branch_weights", !"expected", i32 1, i32 2000}
Index: llvm/test/Transforms/LowerExpectIntrinsic/phi_merge.ll
===================================================================
--- llvm/test/Transforms/LowerExpectIntrinsic/phi_merge.ll
+++ llvm/test/Transforms/LowerExpectIntrinsic/phi_merge.ll
@@ -352,5 +352,5 @@
 !llvm.ident = !{!0}
 
 !0 = !{!"clang version 5.0.0 (trunk 302965)"}
-; CHECK: [[WEIGHT]] = !{!"branch_weights", i32 2000, i32 1}
-; CHECK: [[WEIGHT2]] = !{!"branch_weights", i32 1, i32 2000}
+; CHECK: [[WEIGHT]] = !{!"branch_weights", !"expected", i32 2000, i32 1}
+; CHECK: [[WEIGHT2]] = !{!"branch_weights", !"expected", i32 1, i32 2000}
Index: llvm/test/Transforms/LowerExpectIntrinsic/expect_nonboolean.ll
===================================================================
--- llvm/test/Transforms/LowerExpectIntrinsic/expect_nonboolean.ll
+++ llvm/test/Transforms/LowerExpectIntrinsic/expect_nonboolean.ll
@@ -99,6 +99,6 @@
 
 !0 = !{i32 1, !"wchar_size", i32 4}
 !1 = !{!"clang version 5.0.0 (trunk 304373)"}
-; CHECK: [[LIKELY]] = !{!"branch_weights", i32 2000, i32 1}
-; CHECK: [[UNLIKELY]] = !{!"branch_weights", i32 1, i32 2000}
+; CHECK: [[LIKELY]] = !{!"branch_weights", !"expected", i32 2000, i32 1}
+; CHECK: [[UNLIKELY]] = !{!"branch_weights", !"expected", i32 1, i32 2000}
 
Index: llvm/test/Transforms/LowerExpectIntrinsic/expect-with-probability.ll
===================================================================
--- llvm/test/Transforms/LowerExpectIntrinsic/expect-with-probability.ll
+++ llvm/test/Transforms/LowerExpectIntrinsic/expect-with-probability.ll
@@ -285,7 +285,7 @@
 
 declare i1 @llvm.expect.with.probability.i1(i1, i1, double) nounwind readnone
 
-; CHECK: !0 = !{!"branch_weights", i32 1717986918, i32 429496731}
-; CHECK: !1 = !{!"branch_weights", i32 429496731, i32 1717986918}
-; CHECK: !2 = !{!"branch_weights", i32 214748366, i32 214748366, i32 1717986918}
-; CHECK: !3 = !{!"branch_weights", i32 1717986918, i32 214748366, i32 214748366}
+; CHECK: !0 = !{!"branch_weights", !"expected", i32 1717986918, i32 429496731}
+; CHECK: !1 = !{!"branch_weights", !"expected", i32 429496731, i32 1717986918}
+; CHECK: !2 = !{!"branch_weights", !"expected", i32 214748366, i32 214748366, i32 1717986918}
+; CHECK: !3 = !{!"branch_weights", !"expected", i32 1717986918, i32 214748366, i32 214748366}
Index: llvm/test/Transforms/LowerExpectIntrinsic/basic.ll
===================================================================
--- llvm/test/Transforms/LowerExpectIntrinsic/basic.ll
+++ llvm/test/Transforms/LowerExpectIntrinsic/basic.ll
@@ -285,7 +285,7 @@
 
 declare i1 @llvm.expect.i1(i1, i1) nounwind readnone
 
-; CHECK: !0 = !{!"branch_weights", i32 2000, i32 1}
-; CHECK: !1 = !{!"branch_weights", i32 1, i32 2000}
-; CHECK: !2 = !{!"branch_weights", i32 1, i32 1, i32 2000}
-; CHECK: !3 = !{!"branch_weights", i32 2000, i32 1, i32 1}
+; CHECK: !0 = !{!"branch_weights", !"expected", i32 2000, i32 1}
+; CHECK: !1 = !{!"branch_weights", !"expected", i32 1, i32 2000}
+; CHECK: !2 = !{!"branch_weights", !"expected", i32 1, i32 1, i32 2000}
+; CHECK: !3 = !{!"branch_weights", !"expected", i32 2000, i32 1, i32 1}
Index: llvm/lib/Transforms/Utils/SimplifyCFG.cpp
===================================================================
--- llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -1065,7 +1065,10 @@
                              SmallVectorImpl<uint64_t> &Weights) {
   MDNode *MD = TI->getMetadata(LLVMContext::MD_prof);
   assert(MD);
-  for (unsigned i = 1, e = MD->getNumOperands(); i < e; ++i) {
+  // TODO: this should just use extractBranchWeights(), but a lot of this code
+  // assumes uint64_t
+  auto Offset = getBranchWeightOffset(MD);
+  for (unsigned i = Offset, e = MD->getNumOperands(); i < e; ++i) {
     ConstantInt *CI = mdconst::extract<ConstantInt>(MD->getOperand(i));
     Weights.push_back(CI->getValue().getZExtValue());
   }
Index: llvm/lib/Transforms/Utils/MisExpect.cpp
===================================================================
--- llvm/lib/Transforms/Utils/MisExpect.cpp
+++ llvm/lib/Transforms/Utils/MisExpect.cpp
@@ -59,9 +59,10 @@
     cl::desc("Use this option to turn on/off "
              "warnings about incorrect usage of llvm.expect intrinsics."));
 
+// Command line option for setting the diagnostic tolerance threshold
 static cl::opt<uint32_t> MisExpectTolerance(
     "misexpect-tolerance", cl::init(0),
-    cl::desc("Prevents emiting diagnostics when profile counts are "
+    cl::desc("Prevents emitting diagnostics when profile counts are "
              "within N% of the threshold.."));
 
 } // namespace llvm
@@ -91,7 +92,6 @@
   // improve diagnostic output, such as the caret. If the same problem exists
   // for branch instructions, then we should remove this function and directly
   // use the instruction
-  //
   else if (auto *S = dyn_cast<SwitchInst>(I)) {
     Ret = dyn_cast<Instruction>(S->getCondition());
   }
@@ -150,15 +150,14 @@
   uint64_t TotalBranchWeight =
       LikelyBranchWeight + (UnlikelyBranchWeight * NumUnlikelyTargets);
 
-  // FIXME: When we've addressed sample profiling, restore the assertion
-  //
-  // We cannot calculate branch probability if either of these invariants aren't
-  // met. However, MisExpect diagnostics should not prevent code from compiling,
-  // so we simply forgo emitting diagnostics here, and return early.
-  // assert((TotalBranchWeight >= LikelyBranchWeight) && (TotalBranchWeight > 0)
-  //              && "TotalBranchWeight is less than the Likely branch weight");
-  if ((TotalBranchWeight == 0) || (TotalBranchWeight <= LikelyBranchWeight))
-    return;
+  // Failing this assert means that we've either got corrupted metadata, that
+  // we're checking stale or manually inserted branch weights, or that branch
+  // weights are being added multiple times, as is the case for SampleProfiling
+  // under ThinLTO. We gate all known entry paths to verifyMisExpect() by first
+  // checking for the presence of the "expected" tag in the metadata, which is
+  // *only* added in the LowerExpectIntrinsic Pass, avoiding a false positive.
+  assert((TotalBranchWeight >= LikelyBranchWeight) && (TotalBranchWeight > 0) &&
+         "TotalBranchWeight is less than the Likely branch weight");
 
   // To determine our threshold value we need to obtain the branch probability
   // for the weights added by llvm.expect and use that proportion to calculate
@@ -185,6 +184,8 @@
 
 void checkBackendInstrumentation(Instruction &I,
                                  const ArrayRef<uint32_t> RealWeights) {
+  if (!hasExpectedProvenance(I))
+    return;
   SmallVector<uint32_t> ExpectedWeights;
   if (!extractBranchWeights(I, ExpectedWeights))
     return;
Index: llvm/lib/Transforms/Utils/Local.cpp
===================================================================
--- llvm/lib/Transforms/Utils/Local.cpp
+++ llvm/lib/Transforms/Utils/Local.cpp
@@ -63,6 +63,7 @@
 #include "llvm/IR/Metadata.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/PatternMatch.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/IR/Type.h"
 #include "llvm/IR/Use.h"
 #include "llvm/IR/User.h"
@@ -212,25 +213,21 @@
       if (i->getCaseSuccessor() == DefaultDest) {
         MDNode *MD = SI->getMetadata(LLVMContext::MD_prof);
         unsigned NCases = SI->getNumCases();
+        // Collect branch weights into a vector.
+        SmallVector<uint32_t> Weights;
         // Fold the case metadata into the default if there will be any branches
         // left, unless the metadata doesn't match the switch.
-        if (NCases > 1 && MD && MD->getNumOperands() == 2 + NCases) {
-          // Collect branch weights into a vector.
-          SmallVector<uint32_t, 8> Weights;
-          for (unsigned MD_i = 1, MD_e = MD->getNumOperands(); MD_i < MD_e;
-               ++MD_i) {
-            auto *CI = mdconst::extract<ConstantInt>(MD->getOperand(MD_i));
-            Weights.push_back(CI->getValue().getZExtValue());
-          }
+        if (NCases > 1 && extractBranchWeights(MD, Weights)) {
           // Merge weight of this case to the default weight.
           unsigned idx = i->getCaseIndex();
           Weights[0] += Weights[idx+1];
           // Remove weight for this case.
           std::swap(Weights[idx+1], Weights.back());
           Weights.pop_back();
-          SI->setMetadata(LLVMContext::MD_prof,
-                          MDBuilder(BB->getContext()).
-                          createBranchWeights(Weights));
+          SI->setMetadata(
+              LLVMContext::MD_prof,
+              MDBuilder(BB->getContext())
+                  .createBranchWeights(Weights, hasExpectedProvenance(MD)));
         }
         // Remove this entry.
         BasicBlock *ParentBB = SI->getParent();
@@ -306,7 +303,7 @@
                                                FirstCase.getCaseSuccessor(),
                                                SI->getDefaultDest());
       MDNode *MD = SI->getMetadata(LLVMContext::MD_prof);
-      if (MD && MD->getNumOperands() == 3) {
+      if (MD && MD->getNumOperands() == (2 + getBranchWeightOffset(MD))) {
         ConstantInt *SICase =
             mdconst::dyn_extract<ConstantInt>(MD->getOperand(2));
         ConstantInt *SIDef =
@@ -2696,6 +2693,7 @@
     case LLVMContext::MD_dbg:
     case LLVMContext::MD_tbaa:
     case LLVMContext::MD_prof:
+    case LLVMContext::MD_expected:
     case LLVMContext::MD_fpmath:
     case LLVMContext::MD_tbaa_struct:
     case LLVMContext::MD_invariant_load:
Index: llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp
===================================================================
--- llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp
+++ llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp
@@ -107,7 +107,8 @@
   SI.setCondition(ArgValue);

   SI.setMetadata(LLVMContext::MD_prof,
-                 MDBuilder(CI->getContext()).createBranchWeights(Weights));
+                 MDBuilder(CI->getContext())
+                     .createBranchWeights(Weights, /*IsExpected=*/true));

   return true;
 }
@@ -252,11 +253,13 @@
     if (IsOpndComingFromSuccessor(BI->getSuccessor(1)))
       BI->setMetadata(LLVMContext::MD_prof,
                       MDB.createBranchWeights(LikelyBranchWeightVal,
-                                              UnlikelyBranchWeightVal));
+                                              UnlikelyBranchWeightVal,
+                                              /*IsExpected=*/true));
     else if (IsOpndComingFromSuccessor(BI->getSuccessor(0)))
       BI->setMetadata(LLVMContext::MD_prof,
                       MDB.createBranchWeights(UnlikelyBranchWeightVal,
-                                              LikelyBranchWeightVal));
+                                              LikelyBranchWeightVal,
+                                              /*IsExpected=*/true));
   }
 }
 
@@ -321,12 +324,12 @@
   SmallVector<uint32_t, 4> ExpectedWeights;
   if ((ExpectedValue->getZExtValue() == ValueComparedTo) ==
       (Predicate == CmpInst::ICMP_EQ)) {
-    Node =
-        MDB.createBranchWeights(LikelyBranchWeightVal, UnlikelyBranchWeightVal);
+    Node = MDB.createBranchWeights(
+        LikelyBranchWeightVal, UnlikelyBranchWeightVal, /*IsExpected=*/true);
     ExpectedWeights = {LikelyBranchWeightVal, UnlikelyBranchWeightVal};
   } else {
-    Node =
-        MDB.createBranchWeights(UnlikelyBranchWeightVal, LikelyBranchWeightVal);
+    Node = MDB.createBranchWeights(UnlikelyBranchWeightVal,
+                                   LikelyBranchWeightVal, /*IsExpected=*/true);
     ExpectedWeights = {UnlikelyBranchWeightVal, LikelyBranchWeightVal};
   }
 
Index: llvm/lib/Transforms/Scalar/LoopPredication.cpp
===================================================================
--- llvm/lib/Transforms/Scalar/LoopPredication.cpp
+++ llvm/lib/Transforms/Scalar/LoopPredication.cpp
@@ -191,6 +191,7 @@
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/PatternMatch.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/InitializePasses.h"
 #include "llvm/Pass.h"
 #include "llvm/Support/CommandLine.h"
@@ -960,7 +961,8 @@
     if (MDString *MDS = dyn_cast<MDString>(ProfileData->getOperand(0)))
       if (!MDS->getString().equals("branch_weights"))
         return false;
-    if (ProfileData->getNumOperands() != 1 + Term->getNumSuccessors())
+    if (ProfileData->getNumOperands() !=
+        getBranchWeightOffset(ProfileData) + Term->getNumSuccessors())
       return false;
     return true;
   };
@@ -977,14 +979,13 @@
     MDNode *ProfileData = Term->getMetadata(LLVMContext::MD_prof);
     unsigned NumSucc = Term->getNumSuccessors();
     if (IsValidProfileData(ProfileData, Term)) {
-      uint64_t Numerator = 0, Denominator = 0, ProfVal = 0;
+      uint64_t Numerator = 0, Denominator = 0;
+      SmallVector<uint32_t> Weights;
+      extractBranchWeights(ProfileData, Weights);
       for (unsigned i = 0; i < NumSucc; i++) {
-        ConstantInt *CI =
-            mdconst::extract<ConstantInt>(ProfileData->getOperand(i + 1));
-        ProfVal = CI->getValue().getZExtValue();
         if (Term->getSuccessor(i) == ExitBlock)
-          Numerator += ProfVal;
-        Denominator += ProfVal;
+          Numerator += Weights[i];
+        Denominator += Weights[i];
       }
       return BranchProbability::getBranchProbability(Numerator, Denominator);
     } else {
Index: llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
===================================================================
--- llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
+++ llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
@@ -509,6 +509,7 @@
     case LLVMContext::MD_dbg:
     case LLVMContext::MD_tbaa:
     case LLVMContext::MD_prof:
+    case LLVMContext::MD_expected:
     case LLVMContext::MD_fpmath:
     case LLVMContext::MD_tbaa_struct:
     case LLVMContext::MD_alias_scope:
Index: llvm/lib/IR/Verifier.cpp
===================================================================
--- llvm/lib/IR/Verifier.cpp
+++ llvm/lib/IR/Verifier.cpp
@@ -91,6 +91,7 @@
 #include "llvm/IR/Module.h"
 #include "llvm/IR/ModuleSlotTracker.h"
 #include "llvm/IR/PassManager.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/IR/Statepoint.h"
 #include "llvm/IR/Type.h"
 #include "llvm/IR/Use.h"
@@ -4512,11 +4513,13 @@
         "expected string with name of the !prof annotation", MD);
   MDString *MDS = cast<MDString>(MD->getOperand(0));
   StringRef ProfName = MDS->getString();
+  unsigned Offset = getBranchWeightOffset(I);
 
   // Check consistency of !prof branch_weights metadata.
   if (ProfName.equals("branch_weights")) {
     if (isa<InvokeInst>(&I)) {
-      Check(MD->getNumOperands() == 2 || MD->getNumOperands() == 3,
+      Check(MD->getNumOperands() == (1 + Offset) ||
+                MD->getNumOperands() == (2 + Offset),
             "Wrong number of InvokeInst branch_weights operands", MD);
     } else {
       unsigned ExpectedNumOperands = 0;
@@ -4536,10 +4539,10 @@
         CheckFailed("!prof branch_weights are not allowed for this instruction",
                     MD);
 
-      Check(MD->getNumOperands() == 1 + ExpectedNumOperands,
+      Check(MD->getNumOperands() == Offset + ExpectedNumOperands,
             "Wrong number of operands", MD);
     }
-    for (unsigned i = 1; i < MD->getNumOperands(); ++i) {
+    for (unsigned i = Offset; i < MD->getNumOperands(); ++i) {
       auto &MDO = MD->getOperand(i);
       Check(MDO, "second operand should not be null", MD);
       Check(mdconst::dyn_extract<ConstantInt>(MDO),
Index: llvm/lib/IR/ProfDataUtils.cpp
===================================================================
--- llvm/lib/IR/ProfDataUtils.cpp
+++ llvm/lib/IR/ProfDataUtils.cpp
@@ -39,9 +39,6 @@
 // We maintain some constants here to ensure that we access the branch weights
 // correctly, and can change the behavior in the future if the layout changes
 
-// The index at which the weights vector starts
-constexpr unsigned WeightsIdx = 1;
-
 // the minimum number of operands for MD_prof nodes with branch weights
 constexpr unsigned MinBWOps = 3;
 
@@ -51,6 +48,8 @@
   assert(ProfileData && "ProfileData was nullptr in extractWeights");
   unsigned NOps = ProfileData->getNumOperands();
 
+  // The index at which the weights vector starts
+  unsigned WeightsIdx = getBranchWeightOffset(ProfileData);
   assert(WeightsIdx < NOps && "Weights Index must be less than NOps.");
   Weights.resize(NOps - WeightsIdx);
 
@@ -101,6 +100,30 @@
   return isBranchWeightMD(ProfileData);
 }
 
+bool hasExpectedProvenance(const Instruction &I) {
+  auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
+  return hasExpectedProvenance(ProfileData);
+}
+
+bool hasExpectedProvenance(const MDNode *ProfileData) {
+  if (!isBranchWeightMD(ProfileData))
+    return false;
+
+  auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(1));
+  if (!ProfDataName)
+    return false;
+  return ProfDataName->getString().equals("expected");
+}
+
+unsigned getBranchWeightOffset(const Instruction &I) {
+  auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
+  return getBranchWeightOffset(ProfileData);
+}
+
+unsigned getBranchWeightOffset(const MDNode *ProfileData) {
+  return hasExpectedProvenance(ProfileData) ? 2 : 1;
+}
+
 bool extractBranchWeights(const MDNode *ProfileData,
                           SmallVectorImpl<uint32_t> &Weights) {
   if (!isBranchWeightMD(ProfileData))
@@ -143,7 +166,9 @@
     return false;

   if (ProfDataName->getString().equals("branch_weights")) {
-    for (unsigned Idx = 1; Idx < ProfileData->getNumOperands(); Idx++) {
+    // Skip the name and, if present, the "expected" provenance tag.
+    unsigned Offset = getBranchWeightOffset(ProfileData);
+    for (unsigned Idx = Offset; Idx < ProfileData->getNumOperands(); Idx++) {
       auto *V = mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx));
       assert(V && "Malformed branch_weight in MD_prof node");
       TotalVal += V->getValue().getZExtValue();
Index: llvm/lib/IR/MDBuilder.cpp
===================================================================
--- llvm/lib/IR/MDBuilder.cpp
+++ llvm/lib/IR/MDBuilder.cpp
@@ -35,19 +35,23 @@
 }
 
 MDNode *MDBuilder::createBranchWeights(uint32_t TrueWeight,
-                                       uint32_t FalseWeight) {
-  return createBranchWeights({TrueWeight, FalseWeight});
+                                       uint32_t FalseWeight, bool IsExpected) {
+  return createBranchWeights({TrueWeight, FalseWeight}, IsExpected);
 }
 
-MDNode *MDBuilder::createBranchWeights(ArrayRef<uint32_t> Weights) {
+MDNode *MDBuilder::createBranchWeights(ArrayRef<uint32_t> Weights,
+                                       bool IsExpected) {
   assert(Weights.size() >= 1 && "Need at least one branch weights!");
 
-  SmallVector<Metadata *, 4> Vals(Weights.size() + 1);
+  auto Offset = IsExpected ? 2 : 1;
+  SmallVector<Metadata *, 4> Vals(Weights.size() + Offset);
   Vals[0] = createString("branch_weights");
+  if (IsExpected)
+    Vals[1] = createString("expected");
 
   Type *Int32Ty = Type::getInt32Ty(Context);
   for (unsigned i = 0, e = Weights.size(); i != e; ++i)
-    Vals[i + 1] = createConstant(ConstantInt::get(Int32Ty, Weights[i]));
+    Vals[i + Offset] = createConstant(ConstantInt::get(Int32Ty, Weights[i]));
 
   return MDNode::get(Context, Vals);
 }
Index: llvm/lib/IR/Instructions.cpp
===================================================================
--- llvm/lib/IR/Instructions.cpp
+++ llvm/lib/IR/Instructions.cpp
@@ -32,6 +32,7 @@
 #include "llvm/IR/Metadata.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/Operator.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/IR/Type.h"
 #include "llvm/IR/Value.h"
 #include "llvm/Support/AtomicOrdering.h"
@@ -630,10 +631,12 @@
   APInt APS(128, S), APT(128, T);
   if (ProfDataName->getString().equals("branch_weights") &&
       ProfileData->getNumOperands() > 0) {
+    auto Offset = getBranchWeightOffset(ProfileData);
     // Using APInt::div may be expensive, but most cases should fit 64 bits.
-    APInt Val(128, mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(1))
-                       ->getValue()
-                       .getZExtValue());
+    APInt Val(128,
+              mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Offset))
+                  ->getValue()
+                  .getZExtValue());
     Val *= APS;
     Vals.push_back(MDB.createConstant(
         ConstantInt::get(Type::getInt32Ty(getContext()),
@@ -4500,21 +4503,15 @@
 }
 
 void SwitchInstProfUpdateWrapper::init() {
-  MDNode *ProfileData = getProfBranchWeightsMD(SI);
-  if (!ProfileData)
+  SmallVector<uint32_t, 8> Weights;
+  if (!extractBranchWeights(SI, Weights))
     return;
 
-  if (ProfileData->getNumOperands() != SI.getNumSuccessors() + 1) {
+  if (Weights.size() != SI.getNumSuccessors()) {
     llvm_unreachable("number of prof branch_weights metadata operands does "
                      "not correspond to number of succesors");
   }
 
-  SmallVector<uint32_t, 8> Weights;
-  for (unsigned CI = 1, CE = SI.getNumSuccessors(); CI <= CE; ++CI) {
-    ConstantInt *C = mdconst::extract<ConstantInt>(ProfileData->getOperand(CI));
-    uint32_t CW = C->getValue().getZExtValue();
-    Weights.push_back(CW);
-  }
   this->Weights = std::move(Weights);
 }
 
Index: llvm/lib/IR/Instruction.cpp
===================================================================
--- llvm/lib/IR/Instruction.cpp
+++ llvm/lib/IR/Instruction.cpp
@@ -17,6 +17,7 @@
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/Operator.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/IR/Type.h"
 using namespace llvm;
 
@@ -849,19 +850,21 @@

 void Instruction::swapProfMetadata() {
   MDNode *ProfileData = getMetadata(LLVMContext::MD_prof);
-  if (!ProfileData || ProfileData->getNumOperands() != 3 ||
-      !isa<MDString>(ProfileData->getOperand(0)))
+  if (!isBranchWeightMD(ProfileData))
     return;

-  MDString *MDName = cast<MDString>(ProfileData->getOperand(0));
-  if (MDName->getString() != "branch_weights")
-    return;
-
-  // The first operand is the name. Fetch them backwards and build a new one.
-  Metadata *Ops[] = {ProfileData->getOperand(0), ProfileData->getOperand(2),
-                     ProfileData->getOperand(1)};
-  setMetadata(LLVMContext::MD_prof,
-              MDNode::get(ProfileData->getContext(), Ops));
+  // Only two-way branch weights can be swapped. The weights follow the
+  // "branch_weights" name and, when present, the "expected" provenance tag.
+  unsigned Offset = getBranchWeightOffset(ProfileData);
+  if (ProfileData->getNumOperands() != Offset + 2)
+    return;
+
+  // Keep the leading name/tag operands as-is and swap the two weights.
+  SmallVector<Metadata *, 4> Ops(ProfileData->op_begin(),
+                                 ProfileData->op_end());
+  std::swap(Ops[Offset], Ops[Offset + 1]);
+  setMetadata(LLVMContext::MD_prof,
+              MDNode::get(ProfileData->getContext(), Ops));
 }

 void Instruction::copyMetadata(const Instruction &SrcInst,
Index: llvm/lib/CodeGen/CodeGenPrepare.cpp
===================================================================
--- llvm/lib/CodeGen/CodeGenPrepare.cpp
+++ llvm/lib/CodeGen/CodeGenPrepare.cpp
@@ -8517,7 +8517,8 @@
         scaleWeights(NewTrueWeight, NewFalseWeight);
         Br1->setMetadata(LLVMContext::MD_prof,
                          MDBuilder(Br1->getContext())
-                             .createBranchWeights(TrueWeight, FalseWeight));
+                             .createBranchWeights(TrueWeight, FalseWeight,
+                                                  hasExpectedProvenance(*Br1)));
 
         NewTrueWeight = TrueWeight;
         NewFalseWeight = 2 * FalseWeight;
Index: llvm/lib/Analysis/BranchProbabilityInfo.cpp
===================================================================
--- llvm/lib/Analysis/BranchProbabilityInfo.cpp
+++ llvm/lib/Analysis/BranchProbabilityInfo.cpp
@@ -392,7 +392,8 @@
 
   // Ensure there are weights for all of the successors. Note that the first
   // operand to the metadata node is a name, not a weight.
-  if (WeightsNode->getNumOperands() != TI->getNumSuccessors() + 1)
+  if (WeightsNode->getNumOperands() !=
+      TI->getNumSuccessors() + getBranchWeightOffset(WeightsNode))
     return false;
 
   // Build up the final weights that will be used in a temporary buffer.
Index: llvm/include/llvm/IR/ProfDataUtils.h
===================================================================
--- llvm/include/llvm/IR/ProfDataUtils.h
+++ llvm/include/llvm/IR/ProfDataUtils.h
@@ -34,6 +34,18 @@
 /// otherwise.
 bool hasBranchWeightMD(const Instruction &I);
 
+/// Check if Branch Weight Metadata has an "expected" field
+bool hasExpectedProvenance(const Instruction &I);
+
+/// Check if Branch Weight Metadata has an "expected" field
+bool hasExpectedProvenance(const MDNode *ProfileData);
+
+/// return the offset to the first branch weight data
+unsigned getBranchWeightOffset(const Instruction &I);
+
+/// return the offset to the first branch weight data
+unsigned getBranchWeightOffset(const MDNode *ProfileData);
+
 /// Extract branch weights from MD_prof metadata
 ///
 /// \param ProfileData A pointer to an MDNode.
Index: llvm/include/llvm/IR/MDBuilder.h
===================================================================
--- llvm/include/llvm/IR/MDBuilder.h
+++ llvm/include/llvm/IR/MDBuilder.h
@@ -59,10 +59,12 @@
   //===------------------------------------------------------------------===//
 
   /// Return metadata containing two branch weights.
-  MDNode *createBranchWeights(uint32_t TrueWeight, uint32_t FalseWeight);
+  MDNode *createBranchWeights(uint32_t TrueWeight, uint32_t FalseWeight,
+                              bool IsExpected = false);
 
   /// Return metadata containing a number of branch weights.
-  MDNode *createBranchWeights(ArrayRef<uint32_t> Weights);
+  MDNode *createBranchWeights(ArrayRef<uint32_t> Weights,
+                              bool IsExpected = false);
 
   /// Return metadata specifying that a branch or switch is unpredictable.
   MDNode *createUnpredictable();
Index: llvm/include/llvm/IR/FixedMetadataKinds.def
===================================================================
--- llvm/include/llvm/IR/FixedMetadataKinds.def
+++ llvm/include/llvm/IR/FixedMetadataKinds.def
@@ -49,3 +49,4 @@
 LLVM_FIXED_MD_KIND(MD_callsite, "callsite", 35)
 LLVM_FIXED_MD_KIND(MD_kcfi_type, "kcfi_type", 36)
 LLVM_FIXED_MD_KIND(MD_pcsections, "pcsections", 37)
+LLVM_FIXED_MD_KIND(MD_expected, "expected", 38)
Index: clang/test/CodeGenCXX/attr-likelihood-if-vs-builtin-expect.cpp
===================================================================
--- clang/test/CodeGenCXX/attr-likelihood-if-vs-builtin-expect.cpp
+++ clang/test/CodeGenCXX/attr-likelihood-if-vs-builtin-expect.cpp
@@ -221,5 +221,5 @@
   }
 }
 
-// CHECK: [[BW_LIKELY]] = !{!"branch_weights", i32 2000, i32 1}
-// CHECK: [[BW_UNLIKELY]] = !{!"branch_weights", i32 1, i32 2000}
+// CHECK: [[BW_LIKELY]] = !{!"branch_weights", !"expected", i32 2000, i32 1}
+// CHECK: [[BW_UNLIKELY]] = !{!"branch_weights", !"expected", i32 1, i32 2000}
Index: clang/test/CodeGenCXX/attr-likelihood-if-branch-weights.cpp
===================================================================
--- clang/test/CodeGenCXX/attr-likelihood-if-branch-weights.cpp
+++ clang/test/CodeGenCXX/attr-likelihood-if-branch-weights.cpp
@@ -144,5 +144,5 @@
   }
 }
 
-// CHECK: !7 = !{!"branch_weights", i32 [[UNLIKELY]], i32 [[LIKELY]]}
-// CHECK: !8 = !{!"branch_weights", i32 [[LIKELY]], i32 [[UNLIKELY]]}
+// CHECK: !7 = !{!"branch_weights", !"expected", i32 [[UNLIKELY]], i32 [[LIKELY]]}
+// CHECK: !8 = !{!"branch_weights", !"expected", i32 [[LIKELY]], i32 [[UNLIKELY]]}
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to