[llvm] [InstSimplify] Move Select with bittest folds. NFC. (PR #122944)

Andreas Jonson via llvm-commits llvm-commits at lists.llvm.org
Fri Jan 17 15:11:43 PST 2025


https://github.com/andjo403 updated https://github.com/llvm/llvm-project/pull/122944

>From 39fea1f1f3ed6d959e4dd67eef76fa5e21336baf Mon Sep 17 00:00:00 2001
From: Andreas Jonson <andjo403 at hotmail.com>
Date: Sat, 18 Jan 2025 00:11:20 +0100
Subject: [PATCH] [InstSimplify] Handle trunc to i1 in Select with bit test
 folds.

---
 llvm/lib/Analysis/CmpInstAnalysis.cpp       | 19 +++++++++++++++++++
 llvm/lib/Analysis/InstructionSimplify.cpp   | 19 ++++++++-----------
 llvm/test/Transforms/InstSimplify/select.ll | 12 +++---------
 3 files changed, 30 insertions(+), 20 deletions(-)

diff --git a/llvm/lib/Analysis/CmpInstAnalysis.cpp b/llvm/lib/Analysis/CmpInstAnalysis.cpp
index 3599428c5ff416..e618164d4c19f6 100644
--- a/llvm/lib/Analysis/CmpInstAnalysis.cpp
+++ b/llvm/lib/Analysis/CmpInstAnalysis.cpp
@@ -168,6 +168,7 @@ llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
 
 std::optional<DecomposedBitTest>
 llvm::decomposeBitTest(Value *Cond, bool LookThruTrunc, bool AllowNonZeroC) {
+  using namespace PatternMatch;
   if (auto *ICmp = dyn_cast<ICmpInst>(Cond)) {
     // Don't allow pointers. Splat vectors are fine.
     if (!ICmp->getOperand(0)->getType()->isIntOrIntVectorTy())
@@ -176,6 +177,24 @@ llvm::decomposeBitTest(Value *Cond, bool LookThruTrunc, bool AllowNonZeroC) {
                                 ICmp->getPredicate(), LookThruTrunc,
                                 AllowNonZeroC);
   }
+  Value *X;
+  if (Cond->getType()->isIntOrIntVectorTy(1) &&
+      (match(Cond, m_Trunc(m_Value(X))) ||
+       match(Cond, m_Not(m_Trunc(m_Value(X)))))) {
+    DecomposedBitTest Result;
+    Result.X = X;
+    unsigned BitWidth = X->getType()->getScalarSizeInBits();
+    Result.Mask = APInt(BitWidth, 1);
+    Result.C = APInt::getZero(BitWidth);
+    Result.Pred = isa<TruncInst>(Cond) ? ICmpInst::ICMP_NE : ICmpInst::ICMP_EQ;
+
+    if (LookThruTrunc && match(Result.X, m_Trunc(m_Value(X)))) {
+      Result.X = X;
+      Result.Mask = Result.Mask.zext(X->getType()->getScalarSizeInBits());
+      Result.C = Result.C.zext(X->getType()->getScalarSizeInBits());
+    }
+    return Result;
+  }
 
   return std::nullopt;
 }
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index d69747e30f884d..1facf56937f244 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -4612,12 +4612,11 @@ static Value *simplifyCmpSelOfMaxMin(Value *CmpLHS, Value *CmpRHS,
   return nullptr;
 }
 
-/// An alternative way to test if a bit is set or not uses sgt/slt instead of
-/// eq/ne.
-static Value *simplifySelectWithFakeICmpEq(Value *CmpLHS, Value *CmpRHS,
-                                           CmpPredicate Pred, Value *TrueVal,
-                                           Value *FalseVal) {
-  if (auto Res = decomposeBitTestICmp(CmpLHS, CmpRHS, Pred))
+/// An alternative way to test if a bit is set or not, which uses e.g.
+/// sgt/slt or trunc instead of eq/ne.
+static Value *simplifySelectWithBitTest(Value *CondVal, Value *TrueVal,
+                                        Value *FalseVal) {
+  if (auto Res = decomposeBitTest(CondVal))
     return simplifySelectBitTest(TrueVal, FalseVal, Res->X, &Res->Mask,
                                  Res->Pred == ICmpInst::ICMP_EQ);
 
@@ -4728,11 +4727,6 @@ static Value *simplifySelectWithICmpCond(Value *CondVal, Value *TrueVal,
       return FalseVal;
   }
 
-  // Check for other compares that behave like bit test.
-  if (Value *V =
-          simplifySelectWithFakeICmpEq(CmpLHS, CmpRHS, Pred, TrueVal, FalseVal))
-    return V;
-
   // If we have a scalar equality comparison, then we know the value in one of
   // the arms of the select. See if substituting this value into the arm and
   // simplifying the result yields the same value as the other arm.
@@ -4984,6 +4978,9 @@ static Value *simplifySelectInst(Value *Cond, Value *TrueVal, Value *FalseVal,
           simplifySelectWithICmpCond(Cond, TrueVal, FalseVal, Q, MaxRecurse))
     return V;
 
+  if (Value *V = simplifySelectWithBitTest(Cond, TrueVal, FalseVal))
+    return V;
+
   if (Value *V = simplifySelectWithFCmp(Cond, TrueVal, FalseVal, Q, MaxRecurse))
     return V;
 
diff --git a/llvm/test/Transforms/InstSimplify/select.ll b/llvm/test/Transforms/InstSimplify/select.ll
index 40539b8ade388b..1b5703a46cf683 100644
--- a/llvm/test/Transforms/InstSimplify/select.ll
+++ b/llvm/test/Transforms/InstSimplify/select.ll
@@ -1752,10 +1752,7 @@ define <4 x i32> @select_vector_cmp_with_bitcasts(<2 x i64> %x, <4 x i32> %y) {
 
 define i8 @bittest_trunc_or(i8 %x) {
 ; CHECK-LABEL: @bittest_trunc_or(
-; CHECK-NEXT:    [[TRUNC:%.*]] = trunc i8 [[X1:%.*]] to i1
-; CHECK-NEXT:    [[OR:%.*]] = or i8 [[X1]], 1
-; CHECK-NEXT:    [[X:%.*]] = select i1 [[TRUNC]], i8 [[OR]], i8 [[X1]]
-; CHECK-NEXT:    ret i8 [[X]]
+; CHECK-NEXT:    ret i8 [[X:%.*]]
 ;
   %trunc = trunc i8 %x to i1
   %or = or i8 %x, 1
@@ -1765,11 +1762,8 @@ define i8 @bittest_trunc_or(i8 %x) {
 
 define i8 @bittest_trunc_not_or(i8 %x) {
 ; CHECK-LABEL: @bittest_trunc_not_or(
-; CHECK-NEXT:    [[TRUNC:%.*]] = trunc i8 [[X:%.*]] to i1
-; CHECK-NEXT:    [[NOT:%.*]] = xor i1 [[TRUNC]], true
-; CHECK-NEXT:    [[OR:%.*]] = or i8 [[X]], 1
-; CHECK-NEXT:    [[COND:%.*]] = select i1 [[NOT]], i8 [[OR]], i8 [[X]]
-; CHECK-NEXT:    ret i8 [[COND]]
+; CHECK-NEXT:    [[OR:%.*]] = or i8 [[X:%.*]], 1
+; CHECK-NEXT:    ret i8 [[OR]]
 ;
   %trunc = trunc i8 %x to i1
   %not = xor i1 %trunc, true



More information about the llvm-commits mailing list