[llvm] [VPlan] Impl VPlan-based pattern match for ExtendedRed and MulAccRed (NFCI) (PR #113903)

Elvis Wang via llvm-commits llvm-commits at lists.llvm.org
Wed Nov 6 21:30:00 PST 2024


https://github.com/ElvisWang123 updated https://github.com/llvm/llvm-project/pull/113903

>From bed08b201fbc5604d9775d6446bae1e85c56422b Mon Sep 17 00:00:00 2001
From: Elvis Wang <elvis.wang at sifive.com>
Date: Mon, 28 Oct 2024 05:39:35 -0700
Subject: [PATCH 1/6] [VPlan] Impl VPlan-based pattern match for ExtendedRed
 and MulAccRed. NFCI

This patch implements VPlan-based pattern matching for ExtendedReduction
and MulAccReduction. In these reduction patterns, the extend and mul
instructions can be folded into the reduction instruction, so their cost
is free.

We add `FoldedRecipes` to `VPCostContext` to record recipes that are
folded into other recipes.

ExtendedReductionPatterns:
    reduce(ext(...))
MulAccReductionPatterns:
    reduce.add(mul(...))
    reduce.add(mul(ext(...), ext(...)))
    reduce.add(ext(mul(...)))
    reduce.add(ext(mul(ext(...), ext(...))))

Ref: Original instruction based implementation:
https://reviews.llvm.org/D93476
---
 .../Transforms/Vectorize/LoopVectorize.cpp    |  45 ------
 llvm/lib/Transforms/Vectorize/VPlan.h         |   2 +
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp | 139 ++++++++++++++++--
 3 files changed, 129 insertions(+), 57 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index c07af8519049c4..865dda8f5952cf 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7306,51 +7306,6 @@ LoopVectorizationPlanner::precomputeCosts(VPlan &Plan, ElementCount VF,
       Cost += ReductionCost;
       continue;
     }
-
-    const auto &ChainOps = RdxDesc.getReductionOpChain(RedPhi, OrigLoop);
-    SetVector<Instruction *> ChainOpsAndOperands(ChainOps.begin(),
-                                                 ChainOps.end());
-    auto IsZExtOrSExt = [](const unsigned Opcode) -> bool {
-      return Opcode == Instruction::ZExt || Opcode == Instruction::SExt;
-    };
-    // Also include the operands of instructions in the chain, as the cost-model
-    // may mark extends as free.
-    //
-    // For ARM, some of the instruction can folded into the reducion
-    // instruction. So we need to mark all folded instructions free.
-    // For example: We can fold reduce(mul(ext(A), ext(B))) into one
-    // instruction.
-    for (auto *ChainOp : ChainOps) {
-      for (Value *Op : ChainOp->operands()) {
-        if (auto *I = dyn_cast<Instruction>(Op)) {
-          ChainOpsAndOperands.insert(I);
-          if (I->getOpcode() == Instruction::Mul) {
-            auto *Ext0 = dyn_cast<Instruction>(I->getOperand(0));
-            auto *Ext1 = dyn_cast<Instruction>(I->getOperand(1));
-            if (Ext0 && IsZExtOrSExt(Ext0->getOpcode()) && Ext1 &&
-                Ext0->getOpcode() == Ext1->getOpcode()) {
-              ChainOpsAndOperands.insert(Ext0);
-              ChainOpsAndOperands.insert(Ext1);
-            }
-          }
-        }
-      }
-    }
-
-    // Pre-compute the cost for I, if it has a reduction pattern cost.
-    for (Instruction *I : ChainOpsAndOperands) {
-      auto ReductionCost = CM.getReductionPatternCost(
-          I, VF, ToVectorTy(I->getType(), VF), TTI::TCK_RecipThroughput);
-      if (!ReductionCost)
-        continue;
-
-      assert(!CostCtx.SkipCostComputation.contains(I) &&
-             "reduction op visited multiple times");
-      CostCtx.SkipCostComputation.insert(I);
-      LLVM_DEBUG(dbgs() << "Cost of " << ReductionCost << " for VF " << VF
-                        << ":\n in-loop reduction " << *I << "\n");
-      Cost += *ReductionCost;
-    }
   }
 
   // Pre-compute the costs for branches except for the backedge, as the number
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 18f5f13073aa63..b57ea3b1bd143b 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -682,6 +682,8 @@ struct VPCostContext {
   LLVMContext &LLVMCtx;
   LoopVectorizationCostModel &CM;
   SmallPtrSet<Instruction *, 8> SkipCostComputation;
+  /// Contains recipes that are folded into other recipes.
+  SmallDenseMap<ElementCount, SmallPtrSet<VPRecipeBase *, 4>, 4> FoldedRecipes;
 
   VPCostContext(const TargetTransformInfo &TTI, const TargetLibraryInfo &TLI,
                 Type *CanIVTy, LoopVectorizationCostModel &CM)
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 6254ea15191819..e3ec72d0aca6ce 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -270,7 +270,9 @@ InstructionCost VPRecipeBase::cost(ElementCount VF, VPCostContext &Ctx) {
     UI = &WidenMem->getIngredient();
 
   InstructionCost RecipeCost;
-  if (UI && Ctx.skipCostComputation(UI, VF.isVector())) {
+  if ((UI && Ctx.skipCostComputation(UI, VF.isVector())) ||
+      (Ctx.FoldedRecipes.contains(VF) &&
+       Ctx.FoldedRecipes.at(VF).contains(this))) {
     RecipeCost = 0;
   } else {
     RecipeCost = computeCost(VF, Ctx);
@@ -2185,30 +2187,143 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
   TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
   unsigned Opcode = RdxDesc.getOpcode();
 
-  // TODO: Support any-of and in-loop reductions.
+  // TODO: Support any-of reductions.
   assert(
       (!RecurrenceDescriptor::isAnyOfRecurrenceKind(RdxKind) ||
        ForceTargetInstructionCost.getNumOccurrences() > 0) &&
       "Any-of reduction not implemented in VPlan-based cost model currently.");
-  assert(
-      (!cast<VPReductionPHIRecipe>(getOperand(0))->isInLoop() ||
-       ForceTargetInstructionCost.getNumOccurrences() > 0) &&
-      "In-loop reduction not implemented in VPlan-based cost model currently.");
 
   assert(ElementTy->getTypeID() == RdxDesc.getRecurrenceType()->getTypeID() &&
          "Inferred type and recurrence type mismatch.");
 
-  // Cost = Reduction cost + BinOp cost
-  InstructionCost Cost =
+  // BaseCost = Reduction cost + BinOp cost
+  InstructionCost BaseCost =
       Ctx.TTI.getArithmeticInstrCost(Opcode, ElementTy, CostKind);
   if (RecurrenceDescriptor::isMinMaxRecurrenceKind(RdxKind)) {
     Intrinsic::ID Id = getMinMaxReductionIntrinsicOp(RdxKind);
-    return Cost + Ctx.TTI.getMinMaxReductionCost(
-                      Id, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+    BaseCost += Ctx.TTI.getMinMaxReductionCost(
+        Id, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+  } else {
+    BaseCost += Ctx.TTI.getArithmeticReductionCost(
+        Opcode, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
   }
 
-  return Cost + Ctx.TTI.getArithmeticReductionCost(
-                    Opcode, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+  using namespace llvm::VPlanPatternMatch;
+  // Cost of reduce.add([ext](mul([ext] A, [ext] B))) when lowered to a single
+  // multiply-accumulate reduction. On success the folded mul/ext recipes are
+  // recorded in Ctx.FoldedRecipes[VF]; otherwise an invalid cost is returned.
+  auto GetMulAccReductionCost =
+      [&](const VPReductionRecipe *Red) -> InstructionCost {
+    VPValue *A = nullptr, *B = nullptr;
+    InstructionCost InnerExt0Cost = 0;
+    InstructionCost InnerExt1Cost = 0;
+    InstructionCost ExtCost = 0;
+    InstructionCost MulCost = 0;

+    VectorType *SrcVecTy = VectorTy;
+    bool IsUnsigned = true;
+    bool HasOuterExt = false;

+    auto *Ext = dyn_cast_if_present<VPWidenCastRecipe>(
+        Red->getVecOp()->getDefiningRecipe());
+    VPRecipeBase *Mul;
+    // Try to match outer extend reduce.add(ext(...))
+    if (Ext && match(Ext, m_ZExtOrSExt(m_VPValue())) &&
+        Ext->getNumUsers() == 1) {
+      IsUnsigned = Ext->getOpcode() == Instruction::CastOps::ZExt;
+      ExtCost = Ext->computeCost(VF, Ctx);
+      Mul = Ext->getOperand(0)->getDefiningRecipe();
+      HasOuterExt = true;
+    } else {
+      Mul = Red->getVecOp()->getDefiningRecipe();
+    }

+    // Match reduce.add(mul()). Only a widened mul can be folded, so check the
+    // recipe kind with dyn_cast instead of asserting it with cast<>.
+    auto *MulR = dyn_cast_if_present<VPWidenRecipe>(Mul);
+    if (MulR && match(Mul, m_Mul(m_VPValue(A), m_VPValue(B))) &&
+        MulR->getNumUsers() == 1) {
+      MulCost = MulR->computeCost(VF, Ctx);
+      auto *InnerExt0 =
+          dyn_cast_if_present<VPWidenCastRecipe>(A->getDefiningRecipe());
+      auto *InnerExt1 =
+          dyn_cast_if_present<VPWidenCastRecipe>(B->getDefiningRecipe());
+      bool HasInnerExt = false;
+      // Try to match inner extends.
+      if (InnerExt0 && InnerExt1 &&
+          match(InnerExt0, m_ZExtOrSExt(m_VPValue())) &&
+          match(InnerExt1, m_ZExtOrSExt(m_VPValue())) &&
+          InnerExt0->getOpcode() == InnerExt1->getOpcode() &&
+          (InnerExt0->getNumUsers() > 0 &&
+           !InnerExt0->hasMoreThanOneUniqueUser()) &&
+          (InnerExt1->getNumUsers() > 0 &&
+           !InnerExt1->hasMoreThanOneUniqueUser())) {
+        InnerExt0Cost = InnerExt0->computeCost(VF, Ctx);
+        InnerExt1Cost = InnerExt1->computeCost(VF, Ctx);
+        Type *Ty0 = Ctx.Types.inferScalarType(InnerExt0->getOperand(0));
+        Type *Ty1 = Ctx.Types.inferScalarType(InnerExt1->getOperand(0));
+        Type *MaxInnerExtTy =
+            Ty0->getIntegerBitWidth() > Ty1->getIntegerBitWidth() ? Ty0 : Ty1;
+        SrcVecTy = cast<VectorType>(ToVectorTy(MaxInnerExtTy, VF));
+        // Signedness follows the (identical) inner extend opcode, not a constant.
+        IsUnsigned = InnerExt0->getOpcode() == Instruction::CastOps::ZExt;
+        HasInnerExt = true;
+      }
+      InstructionCost MulAccRedCost = Ctx.TTI.getMulAccReductionCost(
+          IsUnsigned, ElementTy, SrcVecTy, CostKind);
+      // Check if folding ext/mul into MulAccReduction is profitable.
+      if (MulAccRedCost.isValid() &&
+          MulAccRedCost <
+              ExtCost + MulCost + InnerExt0Cost + InnerExt1Cost + BaseCost) {
+        if (HasInnerExt) {
+          Ctx.FoldedRecipes[VF].insert(InnerExt0);
+          Ctx.FoldedRecipes[VF].insert(InnerExt1);
+        }
+        Ctx.FoldedRecipes[VF].insert(Mul);
+        if (HasOuterExt)
+          Ctx.FoldedRecipes[VF].insert(Ext);
+        return MulAccRedCost;
+      }
+    }
+    return InstructionCost::getInvalid();
+  };
+
+  // Cost of reduce(ext(A)) lowered to one extended reduction. Returns an
+  // invalid cost when the pattern does not match or folding is not cheaper.
+  auto GetExtendedReductionCost =
+      [&](const VPReductionRecipe *Red) -> InstructionCost {
+    VPValue *Src = nullptr;
+    VPValue *Op = Red->getVecOp();
+    // Bail out unless the operand is a zext/sext used only by this reduction.
+    if (!match(Op, m_ZExtOrSExt(m_VPValue(Src))) || Op->getNumUsers() != 1)
+      return InstructionCost::getInvalid();
+    auto *Ext = cast<VPWidenCastRecipe>(Op->getDefiningRecipe());
+    const bool IsUnsigned = Ext->getOpcode() == Instruction::CastOps::ZExt;
+    const InstructionCost ExtCost = Ext->computeCost(VF, Ctx);
+    auto *SrcVecTy =
+        cast<VectorType>(ToVectorTy(Ctx.Types.inferScalarType(Src), VF));
+    const InstructionCost FoldedCost = Ctx.TTI.getExtendedReductionCost(
+        Opcode, IsUnsigned, ElementTy, SrcVecTy, RdxDesc.getFastMathFlags(),
+        CostKind);
+    // Fold only if strictly cheaper than the extend plus the plain reduction.
+    if (!FoldedCost.isValid() || !(FoldedCost < ExtCost + BaseCost))
+      return InstructionCost::getInvalid();
+    Ctx.FoldedRecipes[VF].insert(Ext);
+    return FoldedCost;
+  };
+
+  // Match MulAccReduction patterns.
+  InstructionCost MulAccCost = GetMulAccReductionCost(this);
+  if (MulAccCost.isValid())
+    return MulAccCost;
+
+  // Match ExtendedReduction patterns.
+  InstructionCost ExtendedCost = GetExtendedReductionCost(this);
+  if (ExtendedCost.isValid())
+    return ExtendedCost;
+
+  // Default cost.
+  return BaseCost;
 }
 
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)

>From 80ab0a6e2d0a0a16635d7a9c39c2c8bb068df5b3 Mon Sep 17 00:00:00 2001
From: Elvis Wang <elvis.wang at sifive.com>
Date: Sun, 3 Nov 2024 18:55:55 -0800
Subject: [PATCH 2/6] Partially support Extended-reduction.

---
 .../Transforms/Vectorize/LoopVectorize.cpp    |  24 +-
 llvm/lib/Transforms/Vectorize/VPlan.h         | 217 ++++++++++++++++++
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp | 114 +++++++--
 .../Transforms/Vectorize/VPlanTransforms.cpp  |  24 ++
 .../Transforms/Vectorize/VPlanTransforms.h    |   3 +
 llvm/lib/Transforms/Vectorize/VPlanValue.h    |   2 +
 6 files changed, 359 insertions(+), 25 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 865dda8f5952cf..e44975cffbb124 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7660,6 +7660,10 @@ DenseMap<const SCEV *, Value *> LoopVectorizationPlanner::executePlan(
                              ILV.getOrCreateVectorTripCount(nullptr),
                              CanonicalIVStartValue, State);
 
+  // TODO: Rebase to fhahn's implementation.
+  VPlanTransforms::prepareExecute(BestVPlan);
+  dbgs() << "\n\n print plan\n";
+  BestVPlan.print(dbgs());
   BestVPlan.execute(&State);
 
   // 2.5 Collect reduction resume values.
@@ -9254,6 +9258,7 @@ VPlanPtr LoopVectorizationPlanner::buildVPlan(VFRange &Range) {
 // Adjust AnyOf reductions; replace the reduction phi for the selected value
 // with a boolean reduction phi node to check if the condition is true in any
 // iteration. The final value is selected by the final ComputeReductionResult.
+// TODO: Implement VPMulAcc here.
 void LoopVectorizationPlanner::adjustRecipesForReductions(
     VPlanPtr &Plan, VPRecipeBuilder &RecipeBuilder, ElementCount MinVF) {
   using namespace VPlanPatternMatch;
@@ -9372,9 +9377,22 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
       if (CM.blockNeedsPredicationForAnyReason(BB))
         CondOp = RecipeBuilder.getBlockInMask(BB);
 
-      VPReductionRecipe *RedRecipe =
-          new VPReductionRecipe(RdxDesc, CurrentLinkI, PreviousLink, VecOp,
-                                CondOp, CM.useOrderedReductions(RdxDesc));
+      // VPWidenCastRecipes can be folded into VPReductionRecipe
+      VPValue *A;
+      VPSingleDefRecipe *RedRecipe;
+      if (match(VecOp, m_ZExtOrSExt(m_VPValue(A))) &&
+          !VecOp->hasMoreThanOneUniqueUser()) {
+        RedRecipe = new VPExtendedReductionRecipe(
+            RdxDesc, CurrentLinkI,
+            cast<CastInst>(
+                cast<VPWidenCastRecipe>(VecOp)->getUnderlyingInstr()),
+            PreviousLink, A, CondOp, CM.useOrderedReductions(RdxDesc),
+            cast<VPWidenCastRecipe>(VecOp)->getResultType());
+      } else {
+        RedRecipe =
+            new VPReductionRecipe(RdxDesc, CurrentLinkI, PreviousLink, VecOp,
+                                  CondOp, CM.useOrderedReductions(RdxDesc));
+      }
       // Append the recipe to the end of the VPBasicBlock because we need to
       // ensure that it comes after all of it's inputs, including CondOp.
       // Note that this transformation may leave over dead recipes (including
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index b57ea3b1bd143b..176318ed0c5b1d 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -859,6 +859,8 @@ class VPSingleDefRecipe : public VPRecipeBase, public VPValue {
     case VPRecipeBase::VPInstructionSC:
     case VPRecipeBase::VPReductionEVLSC:
     case VPRecipeBase::VPReductionSC:
+    case VPRecipeBase::VPMulAccSC:
+    case VPRecipeBase::VPExtendedReductionSC:
     case VPRecipeBase::VPReplicateSC:
     case VPRecipeBase::VPScalarIVStepsSC:
     case VPRecipeBase::VPVectorPointerSC:
@@ -2655,6 +2657,221 @@ class VPReductionEVLRecipe : public VPReductionRecipe {
   }
 };
 
+/// A recipe to represent in-loop extended reduction operations: it extends a
+/// vector operand, reduces it into a scalar value, and adds the result to a
+/// chain. This is a high-level abstract recipe that is expected to be replaced
+/// by a VPWidenCastRecipe and a VPReductionRecipe before execution. The
+/// operands are {ChainOp, VecOp, [Condition]}.
+class VPExtendedReductionRecipe : public VPSingleDefRecipe {
+  /// The recurrence descriptor for the reduction in question.
+  const RecurrenceDescriptor &RdxDesc;
+  bool IsOrdered; // Whether the in-loop reduction is ordered.
+  /// Whether the reduction is conditional.
+  bool IsConditional = false;
+  /// Type after extend.
+  Type *ResultTy;
+  Instruction::CastOps ExtOp; // Opcode of the folded extend.
+  CastInst *CastInstr; // Underlying IR cast instruction of the folded extend.
+  bool IsZExt; // True if ExtOp is ZExt (set in the constructor body).

+protected:
+  VPExtendedReductionRecipe(const unsigned char SC,
+                            const RecurrenceDescriptor &R, Instruction *RedI,
+                            Instruction::CastOps ExtOp, CastInst *CastI,
+                            ArrayRef<VPValue *> Operands, VPValue *CondOp,
+                            bool IsOrdered, Type *ResultTy)
+      : VPSingleDefRecipe(SC, Operands, RedI), RdxDesc(R), IsOrdered(IsOrdered),
+        ResultTy(ResultTy), ExtOp(ExtOp), CastInstr(CastI) {
+    if (CondOp) {
+      IsConditional = true;
+      addOperand(CondOp);
+    }
+    IsZExt = ExtOp == Instruction::CastOps::ZExt;
+  }

+public:
+  VPExtendedReductionRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
+                            CastInst *CastI, VPValue *ChainOp, VPValue *VecOp,
+                            VPValue *CondOp, bool IsOrdered, Type *ResultTy)
+      : VPExtendedReductionRecipe(VPDef::VPExtendedReductionSC, R, RedI,
+                                  CastI->getOpcode(), CastI,
+                                  ArrayRef<VPValue *>({ChainOp, VecOp}), CondOp,
+                                  IsOrdered, ResultTy) {}

+  VPExtendedReductionRecipe(VPReductionRecipe *Red, VPWidenCastRecipe *Ext)
+      : VPExtendedReductionRecipe(
+            VPDef::VPExtendedReductionSC, Red->getRecurrenceDescriptor(),
+            Red->getUnderlyingInstr(), Ext->getOpcode(),
+            cast<CastInst>(Ext->getUnderlyingInstr()),
+            ArrayRef<VPValue *>({Red->getChainOp(), Ext->getOperand(0)}),
+            Red->getCondOp(), Red->isOrdered(), Ext->getResultType()) {}

+  ~VPExtendedReductionRecipe() override = default;

+  VPExtendedReductionRecipe *clone() override {
+    llvm_unreachable("Not implement yet");
+  }

+  static inline bool classof(const VPRecipeBase *R) {
+    return R->getVPDefID() == VPDef::VPExtendedReductionSC;
+  }

+  static inline bool classof(const VPUser *U) {
+    auto *R = dyn_cast<VPRecipeBase>(U);
+    return R && classof(R);
+  }

+  /// Not executable: the recipe must be lowered before execution is reached.
+  void execute(VPTransformState &State) override {
+    llvm_unreachable("VPExtendedReductionRecipe should be transform to "
+                     "VPExtendedRecipe + VPReductionRecipe before execution.");
+  };

+  /// Return the cost of VPExtendedReductionRecipe.
+  InstructionCost computeCost(ElementCount VF,
+                              VPCostContext &Ctx) const override;

+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+  /// Print the recipe.
+  void print(raw_ostream &O, const Twine &Indent,
+             VPSlotTracker &SlotTracker) const override;
+#endif

+  /// Return the recurrence descriptor for the in-loop reduction.
+  const RecurrenceDescriptor &getRecurrenceDescriptor() const {
+    return RdxDesc;
+  }
+  /// Return true if the in-loop reduction is ordered.
+  bool isOrdered() const { return IsOrdered; };
+  /// Return true if the in-loop reduction is conditional.
+  bool isConditional() const { return IsConditional; };
+  /// The VPValue of the scalar Chain being accumulated.
+  VPValue *getChainOp() const { return getOperand(0); }
+  /// The VPValue of the vector value to be extended and reduced.
+  VPValue *getVecOp() const { return getOperand(1); }
+  /// The VPValue of the condition for the block.
+  VPValue *getCondOp() const {
+    return isConditional() ? getOperand(getNumOperands() - 1) : nullptr;
+  }
+  Type *getResultType() const { return ResultTy; };
+  Instruction::CastOps getExtOpcode() const { return ExtOp; };
+  CastInst *getExtInstr() const { return CastInstr; };
+};
+
+/// A recipe to represent in-loop multiply-accumulate reduction operations of
+/// the form reduce.add(OuterExt(mul(InnerExt(VecOp1), InnerExt(VecOp2)))),
+/// adding the result to a chain. This high-level abstract recipe is meant to
+/// be expanded to VPWidenCastRecipe, VPWidenRecipe (mul) and VPReductionRecipe
+/// before execution. The operands are {ChainOp, VecOp1, VecOp2, [Condition]}.
+class VPMulAccRecipe : public VPSingleDefRecipe {
+  /// The recurrence descriptor for the reduction in question.
+  const RecurrenceDescriptor &RdxDesc;
+  bool IsOrdered; // Whether the in-loop reduction is ordered.
+  /// Whether the reduction is conditional.
+  bool IsConditional = false;
+  /// Type after extend.
+  Type *ResultTy;
+  /// Type for mul.
+  Type *MulTy;
+  /// reduce.add(OuterExt(mul(InnerExt(), InnerExt())))
+  Instruction::CastOps OuterExtOp;
+  Instruction::CastOps InnerExtOp;

+  Instruction *MulI;
+  Instruction *OuterExtI;
+  Instruction *InnerExt0I;
+  Instruction *InnerExt1I;

+protected:
+  VPMulAccRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
+                 Instruction *RedI, Instruction::CastOps OuterExtOp,
+                 Instruction *OuterExtI, Instruction *MulI,
+                 Instruction::CastOps InnerExtOp, Instruction *InnerExt0I,
+                 Instruction *InnerExt1I, ArrayRef<VPValue *> Operands,
+                 VPValue *CondOp, bool IsOrdered, Type *ResultTy, Type *MulTy)
+      : VPSingleDefRecipe(SC, Operands, RedI), RdxDesc(R), IsOrdered(IsOrdered),
+        ResultTy(ResultTy), MulTy(MulTy), OuterExtOp(OuterExtOp),
+        InnerExtOp(InnerExtOp), MulI(MulI), OuterExtI(OuterExtI),
+        InnerExt0I(InnerExt0I), InnerExt1I(InnerExt1I) {
+    if (CondOp) {
+      IsConditional = true;
+      addOperand(CondOp);
+    }
+  }

+public:
+  VPMulAccRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
+                 Instruction *OuterExt, Instruction *Mul,
+                 Instruction *InnerExt0, Instruction *InnerExt1,
+                 VPValue *ChainOp, VPValue *InnerExt0Op, VPValue *InnerExt1Op,
+                 VPValue *CondOp, bool IsOrdered, Type *ResultTy, Type *MulTy)
+      : VPMulAccRecipe(
+            VPDef::VPMulAccSC, R, RedI, cast<CastInst>(OuterExt)->getOpcode(),
+            OuterExt, Mul, cast<CastInst>(InnerExt0)->getOpcode(), InnerExt0,
+            InnerExt1, ArrayRef<VPValue *>({ChainOp, InnerExt0Op, InnerExt1Op}),
+            CondOp, IsOrdered, ResultTy, MulTy) {}

+  VPMulAccRecipe(VPReductionRecipe *Red, VPWidenCastRecipe *OuterExt,
+                 VPWidenRecipe *Mul, VPWidenCastRecipe *InnerExt0,
+                 VPWidenCastRecipe *InnerExt1)
+      : VPMulAccRecipe(
+            VPDef::VPMulAccSC, Red->getRecurrenceDescriptor(),
+            Red->getUnderlyingInstr(), OuterExt->getOpcode(),
+            OuterExt->getUnderlyingInstr(), Mul->getUnderlyingInstr(),
+            InnerExt0->getOpcode(), InnerExt0->getUnderlyingInstr(),
+            InnerExt1->getUnderlyingInstr(),
+            ArrayRef<VPValue *>({Red->getChainOp(), InnerExt0->getOperand(0),
+                                 InnerExt1->getOperand(0)}),
+            Red->getCondOp(), Red->isOrdered(), OuterExt->getResultType(),
+            InnerExt0->getResultType()) {}

+  ~VPMulAccRecipe() override = default;

+  VPMulAccRecipe *clone() override { llvm_unreachable("Not implement yet"); }

+  static inline bool classof(const VPRecipeBase *R) {
+    return R->getVPDefID() == VPRecipeBase::VPMulAccSC;
+  }

+  static inline bool classof(const VPUser *U) {
+    auto *R = dyn_cast<VPRecipeBase>(U);
+    return R && classof(R);
+  }

+  /// Generate the reduction in the loop
+  void execute(VPTransformState &State) override;

+  /// Return the cost of VPMulAccRecipe.
+  InstructionCost computeCost(ElementCount VF,
+                              VPCostContext &Ctx) const override;

+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+  /// Print the recipe.
+  void print(raw_ostream &O, const Twine &Indent,
+             VPSlotTracker &SlotTracker) const override;
+#endif

+  /// Return the recurrence descriptor for the in-loop reduction.
+  const RecurrenceDescriptor &getRecurrenceDescriptor() const {
+    return RdxDesc;
+  }
+  /// Return true if the in-loop reduction is ordered.
+  bool isOrdered() const { return IsOrdered; };
+  /// Return true if the in-loop reduction is conditional.
+  bool isConditional() const { return IsConditional; };
+  /// The VPValue of the scalar Chain being accumulated.
+  VPValue *getChainOp() const { return getOperand(0); }
+  /// The VPValue of the first vector operand (VecOp1) of the multiply.
+  VPValue *getVecOp() const { return getOperand(1); }
+  /// The VPValue of the condition for the block.
+  VPValue *getCondOp() const {
+    return isConditional() ? getOperand(getNumOperands() - 1) : nullptr;
+  }
+  Type *getResultTy() const { return ResultTy; };
+  Instruction::CastOps getOuterExtOpcode() const { return OuterExtOp; };
+  Instruction::CastOps getInnerExtOpcode() const { return InnerExtOp; };
+};
+
 /// VPReplicateRecipe replicates a given instruction producing multiple scalar
 /// copies of the original scalar type, one per lane, instead of producing a
 /// single copy of widened type for all lanes. If the instruction is known to be
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index e3ec72d0aca6ce..dc4fe31d10ab0e 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -1488,6 +1488,27 @@ void VPWidenCastRecipe::execute(VPTransformState &State) {
   State.addMetadata(Cast, cast_or_null<Instruction>(getUnderlyingValue()));
 }
 
+// Maps recipe \p R, the producer or user of a cast's value, to a cast hint.
+static TTI::CastContextHint computeCCH(const VPRecipeBase *R, ElementCount VF) {
+  using CCH = TTI::CastContextHint;
+  if (VF.isScalar())
+    return CCH::Normal;
+  if (isa<VPInterleaveRecipe>(R))
+    return CCH::Interleave;
+  if (const auto *Rep = dyn_cast<VPReplicateRecipe>(R))
+    return Rep->isPredicated() ? CCH::Masked : CCH::Normal;
+  if (const auto *Mem = dyn_cast<VPWidenMemoryRecipe>(R)) {
+    // Checked in priority order: gather/scatter, then reverse, then masked.
+    if (!Mem->isConsecutive())
+      return CCH::GatherScatter;
+    if (Mem->isReverse())
+      return CCH::Reversed;
+    return Mem->isMasked() ? CCH::Masked : CCH::Normal;
+  }
+  // Not an interleave, replicate or widened-memory recipe: no context.
+  return CCH::None;
+}
+
 InstructionCost VPWidenCastRecipe::computeCost(ElementCount VF,
                                                VPCostContext &Ctx) const {
   // TODO: In some cases, VPWidenCastRecipes are created but not considered in
@@ -1495,26 +1516,6 @@ InstructionCost VPWidenCastRecipe::computeCost(ElementCount VF,
   // reduction in a smaller type.
   if (!getUnderlyingValue())
     return 0;
-  // Computes the CastContextHint from a recipes that may access memory.
-  auto ComputeCCH = [&](const VPRecipeBase *R) -> TTI::CastContextHint {
-    if (VF.isScalar())
-      return TTI::CastContextHint::Normal;
-    if (isa<VPInterleaveRecipe>(R))
-      return TTI::CastContextHint::Interleave;
-    if (const auto *ReplicateRecipe = dyn_cast<VPReplicateRecipe>(R))
-      return ReplicateRecipe->isPredicated() ? TTI::CastContextHint::Masked
-                                             : TTI::CastContextHint::Normal;
-    const auto *WidenMemoryRecipe = dyn_cast<VPWidenMemoryRecipe>(R);
-    if (WidenMemoryRecipe == nullptr)
-      return TTI::CastContextHint::None;
-    if (!WidenMemoryRecipe->isConsecutive())
-      return TTI::CastContextHint::GatherScatter;
-    if (WidenMemoryRecipe->isReverse())
-      return TTI::CastContextHint::Reversed;
-    if (WidenMemoryRecipe->isMasked())
-      return TTI::CastContextHint::Masked;
-    return TTI::CastContextHint::Normal;
-  };
 
   VPValue *Operand = getOperand(0);
   TTI::CastContextHint CCH = TTI::CastContextHint::None;
@@ -1522,7 +1523,7 @@ InstructionCost VPWidenCastRecipe::computeCost(ElementCount VF,
   if ((Opcode == Instruction::Trunc || Opcode == Instruction::FPTrunc) &&
       !hasMoreThanOneUniqueUser() && getNumUsers() > 0) {
     if (auto *StoreRecipe = dyn_cast<VPRecipeBase>(*user_begin()))
-      CCH = ComputeCCH(StoreRecipe);
+      CCH = computeCCH(StoreRecipe, VF);
   }
   // For Z/Sext, get the context from the operand.
   else if (Opcode == Instruction::ZExt || Opcode == Instruction::SExt ||
@@ -1530,7 +1531,7 @@ InstructionCost VPWidenCastRecipe::computeCost(ElementCount VF,
     if (Operand->isLiveIn())
       CCH = TTI::CastContextHint::Normal;
     else if (Operand->getDefiningRecipe())
-      CCH = ComputeCCH(Operand->getDefiningRecipe());
+      CCH = computeCCH(Operand->getDefiningRecipe(), VF);
   }
 
   auto *SrcTy =
@@ -2208,6 +2209,7 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
         Opcode, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
   }
 
+  /*
   using namespace llvm::VPlanPatternMatch;
   auto GetMulAccReductionCost =
       [&](const VPReductionRecipe *Red) -> InstructionCost {
@@ -2321,11 +2323,57 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
   InstructionCost ExtendedCost = GetExtendedReductionCost(this);
   if (ExtendedCost.isValid())
     return ExtendedCost;
+  */
 
   // Default cost.
   return BaseCost;
 }
 
+/// Returns the cheaper of the target's fused extended-reduction cost and the
+/// cost of a separate extend (ExtOp) followed by a plain reduction.
+InstructionCost
+VPExtendedReductionRecipe::computeCost(ElementCount VF,
+                                       VPCostContext &Ctx) const {
+  RecurKind RdxKind = RdxDesc.getRecurrenceKind();
+  Type *ElementTy = getResultType();
+  auto *VectorTy = cast<VectorType>(ToVectorTy(ElementTy, VF));
+  TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+  unsigned Opcode = RdxDesc.getOpcode();

+  assert(ElementTy->getTypeID() == RdxDesc.getRecurrenceType()->getTypeID() &&
+         "Inferred type and recurrence type mismatch.");

+  // ReductionCost = Reduction cost + BinOp cost
+  InstructionCost ReductionCost =
+      Ctx.TTI.getArithmeticInstrCost(Opcode, ElementTy, CostKind);
+  if (RecurrenceDescriptor::isMinMaxRecurrenceKind(RdxKind)) {
+    Intrinsic::ID Id = getMinMaxReductionIntrinsicOp(RdxKind);
+    ReductionCost += Ctx.TTI.getMinMaxReductionCost(
+        Id, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+  } else {
+    ReductionCost += Ctx.TTI.getArithmeticReductionCost(
+        Opcode, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+  }

+  // Standalone extend cost: query with the cast's own opcode and instruction.
+  auto *SrcTy =
+      cast<VectorType>(ToVectorTy(Ctx.Types.inferScalarType(getVecOp()), VF));
+  // The operand may have no defining recipe (e.g. a live-in): use Normal then.
+  const auto *Def = getVecOp()->getDefiningRecipe();
+  auto CCH = Def ? computeCCH(Def, VF) : TTI::CastContextHint::Normal;
+  InstructionCost ExtendedCost = Ctx.TTI.getCastInstrCost(
+      ExtOp, VectorTy, SrcTy, CCH, CostKind, CastInstr);

+  // ExtendedReduction Cost
+  InstructionCost ExtendedRedCost = Ctx.TTI.getExtendedReductionCost(
+      Opcode, IsZExt, ElementTy, SrcTy, RdxDesc.getFastMathFlags(), CostKind);
+  // Prefer the fused form only when it is valid and strictly cheaper.
+  if (ExtendedRedCost.isValid() &&
+      ExtendedRedCost < ExtendedCost + ReductionCost)
+    return ExtendedRedCost;
+  return ExtendedCost + ReductionCost;
+}
+
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
 void VPReductionRecipe::print(raw_ostream &O, const Twine &Indent,
                               VPSlotTracker &SlotTracker) const {
@@ -2371,6 +2419,28 @@ void VPReductionEVLRecipe::print(raw_ostream &O, const Twine &Indent,
     O << " (with final reduction value stored in invariant address sank "
          "outside of loop)";
 }
+
+void VPExtendedReductionRecipe::print(raw_ostream &O, const Twine &Indent,
+                                      VPSlotTracker &SlotTracker) const {
+  O << Indent << "EXTENDED-REDUCE ";
+  printAsOperand(O, SlotTracker);
+  O << " = ";
+  getChainOp()->printAsOperand(O, SlotTracker);
+  O << " +";
+  if (isa<FPMathOperator>(getUnderlyingInstr()))
+    O << getUnderlyingInstr()->getFastMathFlags();
+  O << " reduce." << Instruction::getOpcodeName(RdxDesc.getOpcode()) << " (";
+  getVecOp()->printAsOperand(O, SlotTracker);
+  O << " extended to " << *getResultType();
+  if (isConditional()) {
+    O << ", ";
+    getCondOp()->printAsOperand(O, SlotTracker);
+  }
+  O << ")";
+  if (RdxDesc.IntermediateStore)
+    O << " (with final reduction value stored in invariant address sank "
+         "outside of loop)";
+}
 #endif
 
 bool VPReplicateRecipe::shouldPack() const {
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index ea8845eaa75d4d..0cfe1ce7998830 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -521,6 +521,30 @@ void VPlanTransforms::removeDeadRecipes(VPlan &Plan) {
   }
 }
 
+void VPlanTransforms::prepareExecute(VPlan &Plan) {
+  errs() << "\n\n\n!!Prepare to execute\n";
+  ReversePostOrderTraversal<VPBlockDeepTraversalWrapper<VPBlockBase *>> RPOT(
+      Plan.getVectorLoopRegion());
+  for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>(
+           vp_depth_first_deep(Plan.getEntry()))) {
+    for (VPRecipeBase &R : make_early_inc_range(*VPBB)) {
+      if (!isa<VPExtendedReductionRecipe>(&R))
+        continue;
+      auto *ExtRed = cast<VPExtendedReductionRecipe>(&R);
+      auto *Ext = new VPWidenCastRecipe(
+          ExtRed->getExtOpcode(), ExtRed->getVecOp(), ExtRed->getResultType(),
+          *ExtRed->getExtInstr());
+      auto *Red = new VPReductionRecipe(
+          ExtRed->getRecurrenceDescriptor(), ExtRed->getUnderlyingInstr(),
+          ExtRed->getChainOp(), Ext, ExtRed->getCondOp(), ExtRed->isOrdered());
+      Ext->insertBefore(ExtRed);
+      Red->insertBefore(ExtRed);
+      ExtRed->replaceAllUsesWith(Red);
+      ExtRed->eraseFromParent();
+    }
+  }
+}
+
 static VPScalarIVStepsRecipe *
 createScalarIVSteps(VPlan &Plan, InductionDescriptor::InductionKind Kind,
                     Instruction::BinaryOps InductionOpcode,
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.h b/llvm/lib/Transforms/Vectorize/VPlanTransforms.h
index 11e094db6294f6..6310c23b605da3 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.h
@@ -123,6 +123,9 @@ struct VPlanTransforms {
 
   /// Remove dead recipes from \p Plan.
   static void removeDeadRecipes(VPlan &Plan);
+
+  /// TODO: Rebase to fhahn's implementation.
+  static void prepareExecute(VPlan &Plan);
 };
 
 } // namespace llvm
diff --git a/llvm/lib/Transforms/Vectorize/VPlanValue.h b/llvm/lib/Transforms/Vectorize/VPlanValue.h
index 691b0d40823cfb..09defa6406c078 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanValue.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanValue.h
@@ -329,6 +329,8 @@ class VPDef {
     VPInterleaveSC,
     VPReductionEVLSC,
     VPReductionSC,
+    VPMulAccSC,
+    VPExtendedReductionSC,
     VPReplicateSC,
     VPScalarCastSC,
     VPScalarIVStepsSC,

>From a025b91d5a4ebd2e6aa4b84c3196bd33f4eaeb8e Mon Sep 17 00:00:00 2001
From: Elvis Wang <elvis.wang at sifive.com>
Date: Mon, 4 Nov 2024 22:02:22 -0800
Subject: [PATCH 3/6] Support MulAccRecipe

---
 .../Transforms/Vectorize/LoopVectorize.cpp    |  33 ++++-
 llvm/lib/Transforms/Vectorize/VPlan.h         | 103 +++++++++-------
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp | 114 +++++++++++++++++-
 .../Transforms/Vectorize/VPlanTransforms.cpp  |  56 ++++++---
 4 files changed, 237 insertions(+), 69 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index e44975cffbb124..0035442e545c90 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7662,8 +7662,6 @@ DenseMap<const SCEV *, Value *> LoopVectorizationPlanner::executePlan(
 
   // TODO: Rebase to fhahn's implementation.
   VPlanTransforms::prepareExecute(BestVPlan);
-  dbgs() << "\n\n print plan\n";
-  BestVPlan.print(dbgs());
   BestVPlan.execute(&State);
 
   // 2.5 Collect reduction resume values.
@@ -9377,11 +9375,34 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
       if (CM.blockNeedsPredicationForAnyReason(BB))
         CondOp = RecipeBuilder.getBlockInMask(BB);
 
-      // VPWidenCastRecipes can folded into VPReductionRecipe
-      VPValue *A;
+      VPValue *A, *B;
       VPSingleDefRecipe *RedRecipe;
-      if (match(VecOp, m_ZExtOrSExt(m_VPValue(A))) &&
-          !VecOp->hasMoreThanOneUniqueUser()) {
+      // reduce.add(mul(ext, ext)) can be folded into VPMulAccRecipe
+      if (RdxDesc.getOpcode() == Instruction::Add &&
+          match(VecOp, m_Mul(m_VPValue(A), m_VPValue(B)))) {
+        VPRecipeBase *RecipeA = A->getDefiningRecipe();
+        VPRecipeBase *RecipeB = B->getDefiningRecipe();
+        if (RecipeA && RecipeB && match(RecipeA, m_ZExtOrSExt(m_VPValue())) &&
+            match(RecipeB, m_ZExtOrSExt(m_VPValue())) &&
+            cast<VPWidenCastRecipe>(RecipeA)->getOpcode() ==
+                cast<VPWidenCastRecipe>(RecipeB)->getOpcode() &&
+            !A->hasMoreThanOneUniqueUser() && !B->hasMoreThanOneUniqueUser()) {
+          RedRecipe = new VPMulAccRecipe(
+              RdxDesc, CurrentLinkI, PreviousLink, CondOp,
+              CM.useOrderedReductions(RdxDesc),
+              cast<VPWidenRecipe>(VecOp->getDefiningRecipe()),
+              cast<VPWidenCastRecipe>(RecipeA),
+              cast<VPWidenCastRecipe>(RecipeB));
+        } else {
+          RedRecipe = new VPMulAccRecipe(
+              RdxDesc, CurrentLinkI, PreviousLink, CondOp,
+              CM.useOrderedReductions(RdxDesc),
+              cast<VPWidenRecipe>(VecOp->getDefiningRecipe()));
+        }
+      }
+      // VPWidenCastRecipes can folded into VPReductionRecipe
+      else if (match(VecOp, m_ZExtOrSExt(m_VPValue(A))) &&
+               !VecOp->hasMoreThanOneUniqueUser()) {
         RedRecipe = new VPExtendedReductionRecipe(
             RdxDesc, CurrentLinkI,
             cast<CastInst>(
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 176318ed0c5b1d..e91549e11812aa 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -2770,60 +2770,64 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
   /// Whether the reduction is conditional.
   bool IsConditional = false;
   /// Type after extend.
-  Type *ResultTy;
-  /// Type for mul.
-  Type *MulTy;
-  /// reduce.add(OuterExt(mul(InnerExt(), InnerExt())))
-  Instruction::CastOps OuterExtOp;
-  Instruction::CastOps InnerExtOp;
+  Type *ResultType;
+  /// reduce.add(mul(Ext(), Ext()))
+  Instruction::CastOps ExtOp;
+
+  Instruction *MulInstr;
+  CastInst *Ext0Instr;
+  CastInst *Ext1Instr;
 
-  Instruction *MulI;
-  Instruction *OuterExtI;
-  Instruction *InnerExt0I;
-  Instruction *InnerExt1I;
+  bool IsExtended;
 
 protected:
   VPMulAccRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
-                 Instruction *RedI, Instruction::CastOps OuterExtOp,
-                 Instruction *OuterExtI, Instruction *MulI,
-                 Instruction::CastOps InnerExtOp, Instruction *InnerExt0I,
-                 Instruction *InnerExt1I, ArrayRef<VPValue *> Operands,
-                 VPValue *CondOp, bool IsOrdered, Type *ResultTy, Type *MulTy)
+                 Instruction *RedI, Instruction *MulInstr,
+                 Instruction::CastOps ExtOp, Instruction *Ext0Instr,
+                 Instruction *Ext1Instr, ArrayRef<VPValue *> Operands,
+                 VPValue *CondOp, bool IsOrdered, Type *ResultType)
+      : VPSingleDefRecipe(SC, Operands, RedI), RdxDesc(R), IsOrdered(IsOrdered),
+        ResultType(ResultType), ExtOp(ExtOp), MulInstr(MulInstr),
+        Ext0Instr(cast<CastInst>(Ext0Instr)),
+        Ext1Instr(cast<CastInst>(Ext1Instr)) {
+    if (CondOp) {
+      IsConditional = true;
+      addOperand(CondOp);
+    }
+    IsExtended = true;
+  }
+
+  VPMulAccRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
+                 Instruction *RedI, Instruction *MulInstr,
+                 ArrayRef<VPValue *> Operands, VPValue *CondOp, bool IsOrdered)
       : VPSingleDefRecipe(SC, Operands, RedI), RdxDesc(R), IsOrdered(IsOrdered),
-        ResultTy(ResultTy), MulTy(MulTy), OuterExtOp(OuterExtOp),
-        InnerExtOp(InnerExtOp), MulI(MulI), OuterExtI(OuterExtI),
-        InnerExt0I(InnerExt0I), InnerExt1I(InnerExt1I) {
+        MulInstr(MulInstr) {
     if (CondOp) {
       IsConditional = true;
       addOperand(CondOp);
     }
+    IsExtended = false;
   }
 
 public:
   VPMulAccRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
-                 Instruction *OuterExt, Instruction *Mul,
-                 Instruction *InnerExt0, Instruction *InnerExt1,
-                 VPValue *ChainOp, VPValue *InnerExt0Op, VPValue *InnerExt1Op,
-                 VPValue *CondOp, bool IsOrdered, Type *ResultTy, Type *MulTy)
-      : VPMulAccRecipe(
-            VPDef::VPMulAccSC, R, RedI, cast<CastInst>(OuterExt)->getOpcode(),
-            OuterExt, Mul, cast<CastInst>(InnerExt0)->getOpcode(), InnerExt0,
-            InnerExt1, ArrayRef<VPValue *>({ChainOp, InnerExt0Op, InnerExt1Op}),
-            CondOp, IsOrdered, ResultTy, MulTy) {}
-
-  VPMulAccRecipe(VPReductionRecipe *Red, VPWidenCastRecipe *OuterExt,
-                 VPWidenRecipe *Mul, VPWidenCastRecipe *InnerExt0,
-                 VPWidenCastRecipe *InnerExt1)
-      : VPMulAccRecipe(
-            VPDef::VPMulAccSC, Red->getRecurrenceDescriptor(),
-            Red->getUnderlyingInstr(), OuterExt->getOpcode(),
-            OuterExt->getUnderlyingInstr(), Mul->getUnderlyingInstr(),
-            InnerExt0->getOpcode(), InnerExt0->getUnderlyingInstr(),
-            InnerExt1->getUnderlyingInstr(),
-            ArrayRef<VPValue *>({Red->getChainOp(), InnerExt0->getOperand(0),
-                                 InnerExt1->getOperand(0)}),
-            Red->getCondOp(), Red->isOrdered(), OuterExt->getResultType(),
-            InnerExt0->getResultType()) {}
+                 VPValue *ChainOp, VPValue *CondOp, bool IsOrdered,
+                 VPWidenRecipe *Mul, VPWidenCastRecipe *Ext0,
+                 VPWidenCastRecipe *Ext1)
+      : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, Mul->getUnderlyingInstr(),
+                       Ext0->getOpcode(), Ext0->getUnderlyingInstr(),
+                       Ext1->getUnderlyingInstr(),
+                       ArrayRef<VPValue *>(
+                           {ChainOp, Ext0->getOperand(0), Ext1->getOperand(0)}),
+                       CondOp, IsOrdered, Ext0->getResultType()) {}
+
+  VPMulAccRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
+                 VPValue *ChainOp, VPValue *CondOp, bool IsOrdered,
+                 VPWidenRecipe *Mul)
+      : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, Mul->getUnderlyingInstr(),
+                       ArrayRef<VPValue *>(
+                           {ChainOp, Mul->getOperand(0), Mul->getOperand(0)}),
+                       CondOp, IsOrdered) {}
 
   ~VPMulAccRecipe() override = default;
 
@@ -2839,7 +2843,10 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
   }
 
   /// Generate the reduction in the loop
-  void execute(VPTransformState &State) override;
+  void execute(VPTransformState &State) override {
+    llvm_unreachable("VPMulAccRecipe should transform to VPWidenCastRecipe + "
+                     "VPWidenRecipe + VPReductionRecipe before execution");
+  }
 
   /// Return the cost of VPExtendedReductionRecipe.
   InstructionCost computeCost(ElementCount VF,
@@ -2862,14 +2869,18 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
   /// The VPValue of the scalar Chain being accumulated.
   VPValue *getChainOp() const { return getOperand(0); }
   /// The VPValue of the vector value to be extended and reduced.
-  VPValue *getVecOp() const { return getOperand(1); }
+  VPValue *getVecOp0() const { return getOperand(1); }
+  VPValue *getVecOp1() const { return getOperand(2); }
   /// The VPValue of the condition for the block.
   VPValue *getCondOp() const {
     return isConditional() ? getOperand(getNumOperands() - 1) : nullptr;
   }
-  Type *getResultTy() const { return ResultTy; };
-  Instruction::CastOps getOuterExtOpcode() const { return OuterExtOp; };
-  Instruction::CastOps getInnerExtOpcode() const { return InnerExtOp; };
+  Type *getResultType() const { return ResultType; };
+  Instruction::CastOps getExtOpcode() const { return ExtOp; };
+  Instruction *getMulInstr() const { return MulInstr; };
+  CastInst *getExt0Instr() const { return Ext0Instr; };
+  CastInst *getExt1Instr() const { return Ext1Instr; };
+  bool isExtended() const { return IsExtended; };
 };
 
 /// VPReplicateRecipe replicates a given instruction producing multiple scalar
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index dc4fe31d10ab0e..15d39b7eabcdf6 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -270,9 +270,7 @@ InstructionCost VPRecipeBase::cost(ElementCount VF, VPCostContext &Ctx) {
     UI = &WidenMem->getIngredient();
 
   InstructionCost RecipeCost;
-  if ((UI && Ctx.skipCostComputation(UI, VF.isVector())) ||
-      (Ctx.FoldedRecipes.contains(VF) &&
-       Ctx.FoldedRecipes.at(VF).contains(this))) {
+  if ((UI && Ctx.skipCostComputation(UI, VF.isVector()))) {
     RecipeCost = 0;
   } else {
     RecipeCost = computeCost(VF, Ctx);
@@ -2374,6 +2372,85 @@ VPExtendedReductionRecipe::computeCost(ElementCount VF,
   return ExtendedCost + ReductionCost;
 }
 
+InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
+                                            VPCostContext &Ctx) const {
+  Type *ElementTy =
+      IsExtended ? getResultType() : Ctx.Types.inferScalarType(getVecOp0());
+  auto *VectorTy = cast<VectorType>(ToVectorTy(ElementTy, VF));
+  TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+  unsigned Opcode = RdxDesc.getOpcode();
+
+  assert(ElementTy->getTypeID() == RdxDesc.getRecurrenceType()->getTypeID() &&
+         "Inferred type and recurrence type mismatch.");
+
+  // BaseCost = Reduction cost + BinOp cost
+  InstructionCost ReductionCost =
+      Ctx.TTI.getArithmeticInstrCost(Opcode, ElementTy, CostKind);
+  ReductionCost += Ctx.TTI.getArithmeticReductionCost(
+      Opcode, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+
+  // Cost of the two extends; each is hinted by its own source operand.
+  InstructionCost ExtendedCost = 0;
+  if (IsExtended) {
+    auto *SrcTy = cast<VectorType>(
+        ToVectorTy(Ctx.Types.inferScalarType(getVecOp0()), VF));
+    auto *DestTy = cast<VectorType>(ToVectorTy(getResultType(), VF));
+    TTI::CastContextHint CCH0 =
+        computeCCH(getVecOp0()->getDefiningRecipe(), VF);
+    // Arm TTI will use the underlying instruction to determine the cost.
+    ExtendedCost = Ctx.TTI.getCastInstrCost(
+        ExtOp, DestTy, SrcTy, CCH0, TTI::TCK_RecipThroughput,
+        dyn_cast_if_present<Instruction>(getExt0Instr()));
+    TTI::CastContextHint CCH1 =
+        computeCCH(getVecOp1()->getDefiningRecipe(), VF);
+    ExtendedCost += Ctx.TTI.getCastInstrCost(
+        ExtOp, DestTy, SrcTy, CCH1, TTI::TCK_RecipThroughput,
+        dyn_cast_if_present<Instruction>(getExt1Instr()));
+  }
+
+  // Mul cost
+  InstructionCost MulCost;
+  SmallVector<const Value *, 4> Operands;
+  Operands.append(MulInstr->value_op_begin(), MulInstr->value_op_end());
+  if (IsExtended)
+    MulCost = Ctx.TTI.getArithmeticInstrCost(
+        Instruction::Mul, VectorTy, CostKind,
+        {TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None},
+        {TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None},
+        Operands, MulInstr, &Ctx.TLI);
+  else {
+    VPValue *RHS = getVecOp1();
+    // Certain instructions can be cheaper to vectorize if they have a constant
+    // second vector operand. One example of this are shifts on x86.
+    TargetTransformInfo::OperandValueInfo RHSInfo = {
+        TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None};
+    if (RHS->isLiveIn())
+      RHSInfo = Ctx.TTI.getOperandInfo(RHS->getLiveInIRValue());
+
+    if (RHSInfo.Kind == TargetTransformInfo::OK_AnyValue &&
+        RHS->isDefinedOutsideLoopRegions())
+      RHSInfo.Kind = TargetTransformInfo::OK_UniformValue;
+    MulCost = Ctx.TTI.getArithmeticInstrCost(
+        Instruction::Mul, VectorTy, CostKind,
+        {TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None},
+        RHSInfo, Operands, MulInstr, &Ctx.TLI);
+  }
+
+  // ExtendedReduction Cost
+  VectorType *SrcVecTy =
+      cast<VectorType>(ToVectorTy(Ctx.Types.inferScalarType(getVecOp0()), VF));
+  InstructionCost MulAccCost = Ctx.TTI.getMulAccReductionCost(
+      // ExtOp is only set when the operands are extended; else use unsigned.
+      !IsExtended || getExtOpcode() == Instruction::CastOps::ZExt, ElementTy,
+      SrcVecTy, CostKind);
+
+  // Check if folding mul (and ext) into a mul-acc reduction is profitable.
+  InstructionCost UnfoldedCost = ExtendedCost + ReductionCost + MulCost;
+  if (MulAccCost.isValid() && MulAccCost < UnfoldedCost)
+    return MulAccCost;
+  return UnfoldedCost;
+}
+
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
 void VPReductionRecipe::print(raw_ostream &O, const Twine &Indent,
                               VPSlotTracker &SlotTracker) const {
@@ -2441,6 +2518,37 @@ void VPExtendedReductionRecipe::print(raw_ostream &O, const Twine &Indent,
     O << " (with final reduction value stored in invariant address sank "
          "outside of loop)";
 }
+
+void VPMulAccRecipe::print(raw_ostream &O, const Twine &Indent,
+                           VPSlotTracker &SlotTracker) const {
+  O << Indent << "MULACC-REDUCE ";
+  printAsOperand(O, SlotTracker);
+  O << " = ";
+  getChainOp()->printAsOperand(O, SlotTracker);
+  O << " +";
+  if (isa<FPMathOperator>(getUnderlyingInstr()))
+    O << getUnderlyingInstr()->getFastMathFlags();
+  O << " reduce." << Instruction::getOpcodeName(RdxDesc.getOpcode()) << " (";
+  O << " mul ";
+  if (IsExtended)
+    O << "(";
+  getVecOp0()->printAsOperand(O, SlotTracker);
+  if (IsExtended)
+    O << " extended to " << *getResultType() << ")";
+  if (IsExtended)
+    O << "(";
+  getVecOp1()->printAsOperand(O, SlotTracker);
+  if (IsExtended)
+    O << " extended to " << *getResultType() << ")";
+  if (isConditional()) {
+    O << ", ";
+    getCondOp()->printAsOperand(O, SlotTracker);
+  }
+  O << ")";
+  if (RdxDesc.IntermediateStore)
+    O << " (with final reduction value stored in invariant address sank "
+         "outside of loop)";
+}
 #endif
 
 bool VPReplicateRecipe::shouldPack() const {
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index 0cfe1ce7998830..d65954cb8dc14d 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -522,25 +522,53 @@ void VPlanTransforms::removeDeadRecipes(VPlan &Plan) {
 }
 
 void VPlanTransforms::prepareExecute(VPlan &Plan) {
-  errs() << "\n\n\n!!Prepare to execute\n";
   ReversePostOrderTraversal<VPBlockDeepTraversalWrapper<VPBlockBase *>> RPOT(
       Plan.getVectorLoopRegion());
   for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>(
            vp_depth_first_deep(Plan.getEntry()))) {
     for (VPRecipeBase &R : make_early_inc_range(*VPBB)) {
-      if (!isa<VPExtendedReductionRecipe>(&R))
-        continue;
-      auto *ExtRed = cast<VPExtendedReductionRecipe>(&R);
-      auto *Ext = new VPWidenCastRecipe(
-          ExtRed->getExtOpcode(), ExtRed->getVecOp(), ExtRed->getResultType(),
-          *ExtRed->getExtInstr());
-      auto *Red = new VPReductionRecipe(
-          ExtRed->getRecurrenceDescriptor(), ExtRed->getUnderlyingInstr(),
-          ExtRed->getChainOp(), Ext, ExtRed->getCondOp(), ExtRed->isOrdered());
-      Ext->insertBefore(ExtRed);
-      Red->insertBefore(ExtRed);
-      ExtRed->replaceAllUsesWith(Red);
-      ExtRed->eraseFromParent();
+      if (isa<VPExtendedReductionRecipe>(&R)) {
+        auto *ExtRed = cast<VPExtendedReductionRecipe>(&R);
+        auto *Ext = new VPWidenCastRecipe(
+            ExtRed->getExtOpcode(), ExtRed->getVecOp(), ExtRed->getResultType(),
+            *ExtRed->getExtInstr());
+        auto *Red = new VPReductionRecipe(
+            ExtRed->getRecurrenceDescriptor(), ExtRed->getUnderlyingInstr(),
+            ExtRed->getChainOp(), Ext, ExtRed->getCondOp(),
+            ExtRed->isOrdered());
+        Ext->insertBefore(ExtRed);
+        Red->insertBefore(ExtRed);
+        ExtRed->replaceAllUsesWith(Red);
+        ExtRed->eraseFromParent();
+      } else if (isa<VPMulAccRecipe>(&R)) {
+        auto *MulAcc = cast<VPMulAccRecipe>(&R);
+        VPValue *Op0, *Op1;
+        if (MulAcc->isExtended()) {
+          Op0 = new VPWidenCastRecipe(
+              MulAcc->getExtOpcode(), MulAcc->getVecOp0(),
+              MulAcc->getResultType(), *MulAcc->getExt0Instr());
+          Op1 = new VPWidenCastRecipe(
+              MulAcc->getExtOpcode(), MulAcc->getVecOp1(),
+              MulAcc->getResultType(), *MulAcc->getExt1Instr());
+          Op0->getDefiningRecipe()->insertBefore(MulAcc);
+          Op1->getDefiningRecipe()->insertBefore(MulAcc);
+        } else {
+          Op0 = MulAcc->getVecOp0();
+          Op1 = MulAcc->getVecOp1();
+        }
+        Instruction *MulInstr = MulAcc->getMulInstr();
+        SmallVector<VPValue *, 2> MulOps = {Op0, Op1};
+        auto *Mul = new VPWidenRecipe(*MulInstr,
+                                      make_range(MulOps.begin(), MulOps.end()));
+        auto *Red = new VPReductionRecipe(
+            MulAcc->getRecurrenceDescriptor(), MulAcc->getUnderlyingInstr(),
+            MulAcc->getChainOp(), Mul, MulAcc->getCondOp(),
+            MulAcc->isOrdered());
+        Mul->insertBefore(MulAcc);
+        Red->insertBefore(MulAcc);
+        MulAcc->replaceAllUsesWith(Red);
+        MulAcc->eraseFromParent();
+      }
     }
   }
 }

>From 4319f06f9d6220c938f7dda48a268e8914ba36b3 Mon Sep 17 00:00:00 2001
From: Elvis Wang <elvis.wang at sifive.com>
Date: Tue, 5 Nov 2024 16:52:31 -0800
Subject: [PATCH 4/6] Fix several errors and update tests.

We need to update tests since the generated vector IR will be reordered.
---
 .../Transforms/Vectorize/LoopVectorize.cpp    | 45 ++++++++++++++
 llvm/lib/Transforms/Vectorize/VPlan.h         | 34 ++++++++---
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp |  6 +-
 .../Transforms/Vectorize/VPlanTransforms.cpp  | 14 ++++-
 .../LoopVectorize/ARM/mve-reduction-types.ll  |  4 +-
 .../LoopVectorize/ARM/mve-reductions.ll       | 61 ++++++++++---------
 .../LoopVectorize/RISCV/inloop-reduction.ll   | 32 ++++++----
 .../LoopVectorize/reduction-inloop-pred.ll    | 34 +++++------
 .../LoopVectorize/reduction-inloop.ll         | 12 ++--
 9 files changed, 163 insertions(+), 79 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 0035442e545c90..8bdcdae09a62a6 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7394,6 +7394,19 @@ static bool planContainsAdditionalSimplifications(VPlan &Plan,
       }
       if (Instruction *UI = GetInstructionForCost(&R))
         SeenInstrs.insert(UI);
+      // VPExtendedReductionRecipe contains a folded extend instruction.
+      if (auto *ExtendedRed = dyn_cast<VPExtendedReductionRecipe>(&R))
+        SeenInstrs.insert(ExtendedRed->getExtInstr());
+      // VPMulAccRecipe contains a mul and optional extend instructions.
+      else if (auto *MulAcc = dyn_cast<VPMulAccRecipe>(&R)) {
+        SeenInstrs.insert(MulAcc->getMulInstr());
+        if (MulAcc->isExtended()) {
+          SeenInstrs.insert(MulAcc->getExt0Instr());
+          SeenInstrs.insert(MulAcc->getExt1Instr());
+          if (auto *Ext = MulAcc->getExtInstr())
+            SeenInstrs.insert(Ext);
+        }
+      }
     }
   }
 
@@ -9399,6 +9412,38 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
               CM.useOrderedReductions(RdxDesc),
               cast<VPWidenRecipe>(VecOp->getDefiningRecipe()));
         }
+      } else if (RdxDesc.getOpcode() == Instruction::Add &&
+                 match(VecOp,
+                       m_ZExtOrSExt(m_Mul(m_ZExtOrSExt(m_VPValue(A)),
+                                          m_ZExtOrSExt(m_VPValue(B)))))) {
+        VPWidenCastRecipe *Ext =
+            dyn_cast<VPWidenCastRecipe>(VecOp->getDefiningRecipe());
+        VPWidenRecipe *Mul =
+            dyn_cast<VPWidenRecipe>(Ext->getOperand(0)->getDefiningRecipe());
+        if (Mul && match(Mul, m_Mul(m_ZExtOrSExt(m_VPValue()),
+                                    m_ZExtOrSExt(m_VPValue())))) {
+          VPWidenRecipe *Mul =
+              cast<VPWidenRecipe>(Ext->getOperand(0)->getDefiningRecipe());
+          VPWidenCastRecipe *Ext0 =
+              cast<VPWidenCastRecipe>(Mul->getOperand(0)->getDefiningRecipe());
+          VPWidenCastRecipe *Ext1 =
+              cast<VPWidenCastRecipe>(Mul->getOperand(1)->getDefiningRecipe());
+          if (Ext->getOpcode() == Ext0->getOpcode() &&
+              Ext0->getOpcode() == Ext1->getOpcode()) {
+            RedRecipe = new VPMulAccRecipe(
+                RdxDesc, CurrentLinkI, PreviousLink, CondOp,
+                CM.useOrderedReductions(RdxDesc),
+                cast<VPWidenCastRecipe>(VecOp->getDefiningRecipe()), Mul,
+                cast<VPWidenCastRecipe>(Ext0), cast<VPWidenCastRecipe>(Ext1));
+          } else
+            RedRecipe = new VPExtendedReductionRecipe(
+                RdxDesc, CurrentLinkI,
+                cast<CastInst>(
+                    cast<VPWidenCastRecipe>(VecOp)->getUnderlyingInstr()),
+                PreviousLink, cast<VPWidenCastRecipe>(VecOp)->getOperand(0),
+                CondOp, CM.useOrderedReductions(RdxDesc),
+                cast<VPWidenCastRecipe>(VecOp)->getResultType());
+        }
       }
       // VPWidenCastRecipes can folded into VPReductionRecipe
       else if (match(VecOp, m_ZExtOrSExt(m_VPValue(A))) &&
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index e91549e11812aa..6102399e351bdd 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -2771,23 +2771,27 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
   bool IsConditional = false;
   /// Type after extend.
   Type *ResultType;
-  /// reduce.add(mul(Ext(), Ext()))
+  /// reduce.add(ext(mul(Ext(), Ext())))
   Instruction::CastOps ExtOp;
 
   Instruction *MulInstr;
+  CastInst *ExtInstr = nullptr;
   CastInst *Ext0Instr;
   CastInst *Ext1Instr;
 
   bool IsExtended;
+  bool IsOuterExtended = false;
 
 protected:
   VPMulAccRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
-                 Instruction *RedI, Instruction *MulInstr,
-                 Instruction::CastOps ExtOp, Instruction *Ext0Instr,
-                 Instruction *Ext1Instr, ArrayRef<VPValue *> Operands,
-                 VPValue *CondOp, bool IsOrdered, Type *ResultType)
+                 Instruction *RedI, Instruction *ExtInstr,
+                 Instruction *MulInstr, Instruction::CastOps ExtOp,
+                 Instruction *Ext0Instr, Instruction *Ext1Instr,
+                 ArrayRef<VPValue *> Operands, VPValue *CondOp, bool IsOrdered,
+                 Type *ResultType)
       : VPSingleDefRecipe(SC, Operands, RedI), RdxDesc(R), IsOrdered(IsOrdered),
         ResultType(ResultType), ExtOp(ExtOp), MulInstr(MulInstr),
+        ExtInstr(cast_if_present<CastInst>(ExtInstr)),
         Ext0Instr(cast<CastInst>(Ext0Instr)),
         Ext1Instr(cast<CastInst>(Ext1Instr)) {
     if (CondOp) {
@@ -2814,9 +2818,9 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
                  VPValue *ChainOp, VPValue *CondOp, bool IsOrdered,
                  VPWidenRecipe *Mul, VPWidenCastRecipe *Ext0,
                  VPWidenCastRecipe *Ext1)
-      : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, Mul->getUnderlyingInstr(),
-                       Ext0->getOpcode(), Ext0->getUnderlyingInstr(),
-                       Ext1->getUnderlyingInstr(),
+      : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, nullptr,
+                       Mul->getUnderlyingInstr(), Ext0->getOpcode(),
+                       Ext0->getUnderlyingInstr(), Ext1->getUnderlyingInstr(),
                        ArrayRef<VPValue *>(
                            {ChainOp, Ext0->getOperand(0), Ext1->getOperand(0)}),
                        CondOp, IsOrdered, Ext0->getResultType()) {}
@@ -2826,9 +2830,20 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
                  VPWidenRecipe *Mul)
       : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, Mul->getUnderlyingInstr(),
                        ArrayRef<VPValue *>(
-                           {ChainOp, Mul->getOperand(0), Mul->getOperand(0)}),
+                           {ChainOp, Mul->getOperand(0), Mul->getOperand(1)}),
                        CondOp, IsOrdered) {}
 
+  VPMulAccRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
+                 VPValue *ChainOp, VPValue *CondOp, bool IsOrdered,
+                 VPWidenCastRecipe *Ext, VPWidenRecipe *Mul,
+                 VPWidenCastRecipe *Ext0, VPWidenCastRecipe *Ext1)
+      : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, Ext->getUnderlyingInstr(),
+                       Mul->getUnderlyingInstr(), Ext0->getOpcode(),
+                       Ext0->getUnderlyingInstr(), Ext1->getUnderlyingInstr(),
+                       ArrayRef<VPValue *>(
+                           {ChainOp, Ext0->getOperand(0), Ext1->getOperand(0)}),
+                       CondOp, IsOrdered, Ext0->getResultType()) {}
+
   ~VPMulAccRecipe() override = default;
 
   VPMulAccRecipe *clone() override { llvm_unreachable("Not implement yet"); }
