[llvm] [InstCombine] move foldAndOrOfICmpsOfAndWithPow2 into foldLogOpOfMaskedICmps (PR #121970)

Andreas Jonson via llvm-commits llvm-commits at lists.llvm.org
Wed Jan 8 08:37:27 PST 2025


https://github.com/andjo403 updated https://github.com/llvm/llvm-project/pull/121970

>From ddefe31e991b6b8b620485f3039e1a7cae87cb06 Mon Sep 17 00:00:00 2001
From: Andreas Jonson <andjo403 at hotmail.com>
Date: Wed, 8 Jan 2025 17:36:18 +0100
Subject: [PATCH] [InstCombine] move foldAndOrOfICmpsOfAndWithPow2 into
 foldLogOpOfMaskedICmps

---
 .../InstCombine/InstCombineAndOrXor.cpp       | 227 ++++++++----------
 .../InstCombine/InstCombineInternal.h         |   3 -
 .../Transforms/InstCombine/onehot_merge.ll    |   7 +-
 3 files changed, 101 insertions(+), 136 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 184c75a1dd860e..8bfa3d0f6c5ea1 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -514,7 +514,8 @@ static Value *foldLogOpOfMaskedICmpsAsymmetric(
 /// into a single (icmp(A & X) ==/!= Y).
 static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd,
                                      bool IsLogical,
-                                     InstCombiner::BuilderTy &Builder) {
+                                     InstCombiner::BuilderTy &Builder,
+                                     const SimplifyQuery &Q) {
   Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr, *E = nullptr;
   ICmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate();
   std::optional<std::pair<unsigned, unsigned>> MaskPair =
@@ -587,93 +588,107 @@ static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd,
     return Builder.CreateICmp(NewCC, NewAnd2, A);
   }
 
-  // Remaining cases assume at least that B and D are constant, and depend on
-  // their actual values. This isn't strictly necessary, just a "handle the
-  // easy cases for now" decision.
   const APInt *ConstB, *ConstD;
-  if (!match(B, m_APInt(ConstB)) || !match(D, m_APInt(ConstD)))
-    return nullptr;
-
-  if (Mask & (Mask_NotAllZeros | BMask_NotAllOnes)) {
-    // (icmp ne (A & B), 0) & (icmp ne (A & D), 0) and
-    // (icmp ne (A & B), B) & (icmp ne (A & D), D)
-    //     -> (icmp ne (A & B), 0) or (icmp ne (A & D), 0)
-    // Only valid if one of the masks is a superset of the other (check "B&D" is
-    // the same as either B or D).
-    APInt NewMask = *ConstB & *ConstD;
-    if (NewMask == *ConstB)
-      return LHS;
-    else if (NewMask == *ConstD)
-      return RHS;
-  }
-
-  if (Mask & AMask_NotAllOnes) {
-    // (icmp ne (A & B), B) & (icmp ne (A & D), D)
-    //     -> (icmp ne (A & B), A) or (icmp ne (A & D), A)
-    // Only valid if one of the masks is a superset of the other (check "B|D" is
-    // the same as either B or D).
-    APInt NewMask = *ConstB | *ConstD;
-    if (NewMask == *ConstB)
-      return LHS;
-    else if (NewMask == *ConstD)
-      return RHS;
-  }
-
-  if (Mask & (BMask_Mixed | BMask_NotMixed)) {
-    // Mixed:
-    // (icmp eq (A & B), C) & (icmp eq (A & D), E)
-    // We already know that B & C == C && D & E == E.
-    // If we can prove that (B & D) & (C ^ E) == 0, that is, the bits of
-    // C and E, which are shared by both the mask B and the mask D, don't
-    // contradict, then we can transform to
-    // -> (icmp eq (A & (B|D)), (C|E))
-    // Currently, we only handle the case of B, C, D, and E being constant.
-    // We can't simply use C and E because we might actually handle
-    //   (icmp ne (A & B), B) & (icmp eq (A & D), D)
-    // with B and D, having a single bit set.
-
-    // NotMixed:
-    // (icmp ne (A & B), C) & (icmp ne (A & D), E)
-    // -> (icmp ne (A & (B & D)), (C & E))
-    // Check the intersection (B & D) for inequality.
-    // Assume that (B & D) == B || (B & D) == D, i.e B/D is a subset of D/B
-    // and (B & D) & (C ^ E) == 0, bits of C and E, which are shared by both the
-    // B and the D, don't contradict.
-    // Note that we can assume (~B & C) == 0 && (~D & E) == 0, previous
-    // operation should delete these icmps if it hadn't been met.
-
-    const APInt *OldConstC, *OldConstE;
-    if (!match(C, m_APInt(OldConstC)) || !match(E, m_APInt(OldConstE)))
-      return nullptr;
-
-    auto FoldBMixed = [&](ICmpInst::Predicate CC, bool IsNot) -> Value * {
-      CC = IsNot ? CmpInst::getInversePredicate(CC) : CC;
-      const APInt ConstC = PredL != CC ? *ConstB ^ *OldConstC : *OldConstC;
-      const APInt ConstE = PredR != CC ? *ConstD ^ *OldConstE : *OldConstE;
+  if (match(B, m_APInt(ConstB)) && match(D, m_APInt(ConstD))) {
+    if (Mask & (Mask_NotAllZeros | BMask_NotAllOnes)) {
+      // (icmp ne (A & B), 0) & (icmp ne (A & D), 0) and
+      // (icmp ne (A & B), B) & (icmp ne (A & D), D)
+      //     -> (icmp ne (A & B), 0) or (icmp ne (A & D), 0)
+      // Only valid if one of the masks is a superset of the other (check "B&D"
+      // is the same as either B or D).
+      APInt NewMask = *ConstB & *ConstD;
+      if (NewMask == *ConstB)
+        return LHS;
+      if (NewMask == *ConstD)
+        return RHS;
+    }
 
-      if (((*ConstB & *ConstD) & (ConstC ^ ConstE)).getBoolValue())
-        return IsNot ? nullptr : ConstantInt::get(LHS->getType(), !IsAnd);
+    if (Mask & AMask_NotAllOnes) {
+      // (icmp ne (A & B), B) & (icmp ne (A & D), D)
+      //     -> (icmp ne (A & B), A) or (icmp ne (A & D), A)
+      // Only valid if one of the masks is a superset of the other (check "B|D"
+      // is the same as either B or D).
+      APInt NewMask = *ConstB | *ConstD;
+      if (NewMask == *ConstB)
+        return LHS;
+      if (NewMask == *ConstD)
+        return RHS;
+    }
 
-      if (IsNot && !ConstB->isSubsetOf(*ConstD) && !ConstD->isSubsetOf(*ConstB))
+    if (Mask & (BMask_Mixed | BMask_NotMixed)) {
+      // Mixed:
+      // (icmp eq (A & B), C) & (icmp eq (A & D), E)
+      // We already know that B & C == C && D & E == E.
+      // If we can prove that (B & D) & (C ^ E) == 0, that is, the bits of
+      // C and E, which are shared by both the mask B and the mask D, don't
+      // contradict, then we can transform to
+      // -> (icmp eq (A & (B|D)), (C|E))
+      // Currently, we only handle the case of B, C, D, and E being constant.
+      // We can't simply use C and E because we might actually handle
+      //   (icmp ne (A & B), B) & (icmp eq (A & D), D)
+      // with B and D, having a single bit set.
+
+      // NotMixed:
+      // (icmp ne (A & B), C) & (icmp ne (A & D), E)
+      // -> (icmp ne (A & (B & D)), (C & E))
+      // Check the intersection (B & D) for inequality.
+      // Assume that (B & D) == B || (B & D) == D, i.e. B/D is a subset of D/B,
+      // and (B & D) & (C ^ E) == 0, i.e. the bits of C and E shared by both B
+      // and D don't contradict.
+      // Note that we can assume (~B & C) == 0 && (~D & E) == 0; a previous
+      // operation would have deleted these icmps otherwise.
+
+      const APInt *OldConstC, *OldConstE;
+      if (!match(C, m_APInt(OldConstC)) || !match(E, m_APInt(OldConstE)))
         return nullptr;
 
-      APInt BD, CE;
-      if (IsNot) {
-        BD = *ConstB & *ConstD;
-        CE = ConstC & ConstE;
-      } else {
-        BD = *ConstB | *ConstD;
-        CE = ConstC | ConstE;
-      }
-      Value *NewAnd = Builder.CreateAnd(A, BD);
-      Value *CEVal = ConstantInt::get(A->getType(), CE);
-      return Builder.CreateICmp(CC, CEVal, NewAnd);
-    };
+      auto FoldBMixed = [&](ICmpInst::Predicate CC, bool IsNot) -> Value * {
+        CC = IsNot ? CmpInst::getInversePredicate(CC) : CC;
+        const APInt ConstC = PredL != CC ? *ConstB ^ *OldConstC : *OldConstC;
+        const APInt ConstE = PredR != CC ? *ConstD ^ *OldConstE : *OldConstE;
+
+        if (((*ConstB & *ConstD) & (ConstC ^ ConstE)).getBoolValue())
+          return IsNot ? nullptr : ConstantInt::get(LHS->getType(), !IsAnd);
+
+        if (IsNot && !ConstB->isSubsetOf(*ConstD) &&
+            !ConstD->isSubsetOf(*ConstB))
+          return nullptr;
+
+        APInt BD, CE;
+        if (IsNot) {
+          BD = *ConstB & *ConstD;
+          CE = ConstC & ConstE;
+        } else {
+          BD = *ConstB | *ConstD;
+          CE = ConstC | ConstE;
+        }
+        Value *NewAnd = Builder.CreateAnd(A, BD);
+        Value *CEVal = ConstantInt::get(A->getType(), CE);
+        return Builder.CreateICmp(CC, CEVal, NewAnd);
+      };
+
+      if (Mask & BMask_Mixed)
+        return FoldBMixed(NewCC, false);
+      if (Mask & BMask_NotMixed) // can be else also
+        return FoldBMixed(NewCC, true);
+    }
+  }
 
-    if (Mask & BMask_Mixed)
-      return FoldBMixed(NewCC, false);
-    if (Mask & BMask_NotMixed) // can be else also
-      return FoldBMixed(NewCC, true);
+  // (icmp eq (A & B), 0) | (icmp eq (A & D), 0)
+  // -> (icmp ne (A & (B|D)), (B|D))
+  // (icmp ne (A & B), 0) & (icmp ne (A & D), 0)
+  // -> (icmp eq (A & (B|D)), (B|D))
+  // iff B and D are both known to be powers of two
+  if (Mask & Mask_NotAllZeros &&
+      isKnownToBeAPowerOfTwo(B, /*OrZero=*/false, /*Depth=*/0, Q) &&
+      isKnownToBeAPowerOfTwo(D, /*OrZero=*/false, /*Depth=*/0, Q)) {
+    // If this is a logical and/or, then we must prevent propagation of a
+    // poison value from the RHS by inserting freeze.
+    if (IsLogical)
+      D = Builder.CreateFreeze(D);
+    Value *Mask = Builder.CreateOr(B, D);
+    Value *Masked = Builder.CreateAnd(A, Mask);
+    return Builder.CreateICmp(NewCC, Masked, Mask);
   }
   return nullptr;
 }
@@ -776,46 +791,6 @@ foldAndOrOfICmpsWithPow2AndWithZero(InstCombiner::BuilderTy &Builder,
   return Builder.CreateICmp(Pred, And, Op);
 }
 
-// Fold (iszero(A & K1) | iszero(A & K2)) -> (A & (K1 | K2)) != (K1 | K2)
-// Fold (!iszero(A & K1) & !iszero(A & K2)) -> (A & (K1 | K2)) == (K1 | K2)
-Value *InstCombinerImpl::foldAndOrOfICmpsOfAndWithPow2(ICmpInst *LHS,
-                                                       ICmpInst *RHS,
-                                                       Instruction *CxtI,
-                                                       bool IsAnd,
-                                                       bool IsLogical) {
-  CmpInst::Predicate Pred = IsAnd ? CmpInst::ICMP_NE : CmpInst::ICMP_EQ;
-  if (LHS->getPredicate() != Pred || RHS->getPredicate() != Pred)
-    return nullptr;
-
-  if (!match(LHS->getOperand(1), m_Zero()) ||
-      !match(RHS->getOperand(1), m_Zero()))
-    return nullptr;
-
-  Value *L1, *L2, *R1, *R2;
-  if (match(LHS->getOperand(0), m_And(m_Value(L1), m_Value(L2))) &&
-      match(RHS->getOperand(0), m_And(m_Value(R1), m_Value(R2)))) {
-    if (L1 == R2 || L2 == R2)
-      std::swap(R1, R2);
-    if (L2 == R1)
-      std::swap(L1, L2);
-
-    if (L1 == R1 &&
-        isKnownToBeAPowerOfTwo(L2, false, 0, CxtI) &&
-        isKnownToBeAPowerOfTwo(R2, false, 0, CxtI)) {
-      // If this is a logical and/or, then we must prevent propagation of a
-      // poison value from the RHS by inserting freeze.
-      if (IsLogical)
-        R2 = Builder.CreateFreeze(R2);
-      Value *Mask = Builder.CreateOr(L2, R2);
-      Value *Masked = Builder.CreateAnd(L1, Mask);
-      auto NewPred = IsAnd ? CmpInst::ICMP_EQ : CmpInst::ICMP_NE;
-      return Builder.CreateICmp(NewPred, Masked, Mask);
-    }
-  }
-
-  return nullptr;
-}
-
 /// General pattern:
 ///   X & Y
 ///
@@ -3327,12 +3302,6 @@ Value *InstCombinerImpl::foldAndOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
                                           bool IsLogical) {
   const SimplifyQuery Q = SQ.getWithInstruction(&I);
 
-  // Fold (iszero(A & K1) | iszero(A & K2)) ->  (A & (K1 | K2)) != (K1 | K2)
-  // Fold (!iszero(A & K1) & !iszero(A & K2)) ->  (A & (K1 | K2)) == (K1 | K2)
-  // if K1 and K2 are a one-bit mask.
-  if (Value *V = foldAndOrOfICmpsOfAndWithPow2(LHS, RHS, &I, IsAnd, IsLogical))
-    return V;
-
   ICmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate();
   Value *LHS0 = LHS->getOperand(0), *RHS0 = RHS->getOperand(0);
   Value *LHS1 = LHS->getOperand(1), *RHS1 = RHS->getOperand(1);
@@ -3359,7 +3328,7 @@ Value *InstCombinerImpl::foldAndOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
   // handle (roughly):
   // (icmp ne (A & B), C) | (icmp ne (A & D), E)
   // (icmp eq (A & B), C) & (icmp eq (A & D), E)
-  if (Value *V = foldLogOpOfMaskedICmps(LHS, RHS, IsAnd, IsLogical, Builder))
+  if (Value *V = foldLogOpOfMaskedICmps(LHS, RHS, IsAnd, IsLogical, Builder, Q))
     return V;
 
   if (Value *V =
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index b31ae374540bbd..f6992119280c16 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -435,9 +435,6 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
   Instruction *
   canonicalizeConditionalNegationViaMathToSelect(BinaryOperator &i);
 
-  Value *foldAndOrOfICmpsOfAndWithPow2(ICmpInst *LHS, ICmpInst *RHS,
-                                       Instruction *CxtI, bool IsAnd,
-                                       bool IsLogical = false);
   Value *matchSelectFromAndOr(Value *A, Value *B, Value *C, Value *D,
                               bool InvertFalseVal = false);
   Value *getSelectCondition(Value *A, Value *B, bool ABIsTheSame);
diff --git a/llvm/test/Transforms/InstCombine/onehot_merge.ll b/llvm/test/Transforms/InstCombine/onehot_merge.ll
index d68de1f1f01904..3b7314d36eaaa7 100644
--- a/llvm/test/Transforms/InstCombine/onehot_merge.ll
+++ b/llvm/test/Transforms/InstCombine/onehot_merge.ll
@@ -1147,10 +1147,9 @@ define i1 @foo1_and_signbit_lshr_without_shifting_signbit_not_pwr2_logical(i32 %
 define i1 @two_types_of_bittest(i8 %x, i8 %c) {
 ; CHECK-LABEL: @two_types_of_bittest(
 ; CHECK-NEXT:    [[T0:%.*]] = shl nuw i8 1, [[C:%.*]]
-; CHECK-NEXT:    [[ICMP1:%.*]] = icmp slt i8 [[X:%.*]], 0
-; CHECK-NEXT:    [[AND:%.*]] = and i8 [[X]], [[T0]]
-; CHECK-NEXT:    [[ICMP2:%.*]] = icmp ne i8 [[AND]], 0
-; CHECK-NEXT:    [[RET:%.*]] = and i1 [[ICMP1]], [[ICMP2]]
+; CHECK-NEXT:    [[TMP1:%.*]] = or i8 [[T0]], -128
+; CHECK-NEXT:    [[TMP2:%.*]] = and i8 [[X:%.*]], [[TMP1]]
+; CHECK-NEXT:    [[RET:%.*]] = icmp eq i8 [[TMP2]], [[TMP1]]
 ; CHECK-NEXT:    ret i1 [[RET]]
 ;
   %t0 = shl i8 1, %c



More information about the llvm-commits mailing list