[llvm] [VPlan] Implement VPReductionRecipe::computeCost(). NFC (PR #107790)
    via llvm-commits 
    llvm-commits at lists.llvm.org
       
    Sun Sep  8 17:25:03 PDT 2024
    
    
  
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-transforms
Author: Elvis Wang (ElvisWang123)
<details>
<summary>Changes</summary>
Implement the `computeCost()` function for `VPReductionRecipe`.
---
Full diff: https://github.com/llvm/llvm-project/pull/107790.diff
2 Files Affected:
- (modified) llvm/lib/Transforms/Vectorize/VPlan.h (+4) 
- (modified) llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp (+24) 
``````````diff
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index bd71dbffa929e7..68f8ea3dea0db3 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -2242,6 +2242,10 @@ class VPReductionRecipe : public VPSingleDefRecipe {
   /// Generate the reduction in the loop
   void execute(VPTransformState &State) override;
 
+  /// Return the cost of this VPReductionRecipe for the given \p VF.
+  InstructionCost computeCost(ElementCount VF,
+                              VPCostContext &Ctx) const override;
+
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
   /// Print the recipe.
   void print(raw_ostream &O, const Twine &Indent,
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 3d08e3cefbf633..12b6d2c7d06dd1 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -1897,6 +1897,30 @@ void VPReductionEVLRecipe::execute(VPTransformState &State) {
   State.set(this, NewRed, 0, /*IsScalar*/ true);
 }
 
+InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
+                                               VPCostContext &Ctx) const {
+  RecurKind RdxKind = RdxDesc.getRecurrenceKind();
+  Type *ElementTy = RdxDesc.getRecurrenceType();
+  auto *VectorTy = dyn_cast<VectorType>(ToVectorTy(ElementTy, VF));
+  TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+  unsigned Opcode = RdxDesc.getOpcode();
+  // Invalid if VF gives no vector type (presumably a scalar VF; confirm).
+  if (VectorTy == nullptr)
+    return InstructionCost::getInvalid();
+
+  // Cost = scalar BinOp cost (ElementTy) + reduction cost (VectorTy).
+  InstructionCost Cost =
+      Ctx.TTI.getArithmeticInstrCost(Opcode, ElementTy, CostKind);
+  if (RecurrenceDescriptor::isMinMaxRecurrenceKind(RdxKind)) {
+    Intrinsic::ID Id = getMinMaxReductionIntrinsicOp(RdxKind);
+    return Cost + Ctx.TTI.getMinMaxReductionCost(
+                      Id, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+  }
+  // Other kinds: arithmetic reduction using the recurrence Opcode.
+  return Cost + Ctx.TTI.getArithmeticReductionCost(
+                    Opcode, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+}
+
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
 void VPReductionRecipe::print(raw_ostream &O, const Twine &Indent,
                               VPSlotTracker &SlotTracker) const {
``````````
</details>
https://github.com/llvm/llvm-project/pull/107790
    
    
More information about the llvm-commits
mailing list