@@ -2878,6 +2893,7 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
   Type *getResultType() const { return ResultType; };
   Instruction::CastOps getExtOpcode() const { return ExtOp; };
   Instruction *getMulInstr() const { return MulInstr; };
+  CastInst *getExtInstr() const { return ExtInstr; };
   CastInst *getExt0Instr() const { return Ext0Instr; };
   CastInst *getExt1Instr() const { return Ext1Instr; };
   bool isExtended() const { return IsExtended; };
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 15d39b7eabcdf6..f9cba35837d2ea 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -2374,8 +2374,8 @@ VPExtendedReductionRecipe::computeCost(ElementCount VF,
 
 InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
                                             VPCostContext &Ctx) const {
-  Type *ElementTy =
-      IsExtended ? getResultType() : Ctx.Types.inferScalarType(getVecOp0());
+  Type *ElementTy = IsExtended ? RdxDesc.getRecurrenceType()
+                               : Ctx.Types.inferScalarType(getVecOp0());
   auto *VectorTy = cast<VectorType>(ToVectorTy(ElementTy, VF));
   TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
   unsigned Opcode = RdxDesc.getOpcode();
@@ -2436,7 +2436,7 @@ InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
         RHSInfo, Operands, MulInstr, &Ctx.TLI);
   }
 
