[llvm] [InstCombine] handle trunc to i1 in foldSelectICmpAndBinOp (PR #127390)

via llvm-commits llvm-commits at lists.llvm.org
Sun Feb 16 03:50:03 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: Andreas Jonson (andjo403)

<details>
<summary>Changes</summary>

For `trunc nuw` this saves an instruction; otherwise the instruction count is unchanged but the select is replaced by other instructions, which is the same behavior as for the bit test before.

proof: https://alive2.llvm.org/ce/z/a6QmyV

---
Full diff: https://github.com/llvm/llvm-project/pull/127390.diff


2 Files Affected:

- (modified) llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp (+35-27) 
- (modified) llvm/test/Transforms/InstCombine/select-with-bitwise-ops.ll (+14-15) 


``````````diff
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 2e14145aef884..2a4bf82f1c070 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -742,39 +742,47 @@ static Value *foldSelectICmpLshrAshr(const ICmpInst *IC, Value *TrueVal,
 /// 1. The icmp predicate is inverted
 /// 2. The select operands are reversed
 /// 3. The magnitude of C2 and C1 are flipped
-static Value *foldSelectICmpAndBinOp(const ICmpInst *IC, Value *TrueVal,
-                                  Value *FalseVal,
-                                  InstCombiner::BuilderTy &Builder) {
+static Value *foldSelectICmpAndBinOp(Value *CondVal, Value *TrueVal,
+                                     Value *FalseVal,
+                                     InstCombiner::BuilderTy &Builder) {
   // Only handle integer compares. Also, if this is a vector select, we need a
   // vector compare.
   if (!TrueVal->getType()->isIntOrIntVectorTy() ||
-     TrueVal->getType()->isVectorTy() != IC->getType()->isVectorTy())
+      TrueVal->getType()->isVectorTy() != CondVal->getType()->isVectorTy())
     return nullptr;
 
-  Value *CmpLHS = IC->getOperand(0);
-  Value *CmpRHS = IC->getOperand(1);
-
   unsigned C1Log;
   bool NeedAnd = false;
-  CmpInst::Predicate Pred = IC->getPredicate();
-  if (IC->isEquality()) {
-    if (!match(CmpRHS, m_Zero()))
-      return nullptr;
+  CmpPredicate Pred;
+  Value *CmpLHS, *CmpRHS;
 
-    const APInt *C1;
-    if (!match(CmpLHS, m_And(m_Value(), m_Power2(C1))))
-      return nullptr;
+  if (match(CondVal, m_ICmp(Pred, m_Value(CmpLHS), m_Value(CmpRHS)))) {
+    if (ICmpInst::isEquality(Pred)) {
+      if (!match(CmpRHS, m_Zero()))
+        return nullptr;
 
-    C1Log = C1->logBase2();
-  } else {
-    auto Res = decomposeBitTestICmp(CmpLHS, CmpRHS, Pred);
-    if (!Res || !Res->Mask.isPowerOf2())
-      return nullptr;
+      const APInt *C1;
+      if (!match(CmpLHS, m_And(m_Value(), m_Power2(C1))))
+        return nullptr;
 
-    CmpLHS = Res->X;
-    Pred = Res->Pred;
-    C1Log = Res->Mask.logBase2();
-    NeedAnd = true;
+      C1Log = C1->logBase2();
+    } else {
+      auto Res = decomposeBitTestICmp(CmpLHS, CmpRHS, Pred);
+      if (!Res || !Res->Mask.isPowerOf2())
+        return nullptr;
+
+      CmpLHS = Res->X;
+      Pred = Res->Pred;
+      C1Log = Res->Mask.logBase2();
+      NeedAnd = true;
+    }
+  } else if (auto *Trunc = dyn_cast<TruncInst>(CondVal)) {
+    CmpLHS = Trunc->getOperand(0);
+    C1Log = 0;
+    Pred = ICmpInst::ICMP_NE;
+    NeedAnd = !Trunc->hasNoUnsignedWrap();
+  } else {
+    return nullptr;
   }
 
   Value *Y, *V = CmpLHS;
@@ -808,7 +816,7 @@ static Value *foldSelectICmpAndBinOp(const ICmpInst *IC, Value *TrueVal,
 
   // Make sure we don't create more instructions than we save.
   if ((NeedShift + NeedXor + NeedZExtTrunc + NeedAnd) >
-      (IC->hasOneUse() + BinOp->hasOneUse()))
+      (CondVal->hasOneUse() + BinOp->hasOneUse()))
     return nullptr;
 
   if (NeedAnd) {
@@ -1983,9 +1991,6 @@ Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI,
   if (Instruction *V = foldSelectZeroOrOnes(ICI, TrueVal, FalseVal, Builder))
     return V;
 
-  if (Value *V = foldSelectICmpAndBinOp(ICI, TrueVal, FalseVal, Builder))
-    return replaceInstUsesWith(SI, V);
-
   if (Value *V = foldSelectICmpLshrAshr(ICI, TrueVal, FalseVal, Builder))
     return replaceInstUsesWith(SI, V);
 
@@ -3943,6 +3948,9 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
     if (Instruction *Result = foldSelectInstWithICmp(SI, ICI))
       return Result;
 
+  if (Value *V = foldSelectICmpAndBinOp(CondVal, TrueVal, FalseVal, Builder))
+    return replaceInstUsesWith(SI, V);
+
   if (Instruction *Add = foldAddSubSelect(SI, Builder))
     return Add;
   if (Instruction *Add = foldOverflowingAddSubSelect(SI, Builder))
diff --git a/llvm/test/Transforms/InstCombine/select-with-bitwise-ops.ll b/llvm/test/Transforms/InstCombine/select-with-bitwise-ops.ll
index 7c100f579399d..f46c506683c45 100644
--- a/llvm/test/Transforms/InstCombine/select-with-bitwise-ops.ll
+++ b/llvm/test/Transforms/InstCombine/select-with-bitwise-ops.ll
@@ -1712,9 +1712,9 @@ define i8 @select_icmp_eq_and_1_0_lshr_tv(i8 %x, i8 %y) {
 
 define i8 @select_trunc_or_2(i8 %x, i8 %y) {
 ; CHECK-LABEL: @select_trunc_or_2(
-; CHECK-NEXT:    [[TRUNC:%.*]] = trunc i8 [[X:%.*]] to i1
-; CHECK-NEXT:    [[OR:%.*]] = or i8 [[Y:%.*]], 2
-; CHECK-NEXT:    [[SELECT:%.*]] = select i1 [[TRUNC]], i8 [[OR]], i8 [[Y]]
+; CHECK-NEXT:    [[TMP1:%.*]] = shl i8 [[X:%.*]], 1
+; CHECK-NEXT:    [[TMP2:%.*]] = and i8 [[TMP1]], 2
+; CHECK-NEXT:    [[SELECT:%.*]] = or i8 [[Y:%.*]], [[TMP2]]
 ; CHECK-NEXT:    ret i8 [[SELECT]]
 ;
   %trunc = trunc i8 %x to i1
@@ -1725,9 +1725,9 @@ define i8 @select_trunc_or_2(i8 %x, i8 %y) {
 
 define i8 @select_not_trunc_or_2(i8 %x, i8 %y) {
 ; CHECK-LABEL: @select_not_trunc_or_2(
-; CHECK-NEXT:    [[TRUNC:%.*]] = trunc i8 [[X:%.*]] to i1
-; CHECK-NEXT:    [[OR:%.*]] = or i8 [[Y:%.*]], 2
-; CHECK-NEXT:    [[SELECT:%.*]] = select i1 [[TRUNC]], i8 [[OR]], i8 [[Y]]
+; CHECK-NEXT:    [[TMP1:%.*]] = shl i8 [[X:%.*]], 1
+; CHECK-NEXT:    [[TMP2:%.*]] = and i8 [[TMP1]], 2
+; CHECK-NEXT:    [[SELECT:%.*]] = or i8 [[Y:%.*]], [[TMP2]]
 ; CHECK-NEXT:    ret i8 [[SELECT]]
 ;
   %trunc = trunc i8 %x to i1
@@ -1739,9 +1739,8 @@ define i8 @select_not_trunc_or_2(i8 %x, i8 %y) {
 
 define i8 @select_trunc_nuw_or_2(i8 %x, i8 %y) {
 ; CHECK-LABEL: @select_trunc_nuw_or_2(
-; CHECK-NEXT:    [[TRUNC:%.*]] = trunc nuw i8 [[X:%.*]] to i1
-; CHECK-NEXT:    [[OR:%.*]] = or i8 [[Y:%.*]], 2
-; CHECK-NEXT:    [[SELECT:%.*]] = select i1 [[TRUNC]], i8 [[OR]], i8 [[Y]]
+; CHECK-NEXT:    [[TMP1:%.*]] = shl i8 [[X:%.*]], 1
+; CHECK-NEXT:    [[SELECT:%.*]] = or i8 [[Y:%.*]], [[TMP1]]
 ; CHECK-NEXT:    ret i8 [[SELECT]]
 ;
   %trunc = trunc nuw i8 %x to i1
@@ -1752,9 +1751,9 @@ define i8 @select_trunc_nuw_or_2(i8 %x, i8 %y) {
 
 define i8 @select_trunc_nsw_or_2(i8 %x, i8 %y) {
 ; CHECK-LABEL: @select_trunc_nsw_or_2(
-; CHECK-NEXT:    [[TRUNC:%.*]] = trunc nsw i8 [[X:%.*]] to i1
-; CHECK-NEXT:    [[OR:%.*]] = or i8 [[Y:%.*]], 2
-; CHECK-NEXT:    [[SELECT:%.*]] = select i1 [[TRUNC]], i8 [[OR]], i8 [[Y]]
+; CHECK-NEXT:    [[TMP1:%.*]] = shl i8 [[X:%.*]], 1
+; CHECK-NEXT:    [[TMP2:%.*]] = and i8 [[TMP1]], 2
+; CHECK-NEXT:    [[SELECT:%.*]] = or i8 [[Y:%.*]], [[TMP2]]
 ; CHECK-NEXT:    ret i8 [[SELECT]]
 ;
   %trunc = trunc nsw i8 %x to i1
@@ -1765,9 +1764,9 @@ define i8 @select_trunc_nsw_or_2(i8 %x, i8 %y) {
 
 define <2 x i8> @select_trunc_or_2_vec(<2 x i8> %x, <2 x i8> %y) {
 ; CHECK-LABEL: @select_trunc_or_2_vec(
-; CHECK-NEXT:    [[TRUNC:%.*]] = trunc <2 x i8> [[X:%.*]] to <2 x i1>
-; CHECK-NEXT:    [[OR:%.*]] = or <2 x i8> [[Y:%.*]], splat (i8 2)
-; CHECK-NEXT:    [[SELECT:%.*]] = select <2 x i1> [[TRUNC]], <2 x i8> [[OR]], <2 x i8> [[Y]]
+; CHECK-NEXT:    [[TMP1:%.*]] = shl <2 x i8> [[X:%.*]], splat (i8 1)
+; CHECK-NEXT:    [[TMP2:%.*]] = and <2 x i8> [[TMP1]], splat (i8 2)
+; CHECK-NEXT:    [[SELECT:%.*]] = or <2 x i8> [[Y:%.*]], [[TMP2]]
 ; CHECK-NEXT:    ret <2 x i8> [[SELECT]]
 ;
   %trunc = trunc <2 x i8> %x to <2 x i1>

``````````

</details>


https://github.com/llvm/llvm-project/pull/127390


More information about the llvm-commits mailing list