[llvm] [VPlan] Impl VPlan-based pattern match for ExtendedRed and MulAccRed (NFCI) (PR #113903)

Elvis Wang via llvm-commits llvm-commits at lists.llvm.org
Mon Nov 11 23:18:44 PST 2024


https://github.com/ElvisWang123 updated https://github.com/llvm/llvm-project/pull/113903

>From ea25db2ce086df7dd850d3ed992cee13f2295ad0 Mon Sep 17 00:00:00 2001
From: Elvis Wang <elvis.wang at sifive.com>
Date: Mon, 28 Oct 2024 05:39:35 -0700
Subject: [PATCH 1/9] [VPlan] Impl VPlan-based pattern match for ExtendedRed
 and MulAccRed. NFCI

This patch implements the VPlan-based pattern match for ExtendedReduction
and MulAccReduction. In the reduction patterns below, the extend
instructions and the mul instruction can be folded into the reduction
instruction when that is profitable, in which case they cost nothing on
their own.

We add `FoldedRecipes` to `VPCostContext` to record recipes that are
folded into other recipes.

ExtendedReductionPatterns:
    reduce(ext(...))
MulAccReductionPatterns:
    reduce.add(mul(...))
    reduce.add(mul(ext(...), ext(...)))
    reduce.add(ext(mul(...)))
    reduce.add(ext(mul(ext(...), ext(...))))

Ref: Original instruction based implementation:
https://reviews.llvm.org/D93476
---
 .../Transforms/Vectorize/LoopVectorize.cpp    |  45 ------
 llvm/lib/Transforms/Vectorize/VPlan.h         |   2 +
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp | 139 ++++++++++++++++--
 3 files changed, 129 insertions(+), 57 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 1ebc62f9843905..32fbde97bb8283 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7306,51 +7306,6 @@ LoopVectorizationPlanner::precomputeCosts(VPlan &Plan, ElementCount VF,
       Cost += ReductionCost;
       continue;
     }
-
-    const auto &ChainOps = RdxDesc.getReductionOpChain(RedPhi, OrigLoop);
-    SetVector<Instruction *> ChainOpsAndOperands(ChainOps.begin(),
-                                                 ChainOps.end());
-    auto IsZExtOrSExt = [](const unsigned Opcode) -> bool {
-      return Opcode == Instruction::ZExt || Opcode == Instruction::SExt;
-    };
-    // Also include the operands of instructions in the chain, as the cost-model
-    // may mark extends as free.
-    //
-    // For ARM, some of the instruction can folded into the reducion
-    // instruction. So we need to mark all folded instructions free.
-    // For example: We can fold reduce(mul(ext(A), ext(B))) into one
-    // instruction.
-    for (auto *ChainOp : ChainOps) {
-      for (Value *Op : ChainOp->operands()) {
-        if (auto *I = dyn_cast<Instruction>(Op)) {
-          ChainOpsAndOperands.insert(I);
-          if (I->getOpcode() == Instruction::Mul) {
-            auto *Ext0 = dyn_cast<Instruction>(I->getOperand(0));
-            auto *Ext1 = dyn_cast<Instruction>(I->getOperand(1));
-            if (Ext0 && IsZExtOrSExt(Ext0->getOpcode()) && Ext1 &&
-                Ext0->getOpcode() == Ext1->getOpcode()) {
-              ChainOpsAndOperands.insert(Ext0);
-              ChainOpsAndOperands.insert(Ext1);
-            }
-          }
-        }
-      }
-    }
-
-    // Pre-compute the cost for I, if it has a reduction pattern cost.
-    for (Instruction *I : ChainOpsAndOperands) {
-      auto ReductionCost = CM.getReductionPatternCost(
-          I, VF, ToVectorTy(I->getType(), VF), TTI::TCK_RecipThroughput);
-      if (!ReductionCost)
-        continue;
-
-      assert(!CostCtx.SkipCostComputation.contains(I) &&
-             "reduction op visited multiple times");
-      CostCtx.SkipCostComputation.insert(I);
-      LLVM_DEBUG(dbgs() << "Cost of " << ReductionCost << " for VF " << VF
-                        << ":\n in-loop reduction " << *I << "\n");
-      Cost += *ReductionCost;
-    }
   }
 
   // Pre-compute the costs for branches except for the backedge, as the number
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index abfe97b4ab55b6..dd763de1e7a1ef 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -682,6 +682,8 @@ struct VPCostContext {
   LLVMContext &LLVMCtx;
   LoopVectorizationCostModel &CM;
   SmallPtrSet<Instruction *, 8> SkipCostComputation;
+  /// Contains recipes that are folded into other recipes.
+  SmallDenseMap<ElementCount, SmallPtrSet<VPRecipeBase *, 4>, 4> FoldedRecipes;
 
   VPCostContext(const TargetTransformInfo &TTI, const TargetLibraryInfo &TLI,
                 Type *CanIVTy, LoopVectorizationCostModel &CM)
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index ef2ca9af7268d1..36898220c8e48a 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -270,7 +270,9 @@ InstructionCost VPRecipeBase::cost(ElementCount VF, VPCostContext &Ctx) {
     UI = &WidenMem->getIngredient();
 
   InstructionCost RecipeCost;
-  if (UI && Ctx.skipCostComputation(UI, VF.isVector())) {
+  if ((UI && Ctx.skipCostComputation(UI, VF.isVector())) ||
+      (Ctx.FoldedRecipes.contains(VF) &&
+       Ctx.FoldedRecipes.at(VF).contains(this))) {
     RecipeCost = 0;
   } else {
     RecipeCost = computeCost(VF, Ctx);
@@ -2187,30 +2189,143 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
   TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
   unsigned Opcode = RdxDesc.getOpcode();
 
-  // TODO: Support any-of and in-loop reductions.
+  // TODO: Support any-of reductions.
   assert(
       (!RecurrenceDescriptor::isAnyOfRecurrenceKind(RdxKind) ||
        ForceTargetInstructionCost.getNumOccurrences() > 0) &&
       "Any-of reduction not implemented in VPlan-based cost model currently.");
-  assert(
-      (!cast<VPReductionPHIRecipe>(getOperand(0))->isInLoop() ||
-       ForceTargetInstructionCost.getNumOccurrences() > 0) &&
-      "In-loop reduction not implemented in VPlan-based cost model currently.");
 
   assert(ElementTy->getTypeID() == RdxDesc.getRecurrenceType()->getTypeID() &&
          "Inferred type and recurrence type mismatch.");
 
-  // Cost = Reduction cost + BinOp cost
-  InstructionCost Cost =
+  // BaseCost = Reduction cost + BinOp cost
+  InstructionCost BaseCost =
       Ctx.TTI.getArithmeticInstrCost(Opcode, ElementTy, CostKind);
   if (RecurrenceDescriptor::isMinMaxRecurrenceKind(RdxKind)) {
     Intrinsic::ID Id = getMinMaxReductionIntrinsicOp(RdxKind);
-    return Cost + Ctx.TTI.getMinMaxReductionCost(
-                      Id, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+    BaseCost += Ctx.TTI.getMinMaxReductionCost(
+        Id, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+  } else {
+    BaseCost += Ctx.TTI.getArithmeticReductionCost(
+        Opcode, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
   }
 
-  return Cost + Ctx.TTI.getArithmeticReductionCost(
-                    Opcode, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+  using namespace llvm::VPlanPatternMatch;
+  auto GetMulAccReductionCost =
+      [&](const VPReductionRecipe *Red) -> InstructionCost {
+    // Returns the cost of the folded MulAcc reduction, or Invalid if the
+    // pattern does not match or folding is not profitable.
+    VPValue *A, *B;
+    InstructionCost InnerExt0Cost = 0, InnerExt1Cost = 0;
+    InstructionCost ExtCost = 0, MulCost = 0;
+
+    // Only integer add reductions can absorb a multiply (reduce.add(mul)).
+    if (Opcode != Instruction::Add)
+      return InstructionCost::getInvalid();
+    VectorType *SrcVecTy = VectorTy;
+    bool IsUnsigned = true;
+    bool HasOuterExt = false;
+
+    auto *Ext = dyn_cast_if_present<VPWidenCastRecipe>(
+        Red->getVecOp()->getDefiningRecipe());
+    VPRecipeBase *Mul;
+    // Try to match outer extend reduce.add(ext(...))
+    if (Ext && match(Ext, m_ZExtOrSExt(m_VPValue())) &&
+        Ext->getNumUsers() == 1) {
+      // Signedness of the outer extend; inner extends override it below.
+      IsUnsigned = Ext->getOpcode() == Instruction::CastOps::ZExt;
+      ExtCost = Ext->computeCost(VF, Ctx);
+      Mul = Ext->getOperand(0)->getDefiningRecipe();
+      HasOuterExt = true;
+    } else {
+      Mul = Red->getVecOp()->getDefiningRecipe();
+    }
+
+    // Match reduce.add(mul())
+    if (Mul && match(Mul, m_Mul(m_VPValue(A), m_VPValue(B))) &&
+        cast<VPWidenRecipe>(Mul)->getNumUsers() == 1) {
+      MulCost = cast<VPWidenRecipe>(Mul)->computeCost(VF, Ctx);
+      auto *InnerExt0 =
+          dyn_cast_if_present<VPWidenCastRecipe>(A->getDefiningRecipe());
+      auto *InnerExt1 =
+          dyn_cast_if_present<VPWidenCastRecipe>(B->getDefiningRecipe());
+      bool HasInnerExt = false;
+      // Try to match inner extends.
+      if (InnerExt0 && InnerExt1 &&
+          match(InnerExt0, m_ZExtOrSExt(m_VPValue())) &&
+          match(InnerExt1, m_ZExtOrSExt(m_VPValue())) &&
+          InnerExt0->getOpcode() == InnerExt1->getOpcode() &&
+          (InnerExt0->getNumUsers() > 0 &&
+           !InnerExt0->hasMoreThanOneUniqueUser()) &&
+          (InnerExt1->getNumUsers() > 0 &&
+           !InnerExt1->hasMoreThanOneUniqueUser())) {
+        InnerExt0Cost = InnerExt0->computeCost(VF, Ctx);
+        InnerExt1Cost = InnerExt1->computeCost(VF, Ctx);
+        Type *InnerExt0Ty = Ctx.Types.inferScalarType(InnerExt0->getOperand(0));
+        Type *InnerExt1Ty = Ctx.Types.inferScalarType(InnerExt1->getOperand(0));
+        Type *MaxInnerExtTy = InnerExt0Ty->getIntegerBitWidth() >
+                                      InnerExt1Ty->getIntegerBitWidth()
+                                  ? InnerExt0Ty
+                                  : InnerExt1Ty;
+        SrcVecTy = cast<VectorType>(ToVectorTy(MaxInnerExtTy, VF));
+        IsUnsigned = InnerExt0->getOpcode() == Instruction::CastOps::ZExt;
+        HasInnerExt = true;
+      }
+      InstructionCost MulAccRedCost = Ctx.TTI.getMulAccReductionCost(
+          IsUnsigned, ElementTy, SrcVecTy, CostKind);
+      // Check if folding ext/mul into MulAccReduction is profitable.
+      if (MulAccRedCost.isValid() &&
+          MulAccRedCost <
+              ExtCost + MulCost + InnerExt0Cost + InnerExt1Cost + BaseCost) {
+        if (HasInnerExt) {
+          Ctx.FoldedRecipes[VF].insert(InnerExt0);
+          Ctx.FoldedRecipes[VF].insert(InnerExt1);
+        }
+        Ctx.FoldedRecipes[VF].insert(Mul);
+        if (HasOuterExt)
+          Ctx.FoldedRecipes[VF].insert(Ext);
+        return MulAccRedCost;
+      }
+    }
+    return InstructionCost::getInvalid();
+  };
+
+  // Cost of reduce(ext(A)) as one extended reduction; Invalid otherwise.
+  auto GetExtendedReductionCost =
+      [&](const VPReductionRecipe *Red) -> InstructionCost {
+    VPValue *VecOp = Red->getVecOp();
+    VPValue *A;
+    if (match(VecOp, m_ZExtOrSExt(m_VPValue(A))) && VecOp->getNumUsers() == 1) {
+      VPWidenCastRecipe *Ext =
+          cast<VPWidenCastRecipe>(VecOp->getDefiningRecipe());
+      bool IsUnsigned = Ext->getOpcode() == Instruction::CastOps::ZExt;
+      InstructionCost ExtCost = Ext->computeCost(VF, Ctx);
+      auto *ExtVecTy =
+          cast<VectorType>(ToVectorTy(Ctx.Types.inferScalarType(A), VF));
+      InstructionCost ExtendedRedCost = Ctx.TTI.getExtendedReductionCost(
+          Opcode, IsUnsigned, ElementTy, ExtVecTy, RdxDesc.getFastMathFlags(),
+          CostKind);
+      // Check if folding ext into ExtendedReduction is profitable.
+      if (ExtendedRedCost.isValid() && ExtendedRedCost < ExtCost + BaseCost) {
+        Ctx.FoldedRecipes[VF].insert(Ext);
+        return ExtendedRedCost;
+      }
+    }
+    return InstructionCost::getInvalid();
+  };
+
+  // Match MulAccReduction patterns.
+  InstructionCost MulAccCost = GetMulAccReductionCost(this);
+  if (MulAccCost.isValid())
+    return MulAccCost;
+
+  // Match ExtendedReduction patterns.
+  InstructionCost ExtendedCost = GetExtendedReductionCost(this);
+  if (ExtendedCost.isValid())
+    return ExtendedCost;
+
+  // Default cost.
+  return BaseCost;
 }
 
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)

>From bf92b2d9d9e87ae4e646f8b224fd60bc3931c883 Mon Sep 17 00:00:00 2001
From: Elvis Wang <elvis.wang at sifive.com>
Date: Sun, 3 Nov 2024 18:55:55 -0800
Subject: [PATCH 2/9] Partially support Extended-reduction.

---
 .../Transforms/Vectorize/LoopVectorize.cpp    |  24 +-
 llvm/lib/Transforms/Vectorize/VPlan.h         | 217 ++++++++++++++++++
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp | 114 +++++++--
 .../Transforms/Vectorize/VPlanTransforms.cpp  |  24 ++
 .../Transforms/Vectorize/VPlanTransforms.h    |   3 +
 llvm/lib/Transforms/Vectorize/VPlanValue.h    |   2 +
 6 files changed, 359 insertions(+), 25 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 32fbde97bb8283..29df6a52fa98e7 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7662,6 +7662,10 @@ DenseMap<const SCEV *, Value *> LoopVectorizationPlanner::executePlan(
                              ILV.getOrCreateVectorTripCount(nullptr),
                              CanonicalIVStartValue, State);
 
+  // TODO: Rebase to fhahn's implementation.
+  VPlanTransforms::prepareExecute(BestVPlan);
+  LLVM_DEBUG(dbgs() << "\n\n print plan\n"; // Only dump under -debug.
+             BestVPlan.print(dbgs()));
   BestVPlan.execute(&State);
 
   // 2.5 Collect reduction resume values.
@@ -9256,6 +9260,7 @@ VPlanPtr LoopVectorizationPlanner::buildVPlan(VFRange &Range) {
 // Adjust AnyOf reductions; replace the reduction phi for the selected value
 // with a boolean reduction phi node to check if the condition is true in any
 // iteration. The final value is selected by the final ComputeReductionResult.
+// TODO: Implement VPMulAccHere.
 void LoopVectorizationPlanner::adjustRecipesForReductions(
     VPlanPtr &Plan, VPRecipeBuilder &RecipeBuilder, ElementCount MinVF) {
   using namespace VPlanPatternMatch;
@@ -9374,9 +9379,22 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
       if (CM.blockNeedsPredicationForAnyReason(BB))
         CondOp = RecipeBuilder.getBlockInMask(BB);
 
-      VPReductionRecipe *RedRecipe =
-          new VPReductionRecipe(RdxDesc, CurrentLinkI, PreviousLink, VecOp,
-                                CondOp, CM.useOrderedReductions(RdxDesc));
+      // VPWidenCastRecipes can folded into VPReductionRecipe
+      VPValue *A;
+      VPSingleDefRecipe *RedRecipe;
+      if (match(VecOp, m_ZExtOrSExt(m_VPValue(A))) &&
+          !VecOp->hasMoreThanOneUniqueUser()) {
+        RedRecipe = new VPExtendedReductionRecipe(
+            RdxDesc, CurrentLinkI,
+            cast<CastInst>(
+                cast<VPWidenCastRecipe>(VecOp)->getUnderlyingInstr()),
+            PreviousLink, A, CondOp, CM.useOrderedReductions(RdxDesc),
+            cast<VPWidenCastRecipe>(VecOp)->getResultType());
+      } else {
+        RedRecipe =
+            new VPReductionRecipe(RdxDesc, CurrentLinkI, PreviousLink, VecOp,
+                                  CondOp, CM.useOrderedReductions(RdxDesc));
+      }
       // Append the recipe to the end of the VPBasicBlock because we need to
       // ensure that it comes after all of it's inputs, including CondOp.
       // Note that this transformation may leave over dead recipes (including
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index dd763de1e7a1ef..7a05145514d0e7 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -859,6 +859,8 @@ class VPSingleDefRecipe : public VPRecipeBase, public VPValue {
     case VPRecipeBase::VPInstructionSC:
     case VPRecipeBase::VPReductionEVLSC:
     case VPRecipeBase::VPReductionSC:
+    case VPRecipeBase::VPMulAccSC:
+    case VPRecipeBase::VPExtendedReductionSC:
     case VPRecipeBase::VPReplicateSC:
     case VPRecipeBase::VPScalarIVStepsSC:
     case VPRecipeBase::VPVectorPointerSC:
@@ -2655,6 +2657,221 @@ class VPReductionEVLRecipe : public VPReductionRecipe {
   }
 };
 
+/// A recipe to represent in-loop extended reduction operations, performing a
+/// reduction on an extended vector operand into a scalar value, and adding
+/// the result to a chain. This high-level recipe is lowered to a
+/// VPWidenCastRecipe and a VPReductionRecipe before execution. The operands
+/// are {ChainOp, VecOp, [Condition]}.
+class VPExtendedReductionRecipe : public VPSingleDefRecipe {
+  /// The recurrence descriptor for the reduction in question.
+  const RecurrenceDescriptor &RdxDesc;
+  bool IsOrdered; ///< True if the in-loop reduction is ordered.
+  /// Whether the reduction is conditional.
+  bool IsConditional = false;
+  /// Type after extend.
+  Type *ResultTy;
+  Instruction::CastOps ExtOp; ///< Opcode of the folded extend.
+  CastInst *CastInstr;        ///< IR cast underlying the folded extend.
+  bool IsZExt;                ///< True if ExtOp is ZExt.
+
+protected:
+  VPExtendedReductionRecipe(const unsigned char SC,
+                            const RecurrenceDescriptor &R, Instruction *RedI,
+                            Instruction::CastOps ExtOp, CastInst *CastI,
+                            ArrayRef<VPValue *> Operands, VPValue *CondOp,
+                            bool IsOrdered, Type *ResultTy)
+      : VPSingleDefRecipe(SC, Operands, RedI), RdxDesc(R), IsOrdered(IsOrdered),
+        ResultTy(ResultTy), ExtOp(ExtOp), CastInstr(CastI) {
+    if (CondOp) {
+      IsConditional = true;
+      addOperand(CondOp);
+    }
+    IsZExt = ExtOp == Instruction::CastOps::ZExt;
+  }
+
+public:
+  VPExtendedReductionRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
+                            CastInst *CastI, VPValue *ChainOp, VPValue *VecOp,
+                            VPValue *CondOp, bool IsOrdered, Type *ResultTy)
+      : VPExtendedReductionRecipe(VPDef::VPExtendedReductionSC, R, RedI,
+                                  CastI->getOpcode(), CastI,
+                                  ArrayRef<VPValue *>({ChainOp, VecOp}), CondOp,
+                                  IsOrdered, ResultTy) {}
+
+  VPExtendedReductionRecipe(VPReductionRecipe *Red, VPWidenCastRecipe *Ext)
+      : VPExtendedReductionRecipe(
+            VPDef::VPExtendedReductionSC, Red->getRecurrenceDescriptor(),
+            Red->getUnderlyingInstr(), Ext->getOpcode(),
+            cast<CastInst>(Ext->getUnderlyingInstr()),
+            ArrayRef<VPValue *>({Red->getChainOp(), Ext->getOperand(0)}),
+            Red->getCondOp(), Red->isOrdered(), Ext->getResultType()) {}
+
+  ~VPExtendedReductionRecipe() override = default;
+
+  VPExtendedReductionRecipe *clone() override {
+    llvm_unreachable("Not implement yet");
+  }
+
+  static inline bool classof(const VPRecipeBase *R) {
+    return R->getVPDefID() == VPDef::VPExtendedReductionSC;
+  }
+
+  static inline bool classof(const VPUser *U) {
+    auto *R = dyn_cast<VPRecipeBase>(U);
+    return R && classof(R);
+  }
+
+  /// Not executable: this recipe must be lowered before VPlan execution.
+  void execute(VPTransformState &State) override {
+    llvm_unreachable("VPExtendedReductionRecipe should be transformed to "
+                     "VPWidenCastRecipe + VPReductionRecipe before execution.");
+  };
+
+  /// Return the cost of VPExtendedReductionRecipe.
+  InstructionCost computeCost(ElementCount VF,
+                              VPCostContext &Ctx) const override;
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+  /// Print the recipe.
+  void print(raw_ostream &O, const Twine &Indent,
+             VPSlotTracker &SlotTracker) const override;
+#endif
+
+  /// Return the recurrence descriptor for the in-loop reduction.
+  const RecurrenceDescriptor &getRecurrenceDescriptor() const {
+    return RdxDesc;
+  }
+  /// Return true if the in-loop reduction is ordered.
+  bool isOrdered() const { return IsOrdered; };
+  /// Return true if the in-loop reduction is conditional.
+  bool isConditional() const { return IsConditional; };
+  /// The VPValue of the scalar Chain being accumulated.
+  VPValue *getChainOp() const { return getOperand(0); }
+  /// The VPValue of the vector value to be extended and reduced.
+  VPValue *getVecOp() const { return getOperand(1); }
+  /// The VPValue of the condition for the block.
+  VPValue *getCondOp() const {
+    return isConditional() ? getOperand(getNumOperands() - 1) : nullptr;
+  }
+  Type *getResultType() const { return ResultTy; }; ///< Type after extend.
+  Instruction::CastOps getExtOpcode() const { return ExtOp; };
+  CastInst *getExtInstr() const { return CastInstr; }; ///< Original IR cast.
+};
+
+/// A recipe to represent in-loop MulAcc reduction operations, performing a
+/// reduction on the product of two (extended) vector operands into a scalar
+/// value and adding the result to a chain. This high-level recipe will
+/// generate VPReductionRecipe, VPWidenRecipe (mul) and VPWidenCastRecipes
+/// before execution. The operands are {ChainOp, VecOp0, VecOp1, [Condition]}.
+class VPMulAccRecipe : public VPSingleDefRecipe {
+  /// The recurrence descriptor for the reduction in question.
+  const RecurrenceDescriptor &RdxDesc;
+  bool IsOrdered; ///< True if the in-loop reduction is ordered.
+  /// Whether the reduction is conditional.
+  bool IsConditional = false;
+  /// Type after the outer extend.
+  Type *ResultTy;
+  /// Type for mul (result type of the inner extends).
+  Type *MulTy;
+  /// Extend opcodes in reduce.add(OuterExt(mul(InnerExt(), InnerExt()))).
+  Instruction::CastOps OuterExtOp;
+  Instruction::CastOps InnerExtOp;
+
+  Instruction *MulI;       ///< Underlying IR mul.
+  Instruction *OuterExtI;  ///< Underlying IR outer extend.
+  Instruction *InnerExt0I; ///< Underlying IR extend of the first multiplicand.
+  Instruction *InnerExt1I; ///< Underlying IR extend of the second multiplicand.
+
+protected:
+  VPMulAccRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
+                 Instruction *RedI, Instruction::CastOps OuterExtOp,
+                 Instruction *OuterExtI, Instruction *MulI,
+                 Instruction::CastOps InnerExtOp, Instruction *InnerExt0I,
+                 Instruction *InnerExt1I, ArrayRef<VPValue *> Operands,
+                 VPValue *CondOp, bool IsOrdered, Type *ResultTy, Type *MulTy)
+      : VPSingleDefRecipe(SC, Operands, RedI), RdxDesc(R), IsOrdered(IsOrdered),
+        ResultTy(ResultTy), MulTy(MulTy), OuterExtOp(OuterExtOp),
+        InnerExtOp(InnerExtOp), MulI(MulI), OuterExtI(OuterExtI),
+        InnerExt0I(InnerExt0I), InnerExt1I(InnerExt1I) {
+    if (CondOp) {
+      IsConditional = true;
+      addOperand(CondOp);
+    }
+  }
+
+public: // NOTE(review): ctors below assume non-null OuterExt and InnerExt*.
+  VPMulAccRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
+                 Instruction *OuterExt, Instruction *Mul,
+                 Instruction *InnerExt0, Instruction *InnerExt1,
+                 VPValue *ChainOp, VPValue *InnerExt0Op, VPValue *InnerExt1Op,
+                 VPValue *CondOp, bool IsOrdered, Type *ResultTy, Type *MulTy)
+      : VPMulAccRecipe(
+            VPDef::VPMulAccSC, R, RedI, cast<CastInst>(OuterExt)->getOpcode(),
+            OuterExt, Mul, cast<CastInst>(InnerExt0)->getOpcode(), InnerExt0,
+            InnerExt1, ArrayRef<VPValue *>({ChainOp, InnerExt0Op, InnerExt1Op}),
+            CondOp, IsOrdered, ResultTy, MulTy) {}
+
+  VPMulAccRecipe(VPReductionRecipe *Red, VPWidenCastRecipe *OuterExt,
+                 VPWidenRecipe *Mul, VPWidenCastRecipe *InnerExt0,
+                 VPWidenCastRecipe *InnerExt1)
+      : VPMulAccRecipe(
+            VPDef::VPMulAccSC, Red->getRecurrenceDescriptor(),
+            Red->getUnderlyingInstr(), OuterExt->getOpcode(),
+            OuterExt->getUnderlyingInstr(), Mul->getUnderlyingInstr(),
+            InnerExt0->getOpcode(), InnerExt0->getUnderlyingInstr(),
+            InnerExt1->getUnderlyingInstr(),
+            ArrayRef<VPValue *>({Red->getChainOp(), InnerExt0->getOperand(0),
+                                 InnerExt1->getOperand(0)}),
+            Red->getCondOp(), Red->isOrdered(), OuterExt->getResultType(),
+            InnerExt0->getResultType()) {}
+
+  ~VPMulAccRecipe() override = default;
+
+  VPMulAccRecipe *clone() override { llvm_unreachable("Not implement yet"); }
+
+  static inline bool classof(const VPRecipeBase *R) {
+    return R->getVPDefID() == VPRecipeBase::VPMulAccSC;
+  }
+
+  static inline bool classof(const VPUser *U) {
+    auto *R = dyn_cast<VPRecipeBase>(U);
+    return R && classof(R);
+  }
+
+  /// Generate the reduction in the loop
+  void execute(VPTransformState &State) override;
+
+  /// Return the cost of VPMulAccRecipe.
+  InstructionCost computeCost(ElementCount VF,
+                              VPCostContext &Ctx) const override;
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+  /// Print the recipe.
+  void print(raw_ostream &O, const Twine &Indent,
+             VPSlotTracker &SlotTracker) const override;
+#endif
+
+  /// Return the recurrence descriptor for the in-loop reduction.
+  const RecurrenceDescriptor &getRecurrenceDescriptor() const {
+    return RdxDesc;
+  }
+  /// Return true if the in-loop reduction is ordered.
+  bool isOrdered() const { return IsOrdered; };
+  /// Return true if the in-loop reduction is conditional.
+  bool isConditional() const { return IsConditional; };
+  /// The VPValue of the scalar Chain being accumulated.
+  VPValue *getChainOp() const { return getOperand(0); }
+  /// The first multiplicand (source of InnerExt0); operand 2 is the second.
+  VPValue *getVecOp() const { return getOperand(1); }
+  /// The VPValue of the condition for the block.
+  VPValue *getCondOp() const {
+    return isConditional() ? getOperand(getNumOperands() - 1) : nullptr;
+  }
+  Type *getResultTy() const { return ResultTy; }; ///< Type after outer extend.
+  Instruction::CastOps getOuterExtOpcode() const { return OuterExtOp; };
+  Instruction::CastOps getInnerExtOpcode() const { return InnerExtOp; };
+};
+
 /// VPReplicateRecipe replicates a given instruction producing multiple scalar
 /// copies of the original scalar type, one per lane, instead of producing a
 /// single copy of widened type for all lanes. If the instruction is known to be
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 36898220c8e48a..a7858a729c2c92 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -1490,6 +1490,27 @@ void VPWidenCastRecipe::execute(VPTransformState &State) {
     setFlags(CastOp);
 }
 
+// Returns the CastContextHint for R; R must be non-null for vector VF.
+static TTI::CastContextHint computeCCH(const VPRecipeBase *R, ElementCount VF) {
+  if (VF.isScalar())
+    return TTI::CastContextHint::Normal;
+  if (isa<VPInterleaveRecipe>(R))
+    return TTI::CastContextHint::Interleave;
+  if (const auto *ReplicateRecipe = dyn_cast<VPReplicateRecipe>(R))
+    return ReplicateRecipe->isPredicated() ? TTI::CastContextHint::Masked
+                                           : TTI::CastContextHint::Normal;
+  const auto *WidenMemoryRecipe = dyn_cast<VPWidenMemoryRecipe>(R);
+  if (WidenMemoryRecipe == nullptr)
+    return TTI::CastContextHint::None; // Not a memory recipe handled above.
+  if (!WidenMemoryRecipe->isConsecutive())
+    return TTI::CastContextHint::GatherScatter;
+  if (WidenMemoryRecipe->isReverse())
+    return TTI::CastContextHint::Reversed;
+  if (WidenMemoryRecipe->isMasked())
+    return TTI::CastContextHint::Masked;
+  return TTI::CastContextHint::Normal;
+}
+
 InstructionCost VPWidenCastRecipe::computeCost(ElementCount VF,
                                                VPCostContext &Ctx) const {
   // TODO: In some cases, VPWidenCastRecipes are created but not considered in
@@ -1497,26 +1518,6 @@ InstructionCost VPWidenCastRecipe::computeCost(ElementCount VF,
   // reduction in a smaller type.
   if (!getUnderlyingValue())
     return 0;
-  // Computes the CastContextHint from a recipes that may access memory.
-  auto ComputeCCH = [&](const VPRecipeBase *R) -> TTI::CastContextHint {
-    if (VF.isScalar())
-      return TTI::CastContextHint::Normal;
-    if (isa<VPInterleaveRecipe>(R))
-      return TTI::CastContextHint::Interleave;
-    if (const auto *ReplicateRecipe = dyn_cast<VPReplicateRecipe>(R))
-      return ReplicateRecipe->isPredicated() ? TTI::CastContextHint::Masked
-                                             : TTI::CastContextHint::Normal;
-    const auto *WidenMemoryRecipe = dyn_cast<VPWidenMemoryRecipe>(R);
-    if (WidenMemoryRecipe == nullptr)
-      return TTI::CastContextHint::None;
-    if (!WidenMemoryRecipe->isConsecutive())
-      return TTI::CastContextHint::GatherScatter;
-    if (WidenMemoryRecipe->isReverse())
-      return TTI::CastContextHint::Reversed;
-    if (WidenMemoryRecipe->isMasked())
-      return TTI::CastContextHint::Masked;
-    return TTI::CastContextHint::Normal;
-  };
 
   VPValue *Operand = getOperand(0);
   TTI::CastContextHint CCH = TTI::CastContextHint::None;
@@ -1524,7 +1525,7 @@ InstructionCost VPWidenCastRecipe::computeCost(ElementCount VF,
   if ((Opcode == Instruction::Trunc || Opcode == Instruction::FPTrunc) &&
       !hasMoreThanOneUniqueUser() && getNumUsers() > 0) {
     if (auto *StoreRecipe = dyn_cast<VPRecipeBase>(*user_begin()))
-      CCH = ComputeCCH(StoreRecipe);
+      CCH = computeCCH(StoreRecipe, VF);
   }
   // For Z/Sext, get the context from the operand.
   else if (Opcode == Instruction::ZExt || Opcode == Instruction::SExt ||
@@ -1532,7 +1533,7 @@ InstructionCost VPWidenCastRecipe::computeCost(ElementCount VF,
     if (Operand->isLiveIn())
       CCH = TTI::CastContextHint::Normal;
     else if (Operand->getDefiningRecipe())
-      CCH = ComputeCCH(Operand->getDefiningRecipe());
+      CCH = computeCCH(Operand->getDefiningRecipe(), VF);
   }
 
   auto *SrcTy =
@@ -2210,6 +2211,7 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
         Opcode, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
   }
 
+  /*
   using namespace llvm::VPlanPatternMatch;
   auto GetMulAccReductionCost =
       [&](const VPReductionRecipe *Red) -> InstructionCost {
@@ -2323,11 +2325,57 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
   InstructionCost ExtendedCost = GetExtendedReductionCost(this);
   if (ExtendedCost.isValid())
     return ExtendedCost;
+  */
 
   // Default cost.
   return BaseCost;
 }
 
+InstructionCost
+VPExtendedReductionRecipe::computeCost(ElementCount VF,
+                                       VPCostContext &Ctx) const {
+  RecurKind RdxKind = RdxDesc.getRecurrenceKind();
+  Type *ElementTy = getResultType();
+  auto *VectorTy = cast<VectorType>(ToVectorTy(ElementTy, VF));
+  TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+  unsigned Opcode = RdxDesc.getOpcode();
+
+  assert(ElementTy->getTypeID() == RdxDesc.getRecurrenceType()->getTypeID() &&
+         "Inferred type and recurrence type mismatch.");
+
+  // ReductionCost = Reduction cost + BinOp cost, on the extended type.
+  InstructionCost ReductionCost =
+      Ctx.TTI.getArithmeticInstrCost(Opcode, ElementTy, CostKind);
+  if (RecurrenceDescriptor::isMinMaxRecurrenceKind(RdxKind)) {
+    Intrinsic::ID Id = getMinMaxReductionIntrinsicOp(RdxKind);
+    ReductionCost += Ctx.TTI.getMinMaxReductionCost(
+        Id, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+  } else {
+    ReductionCost += Ctx.TTI.getArithmeticReductionCost(
+        Opcode, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+  }
+
+  // Cost of the unfolded extend: ExtOp from SrcTy to DestTy.
+  auto *SrcTy =
+      cast<VectorType>(ToVectorTy(Ctx.Types.inferScalarType(getVecOp()), VF));
+  auto *DestTy = cast<VectorType>(ToVectorTy(getResultType(), VF));
+  const VPRecipeBase *SrcR = getVecOp()->getDefiningRecipe(); // Null: live-in.
+  auto CCH = SrcR ? computeCCH(SrcR, VF) : TTI::CastContextHint::Normal;
+  InstructionCost ExtendedCost = Ctx.TTI.getCastInstrCost(
+      ExtOp, DestTy, SrcTy, CCH, CostKind,
+      CastInstr); // Arm TTI inspects the original IR cast instruction.
+
+  // Cost of the single extended-reduction operation, if the target has one.
+  InstructionCost ExtendedRedCost = Ctx.TTI.getExtendedReductionCost(
+      Opcode, IsZExt, ElementTy, SrcTy, RdxDesc.getFastMathFlags(), CostKind);
+  // Return the fused cost only if it is valid and strictly cheaper.
+  if (ExtendedRedCost.isValid() &&
+      ExtendedRedCost < ExtendedCost + ReductionCost) {
+    return ExtendedRedCost;
+  }
+  return ExtendedCost + ReductionCost;
+}
+
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
 void VPReductionRecipe::print(raw_ostream &O, const Twine &Indent,
                               VPSlotTracker &SlotTracker) const {
@@ -2373,6 +2421,28 @@ void VPReductionEVLRecipe::print(raw_ostream &O, const Twine &Indent,
     O << " (with final reduction value stored in invariant address sank "
          "outside of loop)";
 }
+
+void VPExtendedReductionRecipe::print(raw_ostream &O, const Twine &Indent,
+                                      VPSlotTracker &SlotTracker) const {
+  // Form: EXTENDED-REDUCE <def> = <chain> +<fmf> reduce.<op> (<vec> ...).
+  O << Indent << "EXTENDED-REDUCE ";
+  printAsOperand(O, SlotTracker);
+  O << " = ";
+  getChainOp()->printAsOperand(O, SlotTracker);
+  O << " +";
+  if (isa<FPMathOperator>(getUnderlyingInstr()))
+    O << getUnderlyingInstr()->getFastMathFlags();
+  O << " reduce." << Instruction::getOpcodeName(RdxDesc.getOpcode()) << " (";
+  getVecOp()->printAsOperand(O, SlotTracker);
+  O << " extended to " << *getResultType();
+  if (isConditional())
+    getCondOp()->printAsOperand(O << ", ", SlotTracker);
+  O << ")";
+  if (!RdxDesc.IntermediateStore)
+    return;
+  O << " (with final reduction value stored in invariant address sank "
+       "outside of loop)";
+}
 #endif
 
 bool VPReplicateRecipe::shouldPack() const {
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index b9ab8a8fe60107..22ac98751bbd86 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -519,6 +519,30 @@ void VPlanTransforms::removeDeadRecipes(VPlan &Plan) {
   }
 }
 
+void VPlanTransforms::prepareExecute(VPlan &Plan) {
+  // Expand each VPExtendedReductionRecipe into a VPWidenCastRecipe feeding a
+  // VPReductionRecipe. make_early_inc_range keeps iteration valid while the
+  // current recipe is erased.
+  for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>(
+           vp_depth_first_deep(Plan.getEntry()))) {
+    for (VPRecipeBase &R : make_early_inc_range(*VPBB)) {
+      auto *ExtRed = dyn_cast<VPExtendedReductionRecipe>(&R);
+      if (!ExtRed)
+        continue;
+      auto *Ext = new VPWidenCastRecipe(
+          ExtRed->getExtOpcode(), ExtRed->getVecOp(), ExtRed->getResultType(),
+          *ExtRed->getExtInstr());
+      auto *Red = new VPReductionRecipe(
+          ExtRed->getRecurrenceDescriptor(), ExtRed->getUnderlyingInstr(),
+          ExtRed->getChainOp(), Ext, ExtRed->getCondOp(), ExtRed->isOrdered());
+      Ext->insertBefore(ExtRed);
+      Red->insertBefore(ExtRed);
+      ExtRed->replaceAllUsesWith(Red);
+      ExtRed->eraseFromParent();
+    }
+  }
+}
+
 static VPScalarIVStepsRecipe *
 createScalarIVSteps(VPlan &Plan, InductionDescriptor::InductionKind Kind,
                     Instruction::BinaryOps InductionOpcode,
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.h b/llvm/lib/Transforms/Vectorize/VPlanTransforms.h
index 11e094db6294f6..6310c23b605da3 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.h
@@ -123,6 +123,9 @@ struct VPlanTransforms {
 
   /// Remove dead recipes from \p Plan.
   static void removeDeadRecipes(VPlan &Plan);
+
+  /// TODO: Rebase to fhahn's implementation.
+  static void prepareExecute(VPlan &Plan);
 };
 
 } // namespace llvm
diff --git a/llvm/lib/Transforms/Vectorize/VPlanValue.h b/llvm/lib/Transforms/Vectorize/VPlanValue.h
index 691b0d40823cfb..09defa6406c078 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanValue.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanValue.h
@@ -329,6 +329,8 @@ class VPDef {
     VPInterleaveSC,
     VPReductionEVLSC,
     VPReductionSC,
+    VPMulAccSC,
+    VPExtendedReductionSC,
     VPReplicateSC,
     VPScalarCastSC,
     VPScalarIVStepsSC,

>From d5be1e3f33d4fced02511654f843f9d5893ced9c Mon Sep 17 00:00:00 2001
From: Elvis Wang <elvis.wang at sifive.com>
Date: Mon, 4 Nov 2024 22:02:22 -0800
Subject: [PATCH 3/9] Support MulAccRecipe

---
 .../Transforms/Vectorize/LoopVectorize.cpp    |  33 ++++-
 llvm/lib/Transforms/Vectorize/VPlan.h         | 103 +++++++++-------
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp | 114 +++++++++++++++++-
 .../Transforms/Vectorize/VPlanTransforms.cpp  |  56 ++++++---
 4 files changed, 237 insertions(+), 69 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 29df6a52fa98e7..7147c35c807491 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7664,8 +7664,6 @@ DenseMap<const SCEV *, Value *> LoopVectorizationPlanner::executePlan(
 
   // TODO: Rebase to fhahn's implementation.
   VPlanTransforms::prepareExecute(BestVPlan);
-  dbgs() << "\n\n print plan\n";
-  BestVPlan.print(dbgs());
   BestVPlan.execute(&State);
 
   // 2.5 Collect reduction resume values.
@@ -9379,11 +9377,34 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
       if (CM.blockNeedsPredicationForAnyReason(BB))
         CondOp = RecipeBuilder.getBlockInMask(BB);
 
-      // VPWidenCastRecipes can folded into VPReductionRecipe
-      VPValue *A;
+      VPValue *A, *B;
       VPSingleDefRecipe *RedRecipe;
-      if (match(VecOp, m_ZExtOrSExt(m_VPValue(A))) &&
-          !VecOp->hasMoreThanOneUniqueUser()) {
+      // reduce.add(mul(ext, ext)) can folded into VPMulAccRecipe
+      if (RdxDesc.getOpcode() == Instruction::Add &&
+          match(VecOp, m_Mul(m_VPValue(A), m_VPValue(B)))) {
+        VPRecipeBase *RecipeA = A->getDefiningRecipe();
+        VPRecipeBase *RecipeB = B->getDefiningRecipe();
+        if (RecipeA && RecipeB && match(RecipeA, m_ZExtOrSExt(m_VPValue())) &&
+            match(RecipeB, m_ZExtOrSExt(m_VPValue())) &&
+            cast<VPWidenCastRecipe>(RecipeA)->getOpcode() ==
+                cast<VPWidenCastRecipe>(RecipeB)->getOpcode() &&
+            !A->hasMoreThanOneUniqueUser() && !B->hasMoreThanOneUniqueUser()) {
+          RedRecipe = new VPMulAccRecipe(
+              RdxDesc, CurrentLinkI, PreviousLink, CondOp,
+              CM.useOrderedReductions(RdxDesc),
+              cast<VPWidenRecipe>(VecOp->getDefiningRecipe()),
+              cast<VPWidenCastRecipe>(RecipeA),
+              cast<VPWidenCastRecipe>(RecipeB));
+        } else {
+          RedRecipe = new VPMulAccRecipe(
+              RdxDesc, CurrentLinkI, PreviousLink, CondOp,
+              CM.useOrderedReductions(RdxDesc),
+              cast<VPWidenRecipe>(VecOp->getDefiningRecipe()));
+        }
+      }
+      // VPWidenCastRecipes can folded into VPReductionRecipe
+      else if (match(VecOp, m_ZExtOrSExt(m_VPValue(A))) &&
+               !VecOp->hasMoreThanOneUniqueUser()) {
         RedRecipe = new VPExtendedReductionRecipe(
             RdxDesc, CurrentLinkI,
             cast<CastInst>(
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 7a05145514d0e7..3a49962e8b465c 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -2770,60 +2770,64 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
   /// Whether the reduction is conditional.
   bool IsConditional = false;
   /// Type after extend.
-  Type *ResultTy;
-  /// Type for mul.
-  Type *MulTy;
-  /// reduce.add(OuterExt(mul(InnerExt(), InnerExt())))
-  Instruction::CastOps OuterExtOp;
-  Instruction::CastOps InnerExtOp;
+  Type *ResultType;
+  /// reduce.add(mul(Ext(), Ext()))
+  Instruction::CastOps ExtOp;
+
+  Instruction *MulInstr;
+  CastInst *Ext0Instr;
+  CastInst *Ext1Instr;
 
-  Instruction *MulI;
-  Instruction *OuterExtI;
-  Instruction *InnerExt0I;
-  Instruction *InnerExt1I;
+  bool IsExtended;
 
 protected:
   VPMulAccRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
-                 Instruction *RedI, Instruction::CastOps OuterExtOp,
-                 Instruction *OuterExtI, Instruction *MulI,
-                 Instruction::CastOps InnerExtOp, Instruction *InnerExt0I,
-                 Instruction *InnerExt1I, ArrayRef<VPValue *> Operands,
-                 VPValue *CondOp, bool IsOrdered, Type *ResultTy, Type *MulTy)
+                 Instruction *RedI, Instruction *MulInstr,
+                 Instruction::CastOps ExtOp, Instruction *Ext0Instr,
+                 Instruction *Ext1Instr, ArrayRef<VPValue *> Operands,
+                 VPValue *CondOp, bool IsOrdered, Type *ResultType)
+      : VPSingleDefRecipe(SC, Operands, RedI), RdxDesc(R), IsOrdered(IsOrdered),
+        ResultType(ResultType), ExtOp(ExtOp), MulInstr(MulInstr),
+        Ext0Instr(cast<CastInst>(Ext0Instr)),
+        Ext1Instr(cast<CastInst>(Ext1Instr)) {
+    if (CondOp) {
+      IsConditional = true;
+      addOperand(CondOp);
+    }
+    IsExtended = true;
+  }
+
+  VPMulAccRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
+                 Instruction *RedI, Instruction *MulInstr,
+                 ArrayRef<VPValue *> Operands, VPValue *CondOp, bool IsOrdered)
       : VPSingleDefRecipe(SC, Operands, RedI), RdxDesc(R), IsOrdered(IsOrdered),
-        ResultTy(ResultTy), MulTy(MulTy), OuterExtOp(OuterExtOp),
-        InnerExtOp(InnerExtOp), MulI(MulI), OuterExtI(OuterExtI),
-        InnerExt0I(InnerExt0I), InnerExt1I(InnerExt1I) {
+        MulInstr(MulInstr) {
     if (CondOp) {
       IsConditional = true;
       addOperand(CondOp);
     }
+    IsExtended = false;
   }
 
 public:
   VPMulAccRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
-                 Instruction *OuterExt, Instruction *Mul,
-                 Instruction *InnerExt0, Instruction *InnerExt1,
-                 VPValue *ChainOp, VPValue *InnerExt0Op, VPValue *InnerExt1Op,
-                 VPValue *CondOp, bool IsOrdered, Type *ResultTy, Type *MulTy)
-      : VPMulAccRecipe(
-            VPDef::VPMulAccSC, R, RedI, cast<CastInst>(OuterExt)->getOpcode(),
-            OuterExt, Mul, cast<CastInst>(InnerExt0)->getOpcode(), InnerExt0,
-            InnerExt1, ArrayRef<VPValue *>({ChainOp, InnerExt0Op, InnerExt1Op}),
-            CondOp, IsOrdered, ResultTy, MulTy) {}
-
-  VPMulAccRecipe(VPReductionRecipe *Red, VPWidenCastRecipe *OuterExt,
-                 VPWidenRecipe *Mul, VPWidenCastRecipe *InnerExt0,
-                 VPWidenCastRecipe *InnerExt1)
-      : VPMulAccRecipe(
-            VPDef::VPMulAccSC, Red->getRecurrenceDescriptor(),
-            Red->getUnderlyingInstr(), OuterExt->getOpcode(),
-            OuterExt->getUnderlyingInstr(), Mul->getUnderlyingInstr(),
-            InnerExt0->getOpcode(), InnerExt0->getUnderlyingInstr(),
-            InnerExt1->getUnderlyingInstr(),
-            ArrayRef<VPValue *>({Red->getChainOp(), InnerExt0->getOperand(0),
-                                 InnerExt1->getOperand(0)}),
-            Red->getCondOp(), Red->isOrdered(), OuterExt->getResultType(),
-            InnerExt0->getResultType()) {}
+                 VPValue *ChainOp, VPValue *CondOp, bool IsOrdered,
+                 VPWidenRecipe *Mul, VPWidenCastRecipe *Ext0,
+                 VPWidenCastRecipe *Ext1)
+      : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, Mul->getUnderlyingInstr(),
+                       Ext0->getOpcode(), Ext0->getUnderlyingInstr(),
+                       Ext1->getUnderlyingInstr(),
+                       ArrayRef<VPValue *>(
+                           {ChainOp, Ext0->getOperand(0), Ext1->getOperand(0)}),
+                       CondOp, IsOrdered, Ext0->getResultType()) {}
+
+  VPMulAccRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
+                 VPValue *ChainOp, VPValue *CondOp, bool IsOrdered,
+                 VPWidenRecipe *Mul)
+      : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, Mul->getUnderlyingInstr(),
+                       ArrayRef<VPValue *>(
+                           {ChainOp, Mul->getOperand(0), Mul->getOperand(0)}),
+                       CondOp, IsOrdered) {}
 
   ~VPMulAccRecipe() override = default;
 
@@ -2839,7 +2843,10 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
   }
 
   /// Generate the reduction in the loop
-  void execute(VPTransformState &State) override;
+  void execute(VPTransformState &State) override {
+    llvm_unreachable("VPMulAccRecipe should transform to VPWidenCastRecipe + "
+                     "VPWidenRecipe + VPReductionRecipe before execution");
+  }
 
   /// Return the cost of VPExtendedReductionRecipe.
   InstructionCost computeCost(ElementCount VF,
@@ -2862,14 +2869,18 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
   /// The VPValue of the scalar Chain being accumulated.
   VPValue *getChainOp() const { return getOperand(0); }
   /// The VPValue of the vector value to be extended and reduced.
-  VPValue *getVecOp() const { return getOperand(1); }
+  VPValue *getVecOp0() const { return getOperand(1); }
+  VPValue *getVecOp1() const { return getOperand(2); }
   /// The VPValue of the condition for the block.
   VPValue *getCondOp() const {
     return isConditional() ? getOperand(getNumOperands() - 1) : nullptr;
   }
-  Type *getResultTy() const { return ResultTy; };
-  Instruction::CastOps getOuterExtOpcode() const { return OuterExtOp; };
-  Instruction::CastOps getInnerExtOpcode() const { return InnerExtOp; };
+  Type *getResultType() const { return ResultType; };
+  Instruction::CastOps getExtOpcode() const { return ExtOp; };
+  Instruction *getMulInstr() const { return MulInstr; };
+  CastInst *getExt0Instr() const { return Ext0Instr; };
+  CastInst *getExt1Instr() const { return Ext1Instr; };
+  bool isExtended() const { return IsExtended; };
 };
 
 /// VPReplicateRecipe replicates a given instruction producing multiple scalar
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index a7858a729c2c92..195a4d676f5fa0 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -270,9 +270,7 @@ InstructionCost VPRecipeBase::cost(ElementCount VF, VPCostContext &Ctx) {
     UI = &WidenMem->getIngredient();
 
   InstructionCost RecipeCost;
-  if ((UI && Ctx.skipCostComputation(UI, VF.isVector())) ||
-      (Ctx.FoldedRecipes.contains(VF) &&
-       Ctx.FoldedRecipes.at(VF).contains(this))) {
+  if ((UI && Ctx.skipCostComputation(UI, VF.isVector()))) {
     RecipeCost = 0;
   } else {
     RecipeCost = computeCost(VF, Ctx);
@@ -2376,6 +2374,85 @@ VPExtendedReductionRecipe::computeCost(ElementCount VF,
   return ExtendedCost + ReductionCost;
 }
 
+InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
+                                            VPCostContext &Ctx) const {
+  Type *ElementTy =
+      IsExtended ? getResultType() : Ctx.Types.inferScalarType(getVecOp0());
+  auto *VectorTy = cast<VectorType>(ToVectorTy(ElementTy, VF));
+  TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+  unsigned Opcode = RdxDesc.getOpcode();
+
+  assert(ElementTy->getTypeID() == RdxDesc.getRecurrenceType()->getTypeID() &&
+         "Inferred type and recurrence type mismatch.");
+
+  // Unfused reduction cost: scalar accumulate op + vector reduction.
+  InstructionCost ReductionCost =
+      Ctx.TTI.getArithmeticInstrCost(Opcode, ElementTy, CostKind);
+  ReductionCost += Ctx.TTI.getArithmeticReductionCost(
+      Opcode, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+
+  // Extend cost. Each operand has its own extend, type and context hint.
+  InstructionCost ExtendedCost = 0;
+  if (IsExtended) {
+    auto *DestTy = cast<VectorType>(ToVectorTy(getResultType(), VF));
+    auto *SrcTy0 = cast<VectorType>(
+        ToVectorTy(Ctx.Types.inferScalarType(getVecOp0()), VF));
+    auto *SrcTy1 = cast<VectorType>(
+        ToVectorTy(Ctx.Types.inferScalarType(getVecOp1()), VF));
+    TTI::CastContextHint CCH0 =
+        computeCCH(getVecOp0()->getDefiningRecipe(), VF);
+    TTI::CastContextHint CCH1 =
+        computeCCH(getVecOp1()->getDefiningRecipe(), VF);
+    // Arm TTI will use the underlying instruction to determine the cost.
+    ExtendedCost = Ctx.TTI.getCastInstrCost(ExtOp, DestTy, SrcTy0, CCH0,
+                                            CostKind, getExt0Instr()) +
+                   Ctx.TTI.getCastInstrCost(ExtOp, DestTy, SrcTy1, CCH1,
+                                            CostKind, getExt1Instr());
+  }
+
+  // Mul cost
+  InstructionCost MulCost;
+  SmallVector<const Value *, 4> Operands;
+  Operands.append(MulInstr->value_op_begin(), MulInstr->value_op_end());
+  if (IsExtended) {
+    MulCost = Ctx.TTI.getArithmeticInstrCost(
+        Instruction::Mul, VectorTy, CostKind,
+        {TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None},
+        {TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None},
+        Operands, MulInstr, &Ctx.TLI);
+  } else {
+    VPValue *RHS = getVecOp1();
+    // Certain instructions can be cheaper to vectorize if they have a constant
+    // second vector operand. One example of this are shifts on x86.
+    TargetTransformInfo::OperandValueInfo RHSInfo = {
+        TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None};
+    if (RHS->isLiveIn())
+      RHSInfo = Ctx.TTI.getOperandInfo(RHS->getLiveInIRValue());
+
+    if (RHSInfo.Kind == TargetTransformInfo::OK_AnyValue &&
+        RHS->isDefinedOutsideLoopRegions())
+      RHSInfo.Kind = TargetTransformInfo::OK_UniformValue;
+    MulCost = Ctx.TTI.getArithmeticInstrCost(
+        Instruction::Mul, VectorTy, CostKind,
+        {TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None},
+        RHSInfo, Operands, MulInstr, &Ctx.TLI);
+  }
+
+  // ExtendedReduction Cost
+  VectorType *SrcVecTy =
+      cast<VectorType>(ToVectorTy(Ctx.Types.inferScalarType(getVecOp0()), VF));
+  InstructionCost MulAccCost = Ctx.TTI.getMulAccReductionCost(
+      /*IsUnsigned=*/!IsExtended || ExtOp == Instruction::CastOps::ZExt,
+      ElementTy, SrcVecTy, CostKind);
+
+  // ExtOp is only set for the extended form, hence the !IsExtended guard
+  // above. Use the fused reduction only if it is valid and cheaper.
+  InstructionCost BaseCost = ExtendedCost + ReductionCost + MulCost;
+  if (MulAccCost.isValid() && MulAccCost < BaseCost)
+    return MulAccCost;
+  return BaseCost;
+}
+
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
 void VPReductionRecipe::print(raw_ostream &O, const Twine &Indent,
                               VPSlotTracker &SlotTracker) const {
@@ -2443,6 +2520,37 @@ void VPExtendedReductionRecipe::print(raw_ostream &O, const Twine &Indent,
     O << " (with final reduction value stored in invariant address sank "
          "outside of loop)";
 }
+
+void VPMulAccRecipe::print(raw_ostream &O, const Twine &Indent,
+                           VPSlotTracker &SlotTracker) const {
+  O << Indent << "MULACC-REDUCE ";
+  printAsOperand(O, SlotTracker);
+  O << " = ";
+  getChainOp()->printAsOperand(O, SlotTracker);
+  O << " +";
+  if (isa<FPMathOperator>(getUnderlyingInstr()))
+    O << getUnderlyingInstr()->getFastMathFlags();
+  O << " reduce." << Instruction::getOpcodeName(RdxDesc.getOpcode()) << " (";
+  // Print a mul operand, wrapped with its extension when IsExtended.
+  auto PrintMulOperand = [&](VPValue *Op) {
+    O << (IsExtended ? "(" : "");
+    Op->printAsOperand(O, SlotTracker);
+    if (IsExtended)
+      O << " extended to " << *getResultType() << ")";
+  };
+  O << "mul ";
+  PrintMulOperand(getVecOp0());
+  O << ", ";
+  PrintMulOperand(getVecOp1());
+  if (isConditional()) {
+    O << ", ";
+    getCondOp()->printAsOperand(O, SlotTracker);
+  }
+  O << ")";
+  if (RdxDesc.IntermediateStore)
+    O << " (with final reduction value stored in invariant address sank "
+         "outside of loop)";
+}
 #endif
 
 bool VPReplicateRecipe::shouldPack() const {
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index 22ac98751bbd86..6c9c157d7a9071 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -520,25 +520,53 @@ void VPlanTransforms::removeDeadRecipes(VPlan &Plan) {
 }
 
 void VPlanTransforms::prepareExecute(VPlan &Plan) {
-  errs() << "\n\n\n!!Prepare to execute\n";
   ReversePostOrderTraversal<VPBlockDeepTraversalWrapper<VPBlockBase *>> RPOT(
       Plan.getVectorLoopRegion());
   for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>(
            vp_depth_first_deep(Plan.getEntry()))) {
     for (VPRecipeBase &R : make_early_inc_range(*VPBB)) {
-      if (!isa<VPExtendedReductionRecipe>(&R))
-        continue;
-      auto *ExtRed = cast<VPExtendedReductionRecipe>(&R);
-      auto *Ext = new VPWidenCastRecipe(
-          ExtRed->getExtOpcode(), ExtRed->getVecOp(), ExtRed->getResultType(),
-          *ExtRed->getExtInstr());
-      auto *Red = new VPReductionRecipe(
-          ExtRed->getRecurrenceDescriptor(), ExtRed->getUnderlyingInstr(),
-          ExtRed->getChainOp(), Ext, ExtRed->getCondOp(), ExtRed->isOrdered());
-      Ext->insertBefore(ExtRed);
-      Red->insertBefore(ExtRed);
-      ExtRed->replaceAllUsesWith(Red);
-      ExtRed->eraseFromParent();
+      if (auto *ExtRed = dyn_cast<VPExtendedReductionRecipe>(&R)) {
+        // Expand into a VPWidenCastRecipe feeding a VPReductionRecipe.
+        auto *Ext = new VPWidenCastRecipe(
+            ExtRed->getExtOpcode(), ExtRed->getVecOp(), ExtRed->getResultType(),
+            *ExtRed->getExtInstr());
+        auto *Red = new VPReductionRecipe(
+            ExtRed->getRecurrenceDescriptor(), ExtRed->getUnderlyingInstr(),
+            ExtRed->getChainOp(), Ext, ExtRed->getCondOp(),
+            ExtRed->isOrdered());
+        Ext->insertBefore(ExtRed);
+        Red->insertBefore(ExtRed);
+        ExtRed->replaceAllUsesWith(Red);
+        ExtRed->eraseFromParent();
+      } else if (auto *MulAcc = dyn_cast<VPMulAccRecipe>(&R)) {
+        // Expand into optional extends, a VPWidenRecipe mul and a reduction.
+        VPValue *Op0, *Op1;
+        if (MulAcc->isExtended()) {
+          Op0 = new VPWidenCastRecipe(
+              MulAcc->getExtOpcode(), MulAcc->getVecOp0(),
+              MulAcc->getResultType(), *MulAcc->getExt0Instr());
+          Op1 = new VPWidenCastRecipe(
+              MulAcc->getExtOpcode(), MulAcc->getVecOp1(),
+              MulAcc->getResultType(), *MulAcc->getExt1Instr());
+          Op0->getDefiningRecipe()->insertBefore(MulAcc);
+          Op1->getDefiningRecipe()->insertBefore(MulAcc);
+        } else {
+          Op0 = MulAcc->getVecOp0();
+          Op1 = MulAcc->getVecOp1();
+        }
+        Instruction *MulInstr = MulAcc->getMulInstr();
+        SmallVector<VPValue *, 2> MulOps = {Op0, Op1};
+        auto *Mul = new VPWidenRecipe(*MulInstr,
+                                      make_range(MulOps.begin(), MulOps.end()));
+        auto *Red = new VPReductionRecipe(
+            MulAcc->getRecurrenceDescriptor(), MulAcc->getUnderlyingInstr(),
+            MulAcc->getChainOp(), Mul, MulAcc->getCondOp(),
+            MulAcc->isOrdered());
+        Mul->insertBefore(MulAcc);
+        Red->insertBefore(MulAcc);
+        MulAcc->replaceAllUsesWith(Red);
+        MulAcc->eraseFromParent();
+      }
     }
   }
 }

>From e537818b73c3fbc82c6152e6439d93cea0cb0592 Mon Sep 17 00:00:00 2001
From: Elvis Wang <elvis.wang at sifive.com>
Date: Tue, 5 Nov 2024 16:52:31 -0800
Subject: [PATCH 4/9] Fix several errors and update tests.

We need to update tests since the generated vector IR will be reordered.
---
 .../Transforms/Vectorize/LoopVectorize.cpp    | 45 ++++++++++++++
 llvm/lib/Transforms/Vectorize/VPlan.h         | 34 ++++++++---
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp |  6 +-
 .../Transforms/Vectorize/VPlanTransforms.cpp  | 14 ++++-
 .../LoopVectorize/ARM/mve-reduction-types.ll  |  4 +-
 .../LoopVectorize/ARM/mve-reductions.ll       | 61 ++++++++++---------
 .../LoopVectorize/RISCV/inloop-reduction.ll   | 32 ++++++----
 .../LoopVectorize/reduction-inloop-pred.ll    | 34 +++++------
 .../LoopVectorize/reduction-inloop.ll         | 12 ++--
 9 files changed, 163 insertions(+), 79 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 7147c35c807491..ea2e6b96f7711c 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7394,6 +7394,19 @@ static bool planContainsAdditionalSimplifications(VPlan &Plan,
       }
       if (Instruction *UI = GetInstructionForCost(&R))
         SeenInstrs.insert(UI);
+      // VPExtendedReductionRecipe contains a folded extend instruction.
+      if (auto *ExtendedRed = dyn_cast<VPExtendedReductionRecipe>(&R))
+        SeenInstrs.insert(ExtendedRed->getExtInstr());
+      // VPMulAccRecipe contains a mul and optional extend instructions.
+      else if (auto *MulAcc = dyn_cast<VPMulAccRecipe>(&R)) {
+        SeenInstrs.insert(MulAcc->getMulInstr());
+        if (MulAcc->isExtended()) {
+          SeenInstrs.insert(MulAcc->getExt0Instr());
+          SeenInstrs.insert(MulAcc->getExt1Instr());
+          if (auto *Ext = MulAcc->getExtInstr())
+            SeenInstrs.insert(Ext);
+        }
+      }
     }
   }
 
@@ -9401,6 +9414,38 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
               CM.useOrderedReductions(RdxDesc),
               cast<VPWidenRecipe>(VecOp->getDefiningRecipe()));
         }
+      } else if (RdxDesc.getOpcode() == Instruction::Add &&
+                 match(VecOp,
+                       m_ZExtOrSExt(m_Mul(m_ZExtOrSExt(m_VPValue(A)),
+                                          m_ZExtOrSExt(m_VPValue(B)))))) {
+        VPWidenCastRecipe *Ext =
+            dyn_cast<VPWidenCastRecipe>(VecOp->getDefiningRecipe());
+        VPWidenRecipe *Mul =
+            dyn_cast<VPWidenRecipe>(Ext->getOperand(0)->getDefiningRecipe());
+        if (Mul && match(Mul, m_Mul(m_ZExtOrSExt(m_VPValue()),
+                                    m_ZExtOrSExt(m_VPValue())))) {
+          VPWidenRecipe *Mul =
+              cast<VPWidenRecipe>(Ext->getOperand(0)->getDefiningRecipe());
+          VPWidenCastRecipe *Ext0 =
+              cast<VPWidenCastRecipe>(Mul->getOperand(0)->getDefiningRecipe());
+          VPWidenCastRecipe *Ext1 =
+              cast<VPWidenCastRecipe>(Mul->getOperand(1)->getDefiningRecipe());
+          if (Ext->getOpcode() == Ext0->getOpcode() &&
+              Ext0->getOpcode() == Ext1->getOpcode()) {
+            RedRecipe = new VPMulAccRecipe(
+                RdxDesc, CurrentLinkI, PreviousLink, CondOp,
+                CM.useOrderedReductions(RdxDesc),
+                cast<VPWidenCastRecipe>(VecOp->getDefiningRecipe()), Mul,
+                cast<VPWidenCastRecipe>(Ext0), cast<VPWidenCastRecipe>(Ext1));
+          } else
+            RedRecipe = new VPExtendedReductionRecipe(
+                RdxDesc, CurrentLinkI,
+                cast<CastInst>(
+                    cast<VPWidenCastRecipe>(VecOp)->getUnderlyingInstr()),
+                PreviousLink, cast<VPWidenCastRecipe>(VecOp)->getOperand(0),
+                CondOp, CM.useOrderedReductions(RdxDesc),
+                cast<VPWidenCastRecipe>(VecOp)->getResultType());
+        }
       }
       // VPWidenCastRecipes can folded into VPReductionRecipe
       else if (match(VecOp, m_ZExtOrSExt(m_VPValue(A))) &&
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 3a49962e8b465c..0103686be422d6 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -2771,23 +2771,27 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
   bool IsConditional = false;
   /// Type after extend.
   Type *ResultType;
-  /// reduce.add(mul(Ext(), Ext()))
+  /// reduce.add(ext(mul(Ext(), Ext())))
   Instruction::CastOps ExtOp;
 
   Instruction *MulInstr;
+  CastInst *ExtInstr = nullptr;
   CastInst *Ext0Instr;
   CastInst *Ext1Instr;
 
   bool IsExtended;
+  bool IsOuterExtended = false;
 
 protected:
   VPMulAccRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
-                 Instruction *RedI, Instruction *MulInstr,
-                 Instruction::CastOps ExtOp, Instruction *Ext0Instr,
-                 Instruction *Ext1Instr, ArrayRef<VPValue *> Operands,
-                 VPValue *CondOp, bool IsOrdered, Type *ResultType)
+                 Instruction *RedI, Instruction *ExtInstr,
+                 Instruction *MulInstr, Instruction::CastOps ExtOp,
+                 Instruction *Ext0Instr, Instruction *Ext1Instr,
+                 ArrayRef<VPValue *> Operands, VPValue *CondOp, bool IsOrdered,
+                 Type *ResultType)
       : VPSingleDefRecipe(SC, Operands, RedI), RdxDesc(R), IsOrdered(IsOrdered),
         ResultType(ResultType), ExtOp(ExtOp), MulInstr(MulInstr),
+        ExtInstr(cast_if_present<CastInst>(ExtInstr)),
         Ext0Instr(cast<CastInst>(Ext0Instr)),
         Ext1Instr(cast<CastInst>(Ext1Instr)) {
     if (CondOp) {
@@ -2814,9 +2818,9 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
                  VPValue *ChainOp, VPValue *CondOp, bool IsOrdered,
                  VPWidenRecipe *Mul, VPWidenCastRecipe *Ext0,
                  VPWidenCastRecipe *Ext1)
-      : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, Mul->getUnderlyingInstr(),
-                       Ext0->getOpcode(), Ext0->getUnderlyingInstr(),
-                       Ext1->getUnderlyingInstr(),
+      : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, nullptr,
+                       Mul->getUnderlyingInstr(), Ext0->getOpcode(),
+                       Ext0->getUnderlyingInstr(), Ext1->getUnderlyingInstr(),
                        ArrayRef<VPValue *>(
                            {ChainOp, Ext0->getOperand(0), Ext1->getOperand(0)}),
                        CondOp, IsOrdered, Ext0->getResultType()) {}
@@ -2826,9 +2830,20 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
                  VPWidenRecipe *Mul)
       : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, Mul->getUnderlyingInstr(),
                        ArrayRef<VPValue *>(
-                           {ChainOp, Mul->getOperand(0), Mul->getOperand(0)}),
+                           {ChainOp, Mul->getOperand(0), Mul->getOperand(1)}),
                        CondOp, IsOrdered) {}
 
+  VPMulAccRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
+                 VPValue *ChainOp, VPValue *CondOp, bool IsOrdered,
+                 VPWidenCastRecipe *Ext, VPWidenRecipe *Mul,
+                 VPWidenCastRecipe *Ext0, VPWidenCastRecipe *Ext1)
+      : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, Ext->getUnderlyingInstr(),
+                       Mul->getUnderlyingInstr(), Ext0->getOpcode(),
+                       Ext0->getUnderlyingInstr(), Ext1->getUnderlyingInstr(),
+                       ArrayRef<VPValue *>(
+                           {ChainOp, Ext0->getOperand(0), Ext1->getOperand(0)}),
+                       CondOp, IsOrdered, Ext0->getResultType()) {}
+
   ~VPMulAccRecipe() override = default;
 
   VPMulAccRecipe *clone() override { llvm_unreachable("Not implement yet"); }