-  // ExtendedReduction Cost
+  // MulAccReduction Cost
   VectorType *SrcVecTy =
       cast<VectorType>(ToVectorTy(Ctx.Types.inferScalarType(getVecOp0()), VF));
   InstructionCost MulAccCost = Ctx.TTI.getMulAccReductionCost(
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index d65954cb8dc14d..ddb3b0ae001132 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -556,15 +556,27 @@ void VPlanTransforms::prepareExecute(VPlan &Plan) {
           Op0 = MulAcc->getVecOp0();
           Op1 = MulAcc->getVecOp1();
         }
+        VPSingleDefRecipe *VecOp;
         Instruction *MulInstr = MulAcc->getMulInstr();
         SmallVector<VPValue *, 2> MulOps = {Op0, Op1};
         auto *Mul = new VPWidenRecipe(*MulInstr,
                                       make_range(MulOps.begin(), MulOps.end()));
+        if (auto *OuterExtInstr = MulAcc->getExtInstr()) {
+          // dbgs() <<"\n!!!"<< *OuterExtInstr << " " << MulAcc->getExtOpcode()
+          // << "\n";
+          VecOp = new VPWidenCastRecipe(
+              MulAcc->getExtOpcode(), Mul,
+              MulAcc->getRecurrenceDescriptor().getRecurrenceType(),
+              *OuterExtInstr);
+        } else
+          VecOp = Mul;
         auto *Red = new VPReductionRecipe(
             MulAcc->getRecurrenceDescriptor(), MulAcc->getUnderlyingInstr(),
-            MulAcc->getChainOp(), Mul, MulAcc->getCondOp(),
+            MulAcc->getChainOp(), VecOp, MulAcc->getCondOp(),
             MulAcc->isOrdered());
         Mul->insertBefore(MulAcc);
+        if (VecOp != Mul)
+          VecOp->insertBefore(MulAcc);
         Red->insertBefore(MulAcc);
         MulAcc->replaceAllUsesWith(Red);
         MulAcc->eraseFromParent();
diff --git a/llvm/test/Transforms/LoopVectorize/ARM/mve-reduction-types.ll b/llvm/test/Transforms/LoopVectorize/ARM/mve-reduction-types.ll
index 832d4db53036fb..9d4b83edd55ccb 100644
--- a/llvm/test/Transforms/LoopVectorize/ARM/mve-reduction-types.ll
+++ b/llvm/test/Transforms/LoopVectorize/ARM/mve-reduction-types.ll
@@ -24,11 +24,11 @@ define i32 @mla_i32(ptr noalias nocapture readonly %A, ptr noalias nocapture rea
 ; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds i8, ptr [[A:%.*]], i32 [[TMP0]]
 ; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr inbounds i8, ptr [[TMP1]], i32 0
 ; CHECK-NEXT:    [[WIDE_MASKED_LOAD:%.*]] = call <16 x i8> @llvm.masked.load.v16i8.p0(ptr [[TMP2]], i32 1, <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i8> poison)
-; CHECK-NEXT:    [[TMP3:%.*]] = sext <16 x i8> [[WIDE_MASKED_LOAD]] to <16 x i32>
 ; CHECK-NEXT:    [[TMP4:%.*]] = getelementptr inbounds i8, ptr [[B:%.*]], i32 [[TMP0]]
 ; CHECK-NEXT:    [[TMP5:%.*]] = getelementptr inbounds i8, ptr [[TMP4]], i32 0
 ; CHECK-NEXT:    [[WIDE_MASKED_LOAD1:%.*]] = call <16 x i8> @llvm.masked.load.v16i8.p0(ptr [[TMP5]], i32 1, <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i8> poison)
 ; CHECK-NEXT:    [[TMP6:%.*]] = sext <16 x i8> [[WIDE_MASKED_LOAD1]] to <16 x i32>
+; CHECK-NEXT:    [[TMP3:%.*]] = sext <16 x i8> [[WIDE_MASKED_LOAD]] to <16 x i32>
 ; CHECK-NEXT:    [[TMP7:%.*]] = mul nsw <16 x i32> [[TMP6]], [[TMP3]]
 ; CHECK-NEXT:    [[TMP8:%.*]] = select <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i32> [[TMP7]], <16 x i32> zeroinitializer
 ; CHECK-NEXT:    [[TMP9:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP8]])
@@ -107,11 +107,11 @@ define i32 @mla_i8(ptr noalias nocapture readonly %A, ptr noalias nocapture read
 ; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds i8, ptr [[A:%.*]], i32 [[TMP0]]
 ; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr inbounds i8, ptr [[TMP1]], i32 0
 ; CHECK-NEXT:    [[WIDE_MASKED_LOAD:%.*]] = call <16 x i8> @llvm.masked.load.v16i8.p0(ptr [[TMP2]], i32 1, <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i8> poison)
-; CHECK-NEXT:    [[TMP3:%.*]] = sext <16 x i8> [[WIDE_MASKED_LOAD]] to <16 x i32>
 ; CHECK-NEXT:    [[TMP4:%.*]] = getelementptr inbounds i8, ptr [[B:%.*]], i32 [[TMP0]]
 ; CHECK-NEXT:    [[TMP5:%.*]] = getelementptr inbounds i8, ptr [[TMP4]], i32 0
 ; CHECK-NEXT:    [[WIDE_MASKED_LOAD1:%.*]] = call <16 x i8> @llvm.masked.load.v16i8.p0(ptr [[TMP5]], i32 1, <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i8> poison)
 ; CHECK-NEXT:    [[TMP6:%.*]] = sext <16 x i8> [[WIDE_MASKED_LOAD1]] to <16 x i32>
+; CHECK-NEXT:    [[TMP3:%.*]] = sext <16 x i8> [[WIDE_MASKED_LOAD]] to <16 x i32>
 ; CHECK-NEXT:    [[TMP7:%.*]] = mul nsw <16 x i32> [[TMP6]], [[TMP3]]
 ; CHECK-NEXT:    [[TMP8:%.*]] = select <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i32> [[TMP7]], <16 x i32> zeroinitializer
 ; CHECK-NEXT:    [[TMP9:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP8]])
diff --git a/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll b/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
index a7cb5c61ca5502..ea6519c8c13b93 100644
--- a/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
+++ b/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
@@ -646,11 +646,11 @@ define i64 @mla_i16_i64(ptr nocapture readonly %x, ptr nocapture readonly %y, i3
 ; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[TMP7:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds i16, ptr [[X:%.*]], i32 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <8 x i16>, ptr [[TMP0]], align 2
-; CHECK-NEXT:    [[TMP1:%.*]] = sext <8 x i16> [[WIDE_LOAD]] to <8 x i32>
 ; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr inbounds i16, ptr [[Y:%.*]], i32 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_LOAD1:%.*]] = load <8 x i16>, ptr [[TMP2]], align 2
-; CHECK-NEXT:    [[TMP3:%.*]] = sext <8 x i16> [[WIDE_LOAD1]] to <8 x i32>
-; CHECK-NEXT:    [[TMP4:%.*]] = mul nsw <8 x i32> [[TMP3]], [[TMP1]]
+; CHECK-NEXT:    [[TMP11:%.*]] = sext <8 x i16> [[WIDE_LOAD1]] to <8 x i32>
+; CHECK-NEXT:    [[TMP3:%.*]] = sext <8 x i16> [[WIDE_LOAD]] to <8 x i32>
+; CHECK-NEXT:    [[TMP4:%.*]] = mul nsw <8 x i32> [[TMP11]], [[TMP3]]
 ; CHECK-NEXT:    [[TMP5:%.*]] = sext <8 x i32> [[TMP4]] to <8 x i64>
 ; CHECK-NEXT:    [[TMP6:%.*]] = call i64 @llvm.vector.reduce.add.v8i64(<8 x i64> [[TMP5]])
 ; CHECK-NEXT:    [[TMP7]] = add i64 [[TMP6]], [[VEC_PHI]]
@@ -726,11 +726,11 @@ define i64 @mla_i8_i64(ptr nocapture readonly %x, ptr nocapture readonly %y, i32
 ; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[TMP7:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds i8, ptr [[X:%.*]], i32 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <8 x i8>, ptr [[TMP0]], align 1
-; CHECK-NEXT:    [[TMP1:%.*]] = zext <8 x i8> [[WIDE_LOAD]] to <8 x i32>
 ; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr inbounds i8, ptr [[Y:%.*]], i32 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_LOAD1:%.*]] = load <8 x i8>, ptr [[TMP2]], align 1
-; CHECK-NEXT:    [[TMP3:%.*]] = zext <8 x i8> [[WIDE_LOAD1]] to <8 x i32>
-; CHECK-NEXT:    [[TMP4:%.*]] = mul nuw nsw <8 x i32> [[TMP3]], [[TMP1]]
+; CHECK-NEXT:    [[TMP11:%.*]] = zext <8 x i8> [[WIDE_LOAD1]] to <8 x i32>
+; CHECK-NEXT:    [[TMP3:%.*]] = zext <8 x i8> [[WIDE_LOAD]] to <8 x i32>
+; CHECK-NEXT:    [[TMP4:%.*]] = mul nuw nsw <8 x i32> [[TMP11]], [[TMP3]]
 ; CHECK-NEXT:    [[TMP5:%.*]] = zext nneg <8 x i32> [[TMP4]] to <8 x i64>
 ; CHECK-NEXT:    [[TMP6:%.*]] = call i64 @llvm.vector.reduce.add.v8i64(<8 x i64> [[TMP5]])
 ; CHECK-NEXT:    [[TMP7]] = add i64 [[TMP6]], [[VEC_PHI]]
@@ -802,11 +802,11 @@ define i32 @mla_i32_i32(ptr nocapture readonly %x, ptr nocapture readonly %y, i3
 ; CHECK-NEXT:    [[INDEX:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[TMP5:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[ACTIVE_LANE_MASK:%.*]] = call <4 x i1> @llvm.get.active.lane.mask.v4i1.i32(i32 [[INDEX]], i32 [[N]])
-; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds i32, ptr [[X:%.*]], i32 [[INDEX]]
-; CHECK-NEXT:    [[WIDE_MASKED_LOAD:%.*]] = call <4 x i32> @llvm.masked.load.v4i32.p0(ptr [[TMP0]], i32 4, <4 x i1> [[ACTIVE_LANE_MASK]], <4 x i32> poison)
 ; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds i32, ptr [[Y:%.*]], i32 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_MASKED_LOAD1:%.*]] = call <4 x i32> @llvm.masked.load.v4i32.p0(ptr [[TMP1]], i32 4, <4 x i1> [[ACTIVE_LANE_MASK]], <4 x i32> poison)
-; CHECK-NEXT:    [[TMP2:%.*]] = mul nsw <4 x i32> [[WIDE_MASKED_LOAD1]], [[WIDE_MASKED_LOAD]]
+; CHECK-NEXT:    [[TMP7:%.*]] = getelementptr inbounds i32, ptr [[Y1:%.*]], i32 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_MASKED_LOAD2:%.*]] = call <4 x i32> @llvm.masked.load.v4i32.p0(ptr [[TMP7]], i32 4, <4 x i1> [[ACTIVE_LANE_MASK]], <4 x i32> poison)
+; CHECK-NEXT:    [[TMP2:%.*]] = mul nsw <4 x i32> [[WIDE_MASKED_LOAD2]], [[WIDE_MASKED_LOAD1]]
 ; CHECK-NEXT:    [[TMP3:%.*]] = select <4 x i1> [[ACTIVE_LANE_MASK]], <4 x i32> [[TMP2]], <4 x i32> zeroinitializer
 ; CHECK-NEXT:    [[TMP4:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP3]])
 ; CHECK-NEXT:    [[TMP5]] = add i32 [[TMP4]], [[VEC_PHI]]
