[llvm] 06499f3 - [InstCombine] Prepare foldLogOpOfMaskedICmps to handle trunc to i1. (NFC) (#122179)

via llvm-commits llvm-commits at lists.llvm.org
Wed Jan 15 09:08:57 PST 2025


Author: Andreas Jonson
Date: 2025-01-15T18:08:53+01:00
New Revision: 06499f3672afc371b653bf54422c2e80e1e27c90

URL: https://github.com/llvm/llvm-project/commit/06499f3672afc371b653bf54422c2e80e1e27c90
DIFF: https://github.com/llvm/llvm-project/commit/06499f3672afc371b653bf54422c2e80e1e27c90.diff

LOG: [InstCombine] Prepare foldLogOpOfMaskedICmps to handle trunc to i1. (NFC) (#122179)

Added: 
    

Modified: 
    llvm/include/llvm/Analysis/CmpInstAnalysis.h
    llvm/lib/Analysis/CmpInstAnalysis.cpp
    llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Analysis/CmpInstAnalysis.h b/llvm/include/llvm/Analysis/CmpInstAnalysis.h
index c7862a6d39d07e..aeda58ac7535d5 100644
--- a/llvm/include/llvm/Analysis/CmpInstAnalysis.h
+++ b/llvm/include/llvm/Analysis/CmpInstAnalysis.h
@@ -108,6 +108,12 @@ namespace llvm {
                        bool LookThroughTrunc = true,
                        bool AllowNonZeroC = false);
 
+  /// Decompose an icmp into the form ((X & Mask) pred C) if
+  /// possible. Unless \p AllowNonZeroC is true, C will always be 0.
+  std::optional<DecomposedBitTest>
+  decomposeBitTest(Value *Cond, bool LookThroughTrunc = true,
+                   bool AllowNonZeroC = false);
+
 } // end namespace llvm
 
 #endif

diff  --git a/llvm/lib/Analysis/CmpInstAnalysis.cpp b/llvm/lib/Analysis/CmpInstAnalysis.cpp
index 2580ea7e972488..3599428c5ff416 100644
--- a/llvm/lib/Analysis/CmpInstAnalysis.cpp
+++ b/llvm/lib/Analysis/CmpInstAnalysis.cpp
@@ -165,3 +165,17 @@ llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
 
   return Result;
 }
+
+std::optional<DecomposedBitTest>
+llvm::decomposeBitTest(Value *Cond, bool LookThruTrunc, bool AllowNonZeroC) {
+  if (auto *ICmp = dyn_cast<ICmpInst>(Cond)) {
+    // Don't allow pointers. Splat vectors are fine.
+    if (!ICmp->getOperand(0)->getType()->isIntOrIntVectorTy())
+      return std::nullopt;
+    return decomposeBitTestICmp(ICmp->getOperand(0), ICmp->getOperand(1),
+                                ICmp->getPredicate(), LookThruTrunc,
+                                AllowNonZeroC);
+  }
+
+  return std::nullopt;
+}

diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index f82a557e5760c8..f7d17b1aa3865c 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -179,10 +179,10 @@ static unsigned conjugateICmpMask(unsigned Mask) {
 }
 
 // Adapts the external decomposeBitTestICmp for local use.
