[llvm] [InstCombine] Preserve profile branch weights when folding logical booleans (PR #161293)

Alan Zhao via llvm-commits llvm-commits at lists.llvm.org
Mon Oct 6 11:17:53 PDT 2025


https://github.com/alanzhao1 updated https://github.com/llvm/llvm-project/pull/161293

>From b87743ccc3e29991d3a43278140e023181eb724b Mon Sep 17 00:00:00 2001
From: Alan Zhao <ayzhao at google.com>
Date: Mon, 29 Sep 2025 15:49:10 -0700
Subject: [PATCH] [InstCombine] Preserve profile branch weights when folding
 logical booleans

Logical booleans in LLVM are represented by select statements - e.g. the
statement

```
A && B
```

is represented as

```
select i1 %A, i1 %B, i1 false
```

When LLVM folds two of the same logical booleans into a logical boolean
and a bitwise boolean (e.g. `A && B && C` -> `A && (B & C)`), the
resulting logical boolean is a select statement that retains the
condition of the first (inner) logical boolean of the original
statement. This means that the new select statement has the same branch
weights as that original inner select statement, so they can be copied
over.

Tracking issue: #147390
---
 llvm/include/llvm/IR/IRBuilder.h              | 11 +++++---
 .../InstCombine/InstCombineSelect.cpp         | 10 +++++--
 .../select-safe-bool-transforms.ll            | 27 ++++++++++++-------
 3 files changed, 33 insertions(+), 15 deletions(-)

diff --git a/llvm/include/llvm/IR/IRBuilder.h b/llvm/include/llvm/IR/IRBuilder.h
index 783f8f6d2478c..041a4ce112275 100644
--- a/llvm/include/llvm/IR/IRBuilder.h
+++ b/llvm/include/llvm/IR/IRBuilder.h
@@ -1722,16 +1722,19 @@ class IRBuilderBase {
     return Insert(BinOp, Name);
   }
 
-  Value *CreateLogicalAnd(Value *Cond1, Value *Cond2, const Twine &Name = "") {
+  Value *CreateLogicalAnd(Value *Cond1, Value *Cond2, const Twine &Name = "",
+                          Instruction *MDFrom = nullptr) {
     assert(Cond2->getType()->isIntOrIntVectorTy(1));
     return CreateSelect(Cond1, Cond2,
-                        ConstantInt::getNullValue(Cond2->getType()), Name);
+                        ConstantInt::getNullValue(Cond2->getType()), Name,
+                        MDFrom);
   }
 
-  Value *CreateLogicalOr(Value *Cond1, Value *Cond2, const Twine &Name = "") {
+  Value *CreateLogicalOr(Value *Cond1, Value *Cond2, const Twine &Name = "",
+                         Instruction *MDFrom = nullptr) {
     assert(Cond2->getType()->isIntOrIntVectorTy(1));
     return CreateSelect(Cond1, ConstantInt::getAllOnesValue(Cond2->getType()),
-                        Cond2, Name);
+                        Cond2, Name, MDFrom);
   }
 
   Value *CreateLogicalOp(Instruction::BinaryOps Opc, Value *Cond1, Value *Cond2,
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 8f60e506e8a33..8c8fc69e27ee3 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -3356,7 +3356,10 @@ Instruction *InstCombinerImpl::foldSelectOfBools(SelectInst &SI) {
         impliesPoisonOrCond(FalseVal, B, /*Expected=*/false)) {
       // (A || B) || C --> A || (B | C)
       return replaceInstUsesWith(
-          SI, Builder.CreateLogicalOr(A, Builder.CreateOr(B, FalseVal)));
+          SI, Builder.CreateLogicalOr(A, Builder.CreateOr(B, FalseVal), "",
+                                      ProfcheckDisableMetadataFixes
+                                          ? nullptr
+                                          : cast<SelectInst>(CondVal)));
     }
 
     // (A && B) || (C && B) --> (A || C) && B
@@ -3398,7 +3401,10 @@ Instruction *InstCombinerImpl::foldSelectOfBools(SelectInst &SI) {
         impliesPoisonOrCond(TrueVal, B, /*Expected=*/true)) {
       // (A && B) && C --> A && (B & C)
       return replaceInstUsesWith(
-          SI, Builder.CreateLogicalAnd(A, Builder.CreateAnd(B, TrueVal)));
+          SI, Builder.CreateLogicalAnd(A, Builder.CreateAnd(B, TrueVal), "",
+                                       ProfcheckDisableMetadataFixes
+                                           ? nullptr
+                                           : cast<SelectInst>(CondVal)));
     }
 
     // (A || B) && (C || B) --> (A && C) || B
