[llvm] [VPlan] Implement VPReductionRecipe::computeCost(). NFC (PR #107790)

Elvis Wang via llvm-commits llvm-commits at lists.llvm.org
Fri Oct 4 03:21:02 PDT 2024


https://github.com/ElvisWang123 updated https://github.com/llvm/llvm-project/pull/107790

>From 4aba2d3cff75caade80e01b4635844ac149335a2 Mon Sep 17 00:00:00 2001
From: Elvis Wang <elvis.wang at sifive.com>
Date: Wed, 4 Sep 2024 20:52:14 -0700
Subject: [PATCH 1/3] [VPlan] Implement VPReductionRecipe::computeCost(). NFC

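Implement `computeCost()` for `VPReductionRecipe`. The cost is the scalar binary-op cost of the reduction opcode plus the target's vector reduction cost (min/max reduction cost for min/max kinds, arithmetic reduction cost otherwise).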
---
 llvm/lib/Transforms/Vectorize/VPlan.h         |  4 ++++
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp | 24 +++++++++++++++++++
 2 files changed, 28 insertions(+)

diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index c4567362eaffc7..56a2d074991684 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -2393,6 +2393,10 @@ class VPReductionRecipe : public VPSingleDefRecipe {
   /// Generate the reduction in the loop
   void execute(VPTransformState &State) override;
 
+  /// Return the cost of VPReductionRecipe.
+  InstructionCost computeCost(ElementCount VF,
+                              VPCostContext &Ctx) const override;
+
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
   /// Print the recipe.
   void print(raw_ostream &O, const Twine &Indent,
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 75908638532950..8137e4f4579fc8 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -2022,6 +2022,30 @@ void VPReductionEVLRecipe::execute(VPTransformState &State) {
   State.set(this, NewRed, /*IsScalar*/ true);
 }
 
+InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
+                                               VPCostContext &Ctx) const {
+  RecurKind RdxKind = RdxDesc.getRecurrenceKind();
+  Type *ElementTy = RdxDesc.getRecurrenceType();
+  auto *VectorTy = dyn_cast<VectorType>(ToVectorTy(ElementTy, VF));
+  TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+  unsigned Opcode = RdxDesc.getOpcode();
+
+  if (VectorTy == nullptr)
+    return InstructionCost::getInvalid();
+
+  // Cost = Reduction cost + BinOp cost
+  InstructionCost Cost =
+      Ctx.TTI.getArithmeticInstrCost(Opcode, ElementTy, CostKind);
+  if (RecurrenceDescriptor::isMinMaxRecurrenceKind(RdxKind)) {
+    Intrinsic::ID Id = getMinMaxReductionIntrinsicOp(RdxKind);
+    return Cost + Ctx.TTI.getMinMaxReductionCost(
+                      Id, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+  }
+
+  return Cost + Ctx.TTI.getArithmeticReductionCost(
+                    Opcode, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+}
+
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
 void VPReductionRecipe::print(raw_ostream &O, const Twine &Indent,
                               VPSlotTracker &SlotTracker) const {

>From f2c2b8ff1a89885f48788a98679a660825fcbf4e Mon Sep 17 00:00:00 2001
From: Elvis Wang <elvis.wang at sifive.com>
Date: Tue, 24 Sep 2024 23:31:24 -0700
Subject: [PATCH 2/3] Address comments: use cast<VectorType> and drop the invalid-cost early return.

---
 llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 8137e4f4579fc8..01e4a4d9494abd 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -2026,13 +2026,10 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
                                                VPCostContext &Ctx) const {
   RecurKind RdxKind = RdxDesc.getRecurrenceKind();
   Type *ElementTy = RdxDesc.getRecurrenceType();
-  auto *VectorTy = dyn_cast<VectorType>(ToVectorTy(ElementTy, VF));
+  auto *VectorTy = cast<VectorType>(ToVectorTy(ElementTy, VF));
   TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
   unsigned Opcode = RdxDesc.getOpcode();
 
-  if (VectorTy == nullptr)
-    return InstructionCost::getInvalid();
-
   // Cost = Reduction cost + BinOp cost
   InstructionCost Cost =
       Ctx.TTI.getArithmeticInstrCost(Opcode, ElementTy, CostKind);

>From dc5b4006a2a097b7248bba3d13eb308a30544a18 Mon Sep 17 00:00:00 2001
From: Elvis Wang <elvis.wang at sifive.com>
Date: Fri, 4 Oct 2024 03:19:36 -0700
Subject: [PATCH 3/3] Address comments: infer the scalar type via VPlan type analysis and assert on unsupported any-of reductions.

---
 llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 01e4a4d9494abd..867ec118d466ac 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -2025,7 +2025,14 @@ void VPReductionEVLRecipe::execute(VPTransformState &State) {
 InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
                                                VPCostContext &Ctx) const {
   RecurKind RdxKind = RdxDesc.getRecurrenceKind();
-  Type *ElementTy = RdxDesc.getRecurrenceType();
+  // TODO: Support any-of reduction and in-loop reduction
+  assert(!RecurrenceDescriptor::isAnyOfRecurrenceKind(RdxKind) &&
+         "Not support any-of reduction in VPlan-based cost model currently.");
+
+  Type *ElementTy = Ctx.Types.inferScalarType(this->getVPSingleValue());
+  assert(ElementTy->getTypeID() == RdxDesc.getRecurrenceType()->getTypeID() &&
+         "Infered type and recurrence type mismatch.");
+
   auto *VectorTy = cast<VectorType>(ToVectorTy(ElementTy, VF));
   TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
   unsigned Opcode = RdxDesc.getOpcode();



More information about the llvm-commits mailing list