@@ -2878,6 +2893,7 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
   Type *getResultType() const { return ResultType; };
   Instruction::CastOps getExtOpcode() const { return ExtOp; };
   Instruction *getMulInstr() const { return MulInstr; };
+  CastInst *getExtInstr() const { return ExtInstr; };
   CastInst *getExt0Instr() const { return Ext0Instr; };
   CastInst *getExt1Instr() const { return Ext1Instr; };
   bool isExtended() const { return IsExtended; };
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 195a4d676f5fa0..d28b2d93ee5630 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -2376,8 +2376,8 @@ VPExtendedReductionRecipe::computeCost(ElementCount VF,
 
 InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
                                             VPCostContext &Ctx) const {
-  Type *ElementTy =
-      IsExtended ? getResultType() : Ctx.Types.inferScalarType(getVecOp0());
+  Type *ElementTy = IsExtended ? RdxDesc.getRecurrenceType()
+                               : Ctx.Types.inferScalarType(getVecOp0());
   auto *VectorTy = cast<VectorType>(ToVectorTy(ElementTy, VF));
   TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
   unsigned Opcode = RdxDesc.getOpcode();
@@ -2438,7 +2438,7 @@ InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
         RHSInfo, Operands, MulInstr, &Ctx.TLI);
   }
 