diff --git a/llvm/test/Transforms/InstCombine/select-safe-bool-transforms.ll b/llvm/test/Transforms/InstCombine/select-safe-bool-transforms.ll
index 9de9150f62ed4..8b0a5ca62918c 100644
--- a/llvm/test/Transforms/InstCombine/select-safe-bool-transforms.ll
+++ b/llvm/test/Transforms/InstCombine/select-safe-bool-transforms.ll
@@ -1,4 +1,4 @@
-; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --check-globals
 ; RUN: opt < %s -passes=instcombine -S | FileCheck %s
 
 ; TODO: All of these should be optimized to less than or equal to a single
@@ -7,13 +7,13 @@
 ; --- (A op B) op' A   /   (B op A) op' A ---
 
 ; (A land B) land A
-define i1 @land_land_left1(i1 %A, i1 %B) {
+define i1 @land_land_left1(i1 %A, i1 %B) !prof !0 {
 ; CHECK-LABEL: @land_land_left1(
-; CHECK-NEXT:    [[C:%.*]] = select i1 [[A:%.*]], i1 [[B:%.*]], i1 false
+; CHECK-NEXT:    [[C:%.*]] = select i1 [[A:%.*]], i1 [[B:%.*]], i1 false, !prof [[PROF1:![0-9]+]]
 ; CHECK-NEXT:    ret i1 [[C]]
 ;
-  %c = select i1 %A, i1 %B, i1 false
-  %res = select i1 %c, i1 %A, i1 false
+  %c = select i1 %A, i1 %B, i1 false, !prof !1
+  %res = select i1 %c, i1 %A, i1 false, !prof !2
   ret i1 %res
 }
 define i1 @land_land_left2(i1 %A, i1 %B) {
@@ -157,13 +157,13 @@ define i1 @lor_band_left2(i1 %A, i1 %B) {
 }
 
 ; (A lor B) lor A
-define i1 @lor_lor_left1(i1 %A, i1 %B) {
+define i1 @lor_lor_left1(i1 %A, i1 %B) !prof !0 {
 ; CHECK-LABEL: @lor_lor_left1(
-; CHECK-NEXT:    [[C:%.*]] = select i1 [[A:%.*]], i1 true, i1 [[B:%.*]]
+; CHECK-NEXT:    [[C:%.*]] = select i1 [[A:%.*]], i1 true, i1 [[B:%.*]], !prof [[PROF1]]
 ; CHECK-NEXT:    ret i1 [[C]]
 ;
-  %c = select i1 %A, i1 true, i1 %B
-  %res = select i1 %c, i1 true, i1 %A
+  %c = select i1 %A, i1 true, i1 %B, !prof !1
+  %res = select i1 %c, i1 true, i1 %A, !prof !2
   ret i1 %res
 }
 define i1 @lor_lor_left2(i1 %A, i1 %B) {
@@ -506,3 +506,12 @@ define <2 x i1> @PR50500_falseval(<2 x i1> %a, <2 x i1> %b) {
   %r = select <2 x i1> %a, <2 x i1> %b, <2 x i1> %s
   ret <2 x i1> %r
 }
+
+!0 = !{!"function_entry_count", i64 1000}
+!1 = !{!"branch_weights", i32 2, i32 3}
+!2 = !{!"branch_weights", i32 5, i32 7}
+
+;.
+; CHECK: [[META0:![0-9]+]] = !{!"function_entry_count", i64 1000}
+; CHECK: [[PROF1]] = !{!"branch_weights", i32 2, i32 3}
+;.



More information about the llvm-commits mailing list