[llvm] [CmpInstAnalysis] Decompose icmp eq (and x, C) C2 (PR #136367)

Jeffrey Byrnes via llvm-commits llvm-commits at lists.llvm.org
Fri Apr 18 13:36:05 PDT 2025


https://github.com/jrbyrnes created https://github.com/llvm/llvm-project/pull/136367

This type of decomposition is used in multiple places already. Adding it to `CmpInstAnalysis` reduces code duplication.

From 740a37e616a1a4954efcfa2fa3a739e09c592238 Mon Sep 17 00:00:00 2001
From: Jeffrey Byrnes <Jeffrey.Byrnes at amd.com>
Date: Fri, 18 Apr 2025 10:38:54 -0700
Subject: [PATCH] [CmpInstAnalysis] Decompose icmp eq (and x, C) C2

Change-Id: I1dd786a4652ccd2e486db6903e16e58ffa1a7959
---
 llvm/include/llvm/Analysis/CmpInstAnalysis.h  | 14 +++++---
 llvm/lib/Analysis/CmpInstAnalysis.cpp         | 34 ++++++++++++++-----
 .../InstCombine/InstCombineAndOrXor.cpp       | 14 +++-----
 .../InstCombine/InstCombineSelect.cpp         | 32 ++++++++---------
 .../Transforms/Scalar/LoopIdiomRecognize.cpp  | 14 +++-----
 5 files changed, 59 insertions(+), 49 deletions(-)

diff --git a/llvm/include/llvm/Analysis/CmpInstAnalysis.h b/llvm/include/llvm/Analysis/CmpInstAnalysis.h
index aeda58ac7535d..a796a38feff88 100644
--- a/llvm/include/llvm/Analysis/CmpInstAnalysis.h
+++ b/llvm/include/llvm/Analysis/CmpInstAnalysis.h
@@ -95,24 +95,28 @@ namespace llvm {
   /// Represents the operation icmp (X & Mask) pred C, where pred can only be
   /// eq or ne.
   struct DecomposedBitTest {
-    Value *X;
+    Value *X = nullptr;
     CmpInst::Predicate Pred;
     APInt Mask;
     APInt C;
   };
 
   /// Decompose an icmp into the form ((X & Mask) pred C) if possible.
-  /// Unless \p AllowNonZeroC is true, C will always be 0.
+  /// Unless \p AllowNonZeroC is true, C will always be 0. If \p
+  /// DecomposeBitMask is true, equality predicates on a constant-masked
+  /// value (`icmp eq/ne (and X, Mask), C`) are decomposed as well.
   std::optional<DecomposedBitTest>
   decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
-                       bool LookThroughTrunc = true,
-                       bool AllowNonZeroC = false);
+                       bool LookThroughTrunc = true, bool AllowNonZeroC = false,
+                       bool DecomposeBitMask = false);
 
   /// Decompose an icmp into the form ((X & Mask) pred C) if
   /// possible. Unless \p AllowNonZeroC is true, C will always be 0.
+  /// If \p DecomposeBitMask is true, equality predicates on a constant-masked
+  /// value (`icmp eq/ne (and X, Mask), C`) are decomposed as well.
   std::optional<DecomposedBitTest>
   decomposeBitTest(Value *Cond, bool LookThroughTrunc = true,
-                   bool AllowNonZeroC = false);
+                   bool AllowNonZeroC = false, bool DecomposeBitMask = false);
 
 } // end namespace llvm
 
diff --git a/llvm/lib/Analysis/CmpInstAnalysis.cpp b/llvm/lib/Analysis/CmpInstAnalysis.cpp
index 5c0d1dd1c74b0..6714088097127 100644
--- a/llvm/lib/Analysis/CmpInstAnalysis.cpp
+++ b/llvm/lib/Analysis/CmpInstAnalysis.cpp
@@ -75,11 +75,13 @@ Constant *llvm::getPredForFCmpCode(unsigned Code, Type *OpTy,
 
 std::optional<DecomposedBitTest>
 llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
-                           bool LookThruTrunc, bool AllowNonZeroC) {
+                           bool LookThruTrunc, bool AllowNonZeroC,
+                           bool DecomposeBitMask) {
   using namespace PatternMatch;
 
   const APInt *OrigC;
-  if (!ICmpInst::isRelational(Pred) || !match(RHS, m_APIntAllowPoison(OrigC)))
+  if ((ICmpInst::isEquality(Pred) && !DecomposeBitMask) ||
+      !match(RHS, m_APIntAllowPoison(OrigC)))
     return std::nullopt;
 
   bool Inverted = false;
@@ -97,9 +99,10 @@ llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
   }
 
   DecomposedBitTest Result;
