[llvm] [VPlan] Impl VPlan-based pattern match for ExtendedRed and MulAccRed (NFCI) (PR #113903)

Elvis Wang via llvm-commits llvm-commits at lists.llvm.org
Mon Dec 23 01:38:01 PST 2024


https://github.com/ElvisWang123 updated https://github.com/llvm/llvm-project/pull/113903

>From 33b1f601f5677b8159b0c2ef8cfa8e6c1fa9d541 Mon Sep 17 00:00:00 2001
From: Elvis Wang <elvis.wang at sifive.com>
Date: Mon, 28 Oct 2024 05:39:35 -0700
Subject: [PATCH 01/26] [VPlan] Impl VPlan-based pattern match for ExtendedRed
 and MulAccRed. NFCI

This patch implements VPlan-based pattern matching for ExtendedReduction
and MulAccReduction. In these reduction patterns, the extend and mul
instructions can be folded into the reduction instruction, so their cost
is free.

We add `FoldedRecipes` to `VPCostContext` to record the recipes that can be
folded into other recipes.

ExtendedReductionPatterns:
    reduce(ext(...))
MulAccReductionPatterns:
    reduce.add(mul(...))
    reduce.add(mul(ext(...), ext(...)))
    reduce.add(ext(mul(...)))
    reduce.add(ext(mul(ext(...), ext(...))))

Ref: Original instruction based implementation:
https://reviews.llvm.org/D93476
---
 .../Transforms/Vectorize/LoopVectorize.cpp    |  57 -------
 llvm/lib/Transforms/Vectorize/VPlan.h         |   2 +
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp | 139 ++++++++++++++++--
 3 files changed, 129 insertions(+), 69 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 3c7c044a042719..74f463821addf9 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7300,63 +7300,6 @@ LoopVectorizationPlanner::precomputeCosts(VPlan &Plan, ElementCount VF,
     }
   }
 
-  // The legacy cost model has special logic to compute the cost of in-loop
-  // reductions, which may be smaller than the sum of all instructions involved
-  // in the reduction.
-  // TODO: Switch to costing based on VPlan once the logic has been ported.
-  for (const auto &[RedPhi, RdxDesc] : Legal->getReductionVars()) {
-    if (ForceTargetInstructionCost.getNumOccurrences())
-      continue;
-
-    if (!CM.isInLoopReduction(RedPhi))
-      continue;
-
-    const auto &ChainOps = RdxDesc.getReductionOpChain(RedPhi, OrigLoop);
-    SetVector<Instruction *> ChainOpsAndOperands(ChainOps.begin(),
-                                                 ChainOps.end());
-    auto IsZExtOrSExt = [](const unsigned Opcode) -> bool {
-      return Opcode == Instruction::ZExt || Opcode == Instruction::SExt;
-    };
-    // Also include the operands of instructions in the chain, as the cost-model
-    // may mark extends as free.
-    //
-    // For ARM, some of the instruction can folded into the reducion
-    // instruction. So we need to mark all folded instructions free.
-    // For example: We can fold reduce(mul(ext(A), ext(B))) into one
-    // instruction.
-    for (auto *ChainOp : ChainOps) {
-      for (Value *Op : ChainOp->operands()) {
-        if (auto *I = dyn_cast<Instruction>(Op)) {
-          ChainOpsAndOperands.insert(I);
-          if (I->getOpcode() == Instruction::Mul) {
-            auto *Ext0 = dyn_cast<Instruction>(I->getOperand(0));
-            auto *Ext1 = dyn_cast<Instruction>(I->getOperand(1));
-            if (Ext0 && IsZExtOrSExt(Ext0->getOpcode()) && Ext1 &&
-                Ext0->getOpcode() == Ext1->getOpcode()) {
-              ChainOpsAndOperands.insert(Ext0);
-              ChainOpsAndOperands.insert(Ext1);
-            }
-          }
-        }
-      }
-    }
-
-    // Pre-compute the cost for I, if it has a reduction pattern cost.
-    for (Instruction *I : ChainOpsAndOperands) {
-      auto ReductionCost = CM.getReductionPatternCost(
-          I, VF, ToVectorTy(I->getType(), VF), TTI::TCK_RecipThroughput);
-      if (!ReductionCost)
-        continue;
-
-      assert(!CostCtx.SkipCostComputation.contains(I) &&
-             "reduction op visited multiple times");
-      CostCtx.SkipCostComputation.insert(I);
-      LLVM_DEBUG(dbgs() << "Cost of " << ReductionCost << " for VF " << VF
-                        << ":\n in-loop reduction " << *I << "\n");
-      Cost += *ReductionCost;
-    }
-  }
-
   // Pre-compute the costs for branches except for the backedge, as the number
   // of replicate regions in a VPlan may not directly match the number of
   // branches, which would lead to different decisions.
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index e1d828f038f9a2..b87e042a0d1c50 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -684,6 +684,8 @@ struct VPCostContext {
   LLVMContext &LLVMCtx;
   LoopVectorizationCostModel &CM;
   SmallPtrSet<Instruction *, 8> SkipCostComputation;
+  /// Contains recipes that are folded into other recipes.
+  SmallDenseMap<ElementCount, SmallPtrSet<VPRecipeBase *, 4>, 4> FoldedRecipes;
 
   VPCostContext(const TargetTransformInfo &TTI, const TargetLibraryInfo &TLI,
                 Type *CanIVTy, LoopVectorizationCostModel &CM)
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index ef5f6e22f82206..ac78164b09e17b 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -270,7 +270,9 @@ InstructionCost VPRecipeBase::cost(ElementCount VF, VPCostContext &Ctx) {
     UI = &WidenMem->getIngredient();
 
   InstructionCost RecipeCost;
-  if (UI && Ctx.skipCostComputation(UI, VF.isVector())) {
+  if ((UI && Ctx.skipCostComputation(UI, VF.isVector())) ||
+      (Ctx.FoldedRecipes.contains(VF) &&
+       Ctx.FoldedRecipes.at(VF).contains(this))) {
     RecipeCost = 0;
   } else {
     RecipeCost = computeCost(VF, Ctx);
@@ -2192,30 +2194,143 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
   TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
   unsigned Opcode = RdxDesc.getOpcode();
 
-  // TODO: Support any-of and in-loop reductions.
+  // TODO: Support any-of reductions.
   assert(
       (!RecurrenceDescriptor::isAnyOfRecurrenceKind(RdxKind) ||
        ForceTargetInstructionCost.getNumOccurrences() > 0) &&
       "Any-of reduction not implemented in VPlan-based cost model currently.");
-  assert(
-      (!cast<VPReductionPHIRecipe>(getOperand(0))->isInLoop() ||
-       ForceTargetInstructionCost.getNumOccurrences() > 0) &&
-      "In-loop reduction not implemented in VPlan-based cost model currently.");
 
   assert(ElementTy->getTypeID() == RdxDesc.getRecurrenceType()->getTypeID() &&
          "Inferred type and recurrence type mismatch.");
 
-  // Cost = Reduction cost + BinOp cost
-  InstructionCost Cost =
+  // BaseCost = Reduction cost + BinOp cost
+  InstructionCost BaseCost =
       Ctx.TTI.getArithmeticInstrCost(Opcode, ElementTy, CostKind);
   if (RecurrenceDescriptor::isMinMaxRecurrenceKind(RdxKind)) {
     Intrinsic::ID Id = getMinMaxReductionIntrinsicOp(RdxKind);
-    return Cost + Ctx.TTI.getMinMaxReductionCost(
-                      Id, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+    BaseCost += Ctx.TTI.getMinMaxReductionCost(
+        Id, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+  } else {
+    BaseCost += Ctx.TTI.getArithmeticReductionCost(
+        Opcode, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
   }
 
-  return Cost + Ctx.TTI.getArithmeticReductionCost(
-                    Opcode, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+  using namespace llvm::VPlanPatternMatch;
+  auto GetMulAccReductionCost =
+      [&](const VPReductionRecipe *Red) -> InstructionCost {
+    VPValue *A, *B;
+    InstructionCost InnerExt0Cost = 0;
+    InstructionCost InnerExt1Cost = 0;
+    InstructionCost ExtCost = 0;
+    InstructionCost MulCost = 0;
+
+    VectorType *SrcVecTy = VectorTy;
+    Type *InnerExt0Ty;
+    Type *InnerExt1Ty;
+    Type *MaxInnerExtTy;
+    bool IsUnsigned = true;
+    bool HasOuterExt = false;
+
+    auto *Ext = dyn_cast_if_present<VPWidenCastRecipe>(
+        Red->getVecOp()->getDefiningRecipe());
+    VPRecipeBase *Mul;
+    // Try to match outer extend reduce.add(ext(...))
+    if (Ext && match(Ext, m_ZExtOrSExt(m_VPValue())) &&
+        cast<VPWidenCastRecipe>(Ext)->getNumUsers() == 1) {
+      IsUnsigned =
+          Ext->getOpcode() == Instruction::CastOps::ZExt ? true : false;
+      ExtCost = Ext->computeCost(VF, Ctx);
+      Mul = Ext->getOperand(0)->getDefiningRecipe();
+      HasOuterExt = true;
+    } else {
+      Mul = Red->getVecOp()->getDefiningRecipe();
+    }
+
+    // Match reduce.add(mul())
+    if (Mul && match(Mul, m_Mul(m_VPValue(A), m_VPValue(B))) &&
+        cast<VPWidenRecipe>(Mul)->getNumUsers() == 1) {
+      MulCost = cast<VPWidenRecipe>(Mul)->computeCost(VF, Ctx);
+      auto *InnerExt0 =
+          dyn_cast_if_present<VPWidenCastRecipe>(A->getDefiningRecipe());
+      auto *InnerExt1 =
+          dyn_cast_if_present<VPWidenCastRecipe>(B->getDefiningRecipe());
+      bool HasInnerExt = false;
+      // Try to match inner extends.
+      if (InnerExt0 && InnerExt1 &&
+          match(InnerExt0, m_ZExtOrSExt(m_VPValue())) &&
+          match(InnerExt1, m_ZExtOrSExt(m_VPValue())) &&
+          InnerExt0->getOpcode() == InnerExt1->getOpcode() &&
+          (InnerExt0->getNumUsers() > 0 &&
+           !InnerExt0->hasMoreThanOneUniqueUser()) &&
+          (InnerExt1->getNumUsers() > 0 &&
+           !InnerExt1->hasMoreThanOneUniqueUser())) {
+        InnerExt0Cost = InnerExt0->computeCost(VF, Ctx);
+        InnerExt1Cost = InnerExt1->computeCost(VF, Ctx);
+        Type *InnerExt0Ty = Ctx.Types.inferScalarType(InnerExt0->getOperand(0));
+        Type *InnerExt1Ty = Ctx.Types.inferScalarType(InnerExt1->getOperand(0));
+        Type *MaxInnerExtTy = InnerExt0Ty->getIntegerBitWidth() >
+                                      InnerExt1Ty->getIntegerBitWidth()
+                                  ? InnerExt0Ty
+                                  : InnerExt1Ty;
+        SrcVecTy = cast<VectorType>(ToVectorTy(MaxInnerExtTy, VF));
+        IsUnsigned = true;
+        HasInnerExt = true;
+      }
+      InstructionCost MulAccRedCost = Ctx.TTI.getMulAccReductionCost(
+          IsUnsigned, ElementTy, SrcVecTy, CostKind);
+      // Check if folding ext/mul into MulAccReduction is profitable.
+      if (MulAccRedCost.isValid() &&
+          MulAccRedCost <
+              ExtCost + MulCost + InnerExt0Cost + InnerExt1Cost + BaseCost) {
+        if (HasInnerExt) {
+          Ctx.FoldedRecipes[VF].insert(InnerExt0);
+          Ctx.FoldedRecipes[VF].insert(InnerExt1);
+        }
+        Ctx.FoldedRecipes[VF].insert(Mul);
+        if (HasOuterExt)
+          Ctx.FoldedRecipes[VF].insert(Ext);
+        return MulAccRedCost;
+      }
+    }
+    return InstructionCost::getInvalid();
+  };
+
+  // Match reduce(ext(...))
+  auto GetExtendedReductionCost =
+      [&](const VPReductionRecipe *Red) -> InstructionCost {
+    VPValue *VecOp = Red->getVecOp();
+    VPValue *A;
+    if (match(VecOp, m_ZExtOrSExt(m_VPValue(A))) && VecOp->getNumUsers() == 1) {
+      VPWidenCastRecipe *Ext =
+          cast<VPWidenCastRecipe>(VecOp->getDefiningRecipe());
+      bool IsUnsigned = Ext->getOpcode() == Instruction::CastOps::ZExt;
+      InstructionCost ExtCost = Ext->computeCost(VF, Ctx);
+      auto *ExtVecTy =
+          cast<VectorType>(ToVectorTy(Ctx.Types.inferScalarType(A), VF));
+      InstructionCost ExtendedRedCost = Ctx.TTI.getExtendedReductionCost(
+          Opcode, IsUnsigned, ElementTy, ExtVecTy, RdxDesc.getFastMathFlags(),
+          CostKind);
+      // Check if folding ext into ExtendedReduction is profitable.
+      if (ExtendedRedCost.isValid() && ExtendedRedCost < ExtCost + BaseCost) {
+        Ctx.FoldedRecipes[VF].insert(Ext);
+        return ExtendedRedCost;
+      }
+    }
+    return InstructionCost::getInvalid();
+  };
+
+  // Match MulAccReduction patterns.
+  InstructionCost MulAccCost = GetMulAccReductionCost(this);
+  if (MulAccCost.isValid())
+    return MulAccCost;
+
+  // Match ExtendedReduction patterns.
+  InstructionCost ExtendedCost = GetExtendedReductionCost(this);
+  if (ExtendedCost.isValid())
+    return ExtendedCost;
+
+  // Default cost.
+  return BaseCost;
 }
 
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)

>From 68fbd7047ba6b0aa5bf42a1755d217970a74b0ec Mon Sep 17 00:00:00 2001
From: Elvis Wang <elvis.wang at sifive.com>
Date: Sun, 3 Nov 2024 18:55:55 -0800
Subject: [PATCH 02/26] Partially support Extended-reduction.

---
 .../Transforms/Vectorize/LoopVectorize.cpp    |  24 +-
 llvm/lib/Transforms/Vectorize/VPlan.h         | 217 ++++++++++++++++++
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp | 114 +++++++--
 .../Transforms/Vectorize/VPlanTransforms.cpp  |  24 ++
 llvm/lib/Transforms/Vectorize/VPlanValue.h    |   2 +
 5 files changed, 356 insertions(+), 25 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 74f463821addf9..c41c2a84afa6cd 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7666,6 +7666,10 @@ DenseMap<const SCEV *, Value *> LoopVectorizationPlanner::executePlan(
                              CanonicalIVStartValue, State);
   VPlanTransforms::prepareToExecute(BestVPlan);
 
+  // TODO: Rebase to fhahn's implementation.
+  VPlanTransforms::prepareExecute(BestVPlan);
+  dbgs() << "\n\n print plan\n";
+  BestVPlan.print(dbgs());
   BestVPlan.execute(&State);
 
   // 2.5 Collect reduction resume values.
@@ -9264,6 +9268,7 @@ VPlanPtr LoopVectorizationPlanner::buildVPlan(VFRange &Range) {
 // Adjust AnyOf reductions; replace the reduction phi for the selected value
 // with a boolean reduction phi node to check if the condition is true in any
 // iteration. The final value is selected by the final ComputeReductionResult.
+// TODO: Implement VPMulAcc here.
 void LoopVectorizationPlanner::adjustRecipesForReductions(
     VPlanPtr &Plan, VPRecipeBuilder &RecipeBuilder, ElementCount MinVF) {
   using namespace VPlanPatternMatch;
@@ -9382,9 +9387,22 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
       if (CM.blockNeedsPredicationForAnyReason(BB))
         CondOp = RecipeBuilder.getBlockInMask(BB);
 
-      VPReductionRecipe *RedRecipe =
-          new VPReductionRecipe(RdxDesc, CurrentLinkI, PreviousLink, VecOp,
-                                CondOp, CM.useOrderedReductions(RdxDesc));
+      // VPWidenCastRecipes can be folded into VPReductionRecipe
+      VPValue *A;
+      VPSingleDefRecipe *RedRecipe;
+      if (match(VecOp, m_ZExtOrSExt(m_VPValue(A))) &&
+          !VecOp->hasMoreThanOneUniqueUser()) {
+        RedRecipe = new VPExtendedReductionRecipe(
+            RdxDesc, CurrentLinkI,
+            cast<CastInst>(
+                cast<VPWidenCastRecipe>(VecOp)->getUnderlyingInstr()),
+            PreviousLink, A, CondOp, CM.useOrderedReductions(RdxDesc),
+            cast<VPWidenCastRecipe>(VecOp)->getResultType());
+      } else {
+        RedRecipe =
+            new VPReductionRecipe(RdxDesc, CurrentLinkI, PreviousLink, VecOp,
+                                  CondOp, CM.useOrderedReductions(RdxDesc));
+      }
       // Append the recipe to the end of the VPBasicBlock because we need to
       // ensure that it comes after all of it's inputs, including CondOp.
       // Note that this transformation may leave over dead recipes (including
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index b87e042a0d1c50..d0c2fb2e48d466 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -861,6 +861,8 @@ class VPSingleDefRecipe : public VPRecipeBase, public VPValue {
     case VPRecipeBase::VPInstructionSC:
     case VPRecipeBase::VPReductionEVLSC:
     case VPRecipeBase::VPReductionSC:
+    case VPRecipeBase::VPMulAccSC:
+    case VPRecipeBase::VPExtendedReductionSC:
     case VPRecipeBase::VPReplicateSC:
     case VPRecipeBase::VPScalarIVStepsSC:
     case VPRecipeBase::VPVectorPointerSC:
@@ -2696,6 +2698,221 @@ class VPReductionEVLRecipe : public VPReductionRecipe {
   }
 };
 
+/// A recipe to represent in-loop extended reduction operations, performing a
+/// reduction on a vector operand into a scalar value, and adding the result to
+/// a chain. This is a high-level abstract recipe, which will generate a
+/// VPReductionRecipe and a VPWidenCastRecipe before execution. The operands
+/// are {ChainOp, VecOp, [Condition]}.
+class VPExtendedReductionRecipe : public VPSingleDefRecipe {
+  /// The recurrence descriptor for the reduction in question.
+  const RecurrenceDescriptor &RdxDesc;
+  bool IsOrdered;
+  /// Whether the reduction is conditional.
+  bool IsConditional = false;
+  /// Type after extend.
+  Type *ResultTy;
+  Instruction::CastOps ExtOp;
+  CastInst *CastInstr;
+  bool IsZExt;
+
+protected:
+  VPExtendedReductionRecipe(const unsigned char SC,
+                            const RecurrenceDescriptor &R, Instruction *RedI,
+                            Instruction::CastOps ExtOp, CastInst *CastI,
+                            ArrayRef<VPValue *> Operands, VPValue *CondOp,
+                            bool IsOrdered, Type *ResultTy)
+      : VPSingleDefRecipe(SC, Operands, RedI), RdxDesc(R), IsOrdered(IsOrdered),
+        ResultTy(ResultTy), ExtOp(ExtOp), CastInstr(CastI) {
+    if (CondOp) {
+      IsConditional = true;
+      addOperand(CondOp);
+    }
+    IsZExt = ExtOp == Instruction::CastOps::ZExt;
+  }
+
+public:
+  VPExtendedReductionRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
+                            CastInst *CastI, VPValue *ChainOp, VPValue *VecOp,
+                            VPValue *CondOp, bool IsOrdered, Type *ResultTy)
+      : VPExtendedReductionRecipe(VPDef::VPExtendedReductionSC, R, RedI,
+                                  CastI->getOpcode(), CastI,
+                                  ArrayRef<VPValue *>({ChainOp, VecOp}), CondOp,
+                                  IsOrdered, ResultTy) {}
+
+  VPExtendedReductionRecipe(VPReductionRecipe *Red, VPWidenCastRecipe *Ext)
+      : VPExtendedReductionRecipe(
+            VPDef::VPExtendedReductionSC, Red->getRecurrenceDescriptor(),
+            Red->getUnderlyingInstr(), Ext->getOpcode(),
+            cast<CastInst>(Ext->getUnderlyingInstr()),
+            ArrayRef<VPValue *>({Red->getChainOp(), Ext->getOperand(0)}),
+            Red->getCondOp(), Red->isOrdered(), Ext->getResultType()) {}
+
+  ~VPExtendedReductionRecipe() override = default;
+
+  VPExtendedReductionRecipe *clone() override {
+    llvm_unreachable("Not implement yet");
+  }
+
+  static inline bool classof(const VPRecipeBase *R) {
+    return R->getVPDefID() == VPDef::VPExtendedReductionSC;
+  }
+
+  static inline bool classof(const VPUser *U) {
+    auto *R = dyn_cast<VPRecipeBase>(U);
+    return R && classof(R);
+  }
+
+  /// Generate the reduction in the loop
+  void execute(VPTransformState &State) override {
+    llvm_unreachable("VPExtendedReductionRecipe should be transform to "
+                     "VPExtendedRecipe + VPReductionRecipe before execution.");
+  };
+
+  /// Return the cost of VPExtendedReductionRecipe.
+  InstructionCost computeCost(ElementCount VF,
+                              VPCostContext &Ctx) const override;
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+  /// Print the recipe.
+  void print(raw_ostream &O, const Twine &Indent,
+             VPSlotTracker &SlotTracker) const override;
+#endif
+
+  /// Return the recurrence descriptor for the in-loop reduction.
+  const RecurrenceDescriptor &getRecurrenceDescriptor() const {
+    return RdxDesc;
+  }
+  /// Return true if the in-loop reduction is ordered.
+  bool isOrdered() const { return IsOrdered; };
+  /// Return true if the in-loop reduction is conditional.
+  bool isConditional() const { return IsConditional; };
+  /// The VPValue of the scalar Chain being accumulated.
+  VPValue *getChainOp() const { return getOperand(0); }
+  /// The VPValue of the vector value to be extended and reduced.
+  VPValue *getVecOp() const { return getOperand(1); }
+  /// The VPValue of the condition for the block.
+  VPValue *getCondOp() const {
+    return isConditional() ? getOperand(getNumOperands() - 1) : nullptr;
+  }
+  Type *getResultType() const { return ResultTy; };
+  Instruction::CastOps getExtOpcode() const { return ExtOp; };
+  CastInst *getExtInstr() const { return CastInstr; };
+};
+
+/// A recipe to represent in-loop MulAccReduction operations, performing a
+/// reduction on a vector operand into a scalar value, and adding the result to
+/// a chain. This is a high-level abstract recipe, which will generate a
+/// VPReductionRecipe, a VPWidenRecipe (mul) and VPWidenCastRecipes before
+/// execution. The operands are {ChainOp, VecOp1, VecOp2, [Condition]}.
+class VPMulAccRecipe : public VPSingleDefRecipe {
+  /// The recurrence descriptor for the reduction in question.
+  const RecurrenceDescriptor &RdxDesc;
+  bool IsOrdered;
+  /// Whether the reduction is conditional.
+  bool IsConditional = false;
+  /// Type after extend.
+  Type *ResultTy;
+  /// Type for mul.
+  Type *MulTy;
+  /// reduce.add(OuterExt(mul(InnerExt(), InnerExt())))
+  Instruction::CastOps OuterExtOp;
+  Instruction::CastOps InnerExtOp;
+
+  Instruction *MulI;
+  Instruction *OuterExtI;
+  Instruction *InnerExt0I;
+  Instruction *InnerExt1I;
+
+protected:
+  VPMulAccRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
+                 Instruction *RedI, Instruction::CastOps OuterExtOp,
+                 Instruction *OuterExtI, Instruction *MulI,
+                 Instruction::CastOps InnerExtOp, Instruction *InnerExt0I,
+                 Instruction *InnerExt1I, ArrayRef<VPValue *> Operands,
+                 VPValue *CondOp, bool IsOrdered, Type *ResultTy, Type *MulTy)
+      : VPSingleDefRecipe(SC, Operands, RedI), RdxDesc(R), IsOrdered(IsOrdered),
+        ResultTy(ResultTy), MulTy(MulTy), OuterExtOp(OuterExtOp),
+        InnerExtOp(InnerExtOp), MulI(MulI), OuterExtI(OuterExtI),
+        InnerExt0I(InnerExt0I), InnerExt1I(InnerExt1I) {
+    if (CondOp) {
+      IsConditional = true;
+      addOperand(CondOp);
+    }
+  }
+
+public:
+  VPMulAccRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
+                 Instruction *OuterExt, Instruction *Mul,
+                 Instruction *InnerExt0, Instruction *InnerExt1,
+                 VPValue *ChainOp, VPValue *InnerExt0Op, VPValue *InnerExt1Op,
+                 VPValue *CondOp, bool IsOrdered, Type *ResultTy, Type *MulTy)
+      : VPMulAccRecipe(
+            VPDef::VPMulAccSC, R, RedI, cast<CastInst>(OuterExt)->getOpcode(),
+            OuterExt, Mul, cast<CastInst>(InnerExt0)->getOpcode(), InnerExt0,
+            InnerExt1, ArrayRef<VPValue *>({ChainOp, InnerExt0Op, InnerExt1Op}),
+            CondOp, IsOrdered, ResultTy, MulTy) {}
+
+  VPMulAccRecipe(VPReductionRecipe *Red, VPWidenCastRecipe *OuterExt,
+                 VPWidenRecipe *Mul, VPWidenCastRecipe *InnerExt0,
+                 VPWidenCastRecipe *InnerExt1)
+      : VPMulAccRecipe(
+            VPDef::VPMulAccSC, Red->getRecurrenceDescriptor(),
+            Red->getUnderlyingInstr(), OuterExt->getOpcode(),
+            OuterExt->getUnderlyingInstr(), Mul->getUnderlyingInstr(),
+            InnerExt0->getOpcode(), InnerExt0->getUnderlyingInstr(),
+            InnerExt1->getUnderlyingInstr(),
+            ArrayRef<VPValue *>({Red->getChainOp(), InnerExt0->getOperand(0),
+                                 InnerExt1->getOperand(0)}),
+            Red->getCondOp(), Red->isOrdered(), OuterExt->getResultType(),
+            InnerExt0->getResultType()) {}
+
+  ~VPMulAccRecipe() override = default;
+
+  VPMulAccRecipe *clone() override { llvm_unreachable("Not implement yet"); }
+
+  static inline bool classof(const VPRecipeBase *R) {
+    return R->getVPDefID() == VPRecipeBase::VPMulAccSC;
+  }
+
+  static inline bool classof(const VPUser *U) {
+    auto *R = dyn_cast<VPRecipeBase>(U);
+    return R && classof(R);
+  }
+
+  /// Generate the reduction in the loop
+  void execute(VPTransformState &State) override;
+
+  /// Return the cost of VPMulAccRecipe.
+  InstructionCost computeCost(ElementCount VF,
+                              VPCostContext &Ctx) const override;
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+  /// Print the recipe.
+  void print(raw_ostream &O, const Twine &Indent,
+             VPSlotTracker &SlotTracker) const override;
+#endif
+
+  /// Return the recurrence descriptor for the in-loop reduction.
+  const RecurrenceDescriptor &getRecurrenceDescriptor() const {
+    return RdxDesc;
+  }
+  /// Return true if the in-loop reduction is ordered.
+  bool isOrdered() const { return IsOrdered; };
+  /// Return true if the in-loop reduction is conditional.
+  bool isConditional() const { return IsConditional; };
+  /// The VPValue of the scalar Chain being accumulated.
+  VPValue *getChainOp() const { return getOperand(0); }
+  /// The VPValue of the vector value to be extended and reduced.
+  VPValue *getVecOp() const { return getOperand(1); }
+  /// The VPValue of the condition for the block.
+  VPValue *getCondOp() const {
+    return isConditional() ? getOperand(getNumOperands() - 1) : nullptr;
+  }
+  Type *getResultTy() const { return ResultTy; };
+  Instruction::CastOps getOuterExtOpcode() const { return OuterExtOp; };
+  Instruction::CastOps getInnerExtOpcode() const { return InnerExtOp; };
+};
+
 /// VPReplicateRecipe replicates a given instruction producing multiple scalar
 /// copies of the original scalar type, one per lane, instead of producing a
 /// single copy of widened type for all lanes. If the instruction is known to be
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index ac78164b09e17b..69577a5b78000a 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -1502,6 +1502,27 @@ void VPWidenCastRecipe::execute(VPTransformState &State) {
     setFlags(CastOp);
 }
 
+// Computes the CastContextHint from a recipe that may access memory.
+static TTI::CastContextHint computeCCH(const VPRecipeBase *R, ElementCount VF) {
+  if (VF.isScalar())
+    return TTI::CastContextHint::Normal;
+  if (isa<VPInterleaveRecipe>(R))
+    return TTI::CastContextHint::Interleave;
+  if (const auto *ReplicateRecipe = dyn_cast<VPReplicateRecipe>(R))
+    return ReplicateRecipe->isPredicated() ? TTI::CastContextHint::Masked
+                                           : TTI::CastContextHint::Normal;
+  const auto *WidenMemoryRecipe = dyn_cast<VPWidenMemoryRecipe>(R);
+  if (WidenMemoryRecipe == nullptr)
+    return TTI::CastContextHint::None;
+  if (!WidenMemoryRecipe->isConsecutive())
+    return TTI::CastContextHint::GatherScatter;
+  if (WidenMemoryRecipe->isReverse())
+    return TTI::CastContextHint::Reversed;
+  if (WidenMemoryRecipe->isMasked())
+    return TTI::CastContextHint::Masked;
+  return TTI::CastContextHint::Normal;
+}
+
 InstructionCost VPWidenCastRecipe::computeCost(ElementCount VF,
                                                VPCostContext &Ctx) const {
   // TODO: In some cases, VPWidenCastRecipes are created but not considered in
@@ -1509,26 +1530,6 @@ InstructionCost VPWidenCastRecipe::computeCost(ElementCount VF,
   // reduction in a smaller type.
   if (!getUnderlyingValue())
     return 0;
-  // Computes the CastContextHint from a recipes that may access memory.
-  auto ComputeCCH = [&](const VPRecipeBase *R) -> TTI::CastContextHint {
-    if (VF.isScalar())
-      return TTI::CastContextHint::Normal;
-    if (isa<VPInterleaveRecipe>(R))
-      return TTI::CastContextHint::Interleave;
-    if (const auto *ReplicateRecipe = dyn_cast<VPReplicateRecipe>(R))
-      return ReplicateRecipe->isPredicated() ? TTI::CastContextHint::Masked
-                                             : TTI::CastContextHint::Normal;
-    const auto *WidenMemoryRecipe = dyn_cast<VPWidenMemoryRecipe>(R);
-    if (WidenMemoryRecipe == nullptr)
-      return TTI::CastContextHint::None;
-    if (!WidenMemoryRecipe->isConsecutive())
-      return TTI::CastContextHint::GatherScatter;
-    if (WidenMemoryRecipe->isReverse())
-      return TTI::CastContextHint::Reversed;
-    if (WidenMemoryRecipe->isMasked())
-      return TTI::CastContextHint::Masked;
-    return TTI::CastContextHint::Normal;
-  };
 
   VPValue *Operand = getOperand(0);
   TTI::CastContextHint CCH = TTI::CastContextHint::None;
@@ -1536,7 +1537,7 @@ InstructionCost VPWidenCastRecipe::computeCost(ElementCount VF,
   if ((Opcode == Instruction::Trunc || Opcode == Instruction::FPTrunc) &&
       !hasMoreThanOneUniqueUser() && getNumUsers() > 0) {
     if (auto *StoreRecipe = dyn_cast<VPRecipeBase>(*user_begin()))
-      CCH = ComputeCCH(StoreRecipe);
+      CCH = computeCCH(StoreRecipe, VF);
   }
   // For Z/Sext, get the context from the operand.
   else if (Opcode == Instruction::ZExt || Opcode == Instruction::SExt ||
@@ -1544,7 +1545,7 @@ InstructionCost VPWidenCastRecipe::computeCost(ElementCount VF,
     if (Operand->isLiveIn())
       CCH = TTI::CastContextHint::Normal;
     else if (Operand->getDefiningRecipe())
-      CCH = ComputeCCH(Operand->getDefiningRecipe());
+      CCH = computeCCH(Operand->getDefiningRecipe(), VF);
   }
 
   auto *SrcTy =
@@ -2215,6 +2216,7 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
         Opcode, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
   }
 
+  /*
   using namespace llvm::VPlanPatternMatch;
   auto GetMulAccReductionCost =
       [&](const VPReductionRecipe *Red) -> InstructionCost {
@@ -2328,11 +2330,57 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
   InstructionCost ExtendedCost = GetExtendedReductionCost(this);
   if (ExtendedCost.isValid())
     return ExtendedCost;
+  */
 
   // Default cost.
   return BaseCost;
 }
 
+InstructionCost
+VPExtendedReductionRecipe::computeCost(ElementCount VF,
+                                       VPCostContext &Ctx) const {
+  RecurKind RdxKind = RdxDesc.getRecurrenceKind();
+  Type *ElementTy = getResultType();
+  auto *VectorTy = cast<VectorType>(ToVectorTy(ElementTy, VF));
+  TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+  unsigned Opcode = RdxDesc.getOpcode();
+
+  assert(ElementTy->getTypeID() == RdxDesc.getRecurrenceType()->getTypeID() &&
+         "Inferred type and recurrence type mismatch.");
+
+  // BaseCost = Reduction cost + BinOp cost
+  InstructionCost ReductionCost =
+      Ctx.TTI.getArithmeticInstrCost(Opcode, ElementTy, CostKind);
+  if (RecurrenceDescriptor::isMinMaxRecurrenceKind(RdxKind)) {
+    Intrinsic::ID Id = getMinMaxReductionIntrinsicOp(RdxKind);
+    ReductionCost += Ctx.TTI.getMinMaxReductionCost(
+        Id, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+  } else {
+    ReductionCost += Ctx.TTI.getArithmeticReductionCost(
+        Opcode, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+  }
+
+  // Extended cost
+  auto *SrcTy =
+      cast<VectorType>(ToVectorTy(Ctx.Types.inferScalarType(getVecOp()), VF));
+  auto *DestTy = cast<VectorType>(ToVectorTy(getResultType(), VF));
+  TTI::CastContextHint CCH = computeCCH(getVecOp()->getDefiningRecipe(), VF);
+  // Arm TTI will use the underlying instruction to determine the cost.
+  InstructionCost ExtendedCost = Ctx.TTI.getCastInstrCost(
+      Opcode, DestTy, SrcTy, CCH, TTI::TCK_RecipThroughput,
+      dyn_cast_if_present<Instruction>(getUnderlyingValue()));
+
+  // ExtendedReduction Cost
+  InstructionCost ExtendedRedCost = Ctx.TTI.getExtendedReductionCost(
+      Opcode, IsZExt, ElementTy, SrcTy, RdxDesc.getFastMathFlags(), CostKind);
+  // Check if folding ext into ExtendedReduction is profitable.
+  if (ExtendedRedCost.isValid() &&
+      ExtendedRedCost < ExtendedCost + ReductionCost) {
+    return ExtendedRedCost;
+  }
+  return ExtendedCost + ReductionCost;
+}
+
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
 void VPReductionRecipe::print(raw_ostream &O, const Twine &Indent,
                               VPSlotTracker &SlotTracker) const {
@@ -2378,6 +2426,28 @@ void VPReductionEVLRecipe::print(raw_ostream &O, const Twine &Indent,
     O << " (with final reduction value stored in invariant address sank "
          "outside of loop)";
 }
+
+void VPExtendedReductionRecipe::print(raw_ostream &O, const Twine &Indent,
+                                      VPSlotTracker &SlotTracker) const {
+  O << Indent << "EXTENDED-REDUCE "; // recipe mnemonic
+  printAsOperand(O, SlotTracker); // the value defined by this recipe
+  O << " = ";
+  getChainOp()->printAsOperand(O, SlotTracker);
+  O << " +"; // literal '+', followed by fast-math flags for FP reductions
+  if (isa<FPMathOperator>(getUnderlyingInstr()))
+    O << getUnderlyingInstr()->getFastMathFlags();
+  O << " reduce." << Instruction::getOpcodeName(RdxDesc.getOpcode()) << " (";
+  getVecOp()->printAsOperand(O, SlotTracker); // operand before the extend
+  O << " extended to " << *getResultType();
+  if (isConditional()) { // append the condition operand when there is one
+    O << ", ";
+    getCondOp()->printAsOperand(O, SlotTracker);
+  }
+  O << ")";
+  if (RdxDesc.IntermediateStore)
+    O << " (with final reduction value stored in invariant address sank "
+         "outside of loop)";
+}
 #endif
 
 bool VPReplicateRecipe::shouldPack() const {
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index cee83d1015b536..d56a1f74014759 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -519,6 +519,30 @@ void VPlanTransforms::removeDeadRecipes(VPlan &Plan) {
   }
 }
 
+void VPlanTransforms::prepareExecute(VPlan &Plan) {
+  errs() << "\n\n\n!!Prepare to execute\n";
+  ReversePostOrderTraversal<VPBlockDeepTraversalWrapper<VPBlockBase *>> RPOT(
+      Plan.getVectorLoopRegion());
+  for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>(
+           vp_depth_first_deep(Plan.getEntry()))) {
+    for (VPRecipeBase &R : make_early_inc_range(*VPBB)) {
+      if (!isa<VPExtendedReductionRecipe>(&R))
+        continue;
+      auto *ExtRed = cast<VPExtendedReductionRecipe>(&R);
+      auto *Ext = new VPWidenCastRecipe(
+          ExtRed->getExtOpcode(), ExtRed->getVecOp(), ExtRed->getResultType(),
+          *ExtRed->getExtInstr());
+      auto *Red = new VPReductionRecipe(
+          ExtRed->getRecurrenceDescriptor(), ExtRed->getUnderlyingInstr(),
+          ExtRed->getChainOp(), Ext, ExtRed->getCondOp(), ExtRed->isOrdered());
+      Ext->insertBefore(ExtRed);
+      Red->insertBefore(ExtRed);
+      ExtRed->replaceAllUsesWith(Red);
+      ExtRed->eraseFromParent();
+    }
+  }
+}
+
 static VPScalarIVStepsRecipe *
 createScalarIVSteps(VPlan &Plan, InductionDescriptor::InductionKind Kind,
                     Instruction::BinaryOps InductionOpcode,
diff --git a/llvm/lib/Transforms/Vectorize/VPlanValue.h b/llvm/lib/Transforms/Vectorize/VPlanValue.h
index 957a602091c733..2c922d8d39c927 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanValue.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanValue.h
@@ -329,6 +329,8 @@ class VPDef {
     VPInterleaveSC,
     VPReductionEVLSC,
     VPReductionSC,
+    VPMulAccSC,
+    VPExtendedReductionSC,
     VPReplicateSC,
     VPScalarCastSC,
     VPScalarIVStepsSC,

>From c8c9d56419b763cfcf49be0826aedb79769c6e1b Mon Sep 17 00:00:00 2001
From: Elvis Wang <elvis.wang at sifive.com>
Date: Mon, 4 Nov 2024 22:02:22 -0800
Subject: [PATCH 03/26] Support MulAccRecipe

---
 .../Transforms/Vectorize/LoopVectorize.cpp    |  33 ++++-
 llvm/lib/Transforms/Vectorize/VPlan.h         | 103 +++++++++-------
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp | 114 +++++++++++++++++-
 .../Transforms/Vectorize/VPlanTransforms.cpp  |  56 ++++++---
 4 files changed, 237 insertions(+), 69 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index c41c2a84afa6cd..e8c62fac5ee8c2 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7668,8 +7668,6 @@ DenseMap<const SCEV *, Value *> LoopVectorizationPlanner::executePlan(
 
   // TODO: Rebase to fhahn's implementation.
   VPlanTransforms::prepareExecute(BestVPlan);
-  dbgs() << "\n\n print plan\n";
-  BestVPlan.print(dbgs());
   BestVPlan.execute(&State);
 
   // 2.5 Collect reduction resume values.
@@ -9387,11 +9385,34 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
       if (CM.blockNeedsPredicationForAnyReason(BB))
         CondOp = RecipeBuilder.getBlockInMask(BB);
 
-      // VPWidenCastRecipes can folded into VPReductionRecipe
-      VPValue *A;
+      VPValue *A, *B;
       VPSingleDefRecipe *RedRecipe;
-      if (match(VecOp, m_ZExtOrSExt(m_VPValue(A))) &&
-          !VecOp->hasMoreThanOneUniqueUser()) {
+      // reduce.add(mul(ext, ext)) can be folded into VPMulAccRecipe
+      if (RdxDesc.getOpcode() == Instruction::Add &&
+          match(VecOp, m_Mul(m_VPValue(A), m_VPValue(B)))) {
+        VPRecipeBase *RecipeA = A->getDefiningRecipe();
+        VPRecipeBase *RecipeB = B->getDefiningRecipe();
+        if (RecipeA && RecipeB && match(RecipeA, m_ZExtOrSExt(m_VPValue())) &&
+            match(RecipeB, m_ZExtOrSExt(m_VPValue())) &&
+            cast<VPWidenCastRecipe>(RecipeA)->getOpcode() ==
+                cast<VPWidenCastRecipe>(RecipeB)->getOpcode() &&
+            !A->hasMoreThanOneUniqueUser() && !B->hasMoreThanOneUniqueUser()) {
+          RedRecipe = new VPMulAccRecipe(
+              RdxDesc, CurrentLinkI, PreviousLink, CondOp,
+              CM.useOrderedReductions(RdxDesc),
+              cast<VPWidenRecipe>(VecOp->getDefiningRecipe()),
+              cast<VPWidenCastRecipe>(RecipeA),
+              cast<VPWidenCastRecipe>(RecipeB));
+        } else {
+          RedRecipe = new VPMulAccRecipe(
+              RdxDesc, CurrentLinkI, PreviousLink, CondOp,
+              CM.useOrderedReductions(RdxDesc),
+              cast<VPWidenRecipe>(VecOp->getDefiningRecipe()));
+        }
+      }
+      // VPWidenCastRecipes can folded into VPReductionRecipe
+      else if (match(VecOp, m_ZExtOrSExt(m_VPValue(A))) &&
+               !VecOp->hasMoreThanOneUniqueUser()) {
         RedRecipe = new VPExtendedReductionRecipe(
             RdxDesc, CurrentLinkI,
             cast<CastInst>(
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index d0c2fb2e48d466..19d98674db9963 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -2811,60 +2811,64 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
   /// Whether the reduction is conditional.
   bool IsConditional = false;
   /// Type after extend.
-  Type *ResultTy;
-  /// Type for mul.
-  Type *MulTy;
-  /// reduce.add(OuterExt(mul(InnerExt(), InnerExt())))
-  Instruction::CastOps OuterExtOp;
-  Instruction::CastOps InnerExtOp;
+  Type *ResultType;
+  /// reduce.add(mul(Ext(), Ext()))
+  Instruction::CastOps ExtOp;
+
+  Instruction *MulInstr;
+  CastInst *Ext0Instr;
+  CastInst *Ext1Instr;
 
-  Instruction *MulI;
-  Instruction *OuterExtI;
-  Instruction *InnerExt0I;
-  Instruction *InnerExt1I;
+  bool IsExtended;
 
 protected:
   VPMulAccRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
-                 Instruction *RedI, Instruction::CastOps OuterExtOp,
-                 Instruction *OuterExtI, Instruction *MulI,
-                 Instruction::CastOps InnerExtOp, Instruction *InnerExt0I,
-                 Instruction *InnerExt1I, ArrayRef<VPValue *> Operands,
-                 VPValue *CondOp, bool IsOrdered, Type *ResultTy, Type *MulTy)
+                 Instruction *RedI, Instruction *MulInstr,
+                 Instruction::CastOps ExtOp, Instruction *Ext0Instr,
+                 Instruction *Ext1Instr, ArrayRef<VPValue *> Operands,
+                 VPValue *CondOp, bool IsOrdered, Type *ResultType)
+      : VPSingleDefRecipe(SC, Operands, RedI), RdxDesc(R), IsOrdered(IsOrdered),
+        ResultType(ResultType), ExtOp(ExtOp), MulInstr(MulInstr),
+        Ext0Instr(cast<CastInst>(Ext0Instr)),
+        Ext1Instr(cast<CastInst>(Ext1Instr)) {
+    if (CondOp) {
+      IsConditional = true;
+      addOperand(CondOp);
+    }
+    IsExtended = true;
+  }
+
+  VPMulAccRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
+                 Instruction *RedI, Instruction *MulInstr,
+                 ArrayRef<VPValue *> Operands, VPValue *CondOp, bool IsOrdered)
       : VPSingleDefRecipe(SC, Operands, RedI), RdxDesc(R), IsOrdered(IsOrdered),
-        ResultTy(ResultTy), MulTy(MulTy), OuterExtOp(OuterExtOp),
-        InnerExtOp(InnerExtOp), MulI(MulI), OuterExtI(OuterExtI),
-        InnerExt0I(InnerExt0I), InnerExt1I(InnerExt1I) {
+        MulInstr(MulInstr) {
     if (CondOp) {
       IsConditional = true;
       addOperand(CondOp);
     }
+    IsExtended = false;
   }
 
 public:
   VPMulAccRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
-                 Instruction *OuterExt, Instruction *Mul,
-                 Instruction *InnerExt0, Instruction *InnerExt1,
-                 VPValue *ChainOp, VPValue *InnerExt0Op, VPValue *InnerExt1Op,
-                 VPValue *CondOp, bool IsOrdered, Type *ResultTy, Type *MulTy)
-      : VPMulAccRecipe(
-            VPDef::VPMulAccSC, R, RedI, cast<CastInst>(OuterExt)->getOpcode(),
-            OuterExt, Mul, cast<CastInst>(InnerExt0)->getOpcode(), InnerExt0,
-            InnerExt1, ArrayRef<VPValue *>({ChainOp, InnerExt0Op, InnerExt1Op}),
-            CondOp, IsOrdered, ResultTy, MulTy) {}
-
-  VPMulAccRecipe(VPReductionRecipe *Red, VPWidenCastRecipe *OuterExt,
-                 VPWidenRecipe *Mul, VPWidenCastRecipe *InnerExt0,
-                 VPWidenCastRecipe *InnerExt1)
-      : VPMulAccRecipe(
-            VPDef::VPMulAccSC, Red->getRecurrenceDescriptor(),
-            Red->getUnderlyingInstr(), OuterExt->getOpcode(),
-            OuterExt->getUnderlyingInstr(), Mul->getUnderlyingInstr(),
-            InnerExt0->getOpcode(), InnerExt0->getUnderlyingInstr(),
-            InnerExt1->getUnderlyingInstr(),
-            ArrayRef<VPValue *>({Red->getChainOp(), InnerExt0->getOperand(0),
-                                 InnerExt1->getOperand(0)}),
-            Red->getCondOp(), Red->isOrdered(), OuterExt->getResultType(),
-            InnerExt0->getResultType()) {}
+                 VPValue *ChainOp, VPValue *CondOp, bool IsOrdered,
+                 VPWidenRecipe *Mul, VPWidenCastRecipe *Ext0,
+                 VPWidenCastRecipe *Ext1)
+      : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, Mul->getUnderlyingInstr(),
+                       Ext0->getOpcode(), Ext0->getUnderlyingInstr(),
+                       Ext1->getUnderlyingInstr(),
+                       ArrayRef<VPValue *>(
+                           {ChainOp, Ext0->getOperand(0), Ext1->getOperand(0)}),
+                       CondOp, IsOrdered, Ext0->getResultType()) {}
+
+  VPMulAccRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
+                 VPValue *ChainOp, VPValue *CondOp, bool IsOrdered,
+                 VPWidenRecipe *Mul)
+      : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, Mul->getUnderlyingInstr(),
+                       ArrayRef<VPValue *>(
+                           {ChainOp, Mul->getOperand(0), Mul->getOperand(0)}),
+                       CondOp, IsOrdered) {}
 
   ~VPMulAccRecipe() override = default;
 
@@ -2880,7 +2884,10 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
   }
 
   /// Generate the reduction in the loop
-  void execute(VPTransformState &State) override;
+  void execute(VPTransformState &State) override {
+    llvm_unreachable("VPMulAccRecipe should transform to VPWidenCastRecipe + "
+                     "VPWidenRecipe + VPReductionRecipe before execution");
+  }
 
   /// Return the cost of VPExtendedReductionRecipe.
   InstructionCost computeCost(ElementCount VF,
@@ -2903,14 +2910,18 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
   /// The VPValue of the scalar Chain being accumulated.
   VPValue *getChainOp() const { return getOperand(0); }
   /// The VPValue of the vector value to be extended and reduced.
-  VPValue *getVecOp() const { return getOperand(1); }
+  VPValue *getVecOp0() const { return getOperand(1); }
+  VPValue *getVecOp1() const { return getOperand(2); }
   /// The VPValue of the condition for the block.
   VPValue *getCondOp() const {
     return isConditional() ? getOperand(getNumOperands() - 1) : nullptr;
   }
-  Type *getResultTy() const { return ResultTy; };
-  Instruction::CastOps getOuterExtOpcode() const { return OuterExtOp; };
-  Instruction::CastOps getInnerExtOpcode() const { return InnerExtOp; };
+  Type *getResultType() const { return ResultType; };
+  Instruction::CastOps getExtOpcode() const { return ExtOp; };
+  Instruction *getMulInstr() const { return MulInstr; };
+  CastInst *getExt0Instr() const { return Ext0Instr; };
+  CastInst *getExt1Instr() const { return Ext1Instr; };
+  bool isExtended() const { return IsExtended; };
 };
 
 /// VPReplicateRecipe replicates a given instruction producing multiple scalar
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 69577a5b78000a..a7d74a78e06317 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -270,9 +270,7 @@ InstructionCost VPRecipeBase::cost(ElementCount VF, VPCostContext &Ctx) {
     UI = &WidenMem->getIngredient();
 
   InstructionCost RecipeCost;
-  if ((UI && Ctx.skipCostComputation(UI, VF.isVector())) ||
-      (Ctx.FoldedRecipes.contains(VF) &&
-       Ctx.FoldedRecipes.at(VF).contains(this))) {
+  if ((UI && Ctx.skipCostComputation(UI, VF.isVector()))) {
     RecipeCost = 0;
   } else {
     RecipeCost = computeCost(VF, Ctx);
@@ -2381,6 +2379,85 @@ VPExtendedReductionRecipe::computeCost(ElementCount VF,
   return ExtendedCost + ReductionCost;
 }
 
+InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
+                                            VPCostContext &Ctx) const {
+  Type *ElementTy =
+      IsExtended ? getResultType() : Ctx.Types.inferScalarType(getVecOp0());
+  auto *VectorTy = cast<VectorType>(ToVectorTy(ElementTy, VF));
+  TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+  unsigned Opcode = RdxDesc.getOpcode();
+
+  assert(ElementTy->getTypeID() == RdxDesc.getRecurrenceType()->getTypeID() &&
+         "Inferred type and recurrence type mismatch.");
+
+  // Unfused baseline, part 1: scalar BinOp cost + vector reduction cost.
+  InstructionCost ReductionCost =
+      Ctx.TTI.getArithmeticInstrCost(Opcode, ElementTy, CostKind);
+  ReductionCost += Ctx.TTI.getArithmeticReductionCost(
+      Opcode, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+
+  // Unfused baseline, part 2: the two folded extends (zero if none folded).
+  InstructionCost ExtendedCost = 0;
+  if (IsExtended) {
+    auto *SrcTy = cast<VectorType>(
+        ToVectorTy(Ctx.Types.inferScalarType(getVecOp0()), VF));
+    auto *DestTy = cast<VectorType>(ToVectorTy(getResultType(), VF));
+    TTI::CastContextHint CCH0 =
+        computeCCH(getVecOp0()->getDefiningRecipe(), VF);
+    // Arm TTI will use the underlying instruction to determine the cost.
+    ExtendedCost = Ctx.TTI.getCastInstrCost(
+        ExtOp, DestTy, SrcTy, CCH0, TTI::TCK_RecipThroughput,
+        dyn_cast_if_present<Instruction>(getExt0Instr()));
+    TTI::CastContextHint CCH1 =
+        computeCCH(getVecOp1()->getDefiningRecipe(), VF);
+    ExtendedCost += Ctx.TTI.getCastInstrCost(
+        ExtOp, DestTy, SrcTy, CCH1, TTI::TCK_RecipThroughput,
+        dyn_cast_if_present<Instruction>(getExt1Instr()));
+  }
+
+  // Unfused baseline, part 3: the widened mul.
+  InstructionCost MulCost;
+  SmallVector<const Value *, 4> Operands;
+  Operands.append(MulInstr->value_op_begin(), MulInstr->value_op_end());
+  if (IsExtended)
+    MulCost = Ctx.TTI.getArithmeticInstrCost(
+        Instruction::Mul, VectorTy, CostKind,
+        {TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None},
+        {TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None},
+        Operands, MulInstr, &Ctx.TLI);
+  else {
+    VPValue *RHS = getVecOp1();
+    // Certain instructions can be cheaper to vectorize if they have a constant
+    // second vector operand. One example of this are shifts on x86.
+    TargetTransformInfo::OperandValueInfo RHSInfo = {
+        TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None};
+    if (RHS->isLiveIn())
+      RHSInfo = Ctx.TTI.getOperandInfo(RHS->getLiveInIRValue());
+
+    if (RHSInfo.Kind == TargetTransformInfo::OK_AnyValue &&
+        RHS->isDefinedOutsideLoopRegions())
+      RHSInfo.Kind = TargetTransformInfo::OK_UniformValue;
+    MulCost = Ctx.TTI.getArithmeticInstrCost(
+        Instruction::Mul, VectorTy, CostKind,
+        {TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None},
+        RHSInfo, Operands, MulInstr, &Ctx.TLI);
+  }
+
+  // ExtendedReduction Cost
+  VectorType *SrcVecTy =
+      cast<VectorType>(ToVectorTy(Ctx.Types.inferScalarType(getVecOp0()), VF));
+  InstructionCost MulAccCost = Ctx.TTI.getMulAccReductionCost(
+      IsExtended && getExtOpcode() == Instruction::CastOps::ZExt, ElementTy,
+      SrcVecTy, CostKind);
+
+  // ExtOp is only set when extends were folded, hence the IsExtended guard
+  // above. Use the fused cost only if it is valid and beats the split cost.
+  InstructionCost SplitCost = ExtendedCost + ReductionCost + MulCost;
+  if (MulAccCost.isValid() && MulAccCost < SplitCost)
+    return MulAccCost;
+  return SplitCost;
+}
+
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
 void VPReductionRecipe::print(raw_ostream &O, const Twine &Indent,
                               VPSlotTracker &SlotTracker) const {
@@ -2448,6 +2525,37 @@ void VPExtendedReductionRecipe::print(raw_ostream &O, const Twine &Indent,
     O << " (with final reduction value stored in invariant address sank "
          "outside of loop)";
 }
+
+void VPMulAccRecipe::print(raw_ostream &O, const Twine &Indent,
+                           VPSlotTracker &SlotTracker) const {
+  O << Indent << "MULACC-REDUCE ";
+  printAsOperand(O, SlotTracker);
+  O << " = ";
+  getChainOp()->printAsOperand(O, SlotTracker);
+  O << " +";
+  if (isa<FPMathOperator>(getUnderlyingInstr()))
+    O << getUnderlyingInstr()->getFastMathFlags();
+  O << " reduce." << Instruction::getOpcodeName(RdxDesc.getOpcode()) << " (";
+  O << " mul ";
+  // A folded extend is shown as "(<op> extended to <ty>)".
+  O << (IsExtended ? "(" : "");
+  getVecOp0()->printAsOperand(O, SlotTracker);
+  if (IsExtended)
+    O << " extended to " << *getResultType() << ")";
+  // Separate the two multiplicands so they do not run together in the dump.
+  O << (IsExtended ? ", (" : ", ");
+  getVecOp1()->printAsOperand(O, SlotTracker);
+  if (IsExtended)
+    O << " extended to " << *getResultType() << ")";
+  if (isConditional()) {
+    O << ", ";
+    getCondOp()->printAsOperand(O, SlotTracker);
+  }
+  O << ")";
+  if (RdxDesc.IntermediateStore)
+    O << " (with final reduction value stored in invariant address sank "
+         "outside of loop)";
+}
 #endif
 
 bool VPReplicateRecipe::shouldPack() const {
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index d56a1f74014759..1f82228d4d787e 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -520,25 +520,53 @@ void VPlanTransforms::removeDeadRecipes(VPlan &Plan) {
 }
 
 void VPlanTransforms::prepareExecute(VPlan &Plan) {
-  errs() << "\n\n\n!!Prepare to execute\n";
   ReversePostOrderTraversal<VPBlockDeepTraversalWrapper<VPBlockBase *>> RPOT(
       Plan.getVectorLoopRegion());
   for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>(
            vp_depth_first_deep(Plan.getEntry()))) {
     for (VPRecipeBase &R : make_early_inc_range(*VPBB)) {
-      if (!isa<VPExtendedReductionRecipe>(&R))
-        continue;
-      auto *ExtRed = cast<VPExtendedReductionRecipe>(&R);
-      auto *Ext = new VPWidenCastRecipe(
-          ExtRed->getExtOpcode(), ExtRed->getVecOp(), ExtRed->getResultType(),
-          *ExtRed->getExtInstr());
-      auto *Red = new VPReductionRecipe(
-          ExtRed->getRecurrenceDescriptor(), ExtRed->getUnderlyingInstr(),
-          ExtRed->getChainOp(), Ext, ExtRed->getCondOp(), ExtRed->isOrdered());
-      Ext->insertBefore(ExtRed);
-      Red->insertBefore(ExtRed);
-      ExtRed->replaceAllUsesWith(Red);
-      ExtRed->eraseFromParent();
+      if (isa<VPExtendedReductionRecipe>(&R)) {
+        auto *ExtRed = cast<VPExtendedReductionRecipe>(&R);
+        auto *Ext = new VPWidenCastRecipe(
+            ExtRed->getExtOpcode(), ExtRed->getVecOp(), ExtRed->getResultType(),
+            *ExtRed->getExtInstr());
+        auto *Red = new VPReductionRecipe(
+            ExtRed->getRecurrenceDescriptor(), ExtRed->getUnderlyingInstr(),
+            ExtRed->getChainOp(), Ext, ExtRed->getCondOp(),
+            ExtRed->isOrdered());
+        Ext->insertBefore(ExtRed);
+        Red->insertBefore(ExtRed);
+        ExtRed->replaceAllUsesWith(Red);
+        ExtRed->eraseFromParent();
+      } else if (isa<VPMulAccRecipe>(&R)) {
+        auto *MulAcc = cast<VPMulAccRecipe>(&R);
+        VPValue *Op0, *Op1;
+        if (MulAcc->isExtended()) {
+          Op0 = new VPWidenCastRecipe(
+              MulAcc->getExtOpcode(), MulAcc->getVecOp0(),
+              MulAcc->getResultType(), *MulAcc->getExt0Instr());
+          Op1 = new VPWidenCastRecipe(
+              MulAcc->getExtOpcode(), MulAcc->getVecOp1(),
+              MulAcc->getResultType(), *MulAcc->getExt1Instr());
+          Op0->getDefiningRecipe()->insertBefore(MulAcc);
+          Op1->getDefiningRecipe()->insertBefore(MulAcc);
+        } else {
+          Op0 = MulAcc->getVecOp0();
+          Op1 = MulAcc->getVecOp1();
+        }
+        Instruction *MulInstr = MulAcc->getMulInstr();
+        SmallVector<VPValue *, 2> MulOps = {Op0, Op1};
+        auto *Mul = new VPWidenRecipe(*MulInstr,
+                                      make_range(MulOps.begin(), MulOps.end()));
+        auto *Red = new VPReductionRecipe(
+            MulAcc->getRecurrenceDescriptor(), MulAcc->getUnderlyingInstr(),
+            MulAcc->getChainOp(), Mul, MulAcc->getCondOp(),
+            MulAcc->isOrdered());
+        Mul->insertBefore(MulAcc);
+        Red->insertBefore(MulAcc);
+        MulAcc->replaceAllUsesWith(Red);
+        MulAcc->eraseFromParent();
+      }
     }
   }
 }

>From d29a1189194b9190ff6ae945587b34aff2ecfd49 Mon Sep 17 00:00:00 2001
From: Elvis Wang <elvis.wang at sifive.com>
Date: Tue, 5 Nov 2024 16:52:31 -0800
Subject: [PATCH 04/26] Fix several errors and update tests.

We need to update tests since the generated vector IR will be reordered.
---
 .../Transforms/Vectorize/LoopVectorize.cpp    | 45 ++++++++++++++
 llvm/lib/Transforms/Vectorize/VPlan.h         | 34 ++++++++---
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp |  6 +-
 .../Transforms/Vectorize/VPlanTransforms.cpp  | 14 ++++-
 .../LoopVectorize/ARM/mve-reduction-types.ll  |  4 +-
 .../LoopVectorize/ARM/mve-reductions.ll       | 61 ++++++++++---------
 .../LoopVectorize/RISCV/inloop-reduction.ll   | 22 +++++--
 .../LoopVectorize/reduction-inloop-pred.ll    | 34 +++++------
 .../LoopVectorize/reduction-inloop.ll         | 12 ++--
 9 files changed, 158 insertions(+), 74 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index e8c62fac5ee8c2..2bf3b63961791c 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7396,6 +7396,19 @@ static bool planContainsAdditionalSimplifications(VPlan &Plan,
       }
       if (Instruction *UI = GetInstructionForCost(&R))
         SeenInstrs.insert(UI);
+      // VPExtendedReductionRecipe contains a folded extend instruction.
+      if (auto *ExtendedRed = dyn_cast<VPExtendedReductionRecipe>(&R))
+        SeenInstrs.insert(ExtendedRed->getExtInstr());
+      // VPMulAccRecipe contains a mul and optional extend instructions.
+      else if (auto *MulAcc = dyn_cast<VPMulAccRecipe>(&R)) {
+        SeenInstrs.insert(MulAcc->getMulInstr());
+        if (MulAcc->isExtended()) {
+          SeenInstrs.insert(MulAcc->getExt0Instr());
+          SeenInstrs.insert(MulAcc->getExt1Instr());
+          if (auto *Ext = MulAcc->getExtInstr())
+            SeenInstrs.insert(Ext);
+        }
+      }
     }
   }
 
@@ -9409,6 +9422,38 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
               CM.useOrderedReductions(RdxDesc),
               cast<VPWidenRecipe>(VecOp->getDefiningRecipe()));
         }
+      } else if (RdxDesc.getOpcode() == Instruction::Add &&
+                 match(VecOp,
+                       m_ZExtOrSExt(m_Mul(m_ZExtOrSExt(m_VPValue(A)),
+                                          m_ZExtOrSExt(m_VPValue(B)))))) {
+        VPWidenCastRecipe *Ext =
+            dyn_cast<VPWidenCastRecipe>(VecOp->getDefiningRecipe());
+        VPWidenRecipe *Mul =
+            dyn_cast<VPWidenRecipe>(Ext->getOperand(0)->getDefiningRecipe());
+        if (Mul && match(Mul, m_Mul(m_ZExtOrSExt(m_VPValue()),
+                                    m_ZExtOrSExt(m_VPValue())))) {
+          VPWidenRecipe *Mul =
+              cast<VPWidenRecipe>(Ext->getOperand(0)->getDefiningRecipe());
+          VPWidenCastRecipe *Ext0 =
+              cast<VPWidenCastRecipe>(Mul->getOperand(0)->getDefiningRecipe());
+          VPWidenCastRecipe *Ext1 =
+              cast<VPWidenCastRecipe>(Mul->getOperand(1)->getDefiningRecipe());
+          if (Ext->getOpcode() == Ext0->getOpcode() &&
+              Ext0->getOpcode() == Ext1->getOpcode()) {
+            RedRecipe = new VPMulAccRecipe(
+                RdxDesc, CurrentLinkI, PreviousLink, CondOp,
+                CM.useOrderedReductions(RdxDesc),
+                cast<VPWidenCastRecipe>(VecOp->getDefiningRecipe()), Mul,
+                cast<VPWidenCastRecipe>(Ext0), cast<VPWidenCastRecipe>(Ext1));
+          } else
+            RedRecipe = new VPExtendedReductionRecipe(
+                RdxDesc, CurrentLinkI,
+                cast<CastInst>(
+                    cast<VPWidenCastRecipe>(VecOp)->getUnderlyingInstr()),
+                PreviousLink, cast<VPWidenCastRecipe>(VecOp)->getOperand(0),
+                CondOp, CM.useOrderedReductions(RdxDesc),
+                cast<VPWidenCastRecipe>(VecOp)->getResultType());
+        }
       }
       // VPWidenCastRecipes can folded into VPReductionRecipe
       else if (match(VecOp, m_ZExtOrSExt(m_VPValue(A))) &&
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 19d98674db9963..c33e6081669f8b 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -2812,23 +2812,27 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
   bool IsConditional = false;
   /// Type after extend.
   Type *ResultType;
-  /// reduce.add(mul(Ext(), Ext()))
+  /// reduce.add(ext(mul(Ext(), Ext())))
   Instruction::CastOps ExtOp;
 
   Instruction *MulInstr;
+  CastInst *ExtInstr = nullptr;
   CastInst *Ext0Instr;
   CastInst *Ext1Instr;
 
   bool IsExtended;
+  bool IsOuterExtended = false;
 
 protected:
   VPMulAccRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
-                 Instruction *RedI, Instruction *MulInstr,
-                 Instruction::CastOps ExtOp, Instruction *Ext0Instr,
-                 Instruction *Ext1Instr, ArrayRef<VPValue *> Operands,
-                 VPValue *CondOp, bool IsOrdered, Type *ResultType)
+                 Instruction *RedI, Instruction *ExtInstr,
+                 Instruction *MulInstr, Instruction::CastOps ExtOp,
+                 Instruction *Ext0Instr, Instruction *Ext1Instr,
+                 ArrayRef<VPValue *> Operands, VPValue *CondOp, bool IsOrdered,
+                 Type *ResultType)
       : VPSingleDefRecipe(SC, Operands, RedI), RdxDesc(R), IsOrdered(IsOrdered),
         ResultType(ResultType), ExtOp(ExtOp), MulInstr(MulInstr),
+        ExtInstr(cast_if_present<CastInst>(ExtInstr)),
         Ext0Instr(cast<CastInst>(Ext0Instr)),
         Ext1Instr(cast<CastInst>(Ext1Instr)) {
     if (CondOp) {
@@ -2855,9 +2859,9 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
                  VPValue *ChainOp, VPValue *CondOp, bool IsOrdered,
                  VPWidenRecipe *Mul, VPWidenCastRecipe *Ext0,
                  VPWidenCastRecipe *Ext1)
-      : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, Mul->getUnderlyingInstr(),
-                       Ext0->getOpcode(), Ext0->getUnderlyingInstr(),
-                       Ext1->getUnderlyingInstr(),
+      : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, nullptr,
+                       Mul->getUnderlyingInstr(), Ext0->getOpcode(),
+                       Ext0->getUnderlyingInstr(), Ext1->getUnderlyingInstr(),
                        ArrayRef<VPValue *>(
                            {ChainOp, Ext0->getOperand(0), Ext1->getOperand(0)}),
                        CondOp, IsOrdered, Ext0->getResultType()) {}
@@ -2867,9 +2871,20 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
                  VPWidenRecipe *Mul)
       : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, Mul->getUnderlyingInstr(),
                        ArrayRef<VPValue *>(
-                           {ChainOp, Mul->getOperand(0), Mul->getOperand(0)}),
+                           {ChainOp, Mul->getOperand(0), Mul->getOperand(1)}),
                        CondOp, IsOrdered) {}
 
+  VPMulAccRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
+                 VPValue *ChainOp, VPValue *CondOp, bool IsOrdered,
+                 VPWidenCastRecipe *Ext, VPWidenRecipe *Mul,
+                 VPWidenCastRecipe *Ext0, VPWidenCastRecipe *Ext1)
+      : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, Ext->getUnderlyingInstr(),
+                       Mul->getUnderlyingInstr(), Ext0->getOpcode(),
+                       Ext0->getUnderlyingInstr(), Ext1->getUnderlyingInstr(),
+                       ArrayRef<VPValue *>(
+                           {ChainOp, Ext0->getOperand(0), Ext1->getOperand(0)}),
+                       CondOp, IsOrdered, Ext0->getResultType()) {}
+
   ~VPMulAccRecipe() override = default;
 
   VPMulAccRecipe *clone() override { llvm_unreachable("Not implement yet"); }
@@ -2919,6 +2934,7 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
   Type *getResultType() const { return ResultType; };
   Instruction::CastOps getExtOpcode() const { return ExtOp; };
   Instruction *getMulInstr() const { return MulInstr; };
+  CastInst *getExtInstr() const { return ExtInstr; };
   CastInst *getExt0Instr() const { return Ext0Instr; };
   CastInst *getExt1Instr() const { return Ext1Instr; };
   bool isExtended() const { return IsExtended; };
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index a7d74a78e06317..ff1e31033fd9cf 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -2381,8 +2381,8 @@ VPExtendedReductionRecipe::computeCost(ElementCount VF,
 
 InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
                                             VPCostContext &Ctx) const {
-  Type *ElementTy =
-      IsExtended ? getResultType() : Ctx.Types.inferScalarType(getVecOp0());
+  Type *ElementTy = IsExtended ? RdxDesc.getRecurrenceType()
+                               : Ctx.Types.inferScalarType(getVecOp0());
   auto *VectorTy = cast<VectorType>(ToVectorTy(ElementTy, VF));
   TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
   unsigned Opcode = RdxDesc.getOpcode();
@@ -2443,7 +2443,7 @@ InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
         RHSInfo, Operands, MulInstr, &Ctx.TLI);
   }
 
-  // ExtendedReduction Cost
+  // MulAccReduction Cost
   VectorType *SrcVecTy =
       cast<VectorType>(ToVectorTy(Ctx.Types.inferScalarType(getVecOp0()), VF));
   InstructionCost MulAccCost = Ctx.TTI.getMulAccReductionCost(
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index 1f82228d4d787e..c28953e7314f57 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -554,15 +554,27 @@ void VPlanTransforms::prepareExecute(VPlan &Plan) {
           Op0 = MulAcc->getVecOp0();
           Op1 = MulAcc->getVecOp1();
         }
+        VPSingleDefRecipe *VecOp;
         Instruction *MulInstr = MulAcc->getMulInstr();
         SmallVector<VPValue *, 2> MulOps = {Op0, Op1};
         auto *Mul = new VPWidenRecipe(*MulInstr,
                                       make_range(MulOps.begin(), MulOps.end()));
+        if (auto *OuterExtInstr = MulAcc->getExtInstr()) {
+          // Re-create the outer extend on the widened multiply, casting its
+          // result to the reduction's recurrence type.
+          VecOp = new VPWidenCastRecipe(
+              MulAcc->getExtOpcode(), Mul,
+              MulAcc->getRecurrenceDescriptor().getRecurrenceType(),
+              *OuterExtInstr);
+        } else
+          VecOp = Mul;
         auto *Red = new VPReductionRecipe(
             MulAcc->getRecurrenceDescriptor(), MulAcc->getUnderlyingInstr(),
-            MulAcc->getChainOp(), Mul, MulAcc->getCondOp(),
+            MulAcc->getChainOp(), VecOp, MulAcc->getCondOp(),
             MulAcc->isOrdered());
         Mul->insertBefore(MulAcc);
+        if (VecOp != Mul)
+          VecOp->insertBefore(MulAcc);
         Red->insertBefore(MulAcc);
         MulAcc->replaceAllUsesWith(Red);
         MulAcc->eraseFromParent();
diff --git a/llvm/test/Transforms/LoopVectorize/ARM/mve-reduction-types.ll b/llvm/test/Transforms/LoopVectorize/ARM/mve-reduction-types.ll
index 6f78982d7ab029..0067b5ff5243cb 100644
--- a/llvm/test/Transforms/LoopVectorize/ARM/mve-reduction-types.ll
+++ b/llvm/test/Transforms/LoopVectorize/ARM/mve-reduction-types.ll
@@ -24,11 +24,11 @@ define i32 @mla_i32(ptr noalias nocapture readonly %A, ptr noalias nocapture rea
 ; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds i8, ptr [[A:%.*]], i32 [[TMP0]]
 ; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr inbounds i8, ptr [[TMP1]], i32 0
 ; CHECK-NEXT:    [[WIDE_MASKED_LOAD:%.*]] = call <16 x i8> @llvm.masked.load.v16i8.p0(ptr [[TMP2]], i32 1, <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i8> poison)
-; CHECK-NEXT:    [[TMP3:%.*]] = sext <16 x i8> [[WIDE_MASKED_LOAD]] to <16 x i32>
 ; CHECK-NEXT:    [[TMP4:%.*]] = getelementptr inbounds i8, ptr [[B:%.*]], i32 [[TMP0]]
 ; CHECK-NEXT:    [[TMP5:%.*]] = getelementptr inbounds i8, ptr [[TMP4]], i32 0
 ; CHECK-NEXT:    [[WIDE_MASKED_LOAD1:%.*]] = call <16 x i8> @llvm.masked.load.v16i8.p0(ptr [[TMP5]], i32 1, <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i8> poison)
 ; CHECK-NEXT:    [[TMP6:%.*]] = sext <16 x i8> [[WIDE_MASKED_LOAD1]] to <16 x i32>
+; CHECK-NEXT:    [[TMP3:%.*]] = sext <16 x i8> [[WIDE_MASKED_LOAD]] to <16 x i32>
 ; CHECK-NEXT:    [[TMP7:%.*]] = mul nsw <16 x i32> [[TMP6]], [[TMP3]]
 ; CHECK-NEXT:    [[TMP8:%.*]] = select <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i32> [[TMP7]], <16 x i32> zeroinitializer
 ; CHECK-NEXT:    [[TMP9:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP8]])
@@ -107,11 +107,11 @@ define i32 @mla_i8(ptr noalias nocapture readonly %A, ptr noalias nocapture read
 ; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds i8, ptr [[A:%.*]], i32 [[TMP0]]
 ; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr inbounds i8, ptr [[TMP1]], i32 0
 ; CHECK-NEXT:    [[WIDE_MASKED_LOAD:%.*]] = call <16 x i8> @llvm.masked.load.v16i8.p0(ptr [[TMP2]], i32 1, <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i8> poison)
-; CHECK-NEXT:    [[TMP3:%.*]] = sext <16 x i8> [[WIDE_MASKED_LOAD]] to <16 x i32>
 ; CHECK-NEXT:    [[TMP4:%.*]] = getelementptr inbounds i8, ptr [[B:%.*]], i32 [[TMP0]]
 ; CHECK-NEXT:    [[TMP5:%.*]] = getelementptr inbounds i8, ptr [[TMP4]], i32 0
 ; CHECK-NEXT:    [[WIDE_MASKED_LOAD1:%.*]] = call <16 x i8> @llvm.masked.load.v16i8.p0(ptr [[TMP5]], i32 1, <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i8> poison)
 ; CHECK-NEXT:    [[TMP6:%.*]] = sext <16 x i8> [[WIDE_MASKED_LOAD1]] to <16 x i32>
+; CHECK-NEXT:    [[TMP3:%.*]] = sext <16 x i8> [[WIDE_MASKED_LOAD]] to <16 x i32>
 ; CHECK-NEXT:    [[TMP7:%.*]] = mul nsw <16 x i32> [[TMP6]], [[TMP3]]
 ; CHECK-NEXT:    [[TMP8:%.*]] = select <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i32> [[TMP7]], <16 x i32> zeroinitializer
 ; CHECK-NEXT:    [[TMP9:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP8]])
diff --git a/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll b/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
index baad7a84a891a6..ffdd8e66983d85 100644
--- a/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
+++ b/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
@@ -646,11 +646,11 @@ define i64 @mla_i16_i64(ptr nocapture readonly %x, ptr nocapture readonly %y, i3
 ; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[TMP7:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds i16, ptr [[X:%.*]], i32 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <8 x i16>, ptr [[TMP0]], align 2
-; CHECK-NEXT:    [[TMP1:%.*]] = sext <8 x i16> [[WIDE_LOAD]] to <8 x i32>
 ; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr inbounds i16, ptr [[Y:%.*]], i32 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_LOAD1:%.*]] = load <8 x i16>, ptr [[TMP2]], align 2
-; CHECK-NEXT:    [[TMP3:%.*]] = sext <8 x i16> [[WIDE_LOAD1]] to <8 x i32>
-; CHECK-NEXT:    [[TMP4:%.*]] = mul nsw <8 x i32> [[TMP3]], [[TMP1]]
+; CHECK-NEXT:    [[TMP11:%.*]] = sext <8 x i16> [[WIDE_LOAD1]] to <8 x i32>
+; CHECK-NEXT:    [[TMP3:%.*]] = sext <8 x i16> [[WIDE_LOAD]] to <8 x i32>
+; CHECK-NEXT:    [[TMP4:%.*]] = mul nsw <8 x i32> [[TMP11]], [[TMP3]]
 ; CHECK-NEXT:    [[TMP5:%.*]] = sext <8 x i32> [[TMP4]] to <8 x i64>
 ; CHECK-NEXT:    [[TMP6:%.*]] = call i64 @llvm.vector.reduce.add.v8i64(<8 x i64> [[TMP5]])
 ; CHECK-NEXT:    [[TMP7]] = add i64 [[TMP6]], [[VEC_PHI]]
@@ -726,11 +726,11 @@ define i64 @mla_i8_i64(ptr nocapture readonly %x, ptr nocapture readonly %y, i32
 ; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[TMP7:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds i8, ptr [[X:%.*]], i32 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <8 x i8>, ptr [[TMP0]], align 1
-; CHECK-NEXT:    [[TMP1:%.*]] = zext <8 x i8> [[WIDE_LOAD]] to <8 x i32>
 ; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr inbounds i8, ptr [[Y:%.*]], i32 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_LOAD1:%.*]] = load <8 x i8>, ptr [[TMP2]], align 1
-; CHECK-NEXT:    [[TMP3:%.*]] = zext <8 x i8> [[WIDE_LOAD1]] to <8 x i32>
-; CHECK-NEXT:    [[TMP4:%.*]] = mul nuw nsw <8 x i32> [[TMP3]], [[TMP1]]
+; CHECK-NEXT:    [[TMP11:%.*]] = zext <8 x i8> [[WIDE_LOAD1]] to <8 x i32>
+; CHECK-NEXT:    [[TMP3:%.*]] = zext <8 x i8> [[WIDE_LOAD]] to <8 x i32>
+; CHECK-NEXT:    [[TMP4:%.*]] = mul nuw nsw <8 x i32> [[TMP11]], [[TMP3]]
 ; CHECK-NEXT:    [[TMP5:%.*]] = zext nneg <8 x i32> [[TMP4]] to <8 x i64>
 ; CHECK-NEXT:    [[TMP6:%.*]] = call i64 @llvm.vector.reduce.add.v8i64(<8 x i64> [[TMP5]])
 ; CHECK-NEXT:    [[TMP7]] = add i64 [[TMP6]], [[VEC_PHI]]
@@ -802,11 +802,11 @@ define i32 @mla_i32_i32(ptr nocapture readonly %x, ptr nocapture readonly %y, i3
 ; CHECK-NEXT:    [[INDEX:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[TMP5:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[ACTIVE_LANE_MASK:%.*]] = call <4 x i1> @llvm.get.active.lane.mask.v4i1.i32(i32 [[INDEX]], i32 [[N]])
-; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds i32, ptr [[X:%.*]], i32 [[INDEX]]
-; CHECK-NEXT:    [[WIDE_MASKED_LOAD:%.*]] = call <4 x i32> @llvm.masked.load.v4i32.p0(ptr [[TMP0]], i32 4, <4 x i1> [[ACTIVE_LANE_MASK]], <4 x i32> poison)
 ; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds i32, ptr [[Y:%.*]], i32 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_MASKED_LOAD1:%.*]] = call <4 x i32> @llvm.masked.load.v4i32.p0(ptr [[TMP1]], i32 4, <4 x i1> [[ACTIVE_LANE_MASK]], <4 x i32> poison)
-; CHECK-NEXT:    [[TMP2:%.*]] = mul nsw <4 x i32> [[WIDE_MASKED_LOAD1]], [[WIDE_MASKED_LOAD]]
+; CHECK-NEXT:    [[TMP7:%.*]] = getelementptr inbounds i32, ptr [[Y1:%.*]], i32 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_MASKED_LOAD2:%.*]] = call <4 x i32> @llvm.masked.load.v4i32.p0(ptr [[TMP7]], i32 4, <4 x i1> [[ACTIVE_LANE_MASK]], <4 x i32> poison)
+; CHECK-NEXT:    [[TMP2:%.*]] = mul nsw <4 x i32> [[WIDE_MASKED_LOAD2]], [[WIDE_MASKED_LOAD1]]
 ; CHECK-NEXT:    [[TMP3:%.*]] = select <4 x i1> [[ACTIVE_LANE_MASK]], <4 x i32> [[TMP2]], <4 x i32> zeroinitializer
 ; CHECK-NEXT:    [[TMP4:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP3]])
 ; CHECK-NEXT:    [[TMP5]] = add i32 [[TMP4]], [[VEC_PHI]]
@@ -855,10 +855,10 @@ define i32 @mla_i16_i32(ptr nocapture readonly %x, ptr nocapture readonly %y, i3
 ; CHECK-NEXT:    [[ACTIVE_LANE_MASK:%.*]] = call <8 x i1> @llvm.get.active.lane.mask.v8i1.i32(i32 [[INDEX]], i32 [[N]])
 ; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds i16, ptr [[X:%.*]], i32 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_MASKED_LOAD:%.*]] = call <8 x i16> @llvm.masked.load.v8i16.p0(ptr [[TMP0]], i32 2, <8 x i1> [[ACTIVE_LANE_MASK]], <8 x i16> poison)
-; CHECK-NEXT:    [[TMP1:%.*]] = sext <8 x i16> [[WIDE_MASKED_LOAD]] to <8 x i32>
 ; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr inbounds i16, ptr [[Y:%.*]], i32 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_MASKED_LOAD1:%.*]] = call <8 x i16> @llvm.masked.load.v8i16.p0(ptr [[TMP2]], i32 2, <8 x i1> [[ACTIVE_LANE_MASK]], <8 x i16> poison)
 ; CHECK-NEXT:    [[TMP3:%.*]] = sext <8 x i16> [[WIDE_MASKED_LOAD1]] to <8 x i32>
+; CHECK-NEXT:    [[TMP1:%.*]] = sext <8 x i16> [[WIDE_MASKED_LOAD]] to <8 x i32>
 ; CHECK-NEXT:    [[TMP4:%.*]] = mul nsw <8 x i32> [[TMP3]], [[TMP1]]
 ; CHECK-NEXT:    [[TMP5:%.*]] = select <8 x i1> [[ACTIVE_LANE_MASK]], <8 x i32> [[TMP4]], <8 x i32> zeroinitializer
 ; CHECK-NEXT:    [[TMP6:%.*]] = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> [[TMP5]])
@@ -910,10 +910,10 @@ define i32 @mla_i8_i32(ptr nocapture readonly %x, ptr nocapture readonly %y, i32
 ; CHECK-NEXT:    [[ACTIVE_LANE_MASK:%.*]] = call <16 x i1> @llvm.get.active.lane.mask.v16i1.i32(i32 [[INDEX]], i32 [[N]])
 ; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds i8, ptr [[X:%.*]], i32 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_MASKED_LOAD:%.*]] = call <16 x i8> @llvm.masked.load.v16i8.p0(ptr [[TMP0]], i32 1, <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i8> poison)
-; CHECK-NEXT:    [[TMP1:%.*]] = zext <16 x i8> [[WIDE_MASKED_LOAD]] to <16 x i32>
 ; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr inbounds i8, ptr [[Y:%.*]], i32 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_MASKED_LOAD1:%.*]] = call <16 x i8> @llvm.masked.load.v16i8.p0(ptr [[TMP2]], i32 1, <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i8> poison)
 ; CHECK-NEXT:    [[TMP3:%.*]] = zext <16 x i8> [[WIDE_MASKED_LOAD1]] to <16 x i32>
+; CHECK-NEXT:    [[TMP1:%.*]] = zext <16 x i8> [[WIDE_MASKED_LOAD]] to <16 x i32>
 ; CHECK-NEXT:    [[TMP4:%.*]] = mul nuw nsw <16 x i32> [[TMP3]], [[TMP1]]
 ; CHECK-NEXT:    [[TMP5:%.*]] = select <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i32> [[TMP4]], <16 x i32> zeroinitializer
 ; CHECK-NEXT:    [[TMP6:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP5]])
@@ -963,11 +963,11 @@ define signext i16 @mla_i16_i16(ptr nocapture readonly %x, ptr nocapture readonl
 ; CHECK-NEXT:    [[INDEX:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi i16 [ 0, [[VECTOR_PH]] ], [ [[TMP5:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[ACTIVE_LANE_MASK:%.*]] = call <8 x i1> @llvm.get.active.lane.mask.v8i1.i32(i32 [[INDEX]], i32 [[N]])
-; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds i16, ptr [[X:%.*]], i32 [[INDEX]]
-; CHECK-NEXT:    [[WIDE_MASKED_LOAD:%.*]] = call <8 x i16> @llvm.masked.load.v8i16.p0(ptr [[TMP0]], i32 2, <8 x i1> [[ACTIVE_LANE_MASK]], <8 x i16> poison)
 ; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds i16, ptr [[Y:%.*]], i32 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_MASKED_LOAD1:%.*]] = call <8 x i16> @llvm.masked.load.v8i16.p0(ptr [[TMP1]], i32 2, <8 x i1> [[ACTIVE_LANE_MASK]], <8 x i16> poison)
-; CHECK-NEXT:    [[TMP2:%.*]] = mul <8 x i16> [[WIDE_MASKED_LOAD1]], [[WIDE_MASKED_LOAD]]
+; CHECK-NEXT:    [[TMP7:%.*]] = getelementptr inbounds i16, ptr [[Y1:%.*]], i32 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_MASKED_LOAD2:%.*]] = call <8 x i16> @llvm.masked.load.v8i16.p0(ptr [[TMP7]], i32 2, <8 x i1> [[ACTIVE_LANE_MASK]], <8 x i16> poison)
+; CHECK-NEXT:    [[TMP2:%.*]] = mul <8 x i16> [[WIDE_MASKED_LOAD2]], [[WIDE_MASKED_LOAD1]]
 ; CHECK-NEXT:    [[TMP3:%.*]] = select <8 x i1> [[ACTIVE_LANE_MASK]], <8 x i16> [[TMP2]], <8 x i16> zeroinitializer
 ; CHECK-NEXT:    [[TMP4:%.*]] = call i16 @llvm.vector.reduce.add.v8i16(<8 x i16> [[TMP3]])
 ; CHECK-NEXT:    [[TMP5]] = add i16 [[TMP4]], [[VEC_PHI]]
@@ -1016,10 +1016,10 @@ define signext i16 @mla_i8_i16(ptr nocapture readonly %x, ptr nocapture readonly
 ; CHECK-NEXT:    [[ACTIVE_LANE_MASK:%.*]] = call <16 x i1> @llvm.get.active.lane.mask.v16i1.i32(i32 [[INDEX]], i32 [[N]])
 ; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds i8, ptr [[X:%.*]], i32 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_MASKED_LOAD:%.*]] = call <16 x i8> @llvm.masked.load.v16i8.p0(ptr [[TMP0]], i32 1, <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i8> poison)
-; CHECK-NEXT:    [[TMP1:%.*]] = zext <16 x i8> [[WIDE_MASKED_LOAD]] to <16 x i16>
 ; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr inbounds i8, ptr [[Y:%.*]], i32 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_MASKED_LOAD1:%.*]] = call <16 x i8> @llvm.masked.load.v16i8.p0(ptr [[TMP2]], i32 1, <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i8> poison)
 ; CHECK-NEXT:    [[TMP3:%.*]] = zext <16 x i8> [[WIDE_MASKED_LOAD1]] to <16 x i16>
+; CHECK-NEXT:    [[TMP1:%.*]] = zext <16 x i8> [[WIDE_MASKED_LOAD]] to <16 x i16>
 ; CHECK-NEXT:    [[TMP4:%.*]] = mul nuw <16 x i16> [[TMP3]], [[TMP1]]
 ; CHECK-NEXT:    [[TMP5:%.*]] = select <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i16> [[TMP4]], <16 x i16> zeroinitializer
 ; CHECK-NEXT:    [[TMP6:%.*]] = call i16 @llvm.vector.reduce.add.v16i16(<16 x i16> [[TMP5]])
@@ -1069,11 +1069,11 @@ define zeroext i8 @mla_i8_i8(ptr nocapture readonly %x, ptr nocapture readonly %
 ; CHECK-NEXT:    [[INDEX:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi i8 [ 0, [[VECTOR_PH]] ], [ [[TMP5:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[ACTIVE_LANE_MASK:%.*]] = call <16 x i1> @llvm.get.active.lane.mask.v16i1.i32(i32 [[INDEX]], i32 [[N]])
-; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds i8, ptr [[X:%.*]], i32 [[INDEX]]
-; CHECK-NEXT:    [[WIDE_MASKED_LOAD:%.*]] = call <16 x i8> @llvm.masked.load.v16i8.p0(ptr [[TMP0]], i32 1, <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i8> poison)
 ; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds i8, ptr [[Y:%.*]], i32 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_MASKED_LOAD1:%.*]] = call <16 x i8> @llvm.masked.load.v16i8.p0(ptr [[TMP1]], i32 1, <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i8> poison)
-; CHECK-NEXT:    [[TMP2:%.*]] = mul <16 x i8> [[WIDE_MASKED_LOAD1]], [[WIDE_MASKED_LOAD]]
+; CHECK-NEXT:    [[TMP7:%.*]] = getelementptr inbounds i8, ptr [[Y1:%.*]], i32 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_MASKED_LOAD2:%.*]] = call <16 x i8> @llvm.masked.load.v16i8.p0(ptr [[TMP7]], i32 1, <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i8> poison)
+; CHECK-NEXT:    [[TMP2:%.*]] = mul <16 x i8> [[WIDE_MASKED_LOAD2]], [[WIDE_MASKED_LOAD1]]
 ; CHECK-NEXT:    [[TMP3:%.*]] = select <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i8> [[TMP2]], <16 x i8> zeroinitializer
 ; CHECK-NEXT:    [[TMP4:%.*]] = call i8 @llvm.vector.reduce.add.v16i8(<16 x i8> [[TMP3]])
 ; CHECK-NEXT:    [[TMP5]] = add i8 [[TMP4]], [[VEC_PHI]]
@@ -1122,10 +1122,10 @@ define i32 @red_mla_ext_s8_s16_s32(ptr noalias nocapture readonly %A, ptr noalia
 ; CHECK-NEXT:    [[ACTIVE_LANE_MASK:%.*]] = call <8 x i1> @llvm.get.active.lane.mask.v8i1.i32(i32 [[INDEX]], i32 [[N]])
 ; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds i8, ptr [[A:%.*]], i32 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_MASKED_LOAD:%.*]] = call <8 x i8> @llvm.masked.load.v8i8.p0(ptr [[TMP0]], i32 1, <8 x i1> [[ACTIVE_LANE_MASK]], <8 x i8> poison)
-; CHECK-NEXT:    [[TMP1:%.*]] = sext <8 x i8> [[WIDE_MASKED_LOAD]] to <8 x i32>
 ; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr inbounds i16, ptr [[B:%.*]], i32 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_MASKED_LOAD1:%.*]] = call <8 x i16> @llvm.masked.load.v8i16.p0(ptr [[TMP2]], i32 2, <8 x i1> [[ACTIVE_LANE_MASK]], <8 x i16> poison)
 ; CHECK-NEXT:    [[TMP3:%.*]] = sext <8 x i16> [[WIDE_MASKED_LOAD1]] to <8 x i32>
+; CHECK-NEXT:    [[TMP1:%.*]] = sext <8 x i8> [[WIDE_MASKED_LOAD]] to <8 x i32>
 ; CHECK-NEXT:    [[TMP4:%.*]] = mul nsw <8 x i32> [[TMP3]], [[TMP1]]
 ; CHECK-NEXT:    [[TMP5:%.*]] = select <8 x i1> [[ACTIVE_LANE_MASK]], <8 x i32> [[TMP4]], <8 x i32> zeroinitializer
 ; CHECK-NEXT:    [[TMP6:%.*]] = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> [[TMP5]])
@@ -1183,11 +1183,11 @@ define i64 @red_mla_ext_s16_u16_s64(ptr noalias nocapture readonly %A, ptr noali
 ; CHECK:       vector.body:
 ; CHECK-NEXT:    [[INDEX:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[TMP7:%.*]], [[VECTOR_BODY]] ]
-; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds i16, ptr [[A:%.*]], i32 [[INDEX]]
-; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <4 x i16>, ptr [[TMP0]], align 1
-; CHECK-NEXT:    [[TMP1:%.*]] = sext <4 x i16> [[WIDE_LOAD]] to <4 x i32>
 ; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr inbounds i16, ptr [[B:%.*]], i32 [[INDEX]]
-; CHECK-NEXT:    [[WIDE_LOAD1:%.*]] = load <4 x i16>, ptr [[TMP2]], align 2
+; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <4 x i16>, ptr [[TMP2]], align 1
+; CHECK-NEXT:    [[TMP1:%.*]] = sext <4 x i16> [[WIDE_LOAD]] to <4 x i32>
+; CHECK-NEXT:    [[TMP11:%.*]] = getelementptr inbounds i16, ptr [[B1:%.*]], i32 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_LOAD1:%.*]] = load <4 x i16>, ptr [[TMP11]], align 2
 ; CHECK-NEXT:    [[TMP3:%.*]] = zext <4 x i16> [[WIDE_LOAD1]] to <4 x i32>
 ; CHECK-NEXT:    [[TMP4:%.*]] = mul nsw <4 x i32> [[TMP3]], [[TMP1]]
 ; CHECK-NEXT:    [[TMP5:%.*]] = zext <4 x i32> [[TMP4]] to <4 x i64>
@@ -1206,10 +1206,10 @@ define i64 @red_mla_ext_s16_u16_s64(ptr noalias nocapture readonly %A, ptr noali
 ; CHECK:       for.body:
 ; CHECK-NEXT:    [[I_011:%.*]] = phi i32 [ [[INC:%.*]], [[FOR_BODY]] ], [ [[BC_RESUME_VAL]], [[SCALAR_PH]] ]
 ; CHECK-NEXT:    [[S_010:%.*]] = phi i64 [ [[ADD:%.*]], [[FOR_BODY]] ], [ [[BC_MERGE_RDX]], [[SCALAR_PH]] ]
-; CHECK-NEXT:    [[ARRAYIDX:%.*]] = getelementptr inbounds i16, ptr [[A]], i32 [[I_011]]
+; CHECK-NEXT:    [[ARRAYIDX:%.*]] = getelementptr inbounds i16, ptr [[B]], i32 [[I_011]]
 ; CHECK-NEXT:    [[TMP9:%.*]] = load i16, ptr [[ARRAYIDX]], align 1
 ; CHECK-NEXT:    [[CONV:%.*]] = sext i16 [[TMP9]] to i32
-; CHECK-NEXT:    [[ARRAYIDX1:%.*]] = getelementptr inbounds i16, ptr [[B]], i32 [[I_011]]
+; CHECK-NEXT:    [[ARRAYIDX1:%.*]] = getelementptr inbounds i16, ptr [[B1]], i32 [[I_011]]
 ; CHECK-NEXT:    [[TMP10:%.*]] = load i16, ptr [[ARRAYIDX1]], align 2
 ; CHECK-NEXT:    [[CONV2:%.*]] = zext i16 [[TMP10]] to i32
 ; CHECK-NEXT:    [[MUL:%.*]] = mul nsw i32 [[CONV2]], [[CONV]]
@@ -1268,12 +1268,12 @@ define i32 @red_mla_u8_s8_u32(ptr noalias nocapture readonly %A, ptr noalias noc
 ; CHECK-NEXT:    [[INDEX:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[TMP7:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[ACTIVE_LANE_MASK:%.*]] = call <4 x i1> @llvm.get.active.lane.mask.v4i1.i32(i32 [[INDEX]], i32 [[N]])
-; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds i8, ptr [[A:%.*]], i32 [[INDEX]]
-; CHECK-NEXT:    [[WIDE_MASKED_LOAD:%.*]] = call <4 x i8> @llvm.masked.load.v4i8.p0(ptr [[TMP0]], i32 1, <4 x i1> [[ACTIVE_LANE_MASK]], <4 x i8> poison)
-; CHECK-NEXT:    [[TMP1:%.*]] = zext <4 x i8> [[WIDE_MASKED_LOAD]] to <4 x i32>
 ; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr inbounds i8, ptr [[B:%.*]], i32 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_MASKED_LOAD1:%.*]] = call <4 x i8> @llvm.masked.load.v4i8.p0(ptr [[TMP2]], i32 1, <4 x i1> [[ACTIVE_LANE_MASK]], <4 x i8> poison)
-; CHECK-NEXT:    [[TMP3:%.*]] = sext <4 x i8> [[WIDE_MASKED_LOAD1]] to <4 x i32>
+; CHECK-NEXT:    [[TMP1:%.*]] = zext <4 x i8> [[WIDE_MASKED_LOAD1]] to <4 x i32>
+; CHECK-NEXT:    [[TMP9:%.*]] = getelementptr inbounds i8, ptr [[B1:%.*]], i32 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_MASKED_LOAD2:%.*]] = call <4 x i8> @llvm.masked.load.v4i8.p0(ptr [[TMP9]], i32 1, <4 x i1> [[ACTIVE_LANE_MASK]], <4 x i8> poison)
+; CHECK-NEXT:    [[TMP3:%.*]] = sext <4 x i8> [[WIDE_MASKED_LOAD2]] to <4 x i32>
 ; CHECK-NEXT:    [[TMP4:%.*]] = mul nsw <4 x i32> [[TMP3]], [[TMP1]]
 ; CHECK-NEXT:    [[TMP5:%.*]] = select <4 x i1> [[ACTIVE_LANE_MASK]], <4 x i32> [[TMP4]], <4 x i32> zeroinitializer
 ; CHECK-NEXT:    [[TMP6:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP5]])
@@ -1413,7 +1413,8 @@ define i32 @mla_i8_i32_multiuse(ptr nocapture readonly %x, ptr nocapture readonl
 ; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds i8, ptr [[X:%.*]], i32 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_MASKED_LOAD:%.*]] = call <16 x i8> @llvm.masked.load.v16i8.p0(ptr [[TMP0]], i32 1, <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i8> poison)
 ; CHECK-NEXT:    [[TMP1:%.*]] = zext <16 x i8> [[WIDE_MASKED_LOAD]] to <16 x i32>
-; CHECK-NEXT:    [[TMP2:%.*]] = mul nuw nsw <16 x i32> [[TMP1]], [[TMP1]]
+; CHECK-NEXT:    [[TMP7:%.*]] = zext <16 x i8> [[WIDE_MASKED_LOAD]] to <16 x i32>
+; CHECK-NEXT:    [[TMP2:%.*]] = mul nuw nsw <16 x i32> [[TMP1]], [[TMP7]]
 ; CHECK-NEXT:    [[TMP3:%.*]] = select <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i32> [[TMP2]], <16 x i32> zeroinitializer
 ; CHECK-NEXT:    [[TMP4:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP3]])
 ; CHECK-NEXT:    [[TMP5]] = add i32 [[TMP4]], [[VEC_PHI]]
diff --git a/llvm/test/Transforms/LoopVectorize/RISCV/inloop-reduction.ll b/llvm/test/Transforms/LoopVectorize/RISCV/inloop-reduction.ll
index c09b2de500d8cb..c6438857a7571e 100644
--- a/llvm/test/Transforms/LoopVectorize/RISCV/inloop-reduction.ll
+++ b/llvm/test/Transforms/LoopVectorize/RISCV/inloop-reduction.ll
@@ -187,8 +187,11 @@ define i32 @add_i16_i32(ptr nocapture readonly %x, i32 %n) {
 ; IF-EVL-INLOOP-NEXT:    [[N_RND_UP:%.*]] = add i32 [[N]], [[TMP2]]
 ; IF-EVL-INLOOP-NEXT:    [[N_MOD_VF:%.*]] = urem i32 [[N_RND_UP]], [[TMP1]]
 ; IF-EVL-INLOOP-NEXT:    [[N_VEC:%.*]] = sub i32 [[N_RND_UP]], [[N_MOD_VF]]
+; IF-EVL-INLOOP-NEXT:    [[TRIP_COUNT_MINUS_1:%.*]] = sub i32 [[N]], 1
 ; IF-EVL-INLOOP-NEXT:    [[TMP3:%.*]] = call i32 @llvm.vscale.i32()
 ; IF-EVL-INLOOP-NEXT:    [[TMP4:%.*]] = mul i32 [[TMP3]], 8
+; IF-EVL-INLOOP-NEXT:    [[BROADCAST_SPLATINSERT1:%.*]] = insertelement <vscale x 8 x i32> poison, i32 [[TRIP_COUNT_MINUS_1]], i64 0
+; IF-EVL-INLOOP-NEXT:    [[BROADCAST_SPLAT2:%.*]] = shufflevector <vscale x 8 x i32> [[BROADCAST_SPLATINSERT1]], <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer
 ; IF-EVL-INLOOP-NEXT:    br label [[VECTOR_BODY:%.*]]
 ; IF-EVL-INLOOP:       vector.body:
 ; IF-EVL-INLOOP-NEXT:    [[INDEX:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
@@ -197,12 +200,19 @@ define i32 @add_i16_i32(ptr nocapture readonly %x, i32 %n) {
 ; IF-EVL-INLOOP-NEXT:    [[AVL:%.*]] = sub i32 [[N]], [[EVL_BASED_IV]]
 ; IF-EVL-INLOOP-NEXT:    [[TMP5:%.*]] = call i32 @llvm.experimental.get.vector.length.i32(i32 [[AVL]], i32 8, i1 true)
 ; IF-EVL-INLOOP-NEXT:    [[TMP6:%.*]] = add i32 [[EVL_BASED_IV]], 0
-; IF-EVL-INLOOP-NEXT:    [[TMP7:%.*]] = getelementptr inbounds i16, ptr [[X:%.*]], i32 [[TMP6]]
-; IF-EVL-INLOOP-NEXT:    [[TMP8:%.*]] = getelementptr inbounds i16, ptr [[TMP7]], i32 0
-; IF-EVL-INLOOP-NEXT:    [[VP_OP_LOAD:%.*]] = call <vscale x 8 x i16> @llvm.vp.load.nxv8i16.p0(ptr align 2 [[TMP8]], <vscale x 8 x i1> splat (i1 true), i32 [[TMP5]])
-; IF-EVL-INLOOP-NEXT:    [[TMP9:%.*]] = sext <vscale x 8 x i16> [[VP_OP_LOAD]] to <vscale x 8 x i32>
-; IF-EVL-INLOOP-NEXT:    [[TMP10:%.*]] = call i32 @llvm.vp.reduce.add.nxv8i32(i32 0, <vscale x 8 x i32> [[TMP9]], <vscale x 8 x i1> splat (i1 true), i32 [[TMP5]])
-; IF-EVL-INLOOP-NEXT:    [[TMP11]] = add i32 [[TMP10]], [[VEC_PHI]]
+; IF-EVL-INLOOP-NEXT:    [[BROADCAST_SPLATINSERT:%.*]] = insertelement <vscale x 8 x i32> poison, i32 [[EVL_BASED_IV]], i64 0
+; IF-EVL-INLOOP-NEXT:    [[BROADCAST_SPLAT:%.*]] = shufflevector <vscale x 8 x i32> [[BROADCAST_SPLATINSERT]], <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer
+; IF-EVL-INLOOP-NEXT:    [[TMP7:%.*]] = call <vscale x 8 x i32> @llvm.stepvector.nxv8i32()
+; IF-EVL-INLOOP-NEXT:    [[TMP8:%.*]] = add <vscale x 8 x i32> zeroinitializer, [[TMP7]]
+; IF-EVL-INLOOP-NEXT:    [[VEC_IV:%.*]] = add <vscale x 8 x i32> [[BROADCAST_SPLAT]], [[TMP8]]
+; IF-EVL-INLOOP-NEXT:    [[TMP9:%.*]] = icmp ule <vscale x 8 x i32> [[VEC_IV]], [[BROADCAST_SPLAT2]]
+; IF-EVL-INLOOP-NEXT:    [[TMP10:%.*]] = getelementptr inbounds i16, ptr [[X:%.*]], i32 [[TMP6]]
+; IF-EVL-INLOOP-NEXT:    [[TMP15:%.*]] = getelementptr inbounds i16, ptr [[TMP10]], i32 0
+; IF-EVL-INLOOP-NEXT:    [[VP_OP_LOAD:%.*]] = call <vscale x 8 x i16> @llvm.vp.load.nxv8i16.p0(ptr align 2 [[TMP15]], <vscale x 8 x i1> splat (i1 true), i32 [[TMP5]])
+; IF-EVL-INLOOP-NEXT:    [[TMP16:%.*]] = sext <vscale x 8 x i16> [[VP_OP_LOAD]] to <vscale x 8 x i32>
+; IF-EVL-INLOOP-NEXT:    [[TMP17:%.*]] = select <vscale x 8 x i1> [[TMP9]], <vscale x 8 x i32> [[TMP16]], <vscale x 8 x i32> zeroinitializer
+; IF-EVL-INLOOP-NEXT:    [[TMP14:%.*]] = call i32 @llvm.vector.reduce.add.nxv8i32(<vscale x 8 x i32> [[TMP17]])
+; IF-EVL-INLOOP-NEXT:    [[TMP11]] = add i32 [[TMP14]], [[VEC_PHI]]
 ; IF-EVL-INLOOP-NEXT:    [[INDEX_EVL_NEXT]] = add nuw i32 [[TMP5]], [[EVL_BASED_IV]]
 ; IF-EVL-INLOOP-NEXT:    [[INDEX_NEXT]] = add nuw i32 [[INDEX]], [[TMP4]]
 ; IF-EVL-INLOOP-NEXT:    [[TMP12:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]
diff --git a/llvm/test/Transforms/LoopVectorize/reduction-inloop-pred.ll b/llvm/test/Transforms/LoopVectorize/reduction-inloop-pred.ll
index 8e132ed8399cd6..5b171adec78f6b 100644
--- a/llvm/test/Transforms/LoopVectorize/reduction-inloop-pred.ll
+++ b/llvm/test/Transforms/LoopVectorize/reduction-inloop-pred.ll
@@ -424,62 +424,62 @@ define i32 @reduction_mix(ptr noalias nocapture %A, ptr noalias nocapture %B) {
 ; CHECK-NEXT:    [[TMP1:%.*]] = extractelement <4 x i1> [[TMP0]], i64 0
 ; CHECK-NEXT:    br i1 [[TMP1]], label [[PRED_LOAD_IF:%.*]], label [[PRED_LOAD_CONTINUE:%.*]]
 ; CHECK:       pred.load.if:
-; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr inbounds i32, ptr [[A:%.*]], i64 [[INDEX]]
-; CHECK-NEXT:    [[TMP3:%.*]] = load i32, ptr [[TMP2]], align 4
-; CHECK-NEXT:    [[TMP4:%.*]] = insertelement <4 x i32> poison, i32 [[TMP3]], i64 0
 ; CHECK-NEXT:    [[TMP5:%.*]] = getelementptr inbounds i32, ptr [[B:%.*]], i64 [[INDEX]]
 ; CHECK-NEXT:    [[TMP6:%.*]] = load i32, ptr [[TMP5]], align 4
 ; CHECK-NEXT:    [[TMP7:%.*]] = insertelement <4 x i32> poison, i32 [[TMP6]], i64 0
+; CHECK-NEXT:    [[TMP8:%.*]] = getelementptr inbounds i32, ptr [[B1:%.*]], i64 [[INDEX]]
+; CHECK-NEXT:    [[TMP12:%.*]] = load i32, ptr [[TMP8]], align 4
+; CHECK-NEXT:    [[TMP13:%.*]] = insertelement <4 x i32> poison, i32 [[TMP12]], i64 0
 ; CHECK-NEXT:    br label [[PRED_LOAD_CONTINUE]]
 ; CHECK:       pred.load.continue:
-; CHECK-NEXT:    [[TMP8:%.*]] = phi <4 x i32> [ poison, [[VECTOR_BODY]] ], [ [[TMP4]], [[PRED_LOAD_IF]] ]
 ; CHECK-NEXT:    [[TMP9:%.*]] = phi <4 x i32> [ poison, [[VECTOR_BODY]] ], [ [[TMP7]], [[PRED_LOAD_IF]] ]
+; CHECK-NEXT:    [[TMP14:%.*]] = phi <4 x i32> [ poison, [[VECTOR_BODY]] ], [ [[TMP13]], [[PRED_LOAD_IF]] ]
 ; CHECK-NEXT:    [[TMP10:%.*]] = extractelement <4 x i1> [[TMP0]], i64 1
 ; CHECK-NEXT:    br i1 [[TMP10]], label [[PRED_LOAD_IF3:%.*]], label [[PRED_LOAD_CONTINUE4:%.*]]
 ; CHECK:       pred.load.if3:
 ; CHECK-NEXT:    [[TMP11:%.*]] = or disjoint i64 [[INDEX]], 1
-; CHECK-NEXT:    [[TMP12:%.*]] = getelementptr inbounds i32, ptr [[A]], i64 [[TMP11]]
-; CHECK-NEXT:    [[TMP13:%.*]] = load i32, ptr [[TMP12]], align 4
-; CHECK-NEXT:    [[TMP14:%.*]] = insertelement <4 x i32> [[TMP8]], i32 [[TMP13]], i64 1
 ; CHECK-NEXT:    [[TMP15:%.*]] = getelementptr inbounds i32, ptr [[B]], i64 [[TMP11]]
 ; CHECK-NEXT:    [[TMP16:%.*]] = load i32, ptr [[TMP15]], align 4
 ; CHECK-NEXT:    [[TMP17:%.*]] = insertelement <4 x i32> [[TMP9]], i32 [[TMP16]], i64 1
+; CHECK-NEXT:    [[TMP18:%.*]] = getelementptr inbounds i32, ptr [[B1]], i64 [[TMP11]]
+; CHECK-NEXT:    [[TMP22:%.*]] = load i32, ptr [[TMP18]], align 4
+; CHECK-NEXT:    [[TMP23:%.*]] = insertelement <4 x i32> [[TMP14]], i32 [[TMP22]], i64 1
 ; CHECK-NEXT:    br label [[PRED_LOAD_CONTINUE4]]
 ; CHECK:       pred.load.continue4:
-; CHECK-NEXT:    [[TMP18:%.*]] = phi <4 x i32> [ [[TMP8]], [[PRED_LOAD_CONTINUE]] ], [ [[TMP14]], [[PRED_LOAD_IF3]] ]
 ; CHECK-NEXT:    [[TMP19:%.*]] = phi <4 x i32> [ [[TMP9]], [[PRED_LOAD_CONTINUE]] ], [ [[TMP17]], [[PRED_LOAD_IF3]] ]
+; CHECK-NEXT:    [[TMP24:%.*]] = phi <4 x i32> [ [[TMP14]], [[PRED_LOAD_CONTINUE]] ], [ [[TMP23]], [[PRED_LOAD_IF3]] ]
 ; CHECK-NEXT:    [[TMP20:%.*]] = extractelement <4 x i1> [[TMP0]], i64 2
 ; CHECK-NEXT:    br i1 [[TMP20]], label [[PRED_LOAD_IF5:%.*]], label [[PRED_LOAD_CONTINUE6:%.*]]
 ; CHECK:       pred.load.if5:
 ; CHECK-NEXT:    [[TMP21:%.*]] = or disjoint i64 [[INDEX]], 2
-; CHECK-NEXT:    [[TMP22:%.*]] = getelementptr inbounds i32, ptr [[A]], i64 [[TMP21]]
-; CHECK-NEXT:    [[TMP23:%.*]] = load i32, ptr [[TMP22]], align 4
-; CHECK-NEXT:    [[TMP24:%.*]] = insertelement <4 x i32> [[TMP18]], i32 [[TMP23]], i64 2
 ; CHECK-NEXT:    [[TMP25:%.*]] = getelementptr inbounds i32, ptr [[B]], i64 [[TMP21]]
 ; CHECK-NEXT:    [[TMP26:%.*]] = load i32, ptr [[TMP25]], align 4
 ; CHECK-NEXT:    [[TMP27:%.*]] = insertelement <4 x i32> [[TMP19]], i32 [[TMP26]], i64 2
+; CHECK-NEXT:    [[TMP28:%.*]] = getelementptr inbounds i32, ptr [[B1]], i64 [[TMP21]]
+; CHECK-NEXT:    [[TMP32:%.*]] = load i32, ptr [[TMP28]], align 4
+; CHECK-NEXT:    [[TMP33:%.*]] = insertelement <4 x i32> [[TMP24]], i32 [[TMP32]], i64 2
 ; CHECK-NEXT:    br label [[PRED_LOAD_CONTINUE6]]
 ; CHECK:       pred.load.continue6:
-; CHECK-NEXT:    [[TMP28:%.*]] = phi <4 x i32> [ [[TMP18]], [[PRED_LOAD_CONTINUE4]] ], [ [[TMP24]], [[PRED_LOAD_IF5]] ]
 ; CHECK-NEXT:    [[TMP29:%.*]] = phi <4 x i32> [ [[TMP19]], [[PRED_LOAD_CONTINUE4]] ], [ [[TMP27]], [[PRED_LOAD_IF5]] ]
+; CHECK-NEXT:    [[TMP34:%.*]] = phi <4 x i32> [ [[TMP24]], [[PRED_LOAD_CONTINUE4]] ], [ [[TMP33]], [[PRED_LOAD_IF5]] ]
 ; CHECK-NEXT:    [[TMP30:%.*]] = extractelement <4 x i1> [[TMP0]], i64 3
 ; CHECK-NEXT:    br i1 [[TMP30]], label [[PRED_LOAD_IF7:%.*]], label [[PRED_LOAD_CONTINUE8]]
 ; CHECK:       pred.load.if7:
 ; CHECK-NEXT:    [[TMP31:%.*]] = or disjoint i64 [[INDEX]], 3
-; CHECK-NEXT:    [[TMP32:%.*]] = getelementptr inbounds i32, ptr [[A]], i64 [[TMP31]]
-; CHECK-NEXT:    [[TMP33:%.*]] = load i32, ptr [[TMP32]], align 4
-; CHECK-NEXT:    [[TMP34:%.*]] = insertelement <4 x i32> [[TMP28]], i32 [[TMP33]], i64 3
 ; CHECK-NEXT:    [[TMP35:%.*]] = getelementptr inbounds i32, ptr [[B]], i64 [[TMP31]]
 ; CHECK-NEXT:    [[TMP36:%.*]] = load i32, ptr [[TMP35]], align 4
 ; CHECK-NEXT:    [[TMP37:%.*]] = insertelement <4 x i32> [[TMP29]], i32 [[TMP36]], i64 3
+; CHECK-NEXT:    [[TMP38:%.*]] = getelementptr inbounds i32, ptr [[B1]], i64 [[TMP31]]
+; CHECK-NEXT:    [[TMP48:%.*]] = load i32, ptr [[TMP38]], align 4
+; CHECK-NEXT:    [[TMP49:%.*]] = insertelement <4 x i32> [[TMP34]], i32 [[TMP48]], i64 3
 ; CHECK-NEXT:    br label [[PRED_LOAD_CONTINUE8]]
 ; CHECK:       pred.load.continue8:
-; CHECK-NEXT:    [[TMP38:%.*]] = phi <4 x i32> [ [[TMP28]], [[PRED_LOAD_CONTINUE6]] ], [ [[TMP34]], [[PRED_LOAD_IF7]] ]
 ; CHECK-NEXT:    [[TMP39:%.*]] = phi <4 x i32> [ [[TMP29]], [[PRED_LOAD_CONTINUE6]] ], [ [[TMP37]], [[PRED_LOAD_IF7]] ]
-; CHECK-NEXT:    [[TMP40:%.*]] = mul nsw <4 x i32> [[TMP39]], [[TMP38]]
+; CHECK-NEXT:    [[TMP50:%.*]] = phi <4 x i32> [ [[TMP34]], [[PRED_LOAD_CONTINUE6]] ], [ [[TMP49]], [[PRED_LOAD_IF7]] ]
 ; CHECK-NEXT:    [[TMP41:%.*]] = select <4 x i1> [[TMP0]], <4 x i32> [[VEC_IND1]], <4 x i32> zeroinitializer
 ; CHECK-NEXT:    [[TMP42:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP41]])
 ; CHECK-NEXT:    [[TMP43:%.*]] = add i32 [[TMP42]], [[VEC_PHI]]
+; CHECK-NEXT:    [[TMP40:%.*]] = mul nsw <4 x i32> [[TMP50]], [[TMP39]]
 ; CHECK-NEXT:    [[TMP44:%.*]] = select <4 x i1> [[TMP0]], <4 x i32> [[TMP40]], <4 x i32> zeroinitializer
 ; CHECK-NEXT:    [[TMP45:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP44]])
 ; CHECK-NEXT:    [[TMP46]] = add i32 [[TMP45]], [[TMP43]]
diff --git a/llvm/test/Transforms/LoopVectorize/reduction-inloop.ll b/llvm/test/Transforms/LoopVectorize/reduction-inloop.ll
index fe74a7c3a9b27c..b578e61d85dfa1 100644
--- a/llvm/test/Transforms/LoopVectorize/reduction-inloop.ll
+++ b/llvm/test/Transforms/LoopVectorize/reduction-inloop.ll
@@ -221,13 +221,13 @@ define i32 @reduction_mix(ptr noalias nocapture %A, ptr noalias nocapture %B) {
 ; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[TMP6:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[VEC_IND:%.*]] = phi <4 x i32> [ <i32 0, i32 1, i32 2, i32 3>, [[VECTOR_PH]] ], [ [[VEC_IND_NEXT:%.*]], [[VECTOR_BODY]] ]
-; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds i32, ptr [[A:%.*]], i64 [[INDEX]]
-; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <4 x i32>, ptr [[TMP0]], align 4
 ; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds i32, ptr [[B:%.*]], i64 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_LOAD1:%.*]] = load <4 x i32>, ptr [[TMP1]], align 4
-; CHECK-NEXT:    [[TMP2:%.*]] = mul nsw <4 x i32> [[WIDE_LOAD1]], [[WIDE_LOAD]]
+; CHECK-NEXT:    [[TMP8:%.*]] = getelementptr inbounds i32, ptr [[B1:%.*]], i64 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_LOAD2:%.*]] = load <4 x i32>, ptr [[TMP8]], align 4
 ; CHECK-NEXT:    [[TMP3:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[VEC_IND]])
 ; CHECK-NEXT:    [[TMP4:%.*]] = add i32 [[TMP3]], [[VEC_PHI]]
+; CHECK-NEXT:    [[TMP2:%.*]] = mul nsw <4 x i32> [[WIDE_LOAD2]], [[WIDE_LOAD1]]
 ; CHECK-NEXT:    [[TMP5:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP2]])
 ; CHECK-NEXT:    [[TMP6]] = add i32 [[TMP5]], [[TMP4]]
 ; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 4
@@ -329,11 +329,11 @@ define i32 @start_at_non_zero(ptr nocapture %in, ptr nocapture %coeff, ptr nocap
 ; CHECK:       vector.body:
 ; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi i32 [ 120, [[VECTOR_PH]] ], [ [[TMP4:%.*]], [[VECTOR_BODY]] ]
-; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds i32, ptr [[IN:%.*]], i64 [[INDEX]]
-; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <4 x i32>, ptr [[TMP0]], align 4
 ; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds i32, ptr [[COEFF:%.*]], i64 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_LOAD1:%.*]] = load <4 x i32>, ptr [[TMP1]], align 4
-; CHECK-NEXT:    [[TMP2:%.*]] = mul nsw <4 x i32> [[WIDE_LOAD1]], [[WIDE_LOAD]]
+; CHECK-NEXT:    [[TMP6:%.*]] = getelementptr inbounds i32, ptr [[COEFF1:%.*]], i64 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_LOAD2:%.*]] = load <4 x i32>, ptr [[TMP6]], align 4
+; CHECK-NEXT:    [[TMP2:%.*]] = mul nsw <4 x i32> [[WIDE_LOAD2]], [[WIDE_LOAD1]]
 ; CHECK-NEXT:    [[TMP3:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP2]])
 ; CHECK-NEXT:    [[TMP4]] = add i32 [[TMP3]], [[VEC_PHI]]
 ; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 4

>From e5b50f7883cd07c5eccf4a7025d1926da0c6b4db Mon Sep 17 00:00:00 2001
From: Elvis Wang <elvis.wang at sifive.com>
Date: Tue, 5 Nov 2024 17:49:26 -0800
Subject: [PATCH 05/26] Refactors

Use lambda functions to return early when a pattern is matched.
Leave some assertions.
---
 .../Transforms/Vectorize/LoopVectorize.cpp    | 121 ++++++++---------
 llvm/lib/Transforms/Vectorize/VPlan.h         |  55 ++++----
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp | 123 +-----------------
 .../Transforms/Vectorize/VPlanTransforms.cpp  |  19 +--
 .../LoopVectorize/ARM/mve-reductions.ll       |   3 +-
 .../LoopVectorize/RISCV/inloop-reduction.ll   |   8 +-
 6 files changed, 113 insertions(+), 216 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 2bf3b63961791c..f7bca1383cdad1 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7399,7 +7399,7 @@ static bool planContainsAdditionalSimplifications(VPlan &Plan,
       // VPExtendedReductionRecipe contains a folded extend instruction.
       if (auto *ExtendedRed = dyn_cast<VPExtendedReductionRecipe>(&R))
         SeenInstrs.insert(ExtendedRed->getExtInstr());
-      // VPMulAccRecupe constians a mul and otional extend instructions.
+      // VPMulAccRecipe constians a mul and otional extend instructions.
       else if (auto *MulAcc = dyn_cast<VPMulAccRecipe>(&R)) {
         SeenInstrs.insert(MulAcc->getMulInstr());
         if (MulAcc->isExtended()) {
@@ -9398,77 +9398,82 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
       if (CM.blockNeedsPredicationForAnyReason(BB))
         CondOp = RecipeBuilder.getBlockInMask(BB);
 
-      VPValue *A, *B;
-      VPSingleDefRecipe *RedRecipe;
-      // reduce.add(mul(ext, ext)) can folded into VPMulAccRecipe
-      if (RdxDesc.getOpcode() == Instruction::Add &&
-          match(VecOp, m_Mul(m_VPValue(A), m_VPValue(B)))) {
-        VPRecipeBase *RecipeA = A->getDefiningRecipe();
-        VPRecipeBase *RecipeB = B->getDefiningRecipe();
-        if (RecipeA && RecipeB && match(RecipeA, m_ZExtOrSExt(m_VPValue())) &&
-            match(RecipeB, m_ZExtOrSExt(m_VPValue())) &&
-            cast<VPWidenCastRecipe>(RecipeA)->getOpcode() ==
-                cast<VPWidenCastRecipe>(RecipeB)->getOpcode() &&
-            !A->hasMoreThanOneUniqueUser() && !B->hasMoreThanOneUniqueUser()) {
-          RedRecipe = new VPMulAccRecipe(
-              RdxDesc, CurrentLinkI, PreviousLink, CondOp,
-              CM.useOrderedReductions(RdxDesc),
-              cast<VPWidenRecipe>(VecOp->getDefiningRecipe()),
-              cast<VPWidenCastRecipe>(RecipeA),
-              cast<VPWidenCastRecipe>(RecipeB));
-        } else {
-          RedRecipe = new VPMulAccRecipe(
-              RdxDesc, CurrentLinkI, PreviousLink, CondOp,
-              CM.useOrderedReductions(RdxDesc),
-              cast<VPWidenRecipe>(VecOp->getDefiningRecipe()));
-        }
-      } else if (RdxDesc.getOpcode() == Instruction::Add &&
-                 match(VecOp,
-                       m_ZExtOrSExt(m_Mul(m_ZExtOrSExt(m_VPValue(A)),
-                                          m_ZExtOrSExt(m_VPValue(B)))))) {
-        VPWidenCastRecipe *Ext =
-            dyn_cast<VPWidenCastRecipe>(VecOp->getDefiningRecipe());
-        VPWidenRecipe *Mul =
-            dyn_cast<VPWidenRecipe>(Ext->getOperand(0)->getDefiningRecipe());
-        if (Mul && match(Mul, m_Mul(m_ZExtOrSExt(m_VPValue()),
-                                    m_ZExtOrSExt(m_VPValue())))) {
+      auto TryToMatchMulAcc = [&]() -> VPSingleDefRecipe * {
+        VPValue *A, *B;
+        if (RdxDesc.getOpcode() != Instruction::Add)
+          return nullptr;
+        // reduce.add(mul(ext, ext)) can folded into VPMulAccRecipe
+        if (match(VecOp, m_Mul(m_VPValue(A), m_VPValue(B))) &&
+            !VecOp->hasMoreThanOneUniqueUser()) {
+          VPRecipeBase *RecipeA = A->getDefiningRecipe();
+          VPRecipeBase *RecipeB = B->getDefiningRecipe();
+          if (RecipeA && RecipeB && match(RecipeA, m_ZExtOrSExt(m_VPValue())) &&
+              match(RecipeB, m_ZExtOrSExt(m_VPValue())) &&
+              cast<VPWidenCastRecipe>(RecipeA)->getOpcode() ==
+                  cast<VPWidenCastRecipe>(RecipeB)->getOpcode() &&
+              !A->hasMoreThanOneUniqueUser() &&
+              !B->hasMoreThanOneUniqueUser()) {
+            return new VPMulAccRecipe(
+                RdxDesc, CurrentLinkI, PreviousLink, CondOp,
+                CM.useOrderedReductions(RdxDesc),
+                cast<VPWidenRecipe>(VecOp->getDefiningRecipe()),
+                cast<VPWidenCastRecipe>(RecipeA),
+                cast<VPWidenCastRecipe>(RecipeB));
+          } else {
+            // Matched reduce.add(mul(...))
+            return new VPMulAccRecipe(
+                RdxDesc, CurrentLinkI, PreviousLink, CondOp,
+                CM.useOrderedReductions(RdxDesc),
+                cast<VPWidenRecipe>(VecOp->getDefiningRecipe()));
+          }
+          // Matched reduce.add(ext(mul(ext, ext)))
+          // Note that 3 extend instructions must have same opcode.
+        } else if (match(VecOp,
+                         m_ZExtOrSExt(m_Mul(m_ZExtOrSExt(m_VPValue()),
+                                            m_ZExtOrSExt(m_VPValue())))) &&
+                   !VecOp->hasMoreThanOneUniqueUser()) {
+          VPWidenCastRecipe *Ext =
+              dyn_cast<VPWidenCastRecipe>(VecOp->getDefiningRecipe());
           VPWidenRecipe *Mul =
-              cast<VPWidenRecipe>(Ext->getOperand(0)->getDefiningRecipe());
+              dyn_cast<VPWidenRecipe>(Ext->getOperand(0)->getDefiningRecipe());
           VPWidenCastRecipe *Ext0 =
               cast<VPWidenCastRecipe>(Mul->getOperand(0)->getDefiningRecipe());
           VPWidenCastRecipe *Ext1 =
               cast<VPWidenCastRecipe>(Mul->getOperand(1)->getDefiningRecipe());
           if (Ext->getOpcode() == Ext0->getOpcode() &&
-              Ext0->getOpcode() == Ext1->getOpcode()) {
-            RedRecipe = new VPMulAccRecipe(
+              Ext0->getOpcode() == Ext1->getOpcode() &&
+              !Mul->hasMoreThanOneUniqueUser() &&
+              !Ext0->hasMoreThanOneUniqueUser() &&
+              !Ext1->hasMoreThanOneUniqueUser()) {
+            return new VPMulAccRecipe(
                 RdxDesc, CurrentLinkI, PreviousLink, CondOp,
                 CM.useOrderedReductions(RdxDesc),
                 cast<VPWidenCastRecipe>(VecOp->getDefiningRecipe()), Mul,
                 cast<VPWidenCastRecipe>(Ext0), cast<VPWidenCastRecipe>(Ext1));
-          } else
-            RedRecipe = new VPExtendedReductionRecipe(
-                RdxDesc, CurrentLinkI,
-                cast<CastInst>(
-                    cast<VPWidenCastRecipe>(VecOp)->getUnderlyingInstr()),
-                PreviousLink, cast<VPWidenCastRecipe>(VecOp)->getOperand(0),
-                CondOp, CM.useOrderedReductions(RdxDesc),
-                cast<VPWidenCastRecipe>(VecOp)->getResultType());
+          }
         }
-      }
-      // VPWidenCastRecipes can folded into VPReductionRecipe
-      else if (match(VecOp, m_ZExtOrSExt(m_VPValue(A))) &&
-               !VecOp->hasMoreThanOneUniqueUser()) {
-        RedRecipe = new VPExtendedReductionRecipe(
-            RdxDesc, CurrentLinkI,
-            cast<CastInst>(
-                cast<VPWidenCastRecipe>(VecOp)->getUnderlyingInstr()),
-            PreviousLink, A, CondOp, CM.useOrderedReductions(RdxDesc),
-            cast<VPWidenCastRecipe>(VecOp)->getResultType());
-      } else {
+        return nullptr;
+      };
+      auto TryToMatchExtendedReduction = [&]() -> VPSingleDefRecipe * {
+        VPValue *A;
+        if (match(VecOp, m_ZExtOrSExt(m_VPValue(A))) &&
+            !VecOp->hasMoreThanOneUniqueUser()) {
+          return new VPExtendedReductionRecipe(
+              RdxDesc, CurrentLinkI, PreviousLink,
+              cast<VPWidenCastRecipe>(VecOp), CondOp,
+              CM.useOrderedReductions(RdxDesc));
+        }
+        return nullptr;
+      };
+      VPSingleDefRecipe *RedRecipe;
+      if (auto *MulAcc = TryToMatchMulAcc())
+        RedRecipe = MulAcc;
+      else if (auto *ExtendedRed = TryToMatchExtendedReduction())
+        RedRecipe = ExtendedRed;
+      else
         RedRecipe =
             new VPReductionRecipe(RdxDesc, CurrentLinkI, PreviousLink, VecOp,
                                   CondOp, CM.useOrderedReductions(RdxDesc));
-      }
       // Append the recipe to the end of the VPBasicBlock because we need to
       // ensure that it comes after all of it's inputs, including CondOp.
       // Note that this transformation may leave over dead recipes (including
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index c33e6081669f8b..e5b317dfd6f365 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -2711,18 +2711,19 @@ class VPExtendedReductionRecipe : public VPSingleDefRecipe {
   bool IsConditional = false;
   /// Type after extend.
   Type *ResultTy;
+  /// Opcode for the extend instruction.
   Instruction::CastOps ExtOp;
-  CastInst *CastInstr;
+  CastInst *ExtInstr;
   bool IsZExt;
 
 protected:
   VPExtendedReductionRecipe(const unsigned char SC,
                             const RecurrenceDescriptor &R, Instruction *RedI,
-                            Instruction::CastOps ExtOp, CastInst *CastI,
+                            Instruction::CastOps ExtOp, CastInst *ExtI,
                             ArrayRef<VPValue *> Operands, VPValue *CondOp,
                             bool IsOrdered, Type *ResultTy)
       : VPSingleDefRecipe(SC, Operands, RedI), RdxDesc(R), IsOrdered(IsOrdered),
-        ResultTy(ResultTy), ExtOp(ExtOp), CastInstr(CastI) {
+        ResultTy(ResultTy), ExtOp(ExtOp), ExtInstr(ExtI) {
     if (CondOp) {
       IsConditional = true;
       addOperand(CondOp);
@@ -2732,20 +2733,13 @@ class VPExtendedReductionRecipe : public VPSingleDefRecipe {
 
 public:
   VPExtendedReductionRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
-                            CastInst *CastI, VPValue *ChainOp, VPValue *VecOp,
-                            VPValue *CondOp, bool IsOrdered, Type *ResultTy)
-      : VPExtendedReductionRecipe(VPDef::VPExtendedReductionSC, R, RedI,
-                                  CastI->getOpcode(), CastI,
-                                  ArrayRef<VPValue *>({ChainOp, VecOp}), CondOp,
-                                  IsOrdered, ResultTy) {}
-
-  VPExtendedReductionRecipe(VPReductionRecipe *Red, VPWidenCastRecipe *Ext)
+                            VPValue *ChainOp, VPWidenCastRecipe *Ext,
+                            VPValue *CondOp, bool IsOrdered)
       : VPExtendedReductionRecipe(
-            VPDef::VPExtendedReductionSC, Red->getRecurrenceDescriptor(),
-            Red->getUnderlyingInstr(), Ext->getOpcode(),
+            VPDef::VPExtendedReductionSC, R, RedI, Ext->getOpcode(),
             cast<CastInst>(Ext->getUnderlyingInstr()),
-            ArrayRef<VPValue *>({Red->getChainOp(), Ext->getOperand(0)}),
-            Red->getCondOp(), Red->isOrdered(), Ext->getResultType()) {}
+            ArrayRef<VPValue *>({ChainOp, Ext->getOperand(0)}), CondOp,
+            IsOrdered, Ext->getResultType()) {}
 
   ~VPExtendedReductionRecipe() override = default;
 
@@ -2762,7 +2756,6 @@ class VPExtendedReductionRecipe : public VPSingleDefRecipe {
     return R && classof(R);
   }
 
-  /// Generate the reduction in the loop
   void execute(VPTransformState &State) override {
     llvm_unreachable("VPExtendedReductionRecipe should be transform to "
                      "VPExtendedRecipe + VPReductionRecipe before execution.");
@@ -2794,9 +2787,12 @@ class VPExtendedReductionRecipe : public VPSingleDefRecipe {
   VPValue *getCondOp() const {
     return isConditional() ? getOperand(getNumOperands() - 1) : nullptr;
   }
+  /// The Type after extended.
   Type *getResultType() const { return ResultTy; };
+  /// The Opcode of extend instruction.
   Instruction::CastOps getExtOpcode() const { return ExtOp; };
-  CastInst *getExtInstr() const { return CastInstr; };
+  /// The CastInst of the extend instruction.
+  CastInst *getExtInstr() const { return ExtInstr; };
 };
 
 /// A recipe to represent inloop MulAccreduction operations, performing a
@@ -2812,16 +2808,17 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
   bool IsConditional = false;
   /// Type after extend.
   Type *ResultType;
-  /// reduce.add(ext((mul(Ext(), Ext())))
+  // Note that all extend instruction must have the same opcode in MulAcc.
   Instruction::CastOps ExtOp;
 
+  /// reduce.add(ext(mul(ext0(), ext1())))
   Instruction *MulInstr;
   CastInst *ExtInstr = nullptr;
-  CastInst *Ext0Instr;
-  CastInst *Ext1Instr;
+  CastInst *Ext0Instr = nullptr;
+  CastInst *Ext1Instr = nullptr;
 
+  /// Is this MulAcc recipe contains extend recipes?
   bool IsExtended;
-  bool IsOuterExtended = false;
 
 protected:
   VPMulAccRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
@@ -2835,6 +2832,7 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
         ExtInstr(cast_if_present<CastInst>(ExtInstr)),
         Ext0Instr(cast<CastInst>(Ext0Instr)),
         Ext1Instr(cast<CastInst>(Ext1Instr)) {
+    assert(MulInstr->getOpcode() == Instruction::Mul);
     if (CondOp) {
       IsConditional = true;
       addOperand(CondOp);
@@ -2847,6 +2845,7 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
                  ArrayRef<VPValue *> Operands, VPValue *CondOp, bool IsOrdered)
       : VPSingleDefRecipe(SC, Operands, RedI), RdxDesc(R), IsOrdered(IsOrdered),
         MulInstr(MulInstr) {
+    assert(MulInstr->getOpcode() == Instruction::Mul);
     if (CondOp) {
       IsConditional = true;
       addOperand(CondOp);
@@ -2898,13 +2897,12 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
     return R && classof(R);
   }
 
-  /// Generate the reduction in the loop
   void execute(VPTransformState &State) override {
     llvm_unreachable("VPMulAccRecipe should transform to VPWidenCastRecipe + "
                      "VPWidenRecipe + VPReductionRecipe before execution");
   }
 
-  /// Return the cost of VPExtendedReductionRecipe.
+  /// Return the cost of VPMulAccRecipe.
   InstructionCost computeCost(ElementCount VF,
                               VPCostContext &Ctx) const override;
 
@@ -2931,13 +2929,24 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
   VPValue *getCondOp() const {
     return isConditional() ? getOperand(getNumOperands() - 1) : nullptr;
   }
+  /// Return the type after inner extended, which must equal to the type of mul
+  /// instruction. If the ResultType != recurrenceType, than it must have a
+  /// extend recipe after mul recipe.
   Type *getResultType() const { return ResultType; };
+  /// The opcode of the extend instructions.
   Instruction::CastOps getExtOpcode() const { return ExtOp; };
+  /// The underlying instruction for VPWidenRecipe.
   Instruction *getMulInstr() const { return MulInstr; };
+  /// The underlying Instruction for outer VPWidenCastRecipe.
   CastInst *getExtInstr() const { return ExtInstr; };
+  /// The underlying Instruction for inner VPWidenCastRecipe.
   CastInst *getExt0Instr() const { return Ext0Instr; };
+  /// The underlying Instruction for inner VPWidenCastRecipe.
   CastInst *getExt1Instr() const { return Ext1Instr; };
+  /// Return if this MulAcc recipe contains extend instructions.
   bool isExtended() const { return IsExtended; };
+  /// Return if the operands of mul instruction come from same extend.
+  bool isSameExtend() const { return Ext0Instr == Ext1Instr; };
 };
 
 /// VPReplicateRecipe replicates a given instruction producing multiple scalar
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index ff1e31033fd9cf..33febe56a77294 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -2214,122 +2214,6 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
         Opcode, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
   }
 
-  /*
-  using namespace llvm::VPlanPatternMatch;
-  auto GetMulAccReductionCost =
-      [&](const VPReductionRecipe *Red) -> InstructionCost {
-    VPValue *A, *B;
-    InstructionCost InnerExt0Cost = 0;
-    InstructionCost InnerExt1Cost = 0;
-    InstructionCost ExtCost = 0;
-    InstructionCost MulCost = 0;
-
-    VectorType *SrcVecTy = VectorTy;
-    Type *InnerExt0Ty;
-    Type *InnerExt1Ty;
-    Type *MaxInnerExtTy;
-    bool IsUnsigned = true;
-    bool HasOuterExt = false;
-
-    auto *Ext = dyn_cast_if_present<VPWidenCastRecipe>(
-        Red->getVecOp()->getDefiningRecipe());
-    VPRecipeBase *Mul;
-    // Try to match outer extend reduce.add(ext(...))
-    if (Ext && match(Ext, m_ZExtOrSExt(m_VPValue())) &&
-        cast<VPWidenCastRecipe>(Ext)->getNumUsers() == 1) {
-      IsUnsigned =
-          Ext->getOpcode() == Instruction::CastOps::ZExt ? true : false;
-      ExtCost = Ext->computeCost(VF, Ctx);
-      Mul = Ext->getOperand(0)->getDefiningRecipe();
-      HasOuterExt = true;
-    } else {
-      Mul = Red->getVecOp()->getDefiningRecipe();
-    }
-
-    // Match reduce.add(mul())
-    if (Mul && match(Mul, m_Mul(m_VPValue(A), m_VPValue(B))) &&
-        cast<VPWidenRecipe>(Mul)->getNumUsers() == 1) {
-      MulCost = cast<VPWidenRecipe>(Mul)->computeCost(VF, Ctx);
-      auto *InnerExt0 =
-          dyn_cast_if_present<VPWidenCastRecipe>(A->getDefiningRecipe());
-      auto *InnerExt1 =
-          dyn_cast_if_present<VPWidenCastRecipe>(B->getDefiningRecipe());
-      bool HasInnerExt = false;
-      // Try to match inner extends.
-      if (InnerExt0 && InnerExt1 &&
-          match(InnerExt0, m_ZExtOrSExt(m_VPValue())) &&
-          match(InnerExt1, m_ZExtOrSExt(m_VPValue())) &&
-          InnerExt0->getOpcode() == InnerExt1->getOpcode() &&
-          (InnerExt0->getNumUsers() > 0 &&
-           !InnerExt0->hasMoreThanOneUniqueUser()) &&
-          (InnerExt1->getNumUsers() > 0 &&
-           !InnerExt1->hasMoreThanOneUniqueUser())) {
-        InnerExt0Cost = InnerExt0->computeCost(VF, Ctx);
-        InnerExt1Cost = InnerExt1->computeCost(VF, Ctx);
-        Type *InnerExt0Ty = Ctx.Types.inferScalarType(InnerExt0->getOperand(0));
-        Type *InnerExt1Ty = Ctx.Types.inferScalarType(InnerExt1->getOperand(0));
-        Type *MaxInnerExtTy = InnerExt0Ty->getIntegerBitWidth() >
-                                      InnerExt1Ty->getIntegerBitWidth()
-                                  ? InnerExt0Ty
-                                  : InnerExt1Ty;
-        SrcVecTy = cast<VectorType>(ToVectorTy(MaxInnerExtTy, VF));
-        IsUnsigned = true;
-        HasInnerExt = true;
-      }
-      InstructionCost MulAccRedCost = Ctx.TTI.getMulAccReductionCost(
-          IsUnsigned, ElementTy, SrcVecTy, CostKind);
-      // Check if folding ext/mul into MulAccReduction is profitable.
-      if (MulAccRedCost.isValid() &&
-          MulAccRedCost <
-              ExtCost + MulCost + InnerExt0Cost + InnerExt1Cost + BaseCost) {
-        if (HasInnerExt) {
-          Ctx.FoldedRecipes[VF].insert(InnerExt0);
-          Ctx.FoldedRecipes[VF].insert(InnerExt1);
-        }
-        Ctx.FoldedRecipes[VF].insert(Mul);
-        if (HasOuterExt)
-          Ctx.FoldedRecipes[VF].insert(Ext);
-        return MulAccRedCost;
-      }
-    }
-    return InstructionCost::getInvalid();
-  };
-
-  // Match reduce(ext(...))
-  auto GetExtendedReductionCost =
-      [&](const VPReductionRecipe *Red) -> InstructionCost {
-    VPValue *VecOp = Red->getVecOp();
-    VPValue *A;
-    if (match(VecOp, m_ZExtOrSExt(m_VPValue(A))) && VecOp->getNumUsers() == 1) {
-      VPWidenCastRecipe *Ext =
-          cast<VPWidenCastRecipe>(VecOp->getDefiningRecipe());
-      bool IsUnsigned = Ext->getOpcode() == Instruction::CastOps::ZExt;
-      InstructionCost ExtCost = Ext->computeCost(VF, Ctx);
-      auto *ExtVecTy =
-          cast<VectorType>(ToVectorTy(Ctx.Types.inferScalarType(A), VF));
-      InstructionCost ExtendedRedCost = Ctx.TTI.getExtendedReductionCost(
-          Opcode, IsUnsigned, ElementTy, ExtVecTy, RdxDesc.getFastMathFlags(),
-          CostKind);
-      // Check if folding ext into ExtendedReduction is profitable.
-      if (ExtendedRedCost.isValid() && ExtendedRedCost < ExtCost + BaseCost) {
-        Ctx.FoldedRecipes[VF].insert(Ext);
-        return ExtendedRedCost;
-      }
-    }
-    return InstructionCost::getInvalid();
-  };
-
-  // Match MulAccReduction patterns.
-  InstructionCost MulAccCost = GetMulAccReductionCost(this);
-  if (MulAccCost.isValid())
-    return MulAccCost;
-
-  // Match ExtendedReduction patterns.
-  InstructionCost ExtendedCost = GetExtendedReductionCost(this);
-  if (ExtendedCost.isValid())
-    return ExtendedCost;
-  */
-
   // Default cost.
   return BaseCost;
 }
@@ -2343,9 +2227,6 @@ VPExtendedReductionRecipe::computeCost(ElementCount VF,
   TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
   unsigned Opcode = RdxDesc.getOpcode();
 
-  assert(ElementTy->getTypeID() == RdxDesc.getRecurrenceType()->getTypeID() &&
-         "Inferred type and recurrence type mismatch.");
-
   // BaseCost = Reduction cost + BinOp cost
   InstructionCost ReductionCost =
       Ctx.TTI.getArithmeticInstrCost(Opcode, ElementTy, CostKind);
@@ -2387,8 +2268,8 @@ InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
   TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
   unsigned Opcode = RdxDesc.getOpcode();
 
-  assert(ElementTy->getTypeID() == RdxDesc.getRecurrenceType()->getTypeID() &&
-         "Inferred type and recurrence type mismatch.");
+  assert(Opcode == Instruction::Add &&
+         "Reduction opcode must be add in the VPMulAccRecipe.");
 
   // BaseCost = Reduction cost + BinOp cost
   InstructionCost ReductionCost =
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index c28953e7314f57..423694cdd84e05 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -545,11 +545,15 @@ void VPlanTransforms::prepareExecute(VPlan &Plan) {
           Op0 = new VPWidenCastRecipe(
               MulAcc->getExtOpcode(), MulAcc->getVecOp0(),
               MulAcc->getResultType(), *MulAcc->getExt0Instr());
-          Op1 = new VPWidenCastRecipe(
-              MulAcc->getExtOpcode(), MulAcc->getVecOp1(),
-              MulAcc->getResultType(), *MulAcc->getExt1Instr());
           Op0->getDefiningRecipe()->insertBefore(MulAcc);
-          Op1->getDefiningRecipe()->insertBefore(MulAcc);
+          if (!MulAcc->isSameExtend()) {
+            Op1 = new VPWidenCastRecipe(
+                MulAcc->getExtOpcode(), MulAcc->getVecOp1(),
+                MulAcc->getResultType(), *MulAcc->getExt1Instr());
+            Op1->getDefiningRecipe()->insertBefore(MulAcc);
+          } else {
+            Op1 = Op0;
+          }
         } else {
           Op0 = MulAcc->getVecOp0();
           Op1 = MulAcc->getVecOp1();
@@ -559,14 +563,13 @@ void VPlanTransforms::prepareExecute(VPlan &Plan) {
         SmallVector<VPValue *, 2> MulOps = {Op0, Op1};
         auto *Mul = new VPWidenRecipe(*MulInstr,
                                       make_range(MulOps.begin(), MulOps.end()));
-        if (auto *OuterExtInstr = MulAcc->getExtInstr()) {
-          // dbgs() <<"\n!!!"<< *OuterExtInstr << " " << MulAcc->getExtOpcode()
-          // << "\n";
+        // Outer extend.
+        if (auto *OuterExtInstr = MulAcc->getExtInstr())
           VecOp = new VPWidenCastRecipe(
               MulAcc->getExtOpcode(), Mul,
               MulAcc->getRecurrenceDescriptor().getRecurrenceType(),
               *OuterExtInstr);
-        } else
+        else
           VecOp = Mul;
         auto *Red = new VPReductionRecipe(
             MulAcc->getRecurrenceDescriptor(), MulAcc->getUnderlyingInstr(),
diff --git a/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll b/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
index ffdd8e66983d85..9b7c345076480a 100644
--- a/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
+++ b/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
@@ -1412,9 +1412,8 @@ define i32 @mla_i8_i32_multiuse(ptr nocapture readonly %x, ptr nocapture readonl
 ; CHECK-NEXT:    [[ACTIVE_LANE_MASK:%.*]] = call <16 x i1> @llvm.get.active.lane.mask.v16i1.i32(i32 [[INDEX]], i32 [[N]])
 ; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds i8, ptr [[X:%.*]], i32 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_MASKED_LOAD:%.*]] = call <16 x i8> @llvm.masked.load.v16i8.p0(ptr [[TMP0]], i32 1, <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i8> poison)
-; CHECK-NEXT:    [[TMP1:%.*]] = zext <16 x i8> [[WIDE_MASKED_LOAD]] to <16 x i32>
 ; CHECK-NEXT:    [[TMP7:%.*]] = zext <16 x i8> [[WIDE_MASKED_LOAD]] to <16 x i32>
-; CHECK-NEXT:    [[TMP2:%.*]] = mul nuw nsw <16 x i32> [[TMP1]], [[TMP7]]
+; CHECK-NEXT:    [[TMP2:%.*]] = mul nuw nsw <16 x i32> [[TMP7]], [[TMP7]]
 ; CHECK-NEXT:    [[TMP3:%.*]] = select <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i32> [[TMP2]], <16 x i32> zeroinitializer
 ; CHECK-NEXT:    [[TMP4:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP3]])
 ; CHECK-NEXT:    [[TMP5]] = add i32 [[TMP4]], [[VEC_PHI]]
diff --git a/llvm/test/Transforms/LoopVectorize/RISCV/inloop-reduction.ll b/llvm/test/Transforms/LoopVectorize/RISCV/inloop-reduction.ll
index c6438857a7571e..c20b9051750622 100644
--- a/llvm/test/Transforms/LoopVectorize/RISCV/inloop-reduction.ll
+++ b/llvm/test/Transforms/LoopVectorize/RISCV/inloop-reduction.ll
@@ -209,14 +209,14 @@ define i32 @add_i16_i32(ptr nocapture readonly %x, i32 %n) {
 ; IF-EVL-INLOOP-NEXT:    [[TMP10:%.*]] = getelementptr inbounds i16, ptr [[X:%.*]], i32 [[TMP6]]
 ; IF-EVL-INLOOP-NEXT:    [[TMP15:%.*]] = getelementptr inbounds i16, ptr [[TMP10]], i32 0
 ; IF-EVL-INLOOP-NEXT:    [[VP_OP_LOAD:%.*]] = call <vscale x 8 x i16> @llvm.vp.load.nxv8i16.p0(ptr align 2 [[TMP15]], <vscale x 8 x i1> splat (i1 true), i32 [[TMP5]])
-; IF-EVL-INLOOP-NEXT:    [[TMP16:%.*]] = sext <vscale x 8 x i16> [[VP_OP_LOAD]] to <vscale x 8 x i32>
-; IF-EVL-INLOOP-NEXT:    [[TMP17:%.*]] = select <vscale x 8 x i1> [[TMP9]], <vscale x 8 x i32> [[TMP16]], <vscale x 8 x i32> zeroinitializer
+; IF-EVL-INLOOP-NEXT:    [[TMP12:%.*]] = sext <vscale x 8 x i16> [[VP_OP_LOAD]] to <vscale x 8 x i32>
+; IF-EVL-INLOOP-NEXT:    [[TMP17:%.*]] = select <vscale x 8 x i1> [[TMP9]], <vscale x 8 x i32> [[TMP12]], <vscale x 8 x i32> zeroinitializer
 ; IF-EVL-INLOOP-NEXT:    [[TMP14:%.*]] = call i32 @llvm.vector.reduce.add.nxv8i32(<vscale x 8 x i32> [[TMP17]])
 ; IF-EVL-INLOOP-NEXT:    [[TMP11]] = add i32 [[TMP14]], [[VEC_PHI]]
 ; IF-EVL-INLOOP-NEXT:    [[INDEX_EVL_NEXT]] = add nuw i32 [[TMP5]], [[EVL_BASED_IV]]
 ; IF-EVL-INLOOP-NEXT:    [[INDEX_NEXT]] = add nuw i32 [[INDEX]], [[TMP4]]
-; IF-EVL-INLOOP-NEXT:    [[TMP12:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]
-; IF-EVL-INLOOP-NEXT:    br i1 [[TMP12]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
+; IF-EVL-INLOOP-NEXT:    [[TMP16:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]
+; IF-EVL-INLOOP-NEXT:    br i1 [[TMP16]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
 ; IF-EVL-INLOOP:       middle.block:
 ; IF-EVL-INLOOP-NEXT:    br i1 true, label [[FOR_COND_CLEANUP_LOOPEXIT:%.*]], label [[SCALAR_PH]]
 ; IF-EVL-INLOOP:       scalar.ph:

>From cc004ffd4f94c071c78eda07f327a20f85696be1 Mon Sep 17 00:00:00 2001
From: Elvis Wang <elvis.wang at sifive.com>
Date: Wed, 6 Nov 2024 21:07:04 -0800
Subject: [PATCH 06/26] Fix typos and update printing test

---
 .../Transforms/Vectorize/LoopVectorize.cpp    |   3 +-
 llvm/lib/Transforms/Vectorize/VPlan.h         |   7 +-
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp |  18 +-
 .../LoopVectorize/vplan-printing.ll           | 236 ++++++++++++++++++
 4 files changed, 253 insertions(+), 11 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index f7bca1383cdad1..b70f7f9dba2611 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7679,7 +7679,7 @@ DenseMap<const SCEV *, Value *> LoopVectorizationPlanner::executePlan(
                              CanonicalIVStartValue, State);
   VPlanTransforms::prepareToExecute(BestVPlan);
 
-  // TODO: Rebase to fhahn's implementation.
+  // TODO: Replace with upstream implementation.
   VPlanTransforms::prepareExecute(BestVPlan);
   BestVPlan.execute(&State);
 
@@ -9279,7 +9279,6 @@ VPlanPtr LoopVectorizationPlanner::buildVPlan(VFRange &Range) {
 // Adjust AnyOf reductions; replace the reduction phi for the selected value
 // with a boolean reduction phi node to check if the condition is true in any
 // iteration. The final value is selected by the final ComputeReductionResult.
-// TODO: Implement VPMulAccHere.
 void LoopVectorizationPlanner::adjustRecipesForReductions(
     VPlanPtr &Plan, VPRecipeBuilder &RecipeBuilder, ElementCount MinVF) {
   using namespace VPlanPatternMatch;
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index e5b317dfd6f365..504021d0f621d3 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -2798,8 +2798,8 @@ class VPExtendedReductionRecipe : public VPSingleDefRecipe {
 /// A recipe to represent inloop MulAccreduction operations, performing a
 /// reduction on a vector operand into a scalar value, and adding the result to
 /// a chain. This recipe is high level abstract which will generate
-/// VPReductionRecipe VPWidenRecipe(mul)and VPWidenCastRecipe before execution.
-/// The Operands are {ChainOp, VecOp1, VecOp2, [Condition]}.
+/// VPReductionRecipe VPWidenRecipe(mul) and VPWidenCastRecipes before
+/// execution. The Operands are {ChainOp, VecOp1, VecOp2, [Condition]}.
 class VPMulAccRecipe : public VPSingleDefRecipe {
   /// The recurrence decriptor for the reduction in question.
   const RecurrenceDescriptor &RdxDesc;
@@ -2819,6 +2819,8 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
 
   /// Is this MulAcc recipe contains extend recipes?
   bool IsExtended;
+  /// Does this recipe contain an outer extend instruction?
+  bool IsOuterExtended = false;
 
 protected:
   VPMulAccRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
@@ -2838,6 +2840,7 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
       addOperand(CondOp);
     }
     IsExtended = true;
+    IsOuterExtended = ExtInstr != nullptr;
   }
 
   VPMulAccRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 33febe56a77294..a85a4bc81c2854 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -270,7 +270,7 @@ InstructionCost VPRecipeBase::cost(ElementCount VF, VPCostContext &Ctx) {
     UI = &WidenMem->getIngredient();
 
   InstructionCost RecipeCost;
-  if ((UI && Ctx.skipCostComputation(UI, VF.isVector()))) {
+  if (UI && Ctx.skipCostComputation(UI, VF.isVector())) {
     RecipeCost = 0;
   } else {
     RecipeCost = computeCost(VF, Ctx);
@@ -2413,18 +2413,20 @@ void VPMulAccRecipe::print(raw_ostream &O, const Twine &Indent,
   printAsOperand(O, SlotTracker);
   O << " = ";
   getChainOp()->printAsOperand(O, SlotTracker);
-  O << " +";
+  O << " + ";
   if (isa<FPMathOperator>(getUnderlyingInstr()))
     O << getUnderlyingInstr()->getFastMathFlags();
-  O << " reduce." << Instruction::getOpcodeName(RdxDesc.getOpcode()) << " (";
-  O << " mul ";
+  if (IsOuterExtended)
+    O << " (";
+  O << "reduce." << Instruction::getOpcodeName(RdxDesc.getOpcode()) << " (";
+  O << "mul ";
   if (IsExtended)
     O << "(";
   getVecOp0()->printAsOperand(O, SlotTracker);
   if (IsExtended)
-    O << " extended to " << *getResultType() << ")";
-  if (IsExtended)
-    O << "(";
+    O << " extended to " << *getResultType() << "), (";
+  else
+    O << ", ";
   getVecOp1()->printAsOperand(O, SlotTracker);
   if (IsExtended)
     O << " extended to " << *getResultType() << ")";
@@ -2433,6 +2435,8 @@ void VPMulAccRecipe::print(raw_ostream &O, const Twine &Indent,
     getCondOp()->printAsOperand(O, SlotTracker);
   }
   O << ")";
+  if (IsOuterExtended)
+    O << " extended to " << *RdxDesc.getRecurrenceType() << ")";
   if (RdxDesc.IntermediateStore)
     O << " (with final reduction value stored in invariant address sank "
          "outside of loop)";
diff --git a/llvm/test/Transforms/LoopVectorize/vplan-printing.ll b/llvm/test/Transforms/LoopVectorize/vplan-printing.ll
index 195f6a48640e54..b1d7388005ad4e 100644
--- a/llvm/test/Transforms/LoopVectorize/vplan-printing.ll
+++ b/llvm/test/Transforms/LoopVectorize/vplan-printing.ll
@@ -1143,6 +1143,242 @@ exit:
   ret i16 %for.1
 }
 
+define i64 @print_extended_reduction(ptr nocapture readonly %x, ptr nocapture readonly %y, i32 %n) {
+; CHECK-LABEL: 'print_extended_reduction'
+; CHECK:      VPlan 'Initial VPlan for VF={4},UF>=1' {
+; CHECK-NEXT: Live-in vp<%0> = VF * UF
+; CHECK-NEXT: Live-in vp<%1> = vector-trip-count
+; CHECK-NEXT: Live-in ir<%n> = original trip-count
+; CHECK-EMPTY:
+; CHECK-NEXT: vector.ph:
+; CHECK-NEXT: Successor(s): vector loop
+; CHECK-EMPTY:
+; CHECK-NEXT: <x1> vector loop: {
+; CHECK-NEXT:   vector.body:
+; CHECK-NEXT:     EMIT vp<%2> = CANONICAL-INDUCTION ir<0>, vp<%index.next>
+; CHECK-NEXT:     WIDEN-REDUCTION-PHI ir<%r.09> = phi ir<0>, ir<%add>
+; CHECK-NEXT:     vp<%3> = SCALAR-STEPS vp<%2>, ir<1>
+; CHECK-NEXT:     CLONE ir<%arrayidx> = getelementptr inbounds ir<%x>, vp<%3>
+; CHECK-NEXT:     vp<%4> = vector-pointer ir<%arrayidx>
+; CHECK-NEXT:     WIDEN ir<%load0> = load vp<%4>
+; CHECK-NEXT:     EXTENDED-REDUCE ir<%add> = ir<%r.09> + reduce.add (ir<%load0> extended to i64)
+; CHECK-NEXT:     EMIT vp<%index.next> = add nuw vp<%2>, vp<%0>
+; CHECK-NEXT:     EMIT branch-on-count vp<%index.next>, vp<%1>
+; CHECK-NEXT:   No successors
+; CHECK-NEXT: }
+; CHECK-NEXT: Successor(s): middle.block
+; CHECK-EMPTY:
+; CHECK-NEXT: middle.block:
+; CHECK-NEXT:   EMIT vp<%6> = compute-reduction-result ir<%r.09>, ir<%add>
+; CHECK-NEXT:   EMIT vp<%7> = extract-from-end vp<%6>, ir<1>
+; CHECK-NEXT:   EMIT vp<%cmp.n> = icmp eq ir<%n>, vp<%1>
+; CHECK-NEXT:   EMIT branch-on-cond vp<%cmp.n>
+; CHECK-NEXT: Successor(s): ir-bb<for.cond.cleanup.loopexit>, scalar.ph
+; CHECK-EMPTY:
+; CHECK-NEXT: ir-bb<for.cond.cleanup.loopexit>:
+; CHECK-NEXT:   IR   %add.lcssa = phi i64 [ %add, %for.body ] (extra operand: vp<%7>)
+; CHECK-NEXT: No successors
+; CHECK-EMPTY:
+; CHECK-NEXT: scalar.ph:
+; CHECK-NEXT:   EMIT vp<%bc.merge.rdx> = resume-phi vp<%6>, ir<0>
+; CHECK-NEXT: Successor(s): ir-bb<for.body>
+; CHECK-EMPTY:
+; CHECK-NEXT: ir-bb<for.body>:
+; CHECK-NEXT:   IR   %i.010 = phi i32 [ %inc, %for.body ], [ 0, %for.body.preheader ]
+; CHECK-NEXT:   IR   %r.09 = phi i64 [ %add, %for.body ], [ 0, %for.body.preheader ] (extra operand: vp<%bc.merge.rdx>)
+; CHECK-NEXT:   IR   %arrayidx = getelementptr inbounds i32, ptr %x, i32 %i.010
+; CHECK-NEXT:   IR   %load0 = load i32, ptr %arrayidx, align 4
+; CHECK-NEXT:   IR   %conv0 = zext i32 %load0 to i64
+; CHECK-NEXT:   IR   %add = add nsw i64 %r.09, %conv0
+; CHECK-NEXT:   IR   %inc = add nuw nsw i32 %i.010, 1
+; CHECK-NEXT:   IR   %exitcond = icmp eq i32 %inc, %n
+; CHECK-NEXT: No successors
+; CHECK-NEXT: }
+;
+entry:
+  %cmp8 = icmp sgt i32 %n, 0
+  br i1 %cmp8, label %for.body, label %for.cond.cleanup
+
+for.body:                                         ; preds = %entry, %for.body
+  %i.010 = phi i32 [ %inc, %for.body ], [ 0, %entry ]
+  %r.09 = phi i64 [ %add, %for.body ], [ 0, %entry ]
+  %arrayidx = getelementptr inbounds i32, ptr %x, i32 %i.010
+  %load0 = load i32, ptr %arrayidx, align 4
+  %conv0 = zext i32 %load0 to i64
+  %add = add nsw i64 %r.09, %conv0
+  %inc = add nuw nsw i32 %i.010, 1
+  %exitcond = icmp eq i32 %inc, %n
+  br i1 %exitcond, label %for.cond.cleanup, label %for.body
+
+for.cond.cleanup:                                 ; preds = %for.body, %entry
+  %r.0.lcssa = phi i64 [ 0, %entry ], [ %add, %for.body ]
+  ret i64 %r.0.lcssa
+}
+
+define i64 @print_mulacc(ptr nocapture readonly %x, ptr nocapture readonly %y, i32 %n) {
+; CHECK-LABEL: 'print_mulacc'
+; CHECK:      VPlan 'Initial VPlan for VF={4},UF>=1' {
+; CHECK-NEXT: Live-in vp<%0> = VF * UF
+; CHECK-NEXT: Live-in vp<%1> = vector-trip-count
+; CHECK-NEXT: Live-in ir<%n> = original trip-count
+; CHECK-EMPTY:
+; CHECK-NEXT: vector.ph:
+; CHECK-NEXT: Successor(s): vector loop
+; CHECK-EMPTY:
+; CHECK-NEXT: <x1> vector loop: {
+; CHECK-NEXT:   vector.body:
+; CHECK-NEXT:     EMIT vp<%2> = CANONICAL-INDUCTION ir<0>, vp<%index.next>
+; CHECK-NEXT:     WIDEN-REDUCTION-PHI ir<%r.09> = phi ir<0>, ir<%add>
+; CHECK-NEXT:     vp<%3> = SCALAR-STEPS vp<%2>, ir<1>
+; CHECK-NEXT:     CLONE ir<%arrayidx> = getelementptr inbounds ir<%x>, vp<%3>
+; CHECK-NEXT:     vp<%4> = vector-pointer ir<%arrayidx>
+; CHECK-NEXT:     WIDEN ir<%load0> = load vp<%4>
+; CHECK-NEXT:     CLONE ir<%arrayidx1> = getelementptr inbounds ir<%y>, vp<%3>
+; CHECK-NEXT:     vp<%5> = vector-pointer ir<%arrayidx1>
+; CHECK-NEXT:     WIDEN ir<%load1> = load vp<%5>
+; CHECK-NEXT:     MULACC-REDUCE ir<%add> = ir<%r.09> + reduce.add (mul ir<%load0>, ir<%load1>)
+; CHECK-NEXT:     EMIT vp<%index.next> = add nuw vp<%2>, vp<%0>
+; CHECK-NEXT:     EMIT branch-on-count vp<%index.next>, vp<%1>
+; CHECK-NEXT:   No successors
+; CHECK-NEXT: }
+; CHECK-NEXT: Successor(s): middle.block
+; CHECK-EMPTY:
+; CHECK-NEXT: middle.block:
+; CHECK-NEXT:   EMIT vp<%7> = compute-reduction-result ir<%r.09>, ir<%add>
+; CHECK-NEXT:   EMIT vp<%8> = extract-from-end vp<%7>, ir<1>
+; CHECK-NEXT:   EMIT vp<%cmp.n> = icmp eq ir<%n>, vp<%1>
+; CHECK-NEXT:   EMIT branch-on-cond vp<%cmp.n>
+; CHECK-NEXT: Successor(s): ir-bb<for.cond.cleanup.loopexit>, scalar.ph
+; CHECK-EMPTY:
+; CHECK-NEXT: ir-bb<for.cond.cleanup.loopexit>:
+; CHECK-NEXT:   IR   %add.lcssa = phi i64 [ %add, %for.body ] (extra operand: vp<%8>)
+; CHECK-NEXT: No successors
+; CHECK-EMPTY:
+; CHECK-NEXT: scalar.ph:
+; CHECK-NEXT:   EMIT vp<%bc.merge.rdx> = resume-phi vp<%7>, ir<0>
+; CHECK-NEXT: Successor(s): ir-bb<for.body>
+; CHECK-EMPTY:
+; CHECK-NEXT: ir-bb<for.body>:
+; CHECK-NEXT:   IR   %i.010 = phi i32 [ %inc, %for.body ], [ 0, %for.body.preheader ]
+; CHECK-NEXT:   IR   %r.09 = phi i64 [ %add, %for.body ], [ 0, %for.body.preheader ] (extra operand: vp<%bc.merge.rdx>)
+; CHECK-NEXT:   IR   %arrayidx = getelementptr inbounds i64, ptr %x, i32 %i.010
+; CHECK-NEXT:   IR   %load0 = load i64, ptr %arrayidx, align 4
+; CHECK-NEXT:   IR   %arrayidx1 = getelementptr inbounds i64, ptr %y, i32 %i.010
+; CHECK-NEXT:   IR   %load1 = load i64, ptr %arrayidx1, align 4
+; CHECK-NEXT:   IR   %mul = mul nsw i64 %load0, %load1
+; CHECK-NEXT:   IR   %add = add nsw i64 %r.09, %mul
+; CHECK-NEXT:   IR   %inc = add nuw nsw i32 %i.010, 1
+; CHECK-NEXT:   IR   %exitcond = icmp eq i32 %inc, %n
+; CHECK-NEXT: No successors
+; CHECK-NEXT: }
+;
+entry:
+  %cmp8 = icmp sgt i32 %n, 0
+  br i1 %cmp8, label %for.body, label %for.cond.cleanup
+
+for.body:                                         ; preds = %entry, %for.body
+  %i.010 = phi i32 [ %inc, %for.body ], [ 0, %entry ]
+  %r.09 = phi i64 [ %add, %for.body ], [ 0, %entry ]
+  %arrayidx = getelementptr inbounds i64, ptr %x, i32 %i.010
+  %load0 = load i64, ptr %arrayidx, align 4
+  %arrayidx1 = getelementptr inbounds i64, ptr %y, i32 %i.010
+  %load1 = load i64, ptr %arrayidx1, align 4
+  %mul = mul nsw i64 %load0, %load1
+  %add = add nsw i64 %r.09, %mul
+  %inc = add nuw nsw i32 %i.010, 1
+  %exitcond = icmp eq i32 %inc, %n
+  br i1 %exitcond, label %for.cond.cleanup, label %for.body
+
+for.cond.cleanup:                                 ; preds = %for.body, %entry
+  %r.0.lcssa = phi i64 [ 0, %entry ], [ %add, %for.body ]
+  ret i64 %r.0.lcssa
+}
+
+define i64 @print_mulacc_extended(ptr nocapture readonly %x, ptr nocapture readonly %y, i32 %n) {
+; CHECK-LABEL: 'print_mulacc_extended'
+; CHECK:      VPlan 'Initial VPlan for VF={4},UF>=1' {
+; CHECK-NEXT: Live-in vp<%0> = VF * UF
+; CHECK-NEXT: Live-in vp<%1> = vector-trip-count
+; CHECK-NEXT: Live-in ir<%n> = original trip-count
+; CHECK-EMPTY:
+; CHECK-NEXT: vector.ph:
+; CHECK-NEXT: Successor(s): vector loop
+; CHECK-EMPTY:
+; CHECK-NEXT: <x1> vector loop: {
+; CHECK-NEXT:   vector.body:
+; CHECK-NEXT:     EMIT vp<%2> = CANONICAL-INDUCTION ir<0>, vp<%index.next>
+; CHECK-NEXT:     WIDEN-REDUCTION-PHI ir<%r.09> = phi ir<0>, ir<%add>
+; CHECK-NEXT:     vp<%3> = SCALAR-STEPS vp<%2>, ir<1>
+; CHECK-NEXT:     CLONE ir<%arrayidx> = getelementptr inbounds ir<%x>, vp<%3>
+; CHECK-NEXT:     vp<%4> = vector-pointer ir<%arrayidx>
+; CHECK-NEXT:     WIDEN ir<%load0> = load vp<%4>
+; CHECK-NEXT:     CLONE ir<%arrayidx1> = getelementptr inbounds ir<%y>, vp<%3>
+; CHECK-NEXT:     vp<%5> = vector-pointer ir<%arrayidx1>
+; CHECK-NEXT:     WIDEN ir<%load1> = load vp<%5>
+; CHECK-NEXT:     MULACC-REDUCE ir<%add> = ir<%r.09> +  (reduce.add (mul (ir<%load0> extended to i32), (ir<%load1> extended to i32)) extended to i64)
+; CHECK-NEXT:     EMIT vp<%index.next> = add nuw vp<%2>, vp<%0>
+; CHECK-NEXT:     EMIT branch-on-count vp<%index.next>, vp<%1>
+; CHECK-NEXT:   No successors
+; CHECK-NEXT: }
+; CHECK-NEXT: Successor(s): middle.block
+; CHECK-EMPTY:
+; CHECK-NEXT: middle.block:
+; CHECK-NEXT:   EMIT vp<%7> = compute-reduction-result ir<%r.09>, ir<%add>
+; CHECK-NEXT:   EMIT vp<%8> = extract-from-end vp<%7>, ir<1>
+; CHECK-NEXT:   EMIT vp<%cmp.n> = icmp eq ir<%n>, vp<%1>
+; CHECK-NEXT:   EMIT branch-on-cond vp<%cmp.n>
+; CHECK-NEXT: Successor(s): ir-bb<for.cond.cleanup.loopexit>, scalar.ph
+; CHECK-EMPTY:
+; CHECK-NEXT: ir-bb<for.cond.cleanup.loopexit>:
+; CHECK-NEXT:   IR   %add.lcssa = phi i64 [ %add, %for.body ] (extra operand: vp<%8>)
+; CHECK-NEXT: No successors
+; CHECK-EMPTY:
+; CHECK-NEXT: scalar.ph:
+; CHECK-NEXT:   EMIT vp<%bc.merge.rdx> = resume-phi vp<%7>, ir<0>
+; CHECK-NEXT: Successor(s): ir-bb<for.body>
+; CHECK-EMPTY:
+; CHECK-NEXT: ir-bb<for.body>:
+; CHECK-NEXT:   IR   %i.010 = phi i32 [ %inc, %for.body ], [ 0, %for.body.preheader ]
+; CHECK-NEXT:   IR   %r.09 = phi i64 [ %add, %for.body ], [ 0, %for.body.preheader ] (extra operand: vp<%bc.merge.rdx>)
+; CHECK-NEXT:   IR   %arrayidx = getelementptr inbounds i16, ptr %x, i32 %i.010
+; CHECK-NEXT:   IR   %load0 = load i16, ptr %arrayidx, align 4
+; CHECK-NEXT:   IR   %arrayidx1 = getelementptr inbounds i16, ptr %y, i32 %i.010
+; CHECK-NEXT:   IR   %load1 = load i16, ptr %arrayidx1, align 4
+; CHECK-NEXT:   IR   %conv0 = sext i16 %load0 to i32
+; CHECK-NEXT:   IR   %conv1 = sext i16 %load1 to i32
+; CHECK-NEXT:   IR   %mul = mul nsw i32 %conv0, %conv1
+; CHECK-NEXT:   IR   %conv = sext i32 %mul to i64
+; CHECK-NEXT:   IR   %add = add nsw i64 %r.09, %conv
+; CHECK-NEXT:   IR   %inc = add nuw nsw i32 %i.010, 1
+; CHECK-NEXT:   IR   %exitcond = icmp eq i32 %inc, %n
+; CHECK-NEXT: No successors
+; CHECK-NEXT: }
+;
+entry:
+  %cmp8 = icmp sgt i32 %n, 0
+  br i1 %cmp8, label %for.body, label %for.cond.cleanup
+
+for.body:                                         ; preds = %entry, %for.body
+  %i.010 = phi i32 [ %inc, %for.body ], [ 0, %entry ]
+  %r.09 = phi i64 [ %add, %for.body ], [ 0, %entry ]
+  %arrayidx = getelementptr inbounds i16, ptr %x, i32 %i.010
+  %load0 = load i16, ptr %arrayidx, align 4
+  %arrayidx1 = getelementptr inbounds i16, ptr %y, i32 %i.010
+  %load1 = load i16, ptr %arrayidx1, align 4
+  %conv0 = sext i16 %load0 to i32
+  %conv1 = sext i16 %load1 to i32
+  %mul = mul nsw i32 %conv0, %conv1
+  %conv = sext i32 %mul to i64
+  %add = add nsw i64 %r.09, %conv
+  %inc = add nuw nsw i32 %i.010, 1
+  %exitcond = icmp eq i32 %inc, %n
+  br i1 %exitcond, label %for.cond.cleanup, label %for.body
+
+for.cond.cleanup:                                 ; preds = %for.body, %entry
+  %r.0.lcssa = phi i64 [ 0, %entry ], [ %add, %for.body ]
+  ret i64 %r.0.lcssa
+}
+
 !llvm.dbg.cu = !{!0}
 !llvm.module.flags = !{!3, !4}
 

>From b5445ca7e2b33b3485dee9a4e265f876ef6a2470 Mon Sep 17 00:00:00 2001
From: Elvis Wang <elvis.wang at sifive.com>
Date: Sun, 10 Nov 2024 18:00:52 -0800
Subject: [PATCH 07/26] Fold reduce.add(zext(mul(sext(A), sext(B)))) into
 MulAccRecipe when A == B

To prepare for a future refactor that avoids referencing underlying
instructions, and because the opcode and the extend instruction can be
mismatched in the newly added pattern, stop passing the UI when creating
VPWidenCastRecipe.
This removal can lead to duplicate extend instructions being created by the
loop vectorizer when two reduction patterns exist in the same loop.
These redundant instructions might be removed after LV.
---
 .../Transforms/Vectorize/LoopVectorize.cpp    | 31 ++++++++-----------
 .../Transforms/Vectorize/VPlanTransforms.cpp  | 15 +++++----
 .../LoopVectorize/ARM/mve-reductions.ll       |  3 +-
 .../LoopVectorize/reduction-inloop.ll         |  6 ++--
 4 files changed, 24 insertions(+), 31 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index b70f7f9dba2611..b59fe696c393ac 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -9404,20 +9404,18 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
         // reduce.add(mul(ext, ext)) can folded into VPMulAccRecipe
         if (match(VecOp, m_Mul(m_VPValue(A), m_VPValue(B))) &&
             !VecOp->hasMoreThanOneUniqueUser()) {
-          VPRecipeBase *RecipeA = A->getDefiningRecipe();
-          VPRecipeBase *RecipeB = B->getDefiningRecipe();
+          VPWidenCastRecipe *RecipeA =
+              dyn_cast_if_present<VPWidenCastRecipe>(A->getDefiningRecipe());
+          VPWidenCastRecipe *RecipeB =
+              dyn_cast_if_present<VPWidenCastRecipe>(B->getDefiningRecipe());
           if (RecipeA && RecipeB && match(RecipeA, m_ZExtOrSExt(m_VPValue())) &&
               match(RecipeB, m_ZExtOrSExt(m_VPValue())) &&
-              cast<VPWidenCastRecipe>(RecipeA)->getOpcode() ==
-                  cast<VPWidenCastRecipe>(RecipeB)->getOpcode() &&
-              !A->hasMoreThanOneUniqueUser() &&
-              !B->hasMoreThanOneUniqueUser()) {
+              (RecipeA->getOpcode() == RecipeB->getOpcode() || A == B)) {
             return new VPMulAccRecipe(
                 RdxDesc, CurrentLinkI, PreviousLink, CondOp,
                 CM.useOrderedReductions(RdxDesc),
-                cast<VPWidenRecipe>(VecOp->getDefiningRecipe()),
-                cast<VPWidenCastRecipe>(RecipeA),
-                cast<VPWidenCastRecipe>(RecipeB));
+                cast<VPWidenRecipe>(VecOp->getDefiningRecipe()), RecipeA,
+                RecipeB);
           } else {
             // Matched reduce.add(mul(...))
             return new VPMulAccRecipe(
@@ -9425,8 +9423,9 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
                 CM.useOrderedReductions(RdxDesc),
                 cast<VPWidenRecipe>(VecOp->getDefiningRecipe()));
           }
-          // Matched reduce.add(ext(mul(ext, ext)))
-          // Note that 3 extend instructions must have same opcode.
+          // Matched reduce.add(ext(mul(ext(A), ext(B))))
+          // Note that 3 extend instructions must have same opcode or A == B
+          // which can be transform to reduce.add(zext(mul(sext(A), sext(B)))).
         } else if (match(VecOp,
                          m_ZExtOrSExt(m_Mul(m_ZExtOrSExt(m_VPValue()),
                                             m_ZExtOrSExt(m_VPValue())))) &&
@@ -9439,11 +9438,8 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
               cast<VPWidenCastRecipe>(Mul->getOperand(0)->getDefiningRecipe());
           VPWidenCastRecipe *Ext1 =
               cast<VPWidenCastRecipe>(Mul->getOperand(1)->getDefiningRecipe());
-          if (Ext->getOpcode() == Ext0->getOpcode() &&
-              Ext0->getOpcode() == Ext1->getOpcode() &&
-              !Mul->hasMoreThanOneUniqueUser() &&
-              !Ext0->hasMoreThanOneUniqueUser() &&
-              !Ext1->hasMoreThanOneUniqueUser()) {
+          if ((Ext->getOpcode() == Ext0->getOpcode() || Ext0 == Ext1) &&
+              Ext0->getOpcode() == Ext1->getOpcode()) {
             return new VPMulAccRecipe(
                 RdxDesc, CurrentLinkI, PreviousLink, CondOp,
                 CM.useOrderedReductions(RdxDesc),
@@ -9455,8 +9451,7 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
       };
       auto TryToMatchExtendedReduction = [&]() -> VPSingleDefRecipe * {
         VPValue *A;
-        if (match(VecOp, m_ZExtOrSExt(m_VPValue(A))) &&
-            !VecOp->hasMoreThanOneUniqueUser()) {
+        if (match(VecOp, m_ZExtOrSExt(m_VPValue(A)))) {
           return new VPExtendedReductionRecipe(
               RdxDesc, CurrentLinkI, PreviousLink,
               cast<VPWidenCastRecipe>(VecOp), CondOp,
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index 423694cdd84e05..d3f0b5c9e5ba6b 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -542,14 +542,14 @@ void VPlanTransforms::prepareExecute(VPlan &Plan) {
         auto *MulAcc = cast<VPMulAccRecipe>(&R);
         VPValue *Op0, *Op1;
         if (MulAcc->isExtended()) {
-          Op0 = new VPWidenCastRecipe(
-              MulAcc->getExtOpcode(), MulAcc->getVecOp0(),
-              MulAcc->getResultType(), *MulAcc->getExt0Instr());
+          Op0 =
+              new VPWidenCastRecipe(MulAcc->getExtOpcode(), MulAcc->getVecOp0(),
+                                    MulAcc->getResultType());
           Op0->getDefiningRecipe()->insertBefore(MulAcc);
           if (!MulAcc->isSameExtend()) {
-            Op1 = new VPWidenCastRecipe(
-                MulAcc->getExtOpcode(), MulAcc->getVecOp1(),
-                MulAcc->getResultType(), *MulAcc->getExt1Instr());
+            Op1 = new VPWidenCastRecipe(MulAcc->getExtOpcode(),
+                                        MulAcc->getVecOp1(),
+                                        MulAcc->getResultType());
             Op1->getDefiningRecipe()->insertBefore(MulAcc);
           } else {
             Op1 = Op0;
@@ -567,8 +567,7 @@ void VPlanTransforms::prepareExecute(VPlan &Plan) {
         if (auto *OuterExtInstr = MulAcc->getExtInstr())
           VecOp = new VPWidenCastRecipe(
               MulAcc->getExtOpcode(), Mul,
-              MulAcc->getRecurrenceDescriptor().getRecurrenceType(),
-              *OuterExtInstr);
+              MulAcc->getRecurrenceDescriptor().getRecurrenceType());
         else
           VecOp = Mul;
         auto *Red = new VPReductionRecipe(
diff --git a/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll b/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
index 9b7c345076480a..b138666ba92f9b 100644
--- a/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
+++ b/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
@@ -1535,7 +1535,8 @@ define i64 @mla_and_add_together_16_64(ptr nocapture noundef readonly %x, i32 no
 ; CHECK-NEXT:    [[TMP3:%.*]] = zext nneg <8 x i32> [[TMP2]] to <8 x i64>
 ; CHECK-NEXT:    [[TMP4:%.*]] = call i64 @llvm.vector.reduce.add.v8i64(<8 x i64> [[TMP3]])
 ; CHECK-NEXT:    [[TMP5]] = add i64 [[TMP4]], [[VEC_PHI]]
-; CHECK-NEXT:    [[TMP6:%.*]] = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> [[TMP1]])
+; CHECK-NEXT:    [[TMP10:%.*]] = sext <8 x i16> [[WIDE_LOAD]] to <8 x i32>
+; CHECK-NEXT:    [[TMP6:%.*]] = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> [[TMP10]])
 ; CHECK-NEXT:    [[TMP7]] = add i32 [[TMP6]], [[VEC_PHI1]]
 ; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 8
 ; CHECK-NEXT:    [[TMP8:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]
diff --git a/llvm/test/Transforms/LoopVectorize/reduction-inloop.ll b/llvm/test/Transforms/LoopVectorize/reduction-inloop.ll
index b578e61d85dfa1..d6dbb74f26d4ab 100644
--- a/llvm/test/Transforms/LoopVectorize/reduction-inloop.ll
+++ b/llvm/test/Transforms/LoopVectorize/reduction-inloop.ll
@@ -1206,15 +1206,13 @@ define i32 @predicated_not_dominates_reduction_twoadd(ptr nocapture noundef read
 ; CHECK-NEXT:    [[TMP0:%.*]] = sext i32 [[INDEX]] to i64
 ; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds i8, ptr [[H:%.*]], i64 [[TMP0]]
 ; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <4 x i8>, ptr [[TMP1]], align 1
-; CHECK-NEXT:    [[DOTNOT:%.*]] = icmp eq <4 x i8> [[WIDE_LOAD]], zeroinitializer
 ; CHECK-NEXT:    [[TMP2:%.*]] = udiv <4 x i8> [[WIDE_LOAD]], splat (i8 31)
 ; CHECK-NEXT:    [[TMP3:%.*]] = shl nuw nsw <4 x i8> [[TMP2]], splat (i8 3)
 ; CHECK-NEXT:    [[TMP4:%.*]] = udiv <4 x i8> [[TMP3]], splat (i8 31)
 ; CHECK-NEXT:    [[TMP5:%.*]] = zext nneg <4 x i8> [[TMP4]] to <4 x i32>
-; CHECK-NEXT:    [[TMP6:%.*]] = select <4 x i1> [[DOTNOT]], <4 x i32> zeroinitializer, <4 x i32> [[TMP5]]
-; CHECK-NEXT:    [[TMP7:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP6]])
+; CHECK-NEXT:    [[TMP7:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP5]])
 ; CHECK-NEXT:    [[TMP8:%.*]] = add i32 [[TMP7]], [[VEC_PHI]]
-; CHECK-NEXT:    [[TMP9:%.*]] = select <4 x i1> [[DOTNOT]], <4 x i32> zeroinitializer, <4 x i32> [[TMP5]]
+; CHECK-NEXT:    [[TMP9:%.*]] = zext nneg <4 x i8> [[TMP4]] to <4 x i32>
 ; CHECK-NEXT:    [[TMP10:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP9]])
 ; CHECK-NEXT:    [[TMP11]] = add i32 [[TMP10]], [[TMP8]]
 ; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 4

>From 1df91d490fef8cf29a724110bd59f17b1075ab84 Mon Sep 17 00:00:00 2001
From: Elvis Wang <elvis.wang at sifive.com>
Date: Sun, 10 Nov 2024 23:34:05 -0800
Subject: [PATCH 08/26] Refactor! Reuse functions from VPReductionRecipe.

---
 llvm/lib/Transforms/Vectorize/VPlan.h         | 114 +++++-------------
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp |   6 +-
 .../Transforms/Vectorize/VPlanTransforms.cpp  |  20 +--
 3 files changed, 47 insertions(+), 93 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 504021d0f621d3..2e8786cd569205 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -684,8 +684,6 @@ struct VPCostContext {
   LLVMContext &LLVMCtx;
   LoopVectorizationCostModel &CM;
   SmallPtrSet<Instruction *, 8> SkipCostComputation;
-  /// Contains recipes that are folded into other recipes.
-  SmallDenseMap<ElementCount, SmallPtrSet<VPRecipeBase *, 4>, 4> FoldedRecipes;
 
   VPCostContext(const TargetTransformInfo &TTI, const TargetLibraryInfo &TLI,
                 Type *CanIVTy, LoopVectorizationCostModel &CM)
@@ -2616,6 +2614,8 @@ class VPReductionRecipe : public VPSingleDefRecipe {
                                  getVecOp(), getCondOp(), IsOrdered);
   }
 
+  // TODO: Support VPExtendedReductionRecipe and VPMulAccRecipe after EVL
+  // support.
   static inline bool classof(const VPRecipeBase *R) {
     return R->getVPDefID() == VPRecipeBase::VPReductionSC ||
            R->getVPDefID() == VPRecipeBase::VPReductionEVLSC;
@@ -2703,33 +2703,20 @@ class VPReductionEVLRecipe : public VPReductionRecipe {
 /// a chain. This recipe is high level abstract which will generate
 /// VPReductionRecipe and VPWidenCastRecipe before execution. The Operands are
 /// {ChainOp, VecOp, [Condition]}.
-class VPExtendedReductionRecipe : public VPSingleDefRecipe {
-  /// The recurrence decriptor for the reduction in question.
-  const RecurrenceDescriptor &RdxDesc;
-  bool IsOrdered;
-  /// Whether the reduction is conditional.
-  bool IsConditional = false;
+class VPExtendedReductionRecipe : public VPReductionRecipe {
   /// Type after extend.
   Type *ResultTy;
-  /// Opcode for the extend instruction.
-  Instruction::CastOps ExtOp;
   CastInst *ExtInstr;
-  bool IsZExt;
 
 protected:
   VPExtendedReductionRecipe(const unsigned char SC,
                             const RecurrenceDescriptor &R, Instruction *RedI,
-                            Instruction::CastOps ExtOp, CastInst *ExtI,
-                            ArrayRef<VPValue *> Operands, VPValue *CondOp,
+                            Instruction::CastOps ExtOp, CastInst *ExtInstr,
+                            VPValue *ChainOp, VPValue *VecOp, VPValue *CondOp,
                             bool IsOrdered, Type *ResultTy)
-      : VPSingleDefRecipe(SC, Operands, RedI), RdxDesc(R), IsOrdered(IsOrdered),
-        ResultTy(ResultTy), ExtOp(ExtOp), ExtInstr(ExtI) {
-    if (CondOp) {
-      IsConditional = true;
-      addOperand(CondOp);
-    }
-    IsZExt = ExtOp == Instruction::CastOps::ZExt;
-  }
+      : VPReductionRecipe(SC, R, RedI, ArrayRef<VPValue *>({ChainOp, VecOp}),
+                          CondOp, IsOrdered),
+        ResultTy(ResultTy), ExtInstr(ExtInstr) {}
 
 public:
   VPExtendedReductionRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
@@ -2737,9 +2724,8 @@ class VPExtendedReductionRecipe : public VPSingleDefRecipe {
                             VPValue *CondOp, bool IsOrdered)
       : VPExtendedReductionRecipe(
             VPDef::VPExtendedReductionSC, R, RedI, Ext->getOpcode(),
-            cast<CastInst>(Ext->getUnderlyingInstr()),
-            ArrayRef<VPValue *>({ChainOp, Ext->getOperand(0)}), CondOp,
-            IsOrdered, Ext->getResultType()) {}
+            cast<CastInst>(Ext->getUnderlyingInstr()), ChainOp,
+            Ext->getOperand(0), CondOp, IsOrdered, Ext->getResultType()) {}
 
   ~VPExtendedReductionRecipe() override = default;
 
@@ -2771,26 +2757,11 @@ class VPExtendedReductionRecipe : public VPSingleDefRecipe {
              VPSlotTracker &SlotTracker) const override;
 #endif
 
-  /// Return the recurrence decriptor for the in-loop reduction.
-  const RecurrenceDescriptor &getRecurrenceDescriptor() const {
-    return RdxDesc;
-  }
-  /// Return true if the in-loop reduction is ordered.
-  bool isOrdered() const { return IsOrdered; };
-  /// Return true if the in-loop reduction is conditional.
-  bool isConditional() const { return IsConditional; };
-  /// The VPValue of the scalar Chain being accumulated.
-  VPValue *getChainOp() const { return getOperand(0); }
-  /// The VPValue of the vector value to be extended and reduced.
-  VPValue *getVecOp() const { return getOperand(1); }
-  /// The VPValue of the condition for the block.
-  VPValue *getCondOp() const {
-    return isConditional() ? getOperand(getNumOperands() - 1) : nullptr;
-  }
   /// The Type after extended.
   Type *getResultType() const { return ResultTy; };
+  bool isZExt() const { return getExtOpcode() == Instruction::ZExt; };
   /// The Opcode of extend instruction.
-  Instruction::CastOps getExtOpcode() const { return ExtOp; };
+  Instruction::CastOps getExtOpcode() const { return ExtInstr->getOpcode(); };
   /// The CastInst of the extend instruction.
   CastInst *getExtInstr() const { return ExtInstr; };
 };
@@ -2800,12 +2771,7 @@ class VPExtendedReductionRecipe : public VPSingleDefRecipe {
 /// a chain. This recipe is high level abstract which will generate
 /// VPReductionRecipe VPWidenRecipe(mul) and VPWidenCastRecipes before
 /// execution. The Operands are {ChainOp, VecOp1, VecOp2, [Condition]}.
-class VPMulAccRecipe : public VPSingleDefRecipe {
-  /// The recurrence decriptor for the reduction in question.
-  const RecurrenceDescriptor &RdxDesc;
-  bool IsOrdered;
-  /// Whether the reduction is conditional.
-  bool IsConditional = false;
+class VPMulAccRecipe : public VPReductionRecipe {
   /// Type after extend.
   Type *ResultType;
   // Note that all extend instruction must have the same opcode in MulAcc.
@@ -2827,32 +2793,29 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
                  Instruction *RedI, Instruction *ExtInstr,
                  Instruction *MulInstr, Instruction::CastOps ExtOp,
                  Instruction *Ext0Instr, Instruction *Ext1Instr,
-                 ArrayRef<VPValue *> Operands, VPValue *CondOp, bool IsOrdered,
-                 Type *ResultType)
-      : VPSingleDefRecipe(SC, Operands, RedI), RdxDesc(R), IsOrdered(IsOrdered),
+                 VPValue *ChainOp, VPValue *VecOp0, VPValue *VecOp1,
+                 VPValue *CondOp, bool IsOrdered, Type *ResultType)
+      : VPReductionRecipe(SC, R, RedI,
+                          ArrayRef<VPValue *>({ChainOp, VecOp0, VecOp1}),
+                          CondOp, IsOrdered),
         ResultType(ResultType), ExtOp(ExtOp), MulInstr(MulInstr),
         ExtInstr(cast_if_present<CastInst>(ExtInstr)),
         Ext0Instr(cast<CastInst>(Ext0Instr)),
         Ext1Instr(cast<CastInst>(Ext1Instr)) {
     assert(MulInstr->getOpcode() == Instruction::Mul);
-    if (CondOp) {
-      IsConditional = true;
-      addOperand(CondOp);
-    }
     IsExtended = true;
     IsOuterExtended = ExtInstr != nullptr;
   }
 
   VPMulAccRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
-                 Instruction *RedI, Instruction *MulInstr,
-                 ArrayRef<VPValue *> Operands, VPValue *CondOp, bool IsOrdered)
-      : VPSingleDefRecipe(SC, Operands, RedI), RdxDesc(R), IsOrdered(IsOrdered),
+                 Instruction *RedI, Instruction *MulInstr, VPValue *ChainOp,
+                 VPValue *VecOp0, VPValue *VecOp1, VPValue *CondOp,
+                 bool IsOrdered)
+      : VPReductionRecipe(SC, R, RedI,
+                          ArrayRef<VPValue *>({ChainOp, VecOp0, VecOp1}),
+                          CondOp, IsOrdered),
         MulInstr(MulInstr) {
     assert(MulInstr->getOpcode() == Instruction::Mul);
-    if (CondOp) {
-      IsConditional = true;
-      addOperand(CondOp);
-    }
     IsExtended = false;
   }
 
@@ -2864,17 +2827,15 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
       : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, nullptr,
                        Mul->getUnderlyingInstr(), Ext0->getOpcode(),
                        Ext0->getUnderlyingInstr(), Ext1->getUnderlyingInstr(),
-                       ArrayRef<VPValue *>(
-                           {ChainOp, Ext0->getOperand(0), Ext1->getOperand(0)}),
+                       ChainOp, Ext0->getOperand(0), Ext1->getOperand(0),
                        CondOp, IsOrdered, Ext0->getResultType()) {}
 
   VPMulAccRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
                  VPValue *ChainOp, VPValue *CondOp, bool IsOrdered,
                  VPWidenRecipe *Mul)
       : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, Mul->getUnderlyingInstr(),
-                       ArrayRef<VPValue *>(
-                           {ChainOp, Mul->getOperand(0), Mul->getOperand(1)}),
-                       CondOp, IsOrdered) {}
+                       ChainOp, Mul->getOperand(0), Mul->getOperand(1), CondOp,
+                       IsOrdered) {}
 
   VPMulAccRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
                  VPValue *ChainOp, VPValue *CondOp, bool IsOrdered,
@@ -2883,8 +2844,7 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
       : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, Ext->getUnderlyingInstr(),
                        Mul->getUnderlyingInstr(), Ext0->getOpcode(),
                        Ext0->getUnderlyingInstr(), Ext1->getUnderlyingInstr(),
-                       ArrayRef<VPValue *>(
-                           {ChainOp, Ext0->getOperand(0), Ext1->getOperand(0)}),
+                       ChainOp, Ext0->getOperand(0), Ext1->getOperand(0),
                        CondOp, IsOrdered, Ext0->getResultType()) {}
 
   ~VPMulAccRecipe() override = default;
@@ -2915,37 +2875,27 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
              VPSlotTracker &SlotTracker) const override;
 #endif
 
-  /// Return the recurrence decriptor for the in-loop reduction.
-  const RecurrenceDescriptor &getRecurrenceDescriptor() const {
-    return RdxDesc;
-  }
-  /// Return true if the in-loop reduction is ordered.
-  bool isOrdered() const { return IsOrdered; };
-  /// Return true if the in-loop reduction is conditional.
-  bool isConditional() const { return IsConditional; };
-  /// The VPValue of the scalar Chain being accumulated.
-  VPValue *getChainOp() const { return getOperand(0); }
   /// The VPValue of the vector value to be extended and reduced.
   VPValue *getVecOp0() const { return getOperand(1); }
   VPValue *getVecOp1() const { return getOperand(2); }
-  /// The VPValue of the condition for the block.
-  VPValue *getCondOp() const {
-    return isConditional() ? getOperand(getNumOperands() - 1) : nullptr;
-  }
+
   /// Return the type after inner extended, which must equal to the type of mul
   /// instruction. If the ResultType != recurrenceType, than it must have a
   /// extend recipe after mul recipe.
   Type *getResultType() const { return ResultType; };
+
   /// The opcode of the extend instructions.
   Instruction::CastOps getExtOpcode() const { return ExtOp; };
   /// The underlying instruction for VPWidenRecipe.
   Instruction *getMulInstr() const { return MulInstr; };
+
   /// The underlying Instruction for outer VPWidenCastRecipe.
   CastInst *getExtInstr() const { return ExtInstr; };
   /// The underlying Instruction for inner VPWidenCastRecipe.
   CastInst *getExt0Instr() const { return Ext0Instr; };
   /// The underlying Instruction for inner VPWidenCastRecipe.
   CastInst *getExt1Instr() const { return Ext1Instr; };
+
   /// Return if this MulAcc recipe contains extend instructions.
   bool isExtended() const { return IsExtended; };
   /// Return if the operands of mul instruction come from same extend.
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index a85a4bc81c2854..c7c50702f90b44 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -2221,6 +2221,7 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
 InstructionCost
 VPExtendedReductionRecipe::computeCost(ElementCount VF,
                                        VPCostContext &Ctx) const {
+  const RecurrenceDescriptor &RdxDesc = getRecurrenceDescriptor();
   RecurKind RdxKind = RdxDesc.getRecurrenceKind();
   Type *ElementTy = getResultType();
   auto *VectorTy = cast<VectorType>(ToVectorTy(ElementTy, VF));
@@ -2251,7 +2252,7 @@ VPExtendedReductionRecipe::computeCost(ElementCount VF,
 
   // ExtendedReduction Cost
   InstructionCost ExtendedRedCost = Ctx.TTI.getExtendedReductionCost(
-      Opcode, IsZExt, ElementTy, SrcTy, RdxDesc.getFastMathFlags(), CostKind);
+      Opcode, isZExt(), ElementTy, SrcTy, RdxDesc.getFastMathFlags(), CostKind);
   // Check if folding ext into ExtendedReduction is profitable.
   if (ExtendedRedCost.isValid() &&
       ExtendedRedCost < ExtendedCost + ReductionCost) {
@@ -2262,6 +2263,7 @@ VPExtendedReductionRecipe::computeCost(ElementCount VF,
 
 InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
                                             VPCostContext &Ctx) const {
+  const RecurrenceDescriptor &RdxDesc = getRecurrenceDescriptor();
   Type *ElementTy = IsExtended ? RdxDesc.getRecurrenceType()
                                : Ctx.Types.inferScalarType(getVecOp0());
   auto *VectorTy = cast<VectorType>(ToVectorTy(ElementTy, VF));
@@ -2387,6 +2389,7 @@ void VPReductionEVLRecipe::print(raw_ostream &O, const Twine &Indent,
 
 void VPExtendedReductionRecipe::print(raw_ostream &O, const Twine &Indent,
                                       VPSlotTracker &SlotTracker) const {
+  const RecurrenceDescriptor &RdxDesc = getRecurrenceDescriptor();
   O << Indent << "EXTENDED-REDUCE ";
   printAsOperand(O, SlotTracker);
   O << " = ";
@@ -2409,6 +2412,7 @@ void VPExtendedReductionRecipe::print(raw_ostream &O, const Twine &Indent,
 
 void VPMulAccRecipe::print(raw_ostream &O, const Twine &Indent,
                            VPSlotTracker &SlotTracker) const {
+  const RecurrenceDescriptor &RdxDesc = getRecurrenceDescriptor();
   O << Indent << "MULACC-REDUCE ";
   printAsOperand(O, SlotTracker);
   O << " = ";
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index d3f0b5c9e5ba6b..5e9c30911e34e5 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -525,8 +525,7 @@ void VPlanTransforms::prepareExecute(VPlan &Plan) {
   for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>(
            vp_depth_first_deep(Plan.getEntry()))) {
     for (VPRecipeBase &R : make_early_inc_range(*VPBB)) {
-      if (isa<VPExtendedReductionRecipe>(&R)) {
-        auto *ExtRed = cast<VPExtendedReductionRecipe>(&R);
+      if (auto *ExtRed = dyn_cast<VPExtendedReductionRecipe>(&R)) {
         auto *Ext = new VPWidenCastRecipe(
             ExtRed->getExtOpcode(), ExtRed->getVecOp(), ExtRed->getResultType(),
             *ExtRed->getExtInstr());
@@ -542,14 +541,14 @@ void VPlanTransforms::prepareExecute(VPlan &Plan) {
         auto *MulAcc = cast<VPMulAccRecipe>(&R);
         VPValue *Op0, *Op1;
         if (MulAcc->isExtended()) {
-          Op0 =
-              new VPWidenCastRecipe(MulAcc->getExtOpcode(), MulAcc->getVecOp0(),
-                                    MulAcc->getResultType());
+          CastInst *Ext0 = MulAcc->getExt0Instr();
+          Op0 = new VPWidenCastRecipe(Ext0->getOpcode(), MulAcc->getVecOp0(),
+                                      MulAcc->getResultType(), *Ext0);
           Op0->getDefiningRecipe()->insertBefore(MulAcc);
           if (!MulAcc->isSameExtend()) {
-            Op1 = new VPWidenCastRecipe(MulAcc->getExtOpcode(),
-                                        MulAcc->getVecOp1(),
-                                        MulAcc->getResultType());
+            CastInst *Ext1 = MulAcc->getExt1Instr();
+            Op1 = new VPWidenCastRecipe(Ext1->getOpcode(), MulAcc->getVecOp1(),
+                                        MulAcc->getResultType(), *Ext1);
             Op1->getDefiningRecipe()->insertBefore(MulAcc);
           } else {
             Op1 = Op0;
@@ -566,8 +565,9 @@ void VPlanTransforms::prepareExecute(VPlan &Plan) {
         // Outer extend.
         if (auto *OuterExtInstr = MulAcc->getExtInstr())
           VecOp = new VPWidenCastRecipe(
-              MulAcc->getExtOpcode(), Mul,
-              MulAcc->getRecurrenceDescriptor().getRecurrenceType());
+              OuterExtInstr->getOpcode(), Mul,
+              MulAcc->getRecurrenceDescriptor().getRecurrenceType(),
+              *OuterExtInstr);
         else
           VecOp = Mul;
         auto *Red = new VPReductionRecipe(

>From a0b2f30cd6596e687e68e7ffe5a3407e35a77f05 Mon Sep 17 00:00:00 2001
From: Elvis Wang <elvis.wang at sifive.com>
Date: Mon, 11 Nov 2024 23:17:46 -0800
Subject: [PATCH 09/26] Refactor! Add comments and refine new recipes.

---
 .../Transforms/Vectorize/LoopVectorize.cpp    | 20 +++---
 llvm/lib/Transforms/Vectorize/VPlan.h         | 72 +++++++++----------
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp | 28 ++++----
 .../Transforms/Vectorize/VPlanTransforms.cpp  | 35 ++++++---
 4 files changed, 85 insertions(+), 70 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index b59fe696c393ac..5cee4b7e2ce70e 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -9397,17 +9397,18 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
       if (CM.blockNeedsPredicationForAnyReason(BB))
         CondOp = RecipeBuilder.getBlockInMask(BB);
 
-      auto TryToMatchMulAcc = [&]() -> VPSingleDefRecipe * {
+      auto TryToMatchMulAcc = [&]() -> VPReductionRecipe * {
         VPValue *A, *B;
         if (RdxDesc.getOpcode() != Instruction::Add)
           return nullptr;
-        // reduce.add(mul(ext, ext)) can folded into VPMulAccRecipe
+        // Try to match reduce.add(mul(...))
         if (match(VecOp, m_Mul(m_VPValue(A), m_VPValue(B))) &&
             !VecOp->hasMoreThanOneUniqueUser()) {
           VPWidenCastRecipe *RecipeA =
               dyn_cast_if_present<VPWidenCastRecipe>(A->getDefiningRecipe());
           VPWidenCastRecipe *RecipeB =
               dyn_cast_if_present<VPWidenCastRecipe>(B->getDefiningRecipe());
+          // Matched reduce.add(mul(ext, ext))
           if (RecipeA && RecipeB && match(RecipeA, m_ZExtOrSExt(m_VPValue())) &&
               match(RecipeB, m_ZExtOrSExt(m_VPValue())) &&
               (RecipeA->getOpcode() == RecipeB->getOpcode() || A == B)) {
@@ -9417,23 +9418,23 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
                 cast<VPWidenRecipe>(VecOp->getDefiningRecipe()), RecipeA,
                 RecipeB);
           } else {
-            // Matched reduce.add(mul(...))
+            // Matched reduce.add(mul)
             return new VPMulAccRecipe(
                 RdxDesc, CurrentLinkI, PreviousLink, CondOp,
                 CM.useOrderedReductions(RdxDesc),
                 cast<VPWidenRecipe>(VecOp->getDefiningRecipe()));
           }
           // Matched reduce.add(ext(mul(ext(A), ext(B))))
-          // Note that 3 extend instructions must have same opcode or A == B
+          // Note that all extend instructions must have same opcode or A == B
           // which can be transform to reduce.add(zext(mul(sext(A), sext(B)))).
         } else if (match(VecOp,
                          m_ZExtOrSExt(m_Mul(m_ZExtOrSExt(m_VPValue()),
                                             m_ZExtOrSExt(m_VPValue())))) &&
                    !VecOp->hasMoreThanOneUniqueUser()) {
           VPWidenCastRecipe *Ext =
-              dyn_cast<VPWidenCastRecipe>(VecOp->getDefiningRecipe());
+              cast<VPWidenCastRecipe>(VecOp->getDefiningRecipe());
           VPWidenRecipe *Mul =
-              dyn_cast<VPWidenRecipe>(Ext->getOperand(0)->getDefiningRecipe());
+              cast<VPWidenRecipe>(Ext->getOperand(0)->getDefiningRecipe());
           VPWidenCastRecipe *Ext0 =
               cast<VPWidenCastRecipe>(Mul->getOperand(0)->getDefiningRecipe());
           VPWidenCastRecipe *Ext1 =
@@ -9449,8 +9450,10 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
         }
         return nullptr;
       };
-      auto TryToMatchExtendedReduction = [&]() -> VPSingleDefRecipe * {
+
+      auto TryToMatchExtendedReduction = [&]() -> VPReductionRecipe * {
         VPValue *A;
+        // Try to match reduce(ext(...)).
         if (match(VecOp, m_ZExtOrSExt(m_VPValue(A)))) {
           return new VPExtendedReductionRecipe(
               RdxDesc, CurrentLinkI, PreviousLink,
@@ -9459,7 +9462,8 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
         }
         return nullptr;
       };
-      VPSingleDefRecipe *RedRecipe;
+
+      VPReductionRecipe *RedRecipe;
       if (auto *MulAcc = TryToMatchMulAcc())
         RedRecipe = MulAcc;
       else if (auto *ExtendedRed = TryToMatchExtendedReduction())
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 2e8786cd569205..ab9e6f0393602f 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -2758,12 +2758,12 @@ class VPExtendedReductionRecipe : public VPReductionRecipe {
 #endif
 
   /// The Type after extended.
-  Type *getResultType() const { return ResultTy; };
-  bool isZExt() const { return getExtOpcode() == Instruction::ZExt; };
+  Type *getResultType() const { return ResultTy; }
+  bool isZExt() const { return getExtOpcode() == Instruction::ZExt; }
   /// The Opcode of extend instruction.
-  Instruction::CastOps getExtOpcode() const { return ExtInstr->getOpcode(); };
+  Instruction::CastOps getExtOpcode() const { return ExtInstr->getOpcode(); }
   /// The CastInst of the extend instruction.
-  CastInst *getExtInstr() const { return ExtInstr; };
+  CastInst *getExtInstr() const { return ExtInstr; }
 };
 
 /// A recipe to represent inloop MulAccreduction operations, performing a
@@ -2774,8 +2774,6 @@ class VPExtendedReductionRecipe : public VPReductionRecipe {
 class VPMulAccRecipe : public VPReductionRecipe {
   /// Type after extend.
   Type *ResultType;
-  // Note that all extend instruction must have the same opcode in MulAcc.
-  Instruction::CastOps ExtOp;
 
   /// reduce.add(ext(mul(ext0(), ext1())))
   Instruction *MulInstr;
@@ -2783,28 +2781,21 @@ class VPMulAccRecipe : public VPReductionRecipe {
   CastInst *Ext0Instr = nullptr;
   CastInst *Ext1Instr = nullptr;
 
-  /// Is this MulAcc recipe contains extend recipes?
-  bool IsExtended;
-  /// Is this reciep contains outer extend instuction?
-  bool IsOuterExtended = false;
-
 protected:
   VPMulAccRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
                  Instruction *RedI, Instruction *ExtInstr,
-                 Instruction *MulInstr, Instruction::CastOps ExtOp,
-                 Instruction *Ext0Instr, Instruction *Ext1Instr,
-                 VPValue *ChainOp, VPValue *VecOp0, VPValue *VecOp1,
-                 VPValue *CondOp, bool IsOrdered, Type *ResultType)
+                 Instruction *MulInstr, Instruction *Ext0Instr,
+                 Instruction *Ext1Instr, VPValue *ChainOp, VPValue *VecOp0,
+                 VPValue *VecOp1, VPValue *CondOp, bool IsOrdered,
+                 Type *ResultType)
       : VPReductionRecipe(SC, R, RedI,
                           ArrayRef<VPValue *>({ChainOp, VecOp0, VecOp1}),
                           CondOp, IsOrdered),
-        ResultType(ResultType), ExtOp(ExtOp), MulInstr(MulInstr),
+        ResultType(ResultType), MulInstr(MulInstr),
         ExtInstr(cast_if_present<CastInst>(ExtInstr)),
         Ext0Instr(cast<CastInst>(Ext0Instr)),
         Ext1Instr(cast<CastInst>(Ext1Instr)) {
     assert(MulInstr->getOpcode() == Instruction::Mul);
-    IsExtended = true;
-    IsOuterExtended = ExtInstr != nullptr;
   }
 
   VPMulAccRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
@@ -2816,7 +2807,6 @@ class VPMulAccRecipe : public VPReductionRecipe {
                           CondOp, IsOrdered),
         MulInstr(MulInstr) {
     assert(MulInstr->getOpcode() == Instruction::Mul);
-    IsExtended = false;
   }
 
 public:
@@ -2825,10 +2815,10 @@ class VPMulAccRecipe : public VPReductionRecipe {
                  VPWidenRecipe *Mul, VPWidenCastRecipe *Ext0,
                  VPWidenCastRecipe *Ext1)
       : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, nullptr,
-                       Mul->getUnderlyingInstr(), Ext0->getOpcode(),
-                       Ext0->getUnderlyingInstr(), Ext1->getUnderlyingInstr(),
-                       ChainOp, Ext0->getOperand(0), Ext1->getOperand(0),
-                       CondOp, IsOrdered, Ext0->getResultType()) {}
+                       Mul->getUnderlyingInstr(), Ext0->getUnderlyingInstr(),
+                       Ext1->getUnderlyingInstr(), ChainOp, Ext0->getOperand(0),
+                       Ext1->getOperand(0), CondOp, IsOrdered,
+                       Ext0->getResultType()) {}
 
   VPMulAccRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
                  VPValue *ChainOp, VPValue *CondOp, bool IsOrdered,
@@ -2842,10 +2832,10 @@ class VPMulAccRecipe : public VPReductionRecipe {
                  VPWidenCastRecipe *Ext, VPWidenRecipe *Mul,
                  VPWidenCastRecipe *Ext0, VPWidenCastRecipe *Ext1)
       : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, Ext->getUnderlyingInstr(),
-                       Mul->getUnderlyingInstr(), Ext0->getOpcode(),
-                       Ext0->getUnderlyingInstr(), Ext1->getUnderlyingInstr(),
-                       ChainOp, Ext0->getOperand(0), Ext1->getOperand(0),
-                       CondOp, IsOrdered, Ext0->getResultType()) {}
+                       Mul->getUnderlyingInstr(), Ext0->getUnderlyingInstr(),
+                       Ext1->getUnderlyingInstr(), ChainOp, Ext0->getOperand(0),
+                       Ext1->getOperand(0), CondOp, IsOrdered,
+                       Ext0->getResultType()) {}
 
   ~VPMulAccRecipe() override = default;
 
@@ -2882,24 +2872,34 @@ class VPMulAccRecipe : public VPReductionRecipe {
   /// Return the type after inner extended, which must equal to the type of mul
   /// instruction. If the ResultType != recurrenceType, than it must have a
   /// extend recipe after mul recipe.
-  Type *getResultType() const { return ResultType; };
+  Type *getResultType() const { return ResultType; }
 
-  /// The opcode of the extend instructions.
-  Instruction::CastOps getExtOpcode() const { return ExtOp; };
   /// The underlying instruction for VPWidenRecipe.
-  Instruction *getMulInstr() const { return MulInstr; };
+  Instruction *getMulInstr() const { return MulInstr; }
 
   /// The underlying Instruction for outer VPWidenCastRecipe.
-  CastInst *getExtInstr() const { return ExtInstr; };
+  CastInst *getExtInstr() const { return ExtInstr; }
   /// The underlying Instruction for inner VPWidenCastRecipe.
-  CastInst *getExt0Instr() const { return Ext0Instr; };
+  CastInst *getExt0Instr() const { return Ext0Instr; }
   /// The underlying Instruction for inner VPWidenCastRecipe.
-  CastInst *getExt1Instr() const { return Ext1Instr; };
+  CastInst *getExt1Instr() const { return Ext1Instr; }
 
   /// Return if this MulAcc recipe contains extend instructions.
-  bool isExtended() const { return IsExtended; };
+  bool isExtended() const { return Ext0Instr && Ext1Instr; }
   /// Return if the operands of mul instruction come from same extend.
-  bool isSameExtend() const { return Ext0Instr == Ext1Instr; };
+  bool isSameExtend() const { return Ext0Instr == Ext1Instr; }
+  /// Return true if this MulAcc recipe contains an extend after the mul.
+  bool isOuterExtended() const { return ExtInstr != nullptr; }
+  /// Return if the extend opcode is ZExt.
+  bool isZExt() const {
+    if (!isExtended())
+      return true;
+    // reduce.add(sext(mul(zext(A), zext(A)))) can be transformed to
+    // reduce.add(zext(mul(sext(A), sext(A))))
+    if (ExtInstr && ExtInstr->getOpcode() != Ext0Instr->getOpcode())
+      return true;
+    return Ext0Instr->getOpcode() == Instruction::CastOps::ZExt;
+  }
 };
 
 /// VPReplicateRecipe replicates a given instruction producing multiple scalar
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index c7c50702f90b44..5aef2089ac37d8 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -2264,8 +2264,8 @@ VPExtendedReductionRecipe::computeCost(ElementCount VF,
 InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
                                             VPCostContext &Ctx) const {
   const RecurrenceDescriptor &RdxDesc = getRecurrenceDescriptor();
-  Type *ElementTy = IsExtended ? RdxDesc.getRecurrenceType()
-                               : Ctx.Types.inferScalarType(getVecOp0());
+  Type *ElementTy = isExtended() ? RdxDesc.getRecurrenceType()
+                                 : Ctx.Types.inferScalarType(getVecOp0());
   auto *VectorTy = cast<VectorType>(ToVectorTy(ElementTy, VF));
   TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
   unsigned Opcode = RdxDesc.getOpcode();
@@ -2281,20 +2281,19 @@ InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
 
   // Extended cost
   InstructionCost ExtendedCost = 0;
-  if (IsExtended) {
+  if (isExtended()) {
     auto *SrcTy = cast<VectorType>(
         ToVectorTy(Ctx.Types.inferScalarType(getVecOp0()), VF));
     auto *DestTy = cast<VectorType>(ToVectorTy(getResultType(), VF));
     TTI::CastContextHint CCH0 =
         computeCCH(getVecOp0()->getDefiningRecipe(), VF);
-    // Arm TTI will use the underlying instruction to determine the cost.
     ExtendedCost = Ctx.TTI.getCastInstrCost(
-        ExtOp, DestTy, SrcTy, CCH0, TTI::TCK_RecipThroughput,
+        Ext0Instr->getOpcode(), DestTy, SrcTy, CCH0, TTI::TCK_RecipThroughput,
         dyn_cast_if_present<Instruction>(getExt0Instr()));
     TTI::CastContextHint CCH1 =
         computeCCH(getVecOp0()->getDefiningRecipe(), VF);
     ExtendedCost += Ctx.TTI.getCastInstrCost(
-        ExtOp, DestTy, SrcTy, CCH1, TTI::TCK_RecipThroughput,
+        Ext1Instr->getOpcode(), DestTy, SrcTy, CCH1, TTI::TCK_RecipThroughput,
         dyn_cast_if_present<Instruction>(getExt1Instr()));
   }
 
@@ -2302,7 +2301,7 @@ InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
   InstructionCost MulCost;
   SmallVector<const Value *, 4> Operands;
   Operands.append(MulInstr->value_op_begin(), MulInstr->value_op_end());
-  if (IsExtended)
+  if (isExtended())
     MulCost = Ctx.TTI.getArithmeticInstrCost(
         Instruction::Mul, VectorTy, CostKind,
         {TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None},
@@ -2329,9 +2328,8 @@ InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
   // MulAccReduction Cost
   VectorType *SrcVecTy =
       cast<VectorType>(ToVectorTy(Ctx.Types.inferScalarType(getVecOp0()), VF));
-  InstructionCost MulAccCost = Ctx.TTI.getMulAccReductionCost(
-      getExtOpcode() == Instruction::CastOps::ZExt, ElementTy, SrcVecTy,
-      CostKind);
+  InstructionCost MulAccCost =
+      Ctx.TTI.getMulAccReductionCost(isZExt(), ElementTy, SrcVecTy, CostKind);
 
   // Check if folding ext into ExtendedReduction is profitable.
   if (MulAccCost.isValid() &&
@@ -2420,26 +2418,26 @@ void VPMulAccRecipe::print(raw_ostream &O, const Twine &Indent,
   O << " + ";
   if (isa<FPMathOperator>(getUnderlyingInstr()))
     O << getUnderlyingInstr()->getFastMathFlags();
-  if (IsOuterExtended)
+  if (isOuterExtended())
     O << " (";
   O << "reduce." << Instruction::getOpcodeName(RdxDesc.getOpcode()) << " (";
   O << "mul ";
-  if (IsExtended)
+  if (isExtended())
     O << "(";
   getVecOp0()->printAsOperand(O, SlotTracker);
-  if (IsExtended)
+  if (isExtended())
     O << " extended to " << *getResultType() << "), (";
   else
     O << ", ";
   getVecOp1()->printAsOperand(O, SlotTracker);
-  if (IsExtended)
+  if (isExtended())
     O << " extended to " << *getResultType() << ")";
   if (isConditional()) {
     O << ", ";
     getCondOp()->printAsOperand(O, SlotTracker);
   }
   O << ")";
-  if (IsOuterExtended)
+  if (isOuterExtended())
     O << " extended to " << *RdxDesc.getRecurrenceType() << ")";
   if (RdxDesc.IntermediateStore)
     O << " (with final reduction value stored in invariant address sank "
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index 5e9c30911e34e5..53e080b5bb01d6 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -526,9 +526,12 @@ void VPlanTransforms::prepareExecute(VPlan &Plan) {
            vp_depth_first_deep(Plan.getEntry()))) {
     for (VPRecipeBase &R : make_early_inc_range(*VPBB)) {
       if (auto *ExtRed = dyn_cast<VPExtendedReductionRecipe>(&R)) {
+        // Generate VPWidenCastRecipe.
         auto *Ext = new VPWidenCastRecipe(
             ExtRed->getExtOpcode(), ExtRed->getVecOp(), ExtRed->getResultType(),
             *ExtRed->getExtInstr());
+
+        // Generate VPReductionRecipe.
         auto *Red = new VPReductionRecipe(
             ExtRed->getRecurrenceDescriptor(), ExtRed->getUnderlyingInstr(),
             ExtRed->getChainOp(), Ext, ExtRed->getCondOp(),
@@ -539,45 +542,55 @@ void VPlanTransforms::prepareExecute(VPlan &Plan) {
         ExtRed->eraseFromParent();
       } else if (isa<VPMulAccRecipe>(&R)) {
         auto *MulAcc = cast<VPMulAccRecipe>(&R);
+
+        // Generate inner VPWidenCastRecipes if necessary.
         VPValue *Op0, *Op1;
         if (MulAcc->isExtended()) {
           CastInst *Ext0 = MulAcc->getExt0Instr();
           Op0 = new VPWidenCastRecipe(Ext0->getOpcode(), MulAcc->getVecOp0(),
                                       MulAcc->getResultType(), *Ext0);
           Op0->getDefiningRecipe()->insertBefore(MulAcc);
-          if (!MulAcc->isSameExtend()) {
+          // Avoid generating a duplicate VPWidenCastRecipe for
+          // reduce.add(mul(ext(A), ext(A))).
+          if (MulAcc->isSameExtend()) {
+            Op1 = Op0;
+          } else {
             CastInst *Ext1 = MulAcc->getExt1Instr();
             Op1 = new VPWidenCastRecipe(Ext1->getOpcode(), MulAcc->getVecOp1(),
                                         MulAcc->getResultType(), *Ext1);
             Op1->getDefiningRecipe()->insertBefore(MulAcc);
-          } else {
-            Op1 = Op0;
           }
+          // This MulAccRecipe contains no extend instructions.
         } else {
           Op0 = MulAcc->getVecOp0();
           Op1 = MulAcc->getVecOp1();
         }
+
+        // Generate VPWidenRecipe.
         VPSingleDefRecipe *VecOp;
-        Instruction *MulInstr = MulAcc->getMulInstr();
         SmallVector<VPValue *, 2> MulOps = {Op0, Op1};
-        auto *Mul = new VPWidenRecipe(*MulInstr,
+        auto *Mul = new VPWidenRecipe(*MulAcc->getMulInstr(),
                                       make_range(MulOps.begin(), MulOps.end()));
-        // Outer extend.
-        if (auto *OuterExtInstr = MulAcc->getExtInstr())
+        Mul->insertBefore(MulAcc);
+
+        // Generate outer VPWidenCastRecipe if necessary.
+        if (auto *OuterExtInstr = MulAcc->getExtInstr()) {
           VecOp = new VPWidenCastRecipe(
               OuterExtInstr->getOpcode(), Mul,
               MulAcc->getRecurrenceDescriptor().getRecurrenceType(),
               *OuterExtInstr);
-        else
+          VecOp->insertBefore(MulAcc);
+        } else {
           VecOp = Mul;
+        }
+
+        // Generate VPReductionRecipe.
         auto *Red = new VPReductionRecipe(
             MulAcc->getRecurrenceDescriptor(), MulAcc->getUnderlyingInstr(),
             MulAcc->getChainOp(), VecOp, MulAcc->getCondOp(),
             MulAcc->isOrdered());
-        Mul->insertBefore(MulAcc);
-        if (VecOp != Mul)
-          VecOp->insertBefore(MulAcc);
         Red->insertBefore(MulAcc);
+
         MulAcc->replaceAllUsesWith(Red);
         MulAcc->eraseFromParent();
       }

>From 46928bdc43fe5fc41e4cdaf8e5698967cd076172 Mon Sep 17 00:00:00 2001
From: Elvis Wang <elvis.wang at sifive.com>
Date: Thu, 14 Nov 2024 04:23:02 -0800
Subject: [PATCH 10/26] Remove underying instruction dependency.

---
 .../Transforms/Vectorize/LoopVectorize.cpp    |  24 ++--
 llvm/lib/Transforms/Vectorize/VPlan.h         | 116 +++++++-----------
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp |  31 ++---
 .../Transforms/Vectorize/VPlanTransforms.cpp  |  33 ++---
 .../LoopVectorize/ARM/mve-reductions.ll       |  89 +++++++++-----
 5 files changed, 137 insertions(+), 156 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 5cee4b7e2ce70e..9c22cbcb055e72 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7397,18 +7397,18 @@ static bool planContainsAdditionalSimplifications(VPlan &Plan,
       if (Instruction *UI = GetInstructionForCost(&R))
         SeenInstrs.insert(UI);
       // VPExtendedReductionRecipe contains a folded extend instruction.
-      if (auto *ExtendedRed = dyn_cast<VPExtendedReductionRecipe>(&R))
-        SeenInstrs.insert(ExtendedRed->getExtInstr());
-      // VPMulAccRecipe constians a mul and otional extend instructions.
-      else if (auto *MulAcc = dyn_cast<VPMulAccRecipe>(&R)) {
-        SeenInstrs.insert(MulAcc->getMulInstr());
-        if (MulAcc->isExtended()) {
-          SeenInstrs.insert(MulAcc->getExt0Instr());
-          SeenInstrs.insert(MulAcc->getExt1Instr());
-          if (auto *Ext = MulAcc->getExtInstr())
-            SeenInstrs.insert(Ext);
-        }
-      }
+      // if (auto *ExtendedRed = dyn_cast<VPExtendedReductionRecipe>(&R))
+      //   SeenInstrs.insert(ExtendedRed->getExtInstr());
+      // // VPMulAccRecipe contains a mul and optional extend instructions.
+      // else if (auto *MulAcc = dyn_cast<VPMulAccRecipe>(&R)) {
+      //   SeenInstrs.insert(MulAcc->getMulInstr());
+      //   if (MulAcc->isExtended()) {
+      //     SeenInstrs.insert(MulAcc->getExt0Instr());
+      //     SeenInstrs.insert(MulAcc->getExt1Instr());
+      //     if (auto *Ext = MulAcc->getExtInstr())
+      //       SeenInstrs.insert(Ext);
+      //   }
+      // }
     }
   }
 
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index ab9e6f0393602f..23a5c2756b6f92 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -1448,11 +1448,20 @@ class VPWidenRecipe : public VPRecipeWithIRFlags {
                 iterator_range<IterT> Operands)
       : VPRecipeWithIRFlags(VPDefOpcode, Operands, I), Opcode(I.getOpcode()) {}
 
+  template <typename IterT>
+  VPWidenRecipe(unsigned VPDefOpcode, unsigned InstrOpcode,
+                iterator_range<IterT> Operands)
+      : VPRecipeWithIRFlags(VPDefOpcode, Operands), Opcode(InstrOpcode) {}
+
 public:
   template <typename IterT>
   VPWidenRecipe(Instruction &I, iterator_range<IterT> Operands)
       : VPWidenRecipe(VPDef::VPWidenSC, I, Operands) {}
 
+  template <typename IterT>
+  VPWidenRecipe(unsigned Opcode, iterator_range<IterT> Operands)
+      : VPWidenRecipe(VPDef::VPWidenSC, Opcode, Operands) {}
+
   ~VPWidenRecipe() override = default;
 
   VPWidenRecipe *clone() override {
@@ -2706,26 +2715,25 @@ class VPReductionEVLRecipe : public VPReductionRecipe {
 class VPExtendedReductionRecipe : public VPReductionRecipe {
   /// Type after extend.
   Type *ResultTy;
-  CastInst *ExtInstr;
+  Instruction::CastOps ExtOp;
 
 protected:
   VPExtendedReductionRecipe(const unsigned char SC,
                             const RecurrenceDescriptor &R, Instruction *RedI,
-                            Instruction::CastOps ExtOp, CastInst *ExtInstr,
-                            VPValue *ChainOp, VPValue *VecOp, VPValue *CondOp,
-                            bool IsOrdered, Type *ResultTy)
+                            Instruction::CastOps ExtOp, VPValue *ChainOp,
+                            VPValue *VecOp, VPValue *CondOp, bool IsOrdered,
+                            Type *ResultTy)
       : VPReductionRecipe(SC, R, RedI, ArrayRef<VPValue *>({ChainOp, VecOp}),
                           CondOp, IsOrdered),
-        ResultTy(ResultTy), ExtInstr(ExtInstr) {}
+        ResultTy(ResultTy), ExtOp(ExtOp) {}
 
 public:
   VPExtendedReductionRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
                             VPValue *ChainOp, VPWidenCastRecipe *Ext,
                             VPValue *CondOp, bool IsOrdered)
-      : VPExtendedReductionRecipe(
-            VPDef::VPExtendedReductionSC, R, RedI, Ext->getOpcode(),
-            cast<CastInst>(Ext->getUnderlyingInstr()), ChainOp,
-            Ext->getOperand(0), CondOp, IsOrdered, Ext->getResultType()) {}
+      : VPExtendedReductionRecipe(VPDef::VPExtendedReductionSC, R, RedI,
+                                  Ext->getOpcode(), ChainOp, Ext->getOperand(0),
+                                  CondOp, IsOrdered, Ext->getResultType()) {}
 
   ~VPExtendedReductionRecipe() override = default;
 
@@ -2761,9 +2769,7 @@ class VPExtendedReductionRecipe : public VPReductionRecipe {
   Type *getResultType() const { return ResultTy; }
   bool isZExt() const { return getExtOpcode() == Instruction::ZExt; }
   /// The Opcode of extend instruction.
-  Instruction::CastOps getExtOpcode() const { return ExtInstr->getOpcode(); }
-  /// The CastInst of the extend instruction.
-  CastInst *getExtInstr() const { return ExtInstr; }
+  Instruction::CastOps getExtOpcode() const { return ExtOp; }
 };
 
 /// A recipe to represent inloop MulAccreduction operations, performing a
@@ -2772,70 +2778,52 @@ class VPExtendedReductionRecipe : public VPReductionRecipe {
 /// VPReductionRecipe VPWidenRecipe(mul) and VPWidenCastRecipes before
 /// execution. The Operands are {ChainOp, VecOp1, VecOp2, [Condition]}.
 class VPMulAccRecipe : public VPReductionRecipe {
-  /// Type after extend.
-  Type *ResultType;
-
-  /// reduce.add(ext(mul(ext0(), ext1())))
-  Instruction *MulInstr;
-  CastInst *ExtInstr = nullptr;
-  CastInst *Ext0Instr = nullptr;
-  CastInst *Ext1Instr = nullptr;
+  Instruction::CastOps ExtOp;
+  bool IsExtended = false;
 
 protected:
   VPMulAccRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
-                 Instruction *RedI, Instruction *ExtInstr,
-                 Instruction *MulInstr, Instruction *Ext0Instr,
-                 Instruction *Ext1Instr, VPValue *ChainOp, VPValue *VecOp0,
-                 VPValue *VecOp1, VPValue *CondOp, bool IsOrdered,
-                 Type *ResultType)
+                 Instruction *RedI, Instruction::CastOps ExtOp,
+                 VPValue *ChainOp, VPValue *VecOp0, VPValue *VecOp1,
+                 VPValue *CondOp, bool IsOrdered)
       : VPReductionRecipe(SC, R, RedI,
                           ArrayRef<VPValue *>({ChainOp, VecOp0, VecOp1}),
                           CondOp, IsOrdered),
-        ResultType(ResultType), MulInstr(MulInstr),
-        ExtInstr(cast_if_present<CastInst>(ExtInstr)),
-        Ext0Instr(cast<CastInst>(Ext0Instr)),
-        Ext1Instr(cast<CastInst>(Ext1Instr)) {
-    assert(MulInstr->getOpcode() == Instruction::Mul);
+        ExtOp(ExtOp) {
+    IsExtended = true;
   }
 
   VPMulAccRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
-                 Instruction *RedI, Instruction *MulInstr, VPValue *ChainOp,
-                 VPValue *VecOp0, VPValue *VecOp1, VPValue *CondOp,
-                 bool IsOrdered)
+                 Instruction *RedI, VPValue *ChainOp, VPValue *VecOp0,
+                 VPValue *VecOp1, VPValue *CondOp, bool IsOrdered)
       : VPReductionRecipe(SC, R, RedI,
                           ArrayRef<VPValue *>({ChainOp, VecOp0, VecOp1}),
-                          CondOp, IsOrdered),
-        MulInstr(MulInstr) {
-    assert(MulInstr->getOpcode() == Instruction::Mul);
-  }
+                          CondOp, IsOrdered) {}
 
 public:
   VPMulAccRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
                  VPValue *ChainOp, VPValue *CondOp, bool IsOrdered,
                  VPWidenRecipe *Mul, VPWidenCastRecipe *Ext0,
                  VPWidenCastRecipe *Ext1)
-      : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, nullptr,
-                       Mul->getUnderlyingInstr(), Ext0->getUnderlyingInstr(),
-                       Ext1->getUnderlyingInstr(), ChainOp, Ext0->getOperand(0),
-                       Ext1->getOperand(0), CondOp, IsOrdered,
-                       Ext0->getResultType()) {}
+      : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI,
+                       cast<CastInst>(Ext0->getUnderlyingInstr())->getOpcode(),
+                       ChainOp, Ext0->getOperand(0), Ext1->getOperand(0),
+                       CondOp, IsOrdered) {}
 
   VPMulAccRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
                  VPValue *ChainOp, VPValue *CondOp, bool IsOrdered,
                  VPWidenRecipe *Mul)
-      : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, Mul->getUnderlyingInstr(),
-                       ChainOp, Mul->getOperand(0), Mul->getOperand(1), CondOp,
-                       IsOrdered) {}
+      : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, ChainOp, Mul->getOperand(0),
+                       Mul->getOperand(1), CondOp, IsOrdered) {}
 
   VPMulAccRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
                  VPValue *ChainOp, VPValue *CondOp, bool IsOrdered,
                  VPWidenCastRecipe *Ext, VPWidenRecipe *Mul,
                  VPWidenCastRecipe *Ext0, VPWidenCastRecipe *Ext1)
-      : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, Ext->getUnderlyingInstr(),
-                       Mul->getUnderlyingInstr(), Ext0->getUnderlyingInstr(),
-                       Ext1->getUnderlyingInstr(), ChainOp, Ext0->getOperand(0),
-                       Ext1->getOperand(0), CondOp, IsOrdered,
-                       Ext0->getResultType()) {}
+      : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI,
+                       cast<CastInst>(Ext0->getUnderlyingInstr())->getOpcode(),
+                       ChainOp, Ext0->getOperand(0), Ext1->getOperand(0),
+                       CondOp, IsOrdered) {}
 
   ~VPMulAccRecipe() override = default;
 
@@ -2869,37 +2857,17 @@ class VPMulAccRecipe : public VPReductionRecipe {
   VPValue *getVecOp0() const { return getOperand(1); }
   VPValue *getVecOp1() const { return getOperand(2); }
 
-  /// Return the type after inner extended, which must equal to the type of mul
-  /// instruction. If the ResultType != recurrenceType, than it must have a
-  /// extend recipe after mul recipe.
-  Type *getResultType() const { return ResultType; }
-
-  /// The underlying instruction for VPWidenRecipe.
-  Instruction *getMulInstr() const { return MulInstr; }
-
-  /// The underlying Instruction for outer VPWidenCastRecipe.
-  CastInst *getExtInstr() const { return ExtInstr; }
-  /// The underlying Instruction for inner VPWidenCastRecipe.
-  CastInst *getExt0Instr() const { return Ext0Instr; }
-  /// The underlying Instruction for inner VPWidenCastRecipe.
-  CastInst *getExt1Instr() const { return Ext1Instr; }
-
   /// Return if this MulAcc recipe contains extend instructions.
-  bool isExtended() const { return Ext0Instr && Ext1Instr; }
+  bool isExtended() const { return IsExtended; }
   /// Return if the operands of mul instruction come from same extend.
-  bool isSameExtend() const { return Ext0Instr == Ext1Instr; }
-  /// Return if the MulAcc recipes contains extend after mul.
-  bool isOuterExtended() const { return ExtInstr != nullptr; }
+  bool isSameExtend() const { return getVecOp0() == getVecOp1(); }
   /// Return if the extend opcode is ZExt.
   bool isZExt() const {
     if (!isExtended())
       return true;
-    // reduce.add(sext(mul(zext(A), zext(A)))) can be transform to
-    // reduce.add(zext(mul(sext(A), sext(A))))
-    if (ExtInstr && ExtInstr->getOpcode() != Ext0Instr->getOpcode())
-      return true;
-    return Ext0Instr->getOpcode() == Instruction::CastOps::ZExt;
+    return ExtOp == Instruction::CastOps::ZExt;
   }
+  Instruction::CastOps getExtOpcode() const { return ExtOp; }
 };
 
 /// VPReplicateRecipe replicates a given instruction producing multiple scalar
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 5aef2089ac37d8..edcc229cbfd5a8 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -2267,6 +2267,8 @@ InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
   Type *ElementTy = isExtended() ? RdxDesc.getRecurrenceType()
                                  : Ctx.Types.inferScalarType(getVecOp0());
   auto *VectorTy = cast<VectorType>(ToVectorTy(ElementTy, VF));
+  auto *SrcVecTy =
+      cast<VectorType>(ToVectorTy(Ctx.Types.inferScalarType(getVecOp0()), VF));
   TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
   unsigned Opcode = RdxDesc.getOpcode();
 
@@ -2284,29 +2286,25 @@ InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
   if (isExtended()) {
     auto *SrcTy = cast<VectorType>(
         ToVectorTy(Ctx.Types.inferScalarType(getVecOp0()), VF));
-    auto *DestTy = cast<VectorType>(ToVectorTy(getResultType(), VF));
     TTI::CastContextHint CCH0 =
         computeCCH(getVecOp0()->getDefiningRecipe(), VF);
-    ExtendedCost = Ctx.TTI.getCastInstrCost(
-        Ext0Instr->getOpcode(), DestTy, SrcTy, CCH0, TTI::TCK_RecipThroughput,
-        dyn_cast_if_present<Instruction>(getExt0Instr()));
+    ExtendedCost = Ctx.TTI.getCastInstrCost(getExtOpcode(), VectorTy, SrcVecTy,
+                                            CCH0, TTI::TCK_RecipThroughput);
     TTI::CastContextHint CCH1 =
         computeCCH(getVecOp0()->getDefiningRecipe(), VF);
-    ExtendedCost += Ctx.TTI.getCastInstrCost(
-        Ext1Instr->getOpcode(), DestTy, SrcTy, CCH1, TTI::TCK_RecipThroughput,
-        dyn_cast_if_present<Instruction>(getExt1Instr()));
+    ExtendedCost += Ctx.TTI.getCastInstrCost(getExtOpcode(), VectorTy, SrcVecTy,
+                                             CCH1, TTI::TCK_RecipThroughput);
   }
 
   // Mul cost
   InstructionCost MulCost;
   SmallVector<const Value *, 4> Operands;
-  Operands.append(MulInstr->value_op_begin(), MulInstr->value_op_end());
   if (isExtended())
     MulCost = Ctx.TTI.getArithmeticInstrCost(
         Instruction::Mul, VectorTy, CostKind,
         {TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None},
         {TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None},
-        Operands, MulInstr, &Ctx.TLI);
+        Operands, nullptr, &Ctx.TLI);
   else {
     VPValue *RHS = getVecOp1();
     // Certain instructions can be cheaper to vectorize if they have a constant
@@ -2319,15 +2317,15 @@ InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
     if (RHSInfo.Kind == TargetTransformInfo::OK_AnyValue &&
         RHS->isDefinedOutsideLoopRegions())
       RHSInfo.Kind = TargetTransformInfo::OK_UniformValue;
+    Operands.append(
+        {getVecOp0()->getUnderlyingValue(), RHS->getUnderlyingValue()});
     MulCost = Ctx.TTI.getArithmeticInstrCost(
         Instruction::Mul, VectorTy, CostKind,
         {TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None},
-        RHSInfo, Operands, MulInstr, &Ctx.TLI);
+        RHSInfo, Operands, nullptr, &Ctx.TLI);
   }
 
   // MulAccReduction Cost
-  VectorType *SrcVecTy =
-      cast<VectorType>(ToVectorTy(Ctx.Types.inferScalarType(getVecOp0()), VF));
   InstructionCost MulAccCost =
       Ctx.TTI.getMulAccReductionCost(isZExt(), ElementTy, SrcVecTy, CostKind);
 
@@ -2411,6 +2409,7 @@ void VPExtendedReductionRecipe::print(raw_ostream &O, const Twine &Indent,
 void VPMulAccRecipe::print(raw_ostream &O, const Twine &Indent,
                            VPSlotTracker &SlotTracker) const {
   const RecurrenceDescriptor &RdxDesc = getRecurrenceDescriptor();
+  Type *ElementTy = RdxDesc.getRecurrenceType();
   O << Indent << "MULACC-REDUCE ";
   printAsOperand(O, SlotTracker);
   O << " = ";
@@ -2418,27 +2417,23 @@ void VPMulAccRecipe::print(raw_ostream &O, const Twine &Indent,
   O << " + ";
   if (isa<FPMathOperator>(getUnderlyingInstr()))
     O << getUnderlyingInstr()->getFastMathFlags();
-  if (isOuterExtended())
-    O << " (";
   O << "reduce." << Instruction::getOpcodeName(RdxDesc.getOpcode()) << " (";
   O << "mul ";
   if (isExtended())
     O << "(";
   getVecOp0()->printAsOperand(O, SlotTracker);
   if (isExtended())
-    O << " extended to " << *getResultType() << "), (";
+    O << " extended to " << *ElementTy << "), (";
   else
     O << ", ";
   getVecOp1()->printAsOperand(O, SlotTracker);
   if (isExtended())
-    O << " extended to " << *getResultType() << ")";
+    O << " extended to " << *ElementTy << ")";
   if (isConditional()) {
     O << ", ";
     getCondOp()->printAsOperand(O, SlotTracker);
   }
   O << ")";
-  if (isOuterExtended())
-    O << " extended to " << *RdxDesc.getRecurrenceType() << ")";
   if (RdxDesc.IntermediateStore)
     O << " (with final reduction value stored in invariant address sank "
          "outside of loop)";
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index 53e080b5bb01d6..50afdeb3aa78c8 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -527,9 +527,9 @@ void VPlanTransforms::prepareExecute(VPlan &Plan) {
     for (VPRecipeBase &R : make_early_inc_range(*VPBB)) {
       if (auto *ExtRed = dyn_cast<VPExtendedReductionRecipe>(&R)) {
         // Genearte VPWidenCastRecipe.
-        auto *Ext = new VPWidenCastRecipe(
-            ExtRed->getExtOpcode(), ExtRed->getVecOp(), ExtRed->getResultType(),
-            *ExtRed->getExtInstr());
+        auto *Ext =
+            new VPWidenCastRecipe(ExtRed->getExtOpcode(), ExtRed->getVecOp(),
+                                  ExtRed->getResultType());
 
         // Generate VPreductionRecipe.
         auto *Red = new VPReductionRecipe(
@@ -542,22 +542,21 @@ void VPlanTransforms::prepareExecute(VPlan &Plan) {
         ExtRed->eraseFromParent();
       } else if (isa<VPMulAccRecipe>(&R)) {
         auto *MulAcc = cast<VPMulAccRecipe>(&R);
+        Type *RedType = MulAcc->getRecurrenceDescriptor().getRecurrenceType();
 
         // Generate inner VPWidenCastRecipes if necessary.
         VPValue *Op0, *Op1;
         if (MulAcc->isExtended()) {
-          CastInst *Ext0 = MulAcc->getExt0Instr();
-          Op0 = new VPWidenCastRecipe(Ext0->getOpcode(), MulAcc->getVecOp0(),
-                                      MulAcc->getResultType(), *Ext0);
+          Op0 = new VPWidenCastRecipe(MulAcc->getExtOpcode(),
+                                      MulAcc->getVecOp0(), RedType);
           Op0->getDefiningRecipe()->insertBefore(MulAcc);
           // Prevent reduce.add(mul(ext(A), ext(A))) generate duplicate
           // VPWidenCastRecipe.
           if (MulAcc->isSameExtend()) {
             Op1 = Op0;
           } else {
-            CastInst *Ext1 = MulAcc->getExt1Instr();
-            Op1 = new VPWidenCastRecipe(Ext1->getOpcode(), MulAcc->getVecOp1(),
-                                        MulAcc->getResultType(), *Ext1);
+            Op1 = new VPWidenCastRecipe(MulAcc->getExtOpcode(),
+                                        MulAcc->getVecOp1(), RedType);
             Op1->getDefiningRecipe()->insertBefore(MulAcc);
           }
           // Not contains extend instruction in this MulAccRecipe.
@@ -567,27 +566,15 @@ void VPlanTransforms::prepareExecute(VPlan &Plan) {
         }
 
         // Generate VPWidenRecipe.
-        VPSingleDefRecipe *VecOp;
         SmallVector<VPValue *, 2> MulOps = {Op0, Op1};
-        auto *Mul = new VPWidenRecipe(*MulAcc->getMulInstr(),
+        auto *Mul = new VPWidenRecipe(Instruction::Mul,
                                       make_range(MulOps.begin(), MulOps.end()));
         Mul->insertBefore(MulAcc);
 
-        // Generate outer VPWidenCastRecipe if necessary.
-        if (auto *OuterExtInstr = MulAcc->getExtInstr()) {
-          VecOp = new VPWidenCastRecipe(
-              OuterExtInstr->getOpcode(), Mul,
-              MulAcc->getRecurrenceDescriptor().getRecurrenceType(),
-              *OuterExtInstr);
-          VecOp->insertBefore(MulAcc);
-        } else {
-          VecOp = Mul;
-        }
-
         // Generate VPReductionRecipe.
         auto *Red = new VPReductionRecipe(
             MulAcc->getRecurrenceDescriptor(), MulAcc->getUnderlyingInstr(),
-            MulAcc->getChainOp(), VecOp, MulAcc->getCondOp(),
+            MulAcc->getChainOp(), Mul, MulAcc->getCondOp(),
             MulAcc->isOrdered());
         Red->insertBefore(MulAcc);
 
diff --git a/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll b/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
index b138666ba92f9b..f99f33365fad0f 100644
--- a/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
+++ b/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
@@ -1670,24 +1670,55 @@ define i64 @test_std_q31(ptr %x, i32 %n) #0 {
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[CMP11:%.*]] = icmp sgt i32 [[N:%.*]], 0
 ; CHECK-NEXT:    tail call void @llvm.assume(i1 [[CMP11]])
+; CHECK-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp samesign ult i32 [[N]], 4
+; CHECK-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
+; CHECK:       vector.ph:
+; CHECK-NEXT:    [[N_VEC:%.*]] = and i32 [[N]], 2147483644
 ; CHECK-NEXT:    br label [[FOR_BODY1:%.*]]
+; CHECK:       vector.body:
+; CHECK-NEXT:    [[INDEX:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[FOR_BODY1]] ]
+; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[TMP4:%.*]], [[FOR_BODY1]] ]
+; CHECK-NEXT:    [[VEC_PHI1:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[TMP8:%.*]], [[FOR_BODY1]] ]
+; CHECK-NEXT:    [[TMP10:%.*]] = getelementptr inbounds i32, ptr [[X:%.*]], i32 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <4 x i32>, ptr [[TMP10]], align 4
+; CHECK-NEXT:    [[TMP1:%.*]] = ashr <4 x i32> [[WIDE_LOAD]], splat (i32 8)
+; CHECK-NEXT:    [[TMP2:%.*]] = sext <4 x i32> [[TMP1]] to <4 x i64>
+; CHECK-NEXT:    [[TMP3:%.*]] = call i64 @llvm.vector.reduce.add.v4i64(<4 x i64> [[TMP2]])
+; CHECK-NEXT:    [[TMP4]] = add i64 [[TMP3]], [[VEC_PHI]]
+; CHECK-NEXT:    [[TMP5:%.*]] = sext <4 x i32> [[TMP1]] to <4 x i64>
+; CHECK-NEXT:    [[TMP6:%.*]] = mul nsw <4 x i64> [[TMP5]], [[TMP5]]
+; CHECK-NEXT:    [[TMP7:%.*]] = call i64 @llvm.vector.reduce.add.v4i64(<4 x i64> [[TMP6]])
+; CHECK-NEXT:    [[TMP8]] = add i64 [[TMP7]], [[VEC_PHI1]]
+; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 4
+; CHECK-NEXT:    [[TMP9:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]
+; CHECK-NEXT:    br i1 [[TMP9]], label [[MIDDLE_BLOCK:%.*]], label [[FOR_BODY1]], !llvm.loop [[LOOP37:![0-9]+]]
+; CHECK:       middle.block:
+; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i32 [[N]], [[N_VEC]]
+; CHECK-NEXT:    br i1 [[CMP_N]], label [[FOR_COND_CLEANUP:%.*]], label [[SCALAR_PH]]
+; CHECK:       scalar.ph:
+; CHECK-NEXT:    [[BC_RESUME_VAL:%.*]] = phi i32 [ [[N_VEC]], [[MIDDLE_BLOCK]] ], [ 0, [[ENTRY:%.*]] ]
+; CHECK-NEXT:    [[BC_MERGE_RDX:%.*]] = phi i64 [ [[TMP4]], [[MIDDLE_BLOCK]] ], [ 0, [[ENTRY]] ]
+; CHECK-NEXT:    [[BC_MERGE_RDX2:%.*]] = phi i64 [ [[TMP8]], [[MIDDLE_BLOCK]] ], [ 0, [[ENTRY]] ]
+; CHECK-NEXT:    br label [[FOR_BODY:%.*]]
 ; CHECK:       for.cond.cleanup:
-; CHECK-NEXT:    [[DIV:%.*]] = sdiv i64 [[ADD3:%.*]], [[ADD:%.*]]
+; CHECK-NEXT:    [[ADD:%.*]] = phi i64 [ [[ADD1:%.*]], [[FOR_BODY]] ], [ [[TMP4]], [[MIDDLE_BLOCK]] ]
+; CHECK-NEXT:    [[ADD3:%.*]] = phi i64 [ [[ADD5:%.*]], [[FOR_BODY]] ], [ [[TMP8]], [[MIDDLE_BLOCK]] ]
+; CHECK-NEXT:    [[DIV:%.*]] = sdiv i64 [[ADD3]], [[ADD]]
 ; CHECK-NEXT:    ret i64 [[DIV]]
 ; CHECK:       for.body:
-; CHECK-NEXT:    [[S_014:%.*]] = phi i64 [ [[ADD]], [[FOR_BODY1]] ], [ 0, [[ENTRY:%.*]] ]
-; CHECK-NEXT:    [[I_013:%.*]] = phi i32 [ [[ADD4:%.*]], [[FOR_BODY1]] ], [ 0, [[ENTRY]] ]
-; CHECK-NEXT:    [[T_012:%.*]] = phi i64 [ [[ADD3]], [[FOR_BODY1]] ], [ 0, [[ENTRY]] ]
-; CHECK-NEXT:    [[ARRAYIDX:%.*]] = getelementptr inbounds i32, ptr [[X:%.*]], i32 [[I_013]]
+; CHECK-NEXT:    [[S_014:%.*]] = phi i64 [ [[ADD1]], [[FOR_BODY]] ], [ [[BC_MERGE_RDX]], [[SCALAR_PH]] ]
+; CHECK-NEXT:    [[I_013:%.*]] = phi i32 [ [[ADD4:%.*]], [[FOR_BODY]] ], [ [[BC_RESUME_VAL]], [[SCALAR_PH]] ]
+; CHECK-NEXT:    [[T_012:%.*]] = phi i64 [ [[ADD5]], [[FOR_BODY]] ], [ [[BC_MERGE_RDX2]], [[SCALAR_PH]] ]
+; CHECK-NEXT:    [[ARRAYIDX:%.*]] = getelementptr inbounds i32, ptr [[X]], i32 [[I_013]]
 ; CHECK-NEXT:    [[TMP0:%.*]] = load i32, ptr [[ARRAYIDX]], align 4
 ; CHECK-NEXT:    [[SHR:%.*]] = ashr i32 [[TMP0]], 8
 ; CHECK-NEXT:    [[CONV:%.*]] = sext i32 [[SHR]] to i64
-; CHECK-NEXT:    [[ADD]] = add nsw i64 [[S_014]], [[CONV]]
+; CHECK-NEXT:    [[ADD1]] = add nsw i64 [[S_014]], [[CONV]]
 ; CHECK-NEXT:    [[MUL:%.*]] = mul nsw i64 [[CONV]], [[CONV]]
-; CHECK-NEXT:    [[ADD3]] = add nuw nsw i64 [[MUL]], [[T_012]]
+; CHECK-NEXT:    [[ADD5]] = add nuw nsw i64 [[MUL]], [[T_012]]
 ; CHECK-NEXT:    [[ADD4]] = add nuw nsw i32 [[I_013]], 1
 ; CHECK-NEXT:    [[EXITCOND_NOT:%.*]] = icmp eq i32 [[ADD4]], [[N]]
-; CHECK-NEXT:    br i1 [[EXITCOND_NOT]], label [[FOR_COND_CLEANUP:%.*]], label [[FOR_BODY1]]
+; CHECK-NEXT:    br i1 [[EXITCOND_NOT]], label [[FOR_COND_CLEANUP]], label [[FOR_BODY]], !llvm.loop [[LOOP38:![0-9]+]]
 ;
 entry:
   %cmp11 = icmp sgt i32 %n, 0
@@ -1723,10 +1754,10 @@ define i64 @test_fir_q15(ptr %x, ptr %y, i32 %n) #0 {
 ; CHECK-NEXT:    [[TMP0:%.*]] = add nsw i32 [[N]], -1
 ; CHECK-NEXT:    [[TMP1:%.*]] = lshr i32 [[TMP0]], 1
 ; CHECK-NEXT:    [[TMP2:%.*]] = add nuw i32 [[TMP1]], 1
-; CHECK-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i32 [[N]], 7
+; CHECK-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i32 [[N]], 15
 ; CHECK-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
 ; CHECK:       vector.ph:
-; CHECK-NEXT:    [[N_VEC:%.*]] = and i32 [[TMP2]], -4
+; CHECK-NEXT:    [[N_VEC:%.*]] = and i32 [[TMP2]], -8
 ; CHECK-NEXT:    [[IND_END:%.*]] = shl i32 [[N_VEC]], 1
 ; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
 ; CHECK:       vector.body:
@@ -1734,28 +1765,28 @@ define i64 @test_fir_q15(ptr %x, ptr %y, i32 %n) #0 {
 ; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[TMP16:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[OFFSET_IDX:%.*]] = shl i32 [[INDEX]], 1
 ; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr inbounds i16, ptr [[X:%.*]], i32 [[OFFSET_IDX]]
-; CHECK-NEXT:    [[WIDE_VEC:%.*]] = load <8 x i16>, ptr [[TMP3]], align 2
-; CHECK-NEXT:    [[STRIDED_VEC:%.*]] = shufflevector <8 x i16> [[WIDE_VEC]], <8 x i16> poison, <4 x i32> <i32 0, i32 2, i32 4, i32 6>
-; CHECK-NEXT:    [[STRIDED_VEC1:%.*]] = shufflevector <8 x i16> [[WIDE_VEC]], <8 x i16> poison, <4 x i32> <i32 1, i32 3, i32 5, i32 7>
-; CHECK-NEXT:    [[TMP5:%.*]] = sext <4 x i16> [[STRIDED_VEC]] to <4 x i32>
+; CHECK-NEXT:    [[WIDE_VEC:%.*]] = load <16 x i16>, ptr [[TMP3]], align 2
+; CHECK-NEXT:    [[STRIDED_VEC:%.*]] = shufflevector <16 x i16> [[WIDE_VEC]], <16 x i16> poison, <8 x i32> <i32 0, i32 2, i32 4, i32 6, i32 8, i32 10, i32 12, i32 14>
+; CHECK-NEXT:    [[STRIDED_VEC1:%.*]] = shufflevector <16 x i16> [[WIDE_VEC]], <16 x i16> poison, <8 x i32> <i32 1, i32 3, i32 5, i32 7, i32 9, i32 11, i32 13, i32 15>
 ; CHECK-NEXT:    [[TMP4:%.*]] = getelementptr inbounds i16, ptr [[Y:%.*]], i32 [[OFFSET_IDX]]
-; CHECK-NEXT:    [[WIDE_VEC2:%.*]] = load <8 x i16>, ptr [[TMP4]], align 2
-; CHECK-NEXT:    [[STRIDED_VEC3:%.*]] = shufflevector <8 x i16> [[WIDE_VEC2]], <8 x i16> poison, <4 x i32> <i32 0, i32 2, i32 4, i32 6>
-; CHECK-NEXT:    [[STRIDED_VEC4:%.*]] = shufflevector <8 x i16> [[WIDE_VEC2]], <8 x i16> poison, <4 x i32> <i32 1, i32 3, i32 5, i32 7>
-; CHECK-NEXT:    [[TMP6:%.*]] = sext <4 x i16> [[STRIDED_VEC3]] to <4 x i32>
-; CHECK-NEXT:    [[TMP7:%.*]] = mul nsw <4 x i32> [[TMP6]], [[TMP5]]
-; CHECK-NEXT:    [[TMP8:%.*]] = sext <4 x i32> [[TMP7]] to <4 x i64>
-; CHECK-NEXT:    [[TMP13:%.*]] = sext <4 x i16> [[STRIDED_VEC1]] to <4 x i32>
-; CHECK-NEXT:    [[TMP14:%.*]] = sext <4 x i16> [[STRIDED_VEC4]] to <4 x i32>
-; CHECK-NEXT:    [[TMP11:%.*]] = mul nsw <4 x i32> [[TMP14]], [[TMP13]]
-; CHECK-NEXT:    [[TMP12:%.*]] = sext <4 x i32> [[TMP11]] to <4 x i64>
-; CHECK-NEXT:    [[TMP9:%.*]] = call i64 @llvm.vector.reduce.add.v4i64(<4 x i64> [[TMP8]])
+; CHECK-NEXT:    [[WIDE_VEC2:%.*]] = load <16 x i16>, ptr [[TMP4]], align 2
+; CHECK-NEXT:    [[STRIDED_VEC3:%.*]] = shufflevector <16 x i16> [[WIDE_VEC2]], <16 x i16> poison, <8 x i32> <i32 0, i32 2, i32 4, i32 6, i32 8, i32 10, i32 12, i32 14>
+; CHECK-NEXT:    [[STRIDED_VEC4:%.*]] = shufflevector <16 x i16> [[WIDE_VEC2]], <16 x i16> poison, <8 x i32> <i32 1, i32 3, i32 5, i32 7, i32 9, i32 11, i32 13, i32 15>
+; CHECK-NEXT:    [[TMP5:%.*]] = sext <8 x i16> [[STRIDED_VEC3]] to <8 x i32>
+; CHECK-NEXT:    [[TMP6:%.*]] = sext <8 x i16> [[STRIDED_VEC]] to <8 x i32>
+; CHECK-NEXT:    [[TMP7:%.*]] = mul nsw <8 x i32> [[TMP5]], [[TMP6]]
+; CHECK-NEXT:    [[TMP8:%.*]] = sext <8 x i32> [[TMP7]] to <8 x i64>
+; CHECK-NEXT:    [[TMP9:%.*]] = call i64 @llvm.vector.reduce.add.v8i64(<8 x i64> [[TMP8]])
 ; CHECK-NEXT:    [[TMP10:%.*]] = add i64 [[TMP9]], [[VEC_PHI]]
-; CHECK-NEXT:    [[TMP15:%.*]] = call i64 @llvm.vector.reduce.add.v4i64(<4 x i64> [[TMP12]])
+; CHECK-NEXT:    [[TMP11:%.*]] = sext <8 x i16> [[STRIDED_VEC4]] to <8 x i32>
+; CHECK-NEXT:    [[TMP12:%.*]] = sext <8 x i16> [[STRIDED_VEC1]] to <8 x i32>
+; CHECK-NEXT:    [[TMP13:%.*]] = mul nsw <8 x i32> [[TMP11]], [[TMP12]]
+; CHECK-NEXT:    [[TMP14:%.*]] = sext <8 x i32> [[TMP13]] to <8 x i64>
+; CHECK-NEXT:    [[TMP15:%.*]] = call i64 @llvm.vector.reduce.add.v8i64(<8 x i64> [[TMP14]])
 ; CHECK-NEXT:    [[TMP16]] = add i64 [[TMP15]], [[TMP10]]
-; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 4
+; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 8
 ; CHECK-NEXT:    [[TMP17:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]
-; CHECK-NEXT:    br i1 [[TMP17]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP37:![0-9]+]]
+; CHECK-NEXT:    br i1 [[TMP17]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP39:![0-9]+]]
 ; CHECK:       middle.block:
 ; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i32 [[TMP2]], [[N_VEC]]
 ; CHECK-NEXT:    br i1 [[CMP_N]], label [[FOR_COND_CLEANUP]], label [[SCALAR_PH]]
@@ -1790,7 +1821,7 @@ define i64 @test_fir_q15(ptr %x, ptr %y, i32 %n) #0 {
 ; CHECK-NEXT:    [[ADD12]] = add nsw i64 [[ADD]], [[CONV11]]
 ; CHECK-NEXT:    [[ADD13]] = add nuw nsw i32 [[I_025]], 2
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i32 [[ADD13]], [[N]]
-; CHECK-NEXT:    br i1 [[CMP]], label [[FOR_BODY]], label [[FOR_COND_CLEANUP]], !llvm.loop [[LOOP38:![0-9]+]]
+; CHECK-NEXT:    br i1 [[CMP]], label [[FOR_BODY]], label [[FOR_COND_CLEANUP]], !llvm.loop [[LOOP40:![0-9]+]]
 ;
 entry:
   %cmp23 = icmp sgt i32 %n, 0

>From 35abf19fdb78efe2ded6cbe283a582f78aa068e6 Mon Sep 17 00:00:00 2001
From: Elvis Wang <elvis.wang at sifive.com>
Date: Thu, 14 Nov 2024 05:21:18 -0800
Subject: [PATCH 11/26] Revert "Remove underying instruction dependency."

This reverts commit e4e510802f44c9936f5a892eb16ba2b19651ee15.

We need the underlying instructions for accurate TTI queries and for the
metadata from them to attach to the generated vector instructions.
---
 .../Transforms/Vectorize/LoopVectorize.cpp    |  24 ++--
 llvm/lib/Transforms/Vectorize/VPlan.h         | 116 +++++++++++-------
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp |  31 +++--
 .../Transforms/Vectorize/VPlanTransforms.cpp  |  33 +++--
 .../LoopVectorize/ARM/mve-reductions.ll       |  89 +++++---------
 5 files changed, 156 insertions(+), 137 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 9c22cbcb055e72..5cee4b7e2ce70e 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7397,18 +7397,18 @@ static bool planContainsAdditionalSimplifications(VPlan &Plan,
       if (Instruction *UI = GetInstructionForCost(&R))
         SeenInstrs.insert(UI);
       // VPExtendedReductionRecipe contains a folded extend instruction.
-      // if (auto *ExtendedRed = dyn_cast<VPExtendedReductionRecipe>(&R))
-      //   SeenInstrs.insert(ExtendedRed->getExtInstr());
-      // // VPMulAccRecipe constians a mul and otional extend instructions.
-      // else if (auto *MulAcc = dyn_cast<VPMulAccRecipe>(&R)) {
-      //   SeenInstrs.insert(MulAcc->getMulInstr());
-      //   if (MulAcc->isExtended()) {
-      //     SeenInstrs.insert(MulAcc->getExt0Instr());
-      //     SeenInstrs.insert(MulAcc->getExt1Instr());
-      //     if (auto *Ext = MulAcc->getExtInstr())
-      //       SeenInstrs.insert(Ext);
-      //   }
-      // }
+      if (auto *ExtendedRed = dyn_cast<VPExtendedReductionRecipe>(&R))
+        SeenInstrs.insert(ExtendedRed->getExtInstr());
+      // VPMulAccRecipe constians a mul and otional extend instructions.
+      else if (auto *MulAcc = dyn_cast<VPMulAccRecipe>(&R)) {
+        SeenInstrs.insert(MulAcc->getMulInstr());
+        if (MulAcc->isExtended()) {
+          SeenInstrs.insert(MulAcc->getExt0Instr());
+          SeenInstrs.insert(MulAcc->getExt1Instr());
+          if (auto *Ext = MulAcc->getExtInstr())
+            SeenInstrs.insert(Ext);
+        }
+      }
     }
   }
 
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 23a5c2756b6f92..ab9e6f0393602f 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -1448,20 +1448,11 @@ class VPWidenRecipe : public VPRecipeWithIRFlags {
                 iterator_range<IterT> Operands)
       : VPRecipeWithIRFlags(VPDefOpcode, Operands, I), Opcode(I.getOpcode()) {}
 
-  template <typename IterT>
-  VPWidenRecipe(unsigned VPDefOpcode, unsigned InstrOpcode,
-                iterator_range<IterT> Operands)
-      : VPRecipeWithIRFlags(VPDefOpcode, Operands), Opcode(InstrOpcode) {}
-
 public:
   template <typename IterT>
   VPWidenRecipe(Instruction &I, iterator_range<IterT> Operands)
       : VPWidenRecipe(VPDef::VPWidenSC, I, Operands) {}
 
-  template <typename IterT>
-  VPWidenRecipe(unsigned Opcode, iterator_range<IterT> Operands)
-      : VPWidenRecipe(VPDef::VPWidenSC, Opcode, Operands) {}
-
   ~VPWidenRecipe() override = default;
 
   VPWidenRecipe *clone() override {
@@ -2715,25 +2706,26 @@ class VPReductionEVLRecipe : public VPReductionRecipe {
 class VPExtendedReductionRecipe : public VPReductionRecipe {
   /// Type after extend.
   Type *ResultTy;
-  Instruction::CastOps ExtOp;
+  CastInst *ExtInstr;
 
 protected:
   VPExtendedReductionRecipe(const unsigned char SC,
                             const RecurrenceDescriptor &R, Instruction *RedI,
-                            Instruction::CastOps ExtOp, VPValue *ChainOp,
-                            VPValue *VecOp, VPValue *CondOp, bool IsOrdered,
-                            Type *ResultTy)
+                            Instruction::CastOps ExtOp, CastInst *ExtInstr,
+                            VPValue *ChainOp, VPValue *VecOp, VPValue *CondOp,
+                            bool IsOrdered, Type *ResultTy)
       : VPReductionRecipe(SC, R, RedI, ArrayRef<VPValue *>({ChainOp, VecOp}),
                           CondOp, IsOrdered),
-        ResultTy(ResultTy), ExtOp(ExtOp) {}
+        ResultTy(ResultTy), ExtInstr(ExtInstr) {}
 
 public:
   VPExtendedReductionRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
                             VPValue *ChainOp, VPWidenCastRecipe *Ext,
                             VPValue *CondOp, bool IsOrdered)
-      : VPExtendedReductionRecipe(VPDef::VPExtendedReductionSC, R, RedI,
-                                  Ext->getOpcode(), ChainOp, Ext->getOperand(0),
-                                  CondOp, IsOrdered, Ext->getResultType()) {}
+      : VPExtendedReductionRecipe(
+            VPDef::VPExtendedReductionSC, R, RedI, Ext->getOpcode(),
+            cast<CastInst>(Ext->getUnderlyingInstr()), ChainOp,
+            Ext->getOperand(0), CondOp, IsOrdered, Ext->getResultType()) {}
 
   ~VPExtendedReductionRecipe() override = default;
 
@@ -2769,7 +2761,9 @@ class VPExtendedReductionRecipe : public VPReductionRecipe {
   Type *getResultType() const { return ResultTy; }
   bool isZExt() const { return getExtOpcode() == Instruction::ZExt; }
   /// The Opcode of extend instruction.
-  Instruction::CastOps getExtOpcode() const { return ExtOp; }
+  Instruction::CastOps getExtOpcode() const { return ExtInstr->getOpcode(); }
+  /// The CastInst of the extend instruction.
+  CastInst *getExtInstr() const { return ExtInstr; }
 };
 
 /// A recipe to represent inloop MulAccreduction operations, performing a
@@ -2778,52 +2772,70 @@ class VPExtendedReductionRecipe : public VPReductionRecipe {
 /// VPReductionRecipe VPWidenRecipe(mul) and VPWidenCastRecipes before
 /// execution. The Operands are {ChainOp, VecOp1, VecOp2, [Condition]}.
 class VPMulAccRecipe : public VPReductionRecipe {
-  Instruction::CastOps ExtOp;
-  bool IsExtended = false;
+  /// Type after extend.
+  Type *ResultType;
+
+  /// reduce.add(ext(mul(ext0(), ext1())))
+  Instruction *MulInstr;
+  CastInst *ExtInstr = nullptr;
+  CastInst *Ext0Instr = nullptr;
+  CastInst *Ext1Instr = nullptr;
 
 protected:
   VPMulAccRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
-                 Instruction *RedI, Instruction::CastOps ExtOp,
-                 VPValue *ChainOp, VPValue *VecOp0, VPValue *VecOp1,
-                 VPValue *CondOp, bool IsOrdered)
+                 Instruction *RedI, Instruction *ExtInstr,
+                 Instruction *MulInstr, Instruction *Ext0Instr,
+                 Instruction *Ext1Instr, VPValue *ChainOp, VPValue *VecOp0,
+                 VPValue *VecOp1, VPValue *CondOp, bool IsOrdered,
+                 Type *ResultType)
       : VPReductionRecipe(SC, R, RedI,
                           ArrayRef<VPValue *>({ChainOp, VecOp0, VecOp1}),
                           CondOp, IsOrdered),
-        ExtOp(ExtOp) {
-    IsExtended = true;
+        ResultType(ResultType), MulInstr(MulInstr),
+        ExtInstr(cast_if_present<CastInst>(ExtInstr)),
+        Ext0Instr(cast<CastInst>(Ext0Instr)),
+        Ext1Instr(cast<CastInst>(Ext1Instr)) {
+    assert(MulInstr->getOpcode() == Instruction::Mul);
   }
 
   VPMulAccRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
-                 Instruction *RedI, VPValue *ChainOp, VPValue *VecOp0,
-                 VPValue *VecOp1, VPValue *CondOp, bool IsOrdered)
+                 Instruction *RedI, Instruction *MulInstr, VPValue *ChainOp,
+                 VPValue *VecOp0, VPValue *VecOp1, VPValue *CondOp,
+                 bool IsOrdered)
       : VPReductionRecipe(SC, R, RedI,
                           ArrayRef<VPValue *>({ChainOp, VecOp0, VecOp1}),
-                          CondOp, IsOrdered) {}
+                          CondOp, IsOrdered),
+        MulInstr(MulInstr) {
+    assert(MulInstr->getOpcode() == Instruction::Mul);
+  }
 
 public:
   VPMulAccRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
                  VPValue *ChainOp, VPValue *CondOp, bool IsOrdered,
                  VPWidenRecipe *Mul, VPWidenCastRecipe *Ext0,
                  VPWidenCastRecipe *Ext1)
-      : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI,
-                       cast<CastInst>(Ext0->getUnderlyingInstr())->getOpcode(),
-                       ChainOp, Ext0->getOperand(0), Ext1->getOperand(0),
-                       CondOp, IsOrdered) {}
+      : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, nullptr,
+                       Mul->getUnderlyingInstr(), Ext0->getUnderlyingInstr(),
+                       Ext1->getUnderlyingInstr(), ChainOp, Ext0->getOperand(0),
+                       Ext1->getOperand(0), CondOp, IsOrdered,
+                       Ext0->getResultType()) {}
 
   VPMulAccRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
                  VPValue *ChainOp, VPValue *CondOp, bool IsOrdered,
                  VPWidenRecipe *Mul)
-      : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, ChainOp, Mul->getOperand(0),
-                       Mul->getOperand(1), CondOp, IsOrdered) {}
+      : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, Mul->getUnderlyingInstr(),
+                       ChainOp, Mul->getOperand(0), Mul->getOperand(1), CondOp,
+                       IsOrdered) {}
 
   VPMulAccRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
                  VPValue *ChainOp, VPValue *CondOp, bool IsOrdered,
                  VPWidenCastRecipe *Ext, VPWidenRecipe *Mul,
                  VPWidenCastRecipe *Ext0, VPWidenCastRecipe *Ext1)
-      : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI,
-                       cast<CastInst>(Ext0->getUnderlyingInstr())->getOpcode(),
-                       ChainOp, Ext0->getOperand(0), Ext1->getOperand(0),
-                       CondOp, IsOrdered) {}
+      : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, Ext->getUnderlyingInstr(),
+                       Mul->getUnderlyingInstr(), Ext0->getUnderlyingInstr(),
+                       Ext1->getUnderlyingInstr(), ChainOp, Ext0->getOperand(0),
+                       Ext1->getOperand(0), CondOp, IsOrdered,
+                       Ext0->getResultType()) {}
 
   ~VPMulAccRecipe() override = default;
 
@@ -2857,17 +2869,37 @@ class VPMulAccRecipe : public VPReductionRecipe {
   VPValue *getVecOp0() const { return getOperand(1); }
   VPValue *getVecOp1() const { return getOperand(2); }
 
+  /// Return the type after inner extended, which must equal to the type of mul
+  /// instruction. If the ResultType != recurrenceType, than it must have a
+  /// extend recipe after mul recipe.
+  Type *getResultType() const { return ResultType; }
+
+  /// The underlying instruction for VPWidenRecipe.
+  Instruction *getMulInstr() const { return MulInstr; }
+
+  /// The underlying Instruction for outer VPWidenCastRecipe.
+  CastInst *getExtInstr() const { return ExtInstr; }
+  /// The underlying Instruction for inner VPWidenCastRecipe.
+  CastInst *getExt0Instr() const { return Ext0Instr; }
+  /// The underlying Instruction for inner VPWidenCastRecipe.
+  CastInst *getExt1Instr() const { return Ext1Instr; }
+
   /// Return if this MulAcc recipe contains extend instructions.
-  bool isExtended() const { return IsExtended; }
+  bool isExtended() const { return Ext0Instr && Ext1Instr; }
   /// Return if the operands of mul instruction come from same extend.
-  bool isSameExtend() const { return getVecOp0() == getVecOp1(); }
+  bool isSameExtend() const { return Ext0Instr == Ext1Instr; }
+  /// Return if the MulAcc recipes contains extend after mul.
+  bool isOuterExtended() const { return ExtInstr != nullptr; }
   /// Return if the extend opcode is ZExt.
   bool isZExt() const {
     if (!isExtended())
       return true;
-    return ExtOp == Instruction::CastOps::ZExt;
+    // reduce.add(sext(mul(zext(A), zext(A)))) can be transform to
+    // reduce.add(zext(mul(sext(A), sext(A))))
+    if (ExtInstr && ExtInstr->getOpcode() != Ext0Instr->getOpcode())
+      return true;
+    return Ext0Instr->getOpcode() == Instruction::CastOps::ZExt;
   }
-  Instruction::CastOps getExtOpcode() const { return ExtOp; }
 };
 
 /// VPReplicateRecipe replicates a given instruction producing multiple scalar
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index edcc229cbfd5a8..5aef2089ac37d8 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -2267,8 +2267,6 @@ InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
   Type *ElementTy = isExtended() ? RdxDesc.getRecurrenceType()
                                  : Ctx.Types.inferScalarType(getVecOp0());
   auto *VectorTy = cast<VectorType>(ToVectorTy(ElementTy, VF));
-  auto *SrcVecTy =
-      cast<VectorType>(ToVectorTy(Ctx.Types.inferScalarType(getVecOp0()), VF));
   TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
   unsigned Opcode = RdxDesc.getOpcode();
 
@@ -2286,25 +2284,29 @@ InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
   if (isExtended()) {
     auto *SrcTy = cast<VectorType>(
         ToVectorTy(Ctx.Types.inferScalarType(getVecOp0()), VF));
+    auto *DestTy = cast<VectorType>(ToVectorTy(getResultType(), VF));
     TTI::CastContextHint CCH0 =
         computeCCH(getVecOp0()->getDefiningRecipe(), VF);
-    ExtendedCost = Ctx.TTI.getCastInstrCost(getExtOpcode(), VectorTy, SrcVecTy,
-                                            CCH0, TTI::TCK_RecipThroughput);
+    ExtendedCost = Ctx.TTI.getCastInstrCost(
+        Ext0Instr->getOpcode(), DestTy, SrcTy, CCH0, TTI::TCK_RecipThroughput,
+        dyn_cast_if_present<Instruction>(getExt0Instr()));
     TTI::CastContextHint CCH1 =
         computeCCH(getVecOp0()->getDefiningRecipe(), VF);
-    ExtendedCost += Ctx.TTI.getCastInstrCost(getExtOpcode(), VectorTy, SrcVecTy,
-                                             CCH1, TTI::TCK_RecipThroughput);
+    ExtendedCost += Ctx.TTI.getCastInstrCost(
+        Ext1Instr->getOpcode(), DestTy, SrcTy, CCH1, TTI::TCK_RecipThroughput,
+        dyn_cast_if_present<Instruction>(getExt1Instr()));
   }
 
   // Mul cost
   InstructionCost MulCost;
   SmallVector<const Value *, 4> Operands;
+  Operands.append(MulInstr->value_op_begin(), MulInstr->value_op_end());
   if (isExtended())
     MulCost = Ctx.TTI.getArithmeticInstrCost(
         Instruction::Mul, VectorTy, CostKind,
         {TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None},
         {TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None},
-        Operands, nullptr, &Ctx.TLI);
+        Operands, MulInstr, &Ctx.TLI);
   else {
     VPValue *RHS = getVecOp1();
     // Certain instructions can be cheaper to vectorize if they have a constant
@@ -2317,15 +2319,15 @@ InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
     if (RHSInfo.Kind == TargetTransformInfo::OK_AnyValue &&
         RHS->isDefinedOutsideLoopRegions())
       RHSInfo.Kind = TargetTransformInfo::OK_UniformValue;
-    Operands.append(
-        {getVecOp0()->getUnderlyingValue(), RHS->getUnderlyingValue()});
     MulCost = Ctx.TTI.getArithmeticInstrCost(
         Instruction::Mul, VectorTy, CostKind,
         {TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None},
-        RHSInfo, Operands, nullptr, &Ctx.TLI);
+        RHSInfo, Operands, MulInstr, &Ctx.TLI);
   }
 
   // MulAccReduction Cost
+  VectorType *SrcVecTy =
+      cast<VectorType>(ToVectorTy(Ctx.Types.inferScalarType(getVecOp0()), VF));
   InstructionCost MulAccCost =
       Ctx.TTI.getMulAccReductionCost(isZExt(), ElementTy, SrcVecTy, CostKind);
 
@@ -2409,7 +2411,6 @@ void VPExtendedReductionRecipe::print(raw_ostream &O, const Twine &Indent,
 void VPMulAccRecipe::print(raw_ostream &O, const Twine &Indent,
                            VPSlotTracker &SlotTracker) const {
   const RecurrenceDescriptor &RdxDesc = getRecurrenceDescriptor();
-  Type *ElementTy = RdxDesc.getRecurrenceType();
   O << Indent << "MULACC-REDUCE ";
   printAsOperand(O, SlotTracker);
   O << " = ";
@@ -2417,23 +2418,27 @@ void VPMulAccRecipe::print(raw_ostream &O, const Twine &Indent,
   O << " + ";
   if (isa<FPMathOperator>(getUnderlyingInstr()))
     O << getUnderlyingInstr()->getFastMathFlags();
+  if (isOuterExtended())
+    O << " (";
   O << "reduce." << Instruction::getOpcodeName(RdxDesc.getOpcode()) << " (";
   O << "mul ";
   if (isExtended())
     O << "(";
   getVecOp0()->printAsOperand(O, SlotTracker);
   if (isExtended())
-    O << " extended to " << *ElementTy << "), (";
+    O << " extended to " << *getResultType() << "), (";
   else
     O << ", ";
   getVecOp1()->printAsOperand(O, SlotTracker);
   if (isExtended())
-    O << " extended to " << *ElementTy << ")";
+    O << " extended to " << *getResultType() << ")";
   if (isConditional()) {
     O << ", ";
     getCondOp()->printAsOperand(O, SlotTracker);
   }
   O << ")";
+  if (isOuterExtended())
+    O << " extended to " << *RdxDesc.getRecurrenceType() << ")";
   if (RdxDesc.IntermediateStore)
     O << " (with final reduction value stored in invariant address sank "
          "outside of loop)";
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index 50afdeb3aa78c8..53e080b5bb01d6 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -527,9 +527,9 @@ void VPlanTransforms::prepareExecute(VPlan &Plan) {
     for (VPRecipeBase &R : make_early_inc_range(*VPBB)) {
       if (auto *ExtRed = dyn_cast<VPExtendedReductionRecipe>(&R)) {
         // Genearte VPWidenCastRecipe.
-        auto *Ext =
-            new VPWidenCastRecipe(ExtRed->getExtOpcode(), ExtRed->getVecOp(),
-                                  ExtRed->getResultType());
+        auto *Ext = new VPWidenCastRecipe(
+            ExtRed->getExtOpcode(), ExtRed->getVecOp(), ExtRed->getResultType(),
+            *ExtRed->getExtInstr());
 
         // Generate VPreductionRecipe.
         auto *Red = new VPReductionRecipe(
@@ -542,21 +542,22 @@ void VPlanTransforms::prepareExecute(VPlan &Plan) {
         ExtRed->eraseFromParent();
       } else if (isa<VPMulAccRecipe>(&R)) {
         auto *MulAcc = cast<VPMulAccRecipe>(&R);
-        Type *RedType = MulAcc->getRecurrenceDescriptor().getRecurrenceType();
 
         // Generate inner VPWidenCastRecipes if necessary.
         VPValue *Op0, *Op1;
         if (MulAcc->isExtended()) {
-          Op0 = new VPWidenCastRecipe(MulAcc->getExtOpcode(),
-                                      MulAcc->getVecOp0(), RedType);
+          CastInst *Ext0 = MulAcc->getExt0Instr();
+          Op0 = new VPWidenCastRecipe(Ext0->getOpcode(), MulAcc->getVecOp0(),
+                                      MulAcc->getResultType(), *Ext0);
           Op0->getDefiningRecipe()->insertBefore(MulAcc);
           // Prevent reduce.add(mul(ext(A), ext(A))) generate duplicate
           // VPWidenCastRecipe.
           if (MulAcc->isSameExtend()) {
             Op1 = Op0;
           } else {
-            Op1 = new VPWidenCastRecipe(MulAcc->getExtOpcode(),
-                                        MulAcc->getVecOp1(), RedType);
+            CastInst *Ext1 = MulAcc->getExt1Instr();
+            Op1 = new VPWidenCastRecipe(Ext1->getOpcode(), MulAcc->getVecOp1(),
+                                        MulAcc->getResultType(), *Ext1);
             Op1->getDefiningRecipe()->insertBefore(MulAcc);
           }
           // Not contains extend instruction in this MulAccRecipe.
@@ -566,15 +567,27 @@ void VPlanTransforms::prepareExecute(VPlan &Plan) {
         }
 
         // Generate VPWidenRecipe.
+        VPSingleDefRecipe *VecOp;
         SmallVector<VPValue *, 2> MulOps = {Op0, Op1};
-        auto *Mul = new VPWidenRecipe(Instruction::Mul,
+        auto *Mul = new VPWidenRecipe(*MulAcc->getMulInstr(),
                                       make_range(MulOps.begin(), MulOps.end()));
         Mul->insertBefore(MulAcc);
 
+        // Generate outer VPWidenCastRecipe if necessary.
+        if (auto *OuterExtInstr = MulAcc->getExtInstr()) {
+          VecOp = new VPWidenCastRecipe(
+              OuterExtInstr->getOpcode(), Mul,
+              MulAcc->getRecurrenceDescriptor().getRecurrenceType(),
+              *OuterExtInstr);
+          VecOp->insertBefore(MulAcc);
+        } else {
+          VecOp = Mul;
+        }
+
         // Generate VPReductionRecipe.
         auto *Red = new VPReductionRecipe(
             MulAcc->getRecurrenceDescriptor(), MulAcc->getUnderlyingInstr(),
-            MulAcc->getChainOp(), Mul, MulAcc->getCondOp(),
+            MulAcc->getChainOp(), VecOp, MulAcc->getCondOp(),
             MulAcc->isOrdered());
         Red->insertBefore(MulAcc);
 
diff --git a/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll b/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
index f99f33365fad0f..b138666ba92f9b 100644
--- a/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
+++ b/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
@@ -1670,55 +1670,24 @@ define i64 @test_std_q31(ptr %x, i32 %n) #0 {
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[CMP11:%.*]] = icmp sgt i32 [[N:%.*]], 0
 ; CHECK-NEXT:    tail call void @llvm.assume(i1 [[CMP11]])
-; CHECK-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp samesign ult i32 [[N]], 4
-; CHECK-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
-; CHECK:       vector.ph:
-; CHECK-NEXT:    [[N_VEC:%.*]] = and i32 [[N]], 2147483644
 ; CHECK-NEXT:    br label [[FOR_BODY1:%.*]]
-; CHECK:       vector.body:
-; CHECK-NEXT:    [[INDEX:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[FOR_BODY1]] ]
-; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[TMP4:%.*]], [[FOR_BODY1]] ]
-; CHECK-NEXT:    [[VEC_PHI1:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[TMP8:%.*]], [[FOR_BODY1]] ]
-; CHECK-NEXT:    [[TMP10:%.*]] = getelementptr inbounds i32, ptr [[X:%.*]], i32 [[INDEX]]
-; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <4 x i32>, ptr [[TMP10]], align 4
-; CHECK-NEXT:    [[TMP1:%.*]] = ashr <4 x i32> [[WIDE_LOAD]], splat (i32 8)
-; CHECK-NEXT:    [[TMP2:%.*]] = sext <4 x i32> [[TMP1]] to <4 x i64>
-; CHECK-NEXT:    [[TMP3:%.*]] = call i64 @llvm.vector.reduce.add.v4i64(<4 x i64> [[TMP2]])
-; CHECK-NEXT:    [[TMP4]] = add i64 [[TMP3]], [[VEC_PHI]]
-; CHECK-NEXT:    [[TMP5:%.*]] = sext <4 x i32> [[TMP1]] to <4 x i64>
-; CHECK-NEXT:    [[TMP6:%.*]] = mul nsw <4 x i64> [[TMP5]], [[TMP5]]
-; CHECK-NEXT:    [[TMP7:%.*]] = call i64 @llvm.vector.reduce.add.v4i64(<4 x i64> [[TMP6]])
-; CHECK-NEXT:    [[TMP8]] = add i64 [[TMP7]], [[VEC_PHI1]]
-; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 4
-; CHECK-NEXT:    [[TMP9:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]
-; CHECK-NEXT:    br i1 [[TMP9]], label [[MIDDLE_BLOCK:%.*]], label [[FOR_BODY1]], !llvm.loop [[LOOP37:![0-9]+]]
-; CHECK:       middle.block:
-; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i32 [[N]], [[N_VEC]]
-; CHECK-NEXT:    br i1 [[CMP_N]], label [[FOR_COND_CLEANUP:%.*]], label [[SCALAR_PH]]
-; CHECK:       scalar.ph:
-; CHECK-NEXT:    [[BC_RESUME_VAL:%.*]] = phi i32 [ [[N_VEC]], [[MIDDLE_BLOCK]] ], [ 0, [[ENTRY:%.*]] ]
-; CHECK-NEXT:    [[BC_MERGE_RDX:%.*]] = phi i64 [ [[TMP4]], [[MIDDLE_BLOCK]] ], [ 0, [[ENTRY]] ]
-; CHECK-NEXT:    [[BC_MERGE_RDX2:%.*]] = phi i64 [ [[TMP8]], [[MIDDLE_BLOCK]] ], [ 0, [[ENTRY]] ]
-; CHECK-NEXT:    br label [[FOR_BODY:%.*]]
 ; CHECK:       for.cond.cleanup:
-; CHECK-NEXT:    [[ADD:%.*]] = phi i64 [ [[ADD1:%.*]], [[FOR_BODY]] ], [ [[TMP4]], [[MIDDLE_BLOCK]] ]
-; CHECK-NEXT:    [[ADD3:%.*]] = phi i64 [ [[ADD5:%.*]], [[FOR_BODY]] ], [ [[TMP8]], [[MIDDLE_BLOCK]] ]
-; CHECK-NEXT:    [[DIV:%.*]] = sdiv i64 [[ADD3]], [[ADD]]
+; CHECK-NEXT:    [[DIV:%.*]] = sdiv i64 [[ADD3:%.*]], [[ADD:%.*]]
 ; CHECK-NEXT:    ret i64 [[DIV]]
 ; CHECK:       for.body:
-; CHECK-NEXT:    [[S_014:%.*]] = phi i64 [ [[ADD1]], [[FOR_BODY]] ], [ [[BC_MERGE_RDX]], [[SCALAR_PH]] ]
-; CHECK-NEXT:    [[I_013:%.*]] = phi i32 [ [[ADD4:%.*]], [[FOR_BODY]] ], [ [[BC_RESUME_VAL]], [[SCALAR_PH]] ]
-; CHECK-NEXT:    [[T_012:%.*]] = phi i64 [ [[ADD5]], [[FOR_BODY]] ], [ [[BC_MERGE_RDX2]], [[SCALAR_PH]] ]
-; CHECK-NEXT:    [[ARRAYIDX:%.*]] = getelementptr inbounds i32, ptr [[X]], i32 [[I_013]]
+; CHECK-NEXT:    [[S_014:%.*]] = phi i64 [ [[ADD]], [[FOR_BODY1]] ], [ 0, [[ENTRY:%.*]] ]
+; CHECK-NEXT:    [[I_013:%.*]] = phi i32 [ [[ADD4:%.*]], [[FOR_BODY1]] ], [ 0, [[ENTRY]] ]
+; CHECK-NEXT:    [[T_012:%.*]] = phi i64 [ [[ADD3]], [[FOR_BODY1]] ], [ 0, [[ENTRY]] ]
+; CHECK-NEXT:    [[ARRAYIDX:%.*]] = getelementptr inbounds i32, ptr [[X:%.*]], i32 [[I_013]]
 ; CHECK-NEXT:    [[TMP0:%.*]] = load i32, ptr [[ARRAYIDX]], align 4
 ; CHECK-NEXT:    [[SHR:%.*]] = ashr i32 [[TMP0]], 8
 ; CHECK-NEXT:    [[CONV:%.*]] = sext i32 [[SHR]] to i64
-; CHECK-NEXT:    [[ADD1]] = add nsw i64 [[S_014]], [[CONV]]
+; CHECK-NEXT:    [[ADD]] = add nsw i64 [[S_014]], [[CONV]]
 ; CHECK-NEXT:    [[MUL:%.*]] = mul nsw i64 [[CONV]], [[CONV]]
-; CHECK-NEXT:    [[ADD5]] = add nuw nsw i64 [[MUL]], [[T_012]]
+; CHECK-NEXT:    [[ADD3]] = add nuw nsw i64 [[MUL]], [[T_012]]
 ; CHECK-NEXT:    [[ADD4]] = add nuw nsw i32 [[I_013]], 1
 ; CHECK-NEXT:    [[EXITCOND_NOT:%.*]] = icmp eq i32 [[ADD4]], [[N]]
-; CHECK-NEXT:    br i1 [[EXITCOND_NOT]], label [[FOR_COND_CLEANUP]], label [[FOR_BODY]], !llvm.loop [[LOOP38:![0-9]+]]
+; CHECK-NEXT:    br i1 [[EXITCOND_NOT]], label [[FOR_COND_CLEANUP:%.*]], label [[FOR_BODY1]]
 ;
 entry:
   %cmp11 = icmp sgt i32 %n, 0
@@ -1754,10 +1723,10 @@ define i64 @test_fir_q15(ptr %x, ptr %y, i32 %n) #0 {
 ; CHECK-NEXT:    [[TMP0:%.*]] = add nsw i32 [[N]], -1
 ; CHECK-NEXT:    [[TMP1:%.*]] = lshr i32 [[TMP0]], 1
 ; CHECK-NEXT:    [[TMP2:%.*]] = add nuw i32 [[TMP1]], 1
-; CHECK-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i32 [[N]], 15
+; CHECK-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i32 [[N]], 7
 ; CHECK-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
 ; CHECK:       vector.ph:
-; CHECK-NEXT:    [[N_VEC:%.*]] = and i32 [[TMP2]], -8
+; CHECK-NEXT:    [[N_VEC:%.*]] = and i32 [[TMP2]], -4
 ; CHECK-NEXT:    [[IND_END:%.*]] = shl i32 [[N_VEC]], 1
 ; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
 ; CHECK:       vector.body:
@@ -1765,28 +1734,28 @@ define i64 @test_fir_q15(ptr %x, ptr %y, i32 %n) #0 {
 ; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[TMP16:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[OFFSET_IDX:%.*]] = shl i32 [[INDEX]], 1
 ; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr inbounds i16, ptr [[X:%.*]], i32 [[OFFSET_IDX]]
-; CHECK-NEXT:    [[WIDE_VEC:%.*]] = load <16 x i16>, ptr [[TMP3]], align 2
-; CHECK-NEXT:    [[STRIDED_VEC:%.*]] = shufflevector <16 x i16> [[WIDE_VEC]], <16 x i16> poison, <8 x i32> <i32 0, i32 2, i32 4, i32 6, i32 8, i32 10, i32 12, i32 14>
-; CHECK-NEXT:    [[STRIDED_VEC1:%.*]] = shufflevector <16 x i16> [[WIDE_VEC]], <16 x i16> poison, <8 x i32> <i32 1, i32 3, i32 5, i32 7, i32 9, i32 11, i32 13, i32 15>
+; CHECK-NEXT:    [[WIDE_VEC:%.*]] = load <8 x i16>, ptr [[TMP3]], align 2
+; CHECK-NEXT:    [[STRIDED_VEC:%.*]] = shufflevector <8 x i16> [[WIDE_VEC]], <8 x i16> poison, <4 x i32> <i32 0, i32 2, i32 4, i32 6>
+; CHECK-NEXT:    [[STRIDED_VEC1:%.*]] = shufflevector <8 x i16> [[WIDE_VEC]], <8 x i16> poison, <4 x i32> <i32 1, i32 3, i32 5, i32 7>
+; CHECK-NEXT:    [[TMP5:%.*]] = sext <4 x i16> [[STRIDED_VEC]] to <4 x i32>
 ; CHECK-NEXT:    [[TMP4:%.*]] = getelementptr inbounds i16, ptr [[Y:%.*]], i32 [[OFFSET_IDX]]
-; CHECK-NEXT:    [[WIDE_VEC2:%.*]] = load <16 x i16>, ptr [[TMP4]], align 2
-; CHECK-NEXT:    [[STRIDED_VEC3:%.*]] = shufflevector <16 x i16> [[WIDE_VEC2]], <16 x i16> poison, <8 x i32> <i32 0, i32 2, i32 4, i32 6, i32 8, i32 10, i32 12, i32 14>
-; CHECK-NEXT:    [[STRIDED_VEC4:%.*]] = shufflevector <16 x i16> [[WIDE_VEC2]], <16 x i16> poison, <8 x i32> <i32 1, i32 3, i32 5, i32 7, i32 9, i32 11, i32 13, i32 15>
-; CHECK-NEXT:    [[TMP5:%.*]] = sext <8 x i16> [[STRIDED_VEC3]] to <8 x i32>
-; CHECK-NEXT:    [[TMP6:%.*]] = sext <8 x i16> [[STRIDED_VEC]] to <8 x i32>
-; CHECK-NEXT:    [[TMP7:%.*]] = mul nsw <8 x i32> [[TMP5]], [[TMP6]]
-; CHECK-NEXT:    [[TMP8:%.*]] = sext <8 x i32> [[TMP7]] to <8 x i64>
-; CHECK-NEXT:    [[TMP9:%.*]] = call i64 @llvm.vector.reduce.add.v8i64(<8 x i64> [[TMP8]])
+; CHECK-NEXT:    [[WIDE_VEC2:%.*]] = load <8 x i16>, ptr [[TMP4]], align 2
+; CHECK-NEXT:    [[STRIDED_VEC3:%.*]] = shufflevector <8 x i16> [[WIDE_VEC2]], <8 x i16> poison, <4 x i32> <i32 0, i32 2, i32 4, i32 6>
+; CHECK-NEXT:    [[STRIDED_VEC4:%.*]] = shufflevector <8 x i16> [[WIDE_VEC2]], <8 x i16> poison, <4 x i32> <i32 1, i32 3, i32 5, i32 7>
+; CHECK-NEXT:    [[TMP6:%.*]] = sext <4 x i16> [[STRIDED_VEC3]] to <4 x i32>
+; CHECK-NEXT:    [[TMP7:%.*]] = mul nsw <4 x i32> [[TMP6]], [[TMP5]]
+; CHECK-NEXT:    [[TMP8:%.*]] = sext <4 x i32> [[TMP7]] to <4 x i64>
+; CHECK-NEXT:    [[TMP13:%.*]] = sext <4 x i16> [[STRIDED_VEC1]] to <4 x i32>
+; CHECK-NEXT:    [[TMP14:%.*]] = sext <4 x i16> [[STRIDED_VEC4]] to <4 x i32>
+; CHECK-NEXT:    [[TMP11:%.*]] = mul nsw <4 x i32> [[TMP14]], [[TMP13]]
+; CHECK-NEXT:    [[TMP12:%.*]] = sext <4 x i32> [[TMP11]] to <4 x i64>
+; CHECK-NEXT:    [[TMP9:%.*]] = call i64 @llvm.vector.reduce.add.v4i64(<4 x i64> [[TMP8]])
 ; CHECK-NEXT:    [[TMP10:%.*]] = add i64 [[TMP9]], [[VEC_PHI]]
-; CHECK-NEXT:    [[TMP11:%.*]] = sext <8 x i16> [[STRIDED_VEC4]] to <8 x i32>
-; CHECK-NEXT:    [[TMP12:%.*]] = sext <8 x i16> [[STRIDED_VEC1]] to <8 x i32>
-; CHECK-NEXT:    [[TMP13:%.*]] = mul nsw <8 x i32> [[TMP11]], [[TMP12]]
-; CHECK-NEXT:    [[TMP14:%.*]] = sext <8 x i32> [[TMP13]] to <8 x i64>
-; CHECK-NEXT:    [[TMP15:%.*]] = call i64 @llvm.vector.reduce.add.v8i64(<8 x i64> [[TMP14]])
+; CHECK-NEXT:    [[TMP15:%.*]] = call i64 @llvm.vector.reduce.add.v4i64(<4 x i64> [[TMP12]])
 ; CHECK-NEXT:    [[TMP16]] = add i64 [[TMP15]], [[TMP10]]
-; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 8
+; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 4
 ; CHECK-NEXT:    [[TMP17:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]
-; CHECK-NEXT:    br i1 [[TMP17]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP39:![0-9]+]]
+; CHECK-NEXT:    br i1 [[TMP17]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP37:![0-9]+]]
 ; CHECK:       middle.block:
 ; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i32 [[TMP2]], [[N_VEC]]
 ; CHECK-NEXT:    br i1 [[CMP_N]], label [[FOR_COND_CLEANUP]], label [[SCALAR_PH]]
@@ -1821,7 +1790,7 @@ define i64 @test_fir_q15(ptr %x, ptr %y, i32 %n) #0 {
 ; CHECK-NEXT:    [[ADD12]] = add nsw i64 [[ADD]], [[CONV11]]
 ; CHECK-NEXT:    [[ADD13]] = add nuw nsw i32 [[I_025]], 2
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i32 [[ADD13]], [[N]]
-; CHECK-NEXT:    br i1 [[CMP]], label [[FOR_BODY]], label [[FOR_COND_CLEANUP]], !llvm.loop [[LOOP40:![0-9]+]]
+; CHECK-NEXT:    br i1 [[CMP]], label [[FOR_BODY]], label [[FOR_COND_CLEANUP]], !llvm.loop [[LOOP38:![0-9]+]]
 ;
 entry:
   %cmp23 = icmp sgt i32 %n, 0

>From 453997e4b1a451430491bd1a9cdb851c0927ed23 Mon Sep 17 00:00:00 2001
From: Elvis Wang <elvis.wang at sifive.com>
Date: Thu, 14 Nov 2024 16:48:33 -0800
Subject: [PATCH 12/26] Remove extended instruction after mul in MulAccRecipe.

Note that reduce.add(ext(mul(ext(A), ext(A)))) is mathematically equal
to reduce.add(mul(zext(A), zext(A))).
---
 .../Transforms/Vectorize/LoopVectorize.cpp    |  13 --
 llvm/lib/Transforms/Vectorize/VPlan.h         |  58 ++++-----
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp |  15 +--
 .../Transforms/Vectorize/VPlanTransforms.cpp  |  25 ++--
 .../LoopVectorize/ARM/mve-reductions.ll       | 111 +++++++++++-------
 .../LoopVectorize/vplan-printing.ll           |   2 +-
 6 files changed, 104 insertions(+), 120 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 5cee4b7e2ce70e..14821362f71550 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7396,19 +7396,6 @@ static bool planContainsAdditionalSimplifications(VPlan &Plan,
       }
       if (Instruction *UI = GetInstructionForCost(&R))
         SeenInstrs.insert(UI);
-      // VPExtendedReductionRecipe contains a folded extend instruction.
-      if (auto *ExtendedRed = dyn_cast<VPExtendedReductionRecipe>(&R))
-        SeenInstrs.insert(ExtendedRed->getExtInstr());
-      // VPMulAccRecipe constians a mul and otional extend instructions.
-      else if (auto *MulAcc = dyn_cast<VPMulAccRecipe>(&R)) {
-        SeenInstrs.insert(MulAcc->getMulInstr());
-        if (MulAcc->isExtended()) {
-          SeenInstrs.insert(MulAcc->getExt0Instr());
-          SeenInstrs.insert(MulAcc->getExt1Instr());
-          if (auto *Ext = MulAcc->getExtInstr())
-            SeenInstrs.insert(Ext);
-        }
-      }
     }
   }
 
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index ab9e6f0393602f..b04b694196e298 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -2614,8 +2614,6 @@ class VPReductionRecipe : public VPSingleDefRecipe {
                                  getVecOp(), getCondOp(), IsOrdered);
   }
 
-  // TODO: Support VPExtendedReductionRecipe and VPMulAccRecipe after EVL
-  // support.
   static inline bool classof(const VPRecipeBase *R) {
     return R->getVPDefID() == VPRecipeBase::VPReductionSC ||
            R->getVPDefID() == VPRecipeBase::VPReductionEVLSC;
@@ -2772,30 +2770,26 @@ class VPExtendedReductionRecipe : public VPReductionRecipe {
 /// VPReductionRecipe VPWidenRecipe(mul) and VPWidenCastRecipes before
 /// execution. The Operands are {ChainOp, VecOp1, VecOp2, [Condition]}.
 class VPMulAccRecipe : public VPReductionRecipe {
-  /// Type after extend.
-  Type *ResultType;
 
   /// reduce.add(ext(mul(ext0(), ext1())))
   Instruction *MulInstr;
-  CastInst *ExtInstr = nullptr;
   CastInst *Ext0Instr = nullptr;
   CastInst *Ext1Instr = nullptr;
 
 protected:
   VPMulAccRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
-                 Instruction *RedI, Instruction *ExtInstr,
-                 Instruction *MulInstr, Instruction *Ext0Instr,
-                 Instruction *Ext1Instr, VPValue *ChainOp, VPValue *VecOp0,
-                 VPValue *VecOp1, VPValue *CondOp, bool IsOrdered,
-                 Type *ResultType)
+                 Instruction *RedI, Instruction *MulInstr,
+                 Instruction *Ext0Instr, Instruction *Ext1Instr,
+                 VPValue *ChainOp, VPValue *VecOp0, VPValue *VecOp1,
+                 VPValue *CondOp, bool IsOrdered)
       : VPReductionRecipe(SC, R, RedI,
                           ArrayRef<VPValue *>({ChainOp, VecOp0, VecOp1}),
                           CondOp, IsOrdered),
-        ResultType(ResultType), MulInstr(MulInstr),
-        ExtInstr(cast_if_present<CastInst>(ExtInstr)),
-        Ext0Instr(cast<CastInst>(Ext0Instr)),
+        MulInstr(MulInstr), Ext0Instr(cast<CastInst>(Ext0Instr)),
         Ext1Instr(cast<CastInst>(Ext1Instr)) {
     assert(MulInstr->getOpcode() == Instruction::Mul);
+    assert(R.getOpcode() == Instruction::Add);
+    assert(Ext0Instr->getOpcode() == Ext1Instr->getOpcode());
   }
 
   VPMulAccRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
@@ -2807,6 +2801,7 @@ class VPMulAccRecipe : public VPReductionRecipe {
                           CondOp, IsOrdered),
         MulInstr(MulInstr) {
     assert(MulInstr->getOpcode() == Instruction::Mul);
+    assert(R.getOpcode() == Instruction::Add);
   }
 
 public:
@@ -2814,11 +2809,10 @@ class VPMulAccRecipe : public VPReductionRecipe {
                  VPValue *ChainOp, VPValue *CondOp, bool IsOrdered,
                  VPWidenRecipe *Mul, VPWidenCastRecipe *Ext0,
                  VPWidenCastRecipe *Ext1)
-      : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, nullptr,
-                       Mul->getUnderlyingInstr(), Ext0->getUnderlyingInstr(),
-                       Ext1->getUnderlyingInstr(), ChainOp, Ext0->getOperand(0),
-                       Ext1->getOperand(0), CondOp, IsOrdered,
-                       Ext0->getResultType()) {}
+      : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, Mul->getUnderlyingInstr(),
+                       Ext0->getUnderlyingInstr(), Ext1->getUnderlyingInstr(),
+                       ChainOp, Ext0->getOperand(0), Ext1->getOperand(0),
+                       CondOp, IsOrdered) {}
 
   VPMulAccRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
                  VPValue *ChainOp, VPValue *CondOp, bool IsOrdered,
@@ -2831,11 +2825,10 @@ class VPMulAccRecipe : public VPReductionRecipe {
                  VPValue *ChainOp, VPValue *CondOp, bool IsOrdered,
                  VPWidenCastRecipe *Ext, VPWidenRecipe *Mul,
                  VPWidenCastRecipe *Ext0, VPWidenCastRecipe *Ext1)
-      : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, Ext->getUnderlyingInstr(),
-                       Mul->getUnderlyingInstr(), Ext0->getUnderlyingInstr(),
-                       Ext1->getUnderlyingInstr(), ChainOp, Ext0->getOperand(0),
-                       Ext1->getOperand(0), CondOp, IsOrdered,
-                       Ext0->getResultType()) {}
+      : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, Mul->getUnderlyingInstr(),
+                       Ext0->getUnderlyingInstr(), Ext1->getUnderlyingInstr(),
+                       ChainOp, Ext0->getOperand(0), Ext1->getOperand(0),
+                       CondOp, IsOrdered) {}
 
   ~VPMulAccRecipe() override = default;
 
@@ -2869,35 +2862,26 @@ class VPMulAccRecipe : public VPReductionRecipe {
   VPValue *getVecOp0() const { return getOperand(1); }
   VPValue *getVecOp1() const { return getOperand(2); }
 
-  /// Return the type after inner extended, which must equal to the type of mul
-  /// instruction. If the ResultType != recurrenceType, than it must have a
-  /// extend recipe after mul recipe.
-  Type *getResultType() const { return ResultType; }
-
   /// The underlying instruction for VPWidenRecipe.
   Instruction *getMulInstr() const { return MulInstr; }
 
-  /// The underlying Instruction for outer VPWidenCastRecipe.
-  CastInst *getExtInstr() const { return ExtInstr; }
   /// The underlying Instruction for inner VPWidenCastRecipe.
   CastInst *getExt0Instr() const { return Ext0Instr; }
+
   /// The underlying Instruction for inner VPWidenCastRecipe.
   CastInst *getExt1Instr() const { return Ext1Instr; }
 
   /// Return if this MulAcc recipe contains extend instructions.
   bool isExtended() const { return Ext0Instr && Ext1Instr; }
+
   /// Return if the operands of mul instruction come from same extend.
-  bool isSameExtend() const { return Ext0Instr == Ext1Instr; }
-  /// Return if the MulAcc recipes contains extend after mul.
-  bool isOuterExtended() const { return ExtInstr != nullptr; }
+  bool isSameExtend() const { return getVecOp0() == getVecOp1(); }
+  Instruction::CastOps getExtOpcode() const { return Ext0Instr->getOpcode(); }
+
   /// Return if the extend opcode is ZExt.
   bool isZExt() const {
     if (!isExtended())
       return true;
-    // reduce.add(sext(mul(zext(A), zext(A)))) can be transform to
-    // reduce.add(zext(mul(sext(A), sext(A))))
-    if (ExtInstr && ExtInstr->getOpcode() != Ext0Instr->getOpcode())
-      return true;
     return Ext0Instr->getOpcode() == Instruction::CastOps::ZExt;
   }
 };
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 5aef2089ac37d8..f6e6f601309d55 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -2284,16 +2284,15 @@ InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
   if (isExtended()) {
     auto *SrcTy = cast<VectorType>(
         ToVectorTy(Ctx.Types.inferScalarType(getVecOp0()), VF));
-    auto *DestTy = cast<VectorType>(ToVectorTy(getResultType(), VF));
     TTI::CastContextHint CCH0 =
         computeCCH(getVecOp0()->getDefiningRecipe(), VF);
     ExtendedCost = Ctx.TTI.getCastInstrCost(
-        Ext0Instr->getOpcode(), DestTy, SrcTy, CCH0, TTI::TCK_RecipThroughput,
+        Ext0Instr->getOpcode(), VectorTy, SrcTy, CCH0, TTI::TCK_RecipThroughput,
         dyn_cast_if_present<Instruction>(getExt0Instr()));
     TTI::CastContextHint CCH1 =
         computeCCH(getVecOp0()->getDefiningRecipe(), VF);
     ExtendedCost += Ctx.TTI.getCastInstrCost(
-        Ext1Instr->getOpcode(), DestTy, SrcTy, CCH1, TTI::TCK_RecipThroughput,
+        Ext1Instr->getOpcode(), VectorTy, SrcTy, CCH1, TTI::TCK_RecipThroughput,
         dyn_cast_if_present<Instruction>(getExt1Instr()));
   }
 
@@ -2411,6 +2410,8 @@ void VPExtendedReductionRecipe::print(raw_ostream &O, const Twine &Indent,
 void VPMulAccRecipe::print(raw_ostream &O, const Twine &Indent,
                            VPSlotTracker &SlotTracker) const {
   const RecurrenceDescriptor &RdxDesc = getRecurrenceDescriptor();
+  Type *RedTy = RdxDesc.getRecurrenceType();
+
   O << Indent << "MULACC-REDUCE ";
   printAsOperand(O, SlotTracker);
   O << " = ";
@@ -2418,27 +2419,23 @@ void VPMulAccRecipe::print(raw_ostream &O, const Twine &Indent,
   O << " + ";
   if (isa<FPMathOperator>(getUnderlyingInstr()))
     O << getUnderlyingInstr()->getFastMathFlags();
-  if (isOuterExtended())
-    O << " (";
   O << "reduce." << Instruction::getOpcodeName(RdxDesc.getOpcode()) << " (";
   O << "mul ";
   if (isExtended())
     O << "(";
   getVecOp0()->printAsOperand(O, SlotTracker);
   if (isExtended())
-    O << " extended to " << *getResultType() << "), (";
+    O << " extended to " << *RedTy << "), (";
   else
     O << ", ";
   getVecOp1()->printAsOperand(O, SlotTracker);
   if (isExtended())
-    O << " extended to " << *getResultType() << ")";
+    O << " extended to " << *RedTy << ")";
   if (isConditional()) {
     O << ", ";
     getCondOp()->printAsOperand(O, SlotTracker);
   }
   O << ")";
-  if (isOuterExtended())
-    O << " extended to " << *RdxDesc.getRecurrenceType() << ")";
   if (RdxDesc.IntermediateStore)
     O << " (with final reduction value stored in invariant address sank "
          "outside of loop)";
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index 53e080b5bb01d6..a41d1a411e899d 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -542,52 +542,43 @@ void VPlanTransforms::prepareExecute(VPlan &Plan) {
         ExtRed->eraseFromParent();
       } else if (isa<VPMulAccRecipe>(&R)) {
         auto *MulAcc = cast<VPMulAccRecipe>(&R);
+        Type *RedTy = MulAcc->getRecurrenceDescriptor().getRecurrenceType();
 
         // Generate inner VPWidenCastRecipes if necessary.
+        // Note that we drop the extend after the mul, which transforms
+        // reduce.add(ext(mul(ext, ext))) to reduce.add(mul(ext, ext)).
         VPValue *Op0, *Op1;
         if (MulAcc->isExtended()) {
           CastInst *Ext0 = MulAcc->getExt0Instr();
           Op0 = new VPWidenCastRecipe(Ext0->getOpcode(), MulAcc->getVecOp0(),
-                                      MulAcc->getResultType(), *Ext0);
+                                      RedTy, *Ext0);
           Op0->getDefiningRecipe()->insertBefore(MulAcc);
           // Prevent reduce.add(mul(ext(A), ext(A))) generate duplicate
           // VPWidenCastRecipe.
-          if (MulAcc->isSameExtend()) {
+          if (MulAcc->getVecOp0() == MulAcc->getVecOp1()) {
             Op1 = Op0;
           } else {
             CastInst *Ext1 = MulAcc->getExt1Instr();
             Op1 = new VPWidenCastRecipe(Ext1->getOpcode(), MulAcc->getVecOp1(),
-                                        MulAcc->getResultType(), *Ext1);
+                                        RedTy, *Ext1);
             Op1->getDefiningRecipe()->insertBefore(MulAcc);
           }
-          // Not contains extend instruction in this MulAccRecipe.
+          // No extends in this MulAccRecipe.
         } else {
           Op0 = MulAcc->getVecOp0();
           Op1 = MulAcc->getVecOp1();
         }
 
         // Generate VPWidenRecipe.
-        VPSingleDefRecipe *VecOp;
         SmallVector<VPValue *, 2> MulOps = {Op0, Op1};
         auto *Mul = new VPWidenRecipe(*MulAcc->getMulInstr(),
                                       make_range(MulOps.begin(), MulOps.end()));
         Mul->insertBefore(MulAcc);
 
-        // Generate outer VPWidenCastRecipe if necessary.
-        if (auto *OuterExtInstr = MulAcc->getExtInstr()) {
-          VecOp = new VPWidenCastRecipe(
-              OuterExtInstr->getOpcode(), Mul,
-              MulAcc->getRecurrenceDescriptor().getRecurrenceType(),
-              *OuterExtInstr);
-          VecOp->insertBefore(MulAcc);
-        } else {
-          VecOp = Mul;
-        }
-
         // Generate VPReductionRecipe.
         auto *Red = new VPReductionRecipe(
             MulAcc->getRecurrenceDescriptor(), MulAcc->getUnderlyingInstr(),
-            MulAcc->getChainOp(), VecOp, MulAcc->getCondOp(),
+            MulAcc->getChainOp(), Mul, MulAcc->getCondOp(),
             MulAcc->isOrdered());
         Red->insertBefore(MulAcc);
 
diff --git a/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll b/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
index b138666ba92f9b..f722fd2584ca96 100644
--- a/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
+++ b/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
@@ -648,10 +648,9 @@ define i64 @mla_i16_i64(ptr nocapture readonly %x, ptr nocapture readonly %y, i3
 ; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <8 x i16>, ptr [[TMP0]], align 2
 ; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr inbounds i16, ptr [[Y:%.*]], i32 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_LOAD1:%.*]] = load <8 x i16>, ptr [[TMP2]], align 2
-; CHECK-NEXT:    [[TMP11:%.*]] = sext <8 x i16> [[WIDE_LOAD1]] to <8 x i32>
-; CHECK-NEXT:    [[TMP3:%.*]] = sext <8 x i16> [[WIDE_LOAD]] to <8 x i32>
-; CHECK-NEXT:    [[TMP4:%.*]] = mul nsw <8 x i32> [[TMP11]], [[TMP3]]
-; CHECK-NEXT:    [[TMP5:%.*]] = sext <8 x i32> [[TMP4]] to <8 x i64>
+; CHECK-NEXT:    [[TMP4:%.*]] = sext <8 x i16> [[WIDE_LOAD1]] to <8 x i64>
+; CHECK-NEXT:    [[TMP3:%.*]] = sext <8 x i16> [[WIDE_LOAD]] to <8 x i64>
+; CHECK-NEXT:    [[TMP5:%.*]] = mul nsw <8 x i64> [[TMP4]], [[TMP3]]
 ; CHECK-NEXT:    [[TMP6:%.*]] = call i64 @llvm.vector.reduce.add.v8i64(<8 x i64> [[TMP5]])
 ; CHECK-NEXT:    [[TMP7]] = add i64 [[TMP6]], [[VEC_PHI]]
 ; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 8
@@ -728,10 +727,9 @@ define i64 @mla_i8_i64(ptr nocapture readonly %x, ptr nocapture readonly %y, i32
 ; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <8 x i8>, ptr [[TMP0]], align 1
 ; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr inbounds i8, ptr [[Y:%.*]], i32 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_LOAD1:%.*]] = load <8 x i8>, ptr [[TMP2]], align 1
-; CHECK-NEXT:    [[TMP11:%.*]] = zext <8 x i8> [[WIDE_LOAD1]] to <8 x i32>
-; CHECK-NEXT:    [[TMP3:%.*]] = zext <8 x i8> [[WIDE_LOAD]] to <8 x i32>
-; CHECK-NEXT:    [[TMP4:%.*]] = mul nuw nsw <8 x i32> [[TMP11]], [[TMP3]]
-; CHECK-NEXT:    [[TMP5:%.*]] = zext nneg <8 x i32> [[TMP4]] to <8 x i64>
+; CHECK-NEXT:    [[TMP4:%.*]] = zext <8 x i8> [[WIDE_LOAD1]] to <8 x i64>
+; CHECK-NEXT:    [[TMP3:%.*]] = zext <8 x i8> [[WIDE_LOAD]] to <8 x i64>
+; CHECK-NEXT:    [[TMP5:%.*]] = mul nuw nsw <8 x i64> [[TMP4]], [[TMP3]]
 ; CHECK-NEXT:    [[TMP6:%.*]] = call i64 @llvm.vector.reduce.add.v8i64(<8 x i64> [[TMP5]])
 ; CHECK-NEXT:    [[TMP7]] = add i64 [[TMP6]], [[VEC_PHI]]
 ; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 8
@@ -1461,9 +1459,8 @@ define i64 @mla_xx_sext_zext(ptr nocapture noundef readonly %x, i32 %n) #0 {
 ; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[TMP5:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds i16, ptr [[X:%.*]], i32 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <8 x i16>, ptr [[TMP0]], align 2
-; CHECK-NEXT:    [[TMP1:%.*]] = sext <8 x i16> [[WIDE_LOAD]] to <8 x i32>
-; CHECK-NEXT:    [[TMP2:%.*]] = mul nsw <8 x i32> [[TMP1]], [[TMP1]]
-; CHECK-NEXT:    [[TMP3:%.*]] = zext nneg <8 x i32> [[TMP2]] to <8 x i64>
+; CHECK-NEXT:    [[TMP1:%.*]] = sext <8 x i16> [[WIDE_LOAD]] to <8 x i64>
+; CHECK-NEXT:    [[TMP3:%.*]] = mul nsw <8 x i64> [[TMP1]], [[TMP1]]
 ; CHECK-NEXT:    [[TMP4:%.*]] = call i64 @llvm.vector.reduce.add.v8i64(<8 x i64> [[TMP3]])
 ; CHECK-NEXT:    [[TMP5]] = add i64 [[TMP4]], [[VEC_PHI]]
 ; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 8
@@ -1530,9 +1527,8 @@ define i64 @mla_and_add_together_16_64(ptr nocapture noundef readonly %x, i32 no
 ; CHECK-NEXT:    [[VEC_PHI1:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[TMP7:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds i16, ptr [[X:%.*]], i32 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <8 x i16>, ptr [[TMP0]], align 2
-; CHECK-NEXT:    [[TMP1:%.*]] = sext <8 x i16> [[WIDE_LOAD]] to <8 x i32>
-; CHECK-NEXT:    [[TMP2:%.*]] = mul nsw <8 x i32> [[TMP1]], [[TMP1]]
-; CHECK-NEXT:    [[TMP3:%.*]] = zext nneg <8 x i32> [[TMP2]] to <8 x i64>
+; CHECK-NEXT:    [[TMP1:%.*]] = sext <8 x i16> [[WIDE_LOAD]] to <8 x i64>
+; CHECK-NEXT:    [[TMP3:%.*]] = mul nsw <8 x i64> [[TMP1]], [[TMP1]]
 ; CHECK-NEXT:    [[TMP4:%.*]] = call i64 @llvm.vector.reduce.add.v8i64(<8 x i64> [[TMP3]])
 ; CHECK-NEXT:    [[TMP5]] = add i64 [[TMP4]], [[VEC_PHI]]
 ; CHECK-NEXT:    [[TMP10:%.*]] = sext <8 x i16> [[WIDE_LOAD]] to <8 x i32>
@@ -1670,24 +1666,55 @@ define i64 @test_std_q31(ptr %x, i32 %n) #0 {
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[CMP11:%.*]] = icmp sgt i32 [[N:%.*]], 0
 ; CHECK-NEXT:    tail call void @llvm.assume(i1 [[CMP11]])
+; CHECK-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp samesign ult i32 [[N]], 4
+; CHECK-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
+; CHECK:       vector.ph:
+; CHECK-NEXT:    [[N_VEC:%.*]] = and i32 [[N]], 2147483644
 ; CHECK-NEXT:    br label [[FOR_BODY1:%.*]]
+; CHECK:       vector.body:
+; CHECK-NEXT:    [[INDEX:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[FOR_BODY1]] ]
+; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[TMP4:%.*]], [[FOR_BODY1]] ]
+; CHECK-NEXT:    [[VEC_PHI1:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[TMP8:%.*]], [[FOR_BODY1]] ]
+; CHECK-NEXT:    [[TMP10:%.*]] = getelementptr inbounds i32, ptr [[X:%.*]], i32 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <4 x i32>, ptr [[TMP10]], align 4
+; CHECK-NEXT:    [[TMP1:%.*]] = ashr <4 x i32> [[WIDE_LOAD]], splat (i32 8)
+; CHECK-NEXT:    [[TMP2:%.*]] = sext <4 x i32> [[TMP1]] to <4 x i64>
+; CHECK-NEXT:    [[TMP3:%.*]] = call i64 @llvm.vector.reduce.add.v4i64(<4 x i64> [[TMP2]])
+; CHECK-NEXT:    [[TMP4]] = add i64 [[TMP3]], [[VEC_PHI]]
+; CHECK-NEXT:    [[TMP5:%.*]] = sext <4 x i32> [[TMP1]] to <4 x i64>
+; CHECK-NEXT:    [[TMP6:%.*]] = mul nsw <4 x i64> [[TMP5]], [[TMP5]]
+; CHECK-NEXT:    [[TMP7:%.*]] = call i64 @llvm.vector.reduce.add.v4i64(<4 x i64> [[TMP6]])
+; CHECK-NEXT:    [[TMP8]] = add i64 [[TMP7]], [[VEC_PHI1]]
+; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 4
+; CHECK-NEXT:    [[TMP9:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]
+; CHECK-NEXT:    br i1 [[TMP9]], label [[MIDDLE_BLOCK:%.*]], label [[FOR_BODY1]], !llvm.loop [[LOOP37:![0-9]+]]
+; CHECK:       middle.block:
+; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i32 [[N]], [[N_VEC]]
+; CHECK-NEXT:    br i1 [[CMP_N]], label [[FOR_COND_CLEANUP:%.*]], label [[SCALAR_PH]]
+; CHECK:       scalar.ph:
+; CHECK-NEXT:    [[BC_RESUME_VAL:%.*]] = phi i32 [ [[N_VEC]], [[MIDDLE_BLOCK]] ], [ 0, [[ENTRY:%.*]] ]
+; CHECK-NEXT:    [[BC_MERGE_RDX:%.*]] = phi i64 [ [[TMP4]], [[MIDDLE_BLOCK]] ], [ 0, [[ENTRY]] ]
+; CHECK-NEXT:    [[BC_MERGE_RDX2:%.*]] = phi i64 [ [[TMP8]], [[MIDDLE_BLOCK]] ], [ 0, [[ENTRY]] ]
+; CHECK-NEXT:    br label [[FOR_BODY:%.*]]
 ; CHECK:       for.cond.cleanup:
-; CHECK-NEXT:    [[DIV:%.*]] = sdiv i64 [[ADD3:%.*]], [[ADD:%.*]]
+; CHECK-NEXT:    [[ADD:%.*]] = phi i64 [ [[ADD1:%.*]], [[FOR_BODY]] ], [ [[TMP4]], [[MIDDLE_BLOCK]] ]
+; CHECK-NEXT:    [[ADD3:%.*]] = phi i64 [ [[ADD5:%.*]], [[FOR_BODY]] ], [ [[TMP8]], [[MIDDLE_BLOCK]] ]
+; CHECK-NEXT:    [[DIV:%.*]] = sdiv i64 [[ADD3]], [[ADD]]
 ; CHECK-NEXT:    ret i64 [[DIV]]
 ; CHECK:       for.body:
-; CHECK-NEXT:    [[S_014:%.*]] = phi i64 [ [[ADD]], [[FOR_BODY1]] ], [ 0, [[ENTRY:%.*]] ]
-; CHECK-NEXT:    [[I_013:%.*]] = phi i32 [ [[ADD4:%.*]], [[FOR_BODY1]] ], [ 0, [[ENTRY]] ]
-; CHECK-NEXT:    [[T_012:%.*]] = phi i64 [ [[ADD3]], [[FOR_BODY1]] ], [ 0, [[ENTRY]] ]
-; CHECK-NEXT:    [[ARRAYIDX:%.*]] = getelementptr inbounds i32, ptr [[X:%.*]], i32 [[I_013]]
+; CHECK-NEXT:    [[S_014:%.*]] = phi i64 [ [[ADD1]], [[FOR_BODY]] ], [ [[BC_MERGE_RDX]], [[SCALAR_PH]] ]
+; CHECK-NEXT:    [[I_013:%.*]] = phi i32 [ [[ADD4:%.*]], [[FOR_BODY]] ], [ [[BC_RESUME_VAL]], [[SCALAR_PH]] ]
+; CHECK-NEXT:    [[T_012:%.*]] = phi i64 [ [[ADD5]], [[FOR_BODY]] ], [ [[BC_MERGE_RDX2]], [[SCALAR_PH]] ]
+; CHECK-NEXT:    [[ARRAYIDX:%.*]] = getelementptr inbounds i32, ptr [[X]], i32 [[I_013]]
 ; CHECK-NEXT:    [[TMP0:%.*]] = load i32, ptr [[ARRAYIDX]], align 4
 ; CHECK-NEXT:    [[SHR:%.*]] = ashr i32 [[TMP0]], 8
 ; CHECK-NEXT:    [[CONV:%.*]] = sext i32 [[SHR]] to i64
-; CHECK-NEXT:    [[ADD]] = add nsw i64 [[S_014]], [[CONV]]
+; CHECK-NEXT:    [[ADD1]] = add nsw i64 [[S_014]], [[CONV]]
 ; CHECK-NEXT:    [[MUL:%.*]] = mul nsw i64 [[CONV]], [[CONV]]
-; CHECK-NEXT:    [[ADD3]] = add nuw nsw i64 [[MUL]], [[T_012]]
+; CHECK-NEXT:    [[ADD5]] = add nuw nsw i64 [[MUL]], [[T_012]]
 ; CHECK-NEXT:    [[ADD4]] = add nuw nsw i32 [[I_013]], 1
 ; CHECK-NEXT:    [[EXITCOND_NOT:%.*]] = icmp eq i32 [[ADD4]], [[N]]
-; CHECK-NEXT:    br i1 [[EXITCOND_NOT]], label [[FOR_COND_CLEANUP:%.*]], label [[FOR_BODY1]]
+; CHECK-NEXT:    br i1 [[EXITCOND_NOT]], label [[FOR_COND_CLEANUP]], label [[FOR_BODY]], !llvm.loop [[LOOP38:![0-9]+]]
 ;
 entry:
   %cmp11 = icmp sgt i32 %n, 0
@@ -1723,10 +1750,10 @@ define i64 @test_fir_q15(ptr %x, ptr %y, i32 %n) #0 {
 ; CHECK-NEXT:    [[TMP0:%.*]] = add nsw i32 [[N]], -1
 ; CHECK-NEXT:    [[TMP1:%.*]] = lshr i32 [[TMP0]], 1
 ; CHECK-NEXT:    [[TMP2:%.*]] = add nuw i32 [[TMP1]], 1
-; CHECK-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i32 [[N]], 7
+; CHECK-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i32 [[N]], 15
 ; CHECK-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
 ; CHECK:       vector.ph:
-; CHECK-NEXT:    [[N_VEC:%.*]] = and i32 [[TMP2]], -4
+; CHECK-NEXT:    [[N_VEC:%.*]] = and i32 [[TMP2]], -8
 ; CHECK-NEXT:    [[IND_END:%.*]] = shl i32 [[N_VEC]], 1
 ; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
 ; CHECK:       vector.body:
@@ -1734,28 +1761,26 @@ define i64 @test_fir_q15(ptr %x, ptr %y, i32 %n) #0 {
 ; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[TMP16:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[OFFSET_IDX:%.*]] = shl i32 [[INDEX]], 1
 ; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr inbounds i16, ptr [[X:%.*]], i32 [[OFFSET_IDX]]
-; CHECK-NEXT:    [[WIDE_VEC:%.*]] = load <8 x i16>, ptr [[TMP3]], align 2
-; CHECK-NEXT:    [[STRIDED_VEC:%.*]] = shufflevector <8 x i16> [[WIDE_VEC]], <8 x i16> poison, <4 x i32> <i32 0, i32 2, i32 4, i32 6>
-; CHECK-NEXT:    [[STRIDED_VEC1:%.*]] = shufflevector <8 x i16> [[WIDE_VEC]], <8 x i16> poison, <4 x i32> <i32 1, i32 3, i32 5, i32 7>
-; CHECK-NEXT:    [[TMP5:%.*]] = sext <4 x i16> [[STRIDED_VEC]] to <4 x i32>
+; CHECK-NEXT:    [[WIDE_VEC:%.*]] = load <16 x i16>, ptr [[TMP3]], align 2
+; CHECK-NEXT:    [[STRIDED_VEC:%.*]] = shufflevector <16 x i16> [[WIDE_VEC]], <16 x i16> poison, <8 x i32> <i32 0, i32 2, i32 4, i32 6, i32 8, i32 10, i32 12, i32 14>
+; CHECK-NEXT:    [[STRIDED_VEC1:%.*]] = shufflevector <16 x i16> [[WIDE_VEC]], <16 x i16> poison, <8 x i32> <i32 1, i32 3, i32 5, i32 7, i32 9, i32 11, i32 13, i32 15>
 ; CHECK-NEXT:    [[TMP4:%.*]] = getelementptr inbounds i16, ptr [[Y:%.*]], i32 [[OFFSET_IDX]]
-; CHECK-NEXT:    [[WIDE_VEC2:%.*]] = load <8 x i16>, ptr [[TMP4]], align 2
-; CHECK-NEXT:    [[STRIDED_VEC3:%.*]] = shufflevector <8 x i16> [[WIDE_VEC2]], <8 x i16> poison, <4 x i32> <i32 0, i32 2, i32 4, i32 6>
-; CHECK-NEXT:    [[STRIDED_VEC4:%.*]] = shufflevector <8 x i16> [[WIDE_VEC2]], <8 x i16> poison, <4 x i32> <i32 1, i32 3, i32 5, i32 7>
-; CHECK-NEXT:    [[TMP6:%.*]] = sext <4 x i16> [[STRIDED_VEC3]] to <4 x i32>
-; CHECK-NEXT:    [[TMP7:%.*]] = mul nsw <4 x i32> [[TMP6]], [[TMP5]]
-; CHECK-NEXT:    [[TMP8:%.*]] = sext <4 x i32> [[TMP7]] to <4 x i64>
-; CHECK-NEXT:    [[TMP13:%.*]] = sext <4 x i16> [[STRIDED_VEC1]] to <4 x i32>
-; CHECK-NEXT:    [[TMP14:%.*]] = sext <4 x i16> [[STRIDED_VEC4]] to <4 x i32>
-; CHECK-NEXT:    [[TMP11:%.*]] = mul nsw <4 x i32> [[TMP14]], [[TMP13]]
-; CHECK-NEXT:    [[TMP12:%.*]] = sext <4 x i32> [[TMP11]] to <4 x i64>
-; CHECK-NEXT:    [[TMP9:%.*]] = call i64 @llvm.vector.reduce.add.v4i64(<4 x i64> [[TMP8]])
+; CHECK-NEXT:    [[WIDE_VEC2:%.*]] = load <16 x i16>, ptr [[TMP4]], align 2
+; CHECK-NEXT:    [[STRIDED_VEC3:%.*]] = shufflevector <16 x i16> [[WIDE_VEC2]], <16 x i16> poison, <8 x i32> <i32 0, i32 2, i32 4, i32 6, i32 8, i32 10, i32 12, i32 14>
+; CHECK-NEXT:    [[STRIDED_VEC4:%.*]] = shufflevector <16 x i16> [[WIDE_VEC2]], <16 x i16> poison, <8 x i32> <i32 1, i32 3, i32 5, i32 7, i32 9, i32 11, i32 13, i32 15>
+; CHECK-NEXT:    [[TMP5:%.*]] = sext <8 x i16> [[STRIDED_VEC3]] to <8 x i64>
+; CHECK-NEXT:    [[TMP6:%.*]] = sext <8 x i16> [[STRIDED_VEC]] to <8 x i64>
+; CHECK-NEXT:    [[TMP7:%.*]] = mul nsw <8 x i64> [[TMP5]], [[TMP6]]
+; CHECK-NEXT:    [[TMP9:%.*]] = call i64 @llvm.vector.reduce.add.v8i64(<8 x i64> [[TMP7]])
 ; CHECK-NEXT:    [[TMP10:%.*]] = add i64 [[TMP9]], [[VEC_PHI]]
-; CHECK-NEXT:    [[TMP15:%.*]] = call i64 @llvm.vector.reduce.add.v4i64(<4 x i64> [[TMP12]])
+; CHECK-NEXT:    [[TMP13:%.*]] = sext <8 x i16> [[STRIDED_VEC4]] to <8 x i64>
+; CHECK-NEXT:    [[TMP11:%.*]] = sext <8 x i16> [[STRIDED_VEC1]] to <8 x i64>
+; CHECK-NEXT:    [[TMP12:%.*]] = mul nsw <8 x i64> [[TMP13]], [[TMP11]]
+; CHECK-NEXT:    [[TMP15:%.*]] = call i64 @llvm.vector.reduce.add.v8i64(<8 x i64> [[TMP12]])
 ; CHECK-NEXT:    [[TMP16]] = add i64 [[TMP15]], [[TMP10]]
-; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 4
+; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 8
 ; CHECK-NEXT:    [[TMP17:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]
-; CHECK-NEXT:    br i1 [[TMP17]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP37:![0-9]+]]
+; CHECK-NEXT:    br i1 [[TMP17]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP39:![0-9]+]]
 ; CHECK:       middle.block:
 ; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i32 [[TMP2]], [[N_VEC]]
 ; CHECK-NEXT:    br i1 [[CMP_N]], label [[FOR_COND_CLEANUP]], label [[SCALAR_PH]]
@@ -1790,7 +1815,7 @@ define i64 @test_fir_q15(ptr %x, ptr %y, i32 %n) #0 {
 ; CHECK-NEXT:    [[ADD12]] = add nsw i64 [[ADD]], [[CONV11]]
 ; CHECK-NEXT:    [[ADD13]] = add nuw nsw i32 [[I_025]], 2
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i32 [[ADD13]], [[N]]
-; CHECK-NEXT:    br i1 [[CMP]], label [[FOR_BODY]], label [[FOR_COND_CLEANUP]], !llvm.loop [[LOOP38:![0-9]+]]
+; CHECK-NEXT:    br i1 [[CMP]], label [[FOR_BODY]], label [[FOR_COND_CLEANUP]], !llvm.loop [[LOOP40:![0-9]+]]
 ;
 entry:
   %cmp23 = icmp sgt i32 %n, 0
diff --git a/llvm/test/Transforms/LoopVectorize/vplan-printing.ll b/llvm/test/Transforms/LoopVectorize/vplan-printing.ll
index b1d7388005ad4e..95ce514d8a92ff 100644
--- a/llvm/test/Transforms/LoopVectorize/vplan-printing.ll
+++ b/llvm/test/Transforms/LoopVectorize/vplan-printing.ll
@@ -1315,7 +1315,7 @@ define i64 @print_mulacc_extended(ptr nocapture readonly %x, ptr nocapture reado
 ; CHECK-NEXT:     CLONE ir<%arrayidx1> = getelementptr inbounds ir<%y>, vp<%3>
 ; CHECK-NEXT:     vp<%5> = vector-pointer ir<%arrayidx1>
 ; CHECK-NEXT:     WIDEN ir<%load1> = load vp<%5>
-; CHECK-NEXT:     MULACC-REDUCE ir<%add> = ir<%r.09> +  (reduce.add (mul (ir<%load0> extended to i32), (ir<%load1> extended to i32)) extended to i64)
+; CHECK-NEXT:     MULACC-REDUCE ir<%add> = ir<%r.09> + reduce.add (mul (ir<%load0> extended to i64), (ir<%load1> extended to i64))
 ; CHECK-NEXT:     EMIT vp<%index.next> = add nuw vp<%2>, vp<%0>
 ; CHECK-NEXT:     EMIT branch-on-count vp<%index.next>, vp<%1>
 ; CHECK-NEXT:   No successors

>From fa4f476c811d216cc982931be1ca601dff9cdd2a Mon Sep 17 00:00:00 2001
From: Elvis Wang <elvis.wang at sifive.com>
Date: Thu, 14 Nov 2024 18:07:19 -0800
Subject: [PATCH 13/26] Refactor.

- Move tryToMatch{MulAcc|ExtendedReduction} out from
adjustRecipesForReduction().
- Remove `ResultTy`, which is the same as the recurrence type.
- Use VP_CLASSOF_IMPL.
---
 .../Transforms/Vectorize/LoopVectorize.cpp    | 153 ++++++++++--------
 llvm/lib/Transforms/Vectorize/VPlan.h         |  47 ++----
 .../Transforms/Vectorize/VPlanTransforms.cpp  |   2 +-
 3 files changed, 100 insertions(+), 102 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 14821362f71550..72a9ae76eccb68 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -9253,6 +9253,87 @@ VPlanPtr LoopVectorizationPlanner::buildVPlan(VFRange &Range) {
   return Plan;
 }
 
+/// Try to match the extended-reduction and create VPExtendedReductionRecipe.
+///
+/// This function tries to match the following pattern, which will generate
+/// an extended-reduction instruction.
+///    reduce(ext(...)).
+static VPExtendedReductionRecipe *
+tryToMatchExtendedReduction(const RecurrenceDescriptor &RdxDesc,
+                            Instruction *CurrentLinkI, VPValue *PreviousLink,
+                            VPValue *VecOp, VPValue *CondOp,
+                            LoopVectorizationCostModel &CM) {
+  using namespace VPlanPatternMatch;
+  VPValue *A;
+  // Matched reduce(ext(...)).
+  if (match(VecOp, m_ZExtOrSExt(m_VPValue(A)))) {
+    return new VPExtendedReductionRecipe(RdxDesc, CurrentLinkI, PreviousLink,
+                                         cast<VPWidenCastRecipe>(VecOp), CondOp,
+                                         CM.useOrderedReductions(RdxDesc));
+  }
+  return nullptr;
+}
+
+/// Try to match the mul-acc-reduction and create VPMulAccRecipe.
+///
+/// This function tries to match the following patterns, which will generate
+/// mul-acc instructions.
+///    reduce.add(mul(...)),
+///    reduce.add(mul(ext(A), ext(B))),
+///    reduce.add(ext(mul(ext(A), ext(B)))).
+static VPMulAccRecipe *tryToMatchMulAcc(const RecurrenceDescriptor &RdxDesc,
+                                        Instruction *CurrentLinkI,
+                                        VPValue *PreviousLink, VPValue *VecOp,
+                                        VPValue *CondOp,
+                                        LoopVectorizationCostModel &CM) {
+  using namespace VPlanPatternMatch;
+  VPValue *A, *B;
+  if (RdxDesc.getOpcode() != Instruction::Add)
+    return nullptr;
+  // Try to match reduce.add(mul(...))
+  if (match(VecOp, m_Mul(m_VPValue(A), m_VPValue(B)))) {
+    VPWidenCastRecipe *RecipeA =
+        dyn_cast_if_present<VPWidenCastRecipe>(A->getDefiningRecipe());
+    VPWidenCastRecipe *RecipeB =
+        dyn_cast_if_present<VPWidenCastRecipe>(B->getDefiningRecipe());
+    // Matched reduce.add(mul(ext, ext))
+    if (RecipeA && RecipeB && match(RecipeA, m_ZExtOrSExt(m_VPValue())) &&
+        match(RecipeB, m_ZExtOrSExt(m_VPValue())) &&
+        (RecipeA->getOpcode() == RecipeB->getOpcode() || A == B)) {
+      return new VPMulAccRecipe(RdxDesc, CurrentLinkI, PreviousLink, CondOp,
+                                CM.useOrderedReductions(RdxDesc),
+                                cast<VPWidenRecipe>(VecOp->getDefiningRecipe()),
+                                RecipeA, RecipeB);
+    } else {
+      // Matched reduce.add(mul)
+      return new VPMulAccRecipe(
+          RdxDesc, CurrentLinkI, PreviousLink, CondOp,
+          CM.useOrderedReductions(RdxDesc),
+          cast<VPWidenRecipe>(VecOp->getDefiningRecipe()));
+    }
+    // Matched reduce.add(ext(mul(ext(A), ext(B))))
+    // All extend instructions must have the same opcode, or A == B,
+    // which can be transformed to reduce.add(zext(mul(sext(A), sext(B)))).
+  } else if (match(VecOp, m_ZExtOrSExt(m_Mul(m_ZExtOrSExt(m_VPValue()),
+                                             m_ZExtOrSExt(m_VPValue()))))) {
+    VPWidenCastRecipe *Ext =
+        cast<VPWidenCastRecipe>(VecOp->getDefiningRecipe());
+    VPWidenRecipe *Mul =
+        cast<VPWidenRecipe>(Ext->getOperand(0)->getDefiningRecipe());
+    VPWidenCastRecipe *Ext0 =
+        cast<VPWidenCastRecipe>(Mul->getOperand(0)->getDefiningRecipe());
+    VPWidenCastRecipe *Ext1 =
+        cast<VPWidenCastRecipe>(Mul->getOperand(1)->getDefiningRecipe());
+    if ((Ext->getOpcode() == Ext0->getOpcode() || Ext0 == Ext1) &&
+        Ext0->getOpcode() == Ext1->getOpcode()) {
+      return new VPMulAccRecipe(RdxDesc, CurrentLinkI, PreviousLink, CondOp,
+                                CM.useOrderedReductions(RdxDesc), Mul, Ext0,
+                                Ext1);
+    }
+  }
+  return nullptr;
+}
+
 // Adjust the recipes for reductions. For in-loop reductions the chain of
 // instructions leading from the loop exit instr to the phi need to be converted
 // to reductions, with one operand being vector and the other being the scalar
@@ -9384,76 +9465,12 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
       if (CM.blockNeedsPredicationForAnyReason(BB))
         CondOp = RecipeBuilder.getBlockInMask(BB);
 
-      auto TryToMatchMulAcc = [&]() -> VPReductionRecipe * {
-        VPValue *A, *B;
-        if (RdxDesc.getOpcode() != Instruction::Add)
-          return nullptr;
-        // Try to match reduce.add(mul(...))
-        if (match(VecOp, m_Mul(m_VPValue(A), m_VPValue(B))) &&
-            !VecOp->hasMoreThanOneUniqueUser()) {
-          VPWidenCastRecipe *RecipeA =
-              dyn_cast_if_present<VPWidenCastRecipe>(A->getDefiningRecipe());
-          VPWidenCastRecipe *RecipeB =
-              dyn_cast_if_present<VPWidenCastRecipe>(B->getDefiningRecipe());
-          // Matched reduce.add(mul(ext, ext))
-          if (RecipeA && RecipeB && match(RecipeA, m_ZExtOrSExt(m_VPValue())) &&
-              match(RecipeB, m_ZExtOrSExt(m_VPValue())) &&
-              (RecipeA->getOpcode() == RecipeB->getOpcode() || A == B)) {
-            return new VPMulAccRecipe(
-                RdxDesc, CurrentLinkI, PreviousLink, CondOp,
-                CM.useOrderedReductions(RdxDesc),
-                cast<VPWidenRecipe>(VecOp->getDefiningRecipe()), RecipeA,
-                RecipeB);
-          } else {
-            // Matched reduce.add(mul)
-            return new VPMulAccRecipe(
-                RdxDesc, CurrentLinkI, PreviousLink, CondOp,
-                CM.useOrderedReductions(RdxDesc),
-                cast<VPWidenRecipe>(VecOp->getDefiningRecipe()));
-          }
-          // Matched reduce.add(ext(mul(ext(A), ext(B))))
-          // Note that all extend instructions must have same opcode or A == B
-          // which can be transform to reduce.add(zext(mul(sext(A), sext(B)))).
-        } else if (match(VecOp,
-                         m_ZExtOrSExt(m_Mul(m_ZExtOrSExt(m_VPValue()),
-                                            m_ZExtOrSExt(m_VPValue())))) &&
-                   !VecOp->hasMoreThanOneUniqueUser()) {
-          VPWidenCastRecipe *Ext =
-              cast<VPWidenCastRecipe>(VecOp->getDefiningRecipe());
-          VPWidenRecipe *Mul =
-              cast<VPWidenRecipe>(Ext->getOperand(0)->getDefiningRecipe());
-          VPWidenCastRecipe *Ext0 =
-              cast<VPWidenCastRecipe>(Mul->getOperand(0)->getDefiningRecipe());
-          VPWidenCastRecipe *Ext1 =
-              cast<VPWidenCastRecipe>(Mul->getOperand(1)->getDefiningRecipe());
-          if ((Ext->getOpcode() == Ext0->getOpcode() || Ext0 == Ext1) &&
-              Ext0->getOpcode() == Ext1->getOpcode()) {
-            return new VPMulAccRecipe(
-                RdxDesc, CurrentLinkI, PreviousLink, CondOp,
-                CM.useOrderedReductions(RdxDesc),
-                cast<VPWidenCastRecipe>(VecOp->getDefiningRecipe()), Mul,
-                cast<VPWidenCastRecipe>(Ext0), cast<VPWidenCastRecipe>(Ext1));
-          }
-        }
-        return nullptr;
-      };
-
-      auto TryToMatchExtendedReduction = [&]() -> VPReductionRecipe * {
-        VPValue *A;
-        // Matched reduce(ext)).
-        if (match(VecOp, m_ZExtOrSExt(m_VPValue(A)))) {
-          return new VPExtendedReductionRecipe(
-              RdxDesc, CurrentLinkI, PreviousLink,
-              cast<VPWidenCastRecipe>(VecOp), CondOp,
-              CM.useOrderedReductions(RdxDesc));
-        }
-        return nullptr;
-      };
-
       VPReductionRecipe *RedRecipe;
-      if (auto *MulAcc = TryToMatchMulAcc())
+      if (auto *MulAcc = tryToMatchMulAcc(RdxDesc, CurrentLinkI, PreviousLink,
+                                          VecOp, CondOp, CM))
         RedRecipe = MulAcc;
-      else if (auto *ExtendedRed = TryToMatchExtendedReduction())
+      else if (auto *ExtendedRed = tryToMatchExtendedReduction(
+                   RdxDesc, CurrentLinkI, PreviousLink, VecOp, CondOp, CM))
         RedRecipe = ExtendedRed;
       else
         RedRecipe =
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index b04b694196e298..bc04c5c47f7f75 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -2702,8 +2702,6 @@ class VPReductionEVLRecipe : public VPReductionRecipe {
 /// VPReductionRecipe and VPWidenCastRecipe before execution. The Operands are
 /// {ChainOp, VecOp, [Condition]}.
 class VPExtendedReductionRecipe : public VPReductionRecipe {
-  /// Type after extend.
-  Type *ResultTy;
   CastInst *ExtInstr;
 
 protected:
@@ -2711,10 +2709,10 @@ class VPExtendedReductionRecipe : public VPReductionRecipe {
                             const RecurrenceDescriptor &R, Instruction *RedI,
                             Instruction::CastOps ExtOp, CastInst *ExtInstr,
                             VPValue *ChainOp, VPValue *VecOp, VPValue *CondOp,
-                            bool IsOrdered, Type *ResultTy)
+                            bool IsOrdered)
       : VPReductionRecipe(SC, R, RedI, ArrayRef<VPValue *>({ChainOp, VecOp}),
                           CondOp, IsOrdered),
-        ResultTy(ResultTy), ExtInstr(ExtInstr) {}
+        ExtInstr(ExtInstr) {}
 
 public:
   VPExtendedReductionRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
@@ -2723,7 +2721,7 @@ class VPExtendedReductionRecipe : public VPReductionRecipe {
       : VPExtendedReductionRecipe(
             VPDef::VPExtendedReductionSC, R, RedI, Ext->getOpcode(),
             cast<CastInst>(Ext->getUnderlyingInstr()), ChainOp,
-            Ext->getOperand(0), CondOp, IsOrdered, Ext->getResultType()) {}
+            Ext->getOperand(0), CondOp, IsOrdered) {}
 
   ~VPExtendedReductionRecipe() override = default;
 
@@ -2731,14 +2729,7 @@ class VPExtendedReductionRecipe : public VPReductionRecipe {
     llvm_unreachable("Not implement yet");
   }
 
-  static inline bool classof(const VPRecipeBase *R) {
-    return R->getVPDefID() == VPDef::VPExtendedReductionSC;
-  }
-
-  static inline bool classof(const VPUser *U) {
-    auto *R = dyn_cast<VPRecipeBase>(U);
-    return R && classof(R);
-  }
+  VP_CLASSOF_IMPL(VPDef::VPExtendedReductionSC);
 
   void execute(VPTransformState &State) override {
     llvm_unreachable("VPExtendedReductionRecipe should be transform to "
@@ -2755,11 +2746,16 @@ class VPExtendedReductionRecipe : public VPReductionRecipe {
              VPSlotTracker &SlotTracker) const override;
 #endif
 
-  /// The Type after extended.
-  Type *getResultType() const { return ResultTy; }
+  /// The scalar type after extended.
+  Type *getResultType() const {
+    return getRecurrenceDescriptor().getRecurrenceType();
+  }
+
   bool isZExt() const { return getExtOpcode() == Instruction::ZExt; }
+
   /// The Opcode of extend instruction.
   Instruction::CastOps getExtOpcode() const { return ExtInstr->getOpcode(); }
+
   /// The CastInst of the extend instruction.
   CastInst *getExtInstr() const { return ExtInstr; }
 };
@@ -2771,7 +2767,7 @@ class VPExtendedReductionRecipe : public VPReductionRecipe {
 /// execution. The Operands are {ChainOp, VecOp1, VecOp2, [Condition]}.
 class VPMulAccRecipe : public VPReductionRecipe {
 
-  /// reduce.add(ext(mul(ext0(), ext1())))
+  // reduce.add(mul(ext0, ext1)).
   Instruction *MulInstr;
   CastInst *Ext0Instr = nullptr;
   CastInst *Ext1Instr = nullptr;
@@ -2821,27 +2817,11 @@ class VPMulAccRecipe : public VPReductionRecipe {
                        ChainOp, Mul->getOperand(0), Mul->getOperand(1), CondOp,
                        IsOrdered) {}
 
-  VPMulAccRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
-                 VPValue *ChainOp, VPValue *CondOp, bool IsOrdered,
-                 VPWidenCastRecipe *Ext, VPWidenRecipe *Mul,
-                 VPWidenCastRecipe *Ext0, VPWidenCastRecipe *Ext1)
-      : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, Mul->getUnderlyingInstr(),
-                       Ext0->getUnderlyingInstr(), Ext1->getUnderlyingInstr(),
-                       ChainOp, Ext0->getOperand(0), Ext1->getOperand(0),
-                       CondOp, IsOrdered) {}
-
   ~VPMulAccRecipe() override = default;
 
   VPMulAccRecipe *clone() override { llvm_unreachable("Not implement yet"); }
 
-  static inline bool classof(const VPRecipeBase *R) {
-    return R->getVPDefID() == VPRecipeBase::VPMulAccSC;
-  }
-
-  static inline bool classof(const VPUser *U) {
-    auto *R = dyn_cast<VPRecipeBase>(U);
-    return R && classof(R);
-  }
+  VP_CLASSOF_IMPL(VPDef::VPMulAccSC);
 
   void execute(VPTransformState &State) override {
     llvm_unreachable("VPMulAccRecipe should transform to VPWidenCastRecipe + "
@@ -2876,6 +2856,7 @@ class VPMulAccRecipe : public VPReductionRecipe {
 
   /// Return if the operands of mul instruction come from same extend.
   bool isSameExtend() const { return getVecOp0() == getVecOp1(); }
+
   Instruction::CastOps getExtOpcode() const { return Ext0Instr->getOpcode(); }
 
   /// Return if the extend opcode is ZExt.
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index a41d1a411e899d..bae8aabc6f9692 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -545,7 +545,7 @@ void VPlanTransforms::prepareExecute(VPlan &Plan) {
         Type *RedTy = MulAcc->getRecurrenceDescriptor().getRecurrenceType();
 
         // Generate inner VPWidenCastRecipes if necessary.
-        // Note that we will drop the extend after of mul which transform
+        // Note that we will drop the extend after mul which transform
         // reduce.add(ext(mul(ext, ext))) to reduce.add(mul(ext, ext)).
         VPValue *Op0, *Op1;
         if (MulAcc->isExtended()) {

>From 86ad2d8565c5c2abfcfb88fb6155ff08dfbeb7ba Mon Sep 17 00:00:00 2001
From: Elvis Wang <elvis.wang at sifive.com>
Date: Fri, 15 Nov 2024 01:03:18 -0800
Subject: [PATCH 14/26] Clamp the range when the ExtendedReduction or MulAcc
 cost is invalid.

---
 .../Vectorize/LoopVectorizationPlanner.h      |  2 +-
 .../Transforms/Vectorize/LoopVectorize.cpp    | 97 +++++++++++++++----
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp | 42 ++++----
 3 files changed, 106 insertions(+), 35 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
index fbcf181a45a664..b8920c936f3b6a 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
@@ -509,7 +509,7 @@ class LoopVectorizationPlanner {
   // between the phi and users outside the vector region when folding the tail.
   void adjustRecipesForReductions(VPlanPtr &Plan,
                                   VPRecipeBuilder &RecipeBuilder,
-                                  ElementCount MinVF);
+                                  VFRange &Range);
 
 #ifndef NDEBUG
   /// \return The most profitable vectorization factor for the available VPlans
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 72a9ae76eccb68..fbb39c659c9175 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -9145,7 +9145,7 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
   // ---------------------------------------------------------------------------
 
   // Adjust the recipes for any inloop reductions.
-  adjustRecipesForReductions(Plan, RecipeBuilder, Range.Start);
+  adjustRecipesForReductions(Plan, RecipeBuilder, Range);
 
   // Interleave memory: for each Interleave Group we marked earlier as relevant
   // for this VPlan, replace the Recipes widening its memory instructions with a
@@ -9258,15 +9258,39 @@ VPlanPtr LoopVectorizationPlanner::buildVPlan(VFRange &Range) {
 /// This function try to match following pattern which will generate
 /// extended-reduction instruction.
 ///    reduce(ext(...)).
-static VPExtendedReductionRecipe *
-tryToMatchExtendedReduction(const RecurrenceDescriptor &RdxDesc,
-                            Instruction *CurrentLinkI, VPValue *PreviousLink,
-                            VPValue *VecOp, VPValue *CondOp,
-                            LoopVectorizationCostModel &CM) {
+static VPExtendedReductionRecipe *tryToMatchAndCreateExtendedReduction(
+    const RecurrenceDescriptor &RdxDesc, Instruction *CurrentLinkI,
+    VPValue *PreviousLink, VPValue *VecOp, VPValue *CondOp,
+    LoopVectorizationCostModel &CM, VPCostContext &Ctx, VFRange &Range) {
   using namespace VPlanPatternMatch;
+
   VPValue *A;
+  Type *RedTy = RdxDesc.getRecurrenceType();
+
+  // Test if the cost of extended-reduction is valid and clamp the range.
+  // Note that reduction-extended is not always valid for all VF and types.
+  auto IsExtendedRedValidAndClampRange = [&](unsigned Opcode, bool isZExt,
+                                             Type *SrcTy) -> bool {
+    return LoopVectorizationPlanner::getDecisionAndClampRange(
+        [&](ElementCount VF) {
+          VectorType *SrcVecTy = cast<VectorType>(ToVectorTy(SrcTy, VF));
+          return Ctx.TTI
+              .getExtendedReductionCost(Opcode, isZExt, RedTy, SrcVecTy,
+                                        RdxDesc.getFastMathFlags(),
+                                        TTI::TCK_RecipThroughput)
+              .isValid();
+        },
+        Range);
+  };
+
   // Matched reduce(ext)).
   if (match(VecOp, m_ZExtOrSExt(m_VPValue(A)))) {
+    if (!IsExtendedRedValidAndClampRange(
+            RdxDesc.getOpcode(),
+            cast<VPWidenCastRecipe>(VecOp)->getOpcode() ==
+                Instruction::CastOps::ZExt,
+            Ctx.Types.inferScalarType(A)))
+      return nullptr;
     return new VPExtendedReductionRecipe(RdxDesc, CurrentLinkI, PreviousLink,
                                          cast<VPWidenCastRecipe>(VecOp), CondOp,
                                          CM.useOrderedReductions(RdxDesc));
@@ -9281,31 +9305,59 @@ tryToMatchExtendedReduction(const RecurrenceDescriptor &RdxDesc,
 ///    reduce.add(mul(...)),
 ///    reduce.add(mul(ext(A), ext(B))),
 ///    reduce.add(ext(mul(ext(A), ext(B)))).
-static VPMulAccRecipe *tryToMatchMulAcc(const RecurrenceDescriptor &RdxDesc,
-                                        Instruction *CurrentLinkI,
-                                        VPValue *PreviousLink, VPValue *VecOp,
-                                        VPValue *CondOp,
-                                        LoopVectorizationCostModel &CM) {
+static VPMulAccRecipe *tryToMatchAndCreateMulAcc(
+    const RecurrenceDescriptor &RdxDesc, Instruction *CurrentLinkI,
+    VPValue *PreviousLink, VPValue *VecOp, VPValue *CondOp,
+    LoopVectorizationCostModel &CM, VPCostContext &Ctx, VFRange &Range) {
   using namespace VPlanPatternMatch;
+
   VPValue *A, *B;
+  Type *RedTy = RdxDesc.getRecurrenceType();
+
+  // Test if the cost of MulAcc is valid and clamp the range.
+  // Note that mul-acc is not always valid for all VF and types.
+  auto IsMulAccValidAndClampRange = [&](bool isZExt, Type *SrcTy) -> bool {
+    return LoopVectorizationPlanner::getDecisionAndClampRange(
+        [&](ElementCount VF) {
+          VectorType *SrcVecTy = cast<VectorType>(ToVectorTy(SrcTy, VF));
+          return Ctx.TTI
+              .getMulAccReductionCost(isZExt, RedTy, SrcVecTy,
+                                      TTI::TCK_RecipThroughput)
+              .isValid();
+        },
+        Range);
+  };
+
   if (RdxDesc.getOpcode() != Instruction::Add)
     return nullptr;
+
   // Try to match reduce.add(mul(...))
   if (match(VecOp, m_Mul(m_VPValue(A), m_VPValue(B)))) {
     VPWidenCastRecipe *RecipeA =
         dyn_cast_if_present<VPWidenCastRecipe>(A->getDefiningRecipe());
     VPWidenCastRecipe *RecipeB =
         dyn_cast_if_present<VPWidenCastRecipe>(B->getDefiningRecipe());
+
     // Matched reduce.add(mul(ext, ext))
     if (RecipeA && RecipeB && match(RecipeA, m_ZExtOrSExt(m_VPValue())) &&
         match(RecipeB, m_ZExtOrSExt(m_VPValue())) &&
         (RecipeA->getOpcode() == RecipeB->getOpcode() || A == B)) {
+
+      // Only create MulAccRecipe if the cost is valid.
+      if (!IsMulAccValidAndClampRange(RecipeA->getOpcode() ==
+                                          Instruction::CastOps::ZExt,
+                                      Ctx.Types.inferScalarType(RecipeA)))
+        return nullptr;
+
       return new VPMulAccRecipe(RdxDesc, CurrentLinkI, PreviousLink, CondOp,
                                 CM.useOrderedReductions(RdxDesc),
                                 cast<VPWidenRecipe>(VecOp->getDefiningRecipe()),
                                 RecipeA, RecipeB);
     } else {
       // Matched reduce.add(mul)
+      if (!IsMulAccValidAndClampRange(true, RedTy))
+        return nullptr;
+
       return new VPMulAccRecipe(
           RdxDesc, CurrentLinkI, PreviousLink, CondOp,
           CM.useOrderedReductions(RdxDesc),
@@ -9326,6 +9378,12 @@ static VPMulAccRecipe *tryToMatchMulAcc(const RecurrenceDescriptor &RdxDesc,
         cast<VPWidenCastRecipe>(Mul->getOperand(1)->getDefiningRecipe());
     if ((Ext->getOpcode() == Ext0->getOpcode() || Ext0 == Ext1) &&
         Ext0->getOpcode() == Ext1->getOpcode()) {
+      // Only create MulAcc recipe if the cost is valid.
+      if (!IsMulAccValidAndClampRange(Ext0->getOpcode() ==
+                                          Instruction::CastOps::ZExt,
+                                      Ctx.Types.inferScalarType(Ext0)))
+        return nullptr;
+
       return new VPMulAccRecipe(RdxDesc, CurrentLinkI, PreviousLink, CondOp,
                                 CM.useOrderedReductions(RdxDesc), Mul, Ext0,
                                 Ext1);
@@ -9348,8 +9406,9 @@ static VPMulAccRecipe *tryToMatchMulAcc(const RecurrenceDescriptor &RdxDesc,
 // with a boolean reduction phi node to check if the condition is true in any
 // iteration. The final value is selected by the final ComputeReductionResult.
 void LoopVectorizationPlanner::adjustRecipesForReductions(
-    VPlanPtr &Plan, VPRecipeBuilder &RecipeBuilder, ElementCount MinVF) {
+    VPlanPtr &Plan, VPRecipeBuilder &RecipeBuilder, VFRange &Range) {
   using namespace VPlanPatternMatch;
+  ElementCount MinVF = Range.Start;
   VPRegionBlock *VectorLoopRegion = Plan->getVectorLoopRegion();
   VPBasicBlock *Header = VectorLoopRegion->getEntryBasicBlock();
   VPBasicBlock *MiddleVPBB = Plan->getMiddleBlock();
@@ -9466,12 +9525,16 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
         CondOp = RecipeBuilder.getBlockInMask(BB);
 
       VPReductionRecipe *RedRecipe;
-      if (auto *MulAcc = tryToMatchMulAcc(RdxDesc, CurrentLinkI, PreviousLink,
-                                          VecOp, CondOp, CM))
+      VPCostContext CostCtx(CM.TTI, *CM.TLI, Legal->getWidestInductionType(),
+                            CM);
+      if (auto *MulAcc =
+              tryToMatchAndCreateMulAcc(RdxDesc, CurrentLinkI, PreviousLink,
+                                        VecOp, CondOp, CM, CostCtx, Range))
         RedRecipe = MulAcc;
-      else if (auto *ExtendedRed = tryToMatchExtendedReduction(
-                   RdxDesc, CurrentLinkI, PreviousLink, VecOp, CondOp, CM))
-        RedRecipe = ExtendedRed;
+      else if (auto *ExtRed = tryToMatchAndCreateExtendedReduction(
+                   RdxDesc, CurrentLinkI, PreviousLink, VecOp, CondOp, CM,
+                   CostCtx, Range))
+        RedRecipe = ExtRed;
       else
         RedRecipe =
             new VPReductionRecipe(RdxDesc, CurrentLinkI, PreviousLink, VecOp,
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index f6e6f601309d55..ad3771fcaa142d 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -2225,9 +2225,19 @@ VPExtendedReductionRecipe::computeCost(ElementCount VF,
   RecurKind RdxKind = RdxDesc.getRecurrenceKind();
   Type *ElementTy = getResultType();
   auto *VectorTy = cast<VectorType>(ToVectorTy(ElementTy, VF));
+  auto *SrcVecTy =
+      cast<VectorType>(ToVectorTy(Ctx.Types.inferScalarType(getVecOp()), VF));
   TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
   unsigned Opcode = RdxDesc.getOpcode();
 
+  // ExtendedReduction Cost
+  InstructionCost ExtendedRedCost =
+      Ctx.TTI.getExtendedReductionCost(Opcode, isZExt(), ElementTy, SrcVecTy,
+                                       RdxDesc.getFastMathFlags(), CostKind);
+
+  assert(ExtendedRedCost.isValid() && "VPExtendedReductionRecipe should not be "
+                                      "created if the cost is invalid.");
+
   // BaseCost = Reduction cost + BinOp cost
   InstructionCost ReductionCost =
       Ctx.TTI.getArithmeticInstrCost(Opcode, ElementTy, CostKind);
@@ -2241,23 +2251,18 @@ VPExtendedReductionRecipe::computeCost(ElementCount VF,
   }
 
   // Extended cost
-  auto *SrcTy =
-      cast<VectorType>(ToVectorTy(Ctx.Types.inferScalarType(getVecOp()), VF));
-  auto *DestTy = cast<VectorType>(ToVectorTy(getResultType(), VF));
   TTI::CastContextHint CCH = computeCCH(getVecOp()->getDefiningRecipe(), VF);
   // Arm TTI will use the underlying instruction to determine the cost.
   InstructionCost ExtendedCost = Ctx.TTI.getCastInstrCost(
-      Opcode, DestTy, SrcTy, CCH, TTI::TCK_RecipThroughput,
+      Opcode, VectorTy, SrcVecTy, CCH, TTI::TCK_RecipThroughput,
       dyn_cast_if_present<Instruction>(getUnderlyingValue()));
 
-  // ExtendedReduction Cost
-  InstructionCost ExtendedRedCost = Ctx.TTI.getExtendedReductionCost(
-      Opcode, isZExt(), ElementTy, SrcTy, RdxDesc.getFastMathFlags(), CostKind);
   // Check if folding ext into ExtendedReduction is profitable.
   if (ExtendedRedCost.isValid() &&
       ExtendedRedCost < ExtendedCost + ReductionCost) {
     return ExtendedRedCost;
   }
+
   return ExtendedCost + ReductionCost;
 }
 
@@ -2272,6 +2277,14 @@ InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
 
   assert(Opcode == Instruction::Add &&
          "Reduction opcode must be add in the VPMulAccRecipe.");
+  // MulAccReduction Cost
+  VectorType *SrcVecTy =
+      cast<VectorType>(ToVectorTy(Ctx.Types.inferScalarType(getVecOp0()), VF));
+  InstructionCost MulAccCost =
+      Ctx.TTI.getMulAccReductionCost(isZExt(), ElementTy, SrcVecTy, CostKind);
+
+  assert(MulAccCost.isValid() && "VPMulAccRecipe should not be "
+                                 "created if the cost is invalid.");
 
   // BaseCost = Reduction cost + BinOp cost
   InstructionCost ReductionCost =
@@ -2282,17 +2295,17 @@ InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
   // Extended cost
   InstructionCost ExtendedCost = 0;
   if (isExtended()) {
-    auto *SrcTy = cast<VectorType>(
-        ToVectorTy(Ctx.Types.inferScalarType(getVecOp0()), VF));
     TTI::CastContextHint CCH0 =
         computeCCH(getVecOp0()->getDefiningRecipe(), VF);
     ExtendedCost = Ctx.TTI.getCastInstrCost(
-        Ext0Instr->getOpcode(), VectorTy, SrcTy, CCH0, TTI::TCK_RecipThroughput,
+        Ext0Instr->getOpcode(), VectorTy, SrcVecTy, CCH0,
+        TTI::TCK_RecipThroughput,
         dyn_cast_if_present<Instruction>(getExt0Instr()));
     TTI::CastContextHint CCH1 =
         computeCCH(getVecOp0()->getDefiningRecipe(), VF);
     ExtendedCost += Ctx.TTI.getCastInstrCost(
-        Ext1Instr->getOpcode(), VectorTy, SrcTy, CCH1, TTI::TCK_RecipThroughput,
+        Ext1Instr->getOpcode(), VectorTy, SrcVecTy, CCH1,
+        TTI::TCK_RecipThroughput,
         dyn_cast_if_present<Instruction>(getExt1Instr()));
   }
 
@@ -2324,17 +2337,12 @@ InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
         RHSInfo, Operands, MulInstr, &Ctx.TLI);
   }
 
-  // MulAccReduction Cost
-  VectorType *SrcVecTy =
-      cast<VectorType>(ToVectorTy(Ctx.Types.inferScalarType(getVecOp0()), VF));
-  InstructionCost MulAccCost =
-      Ctx.TTI.getMulAccReductionCost(isZExt(), ElementTy, SrcVecTy, CostKind);
-
   // Check if folding ext into ExtendedReduction is profitable.
   if (MulAccCost.isValid() &&
       MulAccCost < ExtendedCost + ReductionCost + MulCost) {
     return MulAccCost;
   }
+
   return ExtendedCost + ReductionCost + MulCost;
 }
 

>From 594f9e49425fd8c8a223c7102801aa5f5d4ef341 Mon Sep 17 00:00:00 2001
From: Elvis Wang <elvis.wang at sifive.com>
Date: Sun, 17 Nov 2024 21:26:30 -0800
Subject: [PATCH 15/26] Try to not depend on underlying ext/mul instructions
 and preserve flags/DL.

This commit expands the constructors of VPWidenRecipe and VPWidenCastRecipe
so that vector instructions can be generated with debug information and
flags without an underlying instruction.

The extend instruction (ZExt) needs the non-negative flag, and the integer
mul instruction needs the no-unsigned-wrap/no-signed-wrap flags.
---
 llvm/lib/Transforms/Vectorize/VPlan.h         | 157 ++++++++++++------
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp |  19 +--
 .../Transforms/Vectorize/VPlanTransforms.cpp  |  42 +++--
 3 files changed, 144 insertions(+), 74 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index bc04c5c47f7f75..60ab1d156e8806 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -951,13 +951,15 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
     GEPFlagsTy(bool IsInBounds) : IsInBounds(IsInBounds) {}
   };
 
+  struct NonNegFlagsTy {
+    char NonNeg : 1;
+    NonNegFlagsTy(bool IsNonNeg) : NonNeg(IsNonNeg) {}
+  };
+
 private:
   struct ExactFlagsTy {
     char IsExact : 1;
   };
-  struct NonNegFlagsTy {
-    char NonNeg : 1;
-  };
   struct FastMathFlagsTy {
     char AllowReassoc : 1;
     char NoNaNs : 1;
@@ -1051,6 +1053,12 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
       : VPSingleDefRecipe(SC, Operands, DL), OpType(OperationType::DisjointOp),
         DisjointFlags(DisjointFlags) {}
 
+  template <typename IterT>
+  VPRecipeWithIRFlags(const unsigned char SC, IterT Operands,
+                      NonNegFlagsTy NonNegFlags, DebugLoc DL = {})
+      : VPSingleDefRecipe(SC, Operands, DL), OpType(OperationType::NonNegOp),
+        NonNegFlags(NonNegFlags) {}
+
 protected:
   template <typename IterT>
   VPRecipeWithIRFlags(const unsigned char SC, IterT Operands,
@@ -1160,6 +1168,8 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
 
   FastMathFlags getFastMathFlags() const;
 
+  bool isNonNeg() const { return NonNegFlags.NonNeg; }
+
   bool hasNoUnsignedWrap() const {
     assert(OpType == OperationType::OverflowingBinOp &&
            "recipe doesn't have a NUW flag");
@@ -1448,11 +1458,22 @@ class VPWidenRecipe : public VPRecipeWithIRFlags {
                 iterator_range<IterT> Operands)
       : VPRecipeWithIRFlags(VPDefOpcode, Operands, I), Opcode(I.getOpcode()) {}
 
+  template <typename IterT>
+  VPWidenRecipe(unsigned VPDefOpcode, unsigned Opcode,
+                iterator_range<IterT> Operands, bool NUW, bool NSW, DebugLoc DL)
+      : VPRecipeWithIRFlags(VPDefOpcode, Operands, WrapFlagsTy(NUW, NSW), DL),
+        Opcode(Opcode) {}
+
 public:
   template <typename IterT>
   VPWidenRecipe(Instruction &I, iterator_range<IterT> Operands)
       : VPWidenRecipe(VPDef::VPWidenSC, I, Operands) {}
 
+  template <typename IterT>
+  VPWidenRecipe(unsigned Opcode, iterator_range<IterT> Operands, bool NUW,
+                bool NSW, DebugLoc DL)
+      : VPWidenRecipe(VPDef::VPWidenSC, Opcode, Operands, NUW, NSW, DL) {}
+
   ~VPWidenRecipe() override = default;
 
   VPWidenRecipe *clone() override {
@@ -1553,10 +1574,17 @@ class VPWidenCastRecipe : public VPRecipeWithIRFlags {
            "opcode of underlying cast doesn't match");
   }
 
-  VPWidenCastRecipe(Instruction::CastOps Opcode, VPValue *Op, Type *ResultTy)
-      : VPRecipeWithIRFlags(VPDef::VPWidenCastSC, Op), Opcode(Opcode),
+  VPWidenCastRecipe(Instruction::CastOps Opcode, VPValue *Op, Type *ResultTy,
+                    DebugLoc DL = {})
+      : VPRecipeWithIRFlags(VPDef::VPWidenCastSC, Op, DL), Opcode(Opcode),
         ResultTy(ResultTy) {}
 
+  VPWidenCastRecipe(Instruction::CastOps Opcode, VPValue *Op, Type *ResultTy,
+                    bool IsNonNeg, DebugLoc DL)
+      : VPRecipeWithIRFlags(VPDef::VPWidenCastSC, Op, NonNegFlagsTy(IsNonNeg),
+                            DL),
+        Opcode(Opcode), ResultTy(ResultTy) {}
+
   ~VPWidenCastRecipe() override = default;
 
   VPWidenCastRecipe *clone() override {
@@ -2702,26 +2730,29 @@ class VPReductionEVLRecipe : public VPReductionRecipe {
 /// VPReductionRecipe and VPWidenCastRecipe before execution. The Operands are
 /// {ChainOp, VecOp, [Condition]}.
 class VPExtendedReductionRecipe : public VPReductionRecipe {
-  CastInst *ExtInstr;
+  Instruction::CastOps ExtOp;
+  DebugLoc ExtDL;
+  /// Non-negative flag for the extended instruction.
+  bool IsNonNeg;
 
 protected:
   VPExtendedReductionRecipe(const unsigned char SC,
                             const RecurrenceDescriptor &R, Instruction *RedI,
-                            Instruction::CastOps ExtOp, CastInst *ExtInstr,
-                            VPValue *ChainOp, VPValue *VecOp, VPValue *CondOp,
-                            bool IsOrdered)
+                            Instruction::CastOps ExtOp, DebugLoc ExtDL,
+                            bool IsNonNeg, VPValue *ChainOp, VPValue *VecOp,
+                            VPValue *CondOp, bool IsOrdered)
       : VPReductionRecipe(SC, R, RedI, ArrayRef<VPValue *>({ChainOp, VecOp}),
                           CondOp, IsOrdered),
-        ExtInstr(ExtInstr) {}
+        ExtOp(ExtOp), ExtDL(ExtDL), IsNonNeg(IsNonNeg) {}
 
 public:
   VPExtendedReductionRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
                             VPValue *ChainOp, VPWidenCastRecipe *Ext,
                             VPValue *CondOp, bool IsOrdered)
-      : VPExtendedReductionRecipe(
-            VPDef::VPExtendedReductionSC, R, RedI, Ext->getOpcode(),
-            cast<CastInst>(Ext->getUnderlyingInstr()), ChainOp,
-            Ext->getOperand(0), CondOp, IsOrdered) {}
+      : VPExtendedReductionRecipe(VPDef::VPExtendedReductionSC, R, RedI,
+                                  Ext->getOpcode(), Ext->getDebugLoc(),
+                                  Ext->isNonNeg(), ChainOp, Ext->getOperand(0),
+                                  CondOp, IsOrdered) {}
 
   ~VPExtendedReductionRecipe() override = default;
 
@@ -2754,10 +2785,12 @@ class VPExtendedReductionRecipe : public VPReductionRecipe {
   bool isZExt() const { return getExtOpcode() == Instruction::ZExt; }
 
   /// The Opcode of extend instruction.
-  Instruction::CastOps getExtOpcode() const { return ExtInstr->getOpcode(); }
+  Instruction::CastOps getExtOpcode() const { return ExtOp; }
+
+  bool getNonNegFlags() const { return IsNonNeg; }
 
-  /// The CastInst of the extend instruction.
-  CastInst *getExtInstr() const { return ExtInstr; }
+  /// Return the debug location of the extend instruction.
+  DebugLoc getExtDebugLoc() const { return ExtDL; }
 };
 
 /// A recipe to represent inloop MulAccreduction operations, performing a
@@ -2767,36 +2800,47 @@ class VPExtendedReductionRecipe : public VPReductionRecipe {
 /// execution. The Operands are {ChainOp, VecOp1, VecOp2, [Condition]}.
 class VPMulAccRecipe : public VPReductionRecipe {
 
-  // reduce.add(mul(ext0, ext1)).
-  Instruction *MulInstr;
-  CastInst *Ext0Instr = nullptr;
-  CastInst *Ext1Instr = nullptr;
+  // WrapFlags (NoUnsignedWraps, NoSignedWraps)
+  bool NUW;
+  bool NSW;
+
+  // Debug location of the mul instruction.
+  DebugLoc MulDL;
+  Instruction::CastOps ExtOp;
+
+  // Non-neg flag for the extend instruction.
+  bool IsNonNeg;
+
+  // Debug location of extend instruction.
+  DebugLoc Ext0DL;
+  DebugLoc Ext1DL;
+
+  // Is this mul-acc recipe contains extend recipes?
+  bool IsExtended = false;
 
 protected:
   VPMulAccRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
-                 Instruction *RedI, Instruction *MulInstr,
-                 Instruction *Ext0Instr, Instruction *Ext1Instr,
-                 VPValue *ChainOp, VPValue *VecOp0, VPValue *VecOp1,
-                 VPValue *CondOp, bool IsOrdered)
+                 Instruction *RedI, bool NUW, bool NSW, DebugLoc MulDL,
+                 Instruction::CastOps ExtOp, bool IsNonNeg, DebugLoc Ext0DL,
+                 DebugLoc Ext1DL, VPValue *ChainOp, VPValue *VecOp0,
+                 VPValue *VecOp1, VPValue *CondOp, bool IsOrdered)
       : VPReductionRecipe(SC, R, RedI,
                           ArrayRef<VPValue *>({ChainOp, VecOp0, VecOp1}),
                           CondOp, IsOrdered),
-        MulInstr(MulInstr), Ext0Instr(cast<CastInst>(Ext0Instr)),
-        Ext1Instr(cast<CastInst>(Ext1Instr)) {
-    assert(MulInstr->getOpcode() == Instruction::Mul);
+        NUW(NUW), NSW(NSW), MulDL(MulDL), ExtOp(ExtOp), IsNonNeg(IsNonNeg),
+        Ext0DL(Ext0DL), Ext1DL(Ext1DL) {
     assert(R.getOpcode() == Instruction::Add);
-    assert(Ext0Instr->getOpcode() == Ext1Instr->getOpcode());
+    IsExtended = true;
   }
 
   VPMulAccRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
-                 Instruction *RedI, Instruction *MulInstr, VPValue *ChainOp,
-                 VPValue *VecOp0, VPValue *VecOp1, VPValue *CondOp,
-                 bool IsOrdered)
+                 Instruction *RedI, bool NUW, bool NSW, DebugLoc MulDL,
+                 VPValue *ChainOp, VPValue *VecOp0, VPValue *VecOp1,
+                 VPValue *CondOp, bool IsOrdered)
       : VPReductionRecipe(SC, R, RedI,
                           ArrayRef<VPValue *>({ChainOp, VecOp0, VecOp1}),
                           CondOp, IsOrdered),
-        MulInstr(MulInstr) {
-    assert(MulInstr->getOpcode() == Instruction::Mul);
+        NUW(NUW), NSW(NSW), MulDL(MulDL) {
     assert(R.getOpcode() == Instruction::Add);
   }
 
@@ -2805,16 +2849,18 @@ class VPMulAccRecipe : public VPReductionRecipe {
                  VPValue *ChainOp, VPValue *CondOp, bool IsOrdered,
                  VPWidenRecipe *Mul, VPWidenCastRecipe *Ext0,
                  VPWidenCastRecipe *Ext1)
-      : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, Mul->getUnderlyingInstr(),
-                       Ext0->getUnderlyingInstr(), Ext1->getUnderlyingInstr(),
-                       ChainOp, Ext0->getOperand(0), Ext1->getOperand(0),
-                       CondOp, IsOrdered) {}
+      : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, Mul->hasNoUnsignedWrap(),
+                       Mul->hasNoSignedWrap(), Mul->getDebugLoc(),
+                       Ext0->getOpcode(), Ext0->isNonNeg(), Ext0->getDebugLoc(),
+                       Ext1->getDebugLoc(), ChainOp, Ext0->getOperand(0),
+                       Ext1->getOperand(0), CondOp, IsOrdered) {}
 
   VPMulAccRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
                  VPValue *ChainOp, VPValue *CondOp, bool IsOrdered,
                  VPWidenRecipe *Mul)
-      : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, Mul->getUnderlyingInstr(),
-                       ChainOp, Mul->getOperand(0), Mul->getOperand(1), CondOp,
+      : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, Mul->hasNoUnsignedWrap(),
+                       Mul->hasNoSignedWrap(), Mul->getDebugLoc(), ChainOp,
+                       Mul->getOperand(0), Mul->getOperand(1), CondOp,
                        IsOrdered) {}
 
   ~VPMulAccRecipe() override = default;
@@ -2842,29 +2888,36 @@ class VPMulAccRecipe : public VPReductionRecipe {
   VPValue *getVecOp0() const { return getOperand(1); }
   VPValue *getVecOp1() const { return getOperand(2); }
 
-  /// The underlying instruction for VPWidenRecipe.
-  Instruction *getMulInstr() const { return MulInstr; }
-
-  /// The underlying Instruction for inner VPWidenCastRecipe.
-  CastInst *getExt0Instr() const { return Ext0Instr; }
-
-  /// The underlying Instruction for inner VPWidenCastRecipe.
-  CastInst *getExt1Instr() const { return Ext1Instr; }
-
   /// Return if this MulAcc recipe contains extend instructions.
-  bool isExtended() const { return Ext0Instr && Ext1Instr; }
+  bool isExtended() const { return IsExtended; }
 
   /// Return if the operands of mul instruction come from same extend.
   bool isSameExtend() const { return getVecOp0() == getVecOp1(); }
 
-  Instruction::CastOps getExtOpcode() const { return Ext0Instr->getOpcode(); }
+  Instruction::CastOps getExtOpcode() const { return ExtOp; }
 
   /// Return if the extend opcode is ZExt.
   bool isZExt() const {
     if (!isExtended())
       return true;
-    return Ext0Instr->getOpcode() == Instruction::CastOps::ZExt;
+    return ExtOp == Instruction::CastOps::ZExt;
   }
+
+  /// Return the non-neg flag.
+  bool getNonNegFlags() const { return IsNonNeg; }
+
+  /// Return no-unsigned-wrap.
+  bool hasNoUnsignedWrap() const { return NUW; }
+
+  /// Return no-signed-warp.
+  bool hasNoSignedWrap() const { return NSW; }
+
+  /// Return debug location of mul instruction.
+  DebugLoc getMulDebugLoc() const { return MulDL; }
+
+  /// Return debug location of extend instructions.
+  DebugLoc getExt0DebugLoc() const { return Ext0DL; }
+  DebugLoc getExt1DebugLoc() const { return Ext1DL; }
 };
 
 /// VPReplicateRecipe replicates a given instruction producing multiple scalar
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index ad3771fcaa142d..39beced3a1e3ba 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -2297,28 +2297,23 @@ InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
   if (isExtended()) {
     TTI::CastContextHint CCH0 =
         computeCCH(getVecOp0()->getDefiningRecipe(), VF);
-    ExtendedCost = Ctx.TTI.getCastInstrCost(
-        Ext0Instr->getOpcode(), VectorTy, SrcVecTy, CCH0,
-        TTI::TCK_RecipThroughput,
-        dyn_cast_if_present<Instruction>(getExt0Instr()));
+    ExtendedCost = Ctx.TTI.getCastInstrCost(ExtOp, VectorTy, SrcVecTy, CCH0,
+                                            TTI::TCK_RecipThroughput);
     TTI::CastContextHint CCH1 =
         computeCCH(getVecOp0()->getDefiningRecipe(), VF);
-    ExtendedCost += Ctx.TTI.getCastInstrCost(
-        Ext1Instr->getOpcode(), VectorTy, SrcVecTy, CCH1,
-        TTI::TCK_RecipThroughput,
-        dyn_cast_if_present<Instruction>(getExt1Instr()));
+    ExtendedCost += Ctx.TTI.getCastInstrCost(ExtOp, VectorTy, SrcVecTy, CCH1,
+                                             TTI::TCK_RecipThroughput);
   }
 
   // Mul cost
   InstructionCost MulCost;
   SmallVector<const Value *, 4> Operands;
-  Operands.append(MulInstr->value_op_begin(), MulInstr->value_op_end());
   if (isExtended())
     MulCost = Ctx.TTI.getArithmeticInstrCost(
         Instruction::Mul, VectorTy, CostKind,
         {TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None},
         {TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None},
-        Operands, MulInstr, &Ctx.TLI);
+        Operands, nullptr, &Ctx.TLI);
   else {
     VPValue *RHS = getVecOp1();
     // Certain instructions can be cheaper to vectorize if they have a constant
@@ -2331,10 +2326,12 @@ InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
     if (RHSInfo.Kind == TargetTransformInfo::OK_AnyValue &&
         RHS->isDefinedOutsideLoopRegions())
       RHSInfo.Kind = TargetTransformInfo::OK_UniformValue;
+    Operands.append(
+        {getVecOp0()->getUnderlyingValue(), RHS->getUnderlyingValue()});
     MulCost = Ctx.TTI.getArithmeticInstrCost(
         Instruction::Mul, VectorTy, CostKind,
         {TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None},
-        RHSInfo, Operands, MulInstr, &Ctx.TLI);
+        RHSInfo, Operands, nullptr, &Ctx.TLI);
   }
 
   // Check if folding ext into ExtendedReduction is profitable.
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index bae8aabc6f9692..733c144c7aa359 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -527,9 +527,17 @@ void VPlanTransforms::prepareExecute(VPlan &Plan) {
     for (VPRecipeBase &R : make_early_inc_range(*VPBB)) {
       if (auto *ExtRed = dyn_cast<VPExtendedReductionRecipe>(&R)) {
         // Genearte VPWidenCastRecipe.
-        auto *Ext = new VPWidenCastRecipe(
-            ExtRed->getExtOpcode(), ExtRed->getVecOp(), ExtRed->getResultType(),
-            *ExtRed->getExtInstr());
+        VPWidenCastRecipe *Ext;
+        // Only ZExt contiains non-neg flags.
+        if (ExtRed->isZExt())
+          Ext = new VPWidenCastRecipe(
+              ExtRed->getExtOpcode(), ExtRed->getVecOp(),
+              ExtRed->getResultType(), ExtRed->getNonNegFlags(),
+              ExtRed->getExtDebugLoc());
+        else
+          Ext = new VPWidenCastRecipe(
+              ExtRed->getExtOpcode(), ExtRed->getVecOp(),
+              ExtRed->getResultType(), ExtRed->getExtDebugLoc());
 
         // Generate VPreductionRecipe.
         auto *Red = new VPReductionRecipe(
@@ -549,18 +557,28 @@ void VPlanTransforms::prepareExecute(VPlan &Plan) {
         // reduce.add(ext(mul(ext, ext))) to reduce.add(mul(ext, ext)).
         VPValue *Op0, *Op1;
         if (MulAcc->isExtended()) {
-          CastInst *Ext0 = MulAcc->getExt0Instr();
-          Op0 = new VPWidenCastRecipe(Ext0->getOpcode(), MulAcc->getVecOp0(),
-                                      RedTy, *Ext0);
+          if (MulAcc->isZExt())
+            Op0 = new VPWidenCastRecipe(
+                MulAcc->getExtOpcode(), MulAcc->getVecOp0(), RedTy,
+                MulAcc->getNonNegFlags(), MulAcc->getExt0DebugLoc());
+          else
+            Op0 = new VPWidenCastRecipe(MulAcc->getExtOpcode(),
+                                        MulAcc->getVecOp0(), RedTy,
+                                        MulAcc->getExt0DebugLoc());
           Op0->getDefiningRecipe()->insertBefore(MulAcc);
           // Prevent reduce.add(mul(ext(A), ext(A))) generate duplicate
           // VPWidenCastRecipe.
           if (MulAcc->getVecOp0() == MulAcc->getVecOp1()) {
             Op1 = Op0;
           } else {
-            CastInst *Ext1 = MulAcc->getExt1Instr();
-            Op1 = new VPWidenCastRecipe(Ext1->getOpcode(), MulAcc->getVecOp1(),
-                                        RedTy, *Ext1);
+            if (MulAcc->isZExt())
+              Op1 = new VPWidenCastRecipe(
+                  MulAcc->getExtOpcode(), MulAcc->getVecOp1(), RedTy,
+                  MulAcc->getNonNegFlags(), MulAcc->getExt1DebugLoc());
+            else
+              Op1 = new VPWidenCastRecipe(MulAcc->getExtOpcode(),
+                                          MulAcc->getVecOp1(), RedTy,
+                                          MulAcc->getExt1DebugLoc());
             Op1->getDefiningRecipe()->insertBefore(MulAcc);
           }
           // No extends in this MulAccRecipe.
@@ -571,8 +589,10 @@ void VPlanTransforms::prepareExecute(VPlan &Plan) {
 
         // Generate VPWidenRecipe.
         SmallVector<VPValue *, 2> MulOps = {Op0, Op1};
-        auto *Mul = new VPWidenRecipe(*MulAcc->getMulInstr(),
-                                      make_range(MulOps.begin(), MulOps.end()));
+        auto *Mul = new VPWidenRecipe(
+            Instruction::Mul, make_range(MulOps.begin(), MulOps.end()),
+            MulAcc->hasNoUnsignedWrap(), MulAcc->hasNoSignedWrap(),
+            MulAcc->getMulDebugLoc());
         Mul->insertBefore(MulAcc);
 
         // Generate VPReductionRecipe.

>From 52369d05b9fb5ff6bed4446f28988060dee08455 Mon Sep 17 00:00:00 2001
From: Elvis Wang <elvis.wang at sifive.com>
Date: Sun, 24 Nov 2024 23:28:17 -0800
Subject: [PATCH 16/26] Update testcase and fix reduction cost.

TTI should model the cost of moving the reduction result to the scalar
register and of applying add/sub/cmp... to the scalar value. We don't need
to add the cost of the binOp in the VPlan-based cost model.
---
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp | 28 +++++++------------
 .../AArch64/sve-strict-fadd-cost.ll           | 12 ++++----
 .../LoopVectorize/vplan-printing.ll           | 12 ++++----
 3 files changed, 22 insertions(+), 30 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 39beced3a1e3ba..2e35c4e738a156 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -2202,20 +2202,16 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
   assert(ElementTy->getTypeID() == RdxDesc.getRecurrenceType()->getTypeID() &&
          "Inferred type and recurrence type mismatch.");
 
-  // BaseCost = Reduction cost + BinOp cost
-  InstructionCost BaseCost =
-      Ctx.TTI.getArithmeticInstrCost(Opcode, ElementTy, CostKind);
+  // Note that TTI should model the cost of moving result to the scalar register
+  // and the binOp cost in the getReductionCost().
   if (RecurrenceDescriptor::isMinMaxRecurrenceKind(RdxKind)) {
     Intrinsic::ID Id = getMinMaxReductionIntrinsicOp(RdxKind);
-    BaseCost += Ctx.TTI.getMinMaxReductionCost(
-        Id, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
-  } else {
-    BaseCost += Ctx.TTI.getArithmeticReductionCost(
-        Opcode, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+    return Ctx.TTI.getMinMaxReductionCost(Id, VectorTy,
+                                          RdxDesc.getFastMathFlags(), CostKind);
   }
 
-  // Default cost.
-  return BaseCost;
+  return Ctx.TTI.getArithmeticReductionCost(
+      Opcode, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
 }
 
 InstructionCost
@@ -2238,15 +2234,13 @@ VPExtendedReductionRecipe::computeCost(ElementCount VF,
   assert(ExtendedRedCost.isValid() && "VPExtendedReductionRecipe should not be "
                                       "created if the cost is invalid.");
 
-  // BaseCost = Reduction cost + BinOp cost
-  InstructionCost ReductionCost =
-      Ctx.TTI.getArithmeticInstrCost(Opcode, ElementTy, CostKind);
+  InstructionCost ReductionCost;
   if (RecurrenceDescriptor::isMinMaxRecurrenceKind(RdxKind)) {
     Intrinsic::ID Id = getMinMaxReductionIntrinsicOp(RdxKind);
-    ReductionCost += Ctx.TTI.getMinMaxReductionCost(
+    ReductionCost = Ctx.TTI.getMinMaxReductionCost(
         Id, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
   } else {
-    ReductionCost += Ctx.TTI.getArithmeticReductionCost(
+    ReductionCost = Ctx.TTI.getArithmeticReductionCost(
         Opcode, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
   }
 
@@ -2287,9 +2281,7 @@ InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
                                  "created if the cost is invalid.");
 
   // BaseCost = Reduction cost + BinOp cost
-  InstructionCost ReductionCost =
-      Ctx.TTI.getArithmeticInstrCost(Opcode, ElementTy, CostKind);
-  ReductionCost += Ctx.TTI.getArithmeticReductionCost(
+  InstructionCost ReductionCost = Ctx.TTI.getArithmeticReductionCost(
       Opcode, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
 
   // Extended cost
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/sve-strict-fadd-cost.ll b/llvm/test/Transforms/LoopVectorize/AArch64/sve-strict-fadd-cost.ll
index 0efdf077dca669..4208773e947349 100644
--- a/llvm/test/Transforms/LoopVectorize/AArch64/sve-strict-fadd-cost.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/sve-strict-fadd-cost.ll
@@ -12,14 +12,14 @@ target triple="aarch64-unknown-linux-gnu"
 
 ; CHECK-VSCALE2-LABEL: LV: Checking a loop in 'fadd_strict32'
 ; CHECK-VSCALE2: Cost of 4 for VF vscale x 2:
-; CHECK-VSCALE2:  in-loop reduction   %add = fadd float %0, %sum.07
+; CHECK-VSCALE2:  REDUCE ir<%add> = ir<%sum.07> + reduce.fadd (ir<%0>)
 ; CHECK-VSCALE2: Cost of 8 for VF vscale x 4:
-; CHECK-VSCALE2:  in-loop reduction   %add = fadd float %0, %sum.07
+; CHECK-VSCALE2:  REDUCE ir<%add> = ir<%sum.07> + reduce.fadd (ir<%0>)
 ; CHECK-VSCALE1-LABEL: LV: Checking a loop in 'fadd_strict32'
 ; CHECK-VSCALE1: Cost of 2 for VF vscale x 2:
-; CHECK-VSCALE1:  in-loop reduction   %add = fadd float %0, %sum.07
+; CHECK-VSCALE1:  REDUCE ir<%add> = ir<%sum.07> + reduce.fadd (ir<%0>)
 ; CHECK-VSCALE1: Cost of 4 for VF vscale x 4:
-; CHECK-VSCALE1:  in-loop reduction   %add = fadd float %0, %sum.07
+; CHECK-VSCALE1:  REDUCE ir<%add> = ir<%sum.07> + reduce.fadd (ir<%0>)
 
 define float @fadd_strict32(ptr noalias nocapture readonly %a, i64 %n) #0 {
 entry:
@@ -42,10 +42,10 @@ for.end:
 
 ; CHECK-VSCALE2-LABEL: LV: Checking a loop in 'fadd_strict64'
 ; CHECK-VSCALE2: Cost of 4 for VF vscale x 2:
-; CHECK-VSCALE2:  in-loop reduction   %add = fadd double %0, %sum.07
+; CHECK-VSCALE2:  REDUCE ir<%add> = ir<%sum.07> + reduce.fadd (ir<%0>)
 ; CHECK-VSCALE1-LABEL: LV: Checking a loop in 'fadd_strict64'
 ; CHECK-VSCALE1: Cost of 2 for VF vscale x 2:
-; CHECK-VSCALE1:  in-loop reduction   %add = fadd double %0, %sum.07
+; CHECK-VSCALE1:  REDUCE ir<%add> = ir<%sum.07> + reduce.fadd (ir<%0>)
 
 define double @fadd_strict64(ptr noalias nocapture readonly %a, i64 %n) #0 {
 entry:
diff --git a/llvm/test/Transforms/LoopVectorize/vplan-printing.ll b/llvm/test/Transforms/LoopVectorize/vplan-printing.ll
index 95ce514d8a92ff..a80723bd35b52a 100644
--- a/llvm/test/Transforms/LoopVectorize/vplan-printing.ll
+++ b/llvm/test/Transforms/LoopVectorize/vplan-printing.ll
@@ -1176,7 +1176,7 @@ define i64 @print_extended_reduction(ptr nocapture readonly %x, ptr nocapture re
 ; CHECK-NEXT: Successor(s): ir-bb<for.cond.cleanup.loopexit>, scalar.ph
 ; CHECK-EMPTY:
 ; CHECK-NEXT: ir-bb<for.cond.cleanup.loopexit>:
-; CHECK-NEXT:   IR   %add.lcssa = phi i64 [ %add, %for.body ] (extra operand: vp<%7>)
+; CHECK-NEXT:   IR   %add.lcssa = phi i64 [ %add, %for.body ] (extra operand: vp<%7> from middle.block)
 ; CHECK-NEXT: No successors
 ; CHECK-EMPTY:
 ; CHECK-NEXT: scalar.ph:
@@ -1185,7 +1185,7 @@ define i64 @print_extended_reduction(ptr nocapture readonly %x, ptr nocapture re
 ; CHECK-EMPTY:
 ; CHECK-NEXT: ir-bb<for.body>:
 ; CHECK-NEXT:   IR   %i.010 = phi i32 [ %inc, %for.body ], [ 0, %for.body.preheader ]
-; CHECK-NEXT:   IR   %r.09 = phi i64 [ %add, %for.body ], [ 0, %for.body.preheader ] (extra operand: vp<%bc.merge.rdx>)
+; CHECK-NEXT:   IR   %r.09 = phi i64 [ %add, %for.body ], [ 0, %for.body.preheader ] (extra operand: vp<%bc.merge.rdx> from scalar.ph)
 ; CHECK-NEXT:   IR   %arrayidx = getelementptr inbounds i32, ptr %x, i32 %i.010
 ; CHECK-NEXT:   IR   %load0 = load i32, ptr %arrayidx, align 4
 ; CHECK-NEXT:   IR   %conv0 = zext i32 %load0 to i64
@@ -1251,7 +1251,7 @@ define i64 @print_mulacc(ptr nocapture readonly %x, ptr nocapture readonly %y, i
 ; CHECK-NEXT: Successor(s): ir-bb<for.cond.cleanup.loopexit>, scalar.ph
 ; CHECK-EMPTY:
 ; CHECK-NEXT: ir-bb<for.cond.cleanup.loopexit>:
-; CHECK-NEXT:   IR   %add.lcssa = phi i64 [ %add, %for.body ] (extra operand: vp<%8>)
+; CHECK-NEXT:   IR   %add.lcssa = phi i64 [ %add, %for.body ] (extra operand: vp<%8> from middle.block)
 ; CHECK-NEXT: No successors
 ; CHECK-EMPTY:
 ; CHECK-NEXT: scalar.ph:
@@ -1260,7 +1260,7 @@ define i64 @print_mulacc(ptr nocapture readonly %x, ptr nocapture readonly %y, i
 ; CHECK-EMPTY:
 ; CHECK-NEXT: ir-bb<for.body>:
 ; CHECK-NEXT:   IR   %i.010 = phi i32 [ %inc, %for.body ], [ 0, %for.body.preheader ]
-; CHECK-NEXT:   IR   %r.09 = phi i64 [ %add, %for.body ], [ 0, %for.body.preheader ] (extra operand: vp<%bc.merge.rdx>)
+; CHECK-NEXT:   IR   %r.09 = phi i64 [ %add, %for.body ], [ 0, %for.body.preheader ] (extra operand: vp<%bc.merge.rdx> from scalar.ph)
 ; CHECK-NEXT:   IR   %arrayidx = getelementptr inbounds i64, ptr %x, i32 %i.010
 ; CHECK-NEXT:   IR   %load0 = load i64, ptr %arrayidx, align 4
 ; CHECK-NEXT:   IR   %arrayidx1 = getelementptr inbounds i64, ptr %y, i32 %i.010
@@ -1330,7 +1330,7 @@ define i64 @print_mulacc_extended(ptr nocapture readonly %x, ptr nocapture reado
 ; CHECK-NEXT: Successor(s): ir-bb<for.cond.cleanup.loopexit>, scalar.ph
 ; CHECK-EMPTY:
 ; CHECK-NEXT: ir-bb<for.cond.cleanup.loopexit>:
-; CHECK-NEXT:   IR   %add.lcssa = phi i64 [ %add, %for.body ] (extra operand: vp<%8>)
+; CHECK-NEXT:   IR   %add.lcssa = phi i64 [ %add, %for.body ] (extra operand: vp<%8> from middle.block)
 ; CHECK-NEXT: No successors
 ; CHECK-EMPTY:
 ; CHECK-NEXT: scalar.ph:
@@ -1339,7 +1339,7 @@ define i64 @print_mulacc_extended(ptr nocapture readonly %x, ptr nocapture reado
 ; CHECK-EMPTY:
 ; CHECK-NEXT: ir-bb<for.body>:
 ; CHECK-NEXT:   IR   %i.010 = phi i32 [ %inc, %for.body ], [ 0, %for.body.preheader ]
-; CHECK-NEXT:   IR   %r.09 = phi i64 [ %add, %for.body ], [ 0, %for.body.preheader ] (extra operand: vp<%bc.merge.rdx>)
+; CHECK-NEXT:   IR   %r.09 = phi i64 [ %add, %for.body ], [ 0, %for.body.preheader ] (extra operand: vp<%bc.merge.rdx> from scalar.ph)
 ; CHECK-NEXT:   IR   %arrayidx = getelementptr inbounds i16, ptr %x, i32 %i.010
 ; CHECK-NEXT:   IR   %load0 = load i16, ptr %arrayidx, align 4
 ; CHECK-NEXT:   IR   %arrayidx1 = getelementptr inbounds i16, ptr %y, i32 %i.010

>From abc08f39dedbcd31dd0689380c828b84087e36e8 Mon Sep 17 00:00:00 2001
From: Elvis Wang <elvis.wang at sifive.com>
Date: Wed, 4 Dec 2024 16:52:14 -0800
Subject: [PATCH 17/26] !fixup. Rebase to upstream `prepareToExecute()`
 implementation.

---
 .../Transforms/Vectorize/LoopVectorize.cpp    |   2 -
 .../Transforms/Vectorize/VPlanTransforms.cpp  | 180 +++++++++---------
 2 files changed, 89 insertions(+), 93 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index fbb39c659c9175..02b3f1001fd276 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7666,8 +7666,6 @@ DenseMap<const SCEV *, Value *> LoopVectorizationPlanner::executePlan(
                              CanonicalIVStartValue, State);
   VPlanTransforms::prepareToExecute(BestVPlan);
 
-  // TODO: Replace with upstream implementation.
-  VPlanTransforms::prepareExecute(BestVPlan);
   BestVPlan.execute(&State);
 
   // 2.5 Collect reduction resume values.
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index 733c144c7aa359..d4e84ecc46818d 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -519,96 +519,6 @@ void VPlanTransforms::removeDeadRecipes(VPlan &Plan) {
   }
 }
 
-void VPlanTransforms::prepareExecute(VPlan &Plan) {
-  ReversePostOrderTraversal<VPBlockDeepTraversalWrapper<VPBlockBase *>> RPOT(
-      Plan.getVectorLoopRegion());
-  for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>(
-           vp_depth_first_deep(Plan.getEntry()))) {
-    for (VPRecipeBase &R : make_early_inc_range(*VPBB)) {
-      if (auto *ExtRed = dyn_cast<VPExtendedReductionRecipe>(&R)) {
-        // Genearte VPWidenCastRecipe.
-        VPWidenCastRecipe *Ext;
-        // Only ZExt contiains non-neg flags.
-        if (ExtRed->isZExt())
-          Ext = new VPWidenCastRecipe(
-              ExtRed->getExtOpcode(), ExtRed->getVecOp(),
-              ExtRed->getResultType(), ExtRed->getNonNegFlags(),
-              ExtRed->getExtDebugLoc());
-        else
-          Ext = new VPWidenCastRecipe(
-              ExtRed->getExtOpcode(), ExtRed->getVecOp(),
-              ExtRed->getResultType(), ExtRed->getExtDebugLoc());
-
-        // Generate VPreductionRecipe.
-        auto *Red = new VPReductionRecipe(
-            ExtRed->getRecurrenceDescriptor(), ExtRed->getUnderlyingInstr(),
-            ExtRed->getChainOp(), Ext, ExtRed->getCondOp(),
-            ExtRed->isOrdered());
-        Ext->insertBefore(ExtRed);
-        Red->insertBefore(ExtRed);
-        ExtRed->replaceAllUsesWith(Red);
-        ExtRed->eraseFromParent();
-      } else if (isa<VPMulAccRecipe>(&R)) {
-        auto *MulAcc = cast<VPMulAccRecipe>(&R);
-        Type *RedTy = MulAcc->getRecurrenceDescriptor().getRecurrenceType();
-
-        // Generate inner VPWidenCastRecipes if necessary.
-        // Note that we will drop the extend after mul which transform
-        // reduce.add(ext(mul(ext, ext))) to reduce.add(mul(ext, ext)).
-        VPValue *Op0, *Op1;
-        if (MulAcc->isExtended()) {
-          if (MulAcc->isZExt())
-            Op0 = new VPWidenCastRecipe(
-                MulAcc->getExtOpcode(), MulAcc->getVecOp0(), RedTy,
-                MulAcc->getNonNegFlags(), MulAcc->getExt0DebugLoc());
-          else
-            Op0 = new VPWidenCastRecipe(MulAcc->getExtOpcode(),
-                                        MulAcc->getVecOp0(), RedTy,
-                                        MulAcc->getExt0DebugLoc());
-          Op0->getDefiningRecipe()->insertBefore(MulAcc);
-          // Prevent reduce.add(mul(ext(A), ext(A))) generate duplicate
-          // VPWidenCastRecipe.
-          if (MulAcc->getVecOp0() == MulAcc->getVecOp1()) {
-            Op1 = Op0;
-          } else {
-            if (MulAcc->isZExt())
-              Op1 = new VPWidenCastRecipe(
-                  MulAcc->getExtOpcode(), MulAcc->getVecOp1(), RedTy,
-                  MulAcc->getNonNegFlags(), MulAcc->getExt1DebugLoc());
-            else
-              Op1 = new VPWidenCastRecipe(MulAcc->getExtOpcode(),
-                                          MulAcc->getVecOp1(), RedTy,
-                                          MulAcc->getExt1DebugLoc());
-            Op1->getDefiningRecipe()->insertBefore(MulAcc);
-          }
-          // No extends in this MulAccRecipe.
-        } else {
-          Op0 = MulAcc->getVecOp0();
-          Op1 = MulAcc->getVecOp1();
-        }
-
-        // Generate VPWidenRecipe.
-        SmallVector<VPValue *, 2> MulOps = {Op0, Op1};
-        auto *Mul = new VPWidenRecipe(
-            Instruction::Mul, make_range(MulOps.begin(), MulOps.end()),
-            MulAcc->hasNoUnsignedWrap(), MulAcc->hasNoSignedWrap(),
-            MulAcc->getMulDebugLoc());
-        Mul->insertBefore(MulAcc);
-
-        // Generate VPReductionRecipe.
-        auto *Red = new VPReductionRecipe(
-            MulAcc->getRecurrenceDescriptor(), MulAcc->getUnderlyingInstr(),
-            MulAcc->getChainOp(), Mul, MulAcc->getCondOp(),
-            MulAcc->isOrdered());
-        Red->insertBefore(MulAcc);
-
-        MulAcc->replaceAllUsesWith(Red);
-        MulAcc->eraseFromParent();
-      }
-    }
-  }
-}
-
 static VPScalarIVStepsRecipe *
 createScalarIVSteps(VPlan &Plan, InductionDescriptor::InductionKind Kind,
                     Instruction::BinaryOps InductionOpcode,
@@ -1910,12 +1820,100 @@ void VPlanTransforms::createInterleaveGroups(
   }
 }
 
+// Expand VPExtendedReductionRecipe to VPWidenCastRecipe + VPReductionRecipe.
+static void expandVPExtendedReduction(VPExtendedReductionRecipe *ExtRed) {
+  // Genearte VPWidenCastRecipe.
+  VPWidenCastRecipe *Ext;
+  // Only ZExt contiains non-neg flags.
+  if (ExtRed->isZExt())
+    Ext = new VPWidenCastRecipe(
+        ExtRed->getExtOpcode(), ExtRed->getVecOp(), ExtRed->getResultType(),
+        ExtRed->getNonNegFlags(), ExtRed->getExtDebugLoc());
+  else
+    Ext = new VPWidenCastRecipe(ExtRed->getExtOpcode(), ExtRed->getVecOp(),
+                                ExtRed->getResultType(),
+                                ExtRed->getExtDebugLoc());
+
+  // Generate VPreductionRecipe.
+  auto *Red = new VPReductionRecipe(
+      ExtRed->getRecurrenceDescriptor(), ExtRed->getUnderlyingInstr(),
+      ExtRed->getChainOp(), Ext, ExtRed->getCondOp(), ExtRed->isOrdered());
+  Ext->insertBefore(ExtRed);
+  Red->insertBefore(ExtRed);
+  ExtRed->replaceAllUsesWith(Red);
+  ExtRed->eraseFromParent();
+}
+
+// Expand VPMulAccRecipe to VPWidenRecipe (mul) + VPReductionRecipe (reduce.add)
+// + VPWidenCastRecipe (optional).
+static void expandVPMulAcc(VPMulAccRecipe *MulAcc) {
+  Type *RedTy = MulAcc->getRecurrenceDescriptor().getRecurrenceType();
+
+  // Generate inner VPWidenCastRecipes if necessary.
+  // Note that we will drop the extend after mul which transform
+  // reduce.add(ext(mul(ext, ext))) to reduce.add(mul(ext, ext)).
+  VPValue *Op0, *Op1;
+  if (MulAcc->isExtended()) {
+    if (MulAcc->isZExt())
+      Op0 = new VPWidenCastRecipe(MulAcc->getExtOpcode(), MulAcc->getVecOp0(),
+                                  RedTy, MulAcc->getNonNegFlags(),
+                                  MulAcc->getExt0DebugLoc());
+    else
+      Op0 = new VPWidenCastRecipe(MulAcc->getExtOpcode(), MulAcc->getVecOp0(),
+                                  RedTy, MulAcc->getExt0DebugLoc());
+    Op0->getDefiningRecipe()->insertBefore(MulAcc);
+    // Prevent reduce.add(mul(ext(A), ext(A))) generate duplicate
+    // VPWidenCastRecipe.
+    if (MulAcc->getVecOp0() == MulAcc->getVecOp1()) {
+      Op1 = Op0;
+    } else {
+      if (MulAcc->isZExt())
+        Op1 = new VPWidenCastRecipe(MulAcc->getExtOpcode(), MulAcc->getVecOp1(),
+                                    RedTy, MulAcc->getNonNegFlags(),
+                                    MulAcc->getExt1DebugLoc());
+      else
+        Op1 = new VPWidenCastRecipe(MulAcc->getExtOpcode(), MulAcc->getVecOp1(),
+                                    RedTy, MulAcc->getExt1DebugLoc());
+      Op1->getDefiningRecipe()->insertBefore(MulAcc);
+    }
+    // No extends in this MulAccRecipe.
+  } else {
+    Op0 = MulAcc->getVecOp0();
+    Op1 = MulAcc->getVecOp1();
+  }
+
+  // Generate VPWidenRecipe.
+  SmallVector<VPValue *, 2> MulOps = {Op0, Op1};
+  auto *Mul = new VPWidenRecipe(
+      Instruction::Mul, make_range(MulOps.begin(), MulOps.end()),
+      MulAcc->hasNoUnsignedWrap(), MulAcc->hasNoSignedWrap(),
+      MulAcc->getMulDebugLoc());
+  Mul->insertBefore(MulAcc);
+
+  // Generate VPReductionRecipe.
+  auto *Red = new VPReductionRecipe(
+      MulAcc->getRecurrenceDescriptor(), MulAcc->getUnderlyingInstr(),
+      MulAcc->getChainOp(), Mul, MulAcc->getCondOp(), MulAcc->isOrdered());
+  Red->insertBefore(MulAcc);
+
+  MulAcc->replaceAllUsesWith(Red);
+  MulAcc->eraseFromParent();
+}
+
 void VPlanTransforms::prepareToExecute(VPlan &Plan) {
   ReversePostOrderTraversal<VPBlockDeepTraversalWrapper<VPBlockBase *>> RPOT(
       Plan.getVectorLoopRegion());
   for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>(
            vp_depth_first_deep(Plan.getEntry()))) {
-    for (VPRecipeBase &R : make_early_inc_range(VPBB->phis())) {
+    for (VPRecipeBase &R : make_early_inc_range(*VPBB)) {
+      if (auto *ExtRed = dyn_cast<VPExtendedReductionRecipe>(&R)) {
+        expandVPExtendedReduction(ExtRed);
+        continue;
+      }
+      if (auto *MulAcc = dyn_cast<VPMulAccRecipe>(&R)) {
+        expandVPMulAcc(MulAcc);
+        continue;
+      }
       if (!isa<VPCanonicalIVPHIRecipe, VPEVLBasedIVPHIRecipe>(&R))
         continue;
       auto *PhiR = cast<VPHeaderPHIRecipe>(&R);

>From 729a70ea055a14351d1ae3ae0f49a6e265a67447 Mon Sep 17 00:00:00 2001
From: Elvis Wang <elvis.wang at sifive.com>
Date: Tue, 10 Dec 2024 16:12:15 -0800
Subject: [PATCH 18/26] Move VPReductionRecipe inherite from
 VPRecipeWithIRFlags.

This commit makes 2 major changes.

1. VPReductionRecipe inherits from VPRecipeWithIRFlags.
2. VPMulAcc/VPExtendedReduction also inherit from VPRecipeWithIRFlags.

Note that we need to disable dropping flags for VPReductionRecipe in
`clearReductionWithIRFlags()` to prevent the flags of Mul from being dropped.

Note that we temporarily prevent the EVL transformation for
VPExtendedReductionRecipe and VPMulAccRecipe since we have not
implemented the EVL-version of these recipes.

This is mostly NFC, with only a few vplan-printing changes.
---
 .../Transforms/Vectorize/LoopVectorize.cpp    |   2 +-
 llvm/lib/Transforms/Vectorize/VPlan.h         | 144 ++++++++++++------
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp |   7 +-
 .../Transforms/Vectorize/VPlanTransforms.cpp  |  29 ++--
 .../LoopVectorize/vplan-printing.ll           |  36 ++---
 5 files changed, 136 insertions(+), 82 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 02b3f1001fd276..cce50541729227 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -9266,7 +9266,7 @@ static VPExtendedReductionRecipe *tryToMatchAndCreateExtendedReduction(
   Type *RedTy = RdxDesc.getRecurrenceType();
 
   // Test if the cost of extended-reduction is valid and clamp the range.
-  // Note that reduction-extended is not always valid for all VF and types.
+  // Note that extended-reduction is not always valid for all VF and types.
   auto IsExtendedRedValidAndClampRange = [&](unsigned Opcode, bool isZExt,
                                              Type *SrcTy) -> bool {
     return LoopVectorizationPlanner::getDecisionAndClampRange(
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 60ab1d156e8806..f36e5ececfc03e 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -1075,7 +1075,9 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
            R->getVPDefID() == VPRecipeBase::VPWidenCastSC ||
            R->getVPDefID() == VPRecipeBase::VPReplicateSC ||
            R->getVPDefID() == VPRecipeBase::VPReverseVectorPointerSC ||
-           R->getVPDefID() == VPRecipeBase::VPVectorPointerSC;
+           R->getVPDefID() == VPRecipeBase::VPVectorPointerSC ||
+           R->getVPDefID() == VPRecipeBase::VPExtendedReductionSC ||
+           R->getVPDefID() == VPRecipeBase::VPMulAccSC;
   }
 
   static inline bool classof(const VPUser *U) {
@@ -2609,7 +2611,7 @@ class VPInterleaveRecipe : public VPRecipeBase {
 /// A recipe to represent inloop reduction operations, performing a reduction on
 /// a vector operand into a scalar value, and adding the result to a chain.
 /// The Operands are {ChainOp, VecOp, [Condition]}.
-class VPReductionRecipe : public VPSingleDefRecipe {
+class VPReductionRecipe : public VPRecipeWithIRFlags {
   /// The recurrence decriptor for the reduction in question.
   const RecurrenceDescriptor &RdxDesc;
   bool IsOrdered;
@@ -2620,7 +2622,45 @@ class VPReductionRecipe : public VPSingleDefRecipe {
   VPReductionRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
                     Instruction *I, ArrayRef<VPValue *> Operands,
                     VPValue *CondOp, bool IsOrdered)
-      : VPSingleDefRecipe(SC, Operands, I), RdxDesc(R), IsOrdered(IsOrdered) {
+      : VPRecipeWithIRFlags(SC, Operands, *I), RdxDesc(R),
+        IsOrdered(IsOrdered) {
+    if (CondOp) {
+      IsConditional = true;
+      addOperand(CondOp);
+    }
+  }
+
+  // For VPExtendedReductionRecipe.
+  // Note that IsNonNeg flag and the debug location are for extend instruction.
+  VPReductionRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
+                    ArrayRef<VPValue *> Operands, VPValue *CondOp,
+                    bool IsOrdered, bool IsNonNeg, DebugLoc DL)
+      : VPRecipeWithIRFlags(SC, Operands, NonNegFlagsTy(IsNonNeg), DL),
+        RdxDesc(R), IsOrdered(IsOrdered) {
+    if (CondOp) {
+      IsConditional = true;
+      addOperand(CondOp);
+    }
+  }
+
+  // For VPMulAccRecipe.
+  // Note that the NUW/NSW and DL are for mul instruction.
+  VPReductionRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
+                    ArrayRef<VPValue *> Operands, VPValue *CondOp,
+                    bool IsOrdered, bool NUW, bool NSW, DebugLoc DL)
+      : VPRecipeWithIRFlags(SC, Operands, WrapFlagsTy(NUW, NSW), DL),
+        RdxDesc(R), IsOrdered(IsOrdered) {
+    if (CondOp) {
+      IsConditional = true;
+      addOperand(CondOp);
+    }
+  }
+
+  VPReductionRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
+                    ArrayRef<VPValue *> Operands, VPValue *CondOp,
+                    bool IsOrdered, DebugLoc DL)
+      : VPRecipeWithIRFlags(SC, Operands, DL), RdxDesc(R),
+        IsOrdered(IsOrdered) {
     if (CondOp) {
       IsConditional = true;
       addOperand(CondOp);
@@ -2635,6 +2675,13 @@ class VPReductionRecipe : public VPSingleDefRecipe {
                           ArrayRef<VPValue *>({ChainOp, VecOp}), CondOp,
                           IsOrdered) {}
 
+  VPReductionRecipe(const RecurrenceDescriptor &R, VPValue *ChainOp,
+                    VPValue *VecOp, VPValue *CondOp, bool IsOrdered,
+                    DebugLoc DL)
+      : VPReductionRecipe(VPDef::VPReductionSC, R,
+                          ArrayRef<VPValue *>({ChainOp, VecOp}), CondOp,
+                          IsOrdered, DL) {}
+
   ~VPReductionRecipe() override = default;
 
   VPReductionRecipe *clone() override {
@@ -2644,7 +2691,9 @@ class VPReductionRecipe : public VPSingleDefRecipe {
 
   static inline bool classof(const VPRecipeBase *R) {
     return R->getVPDefID() == VPRecipeBase::VPReductionSC ||
-           R->getVPDefID() == VPRecipeBase::VPReductionEVLSC;
+           R->getVPDefID() == VPRecipeBase::VPReductionEVLSC ||
+           R->getVPDefID() == VPRecipeBase::VPExtendedReductionSC ||
+           R->getVPDefID() == VPRecipeBase::VPMulAccSC;
   }
 
   static inline bool classof(const VPUser *U) {
@@ -2725,15 +2774,13 @@ class VPReductionEVLRecipe : public VPReductionRecipe {
 };
 
 /// A recipe to represent inloop extended reduction operations, performing a
-/// reduction on a vector operand into a scalar value, and adding the result to
-/// a chain. This recipe is high level abstract which will generate
-/// VPReductionRecipe and VPWidenCastRecipe before execution. The Operands are
-/// {ChainOp, VecOp, [Condition]}.
+/// reduction on a extended vector operand into a scalar value, and adding the
+/// result to a chain. This recipe is abstract and needs to be lowered to
+/// concrete recipes before codegen. The Operands are {ChainOp, VecOp,
+/// [Condition]}.
 class VPExtendedReductionRecipe : public VPReductionRecipe {
   Instruction::CastOps ExtOp;
-  DebugLoc ExtDL;
-  /// Non-negative flag for the extended instruction.
-  bool IsNonNeg;
+  DebugLoc RedDL;
 
 protected:
   VPExtendedReductionRecipe(const unsigned char SC,
@@ -2741,9 +2788,9 @@ class VPExtendedReductionRecipe : public VPReductionRecipe {
                             Instruction::CastOps ExtOp, DebugLoc ExtDL,
                             bool IsNonNeg, VPValue *ChainOp, VPValue *VecOp,
                             VPValue *CondOp, bool IsOrdered)
-      : VPReductionRecipe(SC, R, RedI, ArrayRef<VPValue *>({ChainOp, VecOp}),
-                          CondOp, IsOrdered),
-        ExtOp(ExtOp), ExtDL(ExtDL), IsNonNeg(IsNonNeg) {}
+      : VPReductionRecipe(SC, R, ArrayRef<VPValue *>({ChainOp, VecOp}), CondOp,
+                          IsOrdered, IsNonNeg, ExtDL),
+        ExtOp(ExtOp), RedDL(RedI->getDebugLoc()) {}
 
 public:
   VPExtendedReductionRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
@@ -2787,25 +2834,19 @@ class VPExtendedReductionRecipe : public VPReductionRecipe {
   /// The Opcode of extend instruction.
   Instruction::CastOps getExtOpcode() const { return ExtOp; }
 
-  bool getNonNegFlags() const { return IsNonNeg; }
-
   /// Return the debug location of the extend instruction.
-  DebugLoc getExtDebugLoc() const { return ExtDL; }
+  DebugLoc getExtDebugLoc() const { return getDebugLoc(); }
+
+  /// Return the debug location of the reduction instruction.
+  DebugLoc getRedDebugLoc() const { return RedDL; }
 };
 
-/// A recipe to represent inloop MulAccreduction operations, performing a
-/// reduction on a vector operand into a scalar value, and adding the result to
-/// a chain. This recipe is high level abstract which will generate
-/// VPReductionRecipe VPWidenRecipe(mul) and VPWidenCastRecipes before
-/// execution. The Operands are {ChainOp, VecOp1, VecOp2, [Condition]}.
+/// A recipe to represent inloop MulAccReduction operations, performing a
+/// reduction.add on the result of vector operands (might be extended)
+/// multiplication into a scalar value, and adding the result to a chain. This
+/// recipe is abstract and needs to be lowered to concrete recipes before
+/// codegen. The Operands are {ChainOp, VecOp1, VecOp2, [Condition]}.
 class VPMulAccRecipe : public VPReductionRecipe {
-
-  // WrapFlags (NoUnsignedWraps, NoSignedWraps)
-  bool NUW;
-  bool NSW;
-
-  // Debug location of the mul instruction.
-  DebugLoc MulDL;
   Instruction::CastOps ExtOp;
 
   // Non-neg flag for the extend instruction.
@@ -2815,6 +2856,9 @@ class VPMulAccRecipe : public VPReductionRecipe {
   DebugLoc Ext0DL;
   DebugLoc Ext1DL;
 
+  // Debug location of reduction instruction.
+  DebugLoc RedDL;
+
   // Is this mul-acc recipe contains extend recipes?
   bool IsExtended = false;
 
@@ -2824,12 +2868,12 @@ class VPMulAccRecipe : public VPReductionRecipe {
                  Instruction::CastOps ExtOp, bool IsNonNeg, DebugLoc Ext0DL,
                  DebugLoc Ext1DL, VPValue *ChainOp, VPValue *VecOp0,
                  VPValue *VecOp1, VPValue *CondOp, bool IsOrdered)
-      : VPReductionRecipe(SC, R, RedI,
-                          ArrayRef<VPValue *>({ChainOp, VecOp0, VecOp1}),
-                          CondOp, IsOrdered),
-        NUW(NUW), NSW(NSW), MulDL(MulDL), ExtOp(ExtOp), IsNonNeg(IsNonNeg),
-        Ext0DL(Ext0DL), Ext1DL(Ext1DL) {
-    assert(R.getOpcode() == Instruction::Add);
+      : VPReductionRecipe(SC, R, ArrayRef<VPValue *>({ChainOp, VecOp0, VecOp1}),
+                          CondOp, IsOrdered, NUW, NSW, MulDL),
+        ExtOp(ExtOp), IsNonNeg(IsNonNeg), Ext0DL(Ext0DL), Ext1DL(Ext1DL),
+        RedDL(RedI->getDebugLoc()) {
+    assert(R.getOpcode() == Instruction::Add &&
+           "The reduction instruction in MulAccRecipe must be Add");
     IsExtended = true;
   }
 
@@ -2837,11 +2881,11 @@ class VPMulAccRecipe : public VPReductionRecipe {
                  Instruction *RedI, bool NUW, bool NSW, DebugLoc MulDL,
                  VPValue *ChainOp, VPValue *VecOp0, VPValue *VecOp1,
                  VPValue *CondOp, bool IsOrdered)
-      : VPReductionRecipe(SC, R, RedI,
-                          ArrayRef<VPValue *>({ChainOp, VecOp0, VecOp1}),
-                          CondOp, IsOrdered),
-        NUW(NUW), NSW(NSW), MulDL(MulDL) {
-    assert(R.getOpcode() == Instruction::Add);
+      : VPReductionRecipe(SC, R, ArrayRef<VPValue *>({ChainOp, VecOp0, VecOp1}),
+                          CondOp, IsOrdered, NUW, NSW, MulDL),
+        RedDL(RedI->getDebugLoc()) {
+    assert(R.getOpcode() == Instruction::Add &&
+           "The reduction instruction in MulAccRecipe must be Add");
   }
 
 public:
@@ -2903,21 +2947,25 @@ class VPMulAccRecipe : public VPReductionRecipe {
     return ExtOp == Instruction::CastOps::ZExt;
   }
 
-  /// Return the non-neg flag.
-  bool getNonNegFlags() const { return IsNonNeg; }
+  /// Return the non negative flag for the ext instruction.
+  bool isNonNeg() const { return IsNonNeg; }
 
-  /// Return no-unsigned-wrap.
-  bool hasNoUnsignedWrap() const { return NUW; }
-
-  /// Return no-signed-warp.
-  bool hasNoSignedWrap() const { return NSW; }
+  /// Drop flags in this recipe.
+  void dropPoisonGeneratingFlags() {
+    VPRecipeWithIRFlags::dropPoisonGeneratingFlags();
+    // Also drop the extra flags not in VPRecipeWithIRFlags.
+    this->IsNonNeg = false;
+  }
 
   /// Return debug location of mul instruction.
-  DebugLoc getMulDebugLoc() const { return MulDL; }
+  DebugLoc getMulDebugLoc() const { return getDebugLoc(); }
 
   /// Return debug location of extend instructions.
   DebugLoc getExt0DebugLoc() const { return Ext0DL; }
   DebugLoc getExt1DebugLoc() const { return Ext1DL; }
+
+  /// Return the debug location of reduction instruction.
+  DebugLoc getRedDebugLoc() const { return RedDL; }
 };
 
 /// VPReplicateRecipe replicates a given instruction producing multiple scalar
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 2e35c4e738a156..07d67af29aaba7 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -2389,8 +2389,6 @@ void VPExtendedReductionRecipe::print(raw_ostream &O, const Twine &Indent,
   O << " = ";
   getChainOp()->printAsOperand(O, SlotTracker);
   O << " +";
-  if (isa<FPMathOperator>(getUnderlyingInstr()))
-    O << getUnderlyingInstr()->getFastMathFlags();
   O << " reduce." << Instruction::getOpcodeName(RdxDesc.getOpcode()) << " (";
   getVecOp()->printAsOperand(O, SlotTracker);
   O << " extended to " << *getResultType();
@@ -2414,10 +2412,9 @@ void VPMulAccRecipe::print(raw_ostream &O, const Twine &Indent,
   O << " = ";
   getChainOp()->printAsOperand(O, SlotTracker);
   O << " + ";
-  if (isa<FPMathOperator>(getUnderlyingInstr()))
-    O << getUnderlyingInstr()->getFastMathFlags();
   O << "reduce." << Instruction::getOpcodeName(RdxDesc.getOpcode()) << " (";
-  O << "mul ";
+  O << "mul";
+  printFlags(O);
   if (isExtended())
     O << "(";
   getVecOp0()->printAsOperand(O, SlotTracker);
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index d4e84ecc46818d..eb47a855cba4cc 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -934,10 +934,17 @@ void VPlanTransforms::clearReductionWrapFlags(VPlan &Plan) {
     if (RK != RecurKind::Add && RK != RecurKind::Mul)
       continue;
 
-    for (VPUser *U : collectUsersRecursively(PhiR))
+    for (VPUser *U : collectUsersRecursively(PhiR)) {
+      // Flags in reduction recipes are control/store by the recurrence
+      // descriptor. Dropping flags for VPExtendedReductionRecipe and
+      // VPMulaccRecipe may drop flags that we don't expect to drop.
+      if (isa<VPReductionRecipe>(U))
+        continue;
+
       if (auto *RecWithFlags = dyn_cast<VPRecipeWithIRFlags>(U)) {
         RecWithFlags->dropPoisonGeneratingFlags();
       }
+    }
   }
 }
 
@@ -1476,6 +1483,8 @@ static void transformRecipestoEVLRecipes(VPlan &Plan, VPValue &EVL) {
                   return nullptr;
                 return new VPWidenEVLRecipe(*W, EVL);
               })
+              .Case<VPExtendedReductionRecipe, VPMulAccRecipe>(
+                  [](const auto *R) { return nullptr; })
               .Case<VPReductionRecipe>([&](VPReductionRecipe *Red) {
                 VPValue *NewMask = GetNewMask(Red->getCondOp());
                 return new VPReductionEVLRecipe(*Red, EVL, NewMask);
@@ -1826,9 +1835,9 @@ static void expandVPExtendedReduction(VPExtendedReductionRecipe *ExtRed) {
   VPWidenCastRecipe *Ext;
   // Only ZExt contiains non-neg flags.
   if (ExtRed->isZExt())
-    Ext = new VPWidenCastRecipe(
-        ExtRed->getExtOpcode(), ExtRed->getVecOp(), ExtRed->getResultType(),
-        ExtRed->getNonNegFlags(), ExtRed->getExtDebugLoc());
+    Ext = new VPWidenCastRecipe(ExtRed->getExtOpcode(), ExtRed->getVecOp(),
+                                ExtRed->getResultType(), ExtRed->isNonNeg(),
+                                ExtRed->getExtDebugLoc());
   else
     Ext = new VPWidenCastRecipe(ExtRed->getExtOpcode(), ExtRed->getVecOp(),
                                 ExtRed->getResultType(),
@@ -1836,8 +1845,8 @@ static void expandVPExtendedReduction(VPExtendedReductionRecipe *ExtRed) {
 
   // Generate VPreductionRecipe.
   auto *Red = new VPReductionRecipe(
-      ExtRed->getRecurrenceDescriptor(), ExtRed->getUnderlyingInstr(),
-      ExtRed->getChainOp(), Ext, ExtRed->getCondOp(), ExtRed->isOrdered());
+      ExtRed->getRecurrenceDescriptor(), ExtRed->getChainOp(), Ext,
+      ExtRed->getCondOp(), ExtRed->isOrdered(), ExtRed->getRedDebugLoc());
   Ext->insertBefore(ExtRed);
   Red->insertBefore(ExtRed);
   ExtRed->replaceAllUsesWith(Red);
@@ -1856,7 +1865,7 @@ static void expandVPMulAcc(VPMulAccRecipe *MulAcc) {
   if (MulAcc->isExtended()) {
     if (MulAcc->isZExt())
       Op0 = new VPWidenCastRecipe(MulAcc->getExtOpcode(), MulAcc->getVecOp0(),
-                                  RedTy, MulAcc->getNonNegFlags(),
+                                  RedTy, MulAcc->isNonNeg(),
                                   MulAcc->getExt0DebugLoc());
     else
       Op0 = new VPWidenCastRecipe(MulAcc->getExtOpcode(), MulAcc->getVecOp0(),
@@ -1869,7 +1878,7 @@ static void expandVPMulAcc(VPMulAccRecipe *MulAcc) {
     } else {
       if (MulAcc->isZExt())
         Op1 = new VPWidenCastRecipe(MulAcc->getExtOpcode(), MulAcc->getVecOp1(),
-                                    RedTy, MulAcc->getNonNegFlags(),
+                                    RedTy, MulAcc->isNonNeg(),
                                     MulAcc->getExt1DebugLoc());
       else
         Op1 = new VPWidenCastRecipe(MulAcc->getExtOpcode(), MulAcc->getVecOp1(),
@@ -1892,8 +1901,8 @@ static void expandVPMulAcc(VPMulAccRecipe *MulAcc) {
 
   // Generate VPReductionRecipe.
   auto *Red = new VPReductionRecipe(
-      MulAcc->getRecurrenceDescriptor(), MulAcc->getUnderlyingInstr(),
-      MulAcc->getChainOp(), Mul, MulAcc->getCondOp(), MulAcc->isOrdered());
+      MulAcc->getRecurrenceDescriptor(), MulAcc->getChainOp(), Mul,
+      MulAcc->getCondOp(), MulAcc->isOrdered(), MulAcc->getRedDebugLoc());
   Red->insertBefore(MulAcc);
 
   MulAcc->replaceAllUsesWith(Red);
diff --git a/llvm/test/Transforms/LoopVectorize/vplan-printing.ll b/llvm/test/Transforms/LoopVectorize/vplan-printing.ll
index a80723bd35b52a..b5e19c6cc3c0b8 100644
--- a/llvm/test/Transforms/LoopVectorize/vplan-printing.ll
+++ b/llvm/test/Transforms/LoopVectorize/vplan-printing.ll
@@ -1156,12 +1156,12 @@ define i64 @print_extended_reduction(ptr nocapture readonly %x, ptr nocapture re
 ; CHECK-NEXT: <x1> vector loop: {
 ; CHECK-NEXT:   vector.body:
 ; CHECK-NEXT:     EMIT vp<%2> = CANONICAL-INDUCTION ir<0>, vp<%index.next>
-; CHECK-NEXT:     WIDEN-REDUCTION-PHI ir<%r.09> = phi ir<0>, ir<%add>
+; CHECK-NEXT:     WIDEN-REDUCTION-PHI ir<%r.09> = phi ir<0>, vp<%5>
 ; CHECK-NEXT:     vp<%3> = SCALAR-STEPS vp<%2>, ir<1>
 ; CHECK-NEXT:     CLONE ir<%arrayidx> = getelementptr inbounds ir<%x>, vp<%3>
 ; CHECK-NEXT:     vp<%4> = vector-pointer ir<%arrayidx>
 ; CHECK-NEXT:     WIDEN ir<%load0> = load vp<%4>
-; CHECK-NEXT:     EXTENDED-REDUCE ir<%add> = ir<%r.09> + reduce.add (ir<%load0> extended to i64)
+; CHECK-NEXT:     EXTENDED-REDUCE vp<%5> = ir<%r.09> + reduce.add (ir<%load0> extended to i64)
 ; CHECK-NEXT:     EMIT vp<%index.next> = add nuw vp<%2>, vp<%0>
 ; CHECK-NEXT:     EMIT branch-on-count vp<%index.next>, vp<%1>
 ; CHECK-NEXT:   No successors
@@ -1169,18 +1169,18 @@ define i64 @print_extended_reduction(ptr nocapture readonly %x, ptr nocapture re
 ; CHECK-NEXT: Successor(s): middle.block
 ; CHECK-EMPTY:
 ; CHECK-NEXT: middle.block:
-; CHECK-NEXT:   EMIT vp<%6> = compute-reduction-result ir<%r.09>, ir<%add>
-; CHECK-NEXT:   EMIT vp<%7> = extract-from-end vp<%6>, ir<1>
+; CHECK-NEXT:   EMIT vp<%7> = compute-reduction-result ir<%r.09>, vp<%5>
+; CHECK-NEXT:   EMIT vp<%8> = extract-from-end vp<%7>, ir<1>
 ; CHECK-NEXT:   EMIT vp<%cmp.n> = icmp eq ir<%n>, vp<%1>
 ; CHECK-NEXT:   EMIT branch-on-cond vp<%cmp.n>
 ; CHECK-NEXT: Successor(s): ir-bb<for.cond.cleanup.loopexit>, scalar.ph
 ; CHECK-EMPTY:
 ; CHECK-NEXT: ir-bb<for.cond.cleanup.loopexit>:
-; CHECK-NEXT:   IR   %add.lcssa = phi i64 [ %add, %for.body ] (extra operand: vp<%7> from middle.block)
+; CHECK-NEXT:   IR   %add.lcssa = phi i64 [ %add, %for.body ] (extra operand: vp<%8> from middle.block)
 ; CHECK-NEXT: No successors
 ; CHECK-EMPTY:
 ; CHECK-NEXT: scalar.ph:
-; CHECK-NEXT:   EMIT vp<%bc.merge.rdx> = resume-phi vp<%6>, ir<0>
+; CHECK-NEXT:   EMIT vp<%bc.merge.rdx> = resume-phi vp<%7>, ir<0>
 ; CHECK-NEXT: Successor(s): ir-bb<for.body>
 ; CHECK-EMPTY:
 ; CHECK-NEXT: ir-bb<for.body>:
@@ -1228,7 +1228,7 @@ define i64 @print_mulacc(ptr nocapture readonly %x, ptr nocapture readonly %y, i
 ; CHECK-NEXT: <x1> vector loop: {
 ; CHECK-NEXT:   vector.body:
 ; CHECK-NEXT:     EMIT vp<%2> = CANONICAL-INDUCTION ir<0>, vp<%index.next>
-; CHECK-NEXT:     WIDEN-REDUCTION-PHI ir<%r.09> = phi ir<0>, ir<%add>
+; CHECK-NEXT:     WIDEN-REDUCTION-PHI ir<%r.09> = phi ir<0>, vp<%6>
 ; CHECK-NEXT:     vp<%3> = SCALAR-STEPS vp<%2>, ir<1>
 ; CHECK-NEXT:     CLONE ir<%arrayidx> = getelementptr inbounds ir<%x>, vp<%3>
 ; CHECK-NEXT:     vp<%4> = vector-pointer ir<%arrayidx>
@@ -1236,7 +1236,7 @@ define i64 @print_mulacc(ptr nocapture readonly %x, ptr nocapture readonly %y, i
 ; CHECK-NEXT:     CLONE ir<%arrayidx1> = getelementptr inbounds ir<%y>, vp<%3>
 ; CHECK-NEXT:     vp<%5> = vector-pointer ir<%arrayidx1>
 ; CHECK-NEXT:     WIDEN ir<%load1> = load vp<%5>
-; CHECK-NEXT:     MULACC-REDUCE ir<%add> = ir<%r.09> + reduce.add (mul ir<%load0>, ir<%load1>)
+; CHECK-NEXT:     MULACC-REDUCE vp<%6> = ir<%r.09> + reduce.add (mul nsw ir<%load0>, ir<%load1>)
 ; CHECK-NEXT:     EMIT vp<%index.next> = add nuw vp<%2>, vp<%0>
 ; CHECK-NEXT:     EMIT branch-on-count vp<%index.next>, vp<%1>
 ; CHECK-NEXT:   No successors
@@ -1244,18 +1244,18 @@ define i64 @print_mulacc(ptr nocapture readonly %x, ptr nocapture readonly %y, i
 ; CHECK-NEXT: Successor(s): middle.block
 ; CHECK-EMPTY:
 ; CHECK-NEXT: middle.block:
-; CHECK-NEXT:   EMIT vp<%7> = compute-reduction-result ir<%r.09>, ir<%add>
-; CHECK-NEXT:   EMIT vp<%8> = extract-from-end vp<%7>, ir<1>
+; CHECK-NEXT:   EMIT vp<%8> = compute-reduction-result ir<%r.09>, vp<%6>
+; CHECK-NEXT:   EMIT vp<%9> = extract-from-end vp<%8>, ir<1>
 ; CHECK-NEXT:   EMIT vp<%cmp.n> = icmp eq ir<%n>, vp<%1>
 ; CHECK-NEXT:   EMIT branch-on-cond vp<%cmp.n>
 ; CHECK-NEXT: Successor(s): ir-bb<for.cond.cleanup.loopexit>, scalar.ph
 ; CHECK-EMPTY:
 ; CHECK-NEXT: ir-bb<for.cond.cleanup.loopexit>:
-; CHECK-NEXT:   IR   %add.lcssa = phi i64 [ %add, %for.body ] (extra operand: vp<%8> from middle.block)
+; CHECK-NEXT:   IR   %add.lcssa = phi i64 [ %add, %for.body ] (extra operand: vp<%9> from middle.block)
 ; CHECK-NEXT: No successors
 ; CHECK-EMPTY:
 ; CHECK-NEXT: scalar.ph:
-; CHECK-NEXT:   EMIT vp<%bc.merge.rdx> = resume-phi vp<%7>, ir<0>
+; CHECK-NEXT:   EMIT vp<%bc.merge.rdx> = resume-phi vp<%8>, ir<0>
 ; CHECK-NEXT: Successor(s): ir-bb<for.body>
 ; CHECK-EMPTY:
 ; CHECK-NEXT: ir-bb<for.body>:
@@ -1307,7 +1307,7 @@ define i64 @print_mulacc_extended(ptr nocapture readonly %x, ptr nocapture reado
 ; CHECK-NEXT: <x1> vector loop: {
 ; CHECK-NEXT:   vector.body:
 ; CHECK-NEXT:     EMIT vp<%2> = CANONICAL-INDUCTION ir<0>, vp<%index.next>
-; CHECK-NEXT:     WIDEN-REDUCTION-PHI ir<%r.09> = phi ir<0>, ir<%add>
+; CHECK-NEXT:     WIDEN-REDUCTION-PHI ir<%r.09> = phi ir<0>, vp<%6>
 ; CHECK-NEXT:     vp<%3> = SCALAR-STEPS vp<%2>, ir<1>
 ; CHECK-NEXT:     CLONE ir<%arrayidx> = getelementptr inbounds ir<%x>, vp<%3>
 ; CHECK-NEXT:     vp<%4> = vector-pointer ir<%arrayidx>
@@ -1315,7 +1315,7 @@ define i64 @print_mulacc_extended(ptr nocapture readonly %x, ptr nocapture reado
 ; CHECK-NEXT:     CLONE ir<%arrayidx1> = getelementptr inbounds ir<%y>, vp<%3>
 ; CHECK-NEXT:     vp<%5> = vector-pointer ir<%arrayidx1>
 ; CHECK-NEXT:     WIDEN ir<%load1> = load vp<%5>
-; CHECK-NEXT:     MULACC-REDUCE ir<%add> = ir<%r.09> + reduce.add (mul (ir<%load0> extended to i64), (ir<%load1> extended to i64))
+; CHECK-NEXT:     MULACC-REDUCE vp<%6> = ir<%r.09> + reduce.add (mul nsw (ir<%load0> extended to i64), (ir<%load1> extended to i64))
 ; CHECK-NEXT:     EMIT vp<%index.next> = add nuw vp<%2>, vp<%0>
 ; CHECK-NEXT:     EMIT branch-on-count vp<%index.next>, vp<%1>
 ; CHECK-NEXT:   No successors
@@ -1323,18 +1323,18 @@ define i64 @print_mulacc_extended(ptr nocapture readonly %x, ptr nocapture reado
 ; CHECK-NEXT: Successor(s): middle.block
 ; CHECK-EMPTY:
 ; CHECK-NEXT: middle.block:
-; CHECK-NEXT:   EMIT vp<%7> = compute-reduction-result ir<%r.09>, ir<%add>
-; CHECK-NEXT:   EMIT vp<%8> = extract-from-end vp<%7>, ir<1>
+; CHECK-NEXT:   EMIT vp<%8> = compute-reduction-result ir<%r.09>, vp<%6>
+; CHECK-NEXT:   EMIT vp<%9> = extract-from-end vp<%8>, ir<1>
 ; CHECK-NEXT:   EMIT vp<%cmp.n> = icmp eq ir<%n>, vp<%1>
 ; CHECK-NEXT:   EMIT branch-on-cond vp<%cmp.n>
 ; CHECK-NEXT: Successor(s): ir-bb<for.cond.cleanup.loopexit>, scalar.ph
 ; CHECK-EMPTY:
 ; CHECK-NEXT: ir-bb<for.cond.cleanup.loopexit>:
-; CHECK-NEXT:   IR   %add.lcssa = phi i64 [ %add, %for.body ] (extra operand: vp<%8> from middle.block)
+; CHECK-NEXT:   IR   %add.lcssa = phi i64 [ %add, %for.body ] (extra operand: vp<%9> from middle.block)
 ; CHECK-NEXT: No successors
 ; CHECK-EMPTY:
 ; CHECK-NEXT: scalar.ph:
-; CHECK-NEXT:   EMIT vp<%bc.merge.rdx> = resume-phi vp<%7>, ir<0>
+; CHECK-NEXT:   EMIT vp<%bc.merge.rdx> = resume-phi vp<%8>, ir<0>
 ; CHECK-NEXT: Successor(s): ir-bb<for.body>
 ; CHECK-EMPTY:
 ; CHECK-NEXT: ir-bb<for.body>:

>From ea58282b72ee77c5c5f179d8e095d88836157224 Mon Sep 17 00:00:00 2001
From: Elvis Wang <elvis.wang at sifive.com>
Date: Tue, 10 Dec 2024 19:20:54 -0800
Subject: [PATCH 19/26] Only create VPMulAcc/VPExtendedReduction recipe when
 beneficial. NFC

Only create the VPMulAccRecipe and the VPExtendedReductionRecipe when it is
beneficial.

Also address review comments and clean up computeCost().
---
 .../Transforms/Vectorize/LoopVectorize.cpp    |  74 ++++++++----
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp | 111 ++----------------
 .../Transforms/Vectorize/VPlanTransforms.cpp  |  20 ++--
 3 files changed, 68 insertions(+), 137 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index cce50541729227..724a98b76fe48d 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -9265,18 +9265,30 @@ static VPExtendedReductionRecipe *tryToMatchAndCreateExtendedReduction(
   VPValue *A;
   Type *RedTy = RdxDesc.getRecurrenceType();
 
-  // Test if the cost of extended-reduction is valid and clamp the range.
-  // Note that extended-reduction is not always valid for all VF and types.
+  // Test if using extended-reduction is beneficial and clamp the range.
   auto IsExtendedRedValidAndClampRange = [&](unsigned Opcode, bool isZExt,
                                              Type *SrcTy) -> bool {
     return LoopVectorizationPlanner::getDecisionAndClampRange(
         [&](ElementCount VF) {
           VectorType *SrcVecTy = cast<VectorType>(ToVectorTy(SrcTy, VF));
-          return Ctx.TTI
-              .getExtendedReductionCost(Opcode, isZExt, RedTy, SrcVecTy,
-                                        RdxDesc.getFastMathFlags(),
-                                        TTI::TCK_RecipThroughput)
-              .isValid();
+          VectorType *VectorTy = cast<VectorType>(ToVectorTy(RedTy, VF));
+          TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+          InstructionCost ExtRedCost = Ctx.TTI.getExtendedReductionCost(
+              Opcode, isZExt, RedTy, SrcVecTy, RdxDesc.getFastMathFlags(),
+              CostKind);
+          InstructionCost ExtCost =
+              cast<VPWidenCastRecipe>(VecOp)->computeCost(VF, Ctx);
+          RecurKind RdxKind = RdxDesc.getRecurrenceKind();
+          InstructionCost RedCost;
+          if (RecurrenceDescriptor::isMinMaxRecurrenceKind(RdxKind)) {
+            Intrinsic::ID Id = getMinMaxReductionIntrinsicOp(RdxKind);
+            RedCost = Ctx.TTI.getMinMaxReductionCost(
+                Id, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+          } else {
+            RedCost = Ctx.TTI.getArithmeticReductionCost(
+                Opcode, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+          }
+          return ExtRedCost.isValid() && ExtRedCost < ExtCost + RedCost;
         },
         Range);
   };
@@ -9312,16 +9324,32 @@ static VPMulAccRecipe *tryToMatchAndCreateMulAcc(
   VPValue *A, *B;
   Type *RedTy = RdxDesc.getRecurrenceType();
 
-  // Test if the cost of MulAcc is valid and clamp the range.
-  // Note that mul-acc is not always valid for all VF and types.
-  auto IsMulAccValidAndClampRange = [&](bool isZExt, Type *SrcTy) -> bool {
+  // Test if using mul-acc-reduction is beneficial and clamp the range.
+  auto IsMulAccValidAndClampRange =
+      [&](bool isZExt, VPWidenRecipe *Mul, VPWidenCastRecipe *Ext0,
+          VPWidenCastRecipe *Ext1, VPWidenCastRecipe *OuterExt) -> bool {
     return LoopVectorizationPlanner::getDecisionAndClampRange(
         [&](ElementCount VF) {
+          TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+          Type *SrcTy =
+              Ext0 ? Ctx.Types.inferScalarType(Ext0->getOperand(0)) : RedTy;
           VectorType *SrcVecTy = cast<VectorType>(ToVectorTy(SrcTy, VF));
-          return Ctx.TTI
-              .getMulAccReductionCost(isZExt, RedTy, SrcVecTy,
-                                      TTI::TCK_RecipThroughput)
-              .isValid();
+          VectorType *VectorTy = cast<VectorType>(ToVectorTy(RedTy, VF));
+          InstructionCost MulAccCost =
+              Ctx.TTI.getMulAccReductionCost(isZExt, RedTy, SrcVecTy, CostKind);
+          InstructionCost MulCost = Mul->computeCost(VF, Ctx);
+          InstructionCost RedCost = Ctx.TTI.getArithmeticReductionCost(
+              Instruction::Add, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+          InstructionCost ExtCost = 0;
+          if (Ext0)
+            ExtCost += Ext0->computeCost(VF, Ctx);
+          if (Ext1)
+            ExtCost += Ext1->computeCost(VF, Ctx);
+          if (OuterExt)
+            ExtCost += OuterExt->computeCost(VF, Ctx);
+
+          return MulAccCost.isValid() &&
+                 MulAccCost < ExtCost + MulCost + RedCost;
         },
         Range);
   };
@@ -9335,6 +9363,7 @@ static VPMulAccRecipe *tryToMatchAndCreateMulAcc(
         dyn_cast_if_present<VPWidenCastRecipe>(A->getDefiningRecipe());
     VPWidenCastRecipe *RecipeB =
         dyn_cast_if_present<VPWidenCastRecipe>(B->getDefiningRecipe());
+    VPWidenRecipe *Mul = cast<VPWidenRecipe>(VecOp->getDefiningRecipe());
 
     // Matched reduce.add(mul(ext, ext))
     if (RecipeA && RecipeB && match(RecipeA, m_ZExtOrSExt(m_VPValue())) &&
@@ -9344,22 +9373,19 @@ static VPMulAccRecipe *tryToMatchAndCreateMulAcc(
       // Only create MulAccRecipe if the cost is valid.
       if (!IsMulAccValidAndClampRange(RecipeA->getOpcode() ==
                                           Instruction::CastOps::ZExt,
-                                      Ctx.Types.inferScalarType(RecipeA)))
+                                      Mul, RecipeA, RecipeB, nullptr))
         return nullptr;
 
       return new VPMulAccRecipe(RdxDesc, CurrentLinkI, PreviousLink, CondOp,
-                                CM.useOrderedReductions(RdxDesc),
-                                cast<VPWidenRecipe>(VecOp->getDefiningRecipe()),
-                                RecipeA, RecipeB);
+                                CM.useOrderedReductions(RdxDesc), Mul, RecipeA,
+                                RecipeB);
     } else {
       // Matched reduce.add(mul)
-      if (!IsMulAccValidAndClampRange(true, RedTy))
+      if (!IsMulAccValidAndClampRange(true, Mul, nullptr, nullptr, nullptr))
         return nullptr;
 
-      return new VPMulAccRecipe(
-          RdxDesc, CurrentLinkI, PreviousLink, CondOp,
-          CM.useOrderedReductions(RdxDesc),
-          cast<VPWidenRecipe>(VecOp->getDefiningRecipe()));
+      return new VPMulAccRecipe(RdxDesc, CurrentLinkI, PreviousLink, CondOp,
+                                CM.useOrderedReductions(RdxDesc), Mul);
     }
     // Matched reduce.add(ext(mul(ext(A), ext(B))))
     // All extend instructions must have same opcode or A == B
@@ -9379,7 +9405,7 @@ static VPMulAccRecipe *tryToMatchAndCreateMulAcc(
       // Only create MulAcc recipe if the cost if valid.
       if (!IsMulAccValidAndClampRange(Ext0->getOpcode() ==
                                           Instruction::CastOps::ZExt,
-                                      Ctx.Types.inferScalarType(Ext0)))
+                                      Mul, Ext0, Ext1, Ext))
         return nullptr;
 
       return new VPMulAccRecipe(RdxDesc, CurrentLinkI, PreviousLink, CondOp,
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 07d67af29aaba7..f6b1b4a8c24bb3 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -2218,121 +2218,24 @@ InstructionCost
 VPExtendedReductionRecipe::computeCost(ElementCount VF,
                                        VPCostContext &Ctx) const {
   const RecurrenceDescriptor &RdxDesc = getRecurrenceDescriptor();
-  RecurKind RdxKind = RdxDesc.getRecurrenceKind();
-  Type *ElementTy = getResultType();
-  auto *VectorTy = cast<VectorType>(ToVectorTy(ElementTy, VF));
+  unsigned Opcode = RdxDesc.getOpcode();
+  Type *RedTy = Ctx.Types.inferScalarType(this);
   auto *SrcVecTy =
       cast<VectorType>(ToVectorTy(Ctx.Types.inferScalarType(getVecOp()), VF));
   TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
-  unsigned Opcode = RdxDesc.getOpcode();
-
-  // ExtendedReduction Cost
-  InstructionCost ExtendedRedCost =
-      Ctx.TTI.getExtendedReductionCost(Opcode, isZExt(), ElementTy, SrcVecTy,
-                                       RdxDesc.getFastMathFlags(), CostKind);
-
-  assert(ExtendedRedCost.isValid() && "VPExtendedReductionRecipe should not be "
-                                      "created if the cost is invalid.");
-
-  InstructionCost ReductionCost;
-  if (RecurrenceDescriptor::isMinMaxRecurrenceKind(RdxKind)) {
-    Intrinsic::ID Id = getMinMaxReductionIntrinsicOp(RdxKind);
-    ReductionCost = Ctx.TTI.getMinMaxReductionCost(
-        Id, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
-  } else {
-    ReductionCost = Ctx.TTI.getArithmeticReductionCost(
-        Opcode, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
-  }
 
-  // Extended cost
-  TTI::CastContextHint CCH = computeCCH(getVecOp()->getDefiningRecipe(), VF);
-  // Arm TTI will use the underlying instruction to determine the cost.
-  InstructionCost ExtendedCost = Ctx.TTI.getCastInstrCost(
-      Opcode, VectorTy, SrcVecTy, CCH, TTI::TCK_RecipThroughput,
-      dyn_cast_if_present<Instruction>(getUnderlyingValue()));
-
-  // Check if folding ext into ExtendedReduction is profitable.
-  if (ExtendedRedCost.isValid() &&
-      ExtendedRedCost < ExtendedCost + ReductionCost) {
-    return ExtendedRedCost;
-  }
-
-  return ExtendedCost + ReductionCost;
+  return Ctx.TTI.getExtendedReductionCost(Opcode, isZExt(), RedTy, SrcVecTy,
+                                          RdxDesc.getFastMathFlags(), CostKind);
 }
 
 InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
                                             VPCostContext &Ctx) const {
-  const RecurrenceDescriptor &RdxDesc = getRecurrenceDescriptor();
-  Type *ElementTy = isExtended() ? RdxDesc.getRecurrenceType()
-                                 : Ctx.Types.inferScalarType(getVecOp0());
-  auto *VectorTy = cast<VectorType>(ToVectorTy(ElementTy, VF));
-  TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
-  unsigned Opcode = RdxDesc.getOpcode();
-
-  assert(Opcode == Instruction::Add &&
-         "Reduction opcode must be add in the VPMulAccRecipe.");
-  // MulAccReduction Cost
+  Type *RedTy = Ctx.Types.inferScalarType(this);
   VectorType *SrcVecTy =
       cast<VectorType>(ToVectorTy(Ctx.Types.inferScalarType(getVecOp0()), VF));
-  InstructionCost MulAccCost =
-      Ctx.TTI.getMulAccReductionCost(isZExt(), ElementTy, SrcVecTy, CostKind);
-
-  assert(MulAccCost.isValid() && "VPMulAccRecipe should not be "
-                                 "created if the cost is invalid.");
-
-  // BaseCost = Reduction cost + BinOp cost
-  InstructionCost ReductionCost = Ctx.TTI.getArithmeticReductionCost(
-      Opcode, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
-
-  // Extended cost
-  InstructionCost ExtendedCost = 0;
-  if (isExtended()) {
-    TTI::CastContextHint CCH0 =
-        computeCCH(getVecOp0()->getDefiningRecipe(), VF);
-    ExtendedCost = Ctx.TTI.getCastInstrCost(ExtOp, VectorTy, SrcVecTy, CCH0,
-                                            TTI::TCK_RecipThroughput);
-    TTI::CastContextHint CCH1 =
-        computeCCH(getVecOp0()->getDefiningRecipe(), VF);
-    ExtendedCost += Ctx.TTI.getCastInstrCost(ExtOp, VectorTy, SrcVecTy, CCH1,
-                                             TTI::TCK_RecipThroughput);
-  }
-
-  // Mul cost
-  InstructionCost MulCost;
-  SmallVector<const Value *, 4> Operands;
-  if (isExtended())
-    MulCost = Ctx.TTI.getArithmeticInstrCost(
-        Instruction::Mul, VectorTy, CostKind,
-        {TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None},
-        {TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None},
-        Operands, nullptr, &Ctx.TLI);
-  else {
-    VPValue *RHS = getVecOp1();
-    // Certain instructions can be cheaper to vectorize if they have a constant
-    // second vector operand. One example of this are shifts on x86.
-    TargetTransformInfo::OperandValueInfo RHSInfo = {
-        TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None};
-    if (RHS->isLiveIn())
-      RHSInfo = Ctx.TTI.getOperandInfo(RHS->getLiveInIRValue());
-
-    if (RHSInfo.Kind == TargetTransformInfo::OK_AnyValue &&
-        RHS->isDefinedOutsideLoopRegions())
-      RHSInfo.Kind = TargetTransformInfo::OK_UniformValue;
-    Operands.append(
-        {getVecOp0()->getUnderlyingValue(), RHS->getUnderlyingValue()});
-    MulCost = Ctx.TTI.getArithmeticInstrCost(
-        Instruction::Mul, VectorTy, CostKind,
-        {TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None},
-        RHSInfo, Operands, nullptr, &Ctx.TLI);
-  }
-
-  // Check if folding ext into ExtendedReduction is profitable.
-  if (MulAccCost.isValid() &&
-      MulAccCost < ExtendedCost + ReductionCost + MulCost) {
-    return MulAccCost;
-  }
+  TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
 
-  return ExtendedCost + ReductionCost + MulCost;
+  return Ctx.TTI.getMulAccReductionCost(isZExt(), RedTy, SrcVecTy, CostKind);
 }
 
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index eb47a855cba4cc..dd330b50240bb6 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -1914,15 +1914,7 @@ void VPlanTransforms::prepareToExecute(VPlan &Plan) {
       Plan.getVectorLoopRegion());
   for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>(
            vp_depth_first_deep(Plan.getEntry()))) {
-    for (VPRecipeBase &R : make_early_inc_range(*VPBB)) {
-      if (auto *ExtRed = dyn_cast<VPExtendedReductionRecipe>(&R)) {
-        expandVPExtendedReduction(ExtRed);
-        continue;
-      }
-      if (auto *MulAcc = dyn_cast<VPMulAccRecipe>(&R)) {
-        expandVPMulAcc(MulAcc);
-        continue;
-      }
+    for (VPRecipeBase &R : make_early_inc_range(VPBB->phis())) {
       if (!isa<VPCanonicalIVPHIRecipe, VPEVLBasedIVPHIRecipe>(&R))
         continue;
       auto *PhiR = cast<VPHeaderPHIRecipe>(&R);
@@ -1935,5 +1927,15 @@ void VPlanTransforms::prepareToExecute(VPlan &Plan) {
       PhiR->replaceAllUsesWith(ScalarR);
       PhiR->eraseFromParent();
     }
+    for (VPRecipeBase &R : make_early_inc_range(*VPBB)) {
+      if (!isa<VPExtendedReductionRecipe, VPMulAccRecipe>(&R))
+        continue;
+      if (auto *ExtRed = dyn_cast<VPExtendedReductionRecipe>(&R)) {
+        expandVPExtendedReduction(ExtRed);
+      }
+      if (auto *MulAcc = dyn_cast<VPMulAccRecipe>(&R)) {
+        expandVPMulAcc(MulAcc);
+      }
+    }
   }
 }

>From a9874560fd6567712146be0dca3fe8fb4779a203 Mon Sep 17 00:00:00 2001
From: Elvis Wang <elvis.wang at sifive.com>
Date: Wed, 11 Dec 2024 22:51:22 -0800
Subject: [PATCH 20/26] !fixup use `auto`

---
 .../Transforms/Vectorize/LoopVectorize.cpp    | 31 +++++++++----------
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp |  2 +-
 2 files changed, 16 insertions(+), 17 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index a34dc8e74d07fd..1e83cd5ef86356 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -9349,8 +9349,8 @@ static VPExtendedReductionRecipe *tryToMatchAndCreateExtendedReduction(
                                              Type *SrcTy) -> bool {
     return LoopVectorizationPlanner::getDecisionAndClampRange(
         [&](ElementCount VF) {
-          VectorType *SrcVecTy = cast<VectorType>(ToVectorTy(SrcTy, VF));
-          VectorType *VectorTy = cast<VectorType>(ToVectorTy(RedTy, VF));
+          auto *SrcVecTy = cast<VectorType>(ToVectorTy(SrcTy, VF));
+          auto *VectorTy = cast<VectorType>(ToVectorTy(RedTy, VF));
           TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
           InstructionCost ExtRedCost = Ctx.TTI.getExtendedReductionCost(
               Opcode, isZExt, RedTy, SrcVecTy, RdxDesc.getFastMathFlags(),
@@ -9412,8 +9412,8 @@ static VPMulAccRecipe *tryToMatchAndCreateMulAcc(
           TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
           Type *SrcTy =
               Ext0 ? Ctx.Types.inferScalarType(Ext0->getOperand(0)) : RedTy;
-          VectorType *SrcVecTy = cast<VectorType>(ToVectorTy(SrcTy, VF));
-          VectorType *VectorTy = cast<VectorType>(ToVectorTy(RedTy, VF));
+          auto *SrcVecTy = cast<VectorType>(ToVectorTy(SrcTy, VF));
+          auto *VectorTy = cast<VectorType>(ToVectorTy(RedTy, VF));
           InstructionCost MulAccCost =
               Ctx.TTI.getMulAccReductionCost(isZExt, RedTy, SrcVecTy, CostKind);
           InstructionCost MulCost = Mul->computeCost(VF, Ctx);
@@ -9438,16 +9438,17 @@ static VPMulAccRecipe *tryToMatchAndCreateMulAcc(
 
   // Try to match reduce.add(mul(...))
   if (match(VecOp, m_Mul(m_VPValue(A), m_VPValue(B)))) {
-    VPWidenCastRecipe *RecipeA =
+    auto *RecipeA =
         dyn_cast_if_present<VPWidenCastRecipe>(A->getDefiningRecipe());
-    VPWidenCastRecipe *RecipeB =
+    auto *RecipeB =
         dyn_cast_if_present<VPWidenCastRecipe>(B->getDefiningRecipe());
-    VPWidenRecipe *Mul = cast<VPWidenRecipe>(VecOp->getDefiningRecipe());
+    auto *Mul = cast<VPWidenRecipe>(VecOp->getDefiningRecipe());
 
     // Matched reduce.add(mul(ext, ext))
-    if (RecipeA && RecipeB && match(RecipeA, m_ZExtOrSExt(m_VPValue())) &&
-        match(RecipeB, m_ZExtOrSExt(m_VPValue())) &&
-        (RecipeA->getOpcode() == RecipeB->getOpcode() || A == B)) {
+    if (RecipeA && RecipeB &&
+        (RecipeA->getOpcode() == RecipeB->getOpcode() || A == B) &&
+        match(RecipeA, m_ZExtOrSExt(m_VPValue())) &&
+        match(RecipeB, m_ZExtOrSExt(m_VPValue()))) {
 
       // Only create MulAccRecipe if the cost is valid.
       if (!IsMulAccValidAndClampRange(RecipeA->getOpcode() ==
@@ -9471,13 +9472,11 @@ static VPMulAccRecipe *tryToMatchAndCreateMulAcc(
     // which can be transform to reduce.add(zext(mul(sext(A), sext(B)))).
   } else if (match(VecOp, m_ZExtOrSExt(m_Mul(m_ZExtOrSExt(m_VPValue()),
                                              m_ZExtOrSExt(m_VPValue()))))) {
-    VPWidenCastRecipe *Ext =
-        cast<VPWidenCastRecipe>(VecOp->getDefiningRecipe());
-    VPWidenRecipe *Mul =
-        cast<VPWidenRecipe>(Ext->getOperand(0)->getDefiningRecipe());
-    VPWidenCastRecipe *Ext0 =
+    auto *Ext = cast<VPWidenCastRecipe>(VecOp->getDefiningRecipe());
+    auto *Mul = cast<VPWidenRecipe>(Ext->getOperand(0)->getDefiningRecipe());
+    auto *Ext0 =
         cast<VPWidenCastRecipe>(Mul->getOperand(0)->getDefiningRecipe());
-    VPWidenCastRecipe *Ext1 =
+    auto *Ext1 =
         cast<VPWidenCastRecipe>(Mul->getOperand(1)->getDefiningRecipe());
     if ((Ext->getOpcode() == Ext0->getOpcode() || Ext0 == Ext1) &&
         Ext0->getOpcode() == Ext1->getOpcode()) {
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 5bdaab04bd2c6a..ae268ea2f96f90 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -2235,7 +2235,7 @@ VPExtendedReductionRecipe::computeCost(ElementCount VF,
 InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
                                             VPCostContext &Ctx) const {
   Type *RedTy = Ctx.Types.inferScalarType(this);
-  VectorType *SrcVecTy =
+  auto *SrcVecTy =
       cast<VectorType>(ToVectorTy(Ctx.Types.inferScalarType(getVecOp0()), VF));
   TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
 

>From 6c434c73f36a81212326b0c8cb5571892d920d39 Mon Sep 17 00:00:00 2001
From: Elvis Wang <elvis.wang at sifive.com>
Date: Thu, 12 Dec 2024 02:05:16 -0800
Subject: [PATCH 21/26] !fixup VPReductionRecipe unit tests.

Update unit tests to construct VPReductionRecipe without a null underlying
instruction. With a null instruction, it will fail in the ctor of
VPRecipeWithIRFlags.

Note that this commit will drop the underlying instruction when transforming
a VPReductionRecipe to a VPReductionEVLRecipe but preserve the debug
location.
---
 llvm/lib/Transforms/Vectorize/VPlan.h            |  5 ++---
 .../unittests/Transforms/Vectorize/VPlanTest.cpp | 16 ++++++++--------
 2 files changed, 10 insertions(+), 11 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index db602d7a3adaa4..bdcdafd7181d68 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -2677,7 +2677,7 @@ class VPReductionRecipe : public VPRecipeWithIRFlags {
 
   VPReductionRecipe(const RecurrenceDescriptor &R, VPValue *ChainOp,
                     VPValue *VecOp, VPValue *CondOp, bool IsOrdered,
-                    DebugLoc DL)
+                    DebugLoc DL = {})
       : VPReductionRecipe(VPDef::VPReductionSC, R,
                           ArrayRef<VPValue *>({ChainOp, VecOp}), CondOp,
                           IsOrdered, DL) {}
@@ -2741,9 +2741,8 @@ class VPReductionEVLRecipe : public VPReductionRecipe {
   VPReductionEVLRecipe(VPReductionRecipe &R, VPValue &EVL, VPValue *CondOp)
       : VPReductionRecipe(
             VPDef::VPReductionEVLSC, R.getRecurrenceDescriptor(),
-            cast_or_null<Instruction>(R.getUnderlyingValue()),
             ArrayRef<VPValue *>({R.getChainOp(), R.getVecOp(), &EVL}), CondOp,
-            R.isOrdered()) {}
+            R.isOrdered(), R.getDebugLoc()) {}
 
   ~VPReductionEVLRecipe() override = default;
 
diff --git a/llvm/unittests/Transforms/Vectorize/VPlanTest.cpp b/llvm/unittests/Transforms/Vectorize/VPlanTest.cpp
index 3179cfc676ab67..31f97f85a44896 100644
--- a/llvm/unittests/Transforms/Vectorize/VPlanTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/VPlanTest.cpp
@@ -1168,8 +1168,8 @@ TEST(VPRecipeTest, MayHaveSideEffectsAndMayReadWriteMemory) {
     VPValue ChainOp;
     VPValue VecOp;
     VPValue CondOp;
-    VPReductionRecipe Recipe(RecurrenceDescriptor(), nullptr, &ChainOp, &CondOp,
-                             &VecOp, false);
+    VPReductionRecipe Recipe(RecurrenceDescriptor(), &ChainOp, &CondOp, &VecOp,
+                             false);
     EXPECT_FALSE(Recipe.mayHaveSideEffects());
     EXPECT_FALSE(Recipe.mayReadFromMemory());
     EXPECT_FALSE(Recipe.mayWriteToMemory());
@@ -1180,8 +1180,8 @@ TEST(VPRecipeTest, MayHaveSideEffectsAndMayReadWriteMemory) {
     VPValue ChainOp;
     VPValue VecOp;
     VPValue CondOp;
-    VPReductionRecipe Recipe(RecurrenceDescriptor(), nullptr, &ChainOp, &CondOp,
-                             &VecOp, false);
+    VPReductionRecipe Recipe(RecurrenceDescriptor(), &ChainOp, &CondOp, &VecOp,
+                             false);
     VPValue EVL;
     VPReductionEVLRecipe EVLRecipe(Recipe, EVL, &CondOp);
     EXPECT_FALSE(EVLRecipe.mayHaveSideEffects());
@@ -1544,8 +1544,8 @@ TEST(VPRecipeTest, CastVPReductionRecipeToVPUser) {
   VPValue ChainOp;
   VPValue VecOp;
   VPValue CondOp;
-  VPReductionRecipe Recipe(RecurrenceDescriptor(), nullptr, &ChainOp, &CondOp,
-                           &VecOp, false);
+  VPReductionRecipe Recipe(RecurrenceDescriptor(), &ChainOp, &CondOp, &VecOp,
+                           false);
   EXPECT_TRUE(isa<VPUser>(&Recipe));
   VPRecipeBase *BaseR = &Recipe;
   EXPECT_TRUE(isa<VPUser>(BaseR));
@@ -1557,8 +1557,8 @@ TEST(VPRecipeTest, CastVPReductionEVLRecipeToVPUser) {
   VPValue ChainOp;
   VPValue VecOp;
   VPValue CondOp;
-  VPReductionRecipe Recipe(RecurrenceDescriptor(), nullptr, &ChainOp, &CondOp,
-                           &VecOp, false);
+  VPReductionRecipe Recipe(RecurrenceDescriptor(), &ChainOp, &CondOp, &VecOp,
+                           false);
   VPValue EVL;
   VPReductionEVLRecipe EVLRecipe(Recipe, EVL, &CondOp);
   EXPECT_TRUE(isa<VPUser>(&EVLRecipe));

>From f4b1b78fdb385236f52b04e2e47692d3f1685457 Mon Sep 17 00:00:00 2001
From: Elvis Wang <elvis.wang at sifive.com>
Date: Sun, 22 Dec 2024 17:09:20 -0800
Subject: [PATCH 22/26] !fixup migrate tryTo* to VPlanTransforms

---
 .../Transforms/Vectorize/LoopVectorize.cpp    | 182 +-----------------
 llvm/lib/Transforms/Vectorize/VPlan.h         |  85 ++++----
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp |   7 +-
 .../Transforms/Vectorize/VPlanTransforms.cpp  | 171 +++++++++++++++-
 .../Transforms/Vectorize/VPlanTransforms.h    |  25 +++
 llvm/lib/Transforms/Vectorize/VPlanValue.h    |   2 +-
 6 files changed, 252 insertions(+), 220 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 1e83cd5ef86356..f64d2ebaedcfb2 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -9330,170 +9330,6 @@ VPlanPtr LoopVectorizationPlanner::buildVPlan(VFRange &Range) {
   return Plan;
 }
 
-/// Try to match the extended-reduction and create VPExtendedReductionRecipe.
-///
-/// This function try to match following pattern which will generate
-/// extended-reduction instruction.
-///    reduce(ext(...)).
-static VPExtendedReductionRecipe *tryToMatchAndCreateExtendedReduction(
-    const RecurrenceDescriptor &RdxDesc, Instruction *CurrentLinkI,
-    VPValue *PreviousLink, VPValue *VecOp, VPValue *CondOp,
-    LoopVectorizationCostModel &CM, VPCostContext &Ctx, VFRange &Range) {
-  using namespace VPlanPatternMatch;
-
-  VPValue *A;
-  Type *RedTy = RdxDesc.getRecurrenceType();
-
-  // Test if using extended-reduction is beneficial and clamp the range.
-  auto IsExtendedRedValidAndClampRange = [&](unsigned Opcode, bool isZExt,
-                                             Type *SrcTy) -> bool {
-    return LoopVectorizationPlanner::getDecisionAndClampRange(
-        [&](ElementCount VF) {
-          auto *SrcVecTy = cast<VectorType>(ToVectorTy(SrcTy, VF));
-          auto *VectorTy = cast<VectorType>(ToVectorTy(RedTy, VF));
-          TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
-          InstructionCost ExtRedCost = Ctx.TTI.getExtendedReductionCost(
-              Opcode, isZExt, RedTy, SrcVecTy, RdxDesc.getFastMathFlags(),
-              CostKind);
-          InstructionCost ExtCost =
-              cast<VPWidenCastRecipe>(VecOp)->computeCost(VF, Ctx);
-          RecurKind RdxKind = RdxDesc.getRecurrenceKind();
-          InstructionCost RedCost;
-          if (RecurrenceDescriptor::isMinMaxRecurrenceKind(RdxKind)) {
-            Intrinsic::ID Id = getMinMaxReductionIntrinsicOp(RdxKind);
-            RedCost = Ctx.TTI.getMinMaxReductionCost(
-                Id, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
-          } else {
-            RedCost = Ctx.TTI.getArithmeticReductionCost(
-                Opcode, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
-          }
-          return ExtRedCost.isValid() && ExtRedCost < ExtCost + RedCost;
-        },
-        Range);
-  };
-
-  // Matched reduce(ext)).
-  if (match(VecOp, m_ZExtOrSExt(m_VPValue(A)))) {
-    if (!IsExtendedRedValidAndClampRange(
-            RdxDesc.getOpcode(),
-            cast<VPWidenCastRecipe>(VecOp)->getOpcode() ==
-                Instruction::CastOps::ZExt,
-            Ctx.Types.inferScalarType(A)))
-      return nullptr;
-    return new VPExtendedReductionRecipe(RdxDesc, CurrentLinkI, PreviousLink,
-                                         cast<VPWidenCastRecipe>(VecOp), CondOp,
-                                         CM.useOrderedReductions(RdxDesc));
-  }
-  return nullptr;
-}
-
-/// Try to match the mul-acc-reduction and create VPMulAccRecipe.
-///
-/// This function try to match following patterns which will generate mul-acc
-/// instructions.
-///    reduce.add(mul(...)),
-///    reduce.add(mul(ext(A), ext(B))),
-///    reduce.add(ext(mul(ext(A), ext(B)))).
-static VPMulAccRecipe *tryToMatchAndCreateMulAcc(
-    const RecurrenceDescriptor &RdxDesc, Instruction *CurrentLinkI,
-    VPValue *PreviousLink, VPValue *VecOp, VPValue *CondOp,
-    LoopVectorizationCostModel &CM, VPCostContext &Ctx, VFRange &Range) {
-  using namespace VPlanPatternMatch;
-
-  VPValue *A, *B;
-  Type *RedTy = RdxDesc.getRecurrenceType();
-
-  // Test if using mul-acc-reduction is beneficial and clamp the range.
-  auto IsMulAccValidAndClampRange =
-      [&](bool isZExt, VPWidenRecipe *Mul, VPWidenCastRecipe *Ext0,
-          VPWidenCastRecipe *Ext1, VPWidenCastRecipe *OuterExt) -> bool {
-    return LoopVectorizationPlanner::getDecisionAndClampRange(
-        [&](ElementCount VF) {
-          TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
-          Type *SrcTy =
-              Ext0 ? Ctx.Types.inferScalarType(Ext0->getOperand(0)) : RedTy;
-          auto *SrcVecTy = cast<VectorType>(ToVectorTy(SrcTy, VF));
-          auto *VectorTy = cast<VectorType>(ToVectorTy(RedTy, VF));
-          InstructionCost MulAccCost =
-              Ctx.TTI.getMulAccReductionCost(isZExt, RedTy, SrcVecTy, CostKind);
-          InstructionCost MulCost = Mul->computeCost(VF, Ctx);
-          InstructionCost RedCost = Ctx.TTI.getArithmeticReductionCost(
-              Instruction::Add, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
-          InstructionCost ExtCost = 0;
-          if (Ext0)
-            ExtCost += Ext0->computeCost(VF, Ctx);
-          if (Ext1)
-            ExtCost += Ext1->computeCost(VF, Ctx);
-          if (OuterExt)
-            ExtCost += OuterExt->computeCost(VF, Ctx);
-
-          return MulAccCost.isValid() &&
-                 MulAccCost < ExtCost + MulCost + RedCost;
-        },
-        Range);
-  };
-
-  if (RdxDesc.getOpcode() != Instruction::Add)
-    return nullptr;
-
-  // Try to match reduce.add(mul(...))
-  if (match(VecOp, m_Mul(m_VPValue(A), m_VPValue(B)))) {
-    auto *RecipeA =
-        dyn_cast_if_present<VPWidenCastRecipe>(A->getDefiningRecipe());
-    auto *RecipeB =
-        dyn_cast_if_present<VPWidenCastRecipe>(B->getDefiningRecipe());
-    auto *Mul = cast<VPWidenRecipe>(VecOp->getDefiningRecipe());
-
-    // Matched reduce.add(mul(ext, ext))
-    if (RecipeA && RecipeB &&
-        (RecipeA->getOpcode() == RecipeB->getOpcode() || A == B) &&
-        match(RecipeA, m_ZExtOrSExt(m_VPValue())) &&
-        match(RecipeB, m_ZExtOrSExt(m_VPValue()))) {
-
-      // Only create MulAccRecipe if the cost is valid.
-      if (!IsMulAccValidAndClampRange(RecipeA->getOpcode() ==
-                                          Instruction::CastOps::ZExt,
-                                      Mul, RecipeA, RecipeB, nullptr))
-        return nullptr;
-
-      return new VPMulAccRecipe(RdxDesc, CurrentLinkI, PreviousLink, CondOp,
-                                CM.useOrderedReductions(RdxDesc), Mul, RecipeA,
-                                RecipeB);
-    } else {
-      // Matched reduce.add(mul)
-      if (!IsMulAccValidAndClampRange(true, Mul, nullptr, nullptr, nullptr))
-        return nullptr;
-
-      return new VPMulAccRecipe(RdxDesc, CurrentLinkI, PreviousLink, CondOp,
-                                CM.useOrderedReductions(RdxDesc), Mul);
-    }
-    // Matched reduce.add(ext(mul(ext(A), ext(B))))
-    // All extend instructions must have same opcode or A == B
-    // which can be transform to reduce.add(zext(mul(sext(A), sext(B)))).
-  } else if (match(VecOp, m_ZExtOrSExt(m_Mul(m_ZExtOrSExt(m_VPValue()),
-                                             m_ZExtOrSExt(m_VPValue()))))) {
-    auto *Ext = cast<VPWidenCastRecipe>(VecOp->getDefiningRecipe());
-    auto *Mul = cast<VPWidenRecipe>(Ext->getOperand(0)->getDefiningRecipe());
-    auto *Ext0 =
-        cast<VPWidenCastRecipe>(Mul->getOperand(0)->getDefiningRecipe());
-    auto *Ext1 =
-        cast<VPWidenCastRecipe>(Mul->getOperand(1)->getDefiningRecipe());
-    if ((Ext->getOpcode() == Ext0->getOpcode() || Ext0 == Ext1) &&
-        Ext0->getOpcode() == Ext1->getOpcode()) {
-      // Only create MulAcc recipe if the cost if valid.
-      if (!IsMulAccValidAndClampRange(Ext0->getOpcode() ==
-                                          Instruction::CastOps::ZExt,
-                                      Mul, Ext0, Ext1, Ext))
-        return nullptr;
-
-      return new VPMulAccRecipe(RdxDesc, CurrentLinkI, PreviousLink, CondOp,
-                                CM.useOrderedReductions(RdxDesc), Mul, Ext0,
-                                Ext1);
-    }
-  }
-  return nullptr;
-}
-
 // Adjust the recipes for reductions. For in-loop reductions the chain of
 // instructions leading from the loop exit instr to the phi need to be converted
 // to reductions, with one operand being vector and the other being the scalar
@@ -9629,18 +9465,20 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
       VPReductionRecipe *RedRecipe;
       VPCostContext CostCtx(CM.TTI, *CM.TLI, Legal->getWidestInductionType(),
                             CM);
+      bool IsOrderedRed = CM.useOrderedReductions(RdxDesc);
       if (auto *MulAcc =
-              tryToMatchAndCreateMulAcc(RdxDesc, CurrentLinkI, PreviousLink,
-                                        VecOp, CondOp, CM, CostCtx, Range))
+              VPlanTransforms::tryToMatchAndCreateMulAccumulateReduction(
+                  RdxDesc, CurrentLinkI, PreviousLink, VecOp, CondOp,
+                  IsOrderedRed, CostCtx, Range))
         RedRecipe = MulAcc;
-      else if (auto *ExtRed = tryToMatchAndCreateExtendedReduction(
-                   RdxDesc, CurrentLinkI, PreviousLink, VecOp, CondOp, CM,
-                   CostCtx, Range))
+      else if (auto *ExtRed =
+                   VPlanTransforms::tryToMatchAndCreateExtendedReduction(
+                       RdxDesc, CurrentLinkI, PreviousLink, VecOp, CondOp,
+                       IsOrderedRed, CostCtx, Range))
         RedRecipe = ExtRed;
       else
-        RedRecipe =
-            new VPReductionRecipe(RdxDesc, CurrentLinkI, PreviousLink, VecOp,
-                                  CondOp, CM.useOrderedReductions(RdxDesc));
+        RedRecipe = new VPReductionRecipe(RdxDesc, CurrentLinkI, PreviousLink,
+                                          VecOp, CondOp, IsOrderedRed);
       // Append the recipe to the end of the VPBasicBlock because we need to
       // ensure that it comes after all of it's inputs, including CondOp.
       // Note that this transformation may leave over dead recipes (including
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index bdcdafd7181d68..0db2b7344ba91b 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -859,7 +859,7 @@ class VPSingleDefRecipe : public VPRecipeBase, public VPValue {
     case VPRecipeBase::VPInstructionSC:
     case VPRecipeBase::VPReductionEVLSC:
     case VPRecipeBase::VPReductionSC:
-    case VPRecipeBase::VPMulAccSC:
+    case VPRecipeBase::VPMulAccumulateReductionSC:
     case VPRecipeBase::VPExtendedReductionSC:
     case VPRecipeBase::VPReplicateSC:
     case VPRecipeBase::VPScalarIVStepsSC:
@@ -1077,7 +1077,7 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
            R->getVPDefID() == VPRecipeBase::VPReverseVectorPointerSC ||
            R->getVPDefID() == VPRecipeBase::VPVectorPointerSC ||
            R->getVPDefID() == VPRecipeBase::VPExtendedReductionSC ||
-           R->getVPDefID() == VPRecipeBase::VPMulAccSC;
+           R->getVPDefID() == VPRecipeBase::VPMulAccumulateReductionSC;
   }
 
   static inline bool classof(const VPUser *U) {
@@ -2693,7 +2693,7 @@ class VPReductionRecipe : public VPRecipeWithIRFlags {
     return R->getVPDefID() == VPRecipeBase::VPReductionSC ||
            R->getVPDefID() == VPRecipeBase::VPReductionEVLSC ||
            R->getVPDefID() == VPRecipeBase::VPExtendedReductionSC ||
-           R->getVPDefID() == VPRecipeBase::VPMulAccSC;
+           R->getVPDefID() == VPRecipeBase::VPMulAccumulateReductionSC;
   }
 
   static inline bool classof(const VPUser *U) {
@@ -2845,7 +2845,7 @@ class VPExtendedReductionRecipe : public VPReductionRecipe {
 /// multiplication into a scalar value, and adding the result to a chain. This
 /// recipe is abstract and needs to be lowered to concrete recipes before
 /// codegen. The Operands are {ChainOp, VecOp1, VecOp2, [Condition]}.
-class VPMulAccRecipe : public VPReductionRecipe {
+class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
   Instruction::CastOps ExtOp;
 
   // Non-neg flag for the extend instruction.
@@ -2862,11 +2862,11 @@ class VPMulAccRecipe : public VPReductionRecipe {
   bool IsExtended = false;
 
 protected:
-  VPMulAccRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
-                 Instruction *RedI, bool NUW, bool NSW, DebugLoc MulDL,
-                 Instruction::CastOps ExtOp, bool IsNonNeg, DebugLoc Ext0DL,
-                 DebugLoc Ext1DL, VPValue *ChainOp, VPValue *VecOp0,
-                 VPValue *VecOp1, VPValue *CondOp, bool IsOrdered)
+  VPMulAccumulateReductionRecipe(
+      const unsigned char SC, const RecurrenceDescriptor &R, Instruction *RedI,
+      bool NUW, bool NSW, DebugLoc MulDL, Instruction::CastOps ExtOp,
+      bool IsNonNeg, DebugLoc Ext0DL, DebugLoc Ext1DL, VPValue *ChainOp,
+      VPValue *VecOp0, VPValue *VecOp1, VPValue *CondOp, bool IsOrdered)
       : VPReductionRecipe(SC, R, ArrayRef<VPValue *>({ChainOp, VecOp0, VecOp1}),
                           CondOp, IsOrdered, NUW, NSW, MulDL),
         ExtOp(ExtOp), IsNonNeg(IsNonNeg), Ext0DL(Ext0DL), Ext1DL(Ext1DL),
@@ -2876,10 +2876,12 @@ class VPMulAccRecipe : public VPReductionRecipe {
     IsExtended = true;
   }
 
-  VPMulAccRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
-                 Instruction *RedI, bool NUW, bool NSW, DebugLoc MulDL,
-                 VPValue *ChainOp, VPValue *VecOp0, VPValue *VecOp1,
-                 VPValue *CondOp, bool IsOrdered)
+  VPMulAccumulateReductionRecipe(const unsigned char SC,
+                                 const RecurrenceDescriptor &R,
+                                 Instruction *RedI, bool NUW, bool NSW,
+                                 DebugLoc MulDL, VPValue *ChainOp,
+                                 VPValue *VecOp0, VPValue *VecOp1,
+                                 VPValue *CondOp, bool IsOrdered)
       : VPReductionRecipe(SC, R, ArrayRef<VPValue *>({ChainOp, VecOp0, VecOp1}),
                           CondOp, IsOrdered, NUW, NSW, MulDL),
         RedDL(RedI->getDebugLoc()) {
@@ -2888,36 +2890,43 @@ class VPMulAccRecipe : public VPReductionRecipe {
   }
 
 public:
-  VPMulAccRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
-                 VPValue *ChainOp, VPValue *CondOp, bool IsOrdered,
-                 VPWidenRecipe *Mul, VPWidenCastRecipe *Ext0,
-                 VPWidenCastRecipe *Ext1)
-      : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, Mul->hasNoUnsignedWrap(),
-                       Mul->hasNoSignedWrap(), Mul->getDebugLoc(),
-                       Ext0->getOpcode(), Ext0->isNonNeg(), Ext0->getDebugLoc(),
-                       Ext1->getDebugLoc(), ChainOp, Ext0->getOperand(0),
-                       Ext1->getOperand(0), CondOp, IsOrdered) {}
-
-  VPMulAccRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
-                 VPValue *ChainOp, VPValue *CondOp, bool IsOrdered,
-                 VPWidenRecipe *Mul)
-      : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, Mul->hasNoUnsignedWrap(),
-                       Mul->hasNoSignedWrap(), Mul->getDebugLoc(), ChainOp,
-                       Mul->getOperand(0), Mul->getOperand(1), CondOp,
-                       IsOrdered) {}
-
-  ~VPMulAccRecipe() override = default;
-
-  VPMulAccRecipe *clone() override { llvm_unreachable("Not implement yet"); }
-
-  VP_CLASSOF_IMPL(VPDef::VPMulAccSC);
+  VPMulAccumulateReductionRecipe(const RecurrenceDescriptor &R,
+                                 Instruction *RedI, VPValue *ChainOp,
+                                 VPValue *CondOp, bool IsOrdered,
+                                 VPWidenRecipe *Mul, VPWidenCastRecipe *Ext0,
+                                 VPWidenCastRecipe *Ext1)
+      : VPMulAccumulateReductionRecipe(
+            VPDef::VPMulAccumulateReductionSC, R, RedI,
+            Mul->hasNoUnsignedWrap(), Mul->hasNoSignedWrap(),
+            Mul->getDebugLoc(), Ext0->getOpcode(), Ext0->isNonNeg(),
+            Ext0->getDebugLoc(), Ext1->getDebugLoc(), ChainOp,
+            Ext0->getOperand(0), Ext1->getOperand(0), CondOp, IsOrdered) {}
+
+  VPMulAccumulateReductionRecipe(const RecurrenceDescriptor &R,
+                                 Instruction *RedI, VPValue *ChainOp,
+                                 VPValue *CondOp, bool IsOrdered,
+                                 VPWidenRecipe *Mul)
+      : VPMulAccumulateReductionRecipe(
+            VPDef::VPMulAccumulateReductionSC, R, RedI,
+            Mul->hasNoUnsignedWrap(), Mul->hasNoSignedWrap(),
+            Mul->getDebugLoc(), ChainOp, Mul->getOperand(0), Mul->getOperand(1),
+            CondOp, IsOrdered) {}
+
+  ~VPMulAccumulateReductionRecipe() override = default;
+
+  VPMulAccumulateReductionRecipe *clone() override {
+    llvm_unreachable("Not implement yet");
+  }
+
+  VP_CLASSOF_IMPL(VPDef::VPMulAccumulateReductionSC);
 
   void execute(VPTransformState &State) override {
-    llvm_unreachable("VPMulAccRecipe should transform to VPWidenCastRecipe + "
+    llvm_unreachable("VPMulAccumulateReductionRecipe should transform to "
+                     "VPWidenCastRecipe + "
                      "VPWidenRecipe + VPReductionRecipe before execution");
   }
 
-  /// Return the cost of VPMulAccRecipe.
+  /// Return the cost of VPMulAccumulateReductionRecipe.
   InstructionCost computeCost(ElementCount VF,
                               VPCostContext &Ctx) const override;
 
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index ae268ea2f96f90..950378ddabed7f 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -2232,7 +2232,8 @@ VPExtendedReductionRecipe::computeCost(ElementCount VF,
                                           RdxDesc.getFastMathFlags(), CostKind);
 }
 
-InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
+InstructionCost
+VPMulAccumulateReductionRecipe::computeCost(ElementCount VF,
                                             VPCostContext &Ctx) const {
   Type *RedTy = Ctx.Types.inferScalarType(this);
   auto *SrcVecTy =
@@ -2309,8 +2310,8 @@ void VPExtendedReductionRecipe::print(raw_ostream &O, const Twine &Indent,
          "outside of loop)";
 }
 
-void VPMulAccRecipe::print(raw_ostream &O, const Twine &Indent,
-                           VPSlotTracker &SlotTracker) const {
+void VPMulAccumulateReductionRecipe::print(raw_ostream &O, const Twine &Indent,
+                                           VPSlotTracker &SlotTracker) const {
   const RecurrenceDescriptor &RdxDesc = getRecurrenceDescriptor();
   Type *RedTy = RdxDesc.getRecurrenceType();
 
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index 4e15ec09cd4de0..f9c02b338efb55 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -27,6 +27,7 @@
 #include "llvm/Analysis/VectorUtils.h"
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/PatternMatch.h"
+#include "llvm/Transforms/Utils/LoopUtils.h"
 
 using namespace llvm;
 
@@ -1483,7 +1484,9 @@ static void transformRecipestoEVLRecipes(VPlan &Plan, VPValue &EVL) {
                   return nullptr;
                 return new VPWidenEVLRecipe(*W, EVL);
               })
-              .Case<VPExtendedReductionRecipe, VPMulAccRecipe>(
+              // TODO: Implement EVL version of the following abstract
+              // reduction recipes.
+              .Case<VPExtendedReductionRecipe, VPMulAccumulateReductionRecipe>(
                   [](const auto *R) { return nullptr; })
               .Case<VPReductionRecipe>([&](VPReductionRecipe *Red) {
                 VPValue *NewMask = GetNewMask(Red->getCondOp());
@@ -1873,9 +1876,11 @@ static void expandVPExtendedReduction(VPExtendedReductionRecipe *ExtRed) {
   ExtRed->eraseFromParent();
 }
 
-// Expand VPMulAccRecipe to VPWidenRecipe (mul) + VPReductionRecipe (reduce.add)
+// Expand VPMulAccumulateReductionRecipe to VPWidenRecipe (mul) +
+// VPReductionRecipe (reduce.add)
 // + VPWidenCastRecipe (optional).
-static void expandVPMulAcc(VPMulAccRecipe *MulAcc) {
+static void
+expandVPMulAccumulateReduction(VPMulAccumulateReductionRecipe *MulAcc) {
   Type *RedTy = MulAcc->getRecurrenceDescriptor().getRecurrenceType();
 
   // Generate inner VPWidenCastRecipes if necessary.
@@ -1946,14 +1951,168 @@ void VPlanTransforms::convertToConcreteRecipes(VPlan &Plan) {
       PhiR->eraseFromParent();
     }
     for (VPRecipeBase &R : make_early_inc_range(*VPBB)) {
-      if (!isa<VPExtendedReductionRecipe, VPMulAccRecipe>(&R))
+      if (!isa<VPExtendedReductionRecipe, VPMulAccumulateReductionRecipe>(&R))
         continue;
       if (auto *ExtRed = dyn_cast<VPExtendedReductionRecipe>(&R)) {
         expandVPExtendedReduction(ExtRed);
       }
-      if (auto *MulAcc = dyn_cast<VPMulAccRecipe>(&R)) {
-        expandVPMulAcc(MulAcc);
+      if (auto *MulAcc = dyn_cast<VPMulAccumulateReductionRecipe>(&R)) {
+        expandVPMulAccumulateReduction(MulAcc);
       }
     }
   }
 }
+
+VPExtendedReductionRecipe *
+VPlanTransforms::tryToMatchAndCreateExtendedReduction(
+    const RecurrenceDescriptor &RdxDesc, Instruction *CurrentLinkI,
+    VPValue *PreviousLink, VPValue *VecOp, VPValue *CondOp, bool IsOrderedRed,
+    VPCostContext &Ctx, VFRange &Range) {
+  using namespace VPlanPatternMatch;
+
+  VPValue *A;
+  Type *RedTy = RdxDesc.getRecurrenceType();
+
+  // Test if using extended-reduction is beneficial and clamp the range.
+  auto IsExtendedRedValidAndClampRange = [&](unsigned Opcode, bool isZExt,
+                                             Type *SrcTy) -> bool {
+    return LoopVectorizationPlanner::getDecisionAndClampRange(
+        [&](ElementCount VF) {
+          auto *SrcVecTy = cast<VectorType>(ToVectorTy(SrcTy, VF));
+          auto *VectorTy = cast<VectorType>(ToVectorTy(RedTy, VF));
+          TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+          InstructionCost ExtRedCost = Ctx.TTI.getExtendedReductionCost(
+              Opcode, isZExt, RedTy, SrcVecTy, RdxDesc.getFastMathFlags(),
+              CostKind);
+          InstructionCost ExtCost =
+              cast<VPWidenCastRecipe>(VecOp)->computeCost(VF, Ctx);
+          RecurKind RdxKind = RdxDesc.getRecurrenceKind();
+          InstructionCost RedCost;
+          if (RecurrenceDescriptor::isMinMaxRecurrenceKind(RdxKind)) {
+            Intrinsic::ID Id = getMinMaxReductionIntrinsicOp(RdxKind);
+            RedCost = Ctx.TTI.getMinMaxReductionCost(
+                Id, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+          } else {
+            RedCost = Ctx.TTI.getArithmeticReductionCost(
+                Opcode, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+          }
+          return ExtRedCost.isValid() && ExtRedCost < ExtCost + RedCost;
+        },
+        Range);
+  };
+
+  // Matched reduce(ext(...)).
+  if (match(VecOp, m_ZExtOrSExt(m_VPValue(A)))) {
+    if (!IsExtendedRedValidAndClampRange(
+            RdxDesc.getOpcode(),
+            cast<VPWidenCastRecipe>(VecOp)->getOpcode() ==
+                Instruction::CastOps::ZExt,
+            Ctx.Types.inferScalarType(A)))
+      return nullptr;
+    return new VPExtendedReductionRecipe(RdxDesc, CurrentLinkI, PreviousLink,
+                                         cast<VPWidenCastRecipe>(VecOp), CondOp,
+                                         IsOrderedRed);
+  }
+  return nullptr;
+}
+
+VPMulAccumulateReductionRecipe *
+VPlanTransforms::tryToMatchAndCreateMulAccumulateReduction(
+    const RecurrenceDescriptor &RdxDesc, Instruction *CurrentLinkI,
+    VPValue *PreviousLink, VPValue *VecOp, VPValue *CondOp, bool IsOrderedRed,
+    VPCostContext &Ctx, VFRange &Range) {
+  using namespace VPlanPatternMatch;
+
+  VPValue *A, *B;
+  Type *RedTy = RdxDesc.getRecurrenceType();
+
+  // Test if using multiply-accumulate-reduction is beneficial and clamp the
+  // range.
+  auto IsMulAccValidAndClampRange =
+      [&](bool isZExt, VPWidenRecipe *Mul, VPWidenCastRecipe *Ext0,
+          VPWidenCastRecipe *Ext1, VPWidenCastRecipe *OuterExt) -> bool {
+    return LoopVectorizationPlanner::getDecisionAndClampRange(
+        [&](ElementCount VF) {
+          TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+          Type *SrcTy =
+              Ext0 ? Ctx.Types.inferScalarType(Ext0->getOperand(0)) : RedTy;
+          auto *SrcVecTy = cast<VectorType>(ToVectorTy(SrcTy, VF));
+          auto *VectorTy = cast<VectorType>(ToVectorTy(RedTy, VF));
+          InstructionCost MulAccCost =
+              Ctx.TTI.getMulAccReductionCost(isZExt, RedTy, SrcVecTy, CostKind);
+          InstructionCost MulCost = Mul->computeCost(VF, Ctx);
+          InstructionCost RedCost = Ctx.TTI.getArithmeticReductionCost(
+              Instruction::Add, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+          InstructionCost ExtCost = 0;
+          if (Ext0)
+            ExtCost += Ext0->computeCost(VF, Ctx);
+          if (Ext1)
+            ExtCost += Ext1->computeCost(VF, Ctx);
+          if (OuterExt)
+            ExtCost += OuterExt->computeCost(VF, Ctx);
+
+          return MulAccCost.isValid() &&
+                 MulAccCost < ExtCost + MulCost + RedCost;
+        },
+        Range);
+  };
+
+  if (RdxDesc.getOpcode() != Instruction::Add)
+    return nullptr;
+
+  // Try to match reduce.add(mul(...))
+  if (match(VecOp, m_Mul(m_VPValue(A), m_VPValue(B)))) {
+    auto *RecipeA =
+        dyn_cast_if_present<VPWidenCastRecipe>(A->getDefiningRecipe());
+    auto *RecipeB =
+        dyn_cast_if_present<VPWidenCastRecipe>(B->getDefiningRecipe());
+    auto *Mul = cast<VPWidenRecipe>(VecOp->getDefiningRecipe());
+
+    // Matched reduce.add(mul(ext, ext))
+    if (RecipeA && RecipeB &&
+        (RecipeA->getOpcode() == RecipeB->getOpcode() || A == B) &&
+        match(RecipeA, m_ZExtOrSExt(m_VPValue())) &&
+        match(RecipeB, m_ZExtOrSExt(m_VPValue()))) {
+
+      // Only create MulAccRecipe if the cost is valid.
+      if (!IsMulAccValidAndClampRange(RecipeA->getOpcode() ==
+                                          Instruction::CastOps::ZExt,
+                                      Mul, RecipeA, RecipeB, nullptr))
+        return nullptr;
+
+      return new VPMulAccumulateReductionRecipe(
+          RdxDesc, CurrentLinkI, PreviousLink, CondOp, IsOrderedRed, Mul,
+          RecipeA, RecipeB);
+    } else {
+      // Matched reduce.add(mul)
+      if (!IsMulAccValidAndClampRange(true, Mul, nullptr, nullptr, nullptr))
+        return nullptr;
+
+      return new VPMulAccumulateReductionRecipe(
+          RdxDesc, CurrentLinkI, PreviousLink, CondOp, IsOrderedRed, Mul);
+    }
+  } else if (match(VecOp, m_ZExtOrSExt(m_Mul(m_ZExtOrSExt(m_VPValue()),
+                                             m_ZExtOrSExt(m_VPValue()))))) {
+    // Matched reduce.add(ext(mul(ext(A), ext(B))))
+    // All extend instructions must have the same opcode or A == B
+    // which can be transformed to reduce.add(zext(mul(sext(A), sext(B)))).
+    auto *Ext = cast<VPWidenCastRecipe>(VecOp->getDefiningRecipe());
+    auto *Mul = cast<VPWidenRecipe>(Ext->getOperand(0)->getDefiningRecipe());
+    auto *Ext0 =
+        cast<VPWidenCastRecipe>(Mul->getOperand(0)->getDefiningRecipe());
+    auto *Ext1 =
+        cast<VPWidenCastRecipe>(Mul->getOperand(1)->getDefiningRecipe());
+    if ((Ext->getOpcode() == Ext0->getOpcode() || Ext0 == Ext1) &&
+        Ext0->getOpcode() == Ext1->getOpcode()) {
+      if (!IsMulAccValidAndClampRange(Ext0->getOpcode() ==
+                                          Instruction::CastOps::ZExt,
+                                      Mul, Ext0, Ext1, Ext))
+        return nullptr;
+
+      return new VPMulAccumulateReductionRecipe(RdxDesc, CurrentLinkI,
+                                                PreviousLink, CondOp,
+                                                IsOrderedRed, Mul, Ext0, Ext1);
+    }
+  }
+  return nullptr;
+}
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.h b/llvm/lib/Transforms/Vectorize/VPlanTransforms.h
index 9cf314a6a9f447..376b4e7a61cb76 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.h
@@ -126,6 +126,31 @@ struct VPlanTransforms {
 
   /// Lower abstract recipes to concrete ones, that can be codegen'd.
   static void convertToConcreteRecipes(VPlan &Plan);
+
+  /// This function tries to match the following pattern to create a
+  /// VPExtendedReductionRecipe and clamp the \p Range if it is beneficial and
+  /// valid. The created VPExtendedReductionRecipe will be lowered to concrete
+  /// recipes before execution.
+  ///   reduce(ext(...)).
+  static VPExtendedReductionRecipe *tryToMatchAndCreateExtendedReduction(
+      const RecurrenceDescriptor &RdxDesc, Instruction *CurrentLinkI,
+      VPValue *PreviousLink, VPValue *VecOp, VPValue *CondOp, bool IsOrderedRed,
+      VPCostContext &Ctx, VFRange &Range);
+
+  /// This function tries to match the following patterns to create a
+  /// VPMulAccumulateReductionRecipe and clamp the \p Range if it is beneficial
+  /// and valid. The created VPMulAccumulateReductionRecipe will be lowered to
+  /// concrete recipes before execution.
+  ///   reduce.add(mul(...)),
+  ///   reduce.add(mul(ext(A), ext(B))),
+  ///   reduce.add(ext(mul(ext(A), ext(B)))).
+  static VPMulAccumulateReductionRecipe *
+  tryToMatchAndCreateMulAccumulateReduction(const RecurrenceDescriptor &RdxDesc,
+                                            Instruction *CurrentLinkI,
+                                            VPValue *PreviousLink,
+                                            VPValue *VecOp, VPValue *CondOp,
+                                            bool IsOrderedRed,
+                                            VPCostContext &Ctx, VFRange &Range);
 };
 
 } // namespace llvm
diff --git a/llvm/lib/Transforms/Vectorize/VPlanValue.h b/llvm/lib/Transforms/Vectorize/VPlanValue.h
index 2c922d8d39c927..e5ae7666e4df35 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanValue.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanValue.h
@@ -329,7 +329,7 @@ class VPDef {
     VPInterleaveSC,
     VPReductionEVLSC,
     VPReductionSC,
-    VPMulAccSC,
+    VPMulAccumulateReductionSC,
     VPExtendedReductionSC,
     VPReplicateSC,
     VPScalarCastSC,

>From bffcac5d9977a362ee77c21ab3ebc06de954da1a Mon Sep 17 00:00:00 2001
From: Elvis Wang <elvis.wang at sifive.com>
Date: Sun, 22 Dec 2024 20:29:54 -0800
Subject: [PATCH 23/26] Implement clone() and add some docs.

---
 .../Transforms/Vectorize/LoopVectorize.cpp    |  29 +++--
 llvm/lib/Transforms/Vectorize/VPlan.h         | 106 +++++++++---------
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp |   2 +-
 3 files changed, 71 insertions(+), 66 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index f64d2ebaedcfb2..04068a88ef9d02 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -9466,19 +9466,26 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
       VPCostContext CostCtx(CM.TTI, *CM.TLI, Legal->getWidestInductionType(),
                             CM);
       bool IsOrderedRed = CM.useOrderedReductions(RdxDesc);
-      if (auto *MulAcc =
-              VPlanTransforms::tryToMatchAndCreateMulAccumulateReduction(
-                  RdxDesc, CurrentLinkI, PreviousLink, VecOp, CondOp,
-                  IsOrderedRed, CostCtx, Range))
-        RedRecipe = MulAcc;
-      else if (auto *ExtRed =
-                   VPlanTransforms::tryToMatchAndCreateExtendedReduction(
-                       RdxDesc, CurrentLinkI, PreviousLink, VecOp, CondOp,
-                       IsOrderedRed, CostCtx, Range))
-        RedRecipe = ExtRed;
-      else
+      // TODO: Remove EVL check when we support EVL version of
+      // VPExtendedReductionRecipe and VPMulAccumulateReductionRecipe.
+      if (CM.foldTailWithEVL()) {
         RedRecipe = new VPReductionRecipe(RdxDesc, CurrentLinkI, PreviousLink,
                                           VecOp, CondOp, IsOrderedRed);
+      } else {
+        if (auto *MulAcc =
+                VPlanTransforms::tryToMatchAndCreateMulAccumulateReduction(
+                    RdxDesc, CurrentLinkI, PreviousLink, VecOp, CondOp,
+                    IsOrderedRed, CostCtx, Range))
+          RedRecipe = MulAcc;
+        else if (auto *ExtRed =
+                     VPlanTransforms::tryToMatchAndCreateExtendedReduction(
+                         RdxDesc, CurrentLinkI, PreviousLink, VecOp, CondOp,
+                         IsOrderedRed, CostCtx, Range))
+          RedRecipe = ExtRed;
+        else
+          RedRecipe = new VPReductionRecipe(RdxDesc, CurrentLinkI, PreviousLink,
+                                            VecOp, CondOp, IsOrderedRed);
+      }
       // Append the recipe to the end of the VPBasicBlock because we need to
       // ensure that it comes after all of it's inputs, including CondOp.
       // Note that this transformation may leave over dead recipes (including
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 0db2b7344ba91b..7e7a3464d37b73 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -2778,32 +2778,34 @@ class VPReductionEVLRecipe : public VPReductionRecipe {
 /// concrete recipes before codegen. The Operands are {ChainOp, VecOp,
 /// [Condition]}.
 class VPExtendedReductionRecipe : public VPReductionRecipe {
+  /// Opcode for the extend recipe.
   Instruction::CastOps ExtOp;
+  /// Debug location of reduction instruction.
   DebugLoc RedDL;
 
 protected:
-  VPExtendedReductionRecipe(const unsigned char SC,
-                            const RecurrenceDescriptor &R, Instruction *RedI,
-                            Instruction::CastOps ExtOp, DebugLoc ExtDL,
-                            bool IsNonNeg, VPValue *ChainOp, VPValue *VecOp,
-                            VPValue *CondOp, bool IsOrdered)
-      : VPReductionRecipe(SC, R, ArrayRef<VPValue *>({ChainOp, VecOp}), CondOp,
-                          IsOrdered, IsNonNeg, ExtDL),
-        ExtOp(ExtOp), RedDL(RedI->getDebugLoc()) {}
-
 public:
   VPExtendedReductionRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
                             VPValue *ChainOp, VPWidenCastRecipe *Ext,
                             VPValue *CondOp, bool IsOrdered)
-      : VPExtendedReductionRecipe(VPDef::VPExtendedReductionSC, R, RedI,
-                                  Ext->getOpcode(), Ext->getDebugLoc(),
-                                  Ext->isNonNeg(), ChainOp, Ext->getOperand(0),
-                                  CondOp, IsOrdered) {}
+      : VPReductionRecipe(VPDef::VPExtendedReductionSC, R,
+                          ArrayRef<VPValue *>({ChainOp, Ext->getOperand(0)}),
+                          CondOp, IsOrdered, Ext->isNonNeg(),
+                          Ext->getDebugLoc()),
+        ExtOp(Ext->getOpcode()), RedDL(RedI->getDebugLoc()) {}
+
+  VPExtendedReductionRecipe(VPExtendedReductionRecipe *ExtRed)
+      : VPReductionRecipe(
+            VPDef::VPExtendedReductionSC, ExtRed->getRecurrenceDescriptor(),
+            ArrayRef<VPValue *>({ExtRed->getChainOp(), ExtRed->getVecOp()}),
+            ExtRed->getCondOp(), ExtRed->isOrdered(), ExtRed->isNonNeg(),
+            ExtRed->getExtDebugLoc()),
+        ExtOp(ExtRed->getExtOpcode()), RedDL(ExtRed->getRedDebugLoc()) {}
 
   ~VPExtendedReductionRecipe() override = default;
 
   VPExtendedReductionRecipe *clone() override {
-    llvm_unreachable("Not implement yet");
+    return new VPExtendedReductionRecipe(this);
   }
 
   VP_CLASSOF_IMPL(VPDef::VPExtendedReductionSC);
@@ -2861,61 +2863,57 @@ class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
   // Is this mul-acc recipe contains extend recipes?
   bool IsExtended = false;
 
-protected:
-  VPMulAccumulateReductionRecipe(
-      const unsigned char SC, const RecurrenceDescriptor &R, Instruction *RedI,
-      bool NUW, bool NSW, DebugLoc MulDL, Instruction::CastOps ExtOp,
-      bool IsNonNeg, DebugLoc Ext0DL, DebugLoc Ext1DL, VPValue *ChainOp,
-      VPValue *VecOp0, VPValue *VecOp1, VPValue *CondOp, bool IsOrdered)
-      : VPReductionRecipe(SC, R, ArrayRef<VPValue *>({ChainOp, VecOp0, VecOp1}),
-                          CondOp, IsOrdered, NUW, NSW, MulDL),
-        ExtOp(ExtOp), IsNonNeg(IsNonNeg), Ext0DL(Ext0DL), Ext1DL(Ext1DL),
-        RedDL(RedI->getDebugLoc()) {
-    assert(R.getOpcode() == Instruction::Add &&
-           "The reduction instruction in MulAccRecipe must be Add");
-    IsExtended = true;
-  }
-
-  VPMulAccumulateReductionRecipe(const unsigned char SC,
-                                 const RecurrenceDescriptor &R,
-                                 Instruction *RedI, bool NUW, bool NSW,
-                                 DebugLoc MulDL, VPValue *ChainOp,
-                                 VPValue *VecOp0, VPValue *VecOp1,
-                                 VPValue *CondOp, bool IsOrdered)
-      : VPReductionRecipe(SC, R, ArrayRef<VPValue *>({ChainOp, VecOp0, VecOp1}),
-                          CondOp, IsOrdered, NUW, NSW, MulDL),
-        RedDL(RedI->getDebugLoc()) {
-    assert(R.getOpcode() == Instruction::Add &&
-           "The reduction instruction in MulAccRecipe must be Add");
-  }
-
 public:
   VPMulAccumulateReductionRecipe(const RecurrenceDescriptor &R,
                                  Instruction *RedI, VPValue *ChainOp,
                                  VPValue *CondOp, bool IsOrdered,
                                  VPWidenRecipe *Mul, VPWidenCastRecipe *Ext0,
                                  VPWidenCastRecipe *Ext1)
-      : VPMulAccumulateReductionRecipe(
-            VPDef::VPMulAccumulateReductionSC, R, RedI,
-            Mul->hasNoUnsignedWrap(), Mul->hasNoSignedWrap(),
-            Mul->getDebugLoc(), Ext0->getOpcode(), Ext0->isNonNeg(),
-            Ext0->getDebugLoc(), Ext1->getDebugLoc(), ChainOp,
-            Ext0->getOperand(0), Ext1->getOperand(0), CondOp, IsOrdered) {}
+      : VPReductionRecipe(VPDef::VPMulAccumulateReductionSC, R,
+                          ArrayRef<VPValue *>({ChainOp, Ext0->getOperand(0),
+                                               Ext1->getOperand(0)}),
+                          CondOp, IsOrdered, Mul->hasNoUnsignedWrap(),
+                          Mul->hasNoSignedWrap(), Mul->getDebugLoc()),
+        ExtOp(Ext0->getOpcode()), IsNonNeg(Ext0->isNonNeg()),
+        Ext0DL(Ext0->getDebugLoc()), Ext1DL(Ext1->getDebugLoc()),
+        RedDL(RedI->getDebugLoc()) {
+    assert(R.getOpcode() == Instruction::Add &&
+           "The reduction instruction in MulAccRecipe must be Add");
+    IsExtended = true;
+  }
 
   VPMulAccumulateReductionRecipe(const RecurrenceDescriptor &R,
                                  Instruction *RedI, VPValue *ChainOp,
                                  VPValue *CondOp, bool IsOrdered,
                                  VPWidenRecipe *Mul)
-      : VPMulAccumulateReductionRecipe(
-            VPDef::VPMulAccumulateReductionSC, R, RedI,
-            Mul->hasNoUnsignedWrap(), Mul->hasNoSignedWrap(),
-            Mul->getDebugLoc(), ChainOp, Mul->getOperand(0), Mul->getOperand(1),
-            CondOp, IsOrdered) {}
+      : VPReductionRecipe(VPDef::VPMulAccumulateReductionSC, R,
+                          ArrayRef<VPValue *>({ChainOp, Mul->getOperand(0),
+                                               Mul->getOperand(1)}),
+                          CondOp, IsOrdered, Mul->hasNoUnsignedWrap(),
+                          Mul->hasNoSignedWrap(), Mul->getDebugLoc()),
+        RedDL(RedI->getDebugLoc()) {
+    assert(R.getOpcode() == Instruction::Add &&
+           "The reduction instruction in MulAccRecipe must be Add");
+  }
+
+  /// Constructor for cloning VPMulAccumulateReductionRecipe.
+  VPMulAccumulateReductionRecipe(VPMulAccumulateReductionRecipe *MulAcc)
+      : VPReductionRecipe(
+            VPDef::VPMulAccumulateReductionSC,
+            MulAcc->getRecurrenceDescriptor(),
+            ArrayRef<VPValue *>({MulAcc->getChainOp(), MulAcc->getOperand(1),
+                                 MulAcc->getOperand(2)}),
+            MulAcc->getCondOp(), MulAcc->isOrdered(),
+            MulAcc->hasNoUnsignedWrap(), MulAcc->hasNoSignedWrap(),
+            MulAcc->getMulDebugLoc()),
+        ExtOp(MulAcc->getExtOpcode()), IsNonNeg(MulAcc->isNonNeg()),
+        Ext0DL(MulAcc->getExt0DebugLoc()), Ext1DL(MulAcc->getExt1DebugLoc()),
+        RedDL(MulAcc->getRedDebugLoc()), IsExtended(MulAcc->isExtended()) {}
 
   ~VPMulAccumulateReductionRecipe() override = default;
 
   VPMulAccumulateReductionRecipe *clone() override {
-    llvm_unreachable("Not implement yet");
+    return new VPMulAccumulateReductionRecipe(this);
   }
 
   VP_CLASSOF_IMPL(VPDef::VPMulAccumulateReductionSC);
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 950378ddabed7f..5e8d2da1f7e822 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -1504,7 +1504,7 @@ void VPWidenCastRecipe::execute(VPTransformState &State) {
     setFlags(CastOp);
 }
 
-// Computes the CastContextHint from a recipes that may access memory.
+/// Computes the CastContextHint for a recipe.
 static TTI::CastContextHint computeCCH(const VPRecipeBase *R, ElementCount VF) {
   if (VF.isScalar())
     return TTI::CastContextHint::Normal;

>From da705f1b9d3b1e41c31ca850ef6e255dd92ae3ab Mon Sep 17 00:00:00 2001
From: Elvis Wang <elvis.wang at sifive.com>
Date: Sun, 22 Dec 2024 21:03:59 -0800
Subject: [PATCH 24/26] Update comments.

---
 llvm/lib/Transforms/Vectorize/VPlan.h | 28 +++++++++++++++------------
 1 file changed, 16 insertions(+), 12 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 7e7a3464d37b73..746007103c573c 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -2778,12 +2778,11 @@ class VPReductionEVLRecipe : public VPReductionRecipe {
 /// concrete recipes before codegen. The Operands are {ChainOp, VecOp,
 /// [Condition]}.
 class VPExtendedReductionRecipe : public VPReductionRecipe {
-  /// Opcode for the extend recipe.
+  /// Opcode of the extend recipe this will be lowered to.
   Instruction::CastOps ExtOp;
-  /// Debug location of reduction instruction.
+  /// Debug location of the reduction recipe this will be lowered to.
   DebugLoc RedDL;
 
-protected:
 public:
   VPExtendedReductionRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
                             VPValue *ChainOp, VPWidenCastRecipe *Ext,
@@ -2794,6 +2793,7 @@ class VPExtendedReductionRecipe : public VPReductionRecipe {
                           Ext->getDebugLoc()),
         ExtOp(Ext->getOpcode()), RedDL(RedI->getDebugLoc()) {}
 
+  /// Constructor for cloning VPExtendedReductionRecipe.
   VPExtendedReductionRecipe(VPExtendedReductionRecipe *ExtRed)
       : VPReductionRecipe(
             VPDef::VPExtendedReductionSC, ExtRed->getRecurrenceDescriptor(),
@@ -2842,25 +2842,26 @@ class VPExtendedReductionRecipe : public VPReductionRecipe {
   DebugLoc getRedDebugLoc() const { return RedDL; }
 };
 
-/// A recipe to represent inloop MulAccReduction operations, performing a
+/// A recipe to represent inloop MulAccumulateReduction operations, performing a
 /// reduction.add on the result of vector operands (might be extended)
 /// multiplication into a scalar value, and adding the result to a chain. This
 /// recipe is abstract and needs to be lowered to concrete recipes before
 /// codegen. The Operands are {ChainOp, VecOp1, VecOp2, [Condition]}.
 class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
+  /// Opcode of the extend recipe.
   Instruction::CastOps ExtOp;
 
-  // Non-neg flag for the extend instruction.
+  /// Non-neg flag of the extend recipe.
   bool IsNonNeg;
 
-  // Debug location of extend instruction.
+  /// Debug locations of the extend recipes this will be lowered to.
   DebugLoc Ext0DL;
   DebugLoc Ext1DL;
 
-  // Debug location of reduction instruction.
+  /// Debug location of the reduction recipe this will be lowered to.
   DebugLoc RedDL;
 
-  // Is this mul-acc recipe contains extend recipes?
+  /// Does this multiply-accumulate-reduction recipe contain extends?
   bool IsExtended = false;
 
 public:
@@ -2878,7 +2879,8 @@ class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
         Ext0DL(Ext0->getDebugLoc()), Ext1DL(Ext1->getDebugLoc()),
         RedDL(RedI->getDebugLoc()) {
     assert(R.getOpcode() == Instruction::Add &&
-           "The reduction instruction in MulAccRecipe must be Add");
+           "The reduction instruction in MulAccumulateteReductionRecipe must "
+           "be Add");
     IsExtended = true;
   }
 
@@ -2893,7 +2895,8 @@ class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
                           Mul->hasNoSignedWrap(), Mul->getDebugLoc()),
         RedDL(RedI->getDebugLoc()) {
     assert(R.getOpcode() == Instruction::Add &&
-           "The reduction instruction in MulAccRecipe must be Add");
+           "The reduction instruction in MulAccumulateReductionRecipe must be "
+           "Add");
   }
 
   /// Constructor for cloning VPMulAccumulateReductionRecipe.
@@ -2901,8 +2904,8 @@ class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
       : VPReductionRecipe(
             VPDef::VPMulAccumulateReductionSC,
             MulAcc->getRecurrenceDescriptor(),
-            ArrayRef<VPValue *>({MulAcc->getChainOp(), MulAcc->getOperand(1),
-                                 MulAcc->getOperand(2)}),
+            ArrayRef<VPValue *>({MulAcc->getChainOp(), MulAcc->getVecOp0(),
+                                 MulAcc->getVecOp1()}),
             MulAcc->getCondOp(), MulAcc->isOrdered(),
             MulAcc->hasNoUnsignedWrap(), MulAcc->hasNoSignedWrap(),
             MulAcc->getMulDebugLoc()),
@@ -2944,6 +2947,7 @@ class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
   /// Return if the operands of mul instruction come from same extend.
   bool isSameExtend() const { return getVecOp0() == getVecOp1(); }
 
+  /// Return the opcode of the underlying extend.
   Instruction::CastOps getExtOpcode() const { return ExtOp; }
 
   /// Return if the extend opcode is ZExt.

>From 1dc279efea6f86986791fae807b7d4b048d4b9d9 Mon Sep 17 00:00:00 2001
From: Elvis Wang <elvis.wang at sifive.com>
Date: Sun, 22 Dec 2024 21:32:30 -0800
Subject: [PATCH 25/26] fix-ReductionEVLRecipe query underlyingInstr().

---
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp |  2 +-
 .../LoopVectorize/RISCV/inloop-reduction.ll   | 34 +++++++------------
 .../RISCV/vplan-vp-intrinsics-reduction.ll    |  2 +-
 3 files changed, 14 insertions(+), 24 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 5e8d2da1f7e822..0fad6d0ff04ee0 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -2273,7 +2273,7 @@ void VPReductionEVLRecipe::print(raw_ostream &O, const Twine &Indent,
   O << " = ";
   getChainOp()->printAsOperand(O, SlotTracker);
   O << " +";
-  if (isa<FPMathOperator>(getUnderlyingInstr()))
+  if (isa_and_nonnull<FPMathOperator>(getUnderlyingValue()))
     O << getUnderlyingInstr()->getFastMathFlags();
   O << " vp.reduce." << Instruction::getOpcodeName(RdxDesc.getOpcode()) << " (";
   getVecOp()->printAsOperand(O, SlotTracker);
diff --git a/llvm/test/Transforms/LoopVectorize/RISCV/inloop-reduction.ll b/llvm/test/Transforms/LoopVectorize/RISCV/inloop-reduction.ll
index 5e8e2111903522..14818199072c28 100644
--- a/llvm/test/Transforms/LoopVectorize/RISCV/inloop-reduction.ll
+++ b/llvm/test/Transforms/LoopVectorize/RISCV/inloop-reduction.ll
@@ -187,36 +187,26 @@ define i32 @add_i16_i32(ptr nocapture readonly %x, i32 %n) {
 ; IF-EVL-INLOOP-NEXT:    [[N_RND_UP:%.*]] = add i32 [[N]], [[TMP2]]
 ; IF-EVL-INLOOP-NEXT:    [[N_MOD_VF:%.*]] = urem i32 [[N_RND_UP]], [[TMP1]]
 ; IF-EVL-INLOOP-NEXT:    [[N_VEC:%.*]] = sub i32 [[N_RND_UP]], [[N_MOD_VF]]
-; IF-EVL-INLOOP-NEXT:    [[TRIP_COUNT_MINUS_1:%.*]] = sub i32 [[N]], 1
 ; IF-EVL-INLOOP-NEXT:    [[TMP3:%.*]] = call i32 @llvm.vscale.i32()
 ; IF-EVL-INLOOP-NEXT:    [[TMP4:%.*]] = mul i32 [[TMP3]], 8
-; IF-EVL-INLOOP-NEXT:    [[BROADCAST_SPLATINSERT1:%.*]] = insertelement <vscale x 8 x i32> poison, i32 [[TRIP_COUNT_MINUS_1]], i64 0
-; IF-EVL-INLOOP-NEXT:    [[BROADCAST_SPLAT2:%.*]] = shufflevector <vscale x 8 x i32> [[BROADCAST_SPLATINSERT1]], <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer
 ; IF-EVL-INLOOP-NEXT:    br label [[VECTOR_BODY:%.*]]
 ; IF-EVL-INLOOP:       vector.body:
 ; IF-EVL-INLOOP-NEXT:    [[INDEX:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
 ; IF-EVL-INLOOP-NEXT:    [[EVL_BASED_IV:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[INDEX_EVL_NEXT:%.*]], [[VECTOR_BODY]] ]
 ; IF-EVL-INLOOP-NEXT:    [[VEC_PHI:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[TMP11:%.*]], [[VECTOR_BODY]] ]
-; IF-EVL-INLOOP-NEXT:    [[AVL:%.*]] = sub i32 [[N]], [[EVL_BASED_IV]]
-; IF-EVL-INLOOP-NEXT:    [[TMP5:%.*]] = call i32 @llvm.experimental.get.vector.length.i32(i32 [[AVL]], i32 8, i1 true)
-; IF-EVL-INLOOP-NEXT:    [[TMP6:%.*]] = add i32 [[EVL_BASED_IV]], 0
-; IF-EVL-INLOOP-NEXT:    [[BROADCAST_SPLATINSERT:%.*]] = insertelement <vscale x 8 x i32> poison, i32 [[EVL_BASED_IV]], i64 0
-; IF-EVL-INLOOP-NEXT:    [[BROADCAST_SPLAT:%.*]] = shufflevector <vscale x 8 x i32> [[BROADCAST_SPLATINSERT]], <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer
-; IF-EVL-INLOOP-NEXT:    [[TMP7:%.*]] = call <vscale x 8 x i32> @llvm.stepvector.nxv8i32()
-; IF-EVL-INLOOP-NEXT:    [[TMP8:%.*]] = add <vscale x 8 x i32> zeroinitializer, [[TMP7]]
-; IF-EVL-INLOOP-NEXT:    [[VEC_IV:%.*]] = add <vscale x 8 x i32> [[BROADCAST_SPLAT]], [[TMP8]]
-; IF-EVL-INLOOP-NEXT:    [[TMP9:%.*]] = icmp ule <vscale x 8 x i32> [[VEC_IV]], [[BROADCAST_SPLAT2]]
-; IF-EVL-INLOOP-NEXT:    [[TMP10:%.*]] = getelementptr inbounds i16, ptr [[X:%.*]], i32 [[TMP6]]
-; IF-EVL-INLOOP-NEXT:    [[TMP15:%.*]] = getelementptr inbounds i16, ptr [[TMP10]], i32 0
-; IF-EVL-INLOOP-NEXT:    [[VP_OP_LOAD:%.*]] = call <vscale x 8 x i16> @llvm.vp.load.nxv8i16.p0(ptr align 2 [[TMP15]], <vscale x 8 x i1> splat (i1 true), i32 [[TMP5]])
-; IF-EVL-INLOOP-NEXT:    [[TMP12:%.*]] = sext <vscale x 8 x i16> [[VP_OP_LOAD]] to <vscale x 8 x i32>
-; IF-EVL-INLOOP-NEXT:    [[TMP17:%.*]] = select <vscale x 8 x i1> [[TMP9]], <vscale x 8 x i32> [[TMP12]], <vscale x 8 x i32> zeroinitializer
-; IF-EVL-INLOOP-NEXT:    [[TMP14:%.*]] = call i32 @llvm.vector.reduce.add.nxv8i32(<vscale x 8 x i32> [[TMP17]])
-; IF-EVL-INLOOP-NEXT:    [[TMP11]] = add i32 [[TMP14]], [[VEC_PHI]]
-; IF-EVL-INLOOP-NEXT:    [[INDEX_EVL_NEXT]] = add nuw i32 [[TMP5]], [[EVL_BASED_IV]]
+; IF-EVL-INLOOP-NEXT:    [[TMP5:%.*]] = sub i32 [[N]], [[EVL_BASED_IV]]
+; IF-EVL-INLOOP-NEXT:    [[TMP6:%.*]] = call i32 @llvm.experimental.get.vector.length.i32(i32 [[TMP5]], i32 8, i1 true)
+; IF-EVL-INLOOP-NEXT:    [[TMP7:%.*]] = add i32 [[EVL_BASED_IV]], 0
+; IF-EVL-INLOOP-NEXT:    [[TMP8:%.*]] = getelementptr inbounds i16, ptr [[X:%.*]], i32 [[TMP7]]
+; IF-EVL-INLOOP-NEXT:    [[TMP9:%.*]] = getelementptr inbounds i16, ptr [[TMP8]], i32 0
+; IF-EVL-INLOOP-NEXT:    [[VP_OP_LOAD:%.*]] = call <vscale x 8 x i16> @llvm.vp.load.nxv8i16.p0(ptr align 2 [[TMP9]], <vscale x 8 x i1> splat (i1 true), i32 [[TMP6]])
+; IF-EVL-INLOOP-NEXT:    [[TMP14:%.*]] = call <vscale x 8 x i32> @llvm.vp.sext.nxv8i32.nxv8i16(<vscale x 8 x i16> [[VP_OP_LOAD]], <vscale x 8 x i1> splat (i1 true), i32 [[TMP6]])
+; IF-EVL-INLOOP-NEXT:    [[TMP10:%.*]] = call i32 @llvm.vp.reduce.add.nxv8i32(i32 0, <vscale x 8 x i32> [[TMP14]], <vscale x 8 x i1> splat (i1 true), i32 [[TMP6]])
+; IF-EVL-INLOOP-NEXT:    [[TMP11]] = add i32 [[TMP10]], [[VEC_PHI]]
+; IF-EVL-INLOOP-NEXT:    [[INDEX_EVL_NEXT]] = add nuw i32 [[TMP6]], [[EVL_BASED_IV]]
 ; IF-EVL-INLOOP-NEXT:    [[INDEX_NEXT]] = add nuw i32 [[INDEX]], [[TMP4]]
-; IF-EVL-INLOOP-NEXT:    [[TMP16:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]
-; IF-EVL-INLOOP-NEXT:    br i1 [[TMP16]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
+; IF-EVL-INLOOP-NEXT:    [[TMP12:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]
+; IF-EVL-INLOOP-NEXT:    br i1 [[TMP12]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
 ; IF-EVL-INLOOP:       middle.block:
 ; IF-EVL-INLOOP-NEXT:    br i1 true, label [[FOR_COND_CLEANUP_LOOPEXIT:%.*]], label [[SCALAR_PH]]
 ; IF-EVL-INLOOP:       scalar.ph:
diff --git a/llvm/test/Transforms/LoopVectorize/RISCV/vplan-vp-intrinsics-reduction.ll b/llvm/test/Transforms/LoopVectorize/RISCV/vplan-vp-intrinsics-reduction.ll
index a474d926a6303f..443b52173fdc85 100644
--- a/llvm/test/Transforms/LoopVectorize/RISCV/vplan-vp-intrinsics-reduction.ll
+++ b/llvm/test/Transforms/LoopVectorize/RISCV/vplan-vp-intrinsics-reduction.ll
@@ -29,7 +29,7 @@ define i32 @reduction(ptr %a, i64 %n, i32 %start) {
 ; IF-EVL-OUTLOOP-NEXT: Live-in vp<[[VTC:%[0-9]+]]> = vector-trip-count
 ; IF-EVL-OUTLOOP-NEXT: Live-in ir<%n> = original trip-count
 ; IF-EVL-OUTLOOP-EMPTY:
-; IF-EVL-OUTLOOP:      vector.ph:
+; IF-EVL-OUTLOOP-NEXT: vector.ph:
 ; IF-EVL-OUTLOOP-NEXT: Successor(s): vector loop
 ; IF-EVL-OUTLOOP-EMPTY:
 ; IF-EVL-OUTLOOP-NEXT: <x1> vector loop: {

>From 90f9ffab9a0cae703b81181b51994b02894fe090 Mon Sep 17 00:00:00 2001
From: Elvis Wang <elvis.wang at sifive.com>
Date: Mon, 23 Dec 2024 00:33:24 -0800
Subject: [PATCH 26/26] Update after merge.

---
 llvm/lib/Transforms/Vectorize/LoopVectorize.cpp      | 2 +-
 llvm/test/Transforms/LoopVectorize/vplan-printing.ll | 9 +++++++++
 2 files changed, 10 insertions(+), 1 deletion(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index e107a536a38e8d..a5042753e22d1e 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -9561,7 +9561,7 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
       bool IsOrderedRed = CM.useOrderedReductions(RdxDesc);
       // TODO: Remove EVL check when we support EVL version of
       // VPExtendedReductionRecipe and VPMulAccumulateReductionRecipe.
-      if (CM.foldTailWithEVL()) {
+      if (ForceTailFoldingStyle == TailFoldingStyle::DataWithEVL) {
         RedRecipe = new VPReductionRecipe(RdxDesc, CurrentLinkI, PreviousLink,
                                           VecOp, CondOp, IsOrderedRed);
       } else {
diff --git a/llvm/test/Transforms/LoopVectorize/vplan-printing.ll b/llvm/test/Transforms/LoopVectorize/vplan-printing.ll
index bb20702eef9620..e24e19bc6ecb12 100644
--- a/llvm/test/Transforms/LoopVectorize/vplan-printing.ll
+++ b/llvm/test/Transforms/LoopVectorize/vplan-printing.ll
@@ -1190,6 +1190,9 @@ define i64 @print_extended_reduction(ptr nocapture readonly %x, ptr nocapture re
 ; CHECK-NEXT: Live-in vp<%1> = vector-trip-count
 ; CHECK-NEXT: Live-in ir<%n> = original trip-count
 ; CHECK-EMPTY:
+; CHECK-NEXT: ir-bb<for.body.preheader>:
+; CHECK-NEXT: Successor(s): vector.ph
+; CHECK-EMPTY:
 ; CHECK-NEXT: vector.ph:
 ; CHECK-NEXT: Successor(s): vector loop
 ; CHECK-EMPTY:
@@ -1262,6 +1265,9 @@ define i64 @print_mulacc(ptr nocapture readonly %x, ptr nocapture readonly %y, i
 ; CHECK-NEXT: Live-in vp<%1> = vector-trip-count
 ; CHECK-NEXT: Live-in ir<%n> = original trip-count
 ; CHECK-EMPTY:
+; CHECK-NEXT: ir-bb<for.body.preheader>:
+; CHECK-NEXT: Successor(s): vector.ph
+; CHECK-EMPTY:
 ; CHECK-NEXT: vector.ph:
 ; CHECK-NEXT: Successor(s): vector loop
 ; CHECK-EMPTY:
@@ -1341,6 +1347,9 @@ define i64 @print_mulacc_extended(ptr nocapture readonly %x, ptr nocapture reado
 ; CHECK-NEXT: Live-in vp<%1> = vector-trip-count
 ; CHECK-NEXT: Live-in ir<%n> = original trip-count
 ; CHECK-EMPTY:
+; CHECK-NEXT: ir-bb<for.body.preheader>:
+; CHECK-NEXT: Successor(s): vector.ph
+; CHECK-EMPTY:
 ; CHECK-NEXT: vector.ph:
 ; CHECK-NEXT: Successor(s): vector loop
 ; CHECK-EMPTY:



More information about the llvm-commits mailing list