@@ -855,10 +855,10 @@ define i32 @mla_i16_i32(ptr nocapture readonly %x, ptr nocapture readonly %y, i3
 ; CHECK-NEXT:    [[ACTIVE_LANE_MASK:%.*]] = call <8 x i1> @llvm.get.active.lane.mask.v8i1.i32(i32 [[INDEX]], i32 [[N]])
 ; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds i16, ptr [[X:%.*]], i32 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_MASKED_LOAD:%.*]] = call <8 x i16> @llvm.masked.load.v8i16.p0(ptr [[TMP0]], i32 2, <8 x i1> [[ACTIVE_LANE_MASK]], <8 x i16> poison)
-; CHECK-NEXT:    [[TMP1:%.*]] = sext <8 x i16> [[WIDE_MASKED_LOAD]] to <8 x i32>
 ; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr inbounds i16, ptr [[Y:%.*]], i32 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_MASKED_LOAD1:%.*]] = call <8 x i16> @llvm.masked.load.v8i16.p0(ptr [[TMP2]], i32 2, <8 x i1> [[ACTIVE_LANE_MASK]], <8 x i16> poison)
 ; CHECK-NEXT:    [[TMP3:%.*]] = sext <8 x i16> [[WIDE_MASKED_LOAD1]] to <8 x i32>
+; CHECK-NEXT:    [[TMP1:%.*]] = sext <8 x i16> [[WIDE_MASKED_LOAD]] to <8 x i32>
 ; CHECK-NEXT:    [[TMP4:%.*]] = mul nsw <8 x i32> [[TMP3]], [[TMP1]]
 ; CHECK-NEXT:    [[TMP5:%.*]] = select <8 x i1> [[ACTIVE_LANE_MASK]], <8 x i32> [[TMP4]], <8 x i32> zeroinitializer
 ; CHECK-NEXT:    [[TMP6:%.*]] = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> [[TMP5]])
@@ -910,10 +910,10 @@ define i32 @mla_i8_i32(ptr nocapture readonly %x, ptr nocapture readonly %y, i32
 ; CHECK-NEXT:    [[ACTIVE_LANE_MASK:%.*]] = call <16 x i1> @llvm.get.active.lane.mask.v16i1.i32(i32 [[INDEX]], i32 [[N]])
 ; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds i8, ptr [[X:%.*]], i32 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_MASKED_LOAD:%.*]] = call <16 x i8> @llvm.masked.load.v16i8.p0(ptr [[TMP0]], i32 1, <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i8> poison)
-; CHECK-NEXT:    [[TMP1:%.*]] = zext <16 x i8> [[WIDE_MASKED_LOAD]] to <16 x i32>
 ; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr inbounds i8, ptr [[Y:%.*]], i32 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_MASKED_LOAD1:%.*]] = call <16 x i8> @llvm.masked.load.v16i8.p0(ptr [[TMP2]], i32 1, <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i8> poison)
 ; CHECK-NEXT:    [[TMP3:%.*]] = zext <16 x i8> [[WIDE_MASKED_LOAD1]] to <16 x i32>
+; CHECK-NEXT:    [[TMP1:%.*]] = zext <16 x i8> [[WIDE_MASKED_LOAD]] to <16 x i32>
 ; CHECK-NEXT:    [[TMP4:%.*]] = mul nuw nsw <16 x i32> [[TMP3]], [[TMP1]]
 ; CHECK-NEXT:    [[TMP5:%.*]] = select <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i32> [[TMP4]], <16 x i32> zeroinitializer
 ; CHECK-NEXT:    [[TMP6:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP5]])
@@ -963,11 +963,11 @@ define signext i16 @mla_i16_i16(ptr nocapture readonly %x, ptr nocapture readonl
 ; CHECK-NEXT:    [[INDEX:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi i16 [ 0, [[VECTOR_PH]] ], [ [[TMP5:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[ACTIVE_LANE_MASK:%.*]] = call <8 x i1> @llvm.get.active.lane.mask.v8i1.i32(i32 [[INDEX]], i32 [[N]])
-; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds i16, ptr [[X:%.*]], i32 [[INDEX]]
-; CHECK-NEXT:    [[WIDE_MASKED_LOAD:%.*]] = call <8 x i16> @llvm.masked.load.v8i16.p0(ptr [[TMP0]], i32 2, <8 x i1> [[ACTIVE_LANE_MASK]], <8 x i16> poison)
 ; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds i16, ptr [[Y:%.*]], i32 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_MASKED_LOAD1:%.*]] = call <8 x i16> @llvm.masked.load.v8i16.p0(ptr [[TMP1]], i32 2, <8 x i1> [[ACTIVE_LANE_MASK]], <8 x i16> poison)
-; CHECK-NEXT:    [[TMP2:%.*]] = mul <8 x i16> [[WIDE_MASKED_LOAD1]], [[WIDE_MASKED_LOAD]]
+; CHECK-NEXT:    [[TMP7:%.*]] = getelementptr inbounds i16, ptr [[Y1:%.*]], i32 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_MASKED_LOAD2:%.*]] = call <8 x i16> @llvm.masked.load.v8i16.p0(ptr [[TMP7]], i32 2, <8 x i1> [[ACTIVE_LANE_MASK]], <8 x i16> poison)
+; CHECK-NEXT:    [[TMP2:%.*]] = mul <8 x i16> [[WIDE_MASKED_LOAD2]], [[WIDE_MASKED_LOAD1]]
 ; CHECK-NEXT:    [[TMP3:%.*]] = select <8 x i1> [[ACTIVE_LANE_MASK]], <8 x i16> [[TMP2]], <8 x i16> zeroinitializer
 ; CHECK-NEXT:    [[TMP4:%.*]] = call i16 @llvm.vector.reduce.add.v8i16(<8 x i16> [[TMP3]])
 ; CHECK-NEXT:    [[TMP5]] = add i16 [[TMP4]], [[VEC_PHI]]
@@ -1016,10 +1016,10 @@ define signext i16 @mla_i8_i16(ptr nocapture readonly %x, ptr nocapture readonly
 ; CHECK-NEXT:    [[ACTIVE_LANE_MASK:%.*]] = call <16 x i1> @llvm.get.active.lane.mask.v16i1.i32(i32 [[INDEX]], i32 [[N]])
 ; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds i8, ptr [[X:%.*]], i32 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_MASKED_LOAD:%.*]] = call <16 x i8> @llvm.masked.load.v16i8.p0(ptr [[TMP0]], i32 1, <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i8> poison)
-; CHECK-NEXT:    [[TMP1:%.*]] = zext <16 x i8> [[WIDE_MASKED_LOAD]] to <16 x i16>
 ; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr inbounds i8, ptr [[Y:%.*]], i32 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_MASKED_LOAD1:%.*]] = call <16 x i8> @llvm.masked.load.v16i8.p0(ptr [[TMP2]], i32 1, <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i8> poison)
 ; CHECK-NEXT:    [[TMP3:%.*]] = zext <16 x i8> [[WIDE_MASKED_LOAD1]] to <16 x i16>
+; CHECK-NEXT:    [[TMP1:%.*]] = zext <16 x i8> [[WIDE_MASKED_LOAD]] to <16 x i16>
 ; CHECK-NEXT:    [[TMP4:%.*]] = mul nuw <16 x i16> [[TMP3]], [[TMP1]]
 ; CHECK-NEXT:    [[TMP5:%.*]] = select <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i16> [[TMP4]], <16 x i16> zeroinitializer
 ; CHECK-NEXT:    [[TMP6:%.*]] = call i16 @llvm.vector.reduce.add.v16i16(<16 x i16> [[TMP5]])
