[llvm] 67d247a - [InstCombine] Decompose more icmps into masks (#110836)

via llvm-commits llvm-commits at lists.llvm.org
Fri Oct 4 01:17:27 PDT 2024


Author: Nikita Popov
Date: 2024-10-04T10:17:23+02:00
New Revision: 67d247a441572c046ccd82682ffe1102584f24ec

URL: https://github.com/llvm/llvm-project/commit/67d247a441572c046ccd82682ffe1102584f24ec
DIFF: https://github.com/llvm/llvm-project/commit/67d247a441572c046ccd82682ffe1102584f24ec.diff

LOG: [InstCombine] Decompose more icmps into masks (#110836)

Extend decomposeBitTestICmp() to handle cases where the resulting
comparison is of the form `icmp (X & Mask) pred C` with non-zero
`C`. Add a flag to allow code to opt in to this behavior and use it in
the "log op of icmp" fold infrastructure.

This addresses regressions from #97289.

Proofs: https://alive2.llvm.org/ce/z/hUhdbU

Added: 
    

Modified: 
    llvm/include/llvm/Analysis/CmpInstAnalysis.h
    llvm/lib/Analysis/CmpInstAnalysis.cpp
    llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
    llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
    llvm/test/Transforms/InstCombine/and-or-icmps.ll

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Analysis/CmpInstAnalysis.h b/llvm/include/llvm/Analysis/CmpInstAnalysis.h
index 406dacd930605e..c7862a6d39d07e 100644
--- a/llvm/include/llvm/Analysis/CmpInstAnalysis.h
+++ b/llvm/include/llvm/Analysis/CmpInstAnalysis.h
@@ -92,18 +92,21 @@ namespace llvm {
   Constant *getPredForFCmpCode(unsigned Code, Type *OpTy,
                                CmpInst::Predicate &Pred);
 
-  /// Represents the operation icmp (X & Mask) pred 0, where pred can only be
+  /// Represents the operation icmp (X & Mask) pred C, where pred can only be
   /// eq or ne.
   struct DecomposedBitTest {
     Value *X;
     CmpInst::Predicate Pred;
     APInt Mask;
+    APInt C;
   };
 
-  /// Decompose an icmp into the form ((X & Mask) pred 0) if possible.
+  /// Decompose an icmp into the form ((X & Mask) pred C) if possible.
+  /// Unless \p AllowNonZeroC is true, C will always be 0.
   std::optional<DecomposedBitTest>
   decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
-                       bool LookThroughTrunc = true);
+                       bool LookThroughTrunc = true,
+                       bool AllowNonZeroC = false);
 
 } // end namespace llvm
 