-  // ExtendedReduction Cost
+  // MulAccReduction Cost
   VectorType *SrcVecTy =
       cast<VectorType>(ToVectorTy(Ctx.Types.inferScalarType(getVecOp0()), VF));
   InstructionCost MulAccCost = Ctx.TTI.getMulAccReductionCost(
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index 6c9c157d7a9071..b7cc945747df13 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -554,15 +554,27 @@ void VPlanTransforms::prepareExecute(VPlan &Plan) {
           Op0 = MulAcc->getVecOp0();
           Op1 = MulAcc->getVecOp1();
         }
+        VPSingleDefRecipe *VecOp;
         Instruction *MulInstr = MulAcc->getMulInstr();
         SmallVector<VPValue *, 2> MulOps = {Op0, Op1};
         auto *Mul = new VPWidenRecipe(*MulInstr,
                                       make_range(MulOps.begin(), MulOps.end()));
+        if (auto *OuterExtInstr = MulAcc->getExtInstr()) {
+          // dbgs() <<"\n!!!"<< *OuterExtInstr << " " << MulAcc->getExtOpcode()
+          // << "\n";
+          VecOp = new VPWidenCastRecipe(
+              MulAcc->getExtOpcode(), Mul,
+              MulAcc->getRecurrenceDescriptor().getRecurrenceType(),
+              *OuterExtInstr);
+        } else
+          VecOp = Mul;
         auto *Red = new VPReductionRecipe(
             MulAcc->getRecurrenceDescriptor(), MulAcc->getUnderlyingInstr(),
-            MulAcc->getChainOp(), Mul, MulAcc->getCondOp(),
+            MulAcc->getChainOp(), VecOp, MulAcc->getCondOp(),
             MulAcc->isOrdered());
         Mul->insertBefore(MulAcc);
+        if (VecOp != Mul)
+          VecOp->insertBefore(MulAcc);
         Red->insertBefore(MulAcc);
         MulAcc->replaceAllUsesWith(Red);
         MulAcc->eraseFromParent();
diff --git a/llvm/test/Transforms/LoopVectorize/ARM/mve-reduction-types.ll b/llvm/test/Transforms/LoopVectorize/ARM/mve-reduction-types.ll
index 832d4db53036fb..9d4b83edd55ccb 100644
--- a/llvm/test/Transforms/LoopVectorize/ARM/mve-reduction-types.ll
+++ b/llvm/test/Transforms/LoopVectorize/ARM/mve-reduction-types.ll
@@ -24,11 +24,11 @@ define i32 @mla_i32(ptr noalias nocapture readonly %A, ptr noalias nocapture rea
 ; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds i8, ptr [[A:%.*]], i32 [[TMP0]]
 ; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr inbounds i8, ptr [[TMP1]], i32 0
 ; CHECK-NEXT:    [[WIDE_MASKED_LOAD:%.*]] = call <16 x i8> @llvm.masked.load.v16i8.p0(ptr [[TMP2]], i32 1, <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i8> poison)
-; CHECK-NEXT:    [[TMP3:%.*]] = sext <16 x i8> [[WIDE_MASKED_LOAD]] to <16 x i32>
 ; CHECK-NEXT:    [[TMP4:%.*]] = getelementptr inbounds i8, ptr [[B:%.*]], i32 [[TMP0]]
 ; CHECK-NEXT:    [[TMP5:%.*]] = getelementptr inbounds i8, ptr [[TMP4]], i32 0
 ; CHECK-NEXT:    [[WIDE_MASKED_LOAD1:%.*]] = call <16 x i8> @llvm.masked.load.v16i8.p0(ptr [[TMP5]], i32 1, <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i8> poison)
 ; CHECK-NEXT:    [[TMP6:%.*]] = sext <16 x i8> [[WIDE_MASKED_LOAD1]] to <16 x i32>
+; CHECK-NEXT:    [[TMP3:%.*]] = sext <16 x i8> [[WIDE_MASKED_LOAD]] to <16 x i32>
 ; CHECK-NEXT:    [[TMP7:%.*]] = mul nsw <16 x i32> [[TMP6]], [[TMP3]]
 ; CHECK-NEXT:    [[TMP8:%.*]] = select <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i32> [[TMP7]], <16 x i32> zeroinitializer
 ; CHECK-NEXT:    [[TMP9:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP8]])
@@ -107,11 +107,11 @@ define i32 @mla_i8(ptr noalias nocapture readonly %A, ptr noalias nocapture read
 ; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds i8, ptr [[A:%.*]], i32 [[TMP0]]
 ; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr inbounds i8, ptr [[TMP1]], i32 0
 ; CHECK-NEXT:    [[WIDE_MASKED_LOAD:%.*]] = call <16 x i8> @llvm.masked.load.v16i8.p0(ptr [[TMP2]], i32 1, <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i8> poison)
-; CHECK-NEXT:    [[TMP3:%.*]] = sext <16 x i8> [[WIDE_MASKED_LOAD]] to <16 x i32>
 ; CHECK-NEXT:    [[TMP4:%.*]] = getelementptr inbounds i8, ptr [[B:%.*]], i32 [[TMP0]]
 ; CHECK-NEXT:    [[TMP5:%.*]] = getelementptr inbounds i8, ptr [[TMP4]], i32 0
 ; CHECK-NEXT:    [[WIDE_MASKED_LOAD1:%.*]] = call <16 x i8> @llvm.masked.load.v16i8.p0(ptr [[TMP5]], i32 1, <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i8> poison)
 ; CHECK-NEXT:    [[TMP6:%.*]] = sext <16 x i8> [[WIDE_MASKED_LOAD1]] to <16 x i32>
+; CHECK-NEXT:    [[TMP3:%.*]] = sext <16 x i8> [[WIDE_MASKED_LOAD]] to <16 x i32>
 ; CHECK-NEXT:    [[TMP7:%.*]] = mul nsw <16 x i32> [[TMP6]], [[TMP3]]
 ; CHECK-NEXT:    [[TMP8:%.*]] = select <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i32> [[TMP7]], <16 x i32> zeroinitializer
 ; CHECK-NEXT:    [[TMP9:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP8]])
diff --git a/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll b/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
index 3dae408feeed7f..1e4134a8fdce97 100644
--- a/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
+++ b/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
@@ -646,11 +646,11 @@ define i64 @mla_i16_i64(ptr nocapture readonly %x, ptr nocapture readonly %y, i3
 ; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[TMP7:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds i16, ptr [[X:%.*]], i32 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <8 x i16>, ptr [[TMP0]], align 2
-; CHECK-NEXT:    [[TMP1:%.*]] = sext <8 x i16> [[WIDE_LOAD]] to <8 x i32>
 ; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr inbounds i16, ptr [[Y:%.*]], i32 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_LOAD1:%.*]] = load <8 x i16>, ptr [[TMP2]], align 2
-; CHECK-NEXT:    [[TMP3:%.*]] = sext <8 x i16> [[WIDE_LOAD1]] to <8 x i32>
-; CHECK-NEXT:    [[TMP4:%.*]] = mul nsw <8 x i32> [[TMP3]], [[TMP1]]
+; CHECK-NEXT:    [[TMP11:%.*]] = sext <8 x i16> [[WIDE_LOAD1]] to <8 x i32>
+; CHECK-NEXT:    [[TMP3:%.*]] = sext <8 x i16> [[WIDE_LOAD]] to <8 x i32>
+; CHECK-NEXT:    [[TMP4:%.*]] = mul nsw <8 x i32> [[TMP11]], [[TMP3]]
 ; CHECK-NEXT:    [[TMP5:%.*]] = sext <8 x i32> [[TMP4]] to <8 x i64>
 ; CHECK-NEXT:    [[TMP6:%.*]] = call i64 @llvm.vector.reduce.add.v8i64(<8 x i64> [[TMP5]])
 ; CHECK-NEXT:    [[TMP7]] = add i64 [[TMP6]], [[VEC_PHI]]
@@ -726,11 +726,11 @@ define i64 @mla_i8_i64(ptr nocapture readonly %x, ptr nocapture readonly %y, i32
 ; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[TMP7:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds i8, ptr [[X:%.*]], i32 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <8 x i8>, ptr [[TMP0]], align 1
-; CHECK-NEXT:    [[TMP1:%.*]] = zext <8 x i8> [[WIDE_LOAD]] to <8 x i32>
 ; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr inbounds i8, ptr [[Y:%.*]], i32 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_LOAD1:%.*]] = load <8 x i8>, ptr [[TMP2]], align 1
-; CHECK-NEXT:    [[TMP3:%.*]] = zext <8 x i8> [[WIDE_LOAD1]] to <8 x i32>
-; CHECK-NEXT:    [[TMP4:%.*]] = mul nuw nsw <8 x i32> [[TMP3]], [[TMP1]]
+; CHECK-NEXT:    [[TMP11:%.*]] = zext <8 x i8> [[WIDE_LOAD1]] to <8 x i32>
+; CHECK-NEXT:    [[TMP3:%.*]] = zext <8 x i8> [[WIDE_LOAD]] to <8 x i32>
+; CHECK-NEXT:    [[TMP4:%.*]] = mul nuw nsw <8 x i32> [[TMP11]], [[TMP3]]
 ; CHECK-NEXT:    [[TMP5:%.*]] = zext nneg <8 x i32> [[TMP4]] to <8 x i64>
 ; CHECK-NEXT:    [[TMP6:%.*]] = call i64 @llvm.vector.reduce.add.v8i64(<8 x i64> [[TMP5]])
 ; CHECK-NEXT:    [[TMP7]] = add i64 [[TMP6]], [[VEC_PHI]]
@@ -802,11 +802,11 @@ define i32 @mla_i32_i32(ptr nocapture readonly %x, ptr nocapture readonly %y, i3
 ; CHECK-NEXT:    [[INDEX:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[TMP5:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[ACTIVE_LANE_MASK:%.*]] = call <4 x i1> @llvm.get.active.lane.mask.v4i1.i32(i32 [[INDEX]], i32 [[N]])
-; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds i32, ptr [[X:%.*]], i32 [[INDEX]]
-; CHECK-NEXT:    [[WIDE_MASKED_LOAD:%.*]] = call <4 x i32> @llvm.masked.load.v4i32.p0(ptr [[TMP0]], i32 4, <4 x i1> [[ACTIVE_LANE_MASK]], <4 x i32> poison)
 ; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds i32, ptr [[Y:%.*]], i32 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_MASKED_LOAD1:%.*]] = call <4 x i32> @llvm.masked.load.v4i32.p0(ptr [[TMP1]], i32 4, <4 x i1> [[ACTIVE_LANE_MASK]], <4 x i32> poison)
-; CHECK-NEXT:    [[TMP2:%.*]] = mul nsw <4 x i32> [[WIDE_MASKED_LOAD1]], [[WIDE_MASKED_LOAD]]
+; CHECK-NEXT:    [[TMP7:%.*]] = getelementptr inbounds i32, ptr [[Y1:%.*]], i32 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_MASKED_LOAD2:%.*]] = call <4 x i32> @llvm.masked.load.v4i32.p0(ptr [[TMP7]], i32 4, <4 x i1> [[ACTIVE_LANE_MASK]], <4 x i32> poison)
+; CHECK-NEXT:    [[TMP2:%.*]] = mul nsw <4 x i32> [[WIDE_MASKED_LOAD2]], [[WIDE_MASKED_LOAD1]]
 ; CHECK-NEXT:    [[TMP3:%.*]] = select <4 x i1> [[ACTIVE_LANE_MASK]], <4 x i32> [[TMP2]], <4 x i32> zeroinitializer
 ; CHECK-NEXT:    [[TMP4:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP3]])
 ; CHECK-NEXT:    [[TMP5]] = add i32 [[TMP4]], [[VEC_PHI]]
@@ -855,10 +855,10 @@ define i32 @mla_i16_i32(ptr nocapture readonly %x, ptr nocapture readonly %y, i3
 ; CHECK-NEXT:    [[ACTIVE_LANE_MASK:%.*]] = call <8 x i1> @llvm.get.active.lane.mask.v8i1.i32(i32 [[INDEX]], i32 [[N]])
 ; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds i16, ptr [[X:%.*]], i32 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_MASKED_LOAD:%.*]] = call <8 x i16> @llvm.masked.load.v8i16.p0(ptr [[TMP0]], i32 2, <8 x i1> [[ACTIVE_LANE_MASK]], <8 x i16> poison)
-; CHECK-NEXT:    [[TMP1:%.*]] = sext <8 x i16> [[WIDE_MASKED_LOAD]] to <8 x i32>
 ; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr inbounds i16, ptr [[Y:%.*]], i32 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_MASKED_LOAD1:%.*]] = call <8 x i16> @llvm.masked.load.v8i16.p0(ptr [[TMP2]], i32 2, <8 x i1> [[ACTIVE_LANE_MASK]], <8 x i16> poison)
 ; CHECK-NEXT:    [[TMP3:%.*]] = sext <8 x i16> [[WIDE_MASKED_LOAD1]] to <8 x i32>
+; CHECK-NEXT:    [[TMP1:%.*]] = sext <8 x i16> [[WIDE_MASKED_LOAD]] to <8 x i32>
 ; CHECK-NEXT:    [[TMP4:%.*]] = mul nsw <8 x i32> [[TMP3]], [[TMP1]]
 ; CHECK-NEXT:    [[TMP5:%.*]] = select <8 x i1> [[ACTIVE_LANE_MASK]], <8 x i32> [[TMP4]], <8 x i32> zeroinitializer
 ; CHECK-NEXT:    [[TMP6:%.*]] = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> [[TMP5]])
@@ -910,10 +910,10 @@ define i32 @mla_i8_i32(ptr nocapture readonly %x, ptr nocapture readonly %y, i32
 ; CHECK-NEXT:    [[ACTIVE_LANE_MASK:%.*]] = call <16 x i1> @llvm.get.active.lane.mask.v16i1.i32(i32 [[INDEX]], i32 [[N]])
 ; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds i8, ptr [[X:%.*]], i32 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_MASKED_LOAD:%.*]] = call <16 x i8> @llvm.masked.load.v16i8.p0(ptr [[TMP0]], i32 1, <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i8> poison)
-; CHECK-NEXT:    [[TMP1:%.*]] = zext <16 x i8> [[WIDE_MASKED_LOAD]] to <16 x i32>
 ; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr inbounds i8, ptr [[Y:%.*]], i32 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_MASKED_LOAD1:%.*]] = call <16 x i8> @llvm.masked.load.v16i8.p0(ptr [[TMP2]], i32 1, <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i8> poison)
 ; CHECK-NEXT:    [[TMP3:%.*]] = zext <16 x i8> [[WIDE_MASKED_LOAD1]] to <16 x i32>
+; CHECK-NEXT:    [[TMP1:%.*]] = zext <16 x i8> [[WIDE_MASKED_LOAD]] to <16 x i32>
 ; CHECK-NEXT:    [[TMP4:%.*]] = mul nuw nsw <16 x i32> [[TMP3]], [[TMP1]]
 ; CHECK-NEXT:    [[TMP5:%.*]] = select <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i32> [[TMP4]], <16 x i32> zeroinitializer
 ; CHECK-NEXT:    [[TMP6:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP5]])