+
   switch (Pred) {
   default:
-    llvm_unreachable("Unexpected predicate");
+    return std::nullopt;
   case ICmpInst::ICMP_SLT: {
     // X < 0 is equivalent to (X & SignMask) != 0.
     if (C.isZero()) {
@@ -128,7 +131,7 @@ llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
 
     return std::nullopt;
   }
-  case ICmpInst::ICMP_ULT:
+  case ICmpInst::ICMP_ULT: {
     // X <u 2^n is equivalent to (X & ~(2^n-1)) == 0.
     if (C.isPowerOf2()) {
       Result.Mask = -C;
@@ -147,6 +150,19 @@ llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
 
     return std::nullopt;
   }
+  case ICmpInst::ICMP_EQ:
+  case ICmpInst::ICMP_NE: {
+    assert(DecomposeBitMask);
+    const APInt *AndC;
+    Value *AndVal;
+    if (match(LHS, m_And(m_Value(AndVal), m_APIntAllowPoison(AndC)))) {
+      Result = {AndVal /*X*/, Pred /*Pred*/, *AndC /*Mask*/, *OrigC /*C*/};
+      break;
+    }
+
+    return std::nullopt;
+  }
+  }
 
   if (!AllowNonZeroC && !Result.C.isZero())
     return std::nullopt;
@@ -159,15 +175,17 @@ llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
     Result.X = X;
     Result.Mask = Result.Mask.zext(X->getType()->getScalarSizeInBits());
     Result.C = Result.C.zext(X->getType()->getScalarSizeInBits());
-  } else {
+  } else if (!Result.X) {
     Result.X = LHS;
   }
 
   return Result;
 }
 