@@ -1069,11 +1069,11 @@ define zeroext i8 @mla_i8_i8(ptr nocapture readonly %x, ptr nocapture readonly %
 ; CHECK-NEXT:    [[INDEX:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi i8 [ 0, [[VECTOR_PH]] ], [ [[TMP5:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[ACTIVE_LANE_MASK:%.*]] = call <16 x i1> @llvm.get.active.lane.mask.v16i1.i32(i32 [[INDEX]], i32 [[N]])
-; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds i8, ptr [[X:%.*]], i32 [[INDEX]]
-; CHECK-NEXT:    [[WIDE_MASKED_LOAD:%.*]] = call <16 x i8> @llvm.masked.load.v16i8.p0(ptr [[TMP0]], i32 1, <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i8> poison)
 ; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds i8, ptr [[Y:%.*]], i32 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_MASKED_LOAD1:%.*]] = call <16 x i8> @llvm.masked.load.v16i8.p0(ptr [[TMP1]], i32 1, <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i8> poison)
-; CHECK-NEXT:    [[TMP2:%.*]] = mul <16 x i8> [[WIDE_MASKED_LOAD1]], [[WIDE_MASKED_LOAD]]
+; CHECK-NEXT:    [[TMP7:%.*]] = getelementptr inbounds i8, ptr [[Y1:%.*]], i32 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_MASKED_LOAD2:%.*]] = call <16 x i8> @llvm.masked.load.v16i8.p0(ptr [[TMP7]], i32 1, <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i8> poison)
+; CHECK-NEXT:    [[TMP2:%.*]] = mul <16 x i8> [[WIDE_MASKED_LOAD2]], [[WIDE_MASKED_LOAD1]]
 ; CHECK-NEXT:    [[TMP3:%.*]] = select <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i8> [[TMP2]], <16 x i8> zeroinitializer
 ; CHECK-NEXT:    [[TMP4:%.*]] = call i8 @llvm.vector.reduce.add.v16i8(<16 x i8> [[TMP3]])
 ; CHECK-NEXT:    [[TMP5]] = add i8 [[TMP4]], [[VEC_PHI]]
@@ -1122,10 +1122,10 @@ define i32 @red_mla_ext_s8_s16_s32(ptr noalias nocapture readonly %A, ptr noalia
 ; CHECK-NEXT:    [[ACTIVE_LANE_MASK:%.*]] = call <8 x i1> @llvm.get.active.lane.mask.v8i1.i32(i32 [[INDEX]], i32 [[N]])
 ; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds i8, ptr [[A:%.*]], i32 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_MASKED_LOAD:%.*]] = call <8 x i8> @llvm.masked.load.v8i8.p0(ptr [[TMP0]], i32 1, <8 x i1> [[ACTIVE_LANE_MASK]], <8 x i8> poison)
-; CHECK-NEXT:    [[TMP1:%.*]] = sext <8 x i8> [[WIDE_MASKED_LOAD]] to <8 x i32>
 ; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr inbounds i16, ptr [[B:%.*]], i32 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_MASKED_LOAD1:%.*]] = call <8 x i16> @llvm.masked.load.v8i16.p0(ptr [[TMP2]], i32 2, <8 x i1> [[ACTIVE_LANE_MASK]], <8 x i16> poison)
 ; CHECK-NEXT:    [[TMP3:%.*]] = sext <8 x i16> [[WIDE_MASKED_LOAD1]] to <8 x i32>
+; CHECK-NEXT:    [[TMP1:%.*]] = sext <8 x i8> [[WIDE_MASKED_LOAD]] to <8 x i32>
 ; CHECK-NEXT:    [[TMP4:%.*]] = mul nsw <8 x i32> [[TMP3]], [[TMP1]]
 ; CHECK-NEXT:    [[TMP5:%.*]] = select <8 x i1> [[ACTIVE_LANE_MASK]], <8 x i32> [[TMP4]], <8 x i32> zeroinitializer
 ; CHECK-NEXT:    [[TMP6:%.*]] = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> [[TMP5]])
@@ -1183,11 +1183,11 @@ define i64 @red_mla_ext_s16_u16_s64(ptr noalias nocapture readonly %A, ptr noali
 ; CHECK:       vector.body:
 ; CHECK-NEXT:    [[INDEX:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[TMP7:%.*]], [[VECTOR_BODY]] ]
-; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds i16, ptr [[A:%.*]], i32 [[INDEX]]
-; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <4 x i16>, ptr [[TMP0]], align 1
-; CHECK-NEXT:    [[TMP1:%.*]] = sext <4 x i16> [[WIDE_LOAD]] to <4 x i32>
 ; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr inbounds i16, ptr [[B:%.*]], i32 [[INDEX]]
-; CHECK-NEXT:    [[WIDE_LOAD1:%.*]] = load <4 x i16>, ptr [[TMP2]], align 2
+; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <4 x i16>, ptr [[TMP2]], align 1
+; CHECK-NEXT:    [[TMP1:%.*]] = sext <4 x i16> [[WIDE_LOAD]] to <4 x i32>
+; CHECK-NEXT:    [[TMP11:%.*]] = getelementptr inbounds i16, ptr [[B1:%.*]], i32 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_LOAD1:%.*]] = load <4 x i16>, ptr [[TMP11]], align 2
 ; CHECK-NEXT:    [[TMP3:%.*]] = zext <4 x i16> [[WIDE_LOAD1]] to <4 x i32>
 ; CHECK-NEXT:    [[TMP4:%.*]] = mul nsw <4 x i32> [[TMP3]], [[TMP1]]
 ; CHECK-NEXT:    [[TMP5:%.*]] = zext <4 x i32> [[TMP4]] to <4 x i64>
@@ -1206,10 +1206,10 @@ define i64 @red_mla_ext_s16_u16_s64(ptr noalias nocapture readonly %A, ptr noali
 ; CHECK:       for.body:
 ; CHECK-NEXT:    [[I_011:%.*]] = phi i32 [ [[INC:%.*]], [[FOR_BODY]] ], [ [[BC_RESUME_VAL]], [[SCALAR_PH]] ]
 ; CHECK-NEXT:    [[S_010:%.*]] = phi i64 [ [[ADD:%.*]], [[FOR_BODY]] ], [ [[BC_MERGE_RDX]], [[SCALAR_PH]] ]
-; CHECK-NEXT:    [[ARRAYIDX:%.*]] = getelementptr inbounds i16, ptr [[A]], i32 [[I_011]]
+; CHECK-NEXT:    [[ARRAYIDX:%.*]] = getelementptr inbounds i16, ptr [[B]], i32 [[I_011]]
 ; CHECK-NEXT:    [[TMP9:%.*]] = load i16, ptr [[ARRAYIDX]], align 1
 ; CHECK-NEXT:    [[CONV:%.*]] = sext i16 [[TMP9]] to i32
-; CHECK-NEXT:    [[ARRAYIDX1:%.*]] = getelementptr inbounds i16, ptr [[B]], i32 [[I_011]]
+; CHECK-NEXT:    [[ARRAYIDX1:%.*]] = getelementptr inbounds i16, ptr [[B1]], i32 [[I_011]]
 ; CHECK-NEXT:    [[TMP10:%.*]] = load i16, ptr [[ARRAYIDX1]], align 2
 ; CHECK-NEXT:    [[CONV2:%.*]] = zext i16 [[TMP10]] to i32
 ; CHECK-NEXT:    [[MUL:%.*]] = mul nsw i32 [[CONV2]], [[CONV]]
@@ -1268,12 +1268,12 @@ define i32 @red_mla_u8_s8_u32(ptr noalias nocapture readonly %A, ptr noalias noc
 ; CHECK-NEXT:    [[INDEX:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[TMP7:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[ACTIVE_LANE_MASK:%.*]] = call <4 x i1> @llvm.get.active.lane.mask.v4i1.i32(i32 [[INDEX]], i32 [[N]])
-; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds i8, ptr [[A:%.*]], i32 [[INDEX]]
-; CHECK-NEXT:    [[WIDE_MASKED_LOAD:%.*]] = call <4 x i8> @llvm.masked.load.v4i8.p0(ptr [[TMP0]], i32 1, <4 x i1> [[ACTIVE_LANE_MASK]], <4 x i8> poison)
-; CHECK-NEXT:    [[TMP1:%.*]] = zext <4 x i8> [[WIDE_MASKED_LOAD]] to <4 x i32>
 ; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr inbounds i8, ptr [[B:%.*]], i32 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_MASKED_LOAD1:%.*]] = call <4 x i8> @llvm.masked.load.v4i8.p0(ptr [[TMP2]], i32 1, <4 x i1> [[ACTIVE_LANE_MASK]], <4 x i8> poison)
-; CHECK-NEXT:    [[TMP3:%.*]] = sext <4 x i8> [[WIDE_MASKED_LOAD1]] to <4 x i32>
+; CHECK-NEXT:    [[TMP1:%.*]] = zext <4 x i8> [[WIDE_MASKED_LOAD1]] to <4 x i32>
+; CHECK-NEXT:    [[TMP9:%.*]] = getelementptr inbounds i8, ptr [[B1:%.*]], i32 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_MASKED_LOAD2:%.*]] = call <4 x i8> @llvm.masked.load.v4i8.p0(ptr [[TMP9]], i32 1, <4 x i1> [[ACTIVE_LANE_MASK]], <4 x i8> poison)
+; CHECK-NEXT:    [[TMP3:%.*]] = sext <4 x i8> [[WIDE_MASKED_LOAD2]] to <4 x i32>
 ; CHECK-NEXT:    [[TMP4:%.*]] = mul nsw <4 x i32> [[TMP3]], [[TMP1]]
 ; CHECK-NEXT:    [[TMP5:%.*]] = select <4 x i1> [[ACTIVE_LANE_MASK]], <4 x i32> [[TMP4]], <4 x i32> zeroinitializer
 ; CHECK-NEXT:    [[TMP6:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP5]])
@@ -1413,7 +1413,8 @@ define i32 @mla_i8_i32_multiuse(ptr nocapture readonly %x, ptr nocapture readonl
 ; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds i8, ptr [[X:%.*]], i32 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_MASKED_LOAD:%.*]] = call <16 x i8> @llvm.masked.load.v16i8.p0(ptr [[TMP0]], i32 1, <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i8> poison)
 ; CHECK-NEXT:    [[TMP1:%.*]] = zext <16 x i8> [[WIDE_MASKED_LOAD]] to <16 x i32>
-; CHECK-NEXT:    [[TMP2:%.*]] = mul nuw nsw <16 x i32> [[TMP1]], [[TMP1]]
+; CHECK-NEXT:    [[TMP7:%.*]] = zext <16 x i8> [[WIDE_MASKED_LOAD]] to <16 x i32>
+; CHECK-NEXT:    [[TMP2:%.*]] = mul nuw nsw <16 x i32> [[TMP1]], [[TMP7]]
 ; CHECK-NEXT:    [[TMP3:%.*]] = select <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i32> [[TMP2]], <16 x i32> zeroinitializer
 ; CHECK-NEXT:    [[TMP4:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP3]])
 ; CHECK-NEXT:    [[TMP5]] = add i32 [[TMP4]], [[VEC_PHI]]
diff --git a/llvm/test/Transforms/LoopVectorize/RISCV/inloop-reduction.ll b/llvm/test/Transforms/LoopVectorize/RISCV/inloop-reduction.ll
index 8ca2bd1f286ae3..9f1a61ebb5efef 100644
--- a/llvm/test/Transforms/LoopVectorize/RISCV/inloop-reduction.ll
+++ b/llvm/test/Transforms/LoopVectorize/RISCV/inloop-reduction.ll
@@ -187,23 +187,33 @@ define i32 @add_i16_i32(ptr nocapture readonly %x, i32 %n) {
 ; IF-EVL-INLOOP-NEXT:    [[N_RND_UP:%.*]] = add i32 [[N]], [[TMP2]]
 ; IF-EVL-INLOOP-NEXT:    [[N_MOD_VF:%.*]] = urem i32 [[N_RND_UP]], [[TMP1]]
 ; IF-EVL-INLOOP-NEXT:    [[N_VEC:%.*]] = sub i32 [[N_RND_UP]], [[N_MOD_VF]]
+; IF-EVL-INLOOP-NEXT:    [[TRIP_COUNT_MINUS_1:%.*]] = sub i32 [[N]], 1
 ; IF-EVL-INLOOP-NEXT:    [[TMP3:%.*]] = call i32 @llvm.vscale.i32()
 ; IF-EVL-INLOOP-NEXT:    [[TMP4:%.*]] = mul i32 [[TMP3]], 8
+; IF-EVL-INLOOP-NEXT:    [[BROADCAST_SPLATINSERT1:%.*]] = insertelement <vscale x 8 x i32> poison, i32 [[TRIP_COUNT_MINUS_1]], i64 0
+; IF-EVL-INLOOP-NEXT:    [[BROADCAST_SPLAT2:%.*]] = shufflevector <vscale x 8 x i32> [[BROADCAST_SPLATINSERT1]], <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer
 ; IF-EVL-INLOOP-NEXT:    br label [[VECTOR_BODY:%.*]]
 ; IF-EVL-INLOOP:       vector.body:
 ; IF-EVL-INLOOP-NEXT:    [[INDEX:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
 ; IF-EVL-INLOOP-NEXT:    [[EVL_BASED_IV:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[INDEX_EVL_NEXT:%.*]], [[VECTOR_BODY]] ]
-; IF-EVL-INLOOP-NEXT:    [[VEC_PHI:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[TMP11:%.*]], [[VECTOR_BODY]] ]
-; IF-EVL-INLOOP-NEXT:    [[AVL:%.*]] = sub i32 [[N]], [[EVL_BASED_IV]]
-; IF-EVL-INLOOP-NEXT:    [[TMP5:%.*]] = call i32 @llvm.experimental.get.vector.length.i32(i32 [[AVL]], i32 8, i1 true)
-; IF-EVL-INLOOP-NEXT:    [[TMP6:%.*]] = add i32 [[EVL_BASED_IV]], 0
-; IF-EVL-INLOOP-NEXT:    [[TMP7:%.*]] = getelementptr inbounds i16, ptr [[X:%.*]], i32 [[TMP6]]
-; IF-EVL-INLOOP-NEXT:    [[TMP8:%.*]] = getelementptr inbounds i16, ptr [[TMP7]], i32 0
-; IF-EVL-INLOOP-NEXT:    [[VP_OP_LOAD:%.*]] = call <vscale x 8 x i16> @llvm.vp.load.nxv8i16.p0(ptr align 2 [[TMP8]], <vscale x 8 x i1> shufflevector (<vscale x 8 x i1> insertelement (<vscale x 8 x i1> poison, i1 true, i64 0), <vscale x 8 x i1> poison, <vscale x 8 x i32> zeroinitializer), i32 [[TMP5]])
-; IF-EVL-INLOOP-NEXT:    [[TMP9:%.*]] = sext <vscale x 8 x i16> [[VP_OP_LOAD]] to <vscale x 8 x i32>
-; IF-EVL-INLOOP-NEXT:    [[TMP10:%.*]] = call i32 @llvm.vp.reduce.add.nxv8i32(i32 0, <vscale x 8 x i32> [[TMP9]], <vscale x 8 x i1> shufflevector (<vscale x 8 x i1> insertelement (<vscale x 8 x i1> poison, i1 true, i64 0), <vscale x 8 x i1> poison, <vscale x 8 x i32> zeroinitializer), i32 [[TMP5]])
-; IF-EVL-INLOOP-NEXT:    [[TMP11]] = add i32 [[TMP10]], [[VEC_PHI]]
-; IF-EVL-INLOOP-NEXT:    [[INDEX_EVL_NEXT]] = add i32 [[TMP5]], [[EVL_BASED_IV]]
+; IF-EVL-INLOOP-NEXT:    [[VEC_PHI:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[TMP12:%.*]], [[VECTOR_BODY]] ]
+; IF-EVL-INLOOP-NEXT:    [[TMP5:%.*]] = sub i32 [[N]], [[EVL_BASED_IV]]
+; IF-EVL-INLOOP-NEXT:    [[TMP6:%.*]] = call i32 @llvm.experimental.get.vector.length.i32(i32 [[TMP5]], i32 8, i1 true)
+; IF-EVL-INLOOP-NEXT:    [[TMP7:%.*]] = add i32 [[EVL_BASED_IV]], 0
+; IF-EVL-INLOOP-NEXT:    [[BROADCAST_SPLATINSERT:%.*]] = insertelement <vscale x 8 x i32> poison, i32 [[EVL_BASED_IV]], i64 0
+; IF-EVL-INLOOP-NEXT:    [[BROADCAST_SPLAT:%.*]] = shufflevector <vscale x 8 x i32> [[BROADCAST_SPLATINSERT]], <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer
+; IF-EVL-INLOOP-NEXT:    [[TMP15:%.*]] = call <vscale x 8 x i32> @llvm.stepvector.nxv8i32()
+; IF-EVL-INLOOP-NEXT:    [[TMP16:%.*]] = add <vscale x 8 x i32> zeroinitializer, [[TMP15]]
+; IF-EVL-INLOOP-NEXT:    [[VEC_IV:%.*]] = add <vscale x 8 x i32> [[BROADCAST_SPLAT]], [[TMP16]]
+; IF-EVL-INLOOP-NEXT:    [[TMP17:%.*]] = icmp ule <vscale x 8 x i32> [[VEC_IV]], [[BROADCAST_SPLAT2]]
+; IF-EVL-INLOOP-NEXT:    [[TMP8:%.*]] = getelementptr inbounds i16, ptr [[X:%.*]], i32 [[TMP7]]
+; IF-EVL-INLOOP-NEXT:    [[TMP9:%.*]] = getelementptr inbounds i16, ptr [[TMP8]], i32 0
+; IF-EVL-INLOOP-NEXT:    [[VP_OP_LOAD:%.*]] = call <vscale x 8 x i16> @llvm.vp.load.nxv8i16.p0(ptr align 2 [[TMP9]], <vscale x 8 x i1> shufflevector (<vscale x 8 x i1> insertelement (<vscale x 8 x i1> poison, i1 true, i64 0), <vscale x 8 x i1> poison, <vscale x 8 x i32> zeroinitializer), i32 [[TMP6]])
+; IF-EVL-INLOOP-NEXT:    [[TMP10:%.*]] = sext <vscale x 8 x i16> [[VP_OP_LOAD]] to <vscale x 8 x i32>
+; IF-EVL-INLOOP-NEXT:    [[TMP18:%.*]] = select <vscale x 8 x i1> [[TMP17]], <vscale x 8 x i32> [[TMP10]], <vscale x 8 x i32> zeroinitializer
+; IF-EVL-INLOOP-NEXT:    [[TMP11:%.*]] = call i32 @llvm.vector.reduce.add.nxv8i32(<vscale x 8 x i32> [[TMP18]])
+; IF-EVL-INLOOP-NEXT:    [[TMP12]] = add i32 [[TMP11]], [[VEC_PHI]]
+; IF-EVL-INLOOP-NEXT:    [[INDEX_EVL_NEXT]] = add i32 [[TMP6]], [[EVL_BASED_IV]]
 ; IF-EVL-INLOOP-NEXT:    [[INDEX_NEXT]] = add i32 [[INDEX]], [[TMP4]]
 ; IF-EVL-INLOOP-NEXT:    [[TMP12:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]
 ; IF-EVL-INLOOP-NEXT:    br i1 [[TMP12]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
diff --git a/llvm/test/Transforms/LoopVectorize/reduction-inloop-pred.ll b/llvm/test/Transforms/LoopVectorize/reduction-inloop-pred.ll
index 6771f561913130..cb025608882e6a 100644
--- a/llvm/test/Transforms/LoopVectorize/reduction-inloop-pred.ll
+++ b/llvm/test/Transforms/LoopVectorize/reduction-inloop-pred.ll
@@ -424,62 +424,62 @@ define i32 @reduction_mix(ptr noalias nocapture %A, ptr noalias nocapture %B) {
 ; CHECK-NEXT:    [[TMP1:%.*]] = extractelement <4 x i1> [[TMP0]], i64 0
 ; CHECK-NEXT:    br i1 [[TMP1]], label [[PRED_LOAD_IF:%.*]], label [[PRED_LOAD_CONTINUE:%.*]]
 ; CHECK:       pred.load.if:
-; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr inbounds i32, ptr [[A:%.*]], i64 [[INDEX]]
-; CHECK-NEXT:    [[TMP3:%.*]] = load i32, ptr [[TMP2]], align 4
-; CHECK-NEXT:    [[TMP4:%.*]] = insertelement <4 x i32> poison, i32 [[TMP3]], i64 0
 ; CHECK-NEXT:    [[TMP5:%.*]] = getelementptr inbounds i32, ptr [[B:%.*]], i64 [[INDEX]]
 ; CHECK-NEXT:    [[TMP6:%.*]] = load i32, ptr [[TMP5]], align 4
 ; CHECK-NEXT:    [[TMP7:%.*]] = insertelement <4 x i32> poison, i32 [[TMP6]], i64 0
+; CHECK-NEXT:    [[TMP8:%.*]] = getelementptr inbounds i32, ptr [[B1:%.*]], i64 [[INDEX]]
+; CHECK-NEXT:    [[TMP12:%.*]] = load i32, ptr [[TMP8]], align 4
+; CHECK-NEXT:    [[TMP13:%.*]] = insertelement <4 x i32> poison, i32 [[TMP12]], i64 0
 ; CHECK-NEXT:    br label [[PRED_LOAD_CONTINUE]]
 ; CHECK:       pred.load.continue:
-; CHECK-NEXT:    [[TMP8:%.*]] = phi <4 x i32> [ poison, [[VECTOR_BODY]] ], [ [[TMP4]], [[PRED_LOAD_IF]] ]
 ; CHECK-NEXT:    [[TMP9:%.*]] = phi <4 x i32> [ poison, [[VECTOR_BODY]] ], [ [[TMP7]], [[PRED_LOAD_IF]] ]
+; CHECK-NEXT:    [[TMP14:%.*]] = phi <4 x i32> [ poison, [[VECTOR_BODY]] ], [ [[TMP13]], [[PRED_LOAD_IF]] ]
 ; CHECK-NEXT:    [[TMP10:%.*]] = extractelement <4 x i1> [[TMP0]], i64 1
 ; CHECK-NEXT:    br i1 [[TMP10]], label [[PRED_LOAD_IF3:%.*]], label [[PRED_LOAD_CONTINUE4:%.*]]
 ; CHECK:       pred.load.if3:
 ; CHECK-NEXT:    [[TMP11:%.*]] = or disjoint i64 [[INDEX]], 1
-; CHECK-NEXT:    [[TMP12:%.*]] = getelementptr inbounds i32, ptr [[A]], i64 [[TMP11]]
-; CHECK-NEXT:    [[TMP13:%.*]] = load i32, ptr [[TMP12]], align 4
-; CHECK-NEXT:    [[TMP14:%.*]] = insertelement <4 x i32> [[TMP8]], i32 [[TMP13]], i64 1
 ; CHECK-NEXT:    [[TMP15:%.*]] = getelementptr inbounds i32, ptr [[B]], i64 [[TMP11]]
 ; CHECK-NEXT:    [[TMP16:%.*]] = load i32, ptr [[TMP15]], align 4
 ; CHECK-NEXT:    [[TMP17:%.*]] = insertelement <4 x i32> [[TMP9]], i32 [[TMP16]], i64 1
+; CHECK-NEXT:    [[TMP18:%.*]] = getelementptr inbounds i32, ptr [[B1]], i64 [[TMP11]]
+; CHECK-NEXT:    [[TMP22:%.*]] = load i32, ptr [[TMP18]], align 4
+; CHECK-NEXT:    [[TMP23:%.*]] = insertelement <4 x i32> [[TMP14]], i32 [[TMP22]], i64 1
 ; CHECK-NEXT:    br label [[PRED_LOAD_CONTINUE4]]
 ; CHECK:       pred.load.continue4:
-; CHECK-NEXT:    [[TMP18:%.*]] = phi <4 x i32> [ [[TMP8]], [[PRED_LOAD_CONTINUE]] ], [ [[TMP14]], [[PRED_LOAD_IF3]] ]
 ; CHECK-NEXT:    [[TMP19:%.*]] = phi <4 x i32> [ [[TMP9]], [[PRED_LOAD_CONTINUE]] ], [ [[TMP17]], [[PRED_LOAD_IF3]] ]
+; CHECK-NEXT:    [[TMP24:%.*]] = phi <4 x i32> [ [[TMP14]], [[PRED_LOAD_CONTINUE]] ], [ [[TMP23]], [[PRED_LOAD_IF3]] ]
 ; CHECK-NEXT:    [[TMP20:%.*]] = extractelement <4 x i1> [[TMP0]], i64 2
 ; CHECK-NEXT:    br i1 [[TMP20]], label [[PRED_LOAD_IF5:%.*]], label [[PRED_LOAD_CONTINUE6:%.*]]
 ; CHECK:       pred.load.if5:
 ; CHECK-NEXT:    [[TMP21:%.*]] = or disjoint i64 [[INDEX]], 2
-; CHECK-NEXT:    [[TMP22:%.*]] = getelementptr inbounds i32, ptr [[A]], i64 [[TMP21]]
-; CHECK-NEXT:    [[TMP23:%.*]] = load i32, ptr [[TMP22]], align 4
-; CHECK-NEXT:    [[TMP24:%.*]] = insertelement <4 x i32> [[TMP18]], i32 [[TMP23]], i64 2
 ; CHECK-NEXT:    [[TMP25:%.*]] = getelementptr inbounds i32, ptr [[B]], i64 [[TMP21]]
 ; CHECK-NEXT:    [[TMP26:%.*]] = load i32, ptr [[TMP25]], align 4
 ; CHECK-NEXT:    [[TMP27:%.*]] = insertelement <4 x i32> [[TMP19]], i32 [[TMP26]], i64 2
+; CHECK-NEXT:    [[TMP28:%.*]] = getelementptr inbounds i32, ptr [[B1]], i64 [[TMP21]]
+; CHECK-NEXT:    [[TMP32:%.*]] = load i32, ptr [[TMP28]], align 4
+; CHECK-NEXT:    [[TMP33:%.*]] = insertelement <4 x i32> [[TMP24]], i32 [[TMP32]], i64 2
 ; CHECK-NEXT:    br label [[PRED_LOAD_CONTINUE6]]
 ; CHECK:       pred.load.continue6:
-; CHECK-NEXT:    [[TMP28:%.*]] = phi <4 x i32> [ [[TMP18]], [[PRED_LOAD_CONTINUE4]] ], [ [[TMP24]], [[PRED_LOAD_IF5]] ]
 ; CHECK-NEXT:    [[TMP29:%.*]] = phi <4 x i32> [ [[TMP19]], [[PRED_LOAD_CONTINUE4]] ], [ [[TMP27]], [[PRED_LOAD_IF5]] ]
+; CHECK-NEXT:    [[TMP34:%.*]] = phi <4 x i32> [ [[TMP24]], [[PRED_LOAD_CONTINUE4]] ], [ [[TMP33]], [[PRED_LOAD_IF5]] ]
 ; CHECK-NEXT:    [[TMP30:%.*]] = extractelement <4 x i1> [[TMP0]], i64 3
 ; CHECK-NEXT:    br i1 [[TMP30]], label [[PRED_LOAD_IF7:%.*]], label [[PRED_LOAD_CONTINUE8]]
 ; CHECK:       pred.load.if7:
 ; CHECK-NEXT:    [[TMP31:%.*]] = or disjoint i64 [[INDEX]], 3
-; CHECK-NEXT:    [[TMP32:%.*]] = getelementptr inbounds i32, ptr [[A]], i64 [[TMP31]]
-; CHECK-NEXT:    [[TMP33:%.*]] = load i32, ptr [[TMP32]], align 4
-; CHECK-NEXT:    [[TMP34:%.*]] = insertelement <4 x i32> [[TMP28]], i32 [[TMP33]], i64 3
 ; CHECK-NEXT:    [[TMP35:%.*]] = getelementptr inbounds i32, ptr [[B]], i64 [[TMP31]]
 ; CHECK-NEXT:    [[TMP36:%.*]] = load i32, ptr [[TMP35]], align 4
 ; CHECK-NEXT:    [[TMP37:%.*]] = insertelement <4 x i32> [[TMP29]], i32 [[TMP36]], i64 3
+; CHECK-NEXT:    [[TMP38:%.*]] = getelementptr inbounds i32, ptr [[B1]], i64 [[TMP31]]
+; CHECK-NEXT:    [[TMP48:%.*]] = load i32, ptr [[TMP38]], align 4
+; CHECK-NEXT:    [[TMP49:%.*]] = insertelement <4 x i32> [[TMP34]], i32 [[TMP48]], i64 3
 ; CHECK-NEXT:    br label [[PRED_LOAD_CONTINUE8]]
 ; CHECK:       pred.load.continue8:
-; CHECK-NEXT:    [[TMP38:%.*]] = phi <4 x i32> [ [[TMP28]], [[PRED_LOAD_CONTINUE6]] ], [ [[TMP34]], [[PRED_LOAD_IF7]] ]
 ; CHECK-NEXT:    [[TMP39:%.*]] = phi <4 x i32> [ [[TMP29]], [[PRED_LOAD_CONTINUE6]] ], [ [[TMP37]], [[PRED_LOAD_IF7]] ]
-; CHECK-NEXT:    [[TMP40:%.*]] = mul nsw <4 x i32> [[TMP39]], [[TMP38]]
+; CHECK-NEXT:    [[TMP50:%.*]] = phi <4 x i32> [ [[TMP34]], [[PRED_LOAD_CONTINUE6]] ], [ [[TMP49]], [[PRED_LOAD_IF7]] ]
 ; CHECK-NEXT:    [[TMP41:%.*]] = select <4 x i1> [[TMP0]], <4 x i32> [[VEC_IND1]], <4 x i32> zeroinitializer
 ; CHECK-NEXT:    [[TMP42:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP41]])
 ; CHECK-NEXT:    [[TMP43:%.*]] = add i32 [[TMP42]], [[VEC_PHI]]
+; CHECK-NEXT:    [[TMP40:%.*]] = mul nsw <4 x i32> [[TMP50]], [[TMP39]]
 ; CHECK-NEXT:    [[TMP44:%.*]] = select <4 x i1> [[TMP0]], <4 x i32> [[TMP40]], <4 x i32> zeroinitializer
 ; CHECK-NEXT:    [[TMP45:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP44]])
 ; CHECK-NEXT:    [[TMP46]] = add i32 [[TMP45]], [[TMP43]]
diff --git a/llvm/test/Transforms/LoopVectorize/reduction-inloop.ll b/llvm/test/Transforms/LoopVectorize/reduction-inloop.ll
index fe74a7c3a9b27c..b578e61d85dfa1 100644
--- a/llvm/test/Transforms/LoopVectorize/reduction-inloop.ll
+++ b/llvm/test/Transforms/LoopVectorize/reduction-inloop.ll
@@ -221,13 +221,13 @@ define i32 @reduction_mix(ptr noalias nocapture %A, ptr noalias nocapture %B) {
 ; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[TMP6:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[VEC_IND:%.*]] = phi <4 x i32> [ <i32 0, i32 1, i32 2, i32 3>, [[VECTOR_PH]] ], [ [[VEC_IND_NEXT:%.*]], [[VECTOR_BODY]] ]
-; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds i32, ptr [[A:%.*]], i64 [[INDEX]]
-; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <4 x i32>, ptr [[TMP0]], align 4
 ; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds i32, ptr [[B:%.*]], i64 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_LOAD1:%.*]] = load <4 x i32>, ptr [[TMP1]], align 4
-; CHECK-NEXT:    [[TMP2:%.*]] = mul nsw <4 x i32> [[WIDE_LOAD1]], [[WIDE_LOAD]]
+; CHECK-NEXT:    [[TMP8:%.*]] = getelementptr inbounds i32, ptr [[B1:%.*]], i64 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_LOAD2:%.*]] = load <4 x i32>, ptr [[TMP8]], align 4
 ; CHECK-NEXT:    [[TMP3:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[VEC_IND]])
 ; CHECK-NEXT:    [[TMP4:%.*]] = add i32 [[TMP3]], [[VEC_PHI]]
+; CHECK-NEXT:    [[TMP2:%.*]] = mul nsw <4 x i32> [[WIDE_LOAD2]], [[WIDE_LOAD1]]
 ; CHECK-NEXT:    [[TMP5:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP2]])
 ; CHECK-NEXT:    [[TMP6]] = add i32 [[TMP5]], [[TMP4]]
 ; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 4
