[llvm] 37fb860 - Add support of __builtin_expect_with_probability

Erich Keane via llvm-commits llvm-commits at lists.llvm.org
Mon Jun 22 10:21:43 PDT 2020


Author: Zhi Zhuang
Date: 2020-06-22T10:21:28-07:00
New Revision: 37fb860301272d8138d289da5b606115b3fe5a13

URL: https://github.com/llvm/llvm-project/commit/37fb860301272d8138d289da5b606115b3fe5a13
DIFF: https://github.com/llvm/llvm-project/commit/37fb860301272d8138d289da5b606115b3fe5a13.diff

LOG: Add support of __builtin_expect_with_probability

Add a new builtin function __builtin_expect_with_probability and a new
intrinsic llvm.expect.with.probability.
The interface is __builtin_expect_with_probability(long expr, long
expected, double probability).
It behaves like __builtin_expect, except that it takes one more argument
giving the probability that the expression equals the expected value. The
probability must be a constant floating-point expression in the range
[0.0, 1.0] inclusive.
It is similar to the __builtin_expect_with_probability built-in function
provided by GCC.

Differential Revision: https://reviews.llvm.org/D79830

Added: 
    clang/test/CodeGen/builtin-expect-with-probability.cpp
    clang/test/Sema/builtin-expect-with-probability-avr.cpp
    clang/test/Sema/builtin-expect-with-probability.cpp
    llvm/test/Transforms/LowerExpectIntrinsic/expect-with-probability.ll

Modified: 
    clang/include/clang/Basic/Builtins.def
    clang/include/clang/Basic/DiagnosticSemaKinds.td
    clang/lib/CodeGen/CGBuiltin.cpp
    clang/lib/Sema/SemaChecking.cpp
    llvm/docs/BranchWeightMetadata.rst
    llvm/docs/LangRef.rst
    llvm/include/llvm/IR/Intrinsics.td
    llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp

Removed: 
    


################################################################################
diff  --git a/clang/include/clang/Basic/Builtins.def b/clang/include/clang/Basic/Builtins.def
index ac9af6028867..c060d49ba338 100644
--- a/clang/include/clang/Basic/Builtins.def
+++ b/clang/include/clang/Basic/Builtins.def
@@ -566,6 +566,7 @@ BUILTIN(__builtin___vprintf_chk, "iicC*a", "FP:1:")
 
 BUILTIN(__builtin_unpredictable, "LiLi"   , "nc")
 BUILTIN(__builtin_expect, "LiLiLi"   , "nc")
+BUILTIN(__builtin_expect_with_probability, "LiLiLid", "nc")
 BUILTIN(__builtin_prefetch, "vvC*.", "nc")
 BUILTIN(__builtin_readcyclecounter, "ULLi", "n")
 BUILTIN(__builtin_trap, "v", "nr")

diff  --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index 7c72eba8c2c1..d6d0ccaa00be 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -10839,4 +10839,12 @@ def err_ext_int_bad_size : Error<"%select{signed|unsigned}0 _ExtInt must "
                                  "have a bit size of at least %select{2|1}0">;
 def err_ext_int_max_size : Error<"%select{signed|unsigned}0 _ExtInt of bit "
                                  "sizes greater than %1 not supported">;
+
+// errors of expect.with.probability
+def err_probability_not_constant_float : Error<
+   "probability argument to __builtin_expect_with_probability must be constant "
+   "floating-point expression">;
+def err_probability_out_of_range : Error<
+   "probability argument to __builtin_expect_with_probability is outside the "
+   "range [0.0, 1.0]">;
 } // end of sema component.

diff  --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 7943c3f34504..c18f01b75353 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -2194,6 +2194,33 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
         Builder.CreateCall(FnExpect, {ArgValue, ExpectedValue}, "expval");
     return RValue::get(Result);
   }
