[llvm] b3edc76 - [VPlan] Implement VPWidenCastRecipe::computeCost(). (NFCI) (#111339)

via llvm-commits llvm-commits at lists.llvm.org
Mon Oct 21 21:23:53 PDT 2024


Author: Elvis Wang
Date: 2024-10-22T12:23:49+08:00
New Revision: b3edc764f70f4e56807af60abdcfbef4dbdc5d95

URL: https://github.com/llvm/llvm-project/commit/b3edc764f70f4e56807af60abdcfbef4dbdc5d95
DIFF: https://github.com/llvm/llvm-project/commit/b3edc764f70f4e56807af60abdcfbef4dbdc5d95.diff

LOG: [VPlan] Implement VPWidenCastRecipe::computeCost(). (NFCI) (#111339)

This patch implements `VPWidenCastRecipe::computeCost()` and skips cast
recipes in the in-loop reduction.

Added: 
    

Modified: 
    llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
    llvm/lib/Transforms/Vectorize/VPlan.h
    llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
    llvm/lib/Transforms/Vectorize/VPlanValue.h

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index e8653498d32a12..a88989aa2829b1 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7307,12 +7307,30 @@ LoopVectorizationPlanner::precomputeCosts(VPlan &Plan, ElementCount VF,
     const auto &ChainOps = RdxDesc.getReductionOpChain(RedPhi, OrigLoop);
     SetVector<Instruction *> ChainOpsAndOperands(ChainOps.begin(),
                                                  ChainOps.end());
+    auto IsZExtOrSExt = [](const unsigned Opcode) -> bool {
+      return Opcode == Instruction::ZExt || Opcode == Instruction::SExt;
+    };
     // Also include the operands of instructions in the chain, as the cost-model
     // may mark extends as free.
+    //
+    // For ARM, some of the instructions can be folded into the reduction
+    // instruction. So we need to mark all folded instructions free.
+    // For example: We can fold reduce(mul(ext(A), ext(B))) into one
+    // instruction.
     for (auto *ChainOp : ChainOps) {
       for (Value *Op : ChainOp->operands()) {
-        if (auto *I = dyn_cast<Instruction>(Op))
+        if (auto *I = dyn_cast<Instruction>(Op)) {
           ChainOpsAndOperands.insert(I);
+          if (I->getOpcode() == Instruction::Mul) {
+            auto *Ext0 = dyn_cast<Instruction>(I->getOperand(0));
+            auto *Ext1 = dyn_cast<Instruction>(I->getOperand(1));
+            if (Ext0 && IsZExtOrSExt(Ext0->getOpcode()) && Ext1 &&
+                Ext0->getOpcode() == Ext1->getOpcode()) {
+              ChainOpsAndOperands.insert(Ext0);
+              ChainOpsAndOperands.insert(Ext1);
+            }
+          }
+        }
       }
     }
 

diff  --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 8dff800a0b2224..0ea73229f1eceb 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -1603,6 +1603,10 @@ class VPWidenCastRecipe : public VPRecipeWithIRFlags {
   /// Produce widened copies of the cast.
   void execute(VPTransformState &State) override;
 
+  /// Return the cost of this VPWidenCastRecipe.
+  InstructionCost computeCost(ElementCount VF,
+                              VPCostContext &Ctx) const override;
+
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
   /// Print the recipe.
   void print(raw_ostream &O, const Twine &Indent,

diff  --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 3038f982fc814b..0eb4f7c7c88cee 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -1522,6 +1522,55 @@ void VPWidenCastRecipe::execute(VPTransformState &State) {
   State.addMetadata(Cast, cast_or_null<Instruction>(getUnderlyingValue()));
 }
 
+InstructionCost VPWidenCastRecipe::computeCost(ElementCount VF,
+                                               VPCostContext &Ctx) const {
+  // Computes the CastContextHint from a recipe that may access memory.
+  auto ComputeCCH = [&](const VPRecipeBase *R) -> TTI::CastContextHint {
+    if (VF.isScalar())
+      return TTI::CastContextHint::Normal;
+    if (isa<VPInterleaveRecipe>(R))
+      return TTI::CastContextHint::Interleave;
+    if (const auto *ReplicateRecipe = dyn_cast<VPReplicateRecipe>(R))
+      return ReplicateRecipe->isPredicated() ? TTI::CastContextHint::Masked
+                                             : TTI::CastContextHint::Normal;
+    const auto *WidenMemoryRecipe = dyn_cast<VPWidenMemoryRecipe>(R);
+    if (WidenMemoryRecipe == nullptr)
+      return TTI::CastContextHint::None;
+    if (!WidenMemoryRecipe->isConsecutive())
+      return TTI::CastContextHint::GatherScatter;
+    if (WidenMemoryRecipe->isReverse())
+      return TTI::CastContextHint::Reversed;
+    if (WidenMemoryRecipe->isMasked())
+      return TTI::CastContextHint::Masked;
+    return TTI::CastContextHint::Normal;
+  };
+
+  VPValue *Operand = getOperand(0);
+  TTI::CastContextHint CCH = TTI::CastContextHint::None;
+  // For Trunc/FPTrunc, get the context from the only user.
+  if ((Opcode == Instruction::Trunc || Opcode == Instruction::FPTrunc) &&
+      !hasMoreThanOneUniqueUser() && getNumUsers() > 0) {
+    if (auto *StoreRecipe = dyn_cast<VPRecipeBase>(*user_begin()))
+      CCH = ComputeCCH(StoreRecipe);
+  }
+  // For ZExt/SExt/FPExt, get the context from the operand.
+  else if (Opcode == Instruction::ZExt || Opcode == Instruction::SExt ||
+           Opcode == Instruction::FPExt) {
+    if (Operand->isLiveIn())
+      CCH = TTI::CastContextHint::Normal;
+    else if (Operand->getDefiningRecipe())
+      CCH = ComputeCCH(Operand->getDefiningRecipe());
+  }
+
+  auto *SrcTy =
+      cast<VectorType>(ToVectorTy(Ctx.Types.inferScalarType(Operand), VF));
+  auto *DestTy = cast<VectorType>(ToVectorTy(getResultType(), VF));
+  // Arm TTI will use the underlying instruction to determine the cost.
+  return Ctx.TTI.getCastInstrCost(
+      Opcode, DestTy, SrcTy, CCH, TTI::TCK_RecipThroughput,
+      dyn_cast_if_present<Instruction>(getUnderlyingValue()));
+}
+
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
 void VPWidenCastRecipe::print(raw_ostream &O, const Twine &Indent,
                               VPSlotTracker &SlotTracker) const {

diff  --git a/llvm/lib/Transforms/Vectorize/VPlanValue.h b/llvm/lib/Transforms/Vectorize/VPlanValue.h
index f2978b0a758b6a..1900182f76e071 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanValue.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanValue.h
@@ -135,7 +135,7 @@ class VPValue {
   }
 
   /// Returns true if the value has more than one unique user.
-  bool hasMoreThanOneUniqueUser() {
+  bool hasMoreThanOneUniqueUser() const {
     if (getNumUsers() == 0)
       return false;
 


        


More information about the llvm-commits mailing list