[llvm] [JumpThreading] Convert `s/zext i1` to `select i1` for further unfolding (PR #89345)

via llvm-commits llvm-commits at lists.llvm.org
Thu Apr 18 22:13:52 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: Franklin Zhang (FLZ101)

<details>
<summary>Changes</summary>



---
Full diff: https://github.com/llvm/llvm-project/pull/89345.diff


3 Files Affected:

- (modified) llvm/include/llvm/Transforms/Scalar/JumpThreading.h (+2) 
- (modified) llvm/lib/Transforms/Scalar/JumpThreading.cpp (+81-1) 
- (added) llvm/test/Transforms/JumpThreading/szext.ll (+94) 


``````````diff
diff --git a/llvm/include/llvm/Transforms/Scalar/JumpThreading.h b/llvm/include/llvm/Transforms/Scalar/JumpThreading.h
index 3364d7eaee4247..c7da9053b7abeb 100644
--- a/llvm/include/llvm/Transforms/Scalar/JumpThreading.h
+++ b/llvm/include/llvm/Transforms/Scalar/JumpThreading.h
@@ -162,6 +162,8 @@ class JumpThreadingPass : public PassInfoMixin<JumpThreadingPass> {
   bool tryToUnfoldSelect(SwitchInst *SI, BasicBlock *BB);
   bool tryToUnfoldSelectInCurrBB(BasicBlock *BB);
 
+  bool tryToConvertSZExtToSelect(BasicBlock *BB);
+
   bool processGuards(BasicBlock *BB);
   bool threadGuard(BasicBlock *BB, IntrinsicInst *Guard, BranchInst *BI);
 
diff --git a/llvm/lib/Transforms/Scalar/JumpThreading.cpp b/llvm/lib/Transforms/Scalar/JumpThreading.cpp
index ffcb511e6a8312..c92519610aec18 100644
--- a/llvm/lib/Transforms/Scalar/JumpThreading.cpp
+++ b/llvm/lib/Transforms/Scalar/JumpThreading.cpp
@@ -971,6 +971,9 @@ bool JumpThreadingPass::processBlock(BasicBlock *BB) {
   if (maybeMergeBasicBlockIntoOnlyPred(BB))
     return true;
 
+  if (tryToConvertSZExtToSelect(BB))
+    return true;
+
   if (tryToUnfoldSelectInCurrBB(BB))
     return true;
 
@@ -2750,7 +2753,7 @@ bool JumpThreadingPass::duplicateCondBranchOnPHIIntoPred(
 // Pred is a predecessor of BB with an unconditional branch to BB. SI is
 // a Select instruction in Pred. BB has other predecessors and SI is used in
 // a PHI node in BB. SI has no other use.
-// A new basic block, NewBB, is created and SI is converted to compare and 
+// A new basic block, NewBB, is created and SI is converted to compare and
 // conditional branch. SI is erased from parent.
 void JumpThreadingPass::unfoldSelectInstr(BasicBlock *Pred, BasicBlock *BB,
                                           SelectInst *SI, PHINode *SIUse,
@@ -2997,6 +3000,83 @@ bool JumpThreadingPass::tryToUnfoldSelectInCurrBB(BasicBlock *BB) {
   return false;
 }
 
+/// Try to convert "sext/zext i1" into "select i1" which could be further
+/// unfolded by tryToUnfoldSelect().
+///
+/// An extension in \p BB is converted only if its single use is a PHI in
+/// BB's unique successor and that successor ends in a conditional branch on
+/// "icmp PHI, <constant int>"; zext becomes "select %a, 1, 0" and sext
+/// becomes "select %a, -1, 0". Returns true if anything was converted.
+///
+/// For example,
+///
+/// ; before the transformation
+/// BB1:
+///   %a = icmp ...
+///   %b = zext i1 %a to i32
+///   br label %BB2
+/// BB2:
+///   %c = phi i32 [ %b, %BB1 ], ...
+///   %d = icmp eq i32 %c, 0
+///   br i1 %d, ...
+///
+/// ; after the transformation (BB2 is unchanged)
+/// BB1:
+///   %a = icmp ...
+///   %b = select i1 %a, i32 1, i32 0
+///   br label %BB2
+bool JumpThreadingPass::tryToConvertSZExtToSelect(BasicBlock *BB) {
+  using namespace PatternMatch;
+
+  // tryToUnfoldSelect requires that BB ends in an unconditional branch.
+  auto *Br = dyn_cast<BranchInst>(BB->getTerminator());
+  if (!Br || Br->isConditional())
+    return false;
+  BasicBlock *BBX = Br->getSuccessor(0);
+
+  // tryToUnfoldSelect requires that BBX branches on "icmp PHI, C" where PHI
+  // lives in BBX. This does not depend on the candidate, so check it once.
+  ICmpInst::Predicate Pred;
+  Value *CmpLHS;
+  if (!match(BBX->getTerminator(),
+             m_Br(m_ICmp(Pred, m_Value(CmpLHS), m_ConstantInt()),
+                  m_BasicBlock(), m_BasicBlock())))
+    return false;
+  auto *Phi = dyn_cast<PHINode>(CmpLHS);
+  if (!Phi || Phi->getParent() != BBX)
+    return false;
+
+  SmallVector<Instruction *> ToConvert;
+  for (Instruction &I : *BB) {
+    Value *V;
+    if (!match(&I, m_ZExtOrSExt(m_Value(V))) || !V->getType()->isIntegerTy(1))
+      continue;
+    // A select is only unfolded when the PHI is its only use, so require
+    // exactly one use (droppable uses count) and that it be that PHI.
+    if (!I.hasOneUse() || *I.user_begin() != Phi)
+      continue;
+    ToConvert.push_back(&I);
+  }
+  if (ToConvert.empty())
+    return false;
+
+  LLVM_DEBUG(dbgs() << "\nconvert-szext-to-select:\n" << *BB << "\n");
+  for (Instruction *I : ToConvert) {
+    Type *Ty = I->getType();
+    Value *TrueV = isa<SExtInst>(I) ? ConstantInt::getAllOnesValue(Ty)
+                                    : ConstantInt::get(Ty, 1);
+    auto *SI = SelectInst::Create(I->getOperand(0), TrueV,
+                                  ConstantInt::getNullValue(Ty), "", I);
+    // Keep the original name and source location on the replacement.
+    SI->takeName(I);
+    SI->setDebugLoc(I->getDebugLoc());
+    I->replaceAllUsesWith(SI);
+    I->eraseFromParent();
+  }
+  LLVM_DEBUG(dbgs() << *BB << "\n");
+  return true;
+}
+
 /// Try to propagate a guard from the current BB into one of its predecessors
 /// in case if another branch of execution implies that the condition of this
 /// guard is always true. Currently we only process the simplest case that
diff --git a/llvm/test/Transforms/JumpThreading/szext.ll b/llvm/test/Transforms/JumpThreading/szext.ll
new file mode 100644
index 00000000000000..290fe7ad0ca257
--- /dev/null
+++ b/llvm/test/Transforms/JumpThreading/szext.ll
@@ -0,0 +1,94 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
+; RUN: opt -S -passes=jump-threading < %s | FileCheck %s
+; @test1 (C source below): the "zext i1" in %cond.false6.i must get unfolded.
+; void fun(int);
+;
+; int compare1(int a, int b, int c, int d)
+; {
+;     return a < b ? -1 :
+;            a > b ?  1 :
+;            c < d ? -1 :
+;            c > d ?  1 : 0;
+; }
+;
+; void test1(int a, int b, int c, int d) {
+;   int x = compare1(a, b, c, d);
+;   if (x < 0)
+;     fun(10);
+;   else if (x > 0)
+;     fun(20);
+;   else
+;     fun(30);
+; }
+
+declare void @fun(i32 noundef)
+
+define void @test1(i32 noundef %a, i32 noundef %b, i32 noundef %c, i32 noundef %d) {
+; CHECK-LABEL: define void @test1(
+; CHECK-SAME: i32 noundef [[A:%.*]], i32 noundef [[B:%.*]], i32 noundef [[C:%.*]], i32 noundef [[D:%.*]]) {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[CMP_I:%.*]] = icmp slt i32 [[A]], [[B]]
+; CHECK-NEXT:    br i1 [[CMP_I]], label [[IF_THEN:%.*]], label [[COND_FALSE_I:%.*]]
+; CHECK:       cond.false.i:
+; CHECK-NEXT:    [[CMP1_I:%.*]] = icmp sgt i32 [[A]], [[B]]
+; CHECK-NEXT:    br i1 [[CMP1_I]], label [[IF_THEN2:%.*]], label [[COND_FALSE3_I:%.*]]
+; CHECK:       cond.false3.i:
+; CHECK-NEXT:    [[CMP4_I:%.*]] = icmp slt i32 [[C]], [[D]]
+; CHECK-NEXT:    br i1 [[CMP4_I]], label [[IF_THEN]], label [[COND_FALSE6_I:%.*]]
+; CHECK:       cond.false6.i:
+; CHECK-NEXT:    [[CMP7_I:%.*]] = icmp sgt i32 [[C]], [[D]]
+; CHECK-NEXT:    br i1 [[CMP7_I]], label [[IF_THEN2]], label [[IF_ELSE3:%.*]]
+; CHECK:       if.then:
+; CHECK-NEXT:    call void @fun(i32 noundef 10)
+; CHECK-NEXT:    br label [[IF_END4:%.*]]
+; CHECK:       if.then2:
+; CHECK-NEXT:    call void @fun(i32 noundef 20)
+; CHECK-NEXT:    br label [[IF_END4]]
+; CHECK:       if.else3:
+; CHECK-NEXT:    [[COND12_I:%.*]] = phi i32 [ 0, [[COND_FALSE6_I]] ]
+; CHECK-NEXT:    call void @fun(i32 noundef 30)
+; CHECK-NEXT:    br label [[IF_END4]]
+; CHECK:       if.end4:
+; CHECK-NEXT:    ret void
+;
+entry:
+  %cmp.i = icmp slt i32 %a, %b
+  br i1 %cmp.i, label %compare1.exit, label %cond.false.i
+
+cond.false.i:                                     ; preds = %entry
+  %cmp1.i = icmp sgt i32 %a, %b
+  br i1 %cmp1.i, label %compare1.exit, label %cond.false3.i
+
+cond.false3.i:                                    ; preds = %cond.false.i
+  %cmp4.i = icmp slt i32 %c, %d
+  br i1 %cmp4.i, label %compare1.exit, label %cond.false6.i
+
+cond.false6.i:                                    ; preds = %cond.false3.i
+  %cmp7.i = icmp sgt i32 %c, %d
+  %cond.i = zext i1 %cmp7.i to i32
+  br label %compare1.exit
+
+compare1.exit:                                    ; preds = %entry, %cond.false.i, %cond.false3.i, %cond.false6.i
+  %cond12.i = phi i32 [ -1, %entry ], [ 1, %cond.false.i ], [ %cond.i, %cond.false6.i ], [ -1, %cond.false3.i ]
+  %cmp = icmp slt i32 %cond12.i, 0
+  br i1 %cmp, label %if.then, label %if.else
+
+if.then:                                          ; preds = %compare1.exit
+  call void @fun(i32 noundef 10)
+  br label %if.end4
+
+if.else:                                          ; preds = %compare1.exit
+  %cmp1.not = icmp eq i32 %cond12.i, 0
+  br i1 %cmp1.not, label %if.else3, label %if.then2
+
+if.then2:                                         ; preds = %if.else
+  call void @fun(i32 noundef 20)
+  br label %if.end4
+
+if.else3:                                         ; preds = %if.else
+  call void @fun(i32 noundef 30)
+  br label %if.end4
+
+if.end4:                                          ; preds = %if.then2, %if.else3, %if.then
+  ret void
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/89345


More information about the llvm-commits mailing list