+  case Builtin::BI__builtin_expect_with_probability: {
+    Value *ArgValue = EmitScalarExpr(E->getArg(0));
+    llvm::Type *ArgType = ArgValue->getType();
+
+    Value *ExpectedValue = EmitScalarExpr(E->getArg(1));
+    llvm::APFloat Probability(0.0);
+    const Expr *ProbArg = E->getArg(2);
+    bool EvalSucceed = ProbArg->EvaluateAsFloat(Probability, CGM.getContext());
+    assert(EvalSucceed && "probability should be able to evaluate as float");
+    (void)EvalSucceed;
+    bool LoseInfo = false;
+    Probability.convert(llvm::APFloat::IEEEdouble(),
+                        llvm::RoundingMode::Dynamic, &LoseInfo);
+    llvm::Type *Ty = ConvertType(ProbArg->getType());
+    Constant *Confidence = ConstantFP::get(Ty, Probability);
+    // Don't generate llvm.expect.with.probability on -O0 as the backend
+    // won't use it for anything.
+    // Note, we still IRGen ExpectedValue because it could have side-effects.
+    if (CGM.getCodeGenOpts().OptimizationLevel == 0)
+      return RValue::get(ArgValue);
+
+    Function *FnExpect =
+        CGM.getIntrinsic(Intrinsic::expect_with_probability, ArgType);
+    Value *Result = Builder.CreateCall(
+        FnExpect, {ArgValue, ExpectedValue, Confidence}, "expval");
+    return RValue::get(Result);
+  }
   case Builtin::BI__builtin_assume_aligned: {
     const Expr *Ptr = E->getArg(0);
     Value *PtrValue = EmitScalarExpr(Ptr);

diff  --git a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp
index 883ff17eb279..e6a3b54c3cc1 100644
--- a/clang/lib/Sema/SemaChecking.cpp
+++ b/clang/lib/Sema/SemaChecking.cpp
@@ -1808,6 +1808,36 @@ Sema::CheckBuiltinFunctionCall(FunctionDecl *FDecl, unsigned BuiltinID,
     TheCall->setType(Context.IntTy);
     break;
   }
+  case Builtin::BI__builtin_expect_with_probability: {
+    // We first want to ensure we are called with 3 arguments
+    if (checkArgCount(*this, TheCall, 3))
+      return ExprError();
+    // then check probability is constant float in range [0.0, 1.0]
+    const Expr *ProbArg = TheCall->getArg(2);
+    SmallVector<PartialDiagnosticAt, 8> Notes;
+    Expr::EvalResult Eval;
+    Eval.Diag = &Notes;
+    if ((!ProbArg->EvaluateAsConstantExpr(Eval, Expr::EvaluateForCodeGen,
+                                          Context)) ||
+        !Eval.Val.isFloat()) {
+      Diag(ProbArg->getBeginLoc(), diag::err_probability_not_constant_float)
+          << ProbArg->getSourceRange();
+      for (const PartialDiagnosticAt &PDiag : Notes)
+        Diag(PDiag.first, PDiag.second);
+      return ExprError();
+    }
+    llvm::APFloat Probability = Eval.Val.getFloat();
+    bool LoseInfo = false;
+    Probability.convert(llvm::APFloat::IEEEdouble(),
+                        llvm::RoundingMode::Dynamic, &LoseInfo);
+    if (!(Probability >= llvm::APFloat(0.0) &&
+          Probability <= llvm::APFloat(1.0))) {
+      Diag(ProbArg->getBeginLoc(), diag::err_probability_out_of_range)
+          << ProbArg->getSourceRange();
+      return ExprError();
+    }
+    break;
+  }
   case Builtin::BI__builtin_preserve_access_index:
     if (SemaBuiltinPreserveAI(*this, TheCall))
       return ExprError();

diff  --git a/clang/test/CodeGen/builtin-expect-with-probability.cpp b/clang/test/CodeGen/builtin-expect-with-probability.cpp
new file mode 100644
index 000000000000..2d25219e07f5
--- /dev/null
+++ b/clang/test/CodeGen/builtin-expect-with-probability.cpp
@@ -0,0 +1,46 @@
+// RUN: %clang_cc1 -emit-llvm -o - %s -O1 | FileCheck %s
+extern int global;
+
+struct S {
+  static constexpr int prob = 1;
+};
+
+template<typename T>
+int expect_taken(int x) {
+// CHECK: !{{[0-9]+}} = !{!"branch_weights", i32 2147483647, i32 1}
+
+	if (__builtin_expect_with_probability (x == 100, 1, T::prob)) {
+		return 0;
+	}
+	return x;
+}
+
+void f() {
+  expect_taken<S>(global);
+}
+
+int expect_taken2(int x) {
+  // CHECK: !{{[0-9]+}} = !{!"branch_weights", i32 1932735283, i32 214748366}
+
+  if (__builtin_expect_with_probability(x == 100, 1, 0.9)) {
+    return 0;
+  }
+  return x;
+}
+
+int expect_taken3(int x) {
+  // CHECK: !{{[0-9]+}} = !{!"branch_weights", i32 107374184, i32 107374184, i32 1717986918, i32 107374184, i32 107374184}
+  switch (__builtin_expect_with_probability(x, 1, 0.8)) {
+  case 0:
+    x = x + 0;
+  case 1:
+    x = x + 1;
+  case 2:
+    x = x + 2;
+  case 5:
+    x = x + 5;
+  default:
+    x = x + 6;
+  }
+  return x;
+}

diff  --git a/clang/test/Sema/builtin-expect-with-probability-avr.cpp b/clang/test/Sema/builtin-expect-with-probability-avr.cpp
new file mode 100644
index 000000000000..1dbb1d52396a
--- /dev/null
+++ b/clang/test/Sema/builtin-expect-with-probability-avr.cpp
@@ -0,0 +1,15 @@
+// RUN: %clang_cc1 -triple avr -fsyntax-only -verify %s
+
+void test(int x, double p) { // expected-note {{declared here}}
+  bool dummy = false;
+  dummy = __builtin_expect_with_probability(x > 0, 1, 0.9);
+  dummy = __builtin_expect_with_probability(x > 0, 1, 1.1); // expected-error {{probability argument to __builtin_expect_with_probability is outside the range [0.0, 1.0]}}
+  dummy = __builtin_expect_with_probability(x > 0, 1, -1); // expected-error {{probability argument to __builtin_expect_with_probability is outside the range [0.0, 1.0]}}
+  dummy = __builtin_expect_with_probability(x > 0, 1, p); // expected-error {{probability argument to __builtin_expect_with_probability must be constant floating-point expression}} expected-note {{read of non-constexpr variable 'p' is not allowed in a constant expression}}
+  dummy = __builtin_expect_with_probability(x > 0, 1, "aa"); // expected-error {{cannot initialize a parameter of type 'double' with an lvalue of type 'const char [3]'}}
+  dummy = __builtin_expect_with_probability(x > 0, 1, __builtin_nan("")); // expected-error {{probability argument to __builtin_expect_with_probability is outside the range [0.0, 1.0]}}
+  dummy = __builtin_expect_with_probability(x > 0, 1, __builtin_inf()); // expected-error {{probability argument to __builtin_expect_with_probability is outside the range [0.0, 1.0]}}
+  dummy = __builtin_expect_with_probability(x > 0, 1, -0.0);
+  dummy = __builtin_expect_with_probability(x > 0, 1, 1.0 + __DBL_EPSILON__); // expected-error {{probability argument to __builtin_expect_with_probability is outside the range [0.0, 1.0]}}
+  dummy = __builtin_expect_with_probability(x > 0, 1, -__DBL_DENORM_MIN__); // expected-error {{probability argument to __builtin_expect_with_probability is outside the range [0.0, 1.0]}}
+}

diff  --git a/clang/test/Sema/builtin-expect-with-probability.cpp b/clang/test/Sema/builtin-expect-with-probability.cpp
new file mode 100644
index 000000000000..4a5b62a55bbc
--- /dev/null
+++ b/clang/test/Sema/builtin-expect-with-probability.cpp
@@ -0,0 +1,32 @@
+// RUN: %clang_cc1 -fsyntax-only -verify %s
+extern int global;
+
+struct S {
+  static constexpr float prob = 0.7;
+};
+
+template<typename T>
+void expect_taken(int x) {
+  if (__builtin_expect_with_probability(x > 0, 1, T::prob)) {
+    global++;
+  }
+}
+
+void test(int x, double p) { // expected-note {{declared here}}
+  bool dummy;
+  dummy = __builtin_expect_with_probability(x > 0, 1, 0.9);
+  dummy = __builtin_expect_with_probability(x > 0, 1, 1.1); // expected-error {{probability argument to __builtin_expect_with_probability is outside the range [0.0, 1.0]}}
+  dummy = __builtin_expect_with_probability(x > 0, 1, -1); // expected-error {{probability argument to __builtin_expect_with_probability is outside the range [0.0, 1.0]}}
+  dummy = __builtin_expect_with_probability(x > 0, 1, p); // expected-error {{probability argument to __builtin_expect_with_probability must be constant floating-point expression}} expected-note {{read of non-constexpr variable 'p' is not allowed in a constant expression}}
+  dummy = __builtin_expect_with_probability(x > 0, 1, "aa"); // expected-error {{cannot initialize a parameter of type 'double' with an lvalue of type 'const char [3]'}}
+  dummy = __builtin_expect_with_probability(x > 0, 1, __builtin_nan("")); // expected-error {{probability argument to __builtin_expect_with_probability is outside the range [0.0, 1.0]}}
+  dummy = __builtin_expect_with_probability(x > 0, 1, __builtin_inf()); // expected-error {{probability argument to __builtin_expect_with_probability is outside the range [0.0, 1.0]}}
+  dummy = __builtin_expect_with_probability(x > 0, 1, -0.0);
+  dummy = __builtin_expect_with_probability(x > 0, 1, 1.0 + __DBL_EPSILON__); // expected-error {{probability argument to __builtin_expect_with_probability is outside the range [0.0, 1.0]}}
+  dummy = __builtin_expect_with_probability(x > 0, 1, -__DBL_DENORM_MIN__); // expected-error {{probability argument to __builtin_expect_with_probability is outside the range [0.0, 1.0]}}
+  constexpr double pd = 0.7;
+  dummy = __builtin_expect_with_probability(x > 0, 1, pd);
+  constexpr int pi = 1;
+  dummy = __builtin_expect_with_probability(x > 0, 1, pi);
+  expect_taken<S>(x);
+}

diff  --git a/llvm/docs/BranchWeightMetadata.rst b/llvm/docs/BranchWeightMetadata.rst
index 2a1083215987..902ba87316ed 100644
--- a/llvm/docs/BranchWeightMetadata.rst
+++ b/llvm/docs/BranchWeightMetadata.rst
@@ -15,7 +15,7 @@ The first operator is always a ``MDString`` node with the string
 "branch_weights".  Number of operators depends on the terminator type.
 
 Branch weights might be fetch from the profiling file, or generated based on
-`__builtin_expect`_ instruction.
+`__builtin_expect`_ and `__builtin_expect_with_probability`_ instructions.
 
 All weights are represented as an unsigned 32-bit values, where higher value
 indicates greater chance to be taken.
@@ -144,6 +144,47 @@ case is assumed to be likely taken.
   case 5:  // This case is likely to be taken.
   }
 
