[llvm] [InstCombine] Move foldLogOpOfMaskedICmps to make it possible to handle trunc to i1. (PR #122179)

via llvm-commits llvm-commits at lists.llvm.org
Wed Jan 8 14:09:00 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: Andreas Jonson (andjo403)

<details>
<summary>Changes</summary>

This moves foldLogOpOfMaskedICmps as a separate change, before adding the handling of `trunc to i1` in getMaskedTypeForICmpPair, because the move by itself caused some diffs in llvm-opt-benchmark.

---
Full diff: https://github.com/llvm/llvm-project/pull/122179.diff


1 Files Affected:

- (modified) llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp (+126-106) 


``````````diff
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 8bfa3d0f6c5ea1..0aeb025ea44840 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -199,113 +199,132 @@ static bool decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate &Pre
 /// the right hand side as a pair.
 /// LHS and RHS are the left hand side and the right hand side ICmps and PredL
 /// and PredR are their predicates, respectively.
-static std::optional<std::pair<unsigned, unsigned>> getMaskedTypeForICmpPair(
-    Value *&A, Value *&B, Value *&C, Value *&D, Value *&E, ICmpInst *LHS,
-    ICmpInst *RHS, ICmpInst::Predicate &PredL, ICmpInst::Predicate &PredR) {
-  // Don't allow pointers. Splat vectors are fine.
-  if (!LHS->getOperand(0)->getType()->isIntOrIntVectorTy() ||
-      !RHS->getOperand(0)->getType()->isIntOrIntVectorTy())
-    return std::nullopt;
+static std::optional<std::pair<unsigned, unsigned>>
+getMaskedTypeForICmpPair(Value *&A, Value *&B, Value *&C, Value *&D, Value *&E,
+                         Value *LHS, Value *RHS, ICmpInst::Predicate &PredL,
+                         ICmpInst::Predicate &PredR) {
 
-  // Here comes the tricky part:
-  // LHS might be of the form L11 & L12 == X, X == L21 & L22,
-  // and L11 & L12 == L21 & L22. The same goes for RHS.
-  // Now we must find those components L** and R**, that are equal, so
-  // that we can extract the parameters A, B, C, D, and E for the canonical
-  // above.
-  Value *L1 = LHS->getOperand(0);
-  Value *L2 = LHS->getOperand(1);
-  Value *L11, *L12, *L21, *L22;
-  // Check whether the icmp can be decomposed into a bit test.
-  if (decomposeBitTestICmp(L1, L2, PredL, L11, L12, L2)) {
-    L21 = L22 = L1 = nullptr;
-  } else {
-    // Look for ANDs in the LHS icmp.
-    if (!match(L1, m_And(m_Value(L11), m_Value(L12)))) {
-      // Any icmp can be viewed as being trivially masked; if it allows us to
-      // remove one, it's worth it.
-      L11 = L1;
-      L12 = Constant::getAllOnesValue(L1->getType());
-    }
+  Value *L1, *L11, *L12, *L2, *L21, *L22;
+  if (auto *LHSCMP = dyn_cast<ICmpInst>(LHS)) {
+
+    // Don't allow pointers. Splat vectors are fine.
+    if (!LHSCMP->getOperand(0)->getType()->isIntOrIntVectorTy())
+      return std::nullopt;
+
+    PredL = LHSCMP->getPredicate();
+
+    // Here comes the tricky part:
+    // LHS might be of the form L11 & L12 == X, X == L21 & L22,
+    // and L11 & L12 == L21 & L22. The same goes for RHS.
+    // Now we must find those components L** and R**, that are equal, so
+    // that we can extract the parameters A, B, C, D, and E for the canonical
+    // above.
+    L1 = LHSCMP->getOperand(0);
+    L2 = LHSCMP->getOperand(1);
+    // Check whether the icmp can be decomposed into a bit test.
+    if (decomposeBitTestICmp(L1, L2, PredL, L11, L12, L2)) {
+      L21 = L22 = L1 = nullptr;
+    } else {
+      // Look for ANDs in the LHS icmp.
+      if (!match(L1, m_And(m_Value(L11), m_Value(L12)))) {
+        // Any icmp can be viewed as being trivially masked; if it allows us to
+        // remove one, it's worth it.
+        L11 = L1;
+        L12 = Constant::getAllOnesValue(L1->getType());
+      }
 
-    if (!match(L2, m_And(m_Value(L21), m_Value(L22)))) {
-      L21 = L2;
-      L22 = Constant::getAllOnesValue(L2->getType());
+      if (!match(L2, m_And(m_Value(L21), m_Value(L22)))) {
+        L21 = L2;
+        L22 = Constant::getAllOnesValue(L2->getType());
+      }
     }
-  }
+    // Bail if LHS was a icmp that can't be decomposed into an equality.
+    if (!ICmpInst::isEquality(PredL))
+      return std::nullopt;
 
-  // Bail if LHS was a icmp that can't be decomposed into an equality.
-  if (!ICmpInst::isEquality(PredL))
+  } else {
     return std::nullopt;
+  }
 
-  Value *R1 = RHS->getOperand(0);
-  Value *R2 = RHS->getOperand(1);
   Value *R11, *R12;
-  bool Ok = false;
-  if (decomposeBitTestICmp(R1, R2, PredR, R11, R12, R2)) {
-    if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) {
-      A = R11;
-      D = R12;
-    } else if (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22) {
-      A = R12;
-      D = R11;
-    } else {
+  if (auto *RHSCMP = dyn_cast<ICmpInst>(RHS)) {
+
+    // Don't allow pointers. Splat vectors are fine.
+    if (!RHSCMP->getOperand(0)->getType()->isIntOrIntVectorTy())
       return std::nullopt;
-    }
-    E = R2;
-    R1 = nullptr;
-    Ok = true;
-  } else {
-    if (!match(R1, m_And(m_Value(R11), m_Value(R12)))) {
-      // As before, model no mask as a trivial mask if it'll let us do an
-      // optimization.
-      R11 = R1;
-      R12 = Constant::getAllOnesValue(R1->getType());
-    }
 
-    if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) {
-      A = R11;
-      D = R12;
-      E = R2;
-      Ok = true;
-    } else if (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22) {
-      A = R12;
-      D = R11;
+    PredR = RHSCMP->getPredicate();
+
+    Value *R1 = RHSCMP->getOperand(0);
+    Value *R2 = RHSCMP->getOperand(1);
+    bool Ok = false;
+    if (decomposeBitTestICmp(R1, R2, PredR, R11, R12, R2)) {
+      if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) {
+        A = R11;
+        D = R12;
+      } else if (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22) {
+        A = R12;
+        D = R11;
+      } else {
+        return std::nullopt;
+      }
       E = R2;
+      R1 = nullptr;
       Ok = true;
-    }
-
-    // Avoid matching against the -1 value we created for unmasked operand.
-    if (Ok && match(A, m_AllOnes()))
-      Ok = false;
-  }
+    } else {
+      if (!match(R1, m_And(m_Value(R11), m_Value(R12)))) {
+        // As before, model no mask as a trivial mask if it'll let us do an
+        // optimization.
+        R11 = R1;
+        R12 = Constant::getAllOnesValue(R1->getType());
+      }
 
-  // Bail if RHS was a icmp that can't be decomposed into an equality.
-  if (!ICmpInst::isEquality(PredR))
-    return std::nullopt;
+      if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) {
+        A = R11;
+        D = R12;
+        E = R2;
+        Ok = true;
+      } else if (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22) {
+        A = R12;
+        D = R11;
+        E = R2;
+        Ok = true;
+      }
 
-  // Look for ANDs on the right side of the RHS icmp.
-  if (!Ok) {
-    if (!match(R2, m_And(m_Value(R11), m_Value(R12)))) {
-      R11 = R2;
-      R12 = Constant::getAllOnesValue(R2->getType());
+      // Avoid matching against the -1 value we created for unmasked operand.
+      if (Ok && match(A, m_AllOnes()))
+        Ok = false;
     }
 
-    if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) {
-      A = R11;
-      D = R12;
-      E = R1;
-      Ok = true;
-    } else if (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22) {
-      A = R12;
-      D = R11;
-      E = R1;
-      Ok = true;
-    } else {
+    // Bail if RHS was a icmp that can't be decomposed into an equality.
+    if (!ICmpInst::isEquality(PredR))
       return std::nullopt;
-    }
 
-    assert(Ok && "Failed to find AND on the right side of the RHS icmp.");
+    // Look for ANDs on the right side of the RHS icmp.
+    if (!Ok) {
+      if (!match(R2, m_And(m_Value(R11), m_Value(R12)))) {
+        R11 = R2;
+        R12 = Constant::getAllOnesValue(R2->getType());
+      }
+
+      if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) {
+        A = R11;
+        D = R12;
+        E = R1;
+        Ok = true;
+      } else if (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22) {
+        A = R12;
+        D = R11;
+        E = R1;
+        Ok = true;
+      } else {
+        return std::nullopt;
+      }
+
+      assert(Ok && "Failed to find AND on the right side of the RHS icmp.");
+    }
+  } else {
+    return std::nullopt;
   }
 
   if (L11 == A) {
@@ -334,8 +353,8 @@ static std::optional<std::pair<unsigned, unsigned>> getMaskedTypeForICmpPair(
 /// (icmp (A & 12) != 0) & (icmp (A & 15) == 8) -> (icmp (A & 15) == 8).
 /// Also used for logical and/or, must be poison safe.
 static Value *foldLogOpOfMaskedICmps_NotAllZeros_BMask_Mixed(
-    ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, Value *A, Value *B, Value *D,
-    Value *E, ICmpInst::Predicate PredL, ICmpInst::Predicate PredR,
+    Value *LHS, Value *RHS, bool IsAnd, Value *A, Value *B, Value *D, Value *E,
+    ICmpInst::Predicate PredL, ICmpInst::Predicate PredR,
     InstCombiner::BuilderTy &Builder) {
   // We are given the canonical form:
   //   (icmp ne (A & B), 0) & (icmp eq (A & D), E).
@@ -458,7 +477,8 @@ static Value *foldLogOpOfMaskedICmps_NotAllZeros_BMask_Mixed(
   // (icmp ne (A & 15), 0) & (icmp eq (A & 15), 8) -> (icmp eq (A & 15), 8).
   if (IsSuperSetOrEqual(BCst, DCst)) {
     // We can't guarantee that samesign hold after this fold.
-    RHS->setSameSign(false);
+    if (auto *ICmp = dyn_cast<ICmpInst>(RHS))
+      ICmp->setSameSign(false);
     return RHS;
   }
   // Otherwise, B is a subset of D. If B and E have a common bit set,
@@ -467,7 +487,8 @@ static Value *foldLogOpOfMaskedICmps_NotAllZeros_BMask_Mixed(
   assert(IsSubSetOrEqual(BCst, DCst) && "Precondition due to above code");
   if ((*BCst & ECst) != 0) {
     // We can't guarantee that samesign hold after this fold.
-    RHS->setSameSign(false);
+    if (auto *ICmp = dyn_cast<ICmpInst>(RHS))
+      ICmp->setSameSign(false);
     return RHS;
   }
   // Otherwise, LHS and RHS contradict and the whole expression becomes false
@@ -482,8 +503,8 @@ static Value *foldLogOpOfMaskedICmps_NotAllZeros_BMask_Mixed(
 /// aren't of the common mask pattern type.
 /// Also used for logical and/or, must be poison safe.
 static Value *foldLogOpOfMaskedICmpsAsymmetric(
-    ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, Value *A, Value *B, Value *C,
-    Value *D, Value *E, ICmpInst::Predicate PredL, ICmpInst::Predicate PredR,
+    Value *LHS, Value *RHS, bool IsAnd, Value *A, Value *B, Value *C, Value *D,
+    Value *E, ICmpInst::Predicate PredL, ICmpInst::Predicate PredR,
     unsigned LHSMask, unsigned RHSMask, InstCombiner::BuilderTy &Builder) {
   assert(ICmpInst::isEquality(PredL) && ICmpInst::isEquality(PredR) &&
          "Expected equality predicates for masked type of icmps.");
@@ -512,12 +533,12 @@ static Value *foldLogOpOfMaskedICmpsAsymmetric(
 
 /// Try to fold (icmp(A & B) ==/!= C) &/| (icmp(A & D) ==/!= E)
 /// into a single (icmp(A & X) ==/!= Y).
-static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd,
+static Value *foldLogOpOfMaskedICmps(Value *LHS, Value *RHS, bool IsAnd,
                                      bool IsLogical,
                                      InstCombiner::BuilderTy &Builder,
                                      const SimplifyQuery &Q) {
   Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr, *E = nullptr;
-  ICmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate();
+  ICmpInst::Predicate PredL, PredR;
   std::optional<std::pair<unsigned, unsigned>> MaskPair =
       getMaskedTypeForICmpPair(A, B, C, D, E, LHS, RHS, PredL, PredR);
   if (!MaskPair)
@@ -1067,8 +1088,7 @@ static Value *foldPowerOf2AndShiftedMask(ICmpInst *Cmp0, ICmpInst *Cmp1,
   if (!JoinedByAnd)
     return nullptr;
   Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr, *E = nullptr;
-  ICmpInst::Predicate CmpPred0 = Cmp0->getPredicate(),
-                      CmpPred1 = Cmp1->getPredicate();
+  ICmpInst::Predicate CmpPred0, CmpPred1;
   // Assuming P is a 2^n, getMaskedTypeForICmpPair will normalize (icmp X u<
   // 2^n) into (icmp (X & ~(2^n-1)) == 0) and (icmp X s> -1) into (icmp (X &
   // SignMask) == 0).
@@ -3325,12 +3345,6 @@ Value *InstCombinerImpl::foldAndOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
     }
   }
 
-  // handle (roughly):
-  // (icmp ne (A & B), C) | (icmp ne (A & D), E)
-  // (icmp eq (A & B), C) & (icmp eq (A & D), E)
-  if (Value *V = foldLogOpOfMaskedICmps(LHS, RHS, IsAnd, IsLogical, Builder, Q))
-    return V;
-
   if (Value *V =
           foldAndOrOfICmpEqConstantAndICmp(LHS, RHS, IsAnd, IsLogical, Builder))
     return V;
@@ -3510,6 +3524,12 @@ Value *InstCombinerImpl::foldBooleanAndOr(Value *LHS, Value *RHS,
       if (Value *Res = foldAndOrOfICmps(LHSCmp, RHSCmp, I, IsAnd, IsLogical))
         return Res;
 
+  /// Try to fold (icmp(A & B) ==/!= C) &/| (icmp(A & D) ==/!= E)
+  /// into a single (icmp(A & X) ==/!= Y).
+  if (Value *V = foldLogOpOfMaskedICmps(LHS, RHS, IsAnd, IsLogical, Builder,
+                                        SQ.getWithInstruction(&I)))
+    return V;
+
   if (auto *LHSCmp = dyn_cast<FCmpInst>(LHS))
     if (auto *RHSCmp = dyn_cast<FCmpInst>(RHS))
       if (Value *Res = foldLogicOfFCmps(LHSCmp, RHSCmp, IsAnd, IsLogical))

``````````

</details>


https://github.com/llvm/llvm-project/pull/122179


More information about the llvm-commits mailing list