[llvm] 72dc5ca - [LV] Make use of PatternMatchers in getReductionPatternCost. NFC

David Green via llvm-commits llvm-commits at lists.llvm.org
Wed Jul 21 03:34:41 PDT 2021


Author: David Green
Date: 2021-07-21T11:34:30+01:00
New Revision: 72dc5cab4f8bce67a78dd149f17b8709208da580

URL: https://github.com/llvm/llvm-project/commit/72dc5cab4f8bce67a78dd149f17b8709208da580
DIFF: https://github.com/llvm/llvm-project/commit/72dc5cab4f8bce67a78dd149f17b8709208da580.diff

LOG: [LV] Make use of PatternMatchers in getReductionPatternCost. NFC

Pulled out of D106166, this modifies getReductionPatternCost to use
PatternMatchers, hopefully simplifying the code a little.

Added: 
    

Modified: 
    llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 4d87897fc235d..001f1e45e488a 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7127,6 +7127,7 @@ LoopVectorizationCostModel::getInterleaveGroupCost(Instruction *I,
 
 Optional<InstructionCost> LoopVectorizationCostModel::getReductionPatternCost(
     Instruction *I, ElementCount VF, Type *Ty, TTI::TargetCostKind CostKind) {
+  using namespace llvm::PatternMatch;
   // Early exit for no inloop reductions
   if (InLoopReductionChains.empty() || VF.isScalar() || !isa<VectorType>(Ty))
     return None;
@@ -7145,13 +7146,12 @@ Optional<InstructionCost> LoopVectorizationCostModel::getReductionPatternCost(
   // it is not we return an invalid cost specifying the orignal cost method
   // should be used.
   Instruction *RetI = I;
-  if ((RetI->getOpcode() == Instruction::SExt ||
-       RetI->getOpcode() == Instruction::ZExt)) {
+  if (match(RetI, m_ZExtOrSExt(m_Value()))) {
     if (!RetI->hasOneUser())
       return None;
     RetI = RetI->user_back();
   }
-  if (RetI->getOpcode() == Instruction::Mul &&
+  if (match(RetI, m_Mul(m_Value(), m_Value())) &&
       RetI->user_back()->getOpcode() == Instruction::Add) {
     if (!RetI->hasOneUser())
       return None;
@@ -7183,8 +7183,10 @@ Optional<InstructionCost> LoopVectorizationCostModel::getReductionPatternCost(
 
   VectorTy = VectorType::get(I->getOperand(0)->getType(), VectorTy);
 
-  if (RedOp && (isa<SExtInst>(RedOp) || isa<ZExtInst>(RedOp)) &&
+  Instruction *Op0, *Op1;
+  if (RedOp && match(RedOp, m_ZExtOrSExt(m_Value())) &&
       !TheLoop->isLoopInvariant(RedOp)) {
+    // Matched reduce(ext(A))
     bool IsUnsigned = isa<ZExtInst>(RedOp);
     auto *ExtType = VectorType::get(RedOp->getOperand(0)->getType(), VectorTy);
     InstructionCost RedCost = TTI.getExtendedAddReductionCost(
@@ -7196,22 +7198,20 @@ Optional<InstructionCost> LoopVectorizationCostModel::getReductionPatternCost(
                              TTI::CastContextHint::None, CostKind, RedOp);
     if (RedCost.isValid() && RedCost < BaseCost + ExtCost)
       return I == RetI ? RedCost : 0;
-  } else if (RedOp && RedOp->getOpcode() == Instruction::Mul) {
-    Instruction *Mul = RedOp;
-    Instruction *Op0 = dyn_cast<Instruction>(Mul->getOperand(0));
-    Instruction *Op1 = dyn_cast<Instruction>(Mul->getOperand(1));
-    if (Op0 && Op1 && (isa<SExtInst>(Op0) || isa<ZExtInst>(Op0)) &&
+  } else if (RedOp &&
+             match(RedOp, m_Mul(m_Instruction(Op0), m_Instruction(Op1)))) {
+    if (match(Op0, m_ZExtOrSExt(m_Value())) &&
         Op0->getOpcode() == Op1->getOpcode() &&
         Op0->getOperand(0)->getType() == Op1->getOperand(0)->getType() &&
         !TheLoop->isLoopInvariant(Op0) && !TheLoop->isLoopInvariant(Op1)) {
       bool IsUnsigned = isa<ZExtInst>(Op0);
       auto *ExtType = VectorType::get(Op0->getOperand(0)->getType(), VectorTy);
-      // reduce(mul(ext, ext))
+      // Matched reduce(mul(ext, ext))
       InstructionCost ExtCost =
           TTI.getCastInstrCost(Op0->getOpcode(), VectorTy, ExtType,
                                TTI::CastContextHint::None, CostKind, Op0);
       InstructionCost MulCost =
-          TTI.getArithmeticInstrCost(Mul->getOpcode(), VectorTy, CostKind);
+          TTI.getArithmeticInstrCost(Instruction::Mul, VectorTy, CostKind);
 
       InstructionCost RedCost = TTI.getExtendedAddReductionCost(
           /*IsMLA=*/true, IsUnsigned, RdxDesc.getRecurrenceType(), ExtType,
@@ -7220,8 +7220,9 @@ Optional<InstructionCost> LoopVectorizationCostModel::getReductionPatternCost(
       if (RedCost.isValid() && RedCost < ExtCost * 2 + MulCost + BaseCost)
         return I == RetI ? RedCost : 0;
     } else {
+      // Matched reduce(mul())
       InstructionCost MulCost =
-          TTI.getArithmeticInstrCost(Mul->getOpcode(), VectorTy, CostKind);
+          TTI.getArithmeticInstrCost(Instruction::Mul, VectorTy, CostKind);
 
       InstructionCost RedCost = TTI.getExtendedAddReductionCost(
           /*IsMLA=*/true, true, RdxDesc.getRecurrenceType(), VectorTy,


        


More information about the llvm-commits mailing list