[llvm] [InstCombine] Support trunc to i1 in foldSelectICmpAnd (PR #127905)

Andreas Jonson via llvm-commits llvm-commits at lists.llvm.org
Thu Feb 20 12:14:19 PST 2025


https://github.com/andjo403 updated https://github.com/llvm/llvm-project/pull/127905

>From 128c0dae0b3572484b491faded7cded90baa15ef Mon Sep 17 00:00:00 2001
From: Andreas Jonson <andjo403 at hotmail.com>
Date: Sun, 16 Feb 2025 23:25:30 +0100
Subject: [PATCH 1/2] [InstCombine] Support trunc to i1 in foldSelectICmpAnd

---
 .../InstCombine/InstCombineSelect.cpp         | 59 +++++++++++--------
 .../Transforms/InstCombine/select-icmp-and.ll | 26 ++++----
 2 files changed, 47 insertions(+), 38 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index e621a0b7fe596..91cf8c266ce47 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -119,7 +119,7 @@ static Instruction *foldSelectBinOpIdentity(SelectInst &Sel,
 ///  (shl (and (X, C1)), (log2(TC-FC) - log2(C1))) + FC
 /// With some variations depending if FC is larger than TC, or the shift
 /// isn't needed, or the bit widths don't match.
-static Value *foldSelectICmpAnd(SelectInst &Sel, ICmpInst *Cmp,
+static Value *foldSelectICmpAnd(SelectInst &Sel, Value *CondVal,
                                 InstCombiner::BuilderTy &Builder) {
   const APInt *SelTC, *SelFC;
   if (!match(Sel.getTrueValue(), m_APInt(SelTC)) ||
@@ -128,33 +128,42 @@ static Value *foldSelectICmpAnd(SelectInst &Sel, ICmpInst *Cmp,
 
   // If this is a vector select, we need a vector compare.
   Type *SelType = Sel.getType();
-  if (SelType->isVectorTy() != Cmp->getType()->isVectorTy())
+  if (SelType->isVectorTy() != CondVal->getType()->isVectorTy())
     return nullptr;
 
   Value *V;
   APInt AndMask;
   bool CreateAnd = false;
-  ICmpInst::Predicate Pred = Cmp->getPredicate();
-  if (ICmpInst::isEquality(Pred)) {
-    if (!match(Cmp->getOperand(1), m_Zero()))
-      return nullptr;
+  CmpPredicate Pred;
+  Value *CmpLHS, *CmpRHS;
 
-    V = Cmp->getOperand(0);
-    const APInt *AndRHS;
-    if (!match(V, m_And(m_Value(), m_Power2(AndRHS))))
-      return nullptr;
+  if (match(CondVal, m_ICmp(Pred, m_Value(CmpLHS), m_Value(CmpRHS)))) {
+    if (ICmpInst::isEquality(Pred)) {
+      if (!match(CmpRHS, m_Zero()))
+        return nullptr;
 
-    AndMask = *AndRHS;
-  } else if (auto Res = decomposeBitTestICmp(Cmp->getOperand(0),
-                                             Cmp->getOperand(1), Pred)) {
-    assert(ICmpInst::isEquality(Res->Pred) && "Not equality test?");
-    if (!Res->Mask.isPowerOf2())
-      return nullptr;
+      V = CmpLHS;
+      const APInt *AndRHS;
+      if (!match(V, m_And(m_Value(), m_Power2(AndRHS))))
+        return nullptr;
+
+      AndMask = *AndRHS;
+    } else {
+      auto Res = decomposeBitTestICmp(CmpLHS, CmpRHS, Pred);
+      if (!Res || !Res->Mask.isPowerOf2())
+        return nullptr;
+      assert(ICmpInst::isEquality(Res->Pred) && "Not equality test?");
 
-    V = Res->X;
-    AndMask = Res->Mask;
-    Pred = Res->Pred;
-    CreateAnd = true;
+      V = Res->X;
+      AndMask = Res->Mask;
+      Pred = Res->Pred;
+      CreateAnd = true;
+    }
+  } else if (auto *Trunc = dyn_cast<TruncInst>(CondVal)) {
+    V = Trunc->getOperand(0);
+    AndMask = APInt(V->getType()->getScalarSizeInBits(), 1);
+    Pred = ICmpInst::ICMP_NE;
+    CreateAnd = !Trunc->hasNoUnsignedWrap();
   } else {
     return nullptr;
   }
@@ -172,7 +181,7 @@ static Value *foldSelectICmpAnd(SelectInst &Sel, ICmpInst *Cmp,
       return nullptr;
     // If we have to create an 'and', then we must kill the cmp to not
     // increase the instruction count.
-    if (CreateAnd && !Cmp->hasOneUse())
+    if (CreateAnd && !CondVal->hasOneUse())
       return nullptr;
 
     // (V & AndMaskC) == 0 ? TC : FC --> TC | (V & AndMaskC)
@@ -213,7 +222,7 @@ static Value *foldSelectICmpAnd(SelectInst &Sel, ICmpInst *Cmp,
   // a 'select' + 'icmp', then this transformation would result in more
   // instructions and potentially interfere with other folding.
   if (CreateAnd + ShouldNotVal + NeedShift + NeedZExtTrunc >
-      1 + Cmp->hasOneUse())
+      1 + CondVal->hasOneUse())
     return nullptr;
 
   // Insert the 'and' instruction on the input to the truncate.
@@ -1955,9 +1964,6 @@ Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI,
           tryToReuseConstantFromSelectInComparison(SI, *ICI, *this))
     return NewSel;
 
-  if (Value *V = foldSelectICmpAnd(SI, ICI, Builder))
-    return replaceInstUsesWith(SI, V);
-
   // NOTE: if we wanted to, this is where to detect integer MIN/MAX
   bool Changed = false;
   Value *TrueVal = SI.getTrueValue();
@@ -3955,6 +3961,9 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
     if (Instruction *Result = foldSelectInstWithICmp(SI, ICI))
       return Result;
 
+  if (Value *V = foldSelectICmpAnd(SI, CondVal, Builder))
+    return replaceInstUsesWith(SI, V);
+
   if (Value *V = foldSelectICmpAndBinOp(CondVal, TrueVal, FalseVal, Builder))
     return replaceInstUsesWith(SI, V);
 
diff --git a/llvm/test/Transforms/InstCombine/select-icmp-and.ll b/llvm/test/Transforms/InstCombine/select-icmp-and.ll
index 16fb3f34047ee..f9f87e8030512 100644
--- a/llvm/test/Transforms/InstCombine/select-icmp-and.ll
+++ b/llvm/test/Transforms/InstCombine/select-icmp-and.ll
@@ -809,8 +809,8 @@ define i8 @select_bittest_to_xor(i8 %x) {
 
 define i8 @select_trunc_bittest_to_sub(i8 %x) {
 ; CHECK-LABEL: @select_trunc_bittest_to_sub(
-; CHECK-NEXT:    [[TRUNC:%.*]] = trunc i8 [[X:%.*]] to i1
-; CHECK-NEXT:    [[RET:%.*]] = select i1 [[TRUNC]], i8 3, i8 4
+; CHECK-NEXT:    [[TMP1:%.*]] = and i8 [[X:%.*]], 1
+; CHECK-NEXT:    [[RET:%.*]] = sub nuw nsw i8 4, [[TMP1]]
 ; CHECK-NEXT:    ret i8 [[RET]]
 ;
   %trunc = trunc i8 %x to i1
@@ -820,8 +820,7 @@ define i8 @select_trunc_bittest_to_sub(i8 %x) {
 
 define i8 @select_trunc_nuw_bittest_to_sub(i8 %x) {
 ; CHECK-LABEL: @select_trunc_nuw_bittest_to_sub(
-; CHECK-NEXT:    [[TRUNC:%.*]] = trunc nuw i8 [[X:%.*]] to i1
-; CHECK-NEXT:    [[RET:%.*]] = select i1 [[TRUNC]], i8 3, i8 4
+; CHECK-NEXT:    [[RET:%.*]] = sub i8 4, [[X:%.*]]
 ; CHECK-NEXT:    ret i8 [[RET]]
 ;
   %trunc = trunc nuw i8 %x to i1
@@ -831,8 +830,8 @@ define i8 @select_trunc_nuw_bittest_to_sub(i8 %x) {
 
 define i8 @select_trunc_nsw_bittest_to_sub(i8 %x) {
 ; CHECK-LABEL: @select_trunc_nsw_bittest_to_sub(
-; CHECK-NEXT:    [[TRUNC:%.*]] = trunc nsw i8 [[X:%.*]] to i1
-; CHECK-NEXT:    [[RET:%.*]] = select i1 [[TRUNC]], i8 3, i8 4
+; CHECK-NEXT:    [[TMP1:%.*]] = and i8 [[X:%.*]], 1
+; CHECK-NEXT:    [[RET:%.*]] = sub nuw nsw i8 4, [[TMP1]]
 ; CHECK-NEXT:    ret i8 [[RET]]
 ;
   %trunc = trunc nsw i8 %x to i1
@@ -844,7 +843,7 @@ define i8 @select_trunc_nuw_bittest_to_sub_extra_use(i8 %x) {
 ; CHECK-LABEL: @select_trunc_nuw_bittest_to_sub_extra_use(
 ; CHECK-NEXT:    [[TRUNC:%.*]] = trunc nuw i8 [[X:%.*]] to i1
 ; CHECK-NEXT:    call void @use1(i1 [[TRUNC]])
-; CHECK-NEXT:    [[RET:%.*]] = select i1 [[TRUNC]], i8 3, i8 4
+; CHECK-NEXT:    [[RET:%.*]] = sub i8 4, [[X]]
 ; CHECK-NEXT:    ret i8 [[RET]]
 ;
   %trunc = trunc nuw i8 %x to i1
@@ -868,8 +867,8 @@ define i8 @neg_select_trunc_bittest_to_sub_extra_use(i8 %x) {
 
 define i8 @select_trunc_nuw_bittest_to_shl_not(i8 %x) {
 ; CHECK-LABEL: @select_trunc_nuw_bittest_to_shl_not(
-; CHECK-NEXT:    [[TRUNC:%.*]] = trunc nuw i8 [[X:%.*]] to i1
-; CHECK-NEXT:    [[RET:%.*]] = select i1 [[TRUNC]], i8 0, i8 4
+; CHECK-NEXT:    [[TMP1:%.*]] = shl i8 [[X:%.*]], 2
+; CHECK-NEXT:    [[RET:%.*]] = xor i8 [[TMP1]], 4
 ; CHECK-NEXT:    ret i8 [[RET]]
 ;
   %trunc = trunc nuw i8 %x to i1
@@ -879,8 +878,8 @@ define i8 @select_trunc_nuw_bittest_to_shl_not(i8 %x) {
 
 define i8 @select_trunc_bittest_to_shl(i8 %x) {
 ; CHECK-LABEL: @select_trunc_bittest_to_shl(
-; CHECK-NEXT:    [[TRUNC:%.*]] = trunc i8 [[X:%.*]] to i1
-; CHECK-NEXT:    [[RET:%.*]] = select i1 [[TRUNC]], i8 4, i8 0
+; CHECK-NEXT:    [[TMP1:%.*]] = shl i8 [[X:%.*]], 2
+; CHECK-NEXT:    [[RET:%.*]] = and i8 [[TMP1]], 4
 ; CHECK-NEXT:    ret i8 [[RET]]
 ;
   %trunc = trunc i8 %x to i1
@@ -903,8 +902,9 @@ define i8 @neg_select_trunc_bittest_to_shl_extra_use(i8 %x) {
 
 define i16 @select_trunc_nuw_bittest_or(i8 %x) {
 ; CHECK-LABEL: @select_trunc_nuw_bittest_or(
-; CHECK-NEXT:    [[TMP1:%.*]] = trunc nuw i8 [[X:%.*]] to i1
-; CHECK-NEXT:    [[RES:%.*]] = select i1 [[TMP1]], i16 20, i16 4
+; CHECK-NEXT:    [[TMP1:%.*]] = zext i8 [[X:%.*]] to i16
+; CHECK-NEXT:    [[SELECT:%.*]] = shl nuw nsw i16 [[TMP1]], 4
+; CHECK-NEXT:    [[RES:%.*]] = or disjoint i16 [[SELECT]], 4
 ; CHECK-NEXT:    ret i16 [[RES]]
 ;
   %trunc = trunc nuw i8 %x to i1

>From c88778acace3fceb840cdb7c32d005adaef1f70a Mon Sep 17 00:00:00 2001
From: Andreas Jonson <andjo403 at hotmail.com>
Date: Thu, 20 Feb 2025 21:14:07 +0100
Subject: [PATCH 2/2] [InstCombine] Reuse common matches between
 foldSelectICmpAndBinOp and foldSelectICmpAnd. (NFC)

---
 .../InstCombine/InstCombineSelect.cpp         | 171 ++++++++----------
 1 file changed, 75 insertions(+), 96 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 91cf8c266ce47..d0994af105c8e 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -119,57 +119,15 @@ static Instruction *foldSelectBinOpIdentity(SelectInst &Sel,
 ///  (shl (and (X, C1)), (log2(TC-FC) - log2(C1))) + FC
 /// With some variations depending if FC is larger than TC, or the shift
 /// isn't needed, or the bit widths don't match.
-static Value *foldSelectICmpAnd(SelectInst &Sel, Value *CondVal,
+static Value *foldSelectICmpAnd(SelectInst &Sel, Value *CondVal, Value *TrueVal,
+                                Value *FalseVal, Value *V, APInt AndMask,
+                                bool CreateAnd,
                                 InstCombiner::BuilderTy &Builder) {
   const APInt *SelTC, *SelFC;
-  if (!match(Sel.getTrueValue(), m_APInt(SelTC)) ||
-      !match(Sel.getFalseValue(), m_APInt(SelFC)))
+  if (!match(TrueVal, m_APInt(SelTC)) || !match(FalseVal, m_APInt(SelFC)))
     return nullptr;
 
-  // If this is a vector select, we need a vector compare.
   Type *SelType = Sel.getType();
-  if (SelType->isVectorTy() != CondVal->getType()->isVectorTy())
-    return nullptr;
-
-  Value *V;
-  APInt AndMask;
-  bool CreateAnd = false;
-  CmpPredicate Pred;
-  Value *CmpLHS, *CmpRHS;
-
-  if (match(CondVal, m_ICmp(Pred, m_Value(CmpLHS), m_Value(CmpRHS)))) {
-    if (ICmpInst::isEquality(Pred)) {
-      if (!match(CmpRHS, m_Zero()))
-        return nullptr;
-
-      V = CmpLHS;
-      const APInt *AndRHS;
-      if (!match(V, m_And(m_Value(), m_Power2(AndRHS))))
-        return nullptr;
-
-      AndMask = *AndRHS;
-    } else {
-      auto Res = decomposeBitTestICmp(CmpLHS, CmpRHS, Pred);
-      if (!Res || !Res->Mask.isPowerOf2())
-        return nullptr;
-      assert(ICmpInst::isEquality(Res->Pred) && "Not equality test?");
-
-      V = Res->X;
-      AndMask = Res->Mask;
-      Pred = Res->Pred;
-      CreateAnd = true;
-    }
-  } else if (auto *Trunc = dyn_cast<TruncInst>(CondVal)) {
-    V = Trunc->getOperand(0);
-    AndMask = APInt(V->getType()->getScalarSizeInBits(), 1);
-    Pred = ICmpInst::ICMP_NE;
-    CreateAnd = !Trunc->hasNoUnsignedWrap();
-  } else {
-    return nullptr;
-  }
-  if (Pred == ICmpInst::ICMP_NE)
-    std::swap(SelTC, SelFC);
-
   // In general, when both constants are non-zero, we would need an offset to
   // replace the select. This would require more instructions than we started
   // with. But there's one special-case that we handle here because it can
@@ -756,60 +714,26 @@ static Value *foldSelectICmpLshrAshr(const ICmpInst *IC, Value *TrueVal,
 /// 2. The select operands are reversed
 /// 3. The magnitude of C2 and C1 are flipped
 static Value *foldSelectICmpAndBinOp(Value *CondVal, Value *TrueVal,
-                                     Value *FalseVal,
+                                     Value *FalseVal, Value *V, APInt AndMask,
+                                     bool CreateAnd,
                                      InstCombiner::BuilderTy &Builder) {
-  // Only handle integer compares. Also, if this is a vector select, we need a
-  // vector compare.
-  if (!TrueVal->getType()->isIntOrIntVectorTy() ||
-      TrueVal->getType()->isVectorTy() != CondVal->getType()->isVectorTy())
-    return nullptr;
-
-  unsigned C1Log;
-  bool NeedAnd = false;
-  CmpPredicate Pred;
-  Value *CmpLHS, *CmpRHS;
-
-  if (match(CondVal, m_ICmp(Pred, m_Value(CmpLHS), m_Value(CmpRHS)))) {
-    if (ICmpInst::isEquality(Pred)) {
-      if (!match(CmpRHS, m_Zero()))
-        return nullptr;
-
-      const APInt *C1;
-      if (!match(CmpLHS, m_And(m_Value(), m_Power2(C1))))
-        return nullptr;
-
-      C1Log = C1->logBase2();
-    } else {
-      auto Res = decomposeBitTestICmp(CmpLHS, CmpRHS, Pred);
-      if (!Res || !Res->Mask.isPowerOf2())
-        return nullptr;
-
-      CmpLHS = Res->X;
-      Pred = Res->Pred;
-      C1Log = Res->Mask.logBase2();
-      NeedAnd = true;
-    }
-  } else if (auto *Trunc = dyn_cast<TruncInst>(CondVal)) {
-    CmpLHS = Trunc->getOperand(0);
-    C1Log = 0;
-    Pred = ICmpInst::ICMP_NE;
-    NeedAnd = !Trunc->hasNoUnsignedWrap();
-  } else {
+  // Only handle integer compares.
+  if (!TrueVal->getType()->isIntOrIntVectorTy())
     return nullptr;
-  }
 
-  Value *Y, *V = CmpLHS;
+  unsigned C1Log = AndMask.logBase2();
+  Value *Y;
   BinaryOperator *BinOp;
   const APInt *C2;
   bool NeedXor;
   if (match(FalseVal, m_BinOp(m_Specific(TrueVal), m_Power2(C2)))) {
     Y = TrueVal;
     BinOp = cast<BinaryOperator>(FalseVal);
-    NeedXor = Pred == ICmpInst::ICMP_NE;
+    NeedXor = false;
   } else if (match(TrueVal, m_BinOp(m_Specific(FalseVal), m_Power2(C2)))) {
     Y = FalseVal;
     BinOp = cast<BinaryOperator>(TrueVal);
-    NeedXor = Pred == ICmpInst::ICMP_EQ;
+    NeedXor = true;
   } else {
     return nullptr;
   }
@@ -828,14 +752,13 @@ static Value *foldSelectICmpAndBinOp(Value *CondVal, Value *TrueVal,
                        V->getType()->getScalarSizeInBits();
 
   // Make sure we don't create more instructions than we save.
-  if ((NeedShift + NeedXor + NeedZExtTrunc + NeedAnd) >
+  if ((NeedShift + NeedXor + NeedZExtTrunc + CreateAnd) >
       (CondVal->hasOneUse() + BinOp->hasOneUse()))
     return nullptr;
 
-  if (NeedAnd) {
+  if (CreateAnd) {
     // Insert the AND instruction on the input to the truncate.
-    APInt C1 = APInt::getOneBitSet(V->getType()->getScalarSizeInBits(), C1Log);
-    V = Builder.CreateAnd(V, ConstantInt::get(V->getType(), C1));
+    V = Builder.CreateAnd(V, ConstantInt::get(V->getType(), AndMask));
   }
 
   if (C2Log > C1Log) {
@@ -3789,6 +3712,65 @@ static Value *foldSelectIntoAddConstant(SelectInst &SI,
   return nullptr;
 }
 
+/// Try to fold a select whose condition tests one bit of an integer value:
+///   - icmp eq/ne (and V, Pow2C), 0
+///   - a non-equality icmp that decomposeBitTestICmp turns into a bit test
+///   - trunc V to i1 (tests the lowest bit of V)
+/// The arms are swapped for the `ne` form so that the helpers always see the
+/// "bit is clear" case, then foldSelectICmpAnd and foldSelectICmpAndBinOp run.
+static Value *foldSelectBitTest(SelectInst &Sel, Value *CondVal, Value *TrueVal,
+                                Value *FalseVal,
+                                InstCombiner::BuilderTy &Builder) {
+  // A vector select can only be folded with a vector condition.
+  if (Sel.getType()->isVectorTy() != CondVal->getType()->isVectorTy())
+    return nullptr;
+
+  Value *V;               // The value whose bit is being tested.
+  APInt AndMask;          // Power-of-two mask of the tested bit.
+  bool CreateAnd = false; // Whether `and V, AndMask` must still be emitted.
+  CmpPredicate Pred;
+  Value *CmpLHS, *CmpRHS;
+
+  if (auto *Trunc = dyn_cast<TruncInst>(CondVal)) {
+    // The truncated condition is bit 0 of V; with nuw the dropped bits are
+    // zero, so V can be used directly without masking.
+    V = Trunc->getOperand(0);
+    AndMask = APInt(V->getType()->getScalarSizeInBits(), 1);
+    Pred = ICmpInst::ICMP_NE;
+    CreateAnd = !Trunc->hasNoUnsignedWrap();
+  } else if (!match(CondVal,
+                    m_ICmp(Pred, m_Value(CmpLHS), m_Value(CmpRHS)))) {
+    return nullptr;
+  } else if (!ICmpInst::isEquality(Pred)) {
+    auto BitTest = decomposeBitTestICmp(CmpLHS, CmpRHS, Pred);
+    if (!BitTest || !BitTest->Mask.isPowerOf2())
+      return nullptr;
+    assert(ICmpInst::isEquality(BitTest->Pred) && "Not equality test?");
+    V = BitTest->X;
+    AndMask = BitTest->Mask;
+    Pred = BitTest->Pred;
+    CreateAnd = true;
+  } else {
+    // Only `(and X, Pow2C) ==/!= 0` qualifies; V is then the `and` itself.
+    const APInt *AndRHS;
+    if (!match(CmpRHS, m_Zero()) ||
+        !match(CmpLHS, m_And(m_Value(), m_Power2(AndRHS))))
+      return nullptr;
+    V = CmpLHS;
+    AndMask = *AndRHS;
+  }
+
+  // Normalize `bit != 0` to `bit == 0` by exchanging the select arms.
+  if (Pred == ICmpInst::ICMP_NE)
+    std::swap(TrueVal, FalseVal);
+
+  if (Value *Folded = foldSelectICmpAnd(Sel, CondVal, TrueVal, FalseVal, V,
+                                        AndMask, CreateAnd, Builder))
+    return Folded;
+  return foldSelectICmpAndBinOp(CondVal, TrueVal, FalseVal, V, AndMask,
+                                CreateAnd, Builder);
+}
+
 Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
   Value *CondVal = SI.getCondition();
   Value *TrueVal = SI.getTrueValue();
@@ -3961,10 +3943,7 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
     if (Instruction *Result = foldSelectInstWithICmp(SI, ICI))
       return Result;
 
-  if (Value *V = foldSelectICmpAnd(SI, CondVal, Builder))
-    return replaceInstUsesWith(SI, V);
-
-  if (Value *V = foldSelectICmpAndBinOp(CondVal, TrueVal, FalseVal, Builder))
+  if (Value *V = foldSelectBitTest(SI, CondVal, TrueVal, FalseVal, Builder))
     return replaceInstUsesWith(SI, V);
 
   if (Instruction *Add = foldAddSubSelect(SI, Builder))



More information about the llvm-commits mailing list