[llvm] [InstCombine] Set !prof metadata on Selects identified by add.ll test (PR #158743)

Alan Zhao via llvm-commits llvm-commits at lists.llvm.org
Mon Sep 29 09:55:07 PDT 2025


https://github.com/alanzhao1 updated https://github.com/llvm/llvm-project/pull/158743

>From 2f5c4f8f945309219c06ed312b0cbfdfa10af532 Mon Sep 17 00:00:00 2001
From: Alan Zhao <ayzhao at google.com>
Date: Mon, 15 Sep 2025 14:41:42 -0700
Subject: [PATCH 1/7] [InstCombine] Set !prof metadata on Selects identified by
 add.ll test

These select instructions are created from non-branching instructions
(e.g. add, shl), so there is no profile data from which to derive their
branch weights; mark them as explicitly unknown.

Tracking issue: #147390
---
 .../InstCombine/InstCombineAddSub.cpp         | 10 ++++-
 .../InstCombine/InstCombineShifts.cpp         |  5 ++-
 .../InstCombine/InstructionCombining.cpp      |  5 ++-
 .../InstCombine/preserve-profile.ll           | 39 +++++++++++++++++++
 llvm/utils/profcheck-xfail.txt                |  2 -
 5 files changed, 55 insertions(+), 6 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index 00951fde0cf8a..a480d96cd4cb9 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -24,6 +24,7 @@
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/Operator.h"
 #include "llvm/IR/PatternMatch.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/IR/Type.h"
 #include "llvm/IR/Value.h"
 #include "llvm/Support/AlignOf.h"
@@ -878,13 +879,18 @@ Instruction *InstCombinerImpl::foldAddWithConstant(BinaryOperator &Add) {
     return BinaryOperator::CreateAdd(Builder.CreateNot(Y), X);
 
   // zext(bool) + C -> bool ? C + 1 : C
+  SelectInst *SI = nullptr;
   if (match(Op0, m_ZExt(m_Value(X))) &&
       X->getType()->getScalarSizeInBits() == 1)
-    return SelectInst::Create(X, InstCombiner::AddOne(Op1C), Op1);
+    SI = SelectInst::Create(X, InstCombiner::AddOne(Op1C), Op1);
   // sext(bool) + C -> bool ? C - 1 : C
   if (match(Op0, m_SExt(m_Value(X))) &&
       X->getType()->getScalarSizeInBits() == 1)
-    return SelectInst::Create(X, InstCombiner::SubOne(Op1C), Op1);
+    SI = SelectInst::Create(X, InstCombiner::SubOne(Op1C), Op1);
+  if (SI) {
+    setExplicitlyUnknownBranchWeights(*SI, DEBUG_TYPE);
+    return SI;
+  }
 
   // ~X + C --> (C-1) - X
   if (match(Op0, m_Not(m_Value(X)))) {
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
index 550f095b26ba4..247f1483e14f5 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -14,6 +14,7 @@
 #include "llvm/Analysis/InstructionSimplify.h"
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/PatternMatch.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/Transforms/InstCombine/InstCombiner.h"
 using namespace llvm;
 using namespace PatternMatch;
@@ -1253,7 +1254,9 @@ Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) {
     // shl (zext i1 X), C1 --> select (X, 1 << C1, 0)
     if (match(Op0, m_ZExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)) {
       auto *NewC = Builder.CreateShl(ConstantInt::get(Ty, 1), C1);
-      return SelectInst::Create(X, NewC, ConstantInt::getNullValue(Ty));
+      auto *SI = SelectInst::Create(X, NewC, ConstantInt::getNullValue(Ty));
+      setExplicitlyUnknownBranchWeights(*SI, DEBUG_TYPE);
+      return SI;
     }
   }
 
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index f0ddd5ca94c5a..957f8bd588857 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -81,6 +81,7 @@
 #include "llvm/IR/Operator.h"
 #include "llvm/IR/PassManager.h"
 #include "llvm/IR/PatternMatch.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/IR/Type.h"
 #include "llvm/IR/Use.h"
 #include "llvm/IR/User.h"
@@ -1735,7 +1736,9 @@ Instruction *InstCombinerImpl::foldBinopOfSextBoolToSelect(BinaryOperator &BO) {
   Constant *Zero = ConstantInt::getNullValue(BO.getType());
   Value *TVal = Builder.CreateBinOp(BO.getOpcode(), Ones, C);
   Value *FVal = Builder.CreateBinOp(BO.getOpcode(), Zero, C);
-  return SelectInst::Create(X, TVal, FVal);
+  SelectInst *SI = SelectInst::Create(X, TVal, FVal);
+  setExplicitlyUnknownBranchWeights(*SI, DEBUG_TYPE);
+  return SI;
 }
 
 static Value *simplifyOperationIntoSelectOperand(Instruction &I, SelectInst *SI,
diff --git a/llvm/test/Transforms/InstCombine/preserve-profile.ll b/llvm/test/Transforms/InstCombine/preserve-profile.ll
index dd83805ed3397..0b750fd87d641 100644
--- a/llvm/test/Transforms/InstCombine/preserve-profile.ll
+++ b/llvm/test/Transforms/InstCombine/preserve-profile.ll
@@ -46,9 +46,48 @@ define i32 @NegBin(i1 %C) !prof !0 {
   ret i32 %V
 }
 
+define i32 @select_C_minus_1_or_C_from_bool(i1 %x) {
+; CHECK-LABEL: define i32 @select_C_minus_1_or_C_from_bool(
+; CHECK-SAME: i1 [[X:%.*]]) {
+; CHECK-NEXT:    [[ADD:%.*]] = select i1 [[X]], i32 41, i32 42, !prof [[PROF2:![0-9]+]]
+; CHECK-NEXT:    ret i32 [[ADD]]
+;
+  %ext = sext i1 %x to i32
+  %add = add i32 %ext, 42
+  ret i32 %add
+}
+
+define i5 @and_add(i1 %x, i1 %y) {
+; CHECK-LABEL: define i5 @and_add(
+; CHECK-SAME: i1 [[X:%.*]], i1 [[Y:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = xor i1 [[X]], true
+; CHECK-NEXT:    [[TMP2:%.*]] = and i1 [[Y]], [[TMP1]]
+; CHECK-NEXT:    [[R:%.*]] = select i1 [[TMP2]], i5 -2, i5 0, !prof [[PROF2]]
+; CHECK-NEXT:    ret i5 [[R]]
+;
+  %xz = zext i1 %x to i5
+  %ys = sext i1 %y to i5
+  %sub = add i5 %xz, %ys
+  %r = and i5 %sub, 30
+  ret i5 %r
+}
+
+define i32 @add_zext_zext_i1(i1 %a) {
+; CHECK-LABEL: define i32 @add_zext_zext_i1(
+; CHECK-SAME: i1 [[A:%.*]]) {
+; CHECK-NEXT:    [[ADD:%.*]] = select i1 [[A]], i32 2, i32 0, !prof [[PROF2]]
+; CHECK-NEXT:    ret i32 [[ADD]]
+;
+  %zext = zext i1 %a to i32
+  %add = add i32 %zext, %zext
+  ret i32 %add
+}
+
+
 !0 = !{!"function_entry_count", i64 1000}
 !1 = !{!"branch_weights", i32 2, i32 3}
 ;.
 ; CHECK: [[PROF0]] = !{!"function_entry_count", i64 1000}
 ; CHECK: [[PROF1]] = !{!"branch_weights", i32 2, i32 3}
+; CHECK: [[PROF2]] = !{!"unknown", !"instcombine"}
 ;.
diff --git a/llvm/utils/profcheck-xfail.txt b/llvm/utils/profcheck-xfail.txt
index 482848842aa05..582bb42315ac7 100644
--- a/llvm/utils/profcheck-xfail.txt
+++ b/llvm/utils/profcheck-xfail.txt
@@ -836,8 +836,6 @@ Transforms/InstCombine/2011-02-14-InfLoop.ll
 Transforms/InstCombine/AArch64/sve-intrinsic-sel.ll
 Transforms/InstCombine/AArch64/sve-intrinsic-simplify-binop.ll
 Transforms/InstCombine/AArch64/sve-intrinsic-simplify-shift.ll
-Transforms/InstCombine/add2.ll
-Transforms/InstCombine/add.ll
 Transforms/InstCombine/add-mask.ll
 Transforms/InstCombine/add-shl-mul-umax.ll
 Transforms/InstCombine/add-shl-sdiv-to-srem.ll

>From b42e0c5b625794a328ed1f3d31602002995369a6 Mon Sep 17 00:00:00 2001
From: Alan Zhao <ayzhao at google.com>
Date: Mon, 15 Sep 2025 18:42:29 -0700
Subject: [PATCH 2/7] skip the sext fold if the zext fold already created a select (!SI)

---
 llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index a480d96cd4cb9..d97ff07a9312a 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -884,7 +884,7 @@ Instruction *InstCombinerImpl::foldAddWithConstant(BinaryOperator &Add) {
       X->getType()->getScalarSizeInBits() == 1)
     SI = SelectInst::Create(X, InstCombiner::AddOne(Op1C), Op1);
   // sext(bool) + C -> bool ? C - 1 : C
-  if (match(Op0, m_SExt(m_Value(X))) &&
+  if (!SI && match(Op0, m_SExt(m_Value(X))) &&
       X->getType()->getScalarSizeInBits() == 1)
     SI = SelectInst::Create(X, InstCombiner::SubOne(Op1C), Op1);
   if (SI) {

>From cc4e8b4b18ce587db56c827495d6f1929d5f3c6d Mon Sep 17 00:00:00 2001
From: Alan Zhao <ayzhao at google.com>
Date: Wed, 17 Sep 2025 15:47:41 -0700
Subject: [PATCH 3/7] set unknown branch weights only if the enclosing function has
 a nonzero profile entry count

---
 llvm/include/llvm/IR/ProfDataUtils.h          |  8 +++++++
 llvm/lib/IR/ProfDataUtils.cpp                 | 10 ++++++++
 .../InstCombine/InstCombineAddSub.cpp         |  2 +-
 .../InstCombine/InstCombineShifts.cpp         |  2 +-
 .../InstCombine/InstructionCombining.cpp      |  2 +-
 .../InstCombine/preserve-profile.ll           | 23 ++++++++++++++-----
 6 files changed, 38 insertions(+), 9 deletions(-)

diff --git a/llvm/include/llvm/IR/ProfDataUtils.h b/llvm/include/llvm/IR/ProfDataUtils.h
index de9675f48c79b..c39ebe5b682bc 100644
--- a/llvm/include/llvm/IR/ProfDataUtils.h
+++ b/llvm/include/llvm/IR/ProfDataUtils.h
@@ -185,6 +185,14 @@ inline uint32_t scaleBranchCount(uint64_t Count, uint64_t Scale) {
 LLVM_ABI void setExplicitlyUnknownBranchWeights(Instruction &I,
                                                 StringRef PassName);
 
+/// Like setExplicitlyUnknownBranchWeights(...), but only sets unknown branch
+/// weights in the new instruction if the parent function of the original
+/// instruction has function counts. This is to not confuse users by injecting
+/// profile data into non-profiled functions.
+LLVM_ABI void setExplicitlyUnknownBranchWeightsIfProfiled(Instruction &New,
+                                                          Instruction &Original,
+                                                          StringRef PassName);
+
 /// Analogous to setExplicitlyUnknownBranchWeights, but for functions and their
 /// entry counts.
 LLVM_ABI void setExplicitlyUnknownFunctionEntryCount(Function &F,
diff --git a/llvm/lib/IR/ProfDataUtils.cpp b/llvm/lib/IR/ProfDataUtils.cpp
index 5827292cee39b..a3e8f03b252e9 100644
--- a/llvm/lib/IR/ProfDataUtils.cpp
+++ b/llvm/lib/IR/ProfDataUtils.cpp
@@ -252,6 +252,16 @@ void setExplicitlyUnknownBranchWeights(Instruction &I, StringRef PassName) {
                    MDB.createString(PassName)}));
 }
 
+void setExplicitlyUnknownBranchWeightsIfProfiled(Instruction &New,
+                                                 Instruction &Original,
+                                                 StringRef PassName) {
+  Function *F = Original.getFunction();
+  assert(F && "instruction does not belong to a function!");
+  std::optional<Function::ProfileCount> EC = F->getEntryCount();
+  if (EC && EC->getCount() > 0)
+    setExplicitlyUnknownBranchWeights(New, PassName);
+}
+
 void setExplicitlyUnknownFunctionEntryCount(Function &F, StringRef PassName) {
   MDBuilder MDB(F.getContext());
   F.setMetadata(
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index d97ff07a9312a..855fd00e9bf83 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -888,7 +888,7 @@ Instruction *InstCombinerImpl::foldAddWithConstant(BinaryOperator &Add) {
       X->getType()->getScalarSizeInBits() == 1)
     SI = SelectInst::Create(X, InstCombiner::SubOne(Op1C), Op1);
   if (SI) {
-    setExplicitlyUnknownBranchWeights(*SI, DEBUG_TYPE);
+    setExplicitlyUnknownBranchWeightsIfProfiled(*SI, Add, DEBUG_TYPE);
     return SI;
   }
 
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
index 247f1483e14f5..0a111d5befd79 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -1255,7 +1255,7 @@ Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) {
     if (match(Op0, m_ZExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)) {
       auto *NewC = Builder.CreateShl(ConstantInt::get(Ty, 1), C1);
       auto *SI = SelectInst::Create(X, NewC, ConstantInt::getNullValue(Ty));
-      setExplicitlyUnknownBranchWeights(*SI, DEBUG_TYPE);
+      setExplicitlyUnknownBranchWeightsIfProfiled(*SI, I, DEBUG_TYPE);
       return SI;
     }
   }
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index 957f8bd588857..b39aa44379514 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -1737,7 +1737,7 @@ Instruction *InstCombinerImpl::foldBinopOfSextBoolToSelect(BinaryOperator &BO) {
   Value *TVal = Builder.CreateBinOp(BO.getOpcode(), Ones, C);
   Value *FVal = Builder.CreateBinOp(BO.getOpcode(), Zero, C);
   SelectInst *SI = SelectInst::Create(X, TVal, FVal);
-  setExplicitlyUnknownBranchWeights(*SI, DEBUG_TYPE);
+  setExplicitlyUnknownBranchWeightsIfProfiled(*SI, BO, DEBUG_TYPE);
   return SI;
 }
 
diff --git a/llvm/test/Transforms/InstCombine/preserve-profile.ll b/llvm/test/Transforms/InstCombine/preserve-profile.ll
index 0b750fd87d641..8cb3e685ae302 100644
--- a/llvm/test/Transforms/InstCombine/preserve-profile.ll
+++ b/llvm/test/Transforms/InstCombine/preserve-profile.ll
@@ -46,9 +46,9 @@ define i32 @NegBin(i1 %C) !prof !0 {
   ret i32 %V
 }
 
-define i32 @select_C_minus_1_or_C_from_bool(i1 %x) {
+define i32 @select_C_minus_1_or_C_from_bool(i1 %x) !prof !0 {
 ; CHECK-LABEL: define i32 @select_C_minus_1_or_C_from_bool(
-; CHECK-SAME: i1 [[X:%.*]]) {
+; CHECK-SAME: i1 [[X:%.*]]) !prof [[PROF0]] {
 ; CHECK-NEXT:    [[ADD:%.*]] = select i1 [[X]], i32 41, i32 42, !prof [[PROF2:![0-9]+]]
 ; CHECK-NEXT:    ret i32 [[ADD]]
 ;
@@ -57,9 +57,9 @@ define i32 @select_C_minus_1_or_C_from_bool(i1 %x) {
   ret i32 %add
 }
 
-define i5 @and_add(i1 %x, i1 %y) {
+define i5 @and_add(i1 %x, i1 %y) !prof !0 {
 ; CHECK-LABEL: define i5 @and_add(
-; CHECK-SAME: i1 [[X:%.*]], i1 [[Y:%.*]]) {
+; CHECK-SAME: i1 [[X:%.*]], i1 [[Y:%.*]]) !prof [[PROF0]] {
 ; CHECK-NEXT:    [[TMP1:%.*]] = xor i1 [[X]], true
 ; CHECK-NEXT:    [[TMP2:%.*]] = and i1 [[Y]], [[TMP1]]
 ; CHECK-NEXT:    [[R:%.*]] = select i1 [[TMP2]], i5 -2, i5 0, !prof [[PROF2]]
@@ -72,9 +72,9 @@ define i5 @and_add(i1 %x, i1 %y) {
   ret i5 %r
 }
 
-define i32 @add_zext_zext_i1(i1 %a) {
+define i32 @add_zext_zext_i1(i1 %a) !prof !0 {
 ; CHECK-LABEL: define i32 @add_zext_zext_i1(
-; CHECK-SAME: i1 [[A:%.*]]) {
+; CHECK-SAME: i1 [[A:%.*]]) !prof [[PROF0]] {
 ; CHECK-NEXT:    [[ADD:%.*]] = select i1 [[A]], i32 2, i32 0, !prof [[PROF2]]
 ; CHECK-NEXT:    ret i32 [[ADD]]
 ;
@@ -83,6 +83,17 @@ define i32 @add_zext_zext_i1(i1 %a) {
   ret i32 %add
 }
 
+define i32 @no_count_no_branch_weights(i1 %a) {
+; CHECK-LABEL: define i32 @no_count_no_branch_weights(
+; CHECK-SAME: i1 [[A:%.*]]) {
+; CHECK-NEXT:    [[ADD:%.*]] = select i1 [[A]], i32 2, i32 0
+; CHECK-NEXT:    ret i32 [[ADD]]
+;
+  %zext = zext i1 %a to i32
+  %add = add i32 %zext, %zext
+  ret i32 %add
+}
+
 
 !0 = !{!"function_entry_count", i64 1000}
 !1 = !{!"branch_weights", i32 2, i32 3}

>From 7c396ade073945da2c157d4232d9a3bc19919185 Mon Sep 17 00:00:00 2001
From: Alan Zhao <ayzhao at google.com>
Date: Thu, 25 Sep 2025 13:48:40 -0700
Subject: [PATCH 4/7] create InstCombine-specific helper and pass function
 directly

---
 llvm/include/llvm/IR/ProfDataUtils.h                |  4 ++--
 llvm/lib/IR/ProfDataUtils.cpp                       | 11 ++++-------
 .../Transforms/InstCombine/InstCombineAddSub.cpp    |  7 ++++---
 .../Transforms/InstCombine/InstCombineInternal.h    | 13 +++++++++++++
 .../Transforms/InstCombine/InstCombineShifts.cpp    |  5 ++---
 .../Transforms/InstCombine/InstructionCombining.cpp |  5 ++---
 6 files changed, 27 insertions(+), 18 deletions(-)

diff --git a/llvm/include/llvm/IR/ProfDataUtils.h b/llvm/include/llvm/IR/ProfDataUtils.h
index c39ebe5b682bc..d401fa7740c8b 100644
--- a/llvm/include/llvm/IR/ProfDataUtils.h
+++ b/llvm/include/llvm/IR/ProfDataUtils.h
@@ -189,8 +189,8 @@ LLVM_ABI void setExplicitlyUnknownBranchWeights(Instruction &I,
 /// weights in the new instruction if the parent function of the original
 /// instruction has function counts. This is to not confuse users by injecting
 /// profile data into non-profiled functions.
-LLVM_ABI void setExplicitlyUnknownBranchWeightsIfProfiled(Instruction &New,
-                                                          Instruction &Original,
+LLVM_ABI void setExplicitlyUnknownBranchWeightsIfProfiled(Instruction &I,
+                                                          Function &F,
                                                           StringRef PassName);
 
 /// Analogous to setExplicitlyUnknownBranchWeights, but for functions and their
diff --git a/llvm/lib/IR/ProfDataUtils.cpp b/llvm/lib/IR/ProfDataUtils.cpp
index a3e8f03b252e9..99029c1719507 100644
--- a/llvm/lib/IR/ProfDataUtils.cpp
+++ b/llvm/lib/IR/ProfDataUtils.cpp
@@ -252,14 +252,11 @@ void setExplicitlyUnknownBranchWeights(Instruction &I, StringRef PassName) {
                    MDB.createString(PassName)}));
 }
 
-void setExplicitlyUnknownBranchWeightsIfProfiled(Instruction &New,
-                                                 Instruction &Original,
+void setExplicitlyUnknownBranchWeightsIfProfiled(Instruction &I, Function &F,
                                                  StringRef PassName) {
-  Function *F = Original.getFunction();
-  assert(F && "instruction does not belong to a function!");
-  std::optional<Function::ProfileCount> EC = F->getEntryCount();
-  if (EC && EC->getCount() > 0)
-    setExplicitlyUnknownBranchWeights(New, PassName);
+  if (std::optional<Function::ProfileCount> EC = F.getEntryCount();
+      EC && EC->getCount() > 0)
+    setExplicitlyUnknownBranchWeights(I, PassName);
 }
 
 void setExplicitlyUnknownFunctionEntryCount(Function &F, StringRef PassName) {
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index 855fd00e9bf83..66c52662de7e0 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -882,13 +882,14 @@ Instruction *InstCombinerImpl::foldAddWithConstant(BinaryOperator &Add) {
   SelectInst *SI = nullptr;
   if (match(Op0, m_ZExt(m_Value(X))) &&
       X->getType()->getScalarSizeInBits() == 1)
-    SI = SelectInst::Create(X, InstCombiner::AddOne(Op1C), Op1);
+    SI = createSelectInstMaybeWithUnknownBranchWeights(
+        X, InstCombiner::AddOne(Op1C), Op1, Add.getFunction());
   // sext(bool) + C -> bool ? C - 1 : C
   if (!SI && match(Op0, m_SExt(m_Value(X))) &&
       X->getType()->getScalarSizeInBits() == 1)
-    SI = SelectInst::Create(X, InstCombiner::SubOne(Op1C), Op1);
+    SI = createSelectInstMaybeWithUnknownBranchWeights(
+        X, InstCombiner::SubOne(Op1C), Op1, Add.getFunction());
   if (SI) {
-    setExplicitlyUnknownBranchWeightsIfProfiled(*SI, Add, DEBUG_TYPE);
     return SI;
   }
 
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index 7a979c16da501..8137eb2fb1f34 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -23,6 +23,7 @@
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/InstVisitor.h"
 #include "llvm/IR/PatternMatch.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/IR/Value.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/KnownBits.h"
@@ -469,6 +470,18 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
   Value *simplifyNonNullOperand(Value *V, bool HasDereferenceable,
                                 unsigned Depth = 0);
 
+  static SelectInst *createSelectInstMaybeWithUnknownBranchWeights(
+      Value *C, Value *S1, Value *S2, Function *F, const Twine &NameStr = "",
+      InsertPosition InsertBefore = nullptr, Instruction *MDFrom = nullptr) {
+    SelectInst *SI =
+        SelectInst::Create(C, S1, S2, NameStr, InsertBefore, MDFrom);
+    if (!SI) {
+      assert(F && "provided parent function is nullptr!");
+      setExplicitlyUnknownBranchWeightsIfProfiled(*SI, *F, DEBUG_TYPE);
+    }
+    return SI;
+  }
+
 public:
   /// Create and insert the idiom we use to indicate a block is unreachable
   /// without having to rewrite the CFG from within InstCombine.
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
index 0a111d5befd79..168a5e8ecb8f9 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -1254,9 +1254,8 @@ Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) {
     // shl (zext i1 X), C1 --> select (X, 1 << C1, 0)
     if (match(Op0, m_ZExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)) {
       auto *NewC = Builder.CreateShl(ConstantInt::get(Ty, 1), C1);
-      auto *SI = SelectInst::Create(X, NewC, ConstantInt::getNullValue(Ty));
-      setExplicitlyUnknownBranchWeightsIfProfiled(*SI, I, DEBUG_TYPE);
-      return SI;
+      return createSelectInstMaybeWithUnknownBranchWeights(
+          X, NewC, ConstantInt::getNullValue(Ty), I.getFunction());
     }
   }
 
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index b39aa44379514..553c9e5abd207 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -1736,9 +1736,8 @@ Instruction *InstCombinerImpl::foldBinopOfSextBoolToSelect(BinaryOperator &BO) {
   Constant *Zero = ConstantInt::getNullValue(BO.getType());
   Value *TVal = Builder.CreateBinOp(BO.getOpcode(), Ones, C);
   Value *FVal = Builder.CreateBinOp(BO.getOpcode(), Zero, C);
-  SelectInst *SI = SelectInst::Create(X, TVal, FVal);
-  setExplicitlyUnknownBranchWeightsIfProfiled(*SI, BO, DEBUG_TYPE);
-  return SI;
+  return createSelectInstMaybeWithUnknownBranchWeights(X, TVal, FVal,
+                                                       BO.getFunction());
 }
 
 static Value *simplifyOperationIntoSelectOperand(Instruction &I, SelectInst *SI,

>From 9764b21c8867669694d8b1a091dbe7de488315d0 Mon Sep 17 00:00:00 2001
From: Alan Zhao <ayzhao at google.com>
Date: Thu, 25 Sep 2025 15:07:26 -0700
Subject: [PATCH 5/7] fix accidentally negated conditional when setting unknown branch weights (also skip when MDFrom is given)

---
 llvm/lib/Transforms/InstCombine/InstCombineInternal.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index 8137eb2fb1f34..496141e90f2c7 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -475,7 +475,7 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
       InsertPosition InsertBefore = nullptr, Instruction *MDFrom = nullptr) {
     SelectInst *SI =
         SelectInst::Create(C, S1, S2, NameStr, InsertBefore, MDFrom);
-    if (!SI) {
+    if (SI && !MDFrom) {
       assert(F && "provided parent function is nullptr!");
       setExplicitlyUnknownBranchWeightsIfProfiled(*SI, *F, DEBUG_TYPE);
     }

>From 2555fce763c4b8c79382808bb9f56af685440d95 Mon Sep 17 00:00:00 2001
From: Alan Zhao <ayzhao at google.com>
Date: Fri, 26 Sep 2025 17:03:54 -0700
Subject: [PATCH 6/7] Add function reference as a member of the InstCombiner
 class and simplify createSelectInst(...)

---
 .../Transforms/InstCombine/InstCombiner.h     | 20 ++++++++++---------
 .../InstCombine/InstCombineAddSub.cpp         |  6 ++----
 .../InstCombine/InstCombineInternal.h         | 16 +++++++--------
 .../InstCombine/InstCombineShifts.cpp         |  3 +--
 .../InstCombine/InstructionCombining.cpp      |  8 +++-----
 5 files changed, 25 insertions(+), 28 deletions(-)

diff --git a/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h b/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
index fa313f5290773..d6c2d7fc48bda 100644
--- a/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
+++ b/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
@@ -64,6 +64,8 @@ class LLVM_LIBRARY_VISIBILITY InstCombiner {
   /// A worklist of the instructions that need to be simplified.
   InstructionWorklist &Worklist;
 
+  Function &F;
+
   // Mode in which we are running the combiner.
   const bool MinimizeSize;
 
@@ -98,17 +100,17 @@ class LLVM_LIBRARY_VISIBILITY InstCombiner {
   bool ComputedBackEdges = false;
 
 public:
-  InstCombiner(InstructionWorklist &Worklist, BuilderTy &Builder,
-               bool MinimizeSize, AAResults *AA, AssumptionCache &AC,
-               TargetLibraryInfo &TLI, TargetTransformInfo &TTI,
-               DominatorTree &DT, OptimizationRemarkEmitter &ORE,
-               BlockFrequencyInfo *BFI, BranchProbabilityInfo *BPI,
-               ProfileSummaryInfo *PSI, const DataLayout &DL,
+  InstCombiner(InstructionWorklist &Worklist, BuilderTy &Builder, Function &F,
+               AAResults *AA, AssumptionCache &AC, TargetLibraryInfo &TLI,
+               TargetTransformInfo &TTI, DominatorTree &DT,
+               OptimizationRemarkEmitter &ORE, BlockFrequencyInfo *BFI,
+               BranchProbabilityInfo *BPI, ProfileSummaryInfo *PSI,
+               const DataLayout &DL,
                ReversePostOrderTraversal<BasicBlock *> &RPOT)
       : TTIForTargetIntrinsicsOnly(TTI), Builder(Builder), Worklist(Worklist),
-        MinimizeSize(MinimizeSize), AA(AA), AC(AC), TLI(TLI), DT(DT), DL(DL),
-        SQ(DL, &TLI, &DT, &AC, nullptr, /*UseInstrInfo*/ true,
-           /*CanUseUndef*/ true, &DC),
+        F(F), MinimizeSize(F.hasMinSize()), AA(AA), AC(AC), TLI(TLI), DT(DT),
+        DL(DL), SQ(DL, &TLI, &DT, &AC, nullptr, /*UseInstrInfo*/ true,
+                   /*CanUseUndef*/ true, &DC),
         ORE(ORE), BFI(BFI), BPI(BPI), PSI(PSI), RPOT(RPOT) {}
 
   virtual ~InstCombiner() = default;
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index 66c52662de7e0..9ca7988627578 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -882,13 +882,11 @@ Instruction *InstCombinerImpl::foldAddWithConstant(BinaryOperator &Add) {
   SelectInst *SI = nullptr;
   if (match(Op0, m_ZExt(m_Value(X))) &&
       X->getType()->getScalarSizeInBits() == 1)
-    SI = createSelectInstMaybeWithUnknownBranchWeights(
-        X, InstCombiner::AddOne(Op1C), Op1, Add.getFunction());
+    SI = createSelectInst(X, InstCombiner::AddOne(Op1C), Op1);
   // sext(bool) + C -> bool ? C - 1 : C
   if (!SI && match(Op0, m_SExt(m_Value(X))) &&
       X->getType()->getScalarSizeInBits() == 1)
-    SI = createSelectInstMaybeWithUnknownBranchWeights(
-        X, InstCombiner::SubOne(Op1C), Op1, Add.getFunction());
+    SI = createSelectInst(X, InstCombiner::SubOne(Op1C), Op1);
   if (SI) {
     return SI;
   }
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index 496141e90f2c7..fb92a14a5f1d9 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -63,14 +63,14 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
       public InstVisitor<InstCombinerImpl, Instruction *> {
 public:
   InstCombinerImpl(InstructionWorklist &Worklist, BuilderTy &Builder,
-                   bool MinimizeSize, AAResults *AA, AssumptionCache &AC,
+                   Function &F, AAResults *AA, AssumptionCache &AC,
                    TargetLibraryInfo &TLI, TargetTransformInfo &TTI,
                    DominatorTree &DT, OptimizationRemarkEmitter &ORE,
                    BlockFrequencyInfo *BFI, BranchProbabilityInfo *BPI,
                    ProfileSummaryInfo *PSI, const DataLayout &DL,
                    ReversePostOrderTraversal<BasicBlock *> &RPOT)
-      : InstCombiner(Worklist, Builder, MinimizeSize, AA, AC, TLI, TTI, DT, ORE,
-                     BFI, BPI, PSI, DL, RPOT) {}
+      : InstCombiner(Worklist, Builder, F, AA, AC, TLI, TTI, DT, ORE, BFI, BPI,
+                     PSI, DL, RPOT) {}
 
   virtual ~InstCombinerImpl() = default;
 
@@ -470,14 +470,14 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
   Value *simplifyNonNullOperand(Value *V, bool HasDereferenceable,
                                 unsigned Depth = 0);
 
-  static SelectInst *createSelectInstMaybeWithUnknownBranchWeights(
-      Value *C, Value *S1, Value *S2, Function *F, const Twine &NameStr = "",
-      InsertPosition InsertBefore = nullptr, Instruction *MDFrom = nullptr) {
+  SelectInst *createSelectInst(Value *C, Value *S1, Value *S2,
+                               const Twine &NameStr = "",
+                               InsertPosition InsertBefore = nullptr,
+                               Instruction *MDFrom = nullptr) {
     SelectInst *SI =
         SelectInst::Create(C, S1, S2, NameStr, InsertBefore, MDFrom);
     if (SI && !MDFrom) {
-      assert(F && "provided parent function is nullptr!");
-      setExplicitlyUnknownBranchWeightsIfProfiled(*SI, *F, DEBUG_TYPE);
+      setExplicitlyUnknownBranchWeightsIfProfiled(*SI, F, DEBUG_TYPE);
     }
     return SI;
   }
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
index 168a5e8ecb8f9..18b4d96571158 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -1254,8 +1254,7 @@ Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) {
     // shl (zext i1 X), C1 --> select (X, 1 << C1, 0)
     if (match(Op0, m_ZExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)) {
       auto *NewC = Builder.CreateShl(ConstantInt::get(Ty, 1), C1);
-      return createSelectInstMaybeWithUnknownBranchWeights(
-          X, NewC, ConstantInt::getNullValue(Ty), I.getFunction());
+      return createSelectInst(X, NewC, ConstantInt::getNullValue(Ty));
     }
   }
 
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index 553c9e5abd207..8fbaf68dfcc43 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -81,7 +81,6 @@
 #include "llvm/IR/Operator.h"
 #include "llvm/IR/PassManager.h"
 #include "llvm/IR/PatternMatch.h"
-#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/IR/Type.h"
 #include "llvm/IR/Use.h"
 #include "llvm/IR/User.h"
@@ -1736,8 +1735,7 @@ Instruction *InstCombinerImpl::foldBinopOfSextBoolToSelect(BinaryOperator &BO) {
   Constant *Zero = ConstantInt::getNullValue(BO.getType());
   Value *TVal = Builder.CreateBinOp(BO.getOpcode(), Ones, C);
   Value *FVal = Builder.CreateBinOp(BO.getOpcode(), Zero, C);
-  return createSelectInstMaybeWithUnknownBranchWeights(X, TVal, FVal,
-                                                       BO.getFunction());
+  return createSelectInst(X, TVal, FVal);
 }
 
 static Value *simplifyOperationIntoSelectOperand(Instruction &I, SelectInst *SI,
@@ -5936,8 +5934,8 @@ static bool combineInstructionsOverFunction(
     LLVM_DEBUG(dbgs() << "\n\nINSTCOMBINE ITERATION #" << Iteration << " on "
                       << F.getName() << "\n");
 
-    InstCombinerImpl IC(Worklist, Builder, F.hasMinSize(), AA, AC, TLI, TTI, DT,
-                        ORE, BFI, BPI, PSI, DL, RPOT);
+    InstCombinerImpl IC(Worklist, Builder, F, AA, AC, TLI, TTI, DT, ORE, BFI,
+                        BPI, PSI, DL, RPOT);
     IC.MaxArraySizeForCombine = MaxArraySize;
     bool MadeChangeInThisIteration = IC.prepareWorklist(F);
     MadeChangeInThisIteration |= IC.run();

>From 2db6abd5fdbd4e8adc182f5bf99317a0e768fada Mon Sep 17 00:00:00 2001
From: Alan Zhao <ayzhao at google.com>
Date: Mon, 29 Sep 2025 09:54:48 -0700
Subject: [PATCH 7/7] remove unnecessary includes and SelectInst variables

---
 llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp | 11 +++--------
 llvm/lib/Transforms/InstCombine/InstCombineInternal.h |  2 +-
 llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp |  1 -
 3 files changed, 4 insertions(+), 10 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index 9ca7988627578..5faae5dae75bb 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -24,7 +24,6 @@
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/Operator.h"
 #include "llvm/IR/PatternMatch.h"
-#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/IR/Type.h"
 #include "llvm/IR/Value.h"
 #include "llvm/Support/AlignOf.h"
@@ -879,17 +878,13 @@ Instruction *InstCombinerImpl::foldAddWithConstant(BinaryOperator &Add) {
     return BinaryOperator::CreateAdd(Builder.CreateNot(Y), X);
 
   // zext(bool) + C -> bool ? C + 1 : C
-  SelectInst *SI = nullptr;
   if (match(Op0, m_ZExt(m_Value(X))) &&
       X->getType()->getScalarSizeInBits() == 1)
-    SI = createSelectInst(X, InstCombiner::AddOne(Op1C), Op1);
+    return createSelectInst(X, InstCombiner::AddOne(Op1C), Op1);
   // sext(bool) + C -> bool ? C - 1 : C
-  if (!SI && match(Op0, m_SExt(m_Value(X))) &&
+  if (match(Op0, m_SExt(m_Value(X))) &&
       X->getType()->getScalarSizeInBits() == 1)
-    SI = createSelectInst(X, InstCombiner::SubOne(Op1C), Op1);
-  if (SI) {
-    return SI;
-  }
+    return createSelectInst(X, InstCombiner::SubOne(Op1C), Op1);
 
   // ~X + C --> (C-1) - X
   if (match(Op0, m_Not(m_Value(X)))) {
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index fb92a14a5f1d9..6f9c206db1336 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -476,7 +476,7 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
                                Instruction *MDFrom = nullptr) {
     SelectInst *SI =
         SelectInst::Create(C, S1, S2, NameStr, InsertBefore, MDFrom);
-    if (SI && !MDFrom) {
+    if (!MDFrom) {
       setExplicitlyUnknownBranchWeightsIfProfiled(*SI, F, DEBUG_TYPE);
     }
     return SI;
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
index 18b4d96571158..d457e0c7dd1c4 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -14,7 +14,6 @@
 #include "llvm/Analysis/InstructionSimplify.h"
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/PatternMatch.h"
-#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/Transforms/InstCombine/InstCombiner.h"
 using namespace llvm;
 using namespace PatternMatch;



More information about the llvm-commits mailing list