[llvm] [LoopVectorizer] Add support for partial reductions (PR #92418)

Nicholas Guy via llvm-commits llvm-commits at lists.llvm.org
Thu Jul 4 08:10:22 PDT 2024


https://github.com/NickGuy-Arm updated https://github.com/llvm/llvm-project/pull/92418

From 0bf0a41250186d78b8f111b2ddc0f1ef9de8d1f8 Mon Sep 17 00:00:00 2001
From: Nicholas Guy <nicholas.guy at arm.com>
Date: Fri, 17 May 2024 11:15:11 +0100
Subject: [PATCH 01/10] [NFC] Test pre-commit

---
 .../CodeGen/AArch64/partial-reduce-sdot.ll    | 99 +++++++++++++++++++
 1 file changed, 99 insertions(+)
 create mode 100644 llvm/test/CodeGen/AArch64/partial-reduce-sdot.ll

diff --git a/llvm/test/CodeGen/AArch64/partial-reduce-sdot.ll b/llvm/test/CodeGen/AArch64/partial-reduce-sdot.ll
new file mode 100644
index 0000000000000..fc6e3239a1b43
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/partial-reduce-sdot.ll
@@ -0,0 +1,99 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
+; RUN: opt -passes=loop-vectorize -force-vector-interleave=1 -S < %s | FileCheck %s
+
+target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128"
+target triple = "aarch64-none-unknown-elf"
+
+define void @dotp(ptr %a, ptr %b) #0 {
+; CHECK-LABEL: define void @dotp(
+; CHECK-SAME: ptr [[A:%.*]], ptr [[B:%.*]]) #[[ATTR0:[0-9]+]] {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP1:%.*]] = mul i64 [[TMP0]], 16
+; CHECK-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 0, [[TMP1]]
+; CHECK-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
+; CHECK:       vector.ph:
+; CHECK-NEXT:    [[TMP2:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP3:%.*]] = mul i64 [[TMP2]], 16
+; CHECK-NEXT:    [[N_MOD_VF:%.*]] = urem i64 0, [[TMP3]]
+; CHECK-NEXT:    [[N_VEC:%.*]] = sub i64 0, [[N_MOD_VF]]
+; CHECK-NEXT:    [[TMP4:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP5:%.*]] = mul i64 [[TMP4]], 16
+; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
+; CHECK:       vector.body:
+; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi <vscale x 16 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP14:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[TMP11:%.*]] = add i64 [[INDEX]], 0
+; CHECK-NEXT:    [[TMP13:%.*]] = getelementptr i8, ptr [[A]], i64 [[TMP11]]
+; CHECK-NEXT:    [[TMP17:%.*]] = getelementptr i8, ptr [[TMP13]], i32 0
+; CHECK-NEXT:    [[WIDE_LOAD2:%.*]] = load <vscale x 16 x i8>, ptr [[TMP17]], align 1
+; CHECK-NEXT:    [[TMP19:%.*]] = zext <vscale x 16 x i8> [[WIDE_LOAD2]] to <vscale x 16 x i32>
+; CHECK-NEXT:    [[TMP21:%.*]] = getelementptr i8, ptr [[B]], i64 [[TMP11]]
+; CHECK-NEXT:    [[TMP25:%.*]] = getelementptr i8, ptr [[TMP21]], i32 0
+; CHECK-NEXT:    [[WIDE_LOAD4:%.*]] = load <vscale x 16 x i8>, ptr [[TMP25]], align 1
+; CHECK-NEXT:    [[TMP27:%.*]] = zext <vscale x 16 x i8> [[WIDE_LOAD4]] to <vscale x 16 x i32>
+; CHECK-NEXT:    [[TMP29:%.*]] = mul <vscale x 16 x i32> [[TMP27]], [[TMP19]]
+; CHECK-NEXT:    [[TMP14]] = add <vscale x 16 x i32> [[TMP29]], [[VEC_PHI]]
+; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP5]]
+; CHECK-NEXT:    [[TMP32:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
+; CHECK-NEXT:    br i1 [[TMP32]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
+; CHECK:       middle.block:
+; CHECK-NEXT:    [[TMP33:%.*]] = call i32 @llvm.vector.reduce.add.nxv16i32(<vscale x 16 x i32> [[TMP14]])
+; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i64 0, [[N_VEC]]
+; CHECK-NEXT:    br i1 [[CMP_N]], label [[FOR_COND_CLEANUP_LOOPEXIT:%.*]], label [[SCALAR_PH]]
+; CHECK:       scalar.ph:
+; CHECK-NEXT:    [[BC_RESUME_VAL:%.*]] = phi i64 [ [[N_VEC]], [[MIDDLE_BLOCK]] ], [ 0, [[ENTRY:%.*]] ]
+; CHECK-NEXT:    [[BC_MERGE_RDX:%.*]] = phi i32 [ 0, [[ENTRY]] ], [ [[TMP33]], [[MIDDLE_BLOCK]] ]
+; CHECK-NEXT:    br label [[FOR_BODY:%.*]]
+; CHECK:       for.cond.cleanup.loopexit:
+; CHECK-NEXT:    [[ADD_LCSSA:%.*]] = phi i32 [ [[ADD:%.*]], [[FOR_BODY]] ], [ [[TMP33]], [[MIDDLE_BLOCK]] ]
+; CHECK-NEXT:    [[TMP20:%.*]] = lshr i32 [[ADD_LCSSA]], 0
+; CHECK-NEXT:    ret void
+; CHECK:       for.body:
+; CHECK-NEXT:    [[INDVARS_IV:%.*]] = phi i64 [ [[BC_RESUME_VAL]], [[SCALAR_PH]] ], [ [[INDVARS_IV_NEXT:%.*]], [[FOR_BODY]] ]
+; CHECK-NEXT:    [[ACC_010:%.*]] = phi i32 [ [[BC_MERGE_RDX]], [[SCALAR_PH]] ], [ [[ADD]], [[FOR_BODY]] ]
+; CHECK-NEXT:    [[ARRAYIDX:%.*]] = getelementptr i8, ptr [[A]], i64 [[INDVARS_IV]]
+; CHECK-NEXT:    [[TMP18:%.*]] = load i8, ptr [[ARRAYIDX]], align 1
+; CHECK-NEXT:    [[CONV:%.*]] = zext i8 [[TMP18]] to i32
+; CHECK-NEXT:    [[ARRAYIDX2:%.*]] = getelementptr i8, ptr [[B]], i64 [[INDVARS_IV]]
+; CHECK-NEXT:    [[TMP22:%.*]] = load i8, ptr [[ARRAYIDX2]], align 1
+; CHECK-NEXT:    [[CONV3:%.*]] = zext i8 [[TMP22]] to i32
+; CHECK-NEXT:    [[MUL:%.*]] = mul i32 [[CONV3]], [[CONV]]
+; CHECK-NEXT:    [[ADD]] = add i32 [[MUL]], [[ACC_010]]
+; CHECK-NEXT:    [[INDVARS_IV_NEXT]] = add i64 [[INDVARS_IV]], 1
+; CHECK-NEXT:    [[EXITCOND_NOT:%.*]] = icmp eq i64 [[INDVARS_IV_NEXT]], 0
+; CHECK-NEXT:    br i1 [[EXITCOND_NOT]], label [[FOR_COND_CLEANUP_LOOPEXIT]], label [[FOR_BODY]], !llvm.loop [[LOOP3:![0-9]+]]
+;
+entry:
+  br label %for.body
+
+for.cond.cleanup.loopexit:                        ; preds = %for.body
+  %0 = lshr i32 %add, 0
+  ret void
+
+for.body:                                         ; preds = %for.body, %entry
+  %indvars.iv = phi i64 [ 0, %entry ], [ %indvars.iv.next, %for.body ]
+  %acc.010 = phi i32 [ 0, %entry ], [ %add, %for.body ]
+  %arrayidx = getelementptr i8, ptr %a, i64 %indvars.iv
+  %1 = load i8, ptr %arrayidx, align 1
+  %conv = zext i8 %1 to i32
+  %arrayidx2 = getelementptr i8, ptr %b, i64 %indvars.iv
+  %2 = load i8, ptr %arrayidx2, align 1
+  %conv3 = zext i8 %2 to i32
+  %mul = mul i32 %conv3, %conv
+  %add = add i32 %mul, %acc.010
+  %indvars.iv.next = add i64 %indvars.iv, 1
+  %exitcond.not = icmp eq i64 %indvars.iv.next, 0
+  br i1 %exitcond.not, label %for.cond.cleanup.loopexit, label %for.body
+
+; uselistorder directives
+  uselistorder i32 %add, { 1, 0 }
+}
+
+attributes #0 = { "target-features"="+fp-armv8,+fullfp16,+neon,+sve,+sve2,+v8a" }
+;.
+; CHECK: [[LOOP0]] = distinct !{[[LOOP0]], [[META1:![0-9]+]], [[META2:![0-9]+]]}
+; CHECK: [[META1]] = !{!"llvm.loop.isvectorized", i32 1}
+; CHECK: [[META2]] = !{!"llvm.loop.unroll.runtime.disable"}
+; CHECK: [[LOOP3]] = distinct !{[[LOOP3]], [[META2]], [[META1]]}
+;.
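
For reference, the IR in this test is the shape produced by a plain
dot-product loop over i8 inputs widened to i32. A source-level sketch of
that kernel (an illustration only, not part of the patch; the test itself
returns void and feeds the sum into a dead lshr):

  int dotp(const unsigned char *a, const unsigned char *b, long n) {
    int acc = 0;
    for (long i = 0; i < n; ++i)
      acc += (int)a[i] * (int)b[i]; // load+zext, load+zext, mul, add
    return acc;
  }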

From 6103cb3218c3fb4f634e40f34fddafeabf949cfd Mon Sep 17 00:00:00 2001
From: Nicholas Guy <nicholas.guy at arm.com>
Date: Fri, 17 May 2024 11:17:26 +0100
Subject: [PATCH 02/10] [LoopVectorizer] Add support for partial reductions
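
A partial reduction accumulates a wide vector of products into a narrower
vector of partial sums instead of reducing to a scalar on every iteration;
the lanes are only combined into a scalar once, after the loop. This is
the shape that instructions such as AArch64 udot/sdot implement. A minimal
scalar model of the operation (the lane grouping here is an assumption made
for illustration; the intrinsic leaves it unspecified, which is legal
because integer addition is associative and commutative):

  #include <array>
  #include <cstdint>

  // Fold 16 products into 4 accumulator lanes (one possible grouping).
  std::array<int32_t, 4> partialReduceAdd(std::array<int32_t, 4> Acc,
                                          const std::array<int32_t, 16> &V) {
    for (unsigned I = 0; I < 16; ++I)
      Acc[I / 4] += V[I];
    return Acc;
  }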

---
 llvm/include/llvm/IR/DerivedTypes.h           |  10 ++
 llvm/include/llvm/IR/Intrinsics.h             |   5 +-
 llvm/include/llvm/IR/Intrinsics.td            |  10 ++
 llvm/lib/IR/Function.cpp                      |  16 +++
 .../Transforms/Vectorize/LoopVectorize.cpp    | 122 ++++++++++++++++++
 .../Transforms/Vectorize/VPRecipeBuilder.h    |   2 +
 llvm/lib/Transforms/Vectorize/VPlan.h         |  43 +++++-
 .../Transforms/Vectorize/VPlanAnalysis.cpp    |   6 +-
 llvm/lib/Transforms/Vectorize/VPlanAnalysis.h |   2 +
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp |  80 +++++++++++-
 llvm/lib/Transforms/Vectorize/VPlanValue.h    |   1 +
 .../CodeGen/AArch64/partial-reduce-sdot.ll    |   7 +-
 12 files changed, 291 insertions(+), 13 deletions(-)

diff --git a/llvm/include/llvm/IR/DerivedTypes.h b/llvm/include/llvm/IR/DerivedTypes.h
index 443fb7de3b821..866a01c9afebd 100644
--- a/llvm/include/llvm/IR/DerivedTypes.h
+++ b/llvm/include/llvm/IR/DerivedTypes.h
@@ -512,6 +512,16 @@ class VectorType : public Type {
                            EltCnt.divideCoefficientBy(2));
   }
 
+  /// This static method returns a VectorType with a quarter as many elements
+  /// as the input type and the same element type.
+  static VectorType *getQuarterElementsVectorType(VectorType *VTy) {
+    auto EltCnt = VTy->getElementCount();
+    assert(EltCnt.isKnownMultipleOf(4) &&
+           "Cannot quarter a vector whose element count is not a multiple of 4.");
+    return VectorType::get(VTy->getElementType(),
+                           EltCnt.divideCoefficientBy(4));
+  }
+
   /// This static method returns a VectorType with twice as many elements as the
   /// input type and the same element type.
   static VectorType *getDoubleElementsVectorType(VectorType *VTy) {
diff --git a/llvm/include/llvm/IR/Intrinsics.h b/llvm/include/llvm/IR/Intrinsics.h
index f79df522dc805..d0c4bee59e889 100644
--- a/llvm/include/llvm/IR/Intrinsics.h
+++ b/llvm/include/llvm/IR/Intrinsics.h
@@ -135,6 +135,7 @@ namespace Intrinsic {
       ExtendArgument,
       TruncArgument,
       HalfVecArgument,
+      QuarterVecArgument,
       SameVecWidthArgument,
       VecOfAnyPtrsToElt,
       VecElementArgument,
@@ -164,7 +165,7 @@ namespace Intrinsic {
 
     unsigned getArgumentNumber() const {
       assert(Kind == Argument || Kind == ExtendArgument ||
-             Kind == TruncArgument || Kind == HalfVecArgument ||
+             Kind == TruncArgument || Kind == HalfVecArgument || Kind == QuarterVecArgument ||
              Kind == SameVecWidthArgument || Kind == VecElementArgument ||
              Kind == Subdivide2Argument || Kind == Subdivide4Argument ||
              Kind == VecOfBitcastsToInt);
@@ -172,7 +173,7 @@ namespace Intrinsic {
     }
     ArgKind getArgumentKind() const {
       assert(Kind == Argument || Kind == ExtendArgument ||
-             Kind == TruncArgument || Kind == HalfVecArgument ||
+             Kind == TruncArgument || Kind == HalfVecArgument || Kind == QuarterVecArgument ||
              Kind == SameVecWidthArgument ||
              Kind == VecElementArgument || Kind == Subdivide2Argument ||
              Kind == Subdivide4Argument || Kind == VecOfBitcastsToInt);
diff --git a/llvm/include/llvm/IR/Intrinsics.td b/llvm/include/llvm/IR/Intrinsics.td
index 95dbd2854322d..98c5221ade2c8 100644
--- a/llvm/include/llvm/IR/Intrinsics.td
+++ b/llvm/include/llvm/IR/Intrinsics.td
@@ -321,6 +321,7 @@ def IIT_I4 : IIT_Int<4, 58>;
 def IIT_AARCH64_SVCOUNT : IIT_VT<aarch64svcount, 59>;
 def IIT_V6 : IIT_Vec<6, 60>;
 def IIT_V10 : IIT_Vec<10, 61>;
+def IIT_QUARTER_VEC_ARG : IIT_Base<62>;
 }
 
 defvar IIT_all_FixedTypes = !filter(iit, IIT_all,
@@ -457,6 +458,9 @@ class LLVMVectorElementType<int num> : LLVMMatchType<num, IIT_VEC_ELEMENT>;
 class LLVMHalfElementsVectorType<int num>
   : LLVMMatchType<num, IIT_HALF_VEC_ARG>;
 
+class LLVMQuarterElementsVectorType<int num>
+  : LLVMMatchType<num, IIT_QUARTER_VEC_ARG>;
+
 // Match the type of another intrinsic parameter that is expected to be a
 // vector type (i.e. <N x iM>) but with each element subdivided to
 // form a vector with more elements that are smaller than the original.
@@ -2646,6 +2650,12 @@ def int_experimental_vector_partial_reduce_add : DefaultAttrsIntrinsic<[LLVMMatc
                                                                        [llvm_anyvector_ty, llvm_anyvector_ty],
                                                                        [IntrNoMem]>;
 
+//===-------------- Intrinsics to perform partial reduction ---------------===//
+
+def int_experimental_vector_partial_reduce_add : DefaultAttrsIntrinsic<[LLVMQuarterElementsVectorType<0>],
+                                                                       [llvm_anyvector_ty],
+                                                                       [IntrNoMem]>;
+
 //===----------------- Pointer Authentication Intrinsics ------------------===//
 //
 
diff --git a/llvm/lib/IR/Function.cpp b/llvm/lib/IR/Function.cpp
index 5fb348a8bbcd4..a1ce95c2c92bb 100644
--- a/llvm/lib/IR/Function.cpp
+++ b/llvm/lib/IR/Function.cpp
@@ -1259,6 +1259,12 @@ static void DecodeIITType(unsigned &NextElt, ArrayRef<unsigned char> Infos,
                                              ArgInfo));
     return;
   }
+  case IIT_QUARTER_VEC_ARG: {
+    unsigned ArgInfo = (NextElt == Infos.size() ? 0 : Infos[NextElt++]);
+    OutputTable.push_back(IITDescriptor::get(IITDescriptor::QuarterVecArgument,
+                                             ArgInfo));
+    return;
+  }
   case IIT_SAME_VEC_WIDTH_ARG: {
     unsigned ArgInfo = (NextElt == Infos.size() ? 0 : Infos[NextElt++]);
     OutputTable.push_back(IITDescriptor::get(IITDescriptor::SameVecWidthArgument,
@@ -1423,6 +1429,9 @@ static Type *DecodeFixedType(ArrayRef<Intrinsic::IITDescriptor> &Infos,
   case IITDescriptor::HalfVecArgument:
     return VectorType::getHalfElementsVectorType(cast<VectorType>(
                                                   Tys[D.getArgumentNumber()]));
+  case IITDescriptor::QuarterVecArgument: {
+    return VectorType::getQuarterElementsVectorType(cast<VectorType>(Tys[D.getArgumentNumber()]));
+  }
   case IITDescriptor::SameVecWidthArgument: {
     Type *EltTy = DecodeFixedType(Infos, Tys, Context);
     Type *Ty = Tys[D.getArgumentNumber()];
@@ -1650,6 +1659,13 @@ static bool matchIntrinsicType(
       return !isa<VectorType>(ArgTys[D.getArgumentNumber()]) ||
              VectorType::getHalfElementsVectorType(
                      cast<VectorType>(ArgTys[D.getArgumentNumber()])) != Ty;
+    case IITDescriptor::QuarterVecArgument: {
+      if (D.getArgumentNumber() >= ArgTys.size())
+        return IsDeferredCheck || DeferCheck(Ty);
+      return !isa<VectorType>(ArgTys[D.getArgumentNumber()]) ||
+             VectorType::getQuarterElementsVectorType(
+                     cast<VectorType>(ArgTys[D.getArgumentNumber()])) != Ty;
+    }
     case IITDescriptor::SameVecWidthArgument: {
       if (D.getArgumentNumber() >= ArgTys.size()) {
         // Defer check and subsequent check for the vector element type.
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index d7b0240fd8a81..a184cb3ea0fb4 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -2191,6 +2191,92 @@ static bool useActiveLaneMaskForControlFlow(TailFoldingStyle Style) {
          Style == TailFoldingStyle::DataAndControlFlowWithoutRuntimeCheck;
 }
 
+static void getPartialReductionInstrChain(Instruction *Instr, SmallVector<Value*, 4> &Chain) {
+  Instruction *Mul = cast<Instruction>(Instr->getOperand(0));
+  Instruction *Ext0 = cast<ZExtInst>(Mul->getOperand(0));
+  Instruction *Ext1 = cast<ZExtInst>(Mul->getOperand(1));
+
+  Chain.push_back(Mul);
+  Chain.push_back(Ext0);
+  Chain.push_back(Ext1);
+  Chain.push_back(Instr->getOperand(1));
+}
+
+
+/// \param Instr The root instruction to scan for a partial-reduction pattern.
+static bool isInstrPartialReduction(Instruction *Instr) {
+  Value *ExpectedPhi;
+  Value *A, *B;
+  Value *InductionA, *InductionB;
+
+  using namespace llvm::PatternMatch;
+  auto Pattern = m_Add(
+    m_OneUse(m_Mul(
+      m_OneUse(m_ZExt(
+        m_OneUse(m_Load(
+          m_GEP(
+              m_Value(A),
+              m_Value(InductionA)))))),
+      m_OneUse(m_ZExt(
+        m_OneUse(m_Load(
+          m_GEP(
+              m_Value(B),
+              m_Value(InductionB))))))
+        )), m_Value(ExpectedPhi));
+
+  bool Matches = match(Instr, Pattern);
+
+  if (!Matches)
+    return false;
+
+  // Check that both GEPs are indexed by the same induction variable.
+  if (InductionA != InductionB) {
+    LLVM_DEBUG(dbgs() << "Loop uses different induction variables for each input variable, cannot create a partial reduction.\n");
+    return false;
+  }
+
+  Instruction *Mul = cast<Instruction>(Instr->getOperand(0));
+  Instruction *Ext0 = cast<ZExtInst>(Mul->getOperand(0));
+  Instruction *Ext1 = cast<ZExtInst>(Mul->getOperand(1));
+
+  // Check that both extends widen to i32.
+  if (!Ext0->getType()->isIntegerTy(32) || !Ext1->getType()->isIntegerTy(32)) {
+    LLVM_DEBUG(dbgs() << "Extends don't extend to the correct width, cannot create a partial reduction.\n");
+    return false;
+  }
+
+  // Check that both loads load i8.
+  LoadInst *Load0 = cast<LoadInst>(Ext0->getOperand(0));
+  LoadInst *Load1 = cast<LoadInst>(Ext1->getOperand(0));
+  if (!Load0->getType()->isIntegerTy(8) || !Load1->getType()->isIntegerTy(8)) {
+    LLVM_DEBUG(dbgs() << "Loads don't load the correct width, cannot create a partial reduction\n");
+    return false;
+  }
+
+  // Check that the add feeds into ExpectedPhi
+  PHINode *PhiNode = dyn_cast<PHINode>(ExpectedPhi);
+  if (!PhiNode) {
+    LLVM_DEBUG(dbgs() << "Expected Phi node was not a phi, cannot create a partial reduction.\n");
+    return false;
+  }
+
+  // Check that the first phi value is a zero initializer
+  ConstantInt *ZeroInit = dyn_cast<ConstantInt>(PhiNode->getIncomingValue(0));
+  if (!ZeroInit || !ZeroInit->isZero()) {
+    LLVM_DEBUG(dbgs() << "First PHI value is not a constant zero, cannot create a partial reduction.\n");
+    return false;
+  }
+
+  // Check that the second phi value is the instruction we're looking at
+  Instruction *MaybeAdd = dyn_cast<Instruction>(PhiNode->getIncomingValue(1));
+  if (!MaybeAdd || MaybeAdd != Instr) {
+    LLVM_DEBUG(dbgs() << "Second PHI value is not the root add, cannot create a partial reduction.\n");
+    return false;
+  }
+
+  return true;
+}
+
 // Return true if \p OuterLp is an outer loop annotated with hints for explicit
 // vectorization. The loop needs to be annotated with #pragma omp simd
 // simdlen(#) or #pragma clang vectorize(enable) vectorize_width(#). If the
@@ -4981,6 +5067,13 @@ bool LoopVectorizationPlanner::isCandidateForEpilogueVectorization(
         return false;
   }
 
+  // Prevent epilogue vectorization if a partial reduction is involved.
+  // TODO: Is there a cleaner way to check this?
+  if (any_of(Legal->getReductionVars(), [&](const std::pair<PHINode *, RecurrenceDescriptor> &Reduction) {
+    return isInstrPartialReduction(Reduction.second.getLoopExitInstr());
+  }))
+    return false;
+
  // Epilogue vectorization code has not been audited to ensure it handles
  // non-latch exits properly. It may be fine, but it needs to be audited
  // and tested.
@@ -7106,6 +7199,17 @@ void LoopVectorizationCostModel::collectValuesToIgnore() {
     const SmallVectorImpl<Instruction *> &Casts = IndDes.getCastInsts();
     VecValuesToIgnore.insert(Casts.begin(), Casts.end());
   }
+
+  // Ignore any values that we know will be subsumed by a partial reduction.
+  for (auto &Reduction : this->Legal->getReductionVars()) {
+    auto &Recurrence = Reduction.second;
+    if (isInstrPartialReduction(Recurrence.getLoopExitInstr())) {
+      SmallVector<Value*, 4> PartialReductionValues;
+      getPartialReductionInstrChain(Recurrence.getLoopExitInstr(), PartialReductionValues);
+      ValuesToIgnore.insert(PartialReductionValues.begin(), PartialReductionValues.end());
+      VecValuesToIgnore.insert(PartialReductionValues.begin(), PartialReductionValues.end());
+    }
+  }
 }
 
 void LoopVectorizationCostModel::collectInLoopReductions() {
@@ -8465,9 +8569,24 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
                                  *CI);
   }
 
+  if (auto *PartialReduce = tryToCreatePartialReduction(Range, Instr, Operands))
+    return PartialReduce;
+
   return tryToWiden(Instr, Operands, VPBB);
 }
 
+VPRecipeBase *VPRecipeBuilder::tryToCreatePartialReduction(
+    VFRange &Range, Instruction *Instr, ArrayRef<VPValue *> Operands) {
+
+  if (isInstrPartialReduction(Instr)) {
+    auto EC = ElementCount::getScalable(16);
+    if (std::find(Range.begin(), Range.end(), EC) == Range.end())
+      return nullptr;
+    return new VPPartialReductionRecipe(*Instr, make_range(Operands.begin(), Operands.end()));
+  }
+  return nullptr;
+}
+
 void LoopVectorizationPlanner::buildVPlansWithVPRecipes(ElementCount MinVF,
                                                         ElementCount MaxVF) {
   assert(OrigLoop->isInnermost() && "Inner loop expected.");
@@ -8677,6 +8796,9 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
         VPBB->appendRecipe(Recipe);
     }
 
+    for (auto &Recipe : *VPBB)
+      Recipe.postInsertionOp();
+
     VPBlockUtils::insertBlockAfter(new VPBasicBlock(), VPBB);
     VPBB = cast<VPBasicBlock>(VPBB->getSingleSuccessor());
   }
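
In summary, the builder only forms a partial-reduction recipe when the
pattern matcher fires and the VF range contains vscale x 16, the single
configuration handled at this point in the series. A condensed model of
that gate (a restatement of the logic above under the same assumptions,
not code from the patch):

  static bool shouldUsePartialReduction(Instruction *I, VFRange &Range) {
    if (!isInstrPartialReduction(I))
      return false;
    // Only VF = vscale x 16 is handled so far; bail out for anything else.
    return is_contained(Range, ElementCount::getScalable(16));
  }
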
diff --git a/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h b/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
index b4c7ab02f928f..c439f221709e1 100644
--- a/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
+++ b/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
@@ -116,6 +116,8 @@ class VPRecipeBuilder {
                                        ArrayRef<VPValue *> Operands,
                                        VFRange &Range, VPBasicBlock *VPBB);
 
+  VPRecipeBase* tryToCreatePartialReduction(VFRange &Range, Instruction* Instr, ArrayRef<VPValue*> Operands);
+
   /// Set the recipe created for given ingredient.
   void setRecipe(Instruction *I, VPRecipeBase *R) {
     assert(!Ingredient2Recipe.contains(I) &&
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index db10c7a240c7e..feef9da7e84f8 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -784,6 +784,8 @@ class VPRecipeBase : public ilist_node_with_parent<VPRecipeBase, VPBasicBlock>,
   /// \returns an iterator pointing to the element after the erased one
   iplist<VPRecipeBase>::iterator eraseFromParent();
 
+  virtual void postInsertionOp() {}
+
   /// Method to support type inquiry through isa, cast, and dyn_cast.
   static inline bool classof(const VPDef *D) {
     // All VPDefs are also VPRecipeBases.
@@ -1915,14 +1917,19 @@ class VPReductionPHIRecipe : public VPHeaderPHIRecipe {
   /// The phi is part of an ordered reduction. Requires IsInLoop to be true.
   bool IsOrdered;
 
+  /// The factor by which the VF is divided during ::execute.
+  unsigned VFScaleFactor = 1;
+
 public:
+
   /// Create a new VPReductionPHIRecipe for the reduction \p Phi described by \p
   /// RdxDesc.
   VPReductionPHIRecipe(PHINode *Phi, const RecurrenceDescriptor &RdxDesc,
                        VPValue &Start, bool IsInLoop = false,
-                       bool IsOrdered = false)
+                       bool IsOrdered = false, unsigned VFScaleFactor = 1)
       : VPHeaderPHIRecipe(VPDef::VPReductionPHISC, Phi, &Start),
-        RdxDesc(RdxDesc), IsInLoop(IsInLoop), IsOrdered(IsOrdered) {
+        RdxDesc(RdxDesc), IsInLoop(IsInLoop), IsOrdered(IsOrdered),
+        VFScaleFactor(VFScaleFactor) {
     assert((!IsOrdered || IsInLoop) && "IsOrdered requires IsInLoop");
   }
 
@@ -1931,7 +1938,7 @@ class VPReductionPHIRecipe : public VPHeaderPHIRecipe {
   VPReductionPHIRecipe *clone() override {
     auto *R =
         new VPReductionPHIRecipe(cast<PHINode>(getUnderlyingInstr()), RdxDesc,
-                                 *getOperand(0), IsInLoop, IsOrdered);
+                                 *getOperand(0), IsInLoop, IsOrdered, VFScaleFactor);
     R->addOperand(getBackedgeValue());
     return R;
   }
@@ -1942,6 +1949,10 @@ class VPReductionPHIRecipe : public VPHeaderPHIRecipe {
     return R->getVPDefID() == VPDef::VPReductionPHISC;
   }
 
+  void setVFScaleFactor(unsigned ScaleFactor) {
+    VFScaleFactor = ScaleFactor;
+  }
+
   /// Generate the phi/select nodes.
   void execute(VPTransformState &State) override;
 
@@ -1962,6 +1973,32 @@ class VPReductionPHIRecipe : public VPHeaderPHIRecipe {
   bool isInLoop() const { return IsInLoop; }
 };
 
+class VPPartialReductionRecipe : public VPRecipeWithIRFlags {
+  unsigned Opcode;
+public:
+  template <typename IterT>
+  VPPartialReductionRecipe(Instruction &I,
+                           iterator_range<IterT> Operands) : VPRecipeWithIRFlags(
+    VPDef::VPPartialReductionSC, Operands, I), Opcode(I.getOpcode())
+  {}
+  ~VPPartialReductionRecipe() override = default;
+  VPPartialReductionRecipe *clone() override {
+    auto *R = new VPPartialReductionRecipe(*getUnderlyingInstr(), operands());
+    R->transferFlags(*this);
+    return R;
+  }
+  VP_CLASSOF_IMPL(VPDef::VPPartialReductionSC)
+  /// Generate the reduction in the loop
+  void execute(VPTransformState &State) override;
+  void postInsertionOp() override;
+  unsigned getOpcode() { return Opcode; }
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+  /// Print the recipe.
+  void print(raw_ostream &O, const Twine &Indent,
+             VPSlotTracker &SlotTracker) const override;
+#endif
+};
+
 /// A recipe for vectorizing a phi-node as a sequence of mask-based select
 /// instructions.
 class VPBlendRecipe : public VPSingleDefRecipe {
diff --git a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
index eca5d1d4c5e1d..f48baff51f290 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
@@ -223,6 +223,10 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPReplicateRecipe *R) {
   llvm_unreachable("Unhandled opcode");
 }
 
+Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPPartialReductionRecipe *R) {
+  return R->getUnderlyingInstr()->getType();
+}
+
 Type *VPTypeAnalysis::inferScalarType(const VPValue *V) {
   if (Type *CachedTy = CachedTypes.lookup(V))
     return CachedTy;
@@ -253,7 +257,7 @@ Type *VPTypeAnalysis::inferScalarType(const VPValue *V) {
             return inferScalarType(R->getOperand(0));
           })
           .Case<VPBlendRecipe, VPInstruction, VPWidenRecipe, VPReplicateRecipe,
-                VPWidenCallRecipe, VPWidenMemoryRecipe, VPWidenSelectRecipe>(
+                VPWidenCallRecipe, VPWidenMemoryRecipe, VPWidenSelectRecipe, VPPartialReductionRecipe>(
               [this](const auto *R) { return inferScalarTypeForRecipe(R); })
           .Case<VPInterleaveRecipe>([V](const VPInterleaveRecipe *R) {
             // TODO: Use info from interleave group.
diff --git a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h
index 7d310b1b31b6f..3bd8d24542199 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h
@@ -23,6 +23,7 @@ class VPWidenIntOrFpInductionRecipe;
 class VPWidenMemoryRecipe;
 struct VPWidenSelectRecipe;
 class VPReplicateRecipe;
+class VPPartialReductionRecipe;
 class Type;
 
 /// An analysis for type-inference for VPValues.
@@ -49,6 +50,7 @@ class VPTypeAnalysis {
   Type *inferScalarTypeForRecipe(const VPWidenMemoryRecipe *R);
   Type *inferScalarTypeForRecipe(const VPWidenSelectRecipe *R);
   Type *inferScalarTypeForRecipe(const VPReplicateRecipe *R);
+  Type *inferScalarTypeForRecipe(const VPPartialReductionRecipe *R);
 
 public:
   VPTypeAnalysis(Type *CanonicalIVTy, LLVMContext &Ctx)
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 788e6c96d32aa..b24762ec5ca68 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -255,6 +255,76 @@ void VPRecipeBase::moveBefore(VPBasicBlock &BB,
   insertBefore(BB, I);
 }
 
+void VPPartialReductionRecipe::execute(VPTransformState &State) {
+  State.setDebugLocFrom(getDebugLoc());
+  auto &Builder = State.Builder;
+
+  switch (Opcode) {
+  case Instruction::Add: {
+
+    for (unsigned Part = 0; Part < State.UF; ++Part) {
+      Value *Mul = nullptr;
+      Value *Phi = nullptr;
+      SmallVector<Value *, 2> Ops;
+      for (VPValue *VPOp : operands()) {
+        auto *Op = State.get(VPOp, Part);
+        Ops.push_back(Op);
+        if (isa<PHINode>(Op))
+          Phi = Op;
+        else
+          Mul = Op;
+      }
+
+      assert(Phi && Mul && "Phi and Mul must be set");
+      assert(isa<ScalableVectorType>(Ops[0]->getType()) && "Type must be a scalable vector");
+
+      ScalableVectorType *FullTy = cast<ScalableVectorType>(Ops[0]->getType());
+      Type *RetTy = ScalableVectorType::get(FullTy->getScalarType(), 4);
+
+      Intrinsic::ID PartialIntrinsic = Intrinsic::not_intrinsic;
+      switch (Opcode) {
+      case Instruction::Add:
+        PartialIntrinsic =
+            Intrinsic::experimental_vector_partial_reduce_add;
+        break;
+      default:
+        llvm_unreachable("Opcode not handled");
+      }
+
+      assert(PartialIntrinsic != Intrinsic::not_intrinsic);
+
+      Value *V = Builder.CreateIntrinsic(RetTy, PartialIntrinsic, Mul, nullptr, Twine("partial.reduce"));
+      V = Builder.CreateNAryOp(Opcode, {V, Phi});
+      if (auto *VecOp = dyn_cast<Instruction>(V))
+        setFlags(VecOp);
+
+      // Use this vector value for all users of the original instruction.
+      State.set(this, V, Part);
+      State.addMetadata(V, dyn_cast_or_null<Instruction>(getUnderlyingValue()));
+    }
+    break;
+  }
+  default:
+    LLVM_DEBUG(dbgs() << "LV: Found an unhandled opcode: " << Instruction::getOpcodeName(Opcode) << '\n');
+    llvm_unreachable("Unhandled instruction!");
+  }
+}
+
+void VPPartialReductionRecipe::postInsertionOp() {
+  cast<VPReductionPHIRecipe>(this->getOperand(1))->setVFScaleFactor(4);
+}
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+void VPPartialReductionRecipe::print(raw_ostream &O, const Twine &Indent,
+  VPSlotTracker &SlotTracker) const {
+  O << Indent << "PARTIAL-REDUCE ";
+  printAsOperand(O, SlotTracker);
+  O << " = " << Instruction::getOpcodeName(Opcode);
+  printFlags(O);
+  printOperands(O, SlotTracker);
+}
+#endif
+
 FastMathFlags VPRecipeWithIRFlags::getFastMathFlags() const {
   assert(OpType == OperationType::FPMathOp &&
          "recipe doesn't have fast math flags");
@@ -2033,6 +2103,8 @@ void VPFirstOrderRecurrencePHIRecipe::print(raw_ostream &O, const Twine &Indent,
 void VPReductionPHIRecipe::execute(VPTransformState &State) {
   auto &Builder = State.Builder;
 
+  auto VF = State.VF.divideCoefficientBy(VFScaleFactor);
+
   // Reductions do not have to start at zero. They can start with
   // any loop invariant values.
   VPValue *StartVPV = getStartValue();
@@ -2042,9 +2114,9 @@ void VPReductionPHIRecipe::execute(VPTransformState &State) {
   // Phi nodes have cycles, so we need to vectorize them in two stages. This is
   // stage #1: We create a new vector PHI node with no incoming edges. We'll use
   // this value when we vectorize all of the instructions that use the PHI.
-  bool ScalarPHI = State.VF.isScalar() || IsInLoop;
+  bool ScalarPHI = VF.isScalar() || IsInLoop;
   Type *VecTy = ScalarPHI ? StartV->getType()
-                          : VectorType::get(StartV->getType(), State.VF);
+                          : VectorType::get(StartV->getType(), VF);
 
   BasicBlock *HeaderBB = State.CFG.PrevBB;
   assert(State.CurrentVectorLoop->getHeader() == HeaderBB &&
@@ -2069,14 +2141,14 @@ void VPReductionPHIRecipe::execute(VPTransformState &State) {
       IRBuilderBase::InsertPointGuard IPBuilder(Builder);
       Builder.SetInsertPoint(VectorPH->getTerminator());
       StartV = Iden =
-          Builder.CreateVectorSplat(State.VF, StartV, "minmax.ident");
+          Builder.CreateVectorSplat(VF, StartV, "minmax.ident");
     }
   } else {
     Iden = RdxDesc.getRecurrenceIdentity(RK, VecTy->getScalarType(),
                                          RdxDesc.getFastMathFlags());
 
     if (!ScalarPHI) {
-      Iden = Builder.CreateVectorSplat(State.VF, Iden);
+      Iden = Builder.CreateVectorSplat(VF, Iden);
       IRBuilderBase::InsertPointGuard IPBuilder(Builder);
       Builder.SetInsertPoint(VectorPH->getTerminator());
       Constant *Zero = Builder.getInt32(0);
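
The effect of VFScaleFactor is easiest to see with concrete numbers. A
worked example under this patch's single supported configuration (Ctx
stands for an in-scope LLVMContext; the values mirror the updated test
below):

  ElementCount VF = ElementCount::getScalable(16);  // plan VF: vscale x 16
  ElementCount PhiVF = VF.divideCoefficientBy(4);   // phi VF:  vscale x 4
  // The reduction phi becomes <vscale x 4 x i32>, matching the result
  // type of the partial-reduce intrinsic that feeds it.
  Type *VecTy = VectorType::get(Type::getInt32Ty(Ctx), PhiVF);
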
diff --git a/llvm/lib/Transforms/Vectorize/VPlanValue.h b/llvm/lib/Transforms/Vectorize/VPlanValue.h
index 8d945f6f2b8ea..ff34e1ded9f3f 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanValue.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanValue.h
@@ -343,6 +343,7 @@ class VPDef {
     VPInstructionSC,
     VPInterleaveSC,
     VPReductionSC,
+    VPPartialReductionSC,
     VPReplicateSC,
     VPScalarCastSC,
     VPScalarIVStepsSC,
diff --git a/llvm/test/CodeGen/AArch64/partial-reduce-sdot.ll b/llvm/test/CodeGen/AArch64/partial-reduce-sdot.ll
index fc6e3239a1b43..1eafd505b199e 100644
--- a/llvm/test/CodeGen/AArch64/partial-reduce-sdot.ll
+++ b/llvm/test/CodeGen/AArch64/partial-reduce-sdot.ll
@@ -22,7 +22,7 @@ define void @dotp(ptr %a, ptr %b) #0 {
 ; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
 ; CHECK:       vector.body:
 ; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
-; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi <vscale x 16 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP14:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi <vscale x 4 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP14:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[TMP11:%.*]] = add i64 [[INDEX]], 0
 ; CHECK-NEXT:    [[TMP13:%.*]] = getelementptr i8, ptr [[A]], i64 [[TMP11]]
 ; CHECK-NEXT:    [[TMP17:%.*]] = getelementptr i8, ptr [[TMP13]], i32 0
@@ -33,12 +33,13 @@ define void @dotp(ptr %a, ptr %b) #0 {
 ; CHECK-NEXT:    [[WIDE_LOAD4:%.*]] = load <vscale x 16 x i8>, ptr [[TMP25]], align 1
 ; CHECK-NEXT:    [[TMP27:%.*]] = zext <vscale x 16 x i8> [[WIDE_LOAD4]] to <vscale x 16 x i32>
 ; CHECK-NEXT:    [[TMP29:%.*]] = mul <vscale x 16 x i32> [[TMP27]], [[TMP19]]
-; CHECK-NEXT:    [[TMP14]] = add <vscale x 16 x i32> [[TMP29]], [[VEC_PHI]]
+; CHECK-NEXT:    [[PARTIAL_REDUCE:%.*]] = call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv16i32(<vscale x 16 x i32> [[TMP29]])
+; CHECK-NEXT:    [[TMP14]] = add <vscale x 4 x i32> [[PARTIAL_REDUCE]], [[VEC_PHI]]
 ; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP5]]
 ; CHECK-NEXT:    [[TMP32:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
 ; CHECK-NEXT:    br i1 [[TMP32]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
 ; CHECK:       middle.block:
-; CHECK-NEXT:    [[TMP33:%.*]] = call i32 @llvm.vector.reduce.add.nxv16i32(<vscale x 16 x i32> [[TMP14]])
+; CHECK-NEXT:    [[TMP33:%.*]] = call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> [[TMP14]])
 ; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i64 0, [[N_VEC]]
 ; CHECK-NEXT:    br i1 [[CMP_N]], label [[FOR_COND_CLEANUP_LOOPEXIT:%.*]], label [[SCALAR_PH]]
 ; CHECK:       scalar.ph:

From f90ddb3ff376b10857a551a513fb03d74757c116 Mon Sep 17 00:00:00 2001
From: Nicholas Guy <nicholas.guy at arm.com>
Date: Thu, 30 May 2024 15:04:55 +0100
Subject: [PATCH 03/10] [LoopVectorizer] Removed 4x restriction from partial
 reduction intrinsic
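
With both the result and the input overloaded (llvm_anyvector_ty on each
side), the intrinsic itself no longer fixes the reduction ratio at 4x; the
caller chooses the ratio by picking the result type. A sketch of what
emitting the call looks like after this change (an assumed helper that
mirrors what the recipe's execute() does, not code from the patch):

  // Fold WideMul into Scale-times-fewer lanes of the same element type.
  Value *emitPartialReduceAdd(IRBuilderBase &B, Value *WideMul,
                              unsigned Scale) {
    auto *InTy = cast<ScalableVectorType>(WideMul->getType());
    auto *RetTy = ScalableVectorType::get(InTy->getScalarType(),
                                          InTy->getMinNumElements() / Scale);
    return B.CreateIntrinsic(RetTy,
                             Intrinsic::experimental_vector_partial_reduce_add,
                             {WideMul}, nullptr, "partial.reduce");
  }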

---
 llvm/include/llvm/IR/DerivedTypes.h              | 10 ----------
 llvm/include/llvm/IR/Intrinsics.h                |  5 ++---
 llvm/include/llvm/IR/Intrinsics.td               |  6 +-----
 llvm/lib/IR/Function.cpp                         | 16 ----------------
 llvm/lib/Transforms/Vectorize/LoopVectorize.cpp  |  2 +-
 llvm/lib/Transforms/Vectorize/VPlan.h            |  7 ++++---
 llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp   |  5 +++--
 llvm/test/CodeGen/AArch64/partial-reduce-sdot.ll |  2 +-
 8 files changed, 12 insertions(+), 41 deletions(-)

diff --git a/llvm/include/llvm/IR/DerivedTypes.h b/llvm/include/llvm/IR/DerivedTypes.h
index 866a01c9afebd..443fb7de3b821 100644
--- a/llvm/include/llvm/IR/DerivedTypes.h
+++ b/llvm/include/llvm/IR/DerivedTypes.h
@@ -512,16 +512,6 @@ class VectorType : public Type {
                            EltCnt.divideCoefficientBy(2));
   }
 
-  /// This static method returns a VectorType with a quarter as many elements
-  /// as the input type and the same element type.
-  static VectorType *getQuarterElementsVectorType(VectorType *VTy) {
-    auto EltCnt = VTy->getElementCount();
-    assert(EltCnt.isKnownMultipleOf(4) &&
-           "Cannot quarter a vector whose element count is not a multiple of 4.");
-    return VectorType::get(VTy->getElementType(),
-                           EltCnt.divideCoefficientBy(4));
-  }
-
   /// This static method returns a VectorType with twice as many elements as the
   /// input type and the same element type.
   static VectorType *getDoubleElementsVectorType(VectorType *VTy) {
diff --git a/llvm/include/llvm/IR/Intrinsics.h b/llvm/include/llvm/IR/Intrinsics.h
index d0c4bee59e889..f79df522dc805 100644
--- a/llvm/include/llvm/IR/Intrinsics.h
+++ b/llvm/include/llvm/IR/Intrinsics.h
@@ -135,7 +135,6 @@ namespace Intrinsic {
       ExtendArgument,
       TruncArgument,
       HalfVecArgument,
-      QuarterVecArgument,
       SameVecWidthArgument,
       VecOfAnyPtrsToElt,
       VecElementArgument,
@@ -165,7 +164,7 @@ namespace Intrinsic {
 
     unsigned getArgumentNumber() const {
       assert(Kind == Argument || Kind == ExtendArgument ||
-             Kind == TruncArgument || Kind == HalfVecArgument || Kind == QuarterVecArgument ||
+             Kind == TruncArgument || Kind == HalfVecArgument ||
              Kind == SameVecWidthArgument || Kind == VecElementArgument ||
              Kind == Subdivide2Argument || Kind == Subdivide4Argument ||
              Kind == VecOfBitcastsToInt);
@@ -173,7 +172,7 @@ namespace Intrinsic {
     }
     ArgKind getArgumentKind() const {
       assert(Kind == Argument || Kind == ExtendArgument ||
-             Kind == TruncArgument || Kind == HalfVecArgument || Kind == QuarterVecArgument ||
+             Kind == TruncArgument || Kind == HalfVecArgument ||
              Kind == SameVecWidthArgument ||
              Kind == VecElementArgument || Kind == Subdivide2Argument ||
              Kind == Subdivide4Argument || Kind == VecOfBitcastsToInt);
diff --git a/llvm/include/llvm/IR/Intrinsics.td b/llvm/include/llvm/IR/Intrinsics.td
index 98c5221ade2c8..0ece1b81a9b0e 100644
--- a/llvm/include/llvm/IR/Intrinsics.td
+++ b/llvm/include/llvm/IR/Intrinsics.td
@@ -321,7 +321,6 @@ def IIT_I4 : IIT_Int<4, 58>;
 def IIT_AARCH64_SVCOUNT : IIT_VT<aarch64svcount, 59>;
 def IIT_V6 : IIT_Vec<6, 60>;
 def IIT_V10 : IIT_Vec<10, 61>;
-def IIT_QUARTER_VEC_ARG : IIT_Base<62>;
 }
 
 defvar IIT_all_FixedTypes = !filter(iit, IIT_all,
@@ -458,9 +457,6 @@ class LLVMVectorElementType<int num> : LLVMMatchType<num, IIT_VEC_ELEMENT>;
 class LLVMHalfElementsVectorType<int num>
   : LLVMMatchType<num, IIT_HALF_VEC_ARG>;
 
-class LLVMQuarterElementsVectorType<int num>
-  : LLVMMatchType<num, IIT_QUARTER_VEC_ARG>;
-
 // Match the type of another intrinsic parameter that is expected to be a
 // vector type (i.e. <N x iM>) but with each element subdivided to
 // form a vector with more elements that are smaller than the original.
@@ -2652,7 +2648,7 @@ def int_experimental_vector_partial_reduce_add : DefaultAttrsIntrinsic<[LLVMMatc
 
 //===-------------- Intrinsics to perform partial reduction ---------------===//
 
-def int_experimental_vector_partial_reduce_add : DefaultAttrsIntrinsic<[LLVMQuarterElementsVectorType<0>],
+def int_experimental_vector_partial_reduce_add : DefaultAttrsIntrinsic<[llvm_anyvector_ty],
                                                                        [llvm_anyvector_ty],
                                                                        [IntrNoMem]>;
 
diff --git a/llvm/lib/IR/Function.cpp b/llvm/lib/IR/Function.cpp
index a1ce95c2c92bb..5fb348a8bbcd4 100644
--- a/llvm/lib/IR/Function.cpp
+++ b/llvm/lib/IR/Function.cpp
@@ -1259,12 +1259,6 @@ static void DecodeIITType(unsigned &NextElt, ArrayRef<unsigned char> Infos,
                                              ArgInfo));
     return;
   }
-  case IIT_QUARTER_VEC_ARG: {
-    unsigned ArgInfo = (NextElt == Infos.size() ? 0 : Infos[NextElt++]);
-    OutputTable.push_back(IITDescriptor::get(IITDescriptor::QuarterVecArgument,
-                                             ArgInfo));
-    return;
-  }
   case IIT_SAME_VEC_WIDTH_ARG: {
     unsigned ArgInfo = (NextElt == Infos.size() ? 0 : Infos[NextElt++]);
     OutputTable.push_back(IITDescriptor::get(IITDescriptor::SameVecWidthArgument,
@@ -1429,9 +1423,6 @@ static Type *DecodeFixedType(ArrayRef<Intrinsic::IITDescriptor> &Infos,
   case IITDescriptor::HalfVecArgument:
     return VectorType::getHalfElementsVectorType(cast<VectorType>(
                                                   Tys[D.getArgumentNumber()]));
-  case IITDescriptor::QuarterVecArgument: {
-    return VectorType::getQuarterElementsVectorType(cast<VectorType>(Tys[D.getArgumentNumber()]));
-  }
   case IITDescriptor::SameVecWidthArgument: {
     Type *EltTy = DecodeFixedType(Infos, Tys, Context);
     Type *Ty = Tys[D.getArgumentNumber()];
@@ -1659,13 +1650,6 @@ static bool matchIntrinsicType(
       return !isa<VectorType>(ArgTys[D.getArgumentNumber()]) ||
              VectorType::getHalfElementsVectorType(
                      cast<VectorType>(ArgTys[D.getArgumentNumber()])) != Ty;
-    case IITDescriptor::QuarterVecArgument: {
-      if (D.getArgumentNumber() >= ArgTys.size())
-        return IsDeferredCheck || DeferCheck(Ty);
-      return !isa<VectorType>(ArgTys[D.getArgumentNumber()]) ||
-             VectorType::getQuarterElementsVectorType(
-                     cast<VectorType>(ArgTys[D.getArgumentNumber()])) != Ty;
-    }
     case IITDescriptor::SameVecWidthArgument: {
       if (D.getArgumentNumber() >= ArgTys.size()) {
         // Defer check and subsequent check for the vector element type.
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index a184cb3ea0fb4..0376a95699e67 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -8582,7 +8582,7 @@ VPRecipeBase *VPRecipeBuilder::tryToCreatePartialReduction(
     auto EC = ElementCount::getScalable(16);
     if (std::find(Range.begin(), Range.end(), EC) == Range.end())
       return nullptr;
-    return new VPPartialReductionRecipe(*Instr, make_range(Operands.begin(), Operands.end()));
+    return new VPPartialReductionRecipe(*Instr, make_range(Operands.begin(), Operands.end()), 4);
   }
   return nullptr;
 }
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index feef9da7e84f8..19a337841be3e 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -1975,15 +1975,16 @@ class VPReductionPHIRecipe : public VPHeaderPHIRecipe {
 
 class VPPartialReductionRecipe : public VPRecipeWithIRFlags {
   unsigned Opcode;
+  unsigned Scale;
 public:
   template <typename IterT>
   VPPartialReductionRecipe(Instruction &I,
-                           iterator_range<IterT> Operands) : VPRecipeWithIRFlags(
-    VPDef::VPPartialReductionSC, Operands, I), Opcode(I.getOpcode())
+                           iterator_range<IterT> Operands, unsigned Scale) : VPRecipeWithIRFlags(
+    VPDef::VPPartialReductionSC, Operands, I), Opcode(I.getOpcode()), Scale(Scale)
   {}
   ~VPPartialReductionRecipe() override = default;
   VPPartialReductionRecipe *clone() override {
-    auto *R = new VPPartialReductionRecipe(*getUnderlyingInstr(), operands());
+    auto *R = new VPPartialReductionRecipe(*getUnderlyingInstr(), operands(), Scale);
     R->transferFlags(*this);
     return R;
   }
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index b24762ec5ca68..a56408edad4a2 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -279,7 +279,8 @@ void VPPartialReductionRecipe::execute(VPTransformState &State) {
       assert(isa<ScalableVectorType>(Ops[0]->getType()) && "Type must be a scalable vector");
 
       ScalableVectorType *FullTy = cast<ScalableVectorType>(Ops[0]->getType());
-      Type *RetTy = ScalableVectorType::get(FullTy->getScalarType(), 4);
+      auto EC = FullTy->getElementCount();
+      Type *RetTy = ScalableVectorType::get(FullTy->getScalarType(), EC.divideCoefficientBy(Scale).getKnownMinValue());
 
       Intrinsic::ID PartialIntrinsic = Intrinsic::not_intrinsic;
       switch (Opcode) {
@@ -311,7 +312,7 @@ void VPPartialReductionRecipe::execute(VPTransformState &State) {
 }
 
 void VPPartialReductionRecipe::postInsertionOp() {
-  cast<VPReductionPHIRecipe>(this->getOperand(1))->setVFScaleFactor(4);
+  cast<VPReductionPHIRecipe>(this->getOperand(1))->setVFScaleFactor(Scale);
 }
 
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
diff --git a/llvm/test/CodeGen/AArch64/partial-reduce-sdot.ll b/llvm/test/CodeGen/AArch64/partial-reduce-sdot.ll
index 1eafd505b199e..7883cfc05a13b 100644
--- a/llvm/test/CodeGen/AArch64/partial-reduce-sdot.ll
+++ b/llvm/test/CodeGen/AArch64/partial-reduce-sdot.ll
@@ -33,7 +33,7 @@ define void @dotp(ptr %a, ptr %b) #0 {
 ; CHECK-NEXT:    [[WIDE_LOAD4:%.*]] = load <vscale x 16 x i8>, ptr [[TMP25]], align 1
 ; CHECK-NEXT:    [[TMP27:%.*]] = zext <vscale x 16 x i8> [[WIDE_LOAD4]] to <vscale x 16 x i32>
 ; CHECK-NEXT:    [[TMP29:%.*]] = mul <vscale x 16 x i32> [[TMP27]], [[TMP19]]
-; CHECK-NEXT:    [[PARTIAL_REDUCE:%.*]] = call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv16i32(<vscale x 16 x i32> [[TMP29]])
+; CHECK-NEXT:    [[PARTIAL_REDUCE:%.*]] = call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 16 x i32> [[TMP29]])
 ; CHECK-NEXT:    [[TMP14]] = add <vscale x 4 x i32> [[PARTIAL_REDUCE]], [[VEC_PHI]]
 ; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP5]]
 ; CHECK-NEXT:    [[TMP32:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]

From 987abef923510a78016b9f4aa4dd92a22a3ec72c Mon Sep 17 00:00:00 2001
From: Nicholas Guy <nicholas.guy at arm.com>
Date: Wed, 5 Jun 2024 14:12:04 +0100
Subject: [PATCH 04/10] Commit of test files

---
 .../CodeGen/AArch64/partial-reduce-sdot-ir.ll | 99 +++++++++++++++++++
 ...rtial-reduce-sdot.ll => partial-reduce.ll} |  0
 2 files changed, 99 insertions(+)
 create mode 100644 llvm/test/CodeGen/AArch64/partial-reduce-sdot-ir.ll
 rename llvm/test/CodeGen/AArch64/{partial-reduce-sdot.ll => partial-reduce.ll} (100%)

diff --git a/llvm/test/CodeGen/AArch64/partial-reduce-sdot-ir.ll b/llvm/test/CodeGen/AArch64/partial-reduce-sdot-ir.ll
new file mode 100644
index 0000000000000..3519ba58b3df3
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/partial-reduce-sdot-ir.ll
@@ -0,0 +1,99 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
+; RUN: opt -passes="default<O3>" -force-vector-interleave=1 -S < %s | FileCheck %s
+
+target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128"
+target triple = "aarch64-none-unknown-elf"
+
+define void @dotp(ptr %out, ptr %a, ptr %b, i64 %wide.trip.count) #0 {
+; CHECK-LABEL: define void @dotp(
+; CHECK-SAME: ptr nocapture writeonly [[OUT:%.*]], ptr nocapture readonly [[A:%.*]], ptr nocapture readonly [[B:%.*]], i64 [[WIDE_TRIP_COUNT:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = add i64 [[WIDE_TRIP_COUNT]], 1
+; CHECK-NEXT:    [[TMP1:%.*]] = tail call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP2:%.*]] = shl i64 [[TMP1]], 4
+; CHECK-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[TMP0]], [[TMP2]]
+; CHECK-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[FOR_BODY_PREHEADER:%.*]], label [[VECTOR_PH:%.*]]
+; CHECK:       vector.ph:
+; CHECK-NEXT:    [[TMP3:%.*]] = tail call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP4:%.*]] = shl i64 [[TMP3]], 4
+; CHECK-NEXT:    [[N_MOD_VF:%.*]] = urem i64 [[TMP0]], [[TMP4]]
+; CHECK-NEXT:    [[N_VEC:%.*]] = sub i64 [[TMP0]], [[N_MOD_VF]]
+; CHECK-NEXT:    [[TMP5:%.*]] = tail call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP6:%.*]] = shl i64 [[TMP5]], 4
+; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
+; CHECK:       vector.body:
+; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi <vscale x 16 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP12:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[TMP7:%.*]] = getelementptr i8, ptr [[A]], i64 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <vscale x 16 x i8>, ptr [[TMP7]], align 1
+; CHECK-NEXT:    [[TMP8:%.*]] = zext <vscale x 16 x i8> [[WIDE_LOAD]] to <vscale x 16 x i32>
+; CHECK-NEXT:    [[TMP9:%.*]] = getelementptr i8, ptr [[B]], i64 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_LOAD1:%.*]] = load <vscale x 16 x i8>, ptr [[TMP9]], align 1
+; CHECK-NEXT:    [[TMP10:%.*]] = zext <vscale x 16 x i8> [[WIDE_LOAD1]] to <vscale x 16 x i32>
+; CHECK-NEXT:    [[TMP11:%.*]] = mul nuw nsw <vscale x 16 x i32> [[TMP10]], [[TMP8]]
+; CHECK-NEXT:    [[TMP12]] = add <vscale x 16 x i32> [[TMP11]], [[VEC_PHI]]
+; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP6]]
+; CHECK-NEXT:    [[TMP13:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
+; CHECK-NEXT:    br i1 [[TMP13]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
+; CHECK:       middle.block:
+; CHECK-NEXT:    [[TMP14:%.*]] = tail call i32 @llvm.vector.reduce.add.nxv16i32(<vscale x 16 x i32> [[TMP12]])
+; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i64 [[N_MOD_VF]], 0
+; CHECK-NEXT:    br i1 [[CMP_N]], label [[FOR_COND_CLEANUP_LOOPEXIT:%.*]], label [[FOR_BODY_PREHEADER]]
+; CHECK:       for.body.preheader:
+; CHECK-NEXT:    [[INDVARS_IV_PH:%.*]] = phi i64 [ 0, [[ENTRY:%.*]] ], [ [[N_VEC]], [[MIDDLE_BLOCK]] ]
+; CHECK-NEXT:    [[ACC_010_PH:%.*]] = phi i32 [ 0, [[ENTRY]] ], [ [[TMP14]], [[MIDDLE_BLOCK]] ]
+; CHECK-NEXT:    br label [[FOR_BODY:%.*]]
+; CHECK:       for.cond.cleanup.loopexit:
+; CHECK-NEXT:    [[ADD_LCSSA:%.*]] = phi i32 [ [[TMP14]], [[MIDDLE_BLOCK]] ], [ [[ADD:%.*]], [[FOR_BODY]] ]
+; CHECK-NEXT:    [[TMP15:%.*]] = trunc i32 [[ADD_LCSSA]] to i8
+; CHECK-NEXT:    store i8 [[TMP15]], ptr [[OUT]], align 1
+; CHECK-NEXT:    ret void
+; CHECK:       for.body:
+; CHECK-NEXT:    [[INDVARS_IV:%.*]] = phi i64 [ [[INDVARS_IV_NEXT:%.*]], [[FOR_BODY]] ], [ [[INDVARS_IV_PH]], [[FOR_BODY_PREHEADER]] ]
+; CHECK-NEXT:    [[ACC_010:%.*]] = phi i32 [ [[ADD]], [[FOR_BODY]] ], [ [[ACC_010_PH]], [[FOR_BODY_PREHEADER]] ]
+; CHECK-NEXT:    [[ARRAYIDX:%.*]] = getelementptr i8, ptr [[A]], i64 [[INDVARS_IV]]
+; CHECK-NEXT:    [[TMP16:%.*]] = load i8, ptr [[ARRAYIDX]], align 1
+; CHECK-NEXT:    [[CONV:%.*]] = zext i8 [[TMP16]] to i32
+; CHECK-NEXT:    [[ARRAYIDX2:%.*]] = getelementptr i8, ptr [[B]], i64 [[INDVARS_IV]]
+; CHECK-NEXT:    [[TMP17:%.*]] = load i8, ptr [[ARRAYIDX2]], align 1
+; CHECK-NEXT:    [[CONV3:%.*]] = zext i8 [[TMP17]] to i32
+; CHECK-NEXT:    [[MUL:%.*]] = mul nuw nsw i32 [[CONV3]], [[CONV]]
+; CHECK-NEXT:    [[ADD]] = add i32 [[MUL]], [[ACC_010]]
+; CHECK-NEXT:    [[INDVARS_IV_NEXT]] = add i64 [[INDVARS_IV]], 1
+; CHECK-NEXT:    [[EXITCOND_NOT:%.*]] = icmp eq i64 [[INDVARS_IV]], [[WIDE_TRIP_COUNT]]
+; CHECK-NEXT:    br i1 [[EXITCOND_NOT]], label [[FOR_COND_CLEANUP_LOOPEXIT]], label [[FOR_BODY]], !llvm.loop [[LOOP3:![0-9]+]]
+;
+entry:
+  br label %for.body
+
+for.cond.cleanup.loopexit:                        ; preds = %for.body
+  %0 = trunc i32 %add to i8
+  store i8 %0, ptr %out, align 1
+  ret void
+
+for.body:                                         ; preds = %for.body, %entry
+  %indvars.iv = phi i64 [ 0, %entry ], [ %indvars.iv.next, %for.body ]
+  %acc.010 = phi i32 [ 0, %entry ], [ %add, %for.body ]
+  %arrayidx = getelementptr i8, ptr %a, i64 %indvars.iv
+  %1 = load i8, ptr %arrayidx, align 1
+  %conv = zext i8 %1 to i32
+  %arrayidx2 = getelementptr i8, ptr %b, i64 %indvars.iv
+  %2 = load i8, ptr %arrayidx2, align 1
+  %conv3 = zext i8 %2 to i32
+  %mul = mul i32 %conv3, %conv
+  %add = add i32 %mul, %acc.010
+  %indvars.iv.next = add i64 %indvars.iv, 1
+  %exitcond.not = icmp eq i64 %indvars.iv, %wide.trip.count
+  br i1 %exitcond.not, label %for.cond.cleanup.loopexit, label %for.body
+
+; uselistorder directives
+  uselistorder i32 %add, { 1, 0 }
+}
+
+attributes #0 = { "target-features"="+fp-armv8,+fullfp16,+neon,+sve,+sve2,+v8a" }
+;.
+; CHECK: [[LOOP0]] = distinct !{[[LOOP0]], [[META1:![0-9]+]], [[META2:![0-9]+]]}
+; CHECK: [[META1]] = !{!"llvm.loop.isvectorized", i32 1}
+; CHECK: [[META2]] = !{!"llvm.loop.unroll.runtime.disable"}
+; CHECK: [[LOOP3]] = distinct !{[[LOOP3]], [[META2]], [[META1]]}
+;.
diff --git a/llvm/test/CodeGen/AArch64/partial-reduce-sdot.ll b/llvm/test/CodeGen/AArch64/partial-reduce.ll
similarity index 100%
rename from llvm/test/CodeGen/AArch64/partial-reduce-sdot.ll
rename to llvm/test/CodeGen/AArch64/partial-reduce.ll

From f1516f5df3efa8ec196a15a90e3ce9adab539c91 Mon Sep 17 00:00:00 2001
From: Nicholas Guy <nicholas.guy at arm.com>
Date: Wed, 5 Jun 2024 14:12:58 +0100
Subject: [PATCH 05/10] Add generic decomposition of partial reduction
 intrinsic

---
 llvm/include/llvm/CodeGen/SelectionDAGNodes.h | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
index 2f36c2e86b1c3..222b82d434d0e 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
@@ -957,6 +957,11 @@ END_TWO_BYTE_PACK()
   inline const APInt &getAsAPIntVal() const;
 
   const SDValue &getOperand(unsigned Num) const {
+    if (Num >= NumOperands) {
+      dbgs() << Num << ">=" << NumOperands << "\n";
+      printr(dbgs());
+      dbgs() << "\n";
+    }
     assert(Num < NumOperands && "Invalid child # of SDNode!");
     return OperandList[Num];
   }

From 07f4398eee272b2a9b79ed40423eb58c910519bb Mon Sep 17 00:00:00 2001
From: Nicholas Guy <nicholas.guy at arm.com>
Date: Wed, 5 Jun 2024 14:14:15 +0100
Subject: [PATCH 06/10] Add basic cost modeling for partial reductions
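
The base implementations return InstructionCost::getMax(), which the
vectorizer treats as "no profitable partial reduction available". A
hypothetical target override (class name and heuristic are assumptions for
illustration; no target implements the hook in this patch):

  InstructionCost
  MyTargetTTIImpl::getPartialReductionCost(unsigned Opcode, bool IsUnsigned,
                                           VectorType *ResTy, VectorType *Ty,
                                           FastMathFlags FMF,
                                           TTI::TargetCostKind CostKind) {
    // Report a nominal cost for the 4:1 integer add reduction that maps
    // onto a dot-product instruction; otherwise defer to the default.
    if (Opcode == Instruction::Add &&
        Ty->getElementCount() ==
            ResTy->getElementCount().multiplyCoefficientBy(4))
      return 1;
    return BaseT::getPartialReductionCost(Opcode, IsUnsigned, ResTy, Ty, FMF,
                                          CostKind);
  }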

---
 llvm/include/llvm/Analysis/TargetTransformInfo.h   | 13 +++++++++++++
 .../llvm/Analysis/TargetTransformInfoImpl.h        |  7 +++++++
 llvm/include/llvm/CodeGen/BasicTTIImpl.h           |  7 +++++++
 llvm/lib/Analysis/TargetTransformInfo.cpp          |  6 ++++++
 llvm/lib/Transforms/Vectorize/LoopVectorize.cpp    | 14 +++++++++++++-
 5 files changed, 46 insertions(+), 1 deletion(-)

diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index dcdd9f82cde8e..ba007ffbfd742 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1528,6 +1528,10 @@ class TargetTransformInfo {
       FastMathFlags FMF,
       TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput) const;
 
+  InstructionCost getPartialReductionCost(
+    unsigned Opcode, bool IsUnsigned, VectorType *ResTy, VectorType *Ty,
+    FastMathFlags FMF, TargetCostKind CostKind = TCK_RecipThroughput) const;
+
   /// \returns The cost of Intrinsic instructions. Analyses the real arguments.
   /// Three cases are handled: 1. scalar instruction 2. vector instruction
   /// 3. scalar instruction which is to be vectorized.
@@ -2093,6 +2097,9 @@ class TargetTransformInfo::Concept {
       unsigned Opcode, bool IsUnsigned, Type *ResTy, VectorType *Ty,
       FastMathFlags FMF,
       TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput) = 0;
+  virtual InstructionCost getPartialReductionCost(
+      unsigned Opcode, bool IsUnsigned, VectorType *ResTy, VectorType *Ty,
+      FastMathFlags FMF, TargetCostKind CostKind = TCK_RecipThroughput) = 0;
   virtual InstructionCost getMulAccReductionCost(
       bool IsUnsigned, Type *ResTy, VectorType *Ty,
       TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput) = 0;
@@ -2776,6 +2783,12 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
     return Impl.getExtendedReductionCost(Opcode, IsUnsigned, ResTy, Ty, FMF,
                                          CostKind);
   }
+  InstructionCost getPartialReductionCost(unsigned Opcode, bool IsUnsigned,
+      VectorType *ResTy, VectorType *Ty, FastMathFlags FMF,
+      TargetCostKind CostKind = TCK_RecipThroughput) override {
+      return Impl.getPartialReductionCost(Opcode, IsUnsigned, ResTy, Ty, FMF,
+          CostKind);
+  }
   InstructionCost
   getMulAccReductionCost(bool IsUnsigned, Type *ResTy, VectorType *Ty,
                          TTI::TargetCostKind CostKind) override {
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 01624de190d51..d96c9e1ef4aec 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -806,6 +806,13 @@ class TargetTransformInfoImplBase {
     return 1;
   }
 
+  InstructionCost getPartialReductionCost(unsigned Opcode, bool IsUnsigned,
+                                           VectorType *ResTy, VectorType *Ty,
+                                           FastMathFlags FMF,
+                                           TTI::TargetCostKind CostKind) const {
+    return InstructionCost::getMax();
+  }
+
   InstructionCost getMulAccReductionCost(bool IsUnsigned, Type *ResTy,
                                          VectorType *Ty,
                                          TTI::TargetCostKind CostKind) const {
diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index 4f1dc9f991c06..13fe0628cbc87 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -2609,6 +2609,13 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
     return RedCost + ExtCost;
   }
 
+  InstructionCost getPartialReductionCost(unsigned Opcode, bool IsUnsigned,
+                                          VectorType *ResTy, VectorType *Ty,
+                                          FastMathFlags FMF,
+                                          TTI::TargetCostKind CostKind) {
+    return InstructionCost::getMax();
+  }
+
   InstructionCost getMulAccReductionCost(bool IsUnsigned, Type *ResTy,
                                          VectorType *Ty,
                                          TTI::TargetCostKind CostKind) {
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index c175d1737e54b..122a41d4791a2 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -1167,6 +1167,12 @@ InstructionCost TargetTransformInfo::getExtendedReductionCost(
                                            CostKind);
 }
 
+InstructionCost TargetTransformInfo::getPartialReductionCost(
+  unsigned Opcode, bool IsUnsigned, VectorType *ResTy, VectorType *Ty,
+  FastMathFlags FMF, TargetCostKind CostKind) const {
+  return TTIImpl->getPartialReductionCost(Opcode, IsUnsigned, ResTy, Ty, FMF, CostKind);
+}
+
 InstructionCost TargetTransformInfo::getMulAccReductionCost(
     bool IsUnsigned, Type *ResTy, VectorType *Ty,
     TTI::TargetCostKind CostKind) const {
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 0376a95699e67..ade40d2a07e69 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -8582,7 +8582,19 @@ VPRecipeBase *VPRecipeBuilder::tryToCreatePartialReduction(
     auto EC = ElementCount::getScalable(16);
     if(std::find(Range.begin(), Range.end(), EC) == Range.end())
       return nullptr;
-    return new VPPartialReductionRecipe(*Instr, make_range(Operands.begin(), Operands.end()), 4);
+
+    // Scale factor of 4 for sdot/udot.
+    unsigned Scale = 4;
+    VectorType* ResTy = ScalableVectorType::get(Instr->getType(), Scale);
+    VectorType* ValTy = ScalableVectorType::get(Instr->getType(), EC.getKnownMinValue());
+    using namespace llvm::PatternMatch;
+    bool IsUnsigned = match(Instr, m_Add(m_Mul(m_ZExt(m_Value()), m_ZExt(m_Value())), m_Value()));
+    auto RecipeCost = this->CM.TTI.getPartialReductionCost(Instr->getOpcode(), IsUnsigned, ResTy, ValTy, FastMathFlags::getFast());
+    // TODO replace with more informed cost check
+    if(RecipeCost == InstructionCost::getMax())
+      return nullptr;
+
+    return new VPPartialReductionRecipe(*Instr, make_range(Operands.begin(), Operands.end()), Scale);
   }
   return nullptr;
 }

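Note for context: the default getPartialReductionCost above returns
InstructionCost::getMax(), so the vectorizer bails out unless a target
overrides the hook. A minimal sketch of such an override, assuming
hypothetical AArch64TTIImpl plumbing (the ST subtarget member and
hasDotProd() are assumptions for illustration, not part of this patch):

  InstructionCost AArch64TTIImpl::getPartialReductionCost(
      unsigned Opcode, bool IsUnsigned, VectorType *ResTy, VectorType *Ty,
      FastMathFlags FMF, TTI::TargetCostKind CostKind) {
    // Only integer add reductions are candidates for udot/sdot.
    if (Opcode != Instruction::Add)
      return InstructionCost::getMax();
    // udot/sdot accumulate i8 products into i32 lanes, so the input vector
    // must have 4x the lanes of the result vector.
    if (ST->hasDotProd() && Ty->getScalarSizeInBits() == 8 &&
        ResTy->getScalarSizeInBits() == 32 &&
        Ty->getElementCount().divideCoefficientBy(4) ==
            ResTy->getElementCount())
      return 1; // One dot-product instruction per vector body iteration.
    return InstructionCost::getMax();
  }
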
>From 7e1339b748a26474fa26d97ee619312edd724d6f Mon Sep 17 00:00:00 2001
From: Nicholas Guy <nicholas.guy at arm.com>
Date: Wed, 5 Jun 2024 14:14:55 +0100
Subject: [PATCH 07/10] Add missing support for sign-extends

---
 .../lib/Transforms/Vectorize/LoopVectorize.cpp | 18 ++++++++++++++----
 1 file changed, 14 insertions(+), 4 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index ade40d2a07e69..d94b882ee9827 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -2203,6 +2203,9 @@ static void getPartialReductionInstrChain(Instruction *Instr, SmallVector<Value*
 }
 
 
+/// Checks if the given instruction is the root of a partial reduction chain
+///
+/// Note: This currently only supports udot/sdot chains
 /// @param Instr The root instruction to scan
 static bool isInstrPartialReduction(Instruction *Instr) {
   Value *ExpectedPhi;
@@ -2212,12 +2215,12 @@ static bool isInstrPartialReduction(Instruction *Instr) {
   using namespace llvm::PatternMatch;
   auto Pattern = m_Add(
     m_OneUse(m_Mul(
-      m_OneUse(m_ZExt(
+      m_OneUse(m_ZExtOrSExt(
         m_OneUse(m_Load(
           m_GEP(
               m_Value(A),
               m_Value(InductionA)))))),
-      m_OneUse(m_ZExt(
+      m_OneUse(m_ZExtOrSExt(
         m_OneUse(m_Load(
           m_GEP(
               m_Value(B),
@@ -2236,8 +2239,13 @@ static bool isInstrPartialReduction(Instruction *Instr) {
   }
 
   Instruction *Mul = cast<Instruction>(Instr->getOperand(0));
-  Instruction *Ext0 = cast<ZExtInst>(Mul->getOperand(0));
-  Instruction *Ext1 = cast<ZExtInst>(Mul->getOperand(1));
+  Instruction *Ext0 = cast<CastInst>(Mul->getOperand(0));
+  Instruction *Ext1 = cast<CastInst>(Mul->getOperand(1));
+
+  if(Ext0->getOpcode() != Ext1->getOpcode()) {
+    LLVM_DEBUG(dbgs() << "Extends aren't of the same type, cannot create a partial reduction.\n");
+    return false;
+  }
 
   // Check that the extends extend to i32
   if(!Ext0->getType()->isIntegerTy(32) || !Ext1->getType()->isIntegerTy(32)) {
@@ -8579,6 +8587,8 @@ VPRecipeBase *VPRecipeBuilder::tryToCreatePartialReduction(
     VFRange &Range, Instruction *Instr, ArrayRef<VPValue *> Operands) {
 
   if(isInstrPartialReduction(Instr)) {
+    // Restricting this case to 16x means that, using a scale of 4, we avoid
+    // trying to generate illegal types such as <vscale x 2 x i32>
     auto EC = ElementCount::getScalable(16);
     if(std::find(Range.begin(), Range.end(), EC) == Range.end())
       return nullptr;

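To make the new m_ZExtOrSExt matching concrete, the following C++ reference
loop is the signed counterpart of the unsigned dot product tested so far; a
sketch of the kind of source that should now be recognized, since the
frontend should emit the sext/mul/add chain this pattern targets:

  #include <cstdint>

  // Both inputs are sign-extended from i8 to i32 before the multiply, so the
  // reduction chain is add(mul(sext, sext), phi) and can lower to sdot.
  int sdot_ref(const int8_t *a, const int8_t *b, int n) {
    int acc = 0;
    for (int i = 0; i < n; ++i)
      acc += (int)a[i] * (int)b[i];
    return acc;
  }
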
>From 794869488eba7a30858bb8f38f7e594cd40a487f Mon Sep 17 00:00:00 2001
From: Nicholas Guy <nicholas.guy at arm.com>
Date: Wed, 5 Jun 2024 14:27:39 +0100
Subject: [PATCH 08/10] Remove debug statements

---
 llvm/include/llvm/CodeGen/SelectionDAGNodes.h | 5 -----
 1 file changed, 5 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
index 222b82d434d0e..2f36c2e86b1c3 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
@@ -957,11 +957,6 @@ END_TWO_BYTE_PACK()
   inline const APInt &getAsAPIntVal() const;
 
   const SDValue &getOperand(unsigned Num) const {
-    if(Num >= NumOperands) {
-      dbgs() << Num << ">=" << NumOperands << "\n";
-      printr(dbgs());
-      dbgs() << "\n";
-    }
     assert(Num < NumOperands && "Invalid child # of SDNode!");
     return OperandList[Num];
   }

>From 2ce1a068c4c67c8ead8040ec02f44a06609f217d Mon Sep 17 00:00:00 2001
From: Nicholas Guy <nicholas.guy at arm.com>
Date: Fri, 14 Jun 2024 12:54:52 +0100
Subject: [PATCH 09/10] Redesign how the LoopVectorizer identifies partial
 reductions

---
 .../llvm/Analysis/TargetTransformInfo.h       |  27 +--
 .../llvm/Analysis/TargetTransformInfoImpl.h   |  14 +-
 llvm/include/llvm/CodeGen/BasicTTIImpl.h      |   7 -
 llvm/include/llvm/IR/Intrinsics.td            |   6 -
 llvm/lib/Analysis/TargetTransformInfo.cpp     |  13 +-
 .../Vectorize/LoopVectorizationPlanner.h      |  23 +++
 .../Transforms/Vectorize/LoopVectorize.cpp    | 177 +++++++++---------
 .../Transforms/Vectorize/VPRecipeBuilder.h    |   2 +-
 llvm/lib/Transforms/Vectorize/VPlan.h         |   2 +
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp |  10 +-
 ...educe.ll => partial-reduce-dot-product.ll} |  54 +++---
 .../CodeGen/AArch64/partial-reduce-sdot-ir.ll |  99 ----------
 12 files changed, 167 insertions(+), 267 deletions(-)
 rename llvm/test/CodeGen/AArch64/{partial-reduce.ll => partial-reduce-dot-product.ll} (66%)
 delete mode 100644 llvm/test/CodeGen/AArch64/partial-reduce-sdot-ir.ll

diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index ba007ffbfd742..7f69c4eea715d 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1248,6 +1248,8 @@ class TargetTransformInfo {
   /// \return if target want to issue a prefetch in address space \p AS.
   bool shouldPrefetchAddressSpace(unsigned AS) const;
 
+  bool isPartialReductionSupported(const Instruction* ReductionInstr, Type* InputType, unsigned ScaleFactor, bool IsInputASignExtended, bool IsInputBSignExtended, const Instruction* BinOp = nullptr) const;
+    
   /// \return The maximum interleave factor that any transform should try to
   /// perform for this target. This number depends on the level of parallelism
   /// and the number of execution units in the CPU.
@@ -1528,10 +1530,6 @@ class TargetTransformInfo {
       FastMathFlags FMF,
       TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput) const;
 
-  InstructionCost getPartialReductionCost(
-    unsigned Opcode, bool IsUnsigned, VectorType *ResTy, VectorType *Ty,
-    FastMathFlags FMF, TargetCostKind CostKind = TCK_RecipThroughput) const;
-
   /// \returns The cost of Intrinsic instructions. Analyses the real arguments.
   /// Three cases are handled: 1. scalar instruction 2. vector instruction
   /// 3. scalar instruction which is to be vectorized.
@@ -2016,6 +2014,11 @@ class TargetTransformInfo::Concept {
   /// \return if target want to issue a prefetch in address space \p AS.
   virtual bool shouldPrefetchAddressSpace(unsigned AS) const = 0;
 
+  virtual bool isPartialReductionSupported(const Instruction* ReductionInstr,
+      Type* InputType, unsigned ScaleFactor,
+      bool IsInputASignExtended, bool IsInputBSignExtended,
+      const Instruction* BinOp = nullptr) const = 0;
+    
   virtual unsigned getMaxInterleaveFactor(ElementCount VF) = 0;
   virtual InstructionCost getArithmeticInstrCost(
       unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind,
@@ -2097,9 +2100,6 @@ class TargetTransformInfo::Concept {
       unsigned Opcode, bool IsUnsigned, Type *ResTy, VectorType *Ty,
       FastMathFlags FMF,
       TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput) = 0;
-  virtual InstructionCost getPartialReductionCost(
-      unsigned Opcode, bool IsUnsigned, VectorType *ResTy, VectorType *Ty,
-      FastMathFlags FMF, TargetCostKind CostKind = TCK_RecipThroughput) = 0;
   virtual InstructionCost getMulAccReductionCost(
       bool IsUnsigned, Type *ResTy, VectorType *Ty,
       TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput) = 0;
@@ -2648,6 +2648,13 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
     return Impl.shouldPrefetchAddressSpace(AS);
   }
 
+  bool isPartialReductionSupported(const Instruction* ReductionInstr, Type* InputType, unsigned ScaleFactor,
+                                              bool IsInputASignExtended, bool IsInputBSignExtended,
+                                              const Instruction* BinOp = nullptr) const override
+  {
+      return Impl.isPartialReductionSupported(ReductionInstr, InputType, ScaleFactor, IsInputASignExtended, IsInputBSignExtended, BinOp);
+  }
+
   unsigned getMaxInterleaveFactor(ElementCount VF) override {
     return Impl.getMaxInterleaveFactor(VF);
   }
@@ -2783,12 +2790,6 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
     return Impl.getExtendedReductionCost(Opcode, IsUnsigned, ResTy, Ty, FMF,
                                          CostKind);
   }
-  InstructionCost getPartialReductionCost(unsigned Opcode, bool IsUnsigned,
-      VectorType *ResTy, VectorType *Ty, FastMathFlags FMF,
-      TargetCostKind CostKind = TCK_RecipThroughput) override {
-      return Impl.getPartialReductionCost(Opcode, IsUnsigned, ResTy, Ty, FMF,
-          CostKind);
-  }
   InstructionCost
   getMulAccReductionCost(bool IsUnsigned, Type *ResTy, VectorType *Ty,
                          TTI::TargetCostKind CostKind) override {
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index d96c9e1ef4aec..11773f59170a9 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -541,6 +541,13 @@ class TargetTransformInfoImplBase {
   bool enableWritePrefetching() const { return false; }
   bool shouldPrefetchAddressSpace(unsigned AS) const { return !AS; }
 
+  bool isPartialReductionSupported(const Instruction* ReductionInstr,
+      Type* InputType, unsigned ScaleFactor,
+      bool IsInputASignExtended, bool IsInputBSignExtended,
+      const Instruction* BinOp = nullptr) const {
+    return false;
+  }
+
   unsigned getMaxInterleaveFactor(ElementCount VF) const { return 1; }
 
   InstructionCost getArithmeticInstrCost(
@@ -806,13 +813,6 @@ class TargetTransformInfoImplBase {
     return 1;
   }
 
-  InstructionCost getPartialReductionCost(unsigned Opcode, bool IsUnsigned,
-                                           VectorType *ResTy, VectorType *Ty,
-                                           FastMathFlags FMF,
-                                           TTI::TargetCostKind CostKind) const {
-    return InstructionCost::getMax();
-  }
-
   InstructionCost getMulAccReductionCost(bool IsUnsigned, Type *ResTy,
                                          VectorType *Ty,
                                          TTI::TargetCostKind CostKind) const {
diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index 13fe0628cbc87..4f1dc9f991c06 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -2609,13 +2609,6 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
     return RedCost + ExtCost;
   }
 
-  InstructionCost getPartialReductionCost(unsigned Opcode, bool IsUnsigned,
-                                          VectorType *ResTy, VectorType *Ty,
-                                          FastMathFlags FMF,
-                                          TTI::TargetCostKind CostKind) {
-    return InstructionCost::getMax();
-  }
-
   InstructionCost getMulAccReductionCost(bool IsUnsigned, Type *ResTy,
                                          VectorType *Ty,
                                          TTI::TargetCostKind CostKind) {
diff --git a/llvm/include/llvm/IR/Intrinsics.td b/llvm/include/llvm/IR/Intrinsics.td
index 0ece1b81a9b0e..95dbd2854322d 100644
--- a/llvm/include/llvm/IR/Intrinsics.td
+++ b/llvm/include/llvm/IR/Intrinsics.td
@@ -2646,12 +2646,6 @@ def int_experimental_vector_partial_reduce_add : DefaultAttrsIntrinsic<[LLVMMatc
                                                                        [llvm_anyvector_ty, llvm_anyvector_ty],
                                                                        [IntrNoMem]>;
 
-//===-------------- Intrinsics to perform partial reduction ---------------===//
-
-def int_experimental_vector_partial_reduce_add : DefaultAttrsIntrinsic<[llvm_anyvector_ty],
-                                                                       [llvm_anyvector_ty],
-                                                                       [IntrNoMem]>;
-
 //===----------------- Pointer Authentication Intrinsics ------------------===//
 //
 
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 122a41d4791a2..f74f1f087e597 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -825,6 +825,13 @@ bool TargetTransformInfo::shouldPrefetchAddressSpace(unsigned AS) const {
   return TTIImpl->shouldPrefetchAddressSpace(AS);
 }
 
+bool TargetTransformInfo::isPartialReductionSupported(
+    const Instruction *ReductionInstr, Type *InputType, unsigned ScaleFactor,
+    bool IsInputASignExtended, bool IsInputBSignExtended,
+    const Instruction *BinOp) const {
+  return TTIImpl->isPartialReductionSupported(ReductionInstr, InputType, ScaleFactor, IsInputASignExtended, IsInputBSignExtended, BinOp);
+}
+
 unsigned TargetTransformInfo::getMaxInterleaveFactor(ElementCount VF) const {
   return TTIImpl->getMaxInterleaveFactor(VF);
 }
@@ -1167,12 +1174,6 @@ InstructionCost TargetTransformInfo::getExtendedReductionCost(
                                            CostKind);
 }
 
-InstructionCost TargetTransformInfo::getPartialReductionCost(
-  unsigned Opcode, bool IsUnsigned, VectorType *ResTy, VectorType *Ty,
-  FastMathFlags FMF, TargetCostKind CostKind) const {
-  return TTIImpl->getPartialReductionCost(Opcode, IsUnsigned, ResTy, Ty, FMF, CostKind);
-}
-
 InstructionCost TargetTransformInfo::getMulAccReductionCost(
     bool IsUnsigned, Type *ResTy, VectorType *Ty,
     TTI::TargetCostKind CostKind) const {
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
index 93f6d1d82e244..e02a00f2912eb 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
@@ -293,6 +293,19 @@ struct FixedScalableVFPair {
   bool hasVector() const { return FixedVF.isVector() || ScalableVF.isVector(); }
 };
 
+struct PartialReductionChain {
+  Instruction *Reduction;
+  Instruction *BinOp;
+  Instruction *ExtendA;
+  Instruction *ExtendB;
+  
+  Value *InputA;
+  Value *InputB;
+  Value *Accumulator;
+
+  unsigned ScaleFactor;
+};
+
 /// Planner drives the vectorization process after having passed
 /// Legality checks.
 class LoopVectorizationPlanner {
@@ -331,6 +344,8 @@ class LoopVectorizationPlanner {
   /// Profitable vector factors.
   SmallVector<VectorizationFactor, 8> ProfitableVFs;
 
+  SmallVector<PartialReductionChain> PartialReductionChains;
+
   /// A builder used to construct the current plan.
   VPBuilder Builder;
 
@@ -398,6 +413,10 @@ class LoopVectorizationPlanner {
   VectorizationFactor
   selectEpilogueVectorizationFactor(const ElementCount MaxVF, unsigned IC);
 
+  SmallVector<PartialReductionChain> getPartialReductionChains() const {
+    return PartialReductionChains;
+  } 
+
 protected:
   /// Build VPlans for power-of-2 VF's between \p MinVF and \p MaxVF inclusive,
   /// according to the information gathered by Legal when it checked if it is
@@ -444,6 +463,10 @@ class LoopVectorizationPlanner {
   /// Determines if we have the infrastructure to vectorize the loop and its
   /// epilogue, assuming the main loop is vectorized by \p VF.
   bool isCandidateForEpilogueVectorization(const ElementCount VF) const;
+
+  bool getInstructionsPartialReduction(Instruction* I, PartialReductionChain &Chain) const;
+
+  
 };
 
 } // namespace llvm
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index d94b882ee9827..ba7a2a1188953 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -1137,7 +1137,7 @@ class LoopVectorizationCostModel {
   calculateRegisterUsage(ArrayRef<ElementCount> VFs);
 
   /// Collect values we want to ignore in the cost model.
-  void collectValuesToIgnore();
+  void collectValuesToIgnore(LoopVectorizationPlanner *LVP);
 
   /// Collect all element types in the loop for which widening is needed.
   void collectElementTypesForWidening();
@@ -2191,73 +2191,59 @@ static bool useActiveLaneMaskForControlFlow(TailFoldingStyle Style) {
          Style == TailFoldingStyle::DataAndControlFlowWithoutRuntimeCheck;
 }
 
-static void getPartialReductionInstrChain(Instruction *Instr, SmallVector<Value*, 4> &Chain) {
-  Instruction *Mul = cast<Instruction>(Instr->getOperand(0));
-  Instruction *Ext0 = cast<ZExtInst>(Mul->getOperand(0));
-  Instruction *Ext1 = cast<ZExtInst>(Mul->getOperand(1));
+static PartialReductionChain getPartialReductionInstrChain(Instruction *Instr) {
+  Instruction *BinOp = cast<Instruction>(Instr->getOperand(0));
+  Instruction *Ext0 = cast<Instruction>(BinOp->getOperand(0));
+  Instruction *Ext1 = cast<Instruction>(BinOp->getOperand(1));
 
-  Chain.push_back(Mul);
-  Chain.push_back(Ext0);
-  Chain.push_back(Ext1);
-  Chain.push_back(Instr->getOperand(1));
+  PartialReductionChain Chain;
+  Chain.Reduction = Instr;
+  Chain.BinOp = BinOp;
+  Chain.ExtendA = Ext0;
+  Chain.ExtendB = Ext1;
+  Chain.InputA = Ext0->getOperand(0);
+  Chain.InputB = Ext1->getOperand(0);
+  Chain.Accumulator = Instr->getOperand(1);
+
+  unsigned InputSizeBits = Chain.InputA->getType()->getScalarSizeInBits();
+  unsigned ResultSizeBits = Chain.Reduction->getType()->getScalarSizeInBits();
+  Chain.ScaleFactor = ResultSizeBits / InputSizeBits; 
+  return Chain;
 }
 
 
 /// Checks if the given instruction is the root of a partial reduction chain
 ///
-/// Note: This currently only supports udot/sdot chains
 /// @param Instr The root instruction to scan
 static bool isInstrPartialReduction(Instruction *Instr) {
   Value *ExpectedPhi;
   Value *A, *B;
-  Value *InductionA, *InductionB;
 
   using namespace llvm::PatternMatch;
-  auto Pattern = m_Add(
-    m_OneUse(m_Mul(
-      m_OneUse(m_ZExtOrSExt(
-        m_OneUse(m_Load(
-          m_GEP(
-              m_Value(A),
-              m_Value(InductionA)))))),
-      m_OneUse(m_ZExtOrSExt(
-        m_OneUse(m_Load(
-          m_GEP(
-              m_Value(B),
-              m_Value(InductionB))))))
-        )), m_Value(ExpectedPhi));
+  auto Pattern = m_BinOp(
+      m_OneUse(m_BinOp(
+        m_OneUse(m_ZExtOrSExt(m_OneUse(m_Value(A)))),
+        m_OneUse(m_ZExtOrSExt(m_OneUse(m_Value(B)))))),
+      m_Value(ExpectedPhi));
 
   bool Matches = match(Instr, Pattern);
 
   if(!Matches)
     return false;
 
-  // Check that the two induction variable uses are to the same induction variable
-  if(InductionA != InductionB) {
-    LLVM_DEBUG(dbgs() << "Loop uses different induction variables for each input variable, cannot create a partial reduction.\n");
-    return false;
-  }
-
-  Instruction *Mul = cast<Instruction>(Instr->getOperand(0));
-  Instruction *Ext0 = cast<CastInst>(Mul->getOperand(0));
-  Instruction *Ext1 = cast<CastInst>(Mul->getOperand(1));
-
-  if(Ext0->getOpcode() != Ext1->getOpcode()) {
-    LLVM_DEBUG(dbgs() << "Extends aren't of the same type, cannot create a partial reduction.\n");
+  // Check that the extends extend from the same type
+  if(A->getType() != B->getType()) {
+    LLVM_DEBUG(dbgs() << "Extends don't extend from the same type, cannot create a partial reduction.\n");
     return false;
   }
 
-  // Check that the extends extend to i32
-  if(!Ext0->getType()->isIntegerTy(32) || !Ext1->getType()->isIntegerTy(32)) {
-    LLVM_DEBUG(dbgs() << "Extends don't extend to the correct width, cannot create a partial reduction.\n");
-    return false;
-  }
+  // A and B are one-use, so the first user of each should be the respective extend 
+  Instruction *Ext0 = cast<CastInst>(*A->user_begin());
+  Instruction *Ext1 = cast<CastInst>(*B->user_begin());
 
-  // Check that the loads are loading i8
-  LoadInst *Load0 = cast<LoadInst>(Ext0->getOperand(0));
-  LoadInst *Load1 = cast<LoadInst>(Ext1->getOperand(0));
-  if(!Load0->getType()->isIntegerTy(8) || !Load1->getType()->isIntegerTy(8)) {
-    LLVM_DEBUG(dbgs() << "Loads don't load the correct width, cannot create a partial reduction\n");
+  // Check that the extends extend to the same type
+  if(Ext0->getType() != Ext1->getType()) {
+    LLVM_DEBUG(dbgs() << "Extends don't extend to the same type, cannot create a partial reduction.\n");
     return false;
   }
 
@@ -2268,23 +2254,32 @@ static bool isInstrPartialReduction(Instruction *Instr) {
     return false;
   }
 
-  // Check that the first phi value is a zero initializer
-  ConstantInt *ZeroInit = dyn_cast<ConstantInt>(PhiNode->getIncomingValue(0));
-  if(!ZeroInit || !ZeroInit->isZero()) {
-    LLVM_DEBUG(dbgs() << "First PHI value is not a constant zero, cannot create a partial reduction.\n");
-    return false;
-  }
-
   // Check that the second phi value is the instruction we're looking at
   Instruction *MaybeAdd = dyn_cast<Instruction>(PhiNode->getIncomingValue(1));
   if(!MaybeAdd || MaybeAdd != Instr) {
-    LLVM_DEBUG(dbgs() << "Second PHI value is not the root add, cannot create a partial reduction.\n");
+    LLVM_DEBUG(dbgs() << "Second PHI value is not the root binop, cannot create a partial reduction.\n");
     return false;
   }
 
   return true;
 }
 
+static bool isPartialReductionChainValid(PartialReductionChain &Chain, const TargetTransformInfo &TTI) {
+  if(Chain.Reduction->getOpcode() != Instruction::Add)
+    return false;
+
+  unsigned InputSizeBits = Chain.InputA->getType()->getScalarSizeInBits();
+  unsigned ResultSizeBits = Chain.Reduction->getType()->getScalarSizeInBits();
+
+  if(ResultSizeBits < InputSizeBits || (ResultSizeBits % InputSizeBits) != 0)
+    return false;
+  
+  bool IsASignExtended = isa<SExtInst>(Chain.ExtendA);
+  bool IsBSignExtended = isa<SExtInst>(Chain.ExtendB);
+
+  return TTI.isPartialReductionSupported(Chain.Reduction, Chain.InputA->getType(), Chain.ScaleFactor, IsASignExtended, IsBSignExtended, Chain.BinOp);
+}
+
 // Return true if \p OuterLp is an outer loop annotated with hints for explicit
 // vectorization. The loop needs to be annotated with #pragma omp simd
 // simdlen(#) or #pragma clang vectorize(enable) vectorize_width(#). If the
@@ -5091,6 +5086,16 @@ bool LoopVectorizationPlanner::isCandidateForEpilogueVectorization(
   return true;
 }
 
+bool LoopVectorizationPlanner::getInstructionsPartialReduction(Instruction *I, PartialReductionChain &Chain) const {
+  for(auto &C : PartialReductionChains) {
+    if(C.Reduction == I) {
+      Chain = C;
+      return true;
+    }
+  }
+  return false;
+}
+
 bool LoopVectorizationCostModel::isEpilogueVectorizationProfitable(
     const ElementCount VF) const {
   // FIXME: We need a much better cost-model to take different parameters such
@@ -7152,7 +7157,7 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I, ElementCount VF,
   } // end of switch.
 }
 
-void LoopVectorizationCostModel::collectValuesToIgnore() {
+void LoopVectorizationCostModel::collectValuesToIgnore(LoopVectorizationPlanner* LVP) {
   // Ignore ephemeral values.
   CodeMetrics::collectEphemeralValues(TheLoop, AC, ValuesToIgnore);
 
@@ -7209,14 +7214,10 @@ void LoopVectorizationCostModel::collectValuesToIgnore() {
   }
 
   // Ignore any values that we know will be flattened
-  for(auto Reduction : this->Legal->getReductionVars()) {
-    auto &Recurrence = Reduction.second;
-    if(isInstrPartialReduction(Recurrence.getLoopExitInstr())) {
-      SmallVector<Value*, 4> PartialReductionValues;
-      getPartialReductionInstrChain(Recurrence.getLoopExitInstr(), PartialReductionValues);
-      ValuesToIgnore.insert(PartialReductionValues.begin(), PartialReductionValues.end());
-      VecValuesToIgnore.insert(PartialReductionValues.begin(), PartialReductionValues.end());
-    }
+  for(auto Chain : LVP->getPartialReductionChains()) {
+    SmallVector<Value*> PartialReductionValues{Chain.Reduction, Chain.BinOp, Chain.ExtendA, Chain.ExtendB, Chain.Accumulator};
+    ValuesToIgnore.insert(PartialReductionValues.begin(), PartialReductionValues.end());
+    VecValuesToIgnore.insert(PartialReductionValues.begin(), PartialReductionValues.end());
   }
 }
 
@@ -7343,7 +7344,17 @@ LoopVectorizationPlanner::planInVPlanNativePath(ElementCount UserVF) {
 std::optional<VectorizationFactor>
 LoopVectorizationPlanner::plan(ElementCount UserVF, unsigned UserIC) {
   assert(OrigLoop->isInnermost() && "Inner loop expected.");
-  CM.collectValuesToIgnore();
+
+  for(auto ReductionVar : Legal->getReductionVars()) {
+    auto *ReductionExitInstr = ReductionVar.second.getLoopExitInstr();
+    if(isInstrPartialReduction(ReductionExitInstr)) {
+      auto Chain = getPartialReductionInstrChain(ReductionExitInstr);
+      if(isPartialReductionChainValid(Chain, TTI)) 
+        PartialReductionChains.push_back(Chain);
+    }
+  }
+  
+  CM.collectValuesToIgnore(this);
   CM.collectElementTypesForWidening();
 
   FixedScalableVFPair MaxFactors = CM.computeMaxVF(UserVF, UserIC);
@@ -8577,36 +8588,12 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
                                  *CI);
   }
 
-  if(auto *PartialReduce = tryToCreatePartialReduction(Range, Instr, Operands))
-    return PartialReduce;
-
   return tryToWiden(Instr, Operands, VPBB);
 }
 
 VPRecipeBase *VPRecipeBuilder::tryToCreatePartialReduction(
-    VFRange &Range, Instruction *Instr, ArrayRef<VPValue *> Operands) {
-
-  if(isInstrPartialReduction(Instr)) {
-    // Restricting this case to 16x means that, using a scale of 4, we avoid
-    // trying to generate illegal types such as <vscale x 2 x i32>
-    auto EC = ElementCount::getScalable(16);
-    if(std::find(Range.begin(), Range.end(), EC) == Range.end())
-      return nullptr;
-
-    // Scale factor of 4 for sdot/udot.
-    unsigned Scale = 4;
-    VectorType* ResTy = ScalableVectorType::get(Instr->getType(), Scale);
-    VectorType* ValTy = ScalableVectorType::get(Instr->getType(), EC.getKnownMinValue());
-    using namespace llvm::PatternMatch;
-    bool IsUnsigned = match(Instr, m_Add(m_Mul(m_ZExt(m_Value()), m_ZExt(m_Value())), m_Value()));
-    auto RecipeCost = this->CM.TTI.getPartialReductionCost(Instr->getOpcode(), IsUnsigned, ResTy, ValTy, FastMathFlags::getFast());
-    // TODO replace with more informed cost check
-    if(RecipeCost == InstructionCost::getMax())
-      return nullptr;
-
-    return new VPPartialReductionRecipe(*Instr, make_range(Operands.begin(), Operands.end()), Scale);
-  }
-  return nullptr;
+    VFRange &Range, PartialReductionChain &Chain, ArrayRef<VPValue *> Operands) {
+  return new VPPartialReductionRecipe(*Chain.Reduction, make_range(Operands.begin(), Operands.end()), Chain.ScaleFactor);
 }
 
 void LoopVectorizationPlanner::buildVPlansWithVPRecipes(ElementCount MinVF,
@@ -8795,8 +8782,14 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
           Legal->isInvariantAddressOfReduction(SI->getPointerOperand()))
         continue;
 
-      VPRecipeBase *Recipe =
-          RecipeBuilder.tryToCreateWidenRecipe(Instr, Operands, Range, VPBB);
+      VPRecipeBase *Recipe = nullptr;
+
+      PartialReductionChain Chain;
+      if(getInstructionsPartialReduction(Instr, Chain)) 
+        Recipe = RecipeBuilder.tryToCreatePartialReduction(Range, Chain, Operands);
+      
+      if (!Recipe)
+        Recipe = RecipeBuilder.tryToCreateWidenRecipe(Instr, Operands, Range, VPBB);
       if (!Recipe)
         Recipe = RecipeBuilder.handleReplication(Instr, Range);
 
diff --git a/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h b/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
index c439f221709e1..91b3a39a89959 100644
--- a/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
+++ b/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
@@ -116,7 +116,7 @@ class VPRecipeBuilder {
                                        ArrayRef<VPValue *> Operands,
                                        VFRange &Range, VPBasicBlock *VPBB);
 
-  VPRecipeBase* tryToCreatePartialReduction(VFRange &Range, Instruction* Instr, ArrayRef<VPValue*> Operands);
+  VPRecipeBase* tryToCreatePartialReduction(VFRange &Range, PartialReductionChain &Chain, ArrayRef<VPValue*> Operands);
 
   /// Set the recipe created for given ingredient.
   void setRecipe(Instruction *I, VPRecipeBase *R) {
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 19a337841be3e..a3711abdef1c1 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -964,7 +964,9 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
     FastMathFlagsTy(const FastMathFlags &FMF);
   };
 
+public:
   OperationType OpType;
+private:
 
   union {
     CmpInst::Predicate CmpPredicate;
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index a56408edad4a2..648823f532927 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -276,11 +276,10 @@ void VPPartialReductionRecipe::execute(VPTransformState &State) {
       }
 
       assert(Phi && Mul && "Phi and Mul must be set");
-      assert(isa<ScalableVectorType>(Ops[0]->getType()) && "Type must be a scalable vector");
 
-      ScalableVectorType *FullTy = cast<ScalableVectorType>(Ops[0]->getType());
+      VectorType *FullTy = cast<VectorType>(Ops[0]->getType());
       auto EC = FullTy->getElementCount();
-      Type *RetTy = ScalableVectorType::get(FullTy->getScalarType(), EC.divideCoefficientBy(Scale).getKnownMinValue());
+      Type *RetTy = VectorType::get(FullTy->getScalarType(), EC.divideCoefficientBy(Scale));
 
       Intrinsic::ID PartialIntrinsic = Intrinsic::not_intrinsic;
       switch(Opcode) {
@@ -294,10 +293,7 @@ void VPPartialReductionRecipe::execute(VPTransformState &State) {
 
       assert(PartialIntrinsic != Intrinsic::not_intrinsic);
 
-      Value *V = Builder.CreateIntrinsic(RetTy, PartialIntrinsic, Mul, nullptr, Twine("partial.reduce"));
-      V = Builder.CreateNAryOp(Opcode, {V, Phi});
-      if (auto *VecOp = dyn_cast<Instruction>(V))
-        setFlags(VecOp);
+      CallInst *V = Builder.CreateIntrinsic(RetTy, PartialIntrinsic, {Phi, Mul}, nullptr, Twine("partial.reduce"));
 
       // Use this vector value for all users of the original instruction.
       State.set(this, V, Part);
diff --git a/llvm/test/CodeGen/AArch64/partial-reduce.ll b/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll
similarity index 66%
rename from llvm/test/CodeGen/AArch64/partial-reduce.ll
rename to llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll
index 7883cfc05a13b..2907bf903c031 100644
--- a/llvm/test/CodeGen/AArch64/partial-reduce.ll
+++ b/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll
@@ -9,56 +9,55 @@ define void @dotp(ptr %a, ptr %b) #0 {
 ; CHECK-SAME: ptr [[A:%.*]], ptr [[B:%.*]]) #[[ATTR0:[0-9]+]] {
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[TMP0:%.*]] = call i64 @llvm.vscale.i64()
-; CHECK-NEXT:    [[TMP1:%.*]] = mul i64 [[TMP0]], 16
+; CHECK-NEXT:    [[TMP1:%.*]] = mul i64 [[TMP0]], 4
 ; CHECK-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 0, [[TMP1]]
 ; CHECK-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
 ; CHECK:       vector.ph:
 ; CHECK-NEXT:    [[TMP2:%.*]] = call i64 @llvm.vscale.i64()
-; CHECK-NEXT:    [[TMP3:%.*]] = mul i64 [[TMP2]], 16
+; CHECK-NEXT:    [[TMP3:%.*]] = mul i64 [[TMP2]], 4
 ; CHECK-NEXT:    [[N_MOD_VF:%.*]] = urem i64 0, [[TMP3]]
 ; CHECK-NEXT:    [[N_VEC:%.*]] = sub i64 0, [[N_MOD_VF]]
 ; CHECK-NEXT:    [[TMP4:%.*]] = call i64 @llvm.vscale.i64()
-; CHECK-NEXT:    [[TMP5:%.*]] = mul i64 [[TMP4]], 16
+; CHECK-NEXT:    [[TMP5:%.*]] = mul i64 [[TMP4]], 4
 ; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
 ; CHECK:       vector.body:
 ; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
-; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi <vscale x 4 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP14:%.*]], [[VECTOR_BODY]] ]
-; CHECK-NEXT:    [[TMP11:%.*]] = add i64 [[INDEX]], 0
-; CHECK-NEXT:    [[TMP13:%.*]] = getelementptr i8, ptr [[A]], i64 [[TMP11]]
-; CHECK-NEXT:    [[TMP17:%.*]] = getelementptr i8, ptr [[TMP13]], i32 0
-; CHECK-NEXT:    [[WIDE_LOAD2:%.*]] = load <vscale x 16 x i8>, ptr [[TMP17]], align 1
-; CHECK-NEXT:    [[TMP19:%.*]] = zext <vscale x 16 x i8> [[WIDE_LOAD2]] to <vscale x 16 x i32>
-; CHECK-NEXT:    [[TMP21:%.*]] = getelementptr i8, ptr [[B]], i64 [[TMP11]]
-; CHECK-NEXT:    [[TMP25:%.*]] = getelementptr i8, ptr [[TMP21]], i32 0
-; CHECK-NEXT:    [[WIDE_LOAD4:%.*]] = load <vscale x 16 x i8>, ptr [[TMP25]], align 1
-; CHECK-NEXT:    [[TMP27:%.*]] = zext <vscale x 16 x i8> [[WIDE_LOAD4]] to <vscale x 16 x i32>
-; CHECK-NEXT:    [[TMP29:%.*]] = mul <vscale x 16 x i32> [[TMP27]], [[TMP19]]
-; CHECK-NEXT:    [[PARTIAL_REDUCE:%.*]] = call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 16 x i32> [[TMP29]])
-; CHECK-NEXT:    [[TMP14]] = add <vscale x 4 x i32> [[PARTIAL_REDUCE]], [[VEC_PHI]]
+; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi <vscale x 4 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[PARTIAL_REDUCE:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[TMP6:%.*]] = add i64 [[INDEX]], 0
+; CHECK-NEXT:    [[TMP7:%.*]] = getelementptr i8, ptr [[A]], i64 [[TMP6]]
+; CHECK-NEXT:    [[TMP8:%.*]] = getelementptr i8, ptr [[TMP7]], i32 0
+; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <vscale x 4 x i8>, ptr [[TMP8]], align 1
+; CHECK-NEXT:    [[TMP9:%.*]] = zext <vscale x 4 x i8> [[WIDE_LOAD]] to <vscale x 4 x i32>
+; CHECK-NEXT:    [[TMP10:%.*]] = getelementptr i8, ptr [[B]], i64 [[TMP6]]
+; CHECK-NEXT:    [[TMP11:%.*]] = getelementptr i8, ptr [[TMP10]], i32 0
+; CHECK-NEXT:    [[WIDE_LOAD1:%.*]] = load <vscale x 4 x i8>, ptr [[TMP11]], align 1
+; CHECK-NEXT:    [[TMP12:%.*]] = zext <vscale x 4 x i8> [[WIDE_LOAD1]] to <vscale x 4 x i32>
+; CHECK-NEXT:    [[TMP13:%.*]] = mul <vscale x 4 x i32> [[TMP12]], [[TMP9]]
+; CHECK-NEXT:    [[PARTIAL_REDUCE]] = add <vscale x 4 x i32> [[TMP13]], [[VEC_PHI]]
 ; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP5]]
-; CHECK-NEXT:    [[TMP32:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
-; CHECK-NEXT:    br i1 [[TMP32]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
+; CHECK-NEXT:    [[TMP14:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
+; CHECK-NEXT:    br i1 [[TMP14]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
 ; CHECK:       middle.block:
-; CHECK-NEXT:    [[TMP33:%.*]] = call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> [[TMP14]])
 ; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i64 0, [[N_VEC]]
+; CHECK-NEXT:    [[TMP15:%.*]] = call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> [[PARTIAL_REDUCE]])
 ; CHECK-NEXT:    br i1 [[CMP_N]], label [[FOR_COND_CLEANUP_LOOPEXIT:%.*]], label [[SCALAR_PH]]
 ; CHECK:       scalar.ph:
 ; CHECK-NEXT:    [[BC_RESUME_VAL:%.*]] = phi i64 [ [[N_VEC]], [[MIDDLE_BLOCK]] ], [ 0, [[ENTRY:%.*]] ]
-; CHECK-NEXT:    [[BC_MERGE_RDX:%.*]] = phi i32 [ 0, [[ENTRY]] ], [ [[TMP33]], [[MIDDLE_BLOCK]] ]
+; CHECK-NEXT:    [[BC_MERGE_RDX:%.*]] = phi i32 [ 0, [[ENTRY]] ], [ [[TMP15]], [[MIDDLE_BLOCK]] ]
 ; CHECK-NEXT:    br label [[FOR_BODY:%.*]]
 ; CHECK:       for.cond.cleanup.loopexit:
-; CHECK-NEXT:    [[ADD_LCSSA:%.*]] = phi i32 [ [[ADD:%.*]], [[FOR_BODY]] ], [ [[TMP33]], [[MIDDLE_BLOCK]] ]
-; CHECK-NEXT:    [[TMP20:%.*]] = lshr i32 [[ADD_LCSSA]], 0
+; CHECK-NEXT:    [[ADD_LCSSA:%.*]] = phi i32 [ [[ADD:%.*]], [[FOR_BODY]] ], [ [[TMP15]], [[MIDDLE_BLOCK]] ]
+; CHECK-NEXT:    [[TMP16:%.*]] = lshr i32 [[ADD_LCSSA]], 0
 ; CHECK-NEXT:    ret void
 ; CHECK:       for.body:
 ; CHECK-NEXT:    [[INDVARS_IV:%.*]] = phi i64 [ [[BC_RESUME_VAL]], [[SCALAR_PH]] ], [ [[INDVARS_IV_NEXT:%.*]], [[FOR_BODY]] ]
 ; CHECK-NEXT:    [[ACC_010:%.*]] = phi i32 [ [[BC_MERGE_RDX]], [[SCALAR_PH]] ], [ [[ADD]], [[FOR_BODY]] ]
 ; CHECK-NEXT:    [[ARRAYIDX:%.*]] = getelementptr i8, ptr [[A]], i64 [[INDVARS_IV]]
-; CHECK-NEXT:    [[TMP18:%.*]] = load i8, ptr [[ARRAYIDX]], align 1
-; CHECK-NEXT:    [[CONV:%.*]] = zext i8 [[TMP18]] to i32
+; CHECK-NEXT:    [[TMP17:%.*]] = load i8, ptr [[ARRAYIDX]], align 1
+; CHECK-NEXT:    [[CONV:%.*]] = zext i8 [[TMP17]] to i32
 ; CHECK-NEXT:    [[ARRAYIDX2:%.*]] = getelementptr i8, ptr [[B]], i64 [[INDVARS_IV]]
-; CHECK-NEXT:    [[TMP22:%.*]] = load i8, ptr [[ARRAYIDX2]], align 1
-; CHECK-NEXT:    [[CONV3:%.*]] = zext i8 [[TMP22]] to i32
+; CHECK-NEXT:    [[TMP18:%.*]] = load i8, ptr [[ARRAYIDX2]], align 1
+; CHECK-NEXT:    [[CONV3:%.*]] = zext i8 [[TMP18]] to i32
 ; CHECK-NEXT:    [[MUL:%.*]] = mul i32 [[CONV3]], [[CONV]]
 ; CHECK-NEXT:    [[ADD]] = add i32 [[MUL]], [[ACC_010]]
 ; CHECK-NEXT:    [[INDVARS_IV_NEXT]] = add i64 [[INDVARS_IV]], 1
@@ -86,9 +85,6 @@ for.body:                                         ; preds = %for.body, %entry
   %indvars.iv.next = add i64 %indvars.iv, 1
   %exitcond.not = icmp eq i64 %indvars.iv.next, 0
   br i1 %exitcond.not, label %for.cond.cleanup.loopexit, label %for.body
-
-; uselistorder directives
-  uselistorder i32 %add, { 1, 0 }
 }
 
 attributes #0 = { "target-features"="+fp-armv8,+fullfp16,+neon,+sve,+sve2,+v8a" }
diff --git a/llvm/test/CodeGen/AArch64/partial-reduce-sdot-ir.ll b/llvm/test/CodeGen/AArch64/partial-reduce-sdot-ir.ll
deleted file mode 100644
index 3519ba58b3df3..0000000000000
--- a/llvm/test/CodeGen/AArch64/partial-reduce-sdot-ir.ll
+++ /dev/null
@@ -1,99 +0,0 @@
-; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
-; RUN: opt -passes="default<O3>" -force-vector-interleave=1 -S < %s | FileCheck %s
-
-target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128"
-target triple = "aarch64-none-unknown-elf"
-
-define void @dotp(ptr %out, ptr %a, ptr %b, i64 %wide.trip.count) #0 {
-; CHECK-LABEL: define void @dotp(
-; CHECK-SAME: ptr nocapture writeonly [[OUT:%.*]], ptr nocapture readonly [[A:%.*]], ptr nocapture readonly [[B:%.*]], i64 [[WIDE_TRIP_COUNT:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
-; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[TMP0:%.*]] = add i64 [[WIDE_TRIP_COUNT]], 1
-; CHECK-NEXT:    [[TMP1:%.*]] = tail call i64 @llvm.vscale.i64()
-; CHECK-NEXT:    [[TMP2:%.*]] = shl i64 [[TMP1]], 4
-; CHECK-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[TMP0]], [[TMP2]]
-; CHECK-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[FOR_BODY_PREHEADER:%.*]], label [[VECTOR_PH:%.*]]
-; CHECK:       vector.ph:
-; CHECK-NEXT:    [[TMP3:%.*]] = tail call i64 @llvm.vscale.i64()
-; CHECK-NEXT:    [[TMP4:%.*]] = shl i64 [[TMP3]], 4
-; CHECK-NEXT:    [[N_MOD_VF:%.*]] = urem i64 [[TMP0]], [[TMP4]]
-; CHECK-NEXT:    [[N_VEC:%.*]] = sub i64 [[TMP0]], [[N_MOD_VF]]
-; CHECK-NEXT:    [[TMP5:%.*]] = tail call i64 @llvm.vscale.i64()
-; CHECK-NEXT:    [[TMP6:%.*]] = shl i64 [[TMP5]], 4
-; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
-; CHECK:       vector.body:
-; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
-; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi <vscale x 16 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP12:%.*]], [[VECTOR_BODY]] ]
-; CHECK-NEXT:    [[TMP7:%.*]] = getelementptr i8, ptr [[A]], i64 [[INDEX]]
-; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <vscale x 16 x i8>, ptr [[TMP7]], align 1
-; CHECK-NEXT:    [[TMP8:%.*]] = zext <vscale x 16 x i8> [[WIDE_LOAD]] to <vscale x 16 x i32>
-; CHECK-NEXT:    [[TMP9:%.*]] = getelementptr i8, ptr [[B]], i64 [[INDEX]]
-; CHECK-NEXT:    [[WIDE_LOAD1:%.*]] = load <vscale x 16 x i8>, ptr [[TMP9]], align 1
-; CHECK-NEXT:    [[TMP10:%.*]] = zext <vscale x 16 x i8> [[WIDE_LOAD1]] to <vscale x 16 x i32>
-; CHECK-NEXT:    [[TMP11:%.*]] = mul nuw nsw <vscale x 16 x i32> [[TMP10]], [[TMP8]]
-; CHECK-NEXT:    [[TMP12]] = add <vscale x 16 x i32> [[TMP11]], [[VEC_PHI]]
-; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP6]]
-; CHECK-NEXT:    [[TMP13:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
-; CHECK-NEXT:    br i1 [[TMP13]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
-; CHECK:       middle.block:
-; CHECK-NEXT:    [[TMP14:%.*]] = tail call i32 @llvm.vector.reduce.add.nxv16i32(<vscale x 16 x i32> [[TMP12]])
-; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i64 [[N_MOD_VF]], 0
-; CHECK-NEXT:    br i1 [[CMP_N]], label [[FOR_COND_CLEANUP_LOOPEXIT:%.*]], label [[FOR_BODY_PREHEADER]]
-; CHECK:       for.body.preheader:
-; CHECK-NEXT:    [[INDVARS_IV_PH:%.*]] = phi i64 [ 0, [[ENTRY:%.*]] ], [ [[N_VEC]], [[MIDDLE_BLOCK]] ]
-; CHECK-NEXT:    [[ACC_010_PH:%.*]] = phi i32 [ 0, [[ENTRY]] ], [ [[TMP14]], [[MIDDLE_BLOCK]] ]
-; CHECK-NEXT:    br label [[FOR_BODY:%.*]]
-; CHECK:       for.cond.cleanup.loopexit:
-; CHECK-NEXT:    [[ADD_LCSSA:%.*]] = phi i32 [ [[TMP14]], [[MIDDLE_BLOCK]] ], [ [[ADD:%.*]], [[FOR_BODY]] ]
-; CHECK-NEXT:    [[TMP15:%.*]] = trunc i32 [[ADD_LCSSA]] to i8
-; CHECK-NEXT:    store i8 [[TMP15]], ptr [[OUT]], align 1
-; CHECK-NEXT:    ret void
-; CHECK:       for.body:
-; CHECK-NEXT:    [[INDVARS_IV:%.*]] = phi i64 [ [[INDVARS_IV_NEXT:%.*]], [[FOR_BODY]] ], [ [[INDVARS_IV_PH]], [[FOR_BODY_PREHEADER]] ]
-; CHECK-NEXT:    [[ACC_010:%.*]] = phi i32 [ [[ADD]], [[FOR_BODY]] ], [ [[ACC_010_PH]], [[FOR_BODY_PREHEADER]] ]
-; CHECK-NEXT:    [[ARRAYIDX:%.*]] = getelementptr i8, ptr [[A]], i64 [[INDVARS_IV]]
-; CHECK-NEXT:    [[TMP16:%.*]] = load i8, ptr [[ARRAYIDX]], align 1
-; CHECK-NEXT:    [[CONV:%.*]] = zext i8 [[TMP16]] to i32
-; CHECK-NEXT:    [[ARRAYIDX2:%.*]] = getelementptr i8, ptr [[B]], i64 [[INDVARS_IV]]
-; CHECK-NEXT:    [[TMP17:%.*]] = load i8, ptr [[ARRAYIDX2]], align 1
-; CHECK-NEXT:    [[CONV3:%.*]] = zext i8 [[TMP17]] to i32
-; CHECK-NEXT:    [[MUL:%.*]] = mul nuw nsw i32 [[CONV3]], [[CONV]]
-; CHECK-NEXT:    [[ADD]] = add i32 [[MUL]], [[ACC_010]]
-; CHECK-NEXT:    [[INDVARS_IV_NEXT]] = add i64 [[INDVARS_IV]], 1
-; CHECK-NEXT:    [[EXITCOND_NOT:%.*]] = icmp eq i64 [[INDVARS_IV]], [[WIDE_TRIP_COUNT]]
-; CHECK-NEXT:    br i1 [[EXITCOND_NOT]], label [[FOR_COND_CLEANUP_LOOPEXIT]], label [[FOR_BODY]], !llvm.loop [[LOOP3:![0-9]+]]
-;
-entry:
-  br label %for.body
-
-for.cond.cleanup.loopexit:                        ; preds = %for.body
-  %0 = trunc i32 %add to i8
-  store i8 %0, ptr %out, align 1
-  ret void
-
-for.body:                                         ; preds = %for.body, %entry
-  %indvars.iv = phi i64 [ 0, %entry ], [ %indvars.iv.next, %for.body ]
-  %acc.010 = phi i32 [ 0, %entry ], [ %add, %for.body ]
-  %arrayidx = getelementptr i8, ptr %a, i64 %indvars.iv
-  %1 = load i8, ptr %arrayidx, align 1
-  %conv = zext i8 %1 to i32
-  %arrayidx2 = getelementptr i8, ptr %b, i64 %indvars.iv
-  %2 = load i8, ptr %arrayidx2, align 1
-  %conv3 = zext i8 %2 to i32
-  %mul = mul i32 %conv3, %conv
-  %add = add i32 %mul, %acc.010
-  %indvars.iv.next = add i64 %indvars.iv, 1
-  %exitcond.not = icmp eq i64 %indvars.iv, %wide.trip.count
-  br i1 %exitcond.not, label %for.cond.cleanup.loopexit, label %for.body
-
-; uselistorder directives
-  uselistorder i32 %add, { 1, 0 }
-}
-
-attributes #0 = { "target-features"="+fp-armv8,+fullfp16,+neon,+sve,+sve2,+v8a" }
-;.
-; CHECK: [[LOOP0]] = distinct !{[[LOOP0]], [[META1:![0-9]+]], [[META2:![0-9]+]]}
-; CHECK: [[META1]] = !{!"llvm.loop.isvectorized", i32 1}
-; CHECK: [[META2]] = !{!"llvm.loop.unroll.runtime.disable"}
-; CHECK: [[LOOP3]] = distinct !{[[LOOP3]], [[META2]], [[META1]]}
-;.

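Under the redesigned scheme the target decides legality up front via
isPartialReductionSupported, instead of a cost query after recipe creation.
A sketch of what a target override might look like, reusing the same
hypothetical AArch64TTIImpl names as in the earlier note (again an
assumption for illustration, not part of this patch):

  bool AArch64TTIImpl::isPartialReductionSupported(
      const Instruction *ReductionInstr, Type *InputType, unsigned ScaleFactor,
      bool IsInputASignExtended, bool IsInputBSignExtended,
      const Instruction *BinOp) const {
    // udot/sdot implement an add reduction of i8 products at a 4x scale.
    if (ReductionInstr->getOpcode() != Instruction::Add || ScaleFactor != 4)
      return false;
    // Mixed-sign extends have no single dot-product instruction.
    if (IsInputASignExtended != IsInputBSignExtended)
      return false;
    return ST->hasDotProd() && InputType->isIntegerTy(8) && BinOp &&
           BinOp->getOpcode() == Instruction::Mul;
  }
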
>From c145094a508f525a535d29c9e387b4b5843612ee Mon Sep 17 00:00:00 2001
From: Nicholas Guy <nicholas.guy at arm.com>
Date: Thu, 4 Jul 2024 16:07:43 +0100
Subject: [PATCH 10/10] Format

---
 .../llvm/Analysis/TargetTransformInfo.h       |  28 +++--
 .../llvm/Analysis/TargetTransformInfoImpl.h   |   9 +-
 llvm/lib/Analysis/TargetTransformInfo.cpp     |   4 +-
 .../Vectorize/LoopVectorizationPlanner.h      |   9 +-
 .../Transforms/Vectorize/LoopVectorize.cpp    | 110 +++++++++++-------
 .../Transforms/Vectorize/VPRecipeBuilder.h    |   4 +-
 llvm/lib/Transforms/Vectorize/VPlan.h         |  25 ++--
 .../Transforms/Vectorize/VPlanAnalysis.cpp    |   6 +-
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp |  33 +++---
 9 files changed, 129 insertions(+), 99 deletions(-)

diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 7f69c4eea715d..3990aef7f5e56 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1248,8 +1248,12 @@ class TargetTransformInfo {
   /// \return if target want to issue a prefetch in address space \p AS.
   bool shouldPrefetchAddressSpace(unsigned AS) const;
 
-  bool isPartialReductionSupported(const Instruction* ReductionInstr, Type* InputType, unsigned ScaleFactor, bool IsInputASignExtended, bool IsInputBSignExtended, const Instruction* BinOp = nullptr) const;
-    
+  bool isPartialReductionSupported(const Instruction *ReductionInstr,
+                                   Type *InputType, unsigned ScaleFactor,
+                                   bool IsInputASignExtended,
+                                   bool IsInputBSignExtended,
+                                   const Instruction *BinOp = nullptr) const;
+
   /// \return The maximum interleave factor that any transform should try to
   /// perform for this target. This number depends on the level of parallelism
   /// and the number of execution units in the CPU.
@@ -2014,11 +2018,11 @@ class TargetTransformInfo::Concept {
   /// \return if target want to issue a prefetch in address space \p AS.
   virtual bool shouldPrefetchAddressSpace(unsigned AS) const = 0;
 
-  virtual bool isPartialReductionSupported(const Instruction* ReductionInstr,
-      Type* InputType, unsigned ScaleFactor,
+  virtual bool isPartialReductionSupported(
+      const Instruction *ReductionInstr, Type *InputType, unsigned ScaleFactor,
       bool IsInputASignExtended, bool IsInputBSignExtended,
-      const Instruction* BinOp = nullptr) const = 0;
-    
+      const Instruction *BinOp = nullptr) const = 0;
+
   virtual unsigned getMaxInterleaveFactor(ElementCount VF) = 0;
   virtual InstructionCost getArithmeticInstrCost(
       unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind,
@@ -2648,11 +2652,13 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
     return Impl.shouldPrefetchAddressSpace(AS);
   }
 
-  bool isPartialReductionSupported(const Instruction* ReductionInstr, Type* InputType, unsigned ScaleFactor,
-                                              bool IsInputASignExtended, bool IsInputBSignExtended,
-                                              const Instruction* BinOp = nullptr) const override
-  {
-      return Impl.isPartialReductionSupported(ReductionInstr, InputType, ScaleFactor, IsInputASignExtended, IsInputBSignExtended, BinOp);
+  bool isPartialReductionSupported(
+      const Instruction *ReductionInstr, Type *InputType, unsigned ScaleFactor,
+      bool IsInputASignExtended, bool IsInputBSignExtended,
+      const Instruction *BinOp = nullptr) const override {
+    return Impl.isPartialReductionSupported(ReductionInstr, InputType,
+                                            ScaleFactor, IsInputASignExtended,
+                                            IsInputBSignExtended, BinOp);
   }
 
   unsigned getMaxInterleaveFactor(ElementCount VF) override {
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 11773f59170a9..5aa8e17fea94f 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -541,10 +541,11 @@ class TargetTransformInfoImplBase {
   bool enableWritePrefetching() const { return false; }
   bool shouldPrefetchAddressSpace(unsigned AS) const { return !AS; }
 
-  bool isPartialReductionSupported(const Instruction* ReductionInstr,
-      Type* InputType, unsigned ScaleFactor,
-      bool IsInputASignExtended, bool IsInputBSignExtended,
-      const Instruction* BinOp = nullptr) const {
+  bool isPartialReductionSupported(const Instruction *ReductionInstr,
+                                   Type *InputType, unsigned ScaleFactor,
+                                   bool IsInputASignExtended,
+                                   bool IsInputBSignExtended,
+                                   const Instruction *BinOp = nullptr) const {
     return false;
   }
 
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index f74f1f087e597..75784bb19e4c4 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -829,7 +829,9 @@ bool TargetTransformInfo::isPartialReductionSupported(
     const Instruction *ReductionInstr, Type *InputType, unsigned ScaleFactor,
     bool IsInputASignExtended, bool IsInputBSignExtended,
     const Instruction *BinOp) const {
-  return TTIImpl->isPartialReductionSupported(ReductionInstr, InputType, ScaleFactor, IsInputASignExtended, IsInputBSignExtended, BinOp);
+  return TTIImpl->isPartialReductionSupported(ReductionInstr, InputType,
+                                              ScaleFactor, IsInputASignExtended,
+                                              IsInputBSignExtended, BinOp);
 }
 
 unsigned TargetTransformInfo::getMaxInterleaveFactor(ElementCount VF) const {
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
index e02a00f2912eb..01f6318e13a3a 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
@@ -298,7 +298,7 @@ struct PartialReductionChain {
   Instruction *BinOp;
   Instruction *ExtendA;
   Instruction *ExtendB;
-  
+
   Value *InputA;
   Value *InputB;
   Value *Accumulator;
@@ -415,7 +415,7 @@ class LoopVectorizationPlanner {
 
   SmallVector<PartialReductionChain> getPartialReductionChains() const {
     return PartialReductionChains;
-  } 
+  }
 
 protected:
   /// Build VPlans for power-of-2 VF's between \p MinVF and \p MaxVF inclusive,
@@ -464,9 +464,8 @@ class LoopVectorizationPlanner {
   /// epilogue, assuming the main loop is vectorized by \p VF.
   bool isCandidateForEpilogueVectorization(const ElementCount VF) const;
 
-  bool getInstructionsPartialReduction(Instruction* I, PartialReductionChain &Chain) const;
-
-  
+  bool getInstructionsPartialReduction(Instruction *I,
+                                       PartialReductionChain &Chain) const;
 };
 
 } // namespace llvm
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index ba7a2a1188953..b198d20775dfa 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -2207,11 +2207,10 @@ static PartialReductionChain getPartialReductionInstrChain(Instruction *Instr) {
 
   unsigned InputSizeBits = Chain.InputA->getType()->getScalarSizeInBits();
   unsigned ResultSizeBits = Chain.Reduction->getType()->getScalarSizeInBits();
-  Chain.ScaleFactor = ResultSizeBits / InputSizeBits; 
+  Chain.ScaleFactor = ResultSizeBits / InputSizeBits;
   return Chain;
 }
 
-
 /// Checks if the given instruction is the root of a partial reduction chain
 ///
 /// @param Instr The root instruction to scan
@@ -2220,64 +2219,71 @@ static bool isInstrPartialReduction(Instruction *Instr) {
   Value *A, *B;
 
   using namespace llvm::PatternMatch;
-  auto Pattern = m_BinOp(
-      m_OneUse(m_BinOp(
-        m_OneUse(m_ZExtOrSExt(m_OneUse(m_Value(A)))),
-        m_OneUse(m_ZExtOrSExt(m_OneUse(m_Value(B)))))),
-      m_Value(ExpectedPhi));
+  auto Pattern =
+      m_BinOp(m_OneUse(m_BinOp(m_OneUse(m_ZExtOrSExt(m_OneUse(m_Value(A)))),
+                               m_OneUse(m_ZExtOrSExt(m_OneUse(m_Value(B)))))),
+              m_Value(ExpectedPhi));
 
   bool Matches = match(Instr, Pattern);
 
-  if(!Matches)
+  if (!Matches)
     return false;
 
   // Check that the extends extend from the same type
-  if(A->getType() != B->getType()) {
-    LLVM_DEBUG(dbgs() << "Extends don't extend from the same type, cannot create a partial reduction.\n");
+  if (A->getType() != B->getType()) {
+    LLVM_DEBUG(dbgs() << "Extends don't extend from the same type, cannot "
+                         "create a partial reduction.\n");
     return false;
   }
 
-  // A and B are one-use, so the first user of each should be the respective extend 
+  // A and B are one-use, so the first user of each should be the respective
+  // extend
   Instruction *Ext0 = cast<CastInst>(*A->user_begin());
   Instruction *Ext1 = cast<CastInst>(*B->user_begin());
 
   // Check that the extends extend to the same type
-  if(Ext0->getType() != Ext1->getType()) {
-    LLVM_DEBUG(dbgs() << "Extends don't extend to the same type, cannot create a partial reduction.\n");
+  if (Ext0->getType() != Ext1->getType()) {
+    LLVM_DEBUG(dbgs() << "Extends don't extend to the same type, cannot create "
+                         "a partial reduction.\n");
     return false;
   }
 
   // Check that the add feeds into ExpectedPhi
   PHINode *PhiNode = dyn_cast<PHINode>(ExpectedPhi);
-  if(!PhiNode) {
-    LLVM_DEBUG(dbgs() << "Expected Phi node was not a phi, cannot create a partial reduction.\n");
+  if (!PhiNode) {
+    LLVM_DEBUG(dbgs() << "Expected Phi node was not a phi, cannot create a "
+                         "partial reduction.\n");
     return false;
   }
 
   // Check that the second phi value is the instruction we're looking at
   Instruction *MaybeAdd = dyn_cast<Instruction>(PhiNode->getIncomingValue(1));
-  if(!MaybeAdd || MaybeAdd != Instr) {
-    LLVM_DEBUG(dbgs() << "Second PHI value is not the root binop, cannot create a partial reduction.\n");
+  if (!MaybeAdd || MaybeAdd != Instr) {
+    LLVM_DEBUG(dbgs() << "Second PHI value is not the root binop, cannot "
+                         "create a partial reduction.\n");
     return false;
   }
 
   return true;
 }
 
-static bool isPartialReductionChainValid(PartialReductionChain &Chain, const TargetTransformInfo &TTI) {
-  if(Chain.Reduction->getOpcode() != Instruction::Add)
+static bool isPartialReductionChainValid(PartialReductionChain &Chain,
+                                         const TargetTransformInfo &TTI) {
+  if (Chain.Reduction->getOpcode() != Instruction::Add)
     return false;
 
   unsigned InputSizeBits = Chain.InputA->getType()->getScalarSizeInBits();
   unsigned ResultSizeBits = Chain.Reduction->getType()->getScalarSizeInBits();
 
-  if(ResultSizeBits < InputSizeBits || (ResultSizeBits % InputSizeBits) != 0)
+  if (ResultSizeBits < InputSizeBits || (ResultSizeBits % InputSizeBits) != 0)
     return false;
-  
+
   bool IsASignExtended = isa<SExtInst>(Chain.ExtendA);
   bool IsBSignExtended = isa<SExtInst>(Chain.ExtendB);
 
-  return TTI.isPartialReductionSupported(Chain.Reduction, Chain.InputA->getType(), Chain.ScaleFactor, IsASignExtended, IsBSignExtended, Chain.BinOp);
+  return TTI.isPartialReductionSupported(
+      Chain.Reduction, Chain.InputA->getType(), Chain.ScaleFactor,
+      IsASignExtended, IsBSignExtended, Chain.BinOp);
 }
 
 // Return true if \p OuterLp is an outer loop annotated with hints for explicit
@@ -5072,9 +5078,11 @@ bool LoopVectorizationPlanner::isCandidateForEpilogueVectorization(
 
   // Prevent epilogue vectorization if a partial reduction is involved
   // TODO Is there a cleaner way to check this?
-  if(any_of(Legal->getReductionVars(), [&](const std::pair<PHINode *, RecurrenceDescriptor> &Reduction) {
-    return isInstrPartialReduction(Reduction.second.getLoopExitInstr());
-  }))
+  if (any_of(Legal->getReductionVars(),
+             [&](const std::pair<PHINode *, RecurrenceDescriptor> &Reduction) {
+               return isInstrPartialReduction(
+                   Reduction.second.getLoopExitInstr());
+             }))
     return false;
 
   // Epilogue vectorization code has not been audited to ensure it handles
@@ -5086,9 +5094,10 @@ bool LoopVectorizationPlanner::isCandidateForEpilogueVectorization(
   return true;
 }
 
-bool LoopVectorizationPlanner::getInstructionsPartialReduction(Instruction *I, PartialReductionChain &Chain) const {
-  for(auto &C : PartialReductionChains) {
-    if(C.Reduction == I) {
+bool LoopVectorizationPlanner::getInstructionsPartialReduction(
+    Instruction *I, PartialReductionChain &Chain) const {
+  for (auto &C : PartialReductionChains) {
+    if (C.Reduction == I) {
       Chain = C;
       return true;
     }
@@ -7157,7 +7166,8 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I, ElementCount VF,
   } // end of switch.
 }
 
-void LoopVectorizationCostModel::collectValuesToIgnore(LoopVectorizationPlanner* LVP) {
+void LoopVectorizationCostModel::collectValuesToIgnore(
+    LoopVectorizationPlanner *LVP) {
   // Ignore ephemeral values.
   CodeMetrics::collectEphemeralValues(TheLoop, AC, ValuesToIgnore);
 
@@ -7214,10 +7224,14 @@ void LoopVectorizationCostModel::collectValuesToIgnore(LoopVectorizationPlanner*
   }
 
   // Ignore any values that we know will be flattened
-  for(auto Chain : LVP->getPartialReductionChains()) {
-    SmallVector<Value*> PartialReductionValues{Chain.Reduction, Chain.BinOp, Chain.ExtendA, Chain.ExtendB, Chain.Accumulator};
-    ValuesToIgnore.insert(PartialReductionValues.begin(), PartialReductionValues.end());
-    VecValuesToIgnore.insert(PartialReductionValues.begin(), PartialReductionValues.end());
+  for (auto Chain : LVP->getPartialReductionChains()) {
+    SmallVector<Value *> PartialReductionValues{Chain.Reduction, Chain.BinOp,
+                                                Chain.ExtendA, Chain.ExtendB,
+                                                Chain.Accumulator};
+    ValuesToIgnore.insert(PartialReductionValues.begin(),
+                          PartialReductionValues.end());
+    VecValuesToIgnore.insert(PartialReductionValues.begin(),
+                             PartialReductionValues.end());
   }
 }
 
@@ -7345,15 +7359,15 @@ std::optional<VectorizationFactor>
 LoopVectorizationPlanner::plan(ElementCount UserVF, unsigned UserIC) {
   assert(OrigLoop->isInnermost() && "Inner loop expected.");
 
-  for(auto ReductionVar : Legal->getReductionVars()) {
+  for (auto ReductionVar : Legal->getReductionVars()) {
     auto *ReductionExitInstr = ReductionVar.second.getLoopExitInstr();
-    if(isInstrPartialReduction(ReductionExitInstr)) {
+    if (isInstrPartialReduction(ReductionExitInstr)) {
       auto Chain = getPartialReductionInstrChain(ReductionExitInstr);
-      if(isPartialReductionChainValid(Chain, TTI)) 
+      if (isPartialReductionChainValid(Chain, TTI))
         PartialReductionChains.push_back(Chain);
     }
   }
-  
+
   CM.collectValuesToIgnore(this);
   CM.collectElementTypesForWidening();
 
@@ -8591,9 +8605,13 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
   return tryToWiden(Instr, Operands, VPBB);
 }
 
-VPRecipeBase *VPRecipeBuilder::tryToCreatePartialReduction(
-    VFRange &Range, PartialReductionChain &Chain, ArrayRef<VPValue *> Operands) {
-  return new VPPartialReductionRecipe(*Chain.Reduction, make_range(Operands.begin(), Operands.end()), Chain.ScaleFactor);
+VPRecipeBase *
+VPRecipeBuilder::tryToCreatePartialReduction(VFRange &Range,
+                                             PartialReductionChain &Chain,
+                                             ArrayRef<VPValue *> Operands) {
+  return new VPPartialReductionRecipe(
+      *Chain.Reduction, make_range(Operands.begin(), Operands.end()),
+      Chain.ScaleFactor);
 }
 
 void LoopVectorizationPlanner::buildVPlansWithVPRecipes(ElementCount MinVF,
@@ -8785,11 +8803,13 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
       VPRecipeBase *Recipe = nullptr;
 
       PartialReductionChain Chain;
-      if(getInstructionsPartialReduction(Instr, Chain)) 
-        Recipe = RecipeBuilder.tryToCreatePartialReduction(Range, Chain, Operands);
-      
+      if (getInstructionsPartialReduction(Instr, Chain))
+        Recipe =
+            RecipeBuilder.tryToCreatePartialReduction(Range, Chain, Operands);
+
       if (!Recipe)
-        Recipe = RecipeBuilder.tryToCreateWidenRecipe(Instr, Operands, Range, VPBB);
+        Recipe =
+            RecipeBuilder.tryToCreateWidenRecipe(Instr, Operands, Range, VPBB);
       if (!Recipe)
         Recipe = RecipeBuilder.handleReplication(Instr, Range);
 
@@ -8811,7 +8831,7 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
         VPBB->appendRecipe(Recipe);
     }
 
-    for(auto &Recipe : *VPBB)
+    for (auto &Recipe : *VPBB)
       Recipe.postInsertionOp();
 
     VPBlockUtils::insertBlockAfter(new VPBasicBlock(), VPBB);
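
To make the matched shape concrete, here is a minimal sketch (hypothetical IR, not taken from the patch's tests) of the scalar chain that isInstrPartialReduction looks for: two one-use extends from the same source type feeding a one-use binop, which in turn feeds the reduction add whose other operand is the loop-carried phi:

  loop:
    %acc = phi i32 [ 0, %entry ], [ %add, %loop ]  ; the expected phi / accumulator
    %a = load i8, ptr %gep.a, align 1
    %b = load i8, ptr %gep.b, align 1
    %a.ext = zext i8 %a to i32                     ; Chain.ExtendA
    %b.ext = zext i8 %b to i32                     ; Chain.ExtendB
    %mul = mul i32 %b.ext, %a.ext                  ; Chain.BinOp
    %add = add i32 %mul, %acc                      ; Chain.Reduction (the root)

For a chain like this, getPartialReductionInstrChain computes ScaleFactor = 32 / 8 = 4, and isPartialReductionChainValid then defers to TTI.isPartialReductionSupported to decide whether the target can profitably handle it.
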
diff --git a/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h b/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
index 91b3a39a89959..f79da9b53441b 100644
--- a/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
+++ b/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
@@ -116,7 +116,9 @@ class VPRecipeBuilder {
                                        ArrayRef<VPValue *> Operands,
                                        VFRange &Range, VPBasicBlock *VPBB);
 
-  VPRecipeBase* tryToCreatePartialReduction(VFRange &Range, PartialReductionChain &Chain, ArrayRef<VPValue*> Operands);
+  VPRecipeBase *tryToCreatePartialReduction(VFRange &Range,
+                                            PartialReductionChain &Chain,
+                                            ArrayRef<VPValue *> Operands);
 
   /// Set the recipe created for given ingredient.
   void setRecipe(Instruction *I, VPRecipeBase *R) {
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index a3711abdef1c1..97469a6d30d5a 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -964,9 +964,7 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
     FastMathFlagsTy(const FastMathFlags &FMF);
   };
 
-public:
   OperationType OpType;
-private:
 
   union {
     CmpInst::Predicate CmpPredicate;
@@ -1923,7 +1921,6 @@ class VPReductionPHIRecipe : public VPHeaderPHIRecipe {
   unsigned VFScaleFactor = 1;
 
 public:
-
   /// Create a new VPReductionPHIRecipe for the reduction \p Phi described by \p
   /// RdxDesc.
   VPReductionPHIRecipe(PHINode *Phi, const RecurrenceDescriptor &RdxDesc,
@@ -1938,9 +1935,9 @@ class VPReductionPHIRecipe : public VPHeaderPHIRecipe {
   ~VPReductionPHIRecipe() override = default;
 
   VPReductionPHIRecipe *clone() override {
-    auto *R =
-        new VPReductionPHIRecipe(cast<PHINode>(getUnderlyingInstr()), RdxDesc,
-                                 *getOperand(0), IsInLoop, IsOrdered, VFScaleFactor);
+    auto *R = new VPReductionPHIRecipe(cast<PHINode>(getUnderlyingInstr()),
+                                       RdxDesc, *getOperand(0), IsInLoop,
+                                       IsOrdered, VFScaleFactor);
     R->addOperand(getBackedgeValue());
     return R;
   }
@@ -1951,9 +1948,7 @@ class VPReductionPHIRecipe : public VPHeaderPHIRecipe {
     return R->getVPDefID() == VPDef::VPReductionPHISC;
   }
 
-  void SetVFScaleFactor(unsigned ScaleFactor) {
-    VFScaleFactor = ScaleFactor;
-  }
+  void SetVFScaleFactor(unsigned ScaleFactor) { VFScaleFactor = ScaleFactor; }
 
   /// Generate the phi/select nodes.
   void execute(VPTransformState &State) override;
@@ -1978,15 +1973,17 @@ class VPReductionPHIRecipe : public VPHeaderPHIRecipe {
 class VPPartialReductionRecipe : public VPRecipeWithIRFlags {
   unsigned Opcode;
   unsigned Scale;
+
 public:
   template <typename IterT>
-  VPPartialReductionRecipe(Instruction &I,
-                           iterator_range<IterT> Operands, unsigned Scale) : VPRecipeWithIRFlags(
-    VPDef::VPPartialReductionSC, Operands, I), Opcode(I.getOpcode()), Scale(Scale)
-  {}
+  VPPartialReductionRecipe(Instruction &I, iterator_range<IterT> Operands,
+                           unsigned Scale)
+      : VPRecipeWithIRFlags(VPDef::VPPartialReductionSC, Operands, I),
+        Opcode(I.getOpcode()), Scale(Scale) {}
   ~VPPartialReductionRecipe() override = default;
   VPPartialReductionRecipe *clone() override {
-    auto *R = new VPPartialReductionRecipe(*getUnderlyingInstr(), operands(), Scale);
+    auto *R =
+        new VPPartialReductionRecipe(*getUnderlyingInstr(), operands(), Scale);
     R->transferFlags(*this);
     return R;
   }
diff --git a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
index f48baff51f290..50ad5dd57268d 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
@@ -223,7 +223,8 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPReplicateRecipe *R) {
   llvm_unreachable("Unhandled opcode");
 }
 
-Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPPartialReductionRecipe *R) {
+Type *
+VPTypeAnalysis::inferScalarTypeForRecipe(const VPPartialReductionRecipe *R) {
   return R->getUnderlyingInstr()->getType();
 }
 
@@ -257,7 +258,8 @@ Type *VPTypeAnalysis::inferScalarType(const VPValue *V) {
             return inferScalarType(R->getOperand(0));
           })
           .Case<VPBlendRecipe, VPInstruction, VPWidenRecipe, VPReplicateRecipe,
-                VPWidenCallRecipe, VPWidenMemoryRecipe, VPWidenSelectRecipe, VPPartialReductionRecipe>(
+                VPWidenCallRecipe, VPWidenMemoryRecipe, VPWidenSelectRecipe,
+                VPPartialReductionRecipe>(
               [this](const auto *R) { return inferScalarTypeForRecipe(R); })
           .Case<VPInterleaveRecipe>([V](const VPInterleaveRecipe *R) {
             // TODO: Use info from interleave group.
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 648823f532927..497b5b2a4a4ad 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -259,17 +259,17 @@ void VPPartialReductionRecipe::execute(VPTransformState &State) {
   State.setDebugLocFrom(getDebugLoc());
   auto &Builder = State.Builder;
 
-  switch(Opcode) {
+  switch (Opcode) {
   case Instruction::Add: {
 
     for (unsigned Part = 0; Part < State.UF; ++Part) {
-      Value* Mul = nullptr;
-      Value* Phi = nullptr;
-      SmallVector<Value*, 2> Ops;
+      Value *Mul = nullptr;
+      Value *Phi = nullptr;
+      SmallVector<Value *, 2> Ops;
       for (VPValue *VPOp : operands()) {
         auto *Op = State.get(VPOp, Part);
         Ops.push_back(Op);
-        if(isa<PHINode>(Op))
+        if (isa<PHINode>(Op))
           Phi = Op;
         else
           Mul = Op;
@@ -279,13 +279,13 @@ void VPPartialReductionRecipe::execute(VPTransformState &State) {
 
       VectorType *FullTy = cast<VectorType>(Ops[0]->getType());
       auto EC = FullTy->getElementCount();
-      Type *RetTy = VectorType::get(FullTy->getScalarType(), EC.divideCoefficientBy(Scale));
+      Type *RetTy = VectorType::get(FullTy->getScalarType(),
+                                    EC.divideCoefficientBy(Scale));
 
       Intrinsic::ID PartialIntrinsic = Intrinsic::not_intrinsic;
-      switch(Opcode) {
+      switch (Opcode) {
       case Instruction::Add:
-        PartialIntrinsic =
-            Intrinsic::experimental_vector_partial_reduce_add;
+        PartialIntrinsic = Intrinsic::experimental_vector_partial_reduce_add;
         break;
       default:
         llvm_unreachable("Opcode not handled");
@@ -293,7 +293,8 @@ void VPPartialReductionRecipe::execute(VPTransformState &State) {
 
       assert(PartialIntrinsic != Intrinsic::not_intrinsic);
 
-      CallInst *V = Builder.CreateIntrinsic(RetTy, PartialIntrinsic, {Phi, Mul}, nullptr, Twine("partial.reduce"));
+      CallInst *V = Builder.CreateIntrinsic(RetTy, PartialIntrinsic, {Phi, Mul},
+                                            nullptr, Twine("partial.reduce"));
 
       // Use this vector value for all users of the original instruction.
       State.set(this, V, Part);
@@ -302,7 +303,8 @@ void VPPartialReductionRecipe::execute(VPTransformState &State) {
     break;
   }
   default:
-    LLVM_DEBUG(dbgs() << "LV: Found an unhandled opcode : " << Instruction::getOpcodeName(Opcode));
+    LLVM_DEBUG(dbgs() << "LV: Found an unhandled opcode : "
+                      << Instruction::getOpcodeName(Opcode));
     llvm_unreachable("Unhandled instruction!");
   }
 }
@@ -313,7 +315,7 @@ void VPPartialReductionRecipe::postInsertionOp() {
 
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
 void VPPartialReductionRecipe::print(raw_ostream &O, const Twine &Indent,
-  VPSlotTracker &SlotTracker) const {
+                                     VPSlotTracker &SlotTracker) const {
   O << Indent << "PARTIAL-REDUCE ";
   printAsOperand(O, SlotTracker);
   O << " = " << Instruction::getOpcodeName(Opcode);
@@ -2112,8 +2114,8 @@ void VPReductionPHIRecipe::execute(VPTransformState &State) {
   // stage #1: We create a new vector PHI node with no incoming edges. We'll use
   // this value when we vectorize all of the instructions that use the PHI.
   bool ScalarPHI = VF.isScalar() || IsInLoop;
-  Type *VecTy = ScalarPHI ? StartV->getType()
-                          : VectorType::get(StartV->getType(), VF);
+  Type *VecTy =
+      ScalarPHI ? StartV->getType() : VectorType::get(StartV->getType(), VF);
 
   BasicBlock *HeaderBB = State.CFG.PrevBB;
   assert(State.CurrentVectorLoop->getHeader() == HeaderBB &&
@@ -2137,8 +2139,7 @@ void VPReductionPHIRecipe::execute(VPTransformState &State) {
     } else {
       IRBuilderBase::InsertPointGuard IPBuilder(Builder);
       Builder.SetInsertPoint(VectorPH->getTerminator());
-      StartV = Iden =
-          Builder.CreateVectorSplat(VF, StartV, "minmax.ident");
+      StartV = Iden = Builder.CreateVectorSplat(VF, StartV, "minmax.ident");
     }
   } else {
     Iden = RdxDesc.getRecurrenceIdentity(RK, VecTy->getScalarType(),

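For illustration, given the i8-to-i32 chain sketched earlier and a fixed VF of 16 (so Scale = 4), the execute() implementation above would emit a call along these lines, folding the wide multiply into a four-element accumulator (operand order and value name follow the CreateIntrinsic call; the exact mangled suffix is an assumption here):

  %partial.reduce = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(
                        <4 x i32> %acc, <16 x i32> %mul)

RetTy is derived in the recipe as the input's element type with the element count divided by Scale (16 / 4 = 4), which is what allows the accumulator to stay narrower than the multiply.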