+.. _\__builtin_expect_with_probability:
+
+Built-in ``expect.with.probability`` Instruction
+================================================
+
+``__builtin_expect_with_probability(long exp, long c, double probability)`` has
+the same semantics as ``__builtin_expect``, but the caller provides the
+probability that ``exp == c``. The last argument ``probability`` must be a
+constant floating-point expression in the range [0.0, 1.0] inclusive.
+Its usage is also similar to that of ``__builtin_expect``, for example:
+
+``if`` statement
+^^^^^^^^^^^^^^^^
+
+If the expected value ``c`` is equal to 1 (true) and ``probability`` is set to
+0.8, the condition is expected to be true with probability 80% and false with
+probability 20%.
+
+.. code-block:: c++
+
+  if (__builtin_expect_with_probability(x > 0, 1, 0.8)) {
+    // This block is likely to be taken with probability 80%.
+  }
+
+``switch`` statement
+^^^^^^^^^^^^^^^^^^^^
+
+This is basically the same as the ``switch`` case for ``__builtin_expect``.
+The probability that ``exp`` equals the expected value is given by the third
+argument ``probability``; the remaining probability (``1.0 - probability``)
+is divided evenly among the other cases. For example:
+
+.. code-block:: c++
+
+  switch (__builtin_expect_with_probability(x, 5, 0.7)) {
+  default: break;  // Take this case with probability 10%
+  case 0:  break;  // Take this case with probability 10%
+  case 3:  break;  // Take this case with probability 10%
+  case 5:  break;  // This case is likely to be taken with probability 70%
+  }
+
 CFG Modifications
 =================
 

