[llvm] [VPlan] Implement VPWidenCastRecipe::computeCost(). (NFCI) (PR #111339)
via llvm-commits
llvm-commits at lists.llvm.org
Sun Oct 6 22:48:15 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-transforms
Author: Elvis Wang (ElvisWang123)
<details>
<summary>Changes</summary>
This patch implements `VPWidenCastRecipe::computeCost()` and skips cast recipes in the in-loop reduction.
---
Full diff: https://github.com/llvm/llvm-project/pull/111339.diff
4 Files Affected:
- (modified) llvm/lib/Transforms/Vectorize/LoopVectorize.cpp (+18-1)
- (modified) llvm/lib/Transforms/Vectorize/VPlan.h (+4)
- (modified) llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp (+43)
- (modified) llvm/lib/Transforms/Vectorize/VPlanValue.h (+1-1)
``````````diff
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 792e0e17dd8719..59527d27d3e7a4 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7218,12 +7218,29 @@ LoopVectorizationPlanner::precomputeCosts(VPlan &Plan, ElementCount VF,
const auto &ChainOps = RdxDesc.getReductionOpChain(RedPhi, OrigLoop);
SetVector<Instruction *> ChainOpsAndOperands(ChainOps.begin(),
ChainOps.end());
+ auto isZExtOrSExt = [](const unsigned Opcode) -> bool {
+ return Opcode == Instruction::ZExt || Opcode == Instruction::SExt;
+ };
// Also include the operands of instructions in the chain, as the cost-model
// may mark extends as free.
+ //
+ // For ARM, some of the instructions can be folded into the reduction
+ // instruction, so we need to mark all folded instructions as free.
+ // For example: We can fold reduce(mul(ext(A), ext(B))) into one
+ // instruction.
for (auto *ChainOp : ChainOps) {
for (Value *Op : ChainOp->operands()) {
- if (auto *I = dyn_cast<Instruction>(Op))
+ if (auto *I = dyn_cast<Instruction>(Op)) {
ChainOpsAndOperands.insert(I);
+ if (I->getOpcode() == Instruction::Mul) {
+ auto *Ext0 = dyn_cast<Instruction>(I->getOperand(0));
+ if (Ext0 && isZExtOrSExt(Ext0->getOpcode()))
+ ChainOpsAndOperands.insert(Ext0);
+ auto *Ext1 = dyn_cast<Instruction>(I->getOperand(1));
+ if (Ext1 && isZExtOrSExt(Ext1->getOpcode()))
+ ChainOpsAndOperands.insert(Ext1);
+ }
+ }
}
}
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index c4567362eaffc7..c8e074785cc877 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -1557,6 +1557,10 @@ class VPWidenCastRecipe : public VPRecipeWithIRFlags {
/// Produce widened copies of the cast.
void execute(VPTransformState &State) override;
+ /// Return the cost of this VPWidenCastRecipe.
+ InstructionCost computeCost(ElementCount VF,
+ VPCostContext &Ctx) const override;
+
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
/// Print the recipe.
void print(raw_ostream &O, const Twine &Indent,
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 75908638532950..764090b7a0ecc2 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -1429,6 +1429,49 @@ void VPWidenCastRecipe::execute(VPTransformState &State) {
State.addMetadata(Cast, cast_or_null<Instruction>(getUnderlyingValue()));
}
+InstructionCost VPWidenCastRecipe::computeCost(ElementCount VF,
+ VPCostContext &Ctx) const {
+ auto *SrcTy = cast<VectorType>(
+ ToVectorTy(Ctx.Types.inferScalarType(getOperand(0)), VF));
+ auto *DestTy = cast<VectorType>(ToVectorTy(getResultType(), VF));
+ // Maps a widened load/store recipe to a CastContextHint: Normal for scalar VFs, otherwise GatherScatter, Reversed, Masked or Normal, checked in that order.
+ auto ComputeCCH = [&](VPWidenMemoryRecipe *R) -> TTI::CastContextHint {
+ assert((isa<VPWidenLoadRecipe>(R) || isa<VPWidenStoreRecipe>(R)) &&
+ "Expected a load or a store!");
+
+ if (VF.isScalar())
+ return TTI::CastContextHint::Normal;
+ if (!R->isConsecutive())
+ return TTI::CastContextHint::GatherScatter;
+ if (R->isReverse())
+ return TTI::CastContextHint::Reversed;
+ if (R->isMasked())
+ return TTI::CastContextHint::Masked;
+ return TTI::CastContextHint::Normal;
+ };
+
+ TTI::CastContextHint CCH = TTI::CastContextHint::None;
+ // For Trunc/FPTrunc, the context comes from the first user, and only when there is at most one unique user and it is a VPWidenMemoryRecipe.
+ // NOTE(review): hasMoreThanOneUniqueUser() is also false for zero users, so user_begin() may be dereferenced on an empty user list - confirm.
+ if (Opcode == Instruction::Trunc || Opcode == Instruction::FPTrunc) {
+ if (!cast<VPValue>(this)->hasMoreThanOneUniqueUser())
+ if (VPWidenMemoryRecipe *Store =
+ dyn_cast<VPWidenMemoryRecipe>(*this->user_begin()))
+ CCH = ComputeCCH(Store);
+ }
+ // For ZExt/SExt/FPExt, the context comes from operand 0 when its defining recipe is a VPWidenMemoryRecipe. NOTE(review): if getDefiningRecipe() can return null here, plain dyn_cast is not null-safe - confirm or use dyn_cast_if_present.
+ else if (Opcode == Instruction::ZExt || Opcode == Instruction::SExt ||
+ Opcode == Instruction::FPExt) {
+ if (VPWidenMemoryRecipe *Load = dyn_cast<VPWidenMemoryRecipe>(
+ this->getOperand(0)->getDefiningRecipe()))
+ CCH = ComputeCCH(Load);
+ }
+ // The underlying instruction (if any) is passed along because the Arm TTI uses it to determine the cost.
+ return Ctx.TTI.getCastInstrCost(
+ Opcode, DestTy, SrcTy, CCH, TTI::TCK_RecipThroughput,
+ dyn_cast_if_present<Instruction>(getUnderlyingValue()));
+}
+
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
void VPWidenCastRecipe::print(raw_ostream &O, const Twine &Indent,
VPSlotTracker &SlotTracker) const {
diff --git a/llvm/lib/Transforms/Vectorize/VPlanValue.h b/llvm/lib/Transforms/Vectorize/VPlanValue.h
index 4c383244f96f1a..ec4d95048de7a4 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanValue.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanValue.h
@@ -135,7 +135,7 @@ class VPValue {
}
/// Returns true if the value has more than one unique user.
- bool hasMoreThanOneUniqueUser() {
+ bool hasMoreThanOneUniqueUser() const {
if (getNumUsers() == 0)
return false;
``````````
</details>
https://github.com/llvm/llvm-project/pull/111339
More information about the llvm-commits
mailing list