[PATCH] D15548: Remove the restriction that known and unknown probabilities cannot coexist when being normalized.

Cong Hou via llvm-commits llvm-commits at lists.llvm.org
Thu Dec 17 14:13:13 PST 2015


congh updated this revision to Diff 43182.
congh added a comment.

Update the patch based on David's comment.


http://reviews.llvm.org/D15548

Files:
  include/llvm/Support/BranchProbability.h
  unittests/Support/BranchProbabilityTest.cpp

Index: unittests/Support/BranchProbabilityTest.cpp
===================================================================
--- unittests/Support/BranchProbabilityTest.cpp
+++ unittests/Support/BranchProbabilityTest.cpp
@@ -288,6 +288,7 @@
 }
 
 TEST(BranchProbabilityTest, NormalizeProbabilities) {
+  const auto UnknownProb = BranchProbability::getUnknown();
   {
     SmallVector<BranchProbability, 2> Probs{{0, 1}, {0, 1}};
     BranchProbability::normalizeProbabilities(Probs.begin(), Probs.end());
@@ -322,6 +323,36 @@
     EXPECT_EQ(BranchProbability::getDenominator() / 3 + 1,
               Probs[2].getNumerator());
   }
+  {
+    SmallVector<BranchProbability, 2> Probs{{0, 1}, UnknownProb};
+    BranchProbability::normalizeProbabilities(Probs.begin(), Probs.end());
+    EXPECT_EQ(0, Probs[0].getNumerator());
+    EXPECT_EQ(BranchProbability::getDenominator(), Probs[1].getNumerator());
+  }
+  {
+    SmallVector<BranchProbability, 2> Probs{{1, 1}, UnknownProb};
+    BranchProbability::normalizeProbabilities(Probs.begin(), Probs.end());
+    EXPECT_EQ(BranchProbability::getDenominator(), Probs[0].getNumerator());
+    EXPECT_EQ(0, Probs[1].getNumerator());
+  }
+  {
+    SmallVector<BranchProbability, 2> Probs{{1, 2}, UnknownProb};
+    BranchProbability::normalizeProbabilities(Probs.begin(), Probs.end());
+    EXPECT_EQ(BranchProbability::getDenominator() / 2, Probs[0].getNumerator());
+    EXPECT_EQ(BranchProbability::getDenominator() / 2, Probs[1].getNumerator());
+  }
+  {
+    SmallVector<BranchProbability, 4> Probs{
+        {1, 2}, {1, 2}, {1, 2}, UnknownProb};
+    BranchProbability::normalizeProbabilities(Probs.begin(), Probs.end());
+    EXPECT_EQ(BranchProbability::getDenominator() / 3 + 1,
+              Probs[0].getNumerator());
+    EXPECT_EQ(BranchProbability::getDenominator() / 3 + 1,
+              Probs[1].getNumerator());
+    EXPECT_EQ(BranchProbability::getDenominator() / 3 + 1,
+              Probs[2].getNumerator());
+    EXPECT_EQ(0, Probs[3].getNumerator());
+  }
 }
 
 }
Index: include/llvm/Support/BranchProbability.h
===================================================================
--- include/llvm/Support/BranchProbability.h
+++ include/llvm/Support/BranchProbability.h
@@ -183,17 +183,32 @@
   if (Begin == End)
     return;
 
-  auto UnknownProbCount =
-      std::count(Begin, End, BranchProbability::getUnknown());
-  assert((UnknownProbCount == 0 ||
-          UnknownProbCount == std::distance(Begin, End)) &&
-         "Cannot normalize probabilities with known and unknown ones.");
-  (void)UnknownProbCount;
-
-  uint64_t Sum = std::accumulate(
-      Begin, End, uint64_t(0),
-      [](uint64_t S, const BranchProbability &BP) { return S + BP.N; });
-
+  unsigned UnknownProbCount = 0;
+  uint64_t Sum = std::accumulate(Begin, End, uint64_t(0),
+                                 [&](uint64_t S, const BranchProbability &BP) {
+                                   if (!BP.isUnknown())
+                                     return S + BP.N;
+                                   UnknownProbCount++;
+                                   return S;
+                                 });
+
+  if (UnknownProbCount > 0) {
+    BranchProbability ProbForUnknown = BranchProbability::getZero();
+    // If the sum of all known probabilities is less than one, evenly distribute
+    // the complement of sum to unknown probabilities. Otherwise, set unknown
+    // probabilities to zeros and continue to normalize known probabilities.
+    if (Sum < BranchProbability::getDenominator())
+      ProbForUnknown = BranchProbability::getRaw(
+          (BranchProbability::getDenominator() - Sum) / UnknownProbCount);
+
+    std::replace_if(Begin, End,
+                    [](const BranchProbability &BP) { return BP.isUnknown(); },
+                    ProbForUnknown);
+
+    if (Sum <= BranchProbability::getDenominator())
+      return;
+  }
+ 
   if (Sum == 0) {
     BranchProbability BP(1, std::distance(Begin, End));
     std::fill(Begin, End, BP);


-------------- next part --------------
A non-text attachment was scrubbed...
Name: D15548.43182.patch
Type: text/x-patch
Size: 4012 bytes
Desc: not available
URL: <http://lists.llvm.org/pipermail/llvm-commits/attachments/20151217/d31d6736/attachment.bin>


More information about the llvm-commits mailing list