[llvm] [InstCombine] handle trunc to i1 in foldSelectICmpAndBinOp (PR #127390)

Andreas Jonson via llvm-commits llvm-commits at lists.llvm.org
Wed Feb 19 09:16:33 PST 2025


https://github.com/andjo403 updated https://github.com/llvm/llvm-project/pull/127390

From 8d61bc5669e094208a486a285a42daa5084681dd Mon Sep 17 00:00:00 2001
From: Andreas Jonson <andjo403 at hotmail.com>
Date: Sun, 16 Feb 2025 12:09:20 +0100
Subject: [PATCH] [InstCombine] handle trunc to i1 in foldSelectICmpAndBinOp

---
 .../InstCombine/InstCombineSelect.cpp         | 62 +++++++++++--------
 .../InstCombine/select-with-bitwise-ops.ll    | 29 +++++----
 2 files changed, 49 insertions(+), 42 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index cf38fc5f058f2..0dfdd9209b40e 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -742,39 +742,47 @@ static Value *foldSelectICmpLshrAshr(const ICmpInst *IC, Value *TrueVal,
 /// 1. The icmp predicate is inverted
 /// 2. The select operands are reversed
 /// 3. The magnitude of C2 and C1 are flipped
-static Value *foldSelectICmpAndBinOp(const ICmpInst *IC, Value *TrueVal,
-                                  Value *FalseVal,
-                                  InstCombiner::BuilderTy &Builder) {
+static Value *foldSelectICmpAndBinOp(Value *CondVal, Value *TrueVal,
+                                     Value *FalseVal,
+                                     InstCombiner::BuilderTy &Builder) {
   // Only handle integer compares. Also, if this is a vector select, we need a
   // vector compare.
   if (!TrueVal->getType()->isIntOrIntVectorTy() ||
-     TrueVal->getType()->isVectorTy() != IC->getType()->isVectorTy())
+      TrueVal->getType()->isVectorTy() != CondVal->getType()->isVectorTy())
     return nullptr;
 
-  Value *CmpLHS = IC->getOperand(0);
-  Value *CmpRHS = IC->getOperand(1);
-
   unsigned C1Log;
   bool NeedAnd = false;
-  CmpInst::Predicate Pred = IC->getPredicate();
-  if (IC->isEquality()) {
-    if (!match(CmpRHS, m_Zero()))
-      return nullptr;
+  CmpPredicate Pred;
+  Value *CmpLHS, *CmpRHS;
 
-    const APInt *C1;
-    if (!match(CmpLHS, m_And(m_Value(), m_Power2(C1))))
-      return nullptr;
+  if (match(CondVal, m_ICmp(Pred, m_Value(CmpLHS), m_Value(CmpRHS)))) {
+    if (ICmpInst::isEquality(Pred)) {
+      if (!match(CmpRHS, m_Zero()))
+        return nullptr;
 
-    C1Log = C1->logBase2();
-  } else {
-    auto Res = decomposeBitTestICmp(CmpLHS, CmpRHS, Pred);
-    if (!Res || !Res->Mask.isPowerOf2())
-      return nullptr;
+      const APInt *C1;
+      if (!match(CmpLHS, m_And(m_Value(), m_Power2(C1))))
+        return nullptr;
 
-    CmpLHS = Res->X;
-    Pred = Res->Pred;
-    C1Log = Res->Mask.logBase2();
-    NeedAnd = true;
+      C1Log = C1->logBase2();
+    } else {
+      auto Res = decomposeBitTestICmp(CmpLHS, CmpRHS, Pred);
+      if (!Res || !Res->Mask.isPowerOf2())
+        return nullptr;
+
+      CmpLHS = Res->X;
+      Pred = Res->Pred;
+      C1Log = Res->Mask.logBase2();
+      NeedAnd = true;
+    }
+  } else if (auto *Trunc = dyn_cast<TruncInst>(CondVal)) {
+    CmpLHS = Trunc->getOperand(0);
+    C1Log = 0;
+    Pred = ICmpInst::ICMP_NE;
+    NeedAnd = !Trunc->hasNoUnsignedWrap();
+  } else {
+    return nullptr;
   }
 
   Value *Y, *V = CmpLHS;
@@ -808,7 +816,7 @@ static Value *foldSelectICmpAndBinOp(const ICmpInst *IC, Value *TrueVal,
 
   // Make sure we don't create more instructions than we save.
   if ((NeedShift + NeedXor + NeedZExtTrunc + NeedAnd) >
-      (IC->hasOneUse() + BinOp->hasOneUse()))
+      (CondVal->hasOneUse() + BinOp->hasOneUse()))
     return nullptr;
 
   if (NeedAnd) {
@@ -1986,9 +1994,6 @@ Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI,
   if (Instruction *V = foldSelectZeroOrOnes(ICI, TrueVal, FalseVal, Builder))
     return V;
 
-  if (Value *V = foldSelectICmpAndBinOp(ICI, TrueVal, FalseVal, Builder))
-    return replaceInstUsesWith(SI, V);
-
   if (Value *V = foldSelectICmpLshrAshr(ICI, TrueVal, FalseVal, Builder))
     return replaceInstUsesWith(SI, V);
 
@@ -3946,6 +3951,9 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
     if (Instruction *Result = foldSelectInstWithICmp(SI, ICI))
       return Result;
 
+  if (Value *V = foldSelectICmpAndBinOp(CondVal, TrueVal, FalseVal, Builder))
+    return replaceInstUsesWith(SI, V);
+
   if (Instruction *Add = foldAddSubSelect(SI, Builder))
     return Add;
   if (Instruction *Add = foldOverflowingAddSubSelect(SI, Builder))
diff --git a/llvm/test/Transforms/InstCombine/select-with-bitwise-ops.ll b/llvm/test/Transforms/InstCombine/select-with-bitwise-ops.ll
index 67dec9178eeca..ca2e23c1d082e 100644
--- a/llvm/test/Transforms/InstCombine/select-with-bitwise-ops.ll
+++ b/llvm/test/Transforms/InstCombine/select-with-bitwise-ops.ll
@@ -1754,9 +1754,9 @@ define i8 @select_icmp_eq_and_1_0_lshr_tv(i8 %x, i8 %y) {
 
 define i8 @select_trunc_or_2(i8 %x, i8 %y) {
 ; CHECK-LABEL: @select_trunc_or_2(
-; CHECK-NEXT:    [[TRUNC:%.*]] = trunc i8 [[X:%.*]] to i1
-; CHECK-NEXT:    [[OR:%.*]] = or i8 [[Y:%.*]], 2
-; CHECK-NEXT:    [[SELECT:%.*]] = select i1 [[TRUNC]], i8 [[OR]], i8 [[Y]]
+; CHECK-NEXT:    [[TMP1:%.*]] = shl i8 [[X:%.*]], 1
+; CHECK-NEXT:    [[TMP2:%.*]] = and i8 [[TMP1]], 2
+; CHECK-NEXT:    [[SELECT:%.*]] = or i8 [[Y:%.*]], [[TMP2]]
 ; CHECK-NEXT:    ret i8 [[SELECT]]
 ;
   %trunc = trunc i8 %x to i1
@@ -1767,9 +1767,9 @@ define i8 @select_trunc_or_2(i8 %x, i8 %y) {
 
 define i8 @select_not_trunc_or_2(i8 %x, i8 %y) {
 ; CHECK-LABEL: @select_not_trunc_or_2(
-; CHECK-NEXT:    [[TRUNC:%.*]] = trunc i8 [[X:%.*]] to i1
-; CHECK-NEXT:    [[OR:%.*]] = or i8 [[Y:%.*]], 2
-; CHECK-NEXT:    [[SELECT:%.*]] = select i1 [[TRUNC]], i8 [[OR]], i8 [[Y]]
+; CHECK-NEXT:    [[TMP1:%.*]] = shl i8 [[X:%.*]], 1
+; CHECK-NEXT:    [[TMP2:%.*]] = and i8 [[TMP1]], 2
+; CHECK-NEXT:    [[SELECT:%.*]] = or i8 [[Y:%.*]], [[TMP2]]
 ; CHECK-NEXT:    ret i8 [[SELECT]]
 ;
   %trunc = trunc i8 %x to i1
@@ -1781,9 +1781,8 @@ define i8 @select_not_trunc_or_2(i8 %x, i8 %y) {
 
 define i8 @select_trunc_nuw_or_2(i8 %x, i8 %y) {
 ; CHECK-LABEL: @select_trunc_nuw_or_2(
-; CHECK-NEXT:    [[TRUNC:%.*]] = trunc nuw i8 [[X:%.*]] to i1
-; CHECK-NEXT:    [[OR:%.*]] = or i8 [[Y:%.*]], 2
-; CHECK-NEXT:    [[SELECT:%.*]] = select i1 [[TRUNC]], i8 [[OR]], i8 [[Y]]
+; CHECK-NEXT:    [[TMP1:%.*]] = shl i8 [[X:%.*]], 1
+; CHECK-NEXT:    [[SELECT:%.*]] = or i8 [[Y:%.*]], [[TMP1]]
 ; CHECK-NEXT:    ret i8 [[SELECT]]
 ;
   %trunc = trunc nuw i8 %x to i1
@@ -1794,9 +1793,9 @@ define i8 @select_trunc_nuw_or_2(i8 %x, i8 %y) {
 
 define i8 @select_trunc_nsw_or_2(i8 %x, i8 %y) {
 ; CHECK-LABEL: @select_trunc_nsw_or_2(
-; CHECK-NEXT:    [[TRUNC:%.*]] = trunc nsw i8 [[X:%.*]] to i1
-; CHECK-NEXT:    [[OR:%.*]] = or i8 [[Y:%.*]], 2
-; CHECK-NEXT:    [[SELECT:%.*]] = select i1 [[TRUNC]], i8 [[OR]], i8 [[Y]]
+; CHECK-NEXT:    [[TMP1:%.*]] = shl i8 [[X:%.*]], 1
+; CHECK-NEXT:    [[TMP2:%.*]] = and i8 [[TMP1]], 2
+; CHECK-NEXT:    [[SELECT:%.*]] = or i8 [[Y:%.*]], [[TMP2]]
 ; CHECK-NEXT:    ret i8 [[SELECT]]
 ;
   %trunc = trunc nsw i8 %x to i1
@@ -1807,9 +1806,9 @@ define i8 @select_trunc_nsw_or_2(i8 %x, i8 %y) {
 
 define <2 x i8> @select_trunc_or_2_vec(<2 x i8> %x, <2 x i8> %y) {
 ; CHECK-LABEL: @select_trunc_or_2_vec(
-; CHECK-NEXT:    [[TRUNC:%.*]] = trunc <2 x i8> [[X:%.*]] to <2 x i1>
-; CHECK-NEXT:    [[OR:%.*]] = or <2 x i8> [[Y:%.*]], splat (i8 2)
-; CHECK-NEXT:    [[SELECT:%.*]] = select <2 x i1> [[TRUNC]], <2 x i8> [[OR]], <2 x i8> [[Y]]
+; CHECK-NEXT:    [[TMP1:%.*]] = shl <2 x i8> [[X:%.*]], splat (i8 1)
+; CHECK-NEXT:    [[TMP2:%.*]] = and <2 x i8> [[TMP1]], splat (i8 2)
+; CHECK-NEXT:    [[SELECT:%.*]] = or <2 x i8> [[Y:%.*]], [[TMP2]]
 ; CHECK-NEXT:    ret <2 x i8> [[SELECT]]
 ;
   %trunc = trunc <2 x i8> %x to <2 x i1>



More information about the llvm-commits mailing list