diff  --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst
index e8e481a0a411..4adf958d86a6 100644
--- a/llvm/docs/LangRef.rst
+++ b/llvm/docs/LangRef.rst
@@ -4619,7 +4619,7 @@ to the ``add`` instruction using the ``!dbg`` identifier:
     %indvar.next = add i64 %indvar, 1, !dbg !21
 
 Metadata can also be attached to a function or a global variable. Here metadata
-``!22`` is attached to the ``f1`` and ``f2 functions, and the globals ``g1``
+``!22`` is attached to the ``f1`` and ``f2`` functions, and the globals ``g1``
 and ``g2`` using the ``!dbg`` identifier:
 
 .. code-block:: llvm
@@ -19063,6 +19063,40 @@ Semantics:
 
 This intrinsic is lowered to the ``val``.
 
+'``llvm.expect.with.probability``' Intrinsic
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Syntax:
+"""""""
+
+This intrinsic is similar to ``llvm.expect``. This is an overloaded intrinsic.
+You can use ``llvm.expect.with.probability`` on any integer bit width.
+
+::
+
+      declare i1 @llvm.expect.with.probability.i1(i1 <val>, i1 <expected_val>, double <prob>)
+      declare i32 @llvm.expect.with.probability.i32(i32 <val>, i32 <expected_val>, double <prob>)
+      declare i64 @llvm.expect.with.probability.i64(i64 <val>, i64 <expected_val>, double <prob>)
+
+Overview:
+"""""""""
+
+The ``llvm.expect.with.probability`` intrinsic provides information about the
+expected value of ``val``, together with the probability (or confidence)
+``prob`` of that value, which can be used by optimizers.
+
+Arguments:
+""""""""""
+
+The ``llvm.expect.with.probability`` intrinsic takes three arguments. The first
+argument is a value. The second argument is an expected value. The third
+argument is a probability.
+
+Semantics:
+""""""""""
+
+This intrinsic is lowered to the ``val``.
+
 .. _int_assume:
 
 '``llvm.assume``' Intrinsic

diff  --git a/llvm/include/llvm/IR/Intrinsics.td b/llvm/include/llvm/IR/Intrinsics.td
index a4fb6c6d9cd5..1af5588f5f56 100644
--- a/llvm/include/llvm/IR/Intrinsics.td
+++ b/llvm/include/llvm/IR/Intrinsics.td
@@ -837,6 +837,10 @@ let IntrProperties = [IntrInaccessibleMemOnly, IntrWillReturn] in {
 def int_expect : Intrinsic<[llvm_anyint_ty],
   [LLVMMatchType<0>, LLVMMatchType<0>], [IntrNoMem, IntrWillReturn]>;
 
+def int_expect_with_probability : Intrinsic<[llvm_anyint_ty],
+  [LLVMMatchType<0>, LLVMMatchType<0>, llvm_double_ty],
+  [IntrNoMem, IntrWillReturn]>;
+
 //===-------------------- Bit Manipulation Intrinsics ---------------------===//
 //
 

diff  --git a/llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp b/llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp
index 53671c7bc3d1..4b98be28b862 100644
--- a/llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp
@@ -55,13 +55,34 @@ static cl::opt<uint32_t> UnlikelyBranchWeight(
     "unlikely-branch-weight", cl::Hidden, cl::init(1),
     cl::desc("Weight of the branch unlikely to be taken (default = 1)"));
 
+std::tuple<uint32_t, uint32_t> getBranchWeight(Intrinsic::ID IntrinsicID,
+                                               CallInst *CI, int BranchCount) {
+  if (IntrinsicID == Intrinsic::expect) {
+    // __builtin_expect
+    return {LikelyBranchWeight, UnlikelyBranchWeight};
+  } else {
+    // __builtin_expect_with_probability
+    assert(CI->getNumOperands() >= 3 &&
+           "expect with probability must have 3 arguments");
+    ConstantFP *Confidence = dyn_cast<ConstantFP>(CI->getArgOperand(2));
+    double TrueProb = Confidence->getValueAPF().convertToDouble();
+    assert((TrueProb >= 0.0 && TrueProb <= 1.0) &&
+           "probability value must be in the range [0.0, 1.0]");
+    double FalseProb = (1.0 - TrueProb) / (BranchCount - 1);
+    uint32_t LikelyBW = ceil((TrueProb * (double)(INT32_MAX - 1)) + 1.0);
+    uint32_t UnlikelyBW = ceil((FalseProb * (double)(INT32_MAX - 1)) + 1.0);
+    return {LikelyBW, UnlikelyBW};
+  }
+}
+
 static bool handleSwitchExpect(SwitchInst &SI) {
   CallInst *CI = dyn_cast<CallInst>(SI.getCondition());
   if (!CI)
     return false;
 
   Function *Fn = CI->getCalledFunction();
-  if (!Fn || Fn->getIntrinsicID() != Intrinsic::expect)
+  if (!Fn || (Fn->getIntrinsicID() != Intrinsic::expect &&
+              Fn->getIntrinsicID() != Intrinsic::expect_with_probability))
     return false;
 
   Value *ArgValue = CI->getArgOperand(0);
@@ -71,15 +92,19 @@ static bool handleSwitchExpect(SwitchInst &SI) {
 
   SwitchInst::CaseHandle Case = *SI.findCaseValue(ExpectedValue);
   unsigned n = SI.getNumCases(); // +1 for default case.
-  SmallVector<uint32_t, 16> Weights(n + 1, UnlikelyBranchWeight);
+  uint32_t LikelyBranchWeightVal, UnlikelyBranchWeightVal;
+  std::tie(LikelyBranchWeightVal, UnlikelyBranchWeightVal) =
+      getBranchWeight(Fn->getIntrinsicID(), CI, n + 1);
+
+  SmallVector<uint32_t, 16> Weights(n + 1, UnlikelyBranchWeightVal);
 
   uint64_t Index = (Case == *SI.case_default()) ? 0 : Case.getCaseIndex() + 1;
-  Weights[Index] = LikelyBranchWeight;
+  Weights[Index] = LikelyBranchWeightVal;
 
-  SI.setMetadata(
-      LLVMContext::MD_misexpect,
-      MDBuilder(CI->getContext())
-          .createMisExpect(Index, LikelyBranchWeight, UnlikelyBranchWeight));
+  SI.setMetadata(LLVMContext::MD_misexpect,
+                 MDBuilder(CI->getContext())
+                     .createMisExpect(Index, LikelyBranchWeightVal,
+                                      UnlikelyBranchWeightVal));
 
   SI.setCondition(ArgValue);
   misexpect::checkFrontendInstrumentation(SI);
@@ -223,15 +248,18 @@ static void handlePhiDef(CallInst *Expect) {
         return true;
       return false;
     };
+    uint32_t LikelyBranchWeightVal, UnlikelyBranchWeightVal;
+    std::tie(LikelyBranchWeightVal, UnlikelyBranchWeightVal) = getBranchWeight(
+        Expect->getCalledFunction()->getIntrinsicID(), Expect, 2);
 
     if (IsOpndComingFromSuccessor(BI->getSuccessor(1)))
-      BI->setMetadata(
-          LLVMContext::MD_prof,
-          MDB.createBranchWeights(LikelyBranchWeight, UnlikelyBranchWeight));
+      BI->setMetadata(LLVMContext::MD_prof,
+                      MDB.createBranchWeights(LikelyBranchWeightVal,
+                                              UnlikelyBranchWeightVal));
     else if (IsOpndComingFromSuccessor(BI->getSuccessor(0)))
-      BI->setMetadata(
-          LLVMContext::MD_prof,
-          MDB.createBranchWeights(UnlikelyBranchWeight, LikelyBranchWeight));
+      BI->setMetadata(LLVMContext::MD_prof,
+                      MDB.createBranchWeights(UnlikelyBranchWeightVal,
+                                              LikelyBranchWeightVal));
   }
 }
 
@@ -277,7 +305,8 @@ template <class BrSelInst> static bool handleBrSelExpect(BrSelInst &BSI) {
   }
 
   Function *Fn = CI->getCalledFunction();
-  if (!Fn || Fn->getIntrinsicID() != Intrinsic::expect)
+  if (!Fn || (Fn->getIntrinsicID() != Intrinsic::expect &&
+              Fn->getIntrinsicID() != Intrinsic::expect_with_probability))
     return false;
 
   Value *ArgValue = CI->getArgOperand(0);
@@ -289,13 +318,21 @@ template <class BrSelInst> static bool handleBrSelExpect(BrSelInst &BSI) {
   MDNode *Node;
   MDNode *ExpNode;
 
+  uint32_t LikelyBranchWeightVal, UnlikelyBranchWeightVal;
+  std::tie(LikelyBranchWeightVal, UnlikelyBranchWeightVal) =
+      getBranchWeight(Fn->getIntrinsicID(), CI, 2);
+
   if ((ExpectedValue->getZExtValue() == ValueComparedTo) ==
       (Predicate == CmpInst::ICMP_EQ)) {
-    Node = MDB.createBranchWeights(LikelyBranchWeight, UnlikelyBranchWeight);
-    ExpNode = MDB.createMisExpect(0, LikelyBranchWeight, UnlikelyBranchWeight);
+    Node =
+        MDB.createBranchWeights(LikelyBranchWeightVal, UnlikelyBranchWeightVal);
+    ExpNode =
+        MDB.createMisExpect(0, LikelyBranchWeightVal, UnlikelyBranchWeightVal);
   } else {
-    Node = MDB.createBranchWeights(UnlikelyBranchWeight, LikelyBranchWeight);
-    ExpNode = MDB.createMisExpect(1, LikelyBranchWeight, UnlikelyBranchWeight);
+    Node =
+        MDB.createBranchWeights(UnlikelyBranchWeightVal, LikelyBranchWeightVal);
+    ExpNode =
+        MDB.createMisExpect(1, LikelyBranchWeightVal, UnlikelyBranchWeightVal);
   }
 
   BSI.setMetadata(LLVMContext::MD_misexpect, ExpNode);
@@ -347,7 +384,8 @@ static bool lowerExpectIntrinsic(Function &F) {
       }
 
       Function *Fn = CI->getCalledFunction();
-      if (Fn && Fn->getIntrinsicID() == Intrinsic::expect) {
+      if (Fn && (Fn->getIntrinsicID() == Intrinsic::expect ||
+                 Fn->getIntrinsicID() == Intrinsic::expect_with_probability)) {
         // Before erasing the llvm.expect, walk backward to find
         // phi that define llvm.expect's first arg, and
         // infer branch probability:

diff  --git a/llvm/test/Transforms/LowerExpectIntrinsic/expect-with-probability.ll b/llvm/test/Transforms/LowerExpectIntrinsic/expect-with-probability.ll
new file mode 100644
index 000000000000..8972ee01c12d
--- /dev/null
+++ b/llvm/test/Transforms/LowerExpectIntrinsic/expect-with-probability.ll
@@ -0,0 +1,295 @@
+; RUN: opt -lower-expect -strip-dead-prototypes -S -o - < %s | FileCheck %s
+; RUN: opt -S -passes='function(lower-expect),strip-dead-prototypes' < %s | FileCheck %s
+
+; CHECK-LABEL: @test1(
+define i32 @test1(i32 %x) nounwind uwtable ssp {
+entry:
+  %retval = alloca i32, align 4
+  %x.addr = alloca i32, align 4
+  store i32 %x, i32* %x.addr, align 4
+  %tmp = load i32, i32* %x.addr, align 4
+  %cmp = icmp sgt i32 %tmp, 1
+  %conv = zext i1 %cmp to i32
+  %conv1 = sext i32 %conv to i64
+  %expval = call i64 @llvm.expect.with.probability.i64(i64 %conv1, i64 1, double 8.000000e-01)
+  %tobool = icmp ne i64 %expval, 0
+; CHECK: !prof !0, !misexpect !1
+; CHECK-NOT: @llvm.expect.with.probability
+  br i1 %tobool, label %if.then, label %if.end
+
+if.then:                                          ; preds = %entry
+  %call = call i32 (...) @f()
+  store i32 %call, i32* %retval
+  br label %return
+
+if.end:                                           ; preds = %entry
+  store i32 1, i32* %retval
+  br label %return
+
+return:                                           ; preds = %if.end, %if.then
+  %0 = load i32, i32* %retval
+  ret i32 %0
+}
+
+declare i64 @llvm.expect.with.probability.i64(i64, i64, double) nounwind readnone
+
+declare i32 @f(...)
+
+; CHECK-LABEL: @test2(
+define i32 @test2(i32 %x) nounwind uwtable ssp {
+entry:
+  %retval = alloca i32, align 4
+  %x.addr = alloca i32, align 4
+  store i32 %x, i32* %x.addr, align 4
+  %tmp = load i32, i32* %x.addr, align 4
+  %conv = sext i32 %tmp to i64
+  %expval = call i64 @llvm.expect.with.probability.i64(i64 %conv, i64 1, double 8.000000e-01)
+  %tobool = icmp ne i64 %expval, 0
+; CHECK: !prof !0, !misexpect !1
+; CHECK-NOT: @llvm.expect.with.probability
+  br i1 %tobool, label %if.then, label %if.end
+
+if.then:                                          ; preds = %entry
+  %call = call i32 (...) @f()
+  store i32 %call, i32* %retval
+  br label %return
+
+if.end:                                           ; preds = %entry
+  store i32 1, i32* %retval
+  br label %return
+
+return:                                           ; preds = %if.end, %if.then
+  %0 = load i32, i32* %retval
+  ret i32 %0
+}
+
+; CHECK-LABEL: @test3(
+define i32 @test3(i32 %x) nounwind uwtable ssp {
+entry:
+  %retval = alloca i32, align 4
+  %x.addr = alloca i32, align 4
+  store i32 %x, i32* %x.addr, align 4
+  %tmp = load i32, i32* %x.addr, align 4
+  %tobool = icmp ne i32 %tmp, 0
+  %lnot = xor i1 %tobool, true
+  %lnot.ext = zext i1 %lnot to i32
+  %conv = sext i32 %lnot.ext to i64
+  %expval = call i64 @llvm.expect.with.probability.i64(i64 %conv, i64 1, double 8.000000e-01)
+  %tobool1 = icmp ne i64 %expval, 0
+; CHECK: !prof !0, !misexpect !1
+; CHECK-NOT: @llvm.expect.with.probability
+  br i1 %tobool1, label %if.then, label %if.end
+
+if.then:                                          ; preds = %entry
+  %call = call i32 (...) @f()
+  store i32 %call, i32* %retval
+  br label %return
+
+if.end:                                           ; preds = %entry
+  store i32 1, i32* %retval
+  br label %return
+
+return:                                           ; preds = %if.end, %if.then
+  %0 = load i32, i32* %retval
+  ret i32 %0
+}
+
+; CHECK-LABEL: @test4(
+define i32 @test4(i32 %x) nounwind uwtable ssp {
+entry:
+  %retval = alloca i32, align 4
+  %x.addr = alloca i32, align 4
+  store i32 %x, i32* %x.addr, align 4
+  %tmp = load i32, i32* %x.addr, align 4
+  %tobool = icmp ne i32 %tmp, 0
+  %lnot = xor i1 %tobool, true
+  %lnot1 = xor i1 %lnot, true
+  %lnot.ext = zext i1 %lnot1 to i32
+  %conv = sext i32 %lnot.ext to i64
+  %expval = call i64 @llvm.expect.with.probability.i64(i64 %conv, i64 1, double 8.000000e-01)
+  %tobool2 = icmp ne i64 %expval, 0
+; CHECK: !prof !0, !misexpect !1
+; CHECK-NOT: @llvm.expect.with.probability
+  br i1 %tobool2, label %if.then, label %if.end
+
+if.then:                                          ; preds = %entry
+  %call = call i32 (...) @f()
+  store i32 %call, i32* %retval
+  br label %return
+
+if.end:                                           ; preds = %entry
+  store i32 1, i32* %retval
+  br label %return
+
+return:                                           ; preds = %if.end, %if.then
+  %0 = load i32, i32* %retval
+  ret i32 %0
+}
+
+; CHECK-LABEL: @test5(
+define i32 @test5(i32 %x) nounwind uwtable ssp {
+entry:
+  %retval = alloca i32, align 4
+  %x.addr = alloca i32, align 4
+  store i32 %x, i32* %x.addr, align 4
+  %tmp = load i32, i32* %x.addr, align 4
+  %cmp = icmp slt i32 %tmp, 0
+  %conv = zext i1 %cmp to i32
+  %conv1 = sext i32 %conv to i64
+  %expval = call i64 @llvm.expect.with.probability.i64(i64 %conv1, i64 0, double 8.000000e-01)
+  %tobool = icmp ne i64 %expval, 0
+; CHECK: !prof !2, !misexpect !3
+; CHECK-NOT: @llvm.expect.with.probability
+  br i1 %tobool, label %if.then, label %if.end
+
+if.then:                                          ; preds = %entry
+  %call = call i32 (...) @f()
+  store i32 %call, i32* %retval
+  br label %return
+
+if.end:                                           ; preds = %entry
+  store i32 1, i32* %retval
+  br label %return
+
+return:                                           ; preds = %if.end, %if.then
+  %0 = load i32, i32* %retval
+  ret i32 %0
+}
+
+; CHECK-LABEL: @test6(
+define i32 @test6(i32 %x) nounwind uwtable ssp {
+entry:
+  %retval = alloca i32, align 4
+  %x.addr = alloca i32, align 4
+  store i32 %x, i32* %x.addr, align 4
+  %tmp = load i32, i32* %x.addr, align 4
+  %conv = sext i32 %tmp to i64
+  %expval = call i64 @llvm.expect.with.probability.i64(i64 %conv, i64 2, double 8.000000e-01)
+; CHECK: !prof !4, !misexpect !5
+; CHECK-NOT: @llvm.expect.with.probability
+  switch i64 %expval, label %sw.epilog [
+    i64 1, label %sw.bb
+    i64 2, label %sw.bb
+  ]
+
+sw.bb:                                            ; preds = %entry, %entry
+  store i32 0, i32* %retval
+  br label %return
+
+sw.epilog:                                        ; preds = %entry
+  store i32 1, i32* %retval
+  br label %return
+
+return:                                           ; preds = %sw.epilog, %sw.bb
+  %0 = load i32, i32* %retval
+  ret i32 %0
+}
+
+; CHECK-LABEL: @test7(
+define i32 @test7(i32 %x) nounwind uwtable ssp {
+entry:
+  %retval = alloca i32, align 4
+  %x.addr = alloca i32, align 4
+  store i32 %x, i32* %x.addr, align 4
+  %tmp = load i32, i32* %x.addr, align 4
+  %conv = sext i32 %tmp to i64
+  %expval = call i64 @llvm.expect.with.probability.i64(i64 %conv, i64 1, double 8.000000e-01)
+; CHECK: !prof !6, !misexpect !7
+; CHECK-NOT: @llvm.expect.with.probability
+  switch i64 %expval, label %sw.epilog [
+    i64 2, label %sw.bb
+    i64 3, label %sw.bb
+  ]
+
+sw.bb:                                            ; preds = %entry, %entry
+  %tmp1 = load i32, i32* %x.addr, align 4
+  store i32 %tmp1, i32* %retval
+  br label %return
+
+sw.epilog:                                        ; preds = %entry
+  store i32 0, i32* %retval
+  br label %return
+
+return:                                           ; preds = %sw.epilog, %sw.bb
+  %0 = load i32, i32* %retval
+  ret i32 %0
+}
+
+; CHECK-LABEL: @test8(
+define i32 @test8(i32 %x) nounwind uwtable ssp {
+entry:
+  %retval = alloca i32, align 4
+  %x.addr = alloca i32, align 4
+  store i32 %x, i32* %x.addr, align 4
+  %tmp = load i32, i32* %x.addr, align 4
+  %cmp = icmp sgt i32 %tmp, 1
+  %conv = zext i1 %cmp to i32
+  %expval = call i32 @llvm.expect.with.probability.i32(i32 %conv, i32 1, double 8.000000e-01)
+  %tobool = icmp ne i32 %expval, 0
+; CHECK: !prof !0, !misexpect !1
+; CHECK-NOT: @llvm.expect.with.probability
+  br i1 %tobool, label %if.then, label %if.end
+
+if.then:                                          ; preds = %entry
+  %call = call i32 (...) @f()
+  store i32 %call, i32* %retval
+  br label %return
+
+if.end:                                           ; preds = %entry
+  store i32 1, i32* %retval
+  br label %return
+
+return:                                           ; preds = %if.end, %if.then
+  %0 = load i32, i32* %retval
+  ret i32 %0
+}
+
+declare i32 @llvm.expect.with.probability.i32(i32, i32, double) nounwind readnone
+
+; CHECK-LABEL: @test9(
+define i32 @test9(i32 %x) nounwind uwtable ssp {
+entry:
+  %retval = alloca i32, align 4
+  %x.addr = alloca i32, align 4
+  store i32 %x, i32* %x.addr, align 4
+  %tmp = load i32, i32* %x.addr, align 4
+  %cmp = icmp sgt i32 %tmp, 1
+  %expval = call i1 @llvm.expect.with.probability.i1(i1 %cmp, i1 1, double 8.000000e-01)
+; CHECK: !prof !0, !misexpect !1
+; CHECK-NOT: @llvm.expect.with.probability
+  br i1 %expval, label %if.then, label %if.end
+
+if.then:                                          ; preds = %entry
+  %call = call i32 (...) @f()
+  store i32 %call, i32* %retval
+  br label %return
+
+if.end:                                           ; preds = %entry
+  store i32 1, i32* %retval
+  br label %return
+
+return:                                           ; preds = %if.end, %if.then
+  %0 = load i32, i32* %retval
+  ret i32 %0
+}
+
+; CHECK-LABEL: @test10(
+define i32 @test10(i64 %t6) {
+  %t7 = call i64 @llvm.expect.with.probability.i64(i64 %t6, i64 0, double 8.000000e-01)
+  %t8 = icmp ne i64 %t7, 0
+  %t9 = select i1 %t8, i32 1, i32 2
+; CHECK: select{{.*}}, !prof !2, !misexpect !3
+  ret i32 %t9
+}
+
+
+declare i1 @llvm.expect.with.probability.i1(i1, i1, double) nounwind readnone
+
+; CHECK: !0 = !{!"branch_weights", i32 1717986918, i32 429496731}
+; CHECK: !1 = !{!"misexpect", i64 0, i64 1717986918, i64 429496731}
+; CHECK: !2 = !{!"branch_weights", i32 429496731, i32 1717986918}
+; CHECK: !3 = !{!"misexpect", i64 1, i64 1717986918, i64 429496731}
+; CHECK: !4 = !{!"branch_weights", i32 214748366, i32 214748366, i32 1717986918}
+; CHECK: !5 = !{!"misexpect", i64 2, i64 1717986918, i64 214748366}
+; CHECK: !6 = !{!"branch_weights", i32 1717986918, i32 214748366, i32 214748366}
+; CHECK: !7 = !{!"misexpect", i64 0, i64 1717986918, i64 214748366}


        


More information about the llvm-commits mailing list