[llvm] Simplify Patterns (PR #102221)

Rose Silicon via llvm-commits llvm-commits at lists.llvm.org
Tue Aug 6 13:49:11 PDT 2024


https://github.com/RSilicon created https://github.com/llvm/llvm-project/pull/102221

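Simplify PatternMatch usage. Where a matched constant or value was captured only to be compared against a known value or tested with a predicate afterwards, fold that check into the matcher itself (m_SpecificInt, m_SpecificInt_ICMP, m_StrictlyPositive, m_Power2, m_Negative, or m_Specific). This removes the separate comparison and, in most cases, the extra temporary variable.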

From b8d9564d09608f09aa1721b623fbcf2b2cd81bbb Mon Sep 17 00:00:00 2001
From: Rose <gfunni234 at gmail.com>
Date: Tue, 6 Aug 2024 16:43:57 -0400
Subject: [PATCH] Simplify Patterns

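Simplify PatternMatch usage. Where a matched constant or value was
captured only to be compared against a known value or tested with a
predicate afterwards, fold that check into the matcher itself
(m_SpecificInt, m_SpecificInt_ICMP, m_StrictlyPositive, m_Power2,
m_Negative, or m_Specific). This removes the separate comparison and,
in most cases, the extra temporary variable.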
---
 llvm/lib/Analysis/ValueTracking.cpp           | 41 ++++++++-----------
 llvm/lib/CodeGen/CodeGenPrepare.cpp           |  4 +-
 llvm/lib/IR/Constants.cpp                     |  4 +-
 llvm/lib/Target/X86/X86ISelLowering.cpp       |  9 ++--
 .../InstCombine/InstCombineAddSub.cpp         |  4 +-
 .../InstCombine/InstCombineCompares.cpp       |  5 +--
 .../InstCombine/InstCombineMulDivRem.cpp      |  5 ++-
 .../InstCombine/InstCombineNegator.cpp        |  3 +-
 .../InstCombine/InstCombineSelect.cpp         | 10 ++---
 .../InstCombine/InstCombineShifts.cpp         |  5 ++-
 10 files changed, 42 insertions(+), 48 deletions(-)

diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 202eaad57d1e3..e364f40fe5c79 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -3803,12 +3803,7 @@ static unsigned ComputeNumSignBitsImpl(const Value *V,
     case Instruction::SDiv: {
       const APInt *Denominator;
       // sdiv X, C -> adds log(C) sign bits.
-      if (match(U->getOperand(1), m_APInt(Denominator))) {
-
-        // Ignore non-positive denominator.
-        if (!Denominator->isStrictlyPositive())
-          break;
-
+      if (match(U->getOperand(1), m_StrictlyPositive(Denominator))) {
         // Calculate the incoming numerator bits.
         unsigned NumBits =
             ComputeNumSignBits(U->getOperand(0), DemandedElts, Depth + 1, Q);
@@ -3826,26 +3821,24 @@ static unsigned ComputeNumSignBitsImpl(const Value *V,
       // srem X, C -> we know that the result is within [-C+1,C) when C is a
       // positive constant.  This let us put a lower bound on the number of sign
       // bits.
-      if (match(U->getOperand(1), m_APInt(Denominator))) {
+      if (match(U->getOperand(1), m_StrictlyPositive(Denominator))) {
 
         // Ignore non-positive denominator.
-        if (Denominator->isStrictlyPositive()) {
-          // Calculate the leading sign bit constraints by examining the
-          // denominator.  Given that the denominator is positive, there are two
-          // cases:
-          //
-          //  1. The numerator is positive. The result range is [0,C) and
-          //     [0,C) u< (1 << ceilLogBase2(C)).
-          //
-          //  2. The numerator is negative. Then the result range is (-C,0] and
-          //     integers in (-C,0] are either 0 or >u (-1 << ceilLogBase2(C)).
-          //
-          // Thus a lower bound on the number of sign bits is `TyBits -
-          // ceilLogBase2(C)`.
-
-          unsigned ResBits = TyBits - Denominator->ceilLogBase2();
-          Tmp = std::max(Tmp, ResBits);
-        }
+        // Calculate the leading sign bit constraints by examining the
+        // denominator.  Given that the denominator is positive, there are two
+        // cases:
+        //
+        //  1. The numerator is positive. The result range is [0,C) and
+        //     [0,C) u< (1 << ceilLogBase2(C)).
+        //
+        //  2. The numerator is negative. Then the result range is (-C,0] and
+        //     integers in (-C,0] are either 0 or >u (-1 << ceilLogBase2(C)).
+        //
+        // Thus a lower bound on the number of sign bits is `TyBits -
+        // ceilLogBase2(C)`.
+
+        unsigned ResBits = TyBits - Denominator->ceilLogBase2();
+        Tmp = std::max(Tmp, ResBits);
       }
       return Tmp;
     }
diff --git a/llvm/lib/CodeGen/CodeGenPrepare.cpp b/llvm/lib/CodeGen/CodeGenPrepare.cpp
index 22d0708f54786..62b3ea23d478e 100644
--- a/llvm/lib/CodeGen/CodeGenPrepare.cpp
+++ b/llvm/lib/CodeGen/CodeGenPrepare.cpp
@@ -1733,9 +1733,9 @@ bool CodeGenPrepare::combineToUSubWithOverflow(CmpInst *Cmp,
     }
 
     // A + (-C), A u< C (canonicalized form of (sub A, C))
-    const APInt *CmpC, *AddC;
+    const APInt *AddC;
     if (match(U, m_Add(m_Specific(A), m_APInt(AddC))) &&
-        match(B, m_APInt(CmpC)) && *AddC == -(*CmpC)) {
+        match(B, m_SpecificInt(-*AddC))) {
       Sub = cast<BinaryOperator>(U);
       break;
     }
diff --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp
index a1c9e925a024f..aca26d9f53130 100644
--- a/llvm/lib/IR/Constants.cpp
+++ b/llvm/lib/IR/Constants.cpp
@@ -2636,7 +2636,7 @@ Constant *ConstantExpr::getXor(Constant *C1, Constant *C2) {
 Constant *ConstantExpr::getExactLogBase2(Constant *C) {
   Type *Ty = C->getType();
   const APInt *IVal;
-  if (match(C, m_APInt(IVal)) && IVal->isPowerOf2())
+  if (match(C, m_Power2(IVal)))
     return ConstantInt::get(Ty, IVal->logBase2());
 
   // FIXME: We can extract pow of 2 of splat constant for scalable vectors.
@@ -2654,7 +2654,7 @@ Constant *ConstantExpr::getExactLogBase2(Constant *C) {
       Elts.push_back(Constant::getNullValue(Ty->getScalarType()));
       continue;
     }
-    if (!match(Elt, m_APInt(IVal)) || !IVal->isPowerOf2())
+    if (!match(Elt, m_Power2(IVal)))
       return nullptr;
     Elts.push_back(ConstantInt::get(Ty->getScalarType(), IVal->logBase2()));
   }
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 2891e21be1b26..f305e205a2f18 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -30579,11 +30579,12 @@ static std::pair<Value *, BitTestKind> FindSingleBitChange(Value *V) {
       Value *BitV = I->getOperand(1);
 
       Value *AndOp;
-      const APInt *AndC;
-      if (match(BitV, m_c_And(m_Value(AndOp), m_APInt(AndC)))) {
+      if (match(BitV,
+                m_c_And(m_Value(AndOp),
+                        m_SpecificInt(I->getType()->getPrimitiveSizeInBits() -
+                                      1)))) {
         // Read past a shiftmask instruction to find count
-        if (*AndC == (I->getType()->getPrimitiveSizeInBits() - 1))
-          BitV = AndOp;
+        BitV = AndOp;
       }
       return {BitV, BTK};
     }
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index 3bd086230cbec..5796b844cb448 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -773,10 +773,10 @@ static Value *checkForNegativeOperand(BinaryOperator &I,
     if (match(X, m_Xor(m_Value(Y), m_APInt(C1)))) {
       // X = XOR(Y, C1), Y = OR(Z, C2), C2 = NOT(C1) ==> X == NOT(AND(Z, C1))
       // ADD(ADD(X, 1), RHS) == ADD(X, ADD(RHS, 1)) == SUB(RHS, AND(Z, C1))
-      if (match(Y, m_Or(m_Value(Z), m_APInt(C2))) && (*C2 == ~(*C1))) {
+      if (match(Y, m_Or(m_Value(Z), m_SpecificInt(~(*C1))))) {
         Value *NewAnd = Builder.CreateAnd(Z, *C1);
         return Builder.CreateSub(RHS, NewAnd, "sub");
-      } else if (match(Y, m_And(m_Value(Z), m_APInt(C2))) && (*C1 == *C2)) {
+      } else if (match(Y, m_And(m_Value(Z), m_SpecificInt(*C1)))) {
         // X = XOR(Y, C1), Y = AND(Z, C2), C2 == C1 ==> X == NOT(OR(Z, ~C1))
         // ADD(ADD(X, 1), RHS) == ADD(X, ADD(RHS, 1)) == SUB(RHS, OR(Z, ~C1))
         Value *NewOr = Builder.CreateOr(Z, ~(*C1));
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 10a89b47e0753..9e1d9d5997271 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -7245,11 +7245,10 @@ Instruction *InstCombinerImpl::foldICmpCommutative(ICmpInst::Predicate Pred,
       return Res;
 
   {
-    Value *X;
     const APInt *C;
     // icmp X+Cst, X
-    if (match(Op0, m_Add(m_Value(X), m_APInt(C))) && Op1 == X)
-      return foldICmpAddOpConst(X, *C, Pred);
+    if (match(Op0, m_Add(m_Specific(Op1), m_APInt(C))))
+      return foldICmpAddOpConst(Op1, *C, Pred);
   }
 
   // abs(X) >=  X --> true
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index f4f3644acfe5e..8f18be6a4db86 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -1366,7 +1366,7 @@ Instruction *InstCombinerImpl::commonIDivTransforms(BinaryOperator &I) {
       auto OB1HasNSW = cast<OverflowingBinaryOperator>(Op1)->hasNoSignedWrap();
       auto OB1HasNUW =
           cast<OverflowingBinaryOperator>(Op1)->hasNoUnsignedWrap();
-      const APInt *C1, *C2;
+      const APInt *C1;
       if (IsSigned && OB0HasNSW) {
         if (OB1HasNSW && match(B, m_APInt(C1)) && !C1->isAllOnes())
           return BinaryOperator::CreateSDiv(A, B);
@@ -1374,7 +1374,8 @@ Instruction *InstCombinerImpl::commonIDivTransforms(BinaryOperator &I) {
       if (!IsSigned && OB0HasNUW) {
         if (OB1HasNUW)
           return BinaryOperator::CreateUDiv(A, B);
-        if (match(A, m_APInt(C1)) && match(B, m_APInt(C2)) && C2->ule(*C1))
+        if (match(A, m_APInt(C1)) &&
+            match(B, m_SpecificInt_ICMP(ICmpInst::ICMP_ULE, *C1)))
           return BinaryOperator::CreateUDiv(A, B);
       }
       return nullptr;
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp b/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp
index e4895b59f4b4a..225a1c6c2dd8d 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp
@@ -181,8 +181,7 @@ std::array<Value *, 2> Negator::getSortedOperandsOfBinOp(Instruction *I) {
   case Instruction::AShr:
   case Instruction::LShr: {
     // Right-shift sign bit smear is negatible.
-    const APInt *Op1Val;
-    if (match(I->getOperand(1), m_APInt(Op1Val)) && *Op1Val == BitWidth - 1) {
+    if (match(I->getOperand(1), m_SpecificInt(BitWidth - 1))) {
       Value *BO = I->getOpcode() == Instruction::AShr
                       ? Builder.CreateLShr(I->getOperand(0), I->getOperand(1))
                       : Builder.CreateAShr(I->getOperand(0), I->getOperand(1));
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 6025e73f07cf3..973c1b02d6b73 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -1882,7 +1882,7 @@ Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI,
         DL.getTypeSizeInBits(TrueVal->getType()->getScalarType());
     APInt MinSignedValue = APInt::getSignedMinValue(BitWidth);
     Value *X;
-    const APInt *Y, *C;
+    const APInt *Y;
     bool TrueWhenUnset;
     bool IsBitTest = false;
     if (ICmpInst::isEquality(Pred) &&
@@ -1905,19 +1905,19 @@ Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI,
       Value *V = nullptr;
       // (X & Y) == 0 ? X : X ^ Y  --> X & ~Y
       if (TrueWhenUnset && TrueVal == X &&
-          match(FalseVal, m_Xor(m_Specific(X), m_APInt(C))) && *Y == *C)
+          match(FalseVal, m_Xor(m_Specific(X), m_SpecificInt(*Y))))
         V = Builder.CreateAnd(X, ~(*Y));
       // (X & Y) != 0 ? X ^ Y : X  --> X & ~Y
       else if (!TrueWhenUnset && FalseVal == X &&
-               match(TrueVal, m_Xor(m_Specific(X), m_APInt(C))) && *Y == *C)
+               match(TrueVal, m_Xor(m_Specific(X), m_SpecificInt(*Y))))
         V = Builder.CreateAnd(X, ~(*Y));
       // (X & Y) == 0 ? X ^ Y : X  --> X | Y
       else if (TrueWhenUnset && FalseVal == X &&
-               match(TrueVal, m_Xor(m_Specific(X), m_APInt(C))) && *Y == *C)
+               match(TrueVal, m_Xor(m_Specific(X), m_SpecificInt(*Y))))
         V = Builder.CreateOr(X, *Y);
       // (X & Y) != 0 ? X : X ^ Y  --> X | Y
       else if (!TrueWhenUnset && TrueVal == X &&
-               match(FalseVal, m_Xor(m_Specific(X), m_APInt(C))) && *Y == *C)
+               match(FalseVal, m_Xor(m_Specific(X), m_SpecificInt(*Y))))
         V = Builder.CreateOr(X, *Y);
 
       if (V)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
index 38f8a41214b68..dd919e9ffad19 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -460,8 +460,9 @@ Instruction *InstCombinerImpl::commonShiftTransforms(BinaryOperator &I) {
   // C << (X - AddC) --> (C >> AddC) << X
   // and
   // C >> (X - AddC) --> (C << AddC) >> X
-  if (match(Op0, m_APInt(AC)) && match(Op1, m_Add(m_Value(A), m_APInt(AddC))) &&
-      AddC->isNegative() && (-*AddC).ult(BitWidth)) {
+  if (match(Op0, m_APInt(AC)) &&
+      match(Op1, m_Add(m_Value(A), m_Negative(AddC))) &&
+      (-*AddC).ult(BitWidth)) {
     assert(!AC->isZero() && "Expected simplify of shifted zero");
     unsigned PosOffset = (-*AddC).getZExtValue();
 



More information about the llvm-commits mailing list