diff  --git a/llvm/lib/Analysis/CmpInstAnalysis.cpp b/llvm/lib/Analysis/CmpInstAnalysis.cpp
index ad111559b0d850..2580ea7e972488 100644
--- a/llvm/lib/Analysis/CmpInstAnalysis.cpp
+++ b/llvm/lib/Analysis/CmpInstAnalysis.cpp
@@ -75,7 +75,7 @@ Constant *llvm::getPredForFCmpCode(unsigned Code, Type *OpTy,
 
 std::optional<DecomposedBitTest>
 llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
-                           bool LookThruTrunc) {
+                           bool LookThruTrunc, bool AllowNonZeroC) {
   using namespace PatternMatch;
 
   const APInt *OrigC;
@@ -100,22 +100,57 @@ llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
   switch (Pred) {
   default:
     llvm_unreachable("Unexpected predicate");
-  case ICmpInst::ICMP_SLT:
+  case ICmpInst::ICMP_SLT: {
     // X < 0 is equivalent to (X & SignMask) != 0.
-    if (!C.isZero())
-      return std::nullopt;
-    Result.Mask = APInt::getSignMask(C.getBitWidth());
-    Result.Pred = ICmpInst::ICMP_NE;
-    break;
+    if (C.isZero()) {
+      Result.Mask = APInt::getSignMask(C.getBitWidth());
+      Result.C = APInt::getZero(C.getBitWidth());
+      Result.Pred = ICmpInst::ICMP_NE;
+      break;
+    }
+
+    APInt FlippedSign = C ^ APInt::getSignMask(C.getBitWidth());
+    if (FlippedSign.isPowerOf2()) {
+      // X s< 10000100 is equivalent to (X & 11111100 == 10000000)
+      Result.Mask = -FlippedSign;
+      Result.C = APInt::getSignMask(C.getBitWidth());
+      Result.Pred = ICmpInst::ICMP_EQ;
+      break;
+    }
+
+    if (FlippedSign.isNegatedPowerOf2()) {
+      // X s< 01111100 is equivalent to (X & 11111100 != 01111100)
+      Result.Mask = FlippedSign;
+      Result.C = C;
+      Result.Pred = ICmpInst::ICMP_NE;
+      break;
+    }
+
+    return std::nullopt;
+  }
   case ICmpInst::ICMP_ULT:
     // X <u 2^n is equivalent to (X & ~(2^n-1)) == 0.
-    if (!C.isPowerOf2())
-      return std::nullopt;
-    Result.Mask = -C;
-    Result.Pred = ICmpInst::ICMP_EQ;
-    break;
+    if (C.isPowerOf2()) {
+      Result.Mask = -C;
+      Result.C = APInt::getZero(C.getBitWidth());
+      Result.Pred = ICmpInst::ICMP_EQ;
+      break;
+    }
+
+    // X u< 11111100 is equivalent to (X & 11111100 != 11111100)
+    if (C.isNegatedPowerOf2()) {
+      Result.Mask = C;
+      Result.C = C;
+      Result.Pred = ICmpInst::ICMP_NE;
+      break;
+    }
+
+    return std::nullopt;
   }
 
+  if (!AllowNonZeroC && !Result.C.isZero())
+    return std::nullopt;
+
   if (Inverted)
     Result.Pred = ICmpInst::getInversePredicate(Result.Pred);
 
@@ -123,6 +158,7 @@ llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
   if (LookThruTrunc && match(LHS, m_Trunc(m_Value(X)))) {
     Result.X = X;
     Result.Mask = Result.Mask.zext(X->getType()->getScalarSizeInBits());
+    Result.C = Result.C.zext(X->getType()->getScalarSizeInBits());
   } else {
     Result.X = LHS;
   }

diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index e8c0b006616543..688601a8ffa543 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -181,14 +181,15 @@ static unsigned conjugateICmpMask(unsigned Mask) {
 // Adapts the external decomposeBitTestICmp for local use.
 static bool decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate &Pred,
                                  Value *&X, Value *&Y, Value *&Z) {
-  auto Res = llvm::decomposeBitTestICmp(LHS, RHS, Pred);
+  auto Res = llvm::decomposeBitTestICmp(
+      LHS, RHS, Pred, /*LookThroughTrunc=*/true, /*AllowNonZeroC=*/true);
   if (!Res)
     return false;
 
   Pred = Res->Pred;
   X = Res->X;
   Y = ConstantInt::get(X->getType(), Res->Mask);
-  Z = ConstantInt::get(X->getType(), 0);
+  Z = ConstantInt::get(X->getType(), Res->C);
   return true;
 }
 

diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index d0aa63ef06ba85..12eec99e7d22bb 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -5932,31 +5932,15 @@ Instruction *InstCombinerImpl::foldICmpWithTrunc(ICmpInst &ICmp) {
     return nullptr;
 
   // This matches patterns corresponding to tests of the signbit as well as:
-  // (trunc X) u< C --> (X & -C) == 0 (are all masked-high-bits clear?)
-  // (trunc X) u> C --> (X & ~C) != 0 (are any masked-high-bits set?)
-  if (auto Res = decomposeBitTestICmp(Op0, Op1, Pred, /*WithTrunc=*/true)) {
+  // (trunc X) pred C2 --> (X & Mask) == C
+  if (auto Res = decomposeBitTestICmp(Op0, Op1, Pred, /*WithTrunc=*/true,
+                                      /*AllowNonZeroC=*/true)) {
     Value *And = Builder.CreateAnd(Res->X, Res->Mask);
-    Constant *Zero = ConstantInt::getNullValue(Res->X->getType());
-    return new ICmpInst(Res->Pred, And, Zero);
+    Constant *C = ConstantInt::get(Res->X->getType(), Res->C);
+    return new ICmpInst(Res->Pred, And, C);
   }
 
   unsigned SrcBits = X->getType()->getScalarSizeInBits();
-  if (Pred == ICmpInst::ICMP_ULT && C->isNegatedPowerOf2()) {
-    // If C is a negative power-of-2 (high-bit mask):
-    // (trunc X) u< C --> (X & C) != C (are any masked-high-bits clear?)
-    Constant *MaskC = ConstantInt::get(X->getType(), C->zext(SrcBits));
-    Value *And = Builder.CreateAnd(X, MaskC);
-    return new ICmpInst(ICmpInst::ICMP_NE, And, MaskC);
-  }
-
-  if (Pred == ICmpInst::ICMP_UGT && (~*C).isPowerOf2()) {
-    // If C is not-of-power-of-2 (one clear bit):
-    // (trunc X) u> C --> (X & (C+1)) == C+1 (are all masked-high-bits set?)
-    Constant *MaskC = ConstantInt::get(X->getType(), (*C + 1).zext(SrcBits));
-    Value *And = Builder.CreateAnd(X, MaskC);
-    return new ICmpInst(ICmpInst::ICMP_EQ, And, MaskC);
-  }
-
   if (auto *II = dyn_cast<IntrinsicInst>(X)) {
     if (II->getIntrinsicID() == Intrinsic::cttz ||
         II->getIntrinsicID() == Intrinsic::ctlz) {

diff  --git a/llvm/test/Transforms/InstCombine/and-or-icmps.ll b/llvm/test/Transforms/InstCombine/and-or-icmps.ll
index 26f708dc787c7d..9ddc628aa47691 100644
--- a/llvm/test/Transforms/InstCombine/and-or-icmps.ll
+++ b/llvm/test/Transforms/InstCombine/and-or-icmps.ll
@@ -3335,10 +3335,7 @@ define i1 @icmp_eq_or_z_or_pow2orz_fail_bad_pred2(i8 %x, i8 %y) {
 
 define i1 @and_slt_to_mask(i8 %x) {
 ; CHECK-LABEL: @and_slt_to_mask(
-; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i8 [[X:%.*]], -124
-; CHECK-NEXT:    [[AND:%.*]] = and i8 [[X]], 2
-; CHECK-NEXT:    [[CMP2:%.*]] = icmp eq i8 [[AND]], 0
-; CHECK-NEXT:    [[AND2:%.*]] = and i1 [[CMP]], [[CMP2]]
+; CHECK-NEXT:    [[AND2:%.*]] = icmp slt i8 [[X:%.*]], -126
 ; CHECK-NEXT:    ret i1 [[AND2]]
 ;
   %cmp = icmp slt i8 %x, -124
@@ -3365,10 +3362,8 @@ define i1 @and_slt_to_mask_off_by_one(i8 %x) {
 
 define i1 @and_sgt_to_mask(i8 %x) {
 ; CHECK-LABEL: @and_sgt_to_mask(
-; CHECK-NEXT:    [[CMP:%.*]] = icmp sgt i8 [[X:%.*]], 123
-; CHECK-NEXT:    [[AND:%.*]] = and i8 [[X]], 2
-; CHECK-NEXT:    [[CMP2:%.*]] = icmp eq i8 [[AND]], 0
-; CHECK-NEXT:    [[AND2:%.*]] = and i1 [[CMP]], [[CMP2]]
+; CHECK-NEXT:    [[TMP1:%.*]] = and i8 [[X:%.*]], -2
+; CHECK-NEXT:    [[AND2:%.*]] = icmp eq i8 [[TMP1]], 124
 ; CHECK-NEXT:    ret i1 [[AND2]]
 ;
   %cmp = icmp sgt i8 %x, 123
@@ -3395,10 +3390,8 @@ define i1 @and_sgt_to_mask_off_by_one(i8 %x) {
 
 define i1 @and_ugt_to_mask(i8 %x) {
 ; CHECK-LABEL: @and_ugt_to_mask(
-; CHECK-NEXT:    [[CMP:%.*]] = icmp ugt i8 [[X:%.*]], -5
-; CHECK-NEXT:    [[AND:%.*]] = and i8 [[X]], 2
-; CHECK-NEXT:    [[CMP2:%.*]] = icmp eq i8 [[AND]], 0
-; CHECK-NEXT:    [[AND2:%.*]] = and i1 [[CMP]], [[CMP2]]
+; CHECK-NEXT:    [[TMP1:%.*]] = and i8 [[X:%.*]], -2
+; CHECK-NEXT:    [[AND2:%.*]] = icmp eq i8 [[TMP1]], -4
 ; CHECK-NEXT:    ret i1 [[AND2]]
 ;
   %cmp = icmp ugt i8 %x, -5


        


More information about the llvm-commits mailing list