@@ -963,11 +963,11 @@ define signext i16 @mla_i16_i16(ptr nocapture readonly %x, ptr nocapture readonl
 ; CHECK-NEXT:    [[INDEX:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi i16 [ 0, [[VECTOR_PH]] ], [ [[TMP5:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[ACTIVE_LANE_MASK:%.*]] = call <8 x i1> @llvm.get.active.lane.mask.v8i1.i32(i32 [[INDEX]], i32 [[N]])
-; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds i16, ptr [[X:%.*]], i32 [[INDEX]]
-; CHECK-NEXT:    [[WIDE_MASKED_LOAD:%.*]] = call <8 x i16> @llvm.masked.load.v8i16.p0(ptr [[TMP0]], i32 2, <8 x i1> [[ACTIVE_LANE_MASK]], <8 x i16> poison)
 ; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds i16, ptr [[Y:%.*]], i32 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_MASKED_LOAD1:%.*]] = call <8 x i16> @llvm.masked.load.v8i16.p0(ptr [[TMP1]], i32 2, <8 x i1> [[ACTIVE_LANE_MASK]], <8 x i16> poison)
-; CHECK-NEXT:    [[TMP2:%.*]] = mul <8 x i16> [[WIDE_MASKED_LOAD1]], [[WIDE_MASKED_LOAD]]
+; CHECK-NEXT:    [[TMP7:%.*]] = getelementptr inbounds i16, ptr [[Y1:%.*]], i32 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_MASKED_LOAD2:%.*]] = call <8 x i16> @llvm.masked.load.v8i16.p0(ptr [[TMP7]], i32 2, <8 x i1> [[ACTIVE_LANE_MASK]], <8 x i16> poison)
+; CHECK-NEXT:    [[TMP2:%.*]] = mul <8 x i16> [[WIDE_MASKED_LOAD2]], [[WIDE_MASKED_LOAD1]]
 ; CHECK-NEXT:    [[TMP3:%.*]] = select <8 x i1> [[ACTIVE_LANE_MASK]], <8 x i16> [[TMP2]], <8 x i16> zeroinitializer
 ; CHECK-NEXT:    [[TMP4:%.*]] = call i16 @llvm.vector.reduce.add.v8i16(<8 x i16> [[TMP3]])
 ; CHECK-NEXT:    [[TMP5]] = add i16 [[TMP4]], [[VEC_PHI]]
@@ -1016,10 +1016,10 @@ define signext i16 @mla_i8_i16(ptr nocapture readonly %x, ptr nocapture readonly
 ; CHECK-NEXT:    [[ACTIVE_LANE_MASK:%.*]] = call <16 x i1> @llvm.get.active.lane.mask.v16i1.i32(i32 [[INDEX]], i32 [[N]])
 ; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds i8, ptr [[X:%.*]], i32 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_MASKED_LOAD:%.*]] = call <16 x i8> @llvm.masked.load.v16i8.p0(ptr [[TMP0]], i32 1, <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i8> poison)
-; CHECK-NEXT:    [[TMP1:%.*]] = zext <16 x i8> [[WIDE_MASKED_LOAD]] to <16 x i16>
 ; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr inbounds i8, ptr [[Y:%.*]], i32 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_MASKED_LOAD1:%.*]] = call <16 x i8> @llvm.masked.load.v16i8.p0(ptr [[TMP2]], i32 1, <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i8> poison)
 ; CHECK-NEXT:    [[TMP3:%.*]] = zext <16 x i8> [[WIDE_MASKED_LOAD1]] to <16 x i16>
+; CHECK-NEXT:    [[TMP1:%.*]] = zext <16 x i8> [[WIDE_MASKED_LOAD]] to <16 x i16>
 ; CHECK-NEXT:    [[TMP4:%.*]] = mul nuw <16 x i16> [[TMP3]], [[TMP1]]
 ; CHECK-NEXT:    [[TMP5:%.*]] = select <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i16> [[TMP4]], <16 x i16> zeroinitializer
 ; CHECK-NEXT:    [[TMP6:%.*]] = call i16 @llvm.vector.reduce.add.v16i16(<16 x i16> [[TMP5]])
@@ -1069,11 +1069,11 @@ define zeroext i8 @mla_i8_i8(ptr nocapture readonly %x, ptr nocapture readonly %
 ; CHECK-NEXT:    [[INDEX:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi i8 [ 0, [[VECTOR_PH]] ], [ [[TMP5:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[ACTIVE_LANE_MASK:%.*]] = call <16 x i1> @llvm.get.active.lane.mask.v16i1.i32(i32 [[INDEX]], i32 [[N]])
-; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds i8, ptr [[X:%.*]], i32 [[INDEX]]
-; CHECK-NEXT:    [[WIDE_MASKED_LOAD:%.*]] = call <16 x i8> @llvm.masked.load.v16i8.p0(ptr [[TMP0]], i32 1, <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i8> poison)
 ; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds i8, ptr [[Y:%.*]], i32 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_MASKED_LOAD1:%.*]] = call <16 x i8> @llvm.masked.load.v16i8.p0(ptr [[TMP1]], i32 1, <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i8> poison)
-; CHECK-NEXT:    [[TMP2:%.*]] = mul <16 x i8> [[WIDE_MASKED_LOAD1]], [[WIDE_MASKED_LOAD]]
+; CHECK-NEXT:    [[TMP7:%.*]] = getelementptr inbounds i8, ptr [[Y1:%.*]], i32 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_MASKED_LOAD2:%.*]] = call <16 x i8> @llvm.masked.load.v16i8.p0(ptr [[TMP7]], i32 1, <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i8> poison)
+; CHECK-NEXT:    [[TMP2:%.*]] = mul <16 x i8> [[WIDE_MASKED_LOAD2]], [[WIDE_MASKED_LOAD1]]
 ; CHECK-NEXT:    [[TMP3:%.*]] = select <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i8> [[TMP2]], <16 x i8> zeroinitializer
 ; CHECK-NEXT:    [[TMP4:%.*]] = call i8 @llvm.vector.reduce.add.v16i8(<16 x i8> [[TMP3]])
 ; CHECK-NEXT:    [[TMP5]] = add i8 [[TMP4]], [[VEC_PHI]]
@@ -1122,10 +1122,10 @@ define i32 @red_mla_ext_s8_s16_s32(ptr noalias nocapture readonly %A, ptr noalia
 ; CHECK-NEXT:    [[ACTIVE_LANE_MASK:%.*]] = call <8 x i1> @llvm.get.active.lane.mask.v8i1.i32(i32 [[INDEX]], i32 [[N]])
 ; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds i8, ptr [[A:%.*]], i32 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_MASKED_LOAD:%.*]] = call <8 x i8> @llvm.masked.load.v8i8.p0(ptr [[TMP0]], i32 1, <8 x i1> [[ACTIVE_LANE_MASK]], <8 x i8> poison)
-; CHECK-NEXT:    [[TMP1:%.*]] = sext <8 x i8> [[WIDE_MASKED_LOAD]] to <8 x i32>
 ; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr inbounds i16, ptr [[B:%.*]], i32 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_MASKED_LOAD1:%.*]] = call <8 x i16> @llvm.masked.load.v8i16.p0(ptr [[TMP2]], i32 2, <8 x i1> [[ACTIVE_LANE_MASK]], <8 x i16> poison)
 ; CHECK-NEXT:    [[TMP3:%.*]] = sext <8 x i16> [[WIDE_MASKED_LOAD1]] to <8 x i32>
+; CHECK-NEXT:    [[TMP1:%.*]] = sext <8 x i8> [[WIDE_MASKED_LOAD]] to <8 x i32>
 ; CHECK-NEXT:    [[TMP4:%.*]] = mul nsw <8 x i32> [[TMP3]], [[TMP1]]
 ; CHECK-NEXT:    [[TMP5:%.*]] = select <8 x i1> [[ACTIVE_LANE_MASK]], <8 x i32> [[TMP4]], <8 x i32> zeroinitializer
 ; CHECK-NEXT:    [[TMP6:%.*]] = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> [[TMP5]])
@@ -1183,11 +1183,11 @@ define i64 @red_mla_ext_s16_u16_s64(ptr noalias nocapture readonly %A, ptr noali
 ; CHECK:       vector.body:
 ; CHECK-NEXT:    [[INDEX:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[TMP7:%.*]], [[VECTOR_BODY]] ]
-; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds i16, ptr [[A:%.*]], i32 [[INDEX]]
-; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <4 x i16>, ptr [[TMP0]], align 1
-; CHECK-NEXT:    [[TMP1:%.*]] = sext <4 x i16> [[WIDE_LOAD]] to <4 x i32>
 ; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr inbounds i16, ptr [[B:%.*]], i32 [[INDEX]]
-; CHECK-NEXT:    [[WIDE_LOAD1:%.*]] = load <4 x i16>, ptr [[TMP2]], align 2
+; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <4 x i16>, ptr [[TMP2]], align 1
+; CHECK-NEXT:    [[TMP1:%.*]] = sext <4 x i16> [[WIDE_LOAD]] to <4 x i32>
+; CHECK-NEXT:    [[TMP11:%.*]] = getelementptr inbounds i16, ptr [[B1:%.*]], i32 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_LOAD1:%.*]] = load <4 x i16>, ptr [[TMP11]], align 2
 ; CHECK-NEXT:    [[TMP3:%.*]] = zext <4 x i16> [[WIDE_LOAD1]] to <4 x i32>
 ; CHECK-NEXT:    [[TMP4:%.*]] = mul nsw <4 x i32> [[TMP3]], [[TMP1]]
 ; CHECK-NEXT:    [[TMP5:%.*]] = zext <4 x i32> [[TMP4]] to <4 x i64>
@@ -1206,10 +1206,10 @@ define i64 @red_mla_ext_s16_u16_s64(ptr noalias nocapture readonly %A, ptr noali
 ; CHECK:       for.body:
 ; CHECK-NEXT:    [[I_011:%.*]] = phi i32 [ [[INC:%.*]], [[FOR_BODY]] ], [ [[BC_RESUME_VAL]], [[SCALAR_PH]] ]
 ; CHECK-NEXT:    [[S_010:%.*]] = phi i64 [ [[ADD:%.*]], [[FOR_BODY]] ], [ [[BC_MERGE_RDX]], [[SCALAR_PH]] ]
-; CHECK-NEXT:    [[ARRAYIDX:%.*]] = getelementptr inbounds i16, ptr [[A]], i32 [[I_011]]
+; CHECK-NEXT:    [[ARRAYIDX:%.*]] = getelementptr inbounds i16, ptr [[B]], i32 [[I_011]]
 ; CHECK-NEXT:    [[TMP9:%.*]] = load i16, ptr [[ARRAYIDX]], align 1
 ; CHECK-NEXT:    [[CONV:%.*]] = sext i16 [[TMP9]] to i32
-; CHECK-NEXT:    [[ARRAYIDX1:%.*]] = getelementptr inbounds i16, ptr [[B]], i32 [[I_011]]
+; CHECK-NEXT:    [[ARRAYIDX1:%.*]] = getelementptr inbounds i16, ptr [[B1]], i32 [[I_011]]
 ; CHECK-NEXT:    [[TMP10:%.*]] = load i16, ptr [[ARRAYIDX1]], align 2
 ; CHECK-NEXT:    [[CONV2:%.*]] = zext i16 [[TMP10]] to i32
 ; CHECK-NEXT:    [[MUL:%.*]] = mul nsw i32 [[CONV2]], [[CONV]]
@@ -1268,12 +1268,12 @@ define i32 @red_mla_u8_s8_u32(ptr noalias nocapture readonly %A, ptr noalias noc
 ; CHECK-NEXT:    [[INDEX:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[TMP7:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[ACTIVE_LANE_MASK:%.*]] = call <4 x i1> @llvm.get.active.lane.mask.v4i1.i32(i32 [[INDEX]], i32 [[N]])
-; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds i8, ptr [[A:%.*]], i32 [[INDEX]]
-; CHECK-NEXT:    [[WIDE_MASKED_LOAD:%.*]] = call <4 x i8> @llvm.masked.load.v4i8.p0(ptr [[TMP0]], i32 1, <4 x i1> [[ACTIVE_LANE_MASK]], <4 x i8> poison)
-; CHECK-NEXT:    [[TMP1:%.*]] = zext <4 x i8> [[WIDE_MASKED_LOAD]] to <4 x i32>
 ; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr inbounds i8, ptr [[B:%.*]], i32 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_MASKED_LOAD1:%.*]] = call <4 x i8> @llvm.masked.load.v4i8.p0(ptr [[TMP2]], i32 1, <4 x i1> [[ACTIVE_LANE_MASK]], <4 x i8> poison)
-; CHECK-NEXT:    [[TMP3:%.*]] = sext <4 x i8> [[WIDE_MASKED_LOAD1]] to <4 x i32>
+; CHECK-NEXT:    [[TMP1:%.*]] = zext <4 x i8> [[WIDE_MASKED_LOAD1]] to <4 x i32>
+; CHECK-NEXT:    [[TMP9:%.*]] = getelementptr inbounds i8, ptr [[B1:%.*]], i32 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_MASKED_LOAD2:%.*]] = call <4 x i8> @llvm.masked.load.v4i8.p0(ptr [[TMP9]], i32 1, <4 x i1> [[ACTIVE_LANE_MASK]], <4 x i8> poison)
+; CHECK-NEXT:    [[TMP3:%.*]] = sext <4 x i8> [[WIDE_MASKED_LOAD2]] to <4 x i32>
 ; CHECK-NEXT:    [[TMP4:%.*]] = mul nsw <4 x i32> [[TMP3]], [[TMP1]]
 ; CHECK-NEXT:    [[TMP5:%.*]] = select <4 x i1> [[ACTIVE_LANE_MASK]], <4 x i32> [[TMP4]], <4 x i32> zeroinitializer
 ; CHECK-NEXT:    [[TMP6:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP5]])
@@ -1413,7 +1413,8 @@ define i32 @mla_i8_i32_multiuse(ptr nocapture readonly %x, ptr nocapture readonl
 ; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds i8, ptr [[X:%.*]], i32 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_MASKED_LOAD:%.*]] = call <16 x i8> @llvm.masked.load.v16i8.p0(ptr [[TMP0]], i32 1, <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i8> poison)
 ; CHECK-NEXT:    [[TMP1:%.*]] = zext <16 x i8> [[WIDE_MASKED_LOAD]] to <16 x i32>
-; CHECK-NEXT:    [[TMP2:%.*]] = mul nuw nsw <16 x i32> [[TMP1]], [[TMP1]]
+; CHECK-NEXT:    [[TMP7:%.*]] = zext <16 x i8> [[WIDE_MASKED_LOAD]] to <16 x i32>
+; CHECK-NEXT:    [[TMP2:%.*]] = mul nuw nsw <16 x i32> [[TMP1]], [[TMP7]]
 ; CHECK-NEXT:    [[TMP3:%.*]] = select <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i32> [[TMP2]], <16 x i32> zeroinitializer
 ; CHECK-NEXT:    [[TMP4:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP3]])
 ; CHECK-NEXT:    [[TMP5]] = add i32 [[TMP4]], [[VEC_PHI]]
diff --git a/llvm/test/Transforms/LoopVectorize/RISCV/inloop-reduction.ll b/llvm/test/Transforms/LoopVectorize/RISCV/inloop-reduction.ll
index 8ca2bd1f286ae3..9f1a61ebb5efef 100644
--- a/llvm/test/Transforms/LoopVectorize/RISCV/inloop-reduction.ll
+++ b/llvm/test/Transforms/LoopVectorize/RISCV/inloop-reduction.ll
@@ -187,23 +187,33 @@ define i32 @add_i16_i32(ptr nocapture readonly %x, i32 %n) {
 ; IF-EVL-INLOOP-NEXT:    [[N_RND_UP:%.*]] = add i32 [[N]], [[TMP2]]
 ; IF-EVL-INLOOP-NEXT:    [[N_MOD_VF:%.*]] = urem i32 [[N_RND_UP]], [[TMP1]]
 ; IF-EVL-INLOOP-NEXT:    [[N_VEC:%.*]] = sub i32 [[N_RND_UP]], [[N_MOD_VF]]
+; IF-EVL-INLOOP-NEXT:    [[TRIP_COUNT_MINUS_1:%.*]] = sub i32 [[N]], 1
 ; IF-EVL-INLOOP-NEXT:    [[TMP3:%.*]] = call i32 @llvm.vscale.i32()
 ; IF-EVL-INLOOP-NEXT:    [[TMP4:%.*]] = mul i32 [[TMP3]], 8
+; IF-EVL-INLOOP-NEXT:    [[BROADCAST_SPLATINSERT1:%.*]] = insertelement <vscale x 8 x i32> poison, i32 [[TRIP_COUNT_MINUS_1]], i64 0
+; IF-EVL-INLOOP-NEXT:    [[BROADCAST_SPLAT2:%.*]] = shufflevector <vscale x 8 x i32> [[BROADCAST_SPLATINSERT1]], <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer
 ; IF-EVL-INLOOP-NEXT:    br label [[VECTOR_BODY:%.*]]
 ; IF-EVL-INLOOP:       vector.body:
 ; IF-EVL-INLOOP-NEXT:    [[INDEX:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
 ; IF-EVL-INLOOP-NEXT:    [[EVL_BASED_IV:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[INDEX_EVL_NEXT:%.*]], [[VECTOR_BODY]] ]
-; IF-EVL-INLOOP-NEXT:    [[VEC_PHI:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[TMP11:%.*]], [[VECTOR_BODY]] ]
-; IF-EVL-INLOOP-NEXT:    [[AVL:%.*]] = sub i32 [[N]], [[EVL_BASED_IV]]
-; IF-EVL-INLOOP-NEXT:    [[TMP5:%.*]] = call i32 @llvm.experimental.get.vector.length.i32(i32 [[AVL]], i32 8, i1 true)
-; IF-EVL-INLOOP-NEXT:    [[TMP6:%.*]] = add i32 [[EVL_BASED_IV]], 0
-; IF-EVL-INLOOP-NEXT:    [[TMP7:%.*]] = getelementptr inbounds i16, ptr [[X:%.*]], i32 [[TMP6]]
-; IF-EVL-INLOOP-NEXT:    [[TMP8:%.*]] = getelementptr inbounds i16, ptr [[TMP7]], i32 0
-; IF-EVL-INLOOP-NEXT:    [[VP_OP_LOAD:%.*]] = call <vscale x 8 x i16> @llvm.vp.load.nxv8i16.p0(ptr align 2 [[TMP8]], <vscale x 8 x i1> shufflevector (<vscale x 8 x i1> insertelement (<vscale x 8 x i1> poison, i1 true, i64 0), <vscale x 8 x i1> poison, <vscale x 8 x i32> zeroinitializer), i32 [[TMP5]])
-; IF-EVL-INLOOP-NEXT:    [[TMP9:%.*]] = sext <vscale x 8 x i16> [[VP_OP_LOAD]] to <vscale x 8 x i32>
-; IF-EVL-INLOOP-NEXT:    [[TMP10:%.*]] = call i32 @llvm.vp.reduce.add.nxv8i32(i32 0, <vscale x 8 x i32> [[TMP9]], <vscale x 8 x i1> shufflevector (<vscale x 8 x i1> insertelement (<vscale x 8 x i1> poison, i1 true, i64 0), <vscale x 8 x i1> poison, <vscale x 8 x i32> zeroinitializer), i32 [[TMP5]])
-; IF-EVL-INLOOP-NEXT:    [[TMP11]] = add i32 [[TMP10]], [[VEC_PHI]]
-; IF-EVL-INLOOP-NEXT:    [[INDEX_EVL_NEXT]] = add i32 [[TMP5]], [[EVL_BASED_IV]]
+; IF-EVL-INLOOP-NEXT:    [[VEC_PHI:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[TMP12:%.*]], [[VECTOR_BODY]] ]
+; IF-EVL-INLOOP-NEXT:    [[TMP5:%.*]] = sub i32 [[N]], [[EVL_BASED_IV]]
+; IF-EVL-INLOOP-NEXT:    [[TMP6:%.*]] = call i32 @llvm.experimental.get.vector.length.i32(i32 [[TMP5]], i32 8, i1 true)
+; IF-EVL-INLOOP-NEXT:    [[TMP7:%.*]] = add i32 [[EVL_BASED_IV]], 0
+; IF-EVL-INLOOP-NEXT:    [[BROADCAST_SPLATINSERT:%.*]] = insertelement <vscale x 8 x i32> poison, i32 [[EVL_BASED_IV]], i64 0
+; IF-EVL-INLOOP-NEXT:    [[BROADCAST_SPLAT:%.*]] = shufflevector <vscale x 8 x i32> [[BROADCAST_SPLATINSERT]], <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer
+; IF-EVL-INLOOP-NEXT:    [[TMP15:%.*]] = call <vscale x 8 x i32> @llvm.stepvector.nxv8i32()
+; IF-EVL-INLOOP-NEXT:    [[TMP16:%.*]] = add <vscale x 8 x i32> zeroinitializer, [[TMP15]]
+; IF-EVL-INLOOP-NEXT:    [[VEC_IV:%.*]] = add <vscale x 8 x i32> [[BROADCAST_SPLAT]], [[TMP16]]
+; IF-EVL-INLOOP-NEXT:    [[TMP17:%.*]] = icmp ule <vscale x 8 x i32> [[VEC_IV]], [[BROADCAST_SPLAT2]]
+; IF-EVL-INLOOP-NEXT:    [[TMP8:%.*]] = getelementptr inbounds i16, ptr [[X:%.*]], i32 [[TMP7]]
+; IF-EVL-INLOOP-NEXT:    [[TMP9:%.*]] = getelementptr inbounds i16, ptr [[TMP8]], i32 0
+; IF-EVL-INLOOP-NEXT:    [[VP_OP_LOAD:%.*]] = call <vscale x 8 x i16> @llvm.vp.load.nxv8i16.p0(ptr align 2 [[TMP9]], <vscale x 8 x i1> shufflevector (<vscale x 8 x i1> insertelement (<vscale x 8 x i1> poison, i1 true, i64 0), <vscale x 8 x i1> poison, <vscale x 8 x i32> zeroinitializer), i32 [[TMP6]])
+; IF-EVL-INLOOP-NEXT:    [[TMP10:%.*]] = sext <vscale x 8 x i16> [[VP_OP_LOAD]] to <vscale x 8 x i32>
+; IF-EVL-INLOOP-NEXT:    [[TMP18:%.*]] = select <vscale x 8 x i1> [[TMP17]], <vscale x 8 x i32> [[TMP10]], <vscale x 8 x i32> zeroinitializer
+; IF-EVL-INLOOP-NEXT:    [[TMP11:%.*]] = call i32 @llvm.vector.reduce.add.nxv8i32(<vscale x 8 x i32> [[TMP18]])
+; IF-EVL-INLOOP-NEXT:    [[TMP12]] = add i32 [[TMP11]], [[VEC_PHI]]
+; IF-EVL-INLOOP-NEXT:    [[INDEX_EVL_NEXT]] = add i32 [[TMP6]], [[EVL_BASED_IV]]
 ; IF-EVL-INLOOP-NEXT:    [[INDEX_NEXT]] = add i32 [[INDEX]], [[TMP4]]
 ; IF-EVL-INLOOP-NEXT:    [[TMP12:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]
 ; IF-EVL-INLOOP-NEXT:    br i1 [[TMP12]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
diff --git a/llvm/test/Transforms/LoopVectorize/reduction-inloop-pred.ll b/llvm/test/Transforms/LoopVectorize/reduction-inloop-pred.ll
index 6771f561913130..cb025608882e6a 100644
--- a/llvm/test/Transforms/LoopVectorize/reduction-inloop-pred.ll
+++ b/llvm/test/Transforms/LoopVectorize/reduction-inloop-pred.ll
@@ -424,62 +424,62 @@ define i32 @reduction_mix(ptr noalias nocapture %A, ptr noalias nocapture %B) {
 ; CHECK-NEXT:    [[TMP1:%.*]] = extractelement <4 x i1> [[TMP0]], i64 0
 ; CHECK-NEXT:    br i1 [[TMP1]], label [[PRED_LOAD_IF:%.*]], label [[PRED_LOAD_CONTINUE:%.*]]
 ; CHECK:       pred.load.if:
-; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr inbounds i32, ptr [[A:%.*]], i64 [[INDEX]]
-; CHECK-NEXT:    [[TMP3:%.*]] = load i32, ptr [[TMP2]], align 4
-; CHECK-NEXT:    [[TMP4:%.*]] = insertelement <4 x i32> poison, i32 [[TMP3]], i64 0
 ; CHECK-NEXT:    [[TMP5:%.*]] = getelementptr inbounds i32, ptr [[B:%.*]], i64 [[INDEX]]
 ; CHECK-NEXT:    [[TMP6:%.*]] = load i32, ptr [[TMP5]], align 4
 ; CHECK-NEXT:    [[TMP7:%.*]] = insertelement <4 x i32> poison, i32 [[TMP6]], i64 0
+; CHECK-NEXT:    [[TMP8:%.*]] = getelementptr inbounds i32, ptr [[B1:%.*]], i64 [[INDEX]]
+; CHECK-NEXT:    [[TMP12:%.*]] = load i32, ptr [[TMP8]], align 4
+; CHECK-NEXT:    [[TMP13:%.*]] = insertelement <4 x i32> poison, i32 [[TMP12]], i64 0
 ; CHECK-NEXT:    br label [[PRED_LOAD_CONTINUE]]
 ; CHECK:       pred.load.continue:
-; CHECK-NEXT:    [[TMP8:%.*]] = phi <4 x i32> [ poison, [[VECTOR_BODY]] ], [ [[TMP4]], [[PRED_LOAD_IF]] ]
 ; CHECK-NEXT:    [[TMP9:%.*]] = phi <4 x i32> [ poison, [[VECTOR_BODY]] ], [ [[TMP7]], [[PRED_LOAD_IF]] ]
+; CHECK-NEXT:    [[TMP14:%.*]] = phi <4 x i32> [ poison, [[VECTOR_BODY]] ], [ [[TMP13]], [[PRED_LOAD_IF]] ]
 ; CHECK-NEXT:    [[TMP10:%.*]] = extractelement <4 x i1> [[TMP0]], i64 1
 ; CHECK-NEXT:    br i1 [[TMP10]], label [[PRED_LOAD_IF3:%.*]], label [[PRED_LOAD_CONTINUE4:%.*]]
 ; CHECK:       pred.load.if3:
 ; CHECK-NEXT:    [[TMP11:%.*]] = or disjoint i64 [[INDEX]], 1
-; CHECK-NEXT:    [[TMP12:%.*]] = getelementptr inbounds i32, ptr [[A]], i64 [[TMP11]]
-; CHECK-NEXT:    [[TMP13:%.*]] = load i32, ptr [[TMP12]], align 4
-; CHECK-NEXT:    [[TMP14:%.*]] = insertelement <4 x i32> [[TMP8]], i32 [[TMP13]], i64 1
 ; CHECK-NEXT:    [[TMP15:%.*]] = getelementptr inbounds i32, ptr [[B]], i64 [[TMP11]]
 ; CHECK-NEXT:    [[TMP16:%.*]] = load i32, ptr [[TMP15]], align 4
 ; CHECK-NEXT:    [[TMP17:%.*]] = insertelement <4 x i32> [[TMP9]], i32 [[TMP16]], i64 1
+; CHECK-NEXT:    [[TMP18:%.*]] = getelementptr inbounds i32, ptr [[B1]], i64 [[TMP11]]
+; CHECK-NEXT:    [[TMP22:%.*]] = load i32, ptr [[TMP18]], align 4
+; CHECK-NEXT:    [[TMP23:%.*]] = insertelement <4 x i32> [[TMP14]], i32 [[TMP22]], i64 1
 ; CHECK-NEXT:    br label [[PRED_LOAD_CONTINUE4]]
 ; CHECK:       pred.load.continue4:
-; CHECK-NEXT:    [[TMP18:%.*]] = phi <4 x i32> [ [[TMP8]], [[PRED_LOAD_CONTINUE]] ], [ [[TMP14]], [[PRED_LOAD_IF3]] ]
 ; CHECK-NEXT:    [[TMP19:%.*]] = phi <4 x i32> [ [[TMP9]], [[PRED_LOAD_CONTINUE]] ], [ [[TMP17]], [[PRED_LOAD_IF3]] ]
+; CHECK-NEXT:    [[TMP24:%.*]] = phi <4 x i32> [ [[TMP14]], [[PRED_LOAD_CONTINUE]] ], [ [[TMP23]], [[PRED_LOAD_IF3]] ]
 ; CHECK-NEXT:    [[TMP20:%.*]] = extractelement <4 x i1> [[TMP0]], i64 2
 ; CHECK-NEXT:    br i1 [[TMP20]], label [[PRED_LOAD_IF5:%.*]], label [[PRED_LOAD_CONTINUE6:%.*]]
 ; CHECK:       pred.load.if5:
 ; CHECK-NEXT:    [[TMP21:%.*]] = or disjoint i64 [[INDEX]], 2
-; CHECK-NEXT:    [[TMP22:%.*]] = getelementptr inbounds i32, ptr [[A]], i64 [[TMP21]]
-; CHECK-NEXT:    [[TMP23:%.*]] = load i32, ptr [[TMP22]], align 4
-; CHECK-NEXT:    [[TMP24:%.*]] = insertelement <4 x i32> [[TMP18]], i32 [[TMP23]], i64 2
 ; CHECK-NEXT:    [[TMP25:%.*]] = getelementptr inbounds i32, ptr [[B]], i64 [[TMP21]]
 ; CHECK-NEXT:    [[TMP26:%.*]] = load i32, ptr [[TMP25]], align 4
 ; CHECK-NEXT:    [[TMP27:%.*]] = insertelement <4 x i32> [[TMP19]], i32 [[TMP26]], i64 2
+; CHECK-NEXT:    [[TMP28:%.*]] = getelementptr inbounds i32, ptr [[B1]], i64 [[TMP21]]
+; CHECK-NEXT:    [[TMP32:%.*]] = load i32, ptr [[TMP28]], align 4
+; CHECK-NEXT:    [[TMP33:%.*]] = insertelement <4 x i32> [[TMP24]], i32 [[TMP32]], i64 2
 ; CHECK-NEXT:    br label [[PRED_LOAD_CONTINUE6]]
 ; CHECK:       pred.load.continue6:
-; CHECK-NEXT:    [[TMP28:%.*]] = phi <4 x i32> [ [[TMP18]], [[PRED_LOAD_CONTINUE4]] ], [ [[TMP24]], [[PRED_LOAD_IF5]] ]
 ; CHECK-NEXT:    [[TMP29:%.*]] = phi <4 x i32> [ [[TMP19]], [[PRED_LOAD_CONTINUE4]] ], [ [[TMP27]], [[PRED_LOAD_IF5]] ]
+; CHECK-NEXT:    [[TMP34:%.*]] = phi <4 x i32> [ [[TMP24]], [[PRED_LOAD_CONTINUE4]] ], [ [[TMP33]], [[PRED_LOAD_IF5]] ]
 ; CHECK-NEXT:    [[TMP30:%.*]] = extractelement <4 x i1> [[TMP0]], i64 3
 ; CHECK-NEXT:    br i1 [[TMP30]], label [[PRED_LOAD_IF7:%.*]], label [[PRED_LOAD_CONTINUE8]]
 ; CHECK:       pred.load.if7:
 ; CHECK-NEXT:    [[TMP31:%.*]] = or disjoint i64 [[INDEX]], 3
-; CHECK-NEXT:    [[TMP32:%.*]] = getelementptr inbounds i32, ptr [[A]], i64 [[TMP31]]
-; CHECK-NEXT:    [[TMP33:%.*]] = load i32, ptr [[TMP32]], align 4
-; CHECK-NEXT:    [[TMP34:%.*]] = insertelement <4 x i32> [[TMP28]], i32 [[TMP33]], i64 3
 ; CHECK-NEXT:    [[TMP35:%.*]] = getelementptr inbounds i32, ptr [[B]], i64 [[TMP31]]
 ; CHECK-NEXT:    [[TMP36:%.*]] = load i32, ptr [[TMP35]], align 4
 ; CHECK-NEXT:    [[TMP37:%.*]] = insertelement <4 x i32> [[TMP29]], i32 [[TMP36]], i64 3
+; CHECK-NEXT:    [[TMP38:%.*]] = getelementptr inbounds i32, ptr [[B1]], i64 [[TMP31]]
+; CHECK-NEXT:    [[TMP48:%.*]] = load i32, ptr [[TMP38]], align 4
+; CHECK-NEXT:    [[TMP49:%.*]] = insertelement <4 x i32> [[TMP34]], i32 [[TMP48]], i64 3
 ; CHECK-NEXT:    br label [[PRED_LOAD_CONTINUE8]]
 ; CHECK:       pred.load.continue8:
-; CHECK-NEXT:    [[TMP38:%.*]] = phi <4 x i32> [ [[TMP28]], [[PRED_LOAD_CONTINUE6]] ], [ [[TMP34]], [[PRED_LOAD_IF7]] ]
 ; CHECK-NEXT:    [[TMP39:%.*]] = phi <4 x i32> [ [[TMP29]], [[PRED_LOAD_CONTINUE6]] ], [ [[TMP37]], [[PRED_LOAD_IF7]] ]
-; CHECK-NEXT:    [[TMP40:%.*]] = mul nsw <4 x i32> [[TMP39]], [[TMP38]]
+; CHECK-NEXT:    [[TMP50:%.*]] = phi <4 x i32> [ [[TMP34]], [[PRED_LOAD_CONTINUE6]] ], [ [[TMP49]], [[PRED_LOAD_IF7]] ]
 ; CHECK-NEXT:    [[TMP41:%.*]] = select <4 x i1> [[TMP0]], <4 x i32> [[VEC_IND1]], <4 x i32> zeroinitializer
 ; CHECK-NEXT:    [[TMP42:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP41]])
 ; CHECK-NEXT:    [[TMP43:%.*]] = add i32 [[TMP42]], [[VEC_PHI]]
+; CHECK-NEXT:    [[TMP40:%.*]] = mul nsw <4 x i32> [[TMP50]], [[TMP39]]
 ; CHECK-NEXT:    [[TMP44:%.*]] = select <4 x i1> [[TMP0]], <4 x i32> [[TMP40]], <4 x i32> zeroinitializer
 ; CHECK-NEXT:    [[TMP45:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP44]])
 ; CHECK-NEXT:    [[TMP46]] = add i32 [[TMP45]], [[TMP43]]
diff --git a/llvm/test/Transforms/LoopVectorize/reduction-inloop.ll b/llvm/test/Transforms/LoopVectorize/reduction-inloop.ll
index fe74a7c3a9b27c..b578e61d85dfa1 100644
--- a/llvm/test/Transforms/LoopVectorize/reduction-inloop.ll
+++ b/llvm/test/Transforms/LoopVectorize/reduction-inloop.ll
@@ -221,13 +221,13 @@ define i32 @reduction_mix(ptr noalias nocapture %A, ptr noalias nocapture %B) {
 ; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[TMP6:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[VEC_IND:%.*]] = phi <4 x i32> [ <i32 0, i32 1, i32 2, i32 3>, [[VECTOR_PH]] ], [ [[VEC_IND_NEXT:%.*]], [[VECTOR_BODY]] ]
-; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds i32, ptr [[A:%.*]], i64 [[INDEX]]
-; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <4 x i32>, ptr [[TMP0]], align 4
 ; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds i32, ptr [[B:%.*]], i64 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_LOAD1:%.*]] = load <4 x i32>, ptr [[TMP1]], align 4
-; CHECK-NEXT:    [[TMP2:%.*]] = mul nsw <4 x i32> [[WIDE_LOAD1]], [[WIDE_LOAD]]
+; CHECK-NEXT:    [[TMP8:%.*]] = getelementptr inbounds i32, ptr [[B1:%.*]], i64 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_LOAD2:%.*]] = load <4 x i32>, ptr [[TMP8]], align 4
 ; CHECK-NEXT:    [[TMP3:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[VEC_IND]])
 ; CHECK-NEXT:    [[TMP4:%.*]] = add i32 [[TMP3]], [[VEC_PHI]]
+; CHECK-NEXT:    [[TMP2:%.*]] = mul nsw <4 x i32> [[WIDE_LOAD2]], [[WIDE_LOAD1]]
 ; CHECK-NEXT:    [[TMP5:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP2]])
 ; CHECK-NEXT:    [[TMP6]] = add i32 [[TMP5]], [[TMP4]]
 ; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 4