-static bool decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate &Pred,
+static bool decomposeBitTestICmp(Value *Cond, CmpInst::Predicate &Pred,
                                  Value *&X, Value *&Y, Value *&Z) {
-  auto Res = llvm::decomposeBitTestICmp(
-      LHS, RHS, Pred, /*LookThroughTrunc=*/true, /*AllowNonZeroC=*/true);
+  auto Res = llvm::decomposeBitTest(Cond, /*LookThroughTrunc=*/true,
+                                    /*AllowNonZeroC=*/true);
   if (!Res)
     return false;
 
@@ -198,13 +198,10 @@ static bool decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate &Pre
 /// the right hand side as a pair.
 /// LHS and RHS are the left hand side and the right hand side ICmps and PredL
 /// and PredR are their predicates, respectively.
-static std::optional<std::pair<unsigned, unsigned>> getMaskedTypeForICmpPair(
-    Value *&A, Value *&B, Value *&C, Value *&D, Value *&E, ICmpInst *LHS,
-    ICmpInst *RHS, ICmpInst::Predicate &PredL, ICmpInst::Predicate &PredR) {
-  // Don't allow pointers. Splat vectors are fine.
-  if (!LHS->getOperand(0)->getType()->isIntOrIntVectorTy() ||
-      !RHS->getOperand(0)->getType()->isIntOrIntVectorTy())
-    return std::nullopt;
+static std::optional<std::pair<unsigned, unsigned>>
+getMaskedTypeForICmpPair(Value *&A, Value *&B, Value *&C, Value *&D, Value *&E,
+                         Value *LHS, Value *RHS, ICmpInst::Predicate &PredL,
+                         ICmpInst::Predicate &PredR) {
 
   // Here comes the tricky part:
   // LHS might be of the form L11 & L12 == X, X == L21 & L22,
@@ -212,13 +209,23 @@ static std::optional<std::pair<unsigned, unsigned>> getMaskedTypeForICmpPair(
   // Now we must find those components L** and R**, that are equal, so
   // that we can extract the parameters A, B, C, D, and E for the canonical
   // above.
-  Value *L1 = LHS->getOperand(0);
-  Value *L2 = LHS->getOperand(1);
-  Value *L11, *L12, *L21, *L22;
+
   // Check whether the icmp can be decomposed into a bit test.
-  if (decomposeBitTestICmp(L1, L2, PredL, L11, L12, L2)) {
+  Value *L1, *L11, *L12, *L2, *L21, *L22;
+  if (decomposeBitTestICmp(LHS, PredL, L11, L12, L2)) {
     L21 = L22 = L1 = nullptr;
   } else {
+    auto *LHSCMP = dyn_cast<ICmpInst>(LHS);
+    if (!LHSCMP)
+      return std::nullopt;
+
+    // Don't allow pointers. Splat vectors are fine.
+    if (!LHSCMP->getOperand(0)->getType()->isIntOrIntVectorTy())
+      return std::nullopt;
+
+    PredL = LHSCMP->getPredicate();
+    L1 = LHSCMP->getOperand(0);
+    L2 = LHSCMP->getOperand(1);
     // Look for ANDs in the LHS icmp.
     if (!match(L1, m_And(m_Value(L11), m_Value(L12)))) {
       // Any icmp can be viewed as being trivially masked; if it allows us to
@@ -237,11 +244,8 @@ static std::optional<std::pair<unsigned, unsigned>> getMaskedTypeForICmpPair(
   if (!ICmpInst::isEquality(PredL))
     return std::nullopt;
 
-  Value *R1 = RHS->getOperand(0);
-  Value *R2 = RHS->getOperand(1);
-  Value *R11, *R12;
-  bool Ok = false;
-  if (decomposeBitTestICmp(R1, R2, PredR, R11, R12, R2)) {
+  Value *R11, *R12, *R2;
+  if (decomposeBitTestICmp(RHS, PredR, R11, R12, R2)) {
     if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) {
       A = R11;
       D = R12;
@@ -252,9 +256,19 @@ static std::optional<std::pair<unsigned, unsigned>> getMaskedTypeForICmpPair(
       return std::nullopt;
     }
     E = R2;
-    R1 = nullptr;
-    Ok = true;
   } else {
+    auto *RHSCMP = dyn_cast<ICmpInst>(RHS);
+    if (!RHSCMP)
+      return std::nullopt;
+    // Don't allow pointers. Splat vectors are fine.
+    if (!RHSCMP->getOperand(0)->getType()->isIntOrIntVectorTy())
+      return std::nullopt;
+
+    PredR = RHSCMP->getPredicate();
+
+    Value *R1 = RHSCMP->getOperand(0);
+    R2 = RHSCMP->getOperand(1);
+    bool Ok = false;
     if (!match(R1, m_And(m_Value(R11), m_Value(R12)))) {
       // As before, model no mask as a trivial mask if it'll let us do an
       // optimization.
@@ -277,36 +291,32 @@ static std::optional<std::pair<unsigned, unsigned>> getMaskedTypeForICmpPair(
     // Avoid matching against the -1 value we created for unmasked operand.
     if (Ok && match(A, m_AllOnes()))
       Ok = false;
+
+    // Look for ANDs on the right side of the RHS icmp.
+    if (!Ok) {
+      if (!match(R2, m_And(m_Value(R11), m_Value(R12)))) {
+        R11 = R2;
+        R12 = Constant::getAllOnesValue(R2->getType());
+      }
+
+      if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) {
+        A = R11;
+        D = R12;
+        E = R1;
+      } else if (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22) {
+        A = R12;
+        D = R11;
+        E = R1;
+      } else {
+        return std::nullopt;
+      }
+    }
   }
 
   // Bail if RHS was a icmp that can't be decomposed into an equality.
   if (!ICmpInst::isEquality(PredR))
     return std::nullopt;
 
-  // Look for ANDs on the right side of the RHS icmp.
-  if (!Ok) {
-    if (!match(R2, m_And(m_Value(R11), m_Value(R12)))) {
-      R11 = R2;
-      R12 = Constant::getAllOnesValue(R2->getType());
-    }
-
-    if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) {
-      A = R11;
-      D = R12;
-      E = R1;
-      Ok = true;
-    } else if (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22) {
-      A = R12;
-      D = R11;
-      E = R1;
-      Ok = true;
-    } else {
-      return std::nullopt;
-    }
-
-    assert(Ok && "Failed to find AND on the right side of the RHS icmp.");
-  }
-
   if (L11 == A) {
     B = L12;
     C = L2;
@@ -333,8 +343,8 @@ static std::optional<std::pair<unsigned, unsigned>> getMaskedTypeForICmpPair(
 /// (icmp (A & 12) != 0) & (icmp (A & 15) == 8) -> (icmp (A & 15) == 8).
 /// Also used for logical and/or, must be poison safe.
 static Value *foldLogOpOfMaskedICmps_NotAllZeros_BMask_Mixed(
-    ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, Value *A, Value *B, Value *D,
-    Value *E, ICmpInst::Predicate PredL, ICmpInst::Predicate PredR,
+    Value *LHS, Value *RHS, bool IsAnd, Value *A, Value *B, Value *D, Value *E,
+    ICmpInst::Predicate PredL, ICmpInst::Predicate PredR,
     InstCombiner::BuilderTy &Builder) {
   // We are given the canonical form:
   //   (icmp ne (A & B), 0) & (icmp eq (A & D), E).
@@ -457,7 +467,8 @@ static Value *foldLogOpOfMaskedICmps_NotAllZeros_BMask_Mixed(
   // (icmp ne (A & 15), 0) & (icmp eq (A & 15), 8) -> (icmp eq (A & 15), 8).
   if (IsSuperSetOrEqual(BCst, DCst)) {
     // We can't guarantee that samesign hold after this fold.
-    RHS->setSameSign(false);
+    if (auto *ICmp = dyn_cast<ICmpInst>(RHS))
+      ICmp->setSameSign(false);
     return RHS;
   }
   // Otherwise, B is a subset of D. If B and E have a common bit set,
@@ -466,7 +477,8 @@ static Value *foldLogOpOfMaskedICmps_NotAllZeros_BMask_Mixed(
   assert(IsSubSetOrEqual(BCst, DCst) && "Precondition due to above code");
   if ((*BCst & ECst) != 0) {
     // We can't guarantee that samesign hold after this fold.
-    RHS->setSameSign(false);
+    if (auto *ICmp = dyn_cast<ICmpInst>(RHS))
+      ICmp->setSameSign(false);
     return RHS;
   }
   // Otherwise, LHS and RHS contradict and the whole expression becomes false
@@ -481,8 +493,8 @@ static Value *foldLogOpOfMaskedICmps_NotAllZeros_BMask_Mixed(
 /// aren't of the common mask pattern type.
 /// Also used for logical and/or, must be poison safe.
 static Value *foldLogOpOfMaskedICmpsAsymmetric(
-    ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, Value *A, Value *B, Value *C,
-    Value *D, Value *E, ICmpInst::Predicate PredL, ICmpInst::Predicate PredR,
+    Value *LHS, Value *RHS, bool IsAnd, Value *A, Value *B, Value *C, Value *D,
+    Value *E, ICmpInst::Predicate PredL, ICmpInst::Predicate PredR,
     unsigned LHSMask, unsigned RHSMask, InstCombiner::BuilderTy &Builder) {
   assert(ICmpInst::isEquality(PredL) && ICmpInst::isEquality(PredR) &&
          "Expected equality predicates for masked type of icmps.");
@@ -511,12 +523,12 @@ static Value *foldLogOpOfMaskedICmpsAsymmetric(
 
 /// Try to fold (icmp(A & B) ==/!= C) &/| (icmp(A & D) ==/!= E)
 /// into a single (icmp(A & X) ==/!= Y).
-static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd,
+static Value *foldLogOpOfMaskedICmps(Value *LHS, Value *RHS, bool IsAnd,
                                      bool IsLogical,
                                      InstCombiner::BuilderTy &Builder,
                                      const SimplifyQuery &Q) {
   Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr, *E = nullptr;
-  ICmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate();
+  ICmpInst::Predicate PredL, PredR;
   std::optional<std::pair<unsigned, unsigned>> MaskPair =
       getMaskedTypeForICmpPair(A, B, C, D, E, LHS, RHS, PredL, PredR);
   if (!MaskPair)
@@ -1066,8 +1078,7 @@ static Value *foldPowerOf2AndShiftedMask(ICmpInst *Cmp0, ICmpInst *Cmp1,
   if (!JoinedByAnd)
     return nullptr;
   Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr, *E = nullptr;
-  ICmpInst::Predicate CmpPred0 = Cmp0->getPredicate(),
-                      CmpPred1 = Cmp1->getPredicate();
+  ICmpInst::Predicate CmpPred0, CmpPred1;
   // Assuming P is a 2^n, getMaskedTypeForICmpPair will normalize (icmp X u<
   // 2^n) into (icmp (X & ~(2^n-1)) == 0) and (icmp X s> -1) into (icmp (X &
   // SignMask) == 0).


        


More information about the llvm-commits mailing list