[llvm] [CmpInstAnalysis] Decompose icmp eq (and x, C) C2 (PR #136367)

Jeffrey Byrnes via llvm-commits llvm-commits at lists.llvm.org
Thu Apr 24 12:17:08 PDT 2025


https://github.com/jrbyrnes updated https://github.com/llvm/llvm-project/pull/136367

>From 740a37e616a1a4954efcfa2fa3a739e09c592238 Mon Sep 17 00:00:00 2001
From: Jeffrey Byrnes <Jeffrey.Byrnes at amd.com>
Date: Fri, 18 Apr 2025 10:38:54 -0700
Subject: [PATCH 1/3] [CmpInstAnalysis] Decompose icmp eq (and x, C) C2

Change-Id: I1dd786a4652ccd2e486db6903e16e58ffa1a7959
---
 llvm/include/llvm/Analysis/CmpInstAnalysis.h  | 14 +++++---
 llvm/lib/Analysis/CmpInstAnalysis.cpp         | 34 ++++++++++++++-----
 .../InstCombine/InstCombineAndOrXor.cpp       | 14 +++-----
 .../InstCombine/InstCombineSelect.cpp         | 32 ++++++++---------
 .../Transforms/Scalar/LoopIdiomRecognize.cpp  | 14 +++-----
 5 files changed, 59 insertions(+), 49 deletions(-)

diff --git a/llvm/include/llvm/Analysis/CmpInstAnalysis.h b/llvm/include/llvm/Analysis/CmpInstAnalysis.h
index aeda58ac7535d..a796a38feff88 100644
--- a/llvm/include/llvm/Analysis/CmpInstAnalysis.h
+++ b/llvm/include/llvm/Analysis/CmpInstAnalysis.h
@@ -95,24 +95,28 @@ namespace llvm {
   /// Represents the operation icmp (X & Mask) pred C, where pred can only be
   /// eq or ne.
   struct DecomposedBitTest {
-    Value *X;
+    Value *X = nullptr;
     CmpInst::Predicate Pred;
     APInt Mask;
     APInt C;
   };
 
   /// Decompose an icmp into the form ((X & Mask) pred C) if possible.
-  /// Unless \p AllowNonZeroC is true, C will always be 0.
+  /// Unless \p AllowNonZeroC is true, C will always be 0. If \p
+  /// DecomposeBitMask is specified, then, for equality predicates, this will
+  /// decompose bitmasking (e.g. implemented via `and`).
   std::optional<DecomposedBitTest>
   decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
-                       bool LookThroughTrunc = true,
-                       bool AllowNonZeroC = false);
+                       bool LookThroughTrunc = true, bool AllowNonZeroC = false,
+                       bool DecomposeBitMask = false);
 
   /// Decompose an icmp into the form ((X & Mask) pred C) if
   /// possible. Unless \p AllowNonZeroC is true, C will always be 0.
+  /// If \p DecomposeBitMask is specified, then, for equality predicates, this
+  /// will decompose bitmasking (e.g. implemented via `and`).
   std::optional<DecomposedBitTest>
   decomposeBitTest(Value *Cond, bool LookThroughTrunc = true,
-                   bool AllowNonZeroC = false);
+                   bool AllowNonZeroC = false, bool DecomposeBitMask = false);
 
 } // end namespace llvm
 
diff --git a/llvm/lib/Analysis/CmpInstAnalysis.cpp b/llvm/lib/Analysis/CmpInstAnalysis.cpp
index 5c0d1dd1c74b0..6714088097127 100644
--- a/llvm/lib/Analysis/CmpInstAnalysis.cpp
+++ b/llvm/lib/Analysis/CmpInstAnalysis.cpp
@@ -75,11 +75,13 @@ Constant *llvm::getPredForFCmpCode(unsigned Code, Type *OpTy,
 
 std::optional<DecomposedBitTest>
 llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
-                           bool LookThruTrunc, bool AllowNonZeroC) {
+                           bool LookThruTrunc, bool AllowNonZeroC,
+                           bool DecomposeBitMask) {
   using namespace PatternMatch;
 
   const APInt *OrigC;
-  if (!ICmpInst::isRelational(Pred) || !match(RHS, m_APIntAllowPoison(OrigC)))
+  if ((ICmpInst::isEquality(Pred) && !DecomposeBitMask) ||
+      !match(RHS, m_APIntAllowPoison(OrigC)))
     return std::nullopt;
 
   bool Inverted = false;
@@ -97,9 +99,10 @@ llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
   }
 
   DecomposedBitTest Result;
+
   switch (Pred) {
   default:
-    llvm_unreachable("Unexpected predicate");
+    return std::nullopt;
   case ICmpInst::ICMP_SLT: {
     // X < 0 is equivalent to (X & SignMask) != 0.
     if (C.isZero()) {
@@ -128,7 +131,7 @@ llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
 
     return std::nullopt;
   }
-  case ICmpInst::ICMP_ULT:
+  case ICmpInst::ICMP_ULT: {
     // X <u 2^n is equivalent to (X & ~(2^n-1)) == 0.
     if (C.isPowerOf2()) {
       Result.Mask = -C;
@@ -147,6 +150,19 @@ llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
 
     return std::nullopt;
   }
+  case ICmpInst::ICMP_EQ:
+  case ICmpInst::ICMP_NE: {
+    assert(DecomposeBitMask);
+    const APInt *AndC;
+    Value *AndVal;
+    if (match(LHS, m_And(m_Value(AndVal), m_APIntAllowPoison(AndC)))) {
+      Result = {AndVal /*X*/, Pred /*Pred*/, *AndC /*Mask*/, *OrigC /*C*/};
+      break;
+    }
+
+    return std::nullopt;
+  }
+  }
 
   if (!AllowNonZeroC && !Result.C.isZero())
     return std::nullopt;
@@ -159,15 +175,17 @@ llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
     Result.X = X;
     Result.Mask = Result.Mask.zext(X->getType()->getScalarSizeInBits());
     Result.C = Result.C.zext(X->getType()->getScalarSizeInBits());
-  } else {
+  } else if (!Result.X) {
     Result.X = LHS;
   }
 
   return Result;
 }
 
-std::optional<DecomposedBitTest>
-llvm::decomposeBitTest(Value *Cond, bool LookThruTrunc, bool AllowNonZeroC) {
+std::optional<DecomposedBitTest> llvm::decomposeBitTest(Value *Cond,
+                                                        bool LookThruTrunc,
+                                                        bool AllowNonZeroC,
+                                                        bool DecomposeBitMask) {
   using namespace PatternMatch;
   if (auto *ICmp = dyn_cast<ICmpInst>(Cond)) {
     // Don't allow pointers. Splat vectors are fine.
@@ -175,7 +193,7 @@ llvm::decomposeBitTest(Value *Cond, bool LookThruTrunc, bool AllowNonZeroC) {
       return std::nullopt;
     return decomposeBitTestICmp(ICmp->getOperand(0), ICmp->getOperand(1),
                                 ICmp->getPredicate(), LookThruTrunc,
-                                AllowNonZeroC);
+                                AllowNonZeroC, DecomposeBitMask);
   }
   Value *X;
   if (Cond->getType()->isIntOrIntVectorTy(1) &&
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 19bf81137aab7..60461a6e8b338 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -875,22 +875,16 @@ static Value *foldSignedTruncationCheck(ICmpInst *ICmp0, ICmpInst *ICmp1,
                            APInt &UnsetBitsMask) -> bool {
     CmpPredicate Pred = ICmp->getPredicate();
     // Can it be decomposed into  icmp eq (X & Mask), 0  ?
-    auto Res =
-        llvm::decomposeBitTestICmp(ICmp->getOperand(0), ICmp->getOperand(1),
-                                   Pred, /*LookThroughTrunc=*/false);
+    auto Res = llvm::decomposeBitTestICmp(
+        ICmp->getOperand(0), ICmp->getOperand(1), Pred,
+        /*LookThroughTrunc=*/true, /*AllowNonZeroC=*/false,
+        /*DecomposeBitMask=*/true);
     if (Res && Res->Pred == ICmpInst::ICMP_EQ) {
       X = Res->X;
       UnsetBitsMask = Res->Mask;
       return true;
     }
 
-    // Is it  icmp eq (X & Mask), 0  already?
-    const APInt *Mask;
-    if (match(ICmp, m_ICmp(Pred, m_And(m_Value(X), m_APInt(Mask)), m_Zero())) &&
-        Pred == ICmpInst::ICMP_EQ) {
-      UnsetBitsMask = *Mask;
-      return true;
-    }
     return false;
   };
 
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 4bba2f406b4c1..8cdb00fb44a48 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -3730,31 +3730,29 @@ static Value *foldSelectBitTest(SelectInst &Sel, Value *CondVal, Value *TrueVal,
   Value *CmpLHS, *CmpRHS;
 
   if (match(CondVal, m_ICmp(Pred, m_Value(CmpLHS), m_Value(CmpRHS)))) {
-    if (ICmpInst::isEquality(Pred)) {
-      if (!match(CmpRHS, m_Zero()))
-        return nullptr;
+    auto Res = decomposeBitTestICmp(
+        CmpLHS, CmpRHS, Pred, /*LookThroughTrunc=*/true,
+        /*AllowNonZeroC=*/false, /*DecomposeBitMask=*/true);
 
-      V = CmpLHS;
-      const APInt *AndRHS;
-      if (!match(CmpLHS, m_And(m_Value(), m_Power2(AndRHS))))
-        return nullptr;
+    if (!Res)
+      return nullptr;
 
-      AndMask = *AndRHS;
-    } else if (auto Res = decomposeBitTestICmp(CmpLHS, CmpRHS, Pred)) {
-      assert(ICmpInst::isEquality(Res->Pred) && "Not equality test?");
-      AndMask = Res->Mask;
+    V = CmpLHS;
+    AndMask = Res->Mask;
+
+    if (!ICmpInst::isEquality(Pred)) {
       V = Res->X;
       KnownBits Known =
           computeKnownBits(V, /*Depth=*/0, SQ.getWithInstruction(&Sel));
       AndMask &= Known.getMaxValue();
-      if (!AndMask.isPowerOf2())
-        return nullptr;
-
-      Pred = Res->Pred;
       CreateAnd = true;
-    } else {
-      return nullptr;
     }
+
+    Pred = Res->Pred;
+
+    if (!AndMask.isPowerOf2())
+      return nullptr;
+
   } else if (auto *Trunc = dyn_cast<TruncInst>(CondVal)) {
     V = Trunc->getOperand(0);
     AndMask = APInt(V->getType()->getScalarSizeInBits(), 1);
diff --git a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
index eacc67cd5a475..ced1557ec3060 100644
--- a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
@@ -2761,14 +2761,11 @@ static bool detectShiftUntilBitTestIdiom(Loop *CurLoop, Value *&BaseX,
                              m_LoopInvariant(m_Shl(m_One(), m_Value(BitPos)),
                                              CurLoop))));
   };
-  auto MatchConstantBitMask = [&]() {
-    return ICmpInst::isEquality(Pred) && match(CmpRHS, m_Zero()) &&
-           match(CmpLHS, m_And(m_Value(CurrX),
-                               m_CombineAnd(m_Value(BitMask), m_Power2()))) &&
-           (BitPos = ConstantExpr::getExactLogBase2(cast<Constant>(BitMask)));
-  };
+
   auto MatchDecomposableConstantBitMask = [&]() {
-    auto Res = llvm::decomposeBitTestICmp(CmpLHS, CmpRHS, Pred);
+    auto Res = llvm::decomposeBitTestICmp(
+        CmpLHS, CmpRHS, Pred, /*LookThroughTrunc=*/true,
+        /*AllowNonZeroC=*/false, /*DecomposeBitMask=*/true);
     if (Res && Res->Mask.isPowerOf2()) {
       assert(ICmpInst::isEquality(Res->Pred));
       Pred = Res->Pred;
@@ -2780,8 +2777,7 @@ static bool detectShiftUntilBitTestIdiom(Loop *CurLoop, Value *&BaseX,
     return false;
   };
 
-  if (!MatchVariableBitMask() && !MatchConstantBitMask() &&
-      !MatchDecomposableConstantBitMask()) {
+  if (!MatchVariableBitMask() && !MatchDecomposableConstantBitMask()) {
     LLVM_DEBUG(dbgs() << DEBUG_TYPE " Bad backedge comparison.\n");
     return false;
   }

>From 5a2b928e096f3d1bc8a92e9149e10c3b2d6974f2 Mon Sep 17 00:00:00 2001
From: Jeffrey Byrnes <Jeffrey.Byrnes at amd.com>
Date: Wed, 23 Apr 2025 16:32:01 -0700
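Subject: [PATCH 2/3] Review comments: rename DecomposeBitMask to DecomposeAnd

Rename the DecomposeBitMask parameter to DecomposeAnd, drop the default
initializer on DecomposedBitTest::X (the EQ/NE case now rebinds LHS to the
`and` operand instead), restore llvm_unreachable for unexpected predicates,
and revert the InstCombineSelect changes.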

Change-Id: I5355b103c34cb5d49868f44e3f8c10b95764a531
---
 llvm/include/llvm/Analysis/CmpInstAnalysis.h  | 10 +++---
 llvm/lib/Analysis/CmpInstAnalysis.cpp         | 20 ++++++------
 .../InstCombine/InstCombineAndOrXor.cpp       |  2 +-
 .../InstCombine/InstCombineSelect.cpp         | 32 ++++++++++---------
 .../Transforms/Scalar/LoopIdiomRecognize.cpp  |  2 +-
 5 files changed, 35 insertions(+), 31 deletions(-)

diff --git a/llvm/include/llvm/Analysis/CmpInstAnalysis.h b/llvm/include/llvm/Analysis/CmpInstAnalysis.h
index a796a38feff88..77ffc47bca942 100644
--- a/llvm/include/llvm/Analysis/CmpInstAnalysis.h
+++ b/llvm/include/llvm/Analysis/CmpInstAnalysis.h
@@ -95,7 +95,7 @@ namespace llvm {
   /// Represents the operation icmp (X & Mask) pred C, where pred can only be
   /// eq or ne.
   struct DecomposedBitTest {
-    Value *X = nullptr;
+    Value *X;
     CmpInst::Predicate Pred;
     APInt Mask;
     APInt C;
@@ -103,20 +103,20 @@ namespace llvm {
 
   /// Decompose an icmp into the form ((X & Mask) pred C) if possible.
   /// Unless \p AllowNonZeroC is true, C will always be 0. If \p
-  /// DecomposeBitMask is specified, then, for equality predicates, this will
+  /// DecomposeAnd is specified, then, for equality predicates, this will
   /// decompose bitmasking (e.g. implemented via `and`).
   std::optional<DecomposedBitTest>
   decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
                        bool LookThroughTrunc = true, bool AllowNonZeroC = false,
-                       bool DecomposeBitMask = false);
+                       bool DecomposeAnd = false);
 
   /// Decompose an icmp into the form ((X & Mask) pred C) if
   /// possible. Unless \p AllowNonZeroC is true, C will always be 0.
-  /// If \p DecomposeBitMask is specified, then, for equality predicates, this
+  /// If \p DecomposeAnd is specified, then, for equality predicates, this
   /// will decompose bitmasking (e.g. implemented via `and`).
   std::optional<DecomposedBitTest>
   decomposeBitTest(Value *Cond, bool LookThroughTrunc = true,
-                   bool AllowNonZeroC = false, bool DecomposeBitMask = false);
+                   bool AllowNonZeroC = false, bool DecomposeAnd = false);
 
 } // end namespace llvm
 
diff --git a/llvm/lib/Analysis/CmpInstAnalysis.cpp b/llvm/lib/Analysis/CmpInstAnalysis.cpp
index 6714088097127..8663928e98e9f 100644
--- a/llvm/lib/Analysis/CmpInstAnalysis.cpp
+++ b/llvm/lib/Analysis/CmpInstAnalysis.cpp
@@ -76,11 +76,11 @@ Constant *llvm::getPredForFCmpCode(unsigned Code, Type *OpTy,
 std::optional<DecomposedBitTest>
 llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
                            bool LookThruTrunc, bool AllowNonZeroC,
-                           bool DecomposeBitMask) {
+                           bool DecomposeAnd) {
   using namespace PatternMatch;
 
   const APInt *OrigC;
-  if ((ICmpInst::isEquality(Pred) && !DecomposeBitMask) ||
+  if ((ICmpInst::isEquality(Pred) && !DecomposeAnd) ||
       !match(RHS, m_APIntAllowPoison(OrigC)))
     return std::nullopt;
 
@@ -102,7 +102,7 @@ llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
 
   switch (Pred) {
   default:
-    return std::nullopt;
+    llvm_unreachable("Unexpected predicate");
   case ICmpInst::ICMP_SLT: {
     // X < 0 is equivalent to (X & SignMask) != 0.
     if (C.isZero()) {
@@ -152,11 +152,14 @@ llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
   }
   case ICmpInst::ICMP_EQ:
   case ICmpInst::ICMP_NE: {
-    assert(DecomposeBitMask);
+    assert(DecomposeAnd);
     const APInt *AndC;
     Value *AndVal;
     if (match(LHS, m_And(m_Value(AndVal), m_APIntAllowPoison(AndC)))) {
-      Result = {AndVal /*X*/, Pred /*Pred*/, *AndC /*Mask*/, *OrigC /*C*/};
+      LHS = AndVal;
+      Result.Mask = *AndC;
+      Result.C = C;
+      Result.Pred = Pred;
       break;
     }
 
@@ -175,9 +178,8 @@ llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
     Result.X = X;
     Result.Mask = Result.Mask.zext(X->getType()->getScalarSizeInBits());
     Result.C = Result.C.zext(X->getType()->getScalarSizeInBits());
-  } else if (!Result.X) {
+  } else
     Result.X = LHS;
-  }
 
   return Result;
 }
@@ -185,7 +187,7 @@ llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
 std::optional<DecomposedBitTest> llvm::decomposeBitTest(Value *Cond,
                                                         bool LookThruTrunc,
                                                         bool AllowNonZeroC,
-                                                        bool DecomposeBitMask) {
+                                                        bool DecomposeAnd) {
   using namespace PatternMatch;
   if (auto *ICmp = dyn_cast<ICmpInst>(Cond)) {
     // Don't allow pointers. Splat vectors are fine.
@@ -193,7 +195,7 @@ std::optional<DecomposedBitTest> llvm::decomposeBitTest(Value *Cond,
       return std::nullopt;
     return decomposeBitTestICmp(ICmp->getOperand(0), ICmp->getOperand(1),
                                 ICmp->getPredicate(), LookThruTrunc,
-                                AllowNonZeroC, DecomposeBitMask);
+                                AllowNonZeroC, DecomposeAnd);
   }
   Value *X;
   if (Cond->getType()->isIntOrIntVectorTy(1) &&
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 60461a6e8b338..5442cad8892b3 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -878,7 +878,7 @@ static Value *foldSignedTruncationCheck(ICmpInst *ICmp0, ICmpInst *ICmp1,
     auto Res = llvm::decomposeBitTestICmp(
         ICmp->getOperand(0), ICmp->getOperand(1), Pred,
         /*LookThroughTrunc=*/true, /*AllowNonZeroC=*/false,
-        /*DecomposeBitMask=*/true);
+        /*DecomposeAnd=*/true);
     if (Res && Res->Pred == ICmpInst::ICMP_EQ) {
       X = Res->X;
       UnsetBitsMask = Res->Mask;
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 8cdb00fb44a48..4bba2f406b4c1 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -3730,29 +3730,31 @@ static Value *foldSelectBitTest(SelectInst &Sel, Value *CondVal, Value *TrueVal,
   Value *CmpLHS, *CmpRHS;
 
   if (match(CondVal, m_ICmp(Pred, m_Value(CmpLHS), m_Value(CmpRHS)))) {
-    auto Res = decomposeBitTestICmp(
-        CmpLHS, CmpRHS, Pred, /*LookThroughTrunc=*/true,
-        /*AllowNonZeroC=*/false, /*DecomposeBitMask=*/true);
-
-    if (!Res)
-      return nullptr;
+    if (ICmpInst::isEquality(Pred)) {
+      if (!match(CmpRHS, m_Zero()))
+        return nullptr;
 
-    V = CmpLHS;
-    AndMask = Res->Mask;
+      V = CmpLHS;
+      const APInt *AndRHS;
+      if (!match(CmpLHS, m_And(m_Value(), m_Power2(AndRHS))))
+        return nullptr;
 
-    if (!ICmpInst::isEquality(Pred)) {
+      AndMask = *AndRHS;
+    } else if (auto Res = decomposeBitTestICmp(CmpLHS, CmpRHS, Pred)) {
+      assert(ICmpInst::isEquality(Res->Pred) && "Not equality test?");
+      AndMask = Res->Mask;
       V = Res->X;
       KnownBits Known =
           computeKnownBits(V, /*Depth=*/0, SQ.getWithInstruction(&Sel));
       AndMask &= Known.getMaxValue();
-      CreateAnd = true;
-    }
-
-    Pred = Res->Pred;
+      if (!AndMask.isPowerOf2())
+        return nullptr;
 
-    if (!AndMask.isPowerOf2())
+      Pred = Res->Pred;
+      CreateAnd = true;
+    } else {
       return nullptr;
-
+    }
   } else if (auto *Trunc = dyn_cast<TruncInst>(CondVal)) {
     V = Trunc->getOperand(0);
     AndMask = APInt(V->getType()->getScalarSizeInBits(), 1);
diff --git a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
index ced1557ec3060..8ddfe46e7aae0 100644
--- a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
@@ -2765,7 +2765,7 @@ static bool detectShiftUntilBitTestIdiom(Loop *CurLoop, Value *&BaseX,
   auto MatchDecomposableConstantBitMask = [&]() {
     auto Res = llvm::decomposeBitTestICmp(
         CmpLHS, CmpRHS, Pred, /*LookThroughTrunc=*/true,
-        /*AllowNonZeroC=*/false, /*DecomposeBitMask=*/true);
+        /*AllowNonZeroC=*/false, /*DecomposeAnd=*/true);
     if (Res && Res->Mask.isPowerOf2()) {
       assert(ICmpInst::isEquality(Res->Pred));
       Pred = Res->Pred;

>From b12d3e93d115c66d3e22fbf880ce14b9e79552b8 Mon Sep 17 00:00:00 2001
From: Jeffrey Byrnes <Jeffrey.Byrnes at amd.com>
Date: Thu, 24 Apr 2025 12:14:15 -0700
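Subject: [PATCH 3/3] Review comments: keep LookThroughTrunc=false in foldSignedTruncationCheck

Reword the DecomposeAnd documentation, restore braces on the else branch
in decomposeBitTestICmp, and keep LookThroughTrunc=false in
foldSignedTruncationCheck, matching its behavior before this series.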

Change-Id: Id1749b08b8e4ae360aa885458e9b45acbd017a39
---
 llvm/include/llvm/Analysis/CmpInstAnalysis.h            | 2 +-
 llvm/lib/Analysis/CmpInstAnalysis.cpp                   | 4 ++--
 llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp | 2 +-
 3 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/llvm/include/llvm/Analysis/CmpInstAnalysis.h b/llvm/include/llvm/Analysis/CmpInstAnalysis.h
index 77ffc47bca942..f863f41ce9d99 100644
--- a/llvm/include/llvm/Analysis/CmpInstAnalysis.h
+++ b/llvm/include/llvm/Analysis/CmpInstAnalysis.h
@@ -104,7 +104,7 @@ namespace llvm {
   /// Decompose an icmp into the form ((X & Mask) pred C) if possible.
   /// Unless \p AllowNonZeroC is true, C will always be 0. If \p
   /// DecomposeAnd is specified, then, for equality predicates, this will
-  /// decompose bitmasking (e.g. implemented via `and`).
+  /// decompose bitmasking via `and`.
   std::optional<DecomposedBitTest>
   decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
                        bool LookThroughTrunc = true, bool AllowNonZeroC = false,
diff --git a/llvm/lib/Analysis/CmpInstAnalysis.cpp b/llvm/lib/Analysis/CmpInstAnalysis.cpp
index 8663928e98e9f..a1a79e5685f80 100644
--- a/llvm/lib/Analysis/CmpInstAnalysis.cpp
+++ b/llvm/lib/Analysis/CmpInstAnalysis.cpp
@@ -99,7 +99,6 @@ llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
   }
 
   DecomposedBitTest Result;
-
   switch (Pred) {
   default:
     llvm_unreachable("Unexpected predicate");
@@ -178,8 +177,9 @@ llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
     Result.X = X;
     Result.Mask = Result.Mask.zext(X->getType()->getScalarSizeInBits());
     Result.C = Result.C.zext(X->getType()->getScalarSizeInBits());
-  } else
+  } else {
     Result.X = LHS;
+  }
 
   return Result;
 }
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 5442cad8892b3..57d4459228d03 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -877,7 +877,7 @@ static Value *foldSignedTruncationCheck(ICmpInst *ICmp0, ICmpInst *ICmp1,
     // Can it be decomposed into  icmp eq (X & Mask), 0  ?
     auto Res = llvm::decomposeBitTestICmp(
         ICmp->getOperand(0), ICmp->getOperand(1), Pred,
-        /*LookThroughTrunc=*/true, /*AllowNonZeroC=*/false,
+        /*LookThroughTrunc=*/false, /*AllowNonZeroC=*/false,
         /*DecomposeAnd=*/true);
     if (Res && Res->Pred == ICmpInst::ICMP_EQ) {
       X = Res->X;



More information about the llvm-commits mailing list