@@ -329,11 +329,11 @@ define i32 @start_at_non_zero(ptr nocapture %in, ptr nocapture %coeff, ptr nocap
 ; CHECK:       vector.body:
 ; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi i32 [ 120, [[VECTOR_PH]] ], [ [[TMP4:%.*]], [[VECTOR_BODY]] ]
-; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds i32, ptr [[IN:%.*]], i64 [[INDEX]]
-; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <4 x i32>, ptr [[TMP0]], align 4
 ; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds i32, ptr [[COEFF:%.*]], i64 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_LOAD1:%.*]] = load <4 x i32>, ptr [[TMP1]], align 4
-; CHECK-NEXT:    [[TMP2:%.*]] = mul nsw <4 x i32> [[WIDE_LOAD1]], [[WIDE_LOAD]]
+; CHECK-NEXT:    [[TMP6:%.*]] = getelementptr inbounds i32, ptr [[COEFF1:%.*]], i64 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_LOAD2:%.*]] = load <4 x i32>, ptr [[TMP6]], align 4
+; CHECK-NEXT:    [[TMP2:%.*]] = mul nsw <4 x i32> [[WIDE_LOAD2]], [[WIDE_LOAD1]]
 ; CHECK-NEXT:    [[TMP3:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP2]])
 ; CHECK-NEXT:    [[TMP4]] = add i32 [[TMP3]], [[VEC_PHI]]
 ; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 4

>From fbfad8f11f3c5e54d0b2a14ddad8fb229e13ca8a Mon Sep 17 00:00:00 2001
From: Elvis Wang <elvis.wang at sifive.com>
Date: Tue, 5 Nov 2024 17:49:26 -0800
Subject: [PATCH 5/9] Refactors

