[llvm-branch-commits] [llvm] d1c4e85 - [SLP] reduce opcode API dependency in reduction cost calc; NFC

Sanjay Patel via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Mon Jan 18 06:43:50 PST 2021


Author: Sanjay Patel
Date: 2021-01-18T09:32:57-05:00
New Revision: d1c4e859ce42c35c61a0db2f1eb8a4209be4503d

URL: https://github.com/llvm/llvm-project/commit/d1c4e859ce42c35c61a0db2f1eb8a4209be4503d
DIFF: https://github.com/llvm/llvm-project/commit/d1c4e859ce42c35c61a0db2f1eb8a4209be4503d.diff

LOG: [SLP] reduce opcode API dependency in reduction cost calc; NFC

The icmp opcode is now hard-coded in the cost model call.
This will make it easier to eventually remove all opcode
queries for min/max patterns as we transition to intrinsics.

Added: 
    

Modified: 
    llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Removed: 
    


################################################################################
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 8dd318a880fc..bf8ef208ccf9 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -7058,12 +7058,10 @@ class HorizontalReduction {
   int getReductionCost(TargetTransformInfo *TTI, Value *FirstReducedVal,
                        unsigned ReduxWidth) {
     Type *ScalarTy = FirstReducedVal->getType();
-    auto *VecTy = FixedVectorType::get(ScalarTy, ReduxWidth);
+    FixedVectorType *VectorTy = FixedVectorType::get(ScalarTy, ReduxWidth);
 
     RecurKind Kind = RdxTreeInst.getKind();
-    unsigned RdxOpcode = RecurrenceDescriptor::getOpcode(Kind);
-    int SplittingRdxCost;
-    int ScalarReduxCost;
+    int VectorCost, ScalarCost;
     switch (Kind) {
     case RecurKind::Add:
     case RecurKind::Mul:
@@ -7071,22 +7069,24 @@ class HorizontalReduction {
     case RecurKind::And:
     case RecurKind::Xor:
     case RecurKind::FAdd:
-    case RecurKind::FMul:
-      SplittingRdxCost = TTI->getArithmeticReductionCost(
-          RdxOpcode, VecTy, /*IsPairwiseForm=*/false);
-      ScalarReduxCost = TTI->getArithmeticInstrCost(RdxOpcode, ScalarTy);
+    case RecurKind::FMul: {
+      unsigned RdxOpcode = RecurrenceDescriptor::getOpcode(Kind);
+      VectorCost = TTI->getArithmeticReductionCost(RdxOpcode, VectorTy,
+                                                      /*IsPairwiseForm=*/false);
+      ScalarCost = TTI->getArithmeticInstrCost(RdxOpcode, ScalarTy);
       break;
+    }
     case RecurKind::SMax:
     case RecurKind::SMin:
     case RecurKind::UMax:
     case RecurKind::UMin: {
-      auto *VecCondTy = cast<VectorType>(CmpInst::makeCmpResultType(VecTy));
+      auto *VecCondTy = cast<VectorType>(CmpInst::makeCmpResultType(VectorTy));
       bool IsUnsigned = Kind == RecurKind::UMax || Kind == RecurKind::UMin;
-      SplittingRdxCost =
-          TTI->getMinMaxReductionCost(VecTy, VecCondTy,
+      VectorCost =
+          TTI->getMinMaxReductionCost(VectorTy, VecCondTy,
                                       /*IsPairwiseForm=*/false, IsUnsigned);
-      ScalarReduxCost =
-          TTI->getCmpSelInstrCost(RdxOpcode, ScalarTy) +
+      ScalarCost =
+          TTI->getCmpSelInstrCost(Instruction::ICmp, ScalarTy) +
           TTI->getCmpSelInstrCost(Instruction::Select, ScalarTy,
                                   CmpInst::makeCmpResultType(ScalarTy));
       break;
@@ -7095,12 +7095,12 @@ class HorizontalReduction {
       llvm_unreachable("Expected arithmetic or min/max reduction operation");
     }
 
-    ScalarReduxCost *= (ReduxWidth - 1);
-    LLVM_DEBUG(dbgs() << "SLP: Adding cost "
-                      << SplittingRdxCost - ScalarReduxCost
+    // Scalar cost is repeated for N-1 elements.
+    ScalarCost *= (ReduxWidth - 1);
+    LLVM_DEBUG(dbgs() << "SLP: Adding cost " << VectorCost - ScalarCost
                       << " for reduction that starts with " << *FirstReducedVal
                       << " (It is a splitting reduction)\n");
-    return SplittingRdxCost - ScalarReduxCost;
+    return VectorCost - ScalarCost;
   }
 
   /// Emit a horizontal reduction of the vectorized value.


        


More information about the llvm-branch-commits mailing list