[llvm] [LoopVectorizer] Add support for partial reductions (PR #92418)

Sam Tebbs via llvm-commits llvm-commits at lists.llvm.org
Thu Jul 18 06:03:39 PDT 2024


https://github.com/SamTebbs33 updated https://github.com/llvm/llvm-project/pull/92418

>From 0bf0a41250186d78b8f111b2ddc0f1ef9de8d1f8 Mon Sep 17 00:00:00 2001
From: Nicholas Guy <nicholas.guy at arm.com>
Date: Fri, 17 May 2024 11:15:11 +0100
Subject: [PATCH 01/16] [NFC] Test pre-commit

---
 .../CodeGen/AArch64/partial-reduce-sdot.ll    | 99 +++++++++++++++++++
 1 file changed, 99 insertions(+)
 create mode 100644 llvm/test/CodeGen/AArch64/partial-reduce-sdot.ll

diff --git a/llvm/test/CodeGen/AArch64/partial-reduce-sdot.ll b/llvm/test/CodeGen/AArch64/partial-reduce-sdot.ll
new file mode 100644
index 0000000000000..fc6e3239a1b43
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/partial-reduce-sdot.ll
@@ -0,0 +1,99 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
+; RUN: opt -passes=loop-vectorize -force-vector-interleave=1 -S < %s | FileCheck %s
+
+target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128"
+target triple = "aarch64-none-unknown-elf"
+
+define void @dotp(ptr %a, ptr %b) #0 {
+; CHECK-LABEL: define void @dotp(
+; CHECK-SAME: ptr [[A:%.*]], ptr [[B:%.*]]) #[[ATTR0:[0-9]+]] {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP1:%.*]] = mul i64 [[TMP0]], 16
+; CHECK-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 0, [[TMP1]]
+; CHECK-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
+; CHECK:       vector.ph:
+; CHECK-NEXT:    [[TMP2:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP3:%.*]] = mul i64 [[TMP2]], 16
+; CHECK-NEXT:    [[N_MOD_VF:%.*]] = urem i64 0, [[TMP3]]
+; CHECK-NEXT:    [[N_VEC:%.*]] = sub i64 0, [[N_MOD_VF]]
+; CHECK-NEXT:    [[TMP4:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP5:%.*]] = mul i64 [[TMP4]], 16
+; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
+; CHECK:       vector.body:
+; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi <vscale x 16 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP14:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[TMP11:%.*]] = add i64 [[INDEX]], 0
+; CHECK-NEXT:    [[TMP13:%.*]] = getelementptr i8, ptr [[A]], i64 [[TMP11]]
+; CHECK-NEXT:    [[TMP17:%.*]] = getelementptr i8, ptr [[TMP13]], i32 0
+; CHECK-NEXT:    [[WIDE_LOAD2:%.*]] = load <vscale x 16 x i8>, ptr [[TMP17]], align 1
+; CHECK-NEXT:    [[TMP19:%.*]] = zext <vscale x 16 x i8> [[WIDE_LOAD2]] to <vscale x 16 x i32>
+; CHECK-NEXT:    [[TMP21:%.*]] = getelementptr i8, ptr [[B]], i64 [[TMP11]]
+; CHECK-NEXT:    [[TMP25:%.*]] = getelementptr i8, ptr [[TMP21]], i32 0
+; CHECK-NEXT:    [[WIDE_LOAD4:%.*]] = load <vscale x 16 x i8>, ptr [[TMP25]], align 1
+; CHECK-NEXT:    [[TMP27:%.*]] = zext <vscale x 16 x i8> [[WIDE_LOAD4]] to <vscale x 16 x i32>
+; CHECK-NEXT:    [[TMP29:%.*]] = mul <vscale x 16 x i32> [[TMP27]], [[TMP19]]
+; CHECK-NEXT:    [[TMP14]] = add <vscale x 16 x i32> [[TMP29]], [[VEC_PHI]]
+; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP5]]
+; CHECK-NEXT:    [[TMP32:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
+; CHECK-NEXT:    br i1 [[TMP32]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
+; CHECK:       middle.block:
+; CHECK-NEXT:    [[TMP33:%.*]] = call i32 @llvm.vector.reduce.add.nxv16i32(<vscale x 16 x i32> [[TMP14]])
+; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i64 0, [[N_VEC]]
+; CHECK-NEXT:    br i1 [[CMP_N]], label [[FOR_COND_CLEANUP_LOOPEXIT:%.*]], label [[SCALAR_PH]]
+; CHECK:       scalar.ph:
+; CHECK-NEXT:    [[BC_RESUME_VAL:%.*]] = phi i64 [ [[N_VEC]], [[MIDDLE_BLOCK]] ], [ 0, [[ENTRY:%.*]] ]
+; CHECK-NEXT:    [[BC_MERGE_RDX:%.*]] = phi i32 [ 0, [[ENTRY]] ], [ [[TMP33]], [[MIDDLE_BLOCK]] ]
+; CHECK-NEXT:    br label [[FOR_BODY:%.*]]
+; CHECK:       for.cond.cleanup.loopexit:
+; CHECK-NEXT:    [[ADD_LCSSA:%.*]] = phi i32 [ [[ADD:%.*]], [[FOR_BODY]] ], [ [[TMP33]], [[MIDDLE_BLOCK]] ]
+; CHECK-NEXT:    [[TMP20:%.*]] = lshr i32 [[ADD_LCSSA]], 0
+; CHECK-NEXT:    ret void
+; CHECK:       for.body:
+; CHECK-NEXT:    [[INDVARS_IV:%.*]] = phi i64 [ [[BC_RESUME_VAL]], [[SCALAR_PH]] ], [ [[INDVARS_IV_NEXT:%.*]], [[FOR_BODY]] ]
+; CHECK-NEXT:    [[ACC_010:%.*]] = phi i32 [ [[BC_MERGE_RDX]], [[SCALAR_PH]] ], [ [[ADD]], [[FOR_BODY]] ]
+; CHECK-NEXT:    [[ARRAYIDX:%.*]] = getelementptr i8, ptr [[A]], i64 [[INDVARS_IV]]
+; CHECK-NEXT:    [[TMP18:%.*]] = load i8, ptr [[ARRAYIDX]], align 1
+; CHECK-NEXT:    [[CONV:%.*]] = zext i8 [[TMP18]] to i32
+; CHECK-NEXT:    [[ARRAYIDX2:%.*]] = getelementptr i8, ptr [[B]], i64 [[INDVARS_IV]]
+; CHECK-NEXT:    [[TMP22:%.*]] = load i8, ptr [[ARRAYIDX2]], align 1
+; CHECK-NEXT:    [[CONV3:%.*]] = zext i8 [[TMP22]] to i32
+; CHECK-NEXT:    [[MUL:%.*]] = mul i32 [[CONV3]], [[CONV]]
+; CHECK-NEXT:    [[ADD]] = add i32 [[MUL]], [[ACC_010]]
+; CHECK-NEXT:    [[INDVARS_IV_NEXT]] = add i64 [[INDVARS_IV]], 1
+; CHECK-NEXT:    [[EXITCOND_NOT:%.*]] = icmp eq i64 [[INDVARS_IV_NEXT]], 0
+; CHECK-NEXT:    br i1 [[EXITCOND_NOT]], label [[FOR_COND_CLEANUP_LOOPEXIT]], label [[FOR_BODY]], !llvm.loop [[LOOP3:![0-9]+]]
+;
+entry:
+  br label %for.body
+
+for.cond.cleanup.loopexit:                        ; preds = %for.body
+  %0 = lshr i32 %add, 0
+  ret void
+
+for.body:                                         ; preds = %for.body, %entry
+  %indvars.iv = phi i64 [ 0, %entry ], [ %indvars.iv.next, %for.body ]
+  %acc.010 = phi i32 [ 0, %entry ], [ %add, %for.body ]
+  %arrayidx = getelementptr i8, ptr %a, i64 %indvars.iv
+  %1 = load i8, ptr %arrayidx, align 1
+  %conv = zext i8 %1 to i32
+  %arrayidx2 = getelementptr i8, ptr %b, i64 %indvars.iv
+  %2 = load i8, ptr %arrayidx2, align 1
+  %conv3 = zext i8 %2 to i32
+  %mul = mul i32 %conv3, %conv
+  %add = add i32 %mul, %acc.010
+  %indvars.iv.next = add i64 %indvars.iv, 1
+  %exitcond.not = icmp eq i64 %indvars.iv.next, 0
+  br i1 %exitcond.not, label %for.cond.cleanup.loopexit, label %for.body
+
+; uselistorder directives
+  uselistorder i32 %add, { 1, 0 }
+}
+
+attributes #0 = { "target-features"="+fp-armv8,+fullfp16,+neon,+sve,+sve2,+v8a" }
+;.
+; CHECK: [[LOOP0]] = distinct !{[[LOOP0]], [[META1:![0-9]+]], [[META2:![0-9]+]]}
+; CHECK: [[META1]] = !{!"llvm.loop.isvectorized", i32 1}
+; CHECK: [[META2]] = !{!"llvm.loop.unroll.runtime.disable"}
+; CHECK: [[LOOP3]] = distinct !{[[LOOP3]], [[META2]], [[META1]]}
+;.

>From 6103cb3218c3fb4f634e40f34fddafeabf949cfd Mon Sep 17 00:00:00 2001
From: Nicholas Guy <nicholas.guy at arm.com>
Date: Fri, 17 May 2024 11:17:26 +0100
Subject: [PATCH 02/16] [LoopVectorizer] Add support for partial reductions

---
 llvm/include/llvm/IR/DerivedTypes.h           |  10 ++
 llvm/include/llvm/IR/Intrinsics.h             |   5 +-
 llvm/include/llvm/IR/Intrinsics.td            |  10 ++
 llvm/lib/IR/Function.cpp                      |  16 +++
 .../Transforms/Vectorize/LoopVectorize.cpp    | 122 ++++++++++++++++++
 .../Transforms/Vectorize/VPRecipeBuilder.h    |   2 +
 llvm/lib/Transforms/Vectorize/VPlan.h         |  43 +++++-
 .../Transforms/Vectorize/VPlanAnalysis.cpp    |   6 +-
 llvm/lib/Transforms/Vectorize/VPlanAnalysis.h |   2 +
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp |  80 +++++++++++-
 llvm/lib/Transforms/Vectorize/VPlanValue.h    |   1 +
 .../CodeGen/AArch64/partial-reduce-sdot.ll    |   7 +-
 12 files changed, 291 insertions(+), 13 deletions(-)

diff --git a/llvm/include/llvm/IR/DerivedTypes.h b/llvm/include/llvm/IR/DerivedTypes.h
index 443fb7de3b821..866a01c9afebd 100644
--- a/llvm/include/llvm/IR/DerivedTypes.h
+++ b/llvm/include/llvm/IR/DerivedTypes.h
@@ -512,6 +512,16 @@ class VectorType : public Type {
                            EltCnt.divideCoefficientBy(2));
   }
 
+  /// This static method returns a VectorType with a quarter as many elements
+  /// as the input type and the same element type.
+  static VectorType *getQuarterElementsVectorType(VectorType *VTy) {
+    auto EltCnt = VTy->getElementCount();
+    assert(EltCnt.isKnownMultipleOf(4) &&
+           "Cannot quarter vector with element count not divisible by 4.");
+    return VectorType::get(VTy->getElementType(),
+                           EltCnt.divideCoefficientBy(4));
+  }
+
   /// This static method returns a VectorType with twice as many elements as the
   /// input type and the same element type.
   static VectorType *getDoubleElementsVectorType(VectorType *VTy) {
diff --git a/llvm/include/llvm/IR/Intrinsics.h b/llvm/include/llvm/IR/Intrinsics.h
index f79df522dc805..d0c4bee59e889 100644
--- a/llvm/include/llvm/IR/Intrinsics.h
+++ b/llvm/include/llvm/IR/Intrinsics.h
@@ -135,6 +135,7 @@ namespace Intrinsic {
       ExtendArgument,
       TruncArgument,
       HalfVecArgument,
+      QuarterVecArgument,
       SameVecWidthArgument,
       VecOfAnyPtrsToElt,
       VecElementArgument,
@@ -164,7 +165,7 @@ namespace Intrinsic {
 
     unsigned getArgumentNumber() const {
       assert(Kind == Argument || Kind == ExtendArgument ||
-             Kind == TruncArgument || Kind == HalfVecArgument ||
+             Kind == TruncArgument || Kind == HalfVecArgument || Kind == QuarterVecArgument ||
              Kind == SameVecWidthArgument || Kind == VecElementArgument ||
              Kind == Subdivide2Argument || Kind == Subdivide4Argument ||
              Kind == VecOfBitcastsToInt);
@@ -172,7 +173,7 @@ namespace Intrinsic {
     }
     ArgKind getArgumentKind() const {
       assert(Kind == Argument || Kind == ExtendArgument ||
-             Kind == TruncArgument || Kind == HalfVecArgument ||
+             Kind == TruncArgument || Kind == HalfVecArgument || Kind == QuarterVecArgument ||
              Kind == SameVecWidthArgument ||
              Kind == VecElementArgument || Kind == Subdivide2Argument ||
              Kind == Subdivide4Argument || Kind == VecOfBitcastsToInt);
diff --git a/llvm/include/llvm/IR/Intrinsics.td b/llvm/include/llvm/IR/Intrinsics.td
index 95dbd2854322d..98c5221ade2c8 100644
--- a/llvm/include/llvm/IR/Intrinsics.td
+++ b/llvm/include/llvm/IR/Intrinsics.td
@@ -321,6 +321,7 @@ def IIT_I4 : IIT_Int<4, 58>;
 def IIT_AARCH64_SVCOUNT : IIT_VT<aarch64svcount, 59>;
 def IIT_V6 : IIT_Vec<6, 60>;
 def IIT_V10 : IIT_Vec<10, 61>;
+def IIT_QUARTER_VEC_ARG : IIT_Base<62>;
 }
 
 defvar IIT_all_FixedTypes = !filter(iit, IIT_all,
@@ -457,6 +458,9 @@ class LLVMVectorElementType<int num> : LLVMMatchType<num, IIT_VEC_ELEMENT>;
 class LLVMHalfElementsVectorType<int num>
   : LLVMMatchType<num, IIT_HALF_VEC_ARG>;
 
+class LLVMQuarterElementsVectorType<int num>
+  : LLVMMatchType<num, IIT_QUARTER_VEC_ARG>;
+
 // Match the type of another intrinsic parameter that is expected to be a
 // vector type (i.e. <N x iM>) but with each element subdivided to
 // form a vector with more elements that are smaller than the original.
@@ -2646,6 +2650,12 @@ def int_experimental_vector_partial_reduce_add : DefaultAttrsIntrinsic<[LLVMMatc
                                                                        [llvm_anyvector_ty, llvm_anyvector_ty],
                                                                        [IntrNoMem]>;
 
+//===-------------- Intrinsics to perform partial reduction ---------------===//
+
+def int_experimental_vector_partial_reduce_add : DefaultAttrsIntrinsic<[LLVMQuarterElementsVectorType<0>],
+                                                                       [llvm_anyvector_ty],
+                                                                       [IntrNoMem]>;
+
 //===----------------- Pointer Authentication Intrinsics ------------------===//
 //
 
diff --git a/llvm/lib/IR/Function.cpp b/llvm/lib/IR/Function.cpp
index 5fb348a8bbcd4..a1ce95c2c92bb 100644
--- a/llvm/lib/IR/Function.cpp
+++ b/llvm/lib/IR/Function.cpp
@@ -1259,6 +1259,12 @@ static void DecodeIITType(unsigned &NextElt, ArrayRef<unsigned char> Infos,
                                              ArgInfo));
     return;
   }
+  case IIT_QUARTER_VEC_ARG: {
+    unsigned ArgInfo = (NextElt == Infos.size() ? 0 : Infos[NextElt++]);
+    OutputTable.push_back(IITDescriptor::get(IITDescriptor::QuarterVecArgument,
+                                             ArgInfo));
+    return;
+  }
   case IIT_SAME_VEC_WIDTH_ARG: {
     unsigned ArgInfo = (NextElt == Infos.size() ? 0 : Infos[NextElt++]);
     OutputTable.push_back(IITDescriptor::get(IITDescriptor::SameVecWidthArgument,
@@ -1423,6 +1429,9 @@ static Type *DecodeFixedType(ArrayRef<Intrinsic::IITDescriptor> &Infos,
   case IITDescriptor::HalfVecArgument:
     return VectorType::getHalfElementsVectorType(cast<VectorType>(
                                                   Tys[D.getArgumentNumber()]));
+  case IITDescriptor::QuarterVecArgument:  {
+    return VectorType::getQuarterElementsVectorType(cast<VectorType>(Tys[D.getArgumentNumber()]));
+  }
   case IITDescriptor::SameVecWidthArgument: {
     Type *EltTy = DecodeFixedType(Infos, Tys, Context);
     Type *Ty = Tys[D.getArgumentNumber()];
@@ -1650,6 +1659,13 @@ static bool matchIntrinsicType(
       return !isa<VectorType>(ArgTys[D.getArgumentNumber()]) ||
              VectorType::getHalfElementsVectorType(
                      cast<VectorType>(ArgTys[D.getArgumentNumber()])) != Ty;
+    case IITDescriptor::QuarterVecArgument: {
+    if (D.getArgumentNumber() >= ArgTys.size())
+        return IsDeferredCheck || DeferCheck(Ty);
+      return !isa<VectorType>(ArgTys[D.getArgumentNumber()]) ||
+             VectorType::getQuarterElementsVectorType(
+                     cast<VectorType>(ArgTys[D.getArgumentNumber()])) != Ty;
+    }
     case IITDescriptor::SameVecWidthArgument: {
       if (D.getArgumentNumber() >= ArgTys.size()) {
         // Defer check and subsequent check for the vector element type.
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index d7b0240fd8a81..a184cb3ea0fb4 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -2191,6 +2191,92 @@ static bool useActiveLaneMaskForControlFlow(TailFoldingStyle Style) {
          Style == TailFoldingStyle::DataAndControlFlowWithoutRuntimeCheck;
 }
 
+/// Collect into \p Chain the values combined by the reduction rooted at
+/// \p Instr: the multiply, both of its zero-extends, and the second operand.
+static void getPartialReductionInstrChain(Instruction *Instr,
+                                          SmallVector<Value *, 4> &Chain) {
+  auto *Product = cast<Instruction>(Instr->getOperand(0));
+  auto *LHSExt = cast<ZExtInst>(Product->getOperand(0));
+  auto *RHSExt = cast<ZExtInst>(Product->getOperand(1));
+  Value *Accumulator = Instr->getOperand(1);
+  Chain.append({Product, LHSExt, RHSExt, Accumulator});
+}
+
+
+/// Return true if \p Instr roots a chain that can become a partial reduction:
+///   add(mul(zext(load(gep(A, Idx))), zext(load(gep(B, Idx)))), phi)
+/// Both loads must be i8 and use the same index, both extends must produce
+/// i32, the mul, extends and loads must each have a single use, and the phi
+/// must take a constant zero as incoming value 0 and \p Instr as incoming
+/// value 1.
+/// @param Instr The root instruction to scan
+static bool isInstrPartialReduction(Instruction *Instr) {
+  using namespace llvm::PatternMatch;
+
+  // Log why the candidate was rejected and report failure to the caller.
+  auto Reject = []([[maybe_unused]] const char *Reason) {
+    LLVM_DEBUG(dbgs() << Reason);
+    return false;
+  };
+
+  // Matches a single-use zext of a single-use load from gep(Base, Index).
+  auto ExtendedLoad = [](Value *&Base, Value *&Index) {
+    return m_OneUse(
+        m_ZExt(m_OneUse(m_Load(m_GEP(m_Value(Base), m_Value(Index))))));
+  };
+
+  Value *BaseA, *BaseB, *IndexA, *IndexB;
+  Value *Accumulator;
+  if (!match(Instr, m_Add(m_OneUse(m_Mul(ExtendedLoad(BaseA, IndexA),
+                                         ExtendedLoad(BaseB, IndexB))),
+                          m_Value(Accumulator))))
+    return false;
+
+  // Check that both loads are addressed with the same index value
+  if (IndexA != IndexB)
+    return Reject(
+        "Loop uses different induction variables for each input variable, cannot create a partial reduction.\n");
+
+  auto *Product = cast<Instruction>(Instr->getOperand(0));
+  auto *ExtA = cast<ZExtInst>(Product->getOperand(0));
+  auto *ExtB = cast<ZExtInst>(Product->getOperand(1));
+
+  // Check that the extends extend to i32
+  const bool ExtendsToI32 =
+      ExtA->getType()->isIntegerTy(32) && ExtB->getType()->isIntegerTy(32);
+  if (!ExtendsToI32)
+    return Reject(
+        "Extends don't extend to the correct width, cannot create a partial reduction.\n");
+
+  // Check that the loads are loading i8
+  auto *LoadA = cast<LoadInst>(ExtA->getOperand(0));
+  auto *LoadB = cast<LoadInst>(ExtB->getOperand(0));
+  const bool LoadsAreI8 =
+      LoadA->getType()->isIntegerTy(8) && LoadB->getType()->isIntegerTy(8);
+  if (!LoadsAreI8)
+    return Reject(
+        "Loads don't load the correct width, cannot create a partial reduction\n");
+
+  // Check that the add's second operand is a phi
+  auto *AccPhi = dyn_cast<PHINode>(Accumulator);
+  if (!AccPhi)
+    return Reject(
+        "Expected Phi node was not a phi, cannot create a partial reduction.\n");
+
+  // Check that the first phi value is a zero initializer
+  auto *Init = dyn_cast<ConstantInt>(AccPhi->getIncomingValue(0));
+  if (!Init || !Init->isZero())
+    return Reject(
+        "First PHI value is not a constant zero, cannot create a partial reduction.\n");
+
+  // Check that the second phi value is the instruction we're looking at
+  if (AccPhi->getIncomingValue(1) != Instr)
+    return Reject(
+        "Second PHI value is not the root add, cannot create a partial reduction.\n");
+
+  return true;
+}
+
 // Return true if \p OuterLp is an outer loop annotated with hints for explicit
 // vectorization. The loop needs to be annotated with #pragma omp simd
 // simdlen(#) or #pragma clang vectorize(enable) vectorize_width(#). If the
@@ -4981,6 +5067,13 @@ bool LoopVectorizationPlanner::isCandidateForEpilogueVectorization(
         return false;
   }
 
+  // Prevent epilogue vectorization if a partial reduction is involved
+  // TODO Is there a cleaner way to check this?
+  if(any_of(Legal->getReductionVars(), [&](const std::pair<PHINode *, RecurrenceDescriptor> &Reduction) {
+    return isInstrPartialReduction(Reduction.second.getLoopExitInstr());
+  }))
+    return false;
+
   // Epilogue vectorization code has not been auditted to ensure it handles
   // non-latch exits properly.  It may be fine, but it needs auditted and
   // tested.
@@ -7106,6 +7199,17 @@ void LoopVectorizationCostModel::collectValuesToIgnore() {
     const SmallVectorImpl<Instruction *> &Casts = IndDes.getCastInsts();
     VecValuesToIgnore.insert(Casts.begin(), Casts.end());
   }
+
+  // Ignore any values that we know will be flattened
+  for(auto Reduction : this->Legal->getReductionVars()) {
+    auto &Recurrence = Reduction.second;
+    if(isInstrPartialReduction(Recurrence.getLoopExitInstr())) {
+      SmallVector<Value*, 4> PartialReductionValues;
+      getPartialReductionInstrChain(Recurrence.getLoopExitInstr(), PartialReductionValues);
+      ValuesToIgnore.insert(PartialReductionValues.begin(), PartialReductionValues.end());
+      VecValuesToIgnore.insert(PartialReductionValues.begin(), PartialReductionValues.end());
+    }
+  }
 }
 
 void LoopVectorizationCostModel::collectInLoopReductions() {
@@ -8465,9 +8569,24 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
                                  *CI);
   }
 
+  if(auto *PartialReduce = tryToCreatePartialReduction(Range, Instr, Operands))
+    return PartialReduce;
+
   return tryToWiden(Instr, Operands, VPBB);
 }
 
+/// Build a VPPartialReductionRecipe for \p Instr when isInstrPartialReduction
+/// accepts it and \p Range contains the VF vscale x 16; else return nullptr.
+VPRecipeBase *VPRecipeBuilder::tryToCreatePartialReduction(
+    VFRange &Range, Instruction *Instr, ArrayRef<VPValue *> Operands) {
+  if (!isInstrPartialReduction(Instr))
+    return nullptr;
+  if (!is_contained(Range, ElementCount::getScalable(16)))
+    return nullptr;
+  return new VPPartialReductionRecipe(
+      *Instr, make_range(Operands.begin(), Operands.end()));
+}
+
 void LoopVectorizationPlanner::buildVPlansWithVPRecipes(ElementCount MinVF,
                                                         ElementCount MaxVF) {
   assert(OrigLoop->isInnermost() && "Inner loop expected.");
@@ -8677,6 +8796,9 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
         VPBB->appendRecipe(Recipe);
     }
 
+    for(auto &Recipe : *VPBB)
+      Recipe.postInsertionOp();
+
     VPBlockUtils::insertBlockAfter(new VPBasicBlock(), VPBB);
     VPBB = cast<VPBasicBlock>(VPBB->getSingleSuccessor());
   }
diff --git a/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h b/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
index b4c7ab02f928f..c439f221709e1 100644
--- a/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
+++ b/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
@@ -116,6 +116,8 @@ class VPRecipeBuilder {
                                        ArrayRef<VPValue *> Operands,
                                        VFRange &Range, VPBasicBlock *VPBB);
 
+  VPRecipeBase* tryToCreatePartialReduction(VFRange &Range, Instruction* Instr, ArrayRef<VPValue*> Operands);
+
   /// Set the recipe created for given ingredient.
   void setRecipe(Instruction *I, VPRecipeBase *R) {
     assert(!Ingredient2Recipe.contains(I) &&
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index db10c7a240c7e..feef9da7e84f8 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -784,6 +784,8 @@ class VPRecipeBase : public ilist_node_with_parent<VPRecipeBase, VPBasicBlock>,
   /// \returns an iterator pointing to the element after the erased one
   iplist<VPRecipeBase>::iterator eraseFromParent();
 
+  virtual void postInsertionOp() {} ///< Hook with an empty default body; recipes may override it.
+
   /// Method to support type inquiry through isa, cast, and dyn_cast.
   static inline bool classof(const VPDef *D) {
     // All VPDefs are also VPRecipeBases.
@@ -1915,14 +1917,19 @@ class VPReductionPHIRecipe : public VPHeaderPHIRecipe {
   /// The phi is part of an ordered reduction. Requires IsInLoop to be true.
   bool IsOrdered;
 
+  /// The amount that the VF should be divided by during ::execute
+  unsigned VFScaleFactor = 1;
+
 public:
+
   /// Create a new VPReductionPHIRecipe for the reduction \p Phi described by \p
   /// RdxDesc.
   VPReductionPHIRecipe(PHINode *Phi, const RecurrenceDescriptor &RdxDesc,
                        VPValue &Start, bool IsInLoop = false,
-                       bool IsOrdered = false)
+                       bool IsOrdered = false, unsigned VFScaleFactor = 1)
       : VPHeaderPHIRecipe(VPDef::VPReductionPHISC, Phi, &Start),
-        RdxDesc(RdxDesc), IsInLoop(IsInLoop), IsOrdered(IsOrdered) {
+        RdxDesc(RdxDesc), IsInLoop(IsInLoop), IsOrdered(IsOrdered),
+        VFScaleFactor(VFScaleFactor) {
     assert((!IsOrdered || IsInLoop) && "IsOrdered requires IsInLoop");
   }
 
@@ -1931,7 +1938,7 @@ class VPReductionPHIRecipe : public VPHeaderPHIRecipe {
   VPReductionPHIRecipe *clone() override {
     auto *R =
         new VPReductionPHIRecipe(cast<PHINode>(getUnderlyingInstr()), RdxDesc,
-                                 *getOperand(0), IsInLoop, IsOrdered);
+                                 *getOperand(0), IsInLoop, IsOrdered, VFScaleFactor);
     R->addOperand(getBackedgeValue());
     return R;
   }
@@ -1942,6 +1949,10 @@ class VPReductionPHIRecipe : public VPHeaderPHIRecipe {
     return R->getVPDefID() == VPDef::VPReductionPHISC;
   }
 
+  /// Override the amount that the VF is divided by when this recipe's
+  /// ::execute creates the phi (the factor defaults to 1).
+  void SetVFScaleFactor(unsigned ScaleFactor) { VFScaleFactor = ScaleFactor; }
+
   /// Generate the phi/select nodes.
   void execute(VPTransformState &State) override;
 
@@ -1962,6 +1973,32 @@ class VPReductionPHIRecipe : public VPHeaderPHIRecipe {
   bool isInLoop() const { return IsInLoop; }
 };
 
+/// A recipe for the root instruction of a partial reduction.
+class VPPartialReductionRecipe : public VPRecipeWithIRFlags {
+  unsigned Opcode; ///< Opcode of the instruction this recipe was created from.
+public:
+  template <typename IterT>
+  VPPartialReductionRecipe(Instruction &I, iterator_range<IterT> Operands)
+      : VPRecipeWithIRFlags(VPDef::VPPartialReductionSC, Operands, I),
+        Opcode(I.getOpcode()) {}
+  ~VPPartialReductionRecipe() override = default;
+  VPPartialReductionRecipe *clone() override {
+    auto *R = new VPPartialReductionRecipe(*getUnderlyingInstr(), operands());
+    R->transferFlags(*this);
+    return R;
+  }
+  VP_CLASSOF_IMPL(VPDef::VPPartialReductionSC)
+  /// Generate the reduction in the loop
+  void execute(VPTransformState &State) override;
+  void postInsertionOp() override; ///< Sets the VF scale factor on operand 1.
+  unsigned getOpcode() const { return Opcode; } // const: usable on const recipes
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+  /// Print the recipe.
+  void print(raw_ostream &O, const Twine &Indent,
+             VPSlotTracker &SlotTracker) const override;
+#endif
+};
+
 /// A recipe for vectorizing a phi-node as a sequence of mask-based select
 /// instructions.
 class VPBlendRecipe : public VPSingleDefRecipe {
diff --git a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
index eca5d1d4c5e1d..f48baff51f290 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
@@ -223,6 +223,10 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPReplicateRecipe *R) {
   llvm_unreachable("Unhandled opcode");
 }
 
+Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPPartialReductionRecipe *R) {
+  return R->getUnderlyingInstr()->getType(); // type of the IR instruction the recipe wraps
+}
+
 Type *VPTypeAnalysis::inferScalarType(const VPValue *V) {
   if (Type *CachedTy = CachedTypes.lookup(V))
     return CachedTy;
@@ -253,7 +257,7 @@ Type *VPTypeAnalysis::inferScalarType(const VPValue *V) {
             return inferScalarType(R->getOperand(0));
           })
           .Case<VPBlendRecipe, VPInstruction, VPWidenRecipe, VPReplicateRecipe,
-                VPWidenCallRecipe, VPWidenMemoryRecipe, VPWidenSelectRecipe>(
+                VPWidenCallRecipe, VPWidenMemoryRecipe, VPWidenSelectRecipe, VPPartialReductionRecipe>(
               [this](const auto *R) { return inferScalarTypeForRecipe(R); })
           .Case<VPInterleaveRecipe>([V](const VPInterleaveRecipe *R) {
             // TODO: Use info from interleave group.
diff --git a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h
index 7d310b1b31b6f..3bd8d24542199 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h
@@ -23,6 +23,7 @@ class VPWidenIntOrFpInductionRecipe;
 class VPWidenMemoryRecipe;
 struct VPWidenSelectRecipe;
 class VPReplicateRecipe;
+class VPPartialReductionRecipe;
 class Type;
 
 /// An analysis for type-inference for VPValues.
@@ -49,6 +50,7 @@ class VPTypeAnalysis {
   Type *inferScalarTypeForRecipe(const VPWidenMemoryRecipe *R);
   Type *inferScalarTypeForRecipe(const VPWidenSelectRecipe *R);
   Type *inferScalarTypeForRecipe(const VPReplicateRecipe *R);
+  Type *inferScalarTypeForRecipe(const VPPartialReductionRecipe *R);
 
 public:
   VPTypeAnalysis(Type *CanonicalIVTy, LLVMContext &Ctx)
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 788e6c96d32aa..b24762ec5ca68 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -255,6 +255,76 @@ void VPRecipeBase::moveBefore(VPBasicBlock &BB,
   insertBefore(BB, I);
 }
 
+/// Generate, for every unrolled part, a call to
+/// llvm.experimental.vector.partial.reduce.add on the non-phi operand and
+/// combine its result with the phi operand using the recipe's opcode.
+/// \p State supplies the IR builder, the unroll factor and the per-part values
+/// of the operands, and receives the generated value for each part. Only
+/// Instruction::Add is supported; any other opcode is treated as unreachable.
+void VPPartialReductionRecipe::execute(VPTransformState &State) {
+  State.setDebugLocFrom(getDebugLoc());
+  auto &Builder = State.Builder;
+
+  // An unsupported opcode is a programming error; report it before any IR.
+  if (Opcode != Instruction::Add) {
+    LLVM_DEBUG(dbgs() << "LV: Found an unhandled opcode : "
+                      << Instruction::getOpcodeName(Opcode) << "\n");
+    llvm_unreachable("Unhandled instruction!");
+  }
+
+  for (unsigned Part = 0; Part < State.UF; ++Part) {
+    // Classify the operands: the one whose generated value is a PHINode is
+    // the accumulator, any other operand is the value to partially reduce.
+    Value *Mul = nullptr;
+    Value *Phi = nullptr;
+    for (VPValue *VPOp : operands()) {
+      Value *Op = State.get(VPOp, Part);
+      if (isa<PHINode>(Op))
+        Phi = Op;
+      else
+        Mul = Op;
+    }
+
+    assert(Phi && Mul && "Phi and Mul must be set");
+    assert(isa<ScalableVectorType>(Mul->getType()) &&
+           "Type must be a scalable vector");
+
+    // The partial reduction produces <vscale x 4 x EltTy>, where EltTy is the
+    // element type of the value being reduced.
+    auto *FullTy = cast<ScalableVectorType>(Mul->getType());
+    Type *RetTy = ScalableVectorType::get(FullTy->getScalarType(), 4);
+
+    // Emit the partial reduction, then accumulate it into the phi.
+    Value *V = Builder.CreateIntrinsic(
+        RetTy, Intrinsic::experimental_vector_partial_reduce_add, Mul, nullptr,
+        Twine("partial.reduce"));
+    V = Builder.CreateNAryOp(Opcode, {V, Phi});
+
+    // Copy the IR flags held by this recipe onto the new instruction.
+    if (auto *VecOp = dyn_cast<Instruction>(V))
+      setFlags(VecOp);
+
+    // Use this vector value for all users of the original instruction.
+    State.set(this, V, Part);
+    State.addMetadata(V, dyn_cast_or_null<Instruction>(getUnderlyingValue()));
+  }
+}
+
+void VPPartialReductionRecipe::postInsertionOp() {
+  cast<VPReductionPHIRecipe>(this->getOperand(1))->SetVFScaleFactor(4); // operand 1 must be the reduction phi recipe; scale its VF down by 4
+}
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+void VPPartialReductionRecipe::print(raw_ostream &O, const Twine &Indent,
+  VPSlotTracker &SlotTracker) const {
+  O << Indent << "PARTIAL-REDUCE "; // recipe tag, prefixed by the caller's indent
+  printAsOperand(O, SlotTracker);
+  O << " = " << Instruction::getOpcodeName(Opcode); // name of the opcode stored in this recipe
+  printFlags(O);
+  printOperands(O, SlotTracker);
+}
+#endif
+
 FastMathFlags VPRecipeWithIRFlags::getFastMathFlags() const {
   assert(OpType == OperationType::FPMathOp &&
          "recipe doesn't have fast math flags");
@@ -2033,6 +2103,8 @@ void VPFirstOrderRecurrencePHIRecipe::print(raw_ostream &O, const Twine &Indent,
 void VPReductionPHIRecipe::execute(VPTransformState &State) {
   auto &Builder = State.Builder;
 
+  auto VF = State.VF.divideCoefficientBy(VFScaleFactor);
+
   // Reductions do not have to start at zero. They can start with
   // any loop invariant values.
   VPValue *StartVPV = getStartValue();
@@ -2042,9 +2114,9 @@ void VPReductionPHIRecipe::execute(VPTransformState &State) {
   // Phi nodes have cycles, so we need to vectorize them in two stages. This is
   // stage #1: We create a new vector PHI node with no incoming edges. We'll use
   // this value when we vectorize all of the instructions that use the PHI.
-  bool ScalarPHI = State.VF.isScalar() || IsInLoop;
+  bool ScalarPHI = VF.isScalar() || IsInLoop;
   Type *VecTy = ScalarPHI ? StartV->getType()
-                          : VectorType::get(StartV->getType(), State.VF);
+                          : VectorType::get(StartV->getType(), VF);
 
   BasicBlock *HeaderBB = State.CFG.PrevBB;
   assert(State.CurrentVectorLoop->getHeader() == HeaderBB &&
@@ -2069,14 +2141,14 @@ void VPReductionPHIRecipe::execute(VPTransformState &State) {
       IRBuilderBase::InsertPointGuard IPBuilder(Builder);
       Builder.SetInsertPoint(VectorPH->getTerminator());
       StartV = Iden =
-          Builder.CreateVectorSplat(State.VF, StartV, "minmax.ident");
+          Builder.CreateVectorSplat(VF, StartV, "minmax.ident");
     }
   } else {
     Iden = RdxDesc.getRecurrenceIdentity(RK, VecTy->getScalarType(),
                                          RdxDesc.getFastMathFlags());
 
     if (!ScalarPHI) {
-      Iden = Builder.CreateVectorSplat(State.VF, Iden);
+      Iden = Builder.CreateVectorSplat(VF, Iden);
       IRBuilderBase::InsertPointGuard IPBuilder(Builder);
       Builder.SetInsertPoint(VectorPH->getTerminator());
       Constant *Zero = Builder.getInt32(0);
diff --git a/llvm/lib/Transforms/Vectorize/VPlanValue.h b/llvm/lib/Transforms/Vectorize/VPlanValue.h
index 8d945f6f2b8ea..ff34e1ded9f3f 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanValue.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanValue.h
@@ -343,6 +343,7 @@ class VPDef {
     VPInstructionSC,
     VPInterleaveSC,
     VPReductionSC,
+    VPPartialReductionSC,
     VPReplicateSC,
     VPScalarCastSC,
     VPScalarIVStepsSC,
diff --git a/llvm/test/CodeGen/AArch64/partial-reduce-sdot.ll b/llvm/test/CodeGen/AArch64/partial-reduce-sdot.ll
index fc6e3239a1b43..1eafd505b199e 100644
--- a/llvm/test/CodeGen/AArch64/partial-reduce-sdot.ll
+++ b/llvm/test/CodeGen/AArch64/partial-reduce-sdot.ll
@@ -22,7 +22,7 @@ define void @dotp(ptr %a, ptr %b) #0 {
 ; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
 ; CHECK:       vector.body:
 ; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
-; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi <vscale x 16 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP14:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi <vscale x 4 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP14:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[TMP11:%.*]] = add i64 [[INDEX]], 0
 ; CHECK-NEXT:    [[TMP13:%.*]] = getelementptr i8, ptr [[A]], i64 [[TMP11]]
 ; CHECK-NEXT:    [[TMP17:%.*]] = getelementptr i8, ptr [[TMP13]], i32 0
@@ -33,12 +33,13 @@ define void @dotp(ptr %a, ptr %b) #0 {
 ; CHECK-NEXT:    [[WIDE_LOAD4:%.*]] = load <vscale x 16 x i8>, ptr [[TMP25]], align 1
 ; CHECK-NEXT:    [[TMP27:%.*]] = zext <vscale x 16 x i8> [[WIDE_LOAD4]] to <vscale x 16 x i32>
 ; CHECK-NEXT:    [[TMP29:%.*]] = mul <vscale x 16 x i32> [[TMP27]], [[TMP19]]
-; CHECK-NEXT:    [[TMP14]] = add <vscale x 16 x i32> [[TMP29]], [[VEC_PHI]]
+; CHECK-NEXT:    [[PARTIAL_REDUCE:%.*]] = call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv16i32(<vscale x 16 x i32> [[TMP29]])
+; CHECK-NEXT:    [[TMP14]] = add <vscale x 4 x i32> [[PARTIAL_REDUCE]], [[VEC_PHI]]
 ; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP5]]
 ; CHECK-NEXT:    [[TMP32:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
 ; CHECK-NEXT:    br i1 [[TMP32]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
 ; CHECK:       middle.block:
-; CHECK-NEXT:    [[TMP33:%.*]] = call i32 @llvm.vector.reduce.add.nxv16i32(<vscale x 16 x i32> [[TMP14]])
+; CHECK-NEXT:    [[TMP33:%.*]] = call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> [[TMP14]])
 ; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i64 0, [[N_VEC]]
 ; CHECK-NEXT:    br i1 [[CMP_N]], label [[FOR_COND_CLEANUP_LOOPEXIT:%.*]], label [[SCALAR_PH]]
 ; CHECK:       scalar.ph:

>From f90ddb3ff376b10857a551a513fb03d74757c116 Mon Sep 17 00:00:00 2001
From: Nicholas Guy <nicholas.guy at arm.com>
Date: Thu, 30 May 2024 15:04:55 +0100
Subject: [PATCH 03/16] [LoopVectorizer] Removed 4x restriction from partial
 reduction intrinsic

---
 llvm/include/llvm/IR/DerivedTypes.h              | 10 ----------
 llvm/include/llvm/IR/Intrinsics.h                |  5 ++---
 llvm/include/llvm/IR/Intrinsics.td               |  6 +-----
 llvm/lib/IR/Function.cpp                         | 16 ----------------
 llvm/lib/Transforms/Vectorize/LoopVectorize.cpp  |  2 +-
 llvm/lib/Transforms/Vectorize/VPlan.h            |  7 ++++---
 llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp   |  5 +++--
 llvm/test/CodeGen/AArch64/partial-reduce-sdot.ll |  2 +-
 8 files changed, 12 insertions(+), 41 deletions(-)

diff --git a/llvm/include/llvm/IR/DerivedTypes.h b/llvm/include/llvm/IR/DerivedTypes.h
index 866a01c9afebd..443fb7de3b821 100644
--- a/llvm/include/llvm/IR/DerivedTypes.h
+++ b/llvm/include/llvm/IR/DerivedTypes.h
@@ -512,16 +512,6 @@ class VectorType : public Type {
                            EltCnt.divideCoefficientBy(2));
   }
 
-  /// This static method returns a VectorType with quarter as many elements as the
-  /// input type and the same element type.
-  static VectorType *getQuarterElementsVectorType(VectorType *VTy) {
-    auto EltCnt = VTy->getElementCount();
-    assert(EltCnt.isKnownEven() &&
-           "Cannot halve vector with odd number of elements.");
-    return VectorType::get(VTy->getElementType(),
-                           EltCnt.divideCoefficientBy(4));
-  }
-
   /// This static method returns a VectorType with twice as many elements as the
   /// input type and the same element type.
   static VectorType *getDoubleElementsVectorType(VectorType *VTy) {
diff --git a/llvm/include/llvm/IR/Intrinsics.h b/llvm/include/llvm/IR/Intrinsics.h
index d0c4bee59e889..f79df522dc805 100644
--- a/llvm/include/llvm/IR/Intrinsics.h
+++ b/llvm/include/llvm/IR/Intrinsics.h
@@ -135,7 +135,6 @@ namespace Intrinsic {
       ExtendArgument,
       TruncArgument,
       HalfVecArgument,
-      QuarterVecArgument,
       SameVecWidthArgument,
       VecOfAnyPtrsToElt,
       VecElementArgument,
@@ -165,7 +164,7 @@ namespace Intrinsic {
 
     unsigned getArgumentNumber() const {
       assert(Kind == Argument || Kind == ExtendArgument ||
-             Kind == TruncArgument || Kind == HalfVecArgument || Kind == QuarterVecArgument ||
+             Kind == TruncArgument || Kind == HalfVecArgument ||
              Kind == SameVecWidthArgument || Kind == VecElementArgument ||
              Kind == Subdivide2Argument || Kind == Subdivide4Argument ||
              Kind == VecOfBitcastsToInt);
@@ -173,7 +172,7 @@ namespace Intrinsic {
     }
     ArgKind getArgumentKind() const {
       assert(Kind == Argument || Kind == ExtendArgument ||
-             Kind == TruncArgument || Kind == HalfVecArgument || Kind == QuarterVecArgument ||
+             Kind == TruncArgument || Kind == HalfVecArgument ||
              Kind == SameVecWidthArgument ||
              Kind == VecElementArgument || Kind == Subdivide2Argument ||
              Kind == Subdivide4Argument || Kind == VecOfBitcastsToInt);
diff --git a/llvm/include/llvm/IR/Intrinsics.td b/llvm/include/llvm/IR/Intrinsics.td
index 98c5221ade2c8..0ece1b81a9b0e 100644
--- a/llvm/include/llvm/IR/Intrinsics.td
+++ b/llvm/include/llvm/IR/Intrinsics.td
@@ -321,7 +321,6 @@ def IIT_I4 : IIT_Int<4, 58>;
 def IIT_AARCH64_SVCOUNT : IIT_VT<aarch64svcount, 59>;
 def IIT_V6 : IIT_Vec<6, 60>;
 def IIT_V10 : IIT_Vec<10, 61>;
-def IIT_QUARTER_VEC_ARG : IIT_Base<62>;
 }
 
 defvar IIT_all_FixedTypes = !filter(iit, IIT_all,
@@ -458,9 +457,6 @@ class LLVMVectorElementType<int num> : LLVMMatchType<num, IIT_VEC_ELEMENT>;
 class LLVMHalfElementsVectorType<int num>
   : LLVMMatchType<num, IIT_HALF_VEC_ARG>;
 
-class LLVMQuarterElementsVectorType<int num>
-  : LLVMMatchType<num, IIT_QUARTER_VEC_ARG>;
-
 // Match the type of another intrinsic parameter that is expected to be a
 // vector type (i.e. <N x iM>) but with each element subdivided to
 // form a vector with more elements that are smaller than the original.
@@ -2652,7 +2648,7 @@ def int_experimental_vector_partial_reduce_add : DefaultAttrsIntrinsic<[LLVMMatc
 
 //===-------------- Intrinsics to perform partial reduction ---------------===//
 
-def int_experimental_vector_partial_reduce_add : DefaultAttrsIntrinsic<[LLVMQuarterElementsVectorType<0>],
+def int_experimental_vector_partial_reduce_add : DefaultAttrsIntrinsic<[llvm_anyvector_ty],
                                                                        [llvm_anyvector_ty],
                                                                        [IntrNoMem]>;
 
diff --git a/llvm/lib/IR/Function.cpp b/llvm/lib/IR/Function.cpp
index a1ce95c2c92bb..5fb348a8bbcd4 100644
--- a/llvm/lib/IR/Function.cpp
+++ b/llvm/lib/IR/Function.cpp
@@ -1259,12 +1259,6 @@ static void DecodeIITType(unsigned &NextElt, ArrayRef<unsigned char> Infos,
                                              ArgInfo));
     return;
   }
-  case IIT_QUARTER_VEC_ARG: {
-    unsigned ArgInfo = (NextElt == Infos.size() ? 0 : Infos[NextElt++]);
-    OutputTable.push_back(IITDescriptor::get(IITDescriptor::QuarterVecArgument,
-                                             ArgInfo));
-    return;
-  }
   case IIT_SAME_VEC_WIDTH_ARG: {
     unsigned ArgInfo = (NextElt == Infos.size() ? 0 : Infos[NextElt++]);
     OutputTable.push_back(IITDescriptor::get(IITDescriptor::SameVecWidthArgument,
@@ -1429,9 +1423,6 @@ static Type *DecodeFixedType(ArrayRef<Intrinsic::IITDescriptor> &Infos,
   case IITDescriptor::HalfVecArgument:
     return VectorType::getHalfElementsVectorType(cast<VectorType>(
                                                   Tys[D.getArgumentNumber()]));
-  case IITDescriptor::QuarterVecArgument:  {
-    return VectorType::getQuarterElementsVectorType(cast<VectorType>(Tys[D.getArgumentNumber()]));
-  }
   case IITDescriptor::SameVecWidthArgument: {
     Type *EltTy = DecodeFixedType(Infos, Tys, Context);
     Type *Ty = Tys[D.getArgumentNumber()];
@@ -1659,13 +1650,6 @@ static bool matchIntrinsicType(
       return !isa<VectorType>(ArgTys[D.getArgumentNumber()]) ||
              VectorType::getHalfElementsVectorType(
                      cast<VectorType>(ArgTys[D.getArgumentNumber()])) != Ty;
-    case IITDescriptor::QuarterVecArgument: {
-    if (D.getArgumentNumber() >= ArgTys.size())
-        return IsDeferredCheck || DeferCheck(Ty);
-      return !isa<VectorType>(ArgTys[D.getArgumentNumber()]) ||
-             VectorType::getQuarterElementsVectorType(
-                     cast<VectorType>(ArgTys[D.getArgumentNumber()])) != Ty;
-    }
     case IITDescriptor::SameVecWidthArgument: {
       if (D.getArgumentNumber() >= ArgTys.size()) {
         // Defer check and subsequent check for the vector element type.
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index a184cb3ea0fb4..0376a95699e67 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -8582,7 +8582,7 @@ VPRecipeBase *VPRecipeBuilder::tryToCreatePartialReduction(
     auto EC = ElementCount::getScalable(16);
     if(std::find(Range.begin(), Range.end(), EC) == Range.end())
       return nullptr;
-    return new VPPartialReductionRecipe(*Instr, make_range(Operands.begin(), Operands.end()));
+    return new VPPartialReductionRecipe(*Instr, make_range(Operands.begin(), Operands.end()), 4);
   }
   return nullptr;
 }
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index feef9da7e84f8..19a337841be3e 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -1975,15 +1975,16 @@ class VPReductionPHIRecipe : public VPHeaderPHIRecipe {
 
 class VPPartialReductionRecipe : public VPRecipeWithIRFlags {
   unsigned Opcode;
+  unsigned Scale;
 public:
   template <typename IterT>
   VPPartialReductionRecipe(Instruction &I,
-                           iterator_range<IterT> Operands) : VPRecipeWithIRFlags(
-    VPDef::VPPartialReductionSC, Operands, I), Opcode(I.getOpcode())
+                           iterator_range<IterT> Operands, unsigned Scale) : VPRecipeWithIRFlags(
+    VPDef::VPPartialReductionSC, Operands, I), Opcode(I.getOpcode()), Scale(Scale)
   {}
   ~VPPartialReductionRecipe() override = default;
   VPPartialReductionRecipe *clone() override {
-    auto *R = new VPPartialReductionRecipe(*getUnderlyingInstr(), operands());
+    auto *R = new VPPartialReductionRecipe(*getUnderlyingInstr(), operands(), Scale);
     R->transferFlags(*this);
     return R;
   }
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index b24762ec5ca68..a56408edad4a2 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -279,7 +279,8 @@ void VPPartialReductionRecipe::execute(VPTransformState &State) {
       assert(isa<ScalableVectorType>(Ops[0]->getType()) && "Type must be a scalable vector");
 
       ScalableVectorType *FullTy = cast<ScalableVectorType>(Ops[0]->getType());
-      Type *RetTy = ScalableVectorType::get(FullTy->getScalarType(), 4);
+      auto EC = FullTy->getElementCount();
+      Type *RetTy = ScalableVectorType::get(FullTy->getScalarType(), EC.divideCoefficientBy(Scale).getKnownMinValue());
 
       Intrinsic::ID PartialIntrinsic = Intrinsic::not_intrinsic;
       switch(Opcode) {
@@ -311,7 +312,7 @@ void VPPartialReductionRecipe::execute(VPTransformState &State) {
 }
 
 void VPPartialReductionRecipe::postInsertionOp() {
-  cast<VPReductionPHIRecipe>(this->getOperand(1))->SetVFScaleFactor(4);
+  cast<VPReductionPHIRecipe>(this->getOperand(1))->SetVFScaleFactor(Scale);
 }
 
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
diff --git a/llvm/test/CodeGen/AArch64/partial-reduce-sdot.ll b/llvm/test/CodeGen/AArch64/partial-reduce-sdot.ll
index 1eafd505b199e..7883cfc05a13b 100644
--- a/llvm/test/CodeGen/AArch64/partial-reduce-sdot.ll
+++ b/llvm/test/CodeGen/AArch64/partial-reduce-sdot.ll
@@ -33,7 +33,7 @@ define void @dotp(ptr %a, ptr %b) #0 {
 ; CHECK-NEXT:    [[WIDE_LOAD4:%.*]] = load <vscale x 16 x i8>, ptr [[TMP25]], align 1
 ; CHECK-NEXT:    [[TMP27:%.*]] = zext <vscale x 16 x i8> [[WIDE_LOAD4]] to <vscale x 16 x i32>
 ; CHECK-NEXT:    [[TMP29:%.*]] = mul <vscale x 16 x i32> [[TMP27]], [[TMP19]]
-; CHECK-NEXT:    [[PARTIAL_REDUCE:%.*]] = call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv16i32(<vscale x 16 x i32> [[TMP29]])
+; CHECK-NEXT:    [[PARTIAL_REDUCE:%.*]] = call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 16 x i32> [[TMP29]])
 ; CHECK-NEXT:    [[TMP14]] = add <vscale x 4 x i32> [[PARTIAL_REDUCE]], [[VEC_PHI]]
 ; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP5]]
 ; CHECK-NEXT:    [[TMP32:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]

>From 987abef923510a78016b9f4aa4dd92a22a3ec72c Mon Sep 17 00:00:00 2001
From: Nicholas Guy <nicholas.guy at arm.com>
Date: Wed, 5 Jun 2024 14:12:04 +0100
Subject: [PATCH 04/16] Commit of test files

---
 .../CodeGen/AArch64/partial-reduce-sdot-ir.ll | 99 +++++++++++++++++++
 ...rtial-reduce-sdot.ll => partial-reduce.ll} |  0
 2 files changed, 99 insertions(+)
 create mode 100644 llvm/test/CodeGen/AArch64/partial-reduce-sdot-ir.ll
 rename llvm/test/CodeGen/AArch64/{partial-reduce-sdot.ll => partial-reduce.ll} (100%)

diff --git a/llvm/test/CodeGen/AArch64/partial-reduce-sdot-ir.ll b/llvm/test/CodeGen/AArch64/partial-reduce-sdot-ir.ll
new file mode 100644
index 0000000000000..3519ba58b3df3
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/partial-reduce-sdot-ir.ll
@@ -0,0 +1,99 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
+; RUN: opt -passes="default<O3>" -force-vector-interleave=1 -S < %s | FileCheck %s
+
+target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128"
+target triple = "aarch64-none-unknown-elf"
+
+define void @dotp(ptr %out, ptr %a, ptr %b, i64 %wide.trip.count) #0 {
+; CHECK-LABEL: define void @dotp(
+; CHECK-SAME: ptr nocapture writeonly [[OUT:%.*]], ptr nocapture readonly [[A:%.*]], ptr nocapture readonly [[B:%.*]], i64 [[WIDE_TRIP_COUNT:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = add i64 [[WIDE_TRIP_COUNT]], 1
+; CHECK-NEXT:    [[TMP1:%.*]] = tail call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP2:%.*]] = shl i64 [[TMP1]], 4
+; CHECK-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[TMP0]], [[TMP2]]
+; CHECK-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[FOR_BODY_PREHEADER:%.*]], label [[VECTOR_PH:%.*]]
+; CHECK:       vector.ph:
+; CHECK-NEXT:    [[TMP3:%.*]] = tail call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP4:%.*]] = shl i64 [[TMP3]], 4
+; CHECK-NEXT:    [[N_MOD_VF:%.*]] = urem i64 [[TMP0]], [[TMP4]]
+; CHECK-NEXT:    [[N_VEC:%.*]] = sub i64 [[TMP0]], [[N_MOD_VF]]
+; CHECK-NEXT:    [[TMP5:%.*]] = tail call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP6:%.*]] = shl i64 [[TMP5]], 4
+; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
+; CHECK:       vector.body:
+; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi <vscale x 16 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP12:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[TMP7:%.*]] = getelementptr i8, ptr [[A]], i64 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <vscale x 16 x i8>, ptr [[TMP7]], align 1
+; CHECK-NEXT:    [[TMP8:%.*]] = zext <vscale x 16 x i8> [[WIDE_LOAD]] to <vscale x 16 x i32>
+; CHECK-NEXT:    [[TMP9:%.*]] = getelementptr i8, ptr [[B]], i64 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_LOAD1:%.*]] = load <vscale x 16 x i8>, ptr [[TMP9]], align 1
+; CHECK-NEXT:    [[TMP10:%.*]] = zext <vscale x 16 x i8> [[WIDE_LOAD1]] to <vscale x 16 x i32>
+; CHECK-NEXT:    [[TMP11:%.*]] = mul nuw nsw <vscale x 16 x i32> [[TMP10]], [[TMP8]]
+; CHECK-NEXT:    [[TMP12]] = add <vscale x 16 x i32> [[TMP11]], [[VEC_PHI]]
+; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP6]]
+; CHECK-NEXT:    [[TMP13:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
+; CHECK-NEXT:    br i1 [[TMP13]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
+; CHECK:       middle.block:
+; CHECK-NEXT:    [[TMP14:%.*]] = tail call i32 @llvm.vector.reduce.add.nxv16i32(<vscale x 16 x i32> [[TMP12]])
+; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i64 [[N_MOD_VF]], 0
+; CHECK-NEXT:    br i1 [[CMP_N]], label [[FOR_COND_CLEANUP_LOOPEXIT:%.*]], label [[FOR_BODY_PREHEADER]]
+; CHECK:       for.body.preheader:
+; CHECK-NEXT:    [[INDVARS_IV_PH:%.*]] = phi i64 [ 0, [[ENTRY:%.*]] ], [ [[N_VEC]], [[MIDDLE_BLOCK]] ]
+; CHECK-NEXT:    [[ACC_010_PH:%.*]] = phi i32 [ 0, [[ENTRY]] ], [ [[TMP14]], [[MIDDLE_BLOCK]] ]
+; CHECK-NEXT:    br label [[FOR_BODY:%.*]]
+; CHECK:       for.cond.cleanup.loopexit:
+; CHECK-NEXT:    [[ADD_LCSSA:%.*]] = phi i32 [ [[TMP14]], [[MIDDLE_BLOCK]] ], [ [[ADD:%.*]], [[FOR_BODY]] ]
+; CHECK-NEXT:    [[TMP15:%.*]] = trunc i32 [[ADD_LCSSA]] to i8
+; CHECK-NEXT:    store i8 [[TMP15]], ptr [[OUT]], align 1
+; CHECK-NEXT:    ret void
+; CHECK:       for.body:
+; CHECK-NEXT:    [[INDVARS_IV:%.*]] = phi i64 [ [[INDVARS_IV_NEXT:%.*]], [[FOR_BODY]] ], [ [[INDVARS_IV_PH]], [[FOR_BODY_PREHEADER]] ]
+; CHECK-NEXT:    [[ACC_010:%.*]] = phi i32 [ [[ADD]], [[FOR_BODY]] ], [ [[ACC_010_PH]], [[FOR_BODY_PREHEADER]] ]
+; CHECK-NEXT:    [[ARRAYIDX:%.*]] = getelementptr i8, ptr [[A]], i64 [[INDVARS_IV]]
+; CHECK-NEXT:    [[TMP16:%.*]] = load i8, ptr [[ARRAYIDX]], align 1
+; CHECK-NEXT:    [[CONV:%.*]] = zext i8 [[TMP16]] to i32
+; CHECK-NEXT:    [[ARRAYIDX2:%.*]] = getelementptr i8, ptr [[B]], i64 [[INDVARS_IV]]
+; CHECK-NEXT:    [[TMP17:%.*]] = load i8, ptr [[ARRAYIDX2]], align 1
+; CHECK-NEXT:    [[CONV3:%.*]] = zext i8 [[TMP17]] to i32
+; CHECK-NEXT:    [[MUL:%.*]] = mul nuw nsw i32 [[CONV3]], [[CONV]]
+; CHECK-NEXT:    [[ADD]] = add i32 [[MUL]], [[ACC_010]]
+; CHECK-NEXT:    [[INDVARS_IV_NEXT]] = add i64 [[INDVARS_IV]], 1
+; CHECK-NEXT:    [[EXITCOND_NOT:%.*]] = icmp eq i64 [[INDVARS_IV]], [[WIDE_TRIP_COUNT]]
+; CHECK-NEXT:    br i1 [[EXITCOND_NOT]], label [[FOR_COND_CLEANUP_LOOPEXIT]], label [[FOR_BODY]], !llvm.loop [[LOOP3:![0-9]+]]
+;
+entry:
+  br label %for.body
+
+for.cond.cleanup.loopexit:                        ; preds = %for.body
+  %0 = trunc i32 %add to i8
+  store i8 %0, ptr %out, align 1
+  ret void
+
+for.body:                                         ; preds = %for.body, %entry
+  %indvars.iv = phi i64 [ 0, %entry ], [ %indvars.iv.next, %for.body ]
+  %acc.010 = phi i32 [ 0, %entry ], [ %add, %for.body ]
+  %arrayidx = getelementptr i8, ptr %a, i64 %indvars.iv
+  %1 = load i8, ptr %arrayidx, align 1
+  %conv = zext i8 %1 to i32
+  %arrayidx2 = getelementptr i8, ptr %b, i64 %indvars.iv
+  %2 = load i8, ptr %arrayidx2, align 1
+  %conv3 = zext i8 %2 to i32
+  %mul = mul i32 %conv3, %conv
+  %add = add i32 %mul, %acc.010
+  %indvars.iv.next = add i64 %indvars.iv, 1
+  %exitcond.not = icmp eq i64 %indvars.iv, %wide.trip.count
+  br i1 %exitcond.not, label %for.cond.cleanup.loopexit, label %for.body
+
+; uselistorder directives
+  uselistorder i32 %add, { 1, 0 }
+}
+
+attributes #0 = { "target-features"="+fp-armv8,+fullfp16,+neon,+sve,+sve2,+v8a" }
+;.
+; CHECK: [[LOOP0]] = distinct !{[[LOOP0]], [[META1:![0-9]+]], [[META2:![0-9]+]]}
+; CHECK: [[META1]] = !{!"llvm.loop.isvectorized", i32 1}
+; CHECK: [[META2]] = !{!"llvm.loop.unroll.runtime.disable"}
+; CHECK: [[LOOP3]] = distinct !{[[LOOP3]], [[META2]], [[META1]]}
+;.
diff --git a/llvm/test/CodeGen/AArch64/partial-reduce-sdot.ll b/llvm/test/CodeGen/AArch64/partial-reduce.ll
similarity index 100%
rename from llvm/test/CodeGen/AArch64/partial-reduce-sdot.ll
rename to llvm/test/CodeGen/AArch64/partial-reduce.ll

>From f1516f5df3efa8ec196a15a90e3ce9adab539c91 Mon Sep 17 00:00:00 2001
From: Nicholas Guy <nicholas.guy at arm.com>
Date: Wed, 5 Jun 2024 14:12:58 +0100
Subject: [PATCH 05/16] Add generic decomposition of partial reduction
 intrinsic

---
 llvm/include/llvm/CodeGen/SelectionDAGNodes.h | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
index 2f36c2e86b1c3..222b82d434d0e 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
@@ -957,6 +957,11 @@ END_TWO_BYTE_PACK()
   inline const APInt &getAsAPIntVal() const;
 
   const SDValue &getOperand(unsigned Num) const {
+    if(Num >= NumOperands) {
+      dbgs() << Num << ">=" << NumOperands << "\n";
+      printr(dbgs());
+      dbgs() << "\n";
+    }
     assert(Num < NumOperands && "Invalid child # of SDNode!");
     return OperandList[Num];
   }

>From 07f4398eee272b2a9b79ed40423eb58c910519bb Mon Sep 17 00:00:00 2001
From: Nicholas Guy <nicholas.guy at arm.com>
Date: Wed, 5 Jun 2024 14:14:15 +0100
Subject: [PATCH 06/16] Add basic cost modeling for partial reductions

---
 llvm/include/llvm/Analysis/TargetTransformInfo.h   | 13 +++++++++++++
 .../llvm/Analysis/TargetTransformInfoImpl.h        |  7 +++++++
 llvm/include/llvm/CodeGen/BasicTTIImpl.h           |  7 +++++++
 llvm/lib/Analysis/TargetTransformInfo.cpp          |  6 ++++++
 llvm/lib/Transforms/Vectorize/LoopVectorize.cpp    | 14 +++++++++++++-
 5 files changed, 46 insertions(+), 1 deletion(-)

diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index dcdd9f82cde8e..ba007ffbfd742 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1528,6 +1528,10 @@ class TargetTransformInfo {
       FastMathFlags FMF,
       TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput) const;
 
+  InstructionCost getPartialReductionCost(
+    unsigned Opcode, bool IsUnsigned, VectorType *ResTy, VectorType *Ty,
+    FastMathFlags FMF, TargetCostKind CostKind = TCK_RecipThroughput) const;
+
   /// \returns The cost of Intrinsic instructions. Analyses the real arguments.
   /// Three cases are handled: 1. scalar instruction 2. vector instruction
   /// 3. scalar instruction which is to be vectorized.
@@ -2093,6 +2097,9 @@ class TargetTransformInfo::Concept {
       unsigned Opcode, bool IsUnsigned, Type *ResTy, VectorType *Ty,
       FastMathFlags FMF,
       TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput) = 0;
+  virtual InstructionCost getPartialReductionCost(
+      unsigned Opcode, bool IsUnsigned, VectorType *ResTy, VectorType *Ty,
+      FastMathFlags FMF, TargetCostKind CostKind = TCK_RecipThroughput) = 0;
   virtual InstructionCost getMulAccReductionCost(
       bool IsUnsigned, Type *ResTy, VectorType *Ty,
       TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput) = 0;
@@ -2776,6 +2783,12 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
     return Impl.getExtendedReductionCost(Opcode, IsUnsigned, ResTy, Ty, FMF,
                                          CostKind);
   }
+  InstructionCost getPartialReductionCost(unsigned Opcode, bool IsUnsigned,
+      VectorType *ResTy, VectorType *Ty, FastMathFlags FMF,
+      TargetCostKind CostKind = TCK_RecipThroughput) override {
+      return Impl.getPartialReductionCost(Opcode, IsUnsigned, ResTy, Ty, FMF,
+          CostKind);
+  }
   InstructionCost
   getMulAccReductionCost(bool IsUnsigned, Type *ResTy, VectorType *Ty,
                          TTI::TargetCostKind CostKind) override {
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 01624de190d51..d96c9e1ef4aec 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -806,6 +806,13 @@ class TargetTransformInfoImplBase {
     return 1;
   }
 
+  InstructionCost getPartialReductionCost(unsigned Opcode, bool IsUnsigned,
+                                           VectorType *ResTy, VectorType *Ty,
+                                           FastMathFlags FMF,
+                                           TTI::TargetCostKind CostKind) const {
+    return InstructionCost::getMax();
+  }
+
   InstructionCost getMulAccReductionCost(bool IsUnsigned, Type *ResTy,
                                          VectorType *Ty,
                                          TTI::TargetCostKind CostKind) const {
diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index 4f1dc9f991c06..13fe0628cbc87 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -2609,6 +2609,13 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
     return RedCost + ExtCost;
   }
 
+  InstructionCost getPartialReductionCost(unsigned Opcode, bool IsUnsigned,
+                                          VectorType *ResTy, VectorType *Ty,
+                                          FastMathFlags FMD,
+                                          TTI::TargetCostKind CostKind) {
+    return InstructionCost::getMax();
+  }
+
   InstructionCost getMulAccReductionCost(bool IsUnsigned, Type *ResTy,
                                          VectorType *Ty,
                                          TTI::TargetCostKind CostKind) {
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index c175d1737e54b..122a41d4791a2 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -1167,6 +1167,12 @@ InstructionCost TargetTransformInfo::getExtendedReductionCost(
                                            CostKind);
 }
 
+InstructionCost TargetTransformInfo::getPartialReductionCost(
+  unsigned Opcode, bool IsUnsigned, VectorType *ResTy, VectorType *Ty,
+  FastMathFlags FMF, TargetCostKind CostKind) const {
+  return TTIImpl->getPartialReductionCost(Opcode, IsUnsigned, ResTy, Ty, FMF, CostKind);
+}
+
 InstructionCost TargetTransformInfo::getMulAccReductionCost(
     bool IsUnsigned, Type *ResTy, VectorType *Ty,
     TTI::TargetCostKind CostKind) const {
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 0376a95699e67..ade40d2a07e69 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -8582,7 +8582,19 @@ VPRecipeBase *VPRecipeBuilder::tryToCreatePartialReduction(
     auto EC = ElementCount::getScalable(16);
     if(std::find(Range.begin(), Range.end(), EC) == Range.end())
       return nullptr;
-    return new VPPartialReductionRecipe(*Instr, make_range(Operands.begin(), Operands.end()), 4);
+
+    // Scale factor of 4 for sdot/udot.
+    unsigned Scale = 4;
+    VectorType* ResTy = ScalableVectorType::get(Instr->getType(), Scale);
+    VectorType* ValTy = ScalableVectorType::get(Instr->getType(), EC.getKnownMinValue());
+    using namespace llvm::PatternMatch;
+    bool IsUnsigned = match(Instr, m_Add(m_Mul(m_ZExt(m_Value()), m_ZExt(m_Value())), m_Value()));
+    auto RecipeCost = this->CM.TTI.getPartialReductionCost(Instr->getOpcode(), IsUnsigned, ResTy, ValTy, FastMathFlags::getFast());
+    // TODO replace with more informed cost check
+    if(RecipeCost == InstructionCost::getMax())
+      return nullptr;
+
+    return new VPPartialReductionRecipe(*Instr, make_range(Operands.begin(), Operands.end()), Scale);
   }
   return nullptr;
 }

>From 7e1339b748a26474fa26d97ee619312edd724d6f Mon Sep 17 00:00:00 2001
From: Nicholas Guy <nicholas.guy at arm.com>
Date: Wed, 5 Jun 2024 14:14:55 +0100
Subject: [PATCH 07/16] Add missing support for sign-extends

---
 .../lib/Transforms/Vectorize/LoopVectorize.cpp | 18 ++++++++++++++----
 1 file changed, 14 insertions(+), 4 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index ade40d2a07e69..d94b882ee9827 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -2203,6 +2203,9 @@ static void getPartialReductionInstrChain(Instruction *Instr, SmallVector<Value*
 }
 
 
+/// Checks if the given instruction is the root of a partial reduction chain
+///
+/// Note: This currently only supports udot/sdot chains
 /// @param Instr The root instruction to scan
 static bool isInstrPartialReduction(Instruction *Instr) {
   Value *ExpectedPhi;
@@ -2212,12 +2215,12 @@ static bool isInstrPartialReduction(Instruction *Instr) {
   using namespace llvm::PatternMatch;
   auto Pattern = m_Add(
     m_OneUse(m_Mul(
-      m_OneUse(m_ZExt(
+      m_OneUse(m_ZExtOrSExt(
         m_OneUse(m_Load(
           m_GEP(
               m_Value(A),
               m_Value(InductionA)))))),
-      m_OneUse(m_ZExt(
+      m_OneUse(m_ZExtOrSExt(
         m_OneUse(m_Load(
           m_GEP(
               m_Value(B),
@@ -2236,8 +2239,13 @@ static bool isInstrPartialReduction(Instruction *Instr) {
   }
 
   Instruction *Mul = cast<Instruction>(Instr->getOperand(0));
-  Instruction *Ext0 = cast<ZExtInst>(Mul->getOperand(0));
-  Instruction *Ext1 = cast<ZExtInst>(Mul->getOperand(1));
+  Instruction *Ext0 = cast<CastInst>(Mul->getOperand(0));
+  Instruction *Ext1 = cast<CastInst>(Mul->getOperand(1));
+
+  if(Ext0->getOpcode() != Ext1->getOpcode()) {
+    LLVM_DEBUG(dbgs() << "Extends aren't of the same type, cannot create a partial reduction.\n");
+    return false;
+  }
 
   // Check that the extends extend to i32
   if(!Ext0->getType()->isIntegerTy(32) || !Ext1->getType()->isIntegerTy(32)) {
@@ -8579,6 +8587,8 @@ VPRecipeBase *VPRecipeBuilder::tryToCreatePartialReduction(
     VFRange &Range, Instruction *Instr, ArrayRef<VPValue *> Operands) {
 
   if(isInstrPartialReduction(Instr)) {
+    // Restricting this case to 16x means that, using a scale of 4, we avoid
+    // trying to generate illegal types such as <vscale x 2 x i32>
     auto EC = ElementCount::getScalable(16);
     if(std::find(Range.begin(), Range.end(), EC) == Range.end())
       return nullptr;

>From 794869488eba7a30858bb8f38f7e594cd40a487f Mon Sep 17 00:00:00 2001
From: Nicholas Guy <nicholas.guy at arm.com>
Date: Wed, 5 Jun 2024 14:27:39 +0100
Subject: [PATCH 08/16] Remove debug statements

---
 llvm/include/llvm/CodeGen/SelectionDAGNodes.h | 5 -----
 1 file changed, 5 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
index 222b82d434d0e..2f36c2e86b1c3 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
@@ -957,11 +957,6 @@ END_TWO_BYTE_PACK()
   inline const APInt &getAsAPIntVal() const;
 
   const SDValue &getOperand(unsigned Num) const {
-    if(Num >= NumOperands) {
-      dbgs() << Num << ">=" << NumOperands << "\n";
-      printr(dbgs());
-      dbgs() << "\n";
-    }
     assert(Num < NumOperands && "Invalid child # of SDNode!");
     return OperandList[Num];
   }

>From 2ce1a068c4c67c8ead8040ec02f44a06609f217d Mon Sep 17 00:00:00 2001
From: Nicholas Guy <nicholas.guy at arm.com>
Date: Fri, 14 Jun 2024 12:54:52 +0100
Subject: [PATCH 09/16] Redesign how the LoopVectorizer identifies partial
 reductions

---
 .../llvm/Analysis/TargetTransformInfo.h       |  27 +--
 .../llvm/Analysis/TargetTransformInfoImpl.h   |  14 +-
 llvm/include/llvm/CodeGen/BasicTTIImpl.h      |   7 -
 llvm/include/llvm/IR/Intrinsics.td            |   6 -
 llvm/lib/Analysis/TargetTransformInfo.cpp     |  13 +-
 .../Vectorize/LoopVectorizationPlanner.h      |  23 +++
 .../Transforms/Vectorize/LoopVectorize.cpp    | 177 +++++++++---------
 .../Transforms/Vectorize/VPRecipeBuilder.h    |   2 +-
 llvm/lib/Transforms/Vectorize/VPlan.h         |   2 +
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp |  10 +-
 ...educe.ll => partial-reduce-dot-product.ll} |  54 +++---
 .../CodeGen/AArch64/partial-reduce-sdot-ir.ll |  99 ----------
 12 files changed, 167 insertions(+), 267 deletions(-)
 rename llvm/test/CodeGen/AArch64/{partial-reduce.ll => partial-reduce-dot-product.ll} (66%)
 delete mode 100644 llvm/test/CodeGen/AArch64/partial-reduce-sdot-ir.ll

diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index ba007ffbfd742..7f69c4eea715d 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1248,6 +1248,8 @@ class TargetTransformInfo {
   /// \return if target want to issue a prefetch in address space \p AS.
   bool shouldPrefetchAddressSpace(unsigned AS) const;
 
+  bool isPartialReductionSupported(const Instruction* ReductionInstr, Type* InputType, unsigned ScaleFactor, bool IsInputASignExtended, bool IsInputBSignExtended, const Instruction* BinOp = nullptr) const;
+    
   /// \return The maximum interleave factor that any transform should try to
   /// perform for this target. This number depends on the level of parallelism
   /// and the number of execution units in the CPU.
@@ -1528,10 +1530,6 @@ class TargetTransformInfo {
       FastMathFlags FMF,
       TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput) const;
 
-  InstructionCost getPartialReductionCost(
-    unsigned Opcode, bool IsUnsigned, VectorType *ResTy, VectorType *Ty,
-    FastMathFlags FMF, TargetCostKind CostKind = TCK_RecipThroughput) const;
-
   /// \returns The cost of Intrinsic instructions. Analyses the real arguments.
   /// Three cases are handled: 1. scalar instruction 2. vector instruction
   /// 3. scalar instruction which is to be vectorized.
@@ -2016,6 +2014,11 @@ class TargetTransformInfo::Concept {
   /// \return if target want to issue a prefetch in address space \p AS.
   virtual bool shouldPrefetchAddressSpace(unsigned AS) const = 0;
 
+  virtual bool isPartialReductionSupported(const Instruction* ReductionInstr,
+      Type* InputType, unsigned ScaleFactor,
+      bool IsInputASignExtended, bool IsInputBSignExtended,
+      const Instruction* BinOp = nullptr) const = 0;
+    
   virtual unsigned getMaxInterleaveFactor(ElementCount VF) = 0;
   virtual InstructionCost getArithmeticInstrCost(
       unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind,
@@ -2097,9 +2100,6 @@ class TargetTransformInfo::Concept {
       unsigned Opcode, bool IsUnsigned, Type *ResTy, VectorType *Ty,
       FastMathFlags FMF,
       TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput) = 0;
-  virtual InstructionCost getPartialReductionCost(
-      unsigned Opcode, bool IsUnsigned, VectorType *ResTy, VectorType *Ty,
-      FastMathFlags FMF, TargetCostKind CostKind = TCK_RecipThroughput) = 0;
   virtual InstructionCost getMulAccReductionCost(
       bool IsUnsigned, Type *ResTy, VectorType *Ty,
       TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput) = 0;
@@ -2648,6 +2648,13 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
     return Impl.shouldPrefetchAddressSpace(AS);
   }
 
+  bool isPartialReductionSupported(const Instruction* ReductionInstr, Type* InputType, unsigned ScaleFactor,
+                                              bool IsInputASignExtended, bool IsInputBSignExtended,
+                                              const Instruction* BinOp = nullptr) const override
+  {
+      return Impl.isPartialReductionSupported(ReductionInstr, InputType, ScaleFactor, IsInputASignExtended, IsInputBSignExtended, BinOp);
+  }
+
   unsigned getMaxInterleaveFactor(ElementCount VF) override {
     return Impl.getMaxInterleaveFactor(VF);
   }
@@ -2783,12 +2790,6 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
     return Impl.getExtendedReductionCost(Opcode, IsUnsigned, ResTy, Ty, FMF,
                                          CostKind);
   }
-  InstructionCost getPartialReductionCost(unsigned Opcode, bool IsUnsigned,
-      VectorType *ResTy, VectorType *Ty, FastMathFlags FMF,
-      TargetCostKind CostKind = TCK_RecipThroughput) override {
-      return Impl.getPartialReductionCost(Opcode, IsUnsigned, ResTy, Ty, FMF,
-          CostKind);
-  }
   InstructionCost
   getMulAccReductionCost(bool IsUnsigned, Type *ResTy, VectorType *Ty,
                          TTI::TargetCostKind CostKind) override {
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index d96c9e1ef4aec..11773f59170a9 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -541,6 +541,13 @@ class TargetTransformInfoImplBase {
   bool enableWritePrefetching() const { return false; }
   bool shouldPrefetchAddressSpace(unsigned AS) const { return !AS; }
 
+  bool isPartialReductionSupported(const Instruction* ReductionInstr,
+      Type* InputType, unsigned ScaleFactor,
+      bool IsInputASignExtended, bool IsInputBSignExtended,
+      const Instruction* BinOp = nullptr) const {
+    return false; // Base implementation: no partial reduction is reported as supported.
+  }
+
   unsigned getMaxInterleaveFactor(ElementCount VF) const { return 1; }
 
   InstructionCost getArithmeticInstrCost(
@@ -806,13 +813,6 @@ class TargetTransformInfoImplBase {
     return 1;
   }
 
-  InstructionCost getPartialReductionCost(unsigned Opcode, bool IsUnsigned,
-                                           VectorType *ResTy, VectorType *Ty,
-                                           FastMathFlags FMF,
-                                           TTI::TargetCostKind CostKind) const {
-    return InstructionCost::getMax();
-  }
-
   InstructionCost getMulAccReductionCost(bool IsUnsigned, Type *ResTy,
                                          VectorType *Ty,
                                          TTI::TargetCostKind CostKind) const {
diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index 13fe0628cbc87..4f1dc9f991c06 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -2609,13 +2609,6 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
     return RedCost + ExtCost;
   }
 
-  InstructionCost getPartialReductionCost(unsigned Opcode, bool IsUnsigned,
-                                          VectorType *ResTy, VectorType *Ty,
-                                          FastMathFlags FMD,
-                                          TTI::TargetCostKind CostKind) {
-    return InstructionCost::getMax();
-  }
-
   InstructionCost getMulAccReductionCost(bool IsUnsigned, Type *ResTy,
                                          VectorType *Ty,
                                          TTI::TargetCostKind CostKind) {
diff --git a/llvm/include/llvm/IR/Intrinsics.td b/llvm/include/llvm/IR/Intrinsics.td
index 0ece1b81a9b0e..95dbd2854322d 100644
--- a/llvm/include/llvm/IR/Intrinsics.td
+++ b/llvm/include/llvm/IR/Intrinsics.td
@@ -2646,12 +2646,6 @@ def int_experimental_vector_partial_reduce_add : DefaultAttrsIntrinsic<[LLVMMatc
                                                                        [llvm_anyvector_ty, llvm_anyvector_ty],
                                                                        [IntrNoMem]>;
 
-//===-------------- Intrinsics to perform partial reduction ---------------===//
-
-def int_experimental_vector_partial_reduce_add : DefaultAttrsIntrinsic<[llvm_anyvector_ty],
-                                                                       [llvm_anyvector_ty],
-                                                                       [IntrNoMem]>;
-
 //===----------------- Pointer Authentication Intrinsics ------------------===//
 //
 
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 122a41d4791a2..f74f1f087e597 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -825,6 +825,13 @@ bool TargetTransformInfo::shouldPrefetchAddressSpace(unsigned AS) const {
   return TTIImpl->shouldPrefetchAddressSpace(AS);
 }
 
+bool TargetTransformInfo::isPartialReductionSupported(
+    const Instruction *ReductionInstr, Type *InputType, unsigned ScaleFactor,
+    bool IsInputASignExtended, bool IsInputBSignExtended,
+    const Instruction *BinOp) const {
+  return TTIImpl->isPartialReductionSupported(ReductionInstr, InputType, ScaleFactor, IsInputASignExtended, IsInputBSignExtended, BinOp);
+}
+
 unsigned TargetTransformInfo::getMaxInterleaveFactor(ElementCount VF) const {
   return TTIImpl->getMaxInterleaveFactor(VF);
 }
@@ -1167,12 +1174,6 @@ InstructionCost TargetTransformInfo::getExtendedReductionCost(
                                            CostKind);
 }
 
-InstructionCost TargetTransformInfo::getPartialReductionCost(
-  unsigned Opcode, bool IsUnsigned, VectorType *ResTy, VectorType *Ty,
-  FastMathFlags FMF, TargetCostKind CostKind) const {
-  return TTIImpl->getPartialReductionCost(Opcode, IsUnsigned, ResTy, Ty, FMF, CostKind);
-}
-
 InstructionCost TargetTransformInfo::getMulAccReductionCost(
     bool IsUnsigned, Type *ResTy, VectorType *Ty,
     TTI::TargetCostKind CostKind) const {
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
index 93f6d1d82e244..e02a00f2912eb 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
@@ -293,6 +293,19 @@ struct FixedScalableVFPair {
   bool hasVector() const { return FixedVF.isVector() || ScalableVF.isVector(); }
 };
 
+struct PartialReductionChain {
+  Instruction *Reduction = nullptr; // Root instruction of the chain.
+  Instruction *BinOp = nullptr;     // Operand 0 of Reduction.
+  Instruction *ExtendA = nullptr;   // Operand 0 of BinOp.
+  Instruction *ExtendB = nullptr;   // Operand 1 of BinOp.
+  // Members are null/zero-initialized so a default-constructed chain is inert.
+  Value *InputA = nullptr;      // Operand 0 of ExtendA.
+  Value *InputB = nullptr;      // Operand 0 of ExtendB.
+  Value *Accumulator = nullptr; // Operand 1 of Reduction.
+  // Scalar bit width of Reduction's type divided by that of InputA's type.
+  unsigned ScaleFactor = 0;
+};
+
 /// Planner drives the vectorization process after having passed
 /// Legality checks.
 class LoopVectorizationPlanner {
@@ -331,6 +344,8 @@ class LoopVectorizationPlanner {
   /// Profitable vector factors.
   SmallVector<VectorizationFactor, 8> ProfitableVFs;
 
+  SmallVector<PartialReductionChain> PartialReductionChains;
+
   /// A builder used to construct the current plan.
   VPBuilder Builder;
 
@@ -398,6 +413,10 @@ class LoopVectorizationPlanner {
   VectorizationFactor
   selectEpilogueVectorizationFactor(const ElementCount MaxVF, unsigned IC);
 
+  const SmallVector<PartialReductionChain> &getPartialReductionChains() const {
+    return PartialReductionChains; // By reference: callers only read the list.
+  }
+
 protected:
   /// Build VPlans for power-of-2 VF's between \p MinVF and \p MaxVF inclusive,
   /// according to the information gathered by Legal when it checked if it is
@@ -444,6 +463,10 @@ class LoopVectorizationPlanner {
   /// Determines if we have the infrastructure to vectorize the loop and its
   /// epilogue, assuming the main loop is vectorized by \p VF.
   bool isCandidateForEpilogueVectorization(const ElementCount VF) const;
+
+  bool getInstructionsPartialReduction(Instruction* I, PartialReductionChain &Chain) const;
+
+  
 };
 
 } // namespace llvm
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index d94b882ee9827..ba7a2a1188953 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -1137,7 +1137,7 @@ class LoopVectorizationCostModel {
   calculateRegisterUsage(ArrayRef<ElementCount> VFs);
 
   /// Collect values we want to ignore in the cost model.
-  void collectValuesToIgnore();
+  void collectValuesToIgnore(LoopVectorizationPlanner *LVP);
 
   /// Collect all element types in the loop for which widening is needed.
   void collectElementTypesForWidening();
@@ -2191,73 +2191,59 @@ static bool useActiveLaneMaskForControlFlow(TailFoldingStyle Style) {
          Style == TailFoldingStyle::DataAndControlFlowWithoutRuntimeCheck;
 }
 
-static void getPartialReductionInstrChain(Instruction *Instr, SmallVector<Value*, 4> &Chain) {
-  Instruction *Mul = cast<Instruction>(Instr->getOperand(0));
-  Instruction *Ext0 = cast<ZExtInst>(Mul->getOperand(0));
-  Instruction *Ext1 = cast<ZExtInst>(Mul->getOperand(1));
+static PartialReductionChain getPartialReductionInstrChain(Instruction *Instr) {
+  Instruction *BinOp = cast<Instruction>(Instr->getOperand(0));
+  Instruction *Ext0 = cast<Instruction>(BinOp->getOperand(0));
+  Instruction *Ext1 = cast<Instruction>(BinOp->getOperand(1));
 
-  Chain.push_back(Mul);
-  Chain.push_back(Ext0);
-  Chain.push_back(Ext1);
-  Chain.push_back(Instr->getOperand(1));
+  PartialReductionChain Chain;
+  Chain.Reduction = Instr;
+  Chain.BinOp = BinOp;
+  Chain.ExtendA = Ext0;
+  Chain.ExtendB = Ext1;
+  Chain.InputA = Ext0->getOperand(0);
+  Chain.InputB = Ext1->getOperand(0);
+  Chain.Accumulator = Instr->getOperand(1);
+
+  unsigned InputSizeBits = Chain.InputA->getType()->getScalarSizeInBits();
+  unsigned ResultSizeBits = Chain.Reduction->getType()->getScalarSizeInBits();
+  Chain.ScaleFactor = ResultSizeBits / InputSizeBits; 
+  return Chain;
 }
 
 
 /// Checks if the given instruction the root of a partial reduction chain
 ///
-/// Note: This currently only supports udot/sdot chains
 /// @param Instr The root instruction to scan
 static bool isInstrPartialReduction(Instruction *Instr) {
   Value *ExpectedPhi;
   Value *A, *B;
-  Value *InductionA, *InductionB;
 
   using namespace llvm::PatternMatch;
-  auto Pattern = m_Add(
-    m_OneUse(m_Mul(
-      m_OneUse(m_ZExtOrSExt(
-        m_OneUse(m_Load(
-          m_GEP(
-              m_Value(A),
-              m_Value(InductionA)))))),
-      m_OneUse(m_ZExtOrSExt(
-        m_OneUse(m_Load(
-          m_GEP(
-              m_Value(B),
-              m_Value(InductionB))))))
-        )), m_Value(ExpectedPhi));
+  auto Pattern = m_BinOp(
+      m_OneUse(m_BinOp(
+        m_OneUse(m_ZExtOrSExt(m_OneUse(m_Value(A)))),
+        m_OneUse(m_ZExtOrSExt(m_OneUse(m_Value(B)))))),
+      m_Value(ExpectedPhi));
 
   bool Matches = match(Instr, Pattern);
 
   if(!Matches)
     return false;
 
-  // Check that the two induction variable uses are to the same induction variable
-  if(InductionA != InductionB) {
-    LLVM_DEBUG(dbgs() << "Loop uses different induction variables for each input variable, cannot create a partial reduction.\n");
-    return false;
-  }
-
-  Instruction *Mul = cast<Instruction>(Instr->getOperand(0));
-  Instruction *Ext0 = cast<CastInst>(Mul->getOperand(0));
-  Instruction *Ext1 = cast<CastInst>(Mul->getOperand(1));
-
-  if(Ext0->getOpcode() != Ext1->getOpcode()) {
-    LLVM_DEBUG(dbgs() << "Extends aren't of the same type, cannot create a partial reduction.\n");
+  // Check that the extends extend from the same type
+  if(A->getType() != B->getType()) {
+    LLVM_DEBUG(dbgs() << "Extends don't extend from the same type, cannot create a partial reduction.\n");
     return false;
   }
 
-  // Check that the extends extend to i32
-  if(!Ext0->getType()->isIntegerTy(32) || !Ext1->getType()->isIntegerTy(32)) {
-    LLVM_DEBUG(dbgs() << "Extends don't extend to the correct width, cannot create a partial reduction.\n");
-    return false;
-  }
+  // A and B are one-use, so the first user of each should be the respective extend 
+  Instruction *Ext0 = cast<CastInst>(*A->user_begin());
+  Instruction *Ext1 = cast<CastInst>(*B->user_begin());
 
-  // Check that the loads are loading i8
-  LoadInst *Load0 = cast<LoadInst>(Ext0->getOperand(0));
-  LoadInst *Load1 = cast<LoadInst>(Ext1->getOperand(0));
-  if(!Load0->getType()->isIntegerTy(8) || !Load1->getType()->isIntegerTy(8)) {
-    LLVM_DEBUG(dbgs() << "Loads don't load the correct width, cannot create a partial reduction\n");
+  // Check that the extends extend to the same type
+  if(Ext0->getType() != Ext1->getType()) {
+    LLVM_DEBUG(dbgs() << "Extends don't extend to the same type, cannot create a partial reduction.\n");
     return false;
   }
 
@@ -2268,23 +2254,32 @@ static bool isInstrPartialReduction(Instruction *Instr) {
     return false;
   }
 
-  // Check that the first phi value is a zero initializer
-  ConstantInt *ZeroInit = dyn_cast<ConstantInt>(PhiNode->getIncomingValue(0));
-  if(!ZeroInit || !ZeroInit->isZero()) {
-    LLVM_DEBUG(dbgs() << "First PHI value is not a constant zero, cannot create a partial reduction.\n");
-    return false;
-  }
-
   // Check that the second phi value is the instruction we're looking at
   Instruction *MaybeAdd = dyn_cast<Instruction>(PhiNode->getIncomingValue(1));
   if(!MaybeAdd || MaybeAdd != Instr) {
-    LLVM_DEBUG(dbgs() << "Second PHI value is not the root add, cannot create a partial reduction.\n");
+    LLVM_DEBUG(dbgs() << "Second PHI value is not the root binop, cannot create a partial reduction.\n");
     return false;
   }
 
   return true;
 }
 
+/// Returns true when \p Chain is rooted at an add, its result width is a whole
+/// multiple of its input width, and \p TTI reports the chain as supported.
+static bool isPartialReductionChainValid(const PartialReductionChain &Chain, const TargetTransformInfo &TTI) {
+  if(Chain.Reduction->getOpcode() != Instruction::Add)
+    return false;
+
+  unsigned InputSizeBits = Chain.InputA->getType()->getScalarSizeInBits();
+  unsigned ResultSizeBits = Chain.Reduction->getType()->getScalarSizeInBits();
+  if(ResultSizeBits < InputSizeBits || (ResultSizeBits % InputSizeBits) != 0)
+    return false;
+
+  bool IsASignExtended = isa<SExtInst>(Chain.ExtendA);
+  bool IsBSignExtended = isa<SExtInst>(Chain.ExtendB);
+  return TTI.isPartialReductionSupported(Chain.Reduction, Chain.InputA->getType(), Chain.ScaleFactor, IsASignExtended, IsBSignExtended, Chain.BinOp);
+}
+
 // Return true if \p OuterLp is an outer loop annotated with hints for explicit
 // vectorization. The loop needs to be annotated with #pragma omp simd
 // simdlen(#) or #pragma clang vectorize(enable) vectorize_width(#). If the
@@ -5091,6 +5086,16 @@ bool LoopVectorizationPlanner::isCandidateForEpilogueVectorization(
   return true;
 }
 
+bool LoopVectorizationPlanner::getInstructionsPartialReduction(Instruction *I, PartialReductionChain &Chain) const {
+  for(const PartialReductionChain &Candidate : PartialReductionChains) {
+    if(Candidate.Reduction != I)
+      continue; // Not the chain rooted at I; keep scanning.
+    Chain = Candidate; // First match wins: copy it to the out-parameter.
+    return true;
+  }
+  return false; // No recorded chain is rooted at I; Chain is left untouched.
+}
+
 bool LoopVectorizationCostModel::isEpilogueVectorizationProfitable(
     const ElementCount VF) const {
   // FIXME: We need a much better cost-model to take different parameters such
@@ -7152,7 +7157,7 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I, ElementCount VF,
   } // end of switch.
 }
 
-void LoopVectorizationCostModel::collectValuesToIgnore() {
+void LoopVectorizationCostModel::collectValuesToIgnore(LoopVectorizationPlanner* LVP) {
   // Ignore ephemeral values.
   CodeMetrics::collectEphemeralValues(TheLoop, AC, ValuesToIgnore);
 
@@ -7209,14 +7214,10 @@ void LoopVectorizationCostModel::collectValuesToIgnore() {
   }
 
   // Ignore any values that we know will be flattened
-  for(auto Reduction : this->Legal->getReductionVars()) {
-    auto &Recurrence = Reduction.second;
-    if(isInstrPartialReduction(Recurrence.getLoopExitInstr())) {
-      SmallVector<Value*, 4> PartialReductionValues;
-      getPartialReductionInstrChain(Recurrence.getLoopExitInstr(), PartialReductionValues);
-      ValuesToIgnore.insert(PartialReductionValues.begin(), PartialReductionValues.end());
-      VecValuesToIgnore.insert(PartialReductionValues.begin(), PartialReductionValues.end());
-    }
+  for(const auto &Chain : LVP->getPartialReductionChains()) { // Bind by reference: no per-chain copy.
+    SmallVector<Value*> PartialReductionValues{Chain.Reduction, Chain.BinOp, Chain.ExtendA, Chain.ExtendB, Chain.Accumulator};
+    ValuesToIgnore.insert(PartialReductionValues.begin(), PartialReductionValues.end());
+    VecValuesToIgnore.insert(PartialReductionValues.begin(), PartialReductionValues.end());
   }
 }
 
@@ -7343,7 +7344,17 @@ LoopVectorizationPlanner::planInVPlanNativePath(ElementCount UserVF) {
 std::optional<VectorizationFactor>
 LoopVectorizationPlanner::plan(ElementCount UserVF, unsigned UserIC) {
   assert(OrigLoop->isInnermost() && "Inner loop expected.");
-  CM.collectValuesToIgnore();
+
+  for(const auto &ReductionVar : Legal->getReductionVars()) { // By reference: avoids copying each map entry.
+    auto *ReductionExitInstr = ReductionVar.second.getLoopExitInstr();
+    if(isInstrPartialReduction(ReductionExitInstr)) {
+      auto Chain = getPartialReductionInstrChain(ReductionExitInstr);
+      if(isPartialReductionChainValid(Chain, TTI)) // Keep only chains the target accepts.
+        PartialReductionChains.push_back(Chain);
+    }
+  }
+  
+  CM.collectValuesToIgnore(this);
   CM.collectElementTypesForWidening();
 
   FixedScalableVFPair MaxFactors = CM.computeMaxVF(UserVF, UserIC);
@@ -8577,36 +8588,12 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
                                  *CI);
   }
 
-  if(auto *PartialReduce = tryToCreatePartialReduction(Range, Instr, Operands))
-    return PartialReduce;
-
   return tryToWiden(Instr, Operands, VPBB);
 }
 
 VPRecipeBase *VPRecipeBuilder::tryToCreatePartialReduction(
-    VFRange &Range, Instruction *Instr, ArrayRef<VPValue *> Operands) {
-
-  if(isInstrPartialReduction(Instr)) {
-    // Restricting this case to 16x means that, using a scale of 4, we avoid
-    // trying to generate illegal types such as <vscale x 2 x i32>
-    auto EC = ElementCount::getScalable(16);
-    if(std::find(Range.begin(), Range.end(), EC) == Range.end())
-      return nullptr;
-
-    // Scale factor of 4 for sdot/udot.
-    unsigned Scale = 4;
-    VectorType* ResTy = ScalableVectorType::get(Instr->getType(), Scale);
-    VectorType* ValTy = ScalableVectorType::get(Instr->getType(), EC.getKnownMinValue());
-    using namespace llvm::PatternMatch;
-    bool IsUnsigned = match(Instr, m_Add(m_Mul(m_ZExt(m_Value()), m_ZExt(m_Value())), m_Value()));
-    auto RecipeCost = this->CM.TTI.getPartialReductionCost(Instr->getOpcode(), IsUnsigned, ResTy, ValTy, FastMathFlags::getFast());
-    // TODO replace with more informed cost check
-    if(RecipeCost == InstructionCost::getMax())
-      return nullptr;
-
-    return new VPPartialReductionRecipe(*Instr, make_range(Operands.begin(), Operands.end()), Scale);
-  }
-  return nullptr;
+    VFRange &Range, PartialReductionChain &Chain, ArrayRef<VPValue *> Operands) {
+  return new VPPartialReductionRecipe(*Chain.Reduction, make_range(Operands.begin(), Operands.end()), Chain.ScaleFactor);
 }
 
 void LoopVectorizationPlanner::buildVPlansWithVPRecipes(ElementCount MinVF,
@@ -8795,8 +8782,14 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
           Legal->isInvariantAddressOfReduction(SI->getPointerOperand()))
         continue;
 
-      VPRecipeBase *Recipe =
-          RecipeBuilder.tryToCreateWidenRecipe(Instr, Operands, Range, VPBB);
+      VPRecipeBase *Recipe = nullptr;
+
+      PartialReductionChain Chain;
+      if(getInstructionsPartialReduction(Instr, Chain)) 
+        Recipe = RecipeBuilder.tryToCreatePartialReduction(Range, Chain, Operands);
+      
+      if (!Recipe)
+        Recipe = RecipeBuilder.tryToCreateWidenRecipe(Instr, Operands, Range, VPBB);
       if (!Recipe)
         Recipe = RecipeBuilder.handleReplication(Instr, Range);
 
diff --git a/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h b/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
index c439f221709e1..91b3a39a89959 100644
--- a/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
+++ b/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
@@ -116,7 +116,7 @@ class VPRecipeBuilder {
                                        ArrayRef<VPValue *> Operands,
                                        VFRange &Range, VPBasicBlock *VPBB);
 
-  VPRecipeBase* tryToCreatePartialReduction(VFRange &Range, Instruction* Instr, ArrayRef<VPValue*> Operands);
+  VPRecipeBase* tryToCreatePartialReduction(VFRange &Range, PartialReductionChain &Chain, ArrayRef<VPValue*> Operands);
 
   /// Set the recipe created for given ingredient.
   void setRecipe(Instruction *I, VPRecipeBase *R) {
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 19a337841be3e..a3711abdef1c1 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -964,7 +964,9 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
     FastMathFlagsTy(const FastMathFlags &FMF);
   };
 
+public:
   OperationType OpType;
+private:
 
   union {
     CmpInst::Predicate CmpPredicate;
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index a56408edad4a2..648823f532927 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -276,11 +276,10 @@ void VPPartialReductionRecipe::execute(VPTransformState &State) {
       }
 
       assert(Phi && Mul && "Phi and Mul must be set");
-      assert(isa<ScalableVectorType>(Ops[0]->getType()) && "Type must be a scalable vector");
 
-      ScalableVectorType *FullTy = cast<ScalableVectorType>(Ops[0]->getType());
+      VectorType *FullTy = cast<VectorType>(Ops[0]->getType());
       auto EC = FullTy->getElementCount();
-      Type *RetTy = ScalableVectorType::get(FullTy->getScalarType(), EC.divideCoefficientBy(Scale).getKnownMinValue());
+      Type *RetTy = VectorType::get(FullTy->getScalarType(), EC.divideCoefficientBy(Scale));
 
       Intrinsic::ID PartialIntrinsic = Intrinsic::not_intrinsic;
       switch(Opcode) {
@@ -294,10 +293,7 @@ void VPPartialReductionRecipe::execute(VPTransformState &State) {
 
       assert(PartialIntrinsic != Intrinsic::not_intrinsic);
 
-      Value *V = Builder.CreateIntrinsic(RetTy, PartialIntrinsic, Mul, nullptr, Twine("partial.reduce"));
-      V = Builder.CreateNAryOp(Opcode, {V, Phi});
-      if (auto *VecOp = dyn_cast<Instruction>(V))
-        setFlags(VecOp);
+      CallInst *V = Builder.CreateIntrinsic(RetTy, PartialIntrinsic, {Phi, Mul}, nullptr, Twine("partial.reduce"));
 
       // Use this vector value for all users of the original instruction.
       State.set(this, V, Part);
diff --git a/llvm/test/CodeGen/AArch64/partial-reduce.ll b/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll
similarity index 66%
rename from llvm/test/CodeGen/AArch64/partial-reduce.ll
rename to llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll
index 7883cfc05a13b..2907bf903c031 100644
--- a/llvm/test/CodeGen/AArch64/partial-reduce.ll
+++ b/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll
@@ -9,56 +9,55 @@ define void @dotp(ptr %a, ptr %b) #0 {
 ; CHECK-SAME: ptr [[A:%.*]], ptr [[B:%.*]]) #[[ATTR0:[0-9]+]] {
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[TMP0:%.*]] = call i64 @llvm.vscale.i64()
-; CHECK-NEXT:    [[TMP1:%.*]] = mul i64 [[TMP0]], 16
+; CHECK-NEXT:    [[TMP1:%.*]] = mul i64 [[TMP0]], 4
 ; CHECK-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 0, [[TMP1]]
 ; CHECK-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
 ; CHECK:       vector.ph:
 ; CHECK-NEXT:    [[TMP2:%.*]] = call i64 @llvm.vscale.i64()
-; CHECK-NEXT:    [[TMP3:%.*]] = mul i64 [[TMP2]], 16
+; CHECK-NEXT:    [[TMP3:%.*]] = mul i64 [[TMP2]], 4
 ; CHECK-NEXT:    [[N_MOD_VF:%.*]] = urem i64 0, [[TMP3]]
 ; CHECK-NEXT:    [[N_VEC:%.*]] = sub i64 0, [[N_MOD_VF]]
 ; CHECK-NEXT:    [[TMP4:%.*]] = call i64 @llvm.vscale.i64()
-; CHECK-NEXT:    [[TMP5:%.*]] = mul i64 [[TMP4]], 16
+; CHECK-NEXT:    [[TMP5:%.*]] = mul i64 [[TMP4]], 4
 ; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
 ; CHECK:       vector.body:
 ; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
-; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi <vscale x 4 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP14:%.*]], [[VECTOR_BODY]] ]
-; CHECK-NEXT:    [[TMP11:%.*]] = add i64 [[INDEX]], 0
-; CHECK-NEXT:    [[TMP13:%.*]] = getelementptr i8, ptr [[A]], i64 [[TMP11]]
-; CHECK-NEXT:    [[TMP17:%.*]] = getelementptr i8, ptr [[TMP13]], i32 0
-; CHECK-NEXT:    [[WIDE_LOAD2:%.*]] = load <vscale x 16 x i8>, ptr [[TMP17]], align 1
-; CHECK-NEXT:    [[TMP19:%.*]] = zext <vscale x 16 x i8> [[WIDE_LOAD2]] to <vscale x 16 x i32>
-; CHECK-NEXT:    [[TMP21:%.*]] = getelementptr i8, ptr [[B]], i64 [[TMP11]]
-; CHECK-NEXT:    [[TMP25:%.*]] = getelementptr i8, ptr [[TMP21]], i32 0
-; CHECK-NEXT:    [[WIDE_LOAD4:%.*]] = load <vscale x 16 x i8>, ptr [[TMP25]], align 1
-; CHECK-NEXT:    [[TMP27:%.*]] = zext <vscale x 16 x i8> [[WIDE_LOAD4]] to <vscale x 16 x i32>
-; CHECK-NEXT:    [[TMP29:%.*]] = mul <vscale x 16 x i32> [[TMP27]], [[TMP19]]
-; CHECK-NEXT:    [[PARTIAL_REDUCE:%.*]] = call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 16 x i32> [[TMP29]])
-; CHECK-NEXT:    [[TMP14]] = add <vscale x 4 x i32> [[PARTIAL_REDUCE]], [[VEC_PHI]]
+; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi <vscale x 4 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[PARTIAL_REDUCE:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[TMP6:%.*]] = add i64 [[INDEX]], 0
+; CHECK-NEXT:    [[TMP7:%.*]] = getelementptr i8, ptr [[A]], i64 [[TMP6]]
+; CHECK-NEXT:    [[TMP8:%.*]] = getelementptr i8, ptr [[TMP7]], i32 0
+; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <vscale x 4 x i8>, ptr [[TMP8]], align 1
+; CHECK-NEXT:    [[TMP9:%.*]] = zext <vscale x 4 x i8> [[WIDE_LOAD]] to <vscale x 4 x i32>
+; CHECK-NEXT:    [[TMP10:%.*]] = getelementptr i8, ptr [[B]], i64 [[TMP6]]
+; CHECK-NEXT:    [[TMP11:%.*]] = getelementptr i8, ptr [[TMP10]], i32 0
+; CHECK-NEXT:    [[WIDE_LOAD1:%.*]] = load <vscale x 4 x i8>, ptr [[TMP11]], align 1
+; CHECK-NEXT:    [[TMP12:%.*]] = zext <vscale x 4 x i8> [[WIDE_LOAD1]] to <vscale x 4 x i32>
+; CHECK-NEXT:    [[TMP13:%.*]] = mul <vscale x 4 x i32> [[TMP12]], [[TMP9]]
+; CHECK-NEXT:    [[PARTIAL_REDUCE]] = add <vscale x 4 x i32> [[TMP13]], [[VEC_PHI]]
 ; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP5]]
-; CHECK-NEXT:    [[TMP32:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
-; CHECK-NEXT:    br i1 [[TMP32]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
+; CHECK-NEXT:    [[TMP14:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
+; CHECK-NEXT:    br i1 [[TMP14]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
 ; CHECK:       middle.block:
-; CHECK-NEXT:    [[TMP33:%.*]] = call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> [[TMP14]])
 ; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i64 0, [[N_VEC]]
+; CHECK-NEXT:    [[TMP15:%.*]] = call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> [[PARTIAL_REDUCE]])
 ; CHECK-NEXT:    br i1 [[CMP_N]], label [[FOR_COND_CLEANUP_LOOPEXIT:%.*]], label [[SCALAR_PH]]
 ; CHECK:       scalar.ph:
 ; CHECK-NEXT:    [[BC_RESUME_VAL:%.*]] = phi i64 [ [[N_VEC]], [[MIDDLE_BLOCK]] ], [ 0, [[ENTRY:%.*]] ]
-; CHECK-NEXT:    [[BC_MERGE_RDX:%.*]] = phi i32 [ 0, [[ENTRY]] ], [ [[TMP33]], [[MIDDLE_BLOCK]] ]
+; CHECK-NEXT:    [[BC_MERGE_RDX:%.*]] = phi i32 [ 0, [[ENTRY]] ], [ [[TMP15]], [[MIDDLE_BLOCK]] ]
 ; CHECK-NEXT:    br label [[FOR_BODY:%.*]]
 ; CHECK:       for.cond.cleanup.loopexit:
-; CHECK-NEXT:    [[ADD_LCSSA:%.*]] = phi i32 [ [[ADD:%.*]], [[FOR_BODY]] ], [ [[TMP33]], [[MIDDLE_BLOCK]] ]
-; CHECK-NEXT:    [[TMP20:%.*]] = lshr i32 [[ADD_LCSSA]], 0
+; CHECK-NEXT:    [[ADD_LCSSA:%.*]] = phi i32 [ [[ADD:%.*]], [[FOR_BODY]] ], [ [[TMP15]], [[MIDDLE_BLOCK]] ]
+; CHECK-NEXT:    [[TMP16:%.*]] = lshr i32 [[ADD_LCSSA]], 0
 ; CHECK-NEXT:    ret void
 ; CHECK:       for.body:
 ; CHECK-NEXT:    [[INDVARS_IV:%.*]] = phi i64 [ [[BC_RESUME_VAL]], [[SCALAR_PH]] ], [ [[INDVARS_IV_NEXT:%.*]], [[FOR_BODY]] ]
 ; CHECK-NEXT:    [[ACC_010:%.*]] = phi i32 [ [[BC_MERGE_RDX]], [[SCALAR_PH]] ], [ [[ADD]], [[FOR_BODY]] ]
 ; CHECK-NEXT:    [[ARRAYIDX:%.*]] = getelementptr i8, ptr [[A]], i64 [[INDVARS_IV]]
-; CHECK-NEXT:    [[TMP18:%.*]] = load i8, ptr [[ARRAYIDX]], align 1
-; CHECK-NEXT:    [[CONV:%.*]] = zext i8 [[TMP18]] to i32
+; CHECK-NEXT:    [[TMP17:%.*]] = load i8, ptr [[ARRAYIDX]], align 1
+; CHECK-NEXT:    [[CONV:%.*]] = zext i8 [[TMP17]] to i32
 ; CHECK-NEXT:    [[ARRAYIDX2:%.*]] = getelementptr i8, ptr [[B]], i64 [[INDVARS_IV]]
-; CHECK-NEXT:    [[TMP22:%.*]] = load i8, ptr [[ARRAYIDX2]], align 1
-; CHECK-NEXT:    [[CONV3:%.*]] = zext i8 [[TMP22]] to i32
+; CHECK-NEXT:    [[TMP18:%.*]] = load i8, ptr [[ARRAYIDX2]], align 1
+; CHECK-NEXT:    [[CONV3:%.*]] = zext i8 [[TMP18]] to i32
 ; CHECK-NEXT:    [[MUL:%.*]] = mul i32 [[CONV3]], [[CONV]]
 ; CHECK-NEXT:    [[ADD]] = add i32 [[MUL]], [[ACC_010]]
 ; CHECK-NEXT:    [[INDVARS_IV_NEXT]] = add i64 [[INDVARS_IV]], 1
@@ -86,9 +85,6 @@ for.body:                                         ; preds = %for.body, %entry
   %indvars.iv.next = add i64 %indvars.iv, 1
   %exitcond.not = icmp eq i64 %indvars.iv.next, 0
   br i1 %exitcond.not, label %for.cond.cleanup.loopexit, label %for.body
-
-; uselistorder directives
-  uselistorder i32 %add, { 1, 0 }
 }
 
 attributes #0 = { "target-features"="+fp-armv8,+fullfp16,+neon,+sve,+sve2,+v8a" }
diff --git a/llvm/test/CodeGen/AArch64/partial-reduce-sdot-ir.ll b/llvm/test/CodeGen/AArch64/partial-reduce-sdot-ir.ll
deleted file mode 100644
index 3519ba58b3df3..0000000000000
--- a/llvm/test/CodeGen/AArch64/partial-reduce-sdot-ir.ll
+++ /dev/null
@@ -1,99 +0,0 @@
-; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
-; RUN: opt -passes="default<O3>" -force-vector-interleave=1 -S < %s | FileCheck %s
-
-target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128"
-target triple = "aarch64-none-unknown-elf"
-
-define void @dotp(ptr %out, ptr %a, ptr %b, i64 %wide.trip.count) #0 {
-; CHECK-LABEL: define void @dotp(
-; CHECK-SAME: ptr nocapture writeonly [[OUT:%.*]], ptr nocapture readonly [[A:%.*]], ptr nocapture readonly [[B:%.*]], i64 [[WIDE_TRIP_COUNT:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
-; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[TMP0:%.*]] = add i64 [[WIDE_TRIP_COUNT]], 1
-; CHECK-NEXT:    [[TMP1:%.*]] = tail call i64 @llvm.vscale.i64()
-; CHECK-NEXT:    [[TMP2:%.*]] = shl i64 [[TMP1]], 4
-; CHECK-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[TMP0]], [[TMP2]]
-; CHECK-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[FOR_BODY_PREHEADER:%.*]], label [[VECTOR_PH:%.*]]
-; CHECK:       vector.ph:
-; CHECK-NEXT:    [[TMP3:%.*]] = tail call i64 @llvm.vscale.i64()
-; CHECK-NEXT:    [[TMP4:%.*]] = shl i64 [[TMP3]], 4
-; CHECK-NEXT:    [[N_MOD_VF:%.*]] = urem i64 [[TMP0]], [[TMP4]]
-; CHECK-NEXT:    [[N_VEC:%.*]] = sub i64 [[TMP0]], [[N_MOD_VF]]
-; CHECK-NEXT:    [[TMP5:%.*]] = tail call i64 @llvm.vscale.i64()
-; CHECK-NEXT:    [[TMP6:%.*]] = shl i64 [[TMP5]], 4
-; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
-; CHECK:       vector.body:
-; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
-; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi <vscale x 16 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP12:%.*]], [[VECTOR_BODY]] ]
-; CHECK-NEXT:    [[TMP7:%.*]] = getelementptr i8, ptr [[A]], i64 [[INDEX]]
-; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <vscale x 16 x i8>, ptr [[TMP7]], align 1
-; CHECK-NEXT:    [[TMP8:%.*]] = zext <vscale x 16 x i8> [[WIDE_LOAD]] to <vscale x 16 x i32>
-; CHECK-NEXT:    [[TMP9:%.*]] = getelementptr i8, ptr [[B]], i64 [[INDEX]]
-; CHECK-NEXT:    [[WIDE_LOAD1:%.*]] = load <vscale x 16 x i8>, ptr [[TMP9]], align 1
-; CHECK-NEXT:    [[TMP10:%.*]] = zext <vscale x 16 x i8> [[WIDE_LOAD1]] to <vscale x 16 x i32>
-; CHECK-NEXT:    [[TMP11:%.*]] = mul nuw nsw <vscale x 16 x i32> [[TMP10]], [[TMP8]]
-; CHECK-NEXT:    [[TMP12]] = add <vscale x 16 x i32> [[TMP11]], [[VEC_PHI]]
-; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP6]]
-; CHECK-NEXT:    [[TMP13:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
-; CHECK-NEXT:    br i1 [[TMP13]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
-; CHECK:       middle.block:
-; CHECK-NEXT:    [[TMP14:%.*]] = tail call i32 @llvm.vector.reduce.add.nxv16i32(<vscale x 16 x i32> [[TMP12]])
-; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i64 [[N_MOD_VF]], 0
-; CHECK-NEXT:    br i1 [[CMP_N]], label [[FOR_COND_CLEANUP_LOOPEXIT:%.*]], label [[FOR_BODY_PREHEADER]]
-; CHECK:       for.body.preheader:
-; CHECK-NEXT:    [[INDVARS_IV_PH:%.*]] = phi i64 [ 0, [[ENTRY:%.*]] ], [ [[N_VEC]], [[MIDDLE_BLOCK]] ]
-; CHECK-NEXT:    [[ACC_010_PH:%.*]] = phi i32 [ 0, [[ENTRY]] ], [ [[TMP14]], [[MIDDLE_BLOCK]] ]
-; CHECK-NEXT:    br label [[FOR_BODY:%.*]]
-; CHECK:       for.cond.cleanup.loopexit:
-; CHECK-NEXT:    [[ADD_LCSSA:%.*]] = phi i32 [ [[TMP14]], [[MIDDLE_BLOCK]] ], [ [[ADD:%.*]], [[FOR_BODY]] ]
-; CHECK-NEXT:    [[TMP15:%.*]] = trunc i32 [[ADD_LCSSA]] to i8
-; CHECK-NEXT:    store i8 [[TMP15]], ptr [[OUT]], align 1
-; CHECK-NEXT:    ret void
-; CHECK:       for.body:
-; CHECK-NEXT:    [[INDVARS_IV:%.*]] = phi i64 [ [[INDVARS_IV_NEXT:%.*]], [[FOR_BODY]] ], [ [[INDVARS_IV_PH]], [[FOR_BODY_PREHEADER]] ]
-; CHECK-NEXT:    [[ACC_010:%.*]] = phi i32 [ [[ADD]], [[FOR_BODY]] ], [ [[ACC_010_PH]], [[FOR_BODY_PREHEADER]] ]
-; CHECK-NEXT:    [[ARRAYIDX:%.*]] = getelementptr i8, ptr [[A]], i64 [[INDVARS_IV]]
-; CHECK-NEXT:    [[TMP16:%.*]] = load i8, ptr [[ARRAYIDX]], align 1
-; CHECK-NEXT:    [[CONV:%.*]] = zext i8 [[TMP16]] to i32
-; CHECK-NEXT:    [[ARRAYIDX2:%.*]] = getelementptr i8, ptr [[B]], i64 [[INDVARS_IV]]
-; CHECK-NEXT:    [[TMP17:%.*]] = load i8, ptr [[ARRAYIDX2]], align 1
-; CHECK-NEXT:    [[CONV3:%.*]] = zext i8 [[TMP17]] to i32
-; CHECK-NEXT:    [[MUL:%.*]] = mul nuw nsw i32 [[CONV3]], [[CONV]]
-; CHECK-NEXT:    [[ADD]] = add i32 [[MUL]], [[ACC_010]]
-; CHECK-NEXT:    [[INDVARS_IV_NEXT]] = add i64 [[INDVARS_IV]], 1
-; CHECK-NEXT:    [[EXITCOND_NOT:%.*]] = icmp eq i64 [[INDVARS_IV]], [[WIDE_TRIP_COUNT]]
-; CHECK-NEXT:    br i1 [[EXITCOND_NOT]], label [[FOR_COND_CLEANUP_LOOPEXIT]], label [[FOR_BODY]], !llvm.loop [[LOOP3:![0-9]+]]
-;
-entry:
-  br label %for.body
-
-for.cond.cleanup.loopexit:                        ; preds = %for.body
-  %0 = trunc i32 %add to i8
-  store i8 %0, ptr %out, align 1
-  ret void
-
-for.body:                                         ; preds = %for.body, %entry
-  %indvars.iv = phi i64 [ 0, %entry ], [ %indvars.iv.next, %for.body ]
-  %acc.010 = phi i32 [ 0, %entry ], [ %add, %for.body ]
-  %arrayidx = getelementptr i8, ptr %a, i64 %indvars.iv
-  %1 = load i8, ptr %arrayidx, align 1
-  %conv = zext i8 %1 to i32
-  %arrayidx2 = getelementptr i8, ptr %b, i64 %indvars.iv
-  %2 = load i8, ptr %arrayidx2, align 1
-  %conv3 = zext i8 %2 to i32
-  %mul = mul i32 %conv3, %conv
-  %add = add i32 %mul, %acc.010
-  %indvars.iv.next = add i64 %indvars.iv, 1
-  %exitcond.not = icmp eq i64 %indvars.iv, %wide.trip.count
-  br i1 %exitcond.not, label %for.cond.cleanup.loopexit, label %for.body
-
-; uselistorder directives
-  uselistorder i32 %add, { 1, 0 }
-}
-
-attributes #0 = { "target-features"="+fp-armv8,+fullfp16,+neon,+sve,+sve2,+v8a" }
-;.
-; CHECK: [[LOOP0]] = distinct !{[[LOOP0]], [[META1:![0-9]+]], [[META2:![0-9]+]]}
-; CHECK: [[META1]] = !{!"llvm.loop.isvectorized", i32 1}
-; CHECK: [[META2]] = !{!"llvm.loop.unroll.runtime.disable"}
-; CHECK: [[LOOP3]] = distinct !{[[LOOP3]], [[META2]], [[META1]]}
-;.

>From c145094a508f525a535d29c9e387b4b5843612ee Mon Sep 17 00:00:00 2001
From: Nicholas Guy <nicholas.guy at arm.com>
Date: Thu, 4 Jul 2024 16:07:43 +0100
Subject: [PATCH 10/16] Format

---
 .../llvm/Analysis/TargetTransformInfo.h       |  28 +++--
 .../llvm/Analysis/TargetTransformInfoImpl.h   |   9 +-
 llvm/lib/Analysis/TargetTransformInfo.cpp     |   4 +-
 .../Vectorize/LoopVectorizationPlanner.h      |   9 +-
 .../Transforms/Vectorize/LoopVectorize.cpp    | 110 +++++++++++-------
 .../Transforms/Vectorize/VPRecipeBuilder.h    |   4 +-
 llvm/lib/Transforms/Vectorize/VPlan.h         |  25 ++--
 .../Transforms/Vectorize/VPlanAnalysis.cpp    |   6 +-
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp |  33 +++---
 9 files changed, 129 insertions(+), 99 deletions(-)

diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 7f69c4eea715d..3990aef7f5e56 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1248,8 +1248,12 @@ class TargetTransformInfo {
   /// \return if target want to issue a prefetch in address space \p AS.
   bool shouldPrefetchAddressSpace(unsigned AS) const;
 
-  bool isPartialReductionSupported(const Instruction* ReductionInstr, Type* InputType, unsigned ScaleFactor, bool IsInputASignExtended, bool IsInputBSignExtended, const Instruction* BinOp = nullptr) const;
-    
+  bool isPartialReductionSupported(const Instruction *ReductionInstr,
+                                   Type *InputType, unsigned ScaleFactor,
+                                   bool IsInputASignExtended,
+                                   bool IsInputBSignExtended,
+                                   const Instruction *BinOp = nullptr) const;
+
   /// \return The maximum interleave factor that any transform should try to
   /// perform for this target. This number depends on the level of parallelism
   /// and the number of execution units in the CPU.
@@ -2014,11 +2018,11 @@ class TargetTransformInfo::Concept {
   /// \return if target want to issue a prefetch in address space \p AS.
   virtual bool shouldPrefetchAddressSpace(unsigned AS) const = 0;
 
-  virtual bool isPartialReductionSupported(const Instruction* ReductionInstr,
-      Type* InputType, unsigned ScaleFactor,
+  virtual bool isPartialReductionSupported(
+      const Instruction *ReductionInstr, Type *InputType, unsigned ScaleFactor,
       bool IsInputASignExtended, bool IsInputBSignExtended,
-      const Instruction* BinOp = nullptr) const = 0;
-    
+      const Instruction *BinOp = nullptr) const = 0;
+
   virtual unsigned getMaxInterleaveFactor(ElementCount VF) = 0;
   virtual InstructionCost getArithmeticInstrCost(
       unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind,
@@ -2648,11 +2652,13 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
     return Impl.shouldPrefetchAddressSpace(AS);
   }
 
-  bool isPartialReductionSupported(const Instruction* ReductionInstr, Type* InputType, unsigned ScaleFactor,
-                                              bool IsInputASignExtended, bool IsInputBSignExtended,
-                                              const Instruction* BinOp = nullptr) const override
-  {
-      return Impl.isPartialReductionSupported(ReductionInstr, InputType, ScaleFactor, IsInputASignExtended, IsInputBSignExtended, BinOp);
+  bool isPartialReductionSupported(
+      const Instruction *ReductionInstr, Type *InputType, unsigned ScaleFactor,
+      bool IsInputASignExtended, bool IsInputBSignExtended,
+      const Instruction *BinOp = nullptr) const override {
+    return Impl.isPartialReductionSupported(ReductionInstr, InputType,
+                                            ScaleFactor, IsInputASignExtended,
+                                            IsInputBSignExtended, BinOp);
   }
 
   unsigned getMaxInterleaveFactor(ElementCount VF) override {
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 11773f59170a9..5aa8e17fea94f 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -541,10 +541,11 @@ class TargetTransformInfoImplBase {
   bool enableWritePrefetching() const { return false; }
   bool shouldPrefetchAddressSpace(unsigned AS) const { return !AS; }
 
-  bool isPartialReductionSupported(const Instruction* ReductionInstr,
-      Type* InputType, unsigned ScaleFactor,
-      bool IsInputASignExtended, bool IsInputBSignExtended,
-      const Instruction* BinOp = nullptr) const {
+  bool isPartialReductionSupported(const Instruction *ReductionInstr,
+                                   Type *InputType, unsigned ScaleFactor,
+                                   bool IsInputASignExtended,
+                                   bool IsInputBSignExtended,
+                                   const Instruction *BinOp = nullptr) const {
     return false;
   }
 
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index f74f1f087e597..75784bb19e4c4 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -829,7 +829,9 @@ bool TargetTransformInfo::isPartialReductionSupported(
     const Instruction *ReductionInstr, Type *InputType, unsigned ScaleFactor,
     bool IsInputASignExtended, bool IsInputBSignExtended,
     const Instruction *BinOp) const {
-  return TTIImpl->isPartialReductionSupported(ReductionInstr, InputType, ScaleFactor, IsInputASignExtended, IsInputBSignExtended, BinOp);
+  return TTIImpl->isPartialReductionSupported(ReductionInstr, InputType,
+                                              ScaleFactor, IsInputASignExtended,
+                                              IsInputBSignExtended, BinOp);
 }
 
 unsigned TargetTransformInfo::getMaxInterleaveFactor(ElementCount VF) const {
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
index e02a00f2912eb..01f6318e13a3a 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
@@ -298,7 +298,7 @@ struct PartialReductionChain {
   Instruction *BinOp;
   Instruction *ExtendA;
   Instruction *ExtendB;
-  
+
   Value *InputA;
   Value *InputB;
   Value *Accumulator;
@@ -415,7 +415,7 @@ class LoopVectorizationPlanner {
 
   SmallVector<PartialReductionChain> getPartialReductionChains() const {
     return PartialReductionChains;
-  } 
+  }
 
 protected:
   /// Build VPlans for power-of-2 VF's between \p MinVF and \p MaxVF inclusive,
@@ -464,9 +464,8 @@ class LoopVectorizationPlanner {
   /// epilogue, assuming the main loop is vectorized by \p VF.
   bool isCandidateForEpilogueVectorization(const ElementCount VF) const;
 
-  bool getInstructionsPartialReduction(Instruction* I, PartialReductionChain &Chain) const;
-
-  
+  bool getInstructionsPartialReduction(Instruction *I,
+                                       PartialReductionChain &Chain) const;
 };
 
 } // namespace llvm
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index ba7a2a1188953..b198d20775dfa 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -2207,11 +2207,10 @@ static PartialReductionChain getPartialReductionInstrChain(Instruction *Instr) {
 
   unsigned InputSizeBits = Chain.InputA->getType()->getScalarSizeInBits();
   unsigned ResultSizeBits = Chain.Reduction->getType()->getScalarSizeInBits();
-  Chain.ScaleFactor = ResultSizeBits / InputSizeBits; 
+  Chain.ScaleFactor = ResultSizeBits / InputSizeBits;
   return Chain;
 }
 
-
 /// Checks if the given instruction the root of a partial reduction chain
 ///
 /// @param Instr The root instruction to scan
@@ -2220,64 +2219,71 @@ static bool isInstrPartialReduction(Instruction *Instr) {
   Value *A, *B;
 
   using namespace llvm::PatternMatch;
-  auto Pattern = m_BinOp(
-      m_OneUse(m_BinOp(
-        m_OneUse(m_ZExtOrSExt(m_OneUse(m_Value(A)))),
-        m_OneUse(m_ZExtOrSExt(m_OneUse(m_Value(B)))))),
-      m_Value(ExpectedPhi));
+  auto Pattern =
+      m_BinOp(m_OneUse(m_BinOp(m_OneUse(m_ZExtOrSExt(m_OneUse(m_Value(A)))),
+                               m_OneUse(m_ZExtOrSExt(m_OneUse(m_Value(B)))))),
+              m_Value(ExpectedPhi));
 
   bool Matches = match(Instr, Pattern);
 
-  if(!Matches)
+  if (!Matches)
     return false;
 
   // Check that the extends extend from the same type
-  if(A->getType() != B->getType()) {
-    LLVM_DEBUG(dbgs() << "Extends don't extend from the same type, cannot create a partial reduction.\n");
+  if (A->getType() != B->getType()) {
+    LLVM_DEBUG(dbgs() << "Extends don't extend from the same type, cannot "
+                         "create a partial reduction.\n");
     return false;
   }
 
-  // A and B are one-use, so the first user of each should be the respective extend 
+  // A and B are one-use, so the first user of each should be the respective
+  // extend
   Instruction *Ext0 = cast<CastInst>(*A->user_begin());
   Instruction *Ext1 = cast<CastInst>(*B->user_begin());
 
   // Check that the extends extend to the same type
-  if(Ext0->getType() != Ext1->getType()) {
-    LLVM_DEBUG(dbgs() << "Extends don't extend to the same type, cannot create a partial reduction.\n");
+  if (Ext0->getType() != Ext1->getType()) {
+    LLVM_DEBUG(dbgs() << "Extends don't extend to the same type, cannot create "
+                         "a partial reduction.\n");
     return false;
   }
 
   // Check that the add feeds into ExpectedPhi
   PHINode *PhiNode = dyn_cast<PHINode>(ExpectedPhi);
-  if(!PhiNode) {
-    LLVM_DEBUG(dbgs() << "Expected Phi node was not a phi, cannot create a partial reduction.\n");
+  if (!PhiNode) {
+    LLVM_DEBUG(dbgs() << "Expected Phi node was not a phi, cannot create a "
+                         "partial reduction.\n");
     return false;
   }
 
   // Check that the second phi value is the instruction we're looking at
   Instruction *MaybeAdd = dyn_cast<Instruction>(PhiNode->getIncomingValue(1));
-  if(!MaybeAdd || MaybeAdd != Instr) {
-    LLVM_DEBUG(dbgs() << "Second PHI value is not the root binop, cannot create a partial reduction.\n");
+  if (!MaybeAdd || MaybeAdd != Instr) {
+    LLVM_DEBUG(dbgs() << "Second PHI value is not the root binop, cannot "
+                         "create a partial reduction.\n");
     return false;
   }
 
   return true;
 }
 
-static bool isPartialReductionChainValid(PartialReductionChain &Chain, const TargetTransformInfo &TTI) {
-  if(Chain.Reduction->getOpcode() != Instruction::Add)
+static bool isPartialReductionChainValid(PartialReductionChain &Chain,
+                                         const TargetTransformInfo &TTI) {
+  if (Chain.Reduction->getOpcode() != Instruction::Add)
     return false;
 
   unsigned InputSizeBits = Chain.InputA->getType()->getScalarSizeInBits();
   unsigned ResultSizeBits = Chain.Reduction->getType()->getScalarSizeInBits();
 
-  if(ResultSizeBits < InputSizeBits || (ResultSizeBits % InputSizeBits) != 0)
+  if (ResultSizeBits < InputSizeBits || (ResultSizeBits % InputSizeBits) != 0)
     return false;
-  
+
   bool IsASignExtended = isa<SExtInst>(Chain.ExtendA);
   bool IsBSignExtended = isa<SExtInst>(Chain.ExtendB);
 
-  return TTI.isPartialReductionSupported(Chain.Reduction, Chain.InputA->getType(), Chain.ScaleFactor, IsASignExtended, IsBSignExtended, Chain.BinOp);
+  return TTI.isPartialReductionSupported(
+      Chain.Reduction, Chain.InputA->getType(), Chain.ScaleFactor,
+      IsASignExtended, IsBSignExtended, Chain.BinOp);
 }
 
 // Return true if \p OuterLp is an outer loop annotated with hints for explicit
@@ -5072,9 +5078,11 @@ bool LoopVectorizationPlanner::isCandidateForEpilogueVectorization(
 
   // Prevent epilogue vectorization if a partial reduction is involved
   // TODO Is there a cleaner way to check this?
-  if(any_of(Legal->getReductionVars(), [&](const std::pair<PHINode *, RecurrenceDescriptor> &Reduction) {
-    return isInstrPartialReduction(Reduction.second.getLoopExitInstr());
-  }))
+  if (any_of(Legal->getReductionVars(),
+             [&](const std::pair<PHINode *, RecurrenceDescriptor> &Reduction) {
+               return isInstrPartialReduction(
+                   Reduction.second.getLoopExitInstr());
+             }))
     return false;
 
   // Epilogue vectorization code has not been auditted to ensure it handles
@@ -5086,9 +5094,10 @@ bool LoopVectorizationPlanner::isCandidateForEpilogueVectorization(
   return true;
 }
 
-bool LoopVectorizationPlanner::getInstructionsPartialReduction(Instruction *I, PartialReductionChain &Chain) const {
-  for(auto &C : PartialReductionChains) {
-    if(C.Reduction == I) {
+bool LoopVectorizationPlanner::getInstructionsPartialReduction(
+    Instruction *I, PartialReductionChain &Chain) const {
+  for (auto &C : PartialReductionChains) {
+    if (C.Reduction == I) {
       Chain = C;
       return true;
     }
@@ -7157,7 +7166,8 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I, ElementCount VF,
   } // end of switch.
 }
 
-void LoopVectorizationCostModel::collectValuesToIgnore(LoopVectorizationPlanner* LVP) {
+void LoopVectorizationCostModel::collectValuesToIgnore(
+    LoopVectorizationPlanner *LVP) {
   // Ignore ephemeral values.
   CodeMetrics::collectEphemeralValues(TheLoop, AC, ValuesToIgnore);
 
@@ -7214,10 +7224,14 @@ void LoopVectorizationCostModel::collectValuesToIgnore(LoopVectorizationPlanner*
   }
 
   // Ignore any values that we know will be flattened
-  for(auto Chain : LVP->getPartialReductionChains()) {
-    SmallVector<Value*> PartialReductionValues{Chain.Reduction, Chain.BinOp, Chain.ExtendA, Chain.ExtendB, Chain.Accumulator};
-    ValuesToIgnore.insert(PartialReductionValues.begin(), PartialReductionValues.end());
-    VecValuesToIgnore.insert(PartialReductionValues.begin(), PartialReductionValues.end());
+  for (auto Chain : LVP->getPartialReductionChains()) {
+    SmallVector<Value *> PartialReductionValues{Chain.Reduction, Chain.BinOp,
+                                                Chain.ExtendA, Chain.ExtendB,
+                                                Chain.Accumulator};
+    ValuesToIgnore.insert(PartialReductionValues.begin(),
+                          PartialReductionValues.end());
+    VecValuesToIgnore.insert(PartialReductionValues.begin(),
+                             PartialReductionValues.end());
   }
 }
 
@@ -7345,15 +7359,15 @@ std::optional<VectorizationFactor>
 LoopVectorizationPlanner::plan(ElementCount UserVF, unsigned UserIC) {
   assert(OrigLoop->isInnermost() && "Inner loop expected.");
 
-  for(auto ReductionVar : Legal->getReductionVars()) {
+  for (auto ReductionVar : Legal->getReductionVars()) {
     auto *ReductionExitInstr = ReductionVar.second.getLoopExitInstr();
-    if(isInstrPartialReduction(ReductionExitInstr)) {
+    if (isInstrPartialReduction(ReductionExitInstr)) {
       auto Chain = getPartialReductionInstrChain(ReductionExitInstr);
-      if(isPartialReductionChainValid(Chain, TTI)) 
+      if (isPartialReductionChainValid(Chain, TTI))
         PartialReductionChains.push_back(Chain);
     }
   }
-  
+
   CM.collectValuesToIgnore(this);
   CM.collectElementTypesForWidening();
 
@@ -8591,9 +8605,13 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
   return tryToWiden(Instr, Operands, VPBB);
 }
 
-VPRecipeBase *VPRecipeBuilder::tryToCreatePartialReduction(
-    VFRange &Range, PartialReductionChain &Chain, ArrayRef<VPValue *> Operands) {
-  return new VPPartialReductionRecipe(*Chain.Reduction, make_range(Operands.begin(), Operands.end()), Chain.ScaleFactor);
+VPRecipeBase *
+VPRecipeBuilder::tryToCreatePartialReduction(VFRange &Range,
+                                             PartialReductionChain &Chain,
+                                             ArrayRef<VPValue *> Operands) {
+  return new VPPartialReductionRecipe(
+      *Chain.Reduction, make_range(Operands.begin(), Operands.end()),
+      Chain.ScaleFactor);
 }
 
 void LoopVectorizationPlanner::buildVPlansWithVPRecipes(ElementCount MinVF,
@@ -8785,11 +8803,13 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
       VPRecipeBase *Recipe = nullptr;
 
       PartialReductionChain Chain;
-      if(getInstructionsPartialReduction(Instr, Chain)) 
-        Recipe = RecipeBuilder.tryToCreatePartialReduction(Range, Chain, Operands);
-      
+      if (getInstructionsPartialReduction(Instr, Chain))
+        Recipe =
+            RecipeBuilder.tryToCreatePartialReduction(Range, Chain, Operands);
+
       if (!Recipe)
-        Recipe = RecipeBuilder.tryToCreateWidenRecipe(Instr, Operands, Range, VPBB);
+        Recipe =
+            RecipeBuilder.tryToCreateWidenRecipe(Instr, Operands, Range, VPBB);
       if (!Recipe)
         Recipe = RecipeBuilder.handleReplication(Instr, Range);
 
@@ -8811,7 +8831,7 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
         VPBB->appendRecipe(Recipe);
     }
 
-    for(auto &Recipe : *VPBB)
+    for (auto &Recipe : *VPBB)
       Recipe.postInsertionOp();
 
     VPBlockUtils::insertBlockAfter(new VPBasicBlock(), VPBB);
diff --git a/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h b/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
index 91b3a39a89959..f79da9b53441b 100644
--- a/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
+++ b/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
@@ -116,7 +116,9 @@ class VPRecipeBuilder {
                                        ArrayRef<VPValue *> Operands,
                                        VFRange &Range, VPBasicBlock *VPBB);
 
-  VPRecipeBase* tryToCreatePartialReduction(VFRange &Range, PartialReductionChain &Chain, ArrayRef<VPValue*> Operands);
+  VPRecipeBase *tryToCreatePartialReduction(VFRange &Range,
+                                            PartialReductionChain &Chain,
+                                            ArrayRef<VPValue *> Operands);
 
   /// Set the recipe created for given ingredient.
   void setRecipe(Instruction *I, VPRecipeBase *R) {
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index a3711abdef1c1..97469a6d30d5a 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -964,9 +964,7 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
     FastMathFlagsTy(const FastMathFlags &FMF);
   };
 
-public:
   OperationType OpType;
-private:
 
   union {
     CmpInst::Predicate CmpPredicate;
@@ -1923,7 +1921,6 @@ class VPReductionPHIRecipe : public VPHeaderPHIRecipe {
   unsigned VFScaleFactor = 1;
 
 public:
-
   /// Create a new VPReductionPHIRecipe for the reduction \p Phi described by \p
   /// RdxDesc.
   VPReductionPHIRecipe(PHINode *Phi, const RecurrenceDescriptor &RdxDesc,
@@ -1938,9 +1935,9 @@ class VPReductionPHIRecipe : public VPHeaderPHIRecipe {
   ~VPReductionPHIRecipe() override = default;
 
   VPReductionPHIRecipe *clone() override {
-    auto *R =
-        new VPReductionPHIRecipe(cast<PHINode>(getUnderlyingInstr()), RdxDesc,
-                                 *getOperand(0), IsInLoop, IsOrdered, VFScaleFactor);
+    auto *R = new VPReductionPHIRecipe(cast<PHINode>(getUnderlyingInstr()),
+                                       RdxDesc, *getOperand(0), IsInLoop,
+                                       IsOrdered, VFScaleFactor);
     R->addOperand(getBackedgeValue());
     return R;
   }
@@ -1951,9 +1948,7 @@ class VPReductionPHIRecipe : public VPHeaderPHIRecipe {
     return R->getVPDefID() == VPDef::VPReductionPHISC;
   }
 
-  void SetVFScaleFactor(unsigned ScaleFactor) {
-    VFScaleFactor = ScaleFactor;
-  }
+  void SetVFScaleFactor(unsigned ScaleFactor) { VFScaleFactor = ScaleFactor; }
 
   /// Generate the phi/select nodes.
   void execute(VPTransformState &State) override;
@@ -1978,15 +1973,17 @@ class VPReductionPHIRecipe : public VPHeaderPHIRecipe {
 class VPPartialReductionRecipe : public VPRecipeWithIRFlags {
   unsigned Opcode;
   unsigned Scale;
+
 public:
   template <typename IterT>
-  VPPartialReductionRecipe(Instruction &I,
-                           iterator_range<IterT> Operands, unsigned Scale) : VPRecipeWithIRFlags(
-    VPDef::VPPartialReductionSC, Operands, I), Opcode(I.getOpcode()), Scale(Scale)
-  {}
+  VPPartialReductionRecipe(Instruction &I, iterator_range<IterT> Operands,
+                           unsigned Scale)
+      : VPRecipeWithIRFlags(VPDef::VPPartialReductionSC, Operands, I),
+        Opcode(I.getOpcode()), Scale(Scale) {}
   ~VPPartialReductionRecipe() override = default;
   VPPartialReductionRecipe *clone() override {
-    auto *R = new VPPartialReductionRecipe(*getUnderlyingInstr(), operands(), Scale);
+    auto *R =
+        new VPPartialReductionRecipe(*getUnderlyingInstr(), operands(), Scale);
     R->transferFlags(*this);
     return R;
   }
diff --git a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
index f48baff51f290..50ad5dd57268d 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
@@ -223,7 +223,8 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPReplicateRecipe *R) {
   llvm_unreachable("Unhandled opcode");
 }
 
-Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPPartialReductionRecipe *R) {
+Type *
+VPTypeAnalysis::inferScalarTypeForRecipe(const VPPartialReductionRecipe *R) {
   return R->getUnderlyingInstr()->getType();
 }
 
@@ -257,7 +258,8 @@ Type *VPTypeAnalysis::inferScalarType(const VPValue *V) {
             return inferScalarType(R->getOperand(0));
           })
           .Case<VPBlendRecipe, VPInstruction, VPWidenRecipe, VPReplicateRecipe,
-                VPWidenCallRecipe, VPWidenMemoryRecipe, VPWidenSelectRecipe, VPPartialReductionRecipe>(
+                VPWidenCallRecipe, VPWidenMemoryRecipe, VPWidenSelectRecipe,
+                VPPartialReductionRecipe>(
               [this](const auto *R) { return inferScalarTypeForRecipe(R); })
           .Case<VPInterleaveRecipe>([V](const VPInterleaveRecipe *R) {
             // TODO: Use info from interleave group.
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 648823f532927..497b5b2a4a4ad 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -259,17 +259,17 @@ void VPPartialReductionRecipe::execute(VPTransformState &State) {
   State.setDebugLocFrom(getDebugLoc());
   auto &Builder = State.Builder;
 
-  switch(Opcode) {
+  switch (Opcode) {
   case Instruction::Add: {
 
     for (unsigned Part = 0; Part < State.UF; ++Part) {
-      Value* Mul = nullptr;
-      Value* Phi = nullptr;
-      SmallVector<Value*, 2> Ops;
+      Value *Mul = nullptr;
+      Value *Phi = nullptr;
+      SmallVector<Value *, 2> Ops;
       for (VPValue *VPOp : operands()) {
         auto *Op = State.get(VPOp, Part);
         Ops.push_back(Op);
-        if(isa<PHINode>(Op))
+        if (isa<PHINode>(Op))
           Phi = Op;
         else
           Mul = Op;
@@ -279,13 +279,13 @@ void VPPartialReductionRecipe::execute(VPTransformState &State) {
 
       VectorType *FullTy = cast<VectorType>(Ops[0]->getType());
       auto EC = FullTy->getElementCount();
-      Type *RetTy = VectorType::get(FullTy->getScalarType(), EC.divideCoefficientBy(Scale));
+      Type *RetTy = VectorType::get(FullTy->getScalarType(),
+                                    EC.divideCoefficientBy(Scale));
 
       Intrinsic::ID PartialIntrinsic = Intrinsic::not_intrinsic;
-      switch(Opcode) {
+      switch (Opcode) {
       case Instruction::Add:
-        PartialIntrinsic =
-            Intrinsic::experimental_vector_partial_reduce_add;
+        PartialIntrinsic = Intrinsic::experimental_vector_partial_reduce_add;
         break;
       default:
         llvm_unreachable("Opcode not handled");
@@ -293,7 +293,8 @@ void VPPartialReductionRecipe::execute(VPTransformState &State) {
 
       assert(PartialIntrinsic != Intrinsic::not_intrinsic);
 
-      CallInst *V = Builder.CreateIntrinsic(RetTy, PartialIntrinsic, {Phi, Mul}, nullptr, Twine("partial.reduce"));
+      CallInst *V = Builder.CreateIntrinsic(RetTy, PartialIntrinsic, {Phi, Mul},
+                                            nullptr, Twine("partial.reduce"));
 
       // Use this vector value for all users of the original instruction.
       State.set(this, V, Part);
@@ -302,7 +303,8 @@ void VPPartialReductionRecipe::execute(VPTransformState &State) {
     break;
   }
   default:
-    LLVM_DEBUG(dbgs() << "LV: Found an unhandled opcode : " << Instruction::getOpcodeName(Opcode));
+    LLVM_DEBUG(dbgs() << "LV: Found an unhandled opcode : "
+                      << Instruction::getOpcodeName(Opcode));
     llvm_unreachable("Unhandled instruction!");
   }
 }
@@ -313,7 +315,7 @@ void VPPartialReductionRecipe::postInsertionOp() {
 
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
 void VPPartialReductionRecipe::print(raw_ostream &O, const Twine &Indent,
-  VPSlotTracker &SlotTracker) const {
+                                     VPSlotTracker &SlotTracker) const {
   O << Indent << "PARTIAL-REDUCE ";
   printAsOperand(O, SlotTracker);
   O << " = " << Instruction::getOpcodeName(Opcode);
@@ -2112,8 +2114,8 @@ void VPReductionPHIRecipe::execute(VPTransformState &State) {
   // stage #1: We create a new vector PHI node with no incoming edges. We'll use
   // this value when we vectorize all of the instructions that use the PHI.
   bool ScalarPHI = VF.isScalar() || IsInLoop;
-  Type *VecTy = ScalarPHI ? StartV->getType()
-                          : VectorType::get(StartV->getType(), VF);
+  Type *VecTy =
+      ScalarPHI ? StartV->getType() : VectorType::get(StartV->getType(), VF);
 
   BasicBlock *HeaderBB = State.CFG.PrevBB;
   assert(State.CurrentVectorLoop->getHeader() == HeaderBB &&
@@ -2137,8 +2139,7 @@ void VPReductionPHIRecipe::execute(VPTransformState &State) {
     } else {
       IRBuilderBase::InsertPointGuard IPBuilder(Builder);
       Builder.SetInsertPoint(VectorPH->getTerminator());
-      StartV = Iden =
-          Builder.CreateVectorSplat(VF, StartV, "minmax.ident");
+      StartV = Iden = Builder.CreateVectorSplat(VF, StartV, "minmax.ident");
     }
   } else {
     Iden = RdxDesc.getRecurrenceIdentity(RK, VecTy->getScalarType(),

>From ae02a231151af15b190275e2ffaab52eb767ed29 Mon Sep 17 00:00:00 2001
From: Nicholas Guy <nicholas.guy at arm.com>
Date: Fri, 5 Jul 2024 13:36:04 +0100
Subject: [PATCH 11/16] Add TLI hook for delegating intrinsic lowering to the
 target

---
 llvm/include/llvm/CodeGen/TargetLowering.h            | 6 ++++++
 llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp | 6 ++++++
 2 files changed, 12 insertions(+)

diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 9a0df8b29d752..af6b06131b3d6 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -467,6 +467,12 @@ class TargetLoweringBase {
     return true;
   }
 
+  /// Return true if the @llvm.experimental.vector.partial.reduce.* intrinsic
+  /// should be expanded using generic code in SelectionDAGBuilder.
+  virtual bool shouldExpandPartialReductionIntrinsic(const CallInst *I) const {
+    return true;
+  }
+
   /// Return true if the @llvm.experimental.cttz.elts intrinsic should be
   /// expanded using generic code in SelectionDAGBuilder.
   virtual bool shouldExpandCttzElements(EVT VT) const { return true; }
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index cc55d53597b65..cf0f3803ee5c8 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -7978,6 +7978,12 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
     return;
   }
   case Intrinsic::experimental_vector_partial_reduce_add: {
+
+    if (!TLI.shouldExpandPartialReductionIntrinsic(&I))) {
+        visitTargetIntrinsic(I, Intrinsic);
+        return;
+      }
+
     SDValue OpNode = getValue(I.getOperand(1));
     EVT ReducedTy = EVT::getEVT(I.getType());
     EVT FullTy = OpNode.getValueType();

>From 227a83ea0faaaebc9dd824ccc0e7f153c02b6253 Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Wed, 17 Jul 2024 11:15:45 +0100
Subject: [PATCH 12/16] Add to VPSingleDefRecipe::classof

---
 llvm/lib/Transforms/Vectorize/VPlan.h | 1 +
 1 file changed, 1 insertion(+)

diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 97469a6d30d5a..40cebaf2e6af0 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -883,6 +883,7 @@ class VPSingleDefRecipe : public VPRecipeBase, public VPValue {
     case VPRecipeBase::VPWidenPointerInductionSC:
     case VPRecipeBase::VPReductionPHISC:
     case VPRecipeBase::VPScalarCastSC:
+    case VPRecipeBase::VPPartialReductionSC:
       return true;
     case VPRecipeBase::VPInterleaveSC:
     case VPRecipeBase::VPBranchOnMaskSC:

>From 98a2c91b8d51a7102a6f4289b91de8d2d92c5bc8 Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Wed, 17 Jul 2024 16:40:04 +0100
Subject: [PATCH 13/16] Fix build error

---
 llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index cf0f3803ee5c8..d4784be4f0d25 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -7979,7 +7979,7 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
   }
   case Intrinsic::experimental_vector_partial_reduce_add: {
 
-    if (!TLI.shouldExpandPartialReductionIntrinsic(&I))) {
+    if (!TLI.shouldExpandPartialReductionIntrinsic(&I)) {
         visitTargetIntrinsic(I, Intrinsic);
         return;
       }

>From 320fcfa307bc6ca861966f0d72de6bf132612cd9 Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Wed, 17 Jul 2024 16:40:50 +0100
Subject: [PATCH 14/16] Move test and add negative tests

---
 .../AArch64/partial-reduce-dot-product.ll     |  96 -------
 .../partial-reduce-dot-product.ll             | 272 ++++++++++++++++++
 2 files changed, 272 insertions(+), 96 deletions(-)
 delete mode 100644 llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll
 create mode 100644 llvm/test/Transforms/LoopVectorize/partial-reduce-dot-product.ll

diff --git a/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll
deleted file mode 100644
index 2907bf903c031..0000000000000
--- a/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll
+++ /dev/null
@@ -1,96 +0,0 @@
-; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
-; RUN: opt -passes=loop-vectorize -force-vector-interleave=1 -S < %s | FileCheck %s
-
-target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128"
-target triple = "aarch64-none-unknown-elf"
-
-define void @dotp(ptr %a, ptr %b) #0 {
-; CHECK-LABEL: define void @dotp(
-; CHECK-SAME: ptr [[A:%.*]], ptr [[B:%.*]]) #[[ATTR0:[0-9]+]] {
-; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[TMP0:%.*]] = call i64 @llvm.vscale.i64()
-; CHECK-NEXT:    [[TMP1:%.*]] = mul i64 [[TMP0]], 4
-; CHECK-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 0, [[TMP1]]
-; CHECK-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
-; CHECK:       vector.ph:
-; CHECK-NEXT:    [[TMP2:%.*]] = call i64 @llvm.vscale.i64()
-; CHECK-NEXT:    [[TMP3:%.*]] = mul i64 [[TMP2]], 4
-; CHECK-NEXT:    [[N_MOD_VF:%.*]] = urem i64 0, [[TMP3]]
-; CHECK-NEXT:    [[N_VEC:%.*]] = sub i64 0, [[N_MOD_VF]]
-; CHECK-NEXT:    [[TMP4:%.*]] = call i64 @llvm.vscale.i64()
-; CHECK-NEXT:    [[TMP5:%.*]] = mul i64 [[TMP4]], 4
-; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
-; CHECK:       vector.body:
-; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
-; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi <vscale x 4 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[PARTIAL_REDUCE:%.*]], [[VECTOR_BODY]] ]
-; CHECK-NEXT:    [[TMP6:%.*]] = add i64 [[INDEX]], 0
-; CHECK-NEXT:    [[TMP7:%.*]] = getelementptr i8, ptr [[A]], i64 [[TMP6]]
-; CHECK-NEXT:    [[TMP8:%.*]] = getelementptr i8, ptr [[TMP7]], i32 0
-; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <vscale x 4 x i8>, ptr [[TMP8]], align 1
-; CHECK-NEXT:    [[TMP9:%.*]] = zext <vscale x 4 x i8> [[WIDE_LOAD]] to <vscale x 4 x i32>
-; CHECK-NEXT:    [[TMP10:%.*]] = getelementptr i8, ptr [[B]], i64 [[TMP6]]
-; CHECK-NEXT:    [[TMP11:%.*]] = getelementptr i8, ptr [[TMP10]], i32 0
-; CHECK-NEXT:    [[WIDE_LOAD1:%.*]] = load <vscale x 4 x i8>, ptr [[TMP11]], align 1
-; CHECK-NEXT:    [[TMP12:%.*]] = zext <vscale x 4 x i8> [[WIDE_LOAD1]] to <vscale x 4 x i32>
-; CHECK-NEXT:    [[TMP13:%.*]] = mul <vscale x 4 x i32> [[TMP12]], [[TMP9]]
-; CHECK-NEXT:    [[PARTIAL_REDUCE]] = add <vscale x 4 x i32> [[TMP13]], [[VEC_PHI]]
-; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP5]]
-; CHECK-NEXT:    [[TMP14:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
-; CHECK-NEXT:    br i1 [[TMP14]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
-; CHECK:       middle.block:
-; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i64 0, [[N_VEC]]
-; CHECK-NEXT:    [[TMP15:%.*]] = call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> [[PARTIAL_REDUCE]])
-; CHECK-NEXT:    br i1 [[CMP_N]], label [[FOR_COND_CLEANUP_LOOPEXIT:%.*]], label [[SCALAR_PH]]
-; CHECK:       scalar.ph:
-; CHECK-NEXT:    [[BC_RESUME_VAL:%.*]] = phi i64 [ [[N_VEC]], [[MIDDLE_BLOCK]] ], [ 0, [[ENTRY:%.*]] ]
-; CHECK-NEXT:    [[BC_MERGE_RDX:%.*]] = phi i32 [ 0, [[ENTRY]] ], [ [[TMP15]], [[MIDDLE_BLOCK]] ]
-; CHECK-NEXT:    br label [[FOR_BODY:%.*]]
-; CHECK:       for.cond.cleanup.loopexit:
-; CHECK-NEXT:    [[ADD_LCSSA:%.*]] = phi i32 [ [[ADD:%.*]], [[FOR_BODY]] ], [ [[TMP15]], [[MIDDLE_BLOCK]] ]
-; CHECK-NEXT:    [[TMP16:%.*]] = lshr i32 [[ADD_LCSSA]], 0
-; CHECK-NEXT:    ret void
-; CHECK:       for.body:
-; CHECK-NEXT:    [[INDVARS_IV:%.*]] = phi i64 [ [[BC_RESUME_VAL]], [[SCALAR_PH]] ], [ [[INDVARS_IV_NEXT:%.*]], [[FOR_BODY]] ]
-; CHECK-NEXT:    [[ACC_010:%.*]] = phi i32 [ [[BC_MERGE_RDX]], [[SCALAR_PH]] ], [ [[ADD]], [[FOR_BODY]] ]
-; CHECK-NEXT:    [[ARRAYIDX:%.*]] = getelementptr i8, ptr [[A]], i64 [[INDVARS_IV]]
-; CHECK-NEXT:    [[TMP17:%.*]] = load i8, ptr [[ARRAYIDX]], align 1
-; CHECK-NEXT:    [[CONV:%.*]] = zext i8 [[TMP17]] to i32
-; CHECK-NEXT:    [[ARRAYIDX2:%.*]] = getelementptr i8, ptr [[B]], i64 [[INDVARS_IV]]
-; CHECK-NEXT:    [[TMP18:%.*]] = load i8, ptr [[ARRAYIDX2]], align 1
-; CHECK-NEXT:    [[CONV3:%.*]] = zext i8 [[TMP18]] to i32
-; CHECK-NEXT:    [[MUL:%.*]] = mul i32 [[CONV3]], [[CONV]]
-; CHECK-NEXT:    [[ADD]] = add i32 [[MUL]], [[ACC_010]]
-; CHECK-NEXT:    [[INDVARS_IV_NEXT]] = add i64 [[INDVARS_IV]], 1
-; CHECK-NEXT:    [[EXITCOND_NOT:%.*]] = icmp eq i64 [[INDVARS_IV_NEXT]], 0
-; CHECK-NEXT:    br i1 [[EXITCOND_NOT]], label [[FOR_COND_CLEANUP_LOOPEXIT]], label [[FOR_BODY]], !llvm.loop [[LOOP3:![0-9]+]]
-;
-entry:
-  br label %for.body
-
-for.cond.cleanup.loopexit:                        ; preds = %for.body
-  %0 = lshr i32 %add, 0
-  ret void
-
-for.body:                                         ; preds = %for.body, %entry
-  %indvars.iv = phi i64 [ 0, %entry ], [ %indvars.iv.next, %for.body ]
-  %acc.010 = phi i32 [ 0, %entry ], [ %add, %for.body ]
-  %arrayidx = getelementptr i8, ptr %a, i64 %indvars.iv
-  %1 = load i8, ptr %arrayidx, align 1
-  %conv = zext i8 %1 to i32
-  %arrayidx2 = getelementptr i8, ptr %b, i64 %indvars.iv
-  %2 = load i8, ptr %arrayidx2, align 1
-  %conv3 = zext i8 %2 to i32
-  %mul = mul i32 %conv3, %conv
-  %add = add i32 %mul, %acc.010
-  %indvars.iv.next = add i64 %indvars.iv, 1
-  %exitcond.not = icmp eq i64 %indvars.iv.next, 0
-  br i1 %exitcond.not, label %for.cond.cleanup.loopexit, label %for.body
-}
-
-attributes #0 = { "target-features"="+fp-armv8,+fullfp16,+neon,+sve,+sve2,+v8a" }
-;.
-; CHECK: [[LOOP0]] = distinct !{[[LOOP0]], [[META1:![0-9]+]], [[META2:![0-9]+]]}
-; CHECK: [[META1]] = !{!"llvm.loop.isvectorized", i32 1}
-; CHECK: [[META2]] = !{!"llvm.loop.unroll.runtime.disable"}
-; CHECK: [[LOOP3]] = distinct !{[[LOOP3]], [[META2]], [[META1]]}
-;.
diff --git a/llvm/test/Transforms/LoopVectorize/partial-reduce-dot-product.ll b/llvm/test/Transforms/LoopVectorize/partial-reduce-dot-product.ll
new file mode 100644
index 0000000000000..949dfb5f8844b
--- /dev/null
+++ b/llvm/test/Transforms/LoopVectorize/partial-reduce-dot-product.ll
@@ -0,0 +1,272 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
+; RUN: opt -passes=loop-vectorize -force-vector-interleave=1 -S < %s | FileCheck %s
+
+target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128"
+target triple = "aarch64-none-unknown-elf"
+
+define void @dotp(ptr %a, ptr %b) {
+; CHECK-LABEL: define void @dotp(
+; CHECK-SAME: ptr [[A:%.*]], ptr [[B:%.*]]) {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br i1 true, label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
+; CHECK:       vector.ph:
+; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
+; CHECK:       vector.body:
+; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi <16 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP14:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[TMP6:%.*]] = add i64 [[INDEX]], 0
+; CHECK-NEXT:    [[TMP7:%.*]] = getelementptr i8, ptr [[A]], i64 [[TMP6]]
+; CHECK-NEXT:    [[TMP8:%.*]] = getelementptr i8, ptr [[TMP7]], i32 0
+; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[TMP8]], align 1
+; CHECK-NEXT:    [[TMP3:%.*]] = zext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
+; CHECK-NEXT:    [[TMP10:%.*]] = getelementptr i8, ptr [[B]], i64 [[TMP6]]
+; CHECK-NEXT:    [[TMP11:%.*]] = getelementptr i8, ptr [[TMP10]], i32 0
+; CHECK-NEXT:    [[WIDE_LOAD1:%.*]] = load <16 x i8>, ptr [[TMP11]], align 1
+; CHECK-NEXT:    [[TMP15:%.*]] = zext <16 x i8> [[WIDE_LOAD1]] to <16 x i32>
+; CHECK-NEXT:    [[TMP16:%.*]] = mul <16 x i32> [[TMP15]], [[TMP3]]
+; CHECK-NEXT:    [[TMP14]] = add <16 x i32> [[TMP16]], [[VEC_PHI]]
+; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 16
+; CHECK-NEXT:    [[TMP9:%.*]] = icmp eq i64 [[INDEX_NEXT]], 0
+; CHECK-NEXT:    br i1 [[TMP9]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
+; CHECK:       middle.block:
+; CHECK-NEXT:    [[TMP17:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP14]])
+; CHECK-NEXT:    br i1 true, label [[FOR_COND_CLEANUP_LOOPEXIT:%.*]], label [[SCALAR_PH]]
+;
+entry:
+  br label %for.body
+
+for.cond.cleanup.loopexit:                        ; preds = %for.body
+  %0 = lshr i32 %add, 0
+  ret void
+
+for.body:                                         ; preds = %for.body, %entry
+  %indvars.iv = phi i64 [ 0, %entry ], [ %indvars.iv.next, %for.body ]
+  %acc.010 = phi i32 [ 0, %entry ], [ %add, %for.body ]
+  %arrayidx = getelementptr i8, ptr %a, i64 %indvars.iv
+  %1 = load i8, ptr %arrayidx, align 1
+  %conv = zext i8 %1 to i32
+  %arrayidx2 = getelementptr i8, ptr %b, i64 %indvars.iv
+  %2 = load i8, ptr %arrayidx2, align 1
+  %conv3 = zext i8 %2 to i32
+  %mul = mul i32 %conv3, %conv
+  %add = add i32 %mul, %acc.010
+  %indvars.iv.next = add i64 %indvars.iv, 1
+  %exitcond.not = icmp eq i64 %indvars.iv.next, 0
+  br i1 %exitcond.not, label %for.cond.cleanup.loopexit, label %for.body
+}
+
+define void @dotp_different_types(ptr %a, ptr %b) {
+; CHECK-LABEL: define void @dotp_different_types(
+; CHECK-SAME: ptr [[A:%.*]], ptr [[B:%.*]]) {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br i1 true, label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
+; CHECK:       vector.ph:
+; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
+; CHECK:       vector.body:
+; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi <16 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP69:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[TMP0:%.*]] = add i64 [[INDEX]], 0
+; CHECK-NEXT:    [[TMP1:%.*]] = add i64 [[INDEX]], 1
+; CHECK-NEXT:    [[TMP2:%.*]] = add i64 [[INDEX]], 2
+; CHECK-NEXT:    [[TMP3:%.*]] = add i64 [[INDEX]], 3
+; CHECK-NEXT:    [[TMP4:%.*]] = add i64 [[INDEX]], 4
+; CHECK-NEXT:    [[TMP5:%.*]] = add i64 [[INDEX]], 5
+; CHECK-NEXT:    [[TMP6:%.*]] = add i64 [[INDEX]], 6
+; CHECK-NEXT:    [[TMP7:%.*]] = add i64 [[INDEX]], 7
+; CHECK-NEXT:    [[TMP8:%.*]] = add i64 [[INDEX]], 8
+; CHECK-NEXT:    [[TMP9:%.*]] = add i64 [[INDEX]], 9
+; CHECK-NEXT:    [[TMP10:%.*]] = add i64 [[INDEX]], 10
+; CHECK-NEXT:    [[TMP11:%.*]] = add i64 [[INDEX]], 11
+; CHECK-NEXT:    [[TMP12:%.*]] = add i64 [[INDEX]], 12
+; CHECK-NEXT:    [[TMP13:%.*]] = add i64 [[INDEX]], 13
+; CHECK-NEXT:    [[TMP14:%.*]] = add i64 [[INDEX]], 14
+; CHECK-NEXT:    [[TMP15:%.*]] = add i64 [[INDEX]], 15
+; CHECK-NEXT:    [[TMP16:%.*]] = getelementptr i8, ptr [[A]], i64 [[TMP0]]
+; CHECK-NEXT:    [[TMP17:%.*]] = getelementptr i8, ptr [[TMP16]], i32 0
+; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[TMP17]], align 1
+; CHECK-NEXT:    [[TMP18:%.*]] = zext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
+; CHECK-NEXT:    [[TMP19:%.*]] = getelementptr i8, ptr [[B]], i64 [[TMP0]]
+; CHECK-NEXT:    [[TMP20:%.*]] = getelementptr i8, ptr [[B]], i64 [[TMP1]]
+; CHECK-NEXT:    [[TMP21:%.*]] = getelementptr i8, ptr [[B]], i64 [[TMP2]]
+; CHECK-NEXT:    [[TMP22:%.*]] = getelementptr i8, ptr [[B]], i64 [[TMP3]]
+; CHECK-NEXT:    [[TMP23:%.*]] = getelementptr i8, ptr [[B]], i64 [[TMP4]]
+; CHECK-NEXT:    [[TMP24:%.*]] = getelementptr i8, ptr [[B]], i64 [[TMP5]]
+; CHECK-NEXT:    [[TMP25:%.*]] = getelementptr i8, ptr [[B]], i64 [[TMP6]]
+; CHECK-NEXT:    [[TMP26:%.*]] = getelementptr i8, ptr [[B]], i64 [[TMP7]]
+; CHECK-NEXT:    [[TMP27:%.*]] = getelementptr i8, ptr [[B]], i64 [[TMP8]]
+; CHECK-NEXT:    [[TMP28:%.*]] = getelementptr i8, ptr [[B]], i64 [[TMP9]]
+; CHECK-NEXT:    [[TMP29:%.*]] = getelementptr i8, ptr [[B]], i64 [[TMP10]]
+; CHECK-NEXT:    [[TMP30:%.*]] = getelementptr i8, ptr [[B]], i64 [[TMP11]]
+; CHECK-NEXT:    [[TMP31:%.*]] = getelementptr i8, ptr [[B]], i64 [[TMP12]]
+; CHECK-NEXT:    [[TMP32:%.*]] = getelementptr i8, ptr [[B]], i64 [[TMP13]]
+; CHECK-NEXT:    [[TMP33:%.*]] = getelementptr i8, ptr [[B]], i64 [[TMP14]]
+; CHECK-NEXT:    [[TMP34:%.*]] = getelementptr i8, ptr [[B]], i64 [[TMP15]]
+; CHECK-NEXT:    [[TMP35:%.*]] = load i16, ptr [[TMP19]], align 2
+; CHECK-NEXT:    [[TMP36:%.*]] = load i16, ptr [[TMP20]], align 2
+; CHECK-NEXT:    [[TMP37:%.*]] = load i16, ptr [[TMP21]], align 2
+; CHECK-NEXT:    [[TMP38:%.*]] = load i16, ptr [[TMP22]], align 2
+; CHECK-NEXT:    [[TMP39:%.*]] = load i16, ptr [[TMP23]], align 2
+; CHECK-NEXT:    [[TMP40:%.*]] = load i16, ptr [[TMP24]], align 2
+; CHECK-NEXT:    [[TMP41:%.*]] = load i16, ptr [[TMP25]], align 2
+; CHECK-NEXT:    [[TMP42:%.*]] = load i16, ptr [[TMP26]], align 2
+; CHECK-NEXT:    [[TMP43:%.*]] = load i16, ptr [[TMP27]], align 2
+; CHECK-NEXT:    [[TMP44:%.*]] = load i16, ptr [[TMP28]], align 2
+; CHECK-NEXT:    [[TMP45:%.*]] = load i16, ptr [[TMP29]], align 2
+; CHECK-NEXT:    [[TMP46:%.*]] = load i16, ptr [[TMP30]], align 2
+; CHECK-NEXT:    [[TMP47:%.*]] = load i16, ptr [[TMP31]], align 2
+; CHECK-NEXT:    [[TMP48:%.*]] = load i16, ptr [[TMP32]], align 2
+; CHECK-NEXT:    [[TMP49:%.*]] = load i16, ptr [[TMP33]], align 2
+; CHECK-NEXT:    [[TMP50:%.*]] = load i16, ptr [[TMP34]], align 2
+; CHECK-NEXT:    [[TMP51:%.*]] = insertelement <16 x i16> poison, i16 [[TMP35]], i32 0
+; CHECK-NEXT:    [[TMP52:%.*]] = insertelement <16 x i16> [[TMP51]], i16 [[TMP36]], i32 1
+; CHECK-NEXT:    [[TMP53:%.*]] = insertelement <16 x i16> [[TMP52]], i16 [[TMP37]], i32 2
+; CHECK-NEXT:    [[TMP54:%.*]] = insertelement <16 x i16> [[TMP53]], i16 [[TMP38]], i32 3
+; CHECK-NEXT:    [[TMP55:%.*]] = insertelement <16 x i16> [[TMP54]], i16 [[TMP39]], i32 4
+; CHECK-NEXT:    [[TMP56:%.*]] = insertelement <16 x i16> [[TMP55]], i16 [[TMP40]], i32 5
+; CHECK-NEXT:    [[TMP57:%.*]] = insertelement <16 x i16> [[TMP56]], i16 [[TMP41]], i32 6
+; CHECK-NEXT:    [[TMP58:%.*]] = insertelement <16 x i16> [[TMP57]], i16 [[TMP42]], i32 7
+; CHECK-NEXT:    [[TMP59:%.*]] = insertelement <16 x i16> [[TMP58]], i16 [[TMP43]], i32 8
+; CHECK-NEXT:    [[TMP60:%.*]] = insertelement <16 x i16> [[TMP59]], i16 [[TMP44]], i32 9
+; CHECK-NEXT:    [[TMP61:%.*]] = insertelement <16 x i16> [[TMP60]], i16 [[TMP45]], i32 10
+; CHECK-NEXT:    [[TMP62:%.*]] = insertelement <16 x i16> [[TMP61]], i16 [[TMP46]], i32 11
+; CHECK-NEXT:    [[TMP63:%.*]] = insertelement <16 x i16> [[TMP62]], i16 [[TMP47]], i32 12
+; CHECK-NEXT:    [[TMP64:%.*]] = insertelement <16 x i16> [[TMP63]], i16 [[TMP48]], i32 13
+; CHECK-NEXT:    [[TMP65:%.*]] = insertelement <16 x i16> [[TMP64]], i16 [[TMP49]], i32 14
+; CHECK-NEXT:    [[TMP66:%.*]] = insertelement <16 x i16> [[TMP65]], i16 [[TMP50]], i32 15
+; CHECK-NEXT:    [[TMP67:%.*]] = zext <16 x i16> [[TMP66]] to <16 x i32>
+; CHECK-NEXT:    [[TMP68:%.*]] = mul <16 x i32> [[TMP67]], [[TMP18]]
+; CHECK-NEXT:    [[TMP69]] = add <16 x i32> [[TMP68]], [[VEC_PHI]]
+; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 16
+; CHECK-NEXT:    [[TMP70:%.*]] = icmp eq i64 [[INDEX_NEXT]], 0
+; CHECK-NEXT:    br i1 [[TMP70]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP4:![0-9]+]]
+; CHECK:       middle.block:
+; CHECK-NEXT:    [[TMP71:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP69]])
+; CHECK-NEXT:    br i1 true, label [[FOR_COND_CLEANUP_LOOPEXIT:%.*]], label [[SCALAR_PH]]
+;
+entry:
+  br label %for.body
+
+for.cond.cleanup.loopexit:                        ; preds = %for.body
+  %0 = lshr i32 %add, 0
+  ret void
+
+for.body:                                         ; preds = %for.body, %entry
+  %indvars.iv = phi i64 [ 0, %entry ], [ %indvars.iv.next, %for.body ]
+  %acc.010 = phi i32 [ 0, %entry ], [ %add, %for.body ]
+  %arrayidx = getelementptr i8, ptr %a, i64 %indvars.iv
+  %1 = load i8, ptr %arrayidx, align 1
+  %conv = zext i8 %1 to i32
+  %arrayidx2 = getelementptr i8, ptr %b, i64 %indvars.iv
+  %2 = load i16, ptr %arrayidx2, align 2
+  %conv3 = zext i16 %2 to i32
+  %mul = mul i32 %conv3, %conv
+  %add = add i32 %mul, %acc.010
+  %indvars.iv.next = add i64 %indvars.iv, 1
+  %exitcond.not = icmp eq i64 %indvars.iv.next, 0
+  br i1 %exitcond.not, label %for.cond.cleanup.loopexit, label %for.body
+}
+
+define void @dotp_not_loop_carried(ptr %a, ptr %b) {
+; CHECK-LABEL: define void @dotp_not_loop_carried(
+; CHECK-SAME: ptr [[A:%.*]], ptr [[B:%.*]]) {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br i1 true, label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
+; CHECK:       vector.ph:
+; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
+; CHECK:       vector.body:
+; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[VECTOR_RECUR:%.*]] = phi <16 x i32> [ <i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 0>, [[VECTOR_PH]] ], [ [[TMP7:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[TMP9:%.*]] = add i64 [[INDEX]], 0
+; CHECK-NEXT:    [[TMP10:%.*]] = getelementptr i8, ptr [[A]], i64 [[TMP9]]
+; CHECK-NEXT:    [[TMP11:%.*]] = getelementptr i8, ptr [[TMP10]], i32 0
+; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[TMP11]], align 1
+; CHECK-NEXT:    [[TMP3:%.*]] = zext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
+; CHECK-NEXT:    [[TMP13:%.*]] = getelementptr i8, ptr [[B]], i64 [[TMP9]]
+; CHECK-NEXT:    [[TMP14:%.*]] = getelementptr i8, ptr [[TMP13]], i32 0
+; CHECK-NEXT:    [[WIDE_LOAD1:%.*]] = load <16 x i8>, ptr [[TMP14]], align 1
+; CHECK-NEXT:    [[TMP6:%.*]] = zext <16 x i8> [[WIDE_LOAD1]] to <16 x i32>
+; CHECK-NEXT:    [[TMP7]] = mul <16 x i32> [[TMP6]], [[TMP3]]
+; CHECK-NEXT:    [[TMP8:%.*]] = shufflevector <16 x i32> [[VECTOR_RECUR]], <16 x i32> [[TMP7]], <16 x i32> <i32 15, i32 16, i32 17, i32 18, i32 19, i32 20, i32 21, i32 22, i32 23, i32 24, i32 25, i32 26, i32 27, i32 28, i32 29, i32 30>
+; CHECK-NEXT:    [[TMP15:%.*]] = add <16 x i32> [[TMP7]], [[TMP8]]
+; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 16
+; CHECK-NEXT:    [[TMP16:%.*]] = icmp eq i64 [[INDEX_NEXT]], 0
+; CHECK-NEXT:    br i1 [[TMP16]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP6:![0-9]+]]
+; CHECK:       middle.block:
+; CHECK-NEXT:    [[TMP17:%.*]] = extractelement <16 x i32> [[TMP15]], i32 15
+; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT:%.*]] = extractelement <16 x i32> [[TMP7]], i32 15
+; CHECK-NEXT:    br i1 true, label [[FOR_COND_CLEANUP_LOOPEXIT:%.*]], label [[SCALAR_PH]]
+;
+entry:
+  br label %for.body
+
+for.cond.cleanup.loopexit:                        ; preds = %for.body
+  %0 = lshr i32 %add, 0
+  ret void
+
+for.body:                                         ; preds = %for.body, %entry
+  %indvars.iv = phi i64 [ 0, %entry ], [ %indvars.iv.next, %for.body ]
+  %acc.010 = phi i32 [ 0, %entry ], [ %mul, %for.body ]
+  %arrayidx = getelementptr i8, ptr %a, i64 %indvars.iv
+  %1 = load i8, ptr %arrayidx, align 1
+  %conv = zext i8 %1 to i32
+  %arrayidx2 = getelementptr i8, ptr %b, i64 %indvars.iv
+  %2 = load i8, ptr %arrayidx2, align 1
+  %conv3 = zext i8 %2 to i32
+  %mul = mul i32 %conv3, %conv
+  %add = add i32 %mul, %acc.010
+  %indvars.iv.next = add i64 %indvars.iv, 1
+  %exitcond.not = icmp eq i64 %indvars.iv.next, 0
+  br i1 %exitcond.not, label %for.cond.cleanup.loopexit, label %for.body
+}
+
+define void @dotp_not_phi(ptr %a, ptr %b) {
+; CHECK-LABEL: define void @dotp_not_phi(
+; CHECK-SAME: ptr [[A:%.*]], ptr [[B:%.*]]) {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br i1 true, label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
+; CHECK:       vector.ph:
+; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
+; CHECK:       vector.body:
+; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[VECTOR_RECUR:%.*]] = phi <16 x i32> [ <i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 0>, [[VECTOR_PH]] ], [ [[TMP8:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[TMP9:%.*]] = add i64 [[INDEX]], 0
+; CHECK-NEXT:    [[TMP10:%.*]] = getelementptr i8, ptr [[A]], i64 [[TMP9]]
+; CHECK-NEXT:    [[TMP11:%.*]] = getelementptr i8, ptr [[TMP10]], i32 0
+; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[TMP11]], align 1
+; CHECK-NEXT:    [[TMP3:%.*]] = zext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
+; CHECK-NEXT:    [[TMP13:%.*]] = getelementptr i8, ptr [[B]], i64 [[TMP9]]
+; CHECK-NEXT:    [[TMP14:%.*]] = getelementptr i8, ptr [[TMP13]], i32 0
+; CHECK-NEXT:    [[WIDE_LOAD1:%.*]] = load <16 x i8>, ptr [[TMP14]], align 1
+; CHECK-NEXT:    [[TMP6:%.*]] = zext <16 x i8> [[WIDE_LOAD1]] to <16 x i32>
+; CHECK-NEXT:    [[TMP7:%.*]] = mul <16 x i32> [[TMP6]], [[TMP3]]
+; CHECK-NEXT:    [[TMP8]] = add <16 x i32> [[TMP7]], [[TMP6]]
+; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 16
+; CHECK-NEXT:    [[TMP15:%.*]] = icmp eq i64 [[INDEX_NEXT]], 0
+; CHECK-NEXT:    br i1 [[TMP15]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP8:![0-9]+]]
+; CHECK:       middle.block:
+; CHECK-NEXT:    [[TMP16:%.*]] = extractelement <16 x i32> [[TMP8]], i32 15
+; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT:%.*]] = extractelement <16 x i32> [[TMP8]], i32 15
+; CHECK-NEXT:    br i1 true, label [[FOR_COND_CLEANUP_LOOPEXIT:%.*]], label [[SCALAR_PH]]
+;
+entry:
+  br label %for.body
+
+for.cond.cleanup.loopexit:                        ; preds = %for.body
+  %0 = lshr i32 %add, 0
+  ret void
+
+for.body:                                         ; preds = %for.body, %entry
+  %indvars.iv = phi i64 [ 0, %entry ], [ %indvars.iv.next, %for.body ]
+  %acc.010 = phi i32 [ 0, %entry ], [ %add, %for.body ]
+  %arrayidx = getelementptr i8, ptr %a, i64 %indvars.iv
+  %1 = load i8, ptr %arrayidx, align 1
+  %conv = zext i8 %1 to i32
+  %arrayidx2 = getelementptr i8, ptr %b, i64 %indvars.iv
+  %2 = load i8, ptr %arrayidx2, align 1
+  %conv3 = zext i8 %2 to i32
+  %mul = mul i32 %conv3, %conv
+  %add = add i32 %mul, %conv3
+  %indvars.iv.next = add i64 %indvars.iv, 1
+  %exitcond.not = icmp eq i64 %indvars.iv.next, 0
+  br i1 %exitcond.not, label %for.cond.cleanup.loopexit, label %for.body
+}

>From b2022deca1973d1b1e6f827080c253d1caf80566 Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Thu, 18 Jul 2024 10:41:13 +0100
Subject: [PATCH 15/16] Add a printing test

---
 .../LoopVectorize/AArch64/vplan-printing.ll   | 94 +++++++++++++++++++
 1 file changed, 94 insertions(+)
 create mode 100644 llvm/test/Transforms/LoopVectorize/AArch64/vplan-printing.ll

diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/vplan-printing.ll b/llvm/test/Transforms/LoopVectorize/AArch64/vplan-printing.ll
new file mode 100644
index 0000000000000..de788f59e5f32
--- /dev/null
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/vplan-printing.ll
@@ -0,0 +1,94 @@
+; REQUIRES: asserts
+
+; RUN: opt -passes=loop-vectorize -debug-only=loop-vectorize -force-vector-interleave=1 -disable-output %s 2>&1 | FileCheck %s
+
+target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128"
+target triple = "aarch64-none-unknown-elf"
+
+; Tests for printing VPlans that are enabled under AArch64
+
+define void @print_partial_reduction(ptr %a, ptr %b) {
+; CHECK-LABEL: Checking a loop in 'print_partial_reduction'
+; CHECK:      VPlan 'Initial VPlan for VF={2,4,8,16},UF>=1' {
+; CHECK-NEXT: Live-in vp<[[VFxUF:%.]]> = VF * UF
+; CHECK-NEXT: Live-in vp<[[VEC_TC:%.+]]> = vector-trip-count
+; CHECK-NEXT: Live-in ir<0> = original trip-count
+; CHECK-EMPTY:
+; CHECK-NEXT: vector.ph:
+; CHECK-NEXT: Successor(s): vector loop
+; CHECK-EMPTY:
+; CHECK-NEXT: <x1> vector loop: {
+; CHECK-NEXT: vector.body:
+; CHECK-NEXT:   EMIT vp<[[CAN_IV:%.+]]> = CANONICAL-INDUCTION ir<0>, vp<[[CAN_IV_NEXT:%.+]]>
+; CHECK-NEXT:   WIDEN-REDUCTION-PHI ir<[[ACC:%.+]]> = phi ir<0>, ir<%add>
+; CHECK-NEXT:   vp<[[STEPS:%.+]]> = SCALAR-STEPS vp<[[CAN_IV]]>, ir<1>
+; CHECK-NEXT:   CLONE ir<%arrayidx> = getelementptr ir<%a>, vp<[[STEPS]]>
+; CHECK-NEXT:   vp<%4> = vector-pointer ir<%arrayidx>
+; CHECK-NEXT:   WIDEN ir<%1> = load vp<%4>
+; CHECK-NEXT:   WIDEN-CAST ir<%conv> = zext ir<%1> to i32
+; CHECK-NEXT:   CLONE ir<%arrayidx2> = getelementptr ir<%b>, vp<[[STEPS]]>
+; CHECK-NEXT:   vp<%5> = vector-pointer ir<%arrayidx2>
+; CHECK-NEXT:   WIDEN ir<%2> = load vp<%5>
+; CHECK-NEXT:   WIDEN-CAST ir<%conv3> = zext ir<%2> to i32
+; CHECK-NEXT:   WIDEN ir<%mul> = mul ir<%conv3>, ir<%conv>
+; CHECK-NEXT:   WIDEN ir<%add> = add ir<%mul>, ir<[[ACC]]>
+; CHECK-NEXT:   EMIT vp<[[CAN_IV_NEXT]]> = add nuw vp<[[CAN_IV]]>, vp<[[VFxUF]]>
+; CHECK-NEXT:   EMIT branch-on-count vp<[[CAN_IV_NEXT]]>, vp<[[VEC_TC]]>
+; CHECK-NEXT: No successors
+; CHECK-NEXT: }
+; CHECK-NEXT: Successor(s): middle.block
+; CHECK-EMPTY:
+; CHECK-NEXT: middle.block:
+; CHECK-NEXT:   EMIT vp<%8> = compute-reduction-result ir<[[ACC]]>, ir<%add>
+; CHECK-NEXT: No successors
+; CHECK-EMPTY:
+; CHECK-NEXT: Live-out i32 %add.lcssa = vp<%8>
+; CHECK-NEXT: }
+;
+entry:
+  br label %for.body
+
+for.cond.cleanup.loopexit:                        ; preds = %for.body
+  %0 = lshr i32 %add, 0
+  ret void
+
+for.body:                                         ; preds = %for.body, %entry
+  %indvars.iv = phi i64 [ 0, %entry ], [ %indvars.iv.next, %for.body ]
+  %acc.010 = phi i32 [ 0, %entry ], [ %add, %for.body ]
+  %arrayidx = getelementptr i8, ptr %a, i64 %indvars.iv
+  %1 = load i8, ptr %arrayidx, align 1
+  %conv = zext i8 %1 to i32
+  %arrayidx2 = getelementptr i8, ptr %b, i64 %indvars.iv
+  %2 = load i8, ptr %arrayidx2, align 1
+  %conv3 = zext i8 %2 to i32
+  %mul = mul i32 %conv3, %conv
+  %add = add i32 %mul, %acc.010
+  %indvars.iv.next = add i64 %indvars.iv, 1
+  %exitcond.not = icmp eq i64 %indvars.iv.next, 0
+  br i1 %exitcond.not, label %for.cond.cleanup.loopexit, label %for.body
+}
+
+!llvm.dbg.cu = !{!0}
+!llvm.module.flags = !{!3, !4}
+
+declare float @foo(float) #0
+declare <2 x float> @vector_foo(<2 x float>, <2 x i1>)
+
+; We need a vector variant in order to allow for vectorization at present, but
+; we want to test scalarization of conditional calls. If we provide a variant
+; with a different number of lanes than the VF we force via
+; "-force-vector-width=4", then it should pass the legality checks but
+; scalarize. TODO: Remove the requirement to have a variant.
+attributes #0 = { readonly nounwind "vector-function-abi-variant"="_ZGV_LLVM_M2v_foo(vector_foo)" }
+
+!0 = distinct !DICompileUnit(language: DW_LANG_C99, file: !1, producer: "clang", isOptimized: true, runtimeVersion: 0, emissionKind: NoDebug, enums: !2)
+!1 = !DIFile(filename: "/tmp/s.c", directory: "/tmp")
+!2 = !{}
+!3 = !{i32 2, !"Debug Info Version", i32 3}
+!4 = !{i32 7, !"PIC Level", i32 2}
+!5 = distinct !DISubprogram(name: "f", scope: !1, file: !1, line: 4, type: !6, scopeLine: 4, flags: DIFlagPrototyped, spFlags: DISPFlagDefinition | DISPFlagOptimized, unit: !0, retainedNodes: !2)
+!6 = !DISubroutineType(types: !2)
+!7 = !DILocation(line: 5, column: 3, scope: !5)
+!8 = !DILocation(line: 5, column: 21, scope: !5)
+;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
+; CHECK: {{.*}}

>From c77e82d0f95257f85451b7a27ff8de80fa2203b4 Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Thu, 18 Jul 2024 11:24:51 +0100
Subject: [PATCH 16/16] Add llvm_unreachable in partial reduction clone func

---
 llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp | 6 +++---
 llvm/lib/Transforms/Vectorize/VPlan.h                 | 6 ++----
 2 files changed, 5 insertions(+), 7 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index d4784be4f0d25..96468065f6d7e 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -7980,9 +7980,9 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
   case Intrinsic::experimental_vector_partial_reduce_add: {
 
     if (!TLI.shouldExpandPartialReductionIntrinsic(&I)) {
-        visitTargetIntrinsic(I, Intrinsic);
-        return;
-      }
+      visitTargetIntrinsic(I, Intrinsic);
+      return;
+    }
 
     SDValue OpNode = getValue(I.getOperand(1));
     EVT ReducedTy = EVT::getEVT(I.getType());
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 40cebaf2e6af0..e62ce871b38dd 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -1983,10 +1983,8 @@ class VPPartialReductionRecipe : public VPRecipeWithIRFlags {
         Opcode(I.getOpcode()), Scale(Scale) {}
   ~VPPartialReductionRecipe() override = default;
   VPPartialReductionRecipe *clone() override {
-    auto *R =
-        new VPPartialReductionRecipe(*getUnderlyingInstr(), operands(), Scale);
-    R->transferFlags(*this);
-    return R;
+    llvm_unreachable("Partial reductions with epilogue vectorization isn't supported yet.");
+    return nullptr;
   }
   VP_CLASSOF_IMPL(VPDef::VPPartialReductionSC)
   /// Generate the reduction in the loop



More information about the llvm-commits mailing list