@@ -329,11 +329,11 @@ define i32 @start_at_non_zero(ptr nocapture %in, ptr nocapture %coeff, ptr nocap
 ; CHECK:       vector.body:
 ; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi i32 [ 120, [[VECTOR_PH]] ], [ [[TMP4:%.*]], [[VECTOR_BODY]] ]
-; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds i32, ptr [[IN:%.*]], i64 [[INDEX]]
-; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <4 x i32>, ptr [[TMP0]], align 4
 ; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds i32, ptr [[COEFF:%.*]], i64 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_LOAD1:%.*]] = load <4 x i32>, ptr [[TMP1]], align 4
-; CHECK-NEXT:    [[TMP2:%.*]] = mul nsw <4 x i32> [[WIDE_LOAD1]], [[WIDE_LOAD]]
+; CHECK-NEXT:    [[TMP6:%.*]] = getelementptr inbounds i32, ptr [[COEFF1:%.*]], i64 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_LOAD2:%.*]] = load <4 x i32>, ptr [[TMP6]], align 4
+; CHECK-NEXT:    [[TMP2:%.*]] = mul nsw <4 x i32> [[WIDE_LOAD2]], [[WIDE_LOAD1]]
 ; CHECK-NEXT:    [[TMP3:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP2]])
 ; CHECK-NEXT:    [[TMP4]] = add i32 [[TMP3]], [[VEC_PHI]]
 ; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 4

>From 8231ac83a9a218e7f0a8b161dd9c334453fe6139 Mon Sep 17 00:00:00 2001
From: Elvis Wang <elvis.wang at sifive.com>
Date: Tue, 5 Nov 2024 17:49:26 -0800
Subject: [PATCH 5/6] Refactors