Use lambda functions to return early when a pattern is matched.
Leave some assertions.
---
 .../Transforms/Vectorize/LoopVectorize.cpp    | 121 ++++++++---------
 llvm/lib/Transforms/Vectorize/VPlan.h         |  55 ++++----
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp | 123 +-----------------
 .../Transforms/Vectorize/VPlanTransforms.cpp  |  19 +--
 .../LoopVectorize/ARM/mve-reductions.ll       |   3 +-
 .../LoopVectorize/RISCV/inloop-reduction.ll   |   8 +-
 6 files changed, 113 insertions(+), 216 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index ea2e6b96f7711c..ee429f9d3fb97d 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7397,7 +7397,7 @@ static bool planContainsAdditionalSimplifications(VPlan &Plan,
       // VPExtendedReductionRecipe contains a folded extend instruction.
       if (auto *ExtendedRed = dyn_cast<VPExtendedReductionRecipe>(&R))
         SeenInstrs.insert(ExtendedRed->getExtInstr());
-      // VPMulAccRecupe constians a mul and otional extend instructions.
+      // VPMulAccRecipe constians a mul and otional extend instructions.
       else if (auto *MulAcc = dyn_cast<VPMulAccRecipe>(&R)) {
         SeenInstrs.insert(MulAcc->getMulInstr());
         if (MulAcc->isExtended()) {
@@ -9390,77 +9390,82 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
       if (CM.blockNeedsPredicationForAnyReason(BB))
         CondOp = RecipeBuilder.getBlockInMask(BB);
 
-      VPValue *A, *B;
-      VPSingleDefRecipe *RedRecipe;
-      // reduce.add(mul(ext, ext)) can folded into VPMulAccRecipe
-      if (RdxDesc.getOpcode() == Instruction::Add &&
-          match(VecOp, m_Mul(m_VPValue(A), m_VPValue(B)))) {
-        VPRecipeBase *RecipeA = A->getDefiningRecipe();
-        VPRecipeBase *RecipeB = B->getDefiningRecipe();
-        if (RecipeA && RecipeB && match(RecipeA, m_ZExtOrSExt(m_VPValue())) &&
-            match(RecipeB, m_ZExtOrSExt(m_VPValue())) &&
-            cast<VPWidenCastRecipe>(RecipeA)->getOpcode() ==
-                cast<VPWidenCastRecipe>(RecipeB)->getOpcode() &&
-            !A->hasMoreThanOneUniqueUser() && !B->hasMoreThanOneUniqueUser()) {
-          RedRecipe = new VPMulAccRecipe(
-              RdxDesc, CurrentLinkI, PreviousLink, CondOp,
-              CM.useOrderedReductions(RdxDesc),
-              cast<VPWidenRecipe>(VecOp->getDefiningRecipe()),
-              cast<VPWidenCastRecipe>(RecipeA),
-              cast<VPWidenCastRecipe>(RecipeB));
-        } else {
-          RedRecipe = new VPMulAccRecipe(
-              RdxDesc, CurrentLinkI, PreviousLink, CondOp,
-              CM.useOrderedReductions(RdxDesc),
-              cast<VPWidenRecipe>(VecOp->getDefiningRecipe()));
-        }
-      } else if (RdxDesc.getOpcode() == Instruction::Add &&
-                 match(VecOp,
-                       m_ZExtOrSExt(m_Mul(m_ZExtOrSExt(m_VPValue(A)),
-                                          m_ZExtOrSExt(m_VPValue(B)))))) {
-        VPWidenCastRecipe *Ext =
-            dyn_cast<VPWidenCastRecipe>(VecOp->getDefiningRecipe());
-        VPWidenRecipe *Mul =
-            dyn_cast<VPWidenRecipe>(Ext->getOperand(0)->getDefiningRecipe());
-        if (Mul && match(Mul, m_Mul(m_ZExtOrSExt(m_VPValue()),
-                                    m_ZExtOrSExt(m_VPValue())))) {
+      auto TryToMatchMulAcc = [&]() -> VPSingleDefRecipe * {
+        VPValue *A, *B;
+        if (RdxDesc.getOpcode() != Instruction::Add)
+          return nullptr;
+        // reduce.add(mul(ext, ext)) can folded into VPMulAccRecipe
+        if (match(VecOp, m_Mul(m_VPValue(A), m_VPValue(B))) &&
+            !VecOp->hasMoreThanOneUniqueUser()) {
+          VPRecipeBase *RecipeA = A->getDefiningRecipe();
+          VPRecipeBase *RecipeB = B->getDefiningRecipe();
+          if (RecipeA && RecipeB && match(RecipeA, m_ZExtOrSExt(m_VPValue())) &&
+              match(RecipeB, m_ZExtOrSExt(m_VPValue())) &&
+              cast<VPWidenCastRecipe>(RecipeA)->getOpcode() ==
+                  cast<VPWidenCastRecipe>(RecipeB)->getOpcode() &&
+              !A->hasMoreThanOneUniqueUser() &&
+              !B->hasMoreThanOneUniqueUser()) {
+            return new VPMulAccRecipe(
+                RdxDesc, CurrentLinkI, PreviousLink, CondOp,
+                CM.useOrderedReductions(RdxDesc),
+                cast<VPWidenRecipe>(VecOp->getDefiningRecipe()),
+                cast<VPWidenCastRecipe>(RecipeA),
+                cast<VPWidenCastRecipe>(RecipeB));
+          } else {
+            // Matched reduce.add(mul(...))
+            return new VPMulAccRecipe(
+                RdxDesc, CurrentLinkI, PreviousLink, CondOp,
+                CM.useOrderedReductions(RdxDesc),
+                cast<VPWidenRecipe>(VecOp->getDefiningRecipe()));
+          }
+          // Matched reduce.add(ext(mul(ext, ext)))
+          // Note that 3 extend instructions must have same opcode.
+        } else if (match(VecOp,
+                         m_ZExtOrSExt(m_Mul(m_ZExtOrSExt(m_VPValue()),
+                                            m_ZExtOrSExt(m_VPValue())))) &&
+                   !VecOp->hasMoreThanOneUniqueUser()) {
+          VPWidenCastRecipe *Ext =
+              dyn_cast<VPWidenCastRecipe>(VecOp->getDefiningRecipe());
           VPWidenRecipe *Mul =
-              cast<VPWidenRecipe>(Ext->getOperand(0)->getDefiningRecipe());
+              dyn_cast<VPWidenRecipe>(Ext->getOperand(0)->getDefiningRecipe());
           VPWidenCastRecipe *Ext0 =
               cast<VPWidenCastRecipe>(Mul->getOperand(0)->getDefiningRecipe());
           VPWidenCastRecipe *Ext1 =
               cast<VPWidenCastRecipe>(Mul->getOperand(1)->getDefiningRecipe());
           if (Ext->getOpcode() == Ext0->getOpcode() &&
-              Ext0->getOpcode() == Ext1->getOpcode()) {
-            RedRecipe = new VPMulAccRecipe(
+              Ext0->getOpcode() == Ext1->getOpcode() &&
+              !Mul->hasMoreThanOneUniqueUser() &&
+              !Ext0->hasMoreThanOneUniqueUser() &&
+              !Ext1->hasMoreThanOneUniqueUser()) {
+            return new VPMulAccRecipe(
                 RdxDesc, CurrentLinkI, PreviousLink, CondOp,
                 CM.useOrderedReductions(RdxDesc),
                 cast<VPWidenCastRecipe>(VecOp->getDefiningRecipe()), Mul,
                 cast<VPWidenCastRecipe>(Ext0), cast<VPWidenCastRecipe>(Ext1));
-          } else
-            RedRecipe = new VPExtendedReductionRecipe(
-                RdxDesc, CurrentLinkI,
-                cast<CastInst>(
-                    cast<VPWidenCastRecipe>(VecOp)->getUnderlyingInstr()),
-                PreviousLink, cast<VPWidenCastRecipe>(VecOp)->getOperand(0),
-                CondOp, CM.useOrderedReductions(RdxDesc),
-                cast<VPWidenCastRecipe>(VecOp)->getResultType());
+          }
         }
-      }
-      // VPWidenCastRecipes can folded into VPReductionRecipe
-      else if (match(VecOp, m_ZExtOrSExt(m_VPValue(A))) &&
-               !VecOp->hasMoreThanOneUniqueUser()) {
-        RedRecipe = new VPExtendedReductionRecipe(
-            RdxDesc, CurrentLinkI,
-            cast<CastInst>(
-                cast<VPWidenCastRecipe>(VecOp)->getUnderlyingInstr()),
-            PreviousLink, A, CondOp, CM.useOrderedReductions(RdxDesc),
-            cast<VPWidenCastRecipe>(VecOp)->getResultType());
-      } else {
+        return nullptr;
+      };
+      auto TryToMatchExtendedReduction = [&]() -> VPSingleDefRecipe * {
+        VPValue *A;
+        if (match(VecOp, m_ZExtOrSExt(m_VPValue(A))) &&
+            !VecOp->hasMoreThanOneUniqueUser()) {
+          return new VPExtendedReductionRecipe(
+              RdxDesc, CurrentLinkI, PreviousLink,
+              cast<VPWidenCastRecipe>(VecOp), CondOp,
+              CM.useOrderedReductions(RdxDesc));
+        }
+        return nullptr;
+      };
+      VPSingleDefRecipe *RedRecipe;
+      if (auto *MulAcc = TryToMatchMulAcc())
+        RedRecipe = MulAcc;
+      else if (auto *ExtendedRed = TryToMatchExtendedReduction())
+        RedRecipe = ExtendedRed;
+      else
         RedRecipe =
             new VPReductionRecipe(RdxDesc, CurrentLinkI, PreviousLink, VecOp,
                                   CondOp, CM.useOrderedReductions(RdxDesc));
-      }
       // Append the recipe to the end of the VPBasicBlock because we need to
       // ensure that it comes after all of it's inputs, including CondOp.
       // Note that this transformation may leave over dead recipes (including
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 0103686be422d6..c7117551bf5dda 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -2670,18 +2670,19 @@ class VPExtendedReductionRecipe : public VPSingleDefRecipe {
   bool IsConditional = false;
   /// Type after extend.
   Type *ResultTy;
+  /// Opcode for the extend instruction.
   Instruction::CastOps ExtOp;
-  CastInst *CastInstr;
+  CastInst *ExtInstr;
   bool IsZExt;
 
 protected:
   VPExtendedReductionRecipe(const unsigned char SC,
                             const RecurrenceDescriptor &R, Instruction *RedI,
-                            Instruction::CastOps ExtOp, CastInst *CastI,
+                            Instruction::CastOps ExtOp, CastInst *ExtI,
                             ArrayRef<VPValue *> Operands, VPValue *CondOp,
                             bool IsOrdered, Type *ResultTy)
       : VPSingleDefRecipe(SC, Operands, RedI), RdxDesc(R), IsOrdered(IsOrdered),
-        ResultTy(ResultTy), ExtOp(ExtOp), CastInstr(CastI) {
+        ResultTy(ResultTy), ExtOp(ExtOp), ExtInstr(ExtI) {
     if (CondOp) {
       IsConditional = true;
       addOperand(CondOp);
@@ -2691,20 +2692,13 @@ class VPExtendedReductionRecipe : public VPSingleDefRecipe {
 
 public:
   VPExtendedReductionRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
-                            CastInst *CastI, VPValue *ChainOp, VPValue *VecOp,
-                            VPValue *CondOp, bool IsOrdered, Type *ResultTy)
-      : VPExtendedReductionRecipe(VPDef::VPExtendedReductionSC, R, RedI,
-                                  CastI->getOpcode(), CastI,
-                                  ArrayRef<VPValue *>({ChainOp, VecOp}), CondOp,
-                                  IsOrdered, ResultTy) {}
-
-  VPExtendedReductionRecipe(VPReductionRecipe *Red, VPWidenCastRecipe *Ext)
+                            VPValue *ChainOp, VPWidenCastRecipe *Ext,
+                            VPValue *CondOp, bool IsOrdered)
       : VPExtendedReductionRecipe(
-            VPDef::VPExtendedReductionSC, Red->getRecurrenceDescriptor(),
-            Red->getUnderlyingInstr(), Ext->getOpcode(),
+            VPDef::VPExtendedReductionSC, R, RedI, Ext->getOpcode(),
             cast<CastInst>(Ext->getUnderlyingInstr()),
-            ArrayRef<VPValue *>({Red->getChainOp(), Ext->getOperand(0)}),
-            Red->getCondOp(), Red->isOrdered(), Ext->getResultType()) {}
+            ArrayRef<VPValue *>({ChainOp, Ext->getOperand(0)}), CondOp,
+            IsOrdered, Ext->getResultType()) {}
 
   ~VPExtendedReductionRecipe() override = default;
 
@@ -2721,7 +2715,6 @@ class VPExtendedReductionRecipe : public VPSingleDefRecipe {
     return R && classof(R);
   }
 
-  /// Generate the reduction in the loop
   void execute(VPTransformState &State) override {
     llvm_unreachable("VPExtendedReductionRecipe should be transform to "
                      "VPExtendedRecipe + VPReductionRecipe before execution.");
@@ -2753,9 +2746,12 @@ class VPExtendedReductionRecipe : public VPSingleDefRecipe {
   VPValue *getCondOp() const {
     return isConditional() ? getOperand(getNumOperands() - 1) : nullptr;
   }
+  /// The Type after extended.
   Type *getResultType() const { return ResultTy; };
+  /// The Opcode of extend instruction.
   Instruction::CastOps getExtOpcode() const { return ExtOp; };
-  CastInst *getExtInstr() const { return CastInstr; };
+  /// The CastInst of the extend instruction.
+  CastInst *getExtInstr() const { return ExtInstr; };
 };
 
 /// A recipe to represent inloop MulAccreduction operations, performing a
@@ -2771,16 +2767,17 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
   bool IsConditional = false;
   /// Type after extend.
   Type *ResultType;
-  /// reduce.add(ext((mul(Ext(), Ext())))
+  // Note that all extend instruction must have the same opcode in MulAcc.
   Instruction::CastOps ExtOp;
 
+  /// reduce.add(ext(mul(ext0(), ext1())))
   Instruction *MulInstr;
   CastInst *ExtInstr = nullptr;
-  CastInst *Ext0Instr;
-  CastInst *Ext1Instr;
+  CastInst *Ext0Instr = nullptr;
+  CastInst *Ext1Instr = nullptr;
 
+  /// Is this MulAcc recipe contains extend recipes?
   bool IsExtended;
-  bool IsOuterExtended = false;
 
 protected:
   VPMulAccRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
@@ -2794,6 +2791,7 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
         ExtInstr(cast_if_present<CastInst>(ExtInstr)),
         Ext0Instr(cast<CastInst>(Ext0Instr)),
         Ext1Instr(cast<CastInst>(Ext1Instr)) {
+    assert(MulInstr->getOpcode() == Instruction::Mul);
     if (CondOp) {
       IsConditional = true;
       addOperand(CondOp);
@@ -2806,6 +2804,7 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
                  ArrayRef<VPValue *> Operands, VPValue *CondOp, bool IsOrdered)
       : VPSingleDefRecipe(SC, Operands, RedI), RdxDesc(R), IsOrdered(IsOrdered),
         MulInstr(MulInstr) {
+    assert(MulInstr->getOpcode() == Instruction::Mul);
     if (CondOp) {
       IsConditional = true;
       addOperand(CondOp);
@@ -2857,13 +2856,12 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
     return R && classof(R);
   }
 
-  /// Generate the reduction in the loop
   void execute(VPTransformState &State) override {
     llvm_unreachable("VPMulAccRecipe should transform to VPWidenCastRecipe + "
                      "VPWidenRecipe + VPReductionRecipe before execution");
   }
 
-  /// Return the cost of VPExtendedReductionRecipe.
+  /// Return the cost of VPMulAccRecipe.
   InstructionCost computeCost(ElementCount VF,
                               VPCostContext &Ctx) const override;
 
@@ -2890,13 +2888,24 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
   VPValue *getCondOp() const {
     return isConditional() ? getOperand(getNumOperands() - 1) : nullptr;
   }
+  /// Return the type after inner extended, which must equal to the type of mul
+  /// instruction. If the ResultType != recurrenceType, than it must have a
+  /// extend recipe after mul recipe.
   Type *getResultType() const { return ResultType; };
+  /// The opcode of the extend instructions.
   Instruction::CastOps getExtOpcode() const { return ExtOp; };
+  /// The underlying instruction for VPWidenRecipe.
   Instruction *getMulInstr() const { return MulInstr; };
+  /// The underlying Instruction for outer VPWidenCastRecipe.
   CastInst *getExtInstr() const { return ExtInstr; };
+  /// The underlying Instruction for inner VPWidenCastRecipe.
   CastInst *getExt0Instr() const { return Ext0Instr; };
+  /// The underlying Instruction for inner VPWidenCastRecipe.
   CastInst *getExt1Instr() const { return Ext1Instr; };
+  /// Return if this MulAcc recipe contains extend instructions.
   bool isExtended() const { return IsExtended; };
+  /// Return if the operands of mul instruction come from same extend.
+  bool isSameExtend() const { return Ext0Instr == Ext1Instr; };
 };
 
 /// VPReplicateRecipe replicates a given instruction producing multiple scalar
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index d28b2d93ee5630..ab40bed209167d 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -2209,122 +2209,6 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
         Opcode, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
   }
 
-  /*
-  using namespace llvm::VPlanPatternMatch;
-  auto GetMulAccReductionCost =
-      [&](const VPReductionRecipe *Red) -> InstructionCost {
-    VPValue *A, *B;
-    InstructionCost InnerExt0Cost = 0;
-    InstructionCost InnerExt1Cost = 0;
-    InstructionCost ExtCost = 0;
-    InstructionCost MulCost = 0;
-
-    VectorType *SrcVecTy = VectorTy;
-    Type *InnerExt0Ty;
-    Type *InnerExt1Ty;
-    Type *MaxInnerExtTy;
-    bool IsUnsigned = true;
-    bool HasOuterExt = false;
-
-    auto *Ext = dyn_cast_if_present<VPWidenCastRecipe>(
-        Red->getVecOp()->getDefiningRecipe());
-    VPRecipeBase *Mul;
-    // Try to match outer extend reduce.add(ext(...))
-    if (Ext && match(Ext, m_ZExtOrSExt(m_VPValue())) &&
-        cast<VPWidenCastRecipe>(Ext)->getNumUsers() == 1) {
-      IsUnsigned =
-          Ext->getOpcode() == Instruction::CastOps::ZExt ? true : false;
-      ExtCost = Ext->computeCost(VF, Ctx);
-      Mul = Ext->getOperand(0)->getDefiningRecipe();
-      HasOuterExt = true;
-    } else {
-      Mul = Red->getVecOp()->getDefiningRecipe();
-    }
-
-    // Match reduce.add(mul())
-    if (Mul && match(Mul, m_Mul(m_VPValue(A), m_VPValue(B))) &&
-        cast<VPWidenRecipe>(Mul)->getNumUsers() == 1) {
-      MulCost = cast<VPWidenRecipe>(Mul)->computeCost(VF, Ctx);
-      auto *InnerExt0 =
-          dyn_cast_if_present<VPWidenCastRecipe>(A->getDefiningRecipe());
-      auto *InnerExt1 =
-          dyn_cast_if_present<VPWidenCastRecipe>(B->getDefiningRecipe());
-      bool HasInnerExt = false;
-      // Try to match inner extends.
-      if (InnerExt0 && InnerExt1 &&
-          match(InnerExt0, m_ZExtOrSExt(m_VPValue())) &&
-          match(InnerExt1, m_ZExtOrSExt(m_VPValue())) &&
-          InnerExt0->getOpcode() == InnerExt1->getOpcode() &&
-          (InnerExt0->getNumUsers() > 0 &&
-           !InnerExt0->hasMoreThanOneUniqueUser()) &&
-          (InnerExt1->getNumUsers() > 0 &&
-           !InnerExt1->hasMoreThanOneUniqueUser())) {
-        InnerExt0Cost = InnerExt0->computeCost(VF, Ctx);
-        InnerExt1Cost = InnerExt1->computeCost(VF, Ctx);
-        Type *InnerExt0Ty = Ctx.Types.inferScalarType(InnerExt0->getOperand(0));
-        Type *InnerExt1Ty = Ctx.Types.inferScalarType(InnerExt1->getOperand(0));
-        Type *MaxInnerExtTy = InnerExt0Ty->getIntegerBitWidth() >
-                                      InnerExt1Ty->getIntegerBitWidth()
-                                  ? InnerExt0Ty
-                                  : InnerExt1Ty;
-        SrcVecTy = cast<VectorType>(ToVectorTy(MaxInnerExtTy, VF));
-        IsUnsigned = true;
-        HasInnerExt = true;
-      }
-      InstructionCost MulAccRedCost = Ctx.TTI.getMulAccReductionCost(
-          IsUnsigned, ElementTy, SrcVecTy, CostKind);
-      // Check if folding ext/mul into MulAccReduction is profitable.
-      if (MulAccRedCost.isValid() &&
-          MulAccRedCost <
-              ExtCost + MulCost + InnerExt0Cost + InnerExt1Cost + BaseCost) {
-        if (HasInnerExt) {
-          Ctx.FoldedRecipes[VF].insert(InnerExt0);
-          Ctx.FoldedRecipes[VF].insert(InnerExt1);
-        }
-        Ctx.FoldedRecipes[VF].insert(Mul);
-        if (HasOuterExt)
-          Ctx.FoldedRecipes[VF].insert(Ext);
-        return MulAccRedCost;
-      }
-    }
-    return InstructionCost::getInvalid();
-  };
-
-  // Match reduce(ext(...))
-  auto GetExtendedReductionCost =
-      [&](const VPReductionRecipe *Red) -> InstructionCost {
-    VPValue *VecOp = Red->getVecOp();
-    VPValue *A;
-    if (match(VecOp, m_ZExtOrSExt(m_VPValue(A))) && VecOp->getNumUsers() == 1) {
-      VPWidenCastRecipe *Ext =
-          cast<VPWidenCastRecipe>(VecOp->getDefiningRecipe());
-      bool IsUnsigned = Ext->getOpcode() == Instruction::CastOps::ZExt;
-      InstructionCost ExtCost = Ext->computeCost(VF, Ctx);
-      auto *ExtVecTy =
-          cast<VectorType>(ToVectorTy(Ctx.Types.inferScalarType(A), VF));
-      InstructionCost ExtendedRedCost = Ctx.TTI.getExtendedReductionCost(
-          Opcode, IsUnsigned, ElementTy, ExtVecTy, RdxDesc.getFastMathFlags(),
-          CostKind);
-      // Check if folding ext into ExtendedReduction is profitable.
-      if (ExtendedRedCost.isValid() && ExtendedRedCost < ExtCost + BaseCost) {
-        Ctx.FoldedRecipes[VF].insert(Ext);
-        return ExtendedRedCost;
-      }
-    }
-    return InstructionCost::getInvalid();
-  };
-
-  // Match MulAccReduction patterns.
-  InstructionCost MulAccCost = GetMulAccReductionCost(this);
-  if (MulAccCost.isValid())
-    return MulAccCost;
-
-  // Match ExtendedReduction patterns.
-  InstructionCost ExtendedCost = GetExtendedReductionCost(this);
-  if (ExtendedCost.isValid())
-    return ExtendedCost;
-  */
-
   // Default cost.
   return BaseCost;
 }
@@ -2338,9 +2222,6 @@ VPExtendedReductionRecipe::computeCost(ElementCount VF,
   TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
   unsigned Opcode = RdxDesc.getOpcode();
 
-  assert(ElementTy->getTypeID() == RdxDesc.getRecurrenceType()->getTypeID() &&
-         "Inferred type and recurrence type mismatch.");
-
   // BaseCost = Reduction cost + BinOp cost
   InstructionCost ReductionCost =
       Ctx.TTI.getArithmeticInstrCost(Opcode, ElementTy, CostKind);
@@ -2382,8 +2263,8 @@ InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
   TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
   unsigned Opcode = RdxDesc.getOpcode();
 
-  assert(ElementTy->getTypeID() == RdxDesc.getRecurrenceType()->getTypeID() &&
-         "Inferred type and recurrence type mismatch.");
+  assert(Opcode == Instruction::Add &&
+         "Reduction opcode must be add in the VPMulAccRecipe.");
 
   // BaseCost = Reduction cost + BinOp cost
   InstructionCost ReductionCost =
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index b7cc945747df13..e52991acd20562 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -545,11 +545,15 @@ void VPlanTransforms::prepareExecute(VPlan &Plan) {
           Op0 = new VPWidenCastRecipe(
               MulAcc->getExtOpcode(), MulAcc->getVecOp0(),
               MulAcc->getResultType(), *MulAcc->getExt0Instr());
-          Op1 = new VPWidenCastRecipe(
-              MulAcc->getExtOpcode(), MulAcc->getVecOp1(),
-              MulAcc->getResultType(), *MulAcc->getExt1Instr());
           Op0->getDefiningRecipe()->insertBefore(MulAcc);
-          Op1->getDefiningRecipe()->insertBefore(MulAcc);
+          if (!MulAcc->isSameExtend()) {
+            Op1 = new VPWidenCastRecipe(
+                MulAcc->getExtOpcode(), MulAcc->getVecOp1(),
+                MulAcc->getResultType(), *MulAcc->getExt1Instr());
+            Op1->getDefiningRecipe()->insertBefore(MulAcc);
+          } else {
+            Op1 = Op0;
+          }
         } else {
           Op0 = MulAcc->getVecOp0();
           Op1 = MulAcc->getVecOp1();
@@ -559,14 +563,13 @@ void VPlanTransforms::prepareExecute(VPlan &Plan) {
         SmallVector<VPValue *, 2> MulOps = {Op0, Op1};
         auto *Mul = new VPWidenRecipe(*MulInstr,
                                       make_range(MulOps.begin(), MulOps.end()));
-        if (auto *OuterExtInstr = MulAcc->getExtInstr()) {
-          // dbgs() <<"\n!!!"<< *OuterExtInstr << " " << MulAcc->getExtOpcode()
-          // << "\n";
+        // Outer extend.
+        if (auto *OuterExtInstr = MulAcc->getExtInstr())
           VecOp = new VPWidenCastRecipe(
               MulAcc->getExtOpcode(), Mul,
               MulAcc->getRecurrenceDescriptor().getRecurrenceType(),
               *OuterExtInstr);
-        } else
+        else
           VecOp = Mul;
         auto *Red = new VPReductionRecipe(
             MulAcc->getRecurrenceDescriptor(), MulAcc->getUnderlyingInstr(),
diff --git a/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll b/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
index 1e4134a8fdce97..1cbca4c8eaf1d9 100644
--- a/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
+++ b/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
@@ -1412,9 +1412,8 @@ define i32 @mla_i8_i32_multiuse(ptr nocapture readonly %x, ptr nocapture readonl
 ; CHECK-NEXT:    [[ACTIVE_LANE_MASK:%.*]] = call <16 x i1> @llvm.get.active.lane.mask.v16i1.i32(i32 [[INDEX]], i32 [[N]])
 ; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds i8, ptr [[X:%.*]], i32 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_MASKED_LOAD:%.*]] = call <16 x i8> @llvm.masked.load.v16i8.p0(ptr [[TMP0]], i32 1, <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i8> poison)
-; CHECK-NEXT:    [[TMP1:%.*]] = zext <16 x i8> [[WIDE_MASKED_LOAD]] to <16 x i32>
 ; CHECK-NEXT:    [[TMP7:%.*]] = zext <16 x i8> [[WIDE_MASKED_LOAD]] to <16 x i32>
-; CHECK-NEXT:    [[TMP2:%.*]] = mul nuw nsw <16 x i32> [[TMP1]], [[TMP7]]
+; CHECK-NEXT:    [[TMP2:%.*]] = mul nuw nsw <16 x i32> [[TMP7]], [[TMP7]]
 ; CHECK-NEXT:    [[TMP3:%.*]] = select <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i32> [[TMP2]], <16 x i32> zeroinitializer
 ; CHECK-NEXT:    [[TMP4:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP3]])
 ; CHECK-NEXT:    [[TMP5]] = add i32 [[TMP4]], [[VEC_PHI]]
diff --git a/llvm/test/Transforms/LoopVectorize/RISCV/inloop-reduction.ll b/llvm/test/Transforms/LoopVectorize/RISCV/inloop-reduction.ll
index 9f1a61ebb5efef..9af31c9a4762b3 100644
--- a/llvm/test/Transforms/LoopVectorize/RISCV/inloop-reduction.ll
+++ b/llvm/test/Transforms/LoopVectorize/RISCV/inloop-reduction.ll
@@ -215,13 +215,13 @@ define i32 @add_i16_i32(ptr nocapture readonly %x, i32 %n) {
 ; IF-EVL-INLOOP-NEXT:    [[TMP12]] = add i32 [[TMP11]], [[VEC_PHI]]
 ; IF-EVL-INLOOP-NEXT:    [[INDEX_EVL_NEXT]] = add i32 [[TMP6]], [[EVL_BASED_IV]]
 ; IF-EVL-INLOOP-NEXT:    [[INDEX_NEXT]] = add i32 [[INDEX]], [[TMP4]]
-; IF-EVL-INLOOP-NEXT:    [[TMP12:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]
-; IF-EVL-INLOOP-NEXT:    br i1 [[TMP12]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
+; IF-EVL-INLOOP-NEXT:    [[TMP19:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]
+; IF-EVL-INLOOP-NEXT:    br i1 [[TMP19]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
 ; IF-EVL-INLOOP:       middle.block:
 ; IF-EVL-INLOOP-NEXT:    br i1 true, label [[FOR_COND_CLEANUP_LOOPEXIT:%.*]], label [[SCALAR_PH]]
 ; IF-EVL-INLOOP:       scalar.ph:
 ; IF-EVL-INLOOP-NEXT:    [[BC_RESUME_VAL:%.*]] = phi i32 [ [[N_VEC]], [[MIDDLE_BLOCK]] ], [ 0, [[FOR_BODY_PREHEADER]] ]
-; IF-EVL-INLOOP-NEXT:    [[BC_MERGE_RDX:%.*]] = phi i32 [ [[TMP11]], [[MIDDLE_BLOCK]] ], [ 0, [[FOR_BODY_PREHEADER]] ]
+; IF-EVL-INLOOP-NEXT:    [[BC_MERGE_RDX:%.*]] = phi i32 [ [[TMP12]], [[MIDDLE_BLOCK]] ], [ 0, [[FOR_BODY_PREHEADER]] ]
 ; IF-EVL-INLOOP-NEXT:    br label [[FOR_BODY:%.*]]
 ; IF-EVL-INLOOP:       for.body:
 ; IF-EVL-INLOOP-NEXT:    [[I_08:%.*]] = phi i32 [ [[INC:%.*]], [[FOR_BODY]] ], [ [[BC_RESUME_VAL]], [[SCALAR_PH]] ]
@@ -234,7 +234,7 @@ define i32 @add_i16_i32(ptr nocapture readonly %x, i32 %n) {
 ; IF-EVL-INLOOP-NEXT:    [[EXITCOND:%.*]] = icmp eq i32 [[INC]], [[N]]
 ; IF-EVL-INLOOP-NEXT:    br i1 [[EXITCOND]], label [[FOR_COND_CLEANUP_LOOPEXIT]], label [[FOR_BODY]], !llvm.loop [[LOOP3:![0-9]+]]
 ; IF-EVL-INLOOP:       for.cond.cleanup.loopexit:
-; IF-EVL-INLOOP-NEXT:    [[ADD_LCSSA:%.*]] = phi i32 [ [[ADD]], [[FOR_BODY]] ], [ [[TMP11]], [[MIDDLE_BLOCK]] ]
+; IF-EVL-INLOOP-NEXT:    [[ADD_LCSSA:%.*]] = phi i32 [ [[ADD]], [[FOR_BODY]] ], [ [[TMP12]], [[MIDDLE_BLOCK]] ]
 ; IF-EVL-INLOOP-NEXT:    br label [[FOR_COND_CLEANUP]]
 ; IF-EVL-INLOOP:       for.cond.cleanup:
 ; IF-EVL-INLOOP-NEXT:    [[R_0_LCSSA:%.*]] = phi i32 [ 0, [[ENTRY:%.*]] ], [ [[ADD_LCSSA]], [[FOR_COND_CLEANUP_LOOPEXIT]] ]

>From efd223615b23b83339ae7249e5647eb166a4745f Mon Sep 17 00:00:00 2001
From: Elvis Wang <elvis.wang at sifive.com>
Date: Wed, 6 Nov 2024 21:07:04 -0800
Subject: [PATCH 6/9] Fix typos and update printing test

---
 .../Transforms/Vectorize/LoopVectorize.cpp    |   3 +-
 llvm/lib/Transforms/Vectorize/VPlan.h         |   7 +-
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp |  18 +-
 .../Transforms/Vectorize/VPlanTransforms.h    |   2 +-
 .../LoopVectorize/vplan-printing.ll           | 236 ++++++++++++++++++
 5 files changed, 254 insertions(+), 12 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index ee429f9d3fb97d..b50f8d9c4cf438 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7675,7 +7675,7 @@ DenseMap<const SCEV *, Value *> LoopVectorizationPlanner::executePlan(
                              ILV.getOrCreateVectorTripCount(nullptr),
                              CanonicalIVStartValue, State);
 
-  // TODO: Rebase to fhahn's implementation.
+  // TODO: Replace with upstream implementation.
   VPlanTransforms::prepareExecute(BestVPlan);
   BestVPlan.execute(&State);
 
@@ -9271,7 +9271,6 @@ VPlanPtr LoopVectorizationPlanner::buildVPlan(VFRange &Range) {
 // Adjust AnyOf reductions; replace the reduction phi for the selected value
 // with a boolean reduction phi node to check if the condition is true in any
 // iteration. The final value is selected by the final ComputeReductionResult.
-// TODO: Implement VPMulAccHere.
 void LoopVectorizationPlanner::adjustRecipesForReductions(
     VPlanPtr &Plan, VPRecipeBuilder &RecipeBuilder, ElementCount MinVF) {
   using namespace VPlanPatternMatch;
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index c7117551bf5dda..d48e75fe9fb64c 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -2757,8 +2757,8 @@ class VPExtendedReductionRecipe : public VPSingleDefRecipe {
 /// A recipe to represent inloop MulAccreduction operations, performing a
 /// reduction on a vector operand into a scalar value, and adding the result to
 /// a chain. This recipe is high level abstract which will generate
-/// VPReductionRecipe VPWidenRecipe(mul)and VPWidenCastRecipe before execution.
-/// The Operands are {ChainOp, VecOp1, VecOp2, [Condition]}.
+/// VPReductionRecipe VPWidenRecipe(mul) and VPWidenCastRecipes before
+/// execution. The Operands are {ChainOp, VecOp1, VecOp2, [Condition]}.
 class VPMulAccRecipe : public VPSingleDefRecipe {
   /// The recurrence decriptor for the reduction in question.
   const RecurrenceDescriptor &RdxDesc;
@@ -2778,6 +2778,8 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
 
   /// Is this MulAcc recipe contains extend recipes?
   bool IsExtended;
+  /// Is this reciep contains outer extend instuction?
+  bool IsOuterExtended = false;
 
 protected:
   VPMulAccRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
@@ -2797,6 +2799,7 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
       addOperand(CondOp);
     }
     IsExtended = true;
+    IsOuterExtended = ExtInstr != nullptr;
   }
 
   VPMulAccRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index ab40bed209167d..8a2ecd46ab672d 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -270,7 +270,7 @@ InstructionCost VPRecipeBase::cost(ElementCount VF, VPCostContext &Ctx) {
     UI = &WidenMem->getIngredient();
 
   InstructionCost RecipeCost;
-  if ((UI && Ctx.skipCostComputation(UI, VF.isVector()))) {
+  if (UI && Ctx.skipCostComputation(UI, VF.isVector())) {
     RecipeCost = 0;
   } else {
     RecipeCost = computeCost(VF, Ctx);
@@ -2408,18 +2408,20 @@ void VPMulAccRecipe::print(raw_ostream &O, const Twine &Indent,
   printAsOperand(O, SlotTracker);
   O << " = ";
   getChainOp()->printAsOperand(O, SlotTracker);
-  O << " +";
+  O << " + ";
   if (isa<FPMathOperator>(getUnderlyingInstr()))
     O << getUnderlyingInstr()->getFastMathFlags();
-  O << " reduce." << Instruction::getOpcodeName(RdxDesc.getOpcode()) << " (";
-  O << " mul ";
+  if (IsOuterExtended)
+    O << " (";
+  O << "reduce." << Instruction::getOpcodeName(RdxDesc.getOpcode()) << " (";
+  O << "mul ";
   if (IsExtended)
     O << "(";
   getVecOp0()->printAsOperand(O, SlotTracker);
   if (IsExtended)
-    O << " extended to " << *getResultType() << ")";
-  if (IsExtended)
-    O << "(";
+    O << " extended to " << *getResultType() << "), (";
+  else
+    O << ", ";
   getVecOp1()->printAsOperand(O, SlotTracker);
   if (IsExtended)
     O << " extended to " << *getResultType() << ")";
@@ -2428,6 +2430,8 @@ void VPMulAccRecipe::print(raw_ostream &O, const Twine &Indent,
     getCondOp()->printAsOperand(O, SlotTracker);
   }
   O << ")";
+  if (IsOuterExtended)
+    O << " extended to " << *RdxDesc.getRecurrenceType() << ")";
   if (RdxDesc.IntermediateStore)
     O << " (with final reduction value stored in invariant address sank "
          "outside of loop)";
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.h b/llvm/lib/Transforms/Vectorize/VPlanTransforms.h
index 6310c23b605da3..40ed91104d566e 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.h
@@ -124,7 +124,7 @@ struct VPlanTransforms {
   /// Remove dead recipes from \p Plan.
   static void removeDeadRecipes(VPlan &Plan);
 
-  /// TODO: Rebase to fhahn's implementation.
+  /// TODO: Rebase to upstream implementation.
   static void prepareExecute(VPlan &Plan);
 };
 
diff --git a/llvm/test/Transforms/LoopVectorize/vplan-printing.ll b/llvm/test/Transforms/LoopVectorize/vplan-printing.ll
index 6bb20a301e0ade..a8fb374a4b162d 100644
--- a/llvm/test/Transforms/LoopVectorize/vplan-printing.ll
+++ b/llvm/test/Transforms/LoopVectorize/vplan-printing.ll
@@ -1143,6 +1143,242 @@ exit:
   ret i16 %for.1
 }
 
+define i64 @print_extended_reduction(ptr nocapture readonly %x, ptr nocapture readonly %y, i32 %n) {
+; CHECK-LABEL: 'print_extended_reduction'
+; CHECK:      VPlan 'Initial VPlan for VF={4},UF>=1' {
+; CHECK-NEXT: Live-in vp<%0> = VF * UF
+; CHECK-NEXT: Live-in vp<%1> = vector-trip-count
+; CHECK-NEXT: Live-in ir<%n> = original trip-count
+; CHECK-EMPTY:
+; CHECK-NEXT: vector.ph:
+; CHECK-NEXT: Successor(s): vector loop
+; CHECK-EMPTY:
+; CHECK-NEXT: <x1> vector loop: {
+; CHECK-NEXT:   vector.body:
+; CHECK-NEXT:     EMIT vp<%2> = CANONICAL-INDUCTION ir<0>, vp<%index.next>
+; CHECK-NEXT:     WIDEN-REDUCTION-PHI ir<%r.09> = phi ir<0>, ir<%add>
+; CHECK-NEXT:     vp<%3> = SCALAR-STEPS vp<%2>, ir<1>
+; CHECK-NEXT:     CLONE ir<%arrayidx> = getelementptr inbounds ir<%x>, vp<%3>
+; CHECK-NEXT:     vp<%4> = vector-pointer ir<%arrayidx>
+; CHECK-NEXT:     WIDEN ir<%load0> = load vp<%4>
+; CHECK-NEXT:     EXTENDED-REDUCE ir<%add> = ir<%r.09> + reduce.add (ir<%load0> extended to i64)
+; CHECK-NEXT:     EMIT vp<%index.next> = add nuw vp<%2>, vp<%0>
+; CHECK-NEXT:     EMIT branch-on-count vp<%index.next>, vp<%1>
+; CHECK-NEXT:   No successors
+; CHECK-NEXT: }
+; CHECK-NEXT: Successor(s): middle.block
+; CHECK-EMPTY:
+; CHECK-NEXT: middle.block:
+; CHECK-NEXT:   EMIT vp<%6> = compute-reduction-result ir<%r.09>, ir<%add>
+; CHECK-NEXT:   EMIT vp<%7> = extract-from-end vp<%6>, ir<1>
+; CHECK-NEXT:   EMIT vp<%cmp.n> = icmp eq ir<%n>, vp<%1>
+; CHECK-NEXT:   EMIT branch-on-cond vp<%cmp.n>
+; CHECK-NEXT: Successor(s): ir-bb<for.cond.cleanup.loopexit>, scalar.ph
+; CHECK-EMPTY:
+; CHECK-NEXT: ir-bb<for.cond.cleanup.loopexit>:
+; CHECK-NEXT:   IR   %add.lcssa = phi i64 [ %add, %for.body ] (extra operand: vp<%7>)
+; CHECK-NEXT: No successors
+; CHECK-EMPTY:
+; CHECK-NEXT: scalar.ph:
+; CHECK-NEXT:   EMIT vp<%bc.merge.rdx> = resume-phi vp<%6>, ir<0>
+; CHECK-NEXT: Successor(s): ir-bb<for.body>
+; CHECK-EMPTY:
+; CHECK-NEXT: ir-bb<for.body>:
+; CHECK-NEXT:   IR   %i.010 = phi i32 [ %inc, %for.body ], [ 0, %for.body.preheader ]
+; CHECK-NEXT:   IR   %r.09 = phi i64 [ %add, %for.body ], [ 0, %for.body.preheader ] (extra operand: vp<%bc.merge.rdx>)
+; CHECK-NEXT:   IR   %arrayidx = getelementptr inbounds i32, ptr %x, i32 %i.010
+; CHECK-NEXT:   IR   %load0 = load i32, ptr %arrayidx, align 4
+; CHECK-NEXT:   IR   %conv0 = zext i32 %load0 to i64
+; CHECK-NEXT:   IR   %add = add nsw i64 %r.09, %conv0
+; CHECK-NEXT:   IR   %inc = add nuw nsw i32 %i.010, 1
+; CHECK-NEXT:   IR   %exitcond = icmp eq i32 %inc, %n
+; CHECK-NEXT: No successors
+; CHECK-NEXT: }
+;
+entry:
+  %cmp8 = icmp sgt i32 %n, 0
+  br i1 %cmp8, label %for.body, label %for.cond.cleanup
+
+for.body:                                         ; preds = %entry, %for.body
+  %i.010 = phi i32 [ %inc, %for.body ], [ 0, %entry ]
+  %r.09 = phi i64 [ %add, %for.body ], [ 0, %entry ]
+  %arrayidx = getelementptr inbounds i32, ptr %x, i32 %i.010
+  %load0 = load i32, ptr %arrayidx, align 4
+  %conv0 = zext i32 %load0 to i64
+  %add = add nsw i64 %r.09, %conv0
+  %inc = add nuw nsw i32 %i.010, 1
+  %exitcond = icmp eq i32 %inc, %n
+  br i1 %exitcond, label %for.cond.cleanup, label %for.body
+
+for.cond.cleanup:                                 ; preds = %for.body, %entry
+  %r.0.lcssa = phi i64 [ 0, %entry ], [ %add, %for.body ]
+  ret i64 %r.0.lcssa
+}
+
+define i64 @print_mulacc(ptr nocapture readonly %x, ptr nocapture readonly %y, i32 %n) {
+; CHECK-LABEL: 'print_mulacc'
+; CHECK:      VPlan 'Initial VPlan for VF={4},UF>=1' {
+; CHECK-NEXT: Live-in vp<%0> = VF * UF
+; CHECK-NEXT: Live-in vp<%1> = vector-trip-count
+; CHECK-NEXT: Live-in ir<%n> = original trip-count
+; CHECK-EMPTY:
+; CHECK-NEXT: vector.ph:
+; CHECK-NEXT: Successor(s): vector loop
+; CHECK-EMPTY:
+; CHECK-NEXT: <x1> vector loop: {
+; CHECK-NEXT:   vector.body:
+; CHECK-NEXT:     EMIT vp<%2> = CANONICAL-INDUCTION ir<0>, vp<%index.next>
+; CHECK-NEXT:     WIDEN-REDUCTION-PHI ir<%r.09> = phi ir<0>, ir<%add>
+; CHECK-NEXT:     vp<%3> = SCALAR-STEPS vp<%2>, ir<1>
+; CHECK-NEXT:     CLONE ir<%arrayidx> = getelementptr inbounds ir<%x>, vp<%3>
+; CHECK-NEXT:     vp<%4> = vector-pointer ir<%arrayidx>
+; CHECK-NEXT:     WIDEN ir<%load0> = load vp<%4>
+; CHECK-NEXT:     CLONE ir<%arrayidx1> = getelementptr inbounds ir<%y>, vp<%3>
+; CHECK-NEXT:     vp<%5> = vector-pointer ir<%arrayidx1>
+; CHECK-NEXT:     WIDEN ir<%load1> = load vp<%5>
+; CHECK-NEXT:     MULACC-REDUCE ir<%add> = ir<%r.09> + reduce.add (mul ir<%load0>, ir<%load1>)
+; CHECK-NEXT:     EMIT vp<%index.next> = add nuw vp<%2>, vp<%0>
+; CHECK-NEXT:     EMIT branch-on-count vp<%index.next>, vp<%1>
+; CHECK-NEXT:   No successors
+; CHECK-NEXT: }
+; CHECK-NEXT: Successor(s): middle.block
+; CHECK-EMPTY:
+; CHECK-NEXT: middle.block:
+; CHECK-NEXT:   EMIT vp<%7> = compute-reduction-result ir<%r.09>, ir<%add>
+; CHECK-NEXT:   EMIT vp<%8> = extract-from-end vp<%7>, ir<1>
+; CHECK-NEXT:   EMIT vp<%cmp.n> = icmp eq ir<%n>, vp<%1>
+; CHECK-NEXT:   EMIT branch-on-cond vp<%cmp.n>
+; CHECK-NEXT: Successor(s): ir-bb<for.cond.cleanup.loopexit>, scalar.ph
+; CHECK-EMPTY:
+; CHECK-NEXT: ir-bb<for.cond.cleanup.loopexit>:
+; CHECK-NEXT:   IR   %add.lcssa = phi i64 [ %add, %for.body ] (extra operand: vp<%8>)
+; CHECK-NEXT: No successors
+; CHECK-EMPTY:
+; CHECK-NEXT: scalar.ph:
+; CHECK-NEXT:   EMIT vp<%bc.merge.rdx> = resume-phi vp<%7>, ir<0>
+; CHECK-NEXT: Successor(s): ir-bb<for.body>
+; CHECK-EMPTY:
+; CHECK-NEXT: ir-bb<for.body>:
+; CHECK-NEXT:   IR   %i.010 = phi i32 [ %inc, %for.body ], [ 0, %for.body.preheader ]
+; CHECK-NEXT:   IR   %r.09 = phi i64 [ %add, %for.body ], [ 0, %for.body.preheader ] (extra operand: vp<%bc.merge.rdx>)
+; CHECK-NEXT:   IR   %arrayidx = getelementptr inbounds i64, ptr %x, i32 %i.010
+; CHECK-NEXT:   IR   %load0 = load i64, ptr %arrayidx, align 4
+; CHECK-NEXT:   IR   %arrayidx1 = getelementptr inbounds i64, ptr %y, i32 %i.010
+; CHECK-NEXT:   IR   %load1 = load i64, ptr %arrayidx1, align 4
+; CHECK-NEXT:   IR   %mul = mul nsw i64 %load0, %load1
+; CHECK-NEXT:   IR   %add = add nsw i64 %r.09, %mul
+; CHECK-NEXT:   IR   %inc = add nuw nsw i32 %i.010, 1
+; CHECK-NEXT:   IR   %exitcond = icmp eq i32 %inc, %n
+; CHECK-NEXT: No successors
+; CHECK-NEXT: }
+;
+entry:
+  %cmp8 = icmp sgt i32 %n, 0
+  br i1 %cmp8, label %for.body, label %for.cond.cleanup
+
+for.body:                                         ; preds = %entry, %for.body
+  %i.010 = phi i32 [ %inc, %for.body ], [ 0, %entry ]
+  %r.09 = phi i64 [ %add, %for.body ], [ 0, %entry ]
+  %arrayidx = getelementptr inbounds i64, ptr %x, i32 %i.010
+  %load0 = load i64, ptr %arrayidx, align 4
+  %arrayidx1 = getelementptr inbounds i64, ptr %y, i32 %i.010
+  %load1 = load i64, ptr %arrayidx1, align 4
+  %mul = mul nsw i64 %load0, %load1
+  %add = add nsw i64 %r.09, %mul
+  %inc = add nuw nsw i32 %i.010, 1
+  %exitcond = icmp eq i32 %inc, %n
+  br i1 %exitcond, label %for.cond.cleanup, label %for.body
+
+for.cond.cleanup:                                 ; preds = %for.body, %entry
+  %r.0.lcssa = phi i64 [ 0, %entry ], [ %add, %for.body ]
+  ret i64 %r.0.lcssa
+}
+
+define i64 @print_mulacc_extended(ptr nocapture readonly %x, ptr nocapture readonly %y, i32 %n) {
+; CHECK-LABEL: 'print_mulacc_extended'
+; CHECK:      VPlan 'Initial VPlan for VF={4},UF>=1' {
+; CHECK-NEXT: Live-in vp<%0> = VF * UF
+; CHECK-NEXT: Live-in vp<%1> = vector-trip-count
+; CHECK-NEXT: Live-in ir<%n> = original trip-count
+; CHECK-EMPTY:
+; CHECK-NEXT: vector.ph:
+; CHECK-NEXT: Successor(s): vector loop
+; CHECK-EMPTY:
+; CHECK-NEXT: <x1> vector loop: {
+; CHECK-NEXT:   vector.body:
+; CHECK-NEXT:     EMIT vp<%2> = CANONICAL-INDUCTION ir<0>, vp<%index.next>
+; CHECK-NEXT:     WIDEN-REDUCTION-PHI ir<%r.09> = phi ir<0>, ir<%add>
+; CHECK-NEXT:     vp<%3> = SCALAR-STEPS vp<%2>, ir<1>
+; CHECK-NEXT:     CLONE ir<%arrayidx> = getelementptr inbounds ir<%x>, vp<%3>
+; CHECK-NEXT:     vp<%4> = vector-pointer ir<%arrayidx>
+; CHECK-NEXT:     WIDEN ir<%load0> = load vp<%4>
+; CHECK-NEXT:     CLONE ir<%arrayidx1> = getelementptr inbounds ir<%y>, vp<%3>
+; CHECK-NEXT:     vp<%5> = vector-pointer ir<%arrayidx1>
+; CHECK-NEXT:     WIDEN ir<%load1> = load vp<%5>
+; CHECK-NEXT:     MULACC-REDUCE ir<%add> = ir<%r.09> +  (reduce.add (mul (ir<%load0> extended to i32), (ir<%load1> extended to i32)) extended to i64)
+; CHECK-NEXT:     EMIT vp<%index.next> = add nuw vp<%2>, vp<%0>
+; CHECK-NEXT:     EMIT branch-on-count vp<%index.next>, vp<%1>
+; CHECK-NEXT:   No successors
+; CHECK-NEXT: }
+; CHECK-NEXT: Successor(s): middle.block
+; CHECK-EMPTY:
+; CHECK-NEXT: middle.block:
+; CHECK-NEXT:   EMIT vp<%7> = compute-reduction-result ir<%r.09>, ir<%add>
+; CHECK-NEXT:   EMIT vp<%8> = extract-from-end vp<%7>, ir<1>
+; CHECK-NEXT:   EMIT vp<%cmp.n> = icmp eq ir<%n>, vp<%1>
+; CHECK-NEXT:   EMIT branch-on-cond vp<%cmp.n>
+; CHECK-NEXT: Successor(s): ir-bb<for.cond.cleanup.loopexit>, scalar.ph
+; CHECK-EMPTY:
+; CHECK-NEXT: ir-bb<for.cond.cleanup.loopexit>:
+; CHECK-NEXT:   IR   %add.lcssa = phi i64 [ %add, %for.body ] (extra operand: vp<%8>)
+; CHECK-NEXT: No successors
+; CHECK-EMPTY:
+; CHECK-NEXT: scalar.ph:
+; CHECK-NEXT:   EMIT vp<%bc.merge.rdx> = resume-phi vp<%7>, ir<0>
+; CHECK-NEXT: Successor(s): ir-bb<for.body>
+; CHECK-EMPTY:
+; CHECK-NEXT: ir-bb<for.body>:
+; CHECK-NEXT:   IR   %i.010 = phi i32 [ %inc, %for.body ], [ 0, %for.body.preheader ]
+; CHECK-NEXT:   IR   %r.09 = phi i64 [ %add, %for.body ], [ 0, %for.body.preheader ] (extra operand: vp<%bc.merge.rdx>)
+; CHECK-NEXT:   IR   %arrayidx = getelementptr inbounds i16, ptr %x, i32 %i.010
+; CHECK-NEXT:   IR   %load0 = load i16, ptr %arrayidx, align 4
+; CHECK-NEXT:   IR   %arrayidx1 = getelementptr inbounds i16, ptr %y, i32 %i.010
+; CHECK-NEXT:   IR   %load1 = load i16, ptr %arrayidx1, align 4
+; CHECK-NEXT:   IR   %conv0 = sext i16 %load0 to i32
+; CHECK-NEXT:   IR   %conv1 = sext i16 %load1 to i32
+; CHECK-NEXT:   IR   %mul = mul nsw i32 %conv0, %conv1
+; CHECK-NEXT:   IR   %conv = sext i32 %mul to i64
+; CHECK-NEXT:   IR   %add = add nsw i64 %r.09, %conv
+; CHECK-NEXT:   IR   %inc = add nuw nsw i32 %i.010, 1
+; CHECK-NEXT:   IR   %exitcond = icmp eq i32 %inc, %n
+; CHECK-NEXT: No successors
+; CHECK-NEXT: }
+;
+entry:
+  %cmp8 = icmp sgt i32 %n, 0
+  br i1 %cmp8, label %for.body, label %for.cond.cleanup
+
+for.body:                                         ; preds = %entry, %for.body
+  %i.010 = phi i32 [ %inc, %for.body ], [ 0, %entry ]
+  %r.09 = phi i64 [ %add, %for.body ], [ 0, %entry ]
+  %arrayidx = getelementptr inbounds i16, ptr %x, i32 %i.010
+  %load0 = load i16, ptr %arrayidx, align 4
+  %arrayidx1 = getelementptr inbounds i16, ptr %y, i32 %i.010
+  %load1 = load i16, ptr %arrayidx1, align 4
+  %conv0 = sext i16 %load0 to i32
+  %conv1 = sext i16 %load1 to i32
+  %mul = mul nsw i32 %conv0, %conv1
+  %conv = sext i32 %mul to i64
+  %add = add nsw i64 %r.09, %conv
+  %inc = add nuw nsw i32 %i.010, 1
+  %exitcond = icmp eq i32 %inc, %n
+  br i1 %exitcond, label %for.cond.cleanup, label %for.body
+
+for.cond.cleanup:                                 ; preds = %for.body, %entry
+  %r.0.lcssa = phi i64 [ 0, %entry ], [ %add, %for.body ]
+  ret i64 %r.0.lcssa
+}
+
 !llvm.dbg.cu = !{!0}
 !llvm.module.flags = !{!3, !4}
 

>From 234e81ee87184b277b4f4b17ac222737fbf114eb Mon Sep 17 00:00:00 2001
From: Elvis Wang <elvis.wang at sifive.com>
Date: Sun, 10 Nov 2024 18:00:52 -0800
Subject: [PATCH 7/9] Fold reduce.add(zext(mul(sext(A), sext(B)))) into
 MulAccRecipe when A == B

In preparation for a future refactor that avoids referencing underlying
instructions, and because the extend opcodes may not match in the newly
added pattern, stop passing the underlying instruction (UI) when creating
VPWidenCastRecipe.
This change can lead to duplicate extend instructions being created by the
loop vectorizer when two reduction patterns exist in the same loop.
These redundant instructions may be cleaned up by passes that run after LV.
---
 .../Transforms/Vectorize/LoopVectorize.cpp    | 31 ++++++++-----------
 .../Transforms/Vectorize/VPlanTransforms.cpp  | 15 +++++----
 .../LoopVectorize/ARM/mve-reductions.ll       |  3 +-
 .../LoopVectorize/reduction-inloop.ll         |  6 ++--
 4 files changed, 24 insertions(+), 31 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index b50f8d9c4cf438..999b4157142a0e 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -9396,20 +9396,18 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
         // reduce.add(mul(ext, ext)) can folded into VPMulAccRecipe
         if (match(VecOp, m_Mul(m_VPValue(A), m_VPValue(B))) &&
             !VecOp->hasMoreThanOneUniqueUser()) {
-          VPRecipeBase *RecipeA = A->getDefiningRecipe();
-          VPRecipeBase *RecipeB = B->getDefiningRecipe();
+          VPWidenCastRecipe *RecipeA =
+              dyn_cast_if_present<VPWidenCastRecipe>(A->getDefiningRecipe());
+          VPWidenCastRecipe *RecipeB =
+              dyn_cast_if_present<VPWidenCastRecipe>(B->getDefiningRecipe());
           if (RecipeA && RecipeB && match(RecipeA, m_ZExtOrSExt(m_VPValue())) &&
               match(RecipeB, m_ZExtOrSExt(m_VPValue())) &&
-              cast<VPWidenCastRecipe>(RecipeA)->getOpcode() ==
-                  cast<VPWidenCastRecipe>(RecipeB)->getOpcode() &&
-              !A->hasMoreThanOneUniqueUser() &&
-              !B->hasMoreThanOneUniqueUser()) {
+              (RecipeA->getOpcode() == RecipeB->getOpcode() || A == B)) {
             return new VPMulAccRecipe(
                 RdxDesc, CurrentLinkI, PreviousLink, CondOp,
                 CM.useOrderedReductions(RdxDesc),
-                cast<VPWidenRecipe>(VecOp->getDefiningRecipe()),
-                cast<VPWidenCastRecipe>(RecipeA),
-                cast<VPWidenCastRecipe>(RecipeB));
+                cast<VPWidenRecipe>(VecOp->getDefiningRecipe()), RecipeA,
+                RecipeB);
           } else {
             // Matched reduce.add(mul(...))
             return new VPMulAccRecipe(
@@ -9417,8 +9415,9 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
                 CM.useOrderedReductions(RdxDesc),
                 cast<VPWidenRecipe>(VecOp->getDefiningRecipe()));
           }
-          // Matched reduce.add(ext(mul(ext, ext)))
-          // Note that 3 extend instructions must have same opcode.
+          // Matched reduce.add(ext(mul(ext(A), ext(B))))
+          // Note that 3 extend instructions must have same opcode or A == B
+          // which can be transform to reduce.add(zext(mul(sext(A), sext(B)))).
         } else if (match(VecOp,
                          m_ZExtOrSExt(m_Mul(m_ZExtOrSExt(m_VPValue()),
                                             m_ZExtOrSExt(m_VPValue())))) &&
@@ -9431,11 +9430,8 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
               cast<VPWidenCastRecipe>(Mul->getOperand(0)->getDefiningRecipe());
           VPWidenCastRecipe *Ext1 =
               cast<VPWidenCastRecipe>(Mul->getOperand(1)->getDefiningRecipe());
-          if (Ext->getOpcode() == Ext0->getOpcode() &&
-              Ext0->getOpcode() == Ext1->getOpcode() &&
-              !Mul->hasMoreThanOneUniqueUser() &&
-              !Ext0->hasMoreThanOneUniqueUser() &&
-              !Ext1->hasMoreThanOneUniqueUser()) {
+          if ((Ext->getOpcode() == Ext0->getOpcode() || Ext0 == Ext1) &&
+              Ext0->getOpcode() == Ext1->getOpcode()) {
             return new VPMulAccRecipe(
                 RdxDesc, CurrentLinkI, PreviousLink, CondOp,
                 CM.useOrderedReductions(RdxDesc),
@@ -9447,8 +9443,7 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
       };
       auto TryToMatchExtendedReduction = [&]() -> VPSingleDefRecipe * {
         VPValue *A;
-        if (match(VecOp, m_ZExtOrSExt(m_VPValue(A))) &&
-            !VecOp->hasMoreThanOneUniqueUser()) {
+        if (match(VecOp, m_ZExtOrSExt(m_VPValue(A)))) {
           return new VPExtendedReductionRecipe(
               RdxDesc, CurrentLinkI, PreviousLink,
               cast<VPWidenCastRecipe>(VecOp), CondOp,
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index e52991acd20562..af287577fcb879 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -542,14 +542,14 @@ void VPlanTransforms::prepareExecute(VPlan &Plan) {
         auto *MulAcc = cast<VPMulAccRecipe>(&R);
         VPValue *Op0, *Op1;
         if (MulAcc->isExtended()) {
-          Op0 = new VPWidenCastRecipe(
-              MulAcc->getExtOpcode(), MulAcc->getVecOp0(),
-              MulAcc->getResultType(), *MulAcc->getExt0Instr());
+          Op0 =
+              new VPWidenCastRecipe(MulAcc->getExtOpcode(), MulAcc->getVecOp0(),
+                                    MulAcc->getResultType());
           Op0->getDefiningRecipe()->insertBefore(MulAcc);
           if (!MulAcc->isSameExtend()) {
-            Op1 = new VPWidenCastRecipe(
-                MulAcc->getExtOpcode(), MulAcc->getVecOp1(),
-                MulAcc->getResultType(), *MulAcc->getExt1Instr());
+            Op1 = new VPWidenCastRecipe(MulAcc->getExtOpcode(),
+                                        MulAcc->getVecOp1(),
+                                        MulAcc->getResultType());
             Op1->getDefiningRecipe()->insertBefore(MulAcc);
           } else {
             Op1 = Op0;
@@ -567,8 +567,7 @@ void VPlanTransforms::prepareExecute(VPlan &Plan) {
         if (auto *OuterExtInstr = MulAcc->getExtInstr())
           VecOp = new VPWidenCastRecipe(
               MulAcc->getExtOpcode(), Mul,
-              MulAcc->getRecurrenceDescriptor().getRecurrenceType(),
-              *OuterExtInstr);
+              MulAcc->getRecurrenceDescriptor().getRecurrenceType());
         else
           VecOp = Mul;
         auto *Red = new VPReductionRecipe(
diff --git a/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll b/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
index 1cbca4c8eaf1d9..1942fd14213c13 100644
--- a/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
+++ b/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
@@ -1535,7 +1535,8 @@ define i64 @mla_and_add_together_16_64(ptr nocapture noundef readonly %x, i32 no
 ; CHECK-NEXT:    [[TMP3:%.*]] = zext nneg <8 x i32> [[TMP2]] to <8 x i64>
 ; CHECK-NEXT:    [[TMP4:%.*]] = call i64 @llvm.vector.reduce.add.v8i64(<8 x i64> [[TMP3]])
 ; CHECK-NEXT:    [[TMP5]] = add i64 [[TMP4]], [[VEC_PHI]]
-; CHECK-NEXT:    [[TMP6:%.*]] = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> [[TMP1]])
+; CHECK-NEXT:    [[TMP10:%.*]] = sext <8 x i16> [[WIDE_LOAD]] to <8 x i32>
+; CHECK-NEXT:    [[TMP6:%.*]] = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> [[TMP10]])
 ; CHECK-NEXT:    [[TMP7]] = add i32 [[TMP6]], [[VEC_PHI1]]
 ; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 8
 ; CHECK-NEXT:    [[TMP8:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]
diff --git a/llvm/test/Transforms/LoopVectorize/reduction-inloop.ll b/llvm/test/Transforms/LoopVectorize/reduction-inloop.ll
index b578e61d85dfa1..d6dbb74f26d4ab 100644
--- a/llvm/test/Transforms/LoopVectorize/reduction-inloop.ll
+++ b/llvm/test/Transforms/LoopVectorize/reduction-inloop.ll
@@ -1206,15 +1206,13 @@ define i32 @predicated_not_dominates_reduction_twoadd(ptr nocapture noundef read
 ; CHECK-NEXT:    [[TMP0:%.*]] = sext i32 [[INDEX]] to i64
 ; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds i8, ptr [[H:%.*]], i64 [[TMP0]]
 ; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <4 x i8>, ptr [[TMP1]], align 1
-; CHECK-NEXT:    [[DOTNOT:%.*]] = icmp eq <4 x i8> [[WIDE_LOAD]], zeroinitializer
 ; CHECK-NEXT:    [[TMP2:%.*]] = udiv <4 x i8> [[WIDE_LOAD]], splat (i8 31)
 ; CHECK-NEXT:    [[TMP3:%.*]] = shl nuw nsw <4 x i8> [[TMP2]], splat (i8 3)
 ; CHECK-NEXT:    [[TMP4:%.*]] = udiv <4 x i8> [[TMP3]], splat (i8 31)
 ; CHECK-NEXT:    [[TMP5:%.*]] = zext nneg <4 x i8> [[TMP4]] to <4 x i32>
-; CHECK-NEXT:    [[TMP6:%.*]] = select <4 x i1> [[DOTNOT]], <4 x i32> zeroinitializer, <4 x i32> [[TMP5]]
-; CHECK-NEXT:    [[TMP7:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP6]])
+; CHECK-NEXT:    [[TMP7:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP5]])
 ; CHECK-NEXT:    [[TMP8:%.*]] = add i32 [[TMP7]], [[VEC_PHI]]
-; CHECK-NEXT:    [[TMP9:%.*]] = select <4 x i1> [[DOTNOT]], <4 x i32> zeroinitializer, <4 x i32> [[TMP5]]
+; CHECK-NEXT:    [[TMP9:%.*]] = zext nneg <4 x i8> [[TMP4]] to <4 x i32>
 ; CHECK-NEXT:    [[TMP10:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP9]])
 ; CHECK-NEXT:    [[TMP11]] = add i32 [[TMP10]], [[TMP8]]
 ; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 4

>From 9237279613a6abb5bec1bbf9bbcb7bdd8e5022db Mon Sep 17 00:00:00 2001
From: Elvis Wang <elvis.wang at sifive.com>
Date: Sun, 10 Nov 2024 23:34:05 -0800
Subject: [PATCH 8/9] Refactor! Reuse functions from VPReductionRecipe.

---
 llvm/lib/Transforms/Vectorize/VPlan.h         | 114 +++++-------------
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp |   6 +-
 .../Transforms/Vectorize/VPlanTransforms.cpp  |  20 +--
 3 files changed, 47 insertions(+), 93 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index d48e75fe9fb64c..bb42d840ded90f 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -682,8 +682,6 @@ struct VPCostContext {
   LLVMContext &LLVMCtx;
   LoopVectorizationCostModel &CM;
   SmallPtrSet<Instruction *, 8> SkipCostComputation;
-  /// Contains recipes that are folded into other recipes.
-  SmallDenseMap<ElementCount, SmallPtrSet<VPRecipeBase *, 4>, 4> FoldedRecipes;
 
   VPCostContext(const TargetTransformInfo &TTI, const TargetLibraryInfo &TLI,
                 Type *CanIVTy, LoopVectorizationCostModel &CM)
@@ -2575,6 +2573,8 @@ class VPReductionRecipe : public VPSingleDefRecipe {
                                  getVecOp(), getCondOp(), IsOrdered);
   }
 
+  // TODO: Support VPExtendedReductionRecipe and VPMulAccRecipe after EVL
+  // support.
   static inline bool classof(const VPRecipeBase *R) {
     return R->getVPDefID() == VPRecipeBase::VPReductionSC ||
            R->getVPDefID() == VPRecipeBase::VPReductionEVLSC;
@@ -2662,33 +2662,20 @@ class VPReductionEVLRecipe : public VPReductionRecipe {
 /// a chain. This recipe is high level abstract which will generate
 /// VPReductionRecipe and VPWidenCastRecipe before execution. The Operands are
 /// {ChainOp, VecOp, [Condition]}.
-class VPExtendedReductionRecipe : public VPSingleDefRecipe {
-  /// The recurrence decriptor for the reduction in question.
-  const RecurrenceDescriptor &RdxDesc;
-  bool IsOrdered;
-  /// Whether the reduction is conditional.
-  bool IsConditional = false;
+class VPExtendedReductionRecipe : public VPReductionRecipe {
   /// Type after extend.
   Type *ResultTy;
-  /// Opcode for the extend instruction.
-  Instruction::CastOps ExtOp;
   CastInst *ExtInstr;
-  bool IsZExt;
 
 protected:
   VPExtendedReductionRecipe(const unsigned char SC,
                             const RecurrenceDescriptor &R, Instruction *RedI,
-                            Instruction::CastOps ExtOp, CastInst *ExtI,
-                            ArrayRef<VPValue *> Operands, VPValue *CondOp,
+                            Instruction::CastOps ExtOp, CastInst *ExtInstr,
+                            VPValue *ChainOp, VPValue *VecOp, VPValue *CondOp,
                             bool IsOrdered, Type *ResultTy)
-      : VPSingleDefRecipe(SC, Operands, RedI), RdxDesc(R), IsOrdered(IsOrdered),
-        ResultTy(ResultTy), ExtOp(ExtOp), ExtInstr(ExtI) {
-    if (CondOp) {
-      IsConditional = true;
-      addOperand(CondOp);
-    }
-    IsZExt = ExtOp == Instruction::CastOps::ZExt;
-  }
+      : VPReductionRecipe(SC, R, RedI, ArrayRef<VPValue *>({ChainOp, VecOp}),
+                          CondOp, IsOrdered),
+        ResultTy(ResultTy), ExtInstr(ExtInstr) {}
 
 public:
   VPExtendedReductionRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
@@ -2696,9 +2683,8 @@ class VPExtendedReductionRecipe : public VPSingleDefRecipe {
                             VPValue *CondOp, bool IsOrdered)
       : VPExtendedReductionRecipe(
             VPDef::VPExtendedReductionSC, R, RedI, Ext->getOpcode(),
-            cast<CastInst>(Ext->getUnderlyingInstr()),
-            ArrayRef<VPValue *>({ChainOp, Ext->getOperand(0)}), CondOp,
-            IsOrdered, Ext->getResultType()) {}
+            cast<CastInst>(Ext->getUnderlyingInstr()), ChainOp,
+            Ext->getOperand(0), CondOp, IsOrdered, Ext->getResultType()) {}
 
   ~VPExtendedReductionRecipe() override = default;
 
@@ -2730,26 +2716,11 @@ class VPExtendedReductionRecipe : public VPSingleDefRecipe {
              VPSlotTracker &SlotTracker) const override;
 #endif
 
-  /// Return the recurrence decriptor for the in-loop reduction.
-  const RecurrenceDescriptor &getRecurrenceDescriptor() const {
-    return RdxDesc;
-  }
-  /// Return true if the in-loop reduction is ordered.
-  bool isOrdered() const { return IsOrdered; };
-  /// Return true if the in-loop reduction is conditional.
-  bool isConditional() const { return IsConditional; };
-  /// The VPValue of the scalar Chain being accumulated.
-  VPValue *getChainOp() const { return getOperand(0); }
-  /// The VPValue of the vector value to be extended and reduced.
-  VPValue *getVecOp() const { return getOperand(1); }
-  /// The VPValue of the condition for the block.
-  VPValue *getCondOp() const {
-    return isConditional() ? getOperand(getNumOperands() - 1) : nullptr;
-  }
   /// The Type after extended.
   Type *getResultType() const { return ResultTy; };
+  bool isZExt() const { return getExtOpcode() == Instruction::ZExt; };
   /// The Opcode of extend instruction.
-  Instruction::CastOps getExtOpcode() const { return ExtOp; };
+  Instruction::CastOps getExtOpcode() const { return ExtInstr->getOpcode(); };
   /// The CastInst of the extend instruction.
   CastInst *getExtInstr() const { return ExtInstr; };
 };
@@ -2759,12 +2730,7 @@ class VPExtendedReductionRecipe : public VPSingleDefRecipe {
 /// a chain. This recipe is high level abstract which will generate
 /// VPReductionRecipe VPWidenRecipe(mul) and VPWidenCastRecipes before
 /// execution. The Operands are {ChainOp, VecOp1, VecOp2, [Condition]}.
-class VPMulAccRecipe : public VPSingleDefRecipe {
-  /// The recurrence decriptor for the reduction in question.
-  const RecurrenceDescriptor &RdxDesc;
-  bool IsOrdered;
-  /// Whether the reduction is conditional.
-  bool IsConditional = false;
+class VPMulAccRecipe : public VPReductionRecipe {
   /// Type after extend.
   Type *ResultType;
   // Note that all extend instruction must have the same opcode in MulAcc.
@@ -2786,32 +2752,29 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
                  Instruction *RedI, Instruction *ExtInstr,
                  Instruction *MulInstr, Instruction::CastOps ExtOp,
                  Instruction *Ext0Instr, Instruction *Ext1Instr,
-                 ArrayRef<VPValue *> Operands, VPValue *CondOp, bool IsOrdered,
-                 Type *ResultType)
-      : VPSingleDefRecipe(SC, Operands, RedI), RdxDesc(R), IsOrdered(IsOrdered),
+                 VPValue *ChainOp, VPValue *VecOp0, VPValue *VecOp1,
+                 VPValue *CondOp, bool IsOrdered, Type *ResultType)
+      : VPReductionRecipe(SC, R, RedI,
+                          ArrayRef<VPValue *>({ChainOp, VecOp0, VecOp1}),
+                          CondOp, IsOrdered),
         ResultType(ResultType), ExtOp(ExtOp), MulInstr(MulInstr),
         ExtInstr(cast_if_present<CastInst>(ExtInstr)),
         Ext0Instr(cast<CastInst>(Ext0Instr)),
         Ext1Instr(cast<CastInst>(Ext1Instr)) {
     assert(MulInstr->getOpcode() == Instruction::Mul);
-    if (CondOp) {
-      IsConditional = true;
-      addOperand(CondOp);
-    }
     IsExtended = true;
     IsOuterExtended = ExtInstr != nullptr;
   }
 
   VPMulAccRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
-                 Instruction *RedI, Instruction *MulInstr,
-                 ArrayRef<VPValue *> Operands, VPValue *CondOp, bool IsOrdered)
-      : VPSingleDefRecipe(SC, Operands, RedI), RdxDesc(R), IsOrdered(IsOrdered),
+                 Instruction *RedI, Instruction *MulInstr, VPValue *ChainOp,
+                 VPValue *VecOp0, VPValue *VecOp1, VPValue *CondOp,
+                 bool IsOrdered)
+      : VPReductionRecipe(SC, R, RedI,
+                          ArrayRef<VPValue *>({ChainOp, VecOp0, VecOp1}),
+                          CondOp, IsOrdered),
         MulInstr(MulInstr) {
     assert(MulInstr->getOpcode() == Instruction::Mul);
-    if (CondOp) {
-      IsConditional = true;
-      addOperand(CondOp);
-    }
     IsExtended = false;
   }
 
@@ -2823,17 +2786,15 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
       : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, nullptr,
                        Mul->getUnderlyingInstr(), Ext0->getOpcode(),
                        Ext0->getUnderlyingInstr(), Ext1->getUnderlyingInstr(),
-                       ArrayRef<VPValue *>(
-                           {ChainOp, Ext0->getOperand(0), Ext1->getOperand(0)}),
+                       ChainOp, Ext0->getOperand(0), Ext1->getOperand(0),
                        CondOp, IsOrdered, Ext0->getResultType()) {}
 
   VPMulAccRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
                  VPValue *ChainOp, VPValue *CondOp, bool IsOrdered,
                  VPWidenRecipe *Mul)
       : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, Mul->getUnderlyingInstr(),
-                       ArrayRef<VPValue *>(
-                           {ChainOp, Mul->getOperand(0), Mul->getOperand(1)}),
-                       CondOp, IsOrdered) {}
+                       ChainOp, Mul->getOperand(0), Mul->getOperand(1), CondOp,
+                       IsOrdered) {}
 
   VPMulAccRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
                  VPValue *ChainOp, VPValue *CondOp, bool IsOrdered,
@@ -2842,8 +2803,7 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
       : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, Ext->getUnderlyingInstr(),
                        Mul->getUnderlyingInstr(), Ext0->getOpcode(),
                        Ext0->getUnderlyingInstr(), Ext1->getUnderlyingInstr(),
-                       ArrayRef<VPValue *>(
-                           {ChainOp, Ext0->getOperand(0), Ext1->getOperand(0)}),
+                       ChainOp, Ext0->getOperand(0), Ext1->getOperand(0),
                        CondOp, IsOrdered, Ext0->getResultType()) {}
 
   ~VPMulAccRecipe() override = default;
@@ -2874,37 +2834,27 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
              VPSlotTracker &SlotTracker) const override;
 #endif
 
-  /// Return the recurrence decriptor for the in-loop reduction.
-  const RecurrenceDescriptor &getRecurrenceDescriptor() const {
-    return RdxDesc;
-  }
-  /// Return true if the in-loop reduction is ordered.
-  bool isOrdered() const { return IsOrdered; };
-  /// Return true if the in-loop reduction is conditional.
-  bool isConditional() const { return IsConditional; };
-  /// The VPValue of the scalar Chain being accumulated.
-  VPValue *getChainOp() const { return getOperand(0); }
   /// The VPValue of the vector value to be extended and reduced.
   VPValue *getVecOp0() const { return getOperand(1); }
   VPValue *getVecOp1() const { return getOperand(2); }
-  /// The VPValue of the condition for the block.
-  VPValue *getCondOp() const {
-    return isConditional() ? getOperand(getNumOperands() - 1) : nullptr;
-  }
+
   /// Return the type after inner extended, which must equal to the type of mul
   /// instruction. If the ResultType != recurrenceType, than it must have a
   /// extend recipe after mul recipe.
   Type *getResultType() const { return ResultType; };
+
   /// The opcode of the extend instructions.
   Instruction::CastOps getExtOpcode() const { return ExtOp; };
   /// The underlying instruction for VPWidenRecipe.
   Instruction *getMulInstr() const { return MulInstr; };
+
   /// The underlying Instruction for outer VPWidenCastRecipe.
   CastInst *getExtInstr() const { return ExtInstr; };
   /// The underlying Instruction for inner VPWidenCastRecipe.
   CastInst *getExt0Instr() const { return Ext0Instr; };
   /// The underlying Instruction for inner VPWidenCastRecipe.
   CastInst *getExt1Instr() const { return Ext1Instr; };
+
   /// Return if this MulAcc recipe contains extend instructions.
   bool isExtended() const { return IsExtended; };
   /// Return if the operands of mul instruction come from same extend.
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 8a2ecd46ab672d..2ecf4e763c39b4 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -2216,6 +2216,7 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
 InstructionCost
 VPExtendedReductionRecipe::computeCost(ElementCount VF,
                                        VPCostContext &Ctx) const {
+  const RecurrenceDescriptor &RdxDesc = getRecurrenceDescriptor();
   RecurKind RdxKind = RdxDesc.getRecurrenceKind();
   Type *ElementTy = getResultType();
   auto *VectorTy = cast<VectorType>(ToVectorTy(ElementTy, VF));
@@ -2246,7 +2247,7 @@ VPExtendedReductionRecipe::computeCost(ElementCount VF,
 
   // ExtendedReduction Cost
   InstructionCost ExtendedRedCost = Ctx.TTI.getExtendedReductionCost(
-      Opcode, IsZExt, ElementTy, SrcTy, RdxDesc.getFastMathFlags(), CostKind);
+      Opcode, isZExt(), ElementTy, SrcTy, RdxDesc.getFastMathFlags(), CostKind);
   // Check if folding ext into ExtendedReduction is profitable.
   if (ExtendedRedCost.isValid() &&
       ExtendedRedCost < ExtendedCost + ReductionCost) {
@@ -2257,6 +2258,7 @@ VPExtendedReductionRecipe::computeCost(ElementCount VF,
 
 InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
                                             VPCostContext &Ctx) const {
+  const RecurrenceDescriptor &RdxDesc = getRecurrenceDescriptor();
   Type *ElementTy = IsExtended ? RdxDesc.getRecurrenceType()
                                : Ctx.Types.inferScalarType(getVecOp0());
   auto *VectorTy = cast<VectorType>(ToVectorTy(ElementTy, VF));
@@ -2382,6 +2384,7 @@ void VPReductionEVLRecipe::print(raw_ostream &O, const Twine &Indent,
 
 void VPExtendedReductionRecipe::print(raw_ostream &O, const Twine &Indent,
                                       VPSlotTracker &SlotTracker) const {
+  const RecurrenceDescriptor &RdxDesc = getRecurrenceDescriptor();
   O << Indent << "EXTENDED-REDUCE ";
   printAsOperand(O, SlotTracker);
   O << " = ";
@@ -2404,6 +2407,7 @@ void VPExtendedReductionRecipe::print(raw_ostream &O, const Twine &Indent,
 
 void VPMulAccRecipe::print(raw_ostream &O, const Twine &Indent,
                            VPSlotTracker &SlotTracker) const {
+  const RecurrenceDescriptor &RdxDesc = getRecurrenceDescriptor();
   O << Indent << "MULACC-REDUCE ";
   printAsOperand(O, SlotTracker);
   O << " = ";
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index af287577fcb879..55e5788e8f9131 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -525,8 +525,7 @@ void VPlanTransforms::prepareExecute(VPlan &Plan) {
   for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>(
            vp_depth_first_deep(Plan.getEntry()))) {
     for (VPRecipeBase &R : make_early_inc_range(*VPBB)) {
-      if (isa<VPExtendedReductionRecipe>(&R)) {
-        auto *ExtRed = cast<VPExtendedReductionRecipe>(&R);
+      if (auto *ExtRed = dyn_cast<VPExtendedReductionRecipe>(&R)) {
         auto *Ext = new VPWidenCastRecipe(
             ExtRed->getExtOpcode(), ExtRed->getVecOp(), ExtRed->getResultType(),
             *ExtRed->getExtInstr());
@@ -542,14 +541,14 @@ void VPlanTransforms::prepareExecute(VPlan &Plan) {
         auto *MulAcc = cast<VPMulAccRecipe>(&R);
         VPValue *Op0, *Op1;
         if (MulAcc->isExtended()) {
-          Op0 =
-              new VPWidenCastRecipe(MulAcc->getExtOpcode(), MulAcc->getVecOp0(),
-                                    MulAcc->getResultType());
+          CastInst *Ext0 = MulAcc->getExt0Instr();
+          Op0 = new VPWidenCastRecipe(Ext0->getOpcode(), MulAcc->getVecOp0(),
+                                      MulAcc->getResultType(), *Ext0);
           Op0->getDefiningRecipe()->insertBefore(MulAcc);
           if (!MulAcc->isSameExtend()) {
-            Op1 = new VPWidenCastRecipe(MulAcc->getExtOpcode(),
-                                        MulAcc->getVecOp1(),
-                                        MulAcc->getResultType());
+            CastInst *Ext1 = MulAcc->getExt1Instr();
+            Op1 = new VPWidenCastRecipe(Ext1->getOpcode(), MulAcc->getVecOp1(),
+                                        MulAcc->getResultType(), *Ext1);
             Op1->getDefiningRecipe()->insertBefore(MulAcc);
           } else {
             Op1 = Op0;
@@ -566,8 +565,9 @@ void VPlanTransforms::prepareExecute(VPlan &Plan) {
         // Outer extend.
         if (auto *OuterExtInstr = MulAcc->getExtInstr())
           VecOp = new VPWidenCastRecipe(
-              MulAcc->getExtOpcode(), Mul,
-              MulAcc->getRecurrenceDescriptor().getRecurrenceType());
+              OuterExtInstr->getOpcode(), Mul,
+              MulAcc->getRecurrenceDescriptor().getRecurrenceType(),
+              *OuterExtInstr);
         else
           VecOp = Mul;
         auto *Red = new VPReductionRecipe(

>From d62e3c7ce63b51244c68c91a65af4e29287a5b61 Mon Sep 17 00:00:00 2001
From: Elvis Wang <elvis.wang at sifive.com>
Date: Mon, 11 Nov 2024 23:17:46 -0800
Subject: [PATCH 9/9] Refactor! Add comments and refine new recipes.

---
 .../Transforms/Vectorize/LoopVectorize.cpp    | 20 +++---
 llvm/lib/Transforms/Vectorize/VPlan.h         | 72 +++++++++----------
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp | 28 ++++----
 .../Transforms/Vectorize/VPlanTransforms.cpp  | 35 ++++++---
 4 files changed, 85 insertions(+), 70 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 999b4157142a0e..fef9956068f1fd 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -9389,17 +9389,18 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
       if (CM.blockNeedsPredicationForAnyReason(BB))
         CondOp = RecipeBuilder.getBlockInMask(BB);
 
-      auto TryToMatchMulAcc = [&]() -> VPSingleDefRecipe * {
+      auto TryToMatchMulAcc = [&]() -> VPReductionRecipe * {
         VPValue *A, *B;
         if (RdxDesc.getOpcode() != Instruction::Add)
           return nullptr;
-        // reduce.add(mul(ext, ext)) can folded into VPMulAccRecipe
+        // Try to match reduce.add(mul(...))
         if (match(VecOp, m_Mul(m_VPValue(A), m_VPValue(B))) &&
             !VecOp->hasMoreThanOneUniqueUser()) {
           VPWidenCastRecipe *RecipeA =
               dyn_cast_if_present<VPWidenCastRecipe>(A->getDefiningRecipe());
           VPWidenCastRecipe *RecipeB =
               dyn_cast_if_present<VPWidenCastRecipe>(B->getDefiningRecipe());
+          // Matched reduce.add(mul(ext, ext))
           if (RecipeA && RecipeB && match(RecipeA, m_ZExtOrSExt(m_VPValue())) &&
               match(RecipeB, m_ZExtOrSExt(m_VPValue())) &&
               (RecipeA->getOpcode() == RecipeB->getOpcode() || A == B)) {
@@ -9409,23 +9410,23 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
                 cast<VPWidenRecipe>(VecOp->getDefiningRecipe()), RecipeA,
                 RecipeB);
           } else {
-            // Matched reduce.add(mul(...))
+            // Matched reduce.add(mul)
             return new VPMulAccRecipe(
                 RdxDesc, CurrentLinkI, PreviousLink, CondOp,
                 CM.useOrderedReductions(RdxDesc),
                 cast<VPWidenRecipe>(VecOp->getDefiningRecipe()));
           }
           // Matched reduce.add(ext(mul(ext(A), ext(B))))
-          // Note that 3 extend instructions must have same opcode or A == B
+          // Note that all extend instructions must have the same opcode or A == B
           // which can be transform to reduce.add(zext(mul(sext(A), sext(B)))).
         } else if (match(VecOp,
                          m_ZExtOrSExt(m_Mul(m_ZExtOrSExt(m_VPValue()),
                                             m_ZExtOrSExt(m_VPValue())))) &&
                    !VecOp->hasMoreThanOneUniqueUser()) {
           VPWidenCastRecipe *Ext =
-              dyn_cast<VPWidenCastRecipe>(VecOp->getDefiningRecipe());
+              cast<VPWidenCastRecipe>(VecOp->getDefiningRecipe());
           VPWidenRecipe *Mul =
-              dyn_cast<VPWidenRecipe>(Ext->getOperand(0)->getDefiningRecipe());
+              cast<VPWidenRecipe>(Ext->getOperand(0)->getDefiningRecipe());
           VPWidenCastRecipe *Ext0 =
               cast<VPWidenCastRecipe>(Mul->getOperand(0)->getDefiningRecipe());
           VPWidenCastRecipe *Ext1 =
@@ -9441,8 +9442,10 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
         }
         return nullptr;
       };
-      auto TryToMatchExtendedReduction = [&]() -> VPSingleDefRecipe * {
+
+      auto TryToMatchExtendedReduction = [&]() -> VPReductionRecipe * {
         VPValue *A;
+        // Matched reduce(ext).
         if (match(VecOp, m_ZExtOrSExt(m_VPValue(A)))) {
           return new VPExtendedReductionRecipe(
               RdxDesc, CurrentLinkI, PreviousLink,
@@ -9451,7 +9454,8 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
         }
         return nullptr;
       };
-      VPSingleDefRecipe *RedRecipe;
+
+      VPReductionRecipe *RedRecipe;
       if (auto *MulAcc = TryToMatchMulAcc())
         RedRecipe = MulAcc;
       else if (auto *ExtendedRed = TryToMatchExtendedReduction())
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index bb42d840ded90f..cd7d0efe8dda45 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -2717,12 +2717,12 @@ class VPExtendedReductionRecipe : public VPReductionRecipe {
 #endif
 
   /// The Type after extended.
-  Type *getResultType() const { return ResultTy; };
-  bool isZExt() const { return getExtOpcode() == Instruction::ZExt; };
+  Type *getResultType() const { return ResultTy; }
+  bool isZExt() const { return getExtOpcode() == Instruction::ZExt; }
   /// The Opcode of extend instruction.
-  Instruction::CastOps getExtOpcode() const { return ExtInstr->getOpcode(); };
+  Instruction::CastOps getExtOpcode() const { return ExtInstr->getOpcode(); }
   /// The CastInst of the extend instruction.
-  CastInst *getExtInstr() const { return ExtInstr; };
+  CastInst *getExtInstr() const { return ExtInstr; }
 };
 
 /// A recipe to represent inloop MulAccreduction operations, performing a
@@ -2733,8 +2733,6 @@ class VPExtendedReductionRecipe : public VPReductionRecipe {
 class VPMulAccRecipe : public VPReductionRecipe {
   /// Type after extend.
   Type *ResultType;
-  // Note that all extend instruction must have the same opcode in MulAcc.
-  Instruction::CastOps ExtOp;
 
   /// reduce.add(ext(mul(ext0(), ext1())))
   Instruction *MulInstr;
@@ -2742,28 +2740,21 @@ class VPMulAccRecipe : public VPReductionRecipe {
   CastInst *Ext0Instr = nullptr;
   CastInst *Ext1Instr = nullptr;
 
-  /// Is this MulAcc recipe contains extend recipes?
-  bool IsExtended;
-  /// Is this reciep contains outer extend instuction?
-  bool IsOuterExtended = false;
-
 protected:
   VPMulAccRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
                  Instruction *RedI, Instruction *ExtInstr,
-                 Instruction *MulInstr, Instruction::CastOps ExtOp,
-                 Instruction *Ext0Instr, Instruction *Ext1Instr,
-                 VPValue *ChainOp, VPValue *VecOp0, VPValue *VecOp1,
-                 VPValue *CondOp, bool IsOrdered, Type *ResultType)
+                 Instruction *MulInstr, Instruction *Ext0Instr,
+                 Instruction *Ext1Instr, VPValue *ChainOp, VPValue *VecOp0,
+                 VPValue *VecOp1, VPValue *CondOp, bool IsOrdered,
+                 Type *ResultType)
       : VPReductionRecipe(SC, R, RedI,
                           ArrayRef<VPValue *>({ChainOp, VecOp0, VecOp1}),
                           CondOp, IsOrdered),
-        ResultType(ResultType), ExtOp(ExtOp), MulInstr(MulInstr),
+        ResultType(ResultType), MulInstr(MulInstr),
         ExtInstr(cast_if_present<CastInst>(ExtInstr)),
         Ext0Instr(cast<CastInst>(Ext0Instr)),
         Ext1Instr(cast<CastInst>(Ext1Instr)) {
     assert(MulInstr->getOpcode() == Instruction::Mul);
-    IsExtended = true;
-    IsOuterExtended = ExtInstr != nullptr;
   }
 
   VPMulAccRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
@@ -2775,7 +2766,6 @@ class VPMulAccRecipe : public VPReductionRecipe {
                           CondOp, IsOrdered),
         MulInstr(MulInstr) {
     assert(MulInstr->getOpcode() == Instruction::Mul);
-    IsExtended = false;
   }
 
 public:
@@ -2784,10 +2774,10 @@ class VPMulAccRecipe : public VPReductionRecipe {
                  VPWidenRecipe *Mul, VPWidenCastRecipe *Ext0,
                  VPWidenCastRecipe *Ext1)
       : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, nullptr,
-                       Mul->getUnderlyingInstr(), Ext0->getOpcode(),
-                       Ext0->getUnderlyingInstr(), Ext1->getUnderlyingInstr(),
-                       ChainOp, Ext0->getOperand(0), Ext1->getOperand(0),
-                       CondOp, IsOrdered, Ext0->getResultType()) {}
+                       Mul->getUnderlyingInstr(), Ext0->getUnderlyingInstr(),
+                       Ext1->getUnderlyingInstr(), ChainOp, Ext0->getOperand(0),
+                       Ext1->getOperand(0), CondOp, IsOrdered,
+                       Ext0->getResultType()) {}
 
   VPMulAccRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
                  VPValue *ChainOp, VPValue *CondOp, bool IsOrdered,
@@ -2801,10 +2791,10 @@ class VPMulAccRecipe : public VPReductionRecipe {
                  VPWidenCastRecipe *Ext, VPWidenRecipe *Mul,
                  VPWidenCastRecipe *Ext0, VPWidenCastRecipe *Ext1)
       : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, Ext->getUnderlyingInstr(),
-                       Mul->getUnderlyingInstr(), Ext0->getOpcode(),
-                       Ext0->getUnderlyingInstr(), Ext1->getUnderlyingInstr(),
-                       ChainOp, Ext0->getOperand(0), Ext1->getOperand(0),
-                       CondOp, IsOrdered, Ext0->getResultType()) {}
+                       Mul->getUnderlyingInstr(), Ext0->getUnderlyingInstr(),
+                       Ext1->getUnderlyingInstr(), ChainOp, Ext0->getOperand(0),
+                       Ext1->getOperand(0), CondOp, IsOrdered,
+                       Ext0->getResultType()) {}
 
   ~VPMulAccRecipe() override = default;
 
@@ -2841,24 +2831,34 @@ class VPMulAccRecipe : public VPReductionRecipe {
   /// Return the type after inner extended, which must equal to the type of mul
   /// instruction. If the ResultType != recurrenceType, than it must have a
   /// extend recipe after mul recipe.
-  Type *getResultType() const { return ResultType; };
+  Type *getResultType() const { return ResultType; }
 
-  /// The opcode of the extend instructions.
-  Instruction::CastOps getExtOpcode() const { return ExtOp; };
   /// The underlying instruction for VPWidenRecipe.
-  Instruction *getMulInstr() const { return MulInstr; };
+  Instruction *getMulInstr() const { return MulInstr; }
 
   /// The underlying Instruction for outer VPWidenCastRecipe.
-  CastInst *getExtInstr() const { return ExtInstr; };
+  CastInst *getExtInstr() const { return ExtInstr; }
   /// The underlying Instruction for inner VPWidenCastRecipe.
-  CastInst *getExt0Instr() const { return Ext0Instr; };
+  CastInst *getExt0Instr() const { return Ext0Instr; }
   /// The underlying Instruction for inner VPWidenCastRecipe.
-  CastInst *getExt1Instr() const { return Ext1Instr; };
+  CastInst *getExt1Instr() const { return Ext1Instr; }
 
   /// Return if this MulAcc recipe contains extend instructions.
-  bool isExtended() const { return IsExtended; };
+  bool isExtended() const { return Ext0Instr && Ext1Instr; }
   /// Return if the operands of mul instruction come from same extend.
-  bool isSameExtend() const { return Ext0Instr == Ext1Instr; };
+  bool isSameExtend() const { return Ext0Instr == Ext1Instr; }
+  /// Return true if this MulAcc recipe contains an extend after the mul.
+  bool isOuterExtended() const { return ExtInstr != nullptr; }
+  /// Return true if the extends are treated as ZExt (also true when the recipe
+  /// has no inner extends).
+  bool isZExt() const {
+    if (!isExtended())
+      return true;
+    // reduce.add(sext(mul(zext(A), zext(A)))) can be transformed to
+    // reduce.add(zext(mul(sext(A), sext(A))))
+    if (ExtInstr && ExtInstr->getOpcode() != Ext0Instr->getOpcode())
+      return true;
+    return Ext0Instr->getOpcode() == Instruction::CastOps::ZExt;
+  }
 };
 
 /// VPReplicateRecipe replicates a given instruction producing multiple scalar
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 2ecf4e763c39b4..fc6b90191efa43 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -2259,8 +2259,8 @@ VPExtendedReductionRecipe::computeCost(ElementCount VF,
 InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
                                             VPCostContext &Ctx) const {
   const RecurrenceDescriptor &RdxDesc = getRecurrenceDescriptor();
-  Type *ElementTy = IsExtended ? RdxDesc.getRecurrenceType()
-                               : Ctx.Types.inferScalarType(getVecOp0());
+  Type *ElementTy = isExtended() ? RdxDesc.getRecurrenceType()
+                                 : Ctx.Types.inferScalarType(getVecOp0());
   auto *VectorTy = cast<VectorType>(ToVectorTy(ElementTy, VF));
   TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
   unsigned Opcode = RdxDesc.getOpcode();
@@ -2276,20 +2276,19 @@ InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
 
   // Extended cost
   InstructionCost ExtendedCost = 0;
-  if (IsExtended) {
+  if (isExtended()) {
     auto *SrcTy = cast<VectorType>(
         ToVectorTy(Ctx.Types.inferScalarType(getVecOp0()), VF));
     auto *DestTy = cast<VectorType>(ToVectorTy(getResultType(), VF));
     TTI::CastContextHint CCH0 =
         computeCCH(getVecOp0()->getDefiningRecipe(), VF);
-    // Arm TTI will use the underlying instruction to determine the cost.
     ExtendedCost = Ctx.TTI.getCastInstrCost(
-        ExtOp, DestTy, SrcTy, CCH0, TTI::TCK_RecipThroughput,
+        Ext0Instr->getOpcode(), DestTy, SrcTy, CCH0, TTI::TCK_RecipThroughput,
         dyn_cast_if_present<Instruction>(getExt0Instr()));
     TTI::CastContextHint CCH1 =
         computeCCH(getVecOp0()->getDefiningRecipe(), VF);
     ExtendedCost += Ctx.TTI.getCastInstrCost(
-        ExtOp, DestTy, SrcTy, CCH1, TTI::TCK_RecipThroughput,
+        Ext1Instr->getOpcode(), DestTy, SrcTy, CCH1, TTI::TCK_RecipThroughput,
         dyn_cast_if_present<Instruction>(getExt1Instr()));
   }
 
@@ -2297,7 +2296,7 @@ InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
   InstructionCost MulCost;
   SmallVector<const Value *, 4> Operands;
   Operands.append(MulInstr->value_op_begin(), MulInstr->value_op_end());
-  if (IsExtended)
+  if (isExtended())
     MulCost = Ctx.TTI.getArithmeticInstrCost(
         Instruction::Mul, VectorTy, CostKind,
         {TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None},
@@ -2324,9 +2323,8 @@ InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
   // MulAccReduction Cost
   VectorType *SrcVecTy =
       cast<VectorType>(ToVectorTy(Ctx.Types.inferScalarType(getVecOp0()), VF));
-  InstructionCost MulAccCost = Ctx.TTI.getMulAccReductionCost(
-      getExtOpcode() == Instruction::CastOps::ZExt, ElementTy, SrcVecTy,
-      CostKind);
+  InstructionCost MulAccCost =
+      Ctx.TTI.getMulAccReductionCost(isZExt(), ElementTy, SrcVecTy, CostKind);
 
   // Check if folding ext into ExtendedReduction is profitable.
   if (MulAccCost.isValid() &&
@@ -2415,26 +2413,26 @@ void VPMulAccRecipe::print(raw_ostream &O, const Twine &Indent,
   O << " + ";
   if (isa<FPMathOperator>(getUnderlyingInstr()))
     O << getUnderlyingInstr()->getFastMathFlags();
-  if (IsOuterExtended)
+  if (isOuterExtended())
     O << " (";
   O << "reduce." << Instruction::getOpcodeName(RdxDesc.getOpcode()) << " (";
   O << "mul ";
-  if (IsExtended)
+  if (isExtended())
     O << "(";
   getVecOp0()->printAsOperand(O, SlotTracker);
-  if (IsExtended)
+  if (isExtended())
     O << " extended to " << *getResultType() << "), (";
   else
     O << ", ";
   getVecOp1()->printAsOperand(O, SlotTracker);
-  if (IsExtended)
+  if (isExtended())
     O << " extended to " << *getResultType() << ")";
   if (isConditional()) {
     O << ", ";
     getCondOp()->printAsOperand(O, SlotTracker);
   }
   O << ")";
-  if (IsOuterExtended)
+  if (isOuterExtended())
     O << " extended to " << *RdxDesc.getRecurrenceType() << ")";
   if (RdxDesc.IntermediateStore)
     O << " (with final reduction value stored in invariant address sank "
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index 55e5788e8f9131..bbfb2540cc6e3d 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -526,9 +526,12 @@ void VPlanTransforms::prepareExecute(VPlan &Plan) {
            vp_depth_first_deep(Plan.getEntry()))) {
     for (VPRecipeBase &R : make_early_inc_range(*VPBB)) {
       if (auto *ExtRed = dyn_cast<VPExtendedReductionRecipe>(&R)) {
+        // Generate VPWidenCastRecipe.
         auto *Ext = new VPWidenCastRecipe(
             ExtRed->getExtOpcode(), ExtRed->getVecOp(), ExtRed->getResultType(),
             *ExtRed->getExtInstr());
+
+        // Generate VPReductionRecipe.
         auto *Red = new VPReductionRecipe(
             ExtRed->getRecurrenceDescriptor(), ExtRed->getUnderlyingInstr(),
             ExtRed->getChainOp(), Ext, ExtRed->getCondOp(),
@@ -539,45 +542,55 @@ void VPlanTransforms::prepareExecute(VPlan &Plan) {
         ExtRed->eraseFromParent();
       } else if (isa<VPMulAccRecipe>(&R)) {
         auto *MulAcc = cast<VPMulAccRecipe>(&R);
+
+        // Generate inner VPWidenCastRecipes if necessary.
         VPValue *Op0, *Op1;
         if (MulAcc->isExtended()) {
           CastInst *Ext0 = MulAcc->getExt0Instr();
           Op0 = new VPWidenCastRecipe(Ext0->getOpcode(), MulAcc->getVecOp0(),
                                       MulAcc->getResultType(), *Ext0);
           Op0->getDefiningRecipe()->insertBefore(MulAcc);
-          if (!MulAcc->isSameExtend()) {
+          // Avoid generating a duplicate VPWidenCastRecipe for
+          // reduce.add(mul(ext(A), ext(A))).
+          if (MulAcc->isSameExtend()) {
+            Op1 = Op0;
+          } else {
             CastInst *Ext1 = MulAcc->getExt1Instr();
             Op1 = new VPWidenCastRecipe(Ext1->getOpcode(), MulAcc->getVecOp1(),
                                         MulAcc->getResultType(), *Ext1);
             Op1->getDefiningRecipe()->insertBefore(MulAcc);
-          } else {
-            Op1 = Op0;
           }
+          // This MulAccRecipe contains no extend instructions.
         } else {
           Op0 = MulAcc->getVecOp0();
           Op1 = MulAcc->getVecOp1();
         }
+
+        // Generate VPWidenRecipe.
         VPSingleDefRecipe *VecOp;
-        Instruction *MulInstr = MulAcc->getMulInstr();
         SmallVector<VPValue *, 2> MulOps = {Op0, Op1};
-        auto *Mul = new VPWidenRecipe(*MulInstr,
+        auto *Mul = new VPWidenRecipe(*MulAcc->getMulInstr(),
                                       make_range(MulOps.begin(), MulOps.end()));
-        // Outer extend.
-        if (auto *OuterExtInstr = MulAcc->getExtInstr())
+        Mul->insertBefore(MulAcc);
+
+        // Generate outer VPWidenCastRecipe if necessary.
+        if (auto *OuterExtInstr = MulAcc->getExtInstr()) {
           VecOp = new VPWidenCastRecipe(
               OuterExtInstr->getOpcode(), Mul,
               MulAcc->getRecurrenceDescriptor().getRecurrenceType(),
               *OuterExtInstr);
-        else
+          VecOp->insertBefore(MulAcc);
+        } else {
           VecOp = Mul;
+        }
+
+        // Generate VPReductionRecipe.
         auto *Red = new VPReductionRecipe(
             MulAcc->getRecurrenceDescriptor(), MulAcc->getUnderlyingInstr(),
             MulAcc->getChainOp(), VecOp, MulAcc->getCondOp(),
             MulAcc->isOrdered());
-        Mul->insertBefore(MulAcc);
-        if (VecOp != Mul)
-          VecOp->insertBefore(MulAcc);
         Red->insertBefore(MulAcc);
+
         MulAcc->replaceAllUsesWith(Red);
         MulAcc->eraseFromParent();
       }



More information about the llvm-commits mailing list