[llvm] [VPlan] Implement VPReductionRecipe::computeCost(). NFC (PR #107790)

Elvis Wang via llvm-commits llvm-commits at lists.llvm.org
Thu Oct 3 18:05:21 PDT 2024


https://github.com/ElvisWang123 updated https://github.com/llvm/llvm-project/pull/107790

>From 4aba2d3cff75caade80e01b4635844ac149335a2 Mon Sep 17 00:00:00 2001
From: Elvis Wang <elvis.wang at sifive.com>
Date: Wed, 4 Sep 2024 20:52:14 -0700
Subject: [PATCH 1/4] [VPlan] Implement VPReductionRecipe::computeCost(). NFC

Implementation of `computeCost()` function for `VPReductionRecipe`.
---
 llvm/lib/Transforms/Vectorize/VPlan.h         |  4 ++++
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp | 24 +++++++++++++++++++
 2 files changed, 28 insertions(+)

diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index c4567362eaffc7..56a2d074991684 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -2393,6 +2393,10 @@ class VPReductionRecipe : public VPSingleDefRecipe {
   /// Generate the reduction in the loop
   void execute(VPTransformState &State) override;
 
+  /// Return the cost of this VPReductionRecipe for the given \p VF.
+  InstructionCost computeCost(ElementCount VF,
+                              VPCostContext &Ctx) const override;
+
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
   /// Print the recipe.
   void print(raw_ostream &O, const Twine &Indent,
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 75908638532950..8137e4f4579fc8 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -2022,6 +2022,30 @@ void VPReductionEVLRecipe::execute(VPTransformState &State) {
   State.set(this, NewRed, /*IsScalar*/ true);
 }
 
+InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
+                                               VPCostContext &Ctx) const {
+  RecurKind RdxKind = RdxDesc.getRecurrenceKind();
+  Type *ElementTy = RdxDesc.getRecurrenceType();
+  auto *VectorTy = dyn_cast<VectorType>(ToVectorTy(ElementTy, VF));
+  TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+  unsigned Opcode = RdxDesc.getOpcode();
+
+  if (VectorTy == nullptr)
+    return InstructionCost::getInvalid();
+
+  // Cost = Reduction cost + BinOp cost
+  InstructionCost Cost =
+      Ctx.TTI.getArithmeticInstrCost(Opcode, ElementTy, CostKind);
+  if (RecurrenceDescriptor::isMinMaxRecurrenceKind(RdxKind)) {
+    Intrinsic::ID Id = getMinMaxReductionIntrinsicOp(RdxKind);
+    return Cost + Ctx.TTI.getMinMaxReductionCost(
+                      Id, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+  }
+
+  return Cost + Ctx.TTI.getArithmeticReductionCost(
+                    Opcode, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+}
+
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
 void VPReductionRecipe::print(raw_ostream &O, const Twine &Indent,
                               VPSlotTracker &SlotTracker) const {

>From f2c2b8ff1a89885f48788a98679a660825fcbf4e Mon Sep 17 00:00:00 2001
From: Elvis Wang <elvis.wang at sifive.com>
Date: Tue, 24 Sep 2024 23:31:24 -0700
Subject: [PATCH 2/4] Address comments: use cast<VectorType> in computeCost().

---
 llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 8137e4f4579fc8..01e4a4d9494abd 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -2026,13 +2026,10 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
                                                VPCostContext &Ctx) const {
   RecurKind RdxKind = RdxDesc.getRecurrenceKind();
   Type *ElementTy = RdxDesc.getRecurrenceType();
-  auto *VectorTy = dyn_cast<VectorType>(ToVectorTy(ElementTy, VF));
+  auto *VectorTy = cast<VectorType>(ToVectorTy(ElementTy, VF));
   TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
   unsigned Opcode = RdxDesc.getOpcode();
 
-  if (VectorTy == nullptr)
-    return InstructionCost::getInvalid();
-
   // Cost = Reduction cost + BinOp cost
   InstructionCost Cost =
       Ctx.TTI.getArithmeticInstrCost(Opcode, ElementTy, CostKind);

>From 6a388173e8ae1480b08b5c8fa58ef4fa6597fa2d Mon Sep 17 00:00:00 2001
From: Elvis Wang <elvis.wang at sifive.com>
Date: Thu, 3 Oct 2024 00:56:56 -0700
Subject: [PATCH 3/4] Recognize the reduction patterns from D93476 in
 computeCost().

---
 .../Transforms/Vectorize/LoopVectorize.cpp    |  15 ++
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp | 151 +++++++++++++++++-
 2 files changed, 159 insertions(+), 7 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 792e0e17dd8719..39d95856c226ef 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7220,6 +7220,9 @@ LoopVectorizationPlanner::precomputeCosts(VPlan &Plan, ElementCount VF,
                                                  ChainOps.end());
     // Also include the operands of instructions in the chain, as the cost-model
     // may mark extends as free.
+    // We only handle the reduction cost in the VPlan-based cost model
+    // currently.
+    // TODO: Handle this calculation in VPWidenRecipe and VPWidenCastRecipe.
     for (auto *ChainOp : ChainOps) {
       for (Value *Op : ChainOp->operands()) {
         if (auto *I = dyn_cast<Instruction>(Op))
@@ -7227,6 +7230,18 @@ LoopVectorizationPlanner::precomputeCosts(VPlan &Plan, ElementCount VF,
       }
     }
 
+    // VPReductionRecipe::computeCost now models the reduction cost, so drop
+    // the first chain instruction with the reduction opcode; otherwise its
+    // cost is precomputed here and VPReductionRecipe::computeCost is skipped.
+    // TODO: Remove the following check once reduction pattern costs are fully
+    // supported in the VPlan-based cost model.
+    for (auto *I : ChainOpsAndOperands) {
+      if (I->getOpcode() == RdxDesc.getOpcode()) {
+        ChainOpsAndOperands.remove(I);
+        break;
+      }
+    }
+
     // Pre-compute the cost for I, if it has a reduction pattern cost.
     for (Instruction *I : ChainOpsAndOperands) {
       auto ReductionCost = CM.getReductionPatternCost(
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 01e4a4d9494abd..1edf524c7f35ed 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -2022,6 +2022,11 @@ void VPReductionEVLRecipe::execute(VPTransformState &State) {
   State.set(this, NewRed, /*IsScalar*/ true);
 }
 
+static bool isZExtOrSExt(Instruction::CastOps CastOpcode) {
+  return CastOpcode == Instruction::CastOps::ZExt ||
+         CastOpcode == Instruction::CastOps::SExt;
+}
+
 InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
                                                VPCostContext &Ctx) const {
   RecurKind RdxKind = RdxDesc.getRecurrenceKind();
@@ -2030,17 +2035,149 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
   TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
   unsigned Opcode = RdxDesc.getOpcode();
 
-  // Cost = Reduction cost + BinOp cost
-  InstructionCost Cost =
-      Ctx.TTI.getArithmeticInstrCost(Opcode, ElementTy, CostKind);
+  InstructionCost BaseCost;
   if (RecurrenceDescriptor::isMinMaxRecurrenceKind(RdxKind)) {
     Intrinsic::ID Id = getMinMaxReductionIntrinsicOp(RdxKind);
-    return Cost + Ctx.TTI.getMinMaxReductionCost(
-                      Id, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+    BaseCost = Ctx.TTI.getMinMaxReductionCost(
+        Id, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+  } else {
+    BaseCost = Ctx.TTI.getArithmeticReductionCost(
+        Opcode, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+  }
+
+  // For a call to the llvm.fmuladd intrinsic we need to add the cost of a
+  // normal fmul instruction to the cost of the fadd reduction.
+  if (RdxKind == RecurKind::FMulAdd)
+    BaseCost +=
+        Ctx.TTI.getArithmeticInstrCost(Instruction::FMul, VectorTy, CostKind);
+
+  // If we're using ordered reductions then we can just return the base cost
+  // here, since getArithmeticReductionCost calculates the full ordered
+  // reduction cost when FP reassociation is not allowed.
+  if (IsOrdered && Opcode == Instruction::FAdd)
+    return BaseCost;
+
+  // Special case for Arm, from D93476: the reduction instruction can be
+  // substituted in the following case.
+  //
+  //       %sa = sext <16 x i8> A to <16 x i32>
+  //       %sb = sext <16 x i8> B to <16 x i32>
+  //       %m = mul <16 x i32> %sa, %sb
+  //       %r = vecreduce.add(%m)
+  //       ->
+  //       R = VMLADAV A, B
+  //
+  // There are other instructions for performing add reductions of
+  // v4i32/v8i16/v16i8 into i32 (VADDV), for doing the same with v4i32->i64
+  // (VADDLV) and for performing a v4i32/v8i16 MLA into an i64 (VMLALDAV).
+  //
+  // We are looking for a pattern of, and finding the minimal acceptable cost:
+  //       reduce.add(ext(mul(ext(A), ext(B)))) or
+  //       reduce(ext(A)) or
+  //       reduce.add(mul(ext(A), ext(B))) or
+  //       reduce.add(mul(A, B)) or
+  //       reduce(A).
+
+  // Try to match reduce(ext(...))
+  auto *Ext = dyn_cast<VPWidenCastRecipe>(getVecOp());
+  if (Ext && isZExtOrSExt(Ext->getOpcode())) {
+    bool isUnsigned = Ext->getOpcode() == Instruction::CastOps::ZExt;
+
+    // Try to match reduce.add(ext(mul(...)))
+    auto *ExtTy = cast<VectorType>(
+        ToVectorTy(Ext->getOperand(0)->getUnderlyingValue()->getType(), VF));
+    auto *Mul = dyn_cast_if_present<VPWidenRecipe>(
+        Ext->getOperand(0)->getDefiningRecipe());
+    if (Mul && Mul->getOpcode() == Instruction::Mul &&
+        Opcode == Instruction::Add) {
+      auto *MulTy = cast<VectorType>(
+          ToVectorTy(Mul->getUnderlyingValue()->getType(), VF));
+      auto *InnerExt0 = dyn_cast<VPWidenCastRecipe>(Mul->getOperand(0));
+      auto *InnerExt1 = dyn_cast<VPWidenCastRecipe>(Mul->getOperand(1));
+
+      // Match reduce.add(ext(mul(ext(A), ext(B))))
+      if (InnerExt0 && isZExtOrSExt(InnerExt0->getOpcode()) && InnerExt1 &&
+          isZExtOrSExt(InnerExt1->getOpcode()) &&
+          InnerExt0->getOpcode() == InnerExt1->getOpcode()) {
+        Type *InnerExt0Ty =
+            InnerExt0->getOperand(0)->getUnderlyingValue()->getType();
+        Type *InnerExt1Ty =
+            InnerExt1->getOperand(0)->getUnderlyingValue()->getType();
+        // Get the largest type.
+        auto *MaxExtVecTy = cast<VectorType>(
+            ToVectorTy(InnerExt0Ty->getIntegerBitWidth() >
+                               InnerExt1Ty->getIntegerBitWidth()
+                           ? InnerExt0Ty
+                           : InnerExt1Ty,
+                       VF));
+        InstructionCost RedCost = Ctx.TTI.getMulAccReductionCost(
+            isUnsigned, ElementTy, MaxExtVecTy, CostKind);
+        InstructionCost InnerExtCost =
+            Ctx.TTI.getCastInstrCost(InnerExt0->getOpcode(), MulTy, MaxExtVecTy,
+                                     TTI::CastContextHint::None, CostKind);
+        InstructionCost MulCost =
+            Ctx.TTI.getArithmeticInstrCost(Instruction::Mul, MulTy, CostKind);
+        InstructionCost ExtCost =
+            Ctx.TTI.getCastInstrCost(Ext->getOpcode(), VectorTy, ExtTy,
+                                     TTI::CastContextHint::None, CostKind);
+        if (RedCost.isValid() &&
+            RedCost < InnerExtCost * 2 + MulCost + ExtCost + BaseCost)
+          return RedCost;
+      }
+    }
+
+    // Match reduce(ext(A))
+    InstructionCost RedCost =
+        Ctx.TTI.getExtendedReductionCost(Opcode, isUnsigned, ElementTy, ExtTy,
+                                         RdxDesc.getFastMathFlags(), CostKind);
+    InstructionCost ExtCost =
+        Ctx.TTI.getCastInstrCost(Ext->getOpcode(), VectorTy, ExtTy,
+                                 TTI::CastContextHint::None, CostKind);
+    if (RedCost.isValid() && RedCost < BaseCost + ExtCost)
+      return RedCost;
+  }
+
+  // Try to match reduce.add(mul(...))
+  auto *Mul =
+      dyn_cast_if_present<VPWidenRecipe>(getVecOp()->getDefiningRecipe());
+  if (Mul && Mul->getOpcode() == Instruction::Mul &&
+      Opcode == Instruction::Add) {
+    // Match reduce.add(mul(ext(A), ext(B)))
+    auto *InnerExt0 = dyn_cast<VPWidenCastRecipe>(Mul->getOperand(0));
+    auto *InnerExt1 = dyn_cast<VPWidenCastRecipe>(Mul->getOperand(1));
+    auto *MulTy =
+        cast<VectorType>(ToVectorTy(Mul->getUnderlyingValue()->getType(), VF));
+    InstructionCost MulCost =
+        Ctx.TTI.getArithmeticInstrCost(Instruction::Mul, MulTy, CostKind);
+    if (InnerExt0 && isZExtOrSExt(InnerExt0->getOpcode()) && InnerExt1 &&
+        InnerExt0->getOpcode() == InnerExt1->getOpcode()) {
+      Type *InnerExt0Ty =
+          InnerExt0->getOperand(0)->getUnderlyingValue()->getType();
+      Type *InnerExt1Ty =
+          InnerExt1->getOperand(0)->getUnderlyingValue()->getType();
+      auto *MaxInnerExtVecTy = cast<VectorType>(ToVectorTy(
+          InnerExt0Ty->getIntegerBitWidth() > InnerExt1Ty->getIntegerBitWidth()
+              ? InnerExt0Ty
+              : InnerExt1Ty,
+          VF));
+      bool isUnsigned = InnerExt0->getOpcode() == Instruction::CastOps::ZExt;
+      InstructionCost RedCost = Ctx.TTI.getMulAccReductionCost(
+          isUnsigned, ElementTy, MaxInnerExtVecTy, CostKind);
+      InstructionCost InnerExtCost = Ctx.TTI.getCastInstrCost(
+          InnerExt0->getOpcode(), MulTy, MaxInnerExtVecTy,
+          TTI::CastContextHint::None, CostKind);
+      if (RedCost.isValid() && RedCost < BaseCost + MulCost + 2 * InnerExtCost)
+        return RedCost;
+    }
+    // Match reduce.add(mul)
+    InstructionCost RedCost =
+        Ctx.TTI.getMulAccReductionCost(true, ElementTy, VectorTy, CostKind);
+    if (RedCost.isValid() && RedCost < BaseCost + MulCost)
+      return RedCost;
   }
 
-  return Cost + Ctx.TTI.getArithmeticReductionCost(
-                    Opcode, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+  // Normal cost = Reduction cost + BinOp cost
+  return BaseCost + Ctx.TTI.getArithmeticInstrCost(Opcode, ElementTy, CostKind);
 }
 
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)

>From faa86e500720fa73e3fa41800a6a042313069b33 Mon Sep 17 00:00:00 2001
From: Elvis Wang <elvis.wang at sifive.com>
Date: Thu, 3 Oct 2024 17:49:09 -0700
Subject: [PATCH 4/4] Address comments and use inferScalarType

---
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp | 30 +++++++++++--------
 1 file changed, 17 insertions(+), 13 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 1edf524c7f35ed..d07826be40d45e 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -2030,11 +2030,19 @@ static bool isZExtOrSExt(Instruction::CastOps CastOpcode) {
 InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
                                                VPCostContext &Ctx) const {
   RecurKind RdxKind = RdxDesc.getRecurrenceKind();
-  Type *ElementTy = RdxDesc.getRecurrenceType();
+  Type *ElementTy = Ctx.Types.inferScalarType(this->getVPSingleValue());
+  assert(ElementTy->getTypeID() == RdxDesc.getRecurrenceType()->getTypeID());
+
   auto *VectorTy = cast<VectorType>(ToVectorTy(ElementTy, VF));
   TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
   unsigned Opcode = RdxDesc.getOpcode();
 
+  // TODO: Remove this assertion once the VPlan-based cost model supports
+  // any-of reductions.
+  assert(!RecurrenceDescriptor::isAnyOfRecurrenceKind(
+             RdxDesc.getRecurrenceKind()) &&
+         "VPlan-base cost model not support any-of reduction.");
+
   InstructionCost BaseCost;
   if (RecurrenceDescriptor::isMinMaxRecurrenceKind(RdxKind)) {
     Intrinsic::ID Id = getMinMaxReductionIntrinsicOp(RdxKind);
@@ -2085,13 +2093,13 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
 
     // Try to match reduce.add(ext(mul(...)))
     auto *ExtTy = cast<VectorType>(
-        ToVectorTy(Ext->getOperand(0)->getUnderlyingValue()->getType(), VF));
+        ToVectorTy(Ctx.Types.inferScalarType(Ext->getOperand(0)), VF));
     auto *Mul = dyn_cast_if_present<VPWidenRecipe>(
         Ext->getOperand(0)->getDefiningRecipe());
     if (Mul && Mul->getOpcode() == Instruction::Mul &&
         Opcode == Instruction::Add) {
       auto *MulTy = cast<VectorType>(
-          ToVectorTy(Mul->getUnderlyingValue()->getType(), VF));
+          ToVectorTy(Ctx.Types.inferScalarType(Mul->getVPSingleValue()), VF));
       auto *InnerExt0 = dyn_cast<VPWidenCastRecipe>(Mul->getOperand(0));
       auto *InnerExt1 = dyn_cast<VPWidenCastRecipe>(Mul->getOperand(1));
 
@@ -2099,10 +2107,8 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
       if (InnerExt0 && isZExtOrSExt(InnerExt0->getOpcode()) && InnerExt1 &&
           isZExtOrSExt(InnerExt1->getOpcode()) &&
           InnerExt0->getOpcode() == InnerExt1->getOpcode()) {
-        Type *InnerExt0Ty =
-            InnerExt0->getOperand(0)->getUnderlyingValue()->getType();
-        Type *InnerExt1Ty =
-            InnerExt1->getOperand(0)->getUnderlyingValue()->getType();
+        Type *InnerExt0Ty = Ctx.Types.inferScalarType(InnerExt0->getOperand(0));
+        Type *InnerExt1Ty = Ctx.Types.inferScalarType(InnerExt1->getOperand(0));
         // Get the largest type.
         auto *MaxExtVecTy = cast<VectorType>(
             ToVectorTy(InnerExt0Ty->getIntegerBitWidth() >
@@ -2145,16 +2151,14 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
     // Match reduce.add(mul(ext(A), ext(B)))
     auto *InnerExt0 = dyn_cast<VPWidenCastRecipe>(Mul->getOperand(0));
     auto *InnerExt1 = dyn_cast<VPWidenCastRecipe>(Mul->getOperand(1));
-    auto *MulTy =
-        cast<VectorType>(ToVectorTy(Mul->getUnderlyingValue()->getType(), VF));
+    auto *MulTy = cast<VectorType>(
+        ToVectorTy(Ctx.Types.inferScalarType(Mul->getVPSingleValue()), VF));
     InstructionCost MulCost =
         Ctx.TTI.getArithmeticInstrCost(Instruction::Mul, MulTy, CostKind);
     if (InnerExt0 && isZExtOrSExt(InnerExt0->getOpcode()) && InnerExt1 &&
         InnerExt0->getOpcode() == InnerExt1->getOpcode()) {
-      Type *InnerExt0Ty =
-          InnerExt0->getOperand(0)->getUnderlyingValue()->getType();
-      Type *InnerExt1Ty =
-          InnerExt1->getOperand(0)->getUnderlyingValue()->getType();
+      Type *InnerExt0Ty = Ctx.Types.inferScalarType(InnerExt0->getOperand(0));
+      Type *InnerExt1Ty = Ctx.Types.inferScalarType(InnerExt1->getOperand(0));
       auto *MaxInnerExtVecTy = cast<VectorType>(ToVectorTy(
           InnerExt0Ty->getIntegerBitWidth() > InnerExt1Ty->getIntegerBitWidth()
               ? InnerExt0Ty



More information about the llvm-commits mailing list