[llvm] [JumpThreading] Convert `s/zext i1` to `select i1` for further unfolding (PR #89345)
via llvm-commits
llvm-commits at lists.llvm.org
Thu Apr 18 22:13:52 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-transforms
Author: Franklin Zhang (FLZ101)
<details>
<summary>Changes</summary>
---
Full diff: https://github.com/llvm/llvm-project/pull/89345.diff
3 Files Affected:
- (modified) llvm/include/llvm/Transforms/Scalar/JumpThreading.h (+2)
- (modified) llvm/lib/Transforms/Scalar/JumpThreading.cpp (+81-1)
- (added) llvm/test/Transforms/JumpThreading/szext.ll (+94)
``````````diff
diff --git a/llvm/include/llvm/Transforms/Scalar/JumpThreading.h b/llvm/include/llvm/Transforms/Scalar/JumpThreading.h
index 3364d7eaee4247..c7da9053b7abeb 100644
--- a/llvm/include/llvm/Transforms/Scalar/JumpThreading.h
+++ b/llvm/include/llvm/Transforms/Scalar/JumpThreading.h
@@ -162,6 +162,8 @@ class JumpThreadingPass : public PassInfoMixin<JumpThreadingPass> {
bool tryToUnfoldSelect(SwitchInst *SI, BasicBlock *BB);
bool tryToUnfoldSelectInCurrBB(BasicBlock *BB);
+ bool tryToConvertSZExtToSelect(BasicBlock *BB);
+
bool processGuards(BasicBlock *BB);
bool threadGuard(BasicBlock *BB, IntrinsicInst *Guard, BranchInst *BI);
diff --git a/llvm/lib/Transforms/Scalar/JumpThreading.cpp b/llvm/lib/Transforms/Scalar/JumpThreading.cpp
index ffcb511e6a8312..c92519610aec18 100644
--- a/llvm/lib/Transforms/Scalar/JumpThreading.cpp
+++ b/llvm/lib/Transforms/Scalar/JumpThreading.cpp
@@ -971,6 +971,9 @@ bool JumpThreadingPass::processBlock(BasicBlock *BB) {
if (maybeMergeBasicBlockIntoOnlyPred(BB))
return true;
+ if (tryToConvertSZExtToSelect(BB))
+ return true;
+
if (tryToUnfoldSelectInCurrBB(BB))
return true;
@@ -2750,7 +2753,7 @@ bool JumpThreadingPass::duplicateCondBranchOnPHIIntoPred(
// Pred is a predecessor of BB with an unconditional branch to BB. SI is
// a Select instruction in Pred. BB has other predecessors and SI is used in
// a PHI node in BB. SI has no other use.
-// A new basic block, NewBB, is created and SI is converted to compare and
+// A new basic block, NewBB, is created and SI is converted to compare and
// conditional branch. SI is erased from parent.
void JumpThreadingPass::unfoldSelectInstr(BasicBlock *Pred, BasicBlock *BB,
SelectInst *SI, PHINode *SIUse,
@@ -2997,6 +3000,83 @@ bool JumpThreadingPass::tryToUnfoldSelectInCurrBB(BasicBlock *BB) {
return false;
}
+/// Try to convert "sext/zext i1" into "select i1" so that it can be further
+/// unfolded by tryToUnfoldSelect().
+///
+/// The conversion is only performed when the resulting select satisfies the
+/// preconditions of tryToUnfoldSelect():
+///   * BB ends in an unconditional branch to BBX;
+///   * BBX ends in a conditional branch on "icmp Phi, C", where Phi is a PHI
+///     node in BBX and C is a ConstantInt;
+///   * Phi's incoming value from BB is a sext/zext of an i1 that is defined in
+///     BB and whose only use is Phi.
+/// BBX branches on a single condition and BB has a single edge into BBX, so
+/// at most one instruction can qualify.
+///
+/// For example,
+///
+///   ; before the transformation
+///   BB1:
+///     %a = icmp ...
+///     %b = zext i1 %a to i32
+///     br label %BB2
+///   BB2:
+///     %c = phi i32 [ %b, %BB1 ], ...
+///     %d = icmp eq i32 %c, 0
+///     br i1 %d, ...
+///
+///   ------
+///
+///   ; after the transformation
+///   BB1:
+///     %a = icmp ...
+///     %b = select i1 %a, i32 1, i32 0
+///     br label %BB2
+///   BB2:
+///     %c = phi i32 [ %b, %BB1 ], ...
+///     %d = icmp eq i32 %c, 0
+///     br i1 %d, ...
+///
+/// \returns true if an extension was replaced by a select.
+bool JumpThreadingPass::tryToConvertSZExtToSelect(BasicBlock *BB) {
+  using namespace PatternMatch;
+
+  // tryToUnfoldSelect() only unfolds a select whose block ends in an
+  // unconditional branch.
+  auto *Br = dyn_cast<BranchInst>(BB->getTerminator());
+  if (!Br || Br->isConditional())
+    return false;
+  BasicBlock *BBX = Br->getSuccessor(0);
+
+  // BBX must branch on "icmp Phi, C" with Phi defined in BBX. This is checked
+  // once up front instead of once per instruction of BB.
+  ICmpInst::Predicate Pred;
+  Value *PhiV;
+  if (!match(BBX->getTerminator(),
+             m_Br(m_ICmp(Pred, m_Value(PhiV), m_ConstantInt()),
+                  m_BasicBlock(), m_BasicBlock())))
+    return false;
+  auto *Phi = dyn_cast<PHINode>(PhiV);
+  if (!Phi || Phi->getParent() != BBX)
+    return false;
+
+  // BB has exactly one edge into BBX, so Phi has exactly one entry for BB.
+  auto *Ext = dyn_cast<Instruction>(Phi->getIncomingValueForBlock(BB));
+  Value *Cond;
+  if (!Ext || Ext->getParent() != BB ||
+      !match(Ext, m_ZExtOrSExt(m_Value(Cond))) ||
+      !Cond->getType()->isIntegerTy(1))
+    return false;
+
+  // tryToUnfoldSelect() only considers a select with exactly one use. Count
+  // droppable uses as well: RAUW would move them onto the select and block
+  // the unfolding, leaving a select where a cheaper extension used to be.
+  if (!Ext->hasOneUse())
+    return false;
+
+  LLVM_DEBUG(dbgs() << "\nconvert-szext-to-select:\n" << *BB << "\n");
+  Type *Ty = Ext->getType();
+  Value *TrueV = isa<SExtInst>(Ext) ? ConstantInt::getAllOnesValue(Ty)
+                                    : ConstantInt::get(Ty, 1);
+  Value *FalseV = ConstantInt::getNullValue(Ty);
+  SelectInst *SI = SelectInst::Create(Cond, TrueV, FalseV, "", Ext);
+  // Keep the original name (instead of a uniquified copy) and debug location.
+  SI->takeName(Ext);
+  SI->setDebugLoc(Ext->getDebugLoc());
+  Ext->replaceAllUsesWith(SI);
+  Ext->eraseFromParent();
+  LLVM_DEBUG(dbgs() << *BB << "\n");
+  return true;
+}
+
/// Try to propagate a guard from the current BB into one of its predecessors
/// in case if another branch of execution implies that the condition of this
/// guard is always true. Currently we only process the simplest case that
diff --git a/llvm/test/Transforms/JumpThreading/szext.ll b/llvm/test/Transforms/JumpThreading/szext.ll
new file mode 100644
index 00000000000000..290fe7ad0ca257
--- /dev/null
+++ b/llvm/test/Transforms/JumpThreading/szext.ll
@@ -0,0 +1,94 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
+; RUN: opt -S -passes=jump-threading < %s | FileCheck %s
+
+; void fun(int);
+;
+; int compare1(int a, int b, int c, int d)
+; {
+; return a < b ? -1 :
+; a > b ? 1 :
+; c < d ? -1 :
+; c > d ? 1 : 0;
+; }
+;
+; void test1(int a, int b, int c, int d) {
+; int x = compare1(a, b, c, d);
+; if (x < 0)
+; fun(10);
+; else if (x > 0)
+; fun(20);
+; else
+; fun(30);
+; }
+
+declare void @fun(i32 noundef)
+
+; The "zext i1" in %cond.false6.i only feeds the phi %cond12.i, which
+; %compare1.exit compares against a constant and branches on. The CHECK lines
+; expect jump-threading to remove the zext and let %cond.false6.i branch on
+; %cmp7.i directly to %if.then2 / %if.else3.
+; TODO(review): also cover "sext i1", and negative cases such as an extension
+; with more than one use or a phi that is not the compared value.
+define void @test1(i32 noundef %a, i32 noundef %b, i32 noundef %c, i32 noundef %d) {
+; CHECK-LABEL: define void @test1(
+; CHECK-SAME: i32 noundef [[A:%.*]], i32 noundef [[B:%.*]], i32 noundef [[C:%.*]], i32 noundef [[D:%.*]]) {
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[CMP_I:%.*]] = icmp slt i32 [[A]], [[B]]
+; CHECK-NEXT: br i1 [[CMP_I]], label [[IF_THEN:%.*]], label [[COND_FALSE_I:%.*]]
+; CHECK: cond.false.i:
+; CHECK-NEXT: [[CMP1_I:%.*]] = icmp sgt i32 [[A]], [[B]]
+; CHECK-NEXT: br i1 [[CMP1_I]], label [[IF_THEN2:%.*]], label [[COND_FALSE3_I:%.*]]
+; CHECK: cond.false3.i:
+; CHECK-NEXT: [[CMP4_I:%.*]] = icmp slt i32 [[C]], [[D]]
+; CHECK-NEXT: br i1 [[CMP4_I]], label [[IF_THEN]], label [[COND_FALSE6_I:%.*]]
+; CHECK: cond.false6.i:
+; CHECK-NEXT: [[CMP7_I:%.*]] = icmp sgt i32 [[C]], [[D]]
+; CHECK-NEXT: br i1 [[CMP7_I]], label [[IF_THEN2]], label [[IF_ELSE3:%.*]]
+; CHECK: if.then:
+; CHECK-NEXT: call void @fun(i32 noundef 10)
+; CHECK-NEXT: br label [[IF_END4:%.*]]
+; CHECK: if.then2:
+; CHECK-NEXT: call void @fun(i32 noundef 20)
+; CHECK-NEXT: br label [[IF_END4]]
+; CHECK: if.else3:
+; CHECK-NEXT: [[COND12_I:%.*]] = phi i32 [ 0, [[COND_FALSE6_I]] ]
+; CHECK-NEXT: call void @fun(i32 noundef 30)
+; CHECK-NEXT: br label [[IF_END4]]
+; CHECK: if.end4:
+; CHECK-NEXT: ret void
+;
+entry:
+ %cmp.i = icmp slt i32 %a, %b
+ br i1 %cmp.i, label %compare1.exit, label %cond.false.i
+
+cond.false.i: ; preds = %entry
+ %cmp1.i = icmp sgt i32 %a, %b
+ br i1 %cmp1.i, label %compare1.exit, label %cond.false3.i
+
+cond.false3.i: ; preds = %cond.false.i
+ %cmp4.i = icmp slt i32 %c, %d
+ br i1 %cmp4.i, label %compare1.exit, label %cond.false6.i
+
+cond.false6.i: ; preds = %cond.false3.i
+ %cmp7.i = icmp sgt i32 %c, %d
+ %cond.i = zext i1 %cmp7.i to i32
+ br label %compare1.exit
+
+compare1.exit: ; preds = %entry, %cond.false.i, %cond.false3.i, %cond.false6.i
+ %cond12.i = phi i32 [ -1, %entry ], [ 1, %cond.false.i ], [ %cond.i, %cond.false6.i ], [ -1, %cond.false3.i ]
+ %cmp = icmp slt i32 %cond12.i, 0
+ br i1 %cmp, label %if.then, label %if.else
+
+if.then: ; preds = %compare1.exit
+ call void @fun(i32 noundef 10)
+ br label %if.end4
+
+if.else: ; preds = %compare1.exit
+ %cmp1.not = icmp eq i32 %cond12.i, 0
+ br i1 %cmp1.not, label %if.else3, label %if.then2
+
+if.then2: ; preds = %if.else
+ call void @fun(i32 noundef 20)
+ br label %if.end4
+
+if.else3: ; preds = %if.else
+ call void @fun(i32 noundef 30)
+ br label %if.end4
+
+if.end4: ; preds = %if.then2, %if.else3, %if.then
+ ret void
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/89345
More information about the llvm-commits mailing list