[llvm] [VPlan] Impl VPlan-based pattern match for ExtendedRed and MulAccRed (NFCI) (PR #113903)
Elvis Wang via llvm-commits
llvm-commits at lists.llvm.org
Fri Nov 15 01:08:32 PST 2024
https://github.com/ElvisWang123 updated https://github.com/llvm/llvm-project/pull/113903
>From 0de427c5830f641460cb684b430fc8dc18621712 Mon Sep 17 00:00:00 2001
From: Elvis Wang <elvis.wang at sifive.com>
Date: Mon, 28 Oct 2024 05:39:35 -0700
Subject: [PATCH 01/14] [VPlan] Impl VPlan-based pattern match for ExtendedRed
and MulAccRed. NFCI
This patch implements the VPlan-based pattern match for extendedReduction
and MulAccReduction. In the reduction patterns listed below, extend
instructions and mul instructions can be folded into the reduction
instruction, so their cost is free.
We add `FoldedRecipes` to `VPCostContext` to record recipes that are
folded into other recipes.
ExtendedReductionPatterns:
reduce(ext(...))
MulAccReductionPatterns:
reduce.add(mul(...))
reduce.add(mul(ext(...), ext(...)))
reduce.add(ext(mul(...)))
reduce.add(ext(mul(ext(...), ext(...))))
Ref: Original instruction based implementation:
https://reviews.llvm.org/D93476
---
.../Transforms/Vectorize/LoopVectorize.cpp | 45 ------
llvm/lib/Transforms/Vectorize/VPlan.h | 2 +
.../lib/Transforms/Vectorize/VPlanRecipes.cpp | 139 ++++++++++++++++--
3 files changed, 129 insertions(+), 57 deletions(-)
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 1ebc62f9843905..32fbde97bb8283 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7306,51 +7306,6 @@ LoopVectorizationPlanner::precomputeCosts(VPlan &Plan, ElementCount VF,
Cost += ReductionCost;
continue;
}
-
- const auto &ChainOps = RdxDesc.getReductionOpChain(RedPhi, OrigLoop);
- SetVector<Instruction *> ChainOpsAndOperands(ChainOps.begin(),
- ChainOps.end());
- auto IsZExtOrSExt = [](const unsigned Opcode) -> bool {
- return Opcode == Instruction::ZExt || Opcode == Instruction::SExt;
- };
- // Also include the operands of instructions in the chain, as the cost-model
- // may mark extends as free.
- //
- // For ARM, some of the instruction can folded into the reducion
- // instruction. So we need to mark all folded instructions free.
- // For example: We can fold reduce(mul(ext(A), ext(B))) into one
- // instruction.
- for (auto *ChainOp : ChainOps) {
- for (Value *Op : ChainOp->operands()) {
- if (auto *I = dyn_cast<Instruction>(Op)) {
- ChainOpsAndOperands.insert(I);
- if (I->getOpcode() == Instruction::Mul) {
- auto *Ext0 = dyn_cast<Instruction>(I->getOperand(0));
- auto *Ext1 = dyn_cast<Instruction>(I->getOperand(1));
- if (Ext0 && IsZExtOrSExt(Ext0->getOpcode()) && Ext1 &&
- Ext0->getOpcode() == Ext1->getOpcode()) {
- ChainOpsAndOperands.insert(Ext0);
- ChainOpsAndOperands.insert(Ext1);
- }
- }
- }
- }
- }
-
- // Pre-compute the cost for I, if it has a reduction pattern cost.
- for (Instruction *I : ChainOpsAndOperands) {
- auto ReductionCost = CM.getReductionPatternCost(
- I, VF, ToVectorTy(I->getType(), VF), TTI::TCK_RecipThroughput);
- if (!ReductionCost)
- continue;
-
- assert(!CostCtx.SkipCostComputation.contains(I) &&
- "reduction op visited multiple times");
- CostCtx.SkipCostComputation.insert(I);
- LLVM_DEBUG(dbgs() << "Cost of " << ReductionCost << " for VF " << VF
- << ":\n in-loop reduction " << *I << "\n");
- Cost += *ReductionCost;
- }
}
// Pre-compute the costs for branches except for the backedge, as the number
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index abfe97b4ab55b6..dd763de1e7a1ef 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -682,6 +682,8 @@ struct VPCostContext {
LLVMContext &LLVMCtx;
LoopVectorizationCostModel &CM;
SmallPtrSet<Instruction *, 8> SkipCostComputation;
+ /// Contains recipes that are folded into other recipes.
+ SmallDenseMap<ElementCount, SmallPtrSet<VPRecipeBase *, 4>, 4> FoldedRecipes;
VPCostContext(const TargetTransformInfo &TTI, const TargetLibraryInfo &TLI,
Type *CanIVTy, LoopVectorizationCostModel &CM)
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index ef2ca9af7268d1..36898220c8e48a 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -270,7 +270,9 @@ InstructionCost VPRecipeBase::cost(ElementCount VF, VPCostContext &Ctx) {
UI = &WidenMem->getIngredient();
InstructionCost RecipeCost;
- if (UI && Ctx.skipCostComputation(UI, VF.isVector())) {
+ if ((UI && Ctx.skipCostComputation(UI, VF.isVector())) ||
+ (Ctx.FoldedRecipes.contains(VF) &&
+ Ctx.FoldedRecipes.at(VF).contains(this))) {
RecipeCost = 0;
} else {
RecipeCost = computeCost(VF, Ctx);
@@ -2187,30 +2189,143 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
unsigned Opcode = RdxDesc.getOpcode();
- // TODO: Support any-of and in-loop reductions.
+ // TODO: Support any-of reductions.
assert(
(!RecurrenceDescriptor::isAnyOfRecurrenceKind(RdxKind) ||
ForceTargetInstructionCost.getNumOccurrences() > 0) &&
"Any-of reduction not implemented in VPlan-based cost model currently.");
- assert(
- (!cast<VPReductionPHIRecipe>(getOperand(0))->isInLoop() ||
- ForceTargetInstructionCost.getNumOccurrences() > 0) &&
- "In-loop reduction not implemented in VPlan-based cost model currently.");
assert(ElementTy->getTypeID() == RdxDesc.getRecurrenceType()->getTypeID() &&
"Inferred type and recurrence type mismatch.");
- // Cost = Reduction cost + BinOp cost
- InstructionCost Cost =
+ // BaseCost = Reduction cost + BinOp cost
+ InstructionCost BaseCost =
Ctx.TTI.getArithmeticInstrCost(Opcode, ElementTy, CostKind);
if (RecurrenceDescriptor::isMinMaxRecurrenceKind(RdxKind)) {
Intrinsic::ID Id = getMinMaxReductionIntrinsicOp(RdxKind);
- return Cost + Ctx.TTI.getMinMaxReductionCost(
- Id, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+ BaseCost += Ctx.TTI.getMinMaxReductionCost(
+ Id, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+ } else {
+ BaseCost += Ctx.TTI.getArithmeticReductionCost(
+ Opcode, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
}
- return Cost + Ctx.TTI.getArithmeticReductionCost(
- Opcode, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+ using namespace llvm::VPlanPatternMatch;
+  // Match the MulAccReduction patterns
+  //   reduce.add(mul(...)),      reduce.add(mul(ext(...), ext(...))),
+  //   reduce.add(ext(mul(...))), reduce.add(ext(mul(ext(...), ext(...)))).
+  // If the target reports a valid MulAcc reduction cost that is cheaper than
+  // costing the (outer ext +) mul (+ inner exts) + reduction separately, the
+  // folded recipes are recorded in Ctx.FoldedRecipes[VF] (VPRecipeBase::cost
+  // then reports 0 for them) and the MulAcc cost is returned. Otherwise an
+  // invalid cost is returned.
+  auto GetMulAccReductionCost =
+      [&](const VPReductionRecipe *Red) -> InstructionCost {
+    // All patterns above are add reductions; getMulAccReductionCost must not
+    // be used to cost other reduction kinds.
+    if (Opcode != Instruction::Add)
+      return InstructionCost::getInvalid();
+
+    VPValue *A, *B;
+    InstructionCost InnerExt0Cost = 0;
+    InstructionCost InnerExt1Cost = 0;
+    InstructionCost ExtCost = 0;
+    InstructionCost MulCost = 0;
+
+    VectorType *SrcVecTy = VectorTy;
+    bool IsUnsigned = true;
+    bool HasOuterExt = false;
+
+    auto *Ext = dyn_cast_if_present<VPWidenCastRecipe>(
+        Red->getVecOp()->getDefiningRecipe());
+    VPRecipeBase *Mul;
+    // Try to match outer extend reduce.add(ext(...)).
+    if (Ext && match(Ext, m_ZExtOrSExt(m_VPValue())) &&
+        Ext->getNumUsers() == 1) {
+      IsUnsigned = Ext->getOpcode() == Instruction::CastOps::ZExt;
+      ExtCost = Ext->computeCost(VF, Ctx);
+      Mul = Ext->getOperand(0)->getDefiningRecipe();
+      HasOuterExt = true;
+    } else {
+      Mul = Red->getVecOp()->getDefiningRecipe();
+    }
+
+    // Match reduce.add(mul()). Check the recipe kind with dyn_cast (as done
+    // for the extends) instead of asserting via cast<> that a matched mul is
+    // a VPWidenRecipe.
+    auto *WidenMul = dyn_cast_if_present<VPWidenRecipe>(Mul);
+    if (WidenMul && match(Mul, m_Mul(m_VPValue(A), m_VPValue(B))) &&
+        WidenMul->getNumUsers() == 1) {
+      MulCost = WidenMul->computeCost(VF, Ctx);
+      auto *InnerExt0 =
+          dyn_cast_if_present<VPWidenCastRecipe>(A->getDefiningRecipe());
+      auto *InnerExt1 =
+          dyn_cast_if_present<VPWidenCastRecipe>(B->getDefiningRecipe());
+      bool HasInnerExt = false;
+      // Try to match inner extends: both operands extended with the same
+      // opcode, each extend having exactly one unique user.
+      if (InnerExt0 && InnerExt1 &&
+          match(InnerExt0, m_ZExtOrSExt(m_VPValue())) &&
+          match(InnerExt1, m_ZExtOrSExt(m_VPValue())) &&
+          InnerExt0->getOpcode() == InnerExt1->getOpcode() &&
+          (InnerExt0->getNumUsers() > 0 &&
+           !InnerExt0->hasMoreThanOneUniqueUser()) &&
+          (InnerExt1->getNumUsers() > 0 &&
+           !InnerExt1->hasMoreThanOneUniqueUser())) {
+        InnerExt0Cost = InnerExt0->computeCost(VF, Ctx);
+        InnerExt1Cost = InnerExt1->computeCost(VF, Ctx);
+        Type *InnerExt0Ty = Ctx.Types.inferScalarType(InnerExt0->getOperand(0));
+        Type *InnerExt1Ty = Ctx.Types.inferScalarType(InnerExt1->getOperand(0));
+        // Cost the MulAcc on the wider of the two pre-extension types.
+        Type *MaxInnerExtTy = InnerExt0Ty->getIntegerBitWidth() >
+                                      InnerExt1Ty->getIntegerBitWidth()
+                                  ? InnerExt0Ty
+                                  : InnerExt1Ty;
+        SrcVecTy = cast<VectorType>(ToVectorTy(MaxInnerExtTy, VF));
+        // Signedness follows the inner extends (same opcode, checked above);
+        // forcing unsigned here would mis-cost sext-based patterns.
+        IsUnsigned = InnerExt0->getOpcode() == Instruction::CastOps::ZExt;
+        HasInnerExt = true;
+      }
+      InstructionCost MulAccRedCost = Ctx.TTI.getMulAccReductionCost(
+          IsUnsigned, ElementTy, SrcVecTy, CostKind);
+      // Check if folding ext/mul into MulAccReduction is profitable.
+      if (MulAccRedCost.isValid() &&
+          MulAccRedCost <
+              ExtCost + MulCost + InnerExt0Cost + InnerExt1Cost + BaseCost) {
+        if (HasInnerExt) {
+          Ctx.FoldedRecipes[VF].insert(InnerExt0);
+          Ctx.FoldedRecipes[VF].insert(InnerExt1);
+        }
+        Ctx.FoldedRecipes[VF].insert(Mul);
+        if (HasOuterExt)
+          Ctx.FoldedRecipes[VF].insert(Ext);
+        return MulAccRedCost;
+      }
+    }
+    return InstructionCost::getInvalid();
+  };
+
+  // Match reduce(ext(...)): if the reduction's vector operand is a single-use
+  // zext/sext VPWidenCastRecipe and the target's extended-reduction cost is
+  // cheaper than extend + reduction, record the extend in
+  // Ctx.FoldedRecipes[VF] and return the extended-reduction cost. Otherwise
+  // return an invalid cost.
+  auto GetExtendedReductionCost =
+      [&](const VPReductionRecipe *Red) -> InstructionCost {
+    VPValue *VecOp = Red->getVecOp();
+    VPValue *A;
+    // Check the recipe kind with dyn_cast (as GetMulAccReductionCost does)
+    // rather than asserting via cast<> that the matched extend is a
+    // VPWidenCastRecipe.
+    auto *Ext =
+        dyn_cast_if_present<VPWidenCastRecipe>(VecOp->getDefiningRecipe());
+    if (Ext && match(VecOp, m_ZExtOrSExt(m_VPValue(A))) &&
+        VecOp->getNumUsers() == 1) {
+      bool IsUnsigned = Ext->getOpcode() == Instruction::CastOps::ZExt;
+      InstructionCost ExtCost = Ext->computeCost(VF, Ctx);
+      auto *ExtVecTy =
+          cast<VectorType>(ToVectorTy(Ctx.Types.inferScalarType(A), VF));
+      InstructionCost ExtendedRedCost = Ctx.TTI.getExtendedReductionCost(
+          Opcode, IsUnsigned, ElementTy, ExtVecTy, RdxDesc.getFastMathFlags(),
+          CostKind);
+      // Check if folding ext into ExtendedReduction is profitable.
+      if (ExtendedRedCost.isValid() && ExtendedRedCost < ExtCost + BaseCost) {
+        Ctx.FoldedRecipes[VF].insert(Ext);
+        return ExtendedRedCost;
+      }
+    }
+    return InstructionCost::getInvalid();
+  };
+
+ // Match MulAccReduction patterns.
+ InstructionCost MulAccCost = GetMulAccReductionCost(this);
+ if (MulAccCost.isValid())
+ return MulAccCost;
+
+ // Match ExtendedReduction patterns.
+ InstructionCost ExtendedCost = GetExtendedReductionCost(this);
+ if (ExtendedCost.isValid())
+ return ExtendedCost;
+
+ // Default cost.
+ return BaseCost;
}
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
>From b4272ccafffb2cbbe9852df7fcc96621b3399c8f Mon Sep 17 00:00:00 2001
From: Elvis Wang <elvis.wang at sifive.com>
Date: Sun, 3 Nov 2024 18:55:55 -0800
Subject: [PATCH 02/14] Partially support Extended-reduction.
---
.../Transforms/Vectorize/LoopVectorize.cpp | 24 +-
llvm/lib/Transforms/Vectorize/VPlan.h | 217 ++++++++++++++++++
.../lib/Transforms/Vectorize/VPlanRecipes.cpp | 114 +++++++--
.../Transforms/Vectorize/VPlanTransforms.cpp | 24 ++
.../Transforms/Vectorize/VPlanTransforms.h | 3 +
llvm/lib/Transforms/Vectorize/VPlanValue.h | 2 +
6 files changed, 359 insertions(+), 25 deletions(-)
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 32fbde97bb8283..29df6a52fa98e7 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7662,6 +7662,10 @@ DenseMap<const SCEV *, Value *> LoopVectorizationPlanner::executePlan(
ILV.getOrCreateVectorTripCount(nullptr),
CanonicalIVStartValue, State);
+ // TODO: Rebase to fhahn's implementation.
+ VPlanTransforms::prepareExecute(BestVPlan);
+ dbgs() << "\n\n print plan\n";
+ BestVPlan.print(dbgs());
BestVPlan.execute(&State);
// 2.5 Collect reduction resume values.
@@ -9256,6 +9260,7 @@ VPlanPtr LoopVectorizationPlanner::buildVPlan(VFRange &Range) {
// Adjust AnyOf reductions; replace the reduction phi for the selected value
// with a boolean reduction phi node to check if the condition is true in any
// iteration. The final value is selected by the final ComputeReductionResult.
+// TODO: Implement VPMulAccHere.
void LoopVectorizationPlanner::adjustRecipesForReductions(
VPlanPtr &Plan, VPRecipeBuilder &RecipeBuilder, ElementCount MinVF) {
using namespace VPlanPatternMatch;
@@ -9374,9 +9379,22 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
if (CM.blockNeedsPredicationForAnyReason(BB))
CondOp = RecipeBuilder.getBlockInMask(BB);
- VPReductionRecipe *RedRecipe =
- new VPReductionRecipe(RdxDesc, CurrentLinkI, PreviousLink, VecOp,
- CondOp, CM.useOrderedReductions(RdxDesc));
+ // VPWidenCastRecipes can folded into VPReductionRecipe
+ VPValue *A;
+ VPSingleDefRecipe *RedRecipe;
+ if (match(VecOp, m_ZExtOrSExt(m_VPValue(A))) &&
+ !VecOp->hasMoreThanOneUniqueUser()) {
+ RedRecipe = new VPExtendedReductionRecipe(
+ RdxDesc, CurrentLinkI,
+ cast<CastInst>(
+ cast<VPWidenCastRecipe>(VecOp)->getUnderlyingInstr()),
+ PreviousLink, A, CondOp, CM.useOrderedReductions(RdxDesc),
+ cast<VPWidenCastRecipe>(VecOp)->getResultType());
+ } else {
+ RedRecipe =
+ new VPReductionRecipe(RdxDesc, CurrentLinkI, PreviousLink, VecOp,
+ CondOp, CM.useOrderedReductions(RdxDesc));
+ }
// Append the recipe to the end of the VPBasicBlock because we need to
// ensure that it comes after all of it's inputs, including CondOp.
// Note that this transformation may leave over dead recipes (including
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index dd763de1e7a1ef..7a05145514d0e7 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -859,6 +859,8 @@ class VPSingleDefRecipe : public VPRecipeBase, public VPValue {
case VPRecipeBase::VPInstructionSC:
case VPRecipeBase::VPReductionEVLSC:
case VPRecipeBase::VPReductionSC:
+ case VPRecipeBase::VPMulAccSC:
+ case VPRecipeBase::VPExtendedReductionSC:
case VPRecipeBase::VPReplicateSC:
case VPRecipeBase::VPScalarIVStepsSC:
case VPRecipeBase::VPVectorPointerSC:
@@ -2655,6 +2657,221 @@ class VPReductionEVLRecipe : public VPReductionRecipe {
}
};
+/// A recipe to represent in-loop extended reduction operations, performing a
+/// reduction on an extended vector operand into a scalar value, and adding the
+/// result to a chain. This recipe is a high-level abstraction that is lowered
+/// to a VPWidenCastRecipe + VPReductionRecipe before execution. The operands
+/// are {ChainOp, VecOp, [Condition]}, where VecOp is the value before the
+/// extend.
+class VPExtendedReductionRecipe : public VPSingleDefRecipe {
+  /// The recurrence descriptor for the reduction in question.
+  const RecurrenceDescriptor &RdxDesc;
+  /// Whether the in-loop reduction is ordered.
+  bool IsOrdered;
+  /// Whether the reduction is conditional.
+  bool IsConditional = false;
+  /// Type after extend.
+  Type *ResultTy;
+  /// Opcode of the folded extend cast.
+  Instruction::CastOps ExtOp;
+  /// The original cast instruction; used to re-create the VPWidenCastRecipe
+  /// when this recipe is lowered.
+  CastInst *CastInstr;
+  /// True if ExtOp is a zero-extend.
+  bool IsZExt;
+
+protected:
+  VPExtendedReductionRecipe(const unsigned char SC,
+                            const RecurrenceDescriptor &R, Instruction *RedI,
+                            Instruction::CastOps ExtOp, CastInst *CastI,
+                            ArrayRef<VPValue *> Operands, VPValue *CondOp,
+                            bool IsOrdered, Type *ResultTy)
+      : VPSingleDefRecipe(SC, Operands, RedI), RdxDesc(R), IsOrdered(IsOrdered),
+        ResultTy(ResultTy), ExtOp(ExtOp), CastInstr(CastI),
+        IsZExt(ExtOp == Instruction::CastOps::ZExt) {
+    // The condition, if any, is always the last operand (see getCondOp()).
+    if (CondOp) {
+      IsConditional = true;
+      addOperand(CondOp);
+    }
+  }
+
+public:
+  VPExtendedReductionRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
+                            CastInst *CastI, VPValue *ChainOp, VPValue *VecOp,
+                            VPValue *CondOp, bool IsOrdered, Type *ResultTy)
+      : VPExtendedReductionRecipe(VPDef::VPExtendedReductionSC, R, RedI,
+                                  CastI->getOpcode(), CastI,
+                                  ArrayRef<VPValue *>({ChainOp, VecOp}), CondOp,
+                                  IsOrdered, ResultTy) {}
+
+  /// Build from an existing reduction \p Red whose vector operand is the
+  /// extend \p Ext; the source of \p Ext becomes this recipe's VecOp.
+  VPExtendedReductionRecipe(VPReductionRecipe *Red, VPWidenCastRecipe *Ext)
+      : VPExtendedReductionRecipe(
+            VPDef::VPExtendedReductionSC, Red->getRecurrenceDescriptor(),
+            Red->getUnderlyingInstr(), Ext->getOpcode(),
+            cast<CastInst>(Ext->getUnderlyingInstr()),
+            ArrayRef<VPValue *>({Red->getChainOp(), Ext->getOperand(0)}),
+            Red->getCondOp(), Red->isOrdered(), Ext->getResultType()) {}
+
+  ~VPExtendedReductionRecipe() override = default;
+
+  VPExtendedReductionRecipe *clone() override {
+    llvm_unreachable("Not implement yet");
+  }
+
+  static inline bool classof(const VPRecipeBase *R) {
+    return R->getVPDefID() == VPDef::VPExtendedReductionSC;
+  }
+
+  static inline bool classof(const VPUser *U) {
+    auto *R = dyn_cast<VPRecipeBase>(U);
+    return R && classof(R);
+  }
+
+  /// Not executable: this recipe must be lowered to a VPWidenCastRecipe +
+  /// VPReductionRecipe pair before execution.
+  void execute(VPTransformState &State) override {
+    llvm_unreachable("VPExtendedReductionRecipe should be transformed to "
+                     "VPWidenCastRecipe + VPReductionRecipe before execution.");
+  }
+
+  /// Return the cost of VPExtendedReductionRecipe.
+  InstructionCost computeCost(ElementCount VF,
+                              VPCostContext &Ctx) const override;
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+  /// Print the recipe.
+  void print(raw_ostream &O, const Twine &Indent,
+             VPSlotTracker &SlotTracker) const override;
+#endif
+
+  /// Return the recurrence descriptor for the in-loop reduction.
+  const RecurrenceDescriptor &getRecurrenceDescriptor() const {
+    return RdxDesc;
+  }
+  /// Return true if the in-loop reduction is ordered.
+  bool isOrdered() const { return IsOrdered; }
+  /// Return true if the in-loop reduction is conditional.
+  bool isConditional() const { return IsConditional; }
+  /// The VPValue of the scalar Chain being accumulated.
+  VPValue *getChainOp() const { return getOperand(0); }
+  /// The VPValue of the vector value to be extended and reduced.
+  VPValue *getVecOp() const { return getOperand(1); }
+  /// The VPValue of the condition for the block.
+  VPValue *getCondOp() const {
+    return isConditional() ? getOperand(getNumOperands() - 1) : nullptr;
+  }
+  /// Type of the vector elements after the extend.
+  Type *getResultType() const { return ResultTy; }
+  /// Opcode of the folded extend.
+  Instruction::CastOps getExtOpcode() const { return ExtOp; }
+  /// The original IR cast instruction of the folded extend.
+  CastInst *getExtInstr() const { return CastInstr; }
+};
+
+/// A recipe to represent in-loop multiply-accumulate reduction operations of
+/// the form reduce.add(OuterExt(mul(InnerExt(A), InnerExt(B)))), reducing the
+/// product into a scalar value and adding the result to a chain. This recipe
+/// is a high-level abstraction intended to be expanded into a
+/// VPReductionRecipe, a VPWidenRecipe (mul) and VPWidenCastRecipe(s) before
+/// execution.
+/// The Operands are {ChainOp, VecOp1, VecOp2, [Condition]}, where VecOp1 and
+/// VecOp2 are the sources of the two inner extends.
+class VPMulAccRecipe : public VPSingleDefRecipe {
+ /// The recurrence descriptor for the reduction in question.
+ const RecurrenceDescriptor &RdxDesc;
+ /// Whether the in-loop reduction is ordered.
+ bool IsOrdered;
+ /// Whether the reduction is conditional.
+ bool IsConditional = false;
+ /// Type after extend.
+ Type *ResultTy;
+ /// Type for mul.
+ Type *MulTy;
+ /// reduce.add(OuterExt(mul(InnerExt(), InnerExt())))
+ /// Only InnerExt0's opcode is stored; both inner extends are assumed to share
+ /// it (NOTE(review): nothing in this class enforces that).
+ Instruction::CastOps OuterExtOp;
+ Instruction::CastOps InnerExtOp;
+
+ /// Underlying IR instructions of the folded mul and extends.
+ Instruction *MulI;
+ Instruction *OuterExtI;
+ Instruction *InnerExt0I;
+ Instruction *InnerExt1I;
+
+protected:
+ VPMulAccRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
+ Instruction *RedI, Instruction::CastOps OuterExtOp,
+ Instruction *OuterExtI, Instruction *MulI,
+ Instruction::CastOps InnerExtOp, Instruction *InnerExt0I,
+ Instruction *InnerExt1I, ArrayRef<VPValue *> Operands,
+ VPValue *CondOp, bool IsOrdered, Type *ResultTy, Type *MulTy)
+ : VPSingleDefRecipe(SC, Operands, RedI), RdxDesc(R), IsOrdered(IsOrdered),
+ ResultTy(ResultTy), MulTy(MulTy), OuterExtOp(OuterExtOp),
+ InnerExtOp(InnerExtOp), MulI(MulI), OuterExtI(OuterExtI),
+ InnerExt0I(InnerExt0I), InnerExt1I(InnerExt1I) {
+ // The condition, if any, is always the last operand (see getCondOp()).
+ if (CondOp) {
+ IsConditional = true;
+ addOperand(CondOp);
+ }
+ }
+
+public:
+ /// NOTE(review): this constructor evaluates cast<CastInst>(OuterExt) and
+ /// cast<CastInst>(InnerExt0) unconditionally, so it asserts if OuterExt is
+ /// null; a pattern without an outer extend cannot be built through it yet.
+ VPMulAccRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
+ Instruction *OuterExt, Instruction *Mul,
+ Instruction *InnerExt0, Instruction *InnerExt1,
+ VPValue *ChainOp, VPValue *InnerExt0Op, VPValue *InnerExt1Op,
+ VPValue *CondOp, bool IsOrdered, Type *ResultTy, Type *MulTy)
+ : VPMulAccRecipe(
+ VPDef::VPMulAccSC, R, RedI, cast<CastInst>(OuterExt)->getOpcode(),
+ OuterExt, Mul, cast<CastInst>(InnerExt0)->getOpcode(), InnerExt0,
+ InnerExt1, ArrayRef<VPValue *>({ChainOp, InnerExt0Op, InnerExt1Op}),
+ CondOp, IsOrdered, ResultTy, MulTy) {}
+
+ /// Build from existing recipes. The sources of the inner extends become
+ /// operands 1 and 2. NOTE(review): OuterExt is dereferenced unconditionally,
+ /// so it must be non-null here as well.
+ VPMulAccRecipe(VPReductionRecipe *Red, VPWidenCastRecipe *OuterExt,
+ VPWidenRecipe *Mul, VPWidenCastRecipe *InnerExt0,
+ VPWidenCastRecipe *InnerExt1)
+ : VPMulAccRecipe(
+ VPDef::VPMulAccSC, Red->getRecurrenceDescriptor(),
+ Red->getUnderlyingInstr(), OuterExt->getOpcode(),
+ OuterExt->getUnderlyingInstr(), Mul->getUnderlyingInstr(),
+ InnerExt0->getOpcode(), InnerExt0->getUnderlyingInstr(),
+ InnerExt1->getUnderlyingInstr(),
+ ArrayRef<VPValue *>({Red->getChainOp(), InnerExt0->getOperand(0),
+ InnerExt1->getOperand(0)}),
+ Red->getCondOp(), Red->isOrdered(), OuterExt->getResultType(),
+ InnerExt0->getResultType()) {}
+
+ ~VPMulAccRecipe() override = default;
+
+ VPMulAccRecipe *clone() override { llvm_unreachable("Not implement yet"); }
+
+ static inline bool classof(const VPRecipeBase *R) {
+ return R->getVPDefID() == VPRecipeBase::VPMulAccSC;
+ }
+
+ static inline bool classof(const VPUser *U) {
+ auto *R = dyn_cast<VPRecipeBase>(U);
+ return R && classof(R);
+ }
+
+ /// Generate the reduction in the loop
+ void execute(VPTransformState &State) override;
+
+ /// Return the cost of this VPMulAccRecipe.
+ InstructionCost computeCost(ElementCount VF,
+ VPCostContext &Ctx) const override;
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+ /// Print the recipe.
+ void print(raw_ostream &O, const Twine &Indent,
+ VPSlotTracker &SlotTracker) const override;
+#endif
+
+ /// Return the recurrence descriptor for the in-loop reduction.
+ const RecurrenceDescriptor &getRecurrenceDescriptor() const {
+ return RdxDesc;
+ }
+ /// Return true if the in-loop reduction is ordered.
+ bool isOrdered() const { return IsOrdered; };
+ /// Return true if the in-loop reduction is conditional.
+ bool isConditional() const { return IsConditional; };
+ /// The VPValue of the scalar Chain being accumulated.
+ VPValue *getChainOp() const { return getOperand(0); }
+ /// Operand 1: the source of the first inner extend (first multiplicand
+ /// before extension). NOTE(review): there is no accessor for operand 2.
+ VPValue *getVecOp() const { return getOperand(1); }
+ /// The VPValue of the condition for the block.
+ VPValue *getCondOp() const {
+ return isConditional() ? getOperand(getNumOperands() - 1) : nullptr;
+ }
+ /// Type of the result after the outer extend.
+ Type *getResultTy() const { return ResultTy; };
+ Instruction::CastOps getOuterExtOpcode() const { return OuterExtOp; };
+ Instruction::CastOps getInnerExtOpcode() const { return InnerExtOp; };
+};
+
/// VPReplicateRecipe replicates a given instruction producing multiple scalar
/// copies of the original scalar type, one per lane, instead of producing a
/// single copy of widened type for all lanes. If the instruction is known to be
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 36898220c8e48a..a7858a729c2c92 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -1490,6 +1490,27 @@ void VPWidenCastRecipe::execute(VPTransformState &State) {
setFlags(CastOp);
}
+/// Derive the TTI::CastContextHint for a cast whose context is the recipe
+/// \p R (the memory recipe feeding or consuming the cast) at factor \p VF.
+/// Scalar VFs always get Normal; non-memory recipes get None.
+static TTI::CastContextHint computeCCH(const VPRecipeBase *R, ElementCount VF) {
+  using Hint = TTI::CastContextHint;
+  if (VF.isScalar())
+    return Hint::Normal;
+  if (isa<VPInterleaveRecipe>(R))
+    return Hint::Interleave;
+  if (const auto *Rep = dyn_cast<VPReplicateRecipe>(R)) {
+    if (Rep->isPredicated())
+      return Hint::Masked;
+    return Hint::Normal;
+  }
+  const auto *Mem = dyn_cast<VPWidenMemoryRecipe>(R);
+  if (!Mem)
+    return Hint::None;
+  // Order matters: non-consecutive wins over reverse, which wins over mask.
+  if (!Mem->isConsecutive())
+    return Hint::GatherScatter;
+  if (Mem->isReverse())
+    return Hint::Reversed;
+  return Mem->isMasked() ? Hint::Masked : Hint::Normal;
+}
+
InstructionCost VPWidenCastRecipe::computeCost(ElementCount VF,
VPCostContext &Ctx) const {
// TODO: In some cases, VPWidenCastRecipes are created but not considered in
@@ -1497,26 +1518,6 @@ InstructionCost VPWidenCastRecipe::computeCost(ElementCount VF,
// reduction in a smaller type.
if (!getUnderlyingValue())
return 0;
- // Computes the CastContextHint from a recipes that may access memory.
- auto ComputeCCH = [&](const VPRecipeBase *R) -> TTI::CastContextHint {
- if (VF.isScalar())
- return TTI::CastContextHint::Normal;
- if (isa<VPInterleaveRecipe>(R))
- return TTI::CastContextHint::Interleave;
- if (const auto *ReplicateRecipe = dyn_cast<VPReplicateRecipe>(R))
- return ReplicateRecipe->isPredicated() ? TTI::CastContextHint::Masked
- : TTI::CastContextHint::Normal;
- const auto *WidenMemoryRecipe = dyn_cast<VPWidenMemoryRecipe>(R);
- if (WidenMemoryRecipe == nullptr)
- return TTI::CastContextHint::None;
- if (!WidenMemoryRecipe->isConsecutive())
- return TTI::CastContextHint::GatherScatter;
- if (WidenMemoryRecipe->isReverse())
- return TTI::CastContextHint::Reversed;
- if (WidenMemoryRecipe->isMasked())
- return TTI::CastContextHint::Masked;
- return TTI::CastContextHint::Normal;
- };
VPValue *Operand = getOperand(0);
TTI::CastContextHint CCH = TTI::CastContextHint::None;
@@ -1524,7 +1525,7 @@ InstructionCost VPWidenCastRecipe::computeCost(ElementCount VF,
if ((Opcode == Instruction::Trunc || Opcode == Instruction::FPTrunc) &&
!hasMoreThanOneUniqueUser() && getNumUsers() > 0) {
if (auto *StoreRecipe = dyn_cast<VPRecipeBase>(*user_begin()))
- CCH = ComputeCCH(StoreRecipe);
+ CCH = computeCCH(StoreRecipe, VF);
}
// For Z/Sext, get the context from the operand.
else if (Opcode == Instruction::ZExt || Opcode == Instruction::SExt ||
@@ -1532,7 +1533,7 @@ InstructionCost VPWidenCastRecipe::computeCost(ElementCount VF,
if (Operand->isLiveIn())
CCH = TTI::CastContextHint::Normal;
else if (Operand->getDefiningRecipe())
- CCH = ComputeCCH(Operand->getDefiningRecipe());
+ CCH = computeCCH(Operand->getDefiningRecipe(), VF);
}
auto *SrcTy =
@@ -2210,6 +2211,7 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
Opcode, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
}
+ /*
using namespace llvm::VPlanPatternMatch;
auto GetMulAccReductionCost =
[&](const VPReductionRecipe *Red) -> InstructionCost {
@@ -2323,11 +2325,57 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
InstructionCost ExtendedCost = GetExtendedReductionCost(this);
if (ExtendedCost.isValid())
return ExtendedCost;
+ */
// Default cost.
return BaseCost;
}
+/// Cost of an in-loop extended reduction. This is the fused
+/// extended-reduction cost when the target supports it and it is cheaper than
+/// costing the extend and the reduction separately; otherwise it is the sum of
+/// the two.
+InstructionCost
+VPExtendedReductionRecipe::computeCost(ElementCount VF,
+                                       VPCostContext &Ctx) const {
+  RecurKind RdxKind = RdxDesc.getRecurrenceKind();
+  Type *ElementTy = getResultType();
+  auto *VectorTy = cast<VectorType>(ToVectorTy(ElementTy, VF));
+  TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+  unsigned Opcode = RdxDesc.getOpcode();
+
+  assert(ElementTy->getTypeID() == RdxDesc.getRecurrenceType()->getTypeID() &&
+         "Inferred type and recurrence type mismatch.");
+
+  // ReductionCost = Reduction cost + BinOp cost
+  InstructionCost ReductionCost =
+      Ctx.TTI.getArithmeticInstrCost(Opcode, ElementTy, CostKind);
+  if (RecurrenceDescriptor::isMinMaxRecurrenceKind(RdxKind)) {
+    Intrinsic::ID Id = getMinMaxReductionIntrinsicOp(RdxKind);
+    ReductionCost += Ctx.TTI.getMinMaxReductionCost(
+        Id, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+  } else {
+    ReductionCost += Ctx.TTI.getArithmeticReductionCost(
+        Opcode, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+  }
+
+  // Cost of the standalone extend from the source element type to ResultTy.
+  VPValue *Src = getVecOp();
+  auto *SrcTy =
+      cast<VectorType>(ToVectorTy(Ctx.Types.inferScalarType(Src), VF));
+  auto *DestTy = cast<VectorType>(ToVectorTy(getResultType(), VF));
+  // A live-in source has no defining recipe, so computeCCH cannot be used for
+  // it (same handling as VPWidenCastRecipe::computeCost).
+  TTI::CastContextHint CCH =
+      Src->isLiveIn() ? TTI::CastContextHint::Normal
+                      : computeCCH(Src->getDefiningRecipe(), VF);
+  // Cost the cast with the extend opcode (not the reduction opcode) and with
+  // the original cast instruction as context, since targets such as Arm inspect
+  // the instruction passed here.
+  InstructionCost ExtendedCost = Ctx.TTI.getCastInstrCost(
+      ExtOp, DestTy, SrcTy, CCH, TTI::TCK_RecipThroughput, CastInstr);
+
+  // ExtendedReduction Cost
+  InstructionCost ExtendedRedCost = Ctx.TTI.getExtendedReductionCost(
+      Opcode, IsZExt, ElementTy, SrcTy, RdxDesc.getFastMathFlags(), CostKind);
+  // Check if folding ext into ExtendedReduction is profitable.
+  if (ExtendedRedCost.isValid() &&
+      ExtendedRedCost < ExtendedCost + ReductionCost)
+    return ExtendedRedCost;
+  return ExtendedCost + ReductionCost;
+}
+
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
void VPReductionRecipe::print(raw_ostream &O, const Twine &Indent,
VPSlotTracker &SlotTracker) const {
@@ -2373,6 +2421,28 @@ void VPReductionEVLRecipe::print(raw_ostream &O, const Twine &Indent,
O << " (with final reduction value stored in invariant address sank "
"outside of loop)";
}
+
+/// Debug printer. Output shape:
+///   EXTENDED-REDUCE <def> = <chain> +[fmf] reduce.<op> (<vec> extended to
+///   <type>[, <cond>])[ (with final reduction value ... sank outside of loop)]
+void VPExtendedReductionRecipe::print(raw_ostream &O, const Twine &Indent,
+                                      VPSlotTracker &SlotTracker) const {
+  O << Indent << "EXTENDED-REDUCE ";
+  printAsOperand(O, SlotTracker);
+  O << " = ";
+  getChainOp()->printAsOperand(O, SlotTracker);
+  O << " +";
+  const auto *RedI = getUnderlyingInstr();
+  if (isa<FPMathOperator>(RedI))
+    O << RedI->getFastMathFlags();
+  const unsigned RdxOpcode = RdxDesc.getOpcode();
+  O << " reduce." << Instruction::getOpcodeName(RdxOpcode) << " (";
+  getVecOp()->printAsOperand(O, SlotTracker);
+  O << " extended to " << *getResultType();
+  if (VPValue *Cond = getCondOp()) {
+    O << ", ";
+    Cond->printAsOperand(O, SlotTracker);
+  }
+  O << ")";
+  if (!RdxDesc.IntermediateStore)
+    return;
+  O << " (with final reduction value stored in invariant address sank "
+       "outside of loop)";
+}
#endif
bool VPReplicateRecipe::shouldPack() const {
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index b9ab8a8fe60107..22ac98751bbd86 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -519,6 +519,30 @@ void VPlanTransforms::removeDeadRecipes(VPlan &Plan) {
}
}
+/// Lower abstract recipes that cannot be executed directly. Each
+/// VPExtendedReductionRecipe is replaced by a VPWidenCastRecipe (the extend)
+/// feeding a VPReductionRecipe, both inserted at the original's position.
+void VPlanTransforms::prepareExecute(VPlan &Plan) {
+  for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>(
+           vp_depth_first_deep(Plan.getEntry()))) {
+    // make_early_inc_range: the visited recipe is erased inside the loop.
+    for (VPRecipeBase &R : make_early_inc_range(*VPBB)) {
+      auto *ExtRed = dyn_cast<VPExtendedReductionRecipe>(&R);
+      if (!ExtRed)
+        continue;
+      auto *Ext = new VPWidenCastRecipe(
+          ExtRed->getExtOpcode(), ExtRed->getVecOp(), ExtRed->getResultType(),
+          *ExtRed->getExtInstr());
+      auto *Red = new VPReductionRecipe(
+          ExtRed->getRecurrenceDescriptor(), ExtRed->getUnderlyingInstr(),
+          ExtRed->getChainOp(), Ext, ExtRed->getCondOp(), ExtRed->isOrdered());
+      Ext->insertBefore(ExtRed);
+      Red->insertBefore(ExtRed);
+      ExtRed->replaceAllUsesWith(Red);
+      ExtRed->eraseFromParent();
+    }
+  }
+}
+
static VPScalarIVStepsRecipe *
createScalarIVSteps(VPlan &Plan, InductionDescriptor::InductionKind Kind,
Instruction::BinaryOps InductionOpcode,
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.h b/llvm/lib/Transforms/Vectorize/VPlanTransforms.h
index 11e094db6294f6..6310c23b605da3 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.h
@@ -123,6 +123,9 @@ struct VPlanTransforms {
/// Remove dead recipes from \p Plan.
static void removeDeadRecipes(VPlan &Plan);
+
+ /// TODO: Rebase to fhahn's implementation.
+ static void prepareExecute(VPlan &Plan);
};
} // namespace llvm
diff --git a/llvm/lib/Transforms/Vectorize/VPlanValue.h b/llvm/lib/Transforms/Vectorize/VPlanValue.h
index 691b0d40823cfb..09defa6406c078 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanValue.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanValue.h
@@ -329,6 +329,8 @@ class VPDef {
VPInterleaveSC,
VPReductionEVLSC,
VPReductionSC,
+ VPMulAccSC,
+ VPExtendedReductionSC,
VPReplicateSC,
VPScalarCastSC,
VPScalarIVStepsSC,
>From 8102cd4f2069cd1354c463b3be1f6393aaad2a3c Mon Sep 17 00:00:00 2001
From: Elvis Wang <elvis.wang at sifive.com>
Date: Mon, 4 Nov 2024 22:02:22 -0800
Subject: [PATCH 03/14] Support MulAccRecipe
---
.../Transforms/Vectorize/LoopVectorize.cpp | 33 ++++-
llvm/lib/Transforms/Vectorize/VPlan.h | 103 +++++++++-------
.../lib/Transforms/Vectorize/VPlanRecipes.cpp | 114 +++++++++++++++++-
.../Transforms/Vectorize/VPlanTransforms.cpp | 56 ++++++---
4 files changed, 237 insertions(+), 69 deletions(-)
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 29df6a52fa98e7..7147c35c807491 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7664,8 +7664,6 @@ DenseMap<const SCEV *, Value *> LoopVectorizationPlanner::executePlan(
// TODO: Rebase to fhahn's implementation.
VPlanTransforms::prepareExecute(BestVPlan);
- dbgs() << "\n\n print plan\n";
- BestVPlan.print(dbgs());
BestVPlan.execute(&State);
// 2.5 Collect reduction resume values.
@@ -9379,11 +9377,34 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
if (CM.blockNeedsPredicationForAnyReason(BB))
CondOp = RecipeBuilder.getBlockInMask(BB);
- // VPWidenCastRecipes can folded into VPReductionRecipe
- VPValue *A;
+ VPValue *A, *B;
VPSingleDefRecipe *RedRecipe;
- if (match(VecOp, m_ZExtOrSExt(m_VPValue(A))) &&
- !VecOp->hasMoreThanOneUniqueUser()) {
+ // reduce.add(mul(ext, ext)) can folded into VPMulAccRecipe
+ if (RdxDesc.getOpcode() == Instruction::Add &&
+ match(VecOp, m_Mul(m_VPValue(A), m_VPValue(B)))) {
+ VPRecipeBase *RecipeA = A->getDefiningRecipe();
+ VPRecipeBase *RecipeB = B->getDefiningRecipe();
+ if (RecipeA && RecipeB && match(RecipeA, m_ZExtOrSExt(m_VPValue())) &&
+ match(RecipeB, m_ZExtOrSExt(m_VPValue())) &&
+ cast<VPWidenCastRecipe>(RecipeA)->getOpcode() ==
+ cast<VPWidenCastRecipe>(RecipeB)->getOpcode() &&
+ !A->hasMoreThanOneUniqueUser() && !B->hasMoreThanOneUniqueUser()) {
+ RedRecipe = new VPMulAccRecipe(
+ RdxDesc, CurrentLinkI, PreviousLink, CondOp,
+ CM.useOrderedReductions(RdxDesc),
+ cast<VPWidenRecipe>(VecOp->getDefiningRecipe()),
+ cast<VPWidenCastRecipe>(RecipeA),
+ cast<VPWidenCastRecipe>(RecipeB));
+ } else {
+ RedRecipe = new VPMulAccRecipe(
+ RdxDesc, CurrentLinkI, PreviousLink, CondOp,
+ CM.useOrderedReductions(RdxDesc),
+ cast<VPWidenRecipe>(VecOp->getDefiningRecipe()));
+ }
+ }
+ // VPWidenCastRecipes can folded into VPReductionRecipe
+ else if (match(VecOp, m_ZExtOrSExt(m_VPValue(A))) &&
+ !VecOp->hasMoreThanOneUniqueUser()) {
RedRecipe = new VPExtendedReductionRecipe(
RdxDesc, CurrentLinkI,
cast<CastInst>(
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 7a05145514d0e7..3a49962e8b465c 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -2770,60 +2770,64 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
/// Whether the reduction is conditional.
bool IsConditional = false;
/// Type after extend.
- Type *ResultTy;
- /// Type for mul.
- Type *MulTy;
- /// reduce.add(OuterExt(mul(InnerExt(), InnerExt())))
- Instruction::CastOps OuterExtOp;
- Instruction::CastOps InnerExtOp;
+ Type *ResultType;
+ /// reduce.add(mul(Ext(), Ext()))
+ Instruction::CastOps ExtOp;
+
+ Instruction *MulInstr;
+ CastInst *Ext0Instr;
+ CastInst *Ext1Instr;
- Instruction *MulI;
- Instruction *OuterExtI;
- Instruction *InnerExt0I;
- Instruction *InnerExt1I;
+ bool IsExtended;
protected:
VPMulAccRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
- Instruction *RedI, Instruction::CastOps OuterExtOp,
- Instruction *OuterExtI, Instruction *MulI,
- Instruction::CastOps InnerExtOp, Instruction *InnerExt0I,
- Instruction *InnerExt1I, ArrayRef<VPValue *> Operands,
- VPValue *CondOp, bool IsOrdered, Type *ResultTy, Type *MulTy)
+ Instruction *RedI, Instruction *MulInstr,
+ Instruction::CastOps ExtOp, Instruction *Ext0Instr,
+ Instruction *Ext1Instr, ArrayRef<VPValue *> Operands,
+ VPValue *CondOp, bool IsOrdered, Type *ResultType)
+ : VPSingleDefRecipe(SC, Operands, RedI), RdxDesc(R), IsOrdered(IsOrdered),
+ ResultType(ResultType), ExtOp(ExtOp), MulInstr(MulInstr),
+ Ext0Instr(cast<CastInst>(Ext0Instr)),
+ Ext1Instr(cast<CastInst>(Ext1Instr)) {
+ if (CondOp) {
+ IsConditional = true;
+ addOperand(CondOp);
+ }
+ IsExtended = true;
+ }
+
+ VPMulAccRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
+ Instruction *RedI, Instruction *MulInstr,
+ ArrayRef<VPValue *> Operands, VPValue *CondOp, bool IsOrdered)
: VPSingleDefRecipe(SC, Operands, RedI), RdxDesc(R), IsOrdered(IsOrdered),
- ResultTy(ResultTy), MulTy(MulTy), OuterExtOp(OuterExtOp),
- InnerExtOp(InnerExtOp), MulI(MulI), OuterExtI(OuterExtI),
- InnerExt0I(InnerExt0I), InnerExt1I(InnerExt1I) {
+ MulInstr(MulInstr) {
if (CondOp) {
IsConditional = true;
addOperand(CondOp);
}
+ IsExtended = false;
}
public:
VPMulAccRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
- Instruction *OuterExt, Instruction *Mul,
- Instruction *InnerExt0, Instruction *InnerExt1,
- VPValue *ChainOp, VPValue *InnerExt0Op, VPValue *InnerExt1Op,
- VPValue *CondOp, bool IsOrdered, Type *ResultTy, Type *MulTy)
- : VPMulAccRecipe(
- VPDef::VPMulAccSC, R, RedI, cast<CastInst>(OuterExt)->getOpcode(),
- OuterExt, Mul, cast<CastInst>(InnerExt0)->getOpcode(), InnerExt0,
- InnerExt1, ArrayRef<VPValue *>({ChainOp, InnerExt0Op, InnerExt1Op}),
- CondOp, IsOrdered, ResultTy, MulTy) {}
-
- VPMulAccRecipe(VPReductionRecipe *Red, VPWidenCastRecipe *OuterExt,
- VPWidenRecipe *Mul, VPWidenCastRecipe *InnerExt0,
- VPWidenCastRecipe *InnerExt1)
- : VPMulAccRecipe(
- VPDef::VPMulAccSC, Red->getRecurrenceDescriptor(),
- Red->getUnderlyingInstr(), OuterExt->getOpcode(),
- OuterExt->getUnderlyingInstr(), Mul->getUnderlyingInstr(),
- InnerExt0->getOpcode(), InnerExt0->getUnderlyingInstr(),
- InnerExt1->getUnderlyingInstr(),
- ArrayRef<VPValue *>({Red->getChainOp(), InnerExt0->getOperand(0),
- InnerExt1->getOperand(0)}),
- Red->getCondOp(), Red->isOrdered(), OuterExt->getResultType(),
- InnerExt0->getResultType()) {}
+ VPValue *ChainOp, VPValue *CondOp, bool IsOrdered,
+ VPWidenRecipe *Mul, VPWidenCastRecipe *Ext0,
+ VPWidenCastRecipe *Ext1)
+ : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, Mul->getUnderlyingInstr(),
+ Ext0->getOpcode(), Ext0->getUnderlyingInstr(),
+ Ext1->getUnderlyingInstr(),
+ ArrayRef<VPValue *>(
+ {ChainOp, Ext0->getOperand(0), Ext1->getOperand(0)}),
+ CondOp, IsOrdered, Ext0->getResultType()) {}
+
+ VPMulAccRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
+ VPValue *ChainOp, VPValue *CondOp, bool IsOrdered,
+ VPWidenRecipe *Mul)
+ : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, Mul->getUnderlyingInstr(),
+ ArrayRef<VPValue *>(
+ {ChainOp, Mul->getOperand(0), Mul->getOperand(0)}),
+ CondOp, IsOrdered) {}
~VPMulAccRecipe() override = default;
@@ -2839,7 +2843,10 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
}
/// Generate the reduction in the loop
- void execute(VPTransformState &State) override;
+ void execute(VPTransformState &State) override {
+ llvm_unreachable("VPMulAccRecipe should transform to VPWidenCastRecipe + "
+ "VPWidenRecipe + VPReductionRecipe before execution");
+ }
/// Return the cost of VPExtendedReductionRecipe.
InstructionCost computeCost(ElementCount VF,
@@ -2862,14 +2869,18 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
/// The VPValue of the scalar Chain being accumulated.
VPValue *getChainOp() const { return getOperand(0); }
/// The VPValue of the vector value to be extended and reduced.
- VPValue *getVecOp() const { return getOperand(1); }
+ VPValue *getVecOp0() const { return getOperand(1); }
+ VPValue *getVecOp1() const { return getOperand(2); }
/// The VPValue of the condition for the block.
VPValue *getCondOp() const {
return isConditional() ? getOperand(getNumOperands() - 1) : nullptr;
}
- Type *getResultTy() const { return ResultTy; };
- Instruction::CastOps getOuterExtOpcode() const { return OuterExtOp; };
- Instruction::CastOps getInnerExtOpcode() const { return InnerExtOp; };
+ Type *getResultType() const { return ResultType; };
+ Instruction::CastOps getExtOpcode() const { return ExtOp; };
+ Instruction *getMulInstr() const { return MulInstr; };
+ CastInst *getExt0Instr() const { return Ext0Instr; };
+ CastInst *getExt1Instr() const { return Ext1Instr; };
+ bool isExtended() const { return IsExtended; };
};
/// VPReplicateRecipe replicates a given instruction producing multiple scalar
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index a7858a729c2c92..195a4d676f5fa0 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -270,9 +270,7 @@ InstructionCost VPRecipeBase::cost(ElementCount VF, VPCostContext &Ctx) {
UI = &WidenMem->getIngredient();
InstructionCost RecipeCost;
- if ((UI && Ctx.skipCostComputation(UI, VF.isVector())) ||
- (Ctx.FoldedRecipes.contains(VF) &&
- Ctx.FoldedRecipes.at(VF).contains(this))) {
+ if ((UI && Ctx.skipCostComputation(UI, VF.isVector()))) {
RecipeCost = 0;
} else {
RecipeCost = computeCost(VF, Ctx);
@@ -2376,6 +2374,85 @@ VPExtendedReductionRecipe::computeCost(ElementCount VF,
return ExtendedCost + ReductionCost;
}
+InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
+ VPCostContext &Ctx) const {
+ Type *ElementTy =
+ IsExtended ? getResultType() : Ctx.Types.inferScalarType(getVecOp0());
+ auto *VectorTy = cast<VectorType>(ToVectorTy(ElementTy, VF));
+ TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+ unsigned Opcode = RdxDesc.getOpcode();
+
+ assert(ElementTy->getTypeID() == RdxDesc.getRecurrenceType()->getTypeID() &&
+ "Inferred type and recurrence type mismatch.");
+
+ // BaseCost = Reduction cost + BinOp cost
+ InstructionCost ReductionCost =
+ Ctx.TTI.getArithmeticInstrCost(Opcode, ElementTy, CostKind);
+ ReductionCost += Ctx.TTI.getArithmeticReductionCost(
+ Opcode, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+
+ // Extended cost
+ InstructionCost ExtendedCost = 0;
+ if (IsExtended) {
+ auto *SrcTy = cast<VectorType>(
+ ToVectorTy(Ctx.Types.inferScalarType(getVecOp0()), VF));
+ auto *DestTy = cast<VectorType>(ToVectorTy(getResultType(), VF));
+ TTI::CastContextHint CCH0 =
+ computeCCH(getVecOp0()->getDefiningRecipe(), VF);
+ // Arm TTI will use the underlying instruction to determine the cost.
+ ExtendedCost = Ctx.TTI.getCastInstrCost(
+ ExtOp, DestTy, SrcTy, CCH0, TTI::TCK_RecipThroughput,
+ dyn_cast_if_present<Instruction>(getExt0Instr()));
+ TTI::CastContextHint CCH1 =
+ computeCCH(getVecOp0()->getDefiningRecipe(), VF);
+ ExtendedCost += Ctx.TTI.getCastInstrCost(
+ ExtOp, DestTy, SrcTy, CCH1, TTI::TCK_RecipThroughput,
+ dyn_cast_if_present<Instruction>(getExt1Instr()));
+ }
+
+ // Mul cost
+ InstructionCost MulCost;
+ SmallVector<const Value *, 4> Operands;
+ Operands.append(MulInstr->value_op_begin(), MulInstr->value_op_end());
+ if (IsExtended)
+ MulCost = Ctx.TTI.getArithmeticInstrCost(
+ Instruction::Mul, VectorTy, CostKind,
+ {TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None},
+ {TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None},
+ Operands, MulInstr, &Ctx.TLI);
+ else {
+ VPValue *RHS = getVecOp1();
+ // Certain instructions can be cheaper to vectorize if they have a constant
+ // second vector operand. One example of this are shifts on x86.
+ TargetTransformInfo::OperandValueInfo RHSInfo = {
+ TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None};
+ if (RHS->isLiveIn())
+ RHSInfo = Ctx.TTI.getOperandInfo(RHS->getLiveInIRValue());
+
+ if (RHSInfo.Kind == TargetTransformInfo::OK_AnyValue &&
+ RHS->isDefinedOutsideLoopRegions())
+ RHSInfo.Kind = TargetTransformInfo::OK_UniformValue;
+ MulCost = Ctx.TTI.getArithmeticInstrCost(
+ Instruction::Mul, VectorTy, CostKind,
+ {TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None},
+ RHSInfo, Operands, MulInstr, &Ctx.TLI);
+ }
+
+ // ExtendedReduction Cost
+ VectorType *SrcVecTy =
+ cast<VectorType>(ToVectorTy(Ctx.Types.inferScalarType(getVecOp0()), VF));
+ InstructionCost MulAccCost = Ctx.TTI.getMulAccReductionCost(
+ getExtOpcode() == Instruction::CastOps::ZExt, ElementTy, SrcVecTy,
+ CostKind);
+
+ // Check if folding ext into ExtendedReduction is profitable.
+ if (MulAccCost.isValid() &&
+ MulAccCost < ExtendedCost + ReductionCost + MulCost) {
+ return MulAccCost;
+ }
+ return ExtendedCost + ReductionCost + MulCost;
+}
+
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
void VPReductionRecipe::print(raw_ostream &O, const Twine &Indent,
VPSlotTracker &SlotTracker) const {
@@ -2443,6 +2520,37 @@ void VPExtendedReductionRecipe::print(raw_ostream &O, const Twine &Indent,
O << " (with final reduction value stored in invariant address sank "
"outside of loop)";
}
+
+void VPMulAccRecipe::print(raw_ostream &O, const Twine &Indent,
+ VPSlotTracker &SlotTracker) const {
+ O << Indent << "MULACC-REDUCE ";
+ printAsOperand(O, SlotTracker);
+ O << " = ";
+ getChainOp()->printAsOperand(O, SlotTracker);
+ O << " +";
+ if (isa<FPMathOperator>(getUnderlyingInstr()))
+ O << getUnderlyingInstr()->getFastMathFlags();
+ O << " reduce." << Instruction::getOpcodeName(RdxDesc.getOpcode()) << " (";
+ O << " mul ";
+ if (IsExtended)
+ O << "(";
+ getVecOp0()->printAsOperand(O, SlotTracker);
+ if (IsExtended)
+ O << " extended to " << *getResultType() << ")";
+ if (IsExtended)
+ O << "(";
+ getVecOp1()->printAsOperand(O, SlotTracker);
+ if (IsExtended)
+ O << " extended to " << *getResultType() << ")";
+ if (isConditional()) {
+ O << ", ";
+ getCondOp()->printAsOperand(O, SlotTracker);
+ }
+ O << ")";
+ if (RdxDesc.IntermediateStore)
+ O << " (with final reduction value stored in invariant address sank "
+ "outside of loop)";
+}
#endif
bool VPReplicateRecipe::shouldPack() const {
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index 22ac98751bbd86..6c9c157d7a9071 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -520,25 +520,53 @@ void VPlanTransforms::removeDeadRecipes(VPlan &Plan) {
}
void VPlanTransforms::prepareExecute(VPlan &Plan) {
- errs() << "\n\n\n!!Prepare to execute\n";
ReversePostOrderTraversal<VPBlockDeepTraversalWrapper<VPBlockBase *>> RPOT(
Plan.getVectorLoopRegion());
for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>(
vp_depth_first_deep(Plan.getEntry()))) {
for (VPRecipeBase &R : make_early_inc_range(*VPBB)) {
- if (!isa<VPExtendedReductionRecipe>(&R))
- continue;
- auto *ExtRed = cast<VPExtendedReductionRecipe>(&R);
- auto *Ext = new VPWidenCastRecipe(
- ExtRed->getExtOpcode(), ExtRed->getVecOp(), ExtRed->getResultType(),
- *ExtRed->getExtInstr());
- auto *Red = new VPReductionRecipe(
- ExtRed->getRecurrenceDescriptor(), ExtRed->getUnderlyingInstr(),
- ExtRed->getChainOp(), Ext, ExtRed->getCondOp(), ExtRed->isOrdered());
- Ext->insertBefore(ExtRed);
- Red->insertBefore(ExtRed);
- ExtRed->replaceAllUsesWith(Red);
- ExtRed->eraseFromParent();
+ if (isa<VPExtendedReductionRecipe>(&R)) {
+ auto *ExtRed = cast<VPExtendedReductionRecipe>(&R);
+ auto *Ext = new VPWidenCastRecipe(
+ ExtRed->getExtOpcode(), ExtRed->getVecOp(), ExtRed->getResultType(),
+ *ExtRed->getExtInstr());
+ auto *Red = new VPReductionRecipe(
+ ExtRed->getRecurrenceDescriptor(), ExtRed->getUnderlyingInstr(),
+ ExtRed->getChainOp(), Ext, ExtRed->getCondOp(),
+ ExtRed->isOrdered());
+ Ext->insertBefore(ExtRed);
+ Red->insertBefore(ExtRed);
+ ExtRed->replaceAllUsesWith(Red);
+ ExtRed->eraseFromParent();
+ } else if (isa<VPMulAccRecipe>(&R)) {
+ auto *MulAcc = cast<VPMulAccRecipe>(&R);
+ VPValue *Op0, *Op1;
+ if (MulAcc->isExtended()) {
+ Op0 = new VPWidenCastRecipe(
+ MulAcc->getExtOpcode(), MulAcc->getVecOp0(),
+ MulAcc->getResultType(), *MulAcc->getExt0Instr());
+ Op1 = new VPWidenCastRecipe(
+ MulAcc->getExtOpcode(), MulAcc->getVecOp1(),
+ MulAcc->getResultType(), *MulAcc->getExt1Instr());
+ Op0->getDefiningRecipe()->insertBefore(MulAcc);
+ Op1->getDefiningRecipe()->insertBefore(MulAcc);
+ } else {
+ Op0 = MulAcc->getVecOp0();
+ Op1 = MulAcc->getVecOp1();
+ }
+ Instruction *MulInstr = MulAcc->getMulInstr();
+ SmallVector<VPValue *, 2> MulOps = {Op0, Op1};
+ auto *Mul = new VPWidenRecipe(*MulInstr,
+ make_range(MulOps.begin(), MulOps.end()));
+ auto *Red = new VPReductionRecipe(
+ MulAcc->getRecurrenceDescriptor(), MulAcc->getUnderlyingInstr(),
+ MulAcc->getChainOp(), Mul, MulAcc->getCondOp(),
+ MulAcc->isOrdered());
+ Mul->insertBefore(MulAcc);
+ Red->insertBefore(MulAcc);
+ MulAcc->replaceAllUsesWith(Red);
+ MulAcc->eraseFromParent();
+ }
}
}
}
>From b92c42d1337a9776123c29b8394f09f77ae98ac2 Mon Sep 17 00:00:00 2001
From: Elvis Wang <elvis.wang at sifive.com>
Date: Tue, 5 Nov 2024 16:52:31 -0800
Subject: [PATCH 04/14] Fix several errors and update tests.
We need to update tests since the generated vector IR will be reordered.
---
.../Transforms/Vectorize/LoopVectorize.cpp | 45 ++++++++++++++
llvm/lib/Transforms/Vectorize/VPlan.h | 34 ++++++++---
.../lib/Transforms/Vectorize/VPlanRecipes.cpp | 6 +-
.../Transforms/Vectorize/VPlanTransforms.cpp | 14 ++++-
.../LoopVectorize/ARM/mve-reduction-types.ll | 4 +-
.../LoopVectorize/ARM/mve-reductions.ll | 61 ++++++++++---------
.../LoopVectorize/RISCV/inloop-reduction.ll | 32 ++++++----
.../LoopVectorize/reduction-inloop-pred.ll | 34 +++++------
.../LoopVectorize/reduction-inloop.ll | 12 ++--
9 files changed, 163 insertions(+), 79 deletions(-)
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 7147c35c807491..ea2e6b96f7711c 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7394,6 +7394,19 @@ static bool planContainsAdditionalSimplifications(VPlan &Plan,
}
if (Instruction *UI = GetInstructionForCost(&R))
SeenInstrs.insert(UI);
+ // VPExtendedReductionRecipe contains a folded extend instruction.
+ if (auto *ExtendedRed = dyn_cast<VPExtendedReductionRecipe>(&R))
+ SeenInstrs.insert(ExtendedRed->getExtInstr());
+ // VPMulAccRecupe constians a mul and otional extend instructions.
+ else if (auto *MulAcc = dyn_cast<VPMulAccRecipe>(&R)) {
+ SeenInstrs.insert(MulAcc->getMulInstr());
+ if (MulAcc->isExtended()) {
+ SeenInstrs.insert(MulAcc->getExt0Instr());
+ SeenInstrs.insert(MulAcc->getExt1Instr());
+ if (auto *Ext = MulAcc->getExtInstr())
+ SeenInstrs.insert(Ext);
+ }
+ }
}
}
@@ -9401,6 +9414,38 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
CM.useOrderedReductions(RdxDesc),
cast<VPWidenRecipe>(VecOp->getDefiningRecipe()));
}
+ } else if (RdxDesc.getOpcode() == Instruction::Add &&
+ match(VecOp,
+ m_ZExtOrSExt(m_Mul(m_ZExtOrSExt(m_VPValue(A)),
+ m_ZExtOrSExt(m_VPValue(B)))))) {
+ VPWidenCastRecipe *Ext =
+ dyn_cast<VPWidenCastRecipe>(VecOp->getDefiningRecipe());
+ VPWidenRecipe *Mul =
+ dyn_cast<VPWidenRecipe>(Ext->getOperand(0)->getDefiningRecipe());
+ if (Mul && match(Mul, m_Mul(m_ZExtOrSExt(m_VPValue()),
+ m_ZExtOrSExt(m_VPValue())))) {
+ VPWidenRecipe *Mul =
+ cast<VPWidenRecipe>(Ext->getOperand(0)->getDefiningRecipe());
+ VPWidenCastRecipe *Ext0 =
+ cast<VPWidenCastRecipe>(Mul->getOperand(0)->getDefiningRecipe());
+ VPWidenCastRecipe *Ext1 =
+ cast<VPWidenCastRecipe>(Mul->getOperand(1)->getDefiningRecipe());
+ if (Ext->getOpcode() == Ext0->getOpcode() &&
+ Ext0->getOpcode() == Ext1->getOpcode()) {
+ RedRecipe = new VPMulAccRecipe(
+ RdxDesc, CurrentLinkI, PreviousLink, CondOp,
+ CM.useOrderedReductions(RdxDesc),
+ cast<VPWidenCastRecipe>(VecOp->getDefiningRecipe()), Mul,
+ cast<VPWidenCastRecipe>(Ext0), cast<VPWidenCastRecipe>(Ext1));
+ } else
+ RedRecipe = new VPExtendedReductionRecipe(
+ RdxDesc, CurrentLinkI,
+ cast<CastInst>(
+ cast<VPWidenCastRecipe>(VecOp)->getUnderlyingInstr()),
+ PreviousLink, cast<VPWidenCastRecipe>(VecOp)->getOperand(0),
+ CondOp, CM.useOrderedReductions(RdxDesc),
+ cast<VPWidenCastRecipe>(VecOp)->getResultType());
+ }
}
// VPWidenCastRecipes can folded into VPReductionRecipe
else if (match(VecOp, m_ZExtOrSExt(m_VPValue(A))) &&
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 3a49962e8b465c..0103686be422d6 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -2771,23 +2771,27 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
bool IsConditional = false;
/// Type after extend.
Type *ResultType;
- /// reduce.add(mul(Ext(), Ext()))
+ /// reduce.add(ext((mul(Ext(), Ext())))
Instruction::CastOps ExtOp;
Instruction *MulInstr;
+ CastInst *ExtInstr = nullptr;
CastInst *Ext0Instr;
CastInst *Ext1Instr;
bool IsExtended;
+ bool IsOuterExtended = false;
protected:
VPMulAccRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
- Instruction *RedI, Instruction *MulInstr,
- Instruction::CastOps ExtOp, Instruction *Ext0Instr,
- Instruction *Ext1Instr, ArrayRef<VPValue *> Operands,
- VPValue *CondOp, bool IsOrdered, Type *ResultType)
+ Instruction *RedI, Instruction *ExtInstr,
+ Instruction *MulInstr, Instruction::CastOps ExtOp,
+ Instruction *Ext0Instr, Instruction *Ext1Instr,
+ ArrayRef<VPValue *> Operands, VPValue *CondOp, bool IsOrdered,
+ Type *ResultType)
: VPSingleDefRecipe(SC, Operands, RedI), RdxDesc(R), IsOrdered(IsOrdered),
ResultType(ResultType), ExtOp(ExtOp), MulInstr(MulInstr),
+ ExtInstr(cast_if_present<CastInst>(ExtInstr)),
Ext0Instr(cast<CastInst>(Ext0Instr)),
Ext1Instr(cast<CastInst>(Ext1Instr)) {
if (CondOp) {
@@ -2814,9 +2818,9 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
VPValue *ChainOp, VPValue *CondOp, bool IsOrdered,
VPWidenRecipe *Mul, VPWidenCastRecipe *Ext0,
VPWidenCastRecipe *Ext1)
- : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, Mul->getUnderlyingInstr(),
- Ext0->getOpcode(), Ext0->getUnderlyingInstr(),
- Ext1->getUnderlyingInstr(),
+ : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, nullptr,
+ Mul->getUnderlyingInstr(), Ext0->getOpcode(),
+ Ext0->getUnderlyingInstr(), Ext1->getUnderlyingInstr(),
ArrayRef<VPValue *>(
{ChainOp, Ext0->getOperand(0), Ext1->getOperand(0)}),
CondOp, IsOrdered, Ext0->getResultType()) {}
@@ -2826,9 +2830,20 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
VPWidenRecipe *Mul)
: VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, Mul->getUnderlyingInstr(),
ArrayRef<VPValue *>(
- {ChainOp, Mul->getOperand(0), Mul->getOperand(0)}),
+ {ChainOp, Mul->getOperand(0), Mul->getOperand(1)}),
CondOp, IsOrdered) {}
+ VPMulAccRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
+ VPValue *ChainOp, VPValue *CondOp, bool IsOrdered,
+ VPWidenCastRecipe *Ext, VPWidenRecipe *Mul,
+ VPWidenCastRecipe *Ext0, VPWidenCastRecipe *Ext1)
+ : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, Ext->getUnderlyingInstr(),
+ Mul->getUnderlyingInstr(), Ext0->getOpcode(),
+ Ext0->getUnderlyingInstr(), Ext1->getUnderlyingInstr(),
+ ArrayRef<VPValue *>(
+ {ChainOp, Ext0->getOperand(0), Ext1->getOperand(0)}),
+ CondOp, IsOrdered, Ext0->getResultType()) {}
+
~VPMulAccRecipe() override = default;
VPMulAccRecipe *clone() override { llvm_unreachable("Not implement yet"); }
@@ -2878,6 +2893,7 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
Type *getResultType() const { return ResultType; };
Instruction::CastOps getExtOpcode() const { return ExtOp; };
Instruction *getMulInstr() const { return MulInstr; };
+ CastInst *getExtInstr() const { return ExtInstr; };
CastInst *getExt0Instr() const { return Ext0Instr; };
CastInst *getExt1Instr() const { return Ext1Instr; };
bool isExtended() const { return IsExtended; };
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 195a4d676f5fa0..d28b2d93ee5630 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -2376,8 +2376,8 @@ VPExtendedReductionRecipe::computeCost(ElementCount VF,
InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
VPCostContext &Ctx) const {
- Type *ElementTy =
- IsExtended ? getResultType() : Ctx.Types.inferScalarType(getVecOp0());
+ Type *ElementTy = IsExtended ? RdxDesc.getRecurrenceType()
+ : Ctx.Types.inferScalarType(getVecOp0());
auto *VectorTy = cast<VectorType>(ToVectorTy(ElementTy, VF));
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
unsigned Opcode = RdxDesc.getOpcode();
@@ -2438,7 +2438,7 @@ InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
RHSInfo, Operands, MulInstr, &Ctx.TLI);
}
- // ExtendedReduction Cost
+ // MulAccReduction Cost
VectorType *SrcVecTy =
cast<VectorType>(ToVectorTy(Ctx.Types.inferScalarType(getVecOp0()), VF));
InstructionCost MulAccCost = Ctx.TTI.getMulAccReductionCost(
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index 6c9c157d7a9071..b7cc945747df13 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -554,15 +554,27 @@ void VPlanTransforms::prepareExecute(VPlan &Plan) {
Op0 = MulAcc->getVecOp0();
Op1 = MulAcc->getVecOp1();
}
+ VPSingleDefRecipe *VecOp;
Instruction *MulInstr = MulAcc->getMulInstr();
SmallVector<VPValue *, 2> MulOps = {Op0, Op1};
auto *Mul = new VPWidenRecipe(*MulInstr,
make_range(MulOps.begin(), MulOps.end()));
+ if (auto *OuterExtInstr = MulAcc->getExtInstr()) {
+ // dbgs() <<"\n!!!"<< *OuterExtInstr << " " << MulAcc->getExtOpcode()
+ // << "\n";
+ VecOp = new VPWidenCastRecipe(
+ MulAcc->getExtOpcode(), Mul,
+ MulAcc->getRecurrenceDescriptor().getRecurrenceType(),
+ *OuterExtInstr);
+ } else
+ VecOp = Mul;
auto *Red = new VPReductionRecipe(
MulAcc->getRecurrenceDescriptor(), MulAcc->getUnderlyingInstr(),
- MulAcc->getChainOp(), Mul, MulAcc->getCondOp(),
+ MulAcc->getChainOp(), VecOp, MulAcc->getCondOp(),
MulAcc->isOrdered());
Mul->insertBefore(MulAcc);
+ if (VecOp != Mul)
+ VecOp->insertBefore(MulAcc);
Red->insertBefore(MulAcc);
MulAcc->replaceAllUsesWith(Red);
MulAcc->eraseFromParent();
diff --git a/llvm/test/Transforms/LoopVectorize/ARM/mve-reduction-types.ll b/llvm/test/Transforms/LoopVectorize/ARM/mve-reduction-types.ll
index 832d4db53036fb..9d4b83edd55ccb 100644
--- a/llvm/test/Transforms/LoopVectorize/ARM/mve-reduction-types.ll
+++ b/llvm/test/Transforms/LoopVectorize/ARM/mve-reduction-types.ll
@@ -24,11 +24,11 @@ define i32 @mla_i32(ptr noalias nocapture readonly %A, ptr noalias nocapture rea
; CHECK-NEXT: [[TMP1:%.*]] = getelementptr inbounds i8, ptr [[A:%.*]], i32 [[TMP0]]
; CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds i8, ptr [[TMP1]], i32 0
; CHECK-NEXT: [[WIDE_MASKED_LOAD:%.*]] = call <16 x i8> @llvm.masked.load.v16i8.p0(ptr [[TMP2]], i32 1, <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i8> poison)
-; CHECK-NEXT: [[TMP3:%.*]] = sext <16 x i8> [[WIDE_MASKED_LOAD]] to <16 x i32>
; CHECK-NEXT: [[TMP4:%.*]] = getelementptr inbounds i8, ptr [[B:%.*]], i32 [[TMP0]]
; CHECK-NEXT: [[TMP5:%.*]] = getelementptr inbounds i8, ptr [[TMP4]], i32 0
; CHECK-NEXT: [[WIDE_MASKED_LOAD1:%.*]] = call <16 x i8> @llvm.masked.load.v16i8.p0(ptr [[TMP5]], i32 1, <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i8> poison)
; CHECK-NEXT: [[TMP6:%.*]] = sext <16 x i8> [[WIDE_MASKED_LOAD1]] to <16 x i32>
+; CHECK-NEXT: [[TMP3:%.*]] = sext <16 x i8> [[WIDE_MASKED_LOAD]] to <16 x i32>
; CHECK-NEXT: [[TMP7:%.*]] = mul nsw <16 x i32> [[TMP6]], [[TMP3]]
; CHECK-NEXT: [[TMP8:%.*]] = select <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i32> [[TMP7]], <16 x i32> zeroinitializer
; CHECK-NEXT: [[TMP9:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP8]])
@@ -107,11 +107,11 @@ define i32 @mla_i8(ptr noalias nocapture readonly %A, ptr noalias nocapture read
; CHECK-NEXT: [[TMP1:%.*]] = getelementptr inbounds i8, ptr [[A:%.*]], i32 [[TMP0]]
; CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds i8, ptr [[TMP1]], i32 0
; CHECK-NEXT: [[WIDE_MASKED_LOAD:%.*]] = call <16 x i8> @llvm.masked.load.v16i8.p0(ptr [[TMP2]], i32 1, <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i8> poison)
-; CHECK-NEXT: [[TMP3:%.*]] = sext <16 x i8> [[WIDE_MASKED_LOAD]] to <16 x i32>
; CHECK-NEXT: [[TMP4:%.*]] = getelementptr inbounds i8, ptr [[B:%.*]], i32 [[TMP0]]
; CHECK-NEXT: [[TMP5:%.*]] = getelementptr inbounds i8, ptr [[TMP4]], i32 0
; CHECK-NEXT: [[WIDE_MASKED_LOAD1:%.*]] = call <16 x i8> @llvm.masked.load.v16i8.p0(ptr [[TMP5]], i32 1, <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i8> poison)
; CHECK-NEXT: [[TMP6:%.*]] = sext <16 x i8> [[WIDE_MASKED_LOAD1]] to <16 x i32>
+; CHECK-NEXT: [[TMP3:%.*]] = sext <16 x i8> [[WIDE_MASKED_LOAD]] to <16 x i32>
; CHECK-NEXT: [[TMP7:%.*]] = mul nsw <16 x i32> [[TMP6]], [[TMP3]]
; CHECK-NEXT: [[TMP8:%.*]] = select <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i32> [[TMP7]], <16 x i32> zeroinitializer
; CHECK-NEXT: [[TMP9:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP8]])
diff --git a/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll b/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
index c115c91cff896c..93a030c2bf1556 100644
--- a/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
+++ b/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
@@ -646,11 +646,11 @@ define i64 @mla_i16_i64(ptr nocapture readonly %x, ptr nocapture readonly %y, i3
; CHECK-NEXT: [[VEC_PHI:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[TMP7:%.*]], [[VECTOR_BODY]] ]
; CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds i16, ptr [[X:%.*]], i32 [[INDEX]]
; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <8 x i16>, ptr [[TMP0]], align 2
-; CHECK-NEXT: [[TMP1:%.*]] = sext <8 x i16> [[WIDE_LOAD]] to <8 x i32>
; CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds i16, ptr [[Y:%.*]], i32 [[INDEX]]
; CHECK-NEXT: [[WIDE_LOAD1:%.*]] = load <8 x i16>, ptr [[TMP2]], align 2
-; CHECK-NEXT: [[TMP3:%.*]] = sext <8 x i16> [[WIDE_LOAD1]] to <8 x i32>
-; CHECK-NEXT: [[TMP4:%.*]] = mul nsw <8 x i32> [[TMP3]], [[TMP1]]
+; CHECK-NEXT: [[TMP11:%.*]] = sext <8 x i16> [[WIDE_LOAD1]] to <8 x i32>
+; CHECK-NEXT: [[TMP3:%.*]] = sext <8 x i16> [[WIDE_LOAD]] to <8 x i32>
+; CHECK-NEXT: [[TMP4:%.*]] = mul nsw <8 x i32> [[TMP11]], [[TMP3]]
; CHECK-NEXT: [[TMP5:%.*]] = sext <8 x i32> [[TMP4]] to <8 x i64>
; CHECK-NEXT: [[TMP6:%.*]] = call i64 @llvm.vector.reduce.add.v8i64(<8 x i64> [[TMP5]])
; CHECK-NEXT: [[TMP7]] = add i64 [[TMP6]], [[VEC_PHI]]
@@ -726,11 +726,11 @@ define i64 @mla_i8_i64(ptr nocapture readonly %x, ptr nocapture readonly %y, i32
; CHECK-NEXT: [[VEC_PHI:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[TMP7:%.*]], [[VECTOR_BODY]] ]
; CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds i8, ptr [[X:%.*]], i32 [[INDEX]]
; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <8 x i8>, ptr [[TMP0]], align 1
-; CHECK-NEXT: [[TMP1:%.*]] = zext <8 x i8> [[WIDE_LOAD]] to <8 x i32>
; CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds i8, ptr [[Y:%.*]], i32 [[INDEX]]
; CHECK-NEXT: [[WIDE_LOAD1:%.*]] = load <8 x i8>, ptr [[TMP2]], align 1
-; CHECK-NEXT: [[TMP3:%.*]] = zext <8 x i8> [[WIDE_LOAD1]] to <8 x i32>
-; CHECK-NEXT: [[TMP4:%.*]] = mul nuw nsw <8 x i32> [[TMP3]], [[TMP1]]
+; CHECK-NEXT: [[TMP11:%.*]] = zext <8 x i8> [[WIDE_LOAD1]] to <8 x i32>
+; CHECK-NEXT: [[TMP3:%.*]] = zext <8 x i8> [[WIDE_LOAD]] to <8 x i32>
+; CHECK-NEXT: [[TMP4:%.*]] = mul nuw nsw <8 x i32> [[TMP11]], [[TMP3]]
; CHECK-NEXT: [[TMP5:%.*]] = zext nneg <8 x i32> [[TMP4]] to <8 x i64>
; CHECK-NEXT: [[TMP6:%.*]] = call i64 @llvm.vector.reduce.add.v8i64(<8 x i64> [[TMP5]])
; CHECK-NEXT: [[TMP7]] = add i64 [[TMP6]], [[VEC_PHI]]
@@ -802,11 +802,11 @@ define i32 @mla_i32_i32(ptr nocapture readonly %x, ptr nocapture readonly %y, i3
; CHECK-NEXT: [[INDEX:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
; CHECK-NEXT: [[VEC_PHI:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[TMP5:%.*]], [[VECTOR_BODY]] ]
; CHECK-NEXT: [[ACTIVE_LANE_MASK:%.*]] = call <4 x i1> @llvm.get.active.lane.mask.v4i1.i32(i32 [[INDEX]], i32 [[N]])
-; CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds i32, ptr [[X:%.*]], i32 [[INDEX]]
-; CHECK-NEXT: [[WIDE_MASKED_LOAD:%.*]] = call <4 x i32> @llvm.masked.load.v4i32.p0(ptr [[TMP0]], i32 4, <4 x i1> [[ACTIVE_LANE_MASK]], <4 x i32> poison)
; CHECK-NEXT: [[TMP1:%.*]] = getelementptr inbounds i32, ptr [[Y:%.*]], i32 [[INDEX]]
; CHECK-NEXT: [[WIDE_MASKED_LOAD1:%.*]] = call <4 x i32> @llvm.masked.load.v4i32.p0(ptr [[TMP1]], i32 4, <4 x i1> [[ACTIVE_LANE_MASK]], <4 x i32> poison)
-; CHECK-NEXT: [[TMP2:%.*]] = mul nsw <4 x i32> [[WIDE_MASKED_LOAD1]], [[WIDE_MASKED_LOAD]]
+; CHECK-NEXT: [[TMP7:%.*]] = getelementptr inbounds i32, ptr [[Y1:%.*]], i32 [[INDEX]]
+; CHECK-NEXT: [[WIDE_MASKED_LOAD2:%.*]] = call <4 x i32> @llvm.masked.load.v4i32.p0(ptr [[TMP7]], i32 4, <4 x i1> [[ACTIVE_LANE_MASK]], <4 x i32> poison)
+; CHECK-NEXT: [[TMP2:%.*]] = mul nsw <4 x i32> [[WIDE_MASKED_LOAD2]], [[WIDE_MASKED_LOAD1]]
; CHECK-NEXT: [[TMP3:%.*]] = select <4 x i1> [[ACTIVE_LANE_MASK]], <4 x i32> [[TMP2]], <4 x i32> zeroinitializer
; CHECK-NEXT: [[TMP4:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP3]])
; CHECK-NEXT: [[TMP5]] = add i32 [[TMP4]], [[VEC_PHI]]
@@ -855,10 +855,10 @@ define i32 @mla_i16_i32(ptr nocapture readonly %x, ptr nocapture readonly %y, i3
; CHECK-NEXT: [[ACTIVE_LANE_MASK:%.*]] = call <8 x i1> @llvm.get.active.lane.mask.v8i1.i32(i32 [[INDEX]], i32 [[N]])
; CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds i16, ptr [[X:%.*]], i32 [[INDEX]]
; CHECK-NEXT: [[WIDE_MASKED_LOAD:%.*]] = call <8 x i16> @llvm.masked.load.v8i16.p0(ptr [[TMP0]], i32 2, <8 x i1> [[ACTIVE_LANE_MASK]], <8 x i16> poison)
-; CHECK-NEXT: [[TMP1:%.*]] = sext <8 x i16> [[WIDE_MASKED_LOAD]] to <8 x i32>
; CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds i16, ptr [[Y:%.*]], i32 [[INDEX]]
; CHECK-NEXT: [[WIDE_MASKED_LOAD1:%.*]] = call <8 x i16> @llvm.masked.load.v8i16.p0(ptr [[TMP2]], i32 2, <8 x i1> [[ACTIVE_LANE_MASK]], <8 x i16> poison)
; CHECK-NEXT: [[TMP3:%.*]] = sext <8 x i16> [[WIDE_MASKED_LOAD1]] to <8 x i32>
+; CHECK-NEXT: [[TMP1:%.*]] = sext <8 x i16> [[WIDE_MASKED_LOAD]] to <8 x i32>
; CHECK-NEXT: [[TMP4:%.*]] = mul nsw <8 x i32> [[TMP3]], [[TMP1]]
; CHECK-NEXT: [[TMP5:%.*]] = select <8 x i1> [[ACTIVE_LANE_MASK]], <8 x i32> [[TMP4]], <8 x i32> zeroinitializer
; CHECK-NEXT: [[TMP6:%.*]] = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> [[TMP5]])
@@ -910,10 +910,10 @@ define i32 @mla_i8_i32(ptr nocapture readonly %x, ptr nocapture readonly %y, i32
; CHECK-NEXT: [[ACTIVE_LANE_MASK:%.*]] = call <16 x i1> @llvm.get.active.lane.mask.v16i1.i32(i32 [[INDEX]], i32 [[N]])
; CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds i8, ptr [[X:%.*]], i32 [[INDEX]]
; CHECK-NEXT: [[WIDE_MASKED_LOAD:%.*]] = call <16 x i8> @llvm.masked.load.v16i8.p0(ptr [[TMP0]], i32 1, <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i8> poison)
-; CHECK-NEXT: [[TMP1:%.*]] = zext <16 x i8> [[WIDE_MASKED_LOAD]] to <16 x i32>
; CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds i8, ptr [[Y:%.*]], i32 [[INDEX]]
; CHECK-NEXT: [[WIDE_MASKED_LOAD1:%.*]] = call <16 x i8> @llvm.masked.load.v16i8.p0(ptr [[TMP2]], i32 1, <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i8> poison)
; CHECK-NEXT: [[TMP3:%.*]] = zext <16 x i8> [[WIDE_MASKED_LOAD1]] to <16 x i32>
+; CHECK-NEXT: [[TMP1:%.*]] = zext <16 x i8> [[WIDE_MASKED_LOAD]] to <16 x i32>
; CHECK-NEXT: [[TMP4:%.*]] = mul nuw nsw <16 x i32> [[TMP3]], [[TMP1]]
; CHECK-NEXT: [[TMP5:%.*]] = select <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i32> [[TMP4]], <16 x i32> zeroinitializer
; CHECK-NEXT: [[TMP6:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP5]])
@@ -963,11 +963,11 @@ define signext i16 @mla_i16_i16(ptr nocapture readonly %x, ptr nocapture readonl
; CHECK-NEXT: [[INDEX:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
; CHECK-NEXT: [[VEC_PHI:%.*]] = phi i16 [ 0, [[VECTOR_PH]] ], [ [[TMP5:%.*]], [[VECTOR_BODY]] ]
; CHECK-NEXT: [[ACTIVE_LANE_MASK:%.*]] = call <8 x i1> @llvm.get.active.lane.mask.v8i1.i32(i32 [[INDEX]], i32 [[N]])
-; CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds i16, ptr [[X:%.*]], i32 [[INDEX]]
-; CHECK-NEXT: [[WIDE_MASKED_LOAD:%.*]] = call <8 x i16> @llvm.masked.load.v8i16.p0(ptr [[TMP0]], i32 2, <8 x i1> [[ACTIVE_LANE_MASK]], <8 x i16> poison)
; CHECK-NEXT: [[TMP1:%.*]] = getelementptr inbounds i16, ptr [[Y:%.*]], i32 [[INDEX]]
; CHECK-NEXT: [[WIDE_MASKED_LOAD1:%.*]] = call <8 x i16> @llvm.masked.load.v8i16.p0(ptr [[TMP1]], i32 2, <8 x i1> [[ACTIVE_LANE_MASK]], <8 x i16> poison)
-; CHECK-NEXT: [[TMP2:%.*]] = mul <8 x i16> [[WIDE_MASKED_LOAD1]], [[WIDE_MASKED_LOAD]]
+; CHECK-NEXT: [[TMP7:%.*]] = getelementptr inbounds i16, ptr [[Y1:%.*]], i32 [[INDEX]]
+; CHECK-NEXT: [[WIDE_MASKED_LOAD2:%.*]] = call <8 x i16> @llvm.masked.load.v8i16.p0(ptr [[TMP7]], i32 2, <8 x i1> [[ACTIVE_LANE_MASK]], <8 x i16> poison)
+; CHECK-NEXT: [[TMP2:%.*]] = mul <8 x i16> [[WIDE_MASKED_LOAD2]], [[WIDE_MASKED_LOAD1]]
; CHECK-NEXT: [[TMP3:%.*]] = select <8 x i1> [[ACTIVE_LANE_MASK]], <8 x i16> [[TMP2]], <8 x i16> zeroinitializer
; CHECK-NEXT: [[TMP4:%.*]] = call i16 @llvm.vector.reduce.add.v8i16(<8 x i16> [[TMP3]])
; CHECK-NEXT: [[TMP5]] = add i16 [[TMP4]], [[VEC_PHI]]
@@ -1016,10 +1016,10 @@ define signext i16 @mla_i8_i16(ptr nocapture readonly %x, ptr nocapture readonly
; CHECK-NEXT: [[ACTIVE_LANE_MASK:%.*]] = call <16 x i1> @llvm.get.active.lane.mask.v16i1.i32(i32 [[INDEX]], i32 [[N]])
; CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds i8, ptr [[X:%.*]], i32 [[INDEX]]
; CHECK-NEXT: [[WIDE_MASKED_LOAD:%.*]] = call <16 x i8> @llvm.masked.load.v16i8.p0(ptr [[TMP0]], i32 1, <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i8> poison)
-; CHECK-NEXT: [[TMP1:%.*]] = zext <16 x i8> [[WIDE_MASKED_LOAD]] to <16 x i16>
; CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds i8, ptr [[Y:%.*]], i32 [[INDEX]]
; CHECK-NEXT: [[WIDE_MASKED_LOAD1:%.*]] = call <16 x i8> @llvm.masked.load.v16i8.p0(ptr [[TMP2]], i32 1, <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i8> poison)
; CHECK-NEXT: [[TMP3:%.*]] = zext <16 x i8> [[WIDE_MASKED_LOAD1]] to <16 x i16>
+; CHECK-NEXT: [[TMP1:%.*]] = zext <16 x i8> [[WIDE_MASKED_LOAD]] to <16 x i16>
; CHECK-NEXT: [[TMP4:%.*]] = mul nuw <16 x i16> [[TMP3]], [[TMP1]]
; CHECK-NEXT: [[TMP5:%.*]] = select <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i16> [[TMP4]], <16 x i16> zeroinitializer
; CHECK-NEXT: [[TMP6:%.*]] = call i16 @llvm.vector.reduce.add.v16i16(<16 x i16> [[TMP5]])
@@ -1069,11 +1069,11 @@ define zeroext i8 @mla_i8_i8(ptr nocapture readonly %x, ptr nocapture readonly %
; CHECK-NEXT: [[INDEX:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
; CHECK-NEXT: [[VEC_PHI:%.*]] = phi i8 [ 0, [[VECTOR_PH]] ], [ [[TMP5:%.*]], [[VECTOR_BODY]] ]
; CHECK-NEXT: [[ACTIVE_LANE_MASK:%.*]] = call <16 x i1> @llvm.get.active.lane.mask.v16i1.i32(i32 [[INDEX]], i32 [[N]])
-; CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds i8, ptr [[X:%.*]], i32 [[INDEX]]
-; CHECK-NEXT: [[WIDE_MASKED_LOAD:%.*]] = call <16 x i8> @llvm.masked.load.v16i8.p0(ptr [[TMP0]], i32 1, <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i8> poison)
; CHECK-NEXT: [[TMP1:%.*]] = getelementptr inbounds i8, ptr [[Y:%.*]], i32 [[INDEX]]
; CHECK-NEXT: [[WIDE_MASKED_LOAD1:%.*]] = call <16 x i8> @llvm.masked.load.v16i8.p0(ptr [[TMP1]], i32 1, <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i8> poison)
-; CHECK-NEXT: [[TMP2:%.*]] = mul <16 x i8> [[WIDE_MASKED_LOAD1]], [[WIDE_MASKED_LOAD]]
+; CHECK-NEXT: [[TMP7:%.*]] = getelementptr inbounds i8, ptr [[Y1:%.*]], i32 [[INDEX]]
+; CHECK-NEXT: [[WIDE_MASKED_LOAD2:%.*]] = call <16 x i8> @llvm.masked.load.v16i8.p0(ptr [[TMP7]], i32 1, <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i8> poison)
+; CHECK-NEXT: [[TMP2:%.*]] = mul <16 x i8> [[WIDE_MASKED_LOAD2]], [[WIDE_MASKED_LOAD1]]
; CHECK-NEXT: [[TMP3:%.*]] = select <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i8> [[TMP2]], <16 x i8> zeroinitializer
; CHECK-NEXT: [[TMP4:%.*]] = call i8 @llvm.vector.reduce.add.v16i8(<16 x i8> [[TMP3]])
; CHECK-NEXT: [[TMP5]] = add i8 [[TMP4]], [[VEC_PHI]]
@@ -1122,10 +1122,10 @@ define i32 @red_mla_ext_s8_s16_s32(ptr noalias nocapture readonly %A, ptr noalia
; CHECK-NEXT: [[ACTIVE_LANE_MASK:%.*]] = call <8 x i1> @llvm.get.active.lane.mask.v8i1.i32(i32 [[INDEX]], i32 [[N]])
; CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds i8, ptr [[A:%.*]], i32 [[INDEX]]
; CHECK-NEXT: [[WIDE_MASKED_LOAD:%.*]] = call <8 x i8> @llvm.masked.load.v8i8.p0(ptr [[TMP0]], i32 1, <8 x i1> [[ACTIVE_LANE_MASK]], <8 x i8> poison)
-; CHECK-NEXT: [[TMP1:%.*]] = sext <8 x i8> [[WIDE_MASKED_LOAD]] to <8 x i32>
; CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds i16, ptr [[B:%.*]], i32 [[INDEX]]
; CHECK-NEXT: [[WIDE_MASKED_LOAD1:%.*]] = call <8 x i16> @llvm.masked.load.v8i16.p0(ptr [[TMP2]], i32 2, <8 x i1> [[ACTIVE_LANE_MASK]], <8 x i16> poison)
; CHECK-NEXT: [[TMP3:%.*]] = sext <8 x i16> [[WIDE_MASKED_LOAD1]] to <8 x i32>
+; CHECK-NEXT: [[TMP1:%.*]] = sext <8 x i8> [[WIDE_MASKED_LOAD]] to <8 x i32>
; CHECK-NEXT: [[TMP4:%.*]] = mul nsw <8 x i32> [[TMP3]], [[TMP1]]
; CHECK-NEXT: [[TMP5:%.*]] = select <8 x i1> [[ACTIVE_LANE_MASK]], <8 x i32> [[TMP4]], <8 x i32> zeroinitializer
; CHECK-NEXT: [[TMP6:%.*]] = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> [[TMP5]])
@@ -1183,11 +1183,11 @@ define i64 @red_mla_ext_s16_u16_s64(ptr noalias nocapture readonly %A, ptr noali
; CHECK: vector.body:
; CHECK-NEXT: [[INDEX:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
; CHECK-NEXT: [[VEC_PHI:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[TMP7:%.*]], [[VECTOR_BODY]] ]
-; CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds i16, ptr [[A:%.*]], i32 [[INDEX]]
-; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <4 x i16>, ptr [[TMP0]], align 1
-; CHECK-NEXT: [[TMP1:%.*]] = sext <4 x i16> [[WIDE_LOAD]] to <4 x i32>
; CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds i16, ptr [[B:%.*]], i32 [[INDEX]]
-; CHECK-NEXT: [[WIDE_LOAD1:%.*]] = load <4 x i16>, ptr [[TMP2]], align 2
+; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <4 x i16>, ptr [[TMP2]], align 1
+; CHECK-NEXT: [[TMP1:%.*]] = sext <4 x i16> [[WIDE_LOAD]] to <4 x i32>
+; CHECK-NEXT: [[TMP11:%.*]] = getelementptr inbounds i16, ptr [[B1:%.*]], i32 [[INDEX]]
+; CHECK-NEXT: [[WIDE_LOAD1:%.*]] = load <4 x i16>, ptr [[TMP11]], align 2
; CHECK-NEXT: [[TMP3:%.*]] = zext <4 x i16> [[WIDE_LOAD1]] to <4 x i32>
; CHECK-NEXT: [[TMP4:%.*]] = mul nsw <4 x i32> [[TMP3]], [[TMP1]]
; CHECK-NEXT: [[TMP5:%.*]] = zext <4 x i32> [[TMP4]] to <4 x i64>
@@ -1206,10 +1206,10 @@ define i64 @red_mla_ext_s16_u16_s64(ptr noalias nocapture readonly %A, ptr noali
; CHECK: for.body:
; CHECK-NEXT: [[I_011:%.*]] = phi i32 [ [[INC:%.*]], [[FOR_BODY]] ], [ [[BC_RESUME_VAL]], [[SCALAR_PH]] ]
; CHECK-NEXT: [[S_010:%.*]] = phi i64 [ [[ADD:%.*]], [[FOR_BODY]] ], [ [[BC_MERGE_RDX]], [[SCALAR_PH]] ]
-; CHECK-NEXT: [[ARRAYIDX:%.*]] = getelementptr inbounds i16, ptr [[A]], i32 [[I_011]]
+; CHECK-NEXT: [[ARRAYIDX:%.*]] = getelementptr inbounds i16, ptr [[B]], i32 [[I_011]]
; CHECK-NEXT: [[TMP9:%.*]] = load i16, ptr [[ARRAYIDX]], align 1
; CHECK-NEXT: [[CONV:%.*]] = sext i16 [[TMP9]] to i32
-; CHECK-NEXT: [[ARRAYIDX1:%.*]] = getelementptr inbounds i16, ptr [[B]], i32 [[I_011]]
+; CHECK-NEXT: [[ARRAYIDX1:%.*]] = getelementptr inbounds i16, ptr [[B1]], i32 [[I_011]]
; CHECK-NEXT: [[TMP10:%.*]] = load i16, ptr [[ARRAYIDX1]], align 2
; CHECK-NEXT: [[CONV2:%.*]] = zext i16 [[TMP10]] to i32
; CHECK-NEXT: [[MUL:%.*]] = mul nsw i32 [[CONV2]], [[CONV]]
@@ -1268,12 +1268,12 @@ define i32 @red_mla_u8_s8_u32(ptr noalias nocapture readonly %A, ptr noalias noc
; CHECK-NEXT: [[INDEX:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
; CHECK-NEXT: [[VEC_PHI:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[TMP7:%.*]], [[VECTOR_BODY]] ]
; CHECK-NEXT: [[ACTIVE_LANE_MASK:%.*]] = call <4 x i1> @llvm.get.active.lane.mask.v4i1.i32(i32 [[INDEX]], i32 [[N]])
-; CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds i8, ptr [[A:%.*]], i32 [[INDEX]]
-; CHECK-NEXT: [[WIDE_MASKED_LOAD:%.*]] = call <4 x i8> @llvm.masked.load.v4i8.p0(ptr [[TMP0]], i32 1, <4 x i1> [[ACTIVE_LANE_MASK]], <4 x i8> poison)
-; CHECK-NEXT: [[TMP1:%.*]] = zext <4 x i8> [[WIDE_MASKED_LOAD]] to <4 x i32>
; CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds i8, ptr [[B:%.*]], i32 [[INDEX]]
; CHECK-NEXT: [[WIDE_MASKED_LOAD1:%.*]] = call <4 x i8> @llvm.masked.load.v4i8.p0(ptr [[TMP2]], i32 1, <4 x i1> [[ACTIVE_LANE_MASK]], <4 x i8> poison)
-; CHECK-NEXT: [[TMP3:%.*]] = sext <4 x i8> [[WIDE_MASKED_LOAD1]] to <4 x i32>
+; CHECK-NEXT: [[TMP1:%.*]] = zext <4 x i8> [[WIDE_MASKED_LOAD1]] to <4 x i32>
+; CHECK-NEXT: [[TMP9:%.*]] = getelementptr inbounds i8, ptr [[B1:%.*]], i32 [[INDEX]]
+; CHECK-NEXT: [[WIDE_MASKED_LOAD2:%.*]] = call <4 x i8> @llvm.masked.load.v4i8.p0(ptr [[TMP9]], i32 1, <4 x i1> [[ACTIVE_LANE_MASK]], <4 x i8> poison)
+; CHECK-NEXT: [[TMP3:%.*]] = sext <4 x i8> [[WIDE_MASKED_LOAD2]] to <4 x i32>
; CHECK-NEXT: [[TMP4:%.*]] = mul nsw <4 x i32> [[TMP3]], [[TMP1]]
; CHECK-NEXT: [[TMP5:%.*]] = select <4 x i1> [[ACTIVE_LANE_MASK]], <4 x i32> [[TMP4]], <4 x i32> zeroinitializer
; CHECK-NEXT: [[TMP6:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP5]])
@@ -1413,7 +1413,8 @@ define i32 @mla_i8_i32_multiuse(ptr nocapture readonly %x, ptr nocapture readonl
; CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds i8, ptr [[X:%.*]], i32 [[INDEX]]
; CHECK-NEXT: [[WIDE_MASKED_LOAD:%.*]] = call <16 x i8> @llvm.masked.load.v16i8.p0(ptr [[TMP0]], i32 1, <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i8> poison)
; CHECK-NEXT: [[TMP1:%.*]] = zext <16 x i8> [[WIDE_MASKED_LOAD]] to <16 x i32>
-; CHECK-NEXT: [[TMP2:%.*]] = mul nuw nsw <16 x i32> [[TMP1]], [[TMP1]]
+; CHECK-NEXT: [[TMP7:%.*]] = zext <16 x i8> [[WIDE_MASKED_LOAD]] to <16 x i32>
+; CHECK-NEXT: [[TMP2:%.*]] = mul nuw nsw <16 x i32> [[TMP1]], [[TMP7]]
; CHECK-NEXT: [[TMP3:%.*]] = select <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i32> [[TMP2]], <16 x i32> zeroinitializer
; CHECK-NEXT: [[TMP4:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP3]])
; CHECK-NEXT: [[TMP5]] = add i32 [[TMP4]], [[VEC_PHI]]
diff --git a/llvm/test/Transforms/LoopVectorize/RISCV/inloop-reduction.ll b/llvm/test/Transforms/LoopVectorize/RISCV/inloop-reduction.ll
index 8ca2bd1f286ae3..9f1a61ebb5efef 100644
--- a/llvm/test/Transforms/LoopVectorize/RISCV/inloop-reduction.ll
+++ b/llvm/test/Transforms/LoopVectorize/RISCV/inloop-reduction.ll
@@ -187,23 +187,33 @@ define i32 @add_i16_i32(ptr nocapture readonly %x, i32 %n) {
; IF-EVL-INLOOP-NEXT: [[N_RND_UP:%.*]] = add i32 [[N]], [[TMP2]]
; IF-EVL-INLOOP-NEXT: [[N_MOD_VF:%.*]] = urem i32 [[N_RND_UP]], [[TMP1]]
; IF-EVL-INLOOP-NEXT: [[N_VEC:%.*]] = sub i32 [[N_RND_UP]], [[N_MOD_VF]]
+; IF-EVL-INLOOP-NEXT: [[TRIP_COUNT_MINUS_1:%.*]] = sub i32 [[N]], 1
; IF-EVL-INLOOP-NEXT: [[TMP3:%.*]] = call i32 @llvm.vscale.i32()
; IF-EVL-INLOOP-NEXT: [[TMP4:%.*]] = mul i32 [[TMP3]], 8
+; IF-EVL-INLOOP-NEXT: [[BROADCAST_SPLATINSERT1:%.*]] = insertelement <vscale x 8 x i32> poison, i32 [[TRIP_COUNT_MINUS_1]], i64 0
+; IF-EVL-INLOOP-NEXT: [[BROADCAST_SPLAT2:%.*]] = shufflevector <vscale x 8 x i32> [[BROADCAST_SPLATINSERT1]], <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer
; IF-EVL-INLOOP-NEXT: br label [[VECTOR_BODY:%.*]]
; IF-EVL-INLOOP: vector.body:
; IF-EVL-INLOOP-NEXT: [[INDEX:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
; IF-EVL-INLOOP-NEXT: [[EVL_BASED_IV:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[INDEX_EVL_NEXT:%.*]], [[VECTOR_BODY]] ]
-; IF-EVL-INLOOP-NEXT: [[VEC_PHI:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[TMP11:%.*]], [[VECTOR_BODY]] ]
-; IF-EVL-INLOOP-NEXT: [[AVL:%.*]] = sub i32 [[N]], [[EVL_BASED_IV]]
-; IF-EVL-INLOOP-NEXT: [[TMP5:%.*]] = call i32 @llvm.experimental.get.vector.length.i32(i32 [[AVL]], i32 8, i1 true)
-; IF-EVL-INLOOP-NEXT: [[TMP6:%.*]] = add i32 [[EVL_BASED_IV]], 0
-; IF-EVL-INLOOP-NEXT: [[TMP7:%.*]] = getelementptr inbounds i16, ptr [[X:%.*]], i32 [[TMP6]]
-; IF-EVL-INLOOP-NEXT: [[TMP8:%.*]] = getelementptr inbounds i16, ptr [[TMP7]], i32 0
-; IF-EVL-INLOOP-NEXT: [[VP_OP_LOAD:%.*]] = call <vscale x 8 x i16> @llvm.vp.load.nxv8i16.p0(ptr align 2 [[TMP8]], <vscale x 8 x i1> shufflevector (<vscale x 8 x i1> insertelement (<vscale x 8 x i1> poison, i1 true, i64 0), <vscale x 8 x i1> poison, <vscale x 8 x i32> zeroinitializer), i32 [[TMP5]])
-; IF-EVL-INLOOP-NEXT: [[TMP9:%.*]] = sext <vscale x 8 x i16> [[VP_OP_LOAD]] to <vscale x 8 x i32>
-; IF-EVL-INLOOP-NEXT: [[TMP10:%.*]] = call i32 @llvm.vp.reduce.add.nxv8i32(i32 0, <vscale x 8 x i32> [[TMP9]], <vscale x 8 x i1> shufflevector (<vscale x 8 x i1> insertelement (<vscale x 8 x i1> poison, i1 true, i64 0), <vscale x 8 x i1> poison, <vscale x 8 x i32> zeroinitializer), i32 [[TMP5]])
-; IF-EVL-INLOOP-NEXT: [[TMP11]] = add i32 [[TMP10]], [[VEC_PHI]]
-; IF-EVL-INLOOP-NEXT: [[INDEX_EVL_NEXT]] = add i32 [[TMP5]], [[EVL_BASED_IV]]
+; IF-EVL-INLOOP-NEXT: [[VEC_PHI:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[TMP12:%.*]], [[VECTOR_BODY]] ]
+; IF-EVL-INLOOP-NEXT: [[TMP5:%.*]] = sub i32 [[N]], [[EVL_BASED_IV]]
+; IF-EVL-INLOOP-NEXT: [[TMP6:%.*]] = call i32 @llvm.experimental.get.vector.length.i32(i32 [[TMP5]], i32 8, i1 true)
+; IF-EVL-INLOOP-NEXT: [[TMP7:%.*]] = add i32 [[EVL_BASED_IV]], 0
+; IF-EVL-INLOOP-NEXT: [[BROADCAST_SPLATINSERT:%.*]] = insertelement <vscale x 8 x i32> poison, i32 [[EVL_BASED_IV]], i64 0
+; IF-EVL-INLOOP-NEXT: [[BROADCAST_SPLAT:%.*]] = shufflevector <vscale x 8 x i32> [[BROADCAST_SPLATINSERT]], <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer
+; IF-EVL-INLOOP-NEXT: [[TMP15:%.*]] = call <vscale x 8 x i32> @llvm.stepvector.nxv8i32()
+; IF-EVL-INLOOP-NEXT: [[TMP16:%.*]] = add <vscale x 8 x i32> zeroinitializer, [[TMP15]]
+; IF-EVL-INLOOP-NEXT: [[VEC_IV:%.*]] = add <vscale x 8 x i32> [[BROADCAST_SPLAT]], [[TMP16]]
+; IF-EVL-INLOOP-NEXT: [[TMP17:%.*]] = icmp ule <vscale x 8 x i32> [[VEC_IV]], [[BROADCAST_SPLAT2]]
+; IF-EVL-INLOOP-NEXT: [[TMP8:%.*]] = getelementptr inbounds i16, ptr [[X:%.*]], i32 [[TMP7]]
+; IF-EVL-INLOOP-NEXT: [[TMP9:%.*]] = getelementptr inbounds i16, ptr [[TMP8]], i32 0
+; IF-EVL-INLOOP-NEXT: [[VP_OP_LOAD:%.*]] = call <vscale x 8 x i16> @llvm.vp.load.nxv8i16.p0(ptr align 2 [[TMP9]], <vscale x 8 x i1> shufflevector (<vscale x 8 x i1> insertelement (<vscale x 8 x i1> poison, i1 true, i64 0), <vscale x 8 x i1> poison, <vscale x 8 x i32> zeroinitializer), i32 [[TMP6]])
+; IF-EVL-INLOOP-NEXT: [[TMP10:%.*]] = sext <vscale x 8 x i16> [[VP_OP_LOAD]] to <vscale x 8 x i32>
+; IF-EVL-INLOOP-NEXT: [[TMP18:%.*]] = select <vscale x 8 x i1> [[TMP17]], <vscale x 8 x i32> [[TMP10]], <vscale x 8 x i32> zeroinitializer
+; IF-EVL-INLOOP-NEXT: [[TMP11:%.*]] = call i32 @llvm.vector.reduce.add.nxv8i32(<vscale x 8 x i32> [[TMP18]])
+; IF-EVL-INLOOP-NEXT: [[TMP12]] = add i32 [[TMP11]], [[VEC_PHI]]
+; IF-EVL-INLOOP-NEXT: [[INDEX_EVL_NEXT]] = add i32 [[TMP6]], [[EVL_BASED_IV]]
; IF-EVL-INLOOP-NEXT: [[INDEX_NEXT]] = add i32 [[INDEX]], [[TMP4]]
; IF-EVL-INLOOP-NEXT: [[TMP12:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]
; IF-EVL-INLOOP-NEXT: br i1 [[TMP12]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
diff --git a/llvm/test/Transforms/LoopVectorize/reduction-inloop-pred.ll b/llvm/test/Transforms/LoopVectorize/reduction-inloop-pred.ll
index 6771f561913130..cb025608882e6a 100644
--- a/llvm/test/Transforms/LoopVectorize/reduction-inloop-pred.ll
+++ b/llvm/test/Transforms/LoopVectorize/reduction-inloop-pred.ll
@@ -424,62 +424,62 @@ define i32 @reduction_mix(ptr noalias nocapture %A, ptr noalias nocapture %B) {
; CHECK-NEXT: [[TMP1:%.*]] = extractelement <4 x i1> [[TMP0]], i64 0
; CHECK-NEXT: br i1 [[TMP1]], label [[PRED_LOAD_IF:%.*]], label [[PRED_LOAD_CONTINUE:%.*]]
; CHECK: pred.load.if:
-; CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds i32, ptr [[A:%.*]], i64 [[INDEX]]
-; CHECK-NEXT: [[TMP3:%.*]] = load i32, ptr [[TMP2]], align 4
-; CHECK-NEXT: [[TMP4:%.*]] = insertelement <4 x i32> poison, i32 [[TMP3]], i64 0
; CHECK-NEXT: [[TMP5:%.*]] = getelementptr inbounds i32, ptr [[B:%.*]], i64 [[INDEX]]
; CHECK-NEXT: [[TMP6:%.*]] = load i32, ptr [[TMP5]], align 4
; CHECK-NEXT: [[TMP7:%.*]] = insertelement <4 x i32> poison, i32 [[TMP6]], i64 0
+; CHECK-NEXT: [[TMP8:%.*]] = getelementptr inbounds i32, ptr [[B1:%.*]], i64 [[INDEX]]
+; CHECK-NEXT: [[TMP12:%.*]] = load i32, ptr [[TMP8]], align 4
+; CHECK-NEXT: [[TMP13:%.*]] = insertelement <4 x i32> poison, i32 [[TMP12]], i64 0
; CHECK-NEXT: br label [[PRED_LOAD_CONTINUE]]
; CHECK: pred.load.continue:
-; CHECK-NEXT: [[TMP8:%.*]] = phi <4 x i32> [ poison, [[VECTOR_BODY]] ], [ [[TMP4]], [[PRED_LOAD_IF]] ]
; CHECK-NEXT: [[TMP9:%.*]] = phi <4 x i32> [ poison, [[VECTOR_BODY]] ], [ [[TMP7]], [[PRED_LOAD_IF]] ]
+; CHECK-NEXT: [[TMP14:%.*]] = phi <4 x i32> [ poison, [[VECTOR_BODY]] ], [ [[TMP13]], [[PRED_LOAD_IF]] ]
; CHECK-NEXT: [[TMP10:%.*]] = extractelement <4 x i1> [[TMP0]], i64 1
; CHECK-NEXT: br i1 [[TMP10]], label [[PRED_LOAD_IF3:%.*]], label [[PRED_LOAD_CONTINUE4:%.*]]
; CHECK: pred.load.if3:
; CHECK-NEXT: [[TMP11:%.*]] = or disjoint i64 [[INDEX]], 1
-; CHECK-NEXT: [[TMP12:%.*]] = getelementptr inbounds i32, ptr [[A]], i64 [[TMP11]]
-; CHECK-NEXT: [[TMP13:%.*]] = load i32, ptr [[TMP12]], align 4
-; CHECK-NEXT: [[TMP14:%.*]] = insertelement <4 x i32> [[TMP8]], i32 [[TMP13]], i64 1
; CHECK-NEXT: [[TMP15:%.*]] = getelementptr inbounds i32, ptr [[B]], i64 [[TMP11]]
; CHECK-NEXT: [[TMP16:%.*]] = load i32, ptr [[TMP15]], align 4
; CHECK-NEXT: [[TMP17:%.*]] = insertelement <4 x i32> [[TMP9]], i32 [[TMP16]], i64 1
+; CHECK-NEXT: [[TMP18:%.*]] = getelementptr inbounds i32, ptr [[B1]], i64 [[TMP11]]
+; CHECK-NEXT: [[TMP22:%.*]] = load i32, ptr [[TMP18]], align 4
+; CHECK-NEXT: [[TMP23:%.*]] = insertelement <4 x i32> [[TMP14]], i32 [[TMP22]], i64 1
; CHECK-NEXT: br label [[PRED_LOAD_CONTINUE4]]
; CHECK: pred.load.continue4:
-; CHECK-NEXT: [[TMP18:%.*]] = phi <4 x i32> [ [[TMP8]], [[PRED_LOAD_CONTINUE]] ], [ [[TMP14]], [[PRED_LOAD_IF3]] ]
; CHECK-NEXT: [[TMP19:%.*]] = phi <4 x i32> [ [[TMP9]], [[PRED_LOAD_CONTINUE]] ], [ [[TMP17]], [[PRED_LOAD_IF3]] ]
+; CHECK-NEXT: [[TMP24:%.*]] = phi <4 x i32> [ [[TMP14]], [[PRED_LOAD_CONTINUE]] ], [ [[TMP23]], [[PRED_LOAD_IF3]] ]
; CHECK-NEXT: [[TMP20:%.*]] = extractelement <4 x i1> [[TMP0]], i64 2
; CHECK-NEXT: br i1 [[TMP20]], label [[PRED_LOAD_IF5:%.*]], label [[PRED_LOAD_CONTINUE6:%.*]]
; CHECK: pred.load.if5:
; CHECK-NEXT: [[TMP21:%.*]] = or disjoint i64 [[INDEX]], 2
-; CHECK-NEXT: [[TMP22:%.*]] = getelementptr inbounds i32, ptr [[A]], i64 [[TMP21]]
-; CHECK-NEXT: [[TMP23:%.*]] = load i32, ptr [[TMP22]], align 4
-; CHECK-NEXT: [[TMP24:%.*]] = insertelement <4 x i32> [[TMP18]], i32 [[TMP23]], i64 2
; CHECK-NEXT: [[TMP25:%.*]] = getelementptr inbounds i32, ptr [[B]], i64 [[TMP21]]
; CHECK-NEXT: [[TMP26:%.*]] = load i32, ptr [[TMP25]], align 4
; CHECK-NEXT: [[TMP27:%.*]] = insertelement <4 x i32> [[TMP19]], i32 [[TMP26]], i64 2
+; CHECK-NEXT: [[TMP28:%.*]] = getelementptr inbounds i32, ptr [[B1]], i64 [[TMP21]]
+; CHECK-NEXT: [[TMP32:%.*]] = load i32, ptr [[TMP28]], align 4
+; CHECK-NEXT: [[TMP33:%.*]] = insertelement <4 x i32> [[TMP24]], i32 [[TMP32]], i64 2
; CHECK-NEXT: br label [[PRED_LOAD_CONTINUE6]]
; CHECK: pred.load.continue6:
-; CHECK-NEXT: [[TMP28:%.*]] = phi <4 x i32> [ [[TMP18]], [[PRED_LOAD_CONTINUE4]] ], [ [[TMP24]], [[PRED_LOAD_IF5]] ]
; CHECK-NEXT: [[TMP29:%.*]] = phi <4 x i32> [ [[TMP19]], [[PRED_LOAD_CONTINUE4]] ], [ [[TMP27]], [[PRED_LOAD_IF5]] ]
+; CHECK-NEXT: [[TMP34:%.*]] = phi <4 x i32> [ [[TMP24]], [[PRED_LOAD_CONTINUE4]] ], [ [[TMP33]], [[PRED_LOAD_IF5]] ]
; CHECK-NEXT: [[TMP30:%.*]] = extractelement <4 x i1> [[TMP0]], i64 3
; CHECK-NEXT: br i1 [[TMP30]], label [[PRED_LOAD_IF7:%.*]], label [[PRED_LOAD_CONTINUE8]]
; CHECK: pred.load.if7:
; CHECK-NEXT: [[TMP31:%.*]] = or disjoint i64 [[INDEX]], 3
-; CHECK-NEXT: [[TMP32:%.*]] = getelementptr inbounds i32, ptr [[A]], i64 [[TMP31]]
-; CHECK-NEXT: [[TMP33:%.*]] = load i32, ptr [[TMP32]], align 4
-; CHECK-NEXT: [[TMP34:%.*]] = insertelement <4 x i32> [[TMP28]], i32 [[TMP33]], i64 3
; CHECK-NEXT: [[TMP35:%.*]] = getelementptr inbounds i32, ptr [[B]], i64 [[TMP31]]
; CHECK-NEXT: [[TMP36:%.*]] = load i32, ptr [[TMP35]], align 4
; CHECK-NEXT: [[TMP37:%.*]] = insertelement <4 x i32> [[TMP29]], i32 [[TMP36]], i64 3
+; CHECK-NEXT: [[TMP38:%.*]] = getelementptr inbounds i32, ptr [[B1]], i64 [[TMP31]]
+; CHECK-NEXT: [[TMP48:%.*]] = load i32, ptr [[TMP38]], align 4
+; CHECK-NEXT: [[TMP49:%.*]] = insertelement <4 x i32> [[TMP34]], i32 [[TMP48]], i64 3
; CHECK-NEXT: br label [[PRED_LOAD_CONTINUE8]]
; CHECK: pred.load.continue8:
-; CHECK-NEXT: [[TMP38:%.*]] = phi <4 x i32> [ [[TMP28]], [[PRED_LOAD_CONTINUE6]] ], [ [[TMP34]], [[PRED_LOAD_IF7]] ]
; CHECK-NEXT: [[TMP39:%.*]] = phi <4 x i32> [ [[TMP29]], [[PRED_LOAD_CONTINUE6]] ], [ [[TMP37]], [[PRED_LOAD_IF7]] ]
-; CHECK-NEXT: [[TMP40:%.*]] = mul nsw <4 x i32> [[TMP39]], [[TMP38]]
+; CHECK-NEXT: [[TMP50:%.*]] = phi <4 x i32> [ [[TMP34]], [[PRED_LOAD_CONTINUE6]] ], [ [[TMP49]], [[PRED_LOAD_IF7]] ]
; CHECK-NEXT: [[TMP41:%.*]] = select <4 x i1> [[TMP0]], <4 x i32> [[VEC_IND1]], <4 x i32> zeroinitializer
; CHECK-NEXT: [[TMP42:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP41]])
; CHECK-NEXT: [[TMP43:%.*]] = add i32 [[TMP42]], [[VEC_PHI]]
+; CHECK-NEXT: [[TMP40:%.*]] = mul nsw <4 x i32> [[TMP50]], [[TMP39]]
; CHECK-NEXT: [[TMP44:%.*]] = select <4 x i1> [[TMP0]], <4 x i32> [[TMP40]], <4 x i32> zeroinitializer
; CHECK-NEXT: [[TMP45:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP44]])
; CHECK-NEXT: [[TMP46]] = add i32 [[TMP45]], [[TMP43]]
diff --git a/llvm/test/Transforms/LoopVectorize/reduction-inloop.ll b/llvm/test/Transforms/LoopVectorize/reduction-inloop.ll
index fe74a7c3a9b27c..b578e61d85dfa1 100644
--- a/llvm/test/Transforms/LoopVectorize/reduction-inloop.ll
+++ b/llvm/test/Transforms/LoopVectorize/reduction-inloop.ll
@@ -221,13 +221,13 @@ define i32 @reduction_mix(ptr noalias nocapture %A, ptr noalias nocapture %B) {
; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
; CHECK-NEXT: [[VEC_PHI:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[TMP6:%.*]], [[VECTOR_BODY]] ]
; CHECK-NEXT: [[VEC_IND:%.*]] = phi <4 x i32> [ <i32 0, i32 1, i32 2, i32 3>, [[VECTOR_PH]] ], [ [[VEC_IND_NEXT:%.*]], [[VECTOR_BODY]] ]
-; CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds i32, ptr [[A:%.*]], i64 [[INDEX]]
-; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <4 x i32>, ptr [[TMP0]], align 4
; CHECK-NEXT: [[TMP1:%.*]] = getelementptr inbounds i32, ptr [[B:%.*]], i64 [[INDEX]]
; CHECK-NEXT: [[WIDE_LOAD1:%.*]] = load <4 x i32>, ptr [[TMP1]], align 4
-; CHECK-NEXT: [[TMP2:%.*]] = mul nsw <4 x i32> [[WIDE_LOAD1]], [[WIDE_LOAD]]
+; CHECK-NEXT: [[TMP8:%.*]] = getelementptr inbounds i32, ptr [[B1:%.*]], i64 [[INDEX]]
+; CHECK-NEXT: [[WIDE_LOAD2:%.*]] = load <4 x i32>, ptr [[TMP8]], align 4
; CHECK-NEXT: [[TMP3:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[VEC_IND]])
; CHECK-NEXT: [[TMP4:%.*]] = add i32 [[TMP3]], [[VEC_PHI]]
+; CHECK-NEXT: [[TMP2:%.*]] = mul nsw <4 x i32> [[WIDE_LOAD2]], [[WIDE_LOAD1]]
; CHECK-NEXT: [[TMP5:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP2]])
; CHECK-NEXT: [[TMP6]] = add i32 [[TMP5]], [[TMP4]]
; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 4
@@ -329,11 +329,11 @@ define i32 @start_at_non_zero(ptr nocapture %in, ptr nocapture %coeff, ptr nocap
; CHECK: vector.body:
; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
; CHECK-NEXT: [[VEC_PHI:%.*]] = phi i32 [ 120, [[VECTOR_PH]] ], [ [[TMP4:%.*]], [[VECTOR_BODY]] ]
-; CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds i32, ptr [[IN:%.*]], i64 [[INDEX]]
-; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <4 x i32>, ptr [[TMP0]], align 4
; CHECK-NEXT: [[TMP1:%.*]] = getelementptr inbounds i32, ptr [[COEFF:%.*]], i64 [[INDEX]]
; CHECK-NEXT: [[WIDE_LOAD1:%.*]] = load <4 x i32>, ptr [[TMP1]], align 4
-; CHECK-NEXT: [[TMP2:%.*]] = mul nsw <4 x i32> [[WIDE_LOAD1]], [[WIDE_LOAD]]
+; CHECK-NEXT: [[TMP6:%.*]] = getelementptr inbounds i32, ptr [[COEFF1:%.*]], i64 [[INDEX]]
+; CHECK-NEXT: [[WIDE_LOAD2:%.*]] = load <4 x i32>, ptr [[TMP6]], align 4
+; CHECK-NEXT: [[TMP2:%.*]] = mul nsw <4 x i32> [[WIDE_LOAD2]], [[WIDE_LOAD1]]
; CHECK-NEXT: [[TMP3:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP2]])
; CHECK-NEXT: [[TMP4]] = add i32 [[TMP3]], [[VEC_PHI]]
; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 4
>From b7b2586eedd4a4ec3cb788fafd1ab6a0520218d2 Mon Sep 17 00:00:00 2001
From: Elvis Wang <elvis.wang at sifive.com>
Date: Tue, 5 Nov 2024 17:49:26 -0800
Subject: [PATCH 05/14] Refactors
Use a lambda function to return early once a pattern is matched.
Leave some assertions.
---
.../Transforms/Vectorize/LoopVectorize.cpp | 121 ++++++++---------
llvm/lib/Transforms/Vectorize/VPlan.h | 55 ++++----
.../lib/Transforms/Vectorize/VPlanRecipes.cpp | 123 +-----------------
.../Transforms/Vectorize/VPlanTransforms.cpp | 19 +--
.../LoopVectorize/ARM/mve-reductions.ll | 3 +-
.../LoopVectorize/RISCV/inloop-reduction.ll | 8 +-
6 files changed, 113 insertions(+), 216 deletions(-)
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index ea2e6b96f7711c..ee429f9d3fb97d 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7397,7 +7397,7 @@ static bool planContainsAdditionalSimplifications(VPlan &Plan,
// VPExtendedReductionRecipe contains a folded extend instruction.
if (auto *ExtendedRed = dyn_cast<VPExtendedReductionRecipe>(&R))
SeenInstrs.insert(ExtendedRed->getExtInstr());
- // VPMulAccRecupe constians a mul and otional extend instructions.
+ // VPMulAccRecipe contains a mul and optional extend instructions.
else if (auto *MulAcc = dyn_cast<VPMulAccRecipe>(&R)) {
SeenInstrs.insert(MulAcc->getMulInstr());
if (MulAcc->isExtended()) {
@@ -9390,77 +9390,82 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
if (CM.blockNeedsPredicationForAnyReason(BB))
CondOp = RecipeBuilder.getBlockInMask(BB);
- VPValue *A, *B;
- VPSingleDefRecipe *RedRecipe;
- // reduce.add(mul(ext, ext)) can folded into VPMulAccRecipe
- if (RdxDesc.getOpcode() == Instruction::Add &&
- match(VecOp, m_Mul(m_VPValue(A), m_VPValue(B)))) {
- VPRecipeBase *RecipeA = A->getDefiningRecipe();
- VPRecipeBase *RecipeB = B->getDefiningRecipe();
- if (RecipeA && RecipeB && match(RecipeA, m_ZExtOrSExt(m_VPValue())) &&
- match(RecipeB, m_ZExtOrSExt(m_VPValue())) &&
- cast<VPWidenCastRecipe>(RecipeA)->getOpcode() ==
- cast<VPWidenCastRecipe>(RecipeB)->getOpcode() &&
- !A->hasMoreThanOneUniqueUser() && !B->hasMoreThanOneUniqueUser()) {
- RedRecipe = new VPMulAccRecipe(
- RdxDesc, CurrentLinkI, PreviousLink, CondOp,
- CM.useOrderedReductions(RdxDesc),
- cast<VPWidenRecipe>(VecOp->getDefiningRecipe()),
- cast<VPWidenCastRecipe>(RecipeA),
- cast<VPWidenCastRecipe>(RecipeB));
- } else {
- RedRecipe = new VPMulAccRecipe(
- RdxDesc, CurrentLinkI, PreviousLink, CondOp,
- CM.useOrderedReductions(RdxDesc),
- cast<VPWidenRecipe>(VecOp->getDefiningRecipe()));
- }
- } else if (RdxDesc.getOpcode() == Instruction::Add &&
- match(VecOp,
- m_ZExtOrSExt(m_Mul(m_ZExtOrSExt(m_VPValue(A)),
- m_ZExtOrSExt(m_VPValue(B)))))) {
- VPWidenCastRecipe *Ext =
- dyn_cast<VPWidenCastRecipe>(VecOp->getDefiningRecipe());
- VPWidenRecipe *Mul =
- dyn_cast<VPWidenRecipe>(Ext->getOperand(0)->getDefiningRecipe());
- if (Mul && match(Mul, m_Mul(m_ZExtOrSExt(m_VPValue()),
- m_ZExtOrSExt(m_VPValue())))) {
+ auto TryToMatchMulAcc = [&]() -> VPSingleDefRecipe * {
+ VPValue *A, *B;
+ if (RdxDesc.getOpcode() != Instruction::Add)
+ return nullptr;
+ // reduce.add(mul(ext, ext)) can folded into VPMulAccRecipe
+ if (match(VecOp, m_Mul(m_VPValue(A), m_VPValue(B))) &&
+ !VecOp->hasMoreThanOneUniqueUser()) {
+ VPRecipeBase *RecipeA = A->getDefiningRecipe();
+ VPRecipeBase *RecipeB = B->getDefiningRecipe();
+ if (RecipeA && RecipeB && match(RecipeA, m_ZExtOrSExt(m_VPValue())) &&
+ match(RecipeB, m_ZExtOrSExt(m_VPValue())) &&
+ cast<VPWidenCastRecipe>(RecipeA)->getOpcode() ==
+ cast<VPWidenCastRecipe>(RecipeB)->getOpcode() &&
+ !A->hasMoreThanOneUniqueUser() &&
+ !B->hasMoreThanOneUniqueUser()) {
+ return new VPMulAccRecipe(
+ RdxDesc, CurrentLinkI, PreviousLink, CondOp,
+ CM.useOrderedReductions(RdxDesc),
+ cast<VPWidenRecipe>(VecOp->getDefiningRecipe()),
+ cast<VPWidenCastRecipe>(RecipeA),
+ cast<VPWidenCastRecipe>(RecipeB));
+ } else {
+ // Matched reduce.add(mul(...))
+ return new VPMulAccRecipe(
+ RdxDesc, CurrentLinkI, PreviousLink, CondOp,
+ CM.useOrderedReductions(RdxDesc),
+ cast<VPWidenRecipe>(VecOp->getDefiningRecipe()));
+ }
+ // Matched reduce.add(ext(mul(ext, ext)))
+ // Note that 3 extend instructions must have same opcode.
+ } else if (match(VecOp,
+ m_ZExtOrSExt(m_Mul(m_ZExtOrSExt(m_VPValue()),
+ m_ZExtOrSExt(m_VPValue())))) &&
+ !VecOp->hasMoreThanOneUniqueUser()) {
+ VPWidenCastRecipe *Ext =
+ dyn_cast<VPWidenCastRecipe>(VecOp->getDefiningRecipe());
VPWidenRecipe *Mul =
- cast<VPWidenRecipe>(Ext->getOperand(0)->getDefiningRecipe());
+ dyn_cast<VPWidenRecipe>(Ext->getOperand(0)->getDefiningRecipe());
VPWidenCastRecipe *Ext0 =
cast<VPWidenCastRecipe>(Mul->getOperand(0)->getDefiningRecipe());
VPWidenCastRecipe *Ext1 =
cast<VPWidenCastRecipe>(Mul->getOperand(1)->getDefiningRecipe());
if (Ext->getOpcode() == Ext0->getOpcode() &&
- Ext0->getOpcode() == Ext1->getOpcode()) {
- RedRecipe = new VPMulAccRecipe(
+ Ext0->getOpcode() == Ext1->getOpcode() &&
+ !Mul->hasMoreThanOneUniqueUser() &&
+ !Ext0->hasMoreThanOneUniqueUser() &&
+ !Ext1->hasMoreThanOneUniqueUser()) {
+ return new VPMulAccRecipe(
RdxDesc, CurrentLinkI, PreviousLink, CondOp,
CM.useOrderedReductions(RdxDesc),
cast<VPWidenCastRecipe>(VecOp->getDefiningRecipe()), Mul,
cast<VPWidenCastRecipe>(Ext0), cast<VPWidenCastRecipe>(Ext1));
- } else
- RedRecipe = new VPExtendedReductionRecipe(
- RdxDesc, CurrentLinkI,
- cast<CastInst>(
- cast<VPWidenCastRecipe>(VecOp)->getUnderlyingInstr()),
- PreviousLink, cast<VPWidenCastRecipe>(VecOp)->getOperand(0),
- CondOp, CM.useOrderedReductions(RdxDesc),
- cast<VPWidenCastRecipe>(VecOp)->getResultType());
+ }
}
- }
- // VPWidenCastRecipes can folded into VPReductionRecipe
- else if (match(VecOp, m_ZExtOrSExt(m_VPValue(A))) &&
- !VecOp->hasMoreThanOneUniqueUser()) {
- RedRecipe = new VPExtendedReductionRecipe(
- RdxDesc, CurrentLinkI,
- cast<CastInst>(
- cast<VPWidenCastRecipe>(VecOp)->getUnderlyingInstr()),
- PreviousLink, A, CondOp, CM.useOrderedReductions(RdxDesc),
- cast<VPWidenCastRecipe>(VecOp)->getResultType());
- } else {
+ return nullptr;
+ };
+ auto TryToMatchExtendedReduction = [&]() -> VPSingleDefRecipe * {
+ VPValue *A;
+ if (match(VecOp, m_ZExtOrSExt(m_VPValue(A))) &&
+ !VecOp->hasMoreThanOneUniqueUser()) {
+ return new VPExtendedReductionRecipe(
+ RdxDesc, CurrentLinkI, PreviousLink,
+ cast<VPWidenCastRecipe>(VecOp), CondOp,
+ CM.useOrderedReductions(RdxDesc));
+ }
+ return nullptr;
+ };
+ VPSingleDefRecipe *RedRecipe;
+ if (auto *MulAcc = TryToMatchMulAcc())
+ RedRecipe = MulAcc;
+ else if (auto *ExtendedRed = TryToMatchExtendedReduction())
+ RedRecipe = ExtendedRed;
+ else
RedRecipe =
new VPReductionRecipe(RdxDesc, CurrentLinkI, PreviousLink, VecOp,
CondOp, CM.useOrderedReductions(RdxDesc));
- }
// Append the recipe to the end of the VPBasicBlock because we need to
// ensure that it comes after all of it's inputs, including CondOp.
// Note that this transformation may leave over dead recipes (including
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 0103686be422d6..c7117551bf5dda 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -2670,18 +2670,19 @@ class VPExtendedReductionRecipe : public VPSingleDefRecipe {
bool IsConditional = false;
/// Type after extend.
Type *ResultTy;
+ /// Opcode for the extend instruction.
Instruction::CastOps ExtOp;
- CastInst *CastInstr;
+ CastInst *ExtInstr;
bool IsZExt;
protected:
VPExtendedReductionRecipe(const unsigned char SC,
const RecurrenceDescriptor &R, Instruction *RedI,
- Instruction::CastOps ExtOp, CastInst *CastI,
+ Instruction::CastOps ExtOp, CastInst *ExtI,
ArrayRef<VPValue *> Operands, VPValue *CondOp,
bool IsOrdered, Type *ResultTy)
: VPSingleDefRecipe(SC, Operands, RedI), RdxDesc(R), IsOrdered(IsOrdered),
- ResultTy(ResultTy), ExtOp(ExtOp), CastInstr(CastI) {
+ ResultTy(ResultTy), ExtOp(ExtOp), ExtInstr(ExtI) {
if (CondOp) {
IsConditional = true;
addOperand(CondOp);
@@ -2691,20 +2692,13 @@ class VPExtendedReductionRecipe : public VPSingleDefRecipe {
public:
VPExtendedReductionRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
- CastInst *CastI, VPValue *ChainOp, VPValue *VecOp,
- VPValue *CondOp, bool IsOrdered, Type *ResultTy)
- : VPExtendedReductionRecipe(VPDef::VPExtendedReductionSC, R, RedI,
- CastI->getOpcode(), CastI,
- ArrayRef<VPValue *>({ChainOp, VecOp}), CondOp,
- IsOrdered, ResultTy) {}
-
- VPExtendedReductionRecipe(VPReductionRecipe *Red, VPWidenCastRecipe *Ext)
+ VPValue *ChainOp, VPWidenCastRecipe *Ext,
+ VPValue *CondOp, bool IsOrdered)
: VPExtendedReductionRecipe(
- VPDef::VPExtendedReductionSC, Red->getRecurrenceDescriptor(),
- Red->getUnderlyingInstr(), Ext->getOpcode(),
+ VPDef::VPExtendedReductionSC, R, RedI, Ext->getOpcode(),
cast<CastInst>(Ext->getUnderlyingInstr()),
- ArrayRef<VPValue *>({Red->getChainOp(), Ext->getOperand(0)}),
- Red->getCondOp(), Red->isOrdered(), Ext->getResultType()) {}
+ ArrayRef<VPValue *>({ChainOp, Ext->getOperand(0)}), CondOp,
+ IsOrdered, Ext->getResultType()) {}
~VPExtendedReductionRecipe() override = default;
@@ -2721,7 +2715,6 @@ class VPExtendedReductionRecipe : public VPSingleDefRecipe {
return R && classof(R);
}
- /// Generate the reduction in the loop
void execute(VPTransformState &State) override {
llvm_unreachable("VPExtendedReductionRecipe should be transform to "
"VPExtendedRecipe + VPReductionRecipe before execution.");
@@ -2753,9 +2746,12 @@ class VPExtendedReductionRecipe : public VPSingleDefRecipe {
VPValue *getCondOp() const {
return isConditional() ? getOperand(getNumOperands() - 1) : nullptr;
}
+ /// The Type after extended.
Type *getResultType() const { return ResultTy; };
+ /// The Opcode of extend instruction.
Instruction::CastOps getExtOpcode() const { return ExtOp; };
- CastInst *getExtInstr() const { return CastInstr; };
+ /// The CastInst of the extend instruction.
+ CastInst *getExtInstr() const { return ExtInstr; };
};
/// A recipe to represent inloop MulAccreduction operations, performing a
@@ -2771,16 +2767,17 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
bool IsConditional = false;
/// Type after extend.
Type *ResultType;
- /// reduce.add(ext((mul(Ext(), Ext())))
+ // Note that all extend instruction must have the same opcode in MulAcc.
Instruction::CastOps ExtOp;
+ /// reduce.add(ext(mul(ext0(), ext1())))
Instruction *MulInstr;
CastInst *ExtInstr = nullptr;
- CastInst *Ext0Instr;
- CastInst *Ext1Instr;
+ CastInst *Ext0Instr = nullptr;
+ CastInst *Ext1Instr = nullptr;
+ /// Is this MulAcc recipe contains extend recipes?
bool IsExtended;
- bool IsOuterExtended = false;
protected:
VPMulAccRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
@@ -2794,6 +2791,7 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
ExtInstr(cast_if_present<CastInst>(ExtInstr)),
Ext0Instr(cast<CastInst>(Ext0Instr)),
Ext1Instr(cast<CastInst>(Ext1Instr)) {
+ assert(MulInstr->getOpcode() == Instruction::Mul);
if (CondOp) {
IsConditional = true;
addOperand(CondOp);
@@ -2806,6 +2804,7 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
ArrayRef<VPValue *> Operands, VPValue *CondOp, bool IsOrdered)
: VPSingleDefRecipe(SC, Operands, RedI), RdxDesc(R), IsOrdered(IsOrdered),
MulInstr(MulInstr) {
+ assert(MulInstr->getOpcode() == Instruction::Mul);
if (CondOp) {
IsConditional = true;
addOperand(CondOp);
@@ -2857,13 +2856,12 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
return R && classof(R);
}
- /// Generate the reduction in the loop
void execute(VPTransformState &State) override {
llvm_unreachable("VPMulAccRecipe should transform to VPWidenCastRecipe + "
"VPWidenRecipe + VPReductionRecipe before execution");
}
- /// Return the cost of VPExtendedReductionRecipe.
+ /// Return the cost of VPMulAccRecipe.
InstructionCost computeCost(ElementCount VF,
VPCostContext &Ctx) const override;
@@ -2890,13 +2888,24 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
VPValue *getCondOp() const {
return isConditional() ? getOperand(getNumOperands() - 1) : nullptr;
}
+ /// Return the type after inner extended, which must equal to the type of mul
+ /// instruction. If the ResultType != recurrenceType, than it must have a
+ /// extend recipe after mul recipe.
Type *getResultType() const { return ResultType; };
+ /// The opcode of the extend instructions.
Instruction::CastOps getExtOpcode() const { return ExtOp; };
+ /// The underlying instruction for VPWidenRecipe.
Instruction *getMulInstr() const { return MulInstr; };
+ /// The underlying Instruction for outer VPWidenCastRecipe.
CastInst *getExtInstr() const { return ExtInstr; };
+ /// The underlying Instruction for inner VPWidenCastRecipe.
CastInst *getExt0Instr() const { return Ext0Instr; };
+ /// The underlying Instruction for inner VPWidenCastRecipe.
CastInst *getExt1Instr() const { return Ext1Instr; };
+ /// Return if this MulAcc recipe contains extend instructions.
bool isExtended() const { return IsExtended; };
+ /// Return if the operands of mul instruction come from same extend.
+ bool isSameExtend() const { return Ext0Instr == Ext1Instr; };
};
/// VPReplicateRecipe replicates a given instruction producing multiple scalar
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index d28b2d93ee5630..ab40bed209167d 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -2209,122 +2209,6 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
Opcode, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
}
- /*
- using namespace llvm::VPlanPatternMatch;
- auto GetMulAccReductionCost =
- [&](const VPReductionRecipe *Red) -> InstructionCost {
- VPValue *A, *B;
- InstructionCost InnerExt0Cost = 0;
- InstructionCost InnerExt1Cost = 0;
- InstructionCost ExtCost = 0;
- InstructionCost MulCost = 0;
-
- VectorType *SrcVecTy = VectorTy;
- Type *InnerExt0Ty;
- Type *InnerExt1Ty;
- Type *MaxInnerExtTy;
- bool IsUnsigned = true;
- bool HasOuterExt = false;
-
- auto *Ext = dyn_cast_if_present<VPWidenCastRecipe>(
- Red->getVecOp()->getDefiningRecipe());
- VPRecipeBase *Mul;
- // Try to match outer extend reduce.add(ext(...))
- if (Ext && match(Ext, m_ZExtOrSExt(m_VPValue())) &&
- cast<VPWidenCastRecipe>(Ext)->getNumUsers() == 1) {
- IsUnsigned =
- Ext->getOpcode() == Instruction::CastOps::ZExt ? true : false;
- ExtCost = Ext->computeCost(VF, Ctx);
- Mul = Ext->getOperand(0)->getDefiningRecipe();
- HasOuterExt = true;
- } else {
- Mul = Red->getVecOp()->getDefiningRecipe();
- }
-
- // Match reduce.add(mul())
- if (Mul && match(Mul, m_Mul(m_VPValue(A), m_VPValue(B))) &&
- cast<VPWidenRecipe>(Mul)->getNumUsers() == 1) {
- MulCost = cast<VPWidenRecipe>(Mul)->computeCost(VF, Ctx);
- auto *InnerExt0 =
- dyn_cast_if_present<VPWidenCastRecipe>(A->getDefiningRecipe());
- auto *InnerExt1 =
- dyn_cast_if_present<VPWidenCastRecipe>(B->getDefiningRecipe());
- bool HasInnerExt = false;
- // Try to match inner extends.
- if (InnerExt0 && InnerExt1 &&
- match(InnerExt0, m_ZExtOrSExt(m_VPValue())) &&
- match(InnerExt1, m_ZExtOrSExt(m_VPValue())) &&
- InnerExt0->getOpcode() == InnerExt1->getOpcode() &&
- (InnerExt0->getNumUsers() > 0 &&
- !InnerExt0->hasMoreThanOneUniqueUser()) &&
- (InnerExt1->getNumUsers() > 0 &&
- !InnerExt1->hasMoreThanOneUniqueUser())) {
- InnerExt0Cost = InnerExt0->computeCost(VF, Ctx);
- InnerExt1Cost = InnerExt1->computeCost(VF, Ctx);
- Type *InnerExt0Ty = Ctx.Types.inferScalarType(InnerExt0->getOperand(0));
- Type *InnerExt1Ty = Ctx.Types.inferScalarType(InnerExt1->getOperand(0));
- Type *MaxInnerExtTy = InnerExt0Ty->getIntegerBitWidth() >
- InnerExt1Ty->getIntegerBitWidth()
- ? InnerExt0Ty
- : InnerExt1Ty;
- SrcVecTy = cast<VectorType>(ToVectorTy(MaxInnerExtTy, VF));
- IsUnsigned = true;
- HasInnerExt = true;
- }
- InstructionCost MulAccRedCost = Ctx.TTI.getMulAccReductionCost(
- IsUnsigned, ElementTy, SrcVecTy, CostKind);
- // Check if folding ext/mul into MulAccReduction is profitable.
- if (MulAccRedCost.isValid() &&
- MulAccRedCost <
- ExtCost + MulCost + InnerExt0Cost + InnerExt1Cost + BaseCost) {
- if (HasInnerExt) {
- Ctx.FoldedRecipes[VF].insert(InnerExt0);
- Ctx.FoldedRecipes[VF].insert(InnerExt1);
- }
- Ctx.FoldedRecipes[VF].insert(Mul);
- if (HasOuterExt)
- Ctx.FoldedRecipes[VF].insert(Ext);
- return MulAccRedCost;
- }
- }
- return InstructionCost::getInvalid();
- };
-
- // Match reduce(ext(...))
- auto GetExtendedReductionCost =
- [&](const VPReductionRecipe *Red) -> InstructionCost {
- VPValue *VecOp = Red->getVecOp();
- VPValue *A;
- if (match(VecOp, m_ZExtOrSExt(m_VPValue(A))) && VecOp->getNumUsers() == 1) {
- VPWidenCastRecipe *Ext =
- cast<VPWidenCastRecipe>(VecOp->getDefiningRecipe());
- bool IsUnsigned = Ext->getOpcode() == Instruction::CastOps::ZExt;
- InstructionCost ExtCost = Ext->computeCost(VF, Ctx);
- auto *ExtVecTy =
- cast<VectorType>(ToVectorTy(Ctx.Types.inferScalarType(A), VF));
- InstructionCost ExtendedRedCost = Ctx.TTI.getExtendedReductionCost(
- Opcode, IsUnsigned, ElementTy, ExtVecTy, RdxDesc.getFastMathFlags(),
- CostKind);
- // Check if folding ext into ExtendedReduction is profitable.
- if (ExtendedRedCost.isValid() && ExtendedRedCost < ExtCost + BaseCost) {
- Ctx.FoldedRecipes[VF].insert(Ext);
- return ExtendedRedCost;
- }
- }
- return InstructionCost::getInvalid();
- };
-
- // Match MulAccReduction patterns.
- InstructionCost MulAccCost = GetMulAccReductionCost(this);
- if (MulAccCost.isValid())
- return MulAccCost;
-
- // Match ExtendedReduction patterns.
- InstructionCost ExtendedCost = GetExtendedReductionCost(this);
- if (ExtendedCost.isValid())
- return ExtendedCost;
- */
-
// Default cost.
return BaseCost;
}
@@ -2338,9 +2222,6 @@ VPExtendedReductionRecipe::computeCost(ElementCount VF,
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
unsigned Opcode = RdxDesc.getOpcode();
- assert(ElementTy->getTypeID() == RdxDesc.getRecurrenceType()->getTypeID() &&
- "Inferred type and recurrence type mismatch.");
-
// BaseCost = Reduction cost + BinOp cost
InstructionCost ReductionCost =
Ctx.TTI.getArithmeticInstrCost(Opcode, ElementTy, CostKind);
@@ -2382,8 +2263,8 @@ InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
unsigned Opcode = RdxDesc.getOpcode();
- assert(ElementTy->getTypeID() == RdxDesc.getRecurrenceType()->getTypeID() &&
- "Inferred type and recurrence type mismatch.");
+ assert(Opcode == Instruction::Add &&
+ "Reduction opcode must be add in the VPMulAccRecipe.");
// BaseCost = Reduction cost + BinOp cost
InstructionCost ReductionCost =
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index b7cc945747df13..e52991acd20562 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -545,11 +545,15 @@ void VPlanTransforms::prepareExecute(VPlan &Plan) {
Op0 = new VPWidenCastRecipe(
MulAcc->getExtOpcode(), MulAcc->getVecOp0(),
MulAcc->getResultType(), *MulAcc->getExt0Instr());
- Op1 = new VPWidenCastRecipe(
- MulAcc->getExtOpcode(), MulAcc->getVecOp1(),
- MulAcc->getResultType(), *MulAcc->getExt1Instr());
Op0->getDefiningRecipe()->insertBefore(MulAcc);
- Op1->getDefiningRecipe()->insertBefore(MulAcc);
+ if (!MulAcc->isSameExtend()) {
+ Op1 = new VPWidenCastRecipe(
+ MulAcc->getExtOpcode(), MulAcc->getVecOp1(),
+ MulAcc->getResultType(), *MulAcc->getExt1Instr());
+ Op1->getDefiningRecipe()->insertBefore(MulAcc);
+ } else {
+ Op1 = Op0;
+ }
} else {
Op0 = MulAcc->getVecOp0();
Op1 = MulAcc->getVecOp1();
@@ -559,14 +563,13 @@ void VPlanTransforms::prepareExecute(VPlan &Plan) {
SmallVector<VPValue *, 2> MulOps = {Op0, Op1};
auto *Mul = new VPWidenRecipe(*MulInstr,
make_range(MulOps.begin(), MulOps.end()));
- if (auto *OuterExtInstr = MulAcc->getExtInstr()) {
- // dbgs() <<"\n!!!"<< *OuterExtInstr << " " << MulAcc->getExtOpcode()
- // << "\n";
+ // Outer extend.
+ if (auto *OuterExtInstr = MulAcc->getExtInstr())
VecOp = new VPWidenCastRecipe(
MulAcc->getExtOpcode(), Mul,
MulAcc->getRecurrenceDescriptor().getRecurrenceType(),
*OuterExtInstr);
- } else
+ else
VecOp = Mul;
auto *Red = new VPReductionRecipe(
MulAcc->getRecurrenceDescriptor(), MulAcc->getUnderlyingInstr(),
diff --git a/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll b/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
index 93a030c2bf1556..f1d29646617b7e 100644
--- a/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
+++ b/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
@@ -1412,9 +1412,8 @@ define i32 @mla_i8_i32_multiuse(ptr nocapture readonly %x, ptr nocapture readonl
; CHECK-NEXT: [[ACTIVE_LANE_MASK:%.*]] = call <16 x i1> @llvm.get.active.lane.mask.v16i1.i32(i32 [[INDEX]], i32 [[N]])
; CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds i8, ptr [[X:%.*]], i32 [[INDEX]]
; CHECK-NEXT: [[WIDE_MASKED_LOAD:%.*]] = call <16 x i8> @llvm.masked.load.v16i8.p0(ptr [[TMP0]], i32 1, <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i8> poison)
-; CHECK-NEXT: [[TMP1:%.*]] = zext <16 x i8> [[WIDE_MASKED_LOAD]] to <16 x i32>
; CHECK-NEXT: [[TMP7:%.*]] = zext <16 x i8> [[WIDE_MASKED_LOAD]] to <16 x i32>
-; CHECK-NEXT: [[TMP2:%.*]] = mul nuw nsw <16 x i32> [[TMP1]], [[TMP7]]
+; CHECK-NEXT: [[TMP2:%.*]] = mul nuw nsw <16 x i32> [[TMP7]], [[TMP7]]
; CHECK-NEXT: [[TMP3:%.*]] = select <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i32> [[TMP2]], <16 x i32> zeroinitializer
; CHECK-NEXT: [[TMP4:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP3]])
; CHECK-NEXT: [[TMP5]] = add i32 [[TMP4]], [[VEC_PHI]]
diff --git a/llvm/test/Transforms/LoopVectorize/RISCV/inloop-reduction.ll b/llvm/test/Transforms/LoopVectorize/RISCV/inloop-reduction.ll
index 9f1a61ebb5efef..9af31c9a4762b3 100644
--- a/llvm/test/Transforms/LoopVectorize/RISCV/inloop-reduction.ll
+++ b/llvm/test/Transforms/LoopVectorize/RISCV/inloop-reduction.ll
@@ -215,13 +215,13 @@ define i32 @add_i16_i32(ptr nocapture readonly %x, i32 %n) {
; IF-EVL-INLOOP-NEXT: [[TMP12]] = add i32 [[TMP11]], [[VEC_PHI]]
; IF-EVL-INLOOP-NEXT: [[INDEX_EVL_NEXT]] = add i32 [[TMP6]], [[EVL_BASED_IV]]
; IF-EVL-INLOOP-NEXT: [[INDEX_NEXT]] = add i32 [[INDEX]], [[TMP4]]
-; IF-EVL-INLOOP-NEXT: [[TMP12:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]
-; IF-EVL-INLOOP-NEXT: br i1 [[TMP12]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
+; IF-EVL-INLOOP-NEXT: [[TMP19:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]
+; IF-EVL-INLOOP-NEXT: br i1 [[TMP19]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
; IF-EVL-INLOOP: middle.block:
; IF-EVL-INLOOP-NEXT: br i1 true, label [[FOR_COND_CLEANUP_LOOPEXIT:%.*]], label [[SCALAR_PH]]
; IF-EVL-INLOOP: scalar.ph:
; IF-EVL-INLOOP-NEXT: [[BC_RESUME_VAL:%.*]] = phi i32 [ [[N_VEC]], [[MIDDLE_BLOCK]] ], [ 0, [[FOR_BODY_PREHEADER]] ]
-; IF-EVL-INLOOP-NEXT: [[BC_MERGE_RDX:%.*]] = phi i32 [ [[TMP11]], [[MIDDLE_BLOCK]] ], [ 0, [[FOR_BODY_PREHEADER]] ]
+; IF-EVL-INLOOP-NEXT: [[BC_MERGE_RDX:%.*]] = phi i32 [ [[TMP12]], [[MIDDLE_BLOCK]] ], [ 0, [[FOR_BODY_PREHEADER]] ]
; IF-EVL-INLOOP-NEXT: br label [[FOR_BODY:%.*]]
; IF-EVL-INLOOP: for.body:
; IF-EVL-INLOOP-NEXT: [[I_08:%.*]] = phi i32 [ [[INC:%.*]], [[FOR_BODY]] ], [ [[BC_RESUME_VAL]], [[SCALAR_PH]] ]
@@ -234,7 +234,7 @@ define i32 @add_i16_i32(ptr nocapture readonly %x, i32 %n) {
; IF-EVL-INLOOP-NEXT: [[EXITCOND:%.*]] = icmp eq i32 [[INC]], [[N]]
; IF-EVL-INLOOP-NEXT: br i1 [[EXITCOND]], label [[FOR_COND_CLEANUP_LOOPEXIT]], label [[FOR_BODY]], !llvm.loop [[LOOP3:![0-9]+]]
; IF-EVL-INLOOP: for.cond.cleanup.loopexit:
-; IF-EVL-INLOOP-NEXT: [[ADD_LCSSA:%.*]] = phi i32 [ [[ADD]], [[FOR_BODY]] ], [ [[TMP11]], [[MIDDLE_BLOCK]] ]
+; IF-EVL-INLOOP-NEXT: [[ADD_LCSSA:%.*]] = phi i32 [ [[ADD]], [[FOR_BODY]] ], [ [[TMP12]], [[MIDDLE_BLOCK]] ]
; IF-EVL-INLOOP-NEXT: br label [[FOR_COND_CLEANUP]]
; IF-EVL-INLOOP: for.cond.cleanup:
; IF-EVL-INLOOP-NEXT: [[R_0_LCSSA:%.*]] = phi i32 [ 0, [[ENTRY:%.*]] ], [ [[ADD_LCSSA]], [[FOR_COND_CLEANUP_LOOPEXIT]] ]
>From 432b172f2c145937b0a1bf522fcf6fc54f114355 Mon Sep 17 00:00:00 2001
From: Elvis Wang <elvis.wang at sifive.com>
Date: Wed, 6 Nov 2024 21:07:04 -0800
Subject: [PATCH 06/14] Fix typos and update printing test
---
.../Transforms/Vectorize/LoopVectorize.cpp | 3 +-
llvm/lib/Transforms/Vectorize/VPlan.h | 7 +-
.../lib/Transforms/Vectorize/VPlanRecipes.cpp | 18 +-
.../Transforms/Vectorize/VPlanTransforms.h | 2 +-
.../LoopVectorize/vplan-printing.ll | 236 ++++++++++++++++++
5 files changed, 254 insertions(+), 12 deletions(-)
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index ee429f9d3fb97d..b50f8d9c4cf438 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7675,7 +7675,7 @@ DenseMap<const SCEV *, Value *> LoopVectorizationPlanner::executePlan(
ILV.getOrCreateVectorTripCount(nullptr),
CanonicalIVStartValue, State);
- // TODO: Rebase to fhahn's implementation.
+ // TODO: Replace with upstream implementation.
VPlanTransforms::prepareExecute(BestVPlan);
BestVPlan.execute(&State);
@@ -9271,7 +9271,6 @@ VPlanPtr LoopVectorizationPlanner::buildVPlan(VFRange &Range) {
// Adjust AnyOf reductions; replace the reduction phi for the selected value
// with a boolean reduction phi node to check if the condition is true in any
// iteration. The final value is selected by the final ComputeReductionResult.
-// TODO: Implement VPMulAccHere.
void LoopVectorizationPlanner::adjustRecipesForReductions(
VPlanPtr &Plan, VPRecipeBuilder &RecipeBuilder, ElementCount MinVF) {
using namespace VPlanPatternMatch;
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index c7117551bf5dda..d48e75fe9fb64c 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -2757,8 +2757,8 @@ class VPExtendedReductionRecipe : public VPSingleDefRecipe {
/// A recipe to represent inloop MulAccreduction operations, performing a
/// reduction on a vector operand into a scalar value, and adding the result to
/// a chain. This recipe is high level abstract which will generate
-/// VPReductionRecipe VPWidenRecipe(mul)and VPWidenCastRecipe before execution.
-/// The Operands are {ChainOp, VecOp1, VecOp2, [Condition]}.
+/// VPReductionRecipe VPWidenRecipe(mul) and VPWidenCastRecipes before
+/// execution. The Operands are {ChainOp, VecOp1, VecOp2, [Condition]}.
class VPMulAccRecipe : public VPSingleDefRecipe {
/// The recurrence decriptor for the reduction in question.
const RecurrenceDescriptor &RdxDesc;
@@ -2778,6 +2778,8 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
/// Is this MulAcc recipe contains extend recipes?
bool IsExtended;
+ /// Is this reciep contains outer extend instuction?
+ bool IsOuterExtended = false;
protected:
VPMulAccRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
@@ -2797,6 +2799,7 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
addOperand(CondOp);
}
IsExtended = true;
+ IsOuterExtended = ExtInstr != nullptr;
}
VPMulAccRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index ab40bed209167d..8a2ecd46ab672d 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -270,7 +270,7 @@ InstructionCost VPRecipeBase::cost(ElementCount VF, VPCostContext &Ctx) {
UI = &WidenMem->getIngredient();
InstructionCost RecipeCost;
- if ((UI && Ctx.skipCostComputation(UI, VF.isVector()))) {
+ if (UI && Ctx.skipCostComputation(UI, VF.isVector())) {
RecipeCost = 0;
} else {
RecipeCost = computeCost(VF, Ctx);
@@ -2408,18 +2408,20 @@ void VPMulAccRecipe::print(raw_ostream &O, const Twine &Indent,
printAsOperand(O, SlotTracker);
O << " = ";
getChainOp()->printAsOperand(O, SlotTracker);
- O << " +";
+ O << " + ";
if (isa<FPMathOperator>(getUnderlyingInstr()))
O << getUnderlyingInstr()->getFastMathFlags();
- O << " reduce." << Instruction::getOpcodeName(RdxDesc.getOpcode()) << " (";
- O << " mul ";
+ if (IsOuterExtended)
+ O << " (";
+ O << "reduce." << Instruction::getOpcodeName(RdxDesc.getOpcode()) << " (";
+ O << "mul ";
if (IsExtended)
O << "(";
getVecOp0()->printAsOperand(O, SlotTracker);
if (IsExtended)
- O << " extended to " << *getResultType() << ")";
- if (IsExtended)
- O << "(";
+ O << " extended to " << *getResultType() << "), (";
+ else
+ O << ", ";
getVecOp1()->printAsOperand(O, SlotTracker);
if (IsExtended)
O << " extended to " << *getResultType() << ")";
@@ -2428,6 +2430,8 @@ void VPMulAccRecipe::print(raw_ostream &O, const Twine &Indent,
getCondOp()->printAsOperand(O, SlotTracker);
}
O << ")";
+ if (IsOuterExtended)
+ O << " extended to " << *RdxDesc.getRecurrenceType() << ")";
if (RdxDesc.IntermediateStore)
O << " (with final reduction value stored in invariant address sank "
"outside of loop)";
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.h b/llvm/lib/Transforms/Vectorize/VPlanTransforms.h
index 6310c23b605da3..40ed91104d566e 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.h
@@ -124,7 +124,7 @@ struct VPlanTransforms {
/// Remove dead recipes from \p Plan.
static void removeDeadRecipes(VPlan &Plan);
- /// TODO: Rebase to fhahn's implementation.
+ /// TODO: Rebase to upstream implementation.
static void prepareExecute(VPlan &Plan);
};
diff --git a/llvm/test/Transforms/LoopVectorize/vplan-printing.ll b/llvm/test/Transforms/LoopVectorize/vplan-printing.ll
index 6bb20a301e0ade..a8fb374a4b162d 100644
--- a/llvm/test/Transforms/LoopVectorize/vplan-printing.ll
+++ b/llvm/test/Transforms/LoopVectorize/vplan-printing.ll
@@ -1143,6 +1143,242 @@ exit:
ret i16 %for.1
}
+define i64 @print_extended_reduction(ptr nocapture readonly %x, ptr nocapture readonly %y, i32 %n) {
+; CHECK-LABEL: 'print_extended_reduction'
+; CHECK: VPlan 'Initial VPlan for VF={4},UF>=1' {
+; CHECK-NEXT: Live-in vp<%0> = VF * UF
+; CHECK-NEXT: Live-in vp<%1> = vector-trip-count
+; CHECK-NEXT: Live-in ir<%n> = original trip-count
+; CHECK-EMPTY:
+; CHECK-NEXT: vector.ph:
+; CHECK-NEXT: Successor(s): vector loop
+; CHECK-EMPTY:
+; CHECK-NEXT: <x1> vector loop: {
+; CHECK-NEXT: vector.body:
+; CHECK-NEXT: EMIT vp<%2> = CANONICAL-INDUCTION ir<0>, vp<%index.next>
+; CHECK-NEXT: WIDEN-REDUCTION-PHI ir<%r.09> = phi ir<0>, ir<%add>
+; CHECK-NEXT: vp<%3> = SCALAR-STEPS vp<%2>, ir<1>
+; CHECK-NEXT: CLONE ir<%arrayidx> = getelementptr inbounds ir<%x>, vp<%3>
+; CHECK-NEXT: vp<%4> = vector-pointer ir<%arrayidx>
+; CHECK-NEXT: WIDEN ir<%load0> = load vp<%4>
+; CHECK-NEXT: EXTENDED-REDUCE ir<%add> = ir<%r.09> + reduce.add (ir<%load0> extended to i64)
+; CHECK-NEXT: EMIT vp<%index.next> = add nuw vp<%2>, vp<%0>
+; CHECK-NEXT: EMIT branch-on-count vp<%index.next>, vp<%1>
+; CHECK-NEXT: No successors
+; CHECK-NEXT: }
+; CHECK-NEXT: Successor(s): middle.block
+; CHECK-EMPTY:
+; CHECK-NEXT: middle.block:
+; CHECK-NEXT: EMIT vp<%6> = compute-reduction-result ir<%r.09>, ir<%add>
+; CHECK-NEXT: EMIT vp<%7> = extract-from-end vp<%6>, ir<1>
+; CHECK-NEXT: EMIT vp<%cmp.n> = icmp eq ir<%n>, vp<%1>
+; CHECK-NEXT: EMIT branch-on-cond vp<%cmp.n>
+; CHECK-NEXT: Successor(s): ir-bb<for.cond.cleanup.loopexit>, scalar.ph
+; CHECK-EMPTY:
+; CHECK-NEXT: ir-bb<for.cond.cleanup.loopexit>:
+; CHECK-NEXT: IR %add.lcssa = phi i64 [ %add, %for.body ] (extra operand: vp<%7>)
+; CHECK-NEXT: No successors
+; CHECK-EMPTY:
+; CHECK-NEXT: scalar.ph:
+; CHECK-NEXT: EMIT vp<%bc.merge.rdx> = resume-phi vp<%6>, ir<0>
+; CHECK-NEXT: Successor(s): ir-bb<for.body>
+; CHECK-EMPTY:
+; CHECK-NEXT: ir-bb<for.body>:
+; CHECK-NEXT: IR %i.010 = phi i32 [ %inc, %for.body ], [ 0, %for.body.preheader ]
+; CHECK-NEXT: IR %r.09 = phi i64 [ %add, %for.body ], [ 0, %for.body.preheader ] (extra operand: vp<%bc.merge.rdx>)
+; CHECK-NEXT: IR %arrayidx = getelementptr inbounds i32, ptr %x, i32 %i.010
+; CHECK-NEXT: IR %load0 = load i32, ptr %arrayidx, align 4
+; CHECK-NEXT: IR %conv0 = zext i32 %load0 to i64
+; CHECK-NEXT: IR %add = add nsw i64 %r.09, %conv0
+; CHECK-NEXT: IR %inc = add nuw nsw i32 %i.010, 1
+; CHECK-NEXT: IR %exitcond = icmp eq i32 %inc, %n
+; CHECK-NEXT: No successors
+; CHECK-NEXT: }
+; Sums zext'd i32 loads into an i64 accumulator; reduce.add(zext(load)) is expected to print as one EXTENDED-REDUCE recipe.
+entry:
+ %cmp8 = icmp sgt i32 %n, 0
+ br i1 %cmp8, label %for.body, label %for.cond.cleanup
+
+for.body: ; preds = %entry, %for.body
+ %i.010 = phi i32 [ %inc, %for.body ], [ 0, %entry ]
+ %r.09 = phi i64 [ %add, %for.body ], [ 0, %entry ]
+ %arrayidx = getelementptr inbounds i32, ptr %x, i32 %i.010
+ %load0 = load i32, ptr %arrayidx, align 4
+ %conv0 = zext i32 %load0 to i64
+ %add = add nsw i64 %r.09, %conv0
+ %inc = add nuw nsw i32 %i.010, 1
+ %exitcond = icmp eq i32 %inc, %n
+ br i1 %exitcond, label %for.cond.cleanup, label %for.body
+
+for.cond.cleanup: ; preds = %for.body, %entry
+ %r.0.lcssa = phi i64 [ 0, %entry ], [ %add, %for.body ]
+ ret i64 %r.0.lcssa
+}
+
+define i64 @print_mulacc(ptr nocapture readonly %x, ptr nocapture readonly %y, i32 %n) {
+; CHECK-LABEL: 'print_mulacc'
+; CHECK: VPlan 'Initial VPlan for VF={4},UF>=1' {
+; CHECK-NEXT: Live-in vp<%0> = VF * UF
+; CHECK-NEXT: Live-in vp<%1> = vector-trip-count
+; CHECK-NEXT: Live-in ir<%n> = original trip-count
+; CHECK-EMPTY:
+; CHECK-NEXT: vector.ph:
+; CHECK-NEXT: Successor(s): vector loop
+; CHECK-EMPTY:
+; CHECK-NEXT: <x1> vector loop: {
+; CHECK-NEXT: vector.body:
+; CHECK-NEXT: EMIT vp<%2> = CANONICAL-INDUCTION ir<0>, vp<%index.next>
+; CHECK-NEXT: WIDEN-REDUCTION-PHI ir<%r.09> = phi ir<0>, ir<%add>
+; CHECK-NEXT: vp<%3> = SCALAR-STEPS vp<%2>, ir<1>
+; CHECK-NEXT: CLONE ir<%arrayidx> = getelementptr inbounds ir<%x>, vp<%3>
+; CHECK-NEXT: vp<%4> = vector-pointer ir<%arrayidx>
+; CHECK-NEXT: WIDEN ir<%load0> = load vp<%4>
+; CHECK-NEXT: CLONE ir<%arrayidx1> = getelementptr inbounds ir<%y>, vp<%3>
+; CHECK-NEXT: vp<%5> = vector-pointer ir<%arrayidx1>
+; CHECK-NEXT: WIDEN ir<%load1> = load vp<%5>
+; CHECK-NEXT: MULACC-REDUCE ir<%add> = ir<%r.09> + reduce.add (mul ir<%load0>, ir<%load1>)
+; CHECK-NEXT: EMIT vp<%index.next> = add nuw vp<%2>, vp<%0>
+; CHECK-NEXT: EMIT branch-on-count vp<%index.next>, vp<%1>
+; CHECK-NEXT: No successors
+; CHECK-NEXT: }
+; CHECK-NEXT: Successor(s): middle.block
+; CHECK-EMPTY:
+; CHECK-NEXT: middle.block:
+; CHECK-NEXT: EMIT vp<%7> = compute-reduction-result ir<%r.09>, ir<%add>
+; CHECK-NEXT: EMIT vp<%8> = extract-from-end vp<%7>, ir<1>
+; CHECK-NEXT: EMIT vp<%cmp.n> = icmp eq ir<%n>, vp<%1>
+; CHECK-NEXT: EMIT branch-on-cond vp<%cmp.n>
+; CHECK-NEXT: Successor(s): ir-bb<for.cond.cleanup.loopexit>, scalar.ph
+; CHECK-EMPTY:
+; CHECK-NEXT: ir-bb<for.cond.cleanup.loopexit>:
+; CHECK-NEXT: IR %add.lcssa = phi i64 [ %add, %for.body ] (extra operand: vp<%8>)
+; CHECK-NEXT: No successors
+; CHECK-EMPTY:
+; CHECK-NEXT: scalar.ph:
+; CHECK-NEXT: EMIT vp<%bc.merge.rdx> = resume-phi vp<%7>, ir<0>
+; CHECK-NEXT: Successor(s): ir-bb<for.body>
+; CHECK-EMPTY:
+; CHECK-NEXT: ir-bb<for.body>:
+; CHECK-NEXT: IR %i.010 = phi i32 [ %inc, %for.body ], [ 0, %for.body.preheader ]
+; CHECK-NEXT: IR %r.09 = phi i64 [ %add, %for.body ], [ 0, %for.body.preheader ] (extra operand: vp<%bc.merge.rdx>)
+; CHECK-NEXT: IR %arrayidx = getelementptr inbounds i64, ptr %x, i32 %i.010
+; CHECK-NEXT: IR %load0 = load i64, ptr %arrayidx, align 4
+; CHECK-NEXT: IR %arrayidx1 = getelementptr inbounds i64, ptr %y, i32 %i.010
+; CHECK-NEXT: IR %load1 = load i64, ptr %arrayidx1, align 4
+; CHECK-NEXT: IR %mul = mul nsw i64 %load0, %load1
+; CHECK-NEXT: IR %add = add nsw i64 %r.09, %mul
+; CHECK-NEXT: IR %inc = add nuw nsw i32 %i.010, 1
+; CHECK-NEXT: IR %exitcond = icmp eq i32 %inc, %n
+; CHECK-NEXT: No successors
+; CHECK-NEXT: }
+; i64 dot product with no extends; reduce.add(mul(load, load)) is expected to print as one MULACC-REDUCE recipe.
+entry:
+ %cmp8 = icmp sgt i32 %n, 0
+ br i1 %cmp8, label %for.body, label %for.cond.cleanup
+
+for.body: ; preds = %entry, %for.body
+ %i.010 = phi i32 [ %inc, %for.body ], [ 0, %entry ]
+ %r.09 = phi i64 [ %add, %for.body ], [ 0, %entry ]
+ %arrayidx = getelementptr inbounds i64, ptr %x, i32 %i.010
+ %load0 = load i64, ptr %arrayidx, align 4
+ %arrayidx1 = getelementptr inbounds i64, ptr %y, i32 %i.010
+ %load1 = load i64, ptr %arrayidx1, align 4
+ %mul = mul nsw i64 %load0, %load1
+ %add = add nsw i64 %r.09, %mul
+ %inc = add nuw nsw i32 %i.010, 1
+ %exitcond = icmp eq i32 %inc, %n
+ br i1 %exitcond, label %for.cond.cleanup, label %for.body
+
+for.cond.cleanup: ; preds = %for.body, %entry
+ %r.0.lcssa = phi i64 [ 0, %entry ], [ %add, %for.body ]
+ ret i64 %r.0.lcssa
+}
+
+define i64 @print_mulacc_extended(ptr nocapture readonly %x, ptr nocapture readonly %y, i32 %n) {
+; CHECK-LABEL: 'print_mulacc_extended'
+; CHECK: VPlan 'Initial VPlan for VF={4},UF>=1' {
+; CHECK-NEXT: Live-in vp<%0> = VF * UF
+; CHECK-NEXT: Live-in vp<%1> = vector-trip-count
+; CHECK-NEXT: Live-in ir<%n> = original trip-count
+; CHECK-EMPTY:
+; CHECK-NEXT: vector.ph:
+; CHECK-NEXT: Successor(s): vector loop
+; CHECK-EMPTY:
+; CHECK-NEXT: <x1> vector loop: {
+; CHECK-NEXT: vector.body:
+; CHECK-NEXT: EMIT vp<%2> = CANONICAL-INDUCTION ir<0>, vp<%index.next>
+; CHECK-NEXT: WIDEN-REDUCTION-PHI ir<%r.09> = phi ir<0>, ir<%add>
+; CHECK-NEXT: vp<%3> = SCALAR-STEPS vp<%2>, ir<1>
+; CHECK-NEXT: CLONE ir<%arrayidx> = getelementptr inbounds ir<%x>, vp<%3>
+; CHECK-NEXT: vp<%4> = vector-pointer ir<%arrayidx>
+; CHECK-NEXT: WIDEN ir<%load0> = load vp<%4>
+; CHECK-NEXT: CLONE ir<%arrayidx1> = getelementptr inbounds ir<%y>, vp<%3>
+; CHECK-NEXT: vp<%5> = vector-pointer ir<%arrayidx1>
+; CHECK-NEXT: WIDEN ir<%load1> = load vp<%5>
+; CHECK-NEXT: MULACC-REDUCE ir<%add> = ir<%r.09> + (reduce.add (mul (ir<%load0> extended to i32), (ir<%load1> extended to i32)) extended to i64)
+; CHECK-NEXT: EMIT vp<%index.next> = add nuw vp<%2>, vp<%0>
+; CHECK-NEXT: EMIT branch-on-count vp<%index.next>, vp<%1>
+; CHECK-NEXT: No successors
+; CHECK-NEXT: }
+; CHECK-NEXT: Successor(s): middle.block
+; CHECK-EMPTY:
+; CHECK-NEXT: middle.block:
+; CHECK-NEXT: EMIT vp<%7> = compute-reduction-result ir<%r.09>, ir<%add>
+; CHECK-NEXT: EMIT vp<%8> = extract-from-end vp<%7>, ir<1>
+; CHECK-NEXT: EMIT vp<%cmp.n> = icmp eq ir<%n>, vp<%1>
+; CHECK-NEXT: EMIT branch-on-cond vp<%cmp.n>
+; CHECK-NEXT: Successor(s): ir-bb<for.cond.cleanup.loopexit>, scalar.ph
+; CHECK-EMPTY:
+; CHECK-NEXT: ir-bb<for.cond.cleanup.loopexit>:
+; CHECK-NEXT: IR %add.lcssa = phi i64 [ %add, %for.body ] (extra operand: vp<%8>)
+; CHECK-NEXT: No successors
+; CHECK-EMPTY:
+; CHECK-NEXT: scalar.ph:
+; CHECK-NEXT: EMIT vp<%bc.merge.rdx> = resume-phi vp<%7>, ir<0>
+; CHECK-NEXT: Successor(s): ir-bb<for.body>
+; CHECK-EMPTY:
+; CHECK-NEXT: ir-bb<for.body>:
+; CHECK-NEXT: IR %i.010 = phi i32 [ %inc, %for.body ], [ 0, %for.body.preheader ]
+; CHECK-NEXT: IR %r.09 = phi i64 [ %add, %for.body ], [ 0, %for.body.preheader ] (extra operand: vp<%bc.merge.rdx>)
+; CHECK-NEXT: IR %arrayidx = getelementptr inbounds i16, ptr %x, i32 %i.010
+; CHECK-NEXT: IR %load0 = load i16, ptr %arrayidx, align 4
+; CHECK-NEXT: IR %arrayidx1 = getelementptr inbounds i16, ptr %y, i32 %i.010
+; CHECK-NEXT: IR %load1 = load i16, ptr %arrayidx1, align 4
+; CHECK-NEXT: IR %conv0 = sext i16 %load0 to i32
+; CHECK-NEXT: IR %conv1 = sext i16 %load1 to i32
+; CHECK-NEXT: IR %mul = mul nsw i32 %conv0, %conv1
+; CHECK-NEXT: IR %conv = sext i32 %mul to i64
+; CHECK-NEXT: IR %add = add nsw i64 %r.09, %conv
+; CHECK-NEXT: IR %inc = add nuw nsw i32 %i.010, 1
+; CHECK-NEXT: IR %exitcond = icmp eq i32 %inc, %n
+; CHECK-NEXT: No successors
+; CHECK-NEXT: }
+; reduce.add(sext(mul(sext(load), sext(load)))) is expected to print as one MULACC-REDUCE recipe showing both the inner (i16->i32) and outer (i32->i64) extends.
+entry:
+ %cmp8 = icmp sgt i32 %n, 0
+ br i1 %cmp8, label %for.body, label %for.cond.cleanup
+
+for.body: ; preds = %entry, %for.body
+ %i.010 = phi i32 [ %inc, %for.body ], [ 0, %entry ]
+ %r.09 = phi i64 [ %add, %for.body ], [ 0, %entry ]
+ %arrayidx = getelementptr inbounds i16, ptr %x, i32 %i.010
+ %load0 = load i16, ptr %arrayidx, align 4
+ %arrayidx1 = getelementptr inbounds i16, ptr %y, i32 %i.010
+ %load1 = load i16, ptr %arrayidx1, align 4
+ %conv0 = sext i16 %load0 to i32
+ %conv1 = sext i16 %load1 to i32
+ %mul = mul nsw i32 %conv0, %conv1
+ %conv = sext i32 %mul to i64
+ %add = add nsw i64 %r.09, %conv
+ %inc = add nuw nsw i32 %i.010, 1
+ %exitcond = icmp eq i32 %inc, %n
+ br i1 %exitcond, label %for.cond.cleanup, label %for.body
+
+for.cond.cleanup: ; preds = %for.body, %entry
+ %r.0.lcssa = phi i64 [ 0, %entry ], [ %add, %for.body ]
+ ret i64 %r.0.lcssa
+}
+
!llvm.dbg.cu = !{!0}
!llvm.module.flags = !{!3, !4}
>From 500f16ecf3f4982354a584a9a7de3537a9564bbd Mon Sep 17 00:00:00 2001
From: Elvis Wang <elvis.wang at sifive.com>
Date: Sun, 10 Nov 2024 18:00:52 -0800
Subject: [PATCH 07/14] Fold reduce.add(zext(mul(sext(A), sext(B)))) into
MulAccRecipe when A == B
To prepare for a future refactor that avoids referencing underlying
instructions, and because the opcode of the extend instruction may not
match in the newly added pattern, stop passing the underlying instruction
(UI) when creating VPWidenCastRecipe.
As a result, the loop vectorizer creates duplicate extend instructions
when two reduction patterns exist in the same loop. These redundant
instructions may be removed by later passes after LV.
---
.../Transforms/Vectorize/LoopVectorize.cpp | 31 ++++++++-----------
.../Transforms/Vectorize/VPlanTransforms.cpp | 15 +++++----
.../LoopVectorize/ARM/mve-reductions.ll | 3 +-
.../LoopVectorize/reduction-inloop.ll | 6 ++--
4 files changed, 24 insertions(+), 31 deletions(-)
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index b50f8d9c4cf438..999b4157142a0e 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -9396,20 +9396,18 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
// reduce.add(mul(ext, ext)) can folded into VPMulAccRecipe
if (match(VecOp, m_Mul(m_VPValue(A), m_VPValue(B))) &&
!VecOp->hasMoreThanOneUniqueUser()) {
- VPRecipeBase *RecipeA = A->getDefiningRecipe();
- VPRecipeBase *RecipeB = B->getDefiningRecipe();
+ VPWidenCastRecipe *RecipeA =
+ dyn_cast_if_present<VPWidenCastRecipe>(A->getDefiningRecipe());
+ VPWidenCastRecipe *RecipeB =
+ dyn_cast_if_present<VPWidenCastRecipe>(B->getDefiningRecipe());
if (RecipeA && RecipeB && match(RecipeA, m_ZExtOrSExt(m_VPValue())) &&
match(RecipeB, m_ZExtOrSExt(m_VPValue())) &&
- cast<VPWidenCastRecipe>(RecipeA)->getOpcode() ==
- cast<VPWidenCastRecipe>(RecipeB)->getOpcode() &&
- !A->hasMoreThanOneUniqueUser() &&
- !B->hasMoreThanOneUniqueUser()) {
+ (RecipeA->getOpcode() == RecipeB->getOpcode() || A == B)) {
return new VPMulAccRecipe(
RdxDesc, CurrentLinkI, PreviousLink, CondOp,
CM.useOrderedReductions(RdxDesc),
- cast<VPWidenRecipe>(VecOp->getDefiningRecipe()),
- cast<VPWidenCastRecipe>(RecipeA),
- cast<VPWidenCastRecipe>(RecipeB));
+ cast<VPWidenRecipe>(VecOp->getDefiningRecipe()), RecipeA,
+ RecipeB);
} else {
// Matched reduce.add(mul(...))
return new VPMulAccRecipe(
@@ -9417,8 +9415,9 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
CM.useOrderedReductions(RdxDesc),
cast<VPWidenRecipe>(VecOp->getDefiningRecipe()));
}
- // Matched reduce.add(ext(mul(ext, ext)))
- // Note that 3 extend instructions must have same opcode.
+ // Matched reduce.add(ext(mul(ext(A), ext(B))))
+ // Note that 3 extend instructions must have same opcode or A == B
+ // which can be transform to reduce.add(zext(mul(sext(A), sext(B)))).
} else if (match(VecOp,
m_ZExtOrSExt(m_Mul(m_ZExtOrSExt(m_VPValue()),
m_ZExtOrSExt(m_VPValue())))) &&
@@ -9431,11 +9430,8 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
cast<VPWidenCastRecipe>(Mul->getOperand(0)->getDefiningRecipe());
VPWidenCastRecipe *Ext1 =
cast<VPWidenCastRecipe>(Mul->getOperand(1)->getDefiningRecipe());
- if (Ext->getOpcode() == Ext0->getOpcode() &&
- Ext0->getOpcode() == Ext1->getOpcode() &&
- !Mul->hasMoreThanOneUniqueUser() &&
- !Ext0->hasMoreThanOneUniqueUser() &&
- !Ext1->hasMoreThanOneUniqueUser()) {
+ if ((Ext->getOpcode() == Ext0->getOpcode() || Ext0 == Ext1) &&
+ Ext0->getOpcode() == Ext1->getOpcode()) {
return new VPMulAccRecipe(
RdxDesc, CurrentLinkI, PreviousLink, CondOp,
CM.useOrderedReductions(RdxDesc),
@@ -9447,8 +9443,7 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
};
auto TryToMatchExtendedReduction = [&]() -> VPSingleDefRecipe * {
VPValue *A;
- if (match(VecOp, m_ZExtOrSExt(m_VPValue(A))) &&
- !VecOp->hasMoreThanOneUniqueUser()) {
+ if (match(VecOp, m_ZExtOrSExt(m_VPValue(A)))) {
return new VPExtendedReductionRecipe(
RdxDesc, CurrentLinkI, PreviousLink,
cast<VPWidenCastRecipe>(VecOp), CondOp,
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index e52991acd20562..af287577fcb879 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -542,14 +542,14 @@ void VPlanTransforms::prepareExecute(VPlan &Plan) {
auto *MulAcc = cast<VPMulAccRecipe>(&R);
VPValue *Op0, *Op1;
if (MulAcc->isExtended()) {
- Op0 = new VPWidenCastRecipe(
- MulAcc->getExtOpcode(), MulAcc->getVecOp0(),
- MulAcc->getResultType(), *MulAcc->getExt0Instr());
+ Op0 =
+ new VPWidenCastRecipe(MulAcc->getExtOpcode(), MulAcc->getVecOp0(),
+ MulAcc->getResultType());
Op0->getDefiningRecipe()->insertBefore(MulAcc);
if (!MulAcc->isSameExtend()) {
- Op1 = new VPWidenCastRecipe(
- MulAcc->getExtOpcode(), MulAcc->getVecOp1(),
- MulAcc->getResultType(), *MulAcc->getExt1Instr());
+ Op1 = new VPWidenCastRecipe(MulAcc->getExtOpcode(),
+ MulAcc->getVecOp1(),
+ MulAcc->getResultType());
Op1->getDefiningRecipe()->insertBefore(MulAcc);
} else {
Op1 = Op0;
@@ -567,8 +567,7 @@ void VPlanTransforms::prepareExecute(VPlan &Plan) {
if (auto *OuterExtInstr = MulAcc->getExtInstr())
VecOp = new VPWidenCastRecipe(
MulAcc->getExtOpcode(), Mul,
- MulAcc->getRecurrenceDescriptor().getRecurrenceType(),
- *OuterExtInstr);
+ MulAcc->getRecurrenceDescriptor().getRecurrenceType());
else
VecOp = Mul;
auto *Red = new VPReductionRecipe(
diff --git a/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll b/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
index f1d29646617b7e..6a48c330775972 100644
--- a/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
+++ b/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
@@ -1535,7 +1535,8 @@ define i64 @mla_and_add_together_16_64(ptr nocapture noundef readonly %x, i32 no
; CHECK-NEXT: [[TMP3:%.*]] = zext nneg <8 x i32> [[TMP2]] to <8 x i64>
; CHECK-NEXT: [[TMP4:%.*]] = call i64 @llvm.vector.reduce.add.v8i64(<8 x i64> [[TMP3]])
; CHECK-NEXT: [[TMP5]] = add i64 [[TMP4]], [[VEC_PHI]]
-; CHECK-NEXT: [[TMP6:%.*]] = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> [[TMP1]])
+; CHECK-NEXT: [[TMP10:%.*]] = sext <8 x i16> [[WIDE_LOAD]] to <8 x i32>
+; CHECK-NEXT: [[TMP6:%.*]] = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> [[TMP10]])
; CHECK-NEXT: [[TMP7]] = add i32 [[TMP6]], [[VEC_PHI1]]
; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 8
; CHECK-NEXT: [[TMP8:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]
diff --git a/llvm/test/Transforms/LoopVectorize/reduction-inloop.ll b/llvm/test/Transforms/LoopVectorize/reduction-inloop.ll
index b578e61d85dfa1..d6dbb74f26d4ab 100644
--- a/llvm/test/Transforms/LoopVectorize/reduction-inloop.ll
+++ b/llvm/test/Transforms/LoopVectorize/reduction-inloop.ll
@@ -1206,15 +1206,13 @@ define i32 @predicated_not_dominates_reduction_twoadd(ptr nocapture noundef read
; CHECK-NEXT: [[TMP0:%.*]] = sext i32 [[INDEX]] to i64
; CHECK-NEXT: [[TMP1:%.*]] = getelementptr inbounds i8, ptr [[H:%.*]], i64 [[TMP0]]
; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <4 x i8>, ptr [[TMP1]], align 1
-; CHECK-NEXT: [[DOTNOT:%.*]] = icmp eq <4 x i8> [[WIDE_LOAD]], zeroinitializer
; CHECK-NEXT: [[TMP2:%.*]] = udiv <4 x i8> [[WIDE_LOAD]], splat (i8 31)
; CHECK-NEXT: [[TMP3:%.*]] = shl nuw nsw <4 x i8> [[TMP2]], splat (i8 3)
; CHECK-NEXT: [[TMP4:%.*]] = udiv <4 x i8> [[TMP3]], splat (i8 31)
; CHECK-NEXT: [[TMP5:%.*]] = zext nneg <4 x i8> [[TMP4]] to <4 x i32>
-; CHECK-NEXT: [[TMP6:%.*]] = select <4 x i1> [[DOTNOT]], <4 x i32> zeroinitializer, <4 x i32> [[TMP5]]
-; CHECK-NEXT: [[TMP7:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP6]])
+; CHECK-NEXT: [[TMP7:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP5]])
; CHECK-NEXT: [[TMP8:%.*]] = add i32 [[TMP7]], [[VEC_PHI]]
-; CHECK-NEXT: [[TMP9:%.*]] = select <4 x i1> [[DOTNOT]], <4 x i32> zeroinitializer, <4 x i32> [[TMP5]]
+; CHECK-NEXT: [[TMP9:%.*]] = zext nneg <4 x i8> [[TMP4]] to <4 x i32>
; CHECK-NEXT: [[TMP10:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP9]])
; CHECK-NEXT: [[TMP11]] = add i32 [[TMP10]], [[TMP8]]
; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 4
>From 91e1f4518bf76ae4b88f56dfc645400e9eeba586 Mon Sep 17 00:00:00 2001
From: Elvis Wang <elvis.wang at sifive.com>
Date: Sun, 10 Nov 2024 23:34:05 -0800
Subject: [PATCH 08/14] Refactor! Reuse functions from VPReductionRecipe.
---
llvm/lib/Transforms/Vectorize/VPlan.h | 114 +++++-------------
.../lib/Transforms/Vectorize/VPlanRecipes.cpp | 6 +-
.../Transforms/Vectorize/VPlanTransforms.cpp | 20 +--
3 files changed, 47 insertions(+), 93 deletions(-)
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index d48e75fe9fb64c..bb42d840ded90f 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -682,8 +682,6 @@ struct VPCostContext {
LLVMContext &LLVMCtx;
LoopVectorizationCostModel &CM;
SmallPtrSet<Instruction *, 8> SkipCostComputation;
- /// Contains recipes that are folded into other recipes.
- SmallDenseMap<ElementCount, SmallPtrSet<VPRecipeBase *, 4>, 4> FoldedRecipes;
VPCostContext(const TargetTransformInfo &TTI, const TargetLibraryInfo &TLI,
Type *CanIVTy, LoopVectorizationCostModel &CM)
@@ -2575,6 +2573,8 @@ class VPReductionRecipe : public VPSingleDefRecipe {
getVecOp(), getCondOp(), IsOrdered);
}
+ // TODO: Support VPExtendedReductionRecipe and VPMulAccRecipe after EVL
+ // support.
static inline bool classof(const VPRecipeBase *R) {
return R->getVPDefID() == VPRecipeBase::VPReductionSC ||
R->getVPDefID() == VPRecipeBase::VPReductionEVLSC;
@@ -2662,33 +2662,20 @@ class VPReductionEVLRecipe : public VPReductionRecipe {
/// a chain. This recipe is high level abstract which will generate
/// VPReductionRecipe and VPWidenCastRecipe before execution. The Operands are
/// {ChainOp, VecOp, [Condition]}.
-class VPExtendedReductionRecipe : public VPSingleDefRecipe {
- /// The recurrence decriptor for the reduction in question.
- const RecurrenceDescriptor &RdxDesc;
- bool IsOrdered;
- /// Whether the reduction is conditional.
- bool IsConditional = false;
+class VPExtendedReductionRecipe : public VPReductionRecipe {
/// Type after extend.
Type *ResultTy;
- /// Opcode for the extend instruction.
- Instruction::CastOps ExtOp;
CastInst *ExtInstr;
- bool IsZExt;
protected:
VPExtendedReductionRecipe(const unsigned char SC,
const RecurrenceDescriptor &R, Instruction *RedI,
- Instruction::CastOps ExtOp, CastInst *ExtI,
- ArrayRef<VPValue *> Operands, VPValue *CondOp,
+ Instruction::CastOps ExtOp, CastInst *ExtInstr,
+ VPValue *ChainOp, VPValue *VecOp, VPValue *CondOp,
bool IsOrdered, Type *ResultTy)
- : VPSingleDefRecipe(SC, Operands, RedI), RdxDesc(R), IsOrdered(IsOrdered),
- ResultTy(ResultTy), ExtOp(ExtOp), ExtInstr(ExtI) {
- if (CondOp) {
- IsConditional = true;
- addOperand(CondOp);
- }
- IsZExt = ExtOp == Instruction::CastOps::ZExt;
- }
+ : VPReductionRecipe(SC, R, RedI, ArrayRef<VPValue *>({ChainOp, VecOp}),
+ CondOp, IsOrdered),
+ ResultTy(ResultTy), ExtInstr(ExtInstr) {}
public:
VPExtendedReductionRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
@@ -2696,9 +2683,8 @@ class VPExtendedReductionRecipe : public VPSingleDefRecipe {
VPValue *CondOp, bool IsOrdered)
: VPExtendedReductionRecipe(
VPDef::VPExtendedReductionSC, R, RedI, Ext->getOpcode(),
- cast<CastInst>(Ext->getUnderlyingInstr()),
- ArrayRef<VPValue *>({ChainOp, Ext->getOperand(0)}), CondOp,
- IsOrdered, Ext->getResultType()) {}
+ cast<CastInst>(Ext->getUnderlyingInstr()), ChainOp,
+ Ext->getOperand(0), CondOp, IsOrdered, Ext->getResultType()) {}
~VPExtendedReductionRecipe() override = default;
@@ -2730,26 +2716,11 @@ class VPExtendedReductionRecipe : public VPSingleDefRecipe {
VPSlotTracker &SlotTracker) const override;
#endif
- /// Return the recurrence decriptor for the in-loop reduction.
- const RecurrenceDescriptor &getRecurrenceDescriptor() const {
- return RdxDesc;
- }
- /// Return true if the in-loop reduction is ordered.
- bool isOrdered() const { return IsOrdered; };
- /// Return true if the in-loop reduction is conditional.
- bool isConditional() const { return IsConditional; };
- /// The VPValue of the scalar Chain being accumulated.
- VPValue *getChainOp() const { return getOperand(0); }
- /// The VPValue of the vector value to be extended and reduced.
- VPValue *getVecOp() const { return getOperand(1); }
- /// The VPValue of the condition for the block.
- VPValue *getCondOp() const {
- return isConditional() ? getOperand(getNumOperands() - 1) : nullptr;
- }
/// The Type after extended.
Type *getResultType() const { return ResultTy; };
+ bool isZExt() const { return getExtOpcode() == Instruction::ZExt; };
/// The Opcode of extend instruction.
- Instruction::CastOps getExtOpcode() const { return ExtOp; };
+ Instruction::CastOps getExtOpcode() const { return ExtInstr->getOpcode(); };
/// The CastInst of the extend instruction.
CastInst *getExtInstr() const { return ExtInstr; };
};
@@ -2759,12 +2730,7 @@ class VPExtendedReductionRecipe : public VPSingleDefRecipe {
/// a chain. This recipe is high level abstract which will generate
/// VPReductionRecipe VPWidenRecipe(mul) and VPWidenCastRecipes before
/// execution. The Operands are {ChainOp, VecOp1, VecOp2, [Condition]}.
-class VPMulAccRecipe : public VPSingleDefRecipe {
- /// The recurrence decriptor for the reduction in question.
- const RecurrenceDescriptor &RdxDesc;
- bool IsOrdered;
- /// Whether the reduction is conditional.
- bool IsConditional = false;
+class VPMulAccRecipe : public VPReductionRecipe {
/// Type after extend.
Type *ResultType;
// Note that all extend instruction must have the same opcode in MulAcc.
@@ -2786,32 +2752,29 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
Instruction *RedI, Instruction *ExtInstr,
Instruction *MulInstr, Instruction::CastOps ExtOp,
Instruction *Ext0Instr, Instruction *Ext1Instr,
- ArrayRef<VPValue *> Operands, VPValue *CondOp, bool IsOrdered,
- Type *ResultType)
- : VPSingleDefRecipe(SC, Operands, RedI), RdxDesc(R), IsOrdered(IsOrdered),
+ VPValue *ChainOp, VPValue *VecOp0, VPValue *VecOp1,
+ VPValue *CondOp, bool IsOrdered, Type *ResultType)
+ : VPReductionRecipe(SC, R, RedI,
+ ArrayRef<VPValue *>({ChainOp, VecOp0, VecOp1}),
+ CondOp, IsOrdered),
ResultType(ResultType), ExtOp(ExtOp), MulInstr(MulInstr),
ExtInstr(cast_if_present<CastInst>(ExtInstr)),
Ext0Instr(cast<CastInst>(Ext0Instr)),
Ext1Instr(cast<CastInst>(Ext1Instr)) {
assert(MulInstr->getOpcode() == Instruction::Mul);
- if (CondOp) {
- IsConditional = true;
- addOperand(CondOp);
- }
IsExtended = true;
IsOuterExtended = ExtInstr != nullptr;
}
VPMulAccRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
- Instruction *RedI, Instruction *MulInstr,
- ArrayRef<VPValue *> Operands, VPValue *CondOp, bool IsOrdered)
- : VPSingleDefRecipe(SC, Operands, RedI), RdxDesc(R), IsOrdered(IsOrdered),
+ Instruction *RedI, Instruction *MulInstr, VPValue *ChainOp,
+ VPValue *VecOp0, VPValue *VecOp1, VPValue *CondOp,
+ bool IsOrdered)
+ : VPReductionRecipe(SC, R, RedI,
+ ArrayRef<VPValue *>({ChainOp, VecOp0, VecOp1}),
+ CondOp, IsOrdered),
MulInstr(MulInstr) {
assert(MulInstr->getOpcode() == Instruction::Mul);
- if (CondOp) {
- IsConditional = true;
- addOperand(CondOp);
- }
IsExtended = false;
}
@@ -2823,17 +2786,15 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
: VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, nullptr,
Mul->getUnderlyingInstr(), Ext0->getOpcode(),
Ext0->getUnderlyingInstr(), Ext1->getUnderlyingInstr(),
- ArrayRef<VPValue *>(
- {ChainOp, Ext0->getOperand(0), Ext1->getOperand(0)}),
+ ChainOp, Ext0->getOperand(0), Ext1->getOperand(0),
CondOp, IsOrdered, Ext0->getResultType()) {}
VPMulAccRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
VPValue *ChainOp, VPValue *CondOp, bool IsOrdered,
VPWidenRecipe *Mul)
: VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, Mul->getUnderlyingInstr(),
- ArrayRef<VPValue *>(
- {ChainOp, Mul->getOperand(0), Mul->getOperand(1)}),
- CondOp, IsOrdered) {}
+ ChainOp, Mul->getOperand(0), Mul->getOperand(1), CondOp,
+ IsOrdered) {}
VPMulAccRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
VPValue *ChainOp, VPValue *CondOp, bool IsOrdered,
@@ -2842,8 +2803,7 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
: VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, Ext->getUnderlyingInstr(),
Mul->getUnderlyingInstr(), Ext0->getOpcode(),
Ext0->getUnderlyingInstr(), Ext1->getUnderlyingInstr(),
- ArrayRef<VPValue *>(
- {ChainOp, Ext0->getOperand(0), Ext1->getOperand(0)}),
+ ChainOp, Ext0->getOperand(0), Ext1->getOperand(0),
CondOp, IsOrdered, Ext0->getResultType()) {}
~VPMulAccRecipe() override = default;
@@ -2874,37 +2834,27 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
VPSlotTracker &SlotTracker) const override;
#endif
- /// Return the recurrence decriptor for the in-loop reduction.
- const RecurrenceDescriptor &getRecurrenceDescriptor() const {
- return RdxDesc;
- }
- /// Return true if the in-loop reduction is ordered.
- bool isOrdered() const { return IsOrdered; };
- /// Return true if the in-loop reduction is conditional.
- bool isConditional() const { return IsConditional; };
- /// The VPValue of the scalar Chain being accumulated.
- VPValue *getChainOp() const { return getOperand(0); }
/// The VPValue of the vector value to be extended and reduced.
VPValue *getVecOp0() const { return getOperand(1); }
VPValue *getVecOp1() const { return getOperand(2); }
- /// The VPValue of the condition for the block.
- VPValue *getCondOp() const {
- return isConditional() ? getOperand(getNumOperands() - 1) : nullptr;
- }
+
/// Return the type after inner extended, which must equal to the type of mul
/// instruction. If the ResultType != recurrenceType, than it must have a
/// extend recipe after mul recipe.
Type *getResultType() const { return ResultType; };
+
/// The opcode of the extend instructions.
Instruction::CastOps getExtOpcode() const { return ExtOp; };
/// The underlying instruction for VPWidenRecipe.
Instruction *getMulInstr() const { return MulInstr; };
+
/// The underlying Instruction for outer VPWidenCastRecipe.
CastInst *getExtInstr() const { return ExtInstr; };
/// The underlying Instruction for inner VPWidenCastRecipe.
CastInst *getExt0Instr() const { return Ext0Instr; };
/// The underlying Instruction for inner VPWidenCastRecipe.
CastInst *getExt1Instr() const { return Ext1Instr; };
+
/// Return if this MulAcc recipe contains extend instructions.
bool isExtended() const { return IsExtended; };
/// Return if the operands of mul instruction come from same extend.
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 8a2ecd46ab672d..2ecf4e763c39b4 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -2216,6 +2216,7 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
InstructionCost
VPExtendedReductionRecipe::computeCost(ElementCount VF,
VPCostContext &Ctx) const {
+ const RecurrenceDescriptor &RdxDesc = getRecurrenceDescriptor();
RecurKind RdxKind = RdxDesc.getRecurrenceKind();
Type *ElementTy = getResultType();
auto *VectorTy = cast<VectorType>(ToVectorTy(ElementTy, VF));
@@ -2246,7 +2247,7 @@ VPExtendedReductionRecipe::computeCost(ElementCount VF,
// ExtendedReduction Cost
InstructionCost ExtendedRedCost = Ctx.TTI.getExtendedReductionCost(
- Opcode, IsZExt, ElementTy, SrcTy, RdxDesc.getFastMathFlags(), CostKind);
+ Opcode, isZExt(), ElementTy, SrcTy, RdxDesc.getFastMathFlags(), CostKind);
// Check if folding ext into ExtendedReduction is profitable.
if (ExtendedRedCost.isValid() &&
ExtendedRedCost < ExtendedCost + ReductionCost) {
@@ -2257,6 +2258,7 @@ VPExtendedReductionRecipe::computeCost(ElementCount VF,
InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
VPCostContext &Ctx) const {
+ const RecurrenceDescriptor &RdxDesc = getRecurrenceDescriptor();
Type *ElementTy = IsExtended ? RdxDesc.getRecurrenceType()
: Ctx.Types.inferScalarType(getVecOp0());
auto *VectorTy = cast<VectorType>(ToVectorTy(ElementTy, VF));
@@ -2382,6 +2384,7 @@ void VPReductionEVLRecipe::print(raw_ostream &O, const Twine &Indent,
void VPExtendedReductionRecipe::print(raw_ostream &O, const Twine &Indent,
VPSlotTracker &SlotTracker) const {
+ const RecurrenceDescriptor &RdxDesc = getRecurrenceDescriptor();
O << Indent << "EXTENDED-REDUCE ";
printAsOperand(O, SlotTracker);
O << " = ";
@@ -2404,6 +2407,7 @@ void VPExtendedReductionRecipe::print(raw_ostream &O, const Twine &Indent,
void VPMulAccRecipe::print(raw_ostream &O, const Twine &Indent,
VPSlotTracker &SlotTracker) const {
+ const RecurrenceDescriptor &RdxDesc = getRecurrenceDescriptor();
O << Indent << "MULACC-REDUCE ";
printAsOperand(O, SlotTracker);
O << " = ";
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index af287577fcb879..55e5788e8f9131 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -525,8 +525,7 @@ void VPlanTransforms::prepareExecute(VPlan &Plan) {
for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>(
vp_depth_first_deep(Plan.getEntry()))) {
for (VPRecipeBase &R : make_early_inc_range(*VPBB)) {
- if (isa<VPExtendedReductionRecipe>(&R)) {
- auto *ExtRed = cast<VPExtendedReductionRecipe>(&R);
+ if (auto *ExtRed = dyn_cast<VPExtendedReductionRecipe>(&R)) {
auto *Ext = new VPWidenCastRecipe(
ExtRed->getExtOpcode(), ExtRed->getVecOp(), ExtRed->getResultType(),
*ExtRed->getExtInstr());
@@ -542,14 +541,14 @@ void VPlanTransforms::prepareExecute(VPlan &Plan) {
auto *MulAcc = cast<VPMulAccRecipe>(&R);
VPValue *Op0, *Op1;
if (MulAcc->isExtended()) {
- Op0 =
- new VPWidenCastRecipe(MulAcc->getExtOpcode(), MulAcc->getVecOp0(),
- MulAcc->getResultType());
+ CastInst *Ext0 = MulAcc->getExt0Instr();
+ Op0 = new VPWidenCastRecipe(Ext0->getOpcode(), MulAcc->getVecOp0(),
+ MulAcc->getResultType(), *Ext0);
Op0->getDefiningRecipe()->insertBefore(MulAcc);
if (!MulAcc->isSameExtend()) {
- Op1 = new VPWidenCastRecipe(MulAcc->getExtOpcode(),
- MulAcc->getVecOp1(),
- MulAcc->getResultType());
+ CastInst *Ext1 = MulAcc->getExt1Instr();
+ Op1 = new VPWidenCastRecipe(Ext1->getOpcode(), MulAcc->getVecOp1(),
+ MulAcc->getResultType(), *Ext1);
Op1->getDefiningRecipe()->insertBefore(MulAcc);
} else {
Op1 = Op0;
@@ -566,8 +565,9 @@ void VPlanTransforms::prepareExecute(VPlan &Plan) {
// Outer extend.
if (auto *OuterExtInstr = MulAcc->getExtInstr())
VecOp = new VPWidenCastRecipe(
- MulAcc->getExtOpcode(), Mul,
- MulAcc->getRecurrenceDescriptor().getRecurrenceType());
+ OuterExtInstr->getOpcode(), Mul,
+ MulAcc->getRecurrenceDescriptor().getRecurrenceType(),
+ *OuterExtInstr);
else
VecOp = Mul;
auto *Red = new VPReductionRecipe(
>From 7829bf22469fe95536f8b40c27dedb5d68eeee57 Mon Sep 17 00:00:00 2001
From: Elvis Wang <elvis.wang at sifive.com>
Date: Mon, 11 Nov 2024 23:17:46 -0800
Subject: [PATCH 09/14] Refactor! Add comments and refine new recipes.
---
.../Transforms/Vectorize/LoopVectorize.cpp | 20 +++---
llvm/lib/Transforms/Vectorize/VPlan.h | 72 +++++++++----------
.../lib/Transforms/Vectorize/VPlanRecipes.cpp | 28 ++++----
.../Transforms/Vectorize/VPlanTransforms.cpp | 35 ++++++---
4 files changed, 85 insertions(+), 70 deletions(-)
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 999b4157142a0e..fef9956068f1fd 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -9389,17 +9389,18 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
if (CM.blockNeedsPredicationForAnyReason(BB))
CondOp = RecipeBuilder.getBlockInMask(BB);
- auto TryToMatchMulAcc = [&]() -> VPSingleDefRecipe * {
+ auto TryToMatchMulAcc = [&]() -> VPReductionRecipe * {
VPValue *A, *B;
if (RdxDesc.getOpcode() != Instruction::Add)
return nullptr;
- // reduce.add(mul(ext, ext)) can folded into VPMulAccRecipe
+ // Try to match reduce.add(mul(...))
if (match(VecOp, m_Mul(m_VPValue(A), m_VPValue(B))) &&
!VecOp->hasMoreThanOneUniqueUser()) {
VPWidenCastRecipe *RecipeA =
dyn_cast_if_present<VPWidenCastRecipe>(A->getDefiningRecipe());
VPWidenCastRecipe *RecipeB =
dyn_cast_if_present<VPWidenCastRecipe>(B->getDefiningRecipe());
+ // Matched reduce.add(mul(ext, ext))
if (RecipeA && RecipeB && match(RecipeA, m_ZExtOrSExt(m_VPValue())) &&
match(RecipeB, m_ZExtOrSExt(m_VPValue())) &&
(RecipeA->getOpcode() == RecipeB->getOpcode() || A == B)) {
@@ -9409,23 +9410,23 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
cast<VPWidenRecipe>(VecOp->getDefiningRecipe()), RecipeA,
RecipeB);
} else {
- // Matched reduce.add(mul(...))
+ // Matched reduce.add(mul)
return new VPMulAccRecipe(
RdxDesc, CurrentLinkI, PreviousLink, CondOp,
CM.useOrderedReductions(RdxDesc),
cast<VPWidenRecipe>(VecOp->getDefiningRecipe()));
}
// Matched reduce.add(ext(mul(ext(A), ext(B))))
- // Note that 3 extend instructions must have same opcode or A == B
+ // Note that all extend instructions must have same opcode or A == B
// which can be transform to reduce.add(zext(mul(sext(A), sext(B)))).
} else if (match(VecOp,
m_ZExtOrSExt(m_Mul(m_ZExtOrSExt(m_VPValue()),
m_ZExtOrSExt(m_VPValue())))) &&
!VecOp->hasMoreThanOneUniqueUser()) {
VPWidenCastRecipe *Ext =
- dyn_cast<VPWidenCastRecipe>(VecOp->getDefiningRecipe());
+ cast<VPWidenCastRecipe>(VecOp->getDefiningRecipe());
VPWidenRecipe *Mul =
- dyn_cast<VPWidenRecipe>(Ext->getOperand(0)->getDefiningRecipe());
+ cast<VPWidenRecipe>(Ext->getOperand(0)->getDefiningRecipe());
VPWidenCastRecipe *Ext0 =
cast<VPWidenCastRecipe>(Mul->getOperand(0)->getDefiningRecipe());
VPWidenCastRecipe *Ext1 =
@@ -9441,8 +9442,10 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
}
return nullptr;
};
- auto TryToMatchExtendedReduction = [&]() -> VPSingleDefRecipe * {
+
+ auto TryToMatchExtendedReduction = [&]() -> VPReductionRecipe * {
VPValue *A;
+ // Matched reduce(ext)).
if (match(VecOp, m_ZExtOrSExt(m_VPValue(A)))) {
return new VPExtendedReductionRecipe(
RdxDesc, CurrentLinkI, PreviousLink,
@@ -9451,7 +9454,8 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
}
return nullptr;
};
- VPSingleDefRecipe *RedRecipe;
+
+ VPReductionRecipe *RedRecipe;
if (auto *MulAcc = TryToMatchMulAcc())
RedRecipe = MulAcc;
else if (auto *ExtendedRed = TryToMatchExtendedReduction())
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index bb42d840ded90f..cd7d0efe8dda45 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -2717,12 +2717,12 @@ class VPExtendedReductionRecipe : public VPReductionRecipe {
#endif
/// The Type after extended.
- Type *getResultType() const { return ResultTy; };
- bool isZExt() const { return getExtOpcode() == Instruction::ZExt; };
+ Type *getResultType() const { return ResultTy; }
+ bool isZExt() const { return getExtOpcode() == Instruction::ZExt; }
/// The Opcode of extend instruction.
- Instruction::CastOps getExtOpcode() const { return ExtInstr->getOpcode(); };
+ Instruction::CastOps getExtOpcode() const { return ExtInstr->getOpcode(); }
/// The CastInst of the extend instruction.
- CastInst *getExtInstr() const { return ExtInstr; };
+ CastInst *getExtInstr() const { return ExtInstr; }
};
/// A recipe to represent inloop MulAccreduction operations, performing a
@@ -2733,8 +2733,6 @@ class VPExtendedReductionRecipe : public VPReductionRecipe {
class VPMulAccRecipe : public VPReductionRecipe {
/// Type after extend.
Type *ResultType;
- // Note that all extend instruction must have the same opcode in MulAcc.
- Instruction::CastOps ExtOp;
/// reduce.add(ext(mul(ext0(), ext1())))
Instruction *MulInstr;
@@ -2742,28 +2740,21 @@ class VPMulAccRecipe : public VPReductionRecipe {
CastInst *Ext0Instr = nullptr;
CastInst *Ext1Instr = nullptr;
- /// Is this MulAcc recipe contains extend recipes?
- bool IsExtended;
- /// Is this reciep contains outer extend instuction?
- bool IsOuterExtended = false;
-
protected:
VPMulAccRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
Instruction *RedI, Instruction *ExtInstr,
- Instruction *MulInstr, Instruction::CastOps ExtOp,
- Instruction *Ext0Instr, Instruction *Ext1Instr,
- VPValue *ChainOp, VPValue *VecOp0, VPValue *VecOp1,
- VPValue *CondOp, bool IsOrdered, Type *ResultType)
+ Instruction *MulInstr, Instruction *Ext0Instr,
+ Instruction *Ext1Instr, VPValue *ChainOp, VPValue *VecOp0,
+ VPValue *VecOp1, VPValue *CondOp, bool IsOrdered,
+ Type *ResultType)
: VPReductionRecipe(SC, R, RedI,
ArrayRef<VPValue *>({ChainOp, VecOp0, VecOp1}),
CondOp, IsOrdered),
- ResultType(ResultType), ExtOp(ExtOp), MulInstr(MulInstr),
+ ResultType(ResultType), MulInstr(MulInstr),
ExtInstr(cast_if_present<CastInst>(ExtInstr)),
Ext0Instr(cast<CastInst>(Ext0Instr)),
Ext1Instr(cast<CastInst>(Ext1Instr)) {
assert(MulInstr->getOpcode() == Instruction::Mul);
- IsExtended = true;
- IsOuterExtended = ExtInstr != nullptr;
}
VPMulAccRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
@@ -2775,7 +2766,6 @@ class VPMulAccRecipe : public VPReductionRecipe {
CondOp, IsOrdered),
MulInstr(MulInstr) {
assert(MulInstr->getOpcode() == Instruction::Mul);
- IsExtended = false;
}
public:
@@ -2784,10 +2774,10 @@ class VPMulAccRecipe : public VPReductionRecipe {
VPWidenRecipe *Mul, VPWidenCastRecipe *Ext0,
VPWidenCastRecipe *Ext1)
: VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, nullptr,
- Mul->getUnderlyingInstr(), Ext0->getOpcode(),
- Ext0->getUnderlyingInstr(), Ext1->getUnderlyingInstr(),
- ChainOp, Ext0->getOperand(0), Ext1->getOperand(0),
- CondOp, IsOrdered, Ext0->getResultType()) {}
+ Mul->getUnderlyingInstr(), Ext0->getUnderlyingInstr(),
+ Ext1->getUnderlyingInstr(), ChainOp, Ext0->getOperand(0),
+ Ext1->getOperand(0), CondOp, IsOrdered,
+ Ext0->getResultType()) {}
VPMulAccRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
VPValue *ChainOp, VPValue *CondOp, bool IsOrdered,
@@ -2801,10 +2791,10 @@ class VPMulAccRecipe : public VPReductionRecipe {
VPWidenCastRecipe *Ext, VPWidenRecipe *Mul,
VPWidenCastRecipe *Ext0, VPWidenCastRecipe *Ext1)
: VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, Ext->getUnderlyingInstr(),
- Mul->getUnderlyingInstr(), Ext0->getOpcode(),
- Ext0->getUnderlyingInstr(), Ext1->getUnderlyingInstr(),
- ChainOp, Ext0->getOperand(0), Ext1->getOperand(0),
- CondOp, IsOrdered, Ext0->getResultType()) {}
+ Mul->getUnderlyingInstr(), Ext0->getUnderlyingInstr(),
+ Ext1->getUnderlyingInstr(), ChainOp, Ext0->getOperand(0),
+ Ext1->getOperand(0), CondOp, IsOrdered,
+ Ext0->getResultType()) {}
~VPMulAccRecipe() override = default;
@@ -2841,24 +2831,34 @@ class VPMulAccRecipe : public VPReductionRecipe {
/// Return the type after inner extended, which must equal to the type of mul
/// instruction. If the ResultType != recurrenceType, than it must have a
/// extend recipe after mul recipe.
- Type *getResultType() const { return ResultType; };
+ Type *getResultType() const { return ResultType; }
- /// The opcode of the extend instructions.
- Instruction::CastOps getExtOpcode() const { return ExtOp; };
/// The underlying instruction for VPWidenRecipe.
- Instruction *getMulInstr() const { return MulInstr; };
+ Instruction *getMulInstr() const { return MulInstr; }
/// The underlying Instruction for outer VPWidenCastRecipe.
- CastInst *getExtInstr() const { return ExtInstr; };
+ CastInst *getExtInstr() const { return ExtInstr; }
/// The underlying Instruction for inner VPWidenCastRecipe.
- CastInst *getExt0Instr() const { return Ext0Instr; };
+ CastInst *getExt0Instr() const { return Ext0Instr; }
/// The underlying Instruction for inner VPWidenCastRecipe.
- CastInst *getExt1Instr() const { return Ext1Instr; };
+ CastInst *getExt1Instr() const { return Ext1Instr; }
/// Return if this MulAcc recipe contains extend instructions.
- bool isExtended() const { return IsExtended; };
+ bool isExtended() const { return Ext0Instr && Ext1Instr; }
/// Return if the operands of mul instruction come from same extend.
- bool isSameExtend() const { return Ext0Instr == Ext1Instr; };
+ bool isSameExtend() const { return Ext0Instr == Ext1Instr; }
+ /// Return if the MulAcc recipes contains extend after mul.
+ bool isOuterExtended() const { return ExtInstr != nullptr; }
+ /// Return if the extend opcode is ZExt.
+ bool isZExt() const {
+ if (!isExtended())
+ return true;
+ // reduce.add(sext(mul(zext(A), zext(A)))) can be transform to
+ // reduce.add(zext(mul(sext(A), sext(A))))
+ if (ExtInstr && ExtInstr->getOpcode() != Ext0Instr->getOpcode())
+ return true;
+ return Ext0Instr->getOpcode() == Instruction::CastOps::ZExt;
+ }
};
/// VPReplicateRecipe replicates a given instruction producing multiple scalar
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 2ecf4e763c39b4..fc6b90191efa43 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -2259,8 +2259,8 @@ VPExtendedReductionRecipe::computeCost(ElementCount VF,
InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
VPCostContext &Ctx) const {
const RecurrenceDescriptor &RdxDesc = getRecurrenceDescriptor();
- Type *ElementTy = IsExtended ? RdxDesc.getRecurrenceType()
- : Ctx.Types.inferScalarType(getVecOp0());
+ Type *ElementTy = isExtended() ? RdxDesc.getRecurrenceType()
+ : Ctx.Types.inferScalarType(getVecOp0());
auto *VectorTy = cast<VectorType>(ToVectorTy(ElementTy, VF));
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
unsigned Opcode = RdxDesc.getOpcode();
@@ -2276,20 +2276,19 @@ InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
// Extended cost
InstructionCost ExtendedCost = 0;
- if (IsExtended) {
+ if (isExtended()) {
auto *SrcTy = cast<VectorType>(
ToVectorTy(Ctx.Types.inferScalarType(getVecOp0()), VF));
auto *DestTy = cast<VectorType>(ToVectorTy(getResultType(), VF));
TTI::CastContextHint CCH0 =
computeCCH(getVecOp0()->getDefiningRecipe(), VF);
- // Arm TTI will use the underlying instruction to determine the cost.
ExtendedCost = Ctx.TTI.getCastInstrCost(
- ExtOp, DestTy, SrcTy, CCH0, TTI::TCK_RecipThroughput,
+ Ext0Instr->getOpcode(), DestTy, SrcTy, CCH0, TTI::TCK_RecipThroughput,
dyn_cast_if_present<Instruction>(getExt0Instr()));
TTI::CastContextHint CCH1 =
computeCCH(getVecOp0()->getDefiningRecipe(), VF);
ExtendedCost += Ctx.TTI.getCastInstrCost(
- ExtOp, DestTy, SrcTy, CCH1, TTI::TCK_RecipThroughput,
+ Ext1Instr->getOpcode(), DestTy, SrcTy, CCH1, TTI::TCK_RecipThroughput,
dyn_cast_if_present<Instruction>(getExt1Instr()));
}
@@ -2297,7 +2296,7 @@ InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
InstructionCost MulCost;
SmallVector<const Value *, 4> Operands;
Operands.append(MulInstr->value_op_begin(), MulInstr->value_op_end());
- if (IsExtended)
+ if (isExtended())
MulCost = Ctx.TTI.getArithmeticInstrCost(
Instruction::Mul, VectorTy, CostKind,
{TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None},
@@ -2324,9 +2323,8 @@ InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
// MulAccReduction Cost
VectorType *SrcVecTy =
cast<VectorType>(ToVectorTy(Ctx.Types.inferScalarType(getVecOp0()), VF));
- InstructionCost MulAccCost = Ctx.TTI.getMulAccReductionCost(
- getExtOpcode() == Instruction::CastOps::ZExt, ElementTy, SrcVecTy,
- CostKind);
+ InstructionCost MulAccCost =
+ Ctx.TTI.getMulAccReductionCost(isZExt(), ElementTy, SrcVecTy, CostKind);
// Check if folding ext into ExtendedReduction is profitable.
if (MulAccCost.isValid() &&
@@ -2415,26 +2413,26 @@ void VPMulAccRecipe::print(raw_ostream &O, const Twine &Indent,
O << " + ";
if (isa<FPMathOperator>(getUnderlyingInstr()))
O << getUnderlyingInstr()->getFastMathFlags();
- if (IsOuterExtended)
+ if (isOuterExtended())
O << " (";
O << "reduce." << Instruction::getOpcodeName(RdxDesc.getOpcode()) << " (";
O << "mul ";
- if (IsExtended)
+ if (isExtended())
O << "(";
getVecOp0()->printAsOperand(O, SlotTracker);
- if (IsExtended)
+ if (isExtended())
O << " extended to " << *getResultType() << "), (";
else
O << ", ";
getVecOp1()->printAsOperand(O, SlotTracker);
- if (IsExtended)
+ if (isExtended())
O << " extended to " << *getResultType() << ")";
if (isConditional()) {
O << ", ";
getCondOp()->printAsOperand(O, SlotTracker);
}
O << ")";
- if (IsOuterExtended)
+ if (isOuterExtended())
O << " extended to " << *RdxDesc.getRecurrenceType() << ")";
if (RdxDesc.IntermediateStore)
O << " (with final reduction value stored in invariant address sank "
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index 55e5788e8f9131..bbfb2540cc6e3d 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -526,9 +526,12 @@ void VPlanTransforms::prepareExecute(VPlan &Plan) {
vp_depth_first_deep(Plan.getEntry()))) {
for (VPRecipeBase &R : make_early_inc_range(*VPBB)) {
if (auto *ExtRed = dyn_cast<VPExtendedReductionRecipe>(&R)) {
+ // Genearte VPWidenCastRecipe.
auto *Ext = new VPWidenCastRecipe(
ExtRed->getExtOpcode(), ExtRed->getVecOp(), ExtRed->getResultType(),
*ExtRed->getExtInstr());
+
+ // Generate VPreductionRecipe.
auto *Red = new VPReductionRecipe(
ExtRed->getRecurrenceDescriptor(), ExtRed->getUnderlyingInstr(),
ExtRed->getChainOp(), Ext, ExtRed->getCondOp(),
@@ -539,45 +542,55 @@ void VPlanTransforms::prepareExecute(VPlan &Plan) {
ExtRed->eraseFromParent();
} else if (isa<VPMulAccRecipe>(&R)) {
auto *MulAcc = cast<VPMulAccRecipe>(&R);
+
+ // Generate inner VPWidenCastRecipes if necessary.
VPValue *Op0, *Op1;
if (MulAcc->isExtended()) {
CastInst *Ext0 = MulAcc->getExt0Instr();
Op0 = new VPWidenCastRecipe(Ext0->getOpcode(), MulAcc->getVecOp0(),
MulAcc->getResultType(), *Ext0);
Op0->getDefiningRecipe()->insertBefore(MulAcc);
- if (!MulAcc->isSameExtend()) {
+ // Prevent reduce.add(mul(ext(A), ext(A))) generate duplicate
+ // VPWidenCastRecipe.
+ if (MulAcc->isSameExtend()) {
+ Op1 = Op0;
+ } else {
CastInst *Ext1 = MulAcc->getExt1Instr();
Op1 = new VPWidenCastRecipe(Ext1->getOpcode(), MulAcc->getVecOp1(),
MulAcc->getResultType(), *Ext1);
Op1->getDefiningRecipe()->insertBefore(MulAcc);
- } else {
- Op1 = Op0;
}
+ // Not contains extend instruction in this MulAccRecipe.
} else {
Op0 = MulAcc->getVecOp0();
Op1 = MulAcc->getVecOp1();
}
+
+ // Generate VPWidenRecipe.
VPSingleDefRecipe *VecOp;
- Instruction *MulInstr = MulAcc->getMulInstr();
SmallVector<VPValue *, 2> MulOps = {Op0, Op1};
- auto *Mul = new VPWidenRecipe(*MulInstr,
+ auto *Mul = new VPWidenRecipe(*MulAcc->getMulInstr(),
make_range(MulOps.begin(), MulOps.end()));
- // Outer extend.
- if (auto *OuterExtInstr = MulAcc->getExtInstr())
+ Mul->insertBefore(MulAcc);
+
+ // Generate outer VPWidenCastRecipe if necessary.
+ if (auto *OuterExtInstr = MulAcc->getExtInstr()) {
VecOp = new VPWidenCastRecipe(
OuterExtInstr->getOpcode(), Mul,
MulAcc->getRecurrenceDescriptor().getRecurrenceType(),
*OuterExtInstr);
- else
+ VecOp->insertBefore(MulAcc);
+ } else {
VecOp = Mul;
+ }
+
+ // Generate VPReductionRecipe.
auto *Red = new VPReductionRecipe(
MulAcc->getRecurrenceDescriptor(), MulAcc->getUnderlyingInstr(),
MulAcc->getChainOp(), VecOp, MulAcc->getCondOp(),
MulAcc->isOrdered());
- Mul->insertBefore(MulAcc);
- if (VecOp != Mul)
- VecOp->insertBefore(MulAcc);
Red->insertBefore(MulAcc);
+
MulAcc->replaceAllUsesWith(Red);
MulAcc->eraseFromParent();
}
>From e4e510802f44c9936f5a892eb16ba2b19651ee15 Mon Sep 17 00:00:00 2001
From: Elvis Wang <elvis.wang at sifive.com>
Date: Thu, 14 Nov 2024 04:23:02 -0800
Subject: [PATCH 10/14] Remove underlying instruction dependency.
---
.../Transforms/Vectorize/LoopVectorize.cpp | 24 ++--
llvm/lib/Transforms/Vectorize/VPlan.h | 116 +++++++-----------
.../lib/Transforms/Vectorize/VPlanRecipes.cpp | 31 ++---
.../Transforms/Vectorize/VPlanTransforms.cpp | 33 ++---
.../LoopVectorize/ARM/mve-reductions.ll | 89 +++++++++-----
5 files changed, 137 insertions(+), 156 deletions(-)
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index fef9956068f1fd..da01dbe1beb99a 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7395,18 +7395,18 @@ static bool planContainsAdditionalSimplifications(VPlan &Plan,
if (Instruction *UI = GetInstructionForCost(&R))
SeenInstrs.insert(UI);
// VPExtendedReductionRecipe contains a folded extend instruction.
- if (auto *ExtendedRed = dyn_cast<VPExtendedReductionRecipe>(&R))
- SeenInstrs.insert(ExtendedRed->getExtInstr());
- // VPMulAccRecipe constians a mul and otional extend instructions.
- else if (auto *MulAcc = dyn_cast<VPMulAccRecipe>(&R)) {
- SeenInstrs.insert(MulAcc->getMulInstr());
- if (MulAcc->isExtended()) {
- SeenInstrs.insert(MulAcc->getExt0Instr());
- SeenInstrs.insert(MulAcc->getExt1Instr());
- if (auto *Ext = MulAcc->getExtInstr())
- SeenInstrs.insert(Ext);
- }
- }
+ // if (auto *ExtendedRed = dyn_cast<VPExtendedReductionRecipe>(&R))
+ // SeenInstrs.insert(ExtendedRed->getExtInstr());
+ // // VPMulAccRecipe constians a mul and otional extend instructions.
+ // else if (auto *MulAcc = dyn_cast<VPMulAccRecipe>(&R)) {
+ // SeenInstrs.insert(MulAcc->getMulInstr());
+ // if (MulAcc->isExtended()) {
+ // SeenInstrs.insert(MulAcc->getExt0Instr());
+ // SeenInstrs.insert(MulAcc->getExt1Instr());
+ // if (auto *Ext = MulAcc->getExtInstr())
+ // SeenInstrs.insert(Ext);
+ // }
+ // }
}
}
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index cd7d0efe8dda45..bdc1823d8e6f68 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -1446,11 +1446,20 @@ class VPWidenRecipe : public VPRecipeWithIRFlags {
iterator_range<IterT> Operands)
: VPRecipeWithIRFlags(VPDefOpcode, Operands, I), Opcode(I.getOpcode()) {}
+ template <typename IterT>
+ VPWidenRecipe(unsigned VPDefOpcode, unsigned InstrOpcode,
+ iterator_range<IterT> Operands)
+ : VPRecipeWithIRFlags(VPDefOpcode, Operands), Opcode(InstrOpcode) {}
+
public:
template <typename IterT>
VPWidenRecipe(Instruction &I, iterator_range<IterT> Operands)
: VPWidenRecipe(VPDef::VPWidenSC, I, Operands) {}
+ template <typename IterT>
+ VPWidenRecipe(unsigned Opcode, iterator_range<IterT> Operands)
+ : VPWidenRecipe(VPDef::VPWidenSC, Opcode, Operands) {}
+
~VPWidenRecipe() override = default;
VPWidenRecipe *clone() override {
@@ -2665,26 +2674,25 @@ class VPReductionEVLRecipe : public VPReductionRecipe {
class VPExtendedReductionRecipe : public VPReductionRecipe {
/// Type after extend.
Type *ResultTy;
- CastInst *ExtInstr;
+ Instruction::CastOps ExtOp;
protected:
VPExtendedReductionRecipe(const unsigned char SC,
const RecurrenceDescriptor &R, Instruction *RedI,
- Instruction::CastOps ExtOp, CastInst *ExtInstr,
- VPValue *ChainOp, VPValue *VecOp, VPValue *CondOp,
- bool IsOrdered, Type *ResultTy)
+ Instruction::CastOps ExtOp, VPValue *ChainOp,
+ VPValue *VecOp, VPValue *CondOp, bool IsOrdered,
+ Type *ResultTy)
: VPReductionRecipe(SC, R, RedI, ArrayRef<VPValue *>({ChainOp, VecOp}),
CondOp, IsOrdered),
- ResultTy(ResultTy), ExtInstr(ExtInstr) {}
+ ResultTy(ResultTy), ExtOp(ExtOp) {}
public:
VPExtendedReductionRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
VPValue *ChainOp, VPWidenCastRecipe *Ext,
VPValue *CondOp, bool IsOrdered)
- : VPExtendedReductionRecipe(
- VPDef::VPExtendedReductionSC, R, RedI, Ext->getOpcode(),
- cast<CastInst>(Ext->getUnderlyingInstr()), ChainOp,
- Ext->getOperand(0), CondOp, IsOrdered, Ext->getResultType()) {}
+ : VPExtendedReductionRecipe(VPDef::VPExtendedReductionSC, R, RedI,
+ Ext->getOpcode(), ChainOp, Ext->getOperand(0),
+ CondOp, IsOrdered, Ext->getResultType()) {}
~VPExtendedReductionRecipe() override = default;
@@ -2720,9 +2728,7 @@ class VPExtendedReductionRecipe : public VPReductionRecipe {
Type *getResultType() const { return ResultTy; }
bool isZExt() const { return getExtOpcode() == Instruction::ZExt; }
/// The Opcode of extend instruction.
- Instruction::CastOps getExtOpcode() const { return ExtInstr->getOpcode(); }
- /// The CastInst of the extend instruction.
- CastInst *getExtInstr() const { return ExtInstr; }
+ Instruction::CastOps getExtOpcode() const { return ExtOp; }
};
/// A recipe to represent inloop MulAccreduction operations, performing a
@@ -2731,70 +2737,52 @@ class VPExtendedReductionRecipe : public VPReductionRecipe {
/// VPReductionRecipe VPWidenRecipe(mul) and VPWidenCastRecipes before
/// execution. The Operands are {ChainOp, VecOp1, VecOp2, [Condition]}.
class VPMulAccRecipe : public VPReductionRecipe {
- /// Type after extend.
- Type *ResultType;
-
- /// reduce.add(ext(mul(ext0(), ext1())))
- Instruction *MulInstr;
- CastInst *ExtInstr = nullptr;
- CastInst *Ext0Instr = nullptr;
- CastInst *Ext1Instr = nullptr;
+ Instruction::CastOps ExtOp;
+ bool IsExtended = false;
protected:
VPMulAccRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
- Instruction *RedI, Instruction *ExtInstr,
- Instruction *MulInstr, Instruction *Ext0Instr,
- Instruction *Ext1Instr, VPValue *ChainOp, VPValue *VecOp0,
- VPValue *VecOp1, VPValue *CondOp, bool IsOrdered,
- Type *ResultType)
+ Instruction *RedI, Instruction::CastOps ExtOp,
+ VPValue *ChainOp, VPValue *VecOp0, VPValue *VecOp1,
+ VPValue *CondOp, bool IsOrdered)
: VPReductionRecipe(SC, R, RedI,
ArrayRef<VPValue *>({ChainOp, VecOp0, VecOp1}),
CondOp, IsOrdered),
- ResultType(ResultType), MulInstr(MulInstr),
- ExtInstr(cast_if_present<CastInst>(ExtInstr)),
- Ext0Instr(cast<CastInst>(Ext0Instr)),
- Ext1Instr(cast<CastInst>(Ext1Instr)) {
- assert(MulInstr->getOpcode() == Instruction::Mul);
+ ExtOp(ExtOp) {
+ IsExtended = true;
}
VPMulAccRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
- Instruction *RedI, Instruction *MulInstr, VPValue *ChainOp,
- VPValue *VecOp0, VPValue *VecOp1, VPValue *CondOp,
- bool IsOrdered)
+ Instruction *RedI, VPValue *ChainOp, VPValue *VecOp0,
+ VPValue *VecOp1, VPValue *CondOp, bool IsOrdered)
: VPReductionRecipe(SC, R, RedI,
ArrayRef<VPValue *>({ChainOp, VecOp0, VecOp1}),
- CondOp, IsOrdered),
- MulInstr(MulInstr) {
- assert(MulInstr->getOpcode() == Instruction::Mul);
- }
+ CondOp, IsOrdered) {}
public:
VPMulAccRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
VPValue *ChainOp, VPValue *CondOp, bool IsOrdered,
VPWidenRecipe *Mul, VPWidenCastRecipe *Ext0,
VPWidenCastRecipe *Ext1)
- : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, nullptr,
- Mul->getUnderlyingInstr(), Ext0->getUnderlyingInstr(),
- Ext1->getUnderlyingInstr(), ChainOp, Ext0->getOperand(0),
- Ext1->getOperand(0), CondOp, IsOrdered,
- Ext0->getResultType()) {}
+ : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI,
+ cast<CastInst>(Ext0->getUnderlyingInstr())->getOpcode(),
+ ChainOp, Ext0->getOperand(0), Ext1->getOperand(0),
+ CondOp, IsOrdered) {}
VPMulAccRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
VPValue *ChainOp, VPValue *CondOp, bool IsOrdered,
VPWidenRecipe *Mul)
- : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, Mul->getUnderlyingInstr(),
- ChainOp, Mul->getOperand(0), Mul->getOperand(1), CondOp,
- IsOrdered) {}
+ : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, ChainOp, Mul->getOperand(0),
+ Mul->getOperand(1), CondOp, IsOrdered) {}
VPMulAccRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
VPValue *ChainOp, VPValue *CondOp, bool IsOrdered,
VPWidenCastRecipe *Ext, VPWidenRecipe *Mul,
VPWidenCastRecipe *Ext0, VPWidenCastRecipe *Ext1)
- : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, Ext->getUnderlyingInstr(),
- Mul->getUnderlyingInstr(), Ext0->getUnderlyingInstr(),
- Ext1->getUnderlyingInstr(), ChainOp, Ext0->getOperand(0),
- Ext1->getOperand(0), CondOp, IsOrdered,
- Ext0->getResultType()) {}
+ : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI,
+ cast<CastInst>(Ext0->getUnderlyingInstr())->getOpcode(),
+ ChainOp, Ext0->getOperand(0), Ext1->getOperand(0),
+ CondOp, IsOrdered) {}
~VPMulAccRecipe() override = default;
@@ -2828,37 +2816,17 @@ class VPMulAccRecipe : public VPReductionRecipe {
VPValue *getVecOp0() const { return getOperand(1); }
VPValue *getVecOp1() const { return getOperand(2); }
- /// Return the type after inner extended, which must equal to the type of mul
- /// instruction. If the ResultType != recurrenceType, than it must have a
- /// extend recipe after mul recipe.
- Type *getResultType() const { return ResultType; }
-
- /// The underlying instruction for VPWidenRecipe.
- Instruction *getMulInstr() const { return MulInstr; }
-
- /// The underlying Instruction for outer VPWidenCastRecipe.
- CastInst *getExtInstr() const { return ExtInstr; }
- /// The underlying Instruction for inner VPWidenCastRecipe.
- CastInst *getExt0Instr() const { return Ext0Instr; }
- /// The underlying Instruction for inner VPWidenCastRecipe.
- CastInst *getExt1Instr() const { return Ext1Instr; }
-
/// Return if this MulAcc recipe contains extend instructions.
- bool isExtended() const { return Ext0Instr && Ext1Instr; }
+ bool isExtended() const { return IsExtended; }
/// Return if the operands of mul instruction come from same extend.
- bool isSameExtend() const { return Ext0Instr == Ext1Instr; }
- /// Return if the MulAcc recipes contains extend after mul.
- bool isOuterExtended() const { return ExtInstr != nullptr; }
+ bool isSameExtend() const { return getVecOp0() == getVecOp1(); }
/// Return if the extend opcode is ZExt.
bool isZExt() const {
if (!isExtended())
return true;
- // reduce.add(sext(mul(zext(A), zext(A)))) can be transform to
- // reduce.add(zext(mul(sext(A), sext(A))))
- if (ExtInstr && ExtInstr->getOpcode() != Ext0Instr->getOpcode())
- return true;
- return Ext0Instr->getOpcode() == Instruction::CastOps::ZExt;
+ return ExtOp == Instruction::CastOps::ZExt;
}
+ Instruction::CastOps getExtOpcode() const { return ExtOp; }
};
/// VPReplicateRecipe replicates a given instruction producing multiple scalar
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index fc6b90191efa43..cd38d4ffe8c87e 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -2262,6 +2262,8 @@ InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
Type *ElementTy = isExtended() ? RdxDesc.getRecurrenceType()
: Ctx.Types.inferScalarType(getVecOp0());
auto *VectorTy = cast<VectorType>(ToVectorTy(ElementTy, VF));
+ auto *SrcVecTy =
+ cast<VectorType>(ToVectorTy(Ctx.Types.inferScalarType(getVecOp0()), VF));
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
unsigned Opcode = RdxDesc.getOpcode();
@@ -2279,29 +2281,25 @@ InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
if (isExtended()) {
auto *SrcTy = cast<VectorType>(
ToVectorTy(Ctx.Types.inferScalarType(getVecOp0()), VF));
- auto *DestTy = cast<VectorType>(ToVectorTy(getResultType(), VF));
TTI::CastContextHint CCH0 =
computeCCH(getVecOp0()->getDefiningRecipe(), VF);
- ExtendedCost = Ctx.TTI.getCastInstrCost(
- Ext0Instr->getOpcode(), DestTy, SrcTy, CCH0, TTI::TCK_RecipThroughput,
- dyn_cast_if_present<Instruction>(getExt0Instr()));
+ ExtendedCost = Ctx.TTI.getCastInstrCost(getExtOpcode(), VectorTy, SrcVecTy,
+ CCH0, TTI::TCK_RecipThroughput);
TTI::CastContextHint CCH1 =
computeCCH(getVecOp0()->getDefiningRecipe(), VF);
- ExtendedCost += Ctx.TTI.getCastInstrCost(
- Ext1Instr->getOpcode(), DestTy, SrcTy, CCH1, TTI::TCK_RecipThroughput,
- dyn_cast_if_present<Instruction>(getExt1Instr()));
+ ExtendedCost += Ctx.TTI.getCastInstrCost(getExtOpcode(), VectorTy, SrcVecTy,
+ CCH1, TTI::TCK_RecipThroughput);
}
// Mul cost
InstructionCost MulCost;
SmallVector<const Value *, 4> Operands;
- Operands.append(MulInstr->value_op_begin(), MulInstr->value_op_end());
if (isExtended())
MulCost = Ctx.TTI.getArithmeticInstrCost(
Instruction::Mul, VectorTy, CostKind,
{TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None},
{TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None},
- Operands, MulInstr, &Ctx.TLI);
+ Operands, nullptr, &Ctx.TLI);
else {
VPValue *RHS = getVecOp1();
// Certain instructions can be cheaper to vectorize if they have a constant
@@ -2314,15 +2312,15 @@ InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
if (RHSInfo.Kind == TargetTransformInfo::OK_AnyValue &&
RHS->isDefinedOutsideLoopRegions())
RHSInfo.Kind = TargetTransformInfo::OK_UniformValue;
+ Operands.append(
+ {getVecOp0()->getUnderlyingValue(), RHS->getUnderlyingValue()});
MulCost = Ctx.TTI.getArithmeticInstrCost(
Instruction::Mul, VectorTy, CostKind,
{TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None},
- RHSInfo, Operands, MulInstr, &Ctx.TLI);
+ RHSInfo, Operands, nullptr, &Ctx.TLI);
}
// MulAccReduction Cost
- VectorType *SrcVecTy =
- cast<VectorType>(ToVectorTy(Ctx.Types.inferScalarType(getVecOp0()), VF));
InstructionCost MulAccCost =
Ctx.TTI.getMulAccReductionCost(isZExt(), ElementTy, SrcVecTy, CostKind);
@@ -2406,6 +2404,7 @@ void VPExtendedReductionRecipe::print(raw_ostream &O, const Twine &Indent,
void VPMulAccRecipe::print(raw_ostream &O, const Twine &Indent,
VPSlotTracker &SlotTracker) const {
const RecurrenceDescriptor &RdxDesc = getRecurrenceDescriptor();
+ Type *ElementTy = RdxDesc.getRecurrenceType();
O << Indent << "MULACC-REDUCE ";
printAsOperand(O, SlotTracker);
O << " = ";
@@ -2413,27 +2412,23 @@ void VPMulAccRecipe::print(raw_ostream &O, const Twine &Indent,
O << " + ";
if (isa<FPMathOperator>(getUnderlyingInstr()))
O << getUnderlyingInstr()->getFastMathFlags();
- if (isOuterExtended())
- O << " (";
O << "reduce." << Instruction::getOpcodeName(RdxDesc.getOpcode()) << " (";
O << "mul ";
if (isExtended())
O << "(";
getVecOp0()->printAsOperand(O, SlotTracker);
if (isExtended())
- O << " extended to " << *getResultType() << "), (";
+ O << " extended to " << *ElementTy << "), (";
else
O << ", ";
getVecOp1()->printAsOperand(O, SlotTracker);
if (isExtended())
- O << " extended to " << *getResultType() << ")";
+ O << " extended to " << *ElementTy << ")";
if (isConditional()) {
O << ", ";
getCondOp()->printAsOperand(O, SlotTracker);
}
O << ")";
- if (isOuterExtended())
- O << " extended to " << *RdxDesc.getRecurrenceType() << ")";
if (RdxDesc.IntermediateStore)
O << " (with final reduction value stored in invariant address sank "
"outside of loop)";
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index bbfb2540cc6e3d..813e0edfdb33ed 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -527,9 +527,9 @@ void VPlanTransforms::prepareExecute(VPlan &Plan) {
for (VPRecipeBase &R : make_early_inc_range(*VPBB)) {
if (auto *ExtRed = dyn_cast<VPExtendedReductionRecipe>(&R)) {
// Genearte VPWidenCastRecipe.
- auto *Ext = new VPWidenCastRecipe(
- ExtRed->getExtOpcode(), ExtRed->getVecOp(), ExtRed->getResultType(),
- *ExtRed->getExtInstr());
+ auto *Ext =
+ new VPWidenCastRecipe(ExtRed->getExtOpcode(), ExtRed->getVecOp(),
+ ExtRed->getResultType());
// Generate VPreductionRecipe.
auto *Red = new VPReductionRecipe(
@@ -542,22 +542,21 @@ void VPlanTransforms::prepareExecute(VPlan &Plan) {
ExtRed->eraseFromParent();
} else if (isa<VPMulAccRecipe>(&R)) {
auto *MulAcc = cast<VPMulAccRecipe>(&R);
+ Type *RedType = MulAcc->getRecurrenceDescriptor().getRecurrenceType();
// Generate inner VPWidenCastRecipes if necessary.
VPValue *Op0, *Op1;
if (MulAcc->isExtended()) {
- CastInst *Ext0 = MulAcc->getExt0Instr();
- Op0 = new VPWidenCastRecipe(Ext0->getOpcode(), MulAcc->getVecOp0(),
- MulAcc->getResultType(), *Ext0);
+ Op0 = new VPWidenCastRecipe(MulAcc->getExtOpcode(),
+ MulAcc->getVecOp0(), RedType);
Op0->getDefiningRecipe()->insertBefore(MulAcc);
// Prevent reduce.add(mul(ext(A), ext(A))) generate duplicate
// VPWidenCastRecipe.
if (MulAcc->isSameExtend()) {
Op1 = Op0;
} else {
- CastInst *Ext1 = MulAcc->getExt1Instr();
- Op1 = new VPWidenCastRecipe(Ext1->getOpcode(), MulAcc->getVecOp1(),
- MulAcc->getResultType(), *Ext1);
+ Op1 = new VPWidenCastRecipe(MulAcc->getExtOpcode(),
+ MulAcc->getVecOp1(), RedType);
Op1->getDefiningRecipe()->insertBefore(MulAcc);
}
// Not contains extend instruction in this MulAccRecipe.
@@ -567,27 +566,15 @@ void VPlanTransforms::prepareExecute(VPlan &Plan) {
}
// Generate VPWidenRecipe.
- VPSingleDefRecipe *VecOp;
SmallVector<VPValue *, 2> MulOps = {Op0, Op1};
- auto *Mul = new VPWidenRecipe(*MulAcc->getMulInstr(),
+ auto *Mul = new VPWidenRecipe(Instruction::Mul,
make_range(MulOps.begin(), MulOps.end()));
Mul->insertBefore(MulAcc);
- // Generate outer VPWidenCastRecipe if necessary.
- if (auto *OuterExtInstr = MulAcc->getExtInstr()) {
- VecOp = new VPWidenCastRecipe(
- OuterExtInstr->getOpcode(), Mul,
- MulAcc->getRecurrenceDescriptor().getRecurrenceType(),
- *OuterExtInstr);
- VecOp->insertBefore(MulAcc);
- } else {
- VecOp = Mul;
- }
-
// Generate VPReductionRecipe.
auto *Red = new VPReductionRecipe(
MulAcc->getRecurrenceDescriptor(), MulAcc->getUnderlyingInstr(),
- MulAcc->getChainOp(), VecOp, MulAcc->getCondOp(),
+ MulAcc->getChainOp(), Mul, MulAcc->getCondOp(),
MulAcc->isOrdered());
Red->insertBefore(MulAcc);
diff --git a/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll b/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
index 6a48c330775972..b6e5d7484fafea 100644
--- a/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
+++ b/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
@@ -1670,24 +1670,55 @@ define i64 @test_std_q31(ptr %x, i32 %n) #0 {
; CHECK-NEXT: entry:
; CHECK-NEXT: [[CMP11:%.*]] = icmp sgt i32 [[N:%.*]], 0
; CHECK-NEXT: tail call void @llvm.assume(i1 [[CMP11]])
+; CHECK-NEXT: [[MIN_ITERS_CHECK:%.*]] = icmp samesign ult i32 [[N]], 4
+; CHECK-NEXT: br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
+; CHECK: vector.ph:
+; CHECK-NEXT: [[N_VEC:%.*]] = and i32 [[N]], 2147483644
; CHECK-NEXT: br label [[FOR_BODY1:%.*]]
+; CHECK: vector.body:
+; CHECK-NEXT: [[INDEX:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[FOR_BODY1]] ]
+; CHECK-NEXT: [[VEC_PHI:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[TMP4:%.*]], [[FOR_BODY1]] ]
+; CHECK-NEXT: [[VEC_PHI1:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[TMP8:%.*]], [[FOR_BODY1]] ]
+; CHECK-NEXT: [[TMP10:%.*]] = getelementptr inbounds i32, ptr [[X:%.*]], i32 [[INDEX]]
+; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <4 x i32>, ptr [[TMP10]], align 4
+; CHECK-NEXT: [[TMP1:%.*]] = ashr <4 x i32> [[WIDE_LOAD]], splat (i32 8)
+; CHECK-NEXT: [[TMP2:%.*]] = sext <4 x i32> [[TMP1]] to <4 x i64>
+; CHECK-NEXT: [[TMP3:%.*]] = call i64 @llvm.vector.reduce.add.v4i64(<4 x i64> [[TMP2]])
+; CHECK-NEXT: [[TMP4]] = add i64 [[TMP3]], [[VEC_PHI]]
+; CHECK-NEXT: [[TMP5:%.*]] = sext <4 x i32> [[TMP1]] to <4 x i64>
+; CHECK-NEXT: [[TMP6:%.*]] = mul nsw <4 x i64> [[TMP5]], [[TMP5]]
+; CHECK-NEXT: [[TMP7:%.*]] = call i64 @llvm.vector.reduce.add.v4i64(<4 x i64> [[TMP6]])
+; CHECK-NEXT: [[TMP8]] = add i64 [[TMP7]], [[VEC_PHI1]]
+; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 4
+; CHECK-NEXT: [[TMP9:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]
+; CHECK-NEXT: br i1 [[TMP9]], label [[MIDDLE_BLOCK:%.*]], label [[FOR_BODY1]], !llvm.loop [[LOOP37:![0-9]+]]
+; CHECK: middle.block:
+; CHECK-NEXT: [[CMP_N:%.*]] = icmp eq i32 [[N]], [[N_VEC]]
+; CHECK-NEXT: br i1 [[CMP_N]], label [[FOR_COND_CLEANUP:%.*]], label [[SCALAR_PH]]
+; CHECK: scalar.ph:
+; CHECK-NEXT: [[BC_RESUME_VAL:%.*]] = phi i32 [ [[N_VEC]], [[MIDDLE_BLOCK]] ], [ 0, [[ENTRY:%.*]] ]
+; CHECK-NEXT: [[BC_MERGE_RDX:%.*]] = phi i64 [ [[TMP4]], [[MIDDLE_BLOCK]] ], [ 0, [[ENTRY]] ]
+; CHECK-NEXT: [[BC_MERGE_RDX2:%.*]] = phi i64 [ [[TMP8]], [[MIDDLE_BLOCK]] ], [ 0, [[ENTRY]] ]
+; CHECK-NEXT: br label [[FOR_BODY:%.*]]
; CHECK: for.cond.cleanup:
-; CHECK-NEXT: [[DIV:%.*]] = sdiv i64 [[ADD3:%.*]], [[ADD:%.*]]
+; CHECK-NEXT: [[ADD:%.*]] = phi i64 [ [[ADD1:%.*]], [[FOR_BODY]] ], [ [[TMP4]], [[MIDDLE_BLOCK]] ]
+; CHECK-NEXT: [[ADD3:%.*]] = phi i64 [ [[ADD5:%.*]], [[FOR_BODY]] ], [ [[TMP8]], [[MIDDLE_BLOCK]] ]
+; CHECK-NEXT: [[DIV:%.*]] = sdiv i64 [[ADD3]], [[ADD]]
; CHECK-NEXT: ret i64 [[DIV]]
; CHECK: for.body:
-; CHECK-NEXT: [[S_014:%.*]] = phi i64 [ [[ADD]], [[FOR_BODY1]] ], [ 0, [[ENTRY:%.*]] ]
-; CHECK-NEXT: [[I_013:%.*]] = phi i32 [ [[ADD4:%.*]], [[FOR_BODY1]] ], [ 0, [[ENTRY]] ]
-; CHECK-NEXT: [[T_012:%.*]] = phi i64 [ [[ADD3]], [[FOR_BODY1]] ], [ 0, [[ENTRY]] ]
-; CHECK-NEXT: [[ARRAYIDX:%.*]] = getelementptr inbounds i32, ptr [[X:%.*]], i32 [[I_013]]
+; CHECK-NEXT: [[S_014:%.*]] = phi i64 [ [[ADD1]], [[FOR_BODY]] ], [ [[BC_MERGE_RDX]], [[SCALAR_PH]] ]
+; CHECK-NEXT: [[I_013:%.*]] = phi i32 [ [[ADD4:%.*]], [[FOR_BODY]] ], [ [[BC_RESUME_VAL]], [[SCALAR_PH]] ]
+; CHECK-NEXT: [[T_012:%.*]] = phi i64 [ [[ADD5]], [[FOR_BODY]] ], [ [[BC_MERGE_RDX2]], [[SCALAR_PH]] ]
+; CHECK-NEXT: [[ARRAYIDX:%.*]] = getelementptr inbounds i32, ptr [[X]], i32 [[I_013]]
; CHECK-NEXT: [[TMP0:%.*]] = load i32, ptr [[ARRAYIDX]], align 4
; CHECK-NEXT: [[SHR:%.*]] = ashr i32 [[TMP0]], 8
; CHECK-NEXT: [[CONV:%.*]] = sext i32 [[SHR]] to i64
-; CHECK-NEXT: [[ADD]] = add nsw i64 [[S_014]], [[CONV]]
+; CHECK-NEXT: [[ADD1]] = add nsw i64 [[S_014]], [[CONV]]
; CHECK-NEXT: [[MUL:%.*]] = mul nsw i64 [[CONV]], [[CONV]]
-; CHECK-NEXT: [[ADD3]] = add nuw nsw i64 [[MUL]], [[T_012]]
+; CHECK-NEXT: [[ADD5]] = add nuw nsw i64 [[MUL]], [[T_012]]
; CHECK-NEXT: [[ADD4]] = add nuw nsw i32 [[I_013]], 1
; CHECK-NEXT: [[EXITCOND_NOT:%.*]] = icmp eq i32 [[ADD4]], [[N]]
-; CHECK-NEXT: br i1 [[EXITCOND_NOT]], label [[FOR_COND_CLEANUP:%.*]], label [[FOR_BODY1]]
+; CHECK-NEXT: br i1 [[EXITCOND_NOT]], label [[FOR_COND_CLEANUP]], label [[FOR_BODY]], !llvm.loop [[LOOP38:![0-9]+]]
;
entry:
%cmp11 = icmp sgt i32 %n, 0
@@ -1723,10 +1754,10 @@ define i64 @test_fir_q15(ptr %x, ptr %y, i32 %n) #0 {
; CHECK-NEXT: [[TMP0:%.*]] = add nsw i32 [[N]], -1
; CHECK-NEXT: [[TMP1:%.*]] = lshr i32 [[TMP0]], 1
; CHECK-NEXT: [[TMP2:%.*]] = add nuw i32 [[TMP1]], 1
-; CHECK-NEXT: [[MIN_ITERS_CHECK:%.*]] = icmp ult i32 [[N]], 7
+; CHECK-NEXT: [[MIN_ITERS_CHECK:%.*]] = icmp ult i32 [[N]], 15
; CHECK-NEXT: br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
; CHECK: vector.ph:
-; CHECK-NEXT: [[N_VEC:%.*]] = and i32 [[TMP2]], -4
+; CHECK-NEXT: [[N_VEC:%.*]] = and i32 [[TMP2]], -8
; CHECK-NEXT: [[IND_END:%.*]] = shl i32 [[N_VEC]], 1
; CHECK-NEXT: br label [[VECTOR_BODY:%.*]]
; CHECK: vector.body:
@@ -1734,28 +1765,28 @@ define i64 @test_fir_q15(ptr %x, ptr %y, i32 %n) #0 {
; CHECK-NEXT: [[VEC_PHI:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[TMP16:%.*]], [[VECTOR_BODY]] ]
; CHECK-NEXT: [[OFFSET_IDX:%.*]] = shl i32 [[INDEX]], 1
; CHECK-NEXT: [[TMP3:%.*]] = getelementptr inbounds i16, ptr [[X:%.*]], i32 [[OFFSET_IDX]]
-; CHECK-NEXT: [[WIDE_VEC:%.*]] = load <8 x i16>, ptr [[TMP3]], align 2
-; CHECK-NEXT: [[STRIDED_VEC:%.*]] = shufflevector <8 x i16> [[WIDE_VEC]], <8 x i16> poison, <4 x i32> <i32 0, i32 2, i32 4, i32 6>
-; CHECK-NEXT: [[STRIDED_VEC1:%.*]] = shufflevector <8 x i16> [[WIDE_VEC]], <8 x i16> poison, <4 x i32> <i32 1, i32 3, i32 5, i32 7>
-; CHECK-NEXT: [[TMP5:%.*]] = sext <4 x i16> [[STRIDED_VEC]] to <4 x i32>
+; CHECK-NEXT: [[WIDE_VEC:%.*]] = load <16 x i16>, ptr [[TMP3]], align 2
+; CHECK-NEXT: [[STRIDED_VEC:%.*]] = shufflevector <16 x i16> [[WIDE_VEC]], <16 x i16> poison, <8 x i32> <i32 0, i32 2, i32 4, i32 6, i32 8, i32 10, i32 12, i32 14>
+; CHECK-NEXT: [[STRIDED_VEC1:%.*]] = shufflevector <16 x i16> [[WIDE_VEC]], <16 x i16> poison, <8 x i32> <i32 1, i32 3, i32 5, i32 7, i32 9, i32 11, i32 13, i32 15>
; CHECK-NEXT: [[TMP4:%.*]] = getelementptr inbounds i16, ptr [[Y:%.*]], i32 [[OFFSET_IDX]]
-; CHECK-NEXT: [[WIDE_VEC2:%.*]] = load <8 x i16>, ptr [[TMP4]], align 2
-; CHECK-NEXT: [[STRIDED_VEC3:%.*]] = shufflevector <8 x i16> [[WIDE_VEC2]], <8 x i16> poison, <4 x i32> <i32 0, i32 2, i32 4, i32 6>
-; CHECK-NEXT: [[STRIDED_VEC4:%.*]] = shufflevector <8 x i16> [[WIDE_VEC2]], <8 x i16> poison, <4 x i32> <i32 1, i32 3, i32 5, i32 7>
-; CHECK-NEXT: [[TMP6:%.*]] = sext <4 x i16> [[STRIDED_VEC3]] to <4 x i32>
-; CHECK-NEXT: [[TMP7:%.*]] = mul nsw <4 x i32> [[TMP6]], [[TMP5]]
-; CHECK-NEXT: [[TMP8:%.*]] = sext <4 x i32> [[TMP7]] to <4 x i64>
-; CHECK-NEXT: [[TMP13:%.*]] = sext <4 x i16> [[STRIDED_VEC1]] to <4 x i32>
-; CHECK-NEXT: [[TMP14:%.*]] = sext <4 x i16> [[STRIDED_VEC4]] to <4 x i32>
-; CHECK-NEXT: [[TMP11:%.*]] = mul nsw <4 x i32> [[TMP14]], [[TMP13]]
-; CHECK-NEXT: [[TMP12:%.*]] = sext <4 x i32> [[TMP11]] to <4 x i64>
-; CHECK-NEXT: [[TMP9:%.*]] = call i64 @llvm.vector.reduce.add.v4i64(<4 x i64> [[TMP8]])
+; CHECK-NEXT: [[WIDE_VEC2:%.*]] = load <16 x i16>, ptr [[TMP4]], align 2
+; CHECK-NEXT: [[STRIDED_VEC3:%.*]] = shufflevector <16 x i16> [[WIDE_VEC2]], <16 x i16> poison, <8 x i32> <i32 0, i32 2, i32 4, i32 6, i32 8, i32 10, i32 12, i32 14>
+; CHECK-NEXT: [[STRIDED_VEC4:%.*]] = shufflevector <16 x i16> [[WIDE_VEC2]], <16 x i16> poison, <8 x i32> <i32 1, i32 3, i32 5, i32 7, i32 9, i32 11, i32 13, i32 15>
+; CHECK-NEXT: [[TMP5:%.*]] = sext <8 x i16> [[STRIDED_VEC3]] to <8 x i32>
+; CHECK-NEXT: [[TMP6:%.*]] = sext <8 x i16> [[STRIDED_VEC]] to <8 x i32>
+; CHECK-NEXT: [[TMP7:%.*]] = mul nsw <8 x i32> [[TMP5]], [[TMP6]]
+; CHECK-NEXT: [[TMP8:%.*]] = sext <8 x i32> [[TMP7]] to <8 x i64>
+; CHECK-NEXT: [[TMP9:%.*]] = call i64 @llvm.vector.reduce.add.v8i64(<8 x i64> [[TMP8]])
; CHECK-NEXT: [[TMP10:%.*]] = add i64 [[TMP9]], [[VEC_PHI]]
-; CHECK-NEXT: [[TMP15:%.*]] = call i64 @llvm.vector.reduce.add.v4i64(<4 x i64> [[TMP12]])
+; CHECK-NEXT: [[TMP11:%.*]] = sext <8 x i16> [[STRIDED_VEC4]] to <8 x i32>
+; CHECK-NEXT: [[TMP12:%.*]] = sext <8 x i16> [[STRIDED_VEC1]] to <8 x i32>
+; CHECK-NEXT: [[TMP13:%.*]] = mul nsw <8 x i32> [[TMP11]], [[TMP12]]
+; CHECK-NEXT: [[TMP14:%.*]] = sext <8 x i32> [[TMP13]] to <8 x i64>
+; CHECK-NEXT: [[TMP15:%.*]] = call i64 @llvm.vector.reduce.add.v8i64(<8 x i64> [[TMP14]])
; CHECK-NEXT: [[TMP16]] = add i64 [[TMP15]], [[TMP10]]
-; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 4
+; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 8
; CHECK-NEXT: [[TMP17:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]
-; CHECK-NEXT: br i1 [[TMP17]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP37:![0-9]+]]
+; CHECK-NEXT: br i1 [[TMP17]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP39:![0-9]+]]
; CHECK: middle.block:
; CHECK-NEXT: [[CMP_N:%.*]] = icmp eq i32 [[TMP2]], [[N_VEC]]
; CHECK-NEXT: br i1 [[CMP_N]], label [[FOR_COND_CLEANUP]], label [[SCALAR_PH]]
@@ -1790,7 +1821,7 @@ define i64 @test_fir_q15(ptr %x, ptr %y, i32 %n) #0 {
; CHECK-NEXT: [[ADD12]] = add nsw i64 [[ADD]], [[CONV11]]
; CHECK-NEXT: [[ADD13]] = add nuw nsw i32 [[I_025]], 2
; CHECK-NEXT: [[CMP:%.*]] = icmp slt i32 [[ADD13]], [[N]]
-; CHECK-NEXT: br i1 [[CMP]], label [[FOR_BODY]], label [[FOR_COND_CLEANUP]], !llvm.loop [[LOOP38:![0-9]+]]
+; CHECK-NEXT: br i1 [[CMP]], label [[FOR_BODY]], label [[FOR_COND_CLEANUP]], !llvm.loop [[LOOP40:![0-9]+]]
;
entry:
%cmp23 = icmp sgt i32 %n, 0
>From 0a7d4937267e92f2ba92bd1560cb2fc211688cf9 Mon Sep 17 00:00:00 2001
From: Elvis Wang <elvis.wang at sifive.com>
Date: Thu, 14 Nov 2024 05:21:18 -0800
Subject: [PATCH 11/14] Revert "Remove underying instruction dependency."
This reverts commit e4e510802f44c9936f5a892eb16ba2b19651ee15.
We need the underlying instructions for accurate TTI queries and for
the metadata to attach to the generated vector instructions.
---
.../Transforms/Vectorize/LoopVectorize.cpp | 24 ++--
llvm/lib/Transforms/Vectorize/VPlan.h | 116 +++++++++++-------
.../lib/Transforms/Vectorize/VPlanRecipes.cpp | 31 +++--
.../Transforms/Vectorize/VPlanTransforms.cpp | 33 +++--
.../LoopVectorize/ARM/mve-reductions.ll | 89 +++++---------
5 files changed, 156 insertions(+), 137 deletions(-)
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index da01dbe1beb99a..fef9956068f1fd 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7395,18 +7395,18 @@ static bool planContainsAdditionalSimplifications(VPlan &Plan,
if (Instruction *UI = GetInstructionForCost(&R))
SeenInstrs.insert(UI);
// VPExtendedReductionRecipe contains a folded extend instruction.
- // if (auto *ExtendedRed = dyn_cast<VPExtendedReductionRecipe>(&R))
- // SeenInstrs.insert(ExtendedRed->getExtInstr());
- // // VPMulAccRecipe constians a mul and otional extend instructions.
- // else if (auto *MulAcc = dyn_cast<VPMulAccRecipe>(&R)) {
- // SeenInstrs.insert(MulAcc->getMulInstr());
- // if (MulAcc->isExtended()) {
- // SeenInstrs.insert(MulAcc->getExt0Instr());
- // SeenInstrs.insert(MulAcc->getExt1Instr());
- // if (auto *Ext = MulAcc->getExtInstr())
- // SeenInstrs.insert(Ext);
- // }
- // }
+ if (auto *ExtendedRed = dyn_cast<VPExtendedReductionRecipe>(&R))
+ SeenInstrs.insert(ExtendedRed->getExtInstr());
+ // VPMulAccRecipe contains a mul and optional extend instructions.
+ else if (auto *MulAcc = dyn_cast<VPMulAccRecipe>(&R)) {
+ SeenInstrs.insert(MulAcc->getMulInstr());
+ if (MulAcc->isExtended()) {
+ SeenInstrs.insert(MulAcc->getExt0Instr());
+ SeenInstrs.insert(MulAcc->getExt1Instr());
+ if (auto *Ext = MulAcc->getExtInstr())
+ SeenInstrs.insert(Ext);
+ }
+ }
}
}
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index bdc1823d8e6f68..cd7d0efe8dda45 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -1446,20 +1446,11 @@ class VPWidenRecipe : public VPRecipeWithIRFlags {
iterator_range<IterT> Operands)
: VPRecipeWithIRFlags(VPDefOpcode, Operands, I), Opcode(I.getOpcode()) {}
- template <typename IterT>
- VPWidenRecipe(unsigned VPDefOpcode, unsigned InstrOpcode,
- iterator_range<IterT> Operands)
- : VPRecipeWithIRFlags(VPDefOpcode, Operands), Opcode(InstrOpcode) {}
-
public:
template <typename IterT>
VPWidenRecipe(Instruction &I, iterator_range<IterT> Operands)
: VPWidenRecipe(VPDef::VPWidenSC, I, Operands) {}
- template <typename IterT>
- VPWidenRecipe(unsigned Opcode, iterator_range<IterT> Operands)
- : VPWidenRecipe(VPDef::VPWidenSC, Opcode, Operands) {}
-
~VPWidenRecipe() override = default;
VPWidenRecipe *clone() override {
@@ -2674,25 +2665,26 @@ class VPReductionEVLRecipe : public VPReductionRecipe {
class VPExtendedReductionRecipe : public VPReductionRecipe {
/// Type after extend.
Type *ResultTy;
- Instruction::CastOps ExtOp;
+ CastInst *ExtInstr;
protected:
VPExtendedReductionRecipe(const unsigned char SC,
const RecurrenceDescriptor &R, Instruction *RedI,
- Instruction::CastOps ExtOp, VPValue *ChainOp,
- VPValue *VecOp, VPValue *CondOp, bool IsOrdered,
- Type *ResultTy)
+ Instruction::CastOps ExtOp, CastInst *ExtInstr,
+ VPValue *ChainOp, VPValue *VecOp, VPValue *CondOp,
+ bool IsOrdered, Type *ResultTy)
: VPReductionRecipe(SC, R, RedI, ArrayRef<VPValue *>({ChainOp, VecOp}),
CondOp, IsOrdered),
- ResultTy(ResultTy), ExtOp(ExtOp) {}
+ ResultTy(ResultTy), ExtInstr(ExtInstr) {}
public:
VPExtendedReductionRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
VPValue *ChainOp, VPWidenCastRecipe *Ext,
VPValue *CondOp, bool IsOrdered)
- : VPExtendedReductionRecipe(VPDef::VPExtendedReductionSC, R, RedI,
- Ext->getOpcode(), ChainOp, Ext->getOperand(0),
- CondOp, IsOrdered, Ext->getResultType()) {}
+ : VPExtendedReductionRecipe(
+ VPDef::VPExtendedReductionSC, R, RedI, Ext->getOpcode(),
+ cast<CastInst>(Ext->getUnderlyingInstr()), ChainOp,
+ Ext->getOperand(0), CondOp, IsOrdered, Ext->getResultType()) {}
~VPExtendedReductionRecipe() override = default;
@@ -2728,7 +2720,9 @@ class VPExtendedReductionRecipe : public VPReductionRecipe {
Type *getResultType() const { return ResultTy; }
bool isZExt() const { return getExtOpcode() == Instruction::ZExt; }
/// The Opcode of extend instruction.
- Instruction::CastOps getExtOpcode() const { return ExtOp; }
+ Instruction::CastOps getExtOpcode() const { return ExtInstr->getOpcode(); }
+ /// The CastInst of the extend instruction.
+ CastInst *getExtInstr() const { return ExtInstr; }
};
/// A recipe to represent inloop MulAccreduction operations, performing a
@@ -2737,52 +2731,70 @@ class VPExtendedReductionRecipe : public VPReductionRecipe {
/// VPReductionRecipe VPWidenRecipe(mul) and VPWidenCastRecipes before
/// execution. The Operands are {ChainOp, VecOp1, VecOp2, [Condition]}.
class VPMulAccRecipe : public VPReductionRecipe {
- Instruction::CastOps ExtOp;
- bool IsExtended = false;
+ /// Type after extend.
+ Type *ResultType;
+
+ /// reduce.add(ext(mul(ext0(), ext1())))
+ Instruction *MulInstr;
+ CastInst *ExtInstr = nullptr;
+ CastInst *Ext0Instr = nullptr;
+ CastInst *Ext1Instr = nullptr;
protected:
VPMulAccRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
- Instruction *RedI, Instruction::CastOps ExtOp,
- VPValue *ChainOp, VPValue *VecOp0, VPValue *VecOp1,
- VPValue *CondOp, bool IsOrdered)
+ Instruction *RedI, Instruction *ExtInstr,
+ Instruction *MulInstr, Instruction *Ext0Instr,
+ Instruction *Ext1Instr, VPValue *ChainOp, VPValue *VecOp0,
+ VPValue *VecOp1, VPValue *CondOp, bool IsOrdered,
+ Type *ResultType)
: VPReductionRecipe(SC, R, RedI,
ArrayRef<VPValue *>({ChainOp, VecOp0, VecOp1}),
CondOp, IsOrdered),
- ExtOp(ExtOp) {
- IsExtended = true;
+ ResultType(ResultType), MulInstr(MulInstr),
+ ExtInstr(cast_if_present<CastInst>(ExtInstr)),
+ Ext0Instr(cast<CastInst>(Ext0Instr)),
+ Ext1Instr(cast<CastInst>(Ext1Instr)) {
+ assert(MulInstr->getOpcode() == Instruction::Mul);
}
VPMulAccRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
- Instruction *RedI, VPValue *ChainOp, VPValue *VecOp0,
- VPValue *VecOp1, VPValue *CondOp, bool IsOrdered)
+ Instruction *RedI, Instruction *MulInstr, VPValue *ChainOp,
+ VPValue *VecOp0, VPValue *VecOp1, VPValue *CondOp,
+ bool IsOrdered)
: VPReductionRecipe(SC, R, RedI,
ArrayRef<VPValue *>({ChainOp, VecOp0, VecOp1}),
- CondOp, IsOrdered) {}
+ CondOp, IsOrdered),
+ MulInstr(MulInstr) {
+ assert(MulInstr->getOpcode() == Instruction::Mul);
+ }
public:
VPMulAccRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
VPValue *ChainOp, VPValue *CondOp, bool IsOrdered,
VPWidenRecipe *Mul, VPWidenCastRecipe *Ext0,
VPWidenCastRecipe *Ext1)
- : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI,
- cast<CastInst>(Ext0->getUnderlyingInstr())->getOpcode(),
- ChainOp, Ext0->getOperand(0), Ext1->getOperand(0),
- CondOp, IsOrdered) {}
+ : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, nullptr,
+ Mul->getUnderlyingInstr(), Ext0->getUnderlyingInstr(),
+ Ext1->getUnderlyingInstr(), ChainOp, Ext0->getOperand(0),
+ Ext1->getOperand(0), CondOp, IsOrdered,
+ Ext0->getResultType()) {}
VPMulAccRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
VPValue *ChainOp, VPValue *CondOp, bool IsOrdered,
VPWidenRecipe *Mul)
- : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, ChainOp, Mul->getOperand(0),
- Mul->getOperand(1), CondOp, IsOrdered) {}
+ : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, Mul->getUnderlyingInstr(),
+ ChainOp, Mul->getOperand(0), Mul->getOperand(1), CondOp,
+ IsOrdered) {}
VPMulAccRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
VPValue *ChainOp, VPValue *CondOp, bool IsOrdered,
VPWidenCastRecipe *Ext, VPWidenRecipe *Mul,
VPWidenCastRecipe *Ext0, VPWidenCastRecipe *Ext1)
- : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI,
- cast<CastInst>(Ext0->getUnderlyingInstr())->getOpcode(),
- ChainOp, Ext0->getOperand(0), Ext1->getOperand(0),
- CondOp, IsOrdered) {}
+ : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, Ext->getUnderlyingInstr(),
+ Mul->getUnderlyingInstr(), Ext0->getUnderlyingInstr(),
+ Ext1->getUnderlyingInstr(), ChainOp, Ext0->getOperand(0),
+ Ext1->getOperand(0), CondOp, IsOrdered,
+ Ext0->getResultType()) {}
~VPMulAccRecipe() override = default;
@@ -2816,17 +2828,37 @@ class VPMulAccRecipe : public VPReductionRecipe {
VPValue *getVecOp0() const { return getOperand(1); }
VPValue *getVecOp1() const { return getOperand(2); }
+ /// Return the type after the inner extends, which must be equal to the type
+ /// of the mul instruction. If ResultType != the recurrence type, then there
+ /// must be an extend recipe after the mul recipe.
+ Type *getResultType() const { return ResultType; }
+
+ /// The underlying instruction for VPWidenRecipe.
+ Instruction *getMulInstr() const { return MulInstr; }
+
+ /// The underlying Instruction for outer VPWidenCastRecipe.
+ CastInst *getExtInstr() const { return ExtInstr; }
+ /// The underlying Instruction for inner VPWidenCastRecipe.
+ CastInst *getExt0Instr() const { return Ext0Instr; }
+ /// The underlying Instruction for inner VPWidenCastRecipe.
+ CastInst *getExt1Instr() const { return Ext1Instr; }
+
/// Return if this MulAcc recipe contains extend instructions.
- bool isExtended() const { return IsExtended; }
+ bool isExtended() const { return Ext0Instr && Ext1Instr; }
/// Return if the operands of mul instruction come from same extend.
- bool isSameExtend() const { return getVecOp0() == getVecOp1(); }
+ bool isSameExtend() const { return Ext0Instr == Ext1Instr; }
+ /// Return true if this MulAcc recipe contains an extend after the mul.
+ bool isOuterExtended() const { return ExtInstr != nullptr; }
/// Return if the extend opcode is ZExt.
bool isZExt() const {
if (!isExtended())
return true;
- return ExtOp == Instruction::CastOps::ZExt;
+ // reduce.add(sext(mul(zext(A), zext(A)))) can be transformed to
+ // reduce.add(zext(mul(sext(A), sext(A))))
+ if (ExtInstr && ExtInstr->getOpcode() != Ext0Instr->getOpcode())
+ return true;
+ return Ext0Instr->getOpcode() == Instruction::CastOps::ZExt;
}
- Instruction::CastOps getExtOpcode() const { return ExtOp; }
};
/// VPReplicateRecipe replicates a given instruction producing multiple scalar
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index cd38d4ffe8c87e..fc6b90191efa43 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -2262,8 +2262,6 @@ InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
Type *ElementTy = isExtended() ? RdxDesc.getRecurrenceType()
: Ctx.Types.inferScalarType(getVecOp0());
auto *VectorTy = cast<VectorType>(ToVectorTy(ElementTy, VF));
- auto *SrcVecTy =
- cast<VectorType>(ToVectorTy(Ctx.Types.inferScalarType(getVecOp0()), VF));
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
unsigned Opcode = RdxDesc.getOpcode();
@@ -2281,25 +2279,29 @@ InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
if (isExtended()) {
auto *SrcTy = cast<VectorType>(
ToVectorTy(Ctx.Types.inferScalarType(getVecOp0()), VF));
+ auto *DestTy = cast<VectorType>(ToVectorTy(getResultType(), VF));
TTI::CastContextHint CCH0 =
computeCCH(getVecOp0()->getDefiningRecipe(), VF);
- ExtendedCost = Ctx.TTI.getCastInstrCost(getExtOpcode(), VectorTy, SrcVecTy,
- CCH0, TTI::TCK_RecipThroughput);
+ ExtendedCost = Ctx.TTI.getCastInstrCost(
+ Ext0Instr->getOpcode(), DestTy, SrcTy, CCH0, TTI::TCK_RecipThroughput,
+ dyn_cast_if_present<Instruction>(getExt0Instr()));
TTI::CastContextHint CCH1 =
computeCCH(getVecOp0()->getDefiningRecipe(), VF);
- ExtendedCost += Ctx.TTI.getCastInstrCost(getExtOpcode(), VectorTy, SrcVecTy,
- CCH1, TTI::TCK_RecipThroughput);
+ ExtendedCost += Ctx.TTI.getCastInstrCost(
+ Ext1Instr->getOpcode(), DestTy, SrcTy, CCH1, TTI::TCK_RecipThroughput,
+ dyn_cast_if_present<Instruction>(getExt1Instr()));
}
// Mul cost
InstructionCost MulCost;
SmallVector<const Value *, 4> Operands;
+ Operands.append(MulInstr->value_op_begin(), MulInstr->value_op_end());
if (isExtended())
MulCost = Ctx.TTI.getArithmeticInstrCost(
Instruction::Mul, VectorTy, CostKind,
{TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None},
{TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None},
- Operands, nullptr, &Ctx.TLI);
+ Operands, MulInstr, &Ctx.TLI);
else {
VPValue *RHS = getVecOp1();
// Certain instructions can be cheaper to vectorize if they have a constant
@@ -2312,15 +2314,15 @@ InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
if (RHSInfo.Kind == TargetTransformInfo::OK_AnyValue &&
RHS->isDefinedOutsideLoopRegions())
RHSInfo.Kind = TargetTransformInfo::OK_UniformValue;
- Operands.append(
- {getVecOp0()->getUnderlyingValue(), RHS->getUnderlyingValue()});
MulCost = Ctx.TTI.getArithmeticInstrCost(
Instruction::Mul, VectorTy, CostKind,
{TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None},
- RHSInfo, Operands, nullptr, &Ctx.TLI);
+ RHSInfo, Operands, MulInstr, &Ctx.TLI);
}
// MulAccReduction Cost
+ VectorType *SrcVecTy =
+ cast<VectorType>(ToVectorTy(Ctx.Types.inferScalarType(getVecOp0()), VF));
InstructionCost MulAccCost =
Ctx.TTI.getMulAccReductionCost(isZExt(), ElementTy, SrcVecTy, CostKind);
@@ -2404,7 +2406,6 @@ void VPExtendedReductionRecipe::print(raw_ostream &O, const Twine &Indent,
void VPMulAccRecipe::print(raw_ostream &O, const Twine &Indent,
VPSlotTracker &SlotTracker) const {
const RecurrenceDescriptor &RdxDesc = getRecurrenceDescriptor();
- Type *ElementTy = RdxDesc.getRecurrenceType();
O << Indent << "MULACC-REDUCE ";
printAsOperand(O, SlotTracker);
O << " = ";
@@ -2412,23 +2413,27 @@ void VPMulAccRecipe::print(raw_ostream &O, const Twine &Indent,
O << " + ";
if (isa<FPMathOperator>(getUnderlyingInstr()))
O << getUnderlyingInstr()->getFastMathFlags();
+ if (isOuterExtended())
+ O << " (";
O << "reduce." << Instruction::getOpcodeName(RdxDesc.getOpcode()) << " (";
O << "mul ";
if (isExtended())
O << "(";
getVecOp0()->printAsOperand(O, SlotTracker);
if (isExtended())
- O << " extended to " << *ElementTy << "), (";
+ O << " extended to " << *getResultType() << "), (";
else
O << ", ";
getVecOp1()->printAsOperand(O, SlotTracker);
if (isExtended())
- O << " extended to " << *ElementTy << ")";
+ O << " extended to " << *getResultType() << ")";
if (isConditional()) {
O << ", ";
getCondOp()->printAsOperand(O, SlotTracker);
}
O << ")";
+ if (isOuterExtended())
+ O << " extended to " << *RdxDesc.getRecurrenceType() << ")";
if (RdxDesc.IntermediateStore)
O << " (with final reduction value stored in invariant address sank "
"outside of loop)";
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index 813e0edfdb33ed..bbfb2540cc6e3d 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -527,9 +527,9 @@ void VPlanTransforms::prepareExecute(VPlan &Plan) {
for (VPRecipeBase &R : make_early_inc_range(*VPBB)) {
if (auto *ExtRed = dyn_cast<VPExtendedReductionRecipe>(&R)) {
// Genearte VPWidenCastRecipe.
- auto *Ext =
- new VPWidenCastRecipe(ExtRed->getExtOpcode(), ExtRed->getVecOp(),
- ExtRed->getResultType());
+ auto *Ext = new VPWidenCastRecipe(
+ ExtRed->getExtOpcode(), ExtRed->getVecOp(), ExtRed->getResultType(),
+ *ExtRed->getExtInstr());
// Generate VPreductionRecipe.
auto *Red = new VPReductionRecipe(
@@ -542,21 +542,22 @@ void VPlanTransforms::prepareExecute(VPlan &Plan) {
ExtRed->eraseFromParent();
} else if (isa<VPMulAccRecipe>(&R)) {
auto *MulAcc = cast<VPMulAccRecipe>(&R);
- Type *RedType = MulAcc->getRecurrenceDescriptor().getRecurrenceType();
// Generate inner VPWidenCastRecipes if necessary.
VPValue *Op0, *Op1;
if (MulAcc->isExtended()) {
- Op0 = new VPWidenCastRecipe(MulAcc->getExtOpcode(),
- MulAcc->getVecOp0(), RedType);
+ CastInst *Ext0 = MulAcc->getExt0Instr();
+ Op0 = new VPWidenCastRecipe(Ext0->getOpcode(), MulAcc->getVecOp0(),
+ MulAcc->getResultType(), *Ext0);
Op0->getDefiningRecipe()->insertBefore(MulAcc);
// Prevent reduce.add(mul(ext(A), ext(A))) generate duplicate
// VPWidenCastRecipe.
if (MulAcc->isSameExtend()) {
Op1 = Op0;
} else {
- Op1 = new VPWidenCastRecipe(MulAcc->getExtOpcode(),
- MulAcc->getVecOp1(), RedType);
+ CastInst *Ext1 = MulAcc->getExt1Instr();
+ Op1 = new VPWidenCastRecipe(Ext1->getOpcode(), MulAcc->getVecOp1(),
+ MulAcc->getResultType(), *Ext1);
Op1->getDefiningRecipe()->insertBefore(MulAcc);
}
// Not contains extend instruction in this MulAccRecipe.
@@ -566,15 +567,27 @@ void VPlanTransforms::prepareExecute(VPlan &Plan) {
}
// Generate VPWidenRecipe.
+ VPSingleDefRecipe *VecOp;
SmallVector<VPValue *, 2> MulOps = {Op0, Op1};
- auto *Mul = new VPWidenRecipe(Instruction::Mul,
+ auto *Mul = new VPWidenRecipe(*MulAcc->getMulInstr(),
make_range(MulOps.begin(), MulOps.end()));
Mul->insertBefore(MulAcc);
+ // Generate outer VPWidenCastRecipe if necessary.
+ if (auto *OuterExtInstr = MulAcc->getExtInstr()) {
+ VecOp = new VPWidenCastRecipe(
+ OuterExtInstr->getOpcode(), Mul,
+ MulAcc->getRecurrenceDescriptor().getRecurrenceType(),
+ *OuterExtInstr);
+ VecOp->insertBefore(MulAcc);
+ } else {
+ VecOp = Mul;
+ }
+
// Generate VPReductionRecipe.
auto *Red = new VPReductionRecipe(
MulAcc->getRecurrenceDescriptor(), MulAcc->getUnderlyingInstr(),
- MulAcc->getChainOp(), Mul, MulAcc->getCondOp(),
+ MulAcc->getChainOp(), VecOp, MulAcc->getCondOp(),
MulAcc->isOrdered());
Red->insertBefore(MulAcc);
diff --git a/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll b/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
index b6e5d7484fafea..6a48c330775972 100644
--- a/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
+++ b/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
@@ -1670,55 +1670,24 @@ define i64 @test_std_q31(ptr %x, i32 %n) #0 {
; CHECK-NEXT: entry:
; CHECK-NEXT: [[CMP11:%.*]] = icmp sgt i32 [[N:%.*]], 0
; CHECK-NEXT: tail call void @llvm.assume(i1 [[CMP11]])
-; CHECK-NEXT: [[MIN_ITERS_CHECK:%.*]] = icmp samesign ult i32 [[N]], 4
-; CHECK-NEXT: br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
-; CHECK: vector.ph:
-; CHECK-NEXT: [[N_VEC:%.*]] = and i32 [[N]], 2147483644
; CHECK-NEXT: br label [[FOR_BODY1:%.*]]
-; CHECK: vector.body:
-; CHECK-NEXT: [[INDEX:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[FOR_BODY1]] ]
-; CHECK-NEXT: [[VEC_PHI:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[TMP4:%.*]], [[FOR_BODY1]] ]
-; CHECK-NEXT: [[VEC_PHI1:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[TMP8:%.*]], [[FOR_BODY1]] ]
-; CHECK-NEXT: [[TMP10:%.*]] = getelementptr inbounds i32, ptr [[X:%.*]], i32 [[INDEX]]
-; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <4 x i32>, ptr [[TMP10]], align 4
-; CHECK-NEXT: [[TMP1:%.*]] = ashr <4 x i32> [[WIDE_LOAD]], splat (i32 8)
-; CHECK-NEXT: [[TMP2:%.*]] = sext <4 x i32> [[TMP1]] to <4 x i64>
-; CHECK-NEXT: [[TMP3:%.*]] = call i64 @llvm.vector.reduce.add.v4i64(<4 x i64> [[TMP2]])
-; CHECK-NEXT: [[TMP4]] = add i64 [[TMP3]], [[VEC_PHI]]
-; CHECK-NEXT: [[TMP5:%.*]] = sext <4 x i32> [[TMP1]] to <4 x i64>
-; CHECK-NEXT: [[TMP6:%.*]] = mul nsw <4 x i64> [[TMP5]], [[TMP5]]
-; CHECK-NEXT: [[TMP7:%.*]] = call i64 @llvm.vector.reduce.add.v4i64(<4 x i64> [[TMP6]])
-; CHECK-NEXT: [[TMP8]] = add i64 [[TMP7]], [[VEC_PHI1]]
-; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 4
-; CHECK-NEXT: [[TMP9:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]
-; CHECK-NEXT: br i1 [[TMP9]], label [[MIDDLE_BLOCK:%.*]], label [[FOR_BODY1]], !llvm.loop [[LOOP37:![0-9]+]]
-; CHECK: middle.block:
-; CHECK-NEXT: [[CMP_N:%.*]] = icmp eq i32 [[N]], [[N_VEC]]
-; CHECK-NEXT: br i1 [[CMP_N]], label [[FOR_COND_CLEANUP:%.*]], label [[SCALAR_PH]]
-; CHECK: scalar.ph:
-; CHECK-NEXT: [[BC_RESUME_VAL:%.*]] = phi i32 [ [[N_VEC]], [[MIDDLE_BLOCK]] ], [ 0, [[ENTRY:%.*]] ]
-; CHECK-NEXT: [[BC_MERGE_RDX:%.*]] = phi i64 [ [[TMP4]], [[MIDDLE_BLOCK]] ], [ 0, [[ENTRY]] ]
-; CHECK-NEXT: [[BC_MERGE_RDX2:%.*]] = phi i64 [ [[TMP8]], [[MIDDLE_BLOCK]] ], [ 0, [[ENTRY]] ]
-; CHECK-NEXT: br label [[FOR_BODY:%.*]]
; CHECK: for.cond.cleanup:
-; CHECK-NEXT: [[ADD:%.*]] = phi i64 [ [[ADD1:%.*]], [[FOR_BODY]] ], [ [[TMP4]], [[MIDDLE_BLOCK]] ]
-; CHECK-NEXT: [[ADD3:%.*]] = phi i64 [ [[ADD5:%.*]], [[FOR_BODY]] ], [ [[TMP8]], [[MIDDLE_BLOCK]] ]
-; CHECK-NEXT: [[DIV:%.*]] = sdiv i64 [[ADD3]], [[ADD]]
+; CHECK-NEXT: [[DIV:%.*]] = sdiv i64 [[ADD3:%.*]], [[ADD:%.*]]
; CHECK-NEXT: ret i64 [[DIV]]
; CHECK: for.body:
-; CHECK-NEXT: [[S_014:%.*]] = phi i64 [ [[ADD1]], [[FOR_BODY]] ], [ [[BC_MERGE_RDX]], [[SCALAR_PH]] ]
-; CHECK-NEXT: [[I_013:%.*]] = phi i32 [ [[ADD4:%.*]], [[FOR_BODY]] ], [ [[BC_RESUME_VAL]], [[SCALAR_PH]] ]
-; CHECK-NEXT: [[T_012:%.*]] = phi i64 [ [[ADD5]], [[FOR_BODY]] ], [ [[BC_MERGE_RDX2]], [[SCALAR_PH]] ]
-; CHECK-NEXT: [[ARRAYIDX:%.*]] = getelementptr inbounds i32, ptr [[X]], i32 [[I_013]]
+; CHECK-NEXT: [[S_014:%.*]] = phi i64 [ [[ADD]], [[FOR_BODY1]] ], [ 0, [[ENTRY:%.*]] ]
+; CHECK-NEXT: [[I_013:%.*]] = phi i32 [ [[ADD4:%.*]], [[FOR_BODY1]] ], [ 0, [[ENTRY]] ]
+; CHECK-NEXT: [[T_012:%.*]] = phi i64 [ [[ADD3]], [[FOR_BODY1]] ], [ 0, [[ENTRY]] ]
+; CHECK-NEXT: [[ARRAYIDX:%.*]] = getelementptr inbounds i32, ptr [[X:%.*]], i32 [[I_013]]
; CHECK-NEXT: [[TMP0:%.*]] = load i32, ptr [[ARRAYIDX]], align 4
; CHECK-NEXT: [[SHR:%.*]] = ashr i32 [[TMP0]], 8
; CHECK-NEXT: [[CONV:%.*]] = sext i32 [[SHR]] to i64
-; CHECK-NEXT: [[ADD1]] = add nsw i64 [[S_014]], [[CONV]]
+; CHECK-NEXT: [[ADD]] = add nsw i64 [[S_014]], [[CONV]]
; CHECK-NEXT: [[MUL:%.*]] = mul nsw i64 [[CONV]], [[CONV]]
-; CHECK-NEXT: [[ADD5]] = add nuw nsw i64 [[MUL]], [[T_012]]
+; CHECK-NEXT: [[ADD3]] = add nuw nsw i64 [[MUL]], [[T_012]]
; CHECK-NEXT: [[ADD4]] = add nuw nsw i32 [[I_013]], 1
; CHECK-NEXT: [[EXITCOND_NOT:%.*]] = icmp eq i32 [[ADD4]], [[N]]
-; CHECK-NEXT: br i1 [[EXITCOND_NOT]], label [[FOR_COND_CLEANUP]], label [[FOR_BODY]], !llvm.loop [[LOOP38:![0-9]+]]
+; CHECK-NEXT: br i1 [[EXITCOND_NOT]], label [[FOR_COND_CLEANUP:%.*]], label [[FOR_BODY1]]
;
entry:
%cmp11 = icmp sgt i32 %n, 0
@@ -1754,10 +1723,10 @@ define i64 @test_fir_q15(ptr %x, ptr %y, i32 %n) #0 {
; CHECK-NEXT: [[TMP0:%.*]] = add nsw i32 [[N]], -1
; CHECK-NEXT: [[TMP1:%.*]] = lshr i32 [[TMP0]], 1
; CHECK-NEXT: [[TMP2:%.*]] = add nuw i32 [[TMP1]], 1
-; CHECK-NEXT: [[MIN_ITERS_CHECK:%.*]] = icmp ult i32 [[N]], 15
+; CHECK-NEXT: [[MIN_ITERS_CHECK:%.*]] = icmp ult i32 [[N]], 7
; CHECK-NEXT: br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
; CHECK: vector.ph:
-; CHECK-NEXT: [[N_VEC:%.*]] = and i32 [[TMP2]], -8
+; CHECK-NEXT: [[N_VEC:%.*]] = and i32 [[TMP2]], -4
; CHECK-NEXT: [[IND_END:%.*]] = shl i32 [[N_VEC]], 1
; CHECK-NEXT: br label [[VECTOR_BODY:%.*]]
; CHECK: vector.body:
@@ -1765,28 +1734,28 @@ define i64 @test_fir_q15(ptr %x, ptr %y, i32 %n) #0 {
; CHECK-NEXT: [[VEC_PHI:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[TMP16:%.*]], [[VECTOR_BODY]] ]
; CHECK-NEXT: [[OFFSET_IDX:%.*]] = shl i32 [[INDEX]], 1
; CHECK-NEXT: [[TMP3:%.*]] = getelementptr inbounds i16, ptr [[X:%.*]], i32 [[OFFSET_IDX]]
-; CHECK-NEXT: [[WIDE_VEC:%.*]] = load <16 x i16>, ptr [[TMP3]], align 2
-; CHECK-NEXT: [[STRIDED_VEC:%.*]] = shufflevector <16 x i16> [[WIDE_VEC]], <16 x i16> poison, <8 x i32> <i32 0, i32 2, i32 4, i32 6, i32 8, i32 10, i32 12, i32 14>
-; CHECK-NEXT: [[STRIDED_VEC1:%.*]] = shufflevector <16 x i16> [[WIDE_VEC]], <16 x i16> poison, <8 x i32> <i32 1, i32 3, i32 5, i32 7, i32 9, i32 11, i32 13, i32 15>
+; CHECK-NEXT: [[WIDE_VEC:%.*]] = load <8 x i16>, ptr [[TMP3]], align 2
+; CHECK-NEXT: [[STRIDED_VEC:%.*]] = shufflevector <8 x i16> [[WIDE_VEC]], <8 x i16> poison, <4 x i32> <i32 0, i32 2, i32 4, i32 6>
+; CHECK-NEXT: [[STRIDED_VEC1:%.*]] = shufflevector <8 x i16> [[WIDE_VEC]], <8 x i16> poison, <4 x i32> <i32 1, i32 3, i32 5, i32 7>
+; CHECK-NEXT: [[TMP5:%.*]] = sext <4 x i16> [[STRIDED_VEC]] to <4 x i32>
; CHECK-NEXT: [[TMP4:%.*]] = getelementptr inbounds i16, ptr [[Y:%.*]], i32 [[OFFSET_IDX]]
-; CHECK-NEXT: [[WIDE_VEC2:%.*]] = load <16 x i16>, ptr [[TMP4]], align 2
-; CHECK-NEXT: [[STRIDED_VEC3:%.*]] = shufflevector <16 x i16> [[WIDE_VEC2]], <16 x i16> poison, <8 x i32> <i32 0, i32 2, i32 4, i32 6, i32 8, i32 10, i32 12, i32 14>
-; CHECK-NEXT: [[STRIDED_VEC4:%.*]] = shufflevector <16 x i16> [[WIDE_VEC2]], <16 x i16> poison, <8 x i32> <i32 1, i32 3, i32 5, i32 7, i32 9, i32 11, i32 13, i32 15>
-; CHECK-NEXT: [[TMP5:%.*]] = sext <8 x i16> [[STRIDED_VEC3]] to <8 x i32>
-; CHECK-NEXT: [[TMP6:%.*]] = sext <8 x i16> [[STRIDED_VEC]] to <8 x i32>
-; CHECK-NEXT: [[TMP7:%.*]] = mul nsw <8 x i32> [[TMP5]], [[TMP6]]
-; CHECK-NEXT: [[TMP8:%.*]] = sext <8 x i32> [[TMP7]] to <8 x i64>
-; CHECK-NEXT: [[TMP9:%.*]] = call i64 @llvm.vector.reduce.add.v8i64(<8 x i64> [[TMP8]])
+; CHECK-NEXT: [[WIDE_VEC2:%.*]] = load <8 x i16>, ptr [[TMP4]], align 2
+; CHECK-NEXT: [[STRIDED_VEC3:%.*]] = shufflevector <8 x i16> [[WIDE_VEC2]], <8 x i16> poison, <4 x i32> <i32 0, i32 2, i32 4, i32 6>
+; CHECK-NEXT: [[STRIDED_VEC4:%.*]] = shufflevector <8 x i16> [[WIDE_VEC2]], <8 x i16> poison, <4 x i32> <i32 1, i32 3, i32 5, i32 7>
+; CHECK-NEXT: [[TMP6:%.*]] = sext <4 x i16> [[STRIDED_VEC3]] to <4 x i32>
+; CHECK-NEXT: [[TMP7:%.*]] = mul nsw <4 x i32> [[TMP6]], [[TMP5]]
+; CHECK-NEXT: [[TMP8:%.*]] = sext <4 x i32> [[TMP7]] to <4 x i64>
+; CHECK-NEXT: [[TMP13:%.*]] = sext <4 x i16> [[STRIDED_VEC1]] to <4 x i32>
+; CHECK-NEXT: [[TMP14:%.*]] = sext <4 x i16> [[STRIDED_VEC4]] to <4 x i32>
+; CHECK-NEXT: [[TMP11:%.*]] = mul nsw <4 x i32> [[TMP14]], [[TMP13]]
+; CHECK-NEXT: [[TMP12:%.*]] = sext <4 x i32> [[TMP11]] to <4 x i64>
+; CHECK-NEXT: [[TMP9:%.*]] = call i64 @llvm.vector.reduce.add.v4i64(<4 x i64> [[TMP8]])
; CHECK-NEXT: [[TMP10:%.*]] = add i64 [[TMP9]], [[VEC_PHI]]
-; CHECK-NEXT: [[TMP11:%.*]] = sext <8 x i16> [[STRIDED_VEC4]] to <8 x i32>
-; CHECK-NEXT: [[TMP12:%.*]] = sext <8 x i16> [[STRIDED_VEC1]] to <8 x i32>
-; CHECK-NEXT: [[TMP13:%.*]] = mul nsw <8 x i32> [[TMP11]], [[TMP12]]
-; CHECK-NEXT: [[TMP14:%.*]] = sext <8 x i32> [[TMP13]] to <8 x i64>
-; CHECK-NEXT: [[TMP15:%.*]] = call i64 @llvm.vector.reduce.add.v8i64(<8 x i64> [[TMP14]])
+; CHECK-NEXT: [[TMP15:%.*]] = call i64 @llvm.vector.reduce.add.v4i64(<4 x i64> [[TMP12]])
; CHECK-NEXT: [[TMP16]] = add i64 [[TMP15]], [[TMP10]]
-; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 8
+; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 4
; CHECK-NEXT: [[TMP17:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]
-; CHECK-NEXT: br i1 [[TMP17]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP39:![0-9]+]]
+; CHECK-NEXT: br i1 [[TMP17]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP37:![0-9]+]]
; CHECK: middle.block:
; CHECK-NEXT: [[CMP_N:%.*]] = icmp eq i32 [[TMP2]], [[N_VEC]]
; CHECK-NEXT: br i1 [[CMP_N]], label [[FOR_COND_CLEANUP]], label [[SCALAR_PH]]
@@ -1821,7 +1790,7 @@ define i64 @test_fir_q15(ptr %x, ptr %y, i32 %n) #0 {
; CHECK-NEXT: [[ADD12]] = add nsw i64 [[ADD]], [[CONV11]]
; CHECK-NEXT: [[ADD13]] = add nuw nsw i32 [[I_025]], 2
; CHECK-NEXT: [[CMP:%.*]] = icmp slt i32 [[ADD13]], [[N]]
-; CHECK-NEXT: br i1 [[CMP]], label [[FOR_BODY]], label [[FOR_COND_CLEANUP]], !llvm.loop [[LOOP40:![0-9]+]]
+; CHECK-NEXT: br i1 [[CMP]], label [[FOR_BODY]], label [[FOR_COND_CLEANUP]], !llvm.loop [[LOOP38:![0-9]+]]
;
entry:
%cmp23 = icmp sgt i32 %n, 0
>From 7d40b4cd814f4c7d8e854690767acfa1c7341c6c Mon Sep 17 00:00:00 2001
From: Elvis Wang <elvis.wang at sifive.com>
Date: Thu, 14 Nov 2024 16:48:33 -0800
Subject: [PATCH 12/14] Remove extended instruction after mul in MulAccRecipe.
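Note that reduce.add(ext(mul(ext(A), ext(A)))) is mathematically equal
to reduce.add(mul(ext(A), ext(A))) with the inner extends (same opcode)
widened directly to the reduction type, provided the narrow mul cannot
wrap (e.g. it is nsw and the outer extend is sext, or it is nuw and the
outer extend is zext).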
---
.../Transforms/Vectorize/LoopVectorize.cpp | 13 --
llvm/lib/Transforms/Vectorize/VPlan.h | 58 ++++-----
.../lib/Transforms/Vectorize/VPlanRecipes.cpp | 15 +--
.../Transforms/Vectorize/VPlanTransforms.cpp | 25 ++--
.../LoopVectorize/ARM/mve-reductions.ll | 111 +++++++++++-------
.../LoopVectorize/vplan-printing.ll | 2 +-
6 files changed, 104 insertions(+), 120 deletions(-)
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index fef9956068f1fd..2afe2f349b6284 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7394,19 +7394,6 @@ static bool planContainsAdditionalSimplifications(VPlan &Plan,
}
if (Instruction *UI = GetInstructionForCost(&R))
SeenInstrs.insert(UI);
- // VPExtendedReductionRecipe contains a folded extend instruction.
- if (auto *ExtendedRed = dyn_cast<VPExtendedReductionRecipe>(&R))
- SeenInstrs.insert(ExtendedRed->getExtInstr());
- // VPMulAccRecipe constians a mul and otional extend instructions.
- else if (auto *MulAcc = dyn_cast<VPMulAccRecipe>(&R)) {
- SeenInstrs.insert(MulAcc->getMulInstr());
- if (MulAcc->isExtended()) {
- SeenInstrs.insert(MulAcc->getExt0Instr());
- SeenInstrs.insert(MulAcc->getExt1Instr());
- if (auto *Ext = MulAcc->getExtInstr())
- SeenInstrs.insert(Ext);
- }
- }
}
}
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index cd7d0efe8dda45..d3e4a436e97cb3 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -2573,8 +2573,6 @@ class VPReductionRecipe : public VPSingleDefRecipe {
getVecOp(), getCondOp(), IsOrdered);
}
- // TODO: Support VPExtendedReductionRecipe and VPMulAccRecipe after EVL
- // support.
static inline bool classof(const VPRecipeBase *R) {
return R->getVPDefID() == VPRecipeBase::VPReductionSC ||
R->getVPDefID() == VPRecipeBase::VPReductionEVLSC;
@@ -2731,30 +2729,26 @@ class VPExtendedReductionRecipe : public VPReductionRecipe {
/// VPReductionRecipe VPWidenRecipe(mul) and VPWidenCastRecipes before
/// execution. The Operands are {ChainOp, VecOp1, VecOp2, [Condition]}.
class VPMulAccRecipe : public VPReductionRecipe {
- /// Type after extend.
- Type *ResultType;
/// reduce.add(ext(mul(ext0(), ext1())))
Instruction *MulInstr;
- CastInst *ExtInstr = nullptr;
CastInst *Ext0Instr = nullptr;
CastInst *Ext1Instr = nullptr;
protected:
VPMulAccRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
- Instruction *RedI, Instruction *ExtInstr,
- Instruction *MulInstr, Instruction *Ext0Instr,
- Instruction *Ext1Instr, VPValue *ChainOp, VPValue *VecOp0,
- VPValue *VecOp1, VPValue *CondOp, bool IsOrdered,
- Type *ResultType)
+ Instruction *RedI, Instruction *MulInstr,
+ Instruction *Ext0Instr, Instruction *Ext1Instr,
+ VPValue *ChainOp, VPValue *VecOp0, VPValue *VecOp1,
+ VPValue *CondOp, bool IsOrdered)
: VPReductionRecipe(SC, R, RedI,
ArrayRef<VPValue *>({ChainOp, VecOp0, VecOp1}),
CondOp, IsOrdered),
- ResultType(ResultType), MulInstr(MulInstr),
- ExtInstr(cast_if_present<CastInst>(ExtInstr)),
- Ext0Instr(cast<CastInst>(Ext0Instr)),
+ MulInstr(MulInstr), Ext0Instr(cast<CastInst>(Ext0Instr)),
Ext1Instr(cast<CastInst>(Ext1Instr)) {
assert(MulInstr->getOpcode() == Instruction::Mul);
+ assert(R.getOpcode() == Instruction::Add);
+ assert(Ext0Instr->getOpcode() == Ext1Instr->getOpcode());
}
VPMulAccRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
@@ -2766,6 +2760,7 @@ class VPMulAccRecipe : public VPReductionRecipe {
CondOp, IsOrdered),
MulInstr(MulInstr) {
assert(MulInstr->getOpcode() == Instruction::Mul);
+ assert(R.getOpcode() == Instruction::Add);
}
public:
@@ -2773,11 +2768,10 @@ class VPMulAccRecipe : public VPReductionRecipe {
VPValue *ChainOp, VPValue *CondOp, bool IsOrdered,
VPWidenRecipe *Mul, VPWidenCastRecipe *Ext0,
VPWidenCastRecipe *Ext1)
- : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, nullptr,
- Mul->getUnderlyingInstr(), Ext0->getUnderlyingInstr(),
- Ext1->getUnderlyingInstr(), ChainOp, Ext0->getOperand(0),
- Ext1->getOperand(0), CondOp, IsOrdered,
- Ext0->getResultType()) {}
+ : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, Mul->getUnderlyingInstr(),
+ Ext0->getUnderlyingInstr(), Ext1->getUnderlyingInstr(),
+ ChainOp, Ext0->getOperand(0), Ext1->getOperand(0),
+ CondOp, IsOrdered) {}
VPMulAccRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
VPValue *ChainOp, VPValue *CondOp, bool IsOrdered,
@@ -2790,11 +2784,10 @@ class VPMulAccRecipe : public VPReductionRecipe {
VPValue *ChainOp, VPValue *CondOp, bool IsOrdered,
VPWidenCastRecipe *Ext, VPWidenRecipe *Mul,
VPWidenCastRecipe *Ext0, VPWidenCastRecipe *Ext1)
- : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, Ext->getUnderlyingInstr(),
- Mul->getUnderlyingInstr(), Ext0->getUnderlyingInstr(),
- Ext1->getUnderlyingInstr(), ChainOp, Ext0->getOperand(0),
- Ext1->getOperand(0), CondOp, IsOrdered,
- Ext0->getResultType()) {}
+ : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, Mul->getUnderlyingInstr(),
+ Ext0->getUnderlyingInstr(), Ext1->getUnderlyingInstr(),
+ ChainOp, Ext0->getOperand(0), Ext1->getOperand(0),
+ CondOp, IsOrdered) {}
~VPMulAccRecipe() override = default;
@@ -2828,35 +2821,26 @@ class VPMulAccRecipe : public VPReductionRecipe {
VPValue *getVecOp0() const { return getOperand(1); }
VPValue *getVecOp1() const { return getOperand(2); }
- /// Return the type after inner extended, which must equal to the type of mul
- /// instruction. If the ResultType != recurrenceType, than it must have a
- /// extend recipe after mul recipe.
- Type *getResultType() const { return ResultType; }
-
/// The underlying instruction for VPWidenRecipe.
Instruction *getMulInstr() const { return MulInstr; }
- /// The underlying Instruction for outer VPWidenCastRecipe.
- CastInst *getExtInstr() const { return ExtInstr; }
/// The underlying Instruction for inner VPWidenCastRecipe.
CastInst *getExt0Instr() const { return Ext0Instr; }
+
/// The underlying Instruction for inner VPWidenCastRecipe.
CastInst *getExt1Instr() const { return Ext1Instr; }
/// Return if this MulAcc recipe contains extend instructions.
bool isExtended() const { return Ext0Instr && Ext1Instr; }
+
/// Return if the operands of mul instruction come from same extend.
- bool isSameExtend() const { return Ext0Instr == Ext1Instr; }
- /// Return if the MulAcc recipes contains extend after mul.
- bool isOuterExtended() const { return ExtInstr != nullptr; }
+ bool isSameExtend() const { return getVecOp0() == getVecOp1(); }
+ Instruction::CastOps getExtOpcode() const { return Ext0Instr->getOpcode(); }
+
/// Return if the extend opcode is ZExt.
bool isZExt() const {
if (!isExtended())
return true;
- // reduce.add(sext(mul(zext(A), zext(A)))) can be transform to
- // reduce.add(zext(mul(sext(A), sext(A))))
- if (ExtInstr && ExtInstr->getOpcode() != Ext0Instr->getOpcode())
- return true;
return Ext0Instr->getOpcode() == Instruction::CastOps::ZExt;
}
};
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index fc6b90191efa43..9e53bfaa753022 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -2279,16 +2279,15 @@ InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
if (isExtended()) {
auto *SrcTy = cast<VectorType>(
ToVectorTy(Ctx.Types.inferScalarType(getVecOp0()), VF));
- auto *DestTy = cast<VectorType>(ToVectorTy(getResultType(), VF));
TTI::CastContextHint CCH0 =
computeCCH(getVecOp0()->getDefiningRecipe(), VF);
ExtendedCost = Ctx.TTI.getCastInstrCost(
- Ext0Instr->getOpcode(), DestTy, SrcTy, CCH0, TTI::TCK_RecipThroughput,
+ Ext0Instr->getOpcode(), VectorTy, SrcTy, CCH0, TTI::TCK_RecipThroughput,
dyn_cast_if_present<Instruction>(getExt0Instr()));
TTI::CastContextHint CCH1 =
computeCCH(getVecOp0()->getDefiningRecipe(), VF);
ExtendedCost += Ctx.TTI.getCastInstrCost(
- Ext1Instr->getOpcode(), DestTy, SrcTy, CCH1, TTI::TCK_RecipThroughput,
+ Ext1Instr->getOpcode(), VectorTy, SrcTy, CCH1, TTI::TCK_RecipThroughput,
dyn_cast_if_present<Instruction>(getExt1Instr()));
}
@@ -2406,6 +2405,8 @@ void VPExtendedReductionRecipe::print(raw_ostream &O, const Twine &Indent,
void VPMulAccRecipe::print(raw_ostream &O, const Twine &Indent,
VPSlotTracker &SlotTracker) const {
const RecurrenceDescriptor &RdxDesc = getRecurrenceDescriptor();
+ Type *RedTy = RdxDesc.getRecurrenceType();
+
O << Indent << "MULACC-REDUCE ";
printAsOperand(O, SlotTracker);
O << " = ";
@@ -2413,27 +2414,23 @@ void VPMulAccRecipe::print(raw_ostream &O, const Twine &Indent,
O << " + ";
if (isa<FPMathOperator>(getUnderlyingInstr()))
O << getUnderlyingInstr()->getFastMathFlags();
- if (isOuterExtended())
- O << " (";
O << "reduce." << Instruction::getOpcodeName(RdxDesc.getOpcode()) << " (";
O << "mul ";
if (isExtended())
O << "(";
getVecOp0()->printAsOperand(O, SlotTracker);
if (isExtended())
- O << " extended to " << *getResultType() << "), (";
+ O << " extended to " << *RedTy << "), (";
else
O << ", ";
getVecOp1()->printAsOperand(O, SlotTracker);
if (isExtended())
- O << " extended to " << *getResultType() << ")";
+ O << " extended to " << *RedTy << ")";
if (isConditional()) {
O << ", ";
getCondOp()->printAsOperand(O, SlotTracker);
}
O << ")";
- if (isOuterExtended())
- O << " extended to " << *RdxDesc.getRecurrenceType() << ")";
if (RdxDesc.IntermediateStore)
O << " (with final reduction value stored in invariant address sank "
"outside of loop)";
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index bbfb2540cc6e3d..ff63dfe7c1eecd 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -542,52 +542,43 @@ void VPlanTransforms::prepareExecute(VPlan &Plan) {
ExtRed->eraseFromParent();
} else if (isa<VPMulAccRecipe>(&R)) {
auto *MulAcc = cast<VPMulAccRecipe>(&R);
+ Type *RedTy = MulAcc->getRecurrenceDescriptor().getRecurrenceType();
// Generate inner VPWidenCastRecipes if necessary.
+ // Note that we will drop the extend after of mul which transform
+ // reduce.add(ext(mul(ext, ext))) to reduce.add(mul(ext, ext)).
VPValue *Op0, *Op1;
if (MulAcc->isExtended()) {
CastInst *Ext0 = MulAcc->getExt0Instr();
Op0 = new VPWidenCastRecipe(Ext0->getOpcode(), MulAcc->getVecOp0(),
- MulAcc->getResultType(), *Ext0);
+ RedTy, *Ext0);
Op0->getDefiningRecipe()->insertBefore(MulAcc);
// Prevent reduce.add(mul(ext(A), ext(A))) generate duplicate
// VPWidenCastRecipe.
- if (MulAcc->isSameExtend()) {
+ if (MulAcc->getVecOp0() == MulAcc->getVecOp1()) {
Op1 = Op0;
} else {
CastInst *Ext1 = MulAcc->getExt1Instr();
Op1 = new VPWidenCastRecipe(Ext1->getOpcode(), MulAcc->getVecOp1(),
- MulAcc->getResultType(), *Ext1);
+ RedTy, *Ext1);
Op1->getDefiningRecipe()->insertBefore(MulAcc);
}
- // Not contains extend instruction in this MulAccRecipe.
+ // No extends in this MulAccRecipe.
} else {
Op0 = MulAcc->getVecOp0();
Op1 = MulAcc->getVecOp1();
}
// Generate VPWidenRecipe.
- VPSingleDefRecipe *VecOp;
SmallVector<VPValue *, 2> MulOps = {Op0, Op1};
auto *Mul = new VPWidenRecipe(*MulAcc->getMulInstr(),
make_range(MulOps.begin(), MulOps.end()));
Mul->insertBefore(MulAcc);
- // Generate outer VPWidenCastRecipe if necessary.
- if (auto *OuterExtInstr = MulAcc->getExtInstr()) {
- VecOp = new VPWidenCastRecipe(
- OuterExtInstr->getOpcode(), Mul,
- MulAcc->getRecurrenceDescriptor().getRecurrenceType(),
- *OuterExtInstr);
- VecOp->insertBefore(MulAcc);
- } else {
- VecOp = Mul;
- }
-
// Generate VPReductionRecipe.
auto *Red = new VPReductionRecipe(
MulAcc->getRecurrenceDescriptor(), MulAcc->getUnderlyingInstr(),
- MulAcc->getChainOp(), VecOp, MulAcc->getCondOp(),
+ MulAcc->getChainOp(), Mul, MulAcc->getCondOp(),
MulAcc->isOrdered());
Red->insertBefore(MulAcc);
diff --git a/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll b/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
index 6a48c330775972..447fa887bf0c20 100644
--- a/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
+++ b/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
@@ -648,10 +648,9 @@ define i64 @mla_i16_i64(ptr nocapture readonly %x, ptr nocapture readonly %y, i3
; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <8 x i16>, ptr [[TMP0]], align 2
; CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds i16, ptr [[Y:%.*]], i32 [[INDEX]]
; CHECK-NEXT: [[WIDE_LOAD1:%.*]] = load <8 x i16>, ptr [[TMP2]], align 2
-; CHECK-NEXT: [[TMP11:%.*]] = sext <8 x i16> [[WIDE_LOAD1]] to <8 x i32>
-; CHECK-NEXT: [[TMP3:%.*]] = sext <8 x i16> [[WIDE_LOAD]] to <8 x i32>
-; CHECK-NEXT: [[TMP4:%.*]] = mul nsw <8 x i32> [[TMP11]], [[TMP3]]
-; CHECK-NEXT: [[TMP5:%.*]] = sext <8 x i32> [[TMP4]] to <8 x i64>
+; CHECK-NEXT: [[TMP4:%.*]] = sext <8 x i16> [[WIDE_LOAD1]] to <8 x i64>
+; CHECK-NEXT: [[TMP3:%.*]] = sext <8 x i16> [[WIDE_LOAD]] to <8 x i64>
+; CHECK-NEXT: [[TMP5:%.*]] = mul nsw <8 x i64> [[TMP4]], [[TMP3]]
; CHECK-NEXT: [[TMP6:%.*]] = call i64 @llvm.vector.reduce.add.v8i64(<8 x i64> [[TMP5]])
; CHECK-NEXT: [[TMP7]] = add i64 [[TMP6]], [[VEC_PHI]]
; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 8
@@ -728,10 +727,9 @@ define i64 @mla_i8_i64(ptr nocapture readonly %x, ptr nocapture readonly %y, i32
; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <8 x i8>, ptr [[TMP0]], align 1
; CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds i8, ptr [[Y:%.*]], i32 [[INDEX]]
; CHECK-NEXT: [[WIDE_LOAD1:%.*]] = load <8 x i8>, ptr [[TMP2]], align 1
-; CHECK-NEXT: [[TMP11:%.*]] = zext <8 x i8> [[WIDE_LOAD1]] to <8 x i32>
-; CHECK-NEXT: [[TMP3:%.*]] = zext <8 x i8> [[WIDE_LOAD]] to <8 x i32>
-; CHECK-NEXT: [[TMP4:%.*]] = mul nuw nsw <8 x i32> [[TMP11]], [[TMP3]]
-; CHECK-NEXT: [[TMP5:%.*]] = zext nneg <8 x i32> [[TMP4]] to <8 x i64>
+; CHECK-NEXT: [[TMP4:%.*]] = zext <8 x i8> [[WIDE_LOAD1]] to <8 x i64>
+; CHECK-NEXT: [[TMP3:%.*]] = zext <8 x i8> [[WIDE_LOAD]] to <8 x i64>
+; CHECK-NEXT: [[TMP5:%.*]] = mul nuw nsw <8 x i64> [[TMP4]], [[TMP3]]
; CHECK-NEXT: [[TMP6:%.*]] = call i64 @llvm.vector.reduce.add.v8i64(<8 x i64> [[TMP5]])
; CHECK-NEXT: [[TMP7]] = add i64 [[TMP6]], [[VEC_PHI]]
; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 8
@@ -1461,9 +1459,8 @@ define i64 @mla_xx_sext_zext(ptr nocapture noundef readonly %x, i32 %n) #0 {
; CHECK-NEXT: [[VEC_PHI:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[TMP5:%.*]], [[VECTOR_BODY]] ]
; CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds i16, ptr [[X:%.*]], i32 [[INDEX]]
; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <8 x i16>, ptr [[TMP0]], align 2
-; CHECK-NEXT: [[TMP1:%.*]] = sext <8 x i16> [[WIDE_LOAD]] to <8 x i32>
-; CHECK-NEXT: [[TMP2:%.*]] = mul nsw <8 x i32> [[TMP1]], [[TMP1]]
-; CHECK-NEXT: [[TMP3:%.*]] = zext nneg <8 x i32> [[TMP2]] to <8 x i64>
+; CHECK-NEXT: [[TMP1:%.*]] = sext <8 x i16> [[WIDE_LOAD]] to <8 x i64>
+; CHECK-NEXT: [[TMP3:%.*]] = mul nsw <8 x i64> [[TMP1]], [[TMP1]]
; CHECK-NEXT: [[TMP4:%.*]] = call i64 @llvm.vector.reduce.add.v8i64(<8 x i64> [[TMP3]])
; CHECK-NEXT: [[TMP5]] = add i64 [[TMP4]], [[VEC_PHI]]
; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 8
@@ -1530,9 +1527,8 @@ define i64 @mla_and_add_together_16_64(ptr nocapture noundef readonly %x, i32 no
; CHECK-NEXT: [[VEC_PHI1:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[TMP7:%.*]], [[VECTOR_BODY]] ]
; CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds i16, ptr [[X:%.*]], i32 [[INDEX]]
; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <8 x i16>, ptr [[TMP0]], align 2
-; CHECK-NEXT: [[TMP1:%.*]] = sext <8 x i16> [[WIDE_LOAD]] to <8 x i32>
-; CHECK-NEXT: [[TMP2:%.*]] = mul nsw <8 x i32> [[TMP1]], [[TMP1]]
-; CHECK-NEXT: [[TMP3:%.*]] = zext nneg <8 x i32> [[TMP2]] to <8 x i64>
+; CHECK-NEXT: [[TMP1:%.*]] = sext <8 x i16> [[WIDE_LOAD]] to <8 x i64>
+; CHECK-NEXT: [[TMP3:%.*]] = mul nsw <8 x i64> [[TMP1]], [[TMP1]]
; CHECK-NEXT: [[TMP4:%.*]] = call i64 @llvm.vector.reduce.add.v8i64(<8 x i64> [[TMP3]])
; CHECK-NEXT: [[TMP5]] = add i64 [[TMP4]], [[VEC_PHI]]
; CHECK-NEXT: [[TMP10:%.*]] = sext <8 x i16> [[WIDE_LOAD]] to <8 x i32>
@@ -1670,24 +1666,55 @@ define i64 @test_std_q31(ptr %x, i32 %n) #0 {
; CHECK-NEXT: entry:
; CHECK-NEXT: [[CMP11:%.*]] = icmp sgt i32 [[N:%.*]], 0
; CHECK-NEXT: tail call void @llvm.assume(i1 [[CMP11]])
+; CHECK-NEXT: [[MIN_ITERS_CHECK:%.*]] = icmp samesign ult i32 [[N]], 4
+; CHECK-NEXT: br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
+; CHECK: vector.ph:
+; CHECK-NEXT: [[N_VEC:%.*]] = and i32 [[N]], 2147483644
; CHECK-NEXT: br label [[FOR_BODY1:%.*]]
+; CHECK: vector.body:
+; CHECK-NEXT: [[INDEX:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[FOR_BODY1]] ]
+; CHECK-NEXT: [[VEC_PHI:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[TMP4:%.*]], [[FOR_BODY1]] ]
+; CHECK-NEXT: [[VEC_PHI1:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[TMP8:%.*]], [[FOR_BODY1]] ]
+; CHECK-NEXT: [[TMP10:%.*]] = getelementptr inbounds i32, ptr [[X:%.*]], i32 [[INDEX]]
+; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <4 x i32>, ptr [[TMP10]], align 4
+; CHECK-NEXT: [[TMP1:%.*]] = ashr <4 x i32> [[WIDE_LOAD]], splat (i32 8)
+; CHECK-NEXT: [[TMP2:%.*]] = sext <4 x i32> [[TMP1]] to <4 x i64>
+; CHECK-NEXT: [[TMP3:%.*]] = call i64 @llvm.vector.reduce.add.v4i64(<4 x i64> [[TMP2]])
+; CHECK-NEXT: [[TMP4]] = add i64 [[TMP3]], [[VEC_PHI]]
+; CHECK-NEXT: [[TMP5:%.*]] = sext <4 x i32> [[TMP1]] to <4 x i64>
+; CHECK-NEXT: [[TMP6:%.*]] = mul nsw <4 x i64> [[TMP5]], [[TMP5]]
+; CHECK-NEXT: [[TMP7:%.*]] = call i64 @llvm.vector.reduce.add.v4i64(<4 x i64> [[TMP6]])
+; CHECK-NEXT: [[TMP8]] = add i64 [[TMP7]], [[VEC_PHI1]]
+; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 4
+; CHECK-NEXT: [[TMP9:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]
+; CHECK-NEXT: br i1 [[TMP9]], label [[MIDDLE_BLOCK:%.*]], label [[FOR_BODY1]], !llvm.loop [[LOOP37:![0-9]+]]
+; CHECK: middle.block:
+; CHECK-NEXT: [[CMP_N:%.*]] = icmp eq i32 [[N]], [[N_VEC]]
+; CHECK-NEXT: br i1 [[CMP_N]], label [[FOR_COND_CLEANUP:%.*]], label [[SCALAR_PH]]
+; CHECK: scalar.ph:
+; CHECK-NEXT: [[BC_RESUME_VAL:%.*]] = phi i32 [ [[N_VEC]], [[MIDDLE_BLOCK]] ], [ 0, [[ENTRY:%.*]] ]
+; CHECK-NEXT: [[BC_MERGE_RDX:%.*]] = phi i64 [ [[TMP4]], [[MIDDLE_BLOCK]] ], [ 0, [[ENTRY]] ]
+; CHECK-NEXT: [[BC_MERGE_RDX2:%.*]] = phi i64 [ [[TMP8]], [[MIDDLE_BLOCK]] ], [ 0, [[ENTRY]] ]
+; CHECK-NEXT: br label [[FOR_BODY:%.*]]
; CHECK: for.cond.cleanup:
-; CHECK-NEXT: [[DIV:%.*]] = sdiv i64 [[ADD3:%.*]], [[ADD:%.*]]
+; CHECK-NEXT: [[ADD:%.*]] = phi i64 [ [[ADD1:%.*]], [[FOR_BODY]] ], [ [[TMP4]], [[MIDDLE_BLOCK]] ]
+; CHECK-NEXT: [[ADD3:%.*]] = phi i64 [ [[ADD5:%.*]], [[FOR_BODY]] ], [ [[TMP8]], [[MIDDLE_BLOCK]] ]
+; CHECK-NEXT: [[DIV:%.*]] = sdiv i64 [[ADD3]], [[ADD]]
; CHECK-NEXT: ret i64 [[DIV]]
; CHECK: for.body:
-; CHECK-NEXT: [[S_014:%.*]] = phi i64 [ [[ADD]], [[FOR_BODY1]] ], [ 0, [[ENTRY:%.*]] ]
-; CHECK-NEXT: [[I_013:%.*]] = phi i32 [ [[ADD4:%.*]], [[FOR_BODY1]] ], [ 0, [[ENTRY]] ]
-; CHECK-NEXT: [[T_012:%.*]] = phi i64 [ [[ADD3]], [[FOR_BODY1]] ], [ 0, [[ENTRY]] ]
-; CHECK-NEXT: [[ARRAYIDX:%.*]] = getelementptr inbounds i32, ptr [[X:%.*]], i32 [[I_013]]
+; CHECK-NEXT: [[S_014:%.*]] = phi i64 [ [[ADD1]], [[FOR_BODY]] ], [ [[BC_MERGE_RDX]], [[SCALAR_PH]] ]
+; CHECK-NEXT: [[I_013:%.*]] = phi i32 [ [[ADD4:%.*]], [[FOR_BODY]] ], [ [[BC_RESUME_VAL]], [[SCALAR_PH]] ]
+; CHECK-NEXT: [[T_012:%.*]] = phi i64 [ [[ADD5]], [[FOR_BODY]] ], [ [[BC_MERGE_RDX2]], [[SCALAR_PH]] ]
+; CHECK-NEXT: [[ARRAYIDX:%.*]] = getelementptr inbounds i32, ptr [[X]], i32 [[I_013]]
; CHECK-NEXT: [[TMP0:%.*]] = load i32, ptr [[ARRAYIDX]], align 4
; CHECK-NEXT: [[SHR:%.*]] = ashr i32 [[TMP0]], 8
; CHECK-NEXT: [[CONV:%.*]] = sext i32 [[SHR]] to i64
-; CHECK-NEXT: [[ADD]] = add nsw i64 [[S_014]], [[CONV]]
+; CHECK-NEXT: [[ADD1]] = add nsw i64 [[S_014]], [[CONV]]
; CHECK-NEXT: [[MUL:%.*]] = mul nsw i64 [[CONV]], [[CONV]]
-; CHECK-NEXT: [[ADD3]] = add nuw nsw i64 [[MUL]], [[T_012]]
+; CHECK-NEXT: [[ADD5]] = add nuw nsw i64 [[MUL]], [[T_012]]
; CHECK-NEXT: [[ADD4]] = add nuw nsw i32 [[I_013]], 1
; CHECK-NEXT: [[EXITCOND_NOT:%.*]] = icmp eq i32 [[ADD4]], [[N]]
-; CHECK-NEXT: br i1 [[EXITCOND_NOT]], label [[FOR_COND_CLEANUP:%.*]], label [[FOR_BODY1]]
+; CHECK-NEXT: br i1 [[EXITCOND_NOT]], label [[FOR_COND_CLEANUP]], label [[FOR_BODY]], !llvm.loop [[LOOP38:![0-9]+]]
;
entry:
%cmp11 = icmp sgt i32 %n, 0
@@ -1723,10 +1750,10 @@ define i64 @test_fir_q15(ptr %x, ptr %y, i32 %n) #0 {
; CHECK-NEXT: [[TMP0:%.*]] = add nsw i32 [[N]], -1
; CHECK-NEXT: [[TMP1:%.*]] = lshr i32 [[TMP0]], 1
; CHECK-NEXT: [[TMP2:%.*]] = add nuw i32 [[TMP1]], 1
-; CHECK-NEXT: [[MIN_ITERS_CHECK:%.*]] = icmp ult i32 [[N]], 7
+; CHECK-NEXT: [[MIN_ITERS_CHECK:%.*]] = icmp ult i32 [[N]], 15
; CHECK-NEXT: br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
; CHECK: vector.ph:
-; CHECK-NEXT: [[N_VEC:%.*]] = and i32 [[TMP2]], -4
+; CHECK-NEXT: [[N_VEC:%.*]] = and i32 [[TMP2]], -8
; CHECK-NEXT: [[IND_END:%.*]] = shl i32 [[N_VEC]], 1
; CHECK-NEXT: br label [[VECTOR_BODY:%.*]]
; CHECK: vector.body:
@@ -1734,28 +1761,26 @@ define i64 @test_fir_q15(ptr %x, ptr %y, i32 %n) #0 {
; CHECK-NEXT: [[VEC_PHI:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[TMP16:%.*]], [[VECTOR_BODY]] ]
; CHECK-NEXT: [[OFFSET_IDX:%.*]] = shl i32 [[INDEX]], 1
; CHECK-NEXT: [[TMP3:%.*]] = getelementptr inbounds i16, ptr [[X:%.*]], i32 [[OFFSET_IDX]]
-; CHECK-NEXT: [[WIDE_VEC:%.*]] = load <8 x i16>, ptr [[TMP3]], align 2
-; CHECK-NEXT: [[STRIDED_VEC:%.*]] = shufflevector <8 x i16> [[WIDE_VEC]], <8 x i16> poison, <4 x i32> <i32 0, i32 2, i32 4, i32 6>
-; CHECK-NEXT: [[STRIDED_VEC1:%.*]] = shufflevector <8 x i16> [[WIDE_VEC]], <8 x i16> poison, <4 x i32> <i32 1, i32 3, i32 5, i32 7>
-; CHECK-NEXT: [[TMP5:%.*]] = sext <4 x i16> [[STRIDED_VEC]] to <4 x i32>
+; CHECK-NEXT: [[WIDE_VEC:%.*]] = load <16 x i16>, ptr [[TMP3]], align 2
+; CHECK-NEXT: [[STRIDED_VEC:%.*]] = shufflevector <16 x i16> [[WIDE_VEC]], <16 x i16> poison, <8 x i32> <i32 0, i32 2, i32 4, i32 6, i32 8, i32 10, i32 12, i32 14>
+; CHECK-NEXT: [[STRIDED_VEC1:%.*]] = shufflevector <16 x i16> [[WIDE_VEC]], <16 x i16> poison, <8 x i32> <i32 1, i32 3, i32 5, i32 7, i32 9, i32 11, i32 13, i32 15>
; CHECK-NEXT: [[TMP4:%.*]] = getelementptr inbounds i16, ptr [[Y:%.*]], i32 [[OFFSET_IDX]]
-; CHECK-NEXT: [[WIDE_VEC2:%.*]] = load <8 x i16>, ptr [[TMP4]], align 2
-; CHECK-NEXT: [[STRIDED_VEC3:%.*]] = shufflevector <8 x i16> [[WIDE_VEC2]], <8 x i16> poison, <4 x i32> <i32 0, i32 2, i32 4, i32 6>
-; CHECK-NEXT: [[STRIDED_VEC4:%.*]] = shufflevector <8 x i16> [[WIDE_VEC2]], <8 x i16> poison, <4 x i32> <i32 1, i32 3, i32 5, i32 7>
-; CHECK-NEXT: [[TMP6:%.*]] = sext <4 x i16> [[STRIDED_VEC3]] to <4 x i32>
-; CHECK-NEXT: [[TMP7:%.*]] = mul nsw <4 x i32> [[TMP6]], [[TMP5]]
-; CHECK-NEXT: [[TMP8:%.*]] = sext <4 x i32> [[TMP7]] to <4 x i64>
-; CHECK-NEXT: [[TMP13:%.*]] = sext <4 x i16> [[STRIDED_VEC1]] to <4 x i32>
-; CHECK-NEXT: [[TMP14:%.*]] = sext <4 x i16> [[STRIDED_VEC4]] to <4 x i32>
-; CHECK-NEXT: [[TMP11:%.*]] = mul nsw <4 x i32> [[TMP14]], [[TMP13]]
-; CHECK-NEXT: [[TMP12:%.*]] = sext <4 x i32> [[TMP11]] to <4 x i64>
-; CHECK-NEXT: [[TMP9:%.*]] = call i64 @llvm.vector.reduce.add.v4i64(<4 x i64> [[TMP8]])
+; CHECK-NEXT: [[WIDE_VEC2:%.*]] = load <16 x i16>, ptr [[TMP4]], align 2
+; CHECK-NEXT: [[STRIDED_VEC3:%.*]] = shufflevector <16 x i16> [[WIDE_VEC2]], <16 x i16> poison, <8 x i32> <i32 0, i32 2, i32 4, i32 6, i32 8, i32 10, i32 12, i32 14>
+; CHECK-NEXT: [[STRIDED_VEC4:%.*]] = shufflevector <16 x i16> [[WIDE_VEC2]], <16 x i16> poison, <8 x i32> <i32 1, i32 3, i32 5, i32 7, i32 9, i32 11, i32 13, i32 15>
+; CHECK-NEXT: [[TMP5:%.*]] = sext <8 x i16> [[STRIDED_VEC3]] to <8 x i64>
+; CHECK-NEXT: [[TMP6:%.*]] = sext <8 x i16> [[STRIDED_VEC]] to <8 x i64>
+; CHECK-NEXT: [[TMP7:%.*]] = mul nsw <8 x i64> [[TMP5]], [[TMP6]]
+; CHECK-NEXT: [[TMP9:%.*]] = call i64 @llvm.vector.reduce.add.v8i64(<8 x i64> [[TMP7]])
; CHECK-NEXT: [[TMP10:%.*]] = add i64 [[TMP9]], [[VEC_PHI]]
-; CHECK-NEXT: [[TMP15:%.*]] = call i64 @llvm.vector.reduce.add.v4i64(<4 x i64> [[TMP12]])
+; CHECK-NEXT: [[TMP13:%.*]] = sext <8 x i16> [[STRIDED_VEC4]] to <8 x i64>
+; CHECK-NEXT: [[TMP11:%.*]] = sext <8 x i16> [[STRIDED_VEC1]] to <8 x i64>
+; CHECK-NEXT: [[TMP12:%.*]] = mul nsw <8 x i64> [[TMP13]], [[TMP11]]
+; CHECK-NEXT: [[TMP15:%.*]] = call i64 @llvm.vector.reduce.add.v8i64(<8 x i64> [[TMP12]])
; CHECK-NEXT: [[TMP16]] = add i64 [[TMP15]], [[TMP10]]
-; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 4
+; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 8
; CHECK-NEXT: [[TMP17:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]
-; CHECK-NEXT: br i1 [[TMP17]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP37:![0-9]+]]
+; CHECK-NEXT: br i1 [[TMP17]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP39:![0-9]+]]
; CHECK: middle.block:
; CHECK-NEXT: [[CMP_N:%.*]] = icmp eq i32 [[TMP2]], [[N_VEC]]
; CHECK-NEXT: br i1 [[CMP_N]], label [[FOR_COND_CLEANUP]], label [[SCALAR_PH]]
@@ -1790,7 +1815,7 @@ define i64 @test_fir_q15(ptr %x, ptr %y, i32 %n) #0 {
; CHECK-NEXT: [[ADD12]] = add nsw i64 [[ADD]], [[CONV11]]
; CHECK-NEXT: [[ADD13]] = add nuw nsw i32 [[I_025]], 2
; CHECK-NEXT: [[CMP:%.*]] = icmp slt i32 [[ADD13]], [[N]]
-; CHECK-NEXT: br i1 [[CMP]], label [[FOR_BODY]], label [[FOR_COND_CLEANUP]], !llvm.loop [[LOOP38:![0-9]+]]
+; CHECK-NEXT: br i1 [[CMP]], label [[FOR_BODY]], label [[FOR_COND_CLEANUP]], !llvm.loop [[LOOP40:![0-9]+]]
;
entry:
%cmp23 = icmp sgt i32 %n, 0
diff --git a/llvm/test/Transforms/LoopVectorize/vplan-printing.ll b/llvm/test/Transforms/LoopVectorize/vplan-printing.ll
index a8fb374a4b162d..c92fa2d4bd66a1 100644
--- a/llvm/test/Transforms/LoopVectorize/vplan-printing.ll
+++ b/llvm/test/Transforms/LoopVectorize/vplan-printing.ll
@@ -1315,7 +1315,7 @@ define i64 @print_mulacc_extended(ptr nocapture readonly %x, ptr nocapture reado
; CHECK-NEXT: CLONE ir<%arrayidx1> = getelementptr inbounds ir<%y>, vp<%3>
; CHECK-NEXT: vp<%5> = vector-pointer ir<%arrayidx1>
; CHECK-NEXT: WIDEN ir<%load1> = load vp<%5>
-; CHECK-NEXT: MULACC-REDUCE ir<%add> = ir<%r.09> + (reduce.add (mul (ir<%load0> extended to i32), (ir<%load1> extended to i32)) extended to i64)
+; CHECK-NEXT: MULACC-REDUCE ir<%add> = ir<%r.09> + reduce.add (mul (ir<%load0> extended to i64), (ir<%load1> extended to i64))
; CHECK-NEXT: EMIT vp<%index.next> = add nuw vp<%2>, vp<%0>
; CHECK-NEXT: EMIT branch-on-count vp<%index.next>, vp<%1>
; CHECK-NEXT: No successors
>From 550868d9b9da0c1ab8a1adcd3bf68cea82625f65 Mon Sep 17 00:00:00 2001
From: Elvis Wang <elvis.wang at sifive.com>
Date: Thu, 14 Nov 2024 18:07:19 -0800
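Subject: [PATCH 13/14] Refactor reduction pattern matching into static helpers.
- Move tryToMatch{MulAcc|ExtendedReduction} out of
 adjustRecipesForReductions() into static functions.
- Remove `ResultTy`, which is the same as the recurrence type.
- Use VP_CLASSOF_IMPL.
- Drop the `hasMoreThanOneUniqueUser()` check on the matched VecOp and the
 VPMulAccRecipe constructor that took the outer extend.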
---
.../Transforms/Vectorize/LoopVectorize.cpp | 153 ++++++++++--------
llvm/lib/Transforms/Vectorize/VPlan.h | 47 ++----
.../Transforms/Vectorize/VPlanTransforms.cpp | 2 +-
3 files changed, 100 insertions(+), 102 deletions(-)
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 2afe2f349b6284..63f42e9ca74121 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -9245,6 +9245,87 @@ VPlanPtr LoopVectorizationPlanner::buildVPlan(VFRange &Range) {
return Plan;
}
+/// Try to match the extended-reduction and create VPExtendedReductionRecipe.
+///
+/// This function try to match following pattern which will generate
+/// extended-reduction instruction.
+/// reduce(ext(...)).
+static VPExtendedReductionRecipe *
+tryToMatchExtendedReduction(const RecurrenceDescriptor &RdxDesc,
+ Instruction *CurrentLinkI, VPValue *PreviousLink,
+ VPValue *VecOp, VPValue *CondOp,
+ LoopVectorizationCostModel &CM) {
+ using namespace VPlanPatternMatch;
+ VPValue *A;
+ // Matched reduce(ext)).
+ if (match(VecOp, m_ZExtOrSExt(m_VPValue(A)))) {
+ return new VPExtendedReductionRecipe(RdxDesc, CurrentLinkI, PreviousLink,
+ cast<VPWidenCastRecipe>(VecOp), CondOp,
+ CM.useOrderedReductions(RdxDesc));
+ }
+ return nullptr;
+}
+
+/// Try to match the mul-acc-reduction and create VPMulAccRecipe.
+///
+/// This function try to match following patterns which will generate mul-acc
+/// instructions.
+/// reduce.add(mul(...)),
+/// reduce.add(mul(ext(A), ext(B))),
+/// reduce.add(ext(mul(ext(A), ext(B)))).
+static VPMulAccRecipe *tryToMatchMulAcc(const RecurrenceDescriptor &RdxDesc,
+ Instruction *CurrentLinkI,
+ VPValue *PreviousLink, VPValue *VecOp,
+ VPValue *CondOp,
+ LoopVectorizationCostModel &CM) {
+ using namespace VPlanPatternMatch;
+ VPValue *A, *B;
+ if (RdxDesc.getOpcode() != Instruction::Add)
+ return nullptr;
+ // Try to match reduce.add(mul(...))
+ if (match(VecOp, m_Mul(m_VPValue(A), m_VPValue(B)))) {
+ VPWidenCastRecipe *RecipeA =
+ dyn_cast_if_present<VPWidenCastRecipe>(A->getDefiningRecipe());
+ VPWidenCastRecipe *RecipeB =
+ dyn_cast_if_present<VPWidenCastRecipe>(B->getDefiningRecipe());
+ // Matched reduce.add(mul(ext, ext))
+ if (RecipeA && RecipeB && match(RecipeA, m_ZExtOrSExt(m_VPValue())) &&
+ match(RecipeB, m_ZExtOrSExt(m_VPValue())) &&
+ (RecipeA->getOpcode() == RecipeB->getOpcode() || A == B)) {
+ return new VPMulAccRecipe(RdxDesc, CurrentLinkI, PreviousLink, CondOp,
+ CM.useOrderedReductions(RdxDesc),
+ cast<VPWidenRecipe>(VecOp->getDefiningRecipe()),
+ RecipeA, RecipeB);
+ } else {
+ // Matched reduce.add(mul)
+ return new VPMulAccRecipe(
+ RdxDesc, CurrentLinkI, PreviousLink, CondOp,
+ CM.useOrderedReductions(RdxDesc),
+ cast<VPWidenRecipe>(VecOp->getDefiningRecipe()));
+ }
+ // Matched reduce.add(ext(mul(ext(A), ext(B))))
+ // All extend instructions must have same opcode or A == B
+ // which can be transform to reduce.add(zext(mul(sext(A), sext(B)))).
+ } else if (match(VecOp, m_ZExtOrSExt(m_Mul(m_ZExtOrSExt(m_VPValue()),
+ m_ZExtOrSExt(m_VPValue()))))) {
+ VPWidenCastRecipe *Ext =
+ cast<VPWidenCastRecipe>(VecOp->getDefiningRecipe());
+ VPWidenRecipe *Mul =
+ cast<VPWidenRecipe>(Ext->getOperand(0)->getDefiningRecipe());
+ VPWidenCastRecipe *Ext0 =
+ cast<VPWidenCastRecipe>(Mul->getOperand(0)->getDefiningRecipe());
+ VPWidenCastRecipe *Ext1 =
+ cast<VPWidenCastRecipe>(Mul->getOperand(1)->getDefiningRecipe());
+ if ((Ext->getOpcode() == Ext0->getOpcode() || Ext0 == Ext1) &&
+ Ext0->getOpcode() == Ext1->getOpcode()) {
+ return new VPMulAccRecipe(RdxDesc, CurrentLinkI, PreviousLink, CondOp,
+ CM.useOrderedReductions(RdxDesc), Mul, Ext0,
+ Ext1);
+ }
+ }
+ return nullptr;
+}
+
// Adjust the recipes for reductions. For in-loop reductions the chain of
// instructions leading from the loop exit instr to the phi need to be converted
// to reductions, with one operand being vector and the other being the scalar
@@ -9376,76 +9457,12 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
if (CM.blockNeedsPredicationForAnyReason(BB))
CondOp = RecipeBuilder.getBlockInMask(BB);
- auto TryToMatchMulAcc = [&]() -> VPReductionRecipe * {
- VPValue *A, *B;
- if (RdxDesc.getOpcode() != Instruction::Add)
- return nullptr;
- // Try to match reduce.add(mul(...))
- if (match(VecOp, m_Mul(m_VPValue(A), m_VPValue(B))) &&
- !VecOp->hasMoreThanOneUniqueUser()) {
- VPWidenCastRecipe *RecipeA =
- dyn_cast_if_present<VPWidenCastRecipe>(A->getDefiningRecipe());
- VPWidenCastRecipe *RecipeB =
- dyn_cast_if_present<VPWidenCastRecipe>(B->getDefiningRecipe());
- // Matched reduce.add(mul(ext, ext))
- if (RecipeA && RecipeB && match(RecipeA, m_ZExtOrSExt(m_VPValue())) &&
- match(RecipeB, m_ZExtOrSExt(m_VPValue())) &&
- (RecipeA->getOpcode() == RecipeB->getOpcode() || A == B)) {
- return new VPMulAccRecipe(
- RdxDesc, CurrentLinkI, PreviousLink, CondOp,
- CM.useOrderedReductions(RdxDesc),
- cast<VPWidenRecipe>(VecOp->getDefiningRecipe()), RecipeA,
- RecipeB);
- } else {
- // Matched reduce.add(mul)
- return new VPMulAccRecipe(
- RdxDesc, CurrentLinkI, PreviousLink, CondOp,
- CM.useOrderedReductions(RdxDesc),
- cast<VPWidenRecipe>(VecOp->getDefiningRecipe()));
- }
- // Matched reduce.add(ext(mul(ext(A), ext(B))))
- // Note that all extend instructions must have same opcode or A == B
- // which can be transform to reduce.add(zext(mul(sext(A), sext(B)))).
- } else if (match(VecOp,
- m_ZExtOrSExt(m_Mul(m_ZExtOrSExt(m_VPValue()),
- m_ZExtOrSExt(m_VPValue())))) &&
- !VecOp->hasMoreThanOneUniqueUser()) {
- VPWidenCastRecipe *Ext =
- cast<VPWidenCastRecipe>(VecOp->getDefiningRecipe());
- VPWidenRecipe *Mul =
- cast<VPWidenRecipe>(Ext->getOperand(0)->getDefiningRecipe());
- VPWidenCastRecipe *Ext0 =
- cast<VPWidenCastRecipe>(Mul->getOperand(0)->getDefiningRecipe());
- VPWidenCastRecipe *Ext1 =
- cast<VPWidenCastRecipe>(Mul->getOperand(1)->getDefiningRecipe());
- if ((Ext->getOpcode() == Ext0->getOpcode() || Ext0 == Ext1) &&
- Ext0->getOpcode() == Ext1->getOpcode()) {
- return new VPMulAccRecipe(
- RdxDesc, CurrentLinkI, PreviousLink, CondOp,
- CM.useOrderedReductions(RdxDesc),
- cast<VPWidenCastRecipe>(VecOp->getDefiningRecipe()), Mul,
- cast<VPWidenCastRecipe>(Ext0), cast<VPWidenCastRecipe>(Ext1));
- }
- }
- return nullptr;
- };
-
- auto TryToMatchExtendedReduction = [&]() -> VPReductionRecipe * {
- VPValue *A;
- // Matched reduce(ext)).
- if (match(VecOp, m_ZExtOrSExt(m_VPValue(A)))) {
- return new VPExtendedReductionRecipe(
- RdxDesc, CurrentLinkI, PreviousLink,
- cast<VPWidenCastRecipe>(VecOp), CondOp,
- CM.useOrderedReductions(RdxDesc));
- }
- return nullptr;
- };
-
VPReductionRecipe *RedRecipe;
- if (auto *MulAcc = TryToMatchMulAcc())
+ if (auto *MulAcc = tryToMatchMulAcc(RdxDesc, CurrentLinkI, PreviousLink,
+ VecOp, CondOp, CM))
RedRecipe = MulAcc;
- else if (auto *ExtendedRed = TryToMatchExtendedReduction())
+ else if (auto *ExtendedRed = tryToMatchExtendedReduction(
+ RdxDesc, CurrentLinkI, PreviousLink, VecOp, CondOp, CM))
RedRecipe = ExtendedRed;
else
RedRecipe =
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index d3e4a436e97cb3..467f5ec01b1ef1 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -2661,8 +2661,6 @@ class VPReductionEVLRecipe : public VPReductionRecipe {
/// VPReductionRecipe and VPWidenCastRecipe before execution. The Operands are
/// {ChainOp, VecOp, [Condition]}.
class VPExtendedReductionRecipe : public VPReductionRecipe {
- /// Type after extend.
- Type *ResultTy;
CastInst *ExtInstr;
protected:
@@ -2670,10 +2668,10 @@ class VPExtendedReductionRecipe : public VPReductionRecipe {
const RecurrenceDescriptor &R, Instruction *RedI,
Instruction::CastOps ExtOp, CastInst *ExtInstr,
VPValue *ChainOp, VPValue *VecOp, VPValue *CondOp,
- bool IsOrdered, Type *ResultTy)
+ bool IsOrdered)
: VPReductionRecipe(SC, R, RedI, ArrayRef<VPValue *>({ChainOp, VecOp}),
CondOp, IsOrdered),
- ResultTy(ResultTy), ExtInstr(ExtInstr) {}
+ ExtInstr(ExtInstr) {}
public:
VPExtendedReductionRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
@@ -2682,7 +2680,7 @@ class VPExtendedReductionRecipe : public VPReductionRecipe {
: VPExtendedReductionRecipe(
VPDef::VPExtendedReductionSC, R, RedI, Ext->getOpcode(),
cast<CastInst>(Ext->getUnderlyingInstr()), ChainOp,
- Ext->getOperand(0), CondOp, IsOrdered, Ext->getResultType()) {}
+ Ext->getOperand(0), CondOp, IsOrdered) {}
~VPExtendedReductionRecipe() override = default;
@@ -2690,14 +2688,7 @@ class VPExtendedReductionRecipe : public VPReductionRecipe {
llvm_unreachable("Not implement yet");
}
- static inline bool classof(const VPRecipeBase *R) {
- return R->getVPDefID() == VPDef::VPExtendedReductionSC;
- }
-
- static inline bool classof(const VPUser *U) {
- auto *R = dyn_cast<VPRecipeBase>(U);
- return R && classof(R);
- }
+ VP_CLASSOF_IMPL(VPDef::VPExtendedReductionSC);
void execute(VPTransformState &State) override {
llvm_unreachable("VPExtendedReductionRecipe should be transform to "
@@ -2714,11 +2705,16 @@ class VPExtendedReductionRecipe : public VPReductionRecipe {
VPSlotTracker &SlotTracker) const override;
#endif
- /// The Type after extended.
- Type *getResultType() const { return ResultTy; }
+ /// The scalar type after extended.
+ Type *getResultType() const {
+ return getRecurrenceDescriptor().getRecurrenceType();
+ }
+
bool isZExt() const { return getExtOpcode() == Instruction::ZExt; }
+
/// The Opcode of extend instruction.
Instruction::CastOps getExtOpcode() const { return ExtInstr->getOpcode(); }
+
/// The CastInst of the extend instruction.
CastInst *getExtInstr() const { return ExtInstr; }
};
@@ -2730,7 +2726,7 @@ class VPExtendedReductionRecipe : public VPReductionRecipe {
/// execution. The Operands are {ChainOp, VecOp1, VecOp2, [Condition]}.
class VPMulAccRecipe : public VPReductionRecipe {
- /// reduce.add(ext(mul(ext0(), ext1())))
+ // reduce.add(mul(ext0, ext1)).
Instruction *MulInstr;
CastInst *Ext0Instr = nullptr;
CastInst *Ext1Instr = nullptr;
@@ -2780,27 +2776,11 @@ class VPMulAccRecipe : public VPReductionRecipe {
ChainOp, Mul->getOperand(0), Mul->getOperand(1), CondOp,
IsOrdered) {}
- VPMulAccRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
- VPValue *ChainOp, VPValue *CondOp, bool IsOrdered,
- VPWidenCastRecipe *Ext, VPWidenRecipe *Mul,
- VPWidenCastRecipe *Ext0, VPWidenCastRecipe *Ext1)
- : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, Mul->getUnderlyingInstr(),
- Ext0->getUnderlyingInstr(), Ext1->getUnderlyingInstr(),
- ChainOp, Ext0->getOperand(0), Ext1->getOperand(0),
- CondOp, IsOrdered) {}
-
~VPMulAccRecipe() override = default;
VPMulAccRecipe *clone() override { llvm_unreachable("Not implement yet"); }
- static inline bool classof(const VPRecipeBase *R) {
- return R->getVPDefID() == VPRecipeBase::VPMulAccSC;
- }
-
- static inline bool classof(const VPUser *U) {
- auto *R = dyn_cast<VPRecipeBase>(U);
- return R && classof(R);
- }
+ VP_CLASSOF_IMPL(VPDef::VPMulAccSC);
void execute(VPTransformState &State) override {
llvm_unreachable("VPMulAccRecipe should transform to VPWidenCastRecipe + "
@@ -2835,6 +2815,7 @@ class VPMulAccRecipe : public VPReductionRecipe {
/// Return if the operands of mul instruction come from same extend.
bool isSameExtend() const { return getVecOp0() == getVecOp1(); }
+
Instruction::CastOps getExtOpcode() const { return Ext0Instr->getOpcode(); }
/// Return if the extend opcode is ZExt.
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index ff63dfe7c1eecd..ecf41a6b2092d1 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -545,7 +545,7 @@ void VPlanTransforms::prepareExecute(VPlan &Plan) {
Type *RedTy = MulAcc->getRecurrenceDescriptor().getRecurrenceType();
// Generate inner VPWidenCastRecipes if necessary.
- // Note that we will drop the extend after of mul which transform
+ // Note that we will drop the extend after mul which transform
// reduce.add(ext(mul(ext, ext))) to reduce.add(mul(ext, ext)).
VPValue *Op0, *Op1;
if (MulAcc->isExtended()) {
>From e4aa4f3a4460e3a7a017465fea4e145a9486dfa8 Mon Sep 17 00:00:00 2001
From: Elvis Wang <elvis.wang at sifive.com>
Date: Fri, 15 Nov 2024 01:03:18 -0800
Subject: [PATCH 14/14] Clamp the range when the ExtendedReduction or MulAcc
cost is invalid.
---
.../Vectorize/LoopVectorizationPlanner.h | 2 +-
.../Transforms/Vectorize/LoopVectorize.cpp | 97 +++++++++++++++----
.../lib/Transforms/Vectorize/VPlanRecipes.cpp | 42 ++++----
3 files changed, 106 insertions(+), 35 deletions(-)
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
index 7787f58683b2a4..3e38a80e9d8007 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
@@ -509,7 +509,7 @@ class LoopVectorizationPlanner {
// between the phi and users outside the vector region when folding the tail.
void adjustRecipesForReductions(VPlanPtr &Plan,
VPRecipeBuilder &RecipeBuilder,
- ElementCount MinVF);
+ VFRange &Range);
#ifndef NDEBUG
/// \return The most profitable vectorization factor for the available VPlans
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 63f42e9ca74121..35aa0bfba0454e 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -9137,7 +9137,7 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
// ---------------------------------------------------------------------------
// Adjust the recipes for any inloop reductions.
- adjustRecipesForReductions(Plan, RecipeBuilder, Range.Start);
+ adjustRecipesForReductions(Plan, RecipeBuilder, Range);
// Interleave memory: for each Interleave Group we marked earlier as relevant
// for this VPlan, replace the Recipes widening its memory instructions with a
@@ -9250,15 +9250,39 @@ VPlanPtr LoopVectorizationPlanner::buildVPlan(VFRange &Range) {
/// This function try to match following pattern which will generate
/// extended-reduction instruction.
/// reduce(ext(...)).
-static VPExtendedReductionRecipe *
-tryToMatchExtendedReduction(const RecurrenceDescriptor &RdxDesc,
- Instruction *CurrentLinkI, VPValue *PreviousLink,
- VPValue *VecOp, VPValue *CondOp,
- LoopVectorizationCostModel &CM) {
+static VPExtendedReductionRecipe *tryToMatchAndCreateExtendedReduction(
+ const RecurrenceDescriptor &RdxDesc, Instruction *CurrentLinkI,
+ VPValue *PreviousLink, VPValue *VecOp, VPValue *CondOp,
+ LoopVectorizationCostModel &CM, VPCostContext &Ctx, VFRange &Range) {
using namespace VPlanPatternMatch;
+
VPValue *A;
+ Type *RedTy = RdxDesc.getRecurrenceType();
+
+ // Test if the cost of extended-reduction is valid and clamp the range.
+ // Note that reduction-extended is not always valid for all VF and types.
+ auto IsExtendedRedValidAndClampRange = [&](unsigned Opcode, bool isZExt,
+ Type *SrcTy) -> bool {
+ return LoopVectorizationPlanner::getDecisionAndClampRange(
+ [&](ElementCount VF) {
+ VectorType *SrcVecTy = cast<VectorType>(ToVectorTy(SrcTy, VF));
+ return Ctx.TTI
+ .getExtendedReductionCost(Opcode, isZExt, RedTy, SrcVecTy,
+ RdxDesc.getFastMathFlags(),
+ TTI::TCK_RecipThroughput)
+ .isValid();
+ },
+ Range);
+ };
+
// Matched reduce(ext)).
if (match(VecOp, m_ZExtOrSExt(m_VPValue(A)))) {
+ if (!IsExtendedRedValidAndClampRange(
+ RdxDesc.getOpcode(),
+ cast<VPWidenCastRecipe>(VecOp)->getOpcode() ==
+ Instruction::CastOps::ZExt,
+ Ctx.Types.inferScalarType(A)))
+ return nullptr;
return new VPExtendedReductionRecipe(RdxDesc, CurrentLinkI, PreviousLink,
cast<VPWidenCastRecipe>(VecOp), CondOp,
CM.useOrderedReductions(RdxDesc));
@@ -9273,31 +9297,59 @@ tryToMatchExtendedReduction(const RecurrenceDescriptor &RdxDesc,
/// reduce.add(mul(...)),
/// reduce.add(mul(ext(A), ext(B))),
/// reduce.add(ext(mul(ext(A), ext(B)))).
-static VPMulAccRecipe *tryToMatchMulAcc(const RecurrenceDescriptor &RdxDesc,
- Instruction *CurrentLinkI,
- VPValue *PreviousLink, VPValue *VecOp,
- VPValue *CondOp,
- LoopVectorizationCostModel &CM) {
+static VPMulAccRecipe *tryToMatchAndCreateMulAcc(
+ const RecurrenceDescriptor &RdxDesc, Instruction *CurrentLinkI,
+ VPValue *PreviousLink, VPValue *VecOp, VPValue *CondOp,
+ LoopVectorizationCostModel &CM, VPCostContext &Ctx, VFRange &Range) {
using namespace VPlanPatternMatch;
+
VPValue *A, *B;
+ Type *RedTy = RdxDesc.getRecurrenceType();
+
+ // Test if the cost of MulAcc is valid and clamp the range.
+ // Note that mul-acc is not always valid for all VF and types.
+ auto IsMulAccValidAndClampRange = [&](bool isZExt, Type *SrcTy) -> bool {
+ return LoopVectorizationPlanner::getDecisionAndClampRange(
+ [&](ElementCount VF) {
+ VectorType *SrcVecTy = cast<VectorType>(ToVectorTy(SrcTy, VF));
+ return Ctx.TTI
+ .getMulAccReductionCost(isZExt, RedTy, SrcVecTy,
+ TTI::TCK_RecipThroughput)
+ .isValid();
+ },
+ Range);
+ };
+
if (RdxDesc.getOpcode() != Instruction::Add)
return nullptr;
+
// Try to match reduce.add(mul(...))
if (match(VecOp, m_Mul(m_VPValue(A), m_VPValue(B)))) {
VPWidenCastRecipe *RecipeA =
dyn_cast_if_present<VPWidenCastRecipe>(A->getDefiningRecipe());
VPWidenCastRecipe *RecipeB =
dyn_cast_if_present<VPWidenCastRecipe>(B->getDefiningRecipe());
+
// Matched reduce.add(mul(ext, ext))
if (RecipeA && RecipeB && match(RecipeA, m_ZExtOrSExt(m_VPValue())) &&
match(RecipeB, m_ZExtOrSExt(m_VPValue())) &&
(RecipeA->getOpcode() == RecipeB->getOpcode() || A == B)) {
+
+ // Only create MulAccRecipe if the cost is valid.
+ if (!IsMulAccValidAndClampRange(RecipeA->getOpcode() ==
+ Instruction::CastOps::ZExt,
+ Ctx.Types.inferScalarType(RecipeA)))
+ return nullptr;
+
return new VPMulAccRecipe(RdxDesc, CurrentLinkI, PreviousLink, CondOp,
CM.useOrderedReductions(RdxDesc),
cast<VPWidenRecipe>(VecOp->getDefiningRecipe()),
RecipeA, RecipeB);
} else {
// Matched reduce.add(mul)
+ if (!IsMulAccValidAndClampRange(true, RedTy))
+ return nullptr;
+
return new VPMulAccRecipe(
RdxDesc, CurrentLinkI, PreviousLink, CondOp,
CM.useOrderedReductions(RdxDesc),
@@ -9318,6 +9370,12 @@ static VPMulAccRecipe *tryToMatchMulAcc(const RecurrenceDescriptor &RdxDesc,
cast<VPWidenCastRecipe>(Mul->getOperand(1)->getDefiningRecipe());
if ((Ext->getOpcode() == Ext0->getOpcode() || Ext0 == Ext1) &&
Ext0->getOpcode() == Ext1->getOpcode()) {
+ // Only create MulAcc recipe if the cost if valid.
+ if (!IsMulAccValidAndClampRange(Ext0->getOpcode() ==
+ Instruction::CastOps::ZExt,
+ Ctx.Types.inferScalarType(Ext0)))
+ return nullptr;
+
return new VPMulAccRecipe(RdxDesc, CurrentLinkI, PreviousLink, CondOp,
CM.useOrderedReductions(RdxDesc), Mul, Ext0,
Ext1);
@@ -9340,8 +9398,9 @@ static VPMulAccRecipe *tryToMatchMulAcc(const RecurrenceDescriptor &RdxDesc,
// with a boolean reduction phi node to check if the condition is true in any
// iteration. The final value is selected by the final ComputeReductionResult.
void LoopVectorizationPlanner::adjustRecipesForReductions(
- VPlanPtr &Plan, VPRecipeBuilder &RecipeBuilder, ElementCount MinVF) {
+ VPlanPtr &Plan, VPRecipeBuilder &RecipeBuilder, VFRange &Range) {
using namespace VPlanPatternMatch;
+ ElementCount MinVF = Range.Start;
VPRegionBlock *VectorLoopRegion = Plan->getVectorLoopRegion();
VPBasicBlock *Header = VectorLoopRegion->getEntryBasicBlock();
VPBasicBlock *MiddleVPBB = Plan->getMiddleBlock();
@@ -9458,12 +9517,16 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
CondOp = RecipeBuilder.getBlockInMask(BB);
VPReductionRecipe *RedRecipe;
- if (auto *MulAcc = tryToMatchMulAcc(RdxDesc, CurrentLinkI, PreviousLink,
- VecOp, CondOp, CM))
+ VPCostContext CostCtx(CM.TTI, *CM.TLI, Legal->getWidestInductionType(),
+ CM);
+ if (auto *MulAcc =
+ tryToMatchAndCreateMulAcc(RdxDesc, CurrentLinkI, PreviousLink,
+ VecOp, CondOp, CM, CostCtx, Range))
RedRecipe = MulAcc;
- else if (auto *ExtendedRed = tryToMatchExtendedReduction(
- RdxDesc, CurrentLinkI, PreviousLink, VecOp, CondOp, CM))
- RedRecipe = ExtendedRed;
+ else if (auto *ExtRed = tryToMatchAndCreateExtendedReduction(
+ RdxDesc, CurrentLinkI, PreviousLink, VecOp, CondOp, CM,
+ CostCtx, Range))
+ RedRecipe = ExtRed;
else
RedRecipe =
new VPReductionRecipe(RdxDesc, CurrentLinkI, PreviousLink, VecOp,
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 9e53bfaa753022..c58985b6094fcb 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -2220,9 +2220,19 @@ VPExtendedReductionRecipe::computeCost(ElementCount VF,
RecurKind RdxKind = RdxDesc.getRecurrenceKind();
Type *ElementTy = getResultType();
auto *VectorTy = cast<VectorType>(ToVectorTy(ElementTy, VF));
+ auto *SrcVecTy =
+ cast<VectorType>(ToVectorTy(Ctx.Types.inferScalarType(getVecOp()), VF));
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
unsigned Opcode = RdxDesc.getOpcode();
+ // ExtendedReduction Cost
+ InstructionCost ExtendedRedCost =
+ Ctx.TTI.getExtendedReductionCost(Opcode, isZExt(), ElementTy, SrcVecTy,
+ RdxDesc.getFastMathFlags(), CostKind);
+
+ assert(ExtendedRedCost.isValid() && "VPExtendedReductionRecipe should not be "
+ "created if the cost is invalid.");
+
// BaseCost = Reduction cost + BinOp cost
InstructionCost ReductionCost =
Ctx.TTI.getArithmeticInstrCost(Opcode, ElementTy, CostKind);
@@ -2236,23 +2246,18 @@ VPExtendedReductionRecipe::computeCost(ElementCount VF,
}
// Extended cost
- auto *SrcTy =
- cast<VectorType>(ToVectorTy(Ctx.Types.inferScalarType(getVecOp()), VF));
- auto *DestTy = cast<VectorType>(ToVectorTy(getResultType(), VF));
TTI::CastContextHint CCH = computeCCH(getVecOp()->getDefiningRecipe(), VF);
// Arm TTI will use the underlying instruction to determine the cost.
InstructionCost ExtendedCost = Ctx.TTI.getCastInstrCost(
- Opcode, DestTy, SrcTy, CCH, TTI::TCK_RecipThroughput,
+ Opcode, VectorTy, SrcVecTy, CCH, TTI::TCK_RecipThroughput,
dyn_cast_if_present<Instruction>(getUnderlyingValue()));
- // ExtendedReduction Cost
- InstructionCost ExtendedRedCost = Ctx.TTI.getExtendedReductionCost(
- Opcode, isZExt(), ElementTy, SrcTy, RdxDesc.getFastMathFlags(), CostKind);
// Check if folding ext into ExtendedReduction is profitable.
if (ExtendedRedCost.isValid() &&
ExtendedRedCost < ExtendedCost + ReductionCost) {
return ExtendedRedCost;
}
+
return ExtendedCost + ReductionCost;
}
@@ -2267,6 +2272,14 @@ InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
assert(Opcode == Instruction::Add &&
"Reduction opcode must be add in the VPMulAccRecipe.");
+ // MulAccReduction Cost
+ VectorType *SrcVecTy =
+ cast<VectorType>(ToVectorTy(Ctx.Types.inferScalarType(getVecOp0()), VF));
+ InstructionCost MulAccCost =
+ Ctx.TTI.getMulAccReductionCost(isZExt(), ElementTy, SrcVecTy, CostKind);
+
+ assert(MulAccCost.isValid() && "VPMulAccRecipe should not be "
+ "created if the cost is invalid.");
// BaseCost = Reduction cost + BinOp cost
InstructionCost ReductionCost =
@@ -2277,17 +2290,17 @@ InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
// Extended cost
InstructionCost ExtendedCost = 0;
if (isExtended()) {
- auto *SrcTy = cast<VectorType>(
- ToVectorTy(Ctx.Types.inferScalarType(getVecOp0()), VF));
TTI::CastContextHint CCH0 =
computeCCH(getVecOp0()->getDefiningRecipe(), VF);
ExtendedCost = Ctx.TTI.getCastInstrCost(
- Ext0Instr->getOpcode(), VectorTy, SrcTy, CCH0, TTI::TCK_RecipThroughput,
+ Ext0Instr->getOpcode(), VectorTy, SrcVecTy, CCH0,
+ TTI::TCK_RecipThroughput,
dyn_cast_if_present<Instruction>(getExt0Instr()));
TTI::CastContextHint CCH1 =
computeCCH(getVecOp0()->getDefiningRecipe(), VF);
ExtendedCost += Ctx.TTI.getCastInstrCost(
- Ext1Instr->getOpcode(), VectorTy, SrcTy, CCH1, TTI::TCK_RecipThroughput,
+ Ext1Instr->getOpcode(), VectorTy, SrcVecTy, CCH1,
+ TTI::TCK_RecipThroughput,
dyn_cast_if_present<Instruction>(getExt1Instr()));
}
@@ -2319,17 +2332,12 @@ InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
RHSInfo, Operands, MulInstr, &Ctx.TLI);
}
- // MulAccReduction Cost
- VectorType *SrcVecTy =
- cast<VectorType>(ToVectorTy(Ctx.Types.inferScalarType(getVecOp0()), VF));
- InstructionCost MulAccCost =
- Ctx.TTI.getMulAccReductionCost(isZExt(), ElementTy, SrcVecTy, CostKind);
-
// Check if folding ext into ExtendedReduction is profitable.
if (MulAccCost.isValid() &&
MulAccCost < ExtendedCost + ReductionCost + MulCost) {
return MulAccCost;
}
+
return ExtendedCost + ReductionCost + MulCost;
}
More information about the llvm-commits
mailing list