[llvm] Combine more examples to new Checked matcher API (PR #91097)

via llvm-commits llvm-commits at lists.llvm.org
Sat May 4 17:18:45 PDT 2024


https://github.com/AtariDreams updated https://github.com/llvm/llvm-project/pull/91097

>From 1646f2e1c73e0a2582ab1a7684b86d7189780398 Mon Sep 17 00:00:00 2001
From: Rose <gfunni234 at gmail.com>
Date: Sat, 4 May 2024 19:24:12 -0400
Subject: [PATCH] Convert more examples to the new m_CheckedInt matcher API

---
 llvm/lib/Analysis/InstructionSimplify.cpp     | 67 ++++++++++++-------
 llvm/lib/Analysis/ValueTracking.cpp           | 12 ++--
 llvm/lib/Target/X86/X86ISelLowering.cpp       | 11 +--
 .../InstCombine/InstCombineAddSub.cpp         |  9 +--
 .../InstCombine/InstCombineCasts.cpp          |  4 +-
 .../InstCombineSimplifyDemanded.cpp           |  9 ++-
 .../InstCombine/InstCombineVectorOps.cpp      |  6 +-
 7 files changed, 70 insertions(+), 48 deletions(-)

diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index 4061dae83c10f3..502c91d33df2c4 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -1028,33 +1028,43 @@ static bool isDivZero(Value *X, Value *Y, const SimplifyQuery &Q,
     // Make sure that a constant is not the minimum signed value because taking
     // the abs() of that is undefined.
     Type *Ty = X->getType();
-    const APInt *C;
-    if (match(X, m_APInt(C)) && !C->isMinSignedValue()) {
-      // Is the variable divisor magnitude always greater than the constant
-      // dividend magnitude?
-      // |Y| > |C| --> Y < -abs(C) or Y > abs(C)
-      Constant *PosDividendC = ConstantInt::get(Ty, C->abs());
-      Constant *NegDividendC = ConstantInt::get(Ty, -C->abs());
+
+    // Constant dividend X == C: is the variable divisor magnitude always
+    // greater than the constant dividend magnitude?
+    // |Y| > |C| --> Y < -abs(C) or Y > abs(C)
+    auto DivisorExceedsDividend = [Ty, Y, &Q, MaxRecurse](const APInt &C) {
+      if (C.isMinSignedValue())
+        return false;
+      Constant *PosDividendC = ConstantInt::get(Ty, C.abs());
+      Constant *NegDividendC = ConstantInt::get(Ty, -C.abs());
       if (isICmpTrue(CmpInst::ICMP_SLT, Y, NegDividendC, Q, MaxRecurse) ||
           isICmpTrue(CmpInst::ICMP_SGT, Y, PosDividendC, Q, MaxRecurse))
         return true;
-    }
-    if (match(Y, m_APInt(C))) {
+      return false;
+    };
+
+    auto DividendBelowDivisor = [Ty, X, Y, &Q, MaxRecurse](const APInt &C) {
       // Special-case: we can't take the abs() of a minimum signed value. If
       // that's the divisor, then all we have to do is prove that the dividend
       // is also not the minimum signed value.
-      if (C->isMinSignedValue())
+      if (C.isMinSignedValue())
         return isICmpTrue(CmpInst::ICMP_NE, X, Y, Q, MaxRecurse);
 
       // Is the variable dividend magnitude always less than the constant
       // divisor magnitude?
       // |X| < |C| --> X > -abs(C) and X < abs(C)
-      Constant *PosDivisorC = ConstantInt::get(Ty, C->abs());
-      Constant *NegDivisorC = ConstantInt::get(Ty, -C->abs());
-      if (isICmpTrue(CmpInst::ICMP_SGT, X, NegDivisorC, Q, MaxRecurse) &&
-          isICmpTrue(CmpInst::ICMP_SLT, X, PosDivisorC, Q, MaxRecurse))
+      Constant *PosDivisorC = ConstantInt::get(Ty, C.abs());
+      Constant *NegDivisorC = ConstantInt::get(Ty, -C.abs());
+      if (isICmpTrue(CmpInst::ICMP_SGT, X, NegDivisorC, Q, MaxRecurse) &&
+          isICmpTrue(CmpInst::ICMP_SLT, X, PosDivisorC, Q, MaxRecurse))
         return true;
-    }
+      return false;
+    };
+
+    if (match(X, m_CheckedInt(DivisorExceedsDividend)))
+      return true;
+    if (match(Y, m_CheckedInt(DividendBelowDivisor)))
+      return true;
     return false;
   }
 
@@ -1063,9 +1073,11 @@ static bool isDivZero(Value *X, Value *Y, const SimplifyQuery &Q,
   // Is the unsigned dividend known to be less than a constant divisor?
   // TODO: Convert this (and above) to range analysis
   //      ("computeConstantRangeIncludingKnownBits")?
-  const APInt *C;
-  if (match(Y, m_APInt(C)) &&
-      computeKnownBits(X, /* Depth */ 0, Q).getMaxValue().ult(*C))
+  // Y == C: X's known-bits maximum must be u< C. Q by ref avoids a copy.
+  auto DividendUltDivisor = [X, &Q](const APInt &C) {
+    return computeKnownBits(X, /* Depth */ 0, Q).getMaxValue().ult(C);
+  };
+  if (match(Y, m_CheckedInt(DividendUltDivisor)))
     return true;
 
   // Try again for any divisor:
@@ -2362,15 +2374,16 @@ static Value *simplifyOrInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,
   // (-1 << X) | (-1 >> (C - X)) --> -1
   // (-1 >> X) | (-1 << (C - X)) --> -1
   // ...with C <= bitwidth (and commuted variants).
-  Value *X, *Y;
+  Value *X = nullptr, *Y = nullptr;
+  auto IsAtMostBitWidth = [&X](const APInt &C) { // By ref: X is bound below.
+    return C.ule(X->getType()->getScalarSizeInBits());
+  };
   if ((match(Op0, m_Shl(m_AllOnes(), m_Value(X))) &&
        match(Op1, m_LShr(m_AllOnes(), m_Value(Y)))) ||
       (match(Op1, m_Shl(m_AllOnes(), m_Value(X))) &&
        match(Op0, m_LShr(m_AllOnes(), m_Value(Y))))) {
-    const APInt *C;
-    if ((match(X, m_Sub(m_APInt(C), m_Specific(Y))) ||
-         match(Y, m_Sub(m_APInt(C), m_Specific(X)))) &&
-        C->ule(X->getType()->getScalarSizeInBits())) {
+    if (match(X, m_Sub(m_CheckedInt(IsAtMostBitWidth), m_Specific(Y))) ||
+        match(Y, m_Sub(m_CheckedInt(IsAtMostBitWidth), m_Specific(X)))) {
       return ConstantInt::getAllOnesValue(X->getType());
     }
   }
@@ -3158,9 +3171,10 @@ static Value *simplifyICmpWithBinOpOnLHS(CmpInst::Predicate Pred,
   // x udiv C >=u x --> false for C != 1.
   // x udiv C ==  x --> false for C != 1.
   // TODO: allow non-constant shift amount/divisor
-  const APInt *C;
-  if ((match(LBO, m_LShr(m_Specific(RHS), m_APInt(C))) && *C != 0) ||
-      (match(LBO, m_UDiv(m_Specific(RHS), m_APInt(C))) && *C != 1)) {
+  auto NonZeroShAmt = [](const APInt &C) { return !C.isZero(); };
+  auto NonUnitDivisor = [](const APInt &C) { return !C.isOne(); };
+  if (match(LBO, m_LShr(m_Specific(RHS), m_CheckedInt(NonZeroShAmt))) ||
+      match(LBO, m_UDiv(m_Specific(RHS), m_CheckedInt(NonUnitDivisor)))) {
     if (isKnownNonZero(RHS, Q)) {
       switch (Pred) {
       default:
@@ -3203,6 +3217,7 @@ static Value *simplifyICmpWithBinOpOnLHS(CmpInst::Predicate Pred,
 
   // (sub C, X) == X, C is odd  --> false
   // (sub C, X) != X, C is odd  --> true
+  const APInt *C;
   if (match(LBO, m_Sub(m_APIntAllowPoison(C), m_Specific(RHS))) &&
       (*C & 1) == 1 && ICmpInst::isEquality(Pred))
     return (Pred == ICmpInst::ICMP_EQ) ? getFalse(ITy) : getTrue(ITy);
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 0dbb39d7c8ec46..9a4ae6cdcf8258 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -3275,11 +3275,11 @@ static bool isModifyingBinopOfNonZero(const Value *V1, const Value *V2,
 /// the multiplication is nuw or nsw.
 static bool isNonEqualMul(const Value *V1, const Value *V2, unsigned Depth,
                           const SimplifyQuery &Q) {
+  auto IsAboveOne = [](const APInt &C) { return C.ugt(1); };
   if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(V2)) {
-    const APInt *C;
-    return match(OBO, m_Mul(m_Specific(V1), m_APInt(C))) &&
+    return match(OBO, m_Mul(m_Specific(V1), m_CheckedInt(IsAboveOne))) &&
            (OBO->hasNoUnsignedWrap() || OBO->hasNoSignedWrap()) &&
-           !C->isZero() && !C->isOne() && isKnownNonZero(V1, Q, Depth + 1);
+           isKnownNonZero(V1, Q, Depth + 1);
   }
   return false;
 }
@@ -3288,11 +3288,11 @@ static bool isNonEqualMul(const Value *V1, const Value *V2, unsigned Depth,
 /// the shift is nuw or nsw.
 static bool isNonEqualShl(const Value *V1, const Value *V2, unsigned Depth,
                           const SimplifyQuery &Q) {
+  auto IsNonZero = [](const APInt &C) { return !C.isZero(); };
   if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(V2)) {
-    const APInt *C;
-    return match(OBO, m_Shl(m_Specific(V1), m_APInt(C))) &&
+    return match(OBO, m_Shl(m_Specific(V1), m_CheckedInt(IsNonZero))) &&
            (OBO->hasNoUnsignedWrap() || OBO->hasNoSignedWrap()) &&
-           !C->isZero() && isKnownNonZero(V1, Q, Depth + 1);
+           isKnownNonZero(V1, Q, Depth + 1);
   }
   return false;
 }
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index cf4a64ffded2e8..b5728b0ca3a04e 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -30447,11 +30447,12 @@ static std::pair<Value *, BitTestKind> FindSingleBitChange(Value *V) {
       Value *BitV = I->getOperand(1);
 
       Value *AndOp;
-      const APInt *AndC;
-      if (match(BitV, m_c_And(m_Value(AndOp), m_APInt(AndC)))) {
-        // Read past a shiftmask instruction to find count
-        if (*AndC == (I->getType()->getPrimitiveSizeInBits() - 1))
-          BitV = AndOp;
+      // Look through a shift-amount mask (and X, BitWidth - 1) to the count.
+      auto IsShiftMask = [I](const APInt &MaskC) {
+        return MaskC == (I->getType()->getPrimitiveSizeInBits() - 1);
+      };
+      if (match(BitV, m_c_And(m_Value(AndOp), m_CheckedInt(IsShiftMask)))) {
+        BitV = AndOp;
       }
       return {BitV, BTK};
     }
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index 51ac77348ed9e3..10964eeb8ba814 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -1761,7 +1761,9 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
   // zext(ctpop(A) >u/!= 1) + (ctlz(A, true) ^ (BW - 1))
   // -->
   // BW - ctlz(A - 1, false)
-  const APInt *XorC;
+  auto IsBWMinusOne = [&A](const APInt &XorC) { // By ref: A is bound later.
+    return XorC == A->getType()->getScalarSizeInBits() - 1;
+  };
   if (match(&I,
             m_c_Add(
                 m_ZExt(m_ICmp(Pred, m_Intrinsic<Intrinsic::ctpop>(m_Value(A)),
@@ -1769,9 +1771,8 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
                 m_OneUse(m_ZExtOrSelf(m_OneUse(m_Xor(
                     m_OneUse(m_TruncOrSelf(m_OneUse(
                         m_Intrinsic<Intrinsic::ctlz>(m_Deferred(A), m_One())))),
-                    m_APInt(XorC))))))) &&
-      (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_NE) &&
-      *XorC == A->getType()->getScalarSizeInBits() - 1) {
+                    m_CheckedInt(IsBWMinusOne))))))) &&
+      (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_NE)) {
     Value *Sub = Builder.CreateAdd(A, Constant::getAllOnesValue(A->getType()));
     Value *Ctlz = Builder.CreateIntrinsic(Intrinsic::ctlz, {A->getType()},
                                           {Sub, Builder.getFalse()});
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
index 11e31877de38c2..bc6c9fd7deeaf5 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -764,8 +764,8 @@ Instruction *InstCombinerImpl::visitTrunc(TruncInst &Trunc) {
     }
 
     {
-      const APInt *C;
-      if (match(Src, m_Shl(m_APInt(C), m_Value(X))) && (*C)[0] == 1) {
+      auto IsOdd = [](const APInt &C) { return C[0]; };
+      if (match(Src, m_Shl(m_CheckedInt(IsOdd), m_Value(X)))) {
         // trunc (C << X) to i1 --> X == 0, where C is odd
         return new ICmpInst(ICmpInst::Predicate::ICMP_EQ, X, Zero);
       }
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
index 6739b8745d74e4..2fbcf29c20d53c 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
@@ -336,9 +336,12 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
     // If the RHS is a constant, see if we can change it. Don't alter a -1
     // constant because that's a canonical 'not' op, and that is better for
     // combining, SCEV, and codegen.
-    const APInt *C;
-    if (match(I->getOperand(1), m_APInt(C)) && !C->isAllOnes()) {
-      if ((*C | ~DemandedMask).isAllOnes()) {
+    auto IsNotAllOnes = [](const APInt &C) { return !C.isAllOnes(); };
+    auto SetsAllDemandedBits = [&DemandedMask](const APInt &C) {
+      return (C | ~DemandedMask).isAllOnes();
+    };
+    if (match(I->getOperand(1), m_CheckedInt(IsNotAllOnes))) {
+      if (match(I->getOperand(1), m_CheckedInt(SetsAllDemandedBits))) {
         // Force bits to 1 to create a 'not' op.
         I->setOperand(1, ConstantInt::getAllOnesValue(VTy));
         return I;
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
index 99f1f8eb34bb5a..7c23a22d717b9d 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
@@ -2071,8 +2071,10 @@ static BinopElts getAlternateBinop(BinaryOperator *BO, const DataLayout &DL) {
   }
   case Instruction::Or: {
     // or X, C --> add X, C (when X and C have no common bits set)
-    const APInt *C;
-    if (match(BO1, m_APInt(C)) && MaskedValueIsZero(BO0, *C, DL))
+    auto NoCommonBitsWithX = [BO0, &DL](const APInt &C) { // &DL: no copy.
+      return MaskedValueIsZero(BO0, C, DL);
+    };
+    if (match(BO1, m_CheckedInt(NoCommonBitsWithX)))
       return {Instruction::Add, BO0, BO1};
     break;
   }



More information about the llvm-commits mailing list