[llvm] [VPlan] Impl VPlan-based pattern match for ExtendedRed and MulAccRed (NFCI) (PR #113903)

Elvis Wang via llvm-commits llvm-commits at lists.llvm.org
Wed Dec 4 17:40:05 PST 2024


https://github.com/ElvisWang123 updated https://github.com/llvm/llvm-project/pull/113903

>From 33b1f601f5677b8159b0c2ef8cfa8e6c1fa9d541 Mon Sep 17 00:00:00 2001
From: Elvis Wang <elvis.wang at sifive.com>
Date: Mon, 28 Oct 2024 05:39:35 -0700
Subject: [PATCH 01/17] [VPlan] Impl VPlan-based pattern match for ExtendedRed
 and MulAccRed. NFCI

This patch implements VPlan-based pattern matching for ExtendedReduction
and MulAccReduction. In the reduction patterns listed below, the extend
and mul instructions can be folded into the reduction instruction, so
their cost is free.

We add `FoldedRecipes` to the `VPCostContext` to record the recipes that
are folded into other recipes.

ExtendedReductionPatterns:
    reduce(ext(...))
MulAccReductionPatterns:
    reduce.add(mul(...))
    reduce.add(mul(ext(...), ext(...)))
    reduce.add(ext(mul(...)))
    reduce.add(ext(mul(ext(...), ext(...))))

Ref: Original instruction based implementation:
https://reviews.llvm.org/D93476
---
 .../Transforms/Vectorize/LoopVectorize.cpp    |  57 -------
 llvm/lib/Transforms/Vectorize/VPlan.h         |   2 +
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp | 139 ++++++++++++++++--
 3 files changed, 129 insertions(+), 69 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 3c7c044a042719..74f463821addf9 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7300,63 +7300,6 @@ LoopVectorizationPlanner::precomputeCosts(VPlan &Plan, ElementCount VF,
     }
   }
 
-  // The legacy cost model has special logic to compute the cost of in-loop
-  // reductions, which may be smaller than the sum of all instructions involved
-  // in the reduction.
-  // TODO: Switch to costing based on VPlan once the logic has been ported.
-  for (const auto &[RedPhi, RdxDesc] : Legal->getReductionVars()) {
-    if (ForceTargetInstructionCost.getNumOccurrences())
-      continue;
-
-    if (!CM.isInLoopReduction(RedPhi))
-      continue;
-
-    const auto &ChainOps = RdxDesc.getReductionOpChain(RedPhi, OrigLoop);
-    SetVector<Instruction *> ChainOpsAndOperands(ChainOps.begin(),
-                                                 ChainOps.end());
-    auto IsZExtOrSExt = [](const unsigned Opcode) -> bool {
-      return Opcode == Instruction::ZExt || Opcode == Instruction::SExt;
-    };
-    // Also include the operands of instructions in the chain, as the cost-model
-    // may mark extends as free.
-    //
-    // For ARM, some of the instruction can folded into the reducion
-    // instruction. So we need to mark all folded instructions free.
-    // For example: We can fold reduce(mul(ext(A), ext(B))) into one
-    // instruction.
-    for (auto *ChainOp : ChainOps) {
-      for (Value *Op : ChainOp->operands()) {
-        if (auto *I = dyn_cast<Instruction>(Op)) {
-          ChainOpsAndOperands.insert(I);
-          if (I->getOpcode() == Instruction::Mul) {
-            auto *Ext0 = dyn_cast<Instruction>(I->getOperand(0));
-            auto *Ext1 = dyn_cast<Instruction>(I->getOperand(1));
-            if (Ext0 && IsZExtOrSExt(Ext0->getOpcode()) && Ext1 &&
-                Ext0->getOpcode() == Ext1->getOpcode()) {
-              ChainOpsAndOperands.insert(Ext0);
-              ChainOpsAndOperands.insert(Ext1);
-            }
-          }
-        }
-      }
-    }
-
-    // Pre-compute the cost for I, if it has a reduction pattern cost.
-    for (Instruction *I : ChainOpsAndOperands) {
-      auto ReductionCost = CM.getReductionPatternCost(
-          I, VF, ToVectorTy(I->getType(), VF), TTI::TCK_RecipThroughput);
-      if (!ReductionCost)
-        continue;
-
-      assert(!CostCtx.SkipCostComputation.contains(I) &&
-             "reduction op visited multiple times");
-      CostCtx.SkipCostComputation.insert(I);
-      LLVM_DEBUG(dbgs() << "Cost of " << ReductionCost << " for VF " << VF
-                        << ":\n in-loop reduction " << *I << "\n");
-      Cost += *ReductionCost;
-    }
-  }
-
   // Pre-compute the costs for branches except for the backedge, as the number
   // of replicate regions in a VPlan may not directly match the number of
   // branches, which would lead to different decisions.
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index e1d828f038f9a2..b87e042a0d1c50 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -684,6 +684,8 @@ struct VPCostContext {
   LLVMContext &LLVMCtx;
   LoopVectorizationCostModel &CM;
   SmallPtrSet<Instruction *, 8> SkipCostComputation;
+  /// Contains recipes that are folded into other recipes.
+  SmallDenseMap<ElementCount, SmallPtrSet<VPRecipeBase *, 4>, 4> FoldedRecipes;
 
   VPCostContext(const TargetTransformInfo &TTI, const TargetLibraryInfo &TLI,
                 Type *CanIVTy, LoopVectorizationCostModel &CM)
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index ef5f6e22f82206..ac78164b09e17b 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -270,7 +270,9 @@ InstructionCost VPRecipeBase::cost(ElementCount VF, VPCostContext &Ctx) {
     UI = &WidenMem->getIngredient();
 
   InstructionCost RecipeCost;
-  if (UI && Ctx.skipCostComputation(UI, VF.isVector())) {
+  if ((UI && Ctx.skipCostComputation(UI, VF.isVector())) ||
+      (Ctx.FoldedRecipes.contains(VF) &&
+       Ctx.FoldedRecipes.at(VF).contains(this))) {
     RecipeCost = 0;
   } else {
     RecipeCost = computeCost(VF, Ctx);
@@ -2192,30 +2194,143 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
   TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
   unsigned Opcode = RdxDesc.getOpcode();
 
-  // TODO: Support any-of and in-loop reductions.
+  // TODO: Support any-of reductions.
   assert(
       (!RecurrenceDescriptor::isAnyOfRecurrenceKind(RdxKind) ||
        ForceTargetInstructionCost.getNumOccurrences() > 0) &&
       "Any-of reduction not implemented in VPlan-based cost model currently.");
-  assert(
-      (!cast<VPReductionPHIRecipe>(getOperand(0))->isInLoop() ||
-       ForceTargetInstructionCost.getNumOccurrences() > 0) &&
-      "In-loop reduction not implemented in VPlan-based cost model currently.");
 
   assert(ElementTy->getTypeID() == RdxDesc.getRecurrenceType()->getTypeID() &&
          "Inferred type and recurrence type mismatch.");
 
-  // Cost = Reduction cost + BinOp cost
-  InstructionCost Cost =
+  // BaseCost = Reduction cost + BinOp cost
+  InstructionCost BaseCost =
       Ctx.TTI.getArithmeticInstrCost(Opcode, ElementTy, CostKind);
   if (RecurrenceDescriptor::isMinMaxRecurrenceKind(RdxKind)) {
     Intrinsic::ID Id = getMinMaxReductionIntrinsicOp(RdxKind);
-    return Cost + Ctx.TTI.getMinMaxReductionCost(
-                      Id, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+    BaseCost += Ctx.TTI.getMinMaxReductionCost(
+        Id, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+  } else {
+    BaseCost += Ctx.TTI.getArithmeticReductionCost(
+        Opcode, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
   }
 
-  return Cost + Ctx.TTI.getArithmeticReductionCost(
-                    Opcode, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+  using namespace llvm::VPlanPatternMatch;
+  auto GetMulAccReductionCost =
+      [&](const VPReductionRecipe *Red) -> InstructionCost {
+    VPValue *A, *B;
+    InstructionCost InnerExt0Cost = 0;
+    InstructionCost InnerExt1Cost = 0;
+    InstructionCost ExtCost = 0;
+    InstructionCost MulCost = 0;
+
+    VectorType *SrcVecTy = VectorTy;
+    Type *InnerExt0Ty;
+    Type *InnerExt1Ty;
+    Type *MaxInnerExtTy;
+    bool IsUnsigned = true;
+    bool HasOuterExt = false;
+
+    auto *Ext = dyn_cast_if_present<VPWidenCastRecipe>(
+        Red->getVecOp()->getDefiningRecipe());
+    VPRecipeBase *Mul;
+    // Try to match outer extend reduce.add(ext(...))
+    if (Ext && match(Ext, m_ZExtOrSExt(m_VPValue())) &&
+        cast<VPWidenCastRecipe>(Ext)->getNumUsers() == 1) {
+      IsUnsigned =
+          Ext->getOpcode() == Instruction::CastOps::ZExt ? true : false;
+      ExtCost = Ext->computeCost(VF, Ctx);
+      Mul = Ext->getOperand(0)->getDefiningRecipe();
+      HasOuterExt = true;
+    } else {
+      Mul = Red->getVecOp()->getDefiningRecipe();
+    }
+
+    // Match reduce.add(mul())
+    if (Mul && match(Mul, m_Mul(m_VPValue(A), m_VPValue(B))) &&
+        cast<VPWidenRecipe>(Mul)->getNumUsers() == 1) {
+      MulCost = cast<VPWidenRecipe>(Mul)->computeCost(VF, Ctx);
+      auto *InnerExt0 =
+          dyn_cast_if_present<VPWidenCastRecipe>(A->getDefiningRecipe());
+      auto *InnerExt1 =
+          dyn_cast_if_present<VPWidenCastRecipe>(B->getDefiningRecipe());
+      bool HasInnerExt = false;
+      // Try to match inner extends.
+      if (InnerExt0 && InnerExt1 &&
+          match(InnerExt0, m_ZExtOrSExt(m_VPValue())) &&
+          match(InnerExt1, m_ZExtOrSExt(m_VPValue())) &&
+          InnerExt0->getOpcode() == InnerExt1->getOpcode() &&
+          (InnerExt0->getNumUsers() > 0 &&
+           !InnerExt0->hasMoreThanOneUniqueUser()) &&
+          (InnerExt1->getNumUsers() > 0 &&
+           !InnerExt1->hasMoreThanOneUniqueUser())) {
+        InnerExt0Cost = InnerExt0->computeCost(VF, Ctx);
+        InnerExt1Cost = InnerExt1->computeCost(VF, Ctx);
+        Type *InnerExt0Ty = Ctx.Types.inferScalarType(InnerExt0->getOperand(0));
+        Type *InnerExt1Ty = Ctx.Types.inferScalarType(InnerExt1->getOperand(0));
+        Type *MaxInnerExtTy = InnerExt0Ty->getIntegerBitWidth() >
+                                      InnerExt1Ty->getIntegerBitWidth()
+                                  ? InnerExt0Ty
+                                  : InnerExt1Ty;
+        SrcVecTy = cast<VectorType>(ToVectorTy(MaxInnerExtTy, VF));
+        IsUnsigned = true;
+        HasInnerExt = true;
+      }
+      InstructionCost MulAccRedCost = Ctx.TTI.getMulAccReductionCost(
+          IsUnsigned, ElementTy, SrcVecTy, CostKind);
+      // Check if folding ext/mul into MulAccReduction is profitable.
+      if (MulAccRedCost.isValid() &&
+          MulAccRedCost <
+              ExtCost + MulCost + InnerExt0Cost + InnerExt1Cost + BaseCost) {
+        if (HasInnerExt) {
+          Ctx.FoldedRecipes[VF].insert(InnerExt0);
+          Ctx.FoldedRecipes[VF].insert(InnerExt1);
+        }
+        Ctx.FoldedRecipes[VF].insert(Mul);
+        if (HasOuterExt)
+          Ctx.FoldedRecipes[VF].insert(Ext);
+        return MulAccRedCost;
+      }
+    }
+    return InstructionCost::getInvalid();
+  };
+
+  // Match reduce(ext(...))
+  auto GetExtendedReductionCost =
+      [&](const VPReductionRecipe *Red) -> InstructionCost {
+    VPValue *VecOp = Red->getVecOp();
+    VPValue *A;
+    if (match(VecOp, m_ZExtOrSExt(m_VPValue(A))) && VecOp->getNumUsers() == 1) {
+      VPWidenCastRecipe *Ext =
+          cast<VPWidenCastRecipe>(VecOp->getDefiningRecipe());
+      bool IsUnsigned = Ext->getOpcode() == Instruction::CastOps::ZExt;
+      InstructionCost ExtCost = Ext->computeCost(VF, Ctx);
+      auto *ExtVecTy =
+          cast<VectorType>(ToVectorTy(Ctx.Types.inferScalarType(A), VF));
+      InstructionCost ExtendedRedCost = Ctx.TTI.getExtendedReductionCost(
+          Opcode, IsUnsigned, ElementTy, ExtVecTy, RdxDesc.getFastMathFlags(),
+          CostKind);
+      // Check if folding ext into ExtendedReduction is profitable.
+      if (ExtendedRedCost.isValid() && ExtendedRedCost < ExtCost + BaseCost) {
+        Ctx.FoldedRecipes[VF].insert(Ext);
+        return ExtendedRedCost;
+      }
+    }
+    return InstructionCost::getInvalid();
+  };
+
+  // Match MulAccReduction patterns.
+  InstructionCost MulAccCost = GetMulAccReductionCost(this);
+  if (MulAccCost.isValid())
+    return MulAccCost;
+
+  // Match ExtendedReduction patterns.
+  InstructionCost ExtendedCost = GetExtendedReductionCost(this);
+  if (ExtendedCost.isValid())
+    return ExtendedCost;
+
+  // Default cost.
+  return BaseCost;
 }
 
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)

>From 68fbd7047ba6b0aa5bf42a1755d217970a74b0ec Mon Sep 17 00:00:00 2001
From: Elvis Wang <elvis.wang at sifive.com>
Date: Sun, 3 Nov 2024 18:55:55 -0800
Subject: [PATCH 02/17] Partially support Extended-reduction.

---
 .../Transforms/Vectorize/LoopVectorize.cpp    |  24 +-
 llvm/lib/Transforms/Vectorize/VPlan.h         | 217 ++++++++++++++++++
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp | 114 +++++++--
 .../Transforms/Vectorize/VPlanTransforms.cpp  |  24 ++
 llvm/lib/Transforms/Vectorize/VPlanValue.h    |   2 +
 5 files changed, 356 insertions(+), 25 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 74f463821addf9..c41c2a84afa6cd 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7666,6 +7666,10 @@ DenseMap<const SCEV *, Value *> LoopVectorizationPlanner::executePlan(
                              CanonicalIVStartValue, State);
   VPlanTransforms::prepareToExecute(BestVPlan);
 
+  // TODO: Rebase to fhahn's implementation.
+  VPlanTransforms::prepareExecute(BestVPlan);
+  dbgs() << "\n\n print plan\n";
+  BestVPlan.print(dbgs());
   BestVPlan.execute(&State);
 
   // 2.5 Collect reduction resume values.
@@ -9264,6 +9268,7 @@ VPlanPtr LoopVectorizationPlanner::buildVPlan(VFRange &Range) {
 // Adjust AnyOf reductions; replace the reduction phi for the selected value
 // with a boolean reduction phi node to check if the condition is true in any
 // iteration. The final value is selected by the final ComputeReductionResult.
+// TODO: Implement VPMulAccHere.
 void LoopVectorizationPlanner::adjustRecipesForReductions(
     VPlanPtr &Plan, VPRecipeBuilder &RecipeBuilder, ElementCount MinVF) {
   using namespace VPlanPatternMatch;
@@ -9382,9 +9387,22 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
       if (CM.blockNeedsPredicationForAnyReason(BB))
         CondOp = RecipeBuilder.getBlockInMask(BB);
 
-      VPReductionRecipe *RedRecipe =
-          new VPReductionRecipe(RdxDesc, CurrentLinkI, PreviousLink, VecOp,
-                                CondOp, CM.useOrderedReductions(RdxDesc));
+      // VPWidenCastRecipes can be folded into VPReductionRecipe.
+      VPValue *A;
+      VPSingleDefRecipe *RedRecipe;
+      if (match(VecOp, m_ZExtOrSExt(m_VPValue(A))) &&
+          !VecOp->hasMoreThanOneUniqueUser()) {
+        RedRecipe = new VPExtendedReductionRecipe(
+            RdxDesc, CurrentLinkI,
+            cast<CastInst>(
+                cast<VPWidenCastRecipe>(VecOp)->getUnderlyingInstr()),
+            PreviousLink, A, CondOp, CM.useOrderedReductions(RdxDesc),
+            cast<VPWidenCastRecipe>(VecOp)->getResultType());
+      } else {
+        RedRecipe =
+            new VPReductionRecipe(RdxDesc, CurrentLinkI, PreviousLink, VecOp,
+                                  CondOp, CM.useOrderedReductions(RdxDesc));
+      }
       // Append the recipe to the end of the VPBasicBlock because we need to
       // ensure that it comes after all of it's inputs, including CondOp.
       // Note that this transformation may leave over dead recipes (including
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index b87e042a0d1c50..d0c2fb2e48d466 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -861,6 +861,8 @@ class VPSingleDefRecipe : public VPRecipeBase, public VPValue {
     case VPRecipeBase::VPInstructionSC:
     case VPRecipeBase::VPReductionEVLSC:
     case VPRecipeBase::VPReductionSC:
+    case VPRecipeBase::VPMulAccSC:
+    case VPRecipeBase::VPExtendedReductionSC:
     case VPRecipeBase::VPReplicateSC:
     case VPRecipeBase::VPScalarIVStepsSC:
     case VPRecipeBase::VPVectorPointerSC:
@@ -2696,6 +2698,221 @@ class VPReductionEVLRecipe : public VPReductionRecipe {
   }
 };
 
+/// A recipe to represent in-loop extended reduction operations, performing a
+/// reduction on an extended vector operand into a scalar value, and adding the
+/// result to a chain. This recipe is a high-level abstraction that must be
+/// replaced by a VPWidenCastRecipe and a VPReductionRecipe before execution.
+/// The operands are {ChainOp, VecOp, [Condition]}.
+class VPExtendedReductionRecipe : public VPSingleDefRecipe {
+  /// The recurrence descriptor for the reduction in question.
+  const RecurrenceDescriptor &RdxDesc;
+  bool IsOrdered;
+  /// Whether the reduction is conditional.
+  bool IsConditional = false;
+  /// Type after extend.
+  Type *ResultTy;
+  Instruction::CastOps ExtOp;
+  CastInst *CastInstr;
+  bool IsZExt;
+
+protected:
+  VPExtendedReductionRecipe(const unsigned char SC,
+                            const RecurrenceDescriptor &R, Instruction *RedI,
+                            Instruction::CastOps ExtOp, CastInst *CastI,
+                            ArrayRef<VPValue *> Operands, VPValue *CondOp,
+                            bool IsOrdered, Type *ResultTy)
+      : VPSingleDefRecipe(SC, Operands, RedI), RdxDesc(R), IsOrdered(IsOrdered),
+        ResultTy(ResultTy), ExtOp(ExtOp), CastInstr(CastI) {
+    if (CondOp) {
+      IsConditional = true;
+      addOperand(CondOp);
+    }
+    IsZExt = ExtOp == Instruction::CastOps::ZExt;
+  }
+
+public:
+  VPExtendedReductionRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
+                            CastInst *CastI, VPValue *ChainOp, VPValue *VecOp,
+                            VPValue *CondOp, bool IsOrdered, Type *ResultTy)
+      : VPExtendedReductionRecipe(VPDef::VPExtendedReductionSC, R, RedI,
+                                  CastI->getOpcode(), CastI,
+                                  ArrayRef<VPValue *>({ChainOp, VecOp}), CondOp,
+                                  IsOrdered, ResultTy) {}
+
+  VPExtendedReductionRecipe(VPReductionRecipe *Red, VPWidenCastRecipe *Ext)
+      : VPExtendedReductionRecipe(
+            VPDef::VPExtendedReductionSC, Red->getRecurrenceDescriptor(),
+            Red->getUnderlyingInstr(), Ext->getOpcode(),
+            cast<CastInst>(Ext->getUnderlyingInstr()),
+            ArrayRef<VPValue *>({Red->getChainOp(), Ext->getOperand(0)}),
+            Red->getCondOp(), Red->isOrdered(), Ext->getResultType()) {}
+
+  ~VPExtendedReductionRecipe() override = default;
+
+  VPExtendedReductionRecipe *clone() override {
+    llvm_unreachable("Not implement yet");
+  }
+
+  static inline bool classof(const VPRecipeBase *R) {
+    return R->getVPDefID() == VPDef::VPExtendedReductionSC;
+  }
+
+  static inline bool classof(const VPUser *U) {
+    auto *R = dyn_cast<VPRecipeBase>(U);
+    return R && classof(R);
+  }
+
+  /// Generate the reduction in the loop
+  void execute(VPTransformState &State) override {
+    llvm_unreachable("VPExtendedReductionRecipe should be transform to "
+                     "VPExtendedRecipe + VPReductionRecipe before execution.");
+  };
+
+  /// Return the cost of VPExtendedReductionRecipe.
+  InstructionCost computeCost(ElementCount VF,
+                              VPCostContext &Ctx) const override;
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+  /// Print the recipe.
+  void print(raw_ostream &O, const Twine &Indent,
+             VPSlotTracker &SlotTracker) const override;
+#endif
+
+  /// Return the recurrence descriptor for the in-loop reduction.
+  const RecurrenceDescriptor &getRecurrenceDescriptor() const {
+    return RdxDesc;
+  }
+  /// Return true if the in-loop reduction is ordered.
+  bool isOrdered() const { return IsOrdered; };
+  /// Return true if the in-loop reduction is conditional.
+  bool isConditional() const { return IsConditional; };
+  /// The VPValue of the scalar Chain being accumulated.
+  VPValue *getChainOp() const { return getOperand(0); }
+  /// The VPValue of the vector value to be extended and reduced.
+  VPValue *getVecOp() const { return getOperand(1); }
+  /// The VPValue of the condition for the block.
+  VPValue *getCondOp() const {
+    return isConditional() ? getOperand(getNumOperands() - 1) : nullptr;
+  }
+  Type *getResultType() const { return ResultTy; };
+  Instruction::CastOps getExtOpcode() const { return ExtOp; };
+  CastInst *getExtInstr() const { return CastInstr; };
+};
+
+/// A recipe to represent in-loop MulAcc reduction operations, performing a
+/// reduction on a vector operand into a scalar value, and adding the result to
+/// a chain. This recipe is a high-level abstraction intended to generate a
+/// VPReductionRecipe, a VPWidenRecipe (mul) and VPWidenCastRecipes before
+/// execution. The operands are {ChainOp, VecOp1, VecOp2, [Condition]}.
+class VPMulAccRecipe : public VPSingleDefRecipe {
+  /// The recurrence descriptor for the reduction in question.
+  const RecurrenceDescriptor &RdxDesc;
+  bool IsOrdered;
+  /// Whether the reduction is conditional.
+  bool IsConditional = false;
+  /// Type after extend.
+  Type *ResultTy;
+  /// Type for mul.
+  Type *MulTy;
+  /// reduce.add(OuterExt(mul(InnerExt(), InnerExt())))
+  Instruction::CastOps OuterExtOp;
+  Instruction::CastOps InnerExtOp;
+
+  Instruction *MulI;
+  Instruction *OuterExtI;
+  Instruction *InnerExt0I;
+  Instruction *InnerExt1I;
+
+protected:
+  VPMulAccRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
+                 Instruction *RedI, Instruction::CastOps OuterExtOp,
+                 Instruction *OuterExtI, Instruction *MulI,
+                 Instruction::CastOps InnerExtOp, Instruction *InnerExt0I,
+                 Instruction *InnerExt1I, ArrayRef<VPValue *> Operands,
+                 VPValue *CondOp, bool IsOrdered, Type *ResultTy, Type *MulTy)
+      : VPSingleDefRecipe(SC, Operands, RedI), RdxDesc(R), IsOrdered(IsOrdered),
+        ResultTy(ResultTy), MulTy(MulTy), OuterExtOp(OuterExtOp),
+        InnerExtOp(InnerExtOp), MulI(MulI), OuterExtI(OuterExtI),
+        InnerExt0I(InnerExt0I), InnerExt1I(InnerExt1I) {
+    if (CondOp) {
+      IsConditional = true;
+      addOperand(CondOp);
+    }
+  }
+
+public:
+  VPMulAccRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
+                 Instruction *OuterExt, Instruction *Mul,
+                 Instruction *InnerExt0, Instruction *InnerExt1,
+                 VPValue *ChainOp, VPValue *InnerExt0Op, VPValue *InnerExt1Op,
+                 VPValue *CondOp, bool IsOrdered, Type *ResultTy, Type *MulTy)
+      : VPMulAccRecipe(
+            VPDef::VPMulAccSC, R, RedI, cast<CastInst>(OuterExt)->getOpcode(),
+            OuterExt, Mul, cast<CastInst>(InnerExt0)->getOpcode(), InnerExt0,
+            InnerExt1, ArrayRef<VPValue *>({ChainOp, InnerExt0Op, InnerExt1Op}),
+            CondOp, IsOrdered, ResultTy, MulTy) {}
+
+  VPMulAccRecipe(VPReductionRecipe *Red, VPWidenCastRecipe *OuterExt,
+                 VPWidenRecipe *Mul, VPWidenCastRecipe *InnerExt0,
+                 VPWidenCastRecipe *InnerExt1)
+      : VPMulAccRecipe(
+            VPDef::VPMulAccSC, Red->getRecurrenceDescriptor(),
+            Red->getUnderlyingInstr(), OuterExt->getOpcode(),
+            OuterExt->getUnderlyingInstr(), Mul->getUnderlyingInstr(),
+            InnerExt0->getOpcode(), InnerExt0->getUnderlyingInstr(),
+            InnerExt1->getUnderlyingInstr(),
+            ArrayRef<VPValue *>({Red->getChainOp(), InnerExt0->getOperand(0),
+                                 InnerExt1->getOperand(0)}),
+            Red->getCondOp(), Red->isOrdered(), OuterExt->getResultType(),
+            InnerExt0->getResultType()) {}
+
+  ~VPMulAccRecipe() override = default;
+
+  VPMulAccRecipe *clone() override { llvm_unreachable("Not implement yet"); }
+
+  static inline bool classof(const VPRecipeBase *R) {
+    return R->getVPDefID() == VPRecipeBase::VPMulAccSC;
+  }
+
+  static inline bool classof(const VPUser *U) {
+    auto *R = dyn_cast<VPRecipeBase>(U);
+    return R && classof(R);
+  }
+
+  /// Generate the reduction in the loop
+  void execute(VPTransformState &State) override;
+
+  /// Return the cost of VPMulAccRecipe.
+  InstructionCost computeCost(ElementCount VF,
+                              VPCostContext &Ctx) const override;
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+  /// Print the recipe.
+  void print(raw_ostream &O, const Twine &Indent,
+             VPSlotTracker &SlotTracker) const override;
+#endif
+
+  /// Return the recurrence descriptor for the in-loop reduction.
+  const RecurrenceDescriptor &getRecurrenceDescriptor() const {
+    return RdxDesc;
+  }
+  /// Return true if the in-loop reduction is ordered.
+  bool isOrdered() const { return IsOrdered; };
+  /// Return true if the in-loop reduction is conditional.
+  bool isConditional() const { return IsConditional; };
+  /// The VPValue of the scalar Chain being accumulated.
+  VPValue *getChainOp() const { return getOperand(0); }
+  /// The VPValue of the vector value to be extended and reduced.
+  VPValue *getVecOp() const { return getOperand(1); }
+  /// The VPValue of the condition for the block.
+  VPValue *getCondOp() const {
+    return isConditional() ? getOperand(getNumOperands() - 1) : nullptr;
+  }
+  Type *getResultTy() const { return ResultTy; };
+  Instruction::CastOps getOuterExtOpcode() const { return OuterExtOp; };
+  Instruction::CastOps getInnerExtOpcode() const { return InnerExtOp; };
+};
+
 /// VPReplicateRecipe replicates a given instruction producing multiple scalar
 /// copies of the original scalar type, one per lane, instead of producing a
 /// single copy of widened type for all lanes. If the instruction is known to be
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index ac78164b09e17b..69577a5b78000a 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -1502,6 +1502,27 @@ void VPWidenCastRecipe::execute(VPTransformState &State) {
     setFlags(CastOp);
 }
 
+// Computes the CastContextHint from a recipe that may access memory.
+static TTI::CastContextHint computeCCH(const VPRecipeBase *R, ElementCount VF) {
+  if (VF.isScalar())
+    return TTI::CastContextHint::Normal;
+  if (isa<VPInterleaveRecipe>(R))
+    return TTI::CastContextHint::Interleave;
+  if (const auto *ReplicateRecipe = dyn_cast<VPReplicateRecipe>(R))
+    return ReplicateRecipe->isPredicated() ? TTI::CastContextHint::Masked
+                                           : TTI::CastContextHint::Normal;
+  const auto *WidenMemoryRecipe = dyn_cast<VPWidenMemoryRecipe>(R);
+  if (WidenMemoryRecipe == nullptr)
+    return TTI::CastContextHint::None;
+  if (!WidenMemoryRecipe->isConsecutive())
+    return TTI::CastContextHint::GatherScatter;
+  if (WidenMemoryRecipe->isReverse())
+    return TTI::CastContextHint::Reversed;
+  if (WidenMemoryRecipe->isMasked())
+    return TTI::CastContextHint::Masked;
+  return TTI::CastContextHint::Normal;
+}
+
 InstructionCost VPWidenCastRecipe::computeCost(ElementCount VF,
                                                VPCostContext &Ctx) const {
   // TODO: In some cases, VPWidenCastRecipes are created but not considered in
@@ -1509,26 +1530,6 @@ InstructionCost VPWidenCastRecipe::computeCost(ElementCount VF,
   // reduction in a smaller type.
   if (!getUnderlyingValue())
     return 0;
-  // Computes the CastContextHint from a recipes that may access memory.
-  auto ComputeCCH = [&](const VPRecipeBase *R) -> TTI::CastContextHint {
-    if (VF.isScalar())
-      return TTI::CastContextHint::Normal;
-    if (isa<VPInterleaveRecipe>(R))
-      return TTI::CastContextHint::Interleave;
-    if (const auto *ReplicateRecipe = dyn_cast<VPReplicateRecipe>(R))
-      return ReplicateRecipe->isPredicated() ? TTI::CastContextHint::Masked
-                                             : TTI::CastContextHint::Normal;
-    const auto *WidenMemoryRecipe = dyn_cast<VPWidenMemoryRecipe>(R);
-    if (WidenMemoryRecipe == nullptr)
-      return TTI::CastContextHint::None;
-    if (!WidenMemoryRecipe->isConsecutive())
-      return TTI::CastContextHint::GatherScatter;
-    if (WidenMemoryRecipe->isReverse())
-      return TTI::CastContextHint::Reversed;
-    if (WidenMemoryRecipe->isMasked())
-      return TTI::CastContextHint::Masked;
-    return TTI::CastContextHint::Normal;
-  };
 
   VPValue *Operand = getOperand(0);
   TTI::CastContextHint CCH = TTI::CastContextHint::None;
@@ -1536,7 +1537,7 @@ InstructionCost VPWidenCastRecipe::computeCost(ElementCount VF,
   if ((Opcode == Instruction::Trunc || Opcode == Instruction::FPTrunc) &&
       !hasMoreThanOneUniqueUser() && getNumUsers() > 0) {
     if (auto *StoreRecipe = dyn_cast<VPRecipeBase>(*user_begin()))
-      CCH = ComputeCCH(StoreRecipe);
+      CCH = computeCCH(StoreRecipe, VF);
   }
   // For Z/Sext, get the context from the operand.
   else if (Opcode == Instruction::ZExt || Opcode == Instruction::SExt ||
@@ -1544,7 +1545,7 @@ InstructionCost VPWidenCastRecipe::computeCost(ElementCount VF,
     if (Operand->isLiveIn())
       CCH = TTI::CastContextHint::Normal;
     else if (Operand->getDefiningRecipe())
-      CCH = ComputeCCH(Operand->getDefiningRecipe());
+      CCH = computeCCH(Operand->getDefiningRecipe(), VF);
   }
 
   auto *SrcTy =
@@ -2215,6 +2216,7 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
         Opcode, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
   }
 
+  /*
   using namespace llvm::VPlanPatternMatch;
   auto GetMulAccReductionCost =
       [&](const VPReductionRecipe *Red) -> InstructionCost {
@@ -2328,11 +2330,57 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
   InstructionCost ExtendedCost = GetExtendedReductionCost(this);
   if (ExtendedCost.isValid())
     return ExtendedCost;
+  */
 
   // Default cost.
   return BaseCost;
 }
 
+InstructionCost
+VPExtendedReductionRecipe::computeCost(ElementCount VF,
+                                       VPCostContext &Ctx) const {
+  RecurKind RdxKind = RdxDesc.getRecurrenceKind();
+  Type *ElementTy = getResultType();
+  auto *VectorTy = cast<VectorType>(ToVectorTy(ElementTy, VF));
+  TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+  unsigned Opcode = RdxDesc.getOpcode();
+
+  assert(ElementTy->getTypeID() == RdxDesc.getRecurrenceType()->getTypeID() &&
+         "Inferred type and recurrence type mismatch.");
+
+  // ReductionCost = Reduction cost + BinOp cost
+  InstructionCost ReductionCost =
+      Ctx.TTI.getArithmeticInstrCost(Opcode, ElementTy, CostKind);
+  if (RecurrenceDescriptor::isMinMaxRecurrenceKind(RdxKind)) {
+    Intrinsic::ID Id = getMinMaxReductionIntrinsicOp(RdxKind);
+    ReductionCost += Ctx.TTI.getMinMaxReductionCost(
+        Id, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+  } else {
+    ReductionCost += Ctx.TTI.getArithmeticReductionCost(
+        Opcode, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+  }
+
+  // Extended cost
+  auto *SrcTy =
+      cast<VectorType>(ToVectorTy(Ctx.Types.inferScalarType(getVecOp()), VF));
+  auto *DestTy = cast<VectorType>(ToVectorTy(getResultType(), VF));
+  TTI::CastContextHint CCH = computeCCH(getVecOp()->getDefiningRecipe(), VF);
+  // Arm TTI will use the underlying instruction to determine the cost.
+  InstructionCost ExtendedCost = Ctx.TTI.getCastInstrCost(
+      Opcode, DestTy, SrcTy, CCH, TTI::TCK_RecipThroughput,
+      dyn_cast_if_present<Instruction>(getUnderlyingValue()));
+
+  // ExtendedReduction Cost
+  InstructionCost ExtendedRedCost = Ctx.TTI.getExtendedReductionCost(
+      Opcode, IsZExt, ElementTy, SrcTy, RdxDesc.getFastMathFlags(), CostKind);
+  // Check if folding ext into ExtendedReduction is profitable.
+  if (ExtendedRedCost.isValid() &&
+      ExtendedRedCost < ExtendedCost + ReductionCost) {
+    return ExtendedRedCost;
+  }
+  return ExtendedCost + ReductionCost;
+}
+
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
 void VPReductionRecipe::print(raw_ostream &O, const Twine &Indent,
                               VPSlotTracker &SlotTracker) const {
@@ -2378,6 +2426,28 @@ void VPReductionEVLRecipe::print(raw_ostream &O, const Twine &Indent,
     O << " (with final reduction value stored in invariant address sank "
          "outside of loop)";
 }
+
+void VPExtendedReductionRecipe::print(raw_ostream &O, const Twine &Indent,
+                                      VPSlotTracker &SlotTracker) const {
+  O << Indent << "EXTENDED-REDUCE ";
+  printAsOperand(O, SlotTracker);
+  O << " = ";
+  getChainOp()->printAsOperand(O, SlotTracker);
+  O << " +";
+  if (isa<FPMathOperator>(getUnderlyingInstr()))
+    O << getUnderlyingInstr()->getFastMathFlags();
+  O << " reduce." << Instruction::getOpcodeName(RdxDesc.getOpcode()) << " (";
+  getVecOp()->printAsOperand(O, SlotTracker);
+  O << " extended to " << *getResultType();
+  if (isConditional()) {
+    O << ", ";
+    getCondOp()->printAsOperand(O, SlotTracker);
+  }
+  O << ")";
+  if (RdxDesc.IntermediateStore)
+    O << " (with final reduction value stored in invariant address sank "
+         "outside of loop)";
+}
 #endif
 
 bool VPReplicateRecipe::shouldPack() const {
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index cee83d1015b536..d56a1f74014759 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -519,6 +519,30 @@ void VPlanTransforms::removeDeadRecipes(VPlan &Plan) {
   }
 }
 
+void VPlanTransforms::prepareExecute(VPlan &Plan) {
+  errs() << "\n\n\n!!Prepare to execute\n";
+  ReversePostOrderTraversal<VPBlockDeepTraversalWrapper<VPBlockBase *>> RPOT(
+      Plan.getVectorLoopRegion());
+  for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>(
+           vp_depth_first_deep(Plan.getEntry()))) {
+    for (VPRecipeBase &R : make_early_inc_range(*VPBB)) {
+      if (!isa<VPExtendedReductionRecipe>(&R))
+        continue;
+      auto *ExtRed = cast<VPExtendedReductionRecipe>(&R);
+      auto *Ext = new VPWidenCastRecipe(
+          ExtRed->getExtOpcode(), ExtRed->getVecOp(), ExtRed->getResultType(),
+          *ExtRed->getExtInstr());
+      auto *Red = new VPReductionRecipe(
+          ExtRed->getRecurrenceDescriptor(), ExtRed->getUnderlyingInstr(),
+          ExtRed->getChainOp(), Ext, ExtRed->getCondOp(), ExtRed->isOrdered());
+      Ext->insertBefore(ExtRed);
+      Red->insertBefore(ExtRed);
+      ExtRed->replaceAllUsesWith(Red);
+      ExtRed->eraseFromParent();
+    }
+  }
+}
+
 static VPScalarIVStepsRecipe *
 createScalarIVSteps(VPlan &Plan, InductionDescriptor::InductionKind Kind,
                     Instruction::BinaryOps InductionOpcode,
diff --git a/llvm/lib/Transforms/Vectorize/VPlanValue.h b/llvm/lib/Transforms/Vectorize/VPlanValue.h
index 957a602091c733..2c922d8d39c927 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanValue.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanValue.h
@@ -329,6 +329,8 @@ class VPDef {
     VPInterleaveSC,
     VPReductionEVLSC,
     VPReductionSC,
+    VPMulAccSC,
+    VPExtendedReductionSC,
     VPReplicateSC,
     VPScalarCastSC,
     VPScalarIVStepsSC,

>From c8c9d56419b763cfcf49be0826aedb79769c6e1b Mon Sep 17 00:00:00 2001
From: Elvis Wang <elvis.wang at sifive.com>
Date: Mon, 4 Nov 2024 22:02:22 -0800
Subject: [PATCH 03/17] Support MulAccRecipe

---
 .../Transforms/Vectorize/LoopVectorize.cpp    |  33 ++++-
 llvm/lib/Transforms/Vectorize/VPlan.h         | 103 +++++++++-------
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp | 114 +++++++++++++++++-
 .../Transforms/Vectorize/VPlanTransforms.cpp  |  56 ++++++---
 4 files changed, 237 insertions(+), 69 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index c41c2a84afa6cd..e8c62fac5ee8c2 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7668,8 +7668,6 @@ DenseMap<const SCEV *, Value *> LoopVectorizationPlanner::executePlan(
 
   // TODO: Rebase to fhahn's implementation.
   VPlanTransforms::prepareExecute(BestVPlan);
-  dbgs() << "\n\n print plan\n";
-  BestVPlan.print(dbgs());
   BestVPlan.execute(&State);
 
   // 2.5 Collect reduction resume values.
@@ -9387,11 +9385,34 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
       if (CM.blockNeedsPredicationForAnyReason(BB))
         CondOp = RecipeBuilder.getBlockInMask(BB);
 
-      // VPWidenCastRecipes can folded into VPReductionRecipe
-      VPValue *A;
+      VPValue *A, *B;
       VPSingleDefRecipe *RedRecipe;
-      if (match(VecOp, m_ZExtOrSExt(m_VPValue(A))) &&
-          !VecOp->hasMoreThanOneUniqueUser()) {
+      // reduce.add(mul(ext, ext)) can be folded into VPMulAccRecipe
+      if (RdxDesc.getOpcode() == Instruction::Add &&
+          match(VecOp, m_Mul(m_VPValue(A), m_VPValue(B)))) {
+        VPRecipeBase *RecipeA = A->getDefiningRecipe();
+        VPRecipeBase *RecipeB = B->getDefiningRecipe();
+        if (RecipeA && RecipeB && match(RecipeA, m_ZExtOrSExt(m_VPValue())) &&
+            match(RecipeB, m_ZExtOrSExt(m_VPValue())) &&
+            cast<VPWidenCastRecipe>(RecipeA)->getOpcode() ==
+                cast<VPWidenCastRecipe>(RecipeB)->getOpcode() &&
+            !A->hasMoreThanOneUniqueUser() && !B->hasMoreThanOneUniqueUser()) {
+          RedRecipe = new VPMulAccRecipe(
+              RdxDesc, CurrentLinkI, PreviousLink, CondOp,
+              CM.useOrderedReductions(RdxDesc),
+              cast<VPWidenRecipe>(VecOp->getDefiningRecipe()),
+              cast<VPWidenCastRecipe>(RecipeA),
+              cast<VPWidenCastRecipe>(RecipeB));
+        } else {
+          RedRecipe = new VPMulAccRecipe(
+              RdxDesc, CurrentLinkI, PreviousLink, CondOp,
+              CM.useOrderedReductions(RdxDesc),
+              cast<VPWidenRecipe>(VecOp->getDefiningRecipe()));
+        }
+      }
+      // VPWidenCastRecipes can folded into VPReductionRecipe
+      else if (match(VecOp, m_ZExtOrSExt(m_VPValue(A))) &&
+               !VecOp->hasMoreThanOneUniqueUser()) {
         RedRecipe = new VPExtendedReductionRecipe(
             RdxDesc, CurrentLinkI,
             cast<CastInst>(
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index d0c2fb2e48d466..19d98674db9963 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -2811,60 +2811,64 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
   /// Whether the reduction is conditional.
   bool IsConditional = false;
   /// Type after extend.
-  Type *ResultTy;
-  /// Type for mul.
-  Type *MulTy;
-  /// reduce.add(OuterExt(mul(InnerExt(), InnerExt())))
-  Instruction::CastOps OuterExtOp;
-  Instruction::CastOps InnerExtOp;
+  Type *ResultType;
+  /// reduce.add(mul(Ext(), Ext()))
+  Instruction::CastOps ExtOp;
+
+  Instruction *MulInstr;
+  CastInst *Ext0Instr;
+  CastInst *Ext1Instr;
 
-  Instruction *MulI;
-  Instruction *OuterExtI;
-  Instruction *InnerExt0I;
-  Instruction *InnerExt1I;
+  bool IsExtended;
 
 protected:
   VPMulAccRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
-                 Instruction *RedI, Instruction::CastOps OuterExtOp,
-                 Instruction *OuterExtI, Instruction *MulI,
-                 Instruction::CastOps InnerExtOp, Instruction *InnerExt0I,
-                 Instruction *InnerExt1I, ArrayRef<VPValue *> Operands,
-                 VPValue *CondOp, bool IsOrdered, Type *ResultTy, Type *MulTy)
+                 Instruction *RedI, Instruction *MulInstr,
+                 Instruction::CastOps ExtOp, Instruction *Ext0Instr,
+                 Instruction *Ext1Instr, ArrayRef<VPValue *> Operands,
+                 VPValue *CondOp, bool IsOrdered, Type *ResultType)
+      : VPSingleDefRecipe(SC, Operands, RedI), RdxDesc(R), IsOrdered(IsOrdered),
+        ResultType(ResultType), ExtOp(ExtOp), MulInstr(MulInstr),
+        Ext0Instr(cast<CastInst>(Ext0Instr)),
+        Ext1Instr(cast<CastInst>(Ext1Instr)) {
+    if (CondOp) {
+      IsConditional = true;
+      addOperand(CondOp);
+    }
+    IsExtended = true;
+  }
+
+  VPMulAccRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
+                 Instruction *RedI, Instruction *MulInstr,
+                 ArrayRef<VPValue *> Operands, VPValue *CondOp, bool IsOrdered)
       : VPSingleDefRecipe(SC, Operands, RedI), RdxDesc(R), IsOrdered(IsOrdered),
-        ResultTy(ResultTy), MulTy(MulTy), OuterExtOp(OuterExtOp),
-        InnerExtOp(InnerExtOp), MulI(MulI), OuterExtI(OuterExtI),
-        InnerExt0I(InnerExt0I), InnerExt1I(InnerExt1I) {
+        MulInstr(MulInstr) {
     if (CondOp) {
       IsConditional = true;
       addOperand(CondOp);
     }
+    IsExtended = false;
   }
 
 public:
   VPMulAccRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
-                 Instruction *OuterExt, Instruction *Mul,
-                 Instruction *InnerExt0, Instruction *InnerExt1,
-                 VPValue *ChainOp, VPValue *InnerExt0Op, VPValue *InnerExt1Op,
-                 VPValue *CondOp, bool IsOrdered, Type *ResultTy, Type *MulTy)
-      : VPMulAccRecipe(
-            VPDef::VPMulAccSC, R, RedI, cast<CastInst>(OuterExt)->getOpcode(),
-            OuterExt, Mul, cast<CastInst>(InnerExt0)->getOpcode(), InnerExt0,
-            InnerExt1, ArrayRef<VPValue *>({ChainOp, InnerExt0Op, InnerExt1Op}),
-            CondOp, IsOrdered, ResultTy, MulTy) {}
-
-  VPMulAccRecipe(VPReductionRecipe *Red, VPWidenCastRecipe *OuterExt,
-                 VPWidenRecipe *Mul, VPWidenCastRecipe *InnerExt0,
-                 VPWidenCastRecipe *InnerExt1)
-      : VPMulAccRecipe(
-            VPDef::VPMulAccSC, Red->getRecurrenceDescriptor(),
-            Red->getUnderlyingInstr(), OuterExt->getOpcode(),
-            OuterExt->getUnderlyingInstr(), Mul->getUnderlyingInstr(),
-            InnerExt0->getOpcode(), InnerExt0->getUnderlyingInstr(),
-            InnerExt1->getUnderlyingInstr(),
-            ArrayRef<VPValue *>({Red->getChainOp(), InnerExt0->getOperand(0),
-                                 InnerExt1->getOperand(0)}),
-            Red->getCondOp(), Red->isOrdered(), OuterExt->getResultType(),
-            InnerExt0->getResultType()) {}
+                 VPValue *ChainOp, VPValue *CondOp, bool IsOrdered,
+                 VPWidenRecipe *Mul, VPWidenCastRecipe *Ext0,
+                 VPWidenCastRecipe *Ext1)
+      : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, Mul->getUnderlyingInstr(),
+                       Ext0->getOpcode(), Ext0->getUnderlyingInstr(),
+                       Ext1->getUnderlyingInstr(),
+                       ArrayRef<VPValue *>(
+                           {ChainOp, Ext0->getOperand(0), Ext1->getOperand(0)}),
+                       CondOp, IsOrdered, Ext0->getResultType()) {}
+
+  VPMulAccRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
+                 VPValue *ChainOp, VPValue *CondOp, bool IsOrdered,
+                 VPWidenRecipe *Mul)
+      : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, Mul->getUnderlyingInstr(),
+                       ArrayRef<VPValue *>(
+                           {ChainOp, Mul->getOperand(0), Mul->getOperand(0)}),
+                       CondOp, IsOrdered) {}
 
   ~VPMulAccRecipe() override = default;
 
@@ -2880,7 +2884,10 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
   }
 
   /// Generate the reduction in the loop
-  void execute(VPTransformState &State) override;
+  void execute(VPTransformState &State) override {
+    llvm_unreachable("VPMulAccRecipe should transform to VPWidenCastRecipe + "
+                     "VPWidenRecipe + VPReductionRecipe before execution");
+  }
 
   /// Return the cost of VPExtendedReductionRecipe.
   InstructionCost computeCost(ElementCount VF,
@@ -2903,14 +2910,18 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
   /// The VPValue of the scalar Chain being accumulated.
   VPValue *getChainOp() const { return getOperand(0); }
   /// The VPValue of the vector value to be extended and reduced.
-  VPValue *getVecOp() const { return getOperand(1); }
+  VPValue *getVecOp0() const { return getOperand(1); }
+  VPValue *getVecOp1() const { return getOperand(2); }
   /// The VPValue of the condition for the block.
   VPValue *getCondOp() const {
     return isConditional() ? getOperand(getNumOperands() - 1) : nullptr;
   }
-  Type *getResultTy() const { return ResultTy; };
-  Instruction::CastOps getOuterExtOpcode() const { return OuterExtOp; };
-  Instruction::CastOps getInnerExtOpcode() const { return InnerExtOp; };
+  Type *getResultType() const { return ResultType; };
+  Instruction::CastOps getExtOpcode() const { return ExtOp; };
+  Instruction *getMulInstr() const { return MulInstr; };
+  CastInst *getExt0Instr() const { return Ext0Instr; };
+  CastInst *getExt1Instr() const { return Ext1Instr; };
+  bool isExtended() const { return IsExtended; };
 };
 
 /// VPReplicateRecipe replicates a given instruction producing multiple scalar
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 69577a5b78000a..a7d74a78e06317 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -270,9 +270,7 @@ InstructionCost VPRecipeBase::cost(ElementCount VF, VPCostContext &Ctx) {
     UI = &WidenMem->getIngredient();
 
   InstructionCost RecipeCost;
-  if ((UI && Ctx.skipCostComputation(UI, VF.isVector())) ||
-      (Ctx.FoldedRecipes.contains(VF) &&
-       Ctx.FoldedRecipes.at(VF).contains(this))) {
+  if ((UI && Ctx.skipCostComputation(UI, VF.isVector()))) {
     RecipeCost = 0;
   } else {
     RecipeCost = computeCost(VF, Ctx);
@@ -2381,6 +2379,85 @@ VPExtendedReductionRecipe::computeCost(ElementCount VF,
   return ExtendedCost + ReductionCost;
 }
 
+InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
+                                            VPCostContext &Ctx) const {
+  Type *ElementTy =
+      IsExtended ? getResultType() : Ctx.Types.inferScalarType(getVecOp0());
+  auto *VectorTy = cast<VectorType>(ToVectorTy(ElementTy, VF));
+  TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+  unsigned Opcode = RdxDesc.getOpcode();
+
+  assert(ElementTy->getTypeID() == RdxDesc.getRecurrenceType()->getTypeID() &&
+         "Inferred type and recurrence type mismatch.");
+
+  // BaseCost = Reduction cost + BinOp cost
+  InstructionCost ReductionCost =
+      Ctx.TTI.getArithmeticInstrCost(Opcode, ElementTy, CostKind);
+  ReductionCost += Ctx.TTI.getArithmeticReductionCost(
+      Opcode, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+
+  // Cost of the two extends (zero when the mul operands are not extended).
+  InstructionCost ExtendedCost = 0;
+  if (IsExtended) {
+    auto *SrcTy = cast<VectorType>(
+        ToVectorTy(Ctx.Types.inferScalarType(getVecOp0()), VF));
+    auto *DestTy = cast<VectorType>(ToVectorTy(getResultType(), VF));
+    TTI::CastContextHint CCH0 =
+        computeCCH(getVecOp0()->getDefiningRecipe(), VF);
+    // Arm TTI will use the underlying instruction to determine the cost.
+    ExtendedCost = Ctx.TTI.getCastInstrCost(
+        ExtOp, DestTy, SrcTy, CCH0, TTI::TCK_RecipThroughput,
+        dyn_cast_if_present<Instruction>(getExt0Instr()));
+    TTI::CastContextHint CCH1 =
+        computeCCH(getVecOp1()->getDefiningRecipe(), VF);
+    ExtendedCost += Ctx.TTI.getCastInstrCost(
+        ExtOp, DestTy, SrcTy, CCH1, TTI::TCK_RecipThroughput,
+        dyn_cast_if_present<Instruction>(getExt1Instr()));
+  }
+
+  // Mul cost
+  InstructionCost MulCost;
+  SmallVector<const Value *, 4> Operands;
+  Operands.append(MulInstr->value_op_begin(), MulInstr->value_op_end());
+  if (IsExtended)
+    MulCost = Ctx.TTI.getArithmeticInstrCost(
+        Instruction::Mul, VectorTy, CostKind,
+        {TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None},
+        {TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None},
+        Operands, MulInstr, &Ctx.TLI);
+  else {
+    VPValue *RHS = getVecOp1();
+    // Certain instructions can be cheaper to vectorize if they have a constant
+    // second vector operand. One example of this are shifts on x86.
+    TargetTransformInfo::OperandValueInfo RHSInfo = {
+        TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None};
+    if (RHS->isLiveIn())
+      RHSInfo = Ctx.TTI.getOperandInfo(RHS->getLiveInIRValue());
+
+    if (RHSInfo.Kind == TargetTransformInfo::OK_AnyValue &&
+        RHS->isDefinedOutsideLoopRegions())
+      RHSInfo.Kind = TargetTransformInfo::OK_UniformValue;
+    MulCost = Ctx.TTI.getArithmeticInstrCost(
+        Instruction::Mul, VectorTy, CostKind,
+        {TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None},
+        RHSInfo, Operands, MulInstr, &Ctx.TLI);
+  }
+
+  // ExtendedReduction Cost
+  VectorType *SrcVecTy =
+      cast<VectorType>(ToVectorTy(Ctx.Types.inferScalarType(getVecOp0()), VF));
+  InstructionCost MulAccCost = Ctx.TTI.getMulAccReductionCost(
+      getExtOpcode() == Instruction::CastOps::ZExt, ElementTy, SrcVecTy,
+      CostKind);
+
+  // Check if folding mul/ext into a MulAcc reduction is profitable.
+  if (MulAccCost.isValid() &&
+      MulAccCost < ExtendedCost + ReductionCost + MulCost) {
+    return MulAccCost;
+  }
+  return ExtendedCost + ReductionCost + MulCost;
+}
+
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
 void VPReductionRecipe::print(raw_ostream &O, const Twine &Indent,
                               VPSlotTracker &SlotTracker) const {
@@ -2448,6 +2525,37 @@ void VPExtendedReductionRecipe::print(raw_ostream &O, const Twine &Indent,
     O << " (with final reduction value stored in invariant address sank "
          "outside of loop)";
 }
+
+void VPMulAccRecipe::print(raw_ostream &O, const Twine &Indent,
+                           VPSlotTracker &SlotTracker) const {
+  O << Indent << "MULACC-REDUCE ";
+  printAsOperand(O, SlotTracker);
+  O << " = ";
+  getChainOp()->printAsOperand(O, SlotTracker);
+  O << " +";
+  if (isa<FPMathOperator>(getUnderlyingInstr()))
+    O << getUnderlyingInstr()->getFastMathFlags();
+  O << " reduce." << Instruction::getOpcodeName(RdxDesc.getOpcode()) << " (";
+  O << " mul ";
+  if (IsExtended)
+    O << "(";
+  getVecOp0()->printAsOperand(O, SlotTracker);
+  if (IsExtended)
+    O << " extended to " << *getResultType() << ")";
+  if (IsExtended)
+    O << "(";
+  getVecOp1()->printAsOperand(O, SlotTracker);
+  if (IsExtended)
+    O << " extended to " << *getResultType() << ")";
+  if (isConditional()) {
+    O << ", ";
+    getCondOp()->printAsOperand(O, SlotTracker);
+  }
+  O << ")";
+  if (RdxDesc.IntermediateStore)
+    O << " (with final reduction value stored in invariant address sank "
+         "outside of loop)";
+}
 #endif
 
 bool VPReplicateRecipe::shouldPack() const {
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index d56a1f74014759..1f82228d4d787e 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -520,25 +520,53 @@ void VPlanTransforms::removeDeadRecipes(VPlan &Plan) {
 }
 
 void VPlanTransforms::prepareExecute(VPlan &Plan) {
-  errs() << "\n\n\n!!Prepare to execute\n";
   ReversePostOrderTraversal<VPBlockDeepTraversalWrapper<VPBlockBase *>> RPOT(
       Plan.getVectorLoopRegion());
   for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>(
            vp_depth_first_deep(Plan.getEntry()))) {
     for (VPRecipeBase &R : make_early_inc_range(*VPBB)) {
-      if (!isa<VPExtendedReductionRecipe>(&R))
-        continue;
-      auto *ExtRed = cast<VPExtendedReductionRecipe>(&R);
-      auto *Ext = new VPWidenCastRecipe(
-          ExtRed->getExtOpcode(), ExtRed->getVecOp(), ExtRed->getResultType(),
-          *ExtRed->getExtInstr());
-      auto *Red = new VPReductionRecipe(
-          ExtRed->getRecurrenceDescriptor(), ExtRed->getUnderlyingInstr(),
-          ExtRed->getChainOp(), Ext, ExtRed->getCondOp(), ExtRed->isOrdered());
-      Ext->insertBefore(ExtRed);
-      Red->insertBefore(ExtRed);
-      ExtRed->replaceAllUsesWith(Red);
-      ExtRed->eraseFromParent();
+      if (isa<VPExtendedReductionRecipe>(&R)) {
+        auto *ExtRed = cast<VPExtendedReductionRecipe>(&R);
+        auto *Ext = new VPWidenCastRecipe(
+            ExtRed->getExtOpcode(), ExtRed->getVecOp(), ExtRed->getResultType(),
+            *ExtRed->getExtInstr());
+        auto *Red = new VPReductionRecipe(
+            ExtRed->getRecurrenceDescriptor(), ExtRed->getUnderlyingInstr(),
+            ExtRed->getChainOp(), Ext, ExtRed->getCondOp(),
+            ExtRed->isOrdered());
+        Ext->insertBefore(ExtRed);
+        Red->insertBefore(ExtRed);
+        ExtRed->replaceAllUsesWith(Red);
+        ExtRed->eraseFromParent();
+      } else if (isa<VPMulAccRecipe>(&R)) {
+        auto *MulAcc = cast<VPMulAccRecipe>(&R);
+        VPValue *Op0, *Op1;
+        if (MulAcc->isExtended()) {
+          Op0 = new VPWidenCastRecipe(
+              MulAcc->getExtOpcode(), MulAcc->getVecOp0(),
+              MulAcc->getResultType(), *MulAcc->getExt0Instr());
+          Op1 = new VPWidenCastRecipe(
+              MulAcc->getExtOpcode(), MulAcc->getVecOp1(),
+              MulAcc->getResultType(), *MulAcc->getExt1Instr());
+          Op0->getDefiningRecipe()->insertBefore(MulAcc);
+          Op1->getDefiningRecipe()->insertBefore(MulAcc);
+        } else {
+          Op0 = MulAcc->getVecOp0();
+          Op1 = MulAcc->getVecOp1();
+        }
+        Instruction *MulInstr = MulAcc->getMulInstr();
+        SmallVector<VPValue *, 2> MulOps = {Op0, Op1};
+        auto *Mul = new VPWidenRecipe(*MulInstr,
+                                      make_range(MulOps.begin(), MulOps.end()));
+        auto *Red = new VPReductionRecipe(
+            MulAcc->getRecurrenceDescriptor(), MulAcc->getUnderlyingInstr(),
+            MulAcc->getChainOp(), Mul, MulAcc->getCondOp(),
+            MulAcc->isOrdered());
+        Mul->insertBefore(MulAcc);
+        Red->insertBefore(MulAcc);
+        MulAcc->replaceAllUsesWith(Red);
+        MulAcc->eraseFromParent();
+      }
     }
   }
 }

>From d29a1189194b9190ff6ae945587b34aff2ecfd49 Mon Sep 17 00:00:00 2001
From: Elvis Wang <elvis.wang at sifive.com>
Date: Tue, 5 Nov 2024 16:52:31 -0800
Subject: [PATCH 04/17] Fix several errors and update tests.

We need to update tests since the generated vector IR will be reordered.
---
 .../Transforms/Vectorize/LoopVectorize.cpp    | 45 ++++++++++++++
 llvm/lib/Transforms/Vectorize/VPlan.h         | 34 ++++++++---
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp |  6 +-
 .../Transforms/Vectorize/VPlanTransforms.cpp  | 14 ++++-
 .../LoopVectorize/ARM/mve-reduction-types.ll  |  4 +-
 .../LoopVectorize/ARM/mve-reductions.ll       | 61 ++++++++++---------
 .../LoopVectorize/RISCV/inloop-reduction.ll   | 22 +++++--
 .../LoopVectorize/reduction-inloop-pred.ll    | 34 +++++------
 .../LoopVectorize/reduction-inloop.ll         | 12 ++--
 9 files changed, 158 insertions(+), 74 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index e8c62fac5ee8c2..2bf3b63961791c 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7396,6 +7396,19 @@ static bool planContainsAdditionalSimplifications(VPlan &Plan,
       }
       if (Instruction *UI = GetInstructionForCost(&R))
         SeenInstrs.insert(UI);
+      // VPExtendedReductionRecipe contains a folded extend instruction.
+      if (auto *ExtendedRed = dyn_cast<VPExtendedReductionRecipe>(&R))
+        SeenInstrs.insert(ExtendedRed->getExtInstr());
+      // VPMulAccRecipe contains a mul and optional extend instructions.
+      else if (auto *MulAcc = dyn_cast<VPMulAccRecipe>(&R)) {
+        SeenInstrs.insert(MulAcc->getMulInstr());
+        if (MulAcc->isExtended()) {
+          SeenInstrs.insert(MulAcc->getExt0Instr());
+          SeenInstrs.insert(MulAcc->getExt1Instr());
+          if (auto *Ext = MulAcc->getExtInstr())
+            SeenInstrs.insert(Ext);
+        }
+      }
     }
   }
 
@@ -9409,6 +9422,38 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
               CM.useOrderedReductions(RdxDesc),
               cast<VPWidenRecipe>(VecOp->getDefiningRecipe()));
         }
+      } else if (RdxDesc.getOpcode() == Instruction::Add &&
+                 match(VecOp,
+                       m_ZExtOrSExt(m_Mul(m_ZExtOrSExt(m_VPValue(A)),
+                                          m_ZExtOrSExt(m_VPValue(B)))))) {
+        VPWidenCastRecipe *Ext =
+            dyn_cast<VPWidenCastRecipe>(VecOp->getDefiningRecipe());
+        VPWidenRecipe *Mul =
+            dyn_cast<VPWidenRecipe>(Ext->getOperand(0)->getDefiningRecipe());
+        if (Mul && match(Mul, m_Mul(m_ZExtOrSExt(m_VPValue()),
+                                    m_ZExtOrSExt(m_VPValue())))) {
+          VPWidenRecipe *Mul =
+              cast<VPWidenRecipe>(Ext->getOperand(0)->getDefiningRecipe());
+          VPWidenCastRecipe *Ext0 =
+              cast<VPWidenCastRecipe>(Mul->getOperand(0)->getDefiningRecipe());
+          VPWidenCastRecipe *Ext1 =
+              cast<VPWidenCastRecipe>(Mul->getOperand(1)->getDefiningRecipe());
+          if (Ext->getOpcode() == Ext0->getOpcode() &&
+              Ext0->getOpcode() == Ext1->getOpcode()) {
+            RedRecipe = new VPMulAccRecipe(
+                RdxDesc, CurrentLinkI, PreviousLink, CondOp,
+                CM.useOrderedReductions(RdxDesc),
+                cast<VPWidenCastRecipe>(VecOp->getDefiningRecipe()), Mul,
+                cast<VPWidenCastRecipe>(Ext0), cast<VPWidenCastRecipe>(Ext1));
+          } else
+            RedRecipe = new VPExtendedReductionRecipe(
+                RdxDesc, CurrentLinkI,
+                cast<CastInst>(
+                    cast<VPWidenCastRecipe>(VecOp)->getUnderlyingInstr()),
+                PreviousLink, cast<VPWidenCastRecipe>(VecOp)->getOperand(0),
+                CondOp, CM.useOrderedReductions(RdxDesc),
+                cast<VPWidenCastRecipe>(VecOp)->getResultType());
+        }
       }
       // VPWidenCastRecipes can folded into VPReductionRecipe
       else if (match(VecOp, m_ZExtOrSExt(m_VPValue(A))) &&
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 19d98674db9963..c33e6081669f8b 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -2812,23 +2812,27 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
   bool IsConditional = false;
   /// Type after extend.
   Type *ResultType;
-  /// reduce.add(mul(Ext(), Ext()))
+  /// reduce.add(ext(mul(Ext(), Ext())))
   Instruction::CastOps ExtOp;
 
   Instruction *MulInstr;
+  CastInst *ExtInstr = nullptr;
   CastInst *Ext0Instr;
   CastInst *Ext1Instr;
 
   bool IsExtended;
+  bool IsOuterExtended = false;
 
 protected:
   VPMulAccRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
-                 Instruction *RedI, Instruction *MulInstr,
-                 Instruction::CastOps ExtOp, Instruction *Ext0Instr,
-                 Instruction *Ext1Instr, ArrayRef<VPValue *> Operands,
-                 VPValue *CondOp, bool IsOrdered, Type *ResultType)
+                 Instruction *RedI, Instruction *ExtInstr,
+                 Instruction *MulInstr, Instruction::CastOps ExtOp,
+                 Instruction *Ext0Instr, Instruction *Ext1Instr,
+                 ArrayRef<VPValue *> Operands, VPValue *CondOp, bool IsOrdered,
+                 Type *ResultType)
       : VPSingleDefRecipe(SC, Operands, RedI), RdxDesc(R), IsOrdered(IsOrdered),
         ResultType(ResultType), ExtOp(ExtOp), MulInstr(MulInstr),
+        ExtInstr(cast_if_present<CastInst>(ExtInstr)),
         Ext0Instr(cast<CastInst>(Ext0Instr)),
         Ext1Instr(cast<CastInst>(Ext1Instr)) {
     if (CondOp) {
@@ -2855,9 +2859,9 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
                  VPValue *ChainOp, VPValue *CondOp, bool IsOrdered,
                  VPWidenRecipe *Mul, VPWidenCastRecipe *Ext0,
                  VPWidenCastRecipe *Ext1)
-      : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, Mul->getUnderlyingInstr(),
-                       Ext0->getOpcode(), Ext0->getUnderlyingInstr(),
-                       Ext1->getUnderlyingInstr(),
+      : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, nullptr,
+                       Mul->getUnderlyingInstr(), Ext0->getOpcode(),
+                       Ext0->getUnderlyingInstr(), Ext1->getUnderlyingInstr(),
                        ArrayRef<VPValue *>(
                            {ChainOp, Ext0->getOperand(0), Ext1->getOperand(0)}),
                        CondOp, IsOrdered, Ext0->getResultType()) {}
@@ -2867,9 +2871,20 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
                  VPWidenRecipe *Mul)
       : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, Mul->getUnderlyingInstr(),
                        ArrayRef<VPValue *>(
-                           {ChainOp, Mul->getOperand(0), Mul->getOperand(0)}),
+                           {ChainOp, Mul->getOperand(0), Mul->getOperand(1)}),
                        CondOp, IsOrdered) {}
 
+  VPMulAccRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
+                 VPValue *ChainOp, VPValue *CondOp, bool IsOrdered,
+                 VPWidenCastRecipe *Ext, VPWidenRecipe *Mul,
+                 VPWidenCastRecipe *Ext0, VPWidenCastRecipe *Ext1)
+      : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, Ext->getUnderlyingInstr(),
+                       Mul->getUnderlyingInstr(), Ext0->getOpcode(),
+                       Ext0->getUnderlyingInstr(), Ext1->getUnderlyingInstr(),
+                       ArrayRef<VPValue *>(
+                           {ChainOp, Ext0->getOperand(0), Ext1->getOperand(0)}),
+                       CondOp, IsOrdered, Ext0->getResultType()) {}
+
   ~VPMulAccRecipe() override = default;
 
   VPMulAccRecipe *clone() override { llvm_unreachable("Not implement yet"); }
@@ -2919,6 +2934,7 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
   Type *getResultType() const { return ResultType; };
   Instruction::CastOps getExtOpcode() const { return ExtOp; };
   Instruction *getMulInstr() const { return MulInstr; };
+  CastInst *getExtInstr() const { return ExtInstr; };
   CastInst *getExt0Instr() const { return Ext0Instr; };
   CastInst *getExt1Instr() const { return Ext1Instr; };
   bool isExtended() const { return IsExtended; };
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index a7d74a78e06317..ff1e31033fd9cf 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -2381,8 +2381,8 @@ VPExtendedReductionRecipe::computeCost(ElementCount VF,
 
 InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
                                             VPCostContext &Ctx) const {
-  Type *ElementTy =
-      IsExtended ? getResultType() : Ctx.Types.inferScalarType(getVecOp0());
+  Type *ElementTy = IsExtended ? RdxDesc.getRecurrenceType()
+                               : Ctx.Types.inferScalarType(getVecOp0());
   auto *VectorTy = cast<VectorType>(ToVectorTy(ElementTy, VF));
   TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
   unsigned Opcode = RdxDesc.getOpcode();
@@ -2443,7 +2443,7 @@ InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
         RHSInfo, Operands, MulInstr, &Ctx.TLI);
   }
 
-  // ExtendedReduction Cost
+  // MulAccReduction Cost
   VectorType *SrcVecTy =
       cast<VectorType>(ToVectorTy(Ctx.Types.inferScalarType(getVecOp0()), VF));
   InstructionCost MulAccCost = Ctx.TTI.getMulAccReductionCost(
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index 1f82228d4d787e..c28953e7314f57 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -554,15 +554,27 @@ void VPlanTransforms::prepareExecute(VPlan &Plan) {
           Op0 = MulAcc->getVecOp0();
           Op1 = MulAcc->getVecOp1();
         }
+        VPSingleDefRecipe *VecOp;
         Instruction *MulInstr = MulAcc->getMulInstr();
         SmallVector<VPValue *, 2> MulOps = {Op0, Op1};
         auto *Mul = new VPWidenRecipe(*MulInstr,
                                       make_range(MulOps.begin(), MulOps.end()));
+        if (auto *OuterExtInstr = MulAcc->getExtInstr()) {
+          // dbgs() <<"\n!!!"<< *OuterExtInstr << " " << MulAcc->getExtOpcode()
+          // << "\n";
+          VecOp = new VPWidenCastRecipe(
+              MulAcc->getExtOpcode(), Mul,
+              MulAcc->getRecurrenceDescriptor().getRecurrenceType(),
+              *OuterExtInstr);
+        } else
+          VecOp = Mul;
         auto *Red = new VPReductionRecipe(
             MulAcc->getRecurrenceDescriptor(), MulAcc->getUnderlyingInstr(),
-            MulAcc->getChainOp(), Mul, MulAcc->getCondOp(),
+            MulAcc->getChainOp(), VecOp, MulAcc->getCondOp(),
             MulAcc->isOrdered());
         Mul->insertBefore(MulAcc);
+        if (VecOp != Mul)
+          VecOp->insertBefore(MulAcc);
         Red->insertBefore(MulAcc);
         MulAcc->replaceAllUsesWith(Red);
         MulAcc->eraseFromParent();
diff --git a/llvm/test/Transforms/LoopVectorize/ARM/mve-reduction-types.ll b/llvm/test/Transforms/LoopVectorize/ARM/mve-reduction-types.ll
index 6f78982d7ab029..0067b5ff5243cb 100644
--- a/llvm/test/Transforms/LoopVectorize/ARM/mve-reduction-types.ll
+++ b/llvm/test/Transforms/LoopVectorize/ARM/mve-reduction-types.ll
@@ -24,11 +24,11 @@ define i32 @mla_i32(ptr noalias nocapture readonly %A, ptr noalias nocapture rea
 ; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds i8, ptr [[A:%.*]], i32 [[TMP0]]
 ; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr inbounds i8, ptr [[TMP1]], i32 0
 ; CHECK-NEXT:    [[WIDE_MASKED_LOAD:%.*]] = call <16 x i8> @llvm.masked.load.v16i8.p0(ptr [[TMP2]], i32 1, <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i8> poison)
-; CHECK-NEXT:    [[TMP3:%.*]] = sext <16 x i8> [[WIDE_MASKED_LOAD]] to <16 x i32>
 ; CHECK-NEXT:    [[TMP4:%.*]] = getelementptr inbounds i8, ptr [[B:%.*]], i32 [[TMP0]]
 ; CHECK-NEXT:    [[TMP5:%.*]] = getelementptr inbounds i8, ptr [[TMP4]], i32 0
 ; CHECK-NEXT:    [[WIDE_MASKED_LOAD1:%.*]] = call <16 x i8> @llvm.masked.load.v16i8.p0(ptr [[TMP5]], i32 1, <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i8> poison)
 ; CHECK-NEXT:    [[TMP6:%.*]] = sext <16 x i8> [[WIDE_MASKED_LOAD1]] to <16 x i32>
+; CHECK-NEXT:    [[TMP3:%.*]] = sext <16 x i8> [[WIDE_MASKED_LOAD]] to <16 x i32>
 ; CHECK-NEXT:    [[TMP7:%.*]] = mul nsw <16 x i32> [[TMP6]], [[TMP3]]
 ; CHECK-NEXT:    [[TMP8:%.*]] = select <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i32> [[TMP7]], <16 x i32> zeroinitializer
 ; CHECK-NEXT:    [[TMP9:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP8]])
@@ -107,11 +107,11 @@ define i32 @mla_i8(ptr noalias nocapture readonly %A, ptr noalias nocapture read
 ; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds i8, ptr [[A:%.*]], i32 [[TMP0]]
 ; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr inbounds i8, ptr [[TMP1]], i32 0
 ; CHECK-NEXT:    [[WIDE_MASKED_LOAD:%.*]] = call <16 x i8> @llvm.masked.load.v16i8.p0(ptr [[TMP2]], i32 1, <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i8> poison)
-; CHECK-NEXT:    [[TMP3:%.*]] = sext <16 x i8> [[WIDE_MASKED_LOAD]] to <16 x i32>
 ; CHECK-NEXT:    [[TMP4:%.*]] = getelementptr inbounds i8, ptr [[B:%.*]], i32 [[TMP0]]
 ; CHECK-NEXT:    [[TMP5:%.*]] = getelementptr inbounds i8, ptr [[TMP4]], i32 0
 ; CHECK-NEXT:    [[WIDE_MASKED_LOAD1:%.*]] = call <16 x i8> @llvm.masked.load.v16i8.p0(ptr [[TMP5]], i32 1, <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i8> poison)
 ; CHECK-NEXT:    [[TMP6:%.*]] = sext <16 x i8> [[WIDE_MASKED_LOAD1]] to <16 x i32>
+; CHECK-NEXT:    [[TMP3:%.*]] = sext <16 x i8> [[WIDE_MASKED_LOAD]] to <16 x i32>
 ; CHECK-NEXT:    [[TMP7:%.*]] = mul nsw <16 x i32> [[TMP6]], [[TMP3]]
 ; CHECK-NEXT:    [[TMP8:%.*]] = select <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i32> [[TMP7]], <16 x i32> zeroinitializer
 ; CHECK-NEXT:    [[TMP9:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP8]])
diff --git a/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll b/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
index baad7a84a891a6..ffdd8e66983d85 100644
--- a/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
+++ b/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
@@ -646,11 +646,11 @@ define i64 @mla_i16_i64(ptr nocapture readonly %x, ptr nocapture readonly %y, i3
 ; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[TMP7:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds i16, ptr [[X:%.*]], i32 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <8 x i16>, ptr [[TMP0]], align 2
-; CHECK-NEXT:    [[TMP1:%.*]] = sext <8 x i16> [[WIDE_LOAD]] to <8 x i32>
 ; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr inbounds i16, ptr [[Y:%.*]], i32 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_LOAD1:%.*]] = load <8 x i16>, ptr [[TMP2]], align 2
-; CHECK-NEXT:    [[TMP3:%.*]] = sext <8 x i16> [[WIDE_LOAD1]] to <8 x i32>
-; CHECK-NEXT:    [[TMP4:%.*]] = mul nsw <8 x i32> [[TMP3]], [[TMP1]]
+; CHECK-NEXT:    [[TMP11:%.*]] = sext <8 x i16> [[WIDE_LOAD1]] to <8 x i32>
+; CHECK-NEXT:    [[TMP3:%.*]] = sext <8 x i16> [[WIDE_LOAD]] to <8 x i32>
+; CHECK-NEXT:    [[TMP4:%.*]] = mul nsw <8 x i32> [[TMP11]], [[TMP3]]
 ; CHECK-NEXT:    [[TMP5:%.*]] = sext <8 x i32> [[TMP4]] to <8 x i64>
 ; CHECK-NEXT:    [[TMP6:%.*]] = call i64 @llvm.vector.reduce.add.v8i64(<8 x i64> [[TMP5]])
 ; CHECK-NEXT:    [[TMP7]] = add i64 [[TMP6]], [[VEC_PHI]]
@@ -726,11 +726,11 @@ define i64 @mla_i8_i64(ptr nocapture readonly %x, ptr nocapture readonly %y, i32
 ; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[TMP7:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds i8, ptr [[X:%.*]], i32 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <8 x i8>, ptr [[TMP0]], align 1
-; CHECK-NEXT:    [[TMP1:%.*]] = zext <8 x i8> [[WIDE_LOAD]] to <8 x i32>
 ; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr inbounds i8, ptr [[Y:%.*]], i32 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_LOAD1:%.*]] = load <8 x i8>, ptr [[TMP2]], align 1
-; CHECK-NEXT:    [[TMP3:%.*]] = zext <8 x i8> [[WIDE_LOAD1]] to <8 x i32>
-; CHECK-NEXT:    [[TMP4:%.*]] = mul nuw nsw <8 x i32> [[TMP3]], [[TMP1]]
+; CHECK-NEXT:    [[TMP11:%.*]] = zext <8 x i8> [[WIDE_LOAD1]] to <8 x i32>
+; CHECK-NEXT:    [[TMP3:%.*]] = zext <8 x i8> [[WIDE_LOAD]] to <8 x i32>
+; CHECK-NEXT:    [[TMP4:%.*]] = mul nuw nsw <8 x i32> [[TMP11]], [[TMP3]]
 ; CHECK-NEXT:    [[TMP5:%.*]] = zext nneg <8 x i32> [[TMP4]] to <8 x i64>
 ; CHECK-NEXT:    [[TMP6:%.*]] = call i64 @llvm.vector.reduce.add.v8i64(<8 x i64> [[TMP5]])
 ; CHECK-NEXT:    [[TMP7]] = add i64 [[TMP6]], [[VEC_PHI]]
@@ -802,11 +802,11 @@ define i32 @mla_i32_i32(ptr nocapture readonly %x, ptr nocapture readonly %y, i3
 ; CHECK-NEXT:    [[INDEX:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[TMP5:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[ACTIVE_LANE_MASK:%.*]] = call <4 x i1> @llvm.get.active.lane.mask.v4i1.i32(i32 [[INDEX]], i32 [[N]])
-; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds i32, ptr [[X:%.*]], i32 [[INDEX]]
-; CHECK-NEXT:    [[WIDE_MASKED_LOAD:%.*]] = call <4 x i32> @llvm.masked.load.v4i32.p0(ptr [[TMP0]], i32 4, <4 x i1> [[ACTIVE_LANE_MASK]], <4 x i32> poison)
 ; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds i32, ptr [[Y:%.*]], i32 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_MASKED_LOAD1:%.*]] = call <4 x i32> @llvm.masked.load.v4i32.p0(ptr [[TMP1]], i32 4, <4 x i1> [[ACTIVE_LANE_MASK]], <4 x i32> poison)
-; CHECK-NEXT:    [[TMP2:%.*]] = mul nsw <4 x i32> [[WIDE_MASKED_LOAD1]], [[WIDE_MASKED_LOAD]]
+; CHECK-NEXT:    [[TMP7:%.*]] = getelementptr inbounds i32, ptr [[Y1:%.*]], i32 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_MASKED_LOAD2:%.*]] = call <4 x i32> @llvm.masked.load.v4i32.p0(ptr [[TMP7]], i32 4, <4 x i1> [[ACTIVE_LANE_MASK]], <4 x i32> poison)
+; CHECK-NEXT:    [[TMP2:%.*]] = mul nsw <4 x i32> [[WIDE_MASKED_LOAD2]], [[WIDE_MASKED_LOAD1]]
 ; CHECK-NEXT:    [[TMP3:%.*]] = select <4 x i1> [[ACTIVE_LANE_MASK]], <4 x i32> [[TMP2]], <4 x i32> zeroinitializer
 ; CHECK-NEXT:    [[TMP4:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP3]])
 ; CHECK-NEXT:    [[TMP5]] = add i32 [[TMP4]], [[VEC_PHI]]
@@ -855,10 +855,10 @@ define i32 @mla_i16_i32(ptr nocapture readonly %x, ptr nocapture readonly %y, i3
 ; CHECK-NEXT:    [[ACTIVE_LANE_MASK:%.*]] = call <8 x i1> @llvm.get.active.lane.mask.v8i1.i32(i32 [[INDEX]], i32 [[N]])
 ; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds i16, ptr [[X:%.*]], i32 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_MASKED_LOAD:%.*]] = call <8 x i16> @llvm.masked.load.v8i16.p0(ptr [[TMP0]], i32 2, <8 x i1> [[ACTIVE_LANE_MASK]], <8 x i16> poison)
-; CHECK-NEXT:    [[TMP1:%.*]] = sext <8 x i16> [[WIDE_MASKED_LOAD]] to <8 x i32>
 ; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr inbounds i16, ptr [[Y:%.*]], i32 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_MASKED_LOAD1:%.*]] = call <8 x i16> @llvm.masked.load.v8i16.p0(ptr [[TMP2]], i32 2, <8 x i1> [[ACTIVE_LANE_MASK]], <8 x i16> poison)
 ; CHECK-NEXT:    [[TMP3:%.*]] = sext <8 x i16> [[WIDE_MASKED_LOAD1]] to <8 x i32>
+; CHECK-NEXT:    [[TMP1:%.*]] = sext <8 x i16> [[WIDE_MASKED_LOAD]] to <8 x i32>
 ; CHECK-NEXT:    [[TMP4:%.*]] = mul nsw <8 x i32> [[TMP3]], [[TMP1]]
 ; CHECK-NEXT:    [[TMP5:%.*]] = select <8 x i1> [[ACTIVE_LANE_MASK]], <8 x i32> [[TMP4]], <8 x i32> zeroinitializer
 ; CHECK-NEXT:    [[TMP6:%.*]] = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> [[TMP5]])
@@ -910,10 +910,10 @@ define i32 @mla_i8_i32(ptr nocapture readonly %x, ptr nocapture readonly %y, i32
 ; CHECK-NEXT:    [[ACTIVE_LANE_MASK:%.*]] = call <16 x i1> @llvm.get.active.lane.mask.v16i1.i32(i32 [[INDEX]], i32 [[N]])
 ; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds i8, ptr [[X:%.*]], i32 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_MASKED_LOAD:%.*]] = call <16 x i8> @llvm.masked.load.v16i8.p0(ptr [[TMP0]], i32 1, <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i8> poison)
-; CHECK-NEXT:    [[TMP1:%.*]] = zext <16 x i8> [[WIDE_MASKED_LOAD]] to <16 x i32>
 ; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr inbounds i8, ptr [[Y:%.*]], i32 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_MASKED_LOAD1:%.*]] = call <16 x i8> @llvm.masked.load.v16i8.p0(ptr [[TMP2]], i32 1, <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i8> poison)
 ; CHECK-NEXT:    [[TMP3:%.*]] = zext <16 x i8> [[WIDE_MASKED_LOAD1]] to <16 x i32>
+; CHECK-NEXT:    [[TMP1:%.*]] = zext <16 x i8> [[WIDE_MASKED_LOAD]] to <16 x i32>
 ; CHECK-NEXT:    [[TMP4:%.*]] = mul nuw nsw <16 x i32> [[TMP3]], [[TMP1]]
 ; CHECK-NEXT:    [[TMP5:%.*]] = select <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i32> [[TMP4]], <16 x i32> zeroinitializer
 ; CHECK-NEXT:    [[TMP6:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP5]])
@@ -963,11 +963,11 @@ define signext i16 @mla_i16_i16(ptr nocapture readonly %x, ptr nocapture readonl
 ; CHECK-NEXT:    [[INDEX:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi i16 [ 0, [[VECTOR_PH]] ], [ [[TMP5:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[ACTIVE_LANE_MASK:%.*]] = call <8 x i1> @llvm.get.active.lane.mask.v8i1.i32(i32 [[INDEX]], i32 [[N]])
-; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds i16, ptr [[X:%.*]], i32 [[INDEX]]
-; CHECK-NEXT:    [[WIDE_MASKED_LOAD:%.*]] = call <8 x i16> @llvm.masked.load.v8i16.p0(ptr [[TMP0]], i32 2, <8 x i1> [[ACTIVE_LANE_MASK]], <8 x i16> poison)
 ; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds i16, ptr [[Y:%.*]], i32 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_MASKED_LOAD1:%.*]] = call <8 x i16> @llvm.masked.load.v8i16.p0(ptr [[TMP1]], i32 2, <8 x i1> [[ACTIVE_LANE_MASK]], <8 x i16> poison)
-; CHECK-NEXT:    [[TMP2:%.*]] = mul <8 x i16> [[WIDE_MASKED_LOAD1]], [[WIDE_MASKED_LOAD]]
+; CHECK-NEXT:    [[TMP7:%.*]] = getelementptr inbounds i16, ptr [[Y1:%.*]], i32 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_MASKED_LOAD2:%.*]] = call <8 x i16> @llvm.masked.load.v8i16.p0(ptr [[TMP7]], i32 2, <8 x i1> [[ACTIVE_LANE_MASK]], <8 x i16> poison)
+; CHECK-NEXT:    [[TMP2:%.*]] = mul <8 x i16> [[WIDE_MASKED_LOAD2]], [[WIDE_MASKED_LOAD1]]
 ; CHECK-NEXT:    [[TMP3:%.*]] = select <8 x i1> [[ACTIVE_LANE_MASK]], <8 x i16> [[TMP2]], <8 x i16> zeroinitializer
 ; CHECK-NEXT:    [[TMP4:%.*]] = call i16 @llvm.vector.reduce.add.v8i16(<8 x i16> [[TMP3]])
 ; CHECK-NEXT:    [[TMP5]] = add i16 [[TMP4]], [[VEC_PHI]]
@@ -1016,10 +1016,10 @@ define signext i16 @mla_i8_i16(ptr nocapture readonly %x, ptr nocapture readonly
 ; CHECK-NEXT:    [[ACTIVE_LANE_MASK:%.*]] = call <16 x i1> @llvm.get.active.lane.mask.v16i1.i32(i32 [[INDEX]], i32 [[N]])
 ; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds i8, ptr [[X:%.*]], i32 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_MASKED_LOAD:%.*]] = call <16 x i8> @llvm.masked.load.v16i8.p0(ptr [[TMP0]], i32 1, <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i8> poison)
-; CHECK-NEXT:    [[TMP1:%.*]] = zext <16 x i8> [[WIDE_MASKED_LOAD]] to <16 x i16>
 ; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr inbounds i8, ptr [[Y:%.*]], i32 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_MASKED_LOAD1:%.*]] = call <16 x i8> @llvm.masked.load.v16i8.p0(ptr [[TMP2]], i32 1, <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i8> poison)
 ; CHECK-NEXT:    [[TMP3:%.*]] = zext <16 x i8> [[WIDE_MASKED_LOAD1]] to <16 x i16>
+; CHECK-NEXT:    [[TMP1:%.*]] = zext <16 x i8> [[WIDE_MASKED_LOAD]] to <16 x i16>
 ; CHECK-NEXT:    [[TMP4:%.*]] = mul nuw <16 x i16> [[TMP3]], [[TMP1]]
 ; CHECK-NEXT:    [[TMP5:%.*]] = select <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i16> [[TMP4]], <16 x i16> zeroinitializer
 ; CHECK-NEXT:    [[TMP6:%.*]] = call i16 @llvm.vector.reduce.add.v16i16(<16 x i16> [[TMP5]])
@@ -1069,11 +1069,11 @@ define zeroext i8 @mla_i8_i8(ptr nocapture readonly %x, ptr nocapture readonly %
 ; CHECK-NEXT:    [[INDEX:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi i8 [ 0, [[VECTOR_PH]] ], [ [[TMP5:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[ACTIVE_LANE_MASK:%.*]] = call <16 x i1> @llvm.get.active.lane.mask.v16i1.i32(i32 [[INDEX]], i32 [[N]])
-; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds i8, ptr [[X:%.*]], i32 [[INDEX]]
-; CHECK-NEXT:    [[WIDE_MASKED_LOAD:%.*]] = call <16 x i8> @llvm.masked.load.v16i8.p0(ptr [[TMP0]], i32 1, <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i8> poison)
 ; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds i8, ptr [[Y:%.*]], i32 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_MASKED_LOAD1:%.*]] = call <16 x i8> @llvm.masked.load.v16i8.p0(ptr [[TMP1]], i32 1, <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i8> poison)
-; CHECK-NEXT:    [[TMP2:%.*]] = mul <16 x i8> [[WIDE_MASKED_LOAD1]], [[WIDE_MASKED_LOAD]]
+; CHECK-NEXT:    [[TMP7:%.*]] = getelementptr inbounds i8, ptr [[Y1:%.*]], i32 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_MASKED_LOAD2:%.*]] = call <16 x i8> @llvm.masked.load.v16i8.p0(ptr [[TMP7]], i32 1, <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i8> poison)
+; CHECK-NEXT:    [[TMP2:%.*]] = mul <16 x i8> [[WIDE_MASKED_LOAD2]], [[WIDE_MASKED_LOAD1]]
 ; CHECK-NEXT:    [[TMP3:%.*]] = select <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i8> [[TMP2]], <16 x i8> zeroinitializer
 ; CHECK-NEXT:    [[TMP4:%.*]] = call i8 @llvm.vector.reduce.add.v16i8(<16 x i8> [[TMP3]])
 ; CHECK-NEXT:    [[TMP5]] = add i8 [[TMP4]], [[VEC_PHI]]
@@ -1122,10 +1122,10 @@ define i32 @red_mla_ext_s8_s16_s32(ptr noalias nocapture readonly %A, ptr noalia
 ; CHECK-NEXT:    [[ACTIVE_LANE_MASK:%.*]] = call <8 x i1> @llvm.get.active.lane.mask.v8i1.i32(i32 [[INDEX]], i32 [[N]])
 ; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds i8, ptr [[A:%.*]], i32 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_MASKED_LOAD:%.*]] = call <8 x i8> @llvm.masked.load.v8i8.p0(ptr [[TMP0]], i32 1, <8 x i1> [[ACTIVE_LANE_MASK]], <8 x i8> poison)
-; CHECK-NEXT:    [[TMP1:%.*]] = sext <8 x i8> [[WIDE_MASKED_LOAD]] to <8 x i32>
 ; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr inbounds i16, ptr [[B:%.*]], i32 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_MASKED_LOAD1:%.*]] = call <8 x i16> @llvm.masked.load.v8i16.p0(ptr [[TMP2]], i32 2, <8 x i1> [[ACTIVE_LANE_MASK]], <8 x i16> poison)
 ; CHECK-NEXT:    [[TMP3:%.*]] = sext <8 x i16> [[WIDE_MASKED_LOAD1]] to <8 x i32>
+; CHECK-NEXT:    [[TMP1:%.*]] = sext <8 x i8> [[WIDE_MASKED_LOAD]] to <8 x i32>
 ; CHECK-NEXT:    [[TMP4:%.*]] = mul nsw <8 x i32> [[TMP3]], [[TMP1]]
 ; CHECK-NEXT:    [[TMP5:%.*]] = select <8 x i1> [[ACTIVE_LANE_MASK]], <8 x i32> [[TMP4]], <8 x i32> zeroinitializer
 ; CHECK-NEXT:    [[TMP6:%.*]] = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> [[TMP5]])
@@ -1183,11 +1183,11 @@ define i64 @red_mla_ext_s16_u16_s64(ptr noalias nocapture readonly %A, ptr noali
 ; CHECK:       vector.body:
 ; CHECK-NEXT:    [[INDEX:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[TMP7:%.*]], [[VECTOR_BODY]] ]
-; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds i16, ptr [[A:%.*]], i32 [[INDEX]]
-; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <4 x i16>, ptr [[TMP0]], align 1
-; CHECK-NEXT:    [[TMP1:%.*]] = sext <4 x i16> [[WIDE_LOAD]] to <4 x i32>
 ; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr inbounds i16, ptr [[B:%.*]], i32 [[INDEX]]
-; CHECK-NEXT:    [[WIDE_LOAD1:%.*]] = load <4 x i16>, ptr [[TMP2]], align 2
+; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <4 x i16>, ptr [[TMP2]], align 1
+; CHECK-NEXT:    [[TMP1:%.*]] = sext <4 x i16> [[WIDE_LOAD]] to <4 x i32>
+; CHECK-NEXT:    [[TMP11:%.*]] = getelementptr inbounds i16, ptr [[B1:%.*]], i32 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_LOAD1:%.*]] = load <4 x i16>, ptr [[TMP11]], align 2
 ; CHECK-NEXT:    [[TMP3:%.*]] = zext <4 x i16> [[WIDE_LOAD1]] to <4 x i32>
 ; CHECK-NEXT:    [[TMP4:%.*]] = mul nsw <4 x i32> [[TMP3]], [[TMP1]]
 ; CHECK-NEXT:    [[TMP5:%.*]] = zext <4 x i32> [[TMP4]] to <4 x i64>
@@ -1206,10 +1206,10 @@ define i64 @red_mla_ext_s16_u16_s64(ptr noalias nocapture readonly %A, ptr noali
 ; CHECK:       for.body:
 ; CHECK-NEXT:    [[I_011:%.*]] = phi i32 [ [[INC:%.*]], [[FOR_BODY]] ], [ [[BC_RESUME_VAL]], [[SCALAR_PH]] ]
 ; CHECK-NEXT:    [[S_010:%.*]] = phi i64 [ [[ADD:%.*]], [[FOR_BODY]] ], [ [[BC_MERGE_RDX]], [[SCALAR_PH]] ]
-; CHECK-NEXT:    [[ARRAYIDX:%.*]] = getelementptr inbounds i16, ptr [[A]], i32 [[I_011]]
+; CHECK-NEXT:    [[ARRAYIDX:%.*]] = getelementptr inbounds i16, ptr [[B]], i32 [[I_011]]
 ; CHECK-NEXT:    [[TMP9:%.*]] = load i16, ptr [[ARRAYIDX]], align 1
 ; CHECK-NEXT:    [[CONV:%.*]] = sext i16 [[TMP9]] to i32
-; CHECK-NEXT:    [[ARRAYIDX1:%.*]] = getelementptr inbounds i16, ptr [[B]], i32 [[I_011]]
+; CHECK-NEXT:    [[ARRAYIDX1:%.*]] = getelementptr inbounds i16, ptr [[B1]], i32 [[I_011]]
 ; CHECK-NEXT:    [[TMP10:%.*]] = load i16, ptr [[ARRAYIDX1]], align 2
 ; CHECK-NEXT:    [[CONV2:%.*]] = zext i16 [[TMP10]] to i32
 ; CHECK-NEXT:    [[MUL:%.*]] = mul nsw i32 [[CONV2]], [[CONV]]
@@ -1268,12 +1268,12 @@ define i32 @red_mla_u8_s8_u32(ptr noalias nocapture readonly %A, ptr noalias noc
 ; CHECK-NEXT:    [[INDEX:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[TMP7:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[ACTIVE_LANE_MASK:%.*]] = call <4 x i1> @llvm.get.active.lane.mask.v4i1.i32(i32 [[INDEX]], i32 [[N]])
-; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds i8, ptr [[A:%.*]], i32 [[INDEX]]
-; CHECK-NEXT:    [[WIDE_MASKED_LOAD:%.*]] = call <4 x i8> @llvm.masked.load.v4i8.p0(ptr [[TMP0]], i32 1, <4 x i1> [[ACTIVE_LANE_MASK]], <4 x i8> poison)
-; CHECK-NEXT:    [[TMP1:%.*]] = zext <4 x i8> [[WIDE_MASKED_LOAD]] to <4 x i32>
 ; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr inbounds i8, ptr [[B:%.*]], i32 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_MASKED_LOAD1:%.*]] = call <4 x i8> @llvm.masked.load.v4i8.p0(ptr [[TMP2]], i32 1, <4 x i1> [[ACTIVE_LANE_MASK]], <4 x i8> poison)
-; CHECK-NEXT:    [[TMP3:%.*]] = sext <4 x i8> [[WIDE_MASKED_LOAD1]] to <4 x i32>
+; CHECK-NEXT:    [[TMP1:%.*]] = zext <4 x i8> [[WIDE_MASKED_LOAD1]] to <4 x i32>
+; CHECK-NEXT:    [[TMP9:%.*]] = getelementptr inbounds i8, ptr [[B1:%.*]], i32 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_MASKED_LOAD2:%.*]] = call <4 x i8> @llvm.masked.load.v4i8.p0(ptr [[TMP9]], i32 1, <4 x i1> [[ACTIVE_LANE_MASK]], <4 x i8> poison)
+; CHECK-NEXT:    [[TMP3:%.*]] = sext <4 x i8> [[WIDE_MASKED_LOAD2]] to <4 x i32>
 ; CHECK-NEXT:    [[TMP4:%.*]] = mul nsw <4 x i32> [[TMP3]], [[TMP1]]
 ; CHECK-NEXT:    [[TMP5:%.*]] = select <4 x i1> [[ACTIVE_LANE_MASK]], <4 x i32> [[TMP4]], <4 x i32> zeroinitializer
 ; CHECK-NEXT:    [[TMP6:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP5]])
@@ -1413,7 +1413,8 @@ define i32 @mla_i8_i32_multiuse(ptr nocapture readonly %x, ptr nocapture readonl
 ; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds i8, ptr [[X:%.*]], i32 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_MASKED_LOAD:%.*]] = call <16 x i8> @llvm.masked.load.v16i8.p0(ptr [[TMP0]], i32 1, <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i8> poison)
 ; CHECK-NEXT:    [[TMP1:%.*]] = zext <16 x i8> [[WIDE_MASKED_LOAD]] to <16 x i32>
-; CHECK-NEXT:    [[TMP2:%.*]] = mul nuw nsw <16 x i32> [[TMP1]], [[TMP1]]
+; CHECK-NEXT:    [[TMP7:%.*]] = zext <16 x i8> [[WIDE_MASKED_LOAD]] to <16 x i32>
+; CHECK-NEXT:    [[TMP2:%.*]] = mul nuw nsw <16 x i32> [[TMP1]], [[TMP7]]
 ; CHECK-NEXT:    [[TMP3:%.*]] = select <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i32> [[TMP2]], <16 x i32> zeroinitializer
 ; CHECK-NEXT:    [[TMP4:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP3]])
 ; CHECK-NEXT:    [[TMP5]] = add i32 [[TMP4]], [[VEC_PHI]]
diff --git a/llvm/test/Transforms/LoopVectorize/RISCV/inloop-reduction.ll b/llvm/test/Transforms/LoopVectorize/RISCV/inloop-reduction.ll
index c09b2de500d8cb..c6438857a7571e 100644
--- a/llvm/test/Transforms/LoopVectorize/RISCV/inloop-reduction.ll
+++ b/llvm/test/Transforms/LoopVectorize/RISCV/inloop-reduction.ll
@@ -187,8 +187,11 @@ define i32 @add_i16_i32(ptr nocapture readonly %x, i32 %n) {
 ; IF-EVL-INLOOP-NEXT:    [[N_RND_UP:%.*]] = add i32 [[N]], [[TMP2]]
 ; IF-EVL-INLOOP-NEXT:    [[N_MOD_VF:%.*]] = urem i32 [[N_RND_UP]], [[TMP1]]
 ; IF-EVL-INLOOP-NEXT:    [[N_VEC:%.*]] = sub i32 [[N_RND_UP]], [[N_MOD_VF]]
+; IF-EVL-INLOOP-NEXT:    [[TRIP_COUNT_MINUS_1:%.*]] = sub i32 [[N]], 1
 ; IF-EVL-INLOOP-NEXT:    [[TMP3:%.*]] = call i32 @llvm.vscale.i32()
 ; IF-EVL-INLOOP-NEXT:    [[TMP4:%.*]] = mul i32 [[TMP3]], 8
+; IF-EVL-INLOOP-NEXT:    [[BROADCAST_SPLATINSERT1:%.*]] = insertelement <vscale x 8 x i32> poison, i32 [[TRIP_COUNT_MINUS_1]], i64 0
+; IF-EVL-INLOOP-NEXT:    [[BROADCAST_SPLAT2:%.*]] = shufflevector <vscale x 8 x i32> [[BROADCAST_SPLATINSERT1]], <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer
 ; IF-EVL-INLOOP-NEXT:    br label [[VECTOR_BODY:%.*]]
 ; IF-EVL-INLOOP:       vector.body:
 ; IF-EVL-INLOOP-NEXT:    [[INDEX:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
@@ -197,12 +200,19 @@ define i32 @add_i16_i32(ptr nocapture readonly %x, i32 %n) {
 ; IF-EVL-INLOOP-NEXT:    [[AVL:%.*]] = sub i32 [[N]], [[EVL_BASED_IV]]
 ; IF-EVL-INLOOP-NEXT:    [[TMP5:%.*]] = call i32 @llvm.experimental.get.vector.length.i32(i32 [[AVL]], i32 8, i1 true)
 ; IF-EVL-INLOOP-NEXT:    [[TMP6:%.*]] = add i32 [[EVL_BASED_IV]], 0
-; IF-EVL-INLOOP-NEXT:    [[TMP7:%.*]] = getelementptr inbounds i16, ptr [[X:%.*]], i32 [[TMP6]]
-; IF-EVL-INLOOP-NEXT:    [[TMP8:%.*]] = getelementptr inbounds i16, ptr [[TMP7]], i32 0
-; IF-EVL-INLOOP-NEXT:    [[VP_OP_LOAD:%.*]] = call <vscale x 8 x i16> @llvm.vp.load.nxv8i16.p0(ptr align 2 [[TMP8]], <vscale x 8 x i1> splat (i1 true), i32 [[TMP5]])
-; IF-EVL-INLOOP-NEXT:    [[TMP9:%.*]] = sext <vscale x 8 x i16> [[VP_OP_LOAD]] to <vscale x 8 x i32>
-; IF-EVL-INLOOP-NEXT:    [[TMP10:%.*]] = call i32 @llvm.vp.reduce.add.nxv8i32(i32 0, <vscale x 8 x i32> [[TMP9]], <vscale x 8 x i1> splat (i1 true), i32 [[TMP5]])
-; IF-EVL-INLOOP-NEXT:    [[TMP11]] = add i32 [[TMP10]], [[VEC_PHI]]
+; IF-EVL-INLOOP-NEXT:    [[BROADCAST_SPLATINSERT:%.*]] = insertelement <vscale x 8 x i32> poison, i32 [[EVL_BASED_IV]], i64 0
+; IF-EVL-INLOOP-NEXT:    [[BROADCAST_SPLAT:%.*]] = shufflevector <vscale x 8 x i32> [[BROADCAST_SPLATINSERT]], <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer
+; IF-EVL-INLOOP-NEXT:    [[TMP7:%.*]] = call <vscale x 8 x i32> @llvm.stepvector.nxv8i32()
+; IF-EVL-INLOOP-NEXT:    [[TMP8:%.*]] = add <vscale x 8 x i32> zeroinitializer, [[TMP7]]
+; IF-EVL-INLOOP-NEXT:    [[VEC_IV:%.*]] = add <vscale x 8 x i32> [[BROADCAST_SPLAT]], [[TMP8]]
+; IF-EVL-INLOOP-NEXT:    [[TMP9:%.*]] = icmp ule <vscale x 8 x i32> [[VEC_IV]], [[BROADCAST_SPLAT2]]
+; IF-EVL-INLOOP-NEXT:    [[TMP10:%.*]] = getelementptr inbounds i16, ptr [[X:%.*]], i32 [[TMP6]]
+; IF-EVL-INLOOP-NEXT:    [[TMP15:%.*]] = getelementptr inbounds i16, ptr [[TMP10]], i32 0
+; IF-EVL-INLOOP-NEXT:    [[VP_OP_LOAD:%.*]] = call <vscale x 8 x i16> @llvm.vp.load.nxv8i16.p0(ptr align 2 [[TMP15]], <vscale x 8 x i1> splat (i1 true), i32 [[TMP5]])
+; IF-EVL-INLOOP-NEXT:    [[TMP16:%.*]] = sext <vscale x 8 x i16> [[VP_OP_LOAD]] to <vscale x 8 x i32>
+; IF-EVL-INLOOP-NEXT:    [[TMP17:%.*]] = select <vscale x 8 x i1> [[TMP9]], <vscale x 8 x i32> [[TMP16]], <vscale x 8 x i32> zeroinitializer
+; IF-EVL-INLOOP-NEXT:    [[TMP14:%.*]] = call i32 @llvm.vector.reduce.add.nxv8i32(<vscale x 8 x i32> [[TMP17]])
+; IF-EVL-INLOOP-NEXT:    [[TMP11]] = add i32 [[TMP14]], [[VEC_PHI]]
 ; IF-EVL-INLOOP-NEXT:    [[INDEX_EVL_NEXT]] = add nuw i32 [[TMP5]], [[EVL_BASED_IV]]
 ; IF-EVL-INLOOP-NEXT:    [[INDEX_NEXT]] = add nuw i32 [[INDEX]], [[TMP4]]
 ; IF-EVL-INLOOP-NEXT:    [[TMP12:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]
diff --git a/llvm/test/Transforms/LoopVectorize/reduction-inloop-pred.ll b/llvm/test/Transforms/LoopVectorize/reduction-inloop-pred.ll
index 8e132ed8399cd6..5b171adec78f6b 100644
--- a/llvm/test/Transforms/LoopVectorize/reduction-inloop-pred.ll
+++ b/llvm/test/Transforms/LoopVectorize/reduction-inloop-pred.ll
@@ -424,62 +424,62 @@ define i32 @reduction_mix(ptr noalias nocapture %A, ptr noalias nocapture %B) {
 ; CHECK-NEXT:    [[TMP1:%.*]] = extractelement <4 x i1> [[TMP0]], i64 0
 ; CHECK-NEXT:    br i1 [[TMP1]], label [[PRED_LOAD_IF:%.*]], label [[PRED_LOAD_CONTINUE:%.*]]
 ; CHECK:       pred.load.if:
-; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr inbounds i32, ptr [[A:%.*]], i64 [[INDEX]]
-; CHECK-NEXT:    [[TMP3:%.*]] = load i32, ptr [[TMP2]], align 4
-; CHECK-NEXT:    [[TMP4:%.*]] = insertelement <4 x i32> poison, i32 [[TMP3]], i64 0
 ; CHECK-NEXT:    [[TMP5:%.*]] = getelementptr inbounds i32, ptr [[B:%.*]], i64 [[INDEX]]
 ; CHECK-NEXT:    [[TMP6:%.*]] = load i32, ptr [[TMP5]], align 4
 ; CHECK-NEXT:    [[TMP7:%.*]] = insertelement <4 x i32> poison, i32 [[TMP6]], i64 0
+; CHECK-NEXT:    [[TMP8:%.*]] = getelementptr inbounds i32, ptr [[B1:%.*]], i64 [[INDEX]]
+; CHECK-NEXT:    [[TMP12:%.*]] = load i32, ptr [[TMP8]], align 4
+; CHECK-NEXT:    [[TMP13:%.*]] = insertelement <4 x i32> poison, i32 [[TMP12]], i64 0
 ; CHECK-NEXT:    br label [[PRED_LOAD_CONTINUE]]
 ; CHECK:       pred.load.continue:
-; CHECK-NEXT:    [[TMP8:%.*]] = phi <4 x i32> [ poison, [[VECTOR_BODY]] ], [ [[TMP4]], [[PRED_LOAD_IF]] ]
 ; CHECK-NEXT:    [[TMP9:%.*]] = phi <4 x i32> [ poison, [[VECTOR_BODY]] ], [ [[TMP7]], [[PRED_LOAD_IF]] ]
+; CHECK-NEXT:    [[TMP14:%.*]] = phi <4 x i32> [ poison, [[VECTOR_BODY]] ], [ [[TMP13]], [[PRED_LOAD_IF]] ]
 ; CHECK-NEXT:    [[TMP10:%.*]] = extractelement <4 x i1> [[TMP0]], i64 1
 ; CHECK-NEXT:    br i1 [[TMP10]], label [[PRED_LOAD_IF3:%.*]], label [[PRED_LOAD_CONTINUE4:%.*]]
 ; CHECK:       pred.load.if3:
 ; CHECK-NEXT:    [[TMP11:%.*]] = or disjoint i64 [[INDEX]], 1
-; CHECK-NEXT:    [[TMP12:%.*]] = getelementptr inbounds i32, ptr [[A]], i64 [[TMP11]]
-; CHECK-NEXT:    [[TMP13:%.*]] = load i32, ptr [[TMP12]], align 4
-; CHECK-NEXT:    [[TMP14:%.*]] = insertelement <4 x i32> [[TMP8]], i32 [[TMP13]], i64 1
 ; CHECK-NEXT:    [[TMP15:%.*]] = getelementptr inbounds i32, ptr [[B]], i64 [[TMP11]]
 ; CHECK-NEXT:    [[TMP16:%.*]] = load i32, ptr [[TMP15]], align 4
 ; CHECK-NEXT:    [[TMP17:%.*]] = insertelement <4 x i32> [[TMP9]], i32 [[TMP16]], i64 1
+; CHECK-NEXT:    [[TMP18:%.*]] = getelementptr inbounds i32, ptr [[B1]], i64 [[TMP11]]
+; CHECK-NEXT:    [[TMP22:%.*]] = load i32, ptr [[TMP18]], align 4
+; CHECK-NEXT:    [[TMP23:%.*]] = insertelement <4 x i32> [[TMP14]], i32 [[TMP22]], i64 1
 ; CHECK-NEXT:    br label [[PRED_LOAD_CONTINUE4]]
 ; CHECK:       pred.load.continue4:
-; CHECK-NEXT:    [[TMP18:%.*]] = phi <4 x i32> [ [[TMP8]], [[PRED_LOAD_CONTINUE]] ], [ [[TMP14]], [[PRED_LOAD_IF3]] ]
 ; CHECK-NEXT:    [[TMP19:%.*]] = phi <4 x i32> [ [[TMP9]], [[PRED_LOAD_CONTINUE]] ], [ [[TMP17]], [[PRED_LOAD_IF3]] ]
+; CHECK-NEXT:    [[TMP24:%.*]] = phi <4 x i32> [ [[TMP14]], [[PRED_LOAD_CONTINUE]] ], [ [[TMP23]], [[PRED_LOAD_IF3]] ]
 ; CHECK-NEXT:    [[TMP20:%.*]] = extractelement <4 x i1> [[TMP0]], i64 2
 ; CHECK-NEXT:    br i1 [[TMP20]], label [[PRED_LOAD_IF5:%.*]], label [[PRED_LOAD_CONTINUE6:%.*]]
 ; CHECK:       pred.load.if5:
 ; CHECK-NEXT:    [[TMP21:%.*]] = or disjoint i64 [[INDEX]], 2
-; CHECK-NEXT:    [[TMP22:%.*]] = getelementptr inbounds i32, ptr [[A]], i64 [[TMP21]]
-; CHECK-NEXT:    [[TMP23:%.*]] = load i32, ptr [[TMP22]], align 4
-; CHECK-NEXT:    [[TMP24:%.*]] = insertelement <4 x i32> [[TMP18]], i32 [[TMP23]], i64 2
 ; CHECK-NEXT:    [[TMP25:%.*]] = getelementptr inbounds i32, ptr [[B]], i64 [[TMP21]]
 ; CHECK-NEXT:    [[TMP26:%.*]] = load i32, ptr [[TMP25]], align 4
 ; CHECK-NEXT:    [[TMP27:%.*]] = insertelement <4 x i32> [[TMP19]], i32 [[TMP26]], i64 2
+; CHECK-NEXT:    [[TMP28:%.*]] = getelementptr inbounds i32, ptr [[B1]], i64 [[TMP21]]
+; CHECK-NEXT:    [[TMP32:%.*]] = load i32, ptr [[TMP28]], align 4
+; CHECK-NEXT:    [[TMP33:%.*]] = insertelement <4 x i32> [[TMP24]], i32 [[TMP32]], i64 2
 ; CHECK-NEXT:    br label [[PRED_LOAD_CONTINUE6]]
 ; CHECK:       pred.load.continue6:
-; CHECK-NEXT:    [[TMP28:%.*]] = phi <4 x i32> [ [[TMP18]], [[PRED_LOAD_CONTINUE4]] ], [ [[TMP24]], [[PRED_LOAD_IF5]] ]
 ; CHECK-NEXT:    [[TMP29:%.*]] = phi <4 x i32> [ [[TMP19]], [[PRED_LOAD_CONTINUE4]] ], [ [[TMP27]], [[PRED_LOAD_IF5]] ]
+; CHECK-NEXT:    [[TMP34:%.*]] = phi <4 x i32> [ [[TMP24]], [[PRED_LOAD_CONTINUE4]] ], [ [[TMP33]], [[PRED_LOAD_IF5]] ]
 ; CHECK-NEXT:    [[TMP30:%.*]] = extractelement <4 x i1> [[TMP0]], i64 3
 ; CHECK-NEXT:    br i1 [[TMP30]], label [[PRED_LOAD_IF7:%.*]], label [[PRED_LOAD_CONTINUE8]]
 ; CHECK:       pred.load.if7:
 ; CHECK-NEXT:    [[TMP31:%.*]] = or disjoint i64 [[INDEX]], 3
-; CHECK-NEXT:    [[TMP32:%.*]] = getelementptr inbounds i32, ptr [[A]], i64 [[TMP31]]
-; CHECK-NEXT:    [[TMP33:%.*]] = load i32, ptr [[TMP32]], align 4
-; CHECK-NEXT:    [[TMP34:%.*]] = insertelement <4 x i32> [[TMP28]], i32 [[TMP33]], i64 3
 ; CHECK-NEXT:    [[TMP35:%.*]] = getelementptr inbounds i32, ptr [[B]], i64 [[TMP31]]
 ; CHECK-NEXT:    [[TMP36:%.*]] = load i32, ptr [[TMP35]], align 4
 ; CHECK-NEXT:    [[TMP37:%.*]] = insertelement <4 x i32> [[TMP29]], i32 [[TMP36]], i64 3
+; CHECK-NEXT:    [[TMP38:%.*]] = getelementptr inbounds i32, ptr [[B1]], i64 [[TMP31]]
+; CHECK-NEXT:    [[TMP48:%.*]] = load i32, ptr [[TMP38]], align 4
+; CHECK-NEXT:    [[TMP49:%.*]] = insertelement <4 x i32> [[TMP34]], i32 [[TMP48]], i64 3
 ; CHECK-NEXT:    br label [[PRED_LOAD_CONTINUE8]]
 ; CHECK:       pred.load.continue8:
-; CHECK-NEXT:    [[TMP38:%.*]] = phi <4 x i32> [ [[TMP28]], [[PRED_LOAD_CONTINUE6]] ], [ [[TMP34]], [[PRED_LOAD_IF7]] ]
 ; CHECK-NEXT:    [[TMP39:%.*]] = phi <4 x i32> [ [[TMP29]], [[PRED_LOAD_CONTINUE6]] ], [ [[TMP37]], [[PRED_LOAD_IF7]] ]
-; CHECK-NEXT:    [[TMP40:%.*]] = mul nsw <4 x i32> [[TMP39]], [[TMP38]]
+; CHECK-NEXT:    [[TMP50:%.*]] = phi <4 x i32> [ [[TMP34]], [[PRED_LOAD_CONTINUE6]] ], [ [[TMP49]], [[PRED_LOAD_IF7]] ]
 ; CHECK-NEXT:    [[TMP41:%.*]] = select <4 x i1> [[TMP0]], <4 x i32> [[VEC_IND1]], <4 x i32> zeroinitializer
 ; CHECK-NEXT:    [[TMP42:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP41]])
 ; CHECK-NEXT:    [[TMP43:%.*]] = add i32 [[TMP42]], [[VEC_PHI]]
+; CHECK-NEXT:    [[TMP40:%.*]] = mul nsw <4 x i32> [[TMP50]], [[TMP39]]
 ; CHECK-NEXT:    [[TMP44:%.*]] = select <4 x i1> [[TMP0]], <4 x i32> [[TMP40]], <4 x i32> zeroinitializer
 ; CHECK-NEXT:    [[TMP45:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP44]])
 ; CHECK-NEXT:    [[TMP46]] = add i32 [[TMP45]], [[TMP43]]
diff --git a/llvm/test/Transforms/LoopVectorize/reduction-inloop.ll b/llvm/test/Transforms/LoopVectorize/reduction-inloop.ll
index fe74a7c3a9b27c..b578e61d85dfa1 100644
--- a/llvm/test/Transforms/LoopVectorize/reduction-inloop.ll
+++ b/llvm/test/Transforms/LoopVectorize/reduction-inloop.ll
@@ -221,13 +221,13 @@ define i32 @reduction_mix(ptr noalias nocapture %A, ptr noalias nocapture %B) {
 ; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[TMP6:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[VEC_IND:%.*]] = phi <4 x i32> [ <i32 0, i32 1, i32 2, i32 3>, [[VECTOR_PH]] ], [ [[VEC_IND_NEXT:%.*]], [[VECTOR_BODY]] ]
-; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds i32, ptr [[A:%.*]], i64 [[INDEX]]
-; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <4 x i32>, ptr [[TMP0]], align 4
 ; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds i32, ptr [[B:%.*]], i64 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_LOAD1:%.*]] = load <4 x i32>, ptr [[TMP1]], align 4
-; CHECK-NEXT:    [[TMP2:%.*]] = mul nsw <4 x i32> [[WIDE_LOAD1]], [[WIDE_LOAD]]
+; CHECK-NEXT:    [[TMP8:%.*]] = getelementptr inbounds i32, ptr [[B1:%.*]], i64 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_LOAD2:%.*]] = load <4 x i32>, ptr [[TMP8]], align 4
 ; CHECK-NEXT:    [[TMP3:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[VEC_IND]])
 ; CHECK-NEXT:    [[TMP4:%.*]] = add i32 [[TMP3]], [[VEC_PHI]]
+; CHECK-NEXT:    [[TMP2:%.*]] = mul nsw <4 x i32> [[WIDE_LOAD2]], [[WIDE_LOAD1]]
 ; CHECK-NEXT:    [[TMP5:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP2]])
 ; CHECK-NEXT:    [[TMP6]] = add i32 [[TMP5]], [[TMP4]]
 ; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 4
@@ -329,11 +329,11 @@ define i32 @start_at_non_zero(ptr nocapture %in, ptr nocapture %coeff, ptr nocap
 ; CHECK:       vector.body:
 ; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi i32 [ 120, [[VECTOR_PH]] ], [ [[TMP4:%.*]], [[VECTOR_BODY]] ]
-; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds i32, ptr [[IN:%.*]], i64 [[INDEX]]
-; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <4 x i32>, ptr [[TMP0]], align 4
 ; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds i32, ptr [[COEFF:%.*]], i64 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_LOAD1:%.*]] = load <4 x i32>, ptr [[TMP1]], align 4
-; CHECK-NEXT:    [[TMP2:%.*]] = mul nsw <4 x i32> [[WIDE_LOAD1]], [[WIDE_LOAD]]
+; CHECK-NEXT:    [[TMP6:%.*]] = getelementptr inbounds i32, ptr [[COEFF1:%.*]], i64 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_LOAD2:%.*]] = load <4 x i32>, ptr [[TMP6]], align 4
+; CHECK-NEXT:    [[TMP2:%.*]] = mul nsw <4 x i32> [[WIDE_LOAD2]], [[WIDE_LOAD1]]
 ; CHECK-NEXT:    [[TMP3:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP2]])
 ; CHECK-NEXT:    [[TMP4]] = add i32 [[TMP3]], [[VEC_PHI]]
 ; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 4

>From e5b50f7883cd07c5eccf4a7025d1926da0c6b4db Mon Sep 17 00:00:00 2001
From: Elvis Wang <elvis.wang at sifive.com>
Date: Tue, 5 Nov 2024 17:49:26 -0800
Subject: [PATCH 05/17] Refactors

Use lambda functions to return early when a pattern is matched.
Leave some assertions.
---
 .../Transforms/Vectorize/LoopVectorize.cpp    | 121 ++++++++---------
 llvm/lib/Transforms/Vectorize/VPlan.h         |  55 ++++----
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp | 123 +-----------------
 .../Transforms/Vectorize/VPlanTransforms.cpp  |  19 +--
 .../LoopVectorize/ARM/mve-reductions.ll       |   3 +-
 .../LoopVectorize/RISCV/inloop-reduction.ll   |   8 +-
 6 files changed, 113 insertions(+), 216 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 2bf3b63961791c..f7bca1383cdad1 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7399,7 +7399,7 @@ static bool planContainsAdditionalSimplifications(VPlan &Plan,
       // VPExtendedReductionRecipe contains a folded extend instruction.
       if (auto *ExtendedRed = dyn_cast<VPExtendedReductionRecipe>(&R))
         SeenInstrs.insert(ExtendedRed->getExtInstr());
-      // VPMulAccRecupe constians a mul and otional extend instructions.
+      // VPMulAccRecipe constians a mul and otional extend instructions.
       else if (auto *MulAcc = dyn_cast<VPMulAccRecipe>(&R)) {
         SeenInstrs.insert(MulAcc->getMulInstr());
         if (MulAcc->isExtended()) {
@@ -9398,77 +9398,82 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
       if (CM.blockNeedsPredicationForAnyReason(BB))
         CondOp = RecipeBuilder.getBlockInMask(BB);
 
-      VPValue *A, *B;
-      VPSingleDefRecipe *RedRecipe;
-      // reduce.add(mul(ext, ext)) can folded into VPMulAccRecipe
-      if (RdxDesc.getOpcode() == Instruction::Add &&
-          match(VecOp, m_Mul(m_VPValue(A), m_VPValue(B)))) {
-        VPRecipeBase *RecipeA = A->getDefiningRecipe();
-        VPRecipeBase *RecipeB = B->getDefiningRecipe();
-        if (RecipeA && RecipeB && match(RecipeA, m_ZExtOrSExt(m_VPValue())) &&
-            match(RecipeB, m_ZExtOrSExt(m_VPValue())) &&
-            cast<VPWidenCastRecipe>(RecipeA)->getOpcode() ==
-                cast<VPWidenCastRecipe>(RecipeB)->getOpcode() &&
-            !A->hasMoreThanOneUniqueUser() && !B->hasMoreThanOneUniqueUser()) {
-          RedRecipe = new VPMulAccRecipe(
-              RdxDesc, CurrentLinkI, PreviousLink, CondOp,
-              CM.useOrderedReductions(RdxDesc),
-              cast<VPWidenRecipe>(VecOp->getDefiningRecipe()),
-              cast<VPWidenCastRecipe>(RecipeA),
-              cast<VPWidenCastRecipe>(RecipeB));
-        } else {
-          RedRecipe = new VPMulAccRecipe(
-              RdxDesc, CurrentLinkI, PreviousLink, CondOp,
-              CM.useOrderedReductions(RdxDesc),
-              cast<VPWidenRecipe>(VecOp->getDefiningRecipe()));
-        }
-      } else if (RdxDesc.getOpcode() == Instruction::Add &&
-                 match(VecOp,
-                       m_ZExtOrSExt(m_Mul(m_ZExtOrSExt(m_VPValue(A)),
-                                          m_ZExtOrSExt(m_VPValue(B)))))) {
-        VPWidenCastRecipe *Ext =
-            dyn_cast<VPWidenCastRecipe>(VecOp->getDefiningRecipe());
-        VPWidenRecipe *Mul =
-            dyn_cast<VPWidenRecipe>(Ext->getOperand(0)->getDefiningRecipe());
-        if (Mul && match(Mul, m_Mul(m_ZExtOrSExt(m_VPValue()),
-                                    m_ZExtOrSExt(m_VPValue())))) {
+      auto TryToMatchMulAcc = [&]() -> VPSingleDefRecipe * {
+        VPValue *A, *B;
+        if (RdxDesc.getOpcode() != Instruction::Add)
+          return nullptr;
+        // reduce.add(mul(ext, ext)) can folded into VPMulAccRecipe
+        if (match(VecOp, m_Mul(m_VPValue(A), m_VPValue(B))) &&
+            !VecOp->hasMoreThanOneUniqueUser()) {
+          VPRecipeBase *RecipeA = A->getDefiningRecipe();
+          VPRecipeBase *RecipeB = B->getDefiningRecipe();
+          if (RecipeA && RecipeB && match(RecipeA, m_ZExtOrSExt(m_VPValue())) &&
+              match(RecipeB, m_ZExtOrSExt(m_VPValue())) &&
+              cast<VPWidenCastRecipe>(RecipeA)->getOpcode() ==
+                  cast<VPWidenCastRecipe>(RecipeB)->getOpcode() &&
+              !A->hasMoreThanOneUniqueUser() &&
+              !B->hasMoreThanOneUniqueUser()) {
+            return new VPMulAccRecipe(
+                RdxDesc, CurrentLinkI, PreviousLink, CondOp,
+                CM.useOrderedReductions(RdxDesc),
+                cast<VPWidenRecipe>(VecOp->getDefiningRecipe()),
+                cast<VPWidenCastRecipe>(RecipeA),
+                cast<VPWidenCastRecipe>(RecipeB));
+          } else {
+            // Matched reduce.add(mul(...))
+            return new VPMulAccRecipe(
+                RdxDesc, CurrentLinkI, PreviousLink, CondOp,
+                CM.useOrderedReductions(RdxDesc),
+                cast<VPWidenRecipe>(VecOp->getDefiningRecipe()));
+          }
+          // Matched reduce.add(ext(mul(ext, ext)))
+          // Note that 3 extend instructions must have same opcode.
+        } else if (match(VecOp,
+                         m_ZExtOrSExt(m_Mul(m_ZExtOrSExt(m_VPValue()),
+                                            m_ZExtOrSExt(m_VPValue())))) &&
+                   !VecOp->hasMoreThanOneUniqueUser()) {
+          VPWidenCastRecipe *Ext =
+              dyn_cast<VPWidenCastRecipe>(VecOp->getDefiningRecipe());
           VPWidenRecipe *Mul =
-              cast<VPWidenRecipe>(Ext->getOperand(0)->getDefiningRecipe());
+              dyn_cast<VPWidenRecipe>(Ext->getOperand(0)->getDefiningRecipe());
           VPWidenCastRecipe *Ext0 =
               cast<VPWidenCastRecipe>(Mul->getOperand(0)->getDefiningRecipe());
           VPWidenCastRecipe *Ext1 =
               cast<VPWidenCastRecipe>(Mul->getOperand(1)->getDefiningRecipe());
           if (Ext->getOpcode() == Ext0->getOpcode() &&
-              Ext0->getOpcode() == Ext1->getOpcode()) {
-            RedRecipe = new VPMulAccRecipe(
+              Ext0->getOpcode() == Ext1->getOpcode() &&
+              !Mul->hasMoreThanOneUniqueUser() &&
+              !Ext0->hasMoreThanOneUniqueUser() &&
+              !Ext1->hasMoreThanOneUniqueUser()) {
+            return new VPMulAccRecipe(
                 RdxDesc, CurrentLinkI, PreviousLink, CondOp,
                 CM.useOrderedReductions(RdxDesc),
                 cast<VPWidenCastRecipe>(VecOp->getDefiningRecipe()), Mul,
                 cast<VPWidenCastRecipe>(Ext0), cast<VPWidenCastRecipe>(Ext1));
-          } else
-            RedRecipe = new VPExtendedReductionRecipe(
-                RdxDesc, CurrentLinkI,
-                cast<CastInst>(
-                    cast<VPWidenCastRecipe>(VecOp)->getUnderlyingInstr()),
-                PreviousLink, cast<VPWidenCastRecipe>(VecOp)->getOperand(0),
-                CondOp, CM.useOrderedReductions(RdxDesc),
-                cast<VPWidenCastRecipe>(VecOp)->getResultType());
+          }
         }
-      }
-      // VPWidenCastRecipes can folded into VPReductionRecipe
-      else if (match(VecOp, m_ZExtOrSExt(m_VPValue(A))) &&
-               !VecOp->hasMoreThanOneUniqueUser()) {
-        RedRecipe = new VPExtendedReductionRecipe(
-            RdxDesc, CurrentLinkI,
-            cast<CastInst>(
-                cast<VPWidenCastRecipe>(VecOp)->getUnderlyingInstr()),
-            PreviousLink, A, CondOp, CM.useOrderedReductions(RdxDesc),
-            cast<VPWidenCastRecipe>(VecOp)->getResultType());
-      } else {
+        return nullptr;
+      };
+      auto TryToMatchExtendedReduction = [&]() -> VPSingleDefRecipe * {
+        VPValue *A;
+        if (match(VecOp, m_ZExtOrSExt(m_VPValue(A))) &&
+            !VecOp->hasMoreThanOneUniqueUser()) {
+          return new VPExtendedReductionRecipe(
+              RdxDesc, CurrentLinkI, PreviousLink,
+              cast<VPWidenCastRecipe>(VecOp), CondOp,
+              CM.useOrderedReductions(RdxDesc));
+        }
+        return nullptr;
+      };
+      VPSingleDefRecipe *RedRecipe;
+      if (auto *MulAcc = TryToMatchMulAcc())
+        RedRecipe = MulAcc;
+      else if (auto *ExtendedRed = TryToMatchExtendedReduction())
+        RedRecipe = ExtendedRed;
+      else
         RedRecipe =
             new VPReductionRecipe(RdxDesc, CurrentLinkI, PreviousLink, VecOp,
                                   CondOp, CM.useOrderedReductions(RdxDesc));
-      }
       // Append the recipe to the end of the VPBasicBlock because we need to
       // ensure that it comes after all of it's inputs, including CondOp.
       // Note that this transformation may leave over dead recipes (including
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index c33e6081669f8b..e5b317dfd6f365 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -2711,18 +2711,19 @@ class VPExtendedReductionRecipe : public VPSingleDefRecipe {
   bool IsConditional = false;
   /// Type after extend.
   Type *ResultTy;
+  /// Opcode for the extend instruction.
   Instruction::CastOps ExtOp;
-  CastInst *CastInstr;
+  CastInst *ExtInstr;
   bool IsZExt;
 
 protected:
   VPExtendedReductionRecipe(const unsigned char SC,
                             const RecurrenceDescriptor &R, Instruction *RedI,
-                            Instruction::CastOps ExtOp, CastInst *CastI,
+                            Instruction::CastOps ExtOp, CastInst *ExtI,
                             ArrayRef<VPValue *> Operands, VPValue *CondOp,
                             bool IsOrdered, Type *ResultTy)
       : VPSingleDefRecipe(SC, Operands, RedI), RdxDesc(R), IsOrdered(IsOrdered),
-        ResultTy(ResultTy), ExtOp(ExtOp), CastInstr(CastI) {
+        ResultTy(ResultTy), ExtOp(ExtOp), ExtInstr(ExtI) {
     if (CondOp) {
       IsConditional = true;
       addOperand(CondOp);
@@ -2732,20 +2733,13 @@ class VPExtendedReductionRecipe : public VPSingleDefRecipe {
 
 public:
   VPExtendedReductionRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
-                            CastInst *CastI, VPValue *ChainOp, VPValue *VecOp,
-                            VPValue *CondOp, bool IsOrdered, Type *ResultTy)
-      : VPExtendedReductionRecipe(VPDef::VPExtendedReductionSC, R, RedI,
-                                  CastI->getOpcode(), CastI,
-                                  ArrayRef<VPValue *>({ChainOp, VecOp}), CondOp,
-                                  IsOrdered, ResultTy) {}
-
-  VPExtendedReductionRecipe(VPReductionRecipe *Red, VPWidenCastRecipe *Ext)
+                            VPValue *ChainOp, VPWidenCastRecipe *Ext,
+                            VPValue *CondOp, bool IsOrdered)
       : VPExtendedReductionRecipe(
-            VPDef::VPExtendedReductionSC, Red->getRecurrenceDescriptor(),
-            Red->getUnderlyingInstr(), Ext->getOpcode(),
+            VPDef::VPExtendedReductionSC, R, RedI, Ext->getOpcode(),
             cast<CastInst>(Ext->getUnderlyingInstr()),
-            ArrayRef<VPValue *>({Red->getChainOp(), Ext->getOperand(0)}),
-            Red->getCondOp(), Red->isOrdered(), Ext->getResultType()) {}
+            ArrayRef<VPValue *>({ChainOp, Ext->getOperand(0)}), CondOp,
+            IsOrdered, Ext->getResultType()) {}
 
   ~VPExtendedReductionRecipe() override = default;
 
@@ -2762,7 +2756,6 @@ class VPExtendedReductionRecipe : public VPSingleDefRecipe {
     return R && classof(R);
   }
 
-  /// Generate the reduction in the loop
   void execute(VPTransformState &State) override {
     llvm_unreachable("VPExtendedReductionRecipe should be transform to "
                      "VPExtendedRecipe + VPReductionRecipe before execution.");
@@ -2794,9 +2787,12 @@ class VPExtendedReductionRecipe : public VPSingleDefRecipe {
   VPValue *getCondOp() const {
     return isConditional() ? getOperand(getNumOperands() - 1) : nullptr;
   }
+  /// The Type after extended.
   Type *getResultType() const { return ResultTy; };
+  /// The Opcode of extend instruction.
   Instruction::CastOps getExtOpcode() const { return ExtOp; };
-  CastInst *getExtInstr() const { return CastInstr; };
+  /// The CastInst of the extend instruction.
+  CastInst *getExtInstr() const { return ExtInstr; };
 };
 
 /// A recipe to represent inloop MulAccreduction operations, performing a
@@ -2812,16 +2808,17 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
   bool IsConditional = false;
   /// Type after extend.
   Type *ResultType;
-  /// reduce.add(ext((mul(Ext(), Ext())))
+  // Note that all extend instruction must have the same opcode in MulAcc.
   Instruction::CastOps ExtOp;
 
+  /// reduce.add(ext(mul(ext0(), ext1())))
   Instruction *MulInstr;
   CastInst *ExtInstr = nullptr;
-  CastInst *Ext0Instr;
-  CastInst *Ext1Instr;
+  CastInst *Ext0Instr = nullptr;
+  CastInst *Ext1Instr = nullptr;
 
+  /// Is this MulAcc recipe contains extend recipes?
   bool IsExtended;
-  bool IsOuterExtended = false;
 
 protected:
   VPMulAccRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
@@ -2835,6 +2832,7 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
         ExtInstr(cast_if_present<CastInst>(ExtInstr)),
         Ext0Instr(cast<CastInst>(Ext0Instr)),
         Ext1Instr(cast<CastInst>(Ext1Instr)) {
+    assert(MulInstr->getOpcode() == Instruction::Mul);
     if (CondOp) {
       IsConditional = true;
       addOperand(CondOp);
@@ -2847,6 +2845,7 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
                  ArrayRef<VPValue *> Operands, VPValue *CondOp, bool IsOrdered)
       : VPSingleDefRecipe(SC, Operands, RedI), RdxDesc(R), IsOrdered(IsOrdered),
         MulInstr(MulInstr) {
+    assert(MulInstr->getOpcode() == Instruction::Mul);
     if (CondOp) {
       IsConditional = true;
       addOperand(CondOp);
@@ -2898,13 +2897,12 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
     return R && classof(R);
   }
 
-  /// Generate the reduction in the loop
   void execute(VPTransformState &State) override {
     llvm_unreachable("VPMulAccRecipe should transform to VPWidenCastRecipe + "
                      "VPWidenRecipe + VPReductionRecipe before execution");
   }
 
-  /// Return the cost of VPExtendedReductionRecipe.
+  /// Return the cost of VPMulAccRecipe.
   InstructionCost computeCost(ElementCount VF,
                               VPCostContext &Ctx) const override;
 
@@ -2931,13 +2929,24 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
   VPValue *getCondOp() const {
     return isConditional() ? getOperand(getNumOperands() - 1) : nullptr;
   }
+  /// Return the type after inner extended, which must equal to the type of mul
+  /// instruction. If the ResultType != recurrenceType, than it must have a
+  /// extend recipe after mul recipe.
   Type *getResultType() const { return ResultType; };
+  /// The opcode of the extend instructions.
   Instruction::CastOps getExtOpcode() const { return ExtOp; };
+  /// The underlying instruction for VPWidenRecipe.
   Instruction *getMulInstr() const { return MulInstr; };
+  /// The underlying Instruction for outer VPWidenCastRecipe.
   CastInst *getExtInstr() const { return ExtInstr; };
+  /// The underlying Instruction for inner VPWidenCastRecipe.
   CastInst *getExt0Instr() const { return Ext0Instr; };
+  /// The underlying Instruction for inner VPWidenCastRecipe.
   CastInst *getExt1Instr() const { return Ext1Instr; };
+  /// Return if this MulAcc recipe contains extend instructions.
   bool isExtended() const { return IsExtended; };
+  /// Return if the operands of mul instruction come from same extend.
+  bool isSameExtend() const { return Ext0Instr == Ext1Instr; };
 };
 
 /// VPReplicateRecipe replicates a given instruction producing multiple scalar
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index ff1e31033fd9cf..33febe56a77294 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -2214,122 +2214,6 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
         Opcode, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
   }
 
-  /*
-  using namespace llvm::VPlanPatternMatch;
-  auto GetMulAccReductionCost =
-      [&](const VPReductionRecipe *Red) -> InstructionCost {
-    VPValue *A, *B;
-    InstructionCost InnerExt0Cost = 0;
-    InstructionCost InnerExt1Cost = 0;
-    InstructionCost ExtCost = 0;
-    InstructionCost MulCost = 0;
-
-    VectorType *SrcVecTy = VectorTy;
-    Type *InnerExt0Ty;
-    Type *InnerExt1Ty;
-    Type *MaxInnerExtTy;
-    bool IsUnsigned = true;
-    bool HasOuterExt = false;
-
-    auto *Ext = dyn_cast_if_present<VPWidenCastRecipe>(
-        Red->getVecOp()->getDefiningRecipe());
-    VPRecipeBase *Mul;
-    // Try to match outer extend reduce.add(ext(...))
-    if (Ext && match(Ext, m_ZExtOrSExt(m_VPValue())) &&
-        cast<VPWidenCastRecipe>(Ext)->getNumUsers() == 1) {
-      IsUnsigned =
-          Ext->getOpcode() == Instruction::CastOps::ZExt ? true : false;
-      ExtCost = Ext->computeCost(VF, Ctx);
-      Mul = Ext->getOperand(0)->getDefiningRecipe();
-      HasOuterExt = true;
-    } else {
-      Mul = Red->getVecOp()->getDefiningRecipe();
-    }
-
-    // Match reduce.add(mul())
-    if (Mul && match(Mul, m_Mul(m_VPValue(A), m_VPValue(B))) &&
-        cast<VPWidenRecipe>(Mul)->getNumUsers() == 1) {
-      MulCost = cast<VPWidenRecipe>(Mul)->computeCost(VF, Ctx);
-      auto *InnerExt0 =
-          dyn_cast_if_present<VPWidenCastRecipe>(A->getDefiningRecipe());
-      auto *InnerExt1 =
-          dyn_cast_if_present<VPWidenCastRecipe>(B->getDefiningRecipe());
-      bool HasInnerExt = false;
-      // Try to match inner extends.
-      if (InnerExt0 && InnerExt1 &&
-          match(InnerExt0, m_ZExtOrSExt(m_VPValue())) &&
-          match(InnerExt1, m_ZExtOrSExt(m_VPValue())) &&
-          InnerExt0->getOpcode() == InnerExt1->getOpcode() &&
-          (InnerExt0->getNumUsers() > 0 &&
-           !InnerExt0->hasMoreThanOneUniqueUser()) &&
-          (InnerExt1->getNumUsers() > 0 &&
-           !InnerExt1->hasMoreThanOneUniqueUser())) {
-        InnerExt0Cost = InnerExt0->computeCost(VF, Ctx);
-        InnerExt1Cost = InnerExt1->computeCost(VF, Ctx);
-        Type *InnerExt0Ty = Ctx.Types.inferScalarType(InnerExt0->getOperand(0));
-        Type *InnerExt1Ty = Ctx.Types.inferScalarType(InnerExt1->getOperand(0));
-        Type *MaxInnerExtTy = InnerExt0Ty->getIntegerBitWidth() >
-                                      InnerExt1Ty->getIntegerBitWidth()
-                                  ? InnerExt0Ty
-                                  : InnerExt1Ty;
-        SrcVecTy = cast<VectorType>(ToVectorTy(MaxInnerExtTy, VF));
-        IsUnsigned = true;
-        HasInnerExt = true;
-      }
-      InstructionCost MulAccRedCost = Ctx.TTI.getMulAccReductionCost(
-          IsUnsigned, ElementTy, SrcVecTy, CostKind);
-      // Check if folding ext/mul into MulAccReduction is profitable.
-      if (MulAccRedCost.isValid() &&
-          MulAccRedCost <
-              ExtCost + MulCost + InnerExt0Cost + InnerExt1Cost + BaseCost) {
-        if (HasInnerExt) {
-          Ctx.FoldedRecipes[VF].insert(InnerExt0);
-          Ctx.FoldedRecipes[VF].insert(InnerExt1);
-        }
-        Ctx.FoldedRecipes[VF].insert(Mul);
-        if (HasOuterExt)
-          Ctx.FoldedRecipes[VF].insert(Ext);
-        return MulAccRedCost;
-      }
-    }
-    return InstructionCost::getInvalid();
-  };
-
-  // Match reduce(ext(...))
-  auto GetExtendedReductionCost =
-      [&](const VPReductionRecipe *Red) -> InstructionCost {
-    VPValue *VecOp = Red->getVecOp();
-    VPValue *A;
-    if (match(VecOp, m_ZExtOrSExt(m_VPValue(A))) && VecOp->getNumUsers() == 1) {
-      VPWidenCastRecipe *Ext =
-          cast<VPWidenCastRecipe>(VecOp->getDefiningRecipe());
-      bool IsUnsigned = Ext->getOpcode() == Instruction::CastOps::ZExt;
-      InstructionCost ExtCost = Ext->computeCost(VF, Ctx);
-      auto *ExtVecTy =
-          cast<VectorType>(ToVectorTy(Ctx.Types.inferScalarType(A), VF));
-      InstructionCost ExtendedRedCost = Ctx.TTI.getExtendedReductionCost(
-          Opcode, IsUnsigned, ElementTy, ExtVecTy, RdxDesc.getFastMathFlags(),
-          CostKind);
-      // Check if folding ext into ExtendedReduction is profitable.
-      if (ExtendedRedCost.isValid() && ExtendedRedCost < ExtCost + BaseCost) {
-        Ctx.FoldedRecipes[VF].insert(Ext);
-        return ExtendedRedCost;
-      }
-    }
-    return InstructionCost::getInvalid();
-  };
-
-  // Match MulAccReduction patterns.
-  InstructionCost MulAccCost = GetMulAccReductionCost(this);
-  if (MulAccCost.isValid())
-    return MulAccCost;
-
-  // Match ExtendedReduction patterns.
-  InstructionCost ExtendedCost = GetExtendedReductionCost(this);
-  if (ExtendedCost.isValid())
-    return ExtendedCost;
-  */
-
   // Default cost.
   return BaseCost;
 }
@@ -2343,9 +2227,6 @@ VPExtendedReductionRecipe::computeCost(ElementCount VF,
   TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
   unsigned Opcode = RdxDesc.getOpcode();
 
-  assert(ElementTy->getTypeID() == RdxDesc.getRecurrenceType()->getTypeID() &&
-         "Inferred type and recurrence type mismatch.");
-
   // BaseCost = Reduction cost + BinOp cost
   InstructionCost ReductionCost =
       Ctx.TTI.getArithmeticInstrCost(Opcode, ElementTy, CostKind);
@@ -2387,8 +2268,8 @@ InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
   TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
   unsigned Opcode = RdxDesc.getOpcode();
 
-  assert(ElementTy->getTypeID() == RdxDesc.getRecurrenceType()->getTypeID() &&
-         "Inferred type and recurrence type mismatch.");
+  assert(Opcode == Instruction::Add &&
+         "Reduction opcode must be add in the VPMulAccRecipe.");
 
   // BaseCost = Reduction cost + BinOp cost
   InstructionCost ReductionCost =
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index c28953e7314f57..423694cdd84e05 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -545,11 +545,15 @@ void VPlanTransforms::prepareExecute(VPlan &Plan) {
           Op0 = new VPWidenCastRecipe(
               MulAcc->getExtOpcode(), MulAcc->getVecOp0(),
               MulAcc->getResultType(), *MulAcc->getExt0Instr());
-          Op1 = new VPWidenCastRecipe(
-              MulAcc->getExtOpcode(), MulAcc->getVecOp1(),
-              MulAcc->getResultType(), *MulAcc->getExt1Instr());
           Op0->getDefiningRecipe()->insertBefore(MulAcc);
-          Op1->getDefiningRecipe()->insertBefore(MulAcc);
+          if (!MulAcc->isSameExtend()) {
+            Op1 = new VPWidenCastRecipe(
+                MulAcc->getExtOpcode(), MulAcc->getVecOp1(),
+                MulAcc->getResultType(), *MulAcc->getExt1Instr());
+            Op1->getDefiningRecipe()->insertBefore(MulAcc);
+          } else {
+            Op1 = Op0;
+          }
         } else {
           Op0 = MulAcc->getVecOp0();
           Op1 = MulAcc->getVecOp1();
@@ -559,14 +563,13 @@ void VPlanTransforms::prepareExecute(VPlan &Plan) {
         SmallVector<VPValue *, 2> MulOps = {Op0, Op1};
         auto *Mul = new VPWidenRecipe(*MulInstr,
                                       make_range(MulOps.begin(), MulOps.end()));
-        if (auto *OuterExtInstr = MulAcc->getExtInstr()) {
-          // dbgs() <<"\n!!!"<< *OuterExtInstr << " " << MulAcc->getExtOpcode()
-          // << "\n";
+        // Outer extend.
+        if (auto *OuterExtInstr = MulAcc->getExtInstr())
           VecOp = new VPWidenCastRecipe(
               MulAcc->getExtOpcode(), Mul,
               MulAcc->getRecurrenceDescriptor().getRecurrenceType(),
               *OuterExtInstr);
-        } else
+        else
           VecOp = Mul;
         auto *Red = new VPReductionRecipe(
             MulAcc->getRecurrenceDescriptor(), MulAcc->getUnderlyingInstr(),
diff --git a/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll b/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
index ffdd8e66983d85..9b7c345076480a 100644
--- a/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
+++ b/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
@@ -1412,9 +1412,8 @@ define i32 @mla_i8_i32_multiuse(ptr nocapture readonly %x, ptr nocapture readonl
 ; CHECK-NEXT:    [[ACTIVE_LANE_MASK:%.*]] = call <16 x i1> @llvm.get.active.lane.mask.v16i1.i32(i32 [[INDEX]], i32 [[N]])
 ; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds i8, ptr [[X:%.*]], i32 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_MASKED_LOAD:%.*]] = call <16 x i8> @llvm.masked.load.v16i8.p0(ptr [[TMP0]], i32 1, <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i8> poison)
-; CHECK-NEXT:    [[TMP1:%.*]] = zext <16 x i8> [[WIDE_MASKED_LOAD]] to <16 x i32>
 ; CHECK-NEXT:    [[TMP7:%.*]] = zext <16 x i8> [[WIDE_MASKED_LOAD]] to <16 x i32>
-; CHECK-NEXT:    [[TMP2:%.*]] = mul nuw nsw <16 x i32> [[TMP1]], [[TMP7]]
+; CHECK-NEXT:    [[TMP2:%.*]] = mul nuw nsw <16 x i32> [[TMP7]], [[TMP7]]
 ; CHECK-NEXT:    [[TMP3:%.*]] = select <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i32> [[TMP2]], <16 x i32> zeroinitializer
 ; CHECK-NEXT:    [[TMP4:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP3]])
 ; CHECK-NEXT:    [[TMP5]] = add i32 [[TMP4]], [[VEC_PHI]]
diff --git a/llvm/test/Transforms/LoopVectorize/RISCV/inloop-reduction.ll b/llvm/test/Transforms/LoopVectorize/RISCV/inloop-reduction.ll
index c6438857a7571e..c20b9051750622 100644
--- a/llvm/test/Transforms/LoopVectorize/RISCV/inloop-reduction.ll
+++ b/llvm/test/Transforms/LoopVectorize/RISCV/inloop-reduction.ll
@@ -209,14 +209,14 @@ define i32 @add_i16_i32(ptr nocapture readonly %x, i32 %n) {
 ; IF-EVL-INLOOP-NEXT:    [[TMP10:%.*]] = getelementptr inbounds i16, ptr [[X:%.*]], i32 [[TMP6]]
 ; IF-EVL-INLOOP-NEXT:    [[TMP15:%.*]] = getelementptr inbounds i16, ptr [[TMP10]], i32 0
 ; IF-EVL-INLOOP-NEXT:    [[VP_OP_LOAD:%.*]] = call <vscale x 8 x i16> @llvm.vp.load.nxv8i16.p0(ptr align 2 [[TMP15]], <vscale x 8 x i1> splat (i1 true), i32 [[TMP5]])
-; IF-EVL-INLOOP-NEXT:    [[TMP16:%.*]] = sext <vscale x 8 x i16> [[VP_OP_LOAD]] to <vscale x 8 x i32>
-; IF-EVL-INLOOP-NEXT:    [[TMP17:%.*]] = select <vscale x 8 x i1> [[TMP9]], <vscale x 8 x i32> [[TMP16]], <vscale x 8 x i32> zeroinitializer
+; IF-EVL-INLOOP-NEXT:    [[TMP12:%.*]] = sext <vscale x 8 x i16> [[VP_OP_LOAD]] to <vscale x 8 x i32>
+; IF-EVL-INLOOP-NEXT:    [[TMP17:%.*]] = select <vscale x 8 x i1> [[TMP9]], <vscale x 8 x i32> [[TMP12]], <vscale x 8 x i32> zeroinitializer
 ; IF-EVL-INLOOP-NEXT:    [[TMP14:%.*]] = call i32 @llvm.vector.reduce.add.nxv8i32(<vscale x 8 x i32> [[TMP17]])
 ; IF-EVL-INLOOP-NEXT:    [[TMP11]] = add i32 [[TMP14]], [[VEC_PHI]]
 ; IF-EVL-INLOOP-NEXT:    [[INDEX_EVL_NEXT]] = add nuw i32 [[TMP5]], [[EVL_BASED_IV]]
 ; IF-EVL-INLOOP-NEXT:    [[INDEX_NEXT]] = add nuw i32 [[INDEX]], [[TMP4]]
-; IF-EVL-INLOOP-NEXT:    [[TMP12:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]
-; IF-EVL-INLOOP-NEXT:    br i1 [[TMP12]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
+; IF-EVL-INLOOP-NEXT:    [[TMP16:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]
+; IF-EVL-INLOOP-NEXT:    br i1 [[TMP16]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
 ; IF-EVL-INLOOP:       middle.block:
 ; IF-EVL-INLOOP-NEXT:    br i1 true, label [[FOR_COND_CLEANUP_LOOPEXIT:%.*]], label [[SCALAR_PH]]
 ; IF-EVL-INLOOP:       scalar.ph:

>From cc004ffd4f94c071c78eda07f327a20f85696be1 Mon Sep 17 00:00:00 2001
From: Elvis Wang <elvis.wang at sifive.com>
Date: Wed, 6 Nov 2024 21:07:04 -0800
Subject: [PATCH 06/17] Fix typos and update printing test

---
 .../Transforms/Vectorize/LoopVectorize.cpp    |   3 +-
 llvm/lib/Transforms/Vectorize/VPlan.h         |   7 +-
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp |  18 +-
 .../LoopVectorize/vplan-printing.ll           | 236 ++++++++++++++++++
 4 files changed, 253 insertions(+), 11 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index f7bca1383cdad1..b70f7f9dba2611 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7679,7 +7679,7 @@ DenseMap<const SCEV *, Value *> LoopVectorizationPlanner::executePlan(
                              CanonicalIVStartValue, State);
   VPlanTransforms::prepareToExecute(BestVPlan);
 
-  // TODO: Rebase to fhahn's implementation.
+  // TODO: Replace with upstream implementation.
   VPlanTransforms::prepareExecute(BestVPlan);
   BestVPlan.execute(&State);
 
@@ -9279,7 +9279,6 @@ VPlanPtr LoopVectorizationPlanner::buildVPlan(VFRange &Range) {
 // Adjust AnyOf reductions; replace the reduction phi for the selected value
 // with a boolean reduction phi node to check if the condition is true in any
 // iteration. The final value is selected by the final ComputeReductionResult.
-// TODO: Implement VPMulAccHere.
 void LoopVectorizationPlanner::adjustRecipesForReductions(
     VPlanPtr &Plan, VPRecipeBuilder &RecipeBuilder, ElementCount MinVF) {
   using namespace VPlanPatternMatch;
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index e5b317dfd6f365..504021d0f621d3 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -2798,8 +2798,8 @@ class VPExtendedReductionRecipe : public VPSingleDefRecipe {
 /// A recipe to represent inloop MulAccreduction operations, performing a
 /// reduction on a vector operand into a scalar value, and adding the result to
 /// a chain. This recipe is high level abstract which will generate
-/// VPReductionRecipe VPWidenRecipe(mul)and VPWidenCastRecipe before execution.
-/// The Operands are {ChainOp, VecOp1, VecOp2, [Condition]}.
+/// VPReductionRecipe VPWidenRecipe(mul) and VPWidenCastRecipes before
+/// execution. The Operands are {ChainOp, VecOp1, VecOp2, [Condition]}.
 class VPMulAccRecipe : public VPSingleDefRecipe {
   /// The recurrence decriptor for the reduction in question.
   const RecurrenceDescriptor &RdxDesc;
@@ -2819,6 +2819,8 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
 
   /// Is this MulAcc recipe contains extend recipes?
   bool IsExtended;
+  /// Does this recipe contain an outer extend instruction?
+  bool IsOuterExtended = false;
 
 protected:
   VPMulAccRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
@@ -2838,6 +2840,7 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
       addOperand(CondOp);
     }
     IsExtended = true;
+    IsOuterExtended = ExtInstr != nullptr;
   }
 
   VPMulAccRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 33febe56a77294..a85a4bc81c2854 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -270,7 +270,7 @@ InstructionCost VPRecipeBase::cost(ElementCount VF, VPCostContext &Ctx) {
     UI = &WidenMem->getIngredient();
 
   InstructionCost RecipeCost;
-  if ((UI && Ctx.skipCostComputation(UI, VF.isVector()))) {
+  if (UI && Ctx.skipCostComputation(UI, VF.isVector())) {
     RecipeCost = 0;
   } else {
     RecipeCost = computeCost(VF, Ctx);
@@ -2413,18 +2413,20 @@ void VPMulAccRecipe::print(raw_ostream &O, const Twine &Indent,
   printAsOperand(O, SlotTracker);
   O << " = ";
   getChainOp()->printAsOperand(O, SlotTracker);
-  O << " +";
+  O << " + ";
   if (isa<FPMathOperator>(getUnderlyingInstr()))
     O << getUnderlyingInstr()->getFastMathFlags();
-  O << " reduce." << Instruction::getOpcodeName(RdxDesc.getOpcode()) << " (";
-  O << " mul ";
+  if (IsOuterExtended)
+    O << " (";
+  O << "reduce." << Instruction::getOpcodeName(RdxDesc.getOpcode()) << " (";
+  O << "mul ";
   if (IsExtended)
     O << "(";
   getVecOp0()->printAsOperand(O, SlotTracker);
   if (IsExtended)
-    O << " extended to " << *getResultType() << ")";
-  if (IsExtended)
-    O << "(";
+    O << " extended to " << *getResultType() << "), (";
+  else
+    O << ", ";
   getVecOp1()->printAsOperand(O, SlotTracker);
   if (IsExtended)
     O << " extended to " << *getResultType() << ")";
@@ -2433,6 +2435,8 @@ void VPMulAccRecipe::print(raw_ostream &O, const Twine &Indent,
     getCondOp()->printAsOperand(O, SlotTracker);
   }
   O << ")";
+  if (IsOuterExtended)
+    O << " extended to " << *RdxDesc.getRecurrenceType() << ")";
   if (RdxDesc.IntermediateStore)
     O << " (with final reduction value stored in invariant address sank "
          "outside of loop)";
diff --git a/llvm/test/Transforms/LoopVectorize/vplan-printing.ll b/llvm/test/Transforms/LoopVectorize/vplan-printing.ll
index 195f6a48640e54..b1d7388005ad4e 100644
--- a/llvm/test/Transforms/LoopVectorize/vplan-printing.ll
+++ b/llvm/test/Transforms/LoopVectorize/vplan-printing.ll
@@ -1143,6 +1143,242 @@ exit:
   ret i16 %for.1
 }
 
+define i64 @print_extended_reduction(ptr nocapture readonly %x, ptr nocapture readonly %y, i32 %n) {
+; CHECK-LABEL: 'print_extended_reduction'
+; CHECK:      VPlan 'Initial VPlan for VF={4},UF>=1' {
+; CHECK-NEXT: Live-in vp<%0> = VF * UF
+; CHECK-NEXT: Live-in vp<%1> = vector-trip-count
+; CHECK-NEXT: Live-in ir<%n> = original trip-count
+; CHECK-EMPTY:
+; CHECK-NEXT: vector.ph:
+; CHECK-NEXT: Successor(s): vector loop
+; CHECK-EMPTY:
+; CHECK-NEXT: <x1> vector loop: {
+; CHECK-NEXT:   vector.body:
+; CHECK-NEXT:     EMIT vp<%2> = CANONICAL-INDUCTION ir<0>, vp<%index.next>
+; CHECK-NEXT:     WIDEN-REDUCTION-PHI ir<%r.09> = phi ir<0>, ir<%add>
+; CHECK-NEXT:     vp<%3> = SCALAR-STEPS vp<%2>, ir<1>
+; CHECK-NEXT:     CLONE ir<%arrayidx> = getelementptr inbounds ir<%x>, vp<%3>
+; CHECK-NEXT:     vp<%4> = vector-pointer ir<%arrayidx>
+; CHECK-NEXT:     WIDEN ir<%load0> = load vp<%4>
+; CHECK-NEXT:     EXTENDED-REDUCE ir<%add> = ir<%r.09> + reduce.add (ir<%load0> extended to i64)
+; CHECK-NEXT:     EMIT vp<%index.next> = add nuw vp<%2>, vp<%0>
+; CHECK-NEXT:     EMIT branch-on-count vp<%index.next>, vp<%1>
+; CHECK-NEXT:   No successors
+; CHECK-NEXT: }
+; CHECK-NEXT: Successor(s): middle.block
+; CHECK-EMPTY:
+; CHECK-NEXT: middle.block:
+; CHECK-NEXT:   EMIT vp<%6> = compute-reduction-result ir<%r.09>, ir<%add>
+; CHECK-NEXT:   EMIT vp<%7> = extract-from-end vp<%6>, ir<1>
+; CHECK-NEXT:   EMIT vp<%cmp.n> = icmp eq ir<%n>, vp<%1>
+; CHECK-NEXT:   EMIT branch-on-cond vp<%cmp.n>
+; CHECK-NEXT: Successor(s): ir-bb<for.cond.cleanup.loopexit>, scalar.ph
+; CHECK-EMPTY:
+; CHECK-NEXT: ir-bb<for.cond.cleanup.loopexit>:
+; CHECK-NEXT:   IR   %add.lcssa = phi i64 [ %add, %for.body ] (extra operand: vp<%7>)
+; CHECK-NEXT: No successors
+; CHECK-EMPTY:
+; CHECK-NEXT: scalar.ph:
+; CHECK-NEXT:   EMIT vp<%bc.merge.rdx> = resume-phi vp<%6>, ir<0>
+; CHECK-NEXT: Successor(s): ir-bb<for.body>
+; CHECK-EMPTY:
+; CHECK-NEXT: ir-bb<for.body>:
+; CHECK-NEXT:   IR   %i.010 = phi i32 [ %inc, %for.body ], [ 0, %for.body.preheader ]
+; CHECK-NEXT:   IR   %r.09 = phi i64 [ %add, %for.body ], [ 0, %for.body.preheader ] (extra operand: vp<%bc.merge.rdx>)
+; CHECK-NEXT:   IR   %arrayidx = getelementptr inbounds i32, ptr %x, i32 %i.010
+; CHECK-NEXT:   IR   %load0 = load i32, ptr %arrayidx, align 4
+; CHECK-NEXT:   IR   %conv0 = zext i32 %load0 to i64
+; CHECK-NEXT:   IR   %add = add nsw i64 %r.09, %conv0
+; CHECK-NEXT:   IR   %inc = add nuw nsw i32 %i.010, 1
+; CHECK-NEXT:   IR   %exitcond = icmp eq i32 %inc, %n
+; CHECK-NEXT: No successors
+; CHECK-NEXT: }
+;
+entry:
+  %cmp8 = icmp sgt i32 %n, 0
+  br i1 %cmp8, label %for.body, label %for.cond.cleanup
+
+for.body:                                         ; preds = %entry, %for.body
+  %i.010 = phi i32 [ %inc, %for.body ], [ 0, %entry ]
+  %r.09 = phi i64 [ %add, %for.body ], [ 0, %entry ]
+  %arrayidx = getelementptr inbounds i32, ptr %x, i32 %i.010
+  %load0 = load i32, ptr %arrayidx, align 4
+  %conv0 = zext i32 %load0 to i64
+  %add = add nsw i64 %r.09, %conv0
+  %inc = add nuw nsw i32 %i.010, 1
+  %exitcond = icmp eq i32 %inc, %n
+  br i1 %exitcond, label %for.cond.cleanup, label %for.body
+
+for.cond.cleanup:                                 ; preds = %for.body, %entry
+  %r.0.lcssa = phi i64 [ 0, %entry ], [ %add, %for.body ]
+  ret i64 %r.0.lcssa
+}
+
+define i64 @print_mulacc(ptr nocapture readonly %x, ptr nocapture readonly %y, i32 %n) {
+; CHECK-LABEL: 'print_mulacc'
+; CHECK:      VPlan 'Initial VPlan for VF={4},UF>=1' {
+; CHECK-NEXT: Live-in vp<%0> = VF * UF
+; CHECK-NEXT: Live-in vp<%1> = vector-trip-count
+; CHECK-NEXT: Live-in ir<%n> = original trip-count
+; CHECK-EMPTY:
+; CHECK-NEXT: vector.ph:
+; CHECK-NEXT: Successor(s): vector loop
+; CHECK-EMPTY:
+; CHECK-NEXT: <x1> vector loop: {
+; CHECK-NEXT:   vector.body:
+; CHECK-NEXT:     EMIT vp<%2> = CANONICAL-INDUCTION ir<0>, vp<%index.next>
+; CHECK-NEXT:     WIDEN-REDUCTION-PHI ir<%r.09> = phi ir<0>, ir<%add>
+; CHECK-NEXT:     vp<%3> = SCALAR-STEPS vp<%2>, ir<1>
+; CHECK-NEXT:     CLONE ir<%arrayidx> = getelementptr inbounds ir<%x>, vp<%3>
+; CHECK-NEXT:     vp<%4> = vector-pointer ir<%arrayidx>
+; CHECK-NEXT:     WIDEN ir<%load0> = load vp<%4>
+; CHECK-NEXT:     CLONE ir<%arrayidx1> = getelementptr inbounds ir<%y>, vp<%3>
+; CHECK-NEXT:     vp<%5> = vector-pointer ir<%arrayidx1>
+; CHECK-NEXT:     WIDEN ir<%load1> = load vp<%5>
+; CHECK-NEXT:     MULACC-REDUCE ir<%add> = ir<%r.09> + reduce.add (mul ir<%load0>, ir<%load1>)
+; CHECK-NEXT:     EMIT vp<%index.next> = add nuw vp<%2>, vp<%0>
+; CHECK-NEXT:     EMIT branch-on-count vp<%index.next>, vp<%1>
+; CHECK-NEXT:   No successors
+; CHECK-NEXT: }
+; CHECK-NEXT: Successor(s): middle.block
+; CHECK-EMPTY:
+; CHECK-NEXT: middle.block:
+; CHECK-NEXT:   EMIT vp<%7> = compute-reduction-result ir<%r.09>, ir<%add>
+; CHECK-NEXT:   EMIT vp<%8> = extract-from-end vp<%7>, ir<1>
+; CHECK-NEXT:   EMIT vp<%cmp.n> = icmp eq ir<%n>, vp<%1>
+; CHECK-NEXT:   EMIT branch-on-cond vp<%cmp.n>
+; CHECK-NEXT: Successor(s): ir-bb<for.cond.cleanup.loopexit>, scalar.ph
+; CHECK-EMPTY:
+; CHECK-NEXT: ir-bb<for.cond.cleanup.loopexit>:
+; CHECK-NEXT:   IR   %add.lcssa = phi i64 [ %add, %for.body ] (extra operand: vp<%8>)
+; CHECK-NEXT: No successors
+; CHECK-EMPTY:
+; CHECK-NEXT: scalar.ph:
+; CHECK-NEXT:   EMIT vp<%bc.merge.rdx> = resume-phi vp<%7>, ir<0>
+; CHECK-NEXT: Successor(s): ir-bb<for.body>
+; CHECK-EMPTY:
+; CHECK-NEXT: ir-bb<for.body>:
+; CHECK-NEXT:   IR   %i.010 = phi i32 [ %inc, %for.body ], [ 0, %for.body.preheader ]
+; CHECK-NEXT:   IR   %r.09 = phi i64 [ %add, %for.body ], [ 0, %for.body.preheader ] (extra operand: vp<%bc.merge.rdx>)
+; CHECK-NEXT:   IR   %arrayidx = getelementptr inbounds i64, ptr %x, i32 %i.010
+; CHECK-NEXT:   IR   %load0 = load i64, ptr %arrayidx, align 4
+; CHECK-NEXT:   IR   %arrayidx1 = getelementptr inbounds i64, ptr %y, i32 %i.010
+; CHECK-NEXT:   IR   %load1 = load i64, ptr %arrayidx1, align 4
+; CHECK-NEXT:   IR   %mul = mul nsw i64 %load0, %load1
+; CHECK-NEXT:   IR   %add = add nsw i64 %r.09, %mul
+; CHECK-NEXT:   IR   %inc = add nuw nsw i32 %i.010, 1
+; CHECK-NEXT:   IR   %exitcond = icmp eq i32 %inc, %n
+; CHECK-NEXT: No successors
+; CHECK-NEXT: }
+;
+entry:
+  %cmp8 = icmp sgt i32 %n, 0
+  br i1 %cmp8, label %for.body, label %for.cond.cleanup
+
+for.body:                                         ; preds = %entry, %for.body
+  %i.010 = phi i32 [ %inc, %for.body ], [ 0, %entry ]
+  %r.09 = phi i64 [ %add, %for.body ], [ 0, %entry ]
+  %arrayidx = getelementptr inbounds i64, ptr %x, i32 %i.010
+  %load0 = load i64, ptr %arrayidx, align 4
+  %arrayidx1 = getelementptr inbounds i64, ptr %y, i32 %i.010
+  %load1 = load i64, ptr %arrayidx1, align 4
+  %mul = mul nsw i64 %load0, %load1
+  %add = add nsw i64 %r.09, %mul
+  %inc = add nuw nsw i32 %i.010, 1
+  %exitcond = icmp eq i32 %inc, %n
+  br i1 %exitcond, label %for.cond.cleanup, label %for.body
+
+for.cond.cleanup:                                 ; preds = %for.body, %entry
+  %r.0.lcssa = phi i64 [ 0, %entry ], [ %add, %for.body ]
+  ret i64 %r.0.lcssa
+}
+
+define i64 @print_mulacc_extended(ptr nocapture readonly %x, ptr nocapture readonly %y, i32 %n) {
+; CHECK-LABEL: 'print_mulacc_extended'
+; CHECK:      VPlan 'Initial VPlan for VF={4},UF>=1' {
+; CHECK-NEXT: Live-in vp<%0> = VF * UF
+; CHECK-NEXT: Live-in vp<%1> = vector-trip-count
+; CHECK-NEXT: Live-in ir<%n> = original trip-count
+; CHECK-EMPTY:
+; CHECK-NEXT: vector.ph:
+; CHECK-NEXT: Successor(s): vector loop
+; CHECK-EMPTY:
+; CHECK-NEXT: <x1> vector loop: {
+; CHECK-NEXT:   vector.body:
+; CHECK-NEXT:     EMIT vp<%2> = CANONICAL-INDUCTION ir<0>, vp<%index.next>
+; CHECK-NEXT:     WIDEN-REDUCTION-PHI ir<%r.09> = phi ir<0>, ir<%add>
+; CHECK-NEXT:     vp<%3> = SCALAR-STEPS vp<%2>, ir<1>
+; CHECK-NEXT:     CLONE ir<%arrayidx> = getelementptr inbounds ir<%x>, vp<%3>
+; CHECK-NEXT:     vp<%4> = vector-pointer ir<%arrayidx>
+; CHECK-NEXT:     WIDEN ir<%load0> = load vp<%4>
+; CHECK-NEXT:     CLONE ir<%arrayidx1> = getelementptr inbounds ir<%y>, vp<%3>
+; CHECK-NEXT:     vp<%5> = vector-pointer ir<%arrayidx1>
+; CHECK-NEXT:     WIDEN ir<%load1> = load vp<%5>
+; CHECK-NEXT:     MULACC-REDUCE ir<%add> = ir<%r.09> +  (reduce.add (mul (ir<%load0> extended to i32), (ir<%load1> extended to i32)) extended to i64)
+; CHECK-NEXT:     EMIT vp<%index.next> = add nuw vp<%2>, vp<%0>
+; CHECK-NEXT:     EMIT branch-on-count vp<%index.next>, vp<%1>
+; CHECK-NEXT:   No successors
+; CHECK-NEXT: }
+; CHECK-NEXT: Successor(s): middle.block
+; CHECK-EMPTY:
+; CHECK-NEXT: middle.block:
+; CHECK-NEXT:   EMIT vp<%7> = compute-reduction-result ir<%r.09>, ir<%add>
+; CHECK-NEXT:   EMIT vp<%8> = extract-from-end vp<%7>, ir<1>
+; CHECK-NEXT:   EMIT vp<%cmp.n> = icmp eq ir<%n>, vp<%1>
+; CHECK-NEXT:   EMIT branch-on-cond vp<%cmp.n>
+; CHECK-NEXT: Successor(s): ir-bb<for.cond.cleanup.loopexit>, scalar.ph
+; CHECK-EMPTY:
+; CHECK-NEXT: ir-bb<for.cond.cleanup.loopexit>:
+; CHECK-NEXT:   IR   %add.lcssa = phi i64 [ %add, %for.body ] (extra operand: vp<%8>)
+; CHECK-NEXT: No successors
+; CHECK-EMPTY:
+; CHECK-NEXT: scalar.ph:
+; CHECK-NEXT:   EMIT vp<%bc.merge.rdx> = resume-phi vp<%7>, ir<0>
+; CHECK-NEXT: Successor(s): ir-bb<for.body>
+; CHECK-EMPTY:
+; CHECK-NEXT: ir-bb<for.body>:
+; CHECK-NEXT:   IR   %i.010 = phi i32 [ %inc, %for.body ], [ 0, %for.body.preheader ]
+; CHECK-NEXT:   IR   %r.09 = phi i64 [ %add, %for.body ], [ 0, %for.body.preheader ] (extra operand: vp<%bc.merge.rdx>)
+; CHECK-NEXT:   IR   %arrayidx = getelementptr inbounds i16, ptr %x, i32 %i.010
+; CHECK-NEXT:   IR   %load0 = load i16, ptr %arrayidx, align 4
+; CHECK-NEXT:   IR   %arrayidx1 = getelementptr inbounds i16, ptr %y, i32 %i.010
+; CHECK-NEXT:   IR   %load1 = load i16, ptr %arrayidx1, align 4
+; CHECK-NEXT:   IR   %conv0 = sext i16 %load0 to i32
+; CHECK-NEXT:   IR   %conv1 = sext i16 %load1 to i32
+; CHECK-NEXT:   IR   %mul = mul nsw i32 %conv0, %conv1
+; CHECK-NEXT:   IR   %conv = sext i32 %mul to i64
+; CHECK-NEXT:   IR   %add = add nsw i64 %r.09, %conv
+; CHECK-NEXT:   IR   %inc = add nuw nsw i32 %i.010, 1
+; CHECK-NEXT:   IR   %exitcond = icmp eq i32 %inc, %n
+; CHECK-NEXT: No successors
+; CHECK-NEXT: }
+;
+entry:
+  %cmp8 = icmp sgt i32 %n, 0
+  br i1 %cmp8, label %for.body, label %for.cond.cleanup
+
+for.body:                                         ; preds = %entry, %for.body
+  %i.010 = phi i32 [ %inc, %for.body ], [ 0, %entry ]
+  %r.09 = phi i64 [ %add, %for.body ], [ 0, %entry ]
+  %arrayidx = getelementptr inbounds i16, ptr %x, i32 %i.010
+  %load0 = load i16, ptr %arrayidx, align 4
+  %arrayidx1 = getelementptr inbounds i16, ptr %y, i32 %i.010
+  %load1 = load i16, ptr %arrayidx1, align 4
+  %conv0 = sext i16 %load0 to i32
+  %conv1 = sext i16 %load1 to i32
+  %mul = mul nsw i32 %conv0, %conv1
+  %conv = sext i32 %mul to i64
+  %add = add nsw i64 %r.09, %conv
+  %inc = add nuw nsw i32 %i.010, 1
+  %exitcond = icmp eq i32 %inc, %n
+  br i1 %exitcond, label %for.cond.cleanup, label %for.body
+
+for.cond.cleanup:                                 ; preds = %for.body, %entry
+  %r.0.lcssa = phi i64 [ 0, %entry ], [ %add, %for.body ]
+  ret i64 %r.0.lcssa
+}
+
 !llvm.dbg.cu = !{!0}
 !llvm.module.flags = !{!3, !4}
 

>From b5445ca7e2b33b3485dee9a4e265f876ef6a2470 Mon Sep 17 00:00:00 2001
From: Elvis Wang <elvis.wang at sifive.com>
Date: Sun, 10 Nov 2024 18:00:52 -0800
Subject: [PATCH 07/17] Fold reduce.add(zext(mul(sext(A), sext(B)))) into
 MulAccRecipe when A == B

To prepare for a future refactor that avoids referencing underlying
instructions, and because the opcode and the extend instruction can be
mismatched in the newly added pattern, stop passing the UI when creating
VPWidenCastRecipe.
This removal leads to duplicate extend instructions being created by the
loop vectorizer when two reduction patterns exist in the same loop.
These redundant instructions might be removed after LV.
---
 .../Transforms/Vectorize/LoopVectorize.cpp    | 31 ++++++++-----------
 .../Transforms/Vectorize/VPlanTransforms.cpp  | 15 +++++----
 .../LoopVectorize/ARM/mve-reductions.ll       |  3 +-
 .../LoopVectorize/reduction-inloop.ll         |  6 ++--
 4 files changed, 24 insertions(+), 31 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index b70f7f9dba2611..b59fe696c393ac 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -9404,20 +9404,18 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
         // reduce.add(mul(ext, ext)) can folded into VPMulAccRecipe
         if (match(VecOp, m_Mul(m_VPValue(A), m_VPValue(B))) &&
             !VecOp->hasMoreThanOneUniqueUser()) {
-          VPRecipeBase *RecipeA = A->getDefiningRecipe();
-          VPRecipeBase *RecipeB = B->getDefiningRecipe();
+          VPWidenCastRecipe *RecipeA =
+              dyn_cast_if_present<VPWidenCastRecipe>(A->getDefiningRecipe());
+          VPWidenCastRecipe *RecipeB =
+              dyn_cast_if_present<VPWidenCastRecipe>(B->getDefiningRecipe());
           if (RecipeA && RecipeB && match(RecipeA, m_ZExtOrSExt(m_VPValue())) &&
               match(RecipeB, m_ZExtOrSExt(m_VPValue())) &&
-              cast<VPWidenCastRecipe>(RecipeA)->getOpcode() ==
-                  cast<VPWidenCastRecipe>(RecipeB)->getOpcode() &&
-              !A->hasMoreThanOneUniqueUser() &&
-              !B->hasMoreThanOneUniqueUser()) {
+              (RecipeA->getOpcode() == RecipeB->getOpcode() || A == B)) {
             return new VPMulAccRecipe(
                 RdxDesc, CurrentLinkI, PreviousLink, CondOp,
                 CM.useOrderedReductions(RdxDesc),
-                cast<VPWidenRecipe>(VecOp->getDefiningRecipe()),
-                cast<VPWidenCastRecipe>(RecipeA),
-                cast<VPWidenCastRecipe>(RecipeB));
+                cast<VPWidenRecipe>(VecOp->getDefiningRecipe()), RecipeA,
+                RecipeB);
           } else {
             // Matched reduce.add(mul(...))
             return new VPMulAccRecipe(
@@ -9425,8 +9423,9 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
                 CM.useOrderedReductions(RdxDesc),
                 cast<VPWidenRecipe>(VecOp->getDefiningRecipe()));
           }
-          // Matched reduce.add(ext(mul(ext, ext)))
-          // Note that 3 extend instructions must have same opcode.
+          // Matched reduce.add(ext(mul(ext(A), ext(B))))
+          // Note that 3 extend instructions must have same opcode or A == B
+          // which can be transform to reduce.add(zext(mul(sext(A), sext(B)))).
         } else if (match(VecOp,
                          m_ZExtOrSExt(m_Mul(m_ZExtOrSExt(m_VPValue()),
                                             m_ZExtOrSExt(m_VPValue())))) &&
@@ -9439,11 +9438,8 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
               cast<VPWidenCastRecipe>(Mul->getOperand(0)->getDefiningRecipe());
           VPWidenCastRecipe *Ext1 =
               cast<VPWidenCastRecipe>(Mul->getOperand(1)->getDefiningRecipe());
-          if (Ext->getOpcode() == Ext0->getOpcode() &&
-              Ext0->getOpcode() == Ext1->getOpcode() &&
-              !Mul->hasMoreThanOneUniqueUser() &&
-              !Ext0->hasMoreThanOneUniqueUser() &&
-              !Ext1->hasMoreThanOneUniqueUser()) {
+          if ((Ext->getOpcode() == Ext0->getOpcode() || Ext0 == Ext1) &&
+              Ext0->getOpcode() == Ext1->getOpcode()) {
             return new VPMulAccRecipe(
                 RdxDesc, CurrentLinkI, PreviousLink, CondOp,
                 CM.useOrderedReductions(RdxDesc),
@@ -9455,8 +9451,7 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
       };
       auto TryToMatchExtendedReduction = [&]() -> VPSingleDefRecipe * {
         VPValue *A;
-        if (match(VecOp, m_ZExtOrSExt(m_VPValue(A))) &&
-            !VecOp->hasMoreThanOneUniqueUser()) {
+        if (match(VecOp, m_ZExtOrSExt(m_VPValue(A)))) {
           return new VPExtendedReductionRecipe(
               RdxDesc, CurrentLinkI, PreviousLink,
               cast<VPWidenCastRecipe>(VecOp), CondOp,
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index 423694cdd84e05..d3f0b5c9e5ba6b 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -542,14 +542,14 @@ void VPlanTransforms::prepareExecute(VPlan &Plan) {
         auto *MulAcc = cast<VPMulAccRecipe>(&R);
         VPValue *Op0, *Op1;
         if (MulAcc->isExtended()) {
-          Op0 = new VPWidenCastRecipe(
-              MulAcc->getExtOpcode(), MulAcc->getVecOp0(),
-              MulAcc->getResultType(), *MulAcc->getExt0Instr());
+          Op0 =
+              new VPWidenCastRecipe(MulAcc->getExtOpcode(), MulAcc->getVecOp0(),
+                                    MulAcc->getResultType());
           Op0->getDefiningRecipe()->insertBefore(MulAcc);
           if (!MulAcc->isSameExtend()) {
-            Op1 = new VPWidenCastRecipe(
-                MulAcc->getExtOpcode(), MulAcc->getVecOp1(),
-                MulAcc->getResultType(), *MulAcc->getExt1Instr());
+            Op1 = new VPWidenCastRecipe(MulAcc->getExtOpcode(),
+                                        MulAcc->getVecOp1(),
+                                        MulAcc->getResultType());
             Op1->getDefiningRecipe()->insertBefore(MulAcc);
           } else {
             Op1 = Op0;
@@ -567,8 +567,7 @@ void VPlanTransforms::prepareExecute(VPlan &Plan) {
         if (auto *OuterExtInstr = MulAcc->getExtInstr())
           VecOp = new VPWidenCastRecipe(
               MulAcc->getExtOpcode(), Mul,
-              MulAcc->getRecurrenceDescriptor().getRecurrenceType(),
-              *OuterExtInstr);
+              MulAcc->getRecurrenceDescriptor().getRecurrenceType());
         else
           VecOp = Mul;
         auto *Red = new VPReductionRecipe(
diff --git a/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll b/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
index 9b7c345076480a..b138666ba92f9b 100644
--- a/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
+++ b/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
@@ -1535,7 +1535,8 @@ define i64 @mla_and_add_together_16_64(ptr nocapture noundef readonly %x, i32 no
 ; CHECK-NEXT:    [[TMP3:%.*]] = zext nneg <8 x i32> [[TMP2]] to <8 x i64>
 ; CHECK-NEXT:    [[TMP4:%.*]] = call i64 @llvm.vector.reduce.add.v8i64(<8 x i64> [[TMP3]])
 ; CHECK-NEXT:    [[TMP5]] = add i64 [[TMP4]], [[VEC_PHI]]
-; CHECK-NEXT:    [[TMP6:%.*]] = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> [[TMP1]])
+; CHECK-NEXT:    [[TMP10:%.*]] = sext <8 x i16> [[WIDE_LOAD]] to <8 x i32>
+; CHECK-NEXT:    [[TMP6:%.*]] = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> [[TMP10]])
 ; CHECK-NEXT:    [[TMP7]] = add i32 [[TMP6]], [[VEC_PHI1]]
 ; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 8
 ; CHECK-NEXT:    [[TMP8:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]
diff --git a/llvm/test/Transforms/LoopVectorize/reduction-inloop.ll b/llvm/test/Transforms/LoopVectorize/reduction-inloop.ll
index b578e61d85dfa1..d6dbb74f26d4ab 100644
--- a/llvm/test/Transforms/LoopVectorize/reduction-inloop.ll
+++ b/llvm/test/Transforms/LoopVectorize/reduction-inloop.ll
@@ -1206,15 +1206,13 @@ define i32 @predicated_not_dominates_reduction_twoadd(ptr nocapture noundef read
 ; CHECK-NEXT:    [[TMP0:%.*]] = sext i32 [[INDEX]] to i64
 ; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds i8, ptr [[H:%.*]], i64 [[TMP0]]
 ; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <4 x i8>, ptr [[TMP1]], align 1
-; CHECK-NEXT:    [[DOTNOT:%.*]] = icmp eq <4 x i8> [[WIDE_LOAD]], zeroinitializer
 ; CHECK-NEXT:    [[TMP2:%.*]] = udiv <4 x i8> [[WIDE_LOAD]], splat (i8 31)
 ; CHECK-NEXT:    [[TMP3:%.*]] = shl nuw nsw <4 x i8> [[TMP2]], splat (i8 3)
 ; CHECK-NEXT:    [[TMP4:%.*]] = udiv <4 x i8> [[TMP3]], splat (i8 31)
 ; CHECK-NEXT:    [[TMP5:%.*]] = zext nneg <4 x i8> [[TMP4]] to <4 x i32>
-; CHECK-NEXT:    [[TMP6:%.*]] = select <4 x i1> [[DOTNOT]], <4 x i32> zeroinitializer, <4 x i32> [[TMP5]]
-; CHECK-NEXT:    [[TMP7:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP6]])
+; CHECK-NEXT:    [[TMP7:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP5]])
 ; CHECK-NEXT:    [[TMP8:%.*]] = add i32 [[TMP7]], [[VEC_PHI]]
-; CHECK-NEXT:    [[TMP9:%.*]] = select <4 x i1> [[DOTNOT]], <4 x i32> zeroinitializer, <4 x i32> [[TMP5]]
+; CHECK-NEXT:    [[TMP9:%.*]] = zext nneg <4 x i8> [[TMP4]] to <4 x i32>
 ; CHECK-NEXT:    [[TMP10:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP9]])
 ; CHECK-NEXT:    [[TMP11]] = add i32 [[TMP10]], [[TMP8]]
 ; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 4

>From 1df91d490fef8cf29a724110bd59f17b1075ab84 Mon Sep 17 00:00:00 2001
From: Elvis Wang <elvis.wang at sifive.com>
Date: Sun, 10 Nov 2024 23:34:05 -0800
Subject: [PATCH 08/17] Refactor! Reuse functions from VPReductionRecipe.

---
 llvm/lib/Transforms/Vectorize/VPlan.h         | 114 +++++-------------
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp |   6 +-
 .../Transforms/Vectorize/VPlanTransforms.cpp  |  20 +--
 3 files changed, 47 insertions(+), 93 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 504021d0f621d3..2e8786cd569205 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -684,8 +684,6 @@ struct VPCostContext {
   LLVMContext &LLVMCtx;
   LoopVectorizationCostModel &CM;
   SmallPtrSet<Instruction *, 8> SkipCostComputation;
-  /// Contains recipes that are folded into other recipes.
-  SmallDenseMap<ElementCount, SmallPtrSet<VPRecipeBase *, 4>, 4> FoldedRecipes;
 
   VPCostContext(const TargetTransformInfo &TTI, const TargetLibraryInfo &TLI,
                 Type *CanIVTy, LoopVectorizationCostModel &CM)
@@ -2616,6 +2614,8 @@ class VPReductionRecipe : public VPSingleDefRecipe {
                                  getVecOp(), getCondOp(), IsOrdered);
   }
 
+  // TODO: Support VPExtendedReductionRecipe and VPMulAccRecipe after EVL
+  // support.
   static inline bool classof(const VPRecipeBase *R) {
     return R->getVPDefID() == VPRecipeBase::VPReductionSC ||
            R->getVPDefID() == VPRecipeBase::VPReductionEVLSC;
@@ -2703,33 +2703,20 @@ class VPReductionEVLRecipe : public VPReductionRecipe {
 /// a chain. This recipe is high level abstract which will generate
 /// VPReductionRecipe and VPWidenCastRecipe before execution. The Operands are
 /// {ChainOp, VecOp, [Condition]}.
-class VPExtendedReductionRecipe : public VPSingleDefRecipe {
-  /// The recurrence decriptor for the reduction in question.
-  const RecurrenceDescriptor &RdxDesc;
-  bool IsOrdered;
-  /// Whether the reduction is conditional.
-  bool IsConditional = false;
+class VPExtendedReductionRecipe : public VPReductionRecipe {
   /// Type after extend.
   Type *ResultTy;
-  /// Opcode for the extend instruction.
-  Instruction::CastOps ExtOp;
   CastInst *ExtInstr;
-  bool IsZExt;
 
 protected:
   VPExtendedReductionRecipe(const unsigned char SC,
                             const RecurrenceDescriptor &R, Instruction *RedI,
-                            Instruction::CastOps ExtOp, CastInst *ExtI,
-                            ArrayRef<VPValue *> Operands, VPValue *CondOp,
+                            Instruction::CastOps ExtOp, CastInst *ExtInstr,
+                            VPValue *ChainOp, VPValue *VecOp, VPValue *CondOp,
                             bool IsOrdered, Type *ResultTy)
-      : VPSingleDefRecipe(SC, Operands, RedI), RdxDesc(R), IsOrdered(IsOrdered),
-        ResultTy(ResultTy), ExtOp(ExtOp), ExtInstr(ExtI) {
-    if (CondOp) {
-      IsConditional = true;
-      addOperand(CondOp);
-    }
-    IsZExt = ExtOp == Instruction::CastOps::ZExt;
-  }
+      : VPReductionRecipe(SC, R, RedI, ArrayRef<VPValue *>({ChainOp, VecOp}),
+                          CondOp, IsOrdered),
+        ResultTy(ResultTy), ExtInstr(ExtInstr) {}
 
 public:
   VPExtendedReductionRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
@@ -2737,9 +2724,8 @@ class VPExtendedReductionRecipe : public VPSingleDefRecipe {
                             VPValue *CondOp, bool IsOrdered)
       : VPExtendedReductionRecipe(
             VPDef::VPExtendedReductionSC, R, RedI, Ext->getOpcode(),
-            cast<CastInst>(Ext->getUnderlyingInstr()),
-            ArrayRef<VPValue *>({ChainOp, Ext->getOperand(0)}), CondOp,
-            IsOrdered, Ext->getResultType()) {}
+            cast<CastInst>(Ext->getUnderlyingInstr()), ChainOp,
+            Ext->getOperand(0), CondOp, IsOrdered, Ext->getResultType()) {}
 
   ~VPExtendedReductionRecipe() override = default;
 
@@ -2771,26 +2757,11 @@ class VPExtendedReductionRecipe : public VPSingleDefRecipe {
              VPSlotTracker &SlotTracker) const override;
 #endif
 
-  /// Return the recurrence decriptor for the in-loop reduction.
-  const RecurrenceDescriptor &getRecurrenceDescriptor() const {
-    return RdxDesc;
-  }
-  /// Return true if the in-loop reduction is ordered.
-  bool isOrdered() const { return IsOrdered; };
-  /// Return true if the in-loop reduction is conditional.
-  bool isConditional() const { return IsConditional; };
-  /// The VPValue of the scalar Chain being accumulated.
-  VPValue *getChainOp() const { return getOperand(0); }
-  /// The VPValue of the vector value to be extended and reduced.
-  VPValue *getVecOp() const { return getOperand(1); }
-  /// The VPValue of the condition for the block.
-  VPValue *getCondOp() const {
-    return isConditional() ? getOperand(getNumOperands() - 1) : nullptr;
-  }
   /// The Type after extended.
   Type *getResultType() const { return ResultTy; };
+  bool isZExt() const { return getExtOpcode() == Instruction::ZExt; };
   /// The Opcode of extend instruction.
-  Instruction::CastOps getExtOpcode() const { return ExtOp; };
+  Instruction::CastOps getExtOpcode() const { return ExtInstr->getOpcode(); };
   /// The CastInst of the extend instruction.
   CastInst *getExtInstr() const { return ExtInstr; };
 };
@@ -2800,12 +2771,7 @@ class VPExtendedReductionRecipe : public VPSingleDefRecipe {
 /// a chain. This recipe is high level abstract which will generate
 /// VPReductionRecipe VPWidenRecipe(mul) and VPWidenCastRecipes before
 /// execution. The Operands are {ChainOp, VecOp1, VecOp2, [Condition]}.
-class VPMulAccRecipe : public VPSingleDefRecipe {
-  /// The recurrence decriptor for the reduction in question.
-  const RecurrenceDescriptor &RdxDesc;
-  bool IsOrdered;
-  /// Whether the reduction is conditional.
-  bool IsConditional = false;
+class VPMulAccRecipe : public VPReductionRecipe {
   /// Type after extend.
   Type *ResultType;
   // Note that all extend instruction must have the same opcode in MulAcc.
@@ -2827,32 +2793,29 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
                  Instruction *RedI, Instruction *ExtInstr,
                  Instruction *MulInstr, Instruction::CastOps ExtOp,
                  Instruction *Ext0Instr, Instruction *Ext1Instr,
-                 ArrayRef<VPValue *> Operands, VPValue *CondOp, bool IsOrdered,
-                 Type *ResultType)
-      : VPSingleDefRecipe(SC, Operands, RedI), RdxDesc(R), IsOrdered(IsOrdered),
+                 VPValue *ChainOp, VPValue *VecOp0, VPValue *VecOp1,
+                 VPValue *CondOp, bool IsOrdered, Type *ResultType)
+      : VPReductionRecipe(SC, R, RedI,
+                          ArrayRef<VPValue *>({ChainOp, VecOp0, VecOp1}),
+                          CondOp, IsOrdered),
         ResultType(ResultType), ExtOp(ExtOp), MulInstr(MulInstr),
         ExtInstr(cast_if_present<CastInst>(ExtInstr)),
         Ext0Instr(cast<CastInst>(Ext0Instr)),
         Ext1Instr(cast<CastInst>(Ext1Instr)) {
     assert(MulInstr->getOpcode() == Instruction::Mul);
-    if (CondOp) {
-      IsConditional = true;
-      addOperand(CondOp);
-    }
     IsExtended = true;
     IsOuterExtended = ExtInstr != nullptr;
   }
 
   VPMulAccRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
-                 Instruction *RedI, Instruction *MulInstr,
-                 ArrayRef<VPValue *> Operands, VPValue *CondOp, bool IsOrdered)
-      : VPSingleDefRecipe(SC, Operands, RedI), RdxDesc(R), IsOrdered(IsOrdered),
+                 Instruction *RedI, Instruction *MulInstr, VPValue *ChainOp,
+                 VPValue *VecOp0, VPValue *VecOp1, VPValue *CondOp,
+                 bool IsOrdered)
+      : VPReductionRecipe(SC, R, RedI,
+                          ArrayRef<VPValue *>({ChainOp, VecOp0, VecOp1}),
+                          CondOp, IsOrdered),
         MulInstr(MulInstr) {
     assert(MulInstr->getOpcode() == Instruction::Mul);
-    if (CondOp) {
-      IsConditional = true;
-      addOperand(CondOp);
-    }
     IsExtended = false;
   }
 
@@ -2864,17 +2827,15 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
       : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, nullptr,
                        Mul->getUnderlyingInstr(), Ext0->getOpcode(),
                        Ext0->getUnderlyingInstr(), Ext1->getUnderlyingInstr(),
-                       ArrayRef<VPValue *>(
-                           {ChainOp, Ext0->getOperand(0), Ext1->getOperand(0)}),
+                       ChainOp, Ext0->getOperand(0), Ext1->getOperand(0),
                        CondOp, IsOrdered, Ext0->getResultType()) {}
 
   VPMulAccRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
                  VPValue *ChainOp, VPValue *CondOp, bool IsOrdered,
                  VPWidenRecipe *Mul)
       : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, Mul->getUnderlyingInstr(),
-                       ArrayRef<VPValue *>(
-                           {ChainOp, Mul->getOperand(0), Mul->getOperand(1)}),
-                       CondOp, IsOrdered) {}
+                       ChainOp, Mul->getOperand(0), Mul->getOperand(1), CondOp,
+                       IsOrdered) {}
 
   VPMulAccRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
                  VPValue *ChainOp, VPValue *CondOp, bool IsOrdered,
@@ -2883,8 +2844,7 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
       : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, Ext->getUnderlyingInstr(),
                        Mul->getUnderlyingInstr(), Ext0->getOpcode(),
                        Ext0->getUnderlyingInstr(), Ext1->getUnderlyingInstr(),
-                       ArrayRef<VPValue *>(
-                           {ChainOp, Ext0->getOperand(0), Ext1->getOperand(0)}),
+                       ChainOp, Ext0->getOperand(0), Ext1->getOperand(0),
                        CondOp, IsOrdered, Ext0->getResultType()) {}
 
   ~VPMulAccRecipe() override = default;
@@ -2915,37 +2875,27 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
              VPSlotTracker &SlotTracker) const override;
 #endif
 
-  /// Return the recurrence decriptor for the in-loop reduction.
-  const RecurrenceDescriptor &getRecurrenceDescriptor() const {
-    return RdxDesc;
-  }
-  /// Return true if the in-loop reduction is ordered.
-  bool isOrdered() const { return IsOrdered; };
-  /// Return true if the in-loop reduction is conditional.
-  bool isConditional() const { return IsConditional; };
-  /// The VPValue of the scalar Chain being accumulated.
-  VPValue *getChainOp() const { return getOperand(0); }
   /// The VPValue of the vector value to be extended and reduced.
   VPValue *getVecOp0() const { return getOperand(1); }
   VPValue *getVecOp1() const { return getOperand(2); }
-  /// The VPValue of the condition for the block.
-  VPValue *getCondOp() const {
-    return isConditional() ? getOperand(getNumOperands() - 1) : nullptr;
-  }
+
   /// Return the type after inner extended, which must equal to the type of mul
   /// instruction. If the ResultType != recurrenceType, than it must have a
   /// extend recipe after mul recipe.
   Type *getResultType() const { return ResultType; };
+
   /// The opcode of the extend instructions.
   Instruction::CastOps getExtOpcode() const { return ExtOp; };
   /// The underlying instruction for VPWidenRecipe.
   Instruction *getMulInstr() const { return MulInstr; };
+
   /// The underlying Instruction for outer VPWidenCastRecipe.
   CastInst *getExtInstr() const { return ExtInstr; };
   /// The underlying Instruction for inner VPWidenCastRecipe.
   CastInst *getExt0Instr() const { return Ext0Instr; };
   /// The underlying Instruction for inner VPWidenCastRecipe.
   CastInst *getExt1Instr() const { return Ext1Instr; };
+
   /// Return if this MulAcc recipe contains extend instructions.
   bool isExtended() const { return IsExtended; };
   /// Return if the operands of mul instruction come from same extend.
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index a85a4bc81c2854..c7c50702f90b44 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -2221,6 +2221,7 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
 InstructionCost
 VPExtendedReductionRecipe::computeCost(ElementCount VF,
                                        VPCostContext &Ctx) const {
+  const RecurrenceDescriptor &RdxDesc = getRecurrenceDescriptor();
   RecurKind RdxKind = RdxDesc.getRecurrenceKind();
   Type *ElementTy = getResultType();
   auto *VectorTy = cast<VectorType>(ToVectorTy(ElementTy, VF));
@@ -2251,7 +2252,7 @@ VPExtendedReductionRecipe::computeCost(ElementCount VF,
 
   // ExtendedReduction Cost
   InstructionCost ExtendedRedCost = Ctx.TTI.getExtendedReductionCost(
-      Opcode, IsZExt, ElementTy, SrcTy, RdxDesc.getFastMathFlags(), CostKind);
+      Opcode, isZExt(), ElementTy, SrcTy, RdxDesc.getFastMathFlags(), CostKind);
   // Check if folding ext into ExtendedReduction is profitable.
   if (ExtendedRedCost.isValid() &&
       ExtendedRedCost < ExtendedCost + ReductionCost) {
@@ -2262,6 +2263,7 @@ VPExtendedReductionRecipe::computeCost(ElementCount VF,
 
 InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
                                             VPCostContext &Ctx) const {
+  const RecurrenceDescriptor &RdxDesc = getRecurrenceDescriptor();
   Type *ElementTy = IsExtended ? RdxDesc.getRecurrenceType()
                                : Ctx.Types.inferScalarType(getVecOp0());
   auto *VectorTy = cast<VectorType>(ToVectorTy(ElementTy, VF));
@@ -2387,6 +2389,7 @@ void VPReductionEVLRecipe::print(raw_ostream &O, const Twine &Indent,
 
 void VPExtendedReductionRecipe::print(raw_ostream &O, const Twine &Indent,
                                       VPSlotTracker &SlotTracker) const {
+  const RecurrenceDescriptor &RdxDesc = getRecurrenceDescriptor();
   O << Indent << "EXTENDED-REDUCE ";
   printAsOperand(O, SlotTracker);
   O << " = ";
@@ -2409,6 +2412,7 @@ void VPExtendedReductionRecipe::print(raw_ostream &O, const Twine &Indent,
 
 void VPMulAccRecipe::print(raw_ostream &O, const Twine &Indent,
                            VPSlotTracker &SlotTracker) const {
+  const RecurrenceDescriptor &RdxDesc = getRecurrenceDescriptor();
   O << Indent << "MULACC-REDUCE ";
   printAsOperand(O, SlotTracker);
   O << " = ";
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index d3f0b5c9e5ba6b..5e9c30911e34e5 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -525,8 +525,7 @@ void VPlanTransforms::prepareExecute(VPlan &Plan) {
   for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>(
            vp_depth_first_deep(Plan.getEntry()))) {
     for (VPRecipeBase &R : make_early_inc_range(*VPBB)) {
-      if (isa<VPExtendedReductionRecipe>(&R)) {
-        auto *ExtRed = cast<VPExtendedReductionRecipe>(&R);
+      if (auto *ExtRed = dyn_cast<VPExtendedReductionRecipe>(&R)) {
         auto *Ext = new VPWidenCastRecipe(
             ExtRed->getExtOpcode(), ExtRed->getVecOp(), ExtRed->getResultType(),
             *ExtRed->getExtInstr());
@@ -542,14 +541,14 @@ void VPlanTransforms::prepareExecute(VPlan &Plan) {
         auto *MulAcc = cast<VPMulAccRecipe>(&R);
         VPValue *Op0, *Op1;
         if (MulAcc->isExtended()) {
-          Op0 =
-              new VPWidenCastRecipe(MulAcc->getExtOpcode(), MulAcc->getVecOp0(),
-                                    MulAcc->getResultType());
+          CastInst *Ext0 = MulAcc->getExt0Instr();
+          Op0 = new VPWidenCastRecipe(Ext0->getOpcode(), MulAcc->getVecOp0(),
+                                      MulAcc->getResultType(), *Ext0);
           Op0->getDefiningRecipe()->insertBefore(MulAcc);
           if (!MulAcc->isSameExtend()) {
-            Op1 = new VPWidenCastRecipe(MulAcc->getExtOpcode(),
-                                        MulAcc->getVecOp1(),
-                                        MulAcc->getResultType());
+            CastInst *Ext1 = MulAcc->getExt1Instr();
+            Op1 = new VPWidenCastRecipe(Ext1->getOpcode(), MulAcc->getVecOp1(),
+                                        MulAcc->getResultType(), *Ext1);
             Op1->getDefiningRecipe()->insertBefore(MulAcc);
           } else {
             Op1 = Op0;
@@ -566,8 +565,9 @@ void VPlanTransforms::prepareExecute(VPlan &Plan) {
         // Outer extend.
         if (auto *OuterExtInstr = MulAcc->getExtInstr())
           VecOp = new VPWidenCastRecipe(
-              MulAcc->getExtOpcode(), Mul,
-              MulAcc->getRecurrenceDescriptor().getRecurrenceType());
+              OuterExtInstr->getOpcode(), Mul,
+              MulAcc->getRecurrenceDescriptor().getRecurrenceType(),
+              *OuterExtInstr);
         else
           VecOp = Mul;
         auto *Red = new VPReductionRecipe(

>From a0b2f30cd6596e687e68e7ffe5a3407e35a77f05 Mon Sep 17 00:00:00 2001
From: Elvis Wang <elvis.wang at sifive.com>
Date: Mon, 11 Nov 2024 23:17:46 -0800
Subject: [PATCH 09/17] Refactor! Add comments and refine new recipes.

---
 .../Transforms/Vectorize/LoopVectorize.cpp    | 20 +++---
 llvm/lib/Transforms/Vectorize/VPlan.h         | 72 +++++++++----------
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp | 28 ++++----
 .../Transforms/Vectorize/VPlanTransforms.cpp  | 35 ++++++---
 4 files changed, 85 insertions(+), 70 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index b59fe696c393ac..5cee4b7e2ce70e 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -9397,17 +9397,18 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
       if (CM.blockNeedsPredicationForAnyReason(BB))
         CondOp = RecipeBuilder.getBlockInMask(BB);
 
-      auto TryToMatchMulAcc = [&]() -> VPSingleDefRecipe * {
+      auto TryToMatchMulAcc = [&]() -> VPReductionRecipe * {
         VPValue *A, *B;
         if (RdxDesc.getOpcode() != Instruction::Add)
           return nullptr;
-        // reduce.add(mul(ext, ext)) can folded into VPMulAccRecipe
+        // Try to match reduce.add(mul(...))
         if (match(VecOp, m_Mul(m_VPValue(A), m_VPValue(B))) &&
             !VecOp->hasMoreThanOneUniqueUser()) {
           VPWidenCastRecipe *RecipeA =
               dyn_cast_if_present<VPWidenCastRecipe>(A->getDefiningRecipe());
           VPWidenCastRecipe *RecipeB =
               dyn_cast_if_present<VPWidenCastRecipe>(B->getDefiningRecipe());
+          // Matched reduce.add(mul(ext, ext))
           if (RecipeA && RecipeB && match(RecipeA, m_ZExtOrSExt(m_VPValue())) &&
               match(RecipeB, m_ZExtOrSExt(m_VPValue())) &&
               (RecipeA->getOpcode() == RecipeB->getOpcode() || A == B)) {
@@ -9417,23 +9418,23 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
                 cast<VPWidenRecipe>(VecOp->getDefiningRecipe()), RecipeA,
                 RecipeB);
           } else {
-            // Matched reduce.add(mul(...))
+            // Matched reduce.add(mul)
             return new VPMulAccRecipe(
                 RdxDesc, CurrentLinkI, PreviousLink, CondOp,
                 CM.useOrderedReductions(RdxDesc),
                 cast<VPWidenRecipe>(VecOp->getDefiningRecipe()));
           }
           // Matched reduce.add(ext(mul(ext(A), ext(B))))
-          // Note that 3 extend instructions must have same opcode or A == B
+          // Note that all extend instructions must have same opcode or A == B
           // which can be transform to reduce.add(zext(mul(sext(A), sext(B)))).
         } else if (match(VecOp,
                          m_ZExtOrSExt(m_Mul(m_ZExtOrSExt(m_VPValue()),
                                             m_ZExtOrSExt(m_VPValue())))) &&
                    !VecOp->hasMoreThanOneUniqueUser()) {
           VPWidenCastRecipe *Ext =
-              dyn_cast<VPWidenCastRecipe>(VecOp->getDefiningRecipe());
+              cast<VPWidenCastRecipe>(VecOp->getDefiningRecipe());
           VPWidenRecipe *Mul =
-              dyn_cast<VPWidenRecipe>(Ext->getOperand(0)->getDefiningRecipe());
+              cast<VPWidenRecipe>(Ext->getOperand(0)->getDefiningRecipe());
           VPWidenCastRecipe *Ext0 =
               cast<VPWidenCastRecipe>(Mul->getOperand(0)->getDefiningRecipe());
           VPWidenCastRecipe *Ext1 =
@@ -9449,8 +9450,10 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
         }
         return nullptr;
       };
-      auto TryToMatchExtendedReduction = [&]() -> VPSingleDefRecipe * {
+
+      auto TryToMatchExtendedReduction = [&]() -> VPReductionRecipe * {
         VPValue *A;
+        // Try to match reduce(ext(...)).
         if (match(VecOp, m_ZExtOrSExt(m_VPValue(A)))) {
           return new VPExtendedReductionRecipe(
               RdxDesc, CurrentLinkI, PreviousLink,
@@ -9459,7 +9462,8 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
         }
         return nullptr;
       };
-      VPSingleDefRecipe *RedRecipe;
+
+      VPReductionRecipe *RedRecipe;
       if (auto *MulAcc = TryToMatchMulAcc())
         RedRecipe = MulAcc;
       else if (auto *ExtendedRed = TryToMatchExtendedReduction())
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 2e8786cd569205..ab9e6f0393602f 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -2758,12 +2758,12 @@ class VPExtendedReductionRecipe : public VPReductionRecipe {
 #endif
 
   /// The Type after extended.
-  Type *getResultType() const { return ResultTy; };
-  bool isZExt() const { return getExtOpcode() == Instruction::ZExt; };
+  Type *getResultType() const { return ResultTy; }
+  bool isZExt() const { return getExtOpcode() == Instruction::ZExt; }
   /// The Opcode of extend instruction.
-  Instruction::CastOps getExtOpcode() const { return ExtInstr->getOpcode(); };
+  Instruction::CastOps getExtOpcode() const { return ExtInstr->getOpcode(); }
   /// The CastInst of the extend instruction.
-  CastInst *getExtInstr() const { return ExtInstr; };
+  CastInst *getExtInstr() const { return ExtInstr; }
 };
 
 /// A recipe to represent inloop MulAccreduction operations, performing a
@@ -2774,8 +2774,6 @@ class VPExtendedReductionRecipe : public VPReductionRecipe {
 class VPMulAccRecipe : public VPReductionRecipe {
   /// Type after extend.
   Type *ResultType;
-  // Note that all extend instruction must have the same opcode in MulAcc.
-  Instruction::CastOps ExtOp;
 
   /// reduce.add(ext(mul(ext0(), ext1())))
   Instruction *MulInstr;
@@ -2783,28 +2781,21 @@ class VPMulAccRecipe : public VPReductionRecipe {
   CastInst *Ext0Instr = nullptr;
   CastInst *Ext1Instr = nullptr;
 
-  /// Is this MulAcc recipe contains extend recipes?
-  bool IsExtended;
-  /// Is this reciep contains outer extend instuction?
-  bool IsOuterExtended = false;
-
 protected:
   VPMulAccRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
                  Instruction *RedI, Instruction *ExtInstr,
-                 Instruction *MulInstr, Instruction::CastOps ExtOp,
-                 Instruction *Ext0Instr, Instruction *Ext1Instr,
-                 VPValue *ChainOp, VPValue *VecOp0, VPValue *VecOp1,
-                 VPValue *CondOp, bool IsOrdered, Type *ResultType)
+                 Instruction *MulInstr, Instruction *Ext0Instr,
+                 Instruction *Ext1Instr, VPValue *ChainOp, VPValue *VecOp0,
+                 VPValue *VecOp1, VPValue *CondOp, bool IsOrdered,
+                 Type *ResultType)
       : VPReductionRecipe(SC, R, RedI,
                           ArrayRef<VPValue *>({ChainOp, VecOp0, VecOp1}),
                           CondOp, IsOrdered),
-        ResultType(ResultType), ExtOp(ExtOp), MulInstr(MulInstr),
+        ResultType(ResultType), MulInstr(MulInstr),
         ExtInstr(cast_if_present<CastInst>(ExtInstr)),
         Ext0Instr(cast<CastInst>(Ext0Instr)),
         Ext1Instr(cast<CastInst>(Ext1Instr)) {
     assert(MulInstr->getOpcode() == Instruction::Mul);
-    IsExtended = true;
-    IsOuterExtended = ExtInstr != nullptr;
   }
 
   VPMulAccRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
@@ -2816,7 +2807,6 @@ class VPMulAccRecipe : public VPReductionRecipe {
                           CondOp, IsOrdered),
         MulInstr(MulInstr) {
     assert(MulInstr->getOpcode() == Instruction::Mul);
-    IsExtended = false;
   }
 
 public:
@@ -2825,10 +2815,10 @@ class VPMulAccRecipe : public VPReductionRecipe {
                  VPWidenRecipe *Mul, VPWidenCastRecipe *Ext0,
                  VPWidenCastRecipe *Ext1)
       : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, nullptr,
-                       Mul->getUnderlyingInstr(), Ext0->getOpcode(),
-                       Ext0->getUnderlyingInstr(), Ext1->getUnderlyingInstr(),
-                       ChainOp, Ext0->getOperand(0), Ext1->getOperand(0),
-                       CondOp, IsOrdered, Ext0->getResultType()) {}
+                       Mul->getUnderlyingInstr(), Ext0->getUnderlyingInstr(),
+                       Ext1->getUnderlyingInstr(), ChainOp, Ext0->getOperand(0),
+                       Ext1->getOperand(0), CondOp, IsOrdered,
+                       Ext0->getResultType()) {}
 
   VPMulAccRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
                  VPValue *ChainOp, VPValue *CondOp, bool IsOrdered,
@@ -2842,10 +2832,10 @@ class VPMulAccRecipe : public VPReductionRecipe {
                  VPWidenCastRecipe *Ext, VPWidenRecipe *Mul,
                  VPWidenCastRecipe *Ext0, VPWidenCastRecipe *Ext1)
       : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, Ext->getUnderlyingInstr(),
-                       Mul->getUnderlyingInstr(), Ext0->getOpcode(),
-                       Ext0->getUnderlyingInstr(), Ext1->getUnderlyingInstr(),
-                       ChainOp, Ext0->getOperand(0), Ext1->getOperand(0),
-                       CondOp, IsOrdered, Ext0->getResultType()) {}
+                       Mul->getUnderlyingInstr(), Ext0->getUnderlyingInstr(),
+                       Ext1->getUnderlyingInstr(), ChainOp, Ext0->getOperand(0),
+                       Ext1->getOperand(0), CondOp, IsOrdered,
+                       Ext0->getResultType()) {}
 
   ~VPMulAccRecipe() override = default;
 
@@ -2882,24 +2872,34 @@ class VPMulAccRecipe : public VPReductionRecipe {
   /// Return the type after inner extended, which must equal to the type of mul
   /// instruction. If the ResultType != recurrenceType, than it must have a
   /// extend recipe after mul recipe.
-  Type *getResultType() const { return ResultType; };
+  Type *getResultType() const { return ResultType; }
 
-  /// The opcode of the extend instructions.
-  Instruction::CastOps getExtOpcode() const { return ExtOp; };
   /// The underlying instruction for VPWidenRecipe.
-  Instruction *getMulInstr() const { return MulInstr; };
+  Instruction *getMulInstr() const { return MulInstr; }
 
   /// The underlying Instruction for outer VPWidenCastRecipe.
-  CastInst *getExtInstr() const { return ExtInstr; };
+  CastInst *getExtInstr() const { return ExtInstr; }
   /// The underlying Instruction for inner VPWidenCastRecipe.
-  CastInst *getExt0Instr() const { return Ext0Instr; };
+  CastInst *getExt0Instr() const { return Ext0Instr; }
   /// The underlying Instruction for inner VPWidenCastRecipe.
-  CastInst *getExt1Instr() const { return Ext1Instr; };
+  CastInst *getExt1Instr() const { return Ext1Instr; }
 
   /// Return if this MulAcc recipe contains extend instructions.
-  bool isExtended() const { return IsExtended; };
+  bool isExtended() const { return Ext0Instr && Ext1Instr; }
   /// Return if the operands of mul instruction come from same extend.
-  bool isSameExtend() const { return Ext0Instr == Ext1Instr; };
+  bool isSameExtend() const { return Ext0Instr == Ext1Instr; }
+  /// Return true if this MulAcc recipe contains an extend after the mul.
+  bool isOuterExtended() const { return ExtInstr != nullptr; }
+  /// Return if the extend opcode is ZExt.
+  bool isZExt() const {
+    if (!isExtended())
+      return true;
+    // reduce.add(sext(mul(zext(A), zext(A)))) can be transformed to
+    // reduce.add(zext(mul(sext(A), sext(A))))
+    if (ExtInstr && ExtInstr->getOpcode() != Ext0Instr->getOpcode())
+      return true;
+    return Ext0Instr->getOpcode() == Instruction::CastOps::ZExt;
+  }
 };
 
 /// VPReplicateRecipe replicates a given instruction producing multiple scalar
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index c7c50702f90b44..5aef2089ac37d8 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -2264,8 +2264,8 @@ VPExtendedReductionRecipe::computeCost(ElementCount VF,
 InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
                                             VPCostContext &Ctx) const {
   const RecurrenceDescriptor &RdxDesc = getRecurrenceDescriptor();
-  Type *ElementTy = IsExtended ? RdxDesc.getRecurrenceType()
-                               : Ctx.Types.inferScalarType(getVecOp0());
+  Type *ElementTy = isExtended() ? RdxDesc.getRecurrenceType()
+                                 : Ctx.Types.inferScalarType(getVecOp0());
   auto *VectorTy = cast<VectorType>(ToVectorTy(ElementTy, VF));
   TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
   unsigned Opcode = RdxDesc.getOpcode();
@@ -2281,20 +2281,19 @@ InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
 
   // Extended cost
   InstructionCost ExtendedCost = 0;
-  if (IsExtended) {
+  if (isExtended()) {
     auto *SrcTy = cast<VectorType>(
         ToVectorTy(Ctx.Types.inferScalarType(getVecOp0()), VF));
     auto *DestTy = cast<VectorType>(ToVectorTy(getResultType(), VF));
     TTI::CastContextHint CCH0 =
         computeCCH(getVecOp0()->getDefiningRecipe(), VF);
-    // Arm TTI will use the underlying instruction to determine the cost.
     ExtendedCost = Ctx.TTI.getCastInstrCost(
-        ExtOp, DestTy, SrcTy, CCH0, TTI::TCK_RecipThroughput,
+        Ext0Instr->getOpcode(), DestTy, SrcTy, CCH0, TTI::TCK_RecipThroughput,
         dyn_cast_if_present<Instruction>(getExt0Instr()));
     TTI::CastContextHint CCH1 =
         computeCCH(getVecOp0()->getDefiningRecipe(), VF);
     ExtendedCost += Ctx.TTI.getCastInstrCost(
-        ExtOp, DestTy, SrcTy, CCH1, TTI::TCK_RecipThroughput,
+        Ext1Instr->getOpcode(), DestTy, SrcTy, CCH1, TTI::TCK_RecipThroughput,
         dyn_cast_if_present<Instruction>(getExt1Instr()));
   }
 
@@ -2302,7 +2301,7 @@ InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
   InstructionCost MulCost;
   SmallVector<const Value *, 4> Operands;
   Operands.append(MulInstr->value_op_begin(), MulInstr->value_op_end());
-  if (IsExtended)
+  if (isExtended())
     MulCost = Ctx.TTI.getArithmeticInstrCost(
         Instruction::Mul, VectorTy, CostKind,
         {TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None},
@@ -2329,9 +2328,8 @@ InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
   // MulAccReduction Cost
   VectorType *SrcVecTy =
       cast<VectorType>(ToVectorTy(Ctx.Types.inferScalarType(getVecOp0()), VF));
-  InstructionCost MulAccCost = Ctx.TTI.getMulAccReductionCost(
-      getExtOpcode() == Instruction::CastOps::ZExt, ElementTy, SrcVecTy,
-      CostKind);
+  InstructionCost MulAccCost =
+      Ctx.TTI.getMulAccReductionCost(isZExt(), ElementTy, SrcVecTy, CostKind);
 
   // Check if folding ext into ExtendedReduction is profitable.
   if (MulAccCost.isValid() &&
@@ -2420,26 +2418,26 @@ void VPMulAccRecipe::print(raw_ostream &O, const Twine &Indent,
   O << " + ";
   if (isa<FPMathOperator>(getUnderlyingInstr()))
     O << getUnderlyingInstr()->getFastMathFlags();
-  if (IsOuterExtended)
+  if (isOuterExtended())
     O << " (";
   O << "reduce." << Instruction::getOpcodeName(RdxDesc.getOpcode()) << " (";
   O << "mul ";
-  if (IsExtended)
+  if (isExtended())
     O << "(";
   getVecOp0()->printAsOperand(O, SlotTracker);
-  if (IsExtended)
+  if (isExtended())
     O << " extended to " << *getResultType() << "), (";
   else
     O << ", ";
   getVecOp1()->printAsOperand(O, SlotTracker);
-  if (IsExtended)
+  if (isExtended())
     O << " extended to " << *getResultType() << ")";
   if (isConditional()) {
     O << ", ";
     getCondOp()->printAsOperand(O, SlotTracker);
   }
   O << ")";
-  if (IsOuterExtended)
+  if (isOuterExtended())
     O << " extended to " << *RdxDesc.getRecurrenceType() << ")";
   if (RdxDesc.IntermediateStore)
     O << " (with final reduction value stored in invariant address sank "
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index 5e9c30911e34e5..53e080b5bb01d6 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -526,9 +526,12 @@ void VPlanTransforms::prepareExecute(VPlan &Plan) {
            vp_depth_first_deep(Plan.getEntry()))) {
     for (VPRecipeBase &R : make_early_inc_range(*VPBB)) {
       if (auto *ExtRed = dyn_cast<VPExtendedReductionRecipe>(&R)) {
+        // Generate VPWidenCastRecipe.
         auto *Ext = new VPWidenCastRecipe(
             ExtRed->getExtOpcode(), ExtRed->getVecOp(), ExtRed->getResultType(),
             *ExtRed->getExtInstr());
+
+        // Generate VPReductionRecipe.
         auto *Red = new VPReductionRecipe(
             ExtRed->getRecurrenceDescriptor(), ExtRed->getUnderlyingInstr(),
             ExtRed->getChainOp(), Ext, ExtRed->getCondOp(),
@@ -539,45 +542,55 @@ void VPlanTransforms::prepareExecute(VPlan &Plan) {
         ExtRed->eraseFromParent();
       } else if (isa<VPMulAccRecipe>(&R)) {
         auto *MulAcc = cast<VPMulAccRecipe>(&R);
+
+        // Generate inner VPWidenCastRecipes if necessary.
         VPValue *Op0, *Op1;
         if (MulAcc->isExtended()) {
           CastInst *Ext0 = MulAcc->getExt0Instr();
           Op0 = new VPWidenCastRecipe(Ext0->getOpcode(), MulAcc->getVecOp0(),
                                       MulAcc->getResultType(), *Ext0);
           Op0->getDefiningRecipe()->insertBefore(MulAcc);
-          if (!MulAcc->isSameExtend()) {
+          // Avoid generating a duplicate VPWidenCastRecipe for
+          // reduce.add(mul(ext(A), ext(A))).
+          if (MulAcc->isSameExtend()) {
+            Op1 = Op0;
+          } else {
             CastInst *Ext1 = MulAcc->getExt1Instr();
             Op1 = new VPWidenCastRecipe(Ext1->getOpcode(), MulAcc->getVecOp1(),
                                         MulAcc->getResultType(), *Ext1);
             Op1->getDefiningRecipe()->insertBefore(MulAcc);
-          } else {
-            Op1 = Op0;
           }
+          // This MulAccRecipe contains no extend instructions.
         } else {
           Op0 = MulAcc->getVecOp0();
           Op1 = MulAcc->getVecOp1();
         }
+
+        // Generate VPWidenRecipe.
         VPSingleDefRecipe *VecOp;
-        Instruction *MulInstr = MulAcc->getMulInstr();
         SmallVector<VPValue *, 2> MulOps = {Op0, Op1};
-        auto *Mul = new VPWidenRecipe(*MulInstr,
+        auto *Mul = new VPWidenRecipe(*MulAcc->getMulInstr(),
                                       make_range(MulOps.begin(), MulOps.end()));
-        // Outer extend.
-        if (auto *OuterExtInstr = MulAcc->getExtInstr())
+        Mul->insertBefore(MulAcc);
+
+        // Generate outer VPWidenCastRecipe if necessary.
+        if (auto *OuterExtInstr = MulAcc->getExtInstr()) {
           VecOp = new VPWidenCastRecipe(
               OuterExtInstr->getOpcode(), Mul,
               MulAcc->getRecurrenceDescriptor().getRecurrenceType(),
               *OuterExtInstr);
-        else
+          VecOp->insertBefore(MulAcc);
+        } else {
           VecOp = Mul;
+        }
+
+        // Generate VPReductionRecipe.
         auto *Red = new VPReductionRecipe(
             MulAcc->getRecurrenceDescriptor(), MulAcc->getUnderlyingInstr(),
             MulAcc->getChainOp(), VecOp, MulAcc->getCondOp(),
             MulAcc->isOrdered());
-        Mul->insertBefore(MulAcc);
-        if (VecOp != Mul)
-          VecOp->insertBefore(MulAcc);
         Red->insertBefore(MulAcc);
+
         MulAcc->replaceAllUsesWith(Red);
         MulAcc->eraseFromParent();
       }

>From 46928bdc43fe5fc41e4cdaf8e5698967cd076172 Mon Sep 17 00:00:00 2001
From: Elvis Wang <elvis.wang at sifive.com>
Date: Thu, 14 Nov 2024 04:23:02 -0800
Subject: [PATCH 10/17] Remove underlying instruction dependency.

---
 .../Transforms/Vectorize/LoopVectorize.cpp    |  24 ++--
 llvm/lib/Transforms/Vectorize/VPlan.h         | 116 +++++++-----------
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp |  31 ++---
 .../Transforms/Vectorize/VPlanTransforms.cpp  |  33 ++---
 .../LoopVectorize/ARM/mve-reductions.ll       |  89 +++++++++-----
 5 files changed, 137 insertions(+), 156 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 5cee4b7e2ce70e..9c22cbcb055e72 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7397,18 +7397,18 @@ static bool planContainsAdditionalSimplifications(VPlan &Plan,
       if (Instruction *UI = GetInstructionForCost(&R))
         SeenInstrs.insert(UI);
       // VPExtendedReductionRecipe contains a folded extend instruction.
-      if (auto *ExtendedRed = dyn_cast<VPExtendedReductionRecipe>(&R))
-        SeenInstrs.insert(ExtendedRed->getExtInstr());
-      // VPMulAccRecipe constians a mul and otional extend instructions.
-      else if (auto *MulAcc = dyn_cast<VPMulAccRecipe>(&R)) {
-        SeenInstrs.insert(MulAcc->getMulInstr());
-        if (MulAcc->isExtended()) {
-          SeenInstrs.insert(MulAcc->getExt0Instr());
-          SeenInstrs.insert(MulAcc->getExt1Instr());
-          if (auto *Ext = MulAcc->getExtInstr())
-            SeenInstrs.insert(Ext);
-        }
-      }
+      // if (auto *ExtendedRed = dyn_cast<VPExtendedReductionRecipe>(&R))
+      //   SeenInstrs.insert(ExtendedRed->getExtInstr());
+      // // VPMulAccRecipe constians a mul and otional extend instructions.
+      // else if (auto *MulAcc = dyn_cast<VPMulAccRecipe>(&R)) {
+      //   SeenInstrs.insert(MulAcc->getMulInstr());
+      //   if (MulAcc->isExtended()) {
+      //     SeenInstrs.insert(MulAcc->getExt0Instr());
+      //     SeenInstrs.insert(MulAcc->getExt1Instr());
+      //     if (auto *Ext = MulAcc->getExtInstr())
+      //       SeenInstrs.insert(Ext);
+      //   }
+      // }
     }
   }
 
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index ab9e6f0393602f..23a5c2756b6f92 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -1448,11 +1448,20 @@ class VPWidenRecipe : public VPRecipeWithIRFlags {
                 iterator_range<IterT> Operands)
       : VPRecipeWithIRFlags(VPDefOpcode, Operands, I), Opcode(I.getOpcode()) {}
 
+  template <typename IterT>
+  VPWidenRecipe(unsigned VPDefOpcode, unsigned InstrOpcode,
+                iterator_range<IterT> Operands)
+      : VPRecipeWithIRFlags(VPDefOpcode, Operands), Opcode(InstrOpcode) {}
+
 public:
   template <typename IterT>
   VPWidenRecipe(Instruction &I, iterator_range<IterT> Operands)
       : VPWidenRecipe(VPDef::VPWidenSC, I, Operands) {}
 
+  template <typename IterT>
+  VPWidenRecipe(unsigned Opcode, iterator_range<IterT> Operands)
+      : VPWidenRecipe(VPDef::VPWidenSC, Opcode, Operands) {}
+
   ~VPWidenRecipe() override = default;
 
   VPWidenRecipe *clone() override {
@@ -2706,26 +2715,25 @@ class VPReductionEVLRecipe : public VPReductionRecipe {
 class VPExtendedReductionRecipe : public VPReductionRecipe {
   /// Type after extend.
   Type *ResultTy;
-  CastInst *ExtInstr;
+  Instruction::CastOps ExtOp;
 
 protected:
   VPExtendedReductionRecipe(const unsigned char SC,
                             const RecurrenceDescriptor &R, Instruction *RedI,
-                            Instruction::CastOps ExtOp, CastInst *ExtInstr,
-                            VPValue *ChainOp, VPValue *VecOp, VPValue *CondOp,
-                            bool IsOrdered, Type *ResultTy)
+                            Instruction::CastOps ExtOp, VPValue *ChainOp,
+                            VPValue *VecOp, VPValue *CondOp, bool IsOrdered,
+                            Type *ResultTy)
       : VPReductionRecipe(SC, R, RedI, ArrayRef<VPValue *>({ChainOp, VecOp}),
                           CondOp, IsOrdered),
-        ResultTy(ResultTy), ExtInstr(ExtInstr) {}
+        ResultTy(ResultTy), ExtOp(ExtOp) {}
 
 public:
   VPExtendedReductionRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
                             VPValue *ChainOp, VPWidenCastRecipe *Ext,
                             VPValue *CondOp, bool IsOrdered)
-      : VPExtendedReductionRecipe(
-            VPDef::VPExtendedReductionSC, R, RedI, Ext->getOpcode(),
-            cast<CastInst>(Ext->getUnderlyingInstr()), ChainOp,
-            Ext->getOperand(0), CondOp, IsOrdered, Ext->getResultType()) {}
+      : VPExtendedReductionRecipe(VPDef::VPExtendedReductionSC, R, RedI,
+                                  Ext->getOpcode(), ChainOp, Ext->getOperand(0),
+                                  CondOp, IsOrdered, Ext->getResultType()) {}
 
   ~VPExtendedReductionRecipe() override = default;
 
@@ -2761,9 +2769,7 @@ class VPExtendedReductionRecipe : public VPReductionRecipe {
   Type *getResultType() const { return ResultTy; }
   bool isZExt() const { return getExtOpcode() == Instruction::ZExt; }
   /// The Opcode of extend instruction.
-  Instruction::CastOps getExtOpcode() const { return ExtInstr->getOpcode(); }
-  /// The CastInst of the extend instruction.
-  CastInst *getExtInstr() const { return ExtInstr; }
+  Instruction::CastOps getExtOpcode() const { return ExtOp; }
 };
 
 /// A recipe to represent inloop MulAccreduction operations, performing a
@@ -2772,70 +2778,52 @@ class VPExtendedReductionRecipe : public VPReductionRecipe {
 /// VPReductionRecipe VPWidenRecipe(mul) and VPWidenCastRecipes before
 /// execution. The Operands are {ChainOp, VecOp1, VecOp2, [Condition]}.
 class VPMulAccRecipe : public VPReductionRecipe {
-  /// Type after extend.
-  Type *ResultType;
-
-  /// reduce.add(ext(mul(ext0(), ext1())))
-  Instruction *MulInstr;
-  CastInst *ExtInstr = nullptr;
-  CastInst *Ext0Instr = nullptr;
-  CastInst *Ext1Instr = nullptr;
+  Instruction::CastOps ExtOp;
+  bool IsExtended = false;
 
 protected:
   VPMulAccRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
-                 Instruction *RedI, Instruction *ExtInstr,
-                 Instruction *MulInstr, Instruction *Ext0Instr,
-                 Instruction *Ext1Instr, VPValue *ChainOp, VPValue *VecOp0,
-                 VPValue *VecOp1, VPValue *CondOp, bool IsOrdered,
-                 Type *ResultType)
+                 Instruction *RedI, Instruction::CastOps ExtOp,
+                 VPValue *ChainOp, VPValue *VecOp0, VPValue *VecOp1,
+                 VPValue *CondOp, bool IsOrdered)
       : VPReductionRecipe(SC, R, RedI,
                           ArrayRef<VPValue *>({ChainOp, VecOp0, VecOp1}),
                           CondOp, IsOrdered),
-        ResultType(ResultType), MulInstr(MulInstr),
-        ExtInstr(cast_if_present<CastInst>(ExtInstr)),
-        Ext0Instr(cast<CastInst>(Ext0Instr)),
-        Ext1Instr(cast<CastInst>(Ext1Instr)) {
-    assert(MulInstr->getOpcode() == Instruction::Mul);
+        ExtOp(ExtOp) {
+    IsExtended = true;
   }
 
   VPMulAccRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
-                 Instruction *RedI, Instruction *MulInstr, VPValue *ChainOp,
-                 VPValue *VecOp0, VPValue *VecOp1, VPValue *CondOp,
-                 bool IsOrdered)
+                 Instruction *RedI, VPValue *ChainOp, VPValue *VecOp0,
+                 VPValue *VecOp1, VPValue *CondOp, bool IsOrdered)
       : VPReductionRecipe(SC, R, RedI,
                           ArrayRef<VPValue *>({ChainOp, VecOp0, VecOp1}),
-                          CondOp, IsOrdered),
-        MulInstr(MulInstr) {
-    assert(MulInstr->getOpcode() == Instruction::Mul);
-  }
+                          CondOp, IsOrdered) {}
 
 public:
   VPMulAccRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
                  VPValue *ChainOp, VPValue *CondOp, bool IsOrdered,
                  VPWidenRecipe *Mul, VPWidenCastRecipe *Ext0,
                  VPWidenCastRecipe *Ext1)
-      : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, nullptr,
-                       Mul->getUnderlyingInstr(), Ext0->getUnderlyingInstr(),
-                       Ext1->getUnderlyingInstr(), ChainOp, Ext0->getOperand(0),
-                       Ext1->getOperand(0), CondOp, IsOrdered,
-                       Ext0->getResultType()) {}
+      : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI,
+                       cast<CastInst>(Ext0->getUnderlyingInstr())->getOpcode(),
+                       ChainOp, Ext0->getOperand(0), Ext1->getOperand(0),
+                       CondOp, IsOrdered) {}
 
   VPMulAccRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
                  VPValue *ChainOp, VPValue *CondOp, bool IsOrdered,
                  VPWidenRecipe *Mul)
-      : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, Mul->getUnderlyingInstr(),
-                       ChainOp, Mul->getOperand(0), Mul->getOperand(1), CondOp,
-                       IsOrdered) {}
+      : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, ChainOp, Mul->getOperand(0),
+                       Mul->getOperand(1), CondOp, IsOrdered) {}
 
   VPMulAccRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
                  VPValue *ChainOp, VPValue *CondOp, bool IsOrdered,
                  VPWidenCastRecipe *Ext, VPWidenRecipe *Mul,
                  VPWidenCastRecipe *Ext0, VPWidenCastRecipe *Ext1)
-      : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, Ext->getUnderlyingInstr(),
-                       Mul->getUnderlyingInstr(), Ext0->getUnderlyingInstr(),
-                       Ext1->getUnderlyingInstr(), ChainOp, Ext0->getOperand(0),
-                       Ext1->getOperand(0), CondOp, IsOrdered,
-                       Ext0->getResultType()) {}
+      : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI,
+                       cast<CastInst>(Ext0->getUnderlyingInstr())->getOpcode(),
+                       ChainOp, Ext0->getOperand(0), Ext1->getOperand(0),
+                       CondOp, IsOrdered) {}
 
   ~VPMulAccRecipe() override = default;
 
@@ -2869,37 +2857,17 @@ class VPMulAccRecipe : public VPReductionRecipe {
   VPValue *getVecOp0() const { return getOperand(1); }
   VPValue *getVecOp1() const { return getOperand(2); }
 
-  /// Return the type after inner extended, which must equal to the type of mul
-  /// instruction. If the ResultType != recurrenceType, than it must have a
-  /// extend recipe after mul recipe.
-  Type *getResultType() const { return ResultType; }
-
-  /// The underlying instruction for VPWidenRecipe.
-  Instruction *getMulInstr() const { return MulInstr; }
-
-  /// The underlying Instruction for outer VPWidenCastRecipe.
-  CastInst *getExtInstr() const { return ExtInstr; }
-  /// The underlying Instruction for inner VPWidenCastRecipe.
-  CastInst *getExt0Instr() const { return Ext0Instr; }
-  /// The underlying Instruction for inner VPWidenCastRecipe.
-  CastInst *getExt1Instr() const { return Ext1Instr; }
-
   /// Return if this MulAcc recipe contains extend instructions.
-  bool isExtended() const { return Ext0Instr && Ext1Instr; }
+  bool isExtended() const { return IsExtended; }
   /// Return if the operands of mul instruction come from same extend.
-  bool isSameExtend() const { return Ext0Instr == Ext1Instr; }
-  /// Return if the MulAcc recipes contains extend after mul.
-  bool isOuterExtended() const { return ExtInstr != nullptr; }
+  bool isSameExtend() const { return getVecOp0() == getVecOp1(); }
   /// Return if the extend opcode is ZExt.
   bool isZExt() const {
     if (!isExtended())
       return true;
-    // reduce.add(sext(mul(zext(A), zext(A)))) can be transform to
-    // reduce.add(zext(mul(sext(A), sext(A))))
-    if (ExtInstr && ExtInstr->getOpcode() != Ext0Instr->getOpcode())
-      return true;
-    return Ext0Instr->getOpcode() == Instruction::CastOps::ZExt;
+    return ExtOp == Instruction::CastOps::ZExt;
   }
+  Instruction::CastOps getExtOpcode() const { return ExtOp; }
 };
 
 /// VPReplicateRecipe replicates a given instruction producing multiple scalar
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 5aef2089ac37d8..edcc229cbfd5a8 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -2267,6 +2267,8 @@ InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
   Type *ElementTy = isExtended() ? RdxDesc.getRecurrenceType()
                                  : Ctx.Types.inferScalarType(getVecOp0());
   auto *VectorTy = cast<VectorType>(ToVectorTy(ElementTy, VF));
+  auto *SrcVecTy =
+      cast<VectorType>(ToVectorTy(Ctx.Types.inferScalarType(getVecOp0()), VF));
   TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
   unsigned Opcode = RdxDesc.getOpcode();
 
@@ -2284,29 +2286,25 @@ InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
   if (isExtended()) {
     auto *SrcTy = cast<VectorType>(
         ToVectorTy(Ctx.Types.inferScalarType(getVecOp0()), VF));
-    auto *DestTy = cast<VectorType>(ToVectorTy(getResultType(), VF));
     TTI::CastContextHint CCH0 =
         computeCCH(getVecOp0()->getDefiningRecipe(), VF);
-    ExtendedCost = Ctx.TTI.getCastInstrCost(
-        Ext0Instr->getOpcode(), DestTy, SrcTy, CCH0, TTI::TCK_RecipThroughput,
-        dyn_cast_if_present<Instruction>(getExt0Instr()));
+    ExtendedCost = Ctx.TTI.getCastInstrCost(getExtOpcode(), VectorTy, SrcVecTy,
+                                            CCH0, TTI::TCK_RecipThroughput);
     TTI::CastContextHint CCH1 =
         computeCCH(getVecOp0()->getDefiningRecipe(), VF);
-    ExtendedCost += Ctx.TTI.getCastInstrCost(
-        Ext1Instr->getOpcode(), DestTy, SrcTy, CCH1, TTI::TCK_RecipThroughput,
-        dyn_cast_if_present<Instruction>(getExt1Instr()));
+    ExtendedCost += Ctx.TTI.getCastInstrCost(getExtOpcode(), VectorTy, SrcVecTy,
+                                             CCH1, TTI::TCK_RecipThroughput);
   }
 
   // Mul cost
   InstructionCost MulCost;
   SmallVector<const Value *, 4> Operands;
-  Operands.append(MulInstr->value_op_begin(), MulInstr->value_op_end());
   if (isExtended())
     MulCost = Ctx.TTI.getArithmeticInstrCost(
         Instruction::Mul, VectorTy, CostKind,
         {TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None},
         {TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None},
-        Operands, MulInstr, &Ctx.TLI);
+        Operands, nullptr, &Ctx.TLI);
   else {
     VPValue *RHS = getVecOp1();
     // Certain instructions can be cheaper to vectorize if they have a constant
@@ -2319,15 +2317,15 @@ InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
     if (RHSInfo.Kind == TargetTransformInfo::OK_AnyValue &&
         RHS->isDefinedOutsideLoopRegions())
       RHSInfo.Kind = TargetTransformInfo::OK_UniformValue;
+    Operands.append(
+        {getVecOp0()->getUnderlyingValue(), RHS->getUnderlyingValue()});
     MulCost = Ctx.TTI.getArithmeticInstrCost(
         Instruction::Mul, VectorTy, CostKind,
         {TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None},
-        RHSInfo, Operands, MulInstr, &Ctx.TLI);
+        RHSInfo, Operands, nullptr, &Ctx.TLI);
   }
 
   // MulAccReduction Cost
-  VectorType *SrcVecTy =
-      cast<VectorType>(ToVectorTy(Ctx.Types.inferScalarType(getVecOp0()), VF));
   InstructionCost MulAccCost =
       Ctx.TTI.getMulAccReductionCost(isZExt(), ElementTy, SrcVecTy, CostKind);
 
@@ -2411,6 +2409,7 @@ void VPExtendedReductionRecipe::print(raw_ostream &O, const Twine &Indent,
 void VPMulAccRecipe::print(raw_ostream &O, const Twine &Indent,
                            VPSlotTracker &SlotTracker) const {
   const RecurrenceDescriptor &RdxDesc = getRecurrenceDescriptor();
+  Type *ElementTy = RdxDesc.getRecurrenceType();
   O << Indent << "MULACC-REDUCE ";
   printAsOperand(O, SlotTracker);
   O << " = ";
@@ -2418,27 +2417,23 @@ void VPMulAccRecipe::print(raw_ostream &O, const Twine &Indent,
   O << " + ";
   if (isa<FPMathOperator>(getUnderlyingInstr()))
     O << getUnderlyingInstr()->getFastMathFlags();
-  if (isOuterExtended())
-    O << " (";
   O << "reduce." << Instruction::getOpcodeName(RdxDesc.getOpcode()) << " (";
   O << "mul ";
   if (isExtended())
     O << "(";
   getVecOp0()->printAsOperand(O, SlotTracker);
   if (isExtended())
-    O << " extended to " << *getResultType() << "), (";
+    O << " extended to " << *ElementTy << "), (";
   else
     O << ", ";
   getVecOp1()->printAsOperand(O, SlotTracker);
   if (isExtended())
-    O << " extended to " << *getResultType() << ")";
+    O << " extended to " << *ElementTy << ")";
   if (isConditional()) {
     O << ", ";
     getCondOp()->printAsOperand(O, SlotTracker);
   }
   O << ")";
-  if (isOuterExtended())
-    O << " extended to " << *RdxDesc.getRecurrenceType() << ")";
   if (RdxDesc.IntermediateStore)
     O << " (with final reduction value stored in invariant address sank "
          "outside of loop)";
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index 53e080b5bb01d6..50afdeb3aa78c8 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -527,9 +527,9 @@ void VPlanTransforms::prepareExecute(VPlan &Plan) {
     for (VPRecipeBase &R : make_early_inc_range(*VPBB)) {
       if (auto *ExtRed = dyn_cast<VPExtendedReductionRecipe>(&R)) {
         // Genearte VPWidenCastRecipe.
-        auto *Ext = new VPWidenCastRecipe(
-            ExtRed->getExtOpcode(), ExtRed->getVecOp(), ExtRed->getResultType(),
-            *ExtRed->getExtInstr());
+        auto *Ext =
+            new VPWidenCastRecipe(ExtRed->getExtOpcode(), ExtRed->getVecOp(),
+                                  ExtRed->getResultType());
 
         // Generate VPreductionRecipe.
         auto *Red = new VPReductionRecipe(
@@ -542,22 +542,21 @@ void VPlanTransforms::prepareExecute(VPlan &Plan) {
         ExtRed->eraseFromParent();
       } else if (isa<VPMulAccRecipe>(&R)) {
         auto *MulAcc = cast<VPMulAccRecipe>(&R);
+        Type *RedType = MulAcc->getRecurrenceDescriptor().getRecurrenceType();
 
         // Generate inner VPWidenCastRecipes if necessary.
         VPValue *Op0, *Op1;
         if (MulAcc->isExtended()) {
-          CastInst *Ext0 = MulAcc->getExt0Instr();
-          Op0 = new VPWidenCastRecipe(Ext0->getOpcode(), MulAcc->getVecOp0(),
-                                      MulAcc->getResultType(), *Ext0);
+          Op0 = new VPWidenCastRecipe(MulAcc->getExtOpcode(),
+                                      MulAcc->getVecOp0(), RedType);
           Op0->getDefiningRecipe()->insertBefore(MulAcc);
           // Prevent reduce.add(mul(ext(A), ext(A))) generate duplicate
           // VPWidenCastRecipe.
           if (MulAcc->isSameExtend()) {
             Op1 = Op0;
           } else {
-            CastInst *Ext1 = MulAcc->getExt1Instr();
-            Op1 = new VPWidenCastRecipe(Ext1->getOpcode(), MulAcc->getVecOp1(),
-                                        MulAcc->getResultType(), *Ext1);
+            Op1 = new VPWidenCastRecipe(MulAcc->getExtOpcode(),
+                                        MulAcc->getVecOp1(), RedType);
             Op1->getDefiningRecipe()->insertBefore(MulAcc);
           }
           // Not contains extend instruction in this MulAccRecipe.
@@ -567,27 +566,15 @@ void VPlanTransforms::prepareExecute(VPlan &Plan) {
         }
 
         // Generate VPWidenRecipe.
-        VPSingleDefRecipe *VecOp;
         SmallVector<VPValue *, 2> MulOps = {Op0, Op1};
-        auto *Mul = new VPWidenRecipe(*MulAcc->getMulInstr(),
+        auto *Mul = new VPWidenRecipe(Instruction::Mul,
                                       make_range(MulOps.begin(), MulOps.end()));
         Mul->insertBefore(MulAcc);
 
-        // Generate outer VPWidenCastRecipe if necessary.
-        if (auto *OuterExtInstr = MulAcc->getExtInstr()) {
-          VecOp = new VPWidenCastRecipe(
-              OuterExtInstr->getOpcode(), Mul,
-              MulAcc->getRecurrenceDescriptor().getRecurrenceType(),
-              *OuterExtInstr);
-          VecOp->insertBefore(MulAcc);
-        } else {
-          VecOp = Mul;
-        }
-
         // Generate VPReductionRecipe.
         auto *Red = new VPReductionRecipe(
             MulAcc->getRecurrenceDescriptor(), MulAcc->getUnderlyingInstr(),
-            MulAcc->getChainOp(), VecOp, MulAcc->getCondOp(),
+            MulAcc->getChainOp(), Mul, MulAcc->getCondOp(),
             MulAcc->isOrdered());
         Red->insertBefore(MulAcc);
 
diff --git a/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll b/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
index b138666ba92f9b..f99f33365fad0f 100644
--- a/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
+++ b/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
@@ -1670,24 +1670,55 @@ define i64 @test_std_q31(ptr %x, i32 %n) #0 {
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[CMP11:%.*]] = icmp sgt i32 [[N:%.*]], 0
 ; CHECK-NEXT:    tail call void @llvm.assume(i1 [[CMP11]])
+; CHECK-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp samesign ult i32 [[N]], 4
+; CHECK-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
+; CHECK:       vector.ph:
+; CHECK-NEXT:    [[N_VEC:%.*]] = and i32 [[N]], 2147483644
 ; CHECK-NEXT:    br label [[FOR_BODY1:%.*]]
+; CHECK:       vector.body:
+; CHECK-NEXT:    [[INDEX:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[FOR_BODY1]] ]
+; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[TMP4:%.*]], [[FOR_BODY1]] ]
+; CHECK-NEXT:    [[VEC_PHI1:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[TMP8:%.*]], [[FOR_BODY1]] ]
+; CHECK-NEXT:    [[TMP10:%.*]] = getelementptr inbounds i32, ptr [[X:%.*]], i32 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <4 x i32>, ptr [[TMP10]], align 4
+; CHECK-NEXT:    [[TMP1:%.*]] = ashr <4 x i32> [[WIDE_LOAD]], splat (i32 8)
+; CHECK-NEXT:    [[TMP2:%.*]] = sext <4 x i32> [[TMP1]] to <4 x i64>
+; CHECK-NEXT:    [[TMP3:%.*]] = call i64 @llvm.vector.reduce.add.v4i64(<4 x i64> [[TMP2]])
+; CHECK-NEXT:    [[TMP4]] = add i64 [[TMP3]], [[VEC_PHI]]
+; CHECK-NEXT:    [[TMP5:%.*]] = sext <4 x i32> [[TMP1]] to <4 x i64>
+; CHECK-NEXT:    [[TMP6:%.*]] = mul nsw <4 x i64> [[TMP5]], [[TMP5]]
+; CHECK-NEXT:    [[TMP7:%.*]] = call i64 @llvm.vector.reduce.add.v4i64(<4 x i64> [[TMP6]])
+; CHECK-NEXT:    [[TMP8]] = add i64 [[TMP7]], [[VEC_PHI1]]
+; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 4
+; CHECK-NEXT:    [[TMP9:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]
+; CHECK-NEXT:    br i1 [[TMP9]], label [[MIDDLE_BLOCK:%.*]], label [[FOR_BODY1]], !llvm.loop [[LOOP37:![0-9]+]]
+; CHECK:       middle.block:
+; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i32 [[N]], [[N_VEC]]
+; CHECK-NEXT:    br i1 [[CMP_N]], label [[FOR_COND_CLEANUP:%.*]], label [[SCALAR_PH]]
+; CHECK:       scalar.ph:
+; CHECK-NEXT:    [[BC_RESUME_VAL:%.*]] = phi i32 [ [[N_VEC]], [[MIDDLE_BLOCK]] ], [ 0, [[ENTRY:%.*]] ]
+; CHECK-NEXT:    [[BC_MERGE_RDX:%.*]] = phi i64 [ [[TMP4]], [[MIDDLE_BLOCK]] ], [ 0, [[ENTRY]] ]
+; CHECK-NEXT:    [[BC_MERGE_RDX2:%.*]] = phi i64 [ [[TMP8]], [[MIDDLE_BLOCK]] ], [ 0, [[ENTRY]] ]
+; CHECK-NEXT:    br label [[FOR_BODY:%.*]]
 ; CHECK:       for.cond.cleanup:
-; CHECK-NEXT:    [[DIV:%.*]] = sdiv i64 [[ADD3:%.*]], [[ADD:%.*]]
+; CHECK-NEXT:    [[ADD:%.*]] = phi i64 [ [[ADD1:%.*]], [[FOR_BODY]] ], [ [[TMP4]], [[MIDDLE_BLOCK]] ]
+; CHECK-NEXT:    [[ADD3:%.*]] = phi i64 [ [[ADD5:%.*]], [[FOR_BODY]] ], [ [[TMP8]], [[MIDDLE_BLOCK]] ]
+; CHECK-NEXT:    [[DIV:%.*]] = sdiv i64 [[ADD3]], [[ADD]]
 ; CHECK-NEXT:    ret i64 [[DIV]]
 ; CHECK:       for.body:
-; CHECK-NEXT:    [[S_014:%.*]] = phi i64 [ [[ADD]], [[FOR_BODY1]] ], [ 0, [[ENTRY:%.*]] ]
-; CHECK-NEXT:    [[I_013:%.*]] = phi i32 [ [[ADD4:%.*]], [[FOR_BODY1]] ], [ 0, [[ENTRY]] ]
-; CHECK-NEXT:    [[T_012:%.*]] = phi i64 [ [[ADD3]], [[FOR_BODY1]] ], [ 0, [[ENTRY]] ]
-; CHECK-NEXT:    [[ARRAYIDX:%.*]] = getelementptr inbounds i32, ptr [[X:%.*]], i32 [[I_013]]
+; CHECK-NEXT:    [[S_014:%.*]] = phi i64 [ [[ADD1]], [[FOR_BODY]] ], [ [[BC_MERGE_RDX]], [[SCALAR_PH]] ]
+; CHECK-NEXT:    [[I_013:%.*]] = phi i32 [ [[ADD4:%.*]], [[FOR_BODY]] ], [ [[BC_RESUME_VAL]], [[SCALAR_PH]] ]
+; CHECK-NEXT:    [[T_012:%.*]] = phi i64 [ [[ADD5]], [[FOR_BODY]] ], [ [[BC_MERGE_RDX2]], [[SCALAR_PH]] ]
+; CHECK-NEXT:    [[ARRAYIDX:%.*]] = getelementptr inbounds i32, ptr [[X]], i32 [[I_013]]
 ; CHECK-NEXT:    [[TMP0:%.*]] = load i32, ptr [[ARRAYIDX]], align 4
 ; CHECK-NEXT:    [[SHR:%.*]] = ashr i32 [[TMP0]], 8
 ; CHECK-NEXT:    [[CONV:%.*]] = sext i32 [[SHR]] to i64
-; CHECK-NEXT:    [[ADD]] = add nsw i64 [[S_014]], [[CONV]]
+; CHECK-NEXT:    [[ADD1]] = add nsw i64 [[S_014]], [[CONV]]
 ; CHECK-NEXT:    [[MUL:%.*]] = mul nsw i64 [[CONV]], [[CONV]]
-; CHECK-NEXT:    [[ADD3]] = add nuw nsw i64 [[MUL]], [[T_012]]
+; CHECK-NEXT:    [[ADD5]] = add nuw nsw i64 [[MUL]], [[T_012]]
 ; CHECK-NEXT:    [[ADD4]] = add nuw nsw i32 [[I_013]], 1
 ; CHECK-NEXT:    [[EXITCOND_NOT:%.*]] = icmp eq i32 [[ADD4]], [[N]]
-; CHECK-NEXT:    br i1 [[EXITCOND_NOT]], label [[FOR_COND_CLEANUP:%.*]], label [[FOR_BODY1]]
+; CHECK-NEXT:    br i1 [[EXITCOND_NOT]], label [[FOR_COND_CLEANUP]], label [[FOR_BODY]], !llvm.loop [[LOOP38:![0-9]+]]
 ;
 entry:
   %cmp11 = icmp sgt i32 %n, 0
@@ -1723,10 +1754,10 @@ define i64 @test_fir_q15(ptr %x, ptr %y, i32 %n) #0 {
 ; CHECK-NEXT:    [[TMP0:%.*]] = add nsw i32 [[N]], -1
 ; CHECK-NEXT:    [[TMP1:%.*]] = lshr i32 [[TMP0]], 1
 ; CHECK-NEXT:    [[TMP2:%.*]] = add nuw i32 [[TMP1]], 1
-; CHECK-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i32 [[N]], 7
+; CHECK-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i32 [[N]], 15
 ; CHECK-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
 ; CHECK:       vector.ph:
-; CHECK-NEXT:    [[N_VEC:%.*]] = and i32 [[TMP2]], -4
+; CHECK-NEXT:    [[N_VEC:%.*]] = and i32 [[TMP2]], -8
 ; CHECK-NEXT:    [[IND_END:%.*]] = shl i32 [[N_VEC]], 1
 ; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
 ; CHECK:       vector.body:
@@ -1734,28 +1765,28 @@ define i64 @test_fir_q15(ptr %x, ptr %y, i32 %n) #0 {
 ; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[TMP16:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[OFFSET_IDX:%.*]] = shl i32 [[INDEX]], 1
 ; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr inbounds i16, ptr [[X:%.*]], i32 [[OFFSET_IDX]]
-; CHECK-NEXT:    [[WIDE_VEC:%.*]] = load <8 x i16>, ptr [[TMP3]], align 2
-; CHECK-NEXT:    [[STRIDED_VEC:%.*]] = shufflevector <8 x i16> [[WIDE_VEC]], <8 x i16> poison, <4 x i32> <i32 0, i32 2, i32 4, i32 6>
-; CHECK-NEXT:    [[STRIDED_VEC1:%.*]] = shufflevector <8 x i16> [[WIDE_VEC]], <8 x i16> poison, <4 x i32> <i32 1, i32 3, i32 5, i32 7>
-; CHECK-NEXT:    [[TMP5:%.*]] = sext <4 x i16> [[STRIDED_VEC]] to <4 x i32>
+; CHECK-NEXT:    [[WIDE_VEC:%.*]] = load <16 x i16>, ptr [[TMP3]], align 2
+; CHECK-NEXT:    [[STRIDED_VEC:%.*]] = shufflevector <16 x i16> [[WIDE_VEC]], <16 x i16> poison, <8 x i32> <i32 0, i32 2, i32 4, i32 6, i32 8, i32 10, i32 12, i32 14>
+; CHECK-NEXT:    [[STRIDED_VEC1:%.*]] = shufflevector <16 x i16> [[WIDE_VEC]], <16 x i16> poison, <8 x i32> <i32 1, i32 3, i32 5, i32 7, i32 9, i32 11, i32 13, i32 15>
 ; CHECK-NEXT:    [[TMP4:%.*]] = getelementptr inbounds i16, ptr [[Y:%.*]], i32 [[OFFSET_IDX]]
-; CHECK-NEXT:    [[WIDE_VEC2:%.*]] = load <8 x i16>, ptr [[TMP4]], align 2
-; CHECK-NEXT:    [[STRIDED_VEC3:%.*]] = shufflevector <8 x i16> [[WIDE_VEC2]], <8 x i16> poison, <4 x i32> <i32 0, i32 2, i32 4, i32 6>
-; CHECK-NEXT:    [[STRIDED_VEC4:%.*]] = shufflevector <8 x i16> [[WIDE_VEC2]], <8 x i16> poison, <4 x i32> <i32 1, i32 3, i32 5, i32 7>
-; CHECK-NEXT:    [[TMP6:%.*]] = sext <4 x i16> [[STRIDED_VEC3]] to <4 x i32>
-; CHECK-NEXT:    [[TMP7:%.*]] = mul nsw <4 x i32> [[TMP6]], [[TMP5]]
-; CHECK-NEXT:    [[TMP8:%.*]] = sext <4 x i32> [[TMP7]] to <4 x i64>
-; CHECK-NEXT:    [[TMP13:%.*]] = sext <4 x i16> [[STRIDED_VEC1]] to <4 x i32>
-; CHECK-NEXT:    [[TMP14:%.*]] = sext <4 x i16> [[STRIDED_VEC4]] to <4 x i32>
-; CHECK-NEXT:    [[TMP11:%.*]] = mul nsw <4 x i32> [[TMP14]], [[TMP13]]
-; CHECK-NEXT:    [[TMP12:%.*]] = sext <4 x i32> [[TMP11]] to <4 x i64>
-; CHECK-NEXT:    [[TMP9:%.*]] = call i64 @llvm.vector.reduce.add.v4i64(<4 x i64> [[TMP8]])
+; CHECK-NEXT:    [[WIDE_VEC2:%.*]] = load <16 x i16>, ptr [[TMP4]], align 2
+; CHECK-NEXT:    [[STRIDED_VEC3:%.*]] = shufflevector <16 x i16> [[WIDE_VEC2]], <16 x i16> poison, <8 x i32> <i32 0, i32 2, i32 4, i32 6, i32 8, i32 10, i32 12, i32 14>
+; CHECK-NEXT:    [[STRIDED_VEC4:%.*]] = shufflevector <16 x i16> [[WIDE_VEC2]], <16 x i16> poison, <8 x i32> <i32 1, i32 3, i32 5, i32 7, i32 9, i32 11, i32 13, i32 15>
+; CHECK-NEXT:    [[TMP5:%.*]] = sext <8 x i16> [[STRIDED_VEC3]] to <8 x i32>
+; CHECK-NEXT:    [[TMP6:%.*]] = sext <8 x i16> [[STRIDED_VEC]] to <8 x i32>
+; CHECK-NEXT:    [[TMP7:%.*]] = mul nsw <8 x i32> [[TMP5]], [[TMP6]]
+; CHECK-NEXT:    [[TMP8:%.*]] = sext <8 x i32> [[TMP7]] to <8 x i64>
+; CHECK-NEXT:    [[TMP9:%.*]] = call i64 @llvm.vector.reduce.add.v8i64(<8 x i64> [[TMP8]])
 ; CHECK-NEXT:    [[TMP10:%.*]] = add i64 [[TMP9]], [[VEC_PHI]]
-; CHECK-NEXT:    [[TMP15:%.*]] = call i64 @llvm.vector.reduce.add.v4i64(<4 x i64> [[TMP12]])
+; CHECK-NEXT:    [[TMP11:%.*]] = sext <8 x i16> [[STRIDED_VEC4]] to <8 x i32>
+; CHECK-NEXT:    [[TMP12:%.*]] = sext <8 x i16> [[STRIDED_VEC1]] to <8 x i32>
+; CHECK-NEXT:    [[TMP13:%.*]] = mul nsw <8 x i32> [[TMP11]], [[TMP12]]
+; CHECK-NEXT:    [[TMP14:%.*]] = sext <8 x i32> [[TMP13]] to <8 x i64>
+; CHECK-NEXT:    [[TMP15:%.*]] = call i64 @llvm.vector.reduce.add.v8i64(<8 x i64> [[TMP14]])
 ; CHECK-NEXT:    [[TMP16]] = add i64 [[TMP15]], [[TMP10]]
-; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 4
+; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 8
 ; CHECK-NEXT:    [[TMP17:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]
-; CHECK-NEXT:    br i1 [[TMP17]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP37:![0-9]+]]
+; CHECK-NEXT:    br i1 [[TMP17]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP39:![0-9]+]]
 ; CHECK:       middle.block:
 ; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i32 [[TMP2]], [[N_VEC]]
 ; CHECK-NEXT:    br i1 [[CMP_N]], label [[FOR_COND_CLEANUP]], label [[SCALAR_PH]]
@@ -1790,7 +1821,7 @@ define i64 @test_fir_q15(ptr %x, ptr %y, i32 %n) #0 {
 ; CHECK-NEXT:    [[ADD12]] = add nsw i64 [[ADD]], [[CONV11]]
 ; CHECK-NEXT:    [[ADD13]] = add nuw nsw i32 [[I_025]], 2
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i32 [[ADD13]], [[N]]
-; CHECK-NEXT:    br i1 [[CMP]], label [[FOR_BODY]], label [[FOR_COND_CLEANUP]], !llvm.loop [[LOOP38:![0-9]+]]
+; CHECK-NEXT:    br i1 [[CMP]], label [[FOR_BODY]], label [[FOR_COND_CLEANUP]], !llvm.loop [[LOOP40:![0-9]+]]
 ;
 entry:
   %cmp23 = icmp sgt i32 %n, 0

>From 35abf19fdb78efe2ded6cbe283a582f78aa068e6 Mon Sep 17 00:00:00 2001
From: Elvis Wang <elvis.wang at sifive.com>
Date: Thu, 14 Nov 2024 05:21:18 -0800
Subject: [PATCH 11/17] Revert "Remove underying instruction dependency."

This reverts commit e4e510802f44c9936f5a892eb16ba2b19651ee15.

We need the underlying instructions for accurate TTI queries and
for the metadata to attach to the generated vector instructions.
---
 .../Transforms/Vectorize/LoopVectorize.cpp    |  24 ++--
 llvm/lib/Transforms/Vectorize/VPlan.h         | 116 +++++++++++-------
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp |  31 +++--
 .../Transforms/Vectorize/VPlanTransforms.cpp  |  33 +++--
 .../LoopVectorize/ARM/mve-reductions.ll       |  89 +++++---------
 5 files changed, 156 insertions(+), 137 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 9c22cbcb055e72..5cee4b7e2ce70e 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7397,18 +7397,18 @@ static bool planContainsAdditionalSimplifications(VPlan &Plan,
       if (Instruction *UI = GetInstructionForCost(&R))
         SeenInstrs.insert(UI);
       // VPExtendedReductionRecipe contains a folded extend instruction.
-      // if (auto *ExtendedRed = dyn_cast<VPExtendedReductionRecipe>(&R))
-      //   SeenInstrs.insert(ExtendedRed->getExtInstr());
-      // // VPMulAccRecipe constians a mul and otional extend instructions.
-      // else if (auto *MulAcc = dyn_cast<VPMulAccRecipe>(&R)) {
-      //   SeenInstrs.insert(MulAcc->getMulInstr());
-      //   if (MulAcc->isExtended()) {
-      //     SeenInstrs.insert(MulAcc->getExt0Instr());
-      //     SeenInstrs.insert(MulAcc->getExt1Instr());
-      //     if (auto *Ext = MulAcc->getExtInstr())
-      //       SeenInstrs.insert(Ext);
-      //   }
-      // }
+      if (auto *ExtendedRed = dyn_cast<VPExtendedReductionRecipe>(&R))
+        SeenInstrs.insert(ExtendedRed->getExtInstr());
+      // VPMulAccRecipe constians a mul and otional extend instructions.
+      else if (auto *MulAcc = dyn_cast<VPMulAccRecipe>(&R)) {
+        SeenInstrs.insert(MulAcc->getMulInstr());
+        if (MulAcc->isExtended()) {
+          SeenInstrs.insert(MulAcc->getExt0Instr());
+          SeenInstrs.insert(MulAcc->getExt1Instr());
+          if (auto *Ext = MulAcc->getExtInstr())
+            SeenInstrs.insert(Ext);
+        }
+      }
     }
   }
 
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 23a5c2756b6f92..ab9e6f0393602f 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -1448,20 +1448,11 @@ class VPWidenRecipe : public VPRecipeWithIRFlags {
                 iterator_range<IterT> Operands)
       : VPRecipeWithIRFlags(VPDefOpcode, Operands, I), Opcode(I.getOpcode()) {}
 
-  template <typename IterT>
-  VPWidenRecipe(unsigned VPDefOpcode, unsigned InstrOpcode,
-                iterator_range<IterT> Operands)
-      : VPRecipeWithIRFlags(VPDefOpcode, Operands), Opcode(InstrOpcode) {}
-
 public:
   template <typename IterT>
   VPWidenRecipe(Instruction &I, iterator_range<IterT> Operands)
       : VPWidenRecipe(VPDef::VPWidenSC, I, Operands) {}
 
-  template <typename IterT>
-  VPWidenRecipe(unsigned Opcode, iterator_range<IterT> Operands)
-      : VPWidenRecipe(VPDef::VPWidenSC, Opcode, Operands) {}
-
   ~VPWidenRecipe() override = default;
 
   VPWidenRecipe *clone() override {
@@ -2715,25 +2706,26 @@ class VPReductionEVLRecipe : public VPReductionRecipe {
 class VPExtendedReductionRecipe : public VPReductionRecipe {
   /// Type after extend.
   Type *ResultTy;
-  Instruction::CastOps ExtOp;
+  CastInst *ExtInstr;
 
 protected:
   VPExtendedReductionRecipe(const unsigned char SC,
                             const RecurrenceDescriptor &R, Instruction *RedI,
-                            Instruction::CastOps ExtOp, VPValue *ChainOp,
-                            VPValue *VecOp, VPValue *CondOp, bool IsOrdered,
-                            Type *ResultTy)
+                            Instruction::CastOps ExtOp, CastInst *ExtInstr,
+                            VPValue *ChainOp, VPValue *VecOp, VPValue *CondOp,
+                            bool IsOrdered, Type *ResultTy)
       : VPReductionRecipe(SC, R, RedI, ArrayRef<VPValue *>({ChainOp, VecOp}),
                           CondOp, IsOrdered),
-        ResultTy(ResultTy), ExtOp(ExtOp) {}
+        ResultTy(ResultTy), ExtInstr(ExtInstr) {}
 
 public:
   VPExtendedReductionRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
                             VPValue *ChainOp, VPWidenCastRecipe *Ext,
                             VPValue *CondOp, bool IsOrdered)
-      : VPExtendedReductionRecipe(VPDef::VPExtendedReductionSC, R, RedI,
-                                  Ext->getOpcode(), ChainOp, Ext->getOperand(0),
-                                  CondOp, IsOrdered, Ext->getResultType()) {}
+      : VPExtendedReductionRecipe(
+            VPDef::VPExtendedReductionSC, R, RedI, Ext->getOpcode(),
+            cast<CastInst>(Ext->getUnderlyingInstr()), ChainOp,
+            Ext->getOperand(0), CondOp, IsOrdered, Ext->getResultType()) {}
 
   ~VPExtendedReductionRecipe() override = default;
 
@@ -2769,7 +2761,9 @@ class VPExtendedReductionRecipe : public VPReductionRecipe {
   Type *getResultType() const { return ResultTy; }
   bool isZExt() const { return getExtOpcode() == Instruction::ZExt; }
   /// The Opcode of extend instruction.
-  Instruction::CastOps getExtOpcode() const { return ExtOp; }
+  Instruction::CastOps getExtOpcode() const { return ExtInstr->getOpcode(); }
+  /// The CastInst of the extend instruction.
+  CastInst *getExtInstr() const { return ExtInstr; }
 };
 
 /// A recipe to represent inloop MulAccreduction operations, performing a
@@ -2778,52 +2772,70 @@ class VPExtendedReductionRecipe : public VPReductionRecipe {
 /// VPReductionRecipe VPWidenRecipe(mul) and VPWidenCastRecipes before
 /// execution. The Operands are {ChainOp, VecOp1, VecOp2, [Condition]}.
 class VPMulAccRecipe : public VPReductionRecipe {
-  Instruction::CastOps ExtOp;
-  bool IsExtended = false;
+  /// Type after extend.
+  Type *ResultType;
+
+  /// reduce.add(ext(mul(ext0(), ext1())))
+  Instruction *MulInstr;
+  CastInst *ExtInstr = nullptr;
+  CastInst *Ext0Instr = nullptr;
+  CastInst *Ext1Instr = nullptr;
 
 protected:
   VPMulAccRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
-                 Instruction *RedI, Instruction::CastOps ExtOp,
-                 VPValue *ChainOp, VPValue *VecOp0, VPValue *VecOp1,
-                 VPValue *CondOp, bool IsOrdered)
+                 Instruction *RedI, Instruction *ExtInstr,
+                 Instruction *MulInstr, Instruction *Ext0Instr,
+                 Instruction *Ext1Instr, VPValue *ChainOp, VPValue *VecOp0,
+                 VPValue *VecOp1, VPValue *CondOp, bool IsOrdered,
+                 Type *ResultType)
       : VPReductionRecipe(SC, R, RedI,
                           ArrayRef<VPValue *>({ChainOp, VecOp0, VecOp1}),
                           CondOp, IsOrdered),
-        ExtOp(ExtOp) {
-    IsExtended = true;
+        ResultType(ResultType), MulInstr(MulInstr),
+        ExtInstr(cast_if_present<CastInst>(ExtInstr)),
+        Ext0Instr(cast<CastInst>(Ext0Instr)),
+        Ext1Instr(cast<CastInst>(Ext1Instr)) {
+    assert(MulInstr->getOpcode() == Instruction::Mul);
   }
 
   VPMulAccRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
-                 Instruction *RedI, VPValue *ChainOp, VPValue *VecOp0,
-                 VPValue *VecOp1, VPValue *CondOp, bool IsOrdered)
+                 Instruction *RedI, Instruction *MulInstr, VPValue *ChainOp,
+                 VPValue *VecOp0, VPValue *VecOp1, VPValue *CondOp,
+                 bool IsOrdered)
       : VPReductionRecipe(SC, R, RedI,
                           ArrayRef<VPValue *>({ChainOp, VecOp0, VecOp1}),
-                          CondOp, IsOrdered) {}
+                          CondOp, IsOrdered),
+        MulInstr(MulInstr) {
+    assert(MulInstr->getOpcode() == Instruction::Mul);
+  }
 
 public:
   VPMulAccRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
                  VPValue *ChainOp, VPValue *CondOp, bool IsOrdered,
                  VPWidenRecipe *Mul, VPWidenCastRecipe *Ext0,
                  VPWidenCastRecipe *Ext1)
-      : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI,
-                       cast<CastInst>(Ext0->getUnderlyingInstr())->getOpcode(),
-                       ChainOp, Ext0->getOperand(0), Ext1->getOperand(0),
-                       CondOp, IsOrdered) {}
+      : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, nullptr,
+                       Mul->getUnderlyingInstr(), Ext0->getUnderlyingInstr(),
+                       Ext1->getUnderlyingInstr(), ChainOp, Ext0->getOperand(0),
+                       Ext1->getOperand(0), CondOp, IsOrdered,
+                       Ext0->getResultType()) {}
 
   VPMulAccRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
                  VPValue *ChainOp, VPValue *CondOp, bool IsOrdered,
                  VPWidenRecipe *Mul)
-      : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, ChainOp, Mul->getOperand(0),
-                       Mul->getOperand(1), CondOp, IsOrdered) {}
+      : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, Mul->getUnderlyingInstr(),
+                       ChainOp, Mul->getOperand(0), Mul->getOperand(1), CondOp,
+                       IsOrdered) {}
 
   VPMulAccRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
                  VPValue *ChainOp, VPValue *CondOp, bool IsOrdered,
                  VPWidenCastRecipe *Ext, VPWidenRecipe *Mul,
                  VPWidenCastRecipe *Ext0, VPWidenCastRecipe *Ext1)
-      : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI,
-                       cast<CastInst>(Ext0->getUnderlyingInstr())->getOpcode(),
-                       ChainOp, Ext0->getOperand(0), Ext1->getOperand(0),
-                       CondOp, IsOrdered) {}
+      : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, Ext->getUnderlyingInstr(),
+                       Mul->getUnderlyingInstr(), Ext0->getUnderlyingInstr(),
+                       Ext1->getUnderlyingInstr(), ChainOp, Ext0->getOperand(0),
+                       Ext1->getOperand(0), CondOp, IsOrdered,
+                       Ext0->getResultType()) {}
 
   ~VPMulAccRecipe() override = default;
 
@@ -2857,17 +2869,37 @@ class VPMulAccRecipe : public VPReductionRecipe {
   VPValue *getVecOp0() const { return getOperand(1); }
   VPValue *getVecOp1() const { return getOperand(2); }
 
+  /// Return the type after inner extended, which must equal to the type of mul
+  /// instruction. If the ResultType != recurrenceType, than it must have a
+  /// extend recipe after mul recipe.
+  Type *getResultType() const { return ResultType; }
+
+  /// The underlying instruction for VPWidenRecipe.
+  Instruction *getMulInstr() const { return MulInstr; }
+
+  /// The underlying Instruction for outer VPWidenCastRecipe.
+  CastInst *getExtInstr() const { return ExtInstr; }
+  /// The underlying Instruction for inner VPWidenCastRecipe.
+  CastInst *getExt0Instr() const { return Ext0Instr; }
+  /// The underlying Instruction for inner VPWidenCastRecipe.
+  CastInst *getExt1Instr() const { return Ext1Instr; }
+
   /// Return if this MulAcc recipe contains extend instructions.
-  bool isExtended() const { return IsExtended; }
+  bool isExtended() const { return Ext0Instr && Ext1Instr; }
   /// Return if the operands of mul instruction come from same extend.
-  bool isSameExtend() const { return getVecOp0() == getVecOp1(); }
+  bool isSameExtend() const { return Ext0Instr == Ext1Instr; }
+  /// Return if the MulAcc recipes contains extend after mul.
+  bool isOuterExtended() const { return ExtInstr != nullptr; }
   /// Return if the extend opcode is ZExt.
   bool isZExt() const {
     if (!isExtended())
       return true;
-    return ExtOp == Instruction::CastOps::ZExt;
+    // reduce.add(sext(mul(zext(A), zext(A)))) can be transform to
+    // reduce.add(zext(mul(sext(A), sext(A))))
+    if (ExtInstr && ExtInstr->getOpcode() != Ext0Instr->getOpcode())
+      return true;
+    return Ext0Instr->getOpcode() == Instruction::CastOps::ZExt;
   }
-  Instruction::CastOps getExtOpcode() const { return ExtOp; }
 };
 
 /// VPReplicateRecipe replicates a given instruction producing multiple scalar
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index edcc229cbfd5a8..5aef2089ac37d8 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -2267,8 +2267,6 @@ InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
   Type *ElementTy = isExtended() ? RdxDesc.getRecurrenceType()
                                  : Ctx.Types.inferScalarType(getVecOp0());
   auto *VectorTy = cast<VectorType>(ToVectorTy(ElementTy, VF));
-  auto *SrcVecTy =
-      cast<VectorType>(ToVectorTy(Ctx.Types.inferScalarType(getVecOp0()), VF));
   TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
   unsigned Opcode = RdxDesc.getOpcode();
 
@@ -2286,25 +2284,29 @@ InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
   if (isExtended()) {
     auto *SrcTy = cast<VectorType>(
         ToVectorTy(Ctx.Types.inferScalarType(getVecOp0()), VF));
+    auto *DestTy = cast<VectorType>(ToVectorTy(getResultType(), VF));
     TTI::CastContextHint CCH0 =
         computeCCH(getVecOp0()->getDefiningRecipe(), VF);
-    ExtendedCost = Ctx.TTI.getCastInstrCost(getExtOpcode(), VectorTy, SrcVecTy,
-                                            CCH0, TTI::TCK_RecipThroughput);
+    ExtendedCost = Ctx.TTI.getCastInstrCost(
+        Ext0Instr->getOpcode(), DestTy, SrcTy, CCH0, TTI::TCK_RecipThroughput,
+        dyn_cast_if_present<Instruction>(getExt0Instr()));
     TTI::CastContextHint CCH1 =
         computeCCH(getVecOp0()->getDefiningRecipe(), VF);
-    ExtendedCost += Ctx.TTI.getCastInstrCost(getExtOpcode(), VectorTy, SrcVecTy,
-                                             CCH1, TTI::TCK_RecipThroughput);
+    ExtendedCost += Ctx.TTI.getCastInstrCost(
+        Ext1Instr->getOpcode(), DestTy, SrcTy, CCH1, TTI::TCK_RecipThroughput,
+        dyn_cast_if_present<Instruction>(getExt1Instr()));
   }
 
   // Mul cost
   InstructionCost MulCost;
   SmallVector<const Value *, 4> Operands;
+  Operands.append(MulInstr->value_op_begin(), MulInstr->value_op_end());
   if (isExtended())
     MulCost = Ctx.TTI.getArithmeticInstrCost(
         Instruction::Mul, VectorTy, CostKind,
         {TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None},
         {TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None},
-        Operands, nullptr, &Ctx.TLI);
+        Operands, MulInstr, &Ctx.TLI);
   else {
     VPValue *RHS = getVecOp1();
     // Certain instructions can be cheaper to vectorize if they have a constant
@@ -2317,15 +2319,15 @@ InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
     if (RHSInfo.Kind == TargetTransformInfo::OK_AnyValue &&
         RHS->isDefinedOutsideLoopRegions())
       RHSInfo.Kind = TargetTransformInfo::OK_UniformValue;
-    Operands.append(
-        {getVecOp0()->getUnderlyingValue(), RHS->getUnderlyingValue()});
     MulCost = Ctx.TTI.getArithmeticInstrCost(
         Instruction::Mul, VectorTy, CostKind,
         {TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None},
-        RHSInfo, Operands, nullptr, &Ctx.TLI);
+        RHSInfo, Operands, MulInstr, &Ctx.TLI);
   }
 
   // MulAccReduction Cost
+  VectorType *SrcVecTy =
+      cast<VectorType>(ToVectorTy(Ctx.Types.inferScalarType(getVecOp0()), VF));
   InstructionCost MulAccCost =
       Ctx.TTI.getMulAccReductionCost(isZExt(), ElementTy, SrcVecTy, CostKind);
 
@@ -2409,7 +2411,6 @@ void VPExtendedReductionRecipe::print(raw_ostream &O, const Twine &Indent,
 void VPMulAccRecipe::print(raw_ostream &O, const Twine &Indent,
                            VPSlotTracker &SlotTracker) const {
   const RecurrenceDescriptor &RdxDesc = getRecurrenceDescriptor();
-  Type *ElementTy = RdxDesc.getRecurrenceType();
   O << Indent << "MULACC-REDUCE ";
   printAsOperand(O, SlotTracker);
   O << " = ";
@@ -2417,23 +2418,27 @@ void VPMulAccRecipe::print(raw_ostream &O, const Twine &Indent,
   O << " + ";
   if (isa<FPMathOperator>(getUnderlyingInstr()))
     O << getUnderlyingInstr()->getFastMathFlags();
+  if (isOuterExtended())
+    O << " (";
   O << "reduce." << Instruction::getOpcodeName(RdxDesc.getOpcode()) << " (";
   O << "mul ";
   if (isExtended())
     O << "(";
   getVecOp0()->printAsOperand(O, SlotTracker);
   if (isExtended())
-    O << " extended to " << *ElementTy << "), (";
+    O << " extended to " << *getResultType() << "), (";
   else
     O << ", ";
   getVecOp1()->printAsOperand(O, SlotTracker);
   if (isExtended())
-    O << " extended to " << *ElementTy << ")";
+    O << " extended to " << *getResultType() << ")";
   if (isConditional()) {
     O << ", ";
     getCondOp()->printAsOperand(O, SlotTracker);
   }
   O << ")";
+  if (isOuterExtended())
+    O << " extended to " << *RdxDesc.getRecurrenceType() << ")";
   if (RdxDesc.IntermediateStore)
     O << " (with final reduction value stored in invariant address sank "
          "outside of loop)";
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index 50afdeb3aa78c8..53e080b5bb01d6 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -527,9 +527,9 @@ void VPlanTransforms::prepareExecute(VPlan &Plan) {
     for (VPRecipeBase &R : make_early_inc_range(*VPBB)) {
       if (auto *ExtRed = dyn_cast<VPExtendedReductionRecipe>(&R)) {
         // Genearte VPWidenCastRecipe.
-        auto *Ext =
-            new VPWidenCastRecipe(ExtRed->getExtOpcode(), ExtRed->getVecOp(),
-                                  ExtRed->getResultType());
+        auto *Ext = new VPWidenCastRecipe(
+            ExtRed->getExtOpcode(), ExtRed->getVecOp(), ExtRed->getResultType(),
+            *ExtRed->getExtInstr());
 
         // Generate VPreductionRecipe.
         auto *Red = new VPReductionRecipe(
@@ -542,21 +542,22 @@ void VPlanTransforms::prepareExecute(VPlan &Plan) {
         ExtRed->eraseFromParent();
       } else if (isa<VPMulAccRecipe>(&R)) {
         auto *MulAcc = cast<VPMulAccRecipe>(&R);
-        Type *RedType = MulAcc->getRecurrenceDescriptor().getRecurrenceType();
 
         // Generate inner VPWidenCastRecipes if necessary.
         VPValue *Op0, *Op1;
         if (MulAcc->isExtended()) {
-          Op0 = new VPWidenCastRecipe(MulAcc->getExtOpcode(),
-                                      MulAcc->getVecOp0(), RedType);
+          CastInst *Ext0 = MulAcc->getExt0Instr();
+          Op0 = new VPWidenCastRecipe(Ext0->getOpcode(), MulAcc->getVecOp0(),
+                                      MulAcc->getResultType(), *Ext0);
           Op0->getDefiningRecipe()->insertBefore(MulAcc);
           // Prevent reduce.add(mul(ext(A), ext(A))) generate duplicate
           // VPWidenCastRecipe.
           if (MulAcc->isSameExtend()) {
             Op1 = Op0;
           } else {
-            Op1 = new VPWidenCastRecipe(MulAcc->getExtOpcode(),
-                                        MulAcc->getVecOp1(), RedType);
+            CastInst *Ext1 = MulAcc->getExt1Instr();
+            Op1 = new VPWidenCastRecipe(Ext1->getOpcode(), MulAcc->getVecOp1(),
+                                        MulAcc->getResultType(), *Ext1);
             Op1->getDefiningRecipe()->insertBefore(MulAcc);
           }
           // Not contains extend instruction in this MulAccRecipe.
@@ -566,15 +567,27 @@ void VPlanTransforms::prepareExecute(VPlan &Plan) {
         }
 
         // Generate VPWidenRecipe.
+        VPSingleDefRecipe *VecOp;
         SmallVector<VPValue *, 2> MulOps = {Op0, Op1};
-        auto *Mul = new VPWidenRecipe(Instruction::Mul,
+        auto *Mul = new VPWidenRecipe(*MulAcc->getMulInstr(),
                                       make_range(MulOps.begin(), MulOps.end()));
         Mul->insertBefore(MulAcc);
 
+        // Generate outer VPWidenCastRecipe if necessary.
+        if (auto *OuterExtInstr = MulAcc->getExtInstr()) {
+          VecOp = new VPWidenCastRecipe(
+              OuterExtInstr->getOpcode(), Mul,
+              MulAcc->getRecurrenceDescriptor().getRecurrenceType(),
+              *OuterExtInstr);
+          VecOp->insertBefore(MulAcc);
+        } else {
+          VecOp = Mul;
+        }
+
         // Generate VPReductionRecipe.
         auto *Red = new VPReductionRecipe(
             MulAcc->getRecurrenceDescriptor(), MulAcc->getUnderlyingInstr(),
-            MulAcc->getChainOp(), Mul, MulAcc->getCondOp(),
+            MulAcc->getChainOp(), VecOp, MulAcc->getCondOp(),
             MulAcc->isOrdered());
         Red->insertBefore(MulAcc);
 
diff --git a/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll b/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
index f99f33365fad0f..b138666ba92f9b 100644
--- a/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
+++ b/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
@@ -1670,55 +1670,24 @@ define i64 @test_std_q31(ptr %x, i32 %n) #0 {
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[CMP11:%.*]] = icmp sgt i32 [[N:%.*]], 0
 ; CHECK-NEXT:    tail call void @llvm.assume(i1 [[CMP11]])
-; CHECK-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp samesign ult i32 [[N]], 4
-; CHECK-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
-; CHECK:       vector.ph:
-; CHECK-NEXT:    [[N_VEC:%.*]] = and i32 [[N]], 2147483644
 ; CHECK-NEXT:    br label [[FOR_BODY1:%.*]]
-; CHECK:       vector.body:
-; CHECK-NEXT:    [[INDEX:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[FOR_BODY1]] ]
-; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[TMP4:%.*]], [[FOR_BODY1]] ]
-; CHECK-NEXT:    [[VEC_PHI1:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[TMP8:%.*]], [[FOR_BODY1]] ]
-; CHECK-NEXT:    [[TMP10:%.*]] = getelementptr inbounds i32, ptr [[X:%.*]], i32 [[INDEX]]
-; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <4 x i32>, ptr [[TMP10]], align 4
-; CHECK-NEXT:    [[TMP1:%.*]] = ashr <4 x i32> [[WIDE_LOAD]], splat (i32 8)
-; CHECK-NEXT:    [[TMP2:%.*]] = sext <4 x i32> [[TMP1]] to <4 x i64>
-; CHECK-NEXT:    [[TMP3:%.*]] = call i64 @llvm.vector.reduce.add.v4i64(<4 x i64> [[TMP2]])
-; CHECK-NEXT:    [[TMP4]] = add i64 [[TMP3]], [[VEC_PHI]]
-; CHECK-NEXT:    [[TMP5:%.*]] = sext <4 x i32> [[TMP1]] to <4 x i64>
-; CHECK-NEXT:    [[TMP6:%.*]] = mul nsw <4 x i64> [[TMP5]], [[TMP5]]
-; CHECK-NEXT:    [[TMP7:%.*]] = call i64 @llvm.vector.reduce.add.v4i64(<4 x i64> [[TMP6]])
-; CHECK-NEXT:    [[TMP8]] = add i64 [[TMP7]], [[VEC_PHI1]]
-; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 4
-; CHECK-NEXT:    [[TMP9:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]
-; CHECK-NEXT:    br i1 [[TMP9]], label [[MIDDLE_BLOCK:%.*]], label [[FOR_BODY1]], !llvm.loop [[LOOP37:![0-9]+]]
-; CHECK:       middle.block:
-; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i32 [[N]], [[N_VEC]]
-; CHECK-NEXT:    br i1 [[CMP_N]], label [[FOR_COND_CLEANUP:%.*]], label [[SCALAR_PH]]
-; CHECK:       scalar.ph:
-; CHECK-NEXT:    [[BC_RESUME_VAL:%.*]] = phi i32 [ [[N_VEC]], [[MIDDLE_BLOCK]] ], [ 0, [[ENTRY:%.*]] ]
-; CHECK-NEXT:    [[BC_MERGE_RDX:%.*]] = phi i64 [ [[TMP4]], [[MIDDLE_BLOCK]] ], [ 0, [[ENTRY]] ]
-; CHECK-NEXT:    [[BC_MERGE_RDX2:%.*]] = phi i64 [ [[TMP8]], [[MIDDLE_BLOCK]] ], [ 0, [[ENTRY]] ]
-; CHECK-NEXT:    br label [[FOR_BODY:%.*]]
 ; CHECK:       for.cond.cleanup:
-; CHECK-NEXT:    [[ADD:%.*]] = phi i64 [ [[ADD1:%.*]], [[FOR_BODY]] ], [ [[TMP4]], [[MIDDLE_BLOCK]] ]
-; CHECK-NEXT:    [[ADD3:%.*]] = phi i64 [ [[ADD5:%.*]], [[FOR_BODY]] ], [ [[TMP8]], [[MIDDLE_BLOCK]] ]
-; CHECK-NEXT:    [[DIV:%.*]] = sdiv i64 [[ADD3]], [[ADD]]
+; CHECK-NEXT:    [[DIV:%.*]] = sdiv i64 [[ADD3:%.*]], [[ADD:%.*]]
 ; CHECK-NEXT:    ret i64 [[DIV]]
 ; CHECK:       for.body:
-; CHECK-NEXT:    [[S_014:%.*]] = phi i64 [ [[ADD1]], [[FOR_BODY]] ], [ [[BC_MERGE_RDX]], [[SCALAR_PH]] ]
-; CHECK-NEXT:    [[I_013:%.*]] = phi i32 [ [[ADD4:%.*]], [[FOR_BODY]] ], [ [[BC_RESUME_VAL]], [[SCALAR_PH]] ]
-; CHECK-NEXT:    [[T_012:%.*]] = phi i64 [ [[ADD5]], [[FOR_BODY]] ], [ [[BC_MERGE_RDX2]], [[SCALAR_PH]] ]
-; CHECK-NEXT:    [[ARRAYIDX:%.*]] = getelementptr inbounds i32, ptr [[X]], i32 [[I_013]]
+; CHECK-NEXT:    [[S_014:%.*]] = phi i64 [ [[ADD]], [[FOR_BODY1]] ], [ 0, [[ENTRY:%.*]] ]
+; CHECK-NEXT:    [[I_013:%.*]] = phi i32 [ [[ADD4:%.*]], [[FOR_BODY1]] ], [ 0, [[ENTRY]] ]
+; CHECK-NEXT:    [[T_012:%.*]] = phi i64 [ [[ADD3]], [[FOR_BODY1]] ], [ 0, [[ENTRY]] ]
+; CHECK-NEXT:    [[ARRAYIDX:%.*]] = getelementptr inbounds i32, ptr [[X:%.*]], i32 [[I_013]]
 ; CHECK-NEXT:    [[TMP0:%.*]] = load i32, ptr [[ARRAYIDX]], align 4
 ; CHECK-NEXT:    [[SHR:%.*]] = ashr i32 [[TMP0]], 8
 ; CHECK-NEXT:    [[CONV:%.*]] = sext i32 [[SHR]] to i64
-; CHECK-NEXT:    [[ADD1]] = add nsw i64 [[S_014]], [[CONV]]
+; CHECK-NEXT:    [[ADD]] = add nsw i64 [[S_014]], [[CONV]]
 ; CHECK-NEXT:    [[MUL:%.*]] = mul nsw i64 [[CONV]], [[CONV]]
-; CHECK-NEXT:    [[ADD5]] = add nuw nsw i64 [[MUL]], [[T_012]]
+; CHECK-NEXT:    [[ADD3]] = add nuw nsw i64 [[MUL]], [[T_012]]
 ; CHECK-NEXT:    [[ADD4]] = add nuw nsw i32 [[I_013]], 1
 ; CHECK-NEXT:    [[EXITCOND_NOT:%.*]] = icmp eq i32 [[ADD4]], [[N]]
-; CHECK-NEXT:    br i1 [[EXITCOND_NOT]], label [[FOR_COND_CLEANUP]], label [[FOR_BODY]], !llvm.loop [[LOOP38:![0-9]+]]
+; CHECK-NEXT:    br i1 [[EXITCOND_NOT]], label [[FOR_COND_CLEANUP:%.*]], label [[FOR_BODY1]]
 ;
 entry:
   %cmp11 = icmp sgt i32 %n, 0
@@ -1754,10 +1723,10 @@ define i64 @test_fir_q15(ptr %x, ptr %y, i32 %n) #0 {
 ; CHECK-NEXT:    [[TMP0:%.*]] = add nsw i32 [[N]], -1
 ; CHECK-NEXT:    [[TMP1:%.*]] = lshr i32 [[TMP0]], 1
 ; CHECK-NEXT:    [[TMP2:%.*]] = add nuw i32 [[TMP1]], 1
-; CHECK-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i32 [[N]], 15
+; CHECK-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i32 [[N]], 7
 ; CHECK-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
 ; CHECK:       vector.ph:
-; CHECK-NEXT:    [[N_VEC:%.*]] = and i32 [[TMP2]], -8
+; CHECK-NEXT:    [[N_VEC:%.*]] = and i32 [[TMP2]], -4
 ; CHECK-NEXT:    [[IND_END:%.*]] = shl i32 [[N_VEC]], 1
 ; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
 ; CHECK:       vector.body:
@@ -1765,28 +1734,28 @@ define i64 @test_fir_q15(ptr %x, ptr %y, i32 %n) #0 {
 ; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[TMP16:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[OFFSET_IDX:%.*]] = shl i32 [[INDEX]], 1
 ; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr inbounds i16, ptr [[X:%.*]], i32 [[OFFSET_IDX]]
-; CHECK-NEXT:    [[WIDE_VEC:%.*]] = load <16 x i16>, ptr [[TMP3]], align 2
-; CHECK-NEXT:    [[STRIDED_VEC:%.*]] = shufflevector <16 x i16> [[WIDE_VEC]], <16 x i16> poison, <8 x i32> <i32 0, i32 2, i32 4, i32 6, i32 8, i32 10, i32 12, i32 14>
-; CHECK-NEXT:    [[STRIDED_VEC1:%.*]] = shufflevector <16 x i16> [[WIDE_VEC]], <16 x i16> poison, <8 x i32> <i32 1, i32 3, i32 5, i32 7, i32 9, i32 11, i32 13, i32 15>
+; CHECK-NEXT:    [[WIDE_VEC:%.*]] = load <8 x i16>, ptr [[TMP3]], align 2
+; CHECK-NEXT:    [[STRIDED_VEC:%.*]] = shufflevector <8 x i16> [[WIDE_VEC]], <8 x i16> poison, <4 x i32> <i32 0, i32 2, i32 4, i32 6>
+; CHECK-NEXT:    [[STRIDED_VEC1:%.*]] = shufflevector <8 x i16> [[WIDE_VEC]], <8 x i16> poison, <4 x i32> <i32 1, i32 3, i32 5, i32 7>
+; CHECK-NEXT:    [[TMP5:%.*]] = sext <4 x i16> [[STRIDED_VEC]] to <4 x i32>
 ; CHECK-NEXT:    [[TMP4:%.*]] = getelementptr inbounds i16, ptr [[Y:%.*]], i32 [[OFFSET_IDX]]
-; CHECK-NEXT:    [[WIDE_VEC2:%.*]] = load <16 x i16>, ptr [[TMP4]], align 2
-; CHECK-NEXT:    [[STRIDED_VEC3:%.*]] = shufflevector <16 x i16> [[WIDE_VEC2]], <16 x i16> poison, <8 x i32> <i32 0, i32 2, i32 4, i32 6, i32 8, i32 10, i32 12, i32 14>
-; CHECK-NEXT:    [[STRIDED_VEC4:%.*]] = shufflevector <16 x i16> [[WIDE_VEC2]], <16 x i16> poison, <8 x i32> <i32 1, i32 3, i32 5, i32 7, i32 9, i32 11, i32 13, i32 15>
-; CHECK-NEXT:    [[TMP5:%.*]] = sext <8 x i16> [[STRIDED_VEC3]] to <8 x i32>
-; CHECK-NEXT:    [[TMP6:%.*]] = sext <8 x i16> [[STRIDED_VEC]] to <8 x i32>
-; CHECK-NEXT:    [[TMP7:%.*]] = mul nsw <8 x i32> [[TMP5]], [[TMP6]]
-; CHECK-NEXT:    [[TMP8:%.*]] = sext <8 x i32> [[TMP7]] to <8 x i64>
-; CHECK-NEXT:    [[TMP9:%.*]] = call i64 @llvm.vector.reduce.add.v8i64(<8 x i64> [[TMP8]])
+; CHECK-NEXT:    [[WIDE_VEC2:%.*]] = load <8 x i16>, ptr [[TMP4]], align 2
+; CHECK-NEXT:    [[STRIDED_VEC3:%.*]] = shufflevector <8 x i16> [[WIDE_VEC2]], <8 x i16> poison, <4 x i32> <i32 0, i32 2, i32 4, i32 6>
+; CHECK-NEXT:    [[STRIDED_VEC4:%.*]] = shufflevector <8 x i16> [[WIDE_VEC2]], <8 x i16> poison, <4 x i32> <i32 1, i32 3, i32 5, i32 7>
+; CHECK-NEXT:    [[TMP6:%.*]] = sext <4 x i16> [[STRIDED_VEC3]] to <4 x i32>
+; CHECK-NEXT:    [[TMP7:%.*]] = mul nsw <4 x i32> [[TMP6]], [[TMP5]]
+; CHECK-NEXT:    [[TMP8:%.*]] = sext <4 x i32> [[TMP7]] to <4 x i64>
+; CHECK-NEXT:    [[TMP13:%.*]] = sext <4 x i16> [[STRIDED_VEC1]] to <4 x i32>
+; CHECK-NEXT:    [[TMP14:%.*]] = sext <4 x i16> [[STRIDED_VEC4]] to <4 x i32>
+; CHECK-NEXT:    [[TMP11:%.*]] = mul nsw <4 x i32> [[TMP14]], [[TMP13]]
+; CHECK-NEXT:    [[TMP12:%.*]] = sext <4 x i32> [[TMP11]] to <4 x i64>
+; CHECK-NEXT:    [[TMP9:%.*]] = call i64 @llvm.vector.reduce.add.v4i64(<4 x i64> [[TMP8]])
 ; CHECK-NEXT:    [[TMP10:%.*]] = add i64 [[TMP9]], [[VEC_PHI]]
-; CHECK-NEXT:    [[TMP11:%.*]] = sext <8 x i16> [[STRIDED_VEC4]] to <8 x i32>
-; CHECK-NEXT:    [[TMP12:%.*]] = sext <8 x i16> [[STRIDED_VEC1]] to <8 x i32>
-; CHECK-NEXT:    [[TMP13:%.*]] = mul nsw <8 x i32> [[TMP11]], [[TMP12]]
-; CHECK-NEXT:    [[TMP14:%.*]] = sext <8 x i32> [[TMP13]] to <8 x i64>
-; CHECK-NEXT:    [[TMP15:%.*]] = call i64 @llvm.vector.reduce.add.v8i64(<8 x i64> [[TMP14]])
+; CHECK-NEXT:    [[TMP15:%.*]] = call i64 @llvm.vector.reduce.add.v4i64(<4 x i64> [[TMP12]])
 ; CHECK-NEXT:    [[TMP16]] = add i64 [[TMP15]], [[TMP10]]
-; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 8
+; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 4
 ; CHECK-NEXT:    [[TMP17:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]
-; CHECK-NEXT:    br i1 [[TMP17]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP39:![0-9]+]]
+; CHECK-NEXT:    br i1 [[TMP17]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP37:![0-9]+]]
 ; CHECK:       middle.block:
 ; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i32 [[TMP2]], [[N_VEC]]
 ; CHECK-NEXT:    br i1 [[CMP_N]], label [[FOR_COND_CLEANUP]], label [[SCALAR_PH]]
@@ -1821,7 +1790,7 @@ define i64 @test_fir_q15(ptr %x, ptr %y, i32 %n) #0 {
 ; CHECK-NEXT:    [[ADD12]] = add nsw i64 [[ADD]], [[CONV11]]
 ; CHECK-NEXT:    [[ADD13]] = add nuw nsw i32 [[I_025]], 2
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i32 [[ADD13]], [[N]]
-; CHECK-NEXT:    br i1 [[CMP]], label [[FOR_BODY]], label [[FOR_COND_CLEANUP]], !llvm.loop [[LOOP40:![0-9]+]]
+; CHECK-NEXT:    br i1 [[CMP]], label [[FOR_BODY]], label [[FOR_COND_CLEANUP]], !llvm.loop [[LOOP38:![0-9]+]]
 ;
 entry:
   %cmp23 = icmp sgt i32 %n, 0

>From 453997e4b1a451430491bd1a9cdb851c0927ed23 Mon Sep 17 00:00:00 2001
From: Elvis Wang <elvis.wang at sifive.com>
Date: Thu, 14 Nov 2024 16:48:33 -0800
Subject: [PATCH 12/17] Remove extended instruction after mul in MulAccRecipe.

Note that reduce.add(ext(mul(ext(A), ext(A)))) is mathematically equal
to reduce.add(mul(zext(A), zext(A))).
---
 .../Transforms/Vectorize/LoopVectorize.cpp    |  13 --
 llvm/lib/Transforms/Vectorize/VPlan.h         |  58 ++++-----
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp |  15 +--
 .../Transforms/Vectorize/VPlanTransforms.cpp  |  25 ++--
 .../LoopVectorize/ARM/mve-reductions.ll       | 111 +++++++++++-------
 .../LoopVectorize/vplan-printing.ll           |   2 +-
 6 files changed, 104 insertions(+), 120 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 5cee4b7e2ce70e..14821362f71550 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7396,19 +7396,6 @@ static bool planContainsAdditionalSimplifications(VPlan &Plan,
       }
       if (Instruction *UI = GetInstructionForCost(&R))
         SeenInstrs.insert(UI);
-      // VPExtendedReductionRecipe contains a folded extend instruction.
-      if (auto *ExtendedRed = dyn_cast<VPExtendedReductionRecipe>(&R))
-        SeenInstrs.insert(ExtendedRed->getExtInstr());
-      // VPMulAccRecipe constians a mul and otional extend instructions.
-      else if (auto *MulAcc = dyn_cast<VPMulAccRecipe>(&R)) {
-        SeenInstrs.insert(MulAcc->getMulInstr());
-        if (MulAcc->isExtended()) {
-          SeenInstrs.insert(MulAcc->getExt0Instr());
-          SeenInstrs.insert(MulAcc->getExt1Instr());
-          if (auto *Ext = MulAcc->getExtInstr())
-            SeenInstrs.insert(Ext);
-        }
-      }
     }
   }
 
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index ab9e6f0393602f..b04b694196e298 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -2614,8 +2614,6 @@ class VPReductionRecipe : public VPSingleDefRecipe {
                                  getVecOp(), getCondOp(), IsOrdered);
   }
 
-  // TODO: Support VPExtendedReductionRecipe and VPMulAccRecipe after EVL
-  // support.
   static inline bool classof(const VPRecipeBase *R) {
     return R->getVPDefID() == VPRecipeBase::VPReductionSC ||
            R->getVPDefID() == VPRecipeBase::VPReductionEVLSC;
@@ -2772,30 +2770,26 @@ class VPExtendedReductionRecipe : public VPReductionRecipe {
 /// VPReductionRecipe VPWidenRecipe(mul) and VPWidenCastRecipes before
 /// execution. The Operands are {ChainOp, VecOp1, VecOp2, [Condition]}.
 class VPMulAccRecipe : public VPReductionRecipe {
-  /// Type after extend.
-  Type *ResultType;
 
   /// reduce.add(ext(mul(ext0(), ext1())))
   Instruction *MulInstr;
-  CastInst *ExtInstr = nullptr;
   CastInst *Ext0Instr = nullptr;
   CastInst *Ext1Instr = nullptr;
 
 protected:
   VPMulAccRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
-                 Instruction *RedI, Instruction *ExtInstr,
-                 Instruction *MulInstr, Instruction *Ext0Instr,
-                 Instruction *Ext1Instr, VPValue *ChainOp, VPValue *VecOp0,
-                 VPValue *VecOp1, VPValue *CondOp, bool IsOrdered,
-                 Type *ResultType)
+                 Instruction *RedI, Instruction *MulInstr,
+                 Instruction *Ext0Instr, Instruction *Ext1Instr,
+                 VPValue *ChainOp, VPValue *VecOp0, VPValue *VecOp1,
+                 VPValue *CondOp, bool IsOrdered)
       : VPReductionRecipe(SC, R, RedI,
                           ArrayRef<VPValue *>({ChainOp, VecOp0, VecOp1}),
                           CondOp, IsOrdered),
-        ResultType(ResultType), MulInstr(MulInstr),
-        ExtInstr(cast_if_present<CastInst>(ExtInstr)),
-        Ext0Instr(cast<CastInst>(Ext0Instr)),
+        MulInstr(MulInstr), Ext0Instr(cast<CastInst>(Ext0Instr)),
         Ext1Instr(cast<CastInst>(Ext1Instr)) {
     assert(MulInstr->getOpcode() == Instruction::Mul);
+    assert(R.getOpcode() == Instruction::Add);
+    assert(Ext0Instr->getOpcode() == Ext1Instr->getOpcode());
   }
 
   VPMulAccRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
@@ -2807,6 +2801,7 @@ class VPMulAccRecipe : public VPReductionRecipe {
                           CondOp, IsOrdered),
         MulInstr(MulInstr) {
     assert(MulInstr->getOpcode() == Instruction::Mul);
+    assert(R.getOpcode() == Instruction::Add);
   }
 
 public:
@@ -2814,11 +2809,10 @@ class VPMulAccRecipe : public VPReductionRecipe {
                  VPValue *ChainOp, VPValue *CondOp, bool IsOrdered,
                  VPWidenRecipe *Mul, VPWidenCastRecipe *Ext0,
                  VPWidenCastRecipe *Ext1)
-      : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, nullptr,
-                       Mul->getUnderlyingInstr(), Ext0->getUnderlyingInstr(),
-                       Ext1->getUnderlyingInstr(), ChainOp, Ext0->getOperand(0),
-                       Ext1->getOperand(0), CondOp, IsOrdered,
-                       Ext0->getResultType()) {}
+      : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, Mul->getUnderlyingInstr(),
+                       Ext0->getUnderlyingInstr(), Ext1->getUnderlyingInstr(),
+                       ChainOp, Ext0->getOperand(0), Ext1->getOperand(0),
+                       CondOp, IsOrdered) {}
 
   VPMulAccRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
                  VPValue *ChainOp, VPValue *CondOp, bool IsOrdered,
@@ -2831,11 +2825,10 @@ class VPMulAccRecipe : public VPReductionRecipe {
                  VPValue *ChainOp, VPValue *CondOp, bool IsOrdered,
                  VPWidenCastRecipe *Ext, VPWidenRecipe *Mul,
                  VPWidenCastRecipe *Ext0, VPWidenCastRecipe *Ext1)
-      : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, Ext->getUnderlyingInstr(),
-                       Mul->getUnderlyingInstr(), Ext0->getUnderlyingInstr(),
-                       Ext1->getUnderlyingInstr(), ChainOp, Ext0->getOperand(0),
-                       Ext1->getOperand(0), CondOp, IsOrdered,
-                       Ext0->getResultType()) {}
+      : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, Mul->getUnderlyingInstr(),
+                       Ext0->getUnderlyingInstr(), Ext1->getUnderlyingInstr(),
+                       ChainOp, Ext0->getOperand(0), Ext1->getOperand(0),
+                       CondOp, IsOrdered) {}
 
   ~VPMulAccRecipe() override = default;
 
@@ -2869,35 +2862,26 @@ class VPMulAccRecipe : public VPReductionRecipe {
   VPValue *getVecOp0() const { return getOperand(1); }
   VPValue *getVecOp1() const { return getOperand(2); }
 
-  /// Return the type after inner extended, which must equal to the type of mul
-  /// instruction. If the ResultType != recurrenceType, than it must have a
-  /// extend recipe after mul recipe.
-  Type *getResultType() const { return ResultType; }
-
   /// The underlying instruction for VPWidenRecipe.
   Instruction *getMulInstr() const { return MulInstr; }
 
-  /// The underlying Instruction for outer VPWidenCastRecipe.
-  CastInst *getExtInstr() const { return ExtInstr; }
   /// The underlying Instruction for inner VPWidenCastRecipe.
   CastInst *getExt0Instr() const { return Ext0Instr; }
+
   /// The underlying Instruction for inner VPWidenCastRecipe.
   CastInst *getExt1Instr() const { return Ext1Instr; }
 
   /// Return if this MulAcc recipe contains extend instructions.
   bool isExtended() const { return Ext0Instr && Ext1Instr; }
+
   /// Return if the operands of mul instruction come from same extend.
-  bool isSameExtend() const { return Ext0Instr == Ext1Instr; }
-  /// Return if the MulAcc recipes contains extend after mul.
-  bool isOuterExtended() const { return ExtInstr != nullptr; }
+  bool isSameExtend() const { return getVecOp0() == getVecOp1(); }
+  Instruction::CastOps getExtOpcode() const { return Ext0Instr->getOpcode(); }
+
   /// Return if the extend opcode is ZExt.
   bool isZExt() const {
     if (!isExtended())
       return true;
-    // reduce.add(sext(mul(zext(A), zext(A)))) can be transform to
-    // reduce.add(zext(mul(sext(A), sext(A))))
-    if (ExtInstr && ExtInstr->getOpcode() != Ext0Instr->getOpcode())
-      return true;
     return Ext0Instr->getOpcode() == Instruction::CastOps::ZExt;
   }
 };
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 5aef2089ac37d8..f6e6f601309d55 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -2284,16 +2284,15 @@ InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
   if (isExtended()) {
     auto *SrcTy = cast<VectorType>(
         ToVectorTy(Ctx.Types.inferScalarType(getVecOp0()), VF));
-    auto *DestTy = cast<VectorType>(ToVectorTy(getResultType(), VF));
     TTI::CastContextHint CCH0 =
         computeCCH(getVecOp0()->getDefiningRecipe(), VF);
     ExtendedCost = Ctx.TTI.getCastInstrCost(
-        Ext0Instr->getOpcode(), DestTy, SrcTy, CCH0, TTI::TCK_RecipThroughput,
+        Ext0Instr->getOpcode(), VectorTy, SrcTy, CCH0, TTI::TCK_RecipThroughput,
         dyn_cast_if_present<Instruction>(getExt0Instr()));
     TTI::CastContextHint CCH1 =
         computeCCH(getVecOp0()->getDefiningRecipe(), VF);
     ExtendedCost += Ctx.TTI.getCastInstrCost(
-        Ext1Instr->getOpcode(), DestTy, SrcTy, CCH1, TTI::TCK_RecipThroughput,
+        Ext1Instr->getOpcode(), VectorTy, SrcTy, CCH1, TTI::TCK_RecipThroughput,
         dyn_cast_if_present<Instruction>(getExt1Instr()));
   }
 
@@ -2411,6 +2410,8 @@ void VPExtendedReductionRecipe::print(raw_ostream &O, const Twine &Indent,
 void VPMulAccRecipe::print(raw_ostream &O, const Twine &Indent,
                            VPSlotTracker &SlotTracker) const {
   const RecurrenceDescriptor &RdxDesc = getRecurrenceDescriptor();
+  Type *RedTy = RdxDesc.getRecurrenceType();
+
   O << Indent << "MULACC-REDUCE ";
   printAsOperand(O, SlotTracker);
   O << " = ";
@@ -2418,27 +2419,23 @@ void VPMulAccRecipe::print(raw_ostream &O, const Twine &Indent,
   O << " + ";
   if (isa<FPMathOperator>(getUnderlyingInstr()))
     O << getUnderlyingInstr()->getFastMathFlags();
-  if (isOuterExtended())
-    O << " (";
   O << "reduce." << Instruction::getOpcodeName(RdxDesc.getOpcode()) << " (";
   O << "mul ";
   if (isExtended())
     O << "(";
   getVecOp0()->printAsOperand(O, SlotTracker);
   if (isExtended())
-    O << " extended to " << *getResultType() << "), (";
+    O << " extended to " << *RedTy << "), (";
   else
     O << ", ";
   getVecOp1()->printAsOperand(O, SlotTracker);
   if (isExtended())
-    O << " extended to " << *getResultType() << ")";
+    O << " extended to " << *RedTy << ")";
   if (isConditional()) {
     O << ", ";
     getCondOp()->printAsOperand(O, SlotTracker);
   }
   O << ")";
-  if (isOuterExtended())
-    O << " extended to " << *RdxDesc.getRecurrenceType() << ")";
   if (RdxDesc.IntermediateStore)
     O << " (with final reduction value stored in invariant address sank "
          "outside of loop)";
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index 53e080b5bb01d6..a41d1a411e899d 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -542,52 +542,43 @@ void VPlanTransforms::prepareExecute(VPlan &Plan) {
         ExtRed->eraseFromParent();
       } else if (isa<VPMulAccRecipe>(&R)) {
         auto *MulAcc = cast<VPMulAccRecipe>(&R);
+        Type *RedTy = MulAcc->getRecurrenceDescriptor().getRecurrenceType();
 
         // Generate inner VPWidenCastRecipes if necessary.
+        // Note that we will drop the extend after of mul which transform
+        // reduce.add(ext(mul(ext, ext))) to reduce.add(mul(ext, ext)).
         VPValue *Op0, *Op1;
         if (MulAcc->isExtended()) {
           CastInst *Ext0 = MulAcc->getExt0Instr();
           Op0 = new VPWidenCastRecipe(Ext0->getOpcode(), MulAcc->getVecOp0(),
-                                      MulAcc->getResultType(), *Ext0);
+                                      RedTy, *Ext0);
           Op0->getDefiningRecipe()->insertBefore(MulAcc);
           // Prevent reduce.add(mul(ext(A), ext(A))) generate duplicate
           // VPWidenCastRecipe.
-          if (MulAcc->isSameExtend()) {
+          if (MulAcc->getVecOp0() == MulAcc->getVecOp1()) {
             Op1 = Op0;
           } else {
             CastInst *Ext1 = MulAcc->getExt1Instr();
             Op1 = new VPWidenCastRecipe(Ext1->getOpcode(), MulAcc->getVecOp1(),
-                                        MulAcc->getResultType(), *Ext1);
+                                        RedTy, *Ext1);
             Op1->getDefiningRecipe()->insertBefore(MulAcc);
           }
-          // Not contains extend instruction in this MulAccRecipe.
+          // No extends in this MulAccRecipe.
         } else {
           Op0 = MulAcc->getVecOp0();
           Op1 = MulAcc->getVecOp1();
         }
 
         // Generate VPWidenRecipe.
-        VPSingleDefRecipe *VecOp;
         SmallVector<VPValue *, 2> MulOps = {Op0, Op1};
         auto *Mul = new VPWidenRecipe(*MulAcc->getMulInstr(),
                                       make_range(MulOps.begin(), MulOps.end()));
         Mul->insertBefore(MulAcc);
 
-        // Generate outer VPWidenCastRecipe if necessary.
-        if (auto *OuterExtInstr = MulAcc->getExtInstr()) {
-          VecOp = new VPWidenCastRecipe(
-              OuterExtInstr->getOpcode(), Mul,
-              MulAcc->getRecurrenceDescriptor().getRecurrenceType(),
-              *OuterExtInstr);
-          VecOp->insertBefore(MulAcc);
-        } else {
-          VecOp = Mul;
-        }
-
         // Generate VPReductionRecipe.
         auto *Red = new VPReductionRecipe(
             MulAcc->getRecurrenceDescriptor(), MulAcc->getUnderlyingInstr(),
-            MulAcc->getChainOp(), VecOp, MulAcc->getCondOp(),
+            MulAcc->getChainOp(), Mul, MulAcc->getCondOp(),
             MulAcc->isOrdered());
         Red->insertBefore(MulAcc);
 
diff --git a/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll b/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
index b138666ba92f9b..f722fd2584ca96 100644
--- a/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
+++ b/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
@@ -648,10 +648,9 @@ define i64 @mla_i16_i64(ptr nocapture readonly %x, ptr nocapture readonly %y, i3
 ; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <8 x i16>, ptr [[TMP0]], align 2
 ; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr inbounds i16, ptr [[Y:%.*]], i32 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_LOAD1:%.*]] = load <8 x i16>, ptr [[TMP2]], align 2
-; CHECK-NEXT:    [[TMP11:%.*]] = sext <8 x i16> [[WIDE_LOAD1]] to <8 x i32>
-; CHECK-NEXT:    [[TMP3:%.*]] = sext <8 x i16> [[WIDE_LOAD]] to <8 x i32>
-; CHECK-NEXT:    [[TMP4:%.*]] = mul nsw <8 x i32> [[TMP11]], [[TMP3]]
-; CHECK-NEXT:    [[TMP5:%.*]] = sext <8 x i32> [[TMP4]] to <8 x i64>
+; CHECK-NEXT:    [[TMP4:%.*]] = sext <8 x i16> [[WIDE_LOAD1]] to <8 x i64>
+; CHECK-NEXT:    [[TMP3:%.*]] = sext <8 x i16> [[WIDE_LOAD]] to <8 x i64>
+; CHECK-NEXT:    [[TMP5:%.*]] = mul nsw <8 x i64> [[TMP4]], [[TMP3]]
 ; CHECK-NEXT:    [[TMP6:%.*]] = call i64 @llvm.vector.reduce.add.v8i64(<8 x i64> [[TMP5]])
 ; CHECK-NEXT:    [[TMP7]] = add i64 [[TMP6]], [[VEC_PHI]]
 ; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 8
@@ -728,10 +727,9 @@ define i64 @mla_i8_i64(ptr nocapture readonly %x, ptr nocapture readonly %y, i32
 ; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <8 x i8>, ptr [[TMP0]], align 1
 ; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr inbounds i8, ptr [[Y:%.*]], i32 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_LOAD1:%.*]] = load <8 x i8>, ptr [[TMP2]], align 1
-; CHECK-NEXT:    [[TMP11:%.*]] = zext <8 x i8> [[WIDE_LOAD1]] to <8 x i32>
-; CHECK-NEXT:    [[TMP3:%.*]] = zext <8 x i8> [[WIDE_LOAD]] to <8 x i32>
-; CHECK-NEXT:    [[TMP4:%.*]] = mul nuw nsw <8 x i32> [[TMP11]], [[TMP3]]
-; CHECK-NEXT:    [[TMP5:%.*]] = zext nneg <8 x i32> [[TMP4]] to <8 x i64>
+; CHECK-NEXT:    [[TMP4:%.*]] = zext <8 x i8> [[WIDE_LOAD1]] to <8 x i64>
+; CHECK-NEXT:    [[TMP3:%.*]] = zext <8 x i8> [[WIDE_LOAD]] to <8 x i64>
+; CHECK-NEXT:    [[TMP5:%.*]] = mul nuw nsw <8 x i64> [[TMP4]], [[TMP3]]
 ; CHECK-NEXT:    [[TMP6:%.*]] = call i64 @llvm.vector.reduce.add.v8i64(<8 x i64> [[TMP5]])
 ; CHECK-NEXT:    [[TMP7]] = add i64 [[TMP6]], [[VEC_PHI]]
 ; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 8
@@ -1461,9 +1459,8 @@ define i64 @mla_xx_sext_zext(ptr nocapture noundef readonly %x, i32 %n) #0 {
 ; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[TMP5:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds i16, ptr [[X:%.*]], i32 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <8 x i16>, ptr [[TMP0]], align 2
-; CHECK-NEXT:    [[TMP1:%.*]] = sext <8 x i16> [[WIDE_LOAD]] to <8 x i32>
-; CHECK-NEXT:    [[TMP2:%.*]] = mul nsw <8 x i32> [[TMP1]], [[TMP1]]
-; CHECK-NEXT:    [[TMP3:%.*]] = zext nneg <8 x i32> [[TMP2]] to <8 x i64>
+; CHECK-NEXT:    [[TMP1:%.*]] = sext <8 x i16> [[WIDE_LOAD]] to <8 x i64>
+; CHECK-NEXT:    [[TMP3:%.*]] = mul nsw <8 x i64> [[TMP1]], [[TMP1]]
 ; CHECK-NEXT:    [[TMP4:%.*]] = call i64 @llvm.vector.reduce.add.v8i64(<8 x i64> [[TMP3]])
 ; CHECK-NEXT:    [[TMP5]] = add i64 [[TMP4]], [[VEC_PHI]]
 ; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 8
@@ -1530,9 +1527,8 @@ define i64 @mla_and_add_together_16_64(ptr nocapture noundef readonly %x, i32 no
 ; CHECK-NEXT:    [[VEC_PHI1:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[TMP7:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds i16, ptr [[X:%.*]], i32 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <8 x i16>, ptr [[TMP0]], align 2
-; CHECK-NEXT:    [[TMP1:%.*]] = sext <8 x i16> [[WIDE_LOAD]] to <8 x i32>
-; CHECK-NEXT:    [[TMP2:%.*]] = mul nsw <8 x i32> [[TMP1]], [[TMP1]]
-; CHECK-NEXT:    [[TMP3:%.*]] = zext nneg <8 x i32> [[TMP2]] to <8 x i64>
+; CHECK-NEXT:    [[TMP1:%.*]] = sext <8 x i16> [[WIDE_LOAD]] to <8 x i64>
+; CHECK-NEXT:    [[TMP3:%.*]] = mul nsw <8 x i64> [[TMP1]], [[TMP1]]
 ; CHECK-NEXT:    [[TMP4:%.*]] = call i64 @llvm.vector.reduce.add.v8i64(<8 x i64> [[TMP3]])
 ; CHECK-NEXT:    [[TMP5]] = add i64 [[TMP4]], [[VEC_PHI]]
 ; CHECK-NEXT:    [[TMP10:%.*]] = sext <8 x i16> [[WIDE_LOAD]] to <8 x i32>
@@ -1670,24 +1666,55 @@ define i64 @test_std_q31(ptr %x, i32 %n) #0 {
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[CMP11:%.*]] = icmp sgt i32 [[N:%.*]], 0
 ; CHECK-NEXT:    tail call void @llvm.assume(i1 [[CMP11]])
+; CHECK-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp samesign ult i32 [[N]], 4
+; CHECK-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
+; CHECK:       vector.ph:
+; CHECK-NEXT:    [[N_VEC:%.*]] = and i32 [[N]], 2147483644
 ; CHECK-NEXT:    br label [[FOR_BODY1:%.*]]
+; CHECK:       vector.body:
+; CHECK-NEXT:    [[INDEX:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[FOR_BODY1]] ]
+; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[TMP4:%.*]], [[FOR_BODY1]] ]
+; CHECK-NEXT:    [[VEC_PHI1:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[TMP8:%.*]], [[FOR_BODY1]] ]
+; CHECK-NEXT:    [[TMP10:%.*]] = getelementptr inbounds i32, ptr [[X:%.*]], i32 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <4 x i32>, ptr [[TMP10]], align 4
+; CHECK-NEXT:    [[TMP1:%.*]] = ashr <4 x i32> [[WIDE_LOAD]], splat (i32 8)
+; CHECK-NEXT:    [[TMP2:%.*]] = sext <4 x i32> [[TMP1]] to <4 x i64>
+; CHECK-NEXT:    [[TMP3:%.*]] = call i64 @llvm.vector.reduce.add.v4i64(<4 x i64> [[TMP2]])
+; CHECK-NEXT:    [[TMP4]] = add i64 [[TMP3]], [[VEC_PHI]]
+; CHECK-NEXT:    [[TMP5:%.*]] = sext <4 x i32> [[TMP1]] to <4 x i64>
+; CHECK-NEXT:    [[TMP6:%.*]] = mul nsw <4 x i64> [[TMP5]], [[TMP5]]
+; CHECK-NEXT:    [[TMP7:%.*]] = call i64 @llvm.vector.reduce.add.v4i64(<4 x i64> [[TMP6]])
+; CHECK-NEXT:    [[TMP8]] = add i64 [[TMP7]], [[VEC_PHI1]]
+; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 4
+; CHECK-NEXT:    [[TMP9:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]
+; CHECK-NEXT:    br i1 [[TMP9]], label [[MIDDLE_BLOCK:%.*]], label [[FOR_BODY1]], !llvm.loop [[LOOP37:![0-9]+]]
+; CHECK:       middle.block:
+; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i32 [[N]], [[N_VEC]]
+; CHECK-NEXT:    br i1 [[CMP_N]], label [[FOR_COND_CLEANUP:%.*]], label [[SCALAR_PH]]
+; CHECK:       scalar.ph:
+; CHECK-NEXT:    [[BC_RESUME_VAL:%.*]] = phi i32 [ [[N_VEC]], [[MIDDLE_BLOCK]] ], [ 0, [[ENTRY:%.*]] ]
+; CHECK-NEXT:    [[BC_MERGE_RDX:%.*]] = phi i64 [ [[TMP4]], [[MIDDLE_BLOCK]] ], [ 0, [[ENTRY]] ]
+; CHECK-NEXT:    [[BC_MERGE_RDX2:%.*]] = phi i64 [ [[TMP8]], [[MIDDLE_BLOCK]] ], [ 0, [[ENTRY]] ]
+; CHECK-NEXT:    br label [[FOR_BODY:%.*]]
 ; CHECK:       for.cond.cleanup:
-; CHECK-NEXT:    [[DIV:%.*]] = sdiv i64 [[ADD3:%.*]], [[ADD:%.*]]
+; CHECK-NEXT:    [[ADD:%.*]] = phi i64 [ [[ADD1:%.*]], [[FOR_BODY]] ], [ [[TMP4]], [[MIDDLE_BLOCK]] ]
+; CHECK-NEXT:    [[ADD3:%.*]] = phi i64 [ [[ADD5:%.*]], [[FOR_BODY]] ], [ [[TMP8]], [[MIDDLE_BLOCK]] ]
+; CHECK-NEXT:    [[DIV:%.*]] = sdiv i64 [[ADD3]], [[ADD]]
 ; CHECK-NEXT:    ret i64 [[DIV]]
 ; CHECK:       for.body:
-; CHECK-NEXT:    [[S_014:%.*]] = phi i64 [ [[ADD]], [[FOR_BODY1]] ], [ 0, [[ENTRY:%.*]] ]
-; CHECK-NEXT:    [[I_013:%.*]] = phi i32 [ [[ADD4:%.*]], [[FOR_BODY1]] ], [ 0, [[ENTRY]] ]
-; CHECK-NEXT:    [[T_012:%.*]] = phi i64 [ [[ADD3]], [[FOR_BODY1]] ], [ 0, [[ENTRY]] ]
-; CHECK-NEXT:    [[ARRAYIDX:%.*]] = getelementptr inbounds i32, ptr [[X:%.*]], i32 [[I_013]]
+; CHECK-NEXT:    [[S_014:%.*]] = phi i64 [ [[ADD1]], [[FOR_BODY]] ], [ [[BC_MERGE_RDX]], [[SCALAR_PH]] ]
+; CHECK-NEXT:    [[I_013:%.*]] = phi i32 [ [[ADD4:%.*]], [[FOR_BODY]] ], [ [[BC_RESUME_VAL]], [[SCALAR_PH]] ]
+; CHECK-NEXT:    [[T_012:%.*]] = phi i64 [ [[ADD5]], [[FOR_BODY]] ], [ [[BC_MERGE_RDX2]], [[SCALAR_PH]] ]
+; CHECK-NEXT:    [[ARRAYIDX:%.*]] = getelementptr inbounds i32, ptr [[X]], i32 [[I_013]]
 ; CHECK-NEXT:    [[TMP0:%.*]] = load i32, ptr [[ARRAYIDX]], align 4
 ; CHECK-NEXT:    [[SHR:%.*]] = ashr i32 [[TMP0]], 8
 ; CHECK-NEXT:    [[CONV:%.*]] = sext i32 [[SHR]] to i64
-; CHECK-NEXT:    [[ADD]] = add nsw i64 [[S_014]], [[CONV]]
+; CHECK-NEXT:    [[ADD1]] = add nsw i64 [[S_014]], [[CONV]]
 ; CHECK-NEXT:    [[MUL:%.*]] = mul nsw i64 [[CONV]], [[CONV]]
-; CHECK-NEXT:    [[ADD3]] = add nuw nsw i64 [[MUL]], [[T_012]]
+; CHECK-NEXT:    [[ADD5]] = add nuw nsw i64 [[MUL]], [[T_012]]
 ; CHECK-NEXT:    [[ADD4]] = add nuw nsw i32 [[I_013]], 1
 ; CHECK-NEXT:    [[EXITCOND_NOT:%.*]] = icmp eq i32 [[ADD4]], [[N]]
-; CHECK-NEXT:    br i1 [[EXITCOND_NOT]], label [[FOR_COND_CLEANUP:%.*]], label [[FOR_BODY1]]
+; CHECK-NEXT:    br i1 [[EXITCOND_NOT]], label [[FOR_COND_CLEANUP]], label [[FOR_BODY]], !llvm.loop [[LOOP38:![0-9]+]]
 ;
 entry:
   %cmp11 = icmp sgt i32 %n, 0
@@ -1723,10 +1750,10 @@ define i64 @test_fir_q15(ptr %x, ptr %y, i32 %n) #0 {
 ; CHECK-NEXT:    [[TMP0:%.*]] = add nsw i32 [[N]], -1
 ; CHECK-NEXT:    [[TMP1:%.*]] = lshr i32 [[TMP0]], 1
 ; CHECK-NEXT:    [[TMP2:%.*]] = add nuw i32 [[TMP1]], 1
-; CHECK-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i32 [[N]], 7
+; CHECK-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i32 [[N]], 15
 ; CHECK-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
 ; CHECK:       vector.ph:
-; CHECK-NEXT:    [[N_VEC:%.*]] = and i32 [[TMP2]], -4
+; CHECK-NEXT:    [[N_VEC:%.*]] = and i32 [[TMP2]], -8
 ; CHECK-NEXT:    [[IND_END:%.*]] = shl i32 [[N_VEC]], 1
 ; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
 ; CHECK:       vector.body:
@@ -1734,28 +1761,26 @@ define i64 @test_fir_q15(ptr %x, ptr %y, i32 %n) #0 {
 ; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[TMP16:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[OFFSET_IDX:%.*]] = shl i32 [[INDEX]], 1
 ; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr inbounds i16, ptr [[X:%.*]], i32 [[OFFSET_IDX]]
-; CHECK-NEXT:    [[WIDE_VEC:%.*]] = load <8 x i16>, ptr [[TMP3]], align 2
-; CHECK-NEXT:    [[STRIDED_VEC:%.*]] = shufflevector <8 x i16> [[WIDE_VEC]], <8 x i16> poison, <4 x i32> <i32 0, i32 2, i32 4, i32 6>
-; CHECK-NEXT:    [[STRIDED_VEC1:%.*]] = shufflevector <8 x i16> [[WIDE_VEC]], <8 x i16> poison, <4 x i32> <i32 1, i32 3, i32 5, i32 7>
-; CHECK-NEXT:    [[TMP5:%.*]] = sext <4 x i16> [[STRIDED_VEC]] to <4 x i32>
+; CHECK-NEXT:    [[WIDE_VEC:%.*]] = load <16 x i16>, ptr [[TMP3]], align 2
+; CHECK-NEXT:    [[STRIDED_VEC:%.*]] = shufflevector <16 x i16> [[WIDE_VEC]], <16 x i16> poison, <8 x i32> <i32 0, i32 2, i32 4, i32 6, i32 8, i32 10, i32 12, i32 14>
+; CHECK-NEXT:    [[STRIDED_VEC1:%.*]] = shufflevector <16 x i16> [[WIDE_VEC]], <16 x i16> poison, <8 x i32> <i32 1, i32 3, i32 5, i32 7, i32 9, i32 11, i32 13, i32 15>
 ; CHECK-NEXT:    [[TMP4:%.*]] = getelementptr inbounds i16, ptr [[Y:%.*]], i32 [[OFFSET_IDX]]
-; CHECK-NEXT:    [[WIDE_VEC2:%.*]] = load <8 x i16>, ptr [[TMP4]], align 2
-; CHECK-NEXT:    [[STRIDED_VEC3:%.*]] = shufflevector <8 x i16> [[WIDE_VEC2]], <8 x i16> poison, <4 x i32> <i32 0, i32 2, i32 4, i32 6>
-; CHECK-NEXT:    [[STRIDED_VEC4:%.*]] = shufflevector <8 x i16> [[WIDE_VEC2]], <8 x i16> poison, <4 x i32> <i32 1, i32 3, i32 5, i32 7>
-; CHECK-NEXT:    [[TMP6:%.*]] = sext <4 x i16> [[STRIDED_VEC3]] to <4 x i32>
-; CHECK-NEXT:    [[TMP7:%.*]] = mul nsw <4 x i32> [[TMP6]], [[TMP5]]
-; CHECK-NEXT:    [[TMP8:%.*]] = sext <4 x i32> [[TMP7]] to <4 x i64>
-; CHECK-NEXT:    [[TMP13:%.*]] = sext <4 x i16> [[STRIDED_VEC1]] to <4 x i32>
-; CHECK-NEXT:    [[TMP14:%.*]] = sext <4 x i16> [[STRIDED_VEC4]] to <4 x i32>
-; CHECK-NEXT:    [[TMP11:%.*]] = mul nsw <4 x i32> [[TMP14]], [[TMP13]]
-; CHECK-NEXT:    [[TMP12:%.*]] = sext <4 x i32> [[TMP11]] to <4 x i64>
-; CHECK-NEXT:    [[TMP9:%.*]] = call i64 @llvm.vector.reduce.add.v4i64(<4 x i64> [[TMP8]])
+; CHECK-NEXT:    [[WIDE_VEC2:%.*]] = load <16 x i16>, ptr [[TMP4]], align 2
+; CHECK-NEXT:    [[STRIDED_VEC3:%.*]] = shufflevector <16 x i16> [[WIDE_VEC2]], <16 x i16> poison, <8 x i32> <i32 0, i32 2, i32 4, i32 6, i32 8, i32 10, i32 12, i32 14>
+; CHECK-NEXT:    [[STRIDED_VEC4:%.*]] = shufflevector <16 x i16> [[WIDE_VEC2]], <16 x i16> poison, <8 x i32> <i32 1, i32 3, i32 5, i32 7, i32 9, i32 11, i32 13, i32 15>
+; CHECK-NEXT:    [[TMP5:%.*]] = sext <8 x i16> [[STRIDED_VEC3]] to <8 x i64>
+; CHECK-NEXT:    [[TMP6:%.*]] = sext <8 x i16> [[STRIDED_VEC]] to <8 x i64>
+; CHECK-NEXT:    [[TMP7:%.*]] = mul nsw <8 x i64> [[TMP5]], [[TMP6]]
+; CHECK-NEXT:    [[TMP9:%.*]] = call i64 @llvm.vector.reduce.add.v8i64(<8 x i64> [[TMP7]])
 ; CHECK-NEXT:    [[TMP10:%.*]] = add i64 [[TMP9]], [[VEC_PHI]]
-; CHECK-NEXT:    [[TMP15:%.*]] = call i64 @llvm.vector.reduce.add.v4i64(<4 x i64> [[TMP12]])
+; CHECK-NEXT:    [[TMP13:%.*]] = sext <8 x i16> [[STRIDED_VEC4]] to <8 x i64>
+; CHECK-NEXT:    [[TMP11:%.*]] = sext <8 x i16> [[STRIDED_VEC1]] to <8 x i64>
+; CHECK-NEXT:    [[TMP12:%.*]] = mul nsw <8 x i64> [[TMP13]], [[TMP11]]
+; CHECK-NEXT:    [[TMP15:%.*]] = call i64 @llvm.vector.reduce.add.v8i64(<8 x i64> [[TMP12]])
 ; CHECK-NEXT:    [[TMP16]] = add i64 [[TMP15]], [[TMP10]]
-; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 4
+; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 8
 ; CHECK-NEXT:    [[TMP17:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]
-; CHECK-NEXT:    br i1 [[TMP17]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP37:![0-9]+]]
+; CHECK-NEXT:    br i1 [[TMP17]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP39:![0-9]+]]
 ; CHECK:       middle.block:
 ; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i32 [[TMP2]], [[N_VEC]]
 ; CHECK-NEXT:    br i1 [[CMP_N]], label [[FOR_COND_CLEANUP]], label [[SCALAR_PH]]
@@ -1790,7 +1815,7 @@ define i64 @test_fir_q15(ptr %x, ptr %y, i32 %n) #0 {
 ; CHECK-NEXT:    [[ADD12]] = add nsw i64 [[ADD]], [[CONV11]]
 ; CHECK-NEXT:    [[ADD13]] = add nuw nsw i32 [[I_025]], 2
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i32 [[ADD13]], [[N]]
-; CHECK-NEXT:    br i1 [[CMP]], label [[FOR_BODY]], label [[FOR_COND_CLEANUP]], !llvm.loop [[LOOP38:![0-9]+]]
+; CHECK-NEXT:    br i1 [[CMP]], label [[FOR_BODY]], label [[FOR_COND_CLEANUP]], !llvm.loop [[LOOP40:![0-9]+]]
 ;
 entry:
   %cmp23 = icmp sgt i32 %n, 0
diff --git a/llvm/test/Transforms/LoopVectorize/vplan-printing.ll b/llvm/test/Transforms/LoopVectorize/vplan-printing.ll
index b1d7388005ad4e..95ce514d8a92ff 100644
--- a/llvm/test/Transforms/LoopVectorize/vplan-printing.ll
+++ b/llvm/test/Transforms/LoopVectorize/vplan-printing.ll
@@ -1315,7 +1315,7 @@ define i64 @print_mulacc_extended(ptr nocapture readonly %x, ptr nocapture reado
 ; CHECK-NEXT:     CLONE ir<%arrayidx1> = getelementptr inbounds ir<%y>, vp<%3>
 ; CHECK-NEXT:     vp<%5> = vector-pointer ir<%arrayidx1>
 ; CHECK-NEXT:     WIDEN ir<%load1> = load vp<%5>
-; CHECK-NEXT:     MULACC-REDUCE ir<%add> = ir<%r.09> +  (reduce.add (mul (ir<%load0> extended to i32), (ir<%load1> extended to i32)) extended to i64)
+; CHECK-NEXT:     MULACC-REDUCE ir<%add> = ir<%r.09> + reduce.add (mul (ir<%load0> extended to i64), (ir<%load1> extended to i64))
 ; CHECK-NEXT:     EMIT vp<%index.next> = add nuw vp<%2>, vp<%0>
 ; CHECK-NEXT:     EMIT branch-on-count vp<%index.next>, vp<%1>
 ; CHECK-NEXT:   No successors

>From fa4f476c811d216cc982931be1ca601dff9cdd2a Mon Sep 17 00:00:00 2001
From: Elvis Wang <elvis.wang at sifive.com>
Date: Thu, 14 Nov 2024 18:07:19 -0800
Subject: [PATCH 13/17] Refactor.

- Move tryToMatch{MulAcc|ExtendedReduction} out of
adjustRecipesForReductions().
- Remove `ResultType`, which is the same as the recurrence type.
- Use VP_CLASSOF_IMPL.
---
 .../Transforms/Vectorize/LoopVectorize.cpp    | 153 ++++++++++--------
 llvm/lib/Transforms/Vectorize/VPlan.h         |  47 ++----
 .../Transforms/Vectorize/VPlanTransforms.cpp  |   2 +-
 3 files changed, 100 insertions(+), 102 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 14821362f71550..72a9ae76eccb68 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -9253,6 +9253,87 @@ VPlanPtr LoopVectorizationPlanner::buildVPlan(VFRange &Range) {
   return Plan;
 }
 
+/// Try to match an extended reduction and create a VPExtendedReductionRecipe.
+///
+/// Matches the pattern below, where the extend feeding the reduction
+/// (\p VecOp) is a zext or sext. Returns a new recipe, or nullptr on no match.
+///    reduce(ext(...)).
+static VPExtendedReductionRecipe *
+tryToMatchExtendedReduction(const RecurrenceDescriptor &RdxDesc,
+                            Instruction *CurrentLinkI, VPValue *PreviousLink,
+                            VPValue *VecOp, VPValue *CondOp,
+                            LoopVectorizationCostModel &CM) {
+  using namespace VPlanPatternMatch;
+  VPValue *A;
+  // Match reduce(ext(A)), where ext is a zext or sext.
+  if (match(VecOp, m_ZExtOrSExt(m_VPValue(A)))) {
+    return new VPExtendedReductionRecipe(RdxDesc, CurrentLinkI, PreviousLink,
+                                         cast<VPWidenCastRecipe>(VecOp), CondOp,
+                                         CM.useOrderedReductions(RdxDesc));
+  }
+  return nullptr;
+}
+
+/// Try to match a mul-acc reduction and create a VPMulAccRecipe.
+///
+/// Only add reductions are considered. Returns a new recipe for one of the
+/// following patterns, or nullptr if none of them matches:
+///    reduce.add(mul(...)),
+///    reduce.add(mul(ext(A), ext(B))),
+///    reduce.add(ext(mul(ext(A), ext(B)))).
+static VPMulAccRecipe *tryToMatchMulAcc(const RecurrenceDescriptor &RdxDesc,
+                                        Instruction *CurrentLinkI,
+                                        VPValue *PreviousLink, VPValue *VecOp,
+                                        VPValue *CondOp,
+                                        LoopVectorizationCostModel &CM) {
+  using namespace VPlanPatternMatch;
+  VPValue *A, *B;
+  if (RdxDesc.getOpcode() != Instruction::Add)
+    return nullptr;
+  // Try to match reduce.add(mul(...)).
+  if (match(VecOp, m_Mul(m_VPValue(A), m_VPValue(B)))) {
+    VPWidenCastRecipe *RecipeA =
+        dyn_cast_if_present<VPWidenCastRecipe>(A->getDefiningRecipe());
+    VPWidenCastRecipe *RecipeB =
+        dyn_cast_if_present<VPWidenCastRecipe>(B->getDefiningRecipe());
+    // reduce.add(mul(ext, ext)): zext/sext with equal opcodes, or A == B.
+    if (RecipeA && RecipeB && match(RecipeA, m_ZExtOrSExt(m_VPValue())) &&
+        match(RecipeB, m_ZExtOrSExt(m_VPValue())) &&
+        (RecipeA->getOpcode() == RecipeB->getOpcode() || A == B)) {
+      return new VPMulAccRecipe(RdxDesc, CurrentLinkI, PreviousLink, CondOp,
+                                CM.useOrderedReductions(RdxDesc),
+                                cast<VPWidenRecipe>(VecOp->getDefiningRecipe()),
+                                RecipeA, RecipeB);
+    } else {
+      // Plain reduce.add(mul): the operands are not folded as extends.
+      return new VPMulAccRecipe(
+          RdxDesc, CurrentLinkI, PreviousLink, CondOp,
+          CM.useOrderedReductions(RdxDesc),
+          cast<VPWidenRecipe>(VecOp->getDefiningRecipe()));
+    }
+    // Otherwise try reduce.add(ext(mul(ext(A), ext(B)))). The inner extends
+    // must share an opcode, and the outer extend must use it too unless both
+    // mul operands are the same extend (e.g. zext(mul(sext(A), sext(A)))).
+  } else if (match(VecOp, m_ZExtOrSExt(m_Mul(m_ZExtOrSExt(m_VPValue()),
+                                             m_ZExtOrSExt(m_VPValue()))))) {
+    VPWidenCastRecipe *Ext =
+        cast<VPWidenCastRecipe>(VecOp->getDefiningRecipe());
+    VPWidenRecipe *Mul =
+        cast<VPWidenRecipe>(Ext->getOperand(0)->getDefiningRecipe());
+    VPWidenCastRecipe *Ext0 =
+        cast<VPWidenCastRecipe>(Mul->getOperand(0)->getDefiningRecipe());
+    VPWidenCastRecipe *Ext1 =
+        cast<VPWidenCastRecipe>(Mul->getOperand(1)->getDefiningRecipe());
+    if ((Ext->getOpcode() == Ext0->getOpcode() || Ext0 == Ext1) &&
+        Ext0->getOpcode() == Ext1->getOpcode()) {
+      return new VPMulAccRecipe(RdxDesc, CurrentLinkI, PreviousLink, CondOp,
+                                CM.useOrderedReductions(RdxDesc), Mul, Ext0,
+                                Ext1);
+    }
+  }
+  return nullptr;
+}
+
 // Adjust the recipes for reductions. For in-loop reductions the chain of
 // instructions leading from the loop exit instr to the phi need to be converted
 // to reductions, with one operand being vector and the other being the scalar
@@ -9384,76 +9465,12 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
       if (CM.blockNeedsPredicationForAnyReason(BB))
         CondOp = RecipeBuilder.getBlockInMask(BB);
 
-      auto TryToMatchMulAcc = [&]() -> VPReductionRecipe * {
-        VPValue *A, *B;
-        if (RdxDesc.getOpcode() != Instruction::Add)
-          return nullptr;
-        // Try to match reduce.add(mul(...))
-        if (match(VecOp, m_Mul(m_VPValue(A), m_VPValue(B))) &&
-            !VecOp->hasMoreThanOneUniqueUser()) {
-          VPWidenCastRecipe *RecipeA =
-              dyn_cast_if_present<VPWidenCastRecipe>(A->getDefiningRecipe());
-          VPWidenCastRecipe *RecipeB =
-              dyn_cast_if_present<VPWidenCastRecipe>(B->getDefiningRecipe());
-          // Matched reduce.add(mul(ext, ext))
-          if (RecipeA && RecipeB && match(RecipeA, m_ZExtOrSExt(m_VPValue())) &&
-              match(RecipeB, m_ZExtOrSExt(m_VPValue())) &&
-              (RecipeA->getOpcode() == RecipeB->getOpcode() || A == B)) {
-            return new VPMulAccRecipe(
-                RdxDesc, CurrentLinkI, PreviousLink, CondOp,
-                CM.useOrderedReductions(RdxDesc),
-                cast<VPWidenRecipe>(VecOp->getDefiningRecipe()), RecipeA,
-                RecipeB);
-          } else {
-            // Matched reduce.add(mul)
-            return new VPMulAccRecipe(
-                RdxDesc, CurrentLinkI, PreviousLink, CondOp,
-                CM.useOrderedReductions(RdxDesc),
-                cast<VPWidenRecipe>(VecOp->getDefiningRecipe()));
-          }
-          // Matched reduce.add(ext(mul(ext(A), ext(B))))
-          // Note that all extend instructions must have same opcode or A == B
-          // which can be transform to reduce.add(zext(mul(sext(A), sext(B)))).
-        } else if (match(VecOp,
-                         m_ZExtOrSExt(m_Mul(m_ZExtOrSExt(m_VPValue()),
-                                            m_ZExtOrSExt(m_VPValue())))) &&
-                   !VecOp->hasMoreThanOneUniqueUser()) {
-          VPWidenCastRecipe *Ext =
-              cast<VPWidenCastRecipe>(VecOp->getDefiningRecipe());
-          VPWidenRecipe *Mul =
-              cast<VPWidenRecipe>(Ext->getOperand(0)->getDefiningRecipe());
-          VPWidenCastRecipe *Ext0 =
-              cast<VPWidenCastRecipe>(Mul->getOperand(0)->getDefiningRecipe());
-          VPWidenCastRecipe *Ext1 =
-              cast<VPWidenCastRecipe>(Mul->getOperand(1)->getDefiningRecipe());
-          if ((Ext->getOpcode() == Ext0->getOpcode() || Ext0 == Ext1) &&
-              Ext0->getOpcode() == Ext1->getOpcode()) {
-            return new VPMulAccRecipe(
-                RdxDesc, CurrentLinkI, PreviousLink, CondOp,
-                CM.useOrderedReductions(RdxDesc),
-                cast<VPWidenCastRecipe>(VecOp->getDefiningRecipe()), Mul,
-                cast<VPWidenCastRecipe>(Ext0), cast<VPWidenCastRecipe>(Ext1));
-          }
-        }
-        return nullptr;
-      };
-
-      auto TryToMatchExtendedReduction = [&]() -> VPReductionRecipe * {
-        VPValue *A;
-        // Matched reduce(ext)).
-        if (match(VecOp, m_ZExtOrSExt(m_VPValue(A)))) {
-          return new VPExtendedReductionRecipe(
-              RdxDesc, CurrentLinkI, PreviousLink,
-              cast<VPWidenCastRecipe>(VecOp), CondOp,
-              CM.useOrderedReductions(RdxDesc));
-        }
-        return nullptr;
-      };
-
       VPReductionRecipe *RedRecipe;
-      if (auto *MulAcc = TryToMatchMulAcc())
+      if (auto *MulAcc = tryToMatchMulAcc(RdxDesc, CurrentLinkI, PreviousLink,
+                                          VecOp, CondOp, CM))
         RedRecipe = MulAcc;
-      else if (auto *ExtendedRed = TryToMatchExtendedReduction())
+      else if (auto *ExtendedRed = tryToMatchExtendedReduction(
+                   RdxDesc, CurrentLinkI, PreviousLink, VecOp, CondOp, CM))
         RedRecipe = ExtendedRed;
       else
         RedRecipe =
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index b04b694196e298..bc04c5c47f7f75 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -2702,8 +2702,6 @@ class VPReductionEVLRecipe : public VPReductionRecipe {
 /// VPReductionRecipe and VPWidenCastRecipe before execution. The Operands are
 /// {ChainOp, VecOp, [Condition]}.
 class VPExtendedReductionRecipe : public VPReductionRecipe {
-  /// Type after extend.
-  Type *ResultTy;
   CastInst *ExtInstr;
 
 protected:
@@ -2711,10 +2709,10 @@ class VPExtendedReductionRecipe : public VPReductionRecipe {
                             const RecurrenceDescriptor &R, Instruction *RedI,
                             Instruction::CastOps ExtOp, CastInst *ExtInstr,
                             VPValue *ChainOp, VPValue *VecOp, VPValue *CondOp,
-                            bool IsOrdered, Type *ResultTy)
+                            bool IsOrdered)
       : VPReductionRecipe(SC, R, RedI, ArrayRef<VPValue *>({ChainOp, VecOp}),
                           CondOp, IsOrdered),
-        ResultTy(ResultTy), ExtInstr(ExtInstr) {}
+        ExtInstr(ExtInstr) {}
 
 public:
   VPExtendedReductionRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
@@ -2723,7 +2721,7 @@ class VPExtendedReductionRecipe : public VPReductionRecipe {
       : VPExtendedReductionRecipe(
             VPDef::VPExtendedReductionSC, R, RedI, Ext->getOpcode(),
             cast<CastInst>(Ext->getUnderlyingInstr()), ChainOp,
-            Ext->getOperand(0), CondOp, IsOrdered, Ext->getResultType()) {}
+            Ext->getOperand(0), CondOp, IsOrdered) {}
 
   ~VPExtendedReductionRecipe() override = default;
 
@@ -2731,14 +2729,7 @@ class VPExtendedReductionRecipe : public VPReductionRecipe {
     llvm_unreachable("Not implement yet");
   }
 
-  static inline bool classof(const VPRecipeBase *R) {
-    return R->getVPDefID() == VPDef::VPExtendedReductionSC;
-  }
-
-  static inline bool classof(const VPUser *U) {
-    auto *R = dyn_cast<VPRecipeBase>(U);
-    return R && classof(R);
-  }
+  VP_CLASSOF_IMPL(VPDef::VPExtendedReductionSC);
 
   void execute(VPTransformState &State) override {
     llvm_unreachable("VPExtendedReductionRecipe should be transform to "
@@ -2755,11 +2746,16 @@ class VPExtendedReductionRecipe : public VPReductionRecipe {
              VPSlotTracker &SlotTracker) const override;
 #endif
 
-  /// The Type after extended.
-  Type *getResultType() const { return ResultTy; }
+  /// The scalar type after extended.
+  Type *getResultType() const {
+    return getRecurrenceDescriptor().getRecurrenceType();
+  }
+
   bool isZExt() const { return getExtOpcode() == Instruction::ZExt; }
+
   /// The Opcode of extend instruction.
   Instruction::CastOps getExtOpcode() const { return ExtInstr->getOpcode(); }
+
   /// The CastInst of the extend instruction.
   CastInst *getExtInstr() const { return ExtInstr; }
 };
@@ -2771,7 +2767,7 @@ class VPExtendedReductionRecipe : public VPReductionRecipe {
 /// execution. The Operands are {ChainOp, VecOp1, VecOp2, [Condition]}.
 class VPMulAccRecipe : public VPReductionRecipe {
 
-  /// reduce.add(ext(mul(ext0(), ext1())))
+  // reduce.add(mul(ext0, ext1)).
   Instruction *MulInstr;
   CastInst *Ext0Instr = nullptr;
   CastInst *Ext1Instr = nullptr;
@@ -2821,27 +2817,11 @@ class VPMulAccRecipe : public VPReductionRecipe {
                        ChainOp, Mul->getOperand(0), Mul->getOperand(1), CondOp,
                        IsOrdered) {}
 
-  VPMulAccRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
-                 VPValue *ChainOp, VPValue *CondOp, bool IsOrdered,
-                 VPWidenCastRecipe *Ext, VPWidenRecipe *Mul,
-                 VPWidenCastRecipe *Ext0, VPWidenCastRecipe *Ext1)
-      : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, Mul->getUnderlyingInstr(),
-                       Ext0->getUnderlyingInstr(), Ext1->getUnderlyingInstr(),
-                       ChainOp, Ext0->getOperand(0), Ext1->getOperand(0),
-                       CondOp, IsOrdered) {}
-
   ~VPMulAccRecipe() override = default;
 
   VPMulAccRecipe *clone() override { llvm_unreachable("Not implement yet"); }
 
-  static inline bool classof(const VPRecipeBase *R) {
-    return R->getVPDefID() == VPRecipeBase::VPMulAccSC;
-  }
-
-  static inline bool classof(const VPUser *U) {
-    auto *R = dyn_cast<VPRecipeBase>(U);
-    return R && classof(R);
-  }
+  VP_CLASSOF_IMPL(VPDef::VPMulAccSC);
 
   void execute(VPTransformState &State) override {
     llvm_unreachable("VPMulAccRecipe should transform to VPWidenCastRecipe + "
@@ -2876,6 +2856,7 @@ class VPMulAccRecipe : public VPReductionRecipe {
 
   /// Return if the operands of mul instruction come from same extend.
   bool isSameExtend() const { return getVecOp0() == getVecOp1(); }
+
   Instruction::CastOps getExtOpcode() const { return Ext0Instr->getOpcode(); }
 
   /// Return if the extend opcode is ZExt.
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index a41d1a411e899d..bae8aabc6f9692 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -545,7 +545,7 @@ void VPlanTransforms::prepareExecute(VPlan &Plan) {
         Type *RedTy = MulAcc->getRecurrenceDescriptor().getRecurrenceType();
 
         // Generate inner VPWidenCastRecipes if necessary.
-        // Note that we will drop the extend after of mul which transform
+        // Note that we will drop the extend after mul which transform
         // reduce.add(ext(mul(ext, ext))) to reduce.add(mul(ext, ext)).
         VPValue *Op0, *Op1;
         if (MulAcc->isExtended()) {

>From 86ad2d8565c5c2abfcfb88fb6155ff08dfbeb7ba Mon Sep 17 00:00:00 2001
From: Elvis Wang <elvis.wang at sifive.com>
Date: Fri, 15 Nov 2024 01:03:18 -0800
Subject: [PATCH 14/17] Clamp the range when the ExtendedReduction or MulAcc
 cost is invalid.

---
 .../Vectorize/LoopVectorizationPlanner.h      |  2 +-
 .../Transforms/Vectorize/LoopVectorize.cpp    | 97 +++++++++++++++----
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp | 42 ++++----
 3 files changed, 106 insertions(+), 35 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
index fbcf181a45a664..b8920c936f3b6a 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
@@ -509,7 +509,7 @@ class LoopVectorizationPlanner {
   // between the phi and users outside the vector region when folding the tail.
   void adjustRecipesForReductions(VPlanPtr &Plan,
                                   VPRecipeBuilder &RecipeBuilder,
-                                  ElementCount MinVF);
+                                  VFRange &Range);
 
 #ifndef NDEBUG
   /// \return The most profitable vectorization factor for the available VPlans
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 72a9ae76eccb68..fbb39c659c9175 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -9145,7 +9145,7 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
   // ---------------------------------------------------------------------------
 
   // Adjust the recipes for any inloop reductions.
-  adjustRecipesForReductions(Plan, RecipeBuilder, Range.Start);
+  adjustRecipesForReductions(Plan, RecipeBuilder, Range);
 
   // Interleave memory: for each Interleave Group we marked earlier as relevant
   // for this VPlan, replace the Recipes widening its memory instructions with a
@@ -9258,15 +9258,39 @@ VPlanPtr LoopVectorizationPlanner::buildVPlan(VFRange &Range) {
 /// This function try to match following pattern which will generate
 /// extended-reduction instruction.
 ///    reduce(ext(...)).
-static VPExtendedReductionRecipe *
-tryToMatchExtendedReduction(const RecurrenceDescriptor &RdxDesc,
-                            Instruction *CurrentLinkI, VPValue *PreviousLink,
-                            VPValue *VecOp, VPValue *CondOp,
-                            LoopVectorizationCostModel &CM) {
+static VPExtendedReductionRecipe *tryToMatchAndCreateExtendedReduction(
+    const RecurrenceDescriptor &RdxDesc, Instruction *CurrentLinkI,
+    VPValue *PreviousLink, VPValue *VecOp, VPValue *CondOp,
+    LoopVectorizationCostModel &CM, VPCostContext &Ctx, VFRange &Range) {
   using namespace VPlanPatternMatch;
+
   VPValue *A;
+  Type *RedTy = RdxDesc.getRecurrenceType();
+
+  // Test if the cost of extended-reduction is valid and clamp the range.
+  // Note that reduction-extended is not always valid for all VF and types.
+  auto IsExtendedRedValidAndClampRange = [&](unsigned Opcode, bool isZExt,
+                                             Type *SrcTy) -> bool {
+    return LoopVectorizationPlanner::getDecisionAndClampRange(
+        [&](ElementCount VF) {
+          VectorType *SrcVecTy = cast<VectorType>(ToVectorTy(SrcTy, VF));
+          return Ctx.TTI
+              .getExtendedReductionCost(Opcode, isZExt, RedTy, SrcVecTy,
+                                        RdxDesc.getFastMathFlags(),
+                                        TTI::TCK_RecipThroughput)
+              .isValid();
+        },
+        Range);
+  };
+
   // Matched reduce(ext)).
   if (match(VecOp, m_ZExtOrSExt(m_VPValue(A)))) {
+    if (!IsExtendedRedValidAndClampRange(
+            RdxDesc.getOpcode(),
+            cast<VPWidenCastRecipe>(VecOp)->getOpcode() ==
+                Instruction::CastOps::ZExt,
+            Ctx.Types.inferScalarType(A)))
+      return nullptr;
     return new VPExtendedReductionRecipe(RdxDesc, CurrentLinkI, PreviousLink,
                                          cast<VPWidenCastRecipe>(VecOp), CondOp,
                                          CM.useOrderedReductions(RdxDesc));
@@ -9281,31 +9305,59 @@ tryToMatchExtendedReduction(const RecurrenceDescriptor &RdxDesc,
 ///    reduce.add(mul(...)),
 ///    reduce.add(mul(ext(A), ext(B))),
 ///    reduce.add(ext(mul(ext(A), ext(B)))).
-static VPMulAccRecipe *tryToMatchMulAcc(const RecurrenceDescriptor &RdxDesc,
-                                        Instruction *CurrentLinkI,
-                                        VPValue *PreviousLink, VPValue *VecOp,
-                                        VPValue *CondOp,
-                                        LoopVectorizationCostModel &CM) {
+static VPMulAccRecipe *tryToMatchAndCreateMulAcc(
+    const RecurrenceDescriptor &RdxDesc, Instruction *CurrentLinkI,
+    VPValue *PreviousLink, VPValue *VecOp, VPValue *CondOp,
+    LoopVectorizationCostModel &CM, VPCostContext &Ctx, VFRange &Range) {
   using namespace VPlanPatternMatch;
+
   VPValue *A, *B;
+  Type *RedTy = RdxDesc.getRecurrenceType();
+
+  // Test if the cost of MulAcc is valid and clamp the range.
+  // Note that mul-acc is not always valid for all VF and types.
+  auto IsMulAccValidAndClampRange = [&](bool isZExt, Type *SrcTy) -> bool {
+    return LoopVectorizationPlanner::getDecisionAndClampRange(
+        [&](ElementCount VF) {
+          VectorType *SrcVecTy = cast<VectorType>(ToVectorTy(SrcTy, VF));
+          return Ctx.TTI
+              .getMulAccReductionCost(isZExt, RedTy, SrcVecTy,
+                                      TTI::TCK_RecipThroughput)
+              .isValid();
+        },
+        Range);
+  };
+
   if (RdxDesc.getOpcode() != Instruction::Add)
     return nullptr;
+
   // Try to match reduce.add(mul(...))
   if (match(VecOp, m_Mul(m_VPValue(A), m_VPValue(B)))) {
     VPWidenCastRecipe *RecipeA =
         dyn_cast_if_present<VPWidenCastRecipe>(A->getDefiningRecipe());
     VPWidenCastRecipe *RecipeB =
         dyn_cast_if_present<VPWidenCastRecipe>(B->getDefiningRecipe());
+
     // Matched reduce.add(mul(ext, ext))
     if (RecipeA && RecipeB && match(RecipeA, m_ZExtOrSExt(m_VPValue())) &&
         match(RecipeB, m_ZExtOrSExt(m_VPValue())) &&
         (RecipeA->getOpcode() == RecipeB->getOpcode() || A == B)) {
+
+      // Only create MulAccRecipe if the cost is valid.
+      if (!IsMulAccValidAndClampRange(RecipeA->getOpcode() ==
+                                          Instruction::CastOps::ZExt,
+                                      Ctx.Types.inferScalarType(RecipeA)))
+        return nullptr;
+
       return new VPMulAccRecipe(RdxDesc, CurrentLinkI, PreviousLink, CondOp,
                                 CM.useOrderedReductions(RdxDesc),
                                 cast<VPWidenRecipe>(VecOp->getDefiningRecipe()),
                                 RecipeA, RecipeB);
     } else {
       // Matched reduce.add(mul)
+      if (!IsMulAccValidAndClampRange(true, RedTy))
+        return nullptr;
+
       return new VPMulAccRecipe(
           RdxDesc, CurrentLinkI, PreviousLink, CondOp,
           CM.useOrderedReductions(RdxDesc),
@@ -9326,6 +9378,12 @@ static VPMulAccRecipe *tryToMatchMulAcc(const RecurrenceDescriptor &RdxDesc,
         cast<VPWidenCastRecipe>(Mul->getOperand(1)->getDefiningRecipe());
     if ((Ext->getOpcode() == Ext0->getOpcode() || Ext0 == Ext1) &&
         Ext0->getOpcode() == Ext1->getOpcode()) {
+      // Only create MulAcc recipe if the cost if valid.
+      if (!IsMulAccValidAndClampRange(Ext0->getOpcode() ==
+                                          Instruction::CastOps::ZExt,
+                                      Ctx.Types.inferScalarType(Ext0)))
+        return nullptr;
+
       return new VPMulAccRecipe(RdxDesc, CurrentLinkI, PreviousLink, CondOp,
                                 CM.useOrderedReductions(RdxDesc), Mul, Ext0,
                                 Ext1);
@@ -9348,8 +9406,9 @@ static VPMulAccRecipe *tryToMatchMulAcc(const RecurrenceDescriptor &RdxDesc,
 // with a boolean reduction phi node to check if the condition is true in any
 // iteration. The final value is selected by the final ComputeReductionResult.
 void LoopVectorizationPlanner::adjustRecipesForReductions(
-    VPlanPtr &Plan, VPRecipeBuilder &RecipeBuilder, ElementCount MinVF) {
+    VPlanPtr &Plan, VPRecipeBuilder &RecipeBuilder, VFRange &Range) {
   using namespace VPlanPatternMatch;
+  ElementCount MinVF = Range.Start;
   VPRegionBlock *VectorLoopRegion = Plan->getVectorLoopRegion();
   VPBasicBlock *Header = VectorLoopRegion->getEntryBasicBlock();
   VPBasicBlock *MiddleVPBB = Plan->getMiddleBlock();
@@ -9466,12 +9525,16 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
         CondOp = RecipeBuilder.getBlockInMask(BB);
 
       VPReductionRecipe *RedRecipe;
-      if (auto *MulAcc = tryToMatchMulAcc(RdxDesc, CurrentLinkI, PreviousLink,
-                                          VecOp, CondOp, CM))
+      VPCostContext CostCtx(CM.TTI, *CM.TLI, Legal->getWidestInductionType(),
+                            CM);
+      if (auto *MulAcc =
+              tryToMatchAndCreateMulAcc(RdxDesc, CurrentLinkI, PreviousLink,
+                                        VecOp, CondOp, CM, CostCtx, Range))
         RedRecipe = MulAcc;
-      else if (auto *ExtendedRed = tryToMatchExtendedReduction(
-                   RdxDesc, CurrentLinkI, PreviousLink, VecOp, CondOp, CM))
-        RedRecipe = ExtendedRed;
+      else if (auto *ExtRed = tryToMatchAndCreateExtendedReduction(
+                   RdxDesc, CurrentLinkI, PreviousLink, VecOp, CondOp, CM,
+                   CostCtx, Range))
+        RedRecipe = ExtRed;
       else
         RedRecipe =
             new VPReductionRecipe(RdxDesc, CurrentLinkI, PreviousLink, VecOp,
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index f6e6f601309d55..ad3771fcaa142d 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -2225,9 +2225,19 @@ VPExtendedReductionRecipe::computeCost(ElementCount VF,
   RecurKind RdxKind = RdxDesc.getRecurrenceKind();
   Type *ElementTy = getResultType();
   auto *VectorTy = cast<VectorType>(ToVectorTy(ElementTy, VF));
+  auto *SrcVecTy =
+      cast<VectorType>(ToVectorTy(Ctx.Types.inferScalarType(getVecOp()), VF));
   TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
   unsigned Opcode = RdxDesc.getOpcode();
 
+  // ExtendedReduction Cost
+  InstructionCost ExtendedRedCost =
+      Ctx.TTI.getExtendedReductionCost(Opcode, isZExt(), ElementTy, SrcVecTy,
+                                       RdxDesc.getFastMathFlags(), CostKind);
+
+  assert(ExtendedRedCost.isValid() && "VPExtendedReductionRecipe should not be "
+                                      "created if the cost is invalid.");
+
   // BaseCost = Reduction cost + BinOp cost
   InstructionCost ReductionCost =
       Ctx.TTI.getArithmeticInstrCost(Opcode, ElementTy, CostKind);
@@ -2241,23 +2251,18 @@ VPExtendedReductionRecipe::computeCost(ElementCount VF,
   }
 
   // Extended cost
-  auto *SrcTy =
-      cast<VectorType>(ToVectorTy(Ctx.Types.inferScalarType(getVecOp()), VF));
-  auto *DestTy = cast<VectorType>(ToVectorTy(getResultType(), VF));
   TTI::CastContextHint CCH = computeCCH(getVecOp()->getDefiningRecipe(), VF);
   // Arm TTI will use the underlying instruction to determine the cost.
   InstructionCost ExtendedCost = Ctx.TTI.getCastInstrCost(
-      Opcode, DestTy, SrcTy, CCH, TTI::TCK_RecipThroughput,
+      Opcode, VectorTy, SrcVecTy, CCH, TTI::TCK_RecipThroughput,
       dyn_cast_if_present<Instruction>(getUnderlyingValue()));
 
-  // ExtendedReduction Cost
-  InstructionCost ExtendedRedCost = Ctx.TTI.getExtendedReductionCost(
-      Opcode, isZExt(), ElementTy, SrcTy, RdxDesc.getFastMathFlags(), CostKind);
   // Check if folding ext into ExtendedReduction is profitable.
   if (ExtendedRedCost.isValid() &&
       ExtendedRedCost < ExtendedCost + ReductionCost) {
     return ExtendedRedCost;
   }
+
   return ExtendedCost + ReductionCost;
 }
 
@@ -2272,6 +2277,14 @@ InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
 
   assert(Opcode == Instruction::Add &&
          "Reduction opcode must be add in the VPMulAccRecipe.");
+  // MulAccReduction Cost
+  VectorType *SrcVecTy =
+      cast<VectorType>(ToVectorTy(Ctx.Types.inferScalarType(getVecOp0()), VF));
+  InstructionCost MulAccCost =
+      Ctx.TTI.getMulAccReductionCost(isZExt(), ElementTy, SrcVecTy, CostKind);
+
+  assert(MulAccCost.isValid() && "VPMulAccRecipe should not be "
+                                 "created if the cost is invalid.");
 
   // BaseCost = Reduction cost + BinOp cost
   InstructionCost ReductionCost =
@@ -2282,17 +2295,17 @@ InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
   // Extended cost
   InstructionCost ExtendedCost = 0;
   if (isExtended()) {
-    auto *SrcTy = cast<VectorType>(
-        ToVectorTy(Ctx.Types.inferScalarType(getVecOp0()), VF));
     TTI::CastContextHint CCH0 =
         computeCCH(getVecOp0()->getDefiningRecipe(), VF);
     ExtendedCost = Ctx.TTI.getCastInstrCost(
-        Ext0Instr->getOpcode(), VectorTy, SrcTy, CCH0, TTI::TCK_RecipThroughput,
+        Ext0Instr->getOpcode(), VectorTy, SrcVecTy, CCH0,
+        TTI::TCK_RecipThroughput,
         dyn_cast_if_present<Instruction>(getExt0Instr()));
     TTI::CastContextHint CCH1 =
         computeCCH(getVecOp0()->getDefiningRecipe(), VF);
     ExtendedCost += Ctx.TTI.getCastInstrCost(
-        Ext1Instr->getOpcode(), VectorTy, SrcTy, CCH1, TTI::TCK_RecipThroughput,
+        Ext1Instr->getOpcode(), VectorTy, SrcVecTy, CCH1,
+        TTI::TCK_RecipThroughput,
         dyn_cast_if_present<Instruction>(getExt1Instr()));
   }
 
@@ -2324,17 +2337,12 @@ InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
         RHSInfo, Operands, MulInstr, &Ctx.TLI);
   }
 
-  // MulAccReduction Cost
-  VectorType *SrcVecTy =
-      cast<VectorType>(ToVectorTy(Ctx.Types.inferScalarType(getVecOp0()), VF));
-  InstructionCost MulAccCost =
-      Ctx.TTI.getMulAccReductionCost(isZExt(), ElementTy, SrcVecTy, CostKind);
-
   // Check if folding ext into ExtendedReduction is profitable.
   if (MulAccCost.isValid() &&
       MulAccCost < ExtendedCost + ReductionCost + MulCost) {
     return MulAccCost;
   }
+
   return ExtendedCost + ReductionCost + MulCost;
 }
 

>From 594f9e49425fd8c8a223c7102801aa5f5d4ef341 Mon Sep 17 00:00:00 2001
From: Elvis Wang <elvis.wang at sifive.com>
Date: Sun, 17 Nov 2024 21:26:30 -0800
Subject: [PATCH 15/17] Try to not depend on underlying ext/mul instructions
 and preserve flags/DL.

This commit expands the constructors of VPWidenRecipe and
VPWidenCastRecipe so that they can generate vector instructions with
debug information and flags without an underlying instruction.

The extend instruction (ZExt) needs the non-negative flag, and the
integer mul instruction needs the no-unsigned-wrap/no-signed-wrap flags.
---
 llvm/lib/Transforms/Vectorize/VPlan.h         | 157 ++++++++++++------
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp |  19 +--
 .../Transforms/Vectorize/VPlanTransforms.cpp  |  42 +++--
 3 files changed, 144 insertions(+), 74 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index bc04c5c47f7f75..60ab1d156e8806 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -951,13 +951,15 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
     GEPFlagsTy(bool IsInBounds) : IsInBounds(IsInBounds) {}
   };
 
+  struct NonNegFlagsTy {
+    char NonNeg : 1;
+    NonNegFlagsTy(bool IsNonNeg) : NonNeg(IsNonNeg) {}
+  };
+
 private:
   struct ExactFlagsTy {
     char IsExact : 1;
   };
-  struct NonNegFlagsTy {
-    char NonNeg : 1;
-  };
   struct FastMathFlagsTy {
     char AllowReassoc : 1;
     char NoNaNs : 1;
@@ -1051,6 +1053,12 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
       : VPSingleDefRecipe(SC, Operands, DL), OpType(OperationType::DisjointOp),
         DisjointFlags(DisjointFlags) {}
 
+  template <typename IterT>
+  VPRecipeWithIRFlags(const unsigned char SC, IterT Operands,
+                      NonNegFlagsTy NonNegFlags, DebugLoc DL = {})
+      : VPSingleDefRecipe(SC, Operands, DL), OpType(OperationType::NonNegOp),
+        NonNegFlags(NonNegFlags) {}
+
 protected:
   template <typename IterT>
   VPRecipeWithIRFlags(const unsigned char SC, IterT Operands,
@@ -1160,6 +1168,8 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
 
   FastMathFlags getFastMathFlags() const;
 
+  bool isNonNeg() const { return NonNegFlags.NonNeg; }
+
   bool hasNoUnsignedWrap() const {
     assert(OpType == OperationType::OverflowingBinOp &&
            "recipe doesn't have a NUW flag");
@@ -1448,11 +1458,22 @@ class VPWidenRecipe : public VPRecipeWithIRFlags {
                 iterator_range<IterT> Operands)
       : VPRecipeWithIRFlags(VPDefOpcode, Operands, I), Opcode(I.getOpcode()) {}
 
+  template <typename IterT>
+  VPWidenRecipe(unsigned VPDefOpcode, unsigned Opcode,
+                iterator_range<IterT> Operands, bool NUW, bool NSW, DebugLoc DL)
+      : VPRecipeWithIRFlags(VPDefOpcode, Operands, WrapFlagsTy(NUW, NSW), DL),
+        Opcode(Opcode) {}
+
 public:
   template <typename IterT>
   VPWidenRecipe(Instruction &I, iterator_range<IterT> Operands)
       : VPWidenRecipe(VPDef::VPWidenSC, I, Operands) {}
 
+  template <typename IterT>
+  VPWidenRecipe(unsigned Opcode, iterator_range<IterT> Operands, bool NUW,
+                bool NSW, DebugLoc DL)
+      : VPWidenRecipe(VPDef::VPWidenSC, Opcode, Operands, NUW, NSW, DL) {}
+
   ~VPWidenRecipe() override = default;
 
   VPWidenRecipe *clone() override {
@@ -1553,10 +1574,17 @@ class VPWidenCastRecipe : public VPRecipeWithIRFlags {
            "opcode of underlying cast doesn't match");
   }
 
-  VPWidenCastRecipe(Instruction::CastOps Opcode, VPValue *Op, Type *ResultTy)
-      : VPRecipeWithIRFlags(VPDef::VPWidenCastSC, Op), Opcode(Opcode),
+  VPWidenCastRecipe(Instruction::CastOps Opcode, VPValue *Op, Type *ResultTy,
+                    DebugLoc DL = {})
+      : VPRecipeWithIRFlags(VPDef::VPWidenCastSC, Op, DL), Opcode(Opcode),
         ResultTy(ResultTy) {}
 
+  VPWidenCastRecipe(Instruction::CastOps Opcode, VPValue *Op, Type *ResultTy,
+                    bool IsNonNeg, DebugLoc DL)
+      : VPRecipeWithIRFlags(VPDef::VPWidenCastSC, Op, NonNegFlagsTy(IsNonNeg),
+                            DL),
+        Opcode(Opcode), ResultTy(ResultTy) {}
+
   ~VPWidenCastRecipe() override = default;
 
   VPWidenCastRecipe *clone() override {
@@ -2702,26 +2730,29 @@ class VPReductionEVLRecipe : public VPReductionRecipe {
 /// VPReductionRecipe and VPWidenCastRecipe before execution. The Operands are
 /// {ChainOp, VecOp, [Condition]}.
 class VPExtendedReductionRecipe : public VPReductionRecipe {
-  CastInst *ExtInstr;
+  Instruction::CastOps ExtOp;
+  DebugLoc ExtDL;
+  /// Non-negative flag for the extended instruction.
+  bool IsNonNeg;
 
 protected:
   VPExtendedReductionRecipe(const unsigned char SC,
                             const RecurrenceDescriptor &R, Instruction *RedI,
-                            Instruction::CastOps ExtOp, CastInst *ExtInstr,
-                            VPValue *ChainOp, VPValue *VecOp, VPValue *CondOp,
-                            bool IsOrdered)
+                            Instruction::CastOps ExtOp, DebugLoc ExtDL,
+                            bool IsNonNeg, VPValue *ChainOp, VPValue *VecOp,
+                            VPValue *CondOp, bool IsOrdered)
       : VPReductionRecipe(SC, R, RedI, ArrayRef<VPValue *>({ChainOp, VecOp}),
                           CondOp, IsOrdered),
-        ExtInstr(ExtInstr) {}
+        ExtOp(ExtOp), ExtDL(ExtDL), IsNonNeg(IsNonNeg) {}
 
 public:
   VPExtendedReductionRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
                             VPValue *ChainOp, VPWidenCastRecipe *Ext,
                             VPValue *CondOp, bool IsOrdered)
-      : VPExtendedReductionRecipe(
-            VPDef::VPExtendedReductionSC, R, RedI, Ext->getOpcode(),
-            cast<CastInst>(Ext->getUnderlyingInstr()), ChainOp,
-            Ext->getOperand(0), CondOp, IsOrdered) {}
+      : VPExtendedReductionRecipe(VPDef::VPExtendedReductionSC, R, RedI,
+                                  Ext->getOpcode(), Ext->getDebugLoc(),
+                                  Ext->isNonNeg(), ChainOp, Ext->getOperand(0),
+                                  CondOp, IsOrdered) {}
 
   ~VPExtendedReductionRecipe() override = default;
 
@@ -2754,10 +2785,12 @@ class VPExtendedReductionRecipe : public VPReductionRecipe {
   bool isZExt() const { return getExtOpcode() == Instruction::ZExt; }
 
   /// The Opcode of extend instruction.
-  Instruction::CastOps getExtOpcode() const { return ExtInstr->getOpcode(); }
+  Instruction::CastOps getExtOpcode() const { return ExtOp; }
+
+  bool getNonNegFlags() const { return IsNonNeg; }
 
-  /// The CastInst of the extend instruction.
-  CastInst *getExtInstr() const { return ExtInstr; }
+  /// Return the debug location of the extend instruction.
+  DebugLoc getExtDebugLoc() const { return ExtDL; }
 };
 
 /// A recipe to represent inloop MulAccreduction operations, performing a
@@ -2767,36 +2800,47 @@ class VPExtendedReductionRecipe : public VPReductionRecipe {
 /// execution. The Operands are {ChainOp, VecOp1, VecOp2, [Condition]}.
 class VPMulAccRecipe : public VPReductionRecipe {
 
-  // reduce.add(mul(ext0, ext1)).
-  Instruction *MulInstr;
-  CastInst *Ext0Instr = nullptr;
-  CastInst *Ext1Instr = nullptr;
+  // WrapFlags (NoUnsignedWraps, NoSignedWraps)
+  bool NUW;
+  bool NSW;
+
+  // Debug location of the mul instruction.
+  DebugLoc MulDL;
+  Instruction::CastOps ExtOp;
+
+  // Non-neg flag for the extend instruction.
+  bool IsNonNeg;
+
+  // Debug location of extend instruction.
+  DebugLoc Ext0DL;
+  DebugLoc Ext1DL;
+
+  // Whether this mul-acc recipe contains extend recipes.
+  bool IsExtended = false;
 
 protected:
   VPMulAccRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
-                 Instruction *RedI, Instruction *MulInstr,
-                 Instruction *Ext0Instr, Instruction *Ext1Instr,
-                 VPValue *ChainOp, VPValue *VecOp0, VPValue *VecOp1,
-                 VPValue *CondOp, bool IsOrdered)
+                 Instruction *RedI, bool NUW, bool NSW, DebugLoc MulDL,
+                 Instruction::CastOps ExtOp, bool IsNonNeg, DebugLoc Ext0DL,
+                 DebugLoc Ext1DL, VPValue *ChainOp, VPValue *VecOp0,
+                 VPValue *VecOp1, VPValue *CondOp, bool IsOrdered)
       : VPReductionRecipe(SC, R, RedI,
                           ArrayRef<VPValue *>({ChainOp, VecOp0, VecOp1}),
                           CondOp, IsOrdered),
-        MulInstr(MulInstr), Ext0Instr(cast<CastInst>(Ext0Instr)),
-        Ext1Instr(cast<CastInst>(Ext1Instr)) {
-    assert(MulInstr->getOpcode() == Instruction::Mul);
+        NUW(NUW), NSW(NSW), MulDL(MulDL), ExtOp(ExtOp), IsNonNeg(IsNonNeg),
+        Ext0DL(Ext0DL), Ext1DL(Ext1DL) {
     assert(R.getOpcode() == Instruction::Add);
-    assert(Ext0Instr->getOpcode() == Ext1Instr->getOpcode());
+    IsExtended = true;
   }
 
   VPMulAccRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
-                 Instruction *RedI, Instruction *MulInstr, VPValue *ChainOp,
-                 VPValue *VecOp0, VPValue *VecOp1, VPValue *CondOp,
-                 bool IsOrdered)
+                 Instruction *RedI, bool NUW, bool NSW, DebugLoc MulDL,
+                 VPValue *ChainOp, VPValue *VecOp0, VPValue *VecOp1,
+                 VPValue *CondOp, bool IsOrdered)
       : VPReductionRecipe(SC, R, RedI,
                           ArrayRef<VPValue *>({ChainOp, VecOp0, VecOp1}),
                           CondOp, IsOrdered),
-        MulInstr(MulInstr) {
-    assert(MulInstr->getOpcode() == Instruction::Mul);
+        NUW(NUW), NSW(NSW), MulDL(MulDL) {
     assert(R.getOpcode() == Instruction::Add);
   }
 
@@ -2805,16 +2849,18 @@ class VPMulAccRecipe : public VPReductionRecipe {
                  VPValue *ChainOp, VPValue *CondOp, bool IsOrdered,
                  VPWidenRecipe *Mul, VPWidenCastRecipe *Ext0,
                  VPWidenCastRecipe *Ext1)
-      : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, Mul->getUnderlyingInstr(),
-                       Ext0->getUnderlyingInstr(), Ext1->getUnderlyingInstr(),
-                       ChainOp, Ext0->getOperand(0), Ext1->getOperand(0),
-                       CondOp, IsOrdered) {}
+      : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, Mul->hasNoUnsignedWrap(),
+                       Mul->hasNoSignedWrap(), Mul->getDebugLoc(),
+                       Ext0->getOpcode(), Ext0->isNonNeg(), Ext0->getDebugLoc(),
+                       Ext1->getDebugLoc(), ChainOp, Ext0->getOperand(0),
+                       Ext1->getOperand(0), CondOp, IsOrdered) {}
 
   VPMulAccRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
                  VPValue *ChainOp, VPValue *CondOp, bool IsOrdered,
                  VPWidenRecipe *Mul)
-      : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, Mul->getUnderlyingInstr(),
-                       ChainOp, Mul->getOperand(0), Mul->getOperand(1), CondOp,
+      : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, Mul->hasNoUnsignedWrap(),
+                       Mul->hasNoSignedWrap(), Mul->getDebugLoc(), ChainOp,
+                       Mul->getOperand(0), Mul->getOperand(1), CondOp,
                        IsOrdered) {}
 
   ~VPMulAccRecipe() override = default;
@@ -2842,29 +2888,36 @@ class VPMulAccRecipe : public VPReductionRecipe {
   VPValue *getVecOp0() const { return getOperand(1); }
   VPValue *getVecOp1() const { return getOperand(2); }
 
-  /// The underlying instruction for VPWidenRecipe.
-  Instruction *getMulInstr() const { return MulInstr; }
-
-  /// The underlying Instruction for inner VPWidenCastRecipe.
-  CastInst *getExt0Instr() const { return Ext0Instr; }
-
-  /// The underlying Instruction for inner VPWidenCastRecipe.
-  CastInst *getExt1Instr() const { return Ext1Instr; }
-
   /// Return if this MulAcc recipe contains extend instructions.
-  bool isExtended() const { return Ext0Instr && Ext1Instr; }
+  bool isExtended() const { return IsExtended; }
 
   /// Return if the operands of mul instruction come from same extend.
   bool isSameExtend() const { return getVecOp0() == getVecOp1(); }
 
-  Instruction::CastOps getExtOpcode() const { return Ext0Instr->getOpcode(); }
+  Instruction::CastOps getExtOpcode() const { return ExtOp; }
 
   /// Return if the extend opcode is ZExt.
   bool isZExt() const {
     if (!isExtended())
       return true;
-    return Ext0Instr->getOpcode() == Instruction::CastOps::ZExt;
+    return ExtOp == Instruction::CastOps::ZExt;
   }
+
+  /// Return the non-neg flag.
+  bool getNonNegFlags() const { return IsNonNeg; }
+
+  /// Return no-unsigned-wrap.
+  bool hasNoUnsignedWrap() const { return NUW; }
+
+  /// Return no-signed-wrap.
+  bool hasNoSignedWrap() const { return NSW; }
+
+  /// Return debug location of mul instruction.
+  DebugLoc getMulDebugLoc() const { return MulDL; }
+
+  /// Return debug location of extend instructions.
+  DebugLoc getExt0DebugLoc() const { return Ext0DL; }
+  DebugLoc getExt1DebugLoc() const { return Ext1DL; }
 };
 
 /// VPReplicateRecipe replicates a given instruction producing multiple scalar
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index ad3771fcaa142d..39beced3a1e3ba 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -2297,28 +2297,23 @@ InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
   if (isExtended()) {
     TTI::CastContextHint CCH0 =
         computeCCH(getVecOp0()->getDefiningRecipe(), VF);
-    ExtendedCost = Ctx.TTI.getCastInstrCost(
-        Ext0Instr->getOpcode(), VectorTy, SrcVecTy, CCH0,
-        TTI::TCK_RecipThroughput,
-        dyn_cast_if_present<Instruction>(getExt0Instr()));
+    ExtendedCost = Ctx.TTI.getCastInstrCost(ExtOp, VectorTy, SrcVecTy, CCH0,
+                                            TTI::TCK_RecipThroughput);
     TTI::CastContextHint CCH1 =
         computeCCH(getVecOp0()->getDefiningRecipe(), VF);
-    ExtendedCost += Ctx.TTI.getCastInstrCost(
-        Ext1Instr->getOpcode(), VectorTy, SrcVecTy, CCH1,
-        TTI::TCK_RecipThroughput,
-        dyn_cast_if_present<Instruction>(getExt1Instr()));
+    ExtendedCost += Ctx.TTI.getCastInstrCost(ExtOp, VectorTy, SrcVecTy, CCH1,
+                                             TTI::TCK_RecipThroughput);
   }
 
   // Mul cost
   InstructionCost MulCost;
   SmallVector<const Value *, 4> Operands;
-  Operands.append(MulInstr->value_op_begin(), MulInstr->value_op_end());
   if (isExtended())
     MulCost = Ctx.TTI.getArithmeticInstrCost(
         Instruction::Mul, VectorTy, CostKind,
         {TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None},
         {TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None},
-        Operands, MulInstr, &Ctx.TLI);
+        Operands, nullptr, &Ctx.TLI);
   else {
     VPValue *RHS = getVecOp1();
     // Certain instructions can be cheaper to vectorize if they have a constant
@@ -2331,10 +2326,12 @@ InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
     if (RHSInfo.Kind == TargetTransformInfo::OK_AnyValue &&
         RHS->isDefinedOutsideLoopRegions())
       RHSInfo.Kind = TargetTransformInfo::OK_UniformValue;
+    Operands.append(
+        {getVecOp0()->getUnderlyingValue(), RHS->getUnderlyingValue()});
     MulCost = Ctx.TTI.getArithmeticInstrCost(
         Instruction::Mul, VectorTy, CostKind,
         {TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None},
-        RHSInfo, Operands, MulInstr, &Ctx.TLI);
+        RHSInfo, Operands, nullptr, &Ctx.TLI);
   }
 
   // Check if folding ext into ExtendedReduction is profitable.
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index bae8aabc6f9692..733c144c7aa359 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -527,9 +527,17 @@ void VPlanTransforms::prepareExecute(VPlan &Plan) {
     for (VPRecipeBase &R : make_early_inc_range(*VPBB)) {
       if (auto *ExtRed = dyn_cast<VPExtendedReductionRecipe>(&R)) {
         // Genearte VPWidenCastRecipe.
-        auto *Ext = new VPWidenCastRecipe(
-            ExtRed->getExtOpcode(), ExtRed->getVecOp(), ExtRed->getResultType(),
-            *ExtRed->getExtInstr());
+        VPWidenCastRecipe *Ext;
+        // Only ZExt contiains non-neg flags.
+        if (ExtRed->isZExt())
+          Ext = new VPWidenCastRecipe(
+              ExtRed->getExtOpcode(), ExtRed->getVecOp(),
+              ExtRed->getResultType(), ExtRed->getNonNegFlags(),
+              ExtRed->getExtDebugLoc());
+        else
+          Ext = new VPWidenCastRecipe(
+              ExtRed->getExtOpcode(), ExtRed->getVecOp(),
+              ExtRed->getResultType(), ExtRed->getExtDebugLoc());
 
         // Generate VPreductionRecipe.
         auto *Red = new VPReductionRecipe(
@@ -549,18 +557,28 @@ void VPlanTransforms::prepareExecute(VPlan &Plan) {
         // reduce.add(ext(mul(ext, ext))) to reduce.add(mul(ext, ext)).
         VPValue *Op0, *Op1;
         if (MulAcc->isExtended()) {
-          CastInst *Ext0 = MulAcc->getExt0Instr();
-          Op0 = new VPWidenCastRecipe(Ext0->getOpcode(), MulAcc->getVecOp0(),
-                                      RedTy, *Ext0);
+          if (MulAcc->isZExt())
+            Op0 = new VPWidenCastRecipe(
+                MulAcc->getExtOpcode(), MulAcc->getVecOp0(), RedTy,
+                MulAcc->getNonNegFlags(), MulAcc->getExt0DebugLoc());
+          else
+            Op0 = new VPWidenCastRecipe(MulAcc->getExtOpcode(),
+                                        MulAcc->getVecOp0(), RedTy,
+                                        MulAcc->getExt0DebugLoc());
           Op0->getDefiningRecipe()->insertBefore(MulAcc);
           // Prevent reduce.add(mul(ext(A), ext(A))) generate duplicate
           // VPWidenCastRecipe.
           if (MulAcc->getVecOp0() == MulAcc->getVecOp1()) {
             Op1 = Op0;
           } else {
-            CastInst *Ext1 = MulAcc->getExt1Instr();
-            Op1 = new VPWidenCastRecipe(Ext1->getOpcode(), MulAcc->getVecOp1(),
-                                        RedTy, *Ext1);
+            if (MulAcc->isZExt())
+              Op1 = new VPWidenCastRecipe(
+                  MulAcc->getExtOpcode(), MulAcc->getVecOp1(), RedTy,
+                  MulAcc->getNonNegFlags(), MulAcc->getExt1DebugLoc());
+            else
+              Op1 = new VPWidenCastRecipe(MulAcc->getExtOpcode(),
+                                          MulAcc->getVecOp1(), RedTy,
+                                          MulAcc->getExt1DebugLoc());
             Op1->getDefiningRecipe()->insertBefore(MulAcc);
           }
           // No extends in this MulAccRecipe.
@@ -571,8 +589,10 @@ void VPlanTransforms::prepareExecute(VPlan &Plan) {
 
         // Generate VPWidenRecipe.
         SmallVector<VPValue *, 2> MulOps = {Op0, Op1};
-        auto *Mul = new VPWidenRecipe(*MulAcc->getMulInstr(),
-                                      make_range(MulOps.begin(), MulOps.end()));
+        auto *Mul = new VPWidenRecipe(
+            Instruction::Mul, make_range(MulOps.begin(), MulOps.end()),
+            MulAcc->hasNoUnsignedWrap(), MulAcc->hasNoSignedWrap(),
+            MulAcc->getMulDebugLoc());
         Mul->insertBefore(MulAcc);
 
         // Generate VPReductionRecipe.

>From 52369d05b9fb5ff6bed4446f28988060dee08455 Mon Sep 17 00:00:00 2001
From: Elvis Wang <elvis.wang at sifive.com>
Date: Sun, 24 Nov 2024 23:28:17 -0800
Subject: [PATCH 16/17] Update testcase and fix reduction cost.

TTI should model the cost of moving the reduction result to a scalar
register and of applying add/sub/cmp... to the scalar value, so we don't
need to add the binOp cost in the VPlan-based cost model.
---
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp | 28 +++++++------------
 .../AArch64/sve-strict-fadd-cost.ll           | 12 ++++----
 .../LoopVectorize/vplan-printing.ll           | 12 ++++----
 3 files changed, 22 insertions(+), 30 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 39beced3a1e3ba..2e35c4e738a156 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -2202,20 +2202,16 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
   assert(ElementTy->getTypeID() == RdxDesc.getRecurrenceType()->getTypeID() &&
          "Inferred type and recurrence type mismatch.");
 
-  // BaseCost = Reduction cost + BinOp cost
-  InstructionCost BaseCost =
-      Ctx.TTI.getArithmeticInstrCost(Opcode, ElementTy, CostKind);
+  // Note that TTI should model the cost of moving result to the scalar register
+  // and the binOp cost in the getReductionCost().
   if (RecurrenceDescriptor::isMinMaxRecurrenceKind(RdxKind)) {
     Intrinsic::ID Id = getMinMaxReductionIntrinsicOp(RdxKind);
-    BaseCost += Ctx.TTI.getMinMaxReductionCost(
-        Id, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
-  } else {
-    BaseCost += Ctx.TTI.getArithmeticReductionCost(
-        Opcode, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+    return Ctx.TTI.getMinMaxReductionCost(Id, VectorTy,
+                                          RdxDesc.getFastMathFlags(), CostKind);
   }
 
-  // Default cost.
-  return BaseCost;
+  return Ctx.TTI.getArithmeticReductionCost(
+      Opcode, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
 }
 
 InstructionCost
@@ -2238,15 +2234,13 @@ VPExtendedReductionRecipe::computeCost(ElementCount VF,
   assert(ExtendedRedCost.isValid() && "VPExtendedReductionRecipe should not be "
                                       "created if the cost is invalid.");
 
-  // BaseCost = Reduction cost + BinOp cost
-  InstructionCost ReductionCost =
-      Ctx.TTI.getArithmeticInstrCost(Opcode, ElementTy, CostKind);
+  InstructionCost ReductionCost;
   if (RecurrenceDescriptor::isMinMaxRecurrenceKind(RdxKind)) {
     Intrinsic::ID Id = getMinMaxReductionIntrinsicOp(RdxKind);
-    ReductionCost += Ctx.TTI.getMinMaxReductionCost(
+    ReductionCost = Ctx.TTI.getMinMaxReductionCost(
         Id, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
   } else {
-    ReductionCost += Ctx.TTI.getArithmeticReductionCost(
+    ReductionCost = Ctx.TTI.getArithmeticReductionCost(
         Opcode, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
   }
 
@@ -2287,9 +2281,7 @@ InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
                                  "created if the cost is invalid.");
 
   // BaseCost = Reduction cost + BinOp cost
-  InstructionCost ReductionCost =
-      Ctx.TTI.getArithmeticInstrCost(Opcode, ElementTy, CostKind);
-  ReductionCost += Ctx.TTI.getArithmeticReductionCost(
+  InstructionCost ReductionCost = Ctx.TTI.getArithmeticReductionCost(
       Opcode, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
 
   // Extended cost
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/sve-strict-fadd-cost.ll b/llvm/test/Transforms/LoopVectorize/AArch64/sve-strict-fadd-cost.ll
index 0efdf077dca669..4208773e947349 100644
--- a/llvm/test/Transforms/LoopVectorize/AArch64/sve-strict-fadd-cost.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/sve-strict-fadd-cost.ll
@@ -12,14 +12,14 @@ target triple="aarch64-unknown-linux-gnu"
 
 ; CHECK-VSCALE2-LABEL: LV: Checking a loop in 'fadd_strict32'
 ; CHECK-VSCALE2: Cost of 4 for VF vscale x 2:
-; CHECK-VSCALE2:  in-loop reduction   %add = fadd float %0, %sum.07
+; CHECK-VSCALE2:  REDUCE ir<%add> = ir<%sum.07> + reduce.fadd (ir<%0>)
 ; CHECK-VSCALE2: Cost of 8 for VF vscale x 4:
-; CHECK-VSCALE2:  in-loop reduction   %add = fadd float %0, %sum.07
+; CHECK-VSCALE2:  REDUCE ir<%add> = ir<%sum.07> + reduce.fadd (ir<%0>)
 ; CHECK-VSCALE1-LABEL: LV: Checking a loop in 'fadd_strict32'
 ; CHECK-VSCALE1: Cost of 2 for VF vscale x 2:
-; CHECK-VSCALE1:  in-loop reduction   %add = fadd float %0, %sum.07
+; CHECK-VSCALE1:  REDUCE ir<%add> = ir<%sum.07> + reduce.fadd (ir<%0>)
 ; CHECK-VSCALE1: Cost of 4 for VF vscale x 4:
-; CHECK-VSCALE1:  in-loop reduction   %add = fadd float %0, %sum.07
+; CHECK-VSCALE1:  REDUCE ir<%add> = ir<%sum.07> + reduce.fadd (ir<%0>)
 
 define float @fadd_strict32(ptr noalias nocapture readonly %a, i64 %n) #0 {
 entry:
@@ -42,10 +42,10 @@ for.end:
 
 ; CHECK-VSCALE2-LABEL: LV: Checking a loop in 'fadd_strict64'
 ; CHECK-VSCALE2: Cost of 4 for VF vscale x 2:
-; CHECK-VSCALE2:  in-loop reduction   %add = fadd double %0, %sum.07
+; CHECK-VSCALE2:  REDUCE ir<%add> = ir<%sum.07> + reduce.fadd (ir<%0>)
 ; CHECK-VSCALE1-LABEL: LV: Checking a loop in 'fadd_strict64'
 ; CHECK-VSCALE1: Cost of 2 for VF vscale x 2:
-; CHECK-VSCALE1:  in-loop reduction   %add = fadd double %0, %sum.07
+; CHECK-VSCALE1:  REDUCE ir<%add> = ir<%sum.07> + reduce.fadd (ir<%0>)
 
 define double @fadd_strict64(ptr noalias nocapture readonly %a, i64 %n) #0 {
 entry:
diff --git a/llvm/test/Transforms/LoopVectorize/vplan-printing.ll b/llvm/test/Transforms/LoopVectorize/vplan-printing.ll
index 95ce514d8a92ff..a80723bd35b52a 100644
--- a/llvm/test/Transforms/LoopVectorize/vplan-printing.ll
+++ b/llvm/test/Transforms/LoopVectorize/vplan-printing.ll
@@ -1176,7 +1176,7 @@ define i64 @print_extended_reduction(ptr nocapture readonly %x, ptr nocapture re
 ; CHECK-NEXT: Successor(s): ir-bb<for.cond.cleanup.loopexit>, scalar.ph
 ; CHECK-EMPTY:
 ; CHECK-NEXT: ir-bb<for.cond.cleanup.loopexit>:
-; CHECK-NEXT:   IR   %add.lcssa = phi i64 [ %add, %for.body ] (extra operand: vp<%7>)
+; CHECK-NEXT:   IR   %add.lcssa = phi i64 [ %add, %for.body ] (extra operand: vp<%7> from middle.block)
 ; CHECK-NEXT: No successors
 ; CHECK-EMPTY:
 ; CHECK-NEXT: scalar.ph:
@@ -1185,7 +1185,7 @@ define i64 @print_extended_reduction(ptr nocapture readonly %x, ptr nocapture re
 ; CHECK-EMPTY:
 ; CHECK-NEXT: ir-bb<for.body>:
 ; CHECK-NEXT:   IR   %i.010 = phi i32 [ %inc, %for.body ], [ 0, %for.body.preheader ]
-; CHECK-NEXT:   IR   %r.09 = phi i64 [ %add, %for.body ], [ 0, %for.body.preheader ] (extra operand: vp<%bc.merge.rdx>)
+; CHECK-NEXT:   IR   %r.09 = phi i64 [ %add, %for.body ], [ 0, %for.body.preheader ] (extra operand: vp<%bc.merge.rdx> from scalar.ph)
 ; CHECK-NEXT:   IR   %arrayidx = getelementptr inbounds i32, ptr %x, i32 %i.010
 ; CHECK-NEXT:   IR   %load0 = load i32, ptr %arrayidx, align 4
 ; CHECK-NEXT:   IR   %conv0 = zext i32 %load0 to i64
@@ -1251,7 +1251,7 @@ define i64 @print_mulacc(ptr nocapture readonly %x, ptr nocapture readonly %y, i
 ; CHECK-NEXT: Successor(s): ir-bb<for.cond.cleanup.loopexit>, scalar.ph
 ; CHECK-EMPTY:
 ; CHECK-NEXT: ir-bb<for.cond.cleanup.loopexit>:
-; CHECK-NEXT:   IR   %add.lcssa = phi i64 [ %add, %for.body ] (extra operand: vp<%8>)
+; CHECK-NEXT:   IR   %add.lcssa = phi i64 [ %add, %for.body ] (extra operand: vp<%8> from middle.block)
 ; CHECK-NEXT: No successors
 ; CHECK-EMPTY:
 ; CHECK-NEXT: scalar.ph:
@@ -1260,7 +1260,7 @@ define i64 @print_mulacc(ptr nocapture readonly %x, ptr nocapture readonly %y, i
 ; CHECK-EMPTY:
 ; CHECK-NEXT: ir-bb<for.body>:
 ; CHECK-NEXT:   IR   %i.010 = phi i32 [ %inc, %for.body ], [ 0, %for.body.preheader ]
-; CHECK-NEXT:   IR   %r.09 = phi i64 [ %add, %for.body ], [ 0, %for.body.preheader ] (extra operand: vp<%bc.merge.rdx>)
+; CHECK-NEXT:   IR   %r.09 = phi i64 [ %add, %for.body ], [ 0, %for.body.preheader ] (extra operand: vp<%bc.merge.rdx> from scalar.ph)
 ; CHECK-NEXT:   IR   %arrayidx = getelementptr inbounds i64, ptr %x, i32 %i.010
 ; CHECK-NEXT:   IR   %load0 = load i64, ptr %arrayidx, align 4
 ; CHECK-NEXT:   IR   %arrayidx1 = getelementptr inbounds i64, ptr %y, i32 %i.010
@@ -1330,7 +1330,7 @@ define i64 @print_mulacc_extended(ptr nocapture readonly %x, ptr nocapture reado
 ; CHECK-NEXT: Successor(s): ir-bb<for.cond.cleanup.loopexit>, scalar.ph
 ; CHECK-EMPTY:
 ; CHECK-NEXT: ir-bb<for.cond.cleanup.loopexit>:
-; CHECK-NEXT:   IR   %add.lcssa = phi i64 [ %add, %for.body ] (extra operand: vp<%8>)
+; CHECK-NEXT:   IR   %add.lcssa = phi i64 [ %add, %for.body ] (extra operand: vp<%8> from middle.block)
 ; CHECK-NEXT: No successors
 ; CHECK-EMPTY:
 ; CHECK-NEXT: scalar.ph:
@@ -1339,7 +1339,7 @@ define i64 @print_mulacc_extended(ptr nocapture readonly %x, ptr nocapture reado
 ; CHECK-EMPTY:
 ; CHECK-NEXT: ir-bb<for.body>:
 ; CHECK-NEXT:   IR   %i.010 = phi i32 [ %inc, %for.body ], [ 0, %for.body.preheader ]
-; CHECK-NEXT:   IR   %r.09 = phi i64 [ %add, %for.body ], [ 0, %for.body.preheader ] (extra operand: vp<%bc.merge.rdx>)
+; CHECK-NEXT:   IR   %r.09 = phi i64 [ %add, %for.body ], [ 0, %for.body.preheader ] (extra operand: vp<%bc.merge.rdx> from scalar.ph)
 ; CHECK-NEXT:   IR   %arrayidx = getelementptr inbounds i16, ptr %x, i32 %i.010
 ; CHECK-NEXT:   IR   %load0 = load i16, ptr %arrayidx, align 4
 ; CHECK-NEXT:   IR   %arrayidx1 = getelementptr inbounds i16, ptr %y, i32 %i.010

>From abc08f39dedbcd31dd0689380c828b84087e36e8 Mon Sep 17 00:00:00 2001
From: Elvis Wang <elvis.wang at sifive.com>
Date: Wed, 4 Dec 2024 16:52:14 -0800
Subject: [PATCH 17/17] !fixup. Rebase to upstream `prepareToExecute()`
 implementation.

---
 .../Transforms/Vectorize/LoopVectorize.cpp    |   2 -
 .../Transforms/Vectorize/VPlanTransforms.cpp  | 180 +++++++++---------
 2 files changed, 89 insertions(+), 93 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index fbb39c659c9175..02b3f1001fd276 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7666,8 +7666,6 @@ DenseMap<const SCEV *, Value *> LoopVectorizationPlanner::executePlan(
                              CanonicalIVStartValue, State);
   VPlanTransforms::prepareToExecute(BestVPlan);
 
-  // TODO: Replace with upstream implementation.
-  VPlanTransforms::prepareExecute(BestVPlan);
   BestVPlan.execute(&State);
 
   // 2.5 Collect reduction resume values.
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index 733c144c7aa359..d4e84ecc46818d 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -519,96 +519,6 @@ void VPlanTransforms::removeDeadRecipes(VPlan &Plan) {
   }
 }
 
-void VPlanTransforms::prepareExecute(VPlan &Plan) {
-  ReversePostOrderTraversal<VPBlockDeepTraversalWrapper<VPBlockBase *>> RPOT(
-      Plan.getVectorLoopRegion());
-  for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>(
-           vp_depth_first_deep(Plan.getEntry()))) {
-    for (VPRecipeBase &R : make_early_inc_range(*VPBB)) {
-      if (auto *ExtRed = dyn_cast<VPExtendedReductionRecipe>(&R)) {
-        // Genearte VPWidenCastRecipe.
-        VPWidenCastRecipe *Ext;
-        // Only ZExt contiains non-neg flags.
-        if (ExtRed->isZExt())
-          Ext = new VPWidenCastRecipe(
-              ExtRed->getExtOpcode(), ExtRed->getVecOp(),
-              ExtRed->getResultType(), ExtRed->getNonNegFlags(),
-              ExtRed->getExtDebugLoc());
-        else
-          Ext = new VPWidenCastRecipe(
-              ExtRed->getExtOpcode(), ExtRed->getVecOp(),
-              ExtRed->getResultType(), ExtRed->getExtDebugLoc());
-
-        // Generate VPreductionRecipe.
-        auto *Red = new VPReductionRecipe(
-            ExtRed->getRecurrenceDescriptor(), ExtRed->getUnderlyingInstr(),
-            ExtRed->getChainOp(), Ext, ExtRed->getCondOp(),
-            ExtRed->isOrdered());
-        Ext->insertBefore(ExtRed);
-        Red->insertBefore(ExtRed);
-        ExtRed->replaceAllUsesWith(Red);
-        ExtRed->eraseFromParent();
-      } else if (isa<VPMulAccRecipe>(&R)) {
-        auto *MulAcc = cast<VPMulAccRecipe>(&R);
-        Type *RedTy = MulAcc->getRecurrenceDescriptor().getRecurrenceType();
-
-        // Generate inner VPWidenCastRecipes if necessary.
-        // Note that we will drop the extend after mul which transform
-        // reduce.add(ext(mul(ext, ext))) to reduce.add(mul(ext, ext)).
-        VPValue *Op0, *Op1;
-        if (MulAcc->isExtended()) {
-          if (MulAcc->isZExt())
-            Op0 = new VPWidenCastRecipe(
-                MulAcc->getExtOpcode(), MulAcc->getVecOp0(), RedTy,
-                MulAcc->getNonNegFlags(), MulAcc->getExt0DebugLoc());
-          else
-            Op0 = new VPWidenCastRecipe(MulAcc->getExtOpcode(),
-                                        MulAcc->getVecOp0(), RedTy,
-                                        MulAcc->getExt0DebugLoc());
-          Op0->getDefiningRecipe()->insertBefore(MulAcc);
-          // Prevent reduce.add(mul(ext(A), ext(A))) generate duplicate
-          // VPWidenCastRecipe.
-          if (MulAcc->getVecOp0() == MulAcc->getVecOp1()) {
-            Op1 = Op0;
-          } else {
-            if (MulAcc->isZExt())
-              Op1 = new VPWidenCastRecipe(
-                  MulAcc->getExtOpcode(), MulAcc->getVecOp1(), RedTy,
-                  MulAcc->getNonNegFlags(), MulAcc->getExt1DebugLoc());
-            else
-              Op1 = new VPWidenCastRecipe(MulAcc->getExtOpcode(),
-                                          MulAcc->getVecOp1(), RedTy,
-                                          MulAcc->getExt1DebugLoc());
-            Op1->getDefiningRecipe()->insertBefore(MulAcc);
-          }
-          // No extends in this MulAccRecipe.
-        } else {
-          Op0 = MulAcc->getVecOp0();
-          Op1 = MulAcc->getVecOp1();
-        }
-
-        // Generate VPWidenRecipe.
-        SmallVector<VPValue *, 2> MulOps = {Op0, Op1};
-        auto *Mul = new VPWidenRecipe(
-            Instruction::Mul, make_range(MulOps.begin(), MulOps.end()),
-            MulAcc->hasNoUnsignedWrap(), MulAcc->hasNoSignedWrap(),
-            MulAcc->getMulDebugLoc());
-        Mul->insertBefore(MulAcc);
-
-        // Generate VPReductionRecipe.
-        auto *Red = new VPReductionRecipe(
-            MulAcc->getRecurrenceDescriptor(), MulAcc->getUnderlyingInstr(),
-            MulAcc->getChainOp(), Mul, MulAcc->getCondOp(),
-            MulAcc->isOrdered());
-        Red->insertBefore(MulAcc);
-
-        MulAcc->replaceAllUsesWith(Red);
-        MulAcc->eraseFromParent();
-      }
-    }
-  }
-}
-
 static VPScalarIVStepsRecipe *
 createScalarIVSteps(VPlan &Plan, InductionDescriptor::InductionKind Kind,
                     Instruction::BinaryOps InductionOpcode,
@@ -1910,12 +1820,100 @@ void VPlanTransforms::createInterleaveGroups(
   }
 }
 
+// Expand VPExtendedReductionRecipe to VPWidenCastRecipe + VPReductionRecipe.
+static void expandVPExtendedReduction(VPExtendedReductionRecipe *ExtRed) {
+  // Generate VPWidenCastRecipe.
+  VPWidenCastRecipe *Ext;
+  // Only ZExt contains non-neg flags.
+  if (ExtRed->isZExt())
+    Ext = new VPWidenCastRecipe(
+        ExtRed->getExtOpcode(), ExtRed->getVecOp(), ExtRed->getResultType(),
+        ExtRed->getNonNegFlags(), ExtRed->getExtDebugLoc());
+  else
+    Ext = new VPWidenCastRecipe(ExtRed->getExtOpcode(), ExtRed->getVecOp(),
+                                ExtRed->getResultType(),
+                                ExtRed->getExtDebugLoc());
+
+  // Generate VPReductionRecipe.
+  auto *Red = new VPReductionRecipe(
+      ExtRed->getRecurrenceDescriptor(), ExtRed->getUnderlyingInstr(),
+      ExtRed->getChainOp(), Ext, ExtRed->getCondOp(), ExtRed->isOrdered());
+  Ext->insertBefore(ExtRed);
+  Red->insertBefore(ExtRed);
+  ExtRed->replaceAllUsesWith(Red);
+  ExtRed->eraseFromParent();
+}
+
+// Expand VPMulAccRecipe to VPWidenRecipe (mul) + VPReductionRecipe (reduce.add)
+// + VPWidenCastRecipe (optional).
+static void expandVPMulAcc(VPMulAccRecipe *MulAcc) {
+  Type *RedTy = MulAcc->getRecurrenceDescriptor().getRecurrenceType();
+
+  // Generate inner VPWidenCastRecipes if necessary.
+  // Note that we will drop the extend after the mul, which transforms
+  // reduce.add(ext(mul(ext, ext))) to reduce.add(mul(ext, ext)).
+  VPValue *Op0, *Op1;
+  if (MulAcc->isExtended()) {
+    if (MulAcc->isZExt())
+      Op0 = new VPWidenCastRecipe(MulAcc->getExtOpcode(), MulAcc->getVecOp0(),
+                                  RedTy, MulAcc->getNonNegFlags(),
+                                  MulAcc->getExt0DebugLoc());
+    else
+      Op0 = new VPWidenCastRecipe(MulAcc->getExtOpcode(), MulAcc->getVecOp0(),
+                                  RedTy, MulAcc->getExt0DebugLoc());
+    Op0->getDefiningRecipe()->insertBefore(MulAcc);
+    // Prevent reduce.add(mul(ext(A), ext(A))) from generating a duplicate
+    // VPWidenCastRecipe.
+    if (MulAcc->getVecOp0() == MulAcc->getVecOp1()) {
+      Op1 = Op0;
+    } else {
+      if (MulAcc->isZExt())
+        Op1 = new VPWidenCastRecipe(MulAcc->getExtOpcode(), MulAcc->getVecOp1(),
+                                    RedTy, MulAcc->getNonNegFlags(),
+                                    MulAcc->getExt1DebugLoc());
+      else
+        Op1 = new VPWidenCastRecipe(MulAcc->getExtOpcode(), MulAcc->getVecOp1(),
+                                    RedTy, MulAcc->getExt1DebugLoc());
+      Op1->getDefiningRecipe()->insertBefore(MulAcc);
+    }
+    // No extends in this MulAccRecipe.
+  } else {
+    Op0 = MulAcc->getVecOp0();
+    Op1 = MulAcc->getVecOp1();
+  }
+
+  // Generate VPWidenRecipe.
+  SmallVector<VPValue *, 2> MulOps = {Op0, Op1};
+  auto *Mul = new VPWidenRecipe(
+      Instruction::Mul, make_range(MulOps.begin(), MulOps.end()),
+      MulAcc->hasNoUnsignedWrap(), MulAcc->hasNoSignedWrap(),
+      MulAcc->getMulDebugLoc());
+  Mul->insertBefore(MulAcc);
+
+  // Generate VPReductionRecipe.
+  auto *Red = new VPReductionRecipe(
+      MulAcc->getRecurrenceDescriptor(), MulAcc->getUnderlyingInstr(),
+      MulAcc->getChainOp(), Mul, MulAcc->getCondOp(), MulAcc->isOrdered());
+  Red->insertBefore(MulAcc);
+
+  MulAcc->replaceAllUsesWith(Red);
+  MulAcc->eraseFromParent();
+}
+
 void VPlanTransforms::prepareToExecute(VPlan &Plan) {
   ReversePostOrderTraversal<VPBlockDeepTraversalWrapper<VPBlockBase *>> RPOT(
       Plan.getVectorLoopRegion());
   for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>(
            vp_depth_first_deep(Plan.getEntry()))) {
-    for (VPRecipeBase &R : make_early_inc_range(VPBB->phis())) {
+    for (VPRecipeBase &R : make_early_inc_range(*VPBB)) {
+      if (auto *ExtRed = dyn_cast<VPExtendedReductionRecipe>(&R)) {
+        expandVPExtendedReduction(ExtRed);
+        continue;
+      }
+      if (auto *MulAcc = dyn_cast<VPMulAccRecipe>(&R)) {
+        expandVPMulAcc(MulAcc);
+        continue;
+      }
       if (!isa<VPCanonicalIVPHIRecipe, VPEVLBasedIVPHIRecipe>(&R))
         continue;
       auto *PhiR = cast<VPHeaderPHIRecipe>(&R);



More information about the llvm-commits mailing list