Use lambda functions to return early when a pattern is matched.
Leave some assertions.
---
 .../Transforms/Vectorize/LoopVectorize.cpp    | 121 ++++++++---------
 llvm/lib/Transforms/Vectorize/VPlan.h         |  55 ++++----
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp | 123 +-----------------
 .../Transforms/Vectorize/VPlanTransforms.cpp  |  19 +--
 .../LoopVectorize/ARM/mve-reductions.ll       |   3 +-
 .../LoopVectorize/RISCV/inloop-reduction.ll   |   8 +-
 6 files changed, 113 insertions(+), 216 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 8bdcdae09a62a6..de482bb32bcc71 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7397,7 +7397,7 @@ static bool planContainsAdditionalSimplifications(VPlan &Plan,
       // VPExtendedReductionRecipe contains a folded extend instruction.
       if (auto *ExtendedRed = dyn_cast<VPExtendedReductionRecipe>(&R))
         SeenInstrs.insert(ExtendedRed->getExtInstr());
-      // VPMulAccRecupe constians a mul and otional extend instructions.
+      // VPMulAccRecipe contains a mul and optional extend instructions.
       else if (auto *MulAcc = dyn_cast<VPMulAccRecipe>(&R)) {
         SeenInstrs.insert(MulAcc->getMulInstr());
         if (MulAcc->isExtended()) {
@@ -9388,77 +9388,82 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
       if (CM.blockNeedsPredicationForAnyReason(BB))
         CondOp = RecipeBuilder.getBlockInMask(BB);
 
-      VPValue *A, *B;
-      VPSingleDefRecipe *RedRecipe;
-      // reduce.add(mul(ext, ext)) can folded into VPMulAccRecipe
-      if (RdxDesc.getOpcode() == Instruction::Add &&
-          match(VecOp, m_Mul(m_VPValue(A), m_VPValue(B)))) {
-        VPRecipeBase *RecipeA = A->getDefiningRecipe();
-        VPRecipeBase *RecipeB = B->getDefiningRecipe();
-        if (RecipeA && RecipeB && match(RecipeA, m_ZExtOrSExt(m_VPValue())) &&
-            match(RecipeB, m_ZExtOrSExt(m_VPValue())) &&
-            cast<VPWidenCastRecipe>(RecipeA)->getOpcode() ==
-                cast<VPWidenCastRecipe>(RecipeB)->getOpcode() &&
-            !A->hasMoreThanOneUniqueUser() && !B->hasMoreThanOneUniqueUser()) {
-          RedRecipe = new VPMulAccRecipe(
-              RdxDesc, CurrentLinkI, PreviousLink, CondOp,
-              CM.useOrderedReductions(RdxDesc),
-              cast<VPWidenRecipe>(VecOp->getDefiningRecipe()),
-              cast<VPWidenCastRecipe>(RecipeA),
-              cast<VPWidenCastRecipe>(RecipeB));
-        } else {
-          RedRecipe = new VPMulAccRecipe(
-              RdxDesc, CurrentLinkI, PreviousLink, CondOp,
-              CM.useOrderedReductions(RdxDesc),
-              cast<VPWidenRecipe>(VecOp->getDefiningRecipe()));
-        }
-      } else if (RdxDesc.getOpcode() == Instruction::Add &&
-                 match(VecOp,
-                       m_ZExtOrSExt(m_Mul(m_ZExtOrSExt(m_VPValue(A)),
-                                          m_ZExtOrSExt(m_VPValue(B)))))) {
-        VPWidenCastRecipe *Ext =
-            dyn_cast<VPWidenCastRecipe>(VecOp->getDefiningRecipe());
-        VPWidenRecipe *Mul =
-            dyn_cast<VPWidenRecipe>(Ext->getOperand(0)->getDefiningRecipe());
-        if (Mul && match(Mul, m_Mul(m_ZExtOrSExt(m_VPValue()),
-                                    m_ZExtOrSExt(m_VPValue())))) {
+      auto TryToMatchMulAcc = [&]() -> VPSingleDefRecipe * {
+        VPValue *A, *B;
+        if (RdxDesc.getOpcode() != Instruction::Add)
+          return nullptr;
+        // reduce.add(mul(ext, ext)) can be folded into VPMulAccRecipe.
+        if (match(VecOp, m_Mul(m_VPValue(A), m_VPValue(B))) &&
+            !VecOp->hasMoreThanOneUniqueUser()) {
+          VPRecipeBase *RecipeA = A->getDefiningRecipe();
+          VPRecipeBase *RecipeB = B->getDefiningRecipe();
+          if (RecipeA && RecipeB && match(RecipeA, m_ZExtOrSExt(m_VPValue())) &&
+              match(RecipeB, m_ZExtOrSExt(m_VPValue())) &&
+              cast<VPWidenCastRecipe>(RecipeA)->getOpcode() ==
+                  cast<VPWidenCastRecipe>(RecipeB)->getOpcode() &&
+              !A->hasMoreThanOneUniqueUser() &&
+              !B->hasMoreThanOneUniqueUser()) {
+            return new VPMulAccRecipe(
+                RdxDesc, CurrentLinkI, PreviousLink, CondOp,
+                CM.useOrderedReductions(RdxDesc),
+                cast<VPWidenRecipe>(VecOp->getDefiningRecipe()),
+                cast<VPWidenCastRecipe>(RecipeA),
+                cast<VPWidenCastRecipe>(RecipeB));
+          } else {
+            // Matched reduce.add(mul(...))
+            return new VPMulAccRecipe(
+                RdxDesc, CurrentLinkI, PreviousLink, CondOp,
+                CM.useOrderedReductions(RdxDesc),
+                cast<VPWidenRecipe>(VecOp->getDefiningRecipe()));
+          }
+          // Matched reduce.add(ext(mul(ext, ext)))
+          // Note that all 3 extend instructions must have the same opcode.
+        } else if (match(VecOp,
+                         m_ZExtOrSExt(m_Mul(m_ZExtOrSExt(m_VPValue()),
+                                            m_ZExtOrSExt(m_VPValue())))) &&
+                   !VecOp->hasMoreThanOneUniqueUser()) {
+          VPWidenCastRecipe *Ext =
+              dyn_cast<VPWidenCastRecipe>(VecOp->getDefiningRecipe());
           VPWidenRecipe *Mul =
-              cast<VPWidenRecipe>(Ext->getOperand(0)->getDefiningRecipe());
+              dyn_cast<VPWidenRecipe>(Ext->getOperand(0)->getDefiningRecipe());
           VPWidenCastRecipe *Ext0 =
               cast<VPWidenCastRecipe>(Mul->getOperand(0)->getDefiningRecipe());
           VPWidenCastRecipe *Ext1 =
               cast<VPWidenCastRecipe>(Mul->getOperand(1)->getDefiningRecipe());
           if (Ext->getOpcode() == Ext0->getOpcode() &&
-              Ext0->getOpcode() == Ext1->getOpcode()) {
-            RedRecipe = new VPMulAccRecipe(
+              Ext0->getOpcode() == Ext1->getOpcode() &&
+              !Mul->hasMoreThanOneUniqueUser() &&
+              !Ext0->hasMoreThanOneUniqueUser() &&
+              !Ext1->hasMoreThanOneUniqueUser()) {
+            return new VPMulAccRecipe(
                 RdxDesc, CurrentLinkI, PreviousLink, CondOp,
                 CM.useOrderedReductions(RdxDesc),
                 cast<VPWidenCastRecipe>(VecOp->getDefiningRecipe()), Mul,
                 cast<VPWidenCastRecipe>(Ext0), cast<VPWidenCastRecipe>(Ext1));
-          } else
-            RedRecipe = new VPExtendedReductionRecipe(
-                RdxDesc, CurrentLinkI,
-                cast<CastInst>(
-                    cast<VPWidenCastRecipe>(VecOp)->getUnderlyingInstr()),
-                PreviousLink, cast<VPWidenCastRecipe>(VecOp)->getOperand(0),
-                CondOp, CM.useOrderedReductions(RdxDesc),
-                cast<VPWidenCastRecipe>(VecOp)->getResultType());
+          }
         }
-      }
-      // VPWidenCastRecipes can folded into VPReductionRecipe
-      else if (match(VecOp, m_ZExtOrSExt(m_VPValue(A))) &&
-               !VecOp->hasMoreThanOneUniqueUser()) {
-        RedRecipe = new VPExtendedReductionRecipe(
-            RdxDesc, CurrentLinkI,
-            cast<CastInst>(
-                cast<VPWidenCastRecipe>(VecOp)->getUnderlyingInstr()),
-            PreviousLink, A, CondOp, CM.useOrderedReductions(RdxDesc),
-            cast<VPWidenCastRecipe>(VecOp)->getResultType());
-      } else {
+        return nullptr;
+      };
+      auto TryToMatchExtendedReduction = [&]() -> VPSingleDefRecipe * {
+        VPValue *A;
+        if (match(VecOp, m_ZExtOrSExt(m_VPValue(A))) &&
+            !VecOp->hasMoreThanOneUniqueUser()) {
+          return new VPExtendedReductionRecipe(
+              RdxDesc, CurrentLinkI, PreviousLink,
+              cast<VPWidenCastRecipe>(VecOp), CondOp,
+              CM.useOrderedReductions(RdxDesc));
+        }
+        return nullptr;
+      };
+      VPSingleDefRecipe *RedRecipe;
+      if (auto *MulAcc = TryToMatchMulAcc())
+        RedRecipe = MulAcc;
+      else if (auto *ExtendedRed = TryToMatchExtendedReduction())
+        RedRecipe = ExtendedRed;
+      else
         RedRecipe =
             new VPReductionRecipe(RdxDesc, CurrentLinkI, PreviousLink, VecOp,
                                   CondOp, CM.useOrderedReductions(RdxDesc));
-      }
       // Append the recipe to the end of the VPBasicBlock because we need to
       // ensure that it comes after all of it's inputs, including CondOp.
       // Note that this transformation may leave over dead recipes (including
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 6102399e351bdd..5cf9c0385fb485 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -2670,18 +2670,19 @@ class VPExtendedReductionRecipe : public VPSingleDefRecipe {
   bool IsConditional = false;
   /// Type after extend.
   Type *ResultTy;
+  /// Opcode for the extend instruction.
   Instruction::CastOps ExtOp;
-  CastInst *CastInstr;
+  CastInst *ExtInstr;
   bool IsZExt;
 
 protected:
   VPExtendedReductionRecipe(const unsigned char SC,
                             const RecurrenceDescriptor &R, Instruction *RedI,
-                            Instruction::CastOps ExtOp, CastInst *CastI,
+                            Instruction::CastOps ExtOp, CastInst *ExtI,
                             ArrayRef<VPValue *> Operands, VPValue *CondOp,
                             bool IsOrdered, Type *ResultTy)
       : VPSingleDefRecipe(SC, Operands, RedI), RdxDesc(R), IsOrdered(IsOrdered),
-        ResultTy(ResultTy), ExtOp(ExtOp), CastInstr(CastI) {
+        ResultTy(ResultTy), ExtOp(ExtOp), ExtInstr(ExtI) {
     if (CondOp) {
       IsConditional = true;
       addOperand(CondOp);
@@ -2691,20 +2692,13 @@ class VPExtendedReductionRecipe : public VPSingleDefRecipe {
 
 public:
   VPExtendedReductionRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
-                            CastInst *CastI, VPValue *ChainOp, VPValue *VecOp,
-                            VPValue *CondOp, bool IsOrdered, Type *ResultTy)
-      : VPExtendedReductionRecipe(VPDef::VPExtendedReductionSC, R, RedI,
-                                  CastI->getOpcode(), CastI,
-                                  ArrayRef<VPValue *>({ChainOp, VecOp}), CondOp,
-                                  IsOrdered, ResultTy) {}
-
-  VPExtendedReductionRecipe(VPReductionRecipe *Red, VPWidenCastRecipe *Ext)
+                            VPValue *ChainOp, VPWidenCastRecipe *Ext,
+                            VPValue *CondOp, bool IsOrdered)
       : VPExtendedReductionRecipe(
-            VPDef::VPExtendedReductionSC, Red->getRecurrenceDescriptor(),
-            Red->getUnderlyingInstr(), Ext->getOpcode(),
+            VPDef::VPExtendedReductionSC, R, RedI, Ext->getOpcode(),
             cast<CastInst>(Ext->getUnderlyingInstr()),
-            ArrayRef<VPValue *>({Red->getChainOp(), Ext->getOperand(0)}),
-            Red->getCondOp(), Red->isOrdered(), Ext->getResultType()) {}
+            ArrayRef<VPValue *>({ChainOp, Ext->getOperand(0)}), CondOp,
+            IsOrdered, Ext->getResultType()) {}
 
   ~VPExtendedReductionRecipe() override = default;
 
@@ -2721,7 +2715,6 @@ class VPExtendedReductionRecipe : public VPSingleDefRecipe {
     return R && classof(R);
   }
 
-  /// Generate the reduction in the loop
   void execute(VPTransformState &State) override {
     llvm_unreachable("VPExtendedReductionRecipe should be transform to "
                      "VPExtendedRecipe + VPReductionRecipe before execution.");
@@ -2753,9 +2746,12 @@ class VPExtendedReductionRecipe : public VPSingleDefRecipe {
   VPValue *getCondOp() const {
     return isConditional() ? getOperand(getNumOperands() - 1) : nullptr;
   }
+  /// The type after the extend.
   Type *getResultType() const { return ResultTy; };
+  /// The Opcode of extend instruction.
   Instruction::CastOps getExtOpcode() const { return ExtOp; };
-  CastInst *getExtInstr() const { return CastInstr; };
+  /// The CastInst of the extend instruction.
+  CastInst *getExtInstr() const { return ExtInstr; };
 };
 
 /// A recipe to represent inloop MulAccreduction operations, performing a
@@ -2771,16 +2767,17 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
   bool IsConditional = false;
   /// Type after extend.
   Type *ResultType;
-  /// reduce.add(ext((mul(Ext(), Ext())))
+  // Note that all extend instructions must have the same opcode in MulAcc.
   Instruction::CastOps ExtOp;
 
+  /// reduce.add(ext(mul(ext0(), ext1())))
   Instruction *MulInstr;
   CastInst *ExtInstr = nullptr;
-  CastInst *Ext0Instr;
-  CastInst *Ext1Instr;
+  CastInst *Ext0Instr = nullptr;
+  CastInst *Ext1Instr = nullptr;
 
+  /// Is this MulAcc recipe contains extend recipes?
   bool IsExtended;
-  bool IsOuterExtended = false;
 
 protected:
   VPMulAccRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
@@ -2794,6 +2791,7 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
         ExtInstr(cast_if_present<CastInst>(ExtInstr)),
         Ext0Instr(cast<CastInst>(Ext0Instr)),
         Ext1Instr(cast<CastInst>(Ext1Instr)) {
+    assert(MulInstr->getOpcode() == Instruction::Mul);
     if (CondOp) {
       IsConditional = true;
       addOperand(CondOp);
@@ -2806,6 +2804,7 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
                  ArrayRef<VPValue *> Operands, VPValue *CondOp, bool IsOrdered)
       : VPSingleDefRecipe(SC, Operands, RedI), RdxDesc(R), IsOrdered(IsOrdered),
         MulInstr(MulInstr) {
+    assert(MulInstr->getOpcode() == Instruction::Mul);
     if (CondOp) {
       IsConditional = true;
       addOperand(CondOp);
@@ -2857,13 +2856,12 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
     return R && classof(R);
   }
 
-  /// Generate the reduction in the loop
   void execute(VPTransformState &State) override {
     llvm_unreachable("VPMulAccRecipe should transform to VPWidenCastRecipe + "
                      "VPWidenRecipe + VPReductionRecipe before execution");
   }
 
-  /// Return the cost of VPExtendedReductionRecipe.
+  /// Return the cost of VPMulAccRecipe.
   InstructionCost computeCost(ElementCount VF,
                               VPCostContext &Ctx) const override;
 
@@ -2890,13 +2888,24 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
   VPValue *getCondOp() const {
     return isConditional() ? getOperand(getNumOperands() - 1) : nullptr;
   }
+  /// Return the type after the inner extends, which must equal the type of the
+  /// mul instruction. If ResultType != the recurrence type, then there must be
+  /// an extend recipe after the mul recipe.
   Type *getResultType() const { return ResultType; };
+  /// The opcode of the extend instructions.
   Instruction::CastOps getExtOpcode() const { return ExtOp; };
+  /// The underlying instruction for VPWidenRecipe.
   Instruction *getMulInstr() const { return MulInstr; };
+  /// The underlying Instruction for outer VPWidenCastRecipe.
   CastInst *getExtInstr() const { return ExtInstr; };
+  /// The underlying Instruction for inner VPWidenCastRecipe.
   CastInst *getExt0Instr() const { return Ext0Instr; };
+  /// The underlying Instruction for inner VPWidenCastRecipe.
   CastInst *getExt1Instr() const { return Ext1Instr; };
+  /// Return if this MulAcc recipe contains extend instructions.
   bool isExtended() const { return IsExtended; };
+  /// Return true if both mul operands come from the same extend.
+  bool isSameExtend() const { return Ext0Instr == Ext1Instr; };
 };
 
 /// VPReplicateRecipe replicates a given instruction producing multiple scalar
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index f9cba35837d2ea..da3f15b996a42f 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -2207,122 +2207,6 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
         Opcode, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
   }
 
-  /*
-  using namespace llvm::VPlanPatternMatch;
-  auto GetMulAccReductionCost =
-      [&](const VPReductionRecipe *Red) -> InstructionCost {
-    VPValue *A, *B;
-    InstructionCost InnerExt0Cost = 0;
-    InstructionCost InnerExt1Cost = 0;
-    InstructionCost ExtCost = 0;
-    InstructionCost MulCost = 0;
-
-    VectorType *SrcVecTy = VectorTy;
-    Type *InnerExt0Ty;
-    Type *InnerExt1Ty;
-    Type *MaxInnerExtTy;
-    bool IsUnsigned = true;
-    bool HasOuterExt = false;
-
-    auto *Ext = dyn_cast_if_present<VPWidenCastRecipe>(
-        Red->getVecOp()->getDefiningRecipe());
-    VPRecipeBase *Mul;
-    // Try to match outer extend reduce.add(ext(...))
-    if (Ext && match(Ext, m_ZExtOrSExt(m_VPValue())) &&
-        cast<VPWidenCastRecipe>(Ext)->getNumUsers() == 1) {
-      IsUnsigned =
-          Ext->getOpcode() == Instruction::CastOps::ZExt ? true : false;
-      ExtCost = Ext->computeCost(VF, Ctx);
-      Mul = Ext->getOperand(0)->getDefiningRecipe();
-      HasOuterExt = true;
-    } else {
-      Mul = Red->getVecOp()->getDefiningRecipe();
-    }
-
-    // Match reduce.add(mul())
-    if (Mul && match(Mul, m_Mul(m_VPValue(A), m_VPValue(B))) &&
-        cast<VPWidenRecipe>(Mul)->getNumUsers() == 1) {
-      MulCost = cast<VPWidenRecipe>(Mul)->computeCost(VF, Ctx);
-      auto *InnerExt0 =
-          dyn_cast_if_present<VPWidenCastRecipe>(A->getDefiningRecipe());
-      auto *InnerExt1 =
-          dyn_cast_if_present<VPWidenCastRecipe>(B->getDefiningRecipe());
-      bool HasInnerExt = false;
-      // Try to match inner extends.
-      if (InnerExt0 && InnerExt1 &&
-          match(InnerExt0, m_ZExtOrSExt(m_VPValue())) &&
-          match(InnerExt1, m_ZExtOrSExt(m_VPValue())) &&
-          InnerExt0->getOpcode() == InnerExt1->getOpcode() &&
-          (InnerExt0->getNumUsers() > 0 &&
-           !InnerExt0->hasMoreThanOneUniqueUser()) &&
-          (InnerExt1->getNumUsers() > 0 &&
-           !InnerExt1->hasMoreThanOneUniqueUser())) {
-        InnerExt0Cost = InnerExt0->computeCost(VF, Ctx);
-        InnerExt1Cost = InnerExt1->computeCost(VF, Ctx);
-        Type *InnerExt0Ty = Ctx.Types.inferScalarType(InnerExt0->getOperand(0));
-        Type *InnerExt1Ty = Ctx.Types.inferScalarType(InnerExt1->getOperand(0));
-        Type *MaxInnerExtTy = InnerExt0Ty->getIntegerBitWidth() >
-                                      InnerExt1Ty->getIntegerBitWidth()
-                                  ? InnerExt0Ty
-                                  : InnerExt1Ty;
-        SrcVecTy = cast<VectorType>(ToVectorTy(MaxInnerExtTy, VF));
-        IsUnsigned = true;
-        HasInnerExt = true;
-      }
-      InstructionCost MulAccRedCost = Ctx.TTI.getMulAccReductionCost(
-          IsUnsigned, ElementTy, SrcVecTy, CostKind);
-      // Check if folding ext/mul into MulAccReduction is profitable.
-      if (MulAccRedCost.isValid() &&
-          MulAccRedCost <
-              ExtCost + MulCost + InnerExt0Cost + InnerExt1Cost + BaseCost) {
-        if (HasInnerExt) {
-          Ctx.FoldedRecipes[VF].insert(InnerExt0);
-          Ctx.FoldedRecipes[VF].insert(InnerExt1);
-        }
-        Ctx.FoldedRecipes[VF].insert(Mul);
-        if (HasOuterExt)
-          Ctx.FoldedRecipes[VF].insert(Ext);
-        return MulAccRedCost;
-      }
-    }
-    return InstructionCost::getInvalid();
-  };
-
-  // Match reduce(ext(...))
-  auto GetExtendedReductionCost =
-      [&](const VPReductionRecipe *Red) -> InstructionCost {
-    VPValue *VecOp = Red->getVecOp();
-    VPValue *A;
-    if (match(VecOp, m_ZExtOrSExt(m_VPValue(A))) && VecOp->getNumUsers() == 1) {
-      VPWidenCastRecipe *Ext =
-          cast<VPWidenCastRecipe>(VecOp->getDefiningRecipe());
-      bool IsUnsigned = Ext->getOpcode() == Instruction::CastOps::ZExt;
-      InstructionCost ExtCost = Ext->computeCost(VF, Ctx);
-      auto *ExtVecTy =
-          cast<VectorType>(ToVectorTy(Ctx.Types.inferScalarType(A), VF));
-      InstructionCost ExtendedRedCost = Ctx.TTI.getExtendedReductionCost(
-          Opcode, IsUnsigned, ElementTy, ExtVecTy, RdxDesc.getFastMathFlags(),
-          CostKind);
-      // Check if folding ext into ExtendedReduction is profitable.
-      if (ExtendedRedCost.isValid() && ExtendedRedCost < ExtCost + BaseCost) {
-        Ctx.FoldedRecipes[VF].insert(Ext);
-        return ExtendedRedCost;
-      }
-    }
-    return InstructionCost::getInvalid();
-  };
-
-  // Match MulAccReduction patterns.
-  InstructionCost MulAccCost = GetMulAccReductionCost(this);
-  if (MulAccCost.isValid())
-    return MulAccCost;
-
-  // Match ExtendedReduction patterns.
-  InstructionCost ExtendedCost = GetExtendedReductionCost(this);
-  if (ExtendedCost.isValid())
-    return ExtendedCost;
-  */
-
   // Default cost.
   return BaseCost;
 }
@@ -2336,9 +2220,6 @@ VPExtendedReductionRecipe::computeCost(ElementCount VF,
   TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
   unsigned Opcode = RdxDesc.getOpcode();
 
-  assert(ElementTy->getTypeID() == RdxDesc.getRecurrenceType()->getTypeID() &&
-         "Inferred type and recurrence type mismatch.");
-
   // BaseCost = Reduction cost + BinOp cost
   InstructionCost ReductionCost =
       Ctx.TTI.getArithmeticInstrCost(Opcode, ElementTy, CostKind);
@@ -2380,8 +2261,8 @@ InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
   TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
   unsigned Opcode = RdxDesc.getOpcode();
 
-  assert(ElementTy->getTypeID() == RdxDesc.getRecurrenceType()->getTypeID() &&
-         "Inferred type and recurrence type mismatch.");
+  assert(Opcode == Instruction::Add &&
+         "Reduction opcode must be add in the VPMulAccRecipe.");
 
   // BaseCost = Reduction cost + BinOp cost
   InstructionCost ReductionCost =
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index ddb3b0ae001132..8bb70f44ab6845 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -547,11 +547,15 @@ void VPlanTransforms::prepareExecute(VPlan &Plan) {
           Op0 = new VPWidenCastRecipe(
               MulAcc->getExtOpcode(), MulAcc->getVecOp0(),
               MulAcc->getResultType(), *MulAcc->getExt0Instr());
-          Op1 = new VPWidenCastRecipe(
-              MulAcc->getExtOpcode(), MulAcc->getVecOp1(),
-              MulAcc->getResultType(), *MulAcc->getExt1Instr());
           Op0->getDefiningRecipe()->insertBefore(MulAcc);
-          Op1->getDefiningRecipe()->insertBefore(MulAcc);
+          if (!MulAcc->isSameExtend()) {
+            Op1 = new VPWidenCastRecipe(
+                MulAcc->getExtOpcode(), MulAcc->getVecOp1(),
+                MulAcc->getResultType(), *MulAcc->getExt1Instr());
+            Op1->getDefiningRecipe()->insertBefore(MulAcc);
+          } else {
+            Op1 = Op0;
+          }
         } else {
           Op0 = MulAcc->getVecOp0();
           Op1 = MulAcc->getVecOp1();
@@ -561,14 +565,13 @@ void VPlanTransforms::prepareExecute(VPlan &Plan) {
         SmallVector<VPValue *, 2> MulOps = {Op0, Op1};
         auto *Mul = new VPWidenRecipe(*MulInstr,
                                       make_range(MulOps.begin(), MulOps.end()));
-        if (auto *OuterExtInstr = MulAcc->getExtInstr()) {
-          // dbgs() <<"\n!!!"<< *OuterExtInstr << " " << MulAcc->getExtOpcode()
-          // << "\n";
+        // Outer extend.
+        if (auto *OuterExtInstr = MulAcc->getExtInstr())
           VecOp = new VPWidenCastRecipe(
               MulAcc->getExtOpcode(), Mul,
               MulAcc->getRecurrenceDescriptor().getRecurrenceType(),
               *OuterExtInstr);
-        } else
+        else
           VecOp = Mul;
         auto *Red = new VPReductionRecipe(
             MulAcc->getRecurrenceDescriptor(), MulAcc->getUnderlyingInstr(),
diff --git a/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll b/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
index ea6519c8c13b93..1bee855511d128 100644
--- a/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
+++ b/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
@@ -1412,9 +1412,8 @@ define i32 @mla_i8_i32_multiuse(ptr nocapture readonly %x, ptr nocapture readonl
 ; CHECK-NEXT:    [[ACTIVE_LANE_MASK:%.*]] = call <16 x i1> @llvm.get.active.lane.mask.v16i1.i32(i32 [[INDEX]], i32 [[N]])
 ; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds i8, ptr [[X:%.*]], i32 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_MASKED_LOAD:%.*]] = call <16 x i8> @llvm.masked.load.v16i8.p0(ptr [[TMP0]], i32 1, <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i8> poison)
-; CHECK-NEXT:    [[TMP1:%.*]] = zext <16 x i8> [[WIDE_MASKED_LOAD]] to <16 x i32>
 ; CHECK-NEXT:    [[TMP7:%.*]] = zext <16 x i8> [[WIDE_MASKED_LOAD]] to <16 x i32>
-; CHECK-NEXT:    [[TMP2:%.*]] = mul nuw nsw <16 x i32> [[TMP1]], [[TMP7]]
+; CHECK-NEXT:    [[TMP2:%.*]] = mul nuw nsw <16 x i32> [[TMP7]], [[TMP7]]
 ; CHECK-NEXT:    [[TMP3:%.*]] = select <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i32> [[TMP2]], <16 x i32> zeroinitializer
 ; CHECK-NEXT:    [[TMP4:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP3]])
 ; CHECK-NEXT:    [[TMP5]] = add i32 [[TMP4]], [[VEC_PHI]]
diff --git a/llvm/test/Transforms/LoopVectorize/RISCV/inloop-reduction.ll b/llvm/test/Transforms/LoopVectorize/RISCV/inloop-reduction.ll
index 9f1a61ebb5efef..9af31c9a4762b3 100644
--- a/llvm/test/Transforms/LoopVectorize/RISCV/inloop-reduction.ll
+++ b/llvm/test/Transforms/LoopVectorize/RISCV/inloop-reduction.ll
@@ -215,13 +215,13 @@ define i32 @add_i16_i32(ptr nocapture readonly %x, i32 %n) {
 ; IF-EVL-INLOOP-NEXT:    [[TMP12]] = add i32 [[TMP11]], [[VEC_PHI]]
 ; IF-EVL-INLOOP-NEXT:    [[INDEX_EVL_NEXT]] = add i32 [[TMP6]], [[EVL_BASED_IV]]
 ; IF-EVL-INLOOP-NEXT:    [[INDEX_NEXT]] = add i32 [[INDEX]], [[TMP4]]
-; IF-EVL-INLOOP-NEXT:    [[TMP12:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]
-; IF-EVL-INLOOP-NEXT:    br i1 [[TMP12]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
+; IF-EVL-INLOOP-NEXT:    [[TMP19:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]
+; IF-EVL-INLOOP-NEXT:    br i1 [[TMP19]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
 ; IF-EVL-INLOOP:       middle.block:
 ; IF-EVL-INLOOP-NEXT:    br i1 true, label [[FOR_COND_CLEANUP_LOOPEXIT:%.*]], label [[SCALAR_PH]]
 ; IF-EVL-INLOOP:       scalar.ph:
 ; IF-EVL-INLOOP-NEXT:    [[BC_RESUME_VAL:%.*]] = phi i32 [ [[N_VEC]], [[MIDDLE_BLOCK]] ], [ 0, [[FOR_BODY_PREHEADER]] ]
-; IF-EVL-INLOOP-NEXT:    [[BC_MERGE_RDX:%.*]] = phi i32 [ [[TMP11]], [[MIDDLE_BLOCK]] ], [ 0, [[FOR_BODY_PREHEADER]] ]
+; IF-EVL-INLOOP-NEXT:    [[BC_MERGE_RDX:%.*]] = phi i32 [ [[TMP12]], [[MIDDLE_BLOCK]] ], [ 0, [[FOR_BODY_PREHEADER]] ]
 ; IF-EVL-INLOOP-NEXT:    br label [[FOR_BODY:%.*]]
 ; IF-EVL-INLOOP:       for.body:
 ; IF-EVL-INLOOP-NEXT:    [[I_08:%.*]] = phi i32 [ [[INC:%.*]], [[FOR_BODY]] ], [ [[BC_RESUME_VAL]], [[SCALAR_PH]] ]
@@ -234,7 +234,7 @@ define i32 @add_i16_i32(ptr nocapture readonly %x, i32 %n) {
 ; IF-EVL-INLOOP-NEXT:    [[EXITCOND:%.*]] = icmp eq i32 [[INC]], [[N]]
 ; IF-EVL-INLOOP-NEXT:    br i1 [[EXITCOND]], label [[FOR_COND_CLEANUP_LOOPEXIT]], label [[FOR_BODY]], !llvm.loop [[LOOP3:![0-9]+]]
 ; IF-EVL-INLOOP:       for.cond.cleanup.loopexit:
-; IF-EVL-INLOOP-NEXT:    [[ADD_LCSSA:%.*]] = phi i32 [ [[ADD]], [[FOR_BODY]] ], [ [[TMP11]], [[MIDDLE_BLOCK]] ]
+; IF-EVL-INLOOP-NEXT:    [[ADD_LCSSA:%.*]] = phi i32 [ [[ADD]], [[FOR_BODY]] ], [ [[TMP12]], [[MIDDLE_BLOCK]] ]
 ; IF-EVL-INLOOP-NEXT:    br label [[FOR_COND_CLEANUP]]
 ; IF-EVL-INLOOP:       for.cond.cleanup:
 ; IF-EVL-INLOOP-NEXT:    [[R_0_LCSSA:%.*]] = phi i32 [ 0, [[ENTRY:%.*]] ], [ [[ADD_LCSSA]], [[FOR_COND_CLEANUP_LOOPEXIT]] ]

>From bcccb133b476b4162fc1e26b6929673b77fa0c5f Mon Sep 17 00:00:00 2001
From: Elvis Wang <elvis.wang at sifive.com>
Date: Wed, 6 Nov 2024 21:07:04 -0800
Subject: [PATCH 6/6] Fix typos and update printing test

---
 .../Transforms/Vectorize/LoopVectorize.cpp    |   3 +-
 llvm/lib/Transforms/Vectorize/VPlan.h         |   7 +-
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp |  18 +-
 .../Transforms/Vectorize/VPlanTransforms.h    |   2 +-
 .../LoopVectorize/vplan-printing.ll           | 236 ++++++++++++++++++
 5 files changed, 254 insertions(+), 12 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index de482bb32bcc71..118adf61ef19a6 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7673,7 +7673,7 @@ DenseMap<const SCEV *, Value *> LoopVectorizationPlanner::executePlan(
                              ILV.getOrCreateVectorTripCount(nullptr),
                              CanonicalIVStartValue, State);
 
-  // TODO: Rebase to fhahn's implementation.
+  // TODO: Replace with upstream implementation.
   VPlanTransforms::prepareExecute(BestVPlan);
   BestVPlan.execute(&State);
 
@@ -9269,7 +9269,6 @@ VPlanPtr LoopVectorizationPlanner::buildVPlan(VFRange &Range) {
 // Adjust AnyOf reductions; replace the reduction phi for the selected value
 // with a boolean reduction phi node to check if the condition is true in any
 // iteration. The final value is selected by the final ComputeReductionResult.
-// TODO: Implement VPMulAccHere.
 void LoopVectorizationPlanner::adjustRecipesForReductions(
     VPlanPtr &Plan, VPRecipeBuilder &RecipeBuilder, ElementCount MinVF) {
   using namespace VPlanPatternMatch;
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 5cf9c0385fb485..fb02817423e9de 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -2757,8 +2757,8 @@ class VPExtendedReductionRecipe : public VPSingleDefRecipe {
 /// A recipe to represent inloop MulAccreduction operations, performing a
 /// reduction on a vector operand into a scalar value, and adding the result to
 /// a chain. This recipe is high level abstract which will generate
-/// VPReductionRecipe VPWidenRecipe(mul)and VPWidenCastRecipe before execution.
-/// The Operands are {ChainOp, VecOp1, VecOp2, [Condition]}.
+/// VPReductionRecipe, VPWidenRecipe(mul) and VPWidenCastRecipes before
+/// execution. The Operands are {ChainOp, VecOp1, VecOp2, [Condition]}.
 class VPMulAccRecipe : public VPSingleDefRecipe {
   /// The recurrence decriptor for the reduction in question.
   const RecurrenceDescriptor &RdxDesc;
@@ -2778,6 +2778,8 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
 
   /// Is this MulAcc recipe contains extend recipes?
   bool IsExtended;
+  /// Does this recipe contain an outer extend instruction?
+  bool IsOuterExtended = false;
 
 protected:
   VPMulAccRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
@@ -2797,6 +2799,7 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
       addOperand(CondOp);
     }
     IsExtended = true;
+    IsOuterExtended = ExtInstr != nullptr;
   }
 
   VPMulAccRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index da3f15b996a42f..c1acec5ee5fc53 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -270,7 +270,7 @@ InstructionCost VPRecipeBase::cost(ElementCount VF, VPCostContext &Ctx) {
     UI = &WidenMem->getIngredient();
 
   InstructionCost RecipeCost;
-  if ((UI && Ctx.skipCostComputation(UI, VF.isVector()))) {
+  if (UI && Ctx.skipCostComputation(UI, VF.isVector())) {
     RecipeCost = 0;
   } else {
     RecipeCost = computeCost(VF, Ctx);
@@ -2406,18 +2406,20 @@ void VPMulAccRecipe::print(raw_ostream &O, const Twine &Indent,
   printAsOperand(O, SlotTracker);
   O << " = ";
   getChainOp()->printAsOperand(O, SlotTracker);
-  O << " +";
+  O << " + ";
   if (isa<FPMathOperator>(getUnderlyingInstr()))
     O << getUnderlyingInstr()->getFastMathFlags();
-  O << " reduce." << Instruction::getOpcodeName(RdxDesc.getOpcode()) << " (";
-  O << " mul ";
+  if (IsOuterExtended)
+    O << " (";
+  O << "reduce." << Instruction::getOpcodeName(RdxDesc.getOpcode()) << " (";
+  O << "mul ";
   if (IsExtended)
     O << "(";
   getVecOp0()->printAsOperand(O, SlotTracker);
   if (IsExtended)
-    O << " extended to " << *getResultType() << ")";
-  if (IsExtended)
-    O << "(";
+    O << " extended to " << *getResultType() << "), (";
+  else
+    O << ", ";
   getVecOp1()->printAsOperand(O, SlotTracker);
   if (IsExtended)
     O << " extended to " << *getResultType() << ")";
@@ -2426,6 +2428,8 @@ void VPMulAccRecipe::print(raw_ostream &O, const Twine &Indent,
     getCondOp()->printAsOperand(O, SlotTracker);
   }
   O << ")";
+  if (IsOuterExtended)
+    O << " extended to " << *RdxDesc.getRecurrenceType() << ")";
   if (RdxDesc.IntermediateStore)
     O << " (with final reduction value stored in invariant address sank "
          "outside of loop)";
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.h b/llvm/lib/Transforms/Vectorize/VPlanTransforms.h
index 6310c23b605da3..40ed91104d566e 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.h
@@ -124,7 +124,7 @@ struct VPlanTransforms {
   /// Remove dead recipes from \p Plan.
   static void removeDeadRecipes(VPlan &Plan);
 
-  /// TODO: Rebase to fhahn's implementation.
+  /// TODO: Rebase to upstream implementation.
   static void prepareExecute(VPlan &Plan);
 };
 
diff --git a/llvm/test/Transforms/LoopVectorize/vplan-printing.ll b/llvm/test/Transforms/LoopVectorize/vplan-printing.ll
index 6bb20a301e0ade..a8fb374a4b162d 100644
--- a/llvm/test/Transforms/LoopVectorize/vplan-printing.ll
+++ b/llvm/test/Transforms/LoopVectorize/vplan-printing.ll
@@ -1143,6 +1143,242 @@ exit:
   ret i16 %for.1
 }
 
+define i64 @print_extended_reduction(ptr nocapture readonly %x, ptr nocapture readonly %y, i32 %n) {
+; CHECK-LABEL: 'print_extended_reduction'
+; CHECK:      VPlan 'Initial VPlan for VF={4},UF>=1' {
+; CHECK-NEXT: Live-in vp<%0> = VF * UF
+; CHECK-NEXT: Live-in vp<%1> = vector-trip-count
+; CHECK-NEXT: Live-in ir<%n> = original trip-count
+; CHECK-EMPTY:
+; CHECK-NEXT: vector.ph:
+; CHECK-NEXT: Successor(s): vector loop
+; CHECK-EMPTY:
+; CHECK-NEXT: <x1> vector loop: {
+; CHECK-NEXT:   vector.body:
+; CHECK-NEXT:     EMIT vp<%2> = CANONICAL-INDUCTION ir<0>, vp<%index.next>
+; CHECK-NEXT:     WIDEN-REDUCTION-PHI ir<%r.09> = phi ir<0>, ir<%add>
+; CHECK-NEXT:     vp<%3> = SCALAR-STEPS vp<%2>, ir<1>
+; CHECK-NEXT:     CLONE ir<%arrayidx> = getelementptr inbounds ir<%x>, vp<%3>
+; CHECK-NEXT:     vp<%4> = vector-pointer ir<%arrayidx>
+; CHECK-NEXT:     WIDEN ir<%load0> = load vp<%4>
+; CHECK-NEXT:     EXTENDED-REDUCE ir<%add> = ir<%r.09> + reduce.add (ir<%load0> extended to i64)
+; CHECK-NEXT:     EMIT vp<%index.next> = add nuw vp<%2>, vp<%0>
+; CHECK-NEXT:     EMIT branch-on-count vp<%index.next>, vp<%1>
+; CHECK-NEXT:   No successors
+; CHECK-NEXT: }
+; CHECK-NEXT: Successor(s): middle.block
+; CHECK-EMPTY:
+; CHECK-NEXT: middle.block:
+; CHECK-NEXT:   EMIT vp<%6> = compute-reduction-result ir<%r.09>, ir<%add>
+; CHECK-NEXT:   EMIT vp<%7> = extract-from-end vp<%6>, ir<1>
+; CHECK-NEXT:   EMIT vp<%cmp.n> = icmp eq ir<%n>, vp<%1>
+; CHECK-NEXT:   EMIT branch-on-cond vp<%cmp.n>
+; CHECK-NEXT: Successor(s): ir-bb<for.cond.cleanup.loopexit>, scalar.ph
+; CHECK-EMPTY:
+; CHECK-NEXT: ir-bb<for.cond.cleanup.loopexit>:
+; CHECK-NEXT:   IR   %add.lcssa = phi i64 [ %add, %for.body ] (extra operand: vp<%7>)
+; CHECK-NEXT: No successors
+; CHECK-EMPTY:
+; CHECK-NEXT: scalar.ph:
+; CHECK-NEXT:   EMIT vp<%bc.merge.rdx> = resume-phi vp<%6>, ir<0>
+; CHECK-NEXT: Successor(s): ir-bb<for.body>
+; CHECK-EMPTY:
+; CHECK-NEXT: ir-bb<for.body>:
+; CHECK-NEXT:   IR   %i.010 = phi i32 [ %inc, %for.body ], [ 0, %for.body.preheader ]
+; CHECK-NEXT:   IR   %r.09 = phi i64 [ %add, %for.body ], [ 0, %for.body.preheader ] (extra operand: vp<%bc.merge.rdx>)
+; CHECK-NEXT:   IR   %arrayidx = getelementptr inbounds i32, ptr %x, i32 %i.010
+; CHECK-NEXT:   IR   %load0 = load i32, ptr %arrayidx, align 4
+; CHECK-NEXT:   IR   %conv0 = zext i32 %load0 to i64
+; CHECK-NEXT:   IR   %add = add nsw i64 %r.09, %conv0
+; CHECK-NEXT:   IR   %inc = add nuw nsw i32 %i.010, 1
+; CHECK-NEXT:   IR   %exitcond = icmp eq i32 %inc, %n
+; CHECK-NEXT: No successors
+; CHECK-NEXT: }
+;
+entry:
+  %cmp8 = icmp sgt i32 %n, 0
+  br i1 %cmp8, label %for.body, label %for.cond.cleanup
+
+for.body:                                         ; preds = %entry, %for.body
+  %i.010 = phi i32 [ %inc, %for.body ], [ 0, %entry ]
+  %r.09 = phi i64 [ %add, %for.body ], [ 0, %entry ]
+  %arrayidx = getelementptr inbounds i32, ptr %x, i32 %i.010
+  %load0 = load i32, ptr %arrayidx, align 4
+  %conv0 = zext i32 %load0 to i64
+  %add = add nsw i64 %r.09, %conv0
+  %inc = add nuw nsw i32 %i.010, 1
+  %exitcond = icmp eq i32 %inc, %n
+  br i1 %exitcond, label %for.cond.cleanup, label %for.body
+
+for.cond.cleanup:                                 ; preds = %for.body, %entry
+  %r.0.lcssa = phi i64 [ 0, %entry ], [ %add, %for.body ]
+  ret i64 %r.0.lcssa
+}
+
+define i64 @print_mulacc(ptr nocapture readonly %x, ptr nocapture readonly %y, i32 %n) {
+; CHECK-LABEL: 'print_mulacc'
+; CHECK:      VPlan 'Initial VPlan for VF={4},UF>=1' {
+; CHECK-NEXT: Live-in vp<%0> = VF * UF
+; CHECK-NEXT: Live-in vp<%1> = vector-trip-count
+; CHECK-NEXT: Live-in ir<%n> = original trip-count
+; CHECK-EMPTY:
+; CHECK-NEXT: vector.ph:
+; CHECK-NEXT: Successor(s): vector loop
+; CHECK-EMPTY:
+; CHECK-NEXT: <x1> vector loop: {
+; CHECK-NEXT:   vector.body:
+; CHECK-NEXT:     EMIT vp<%2> = CANONICAL-INDUCTION ir<0>, vp<%index.next>
+; CHECK-NEXT:     WIDEN-REDUCTION-PHI ir<%r.09> = phi ir<0>, ir<%add>
+; CHECK-NEXT:     vp<%3> = SCALAR-STEPS vp<%2>, ir<1>
+; CHECK-NEXT:     CLONE ir<%arrayidx> = getelementptr inbounds ir<%x>, vp<%3>
+; CHECK-NEXT:     vp<%4> = vector-pointer ir<%arrayidx>
+; CHECK-NEXT:     WIDEN ir<%load0> = load vp<%4>
+; CHECK-NEXT:     CLONE ir<%arrayidx1> = getelementptr inbounds ir<%y>, vp<%3>
+; CHECK-NEXT:     vp<%5> = vector-pointer ir<%arrayidx1>
+; CHECK-NEXT:     WIDEN ir<%load1> = load vp<%5>
+; CHECK-NEXT:     MULACC-REDUCE ir<%add> = ir<%r.09> + reduce.add (mul ir<%load0>, ir<%load1>)
+; CHECK-NEXT:     EMIT vp<%index.next> = add nuw vp<%2>, vp<%0>
+; CHECK-NEXT:     EMIT branch-on-count vp<%index.next>, vp<%1>
+; CHECK-NEXT:   No successors
+; CHECK-NEXT: }
+; CHECK-NEXT: Successor(s): middle.block
+; CHECK-EMPTY:
+; CHECK-NEXT: middle.block:
+; CHECK-NEXT:   EMIT vp<%7> = compute-reduction-result ir<%r.09>, ir<%add>
+; CHECK-NEXT:   EMIT vp<%8> = extract-from-end vp<%7>, ir<1>
+; CHECK-NEXT:   EMIT vp<%cmp.n> = icmp eq ir<%n>, vp<%1>
+; CHECK-NEXT:   EMIT branch-on-cond vp<%cmp.n>
+; CHECK-NEXT: Successor(s): ir-bb<for.cond.cleanup.loopexit>, scalar.ph
+; CHECK-EMPTY:
+; CHECK-NEXT: ir-bb<for.cond.cleanup.loopexit>:
+; CHECK-NEXT:   IR   %add.lcssa = phi i64 [ %add, %for.body ] (extra operand: vp<%8>)
+; CHECK-NEXT: No successors
+; CHECK-EMPTY:
+; CHECK-NEXT: scalar.ph:
+; CHECK-NEXT:   EMIT vp<%bc.merge.rdx> = resume-phi vp<%7>, ir<0>
+; CHECK-NEXT: Successor(s): ir-bb<for.body>
+; CHECK-EMPTY:
+; CHECK-NEXT: ir-bb<for.body>:
+; CHECK-NEXT:   IR   %i.010 = phi i32 [ %inc, %for.body ], [ 0, %for.body.preheader ]
+; CHECK-NEXT:   IR   %r.09 = phi i64 [ %add, %for.body ], [ 0, %for.body.preheader ] (extra operand: vp<%bc.merge.rdx>)
+; CHECK-NEXT:   IR   %arrayidx = getelementptr inbounds i64, ptr %x, i32 %i.010
+; CHECK-NEXT:   IR   %load0 = load i64, ptr %arrayidx, align 4
+; CHECK-NEXT:   IR   %arrayidx1 = getelementptr inbounds i64, ptr %y, i32 %i.010
+; CHECK-NEXT:   IR   %load1 = load i64, ptr %arrayidx1, align 4
+; CHECK-NEXT:   IR   %mul = mul nsw i64 %load0, %load1
+; CHECK-NEXT:   IR   %add = add nsw i64 %r.09, %mul
+; CHECK-NEXT:   IR   %inc = add nuw nsw i32 %i.010, 1
+; CHECK-NEXT:   IR   %exitcond = icmp eq i32 %inc, %n
+; CHECK-NEXT: No successors
+; CHECK-NEXT: }
+;
+entry:
+  %cmp8 = icmp sgt i32 %n, 0
+  br i1 %cmp8, label %for.body, label %for.cond.cleanup
+
+for.body:                                         ; preds = %entry, %for.body
+  %i.010 = phi i32 [ %inc, %for.body ], [ 0, %entry ]
+  %r.09 = phi i64 [ %add, %for.body ], [ 0, %entry ]
+  %arrayidx = getelementptr inbounds i64, ptr %x, i32 %i.010
+  %load0 = load i64, ptr %arrayidx, align 4
+  %arrayidx1 = getelementptr inbounds i64, ptr %y, i32 %i.010
+  %load1 = load i64, ptr %arrayidx1, align 4
+  %mul = mul nsw i64 %load0, %load1
+  %add = add nsw i64 %r.09, %mul
+  %inc = add nuw nsw i32 %i.010, 1
+  %exitcond = icmp eq i32 %inc, %n
+  br i1 %exitcond, label %for.cond.cleanup, label %for.body
+
+for.cond.cleanup:                                 ; preds = %for.body, %entry
+  %r.0.lcssa = phi i64 [ 0, %entry ], [ %add, %for.body ]
+  ret i64 %r.0.lcssa
+}
+
+define i64 @print_mulacc_extended(ptr nocapture readonly %x, ptr nocapture readonly %y, i32 %n) {
+; CHECK-LABEL: 'print_mulacc_extended'
+; CHECK:      VPlan 'Initial VPlan for VF={4},UF>=1' {
+; CHECK-NEXT: Live-in vp<%0> = VF * UF
+; CHECK-NEXT: Live-in vp<%1> = vector-trip-count
+; CHECK-NEXT: Live-in ir<%n> = original trip-count
+; CHECK-EMPTY:
+; CHECK-NEXT: vector.ph:
+; CHECK-NEXT: Successor(s): vector loop
+; CHECK-EMPTY:
+; CHECK-NEXT: <x1> vector loop: {
+; CHECK-NEXT:   vector.body:
+; CHECK-NEXT:     EMIT vp<%2> = CANONICAL-INDUCTION ir<0>, vp<%index.next>
+; CHECK-NEXT:     WIDEN-REDUCTION-PHI ir<%r.09> = phi ir<0>, ir<%add>
+; CHECK-NEXT:     vp<%3> = SCALAR-STEPS vp<%2>, ir<1>
+; CHECK-NEXT:     CLONE ir<%arrayidx> = getelementptr inbounds ir<%x>, vp<%3>
+; CHECK-NEXT:     vp<%4> = vector-pointer ir<%arrayidx>
+; CHECK-NEXT:     WIDEN ir<%load0> = load vp<%4>
+; CHECK-NEXT:     CLONE ir<%arrayidx1> = getelementptr inbounds ir<%y>, vp<%3>
+; CHECK-NEXT:     vp<%5> = vector-pointer ir<%arrayidx1>
+; CHECK-NEXT:     WIDEN ir<%load1> = load vp<%5>
+; CHECK-NEXT:     MULACC-REDUCE ir<%add> = ir<%r.09> +  (reduce.add (mul (ir<%load0> extended to i32), (ir<%load1> extended to i32)) extended to i64)
+; CHECK-NEXT:     EMIT vp<%index.next> = add nuw vp<%2>, vp<%0>
+; CHECK-NEXT:     EMIT branch-on-count vp<%index.next>, vp<%1>
+; CHECK-NEXT:   No successors
+; CHECK-NEXT: }
+; CHECK-NEXT: Successor(s): middle.block
+; CHECK-EMPTY:
+; CHECK-NEXT: middle.block:
+; CHECK-NEXT:   EMIT vp<%7> = compute-reduction-result ir<%r.09>, ir<%add>
+; CHECK-NEXT:   EMIT vp<%8> = extract-from-end vp<%7>, ir<1>
+; CHECK-NEXT:   EMIT vp<%cmp.n> = icmp eq ir<%n>, vp<%1>
+; CHECK-NEXT:   EMIT branch-on-cond vp<%cmp.n>
+; CHECK-NEXT: Successor(s): ir-bb<for.cond.cleanup.loopexit>, scalar.ph
+; CHECK-EMPTY:
+; CHECK-NEXT: ir-bb<for.cond.cleanup.loopexit>:
+; CHECK-NEXT:   IR   %add.lcssa = phi i64 [ %add, %for.body ] (extra operand: vp<%8>)
+; CHECK-NEXT: No successors
+; CHECK-EMPTY:
+; CHECK-NEXT: scalar.ph:
+; CHECK-NEXT:   EMIT vp<%bc.merge.rdx> = resume-phi vp<%7>, ir<0>
+; CHECK-NEXT: Successor(s): ir-bb<for.body>
+; CHECK-EMPTY:
+; CHECK-NEXT: ir-bb<for.body>:
+; CHECK-NEXT:   IR   %i.010 = phi i32 [ %inc, %for.body ], [ 0, %for.body.preheader ]
+; CHECK-NEXT:   IR   %r.09 = phi i64 [ %add, %for.body ], [ 0, %for.body.preheader ] (extra operand: vp<%bc.merge.rdx>)
+; CHECK-NEXT:   IR   %arrayidx = getelementptr inbounds i16, ptr %x, i32 %i.010
+; CHECK-NEXT:   IR   %load0 = load i16, ptr %arrayidx, align 4
+; CHECK-NEXT:   IR   %arrayidx1 = getelementptr inbounds i16, ptr %y, i32 %i.010
+; CHECK-NEXT:   IR   %load1 = load i16, ptr %arrayidx1, align 4
+; CHECK-NEXT:   IR   %conv0 = sext i16 %load0 to i32
+; CHECK-NEXT:   IR   %conv1 = sext i16 %load1 to i32
+; CHECK-NEXT:   IR   %mul = mul nsw i32 %conv0, %conv1
+; CHECK-NEXT:   IR   %conv = sext i32 %mul to i64
+; CHECK-NEXT:   IR   %add = add nsw i64 %r.09, %conv
+; CHECK-NEXT:   IR   %inc = add nuw nsw i32 %i.010, 1
+; CHECK-NEXT:   IR   %exitcond = icmp eq i32 %inc, %n
+; CHECK-NEXT: No successors
+; CHECK-NEXT: }
+;
+entry:
+  %cmp8 = icmp sgt i32 %n, 0
+  br i1 %cmp8, label %for.body, label %for.cond.cleanup
+
+for.body:                                         ; preds = %entry, %for.body
+  %i.010 = phi i32 [ %inc, %for.body ], [ 0, %entry ]
+  %r.09 = phi i64 [ %add, %for.body ], [ 0, %entry ]
+  %arrayidx = getelementptr inbounds i16, ptr %x, i32 %i.010
+  %load0 = load i16, ptr %arrayidx, align 4
+  %arrayidx1 = getelementptr inbounds i16, ptr %y, i32 %i.010
+  %load1 = load i16, ptr %arrayidx1, align 4
+  %conv0 = sext i16 %load0 to i32
+  %conv1 = sext i16 %load1 to i32
+  %mul = mul nsw i32 %conv0, %conv1
+  %conv = sext i32 %mul to i64
+  %add = add nsw i64 %r.09, %conv
+  %inc = add nuw nsw i32 %i.010, 1
+  %exitcond = icmp eq i32 %inc, %n
+  br i1 %exitcond, label %for.cond.cleanup, label %for.body
+
+for.cond.cleanup:                                 ; preds = %for.body, %entry
+  %r.0.lcssa = phi i64 [ 0, %entry ], [ %add, %for.body ]
+  ret i64 %r.0.lcssa
+}
+
 !llvm.dbg.cu = !{!0}
 !llvm.module.flags = !{!3, !4}
 



More information about the llvm-commits mailing list