-std::optional<DecomposedBitTest>
-llvm::decomposeBitTest(Value *Cond, bool LookThruTrunc, bool AllowNonZeroC) {
+std::optional<DecomposedBitTest> llvm::decomposeBitTest(Value *Cond,
+                                                        bool LookThruTrunc,
+                                                        bool AllowNonZeroC,
+                                                        bool DecomposeBitMask) {
   using namespace PatternMatch;
   if (auto *ICmp = dyn_cast<ICmpInst>(Cond)) {
     // Don't allow pointers. Splat vectors are fine.
@@ -175,7 +193,7 @@ llvm::decomposeBitTest(Value *Cond, bool LookThruTrunc, bool AllowNonZeroC) {
       return std::nullopt;
     return decomposeBitTestICmp(ICmp->getOperand(0), ICmp->getOperand(1),
                                 ICmp->getPredicate(), LookThruTrunc,
-                                AllowNonZeroC);
+                                AllowNonZeroC, DecomposeBitMask);
   }
   Value *X;
   if (Cond->getType()->isIntOrIntVectorTy(1) &&
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 19bf81137aab7..60461a6e8b338 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -875,22 +875,16 @@ static Value *foldSignedTruncationCheck(ICmpInst *ICmp0, ICmpInst *ICmp1,
                            APInt &UnsetBitsMask) -> bool {
     CmpPredicate Pred = ICmp->getPredicate();
     // Can it be decomposed into  icmp eq (X & Mask), 0  ?
-    auto Res =
-        llvm::decomposeBitTestICmp(ICmp->getOperand(0), ICmp->getOperand(1),
-                                   Pred, /*LookThroughTrunc=*/false);
+    auto Res = llvm::decomposeBitTestICmp(
+        ICmp->getOperand(0), ICmp->getOperand(1), Pred,
+        /*LookThroughTrunc=*/true, /*AllowNonZeroC=*/false,
+        /*DecomposeBitMask=*/true);
     if (Res && Res->Pred == ICmpInst::ICMP_EQ) {
       X = Res->X;
       UnsetBitsMask = Res->Mask;
       return true;
     }
 
-    // Is it  icmp eq (X & Mask), 0  already?
-    const APInt *Mask;
-    if (match(ICmp, m_ICmp(Pred, m_And(m_Value(X), m_APInt(Mask)), m_Zero())) &&
-        Pred == ICmpInst::ICMP_EQ) {
-      UnsetBitsMask = *Mask;
-      return true;
-    }
     return false;
   };
 
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 4bba2f406b4c1..8cdb00fb44a48 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -3730,31 +3730,29 @@ static Value *foldSelectBitTest(SelectInst &Sel, Value *CondVal, Value *TrueVal,
   Value *CmpLHS, *CmpRHS;
 
   if (match(CondVal, m_ICmp(Pred, m_Value(CmpLHS), m_Value(CmpRHS)))) {
-    if (ICmpInst::isEquality(Pred)) {
-      if (!match(CmpRHS, m_Zero()))
-        return nullptr;
+    auto Res = decomposeBitTestICmp(
+        CmpLHS, CmpRHS, Pred, /*LookThroughTrunc=*/true,
+        /*AllowNonZeroC=*/false, /*DecomposeBitMask=*/true);
 
-      V = CmpLHS;
-      const APInt *AndRHS;
-      if (!match(CmpLHS, m_And(m_Value(), m_Power2(AndRHS))))
-        return nullptr;
+    if (!Res)
+      return nullptr;
 
-      AndMask = *AndRHS;
-    } else if (auto Res = decomposeBitTestICmp(CmpLHS, CmpRHS, Pred)) {
-      assert(ICmpInst::isEquality(Res->Pred) && "Not equality test?");
-      AndMask = Res->Mask;
+    V = CmpLHS;
+    AndMask = Res->Mask;
+
+    if (!ICmpInst::isEquality(Pred)) {
       V = Res->X;
       KnownBits Known =
           computeKnownBits(V, /*Depth=*/0, SQ.getWithInstruction(&Sel));
       AndMask &= Known.getMaxValue();
-      if (!AndMask.isPowerOf2())
-        return nullptr;
-
-      Pred = Res->Pred;
       CreateAnd = true;
-    } else {
-      return nullptr;
     }
+
+    Pred = Res->Pred;
+
+    if (!AndMask.isPowerOf2())
+      return nullptr;
+
   } else if (auto *Trunc = dyn_cast<TruncInst>(CondVal)) {
     V = Trunc->getOperand(0);
     AndMask = APInt(V->getType()->getScalarSizeInBits(), 1);
diff --git a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
index eacc67cd5a475..ced1557ec3060 100644
--- a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
@@ -2761,14 +2761,11 @@ static bool detectShiftUntilBitTestIdiom(Loop *CurLoop, Value *&BaseX,
                              m_LoopInvariant(m_Shl(m_One(), m_Value(BitPos)),
                                              CurLoop))));
   };
-  auto MatchConstantBitMask = [&]() {
-    return ICmpInst::isEquality(Pred) && match(CmpRHS, m_Zero()) &&
-           match(CmpLHS, m_And(m_Value(CurrX),
-                               m_CombineAnd(m_Value(BitMask), m_Power2()))) &&
-           (BitPos = ConstantExpr::getExactLogBase2(cast<Constant>(BitMask)));
-  };
+
   auto MatchDecomposableConstantBitMask = [&]() {
-    auto Res = llvm::decomposeBitTestICmp(CmpLHS, CmpRHS, Pred);
+    auto Res = llvm::decomposeBitTestICmp(
+        CmpLHS, CmpRHS, Pred, /*LookThroughTrunc=*/true,
+        /*AllowNonZeroC=*/false, /*DecomposeBitMask=*/true);
     if (Res && Res->Mask.isPowerOf2()) {
       assert(ICmpInst::isEquality(Res->Pred));
       Pred = Res->Pred;
@@ -2780,8 +2777,7 @@ static bool detectShiftUntilBitTestIdiom(Loop *CurLoop, Value *&BaseX,
     return false;
   };
 
-  if (!MatchVariableBitMask() && !MatchConstantBitMask() &&
-      !MatchDecomposableConstantBitMask()) {
+  if (!MatchVariableBitMask() && !MatchDecomposableConstantBitMask()) {
     LLVM_DEBUG(dbgs() << DEBUG_TYPE " Bad backedge comparison.\n");
     return false;
   }



More information about the llvm-commits mailing list