[llvm] [VPlan] Implement VPWidenCastRecipe::computeCost(). (NFCI) (PR #111339)

via llvm-commits llvm-commits at lists.llvm.org
Sun Oct 6 22:48:15 PDT 2024


llvmbot wrote:


@llvm/pr-subscribers-llvm-transforms

Author: Elvis Wang (ElvisWang123)

Changes:

This patch implements `VPWidenCastRecipe::computeCost()` and skips cast recipes in in-loop reductions.
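
For context, here is a minimal sketch (not part of the patch; a hypothetical example) of the kind of loop this affects: an integer dot product whose reduction chain is `reduce.add(mul(sext(a[i]), sext(b[i])))`. As the comment in the diff below notes, targets such as ARM can fold the extends and the multiply into a single reduction instruction, so the corresponding cast recipes should be costed as free.

```cpp
// Hypothetical example (not from the PR): both extends feed the multiply in the
// reduction chain reduce.add(mul(sext(a[i]), sext(b[i]))), which some targets
// lower to a single multiply-accumulate reduction.
int dot_product(const short *a, const short *b, int n) {
  int sum = 0;
  for (int i = 0; i < n; ++i)
    sum += (int)a[i] * (int)b[i]; // sext(i16 -> i32) on both operands
  return sum;
}
```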

---
Full diff: https://github.com/llvm/llvm-project/pull/111339.diff


4 Files Affected:

- (modified) llvm/lib/Transforms/Vectorize/LoopVectorize.cpp (+18-1) 
- (modified) llvm/lib/Transforms/Vectorize/VPlan.h (+4) 
- (modified) llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp (+43) 
- (modified) llvm/lib/Transforms/Vectorize/VPlanValue.h (+1-1) 


``````````diff
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 792e0e17dd8719..59527d27d3e7a4 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7218,12 +7218,29 @@ LoopVectorizationPlanner::precomputeCosts(VPlan &Plan, ElementCount VF,
     const auto &ChainOps = RdxDesc.getReductionOpChain(RedPhi, OrigLoop);
     SetVector<Instruction *> ChainOpsAndOperands(ChainOps.begin(),
                                                  ChainOps.end());
+    auto isZExtOrSExt = [](const unsigned Opcode) -> bool {
+      return Opcode == Instruction::ZExt || Opcode == Instruction::SExt;
+    };
     // Also include the operands of instructions in the chain, as the cost-model
     // may mark extends as free.
+    //
+    // For ARM, some of the instructions can be folded into the reduction
+    // instruction, so we need to mark all folded instructions as free.
+    // For example: We can fold reduce(mul(ext(A), ext(B))) into one
+    // instruction.
     for (auto *ChainOp : ChainOps) {
       for (Value *Op : ChainOp->operands()) {
-        if (auto *I = dyn_cast<Instruction>(Op))
+        if (auto *I = dyn_cast<Instruction>(Op)) {
           ChainOpsAndOperands.insert(I);
+          if (I->getOpcode() == Instruction::Mul) {
+            auto *Ext0 = dyn_cast<Instruction>(I->getOperand(0));
+            if (Ext0 && isZExtOrSExt(Ext0->getOpcode()))
+              ChainOpsAndOperands.insert(Ext0);
+            auto *Ext1 = dyn_cast<Instruction>(I->getOperand(1));
+            if (Ext1 && isZExtOrSExt(Ext1->getOpcode()))
+              ChainOpsAndOperands.insert(Ext1);
+          }
+        }
       }
     }
 
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index c4567362eaffc7..c8e074785cc877 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -1557,6 +1557,10 @@ class VPWidenCastRecipe : public VPRecipeWithIRFlags {
   /// Produce widened copies of the cast.
   void execute(VPTransformState &State) override;
 
+  /// Return the cost of this VPWidenCastRecipe.
+  InstructionCost computeCost(ElementCount VF,
+                              VPCostContext &Ctx) const override;
+
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
   /// Print the recipe.
   void print(raw_ostream &O, const Twine &Indent,
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 75908638532950..764090b7a0ecc2 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -1429,6 +1429,49 @@ void VPWidenCastRecipe::execute(VPTransformState &State) {
   State.addMetadata(Cast, cast_or_null<Instruction>(getUnderlyingValue()));
 }
 
+InstructionCost VPWidenCastRecipe::computeCost(ElementCount VF,
+                                               VPCostContext &Ctx) const {
+  auto *SrcTy = cast<VectorType>(
+      ToVectorTy(Ctx.Types.inferScalarType(getOperand(0)), VF));
+  auto *DestTy = cast<VectorType>(ToVectorTy(getResultType(), VF));
+  // Computes the CastContextHint from a VPWidenMemoryRecipe.
+  auto ComputeCCH = [&](VPWidenMemoryRecipe *R) -> TTI::CastContextHint {
+    assert((isa<VPWidenLoadRecipe>(R) || isa<VPWidenStoreRecipe>(R)) &&
+           "Expected a load or a store!");
+
+    if (VF.isScalar())
+      return TTI::CastContextHint::Normal;
+    if (!R->isConsecutive())
+      return TTI::CastContextHint::GatherScatter;
+    if (R->isReverse())
+      return TTI::CastContextHint::Reversed;
+    if (R->isMasked())
+      return TTI::CastContextHint::Masked;
+    return TTI::CastContextHint::Normal;
+  };
+
+  TTI::CastContextHint CCH = TTI::CastContextHint::None;
+  // For Trunc, the context is the only user, which must be a
+  // VPWidenStoreRecipe.
+  if (Opcode == Instruction::Trunc || Opcode == Instruction::FPTrunc) {
+    if (!cast<VPValue>(this)->hasMoreThanOneUniqueUser())
+      if (VPWidenMemoryRecipe *Store =
+              dyn_cast<VPWidenMemoryRecipe>(*this->user_begin()))
+        CCH = ComputeCCH(Store);
+  }
+  // For Z/Sext, the context is the operand, which must be a VPWidenLoadRecipe.
+  else if (Opcode == Instruction::ZExt || Opcode == Instruction::SExt ||
+           Opcode == Instruction::FPExt) {
+    if (VPWidenMemoryRecipe *Load = dyn_cast<VPWidenMemoryRecipe>(
+            this->getOperand(0)->getDefiningRecipe()))
+      CCH = ComputeCCH(Load);
+  }
+  // Arm TTI will use the underlying instruction to determine the cost.
+  return Ctx.TTI.getCastInstrCost(
+      Opcode, DestTy, SrcTy, CCH, TTI::TCK_RecipThroughput,
+      dyn_cast_if_present<Instruction>(getUnderlyingValue()));
+}
+
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
 void VPWidenCastRecipe::print(raw_ostream &O, const Twine &Indent,
                               VPSlotTracker &SlotTracker) const {
diff --git a/llvm/lib/Transforms/Vectorize/VPlanValue.h b/llvm/lib/Transforms/Vectorize/VPlanValue.h
index 4c383244f96f1a..ec4d95048de7a4 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanValue.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanValue.h
@@ -135,7 +135,7 @@ class VPValue {
   }
 
   /// Returns true if the value has more than one unique user.
-  bool hasMoreThanOneUniqueUser() {
+  bool hasMoreThanOneUniqueUser() const {
     if (getNumUsers() == 0)
       return false;
 

``````````
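
The hint selection inside `ComputeCCH` is the core of the new cost hook. Below is a small, self-contained sketch of that priority order, using illustrative names only (this is a model, not the LLVM API): a scalar VF falls back to `Normal`, non-consecutive accesses become gather/scatter, then reversed, then masked accesses.

```cpp
#include <cstdio>

// Illustrative stand-ins for TTI::CastContextHint and the memory-recipe flags.
enum class CastContextHint { None, Normal, Masked, GatherScatter, Reversed };

struct MemAccess {
  bool Consecutive; // contiguous access
  bool Reverse;     // contiguous, but in reverse order
  bool Masked;      // predicated access
};

// Mirrors the priority order of ComputeCCH in the patch.
CastContextHint pickCastContextHint(const MemAccess &A, bool ScalarVF) {
  if (ScalarVF)
    return CastContextHint::Normal;
  if (!A.Consecutive)
    return CastContextHint::GatherScatter;
  if (A.Reverse)
    return CastContextHint::Reversed;
  if (A.Masked)
    return CastContextHint::Masked;
  return CastContextHint::Normal;
}

int main() {
  MemAccess ReverseLoad{/*Consecutive=*/true, /*Reverse=*/true, /*Masked=*/false};
  std::printf("%d\n", static_cast<int>(pickCastContextHint(ReverseLoad, /*ScalarVF=*/false)));
  return 0;
}
```

In the patch itself, the chosen hint is passed to `Ctx.TTI.getCastInstrCost()` together with the underlying instruction, so target cost models (e.g. ARM) can recognize extends that fold into the surrounding reduction.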



https://github.com/llvm/llvm-project/pull/111339

