[llvm] [LoopVectorizer] Add support for partial reductions (PR #92418)

Sam Tebbs via llvm-commits llvm-commits at lists.llvm.org
Tue Sep 24 06:11:12 PDT 2024


https://github.com/SamTebbs33 updated https://github.com/llvm/llvm-project/pull/92418

>From 16e4da01ebc3740f012fa127bd6fd438e734b8ee Mon Sep 17 00:00:00 2001
From: Nicholas Guy <nicholas.guy at arm.com>
Date: Fri, 17 May 2024 11:15:11 +0100
Subject: [PATCH 01/54] [NFC] Test pre-commit

---
 .../CodeGen/AArch64/partial-reduce-sdot.ll    | 99 +++++++++++++++++++
 1 file changed, 99 insertions(+)
 create mode 100644 llvm/test/CodeGen/AArch64/partial-reduce-sdot.ll

diff --git a/llvm/test/CodeGen/AArch64/partial-reduce-sdot.ll b/llvm/test/CodeGen/AArch64/partial-reduce-sdot.ll
new file mode 100644
index 00000000000000..fc6e3239a1b43c
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/partial-reduce-sdot.ll
@@ -0,0 +1,99 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
+; RUN: opt -passes=loop-vectorize -force-vector-interleave=1 -S < %s | FileCheck %s
+
+target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128"
+target triple = "aarch64-none-unknown-elf"
+
+define void @dotp(ptr %a, ptr %b) #0 {
+; CHECK-LABEL: define void @dotp(
+; CHECK-SAME: ptr [[A:%.*]], ptr [[B:%.*]]) #[[ATTR0:[0-9]+]] {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP1:%.*]] = mul i64 [[TMP0]], 16
+; CHECK-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 0, [[TMP1]]
+; CHECK-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
+; CHECK:       vector.ph:
+; CHECK-NEXT:    [[TMP2:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP3:%.*]] = mul i64 [[TMP2]], 16
+; CHECK-NEXT:    [[N_MOD_VF:%.*]] = urem i64 0, [[TMP3]]
+; CHECK-NEXT:    [[N_VEC:%.*]] = sub i64 0, [[N_MOD_VF]]
+; CHECK-NEXT:    [[TMP4:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP5:%.*]] = mul i64 [[TMP4]], 16
+; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
+; CHECK:       vector.body:
+; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi <vscale x 16 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP14:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[TMP11:%.*]] = add i64 [[INDEX]], 0
+; CHECK-NEXT:    [[TMP13:%.*]] = getelementptr i8, ptr [[A]], i64 [[TMP11]]
+; CHECK-NEXT:    [[TMP17:%.*]] = getelementptr i8, ptr [[TMP13]], i32 0
+; CHECK-NEXT:    [[WIDE_LOAD2:%.*]] = load <vscale x 16 x i8>, ptr [[TMP17]], align 1
+; CHECK-NEXT:    [[TMP19:%.*]] = zext <vscale x 16 x i8> [[WIDE_LOAD2]] to <vscale x 16 x i32>
+; CHECK-NEXT:    [[TMP21:%.*]] = getelementptr i8, ptr [[B]], i64 [[TMP11]]
+; CHECK-NEXT:    [[TMP25:%.*]] = getelementptr i8, ptr [[TMP21]], i32 0
+; CHECK-NEXT:    [[WIDE_LOAD4:%.*]] = load <vscale x 16 x i8>, ptr [[TMP25]], align 1
+; CHECK-NEXT:    [[TMP27:%.*]] = zext <vscale x 16 x i8> [[WIDE_LOAD4]] to <vscale x 16 x i32>
+; CHECK-NEXT:    [[TMP29:%.*]] = mul <vscale x 16 x i32> [[TMP27]], [[TMP19]]
+; CHECK-NEXT:    [[TMP14]] = add <vscale x 16 x i32> [[TMP29]], [[VEC_PHI]]
+; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP5]]
+; CHECK-NEXT:    [[TMP32:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
+; CHECK-NEXT:    br i1 [[TMP32]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
+; CHECK:       middle.block:
+; CHECK-NEXT:    [[TMP33:%.*]] = call i32 @llvm.vector.reduce.add.nxv16i32(<vscale x 16 x i32> [[TMP14]])
+; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i64 0, [[N_VEC]]
+; CHECK-NEXT:    br i1 [[CMP_N]], label [[FOR_COND_CLEANUP_LOOPEXIT:%.*]], label [[SCALAR_PH]]
+; CHECK:       scalar.ph:
+; CHECK-NEXT:    [[BC_RESUME_VAL:%.*]] = phi i64 [ [[N_VEC]], [[MIDDLE_BLOCK]] ], [ 0, [[ENTRY:%.*]] ]
+; CHECK-NEXT:    [[BC_MERGE_RDX:%.*]] = phi i32 [ 0, [[ENTRY]] ], [ [[TMP33]], [[MIDDLE_BLOCK]] ]
+; CHECK-NEXT:    br label [[FOR_BODY:%.*]]
+; CHECK:       for.cond.cleanup.loopexit:
+; CHECK-NEXT:    [[ADD_LCSSA:%.*]] = phi i32 [ [[ADD:%.*]], [[FOR_BODY]] ], [ [[TMP33]], [[MIDDLE_BLOCK]] ]
+; CHECK-NEXT:    [[TMP20:%.*]] = lshr i32 [[ADD_LCSSA]], 0
+; CHECK-NEXT:    ret void
+; CHECK:       for.body:
+; CHECK-NEXT:    [[INDVARS_IV:%.*]] = phi i64 [ [[BC_RESUME_VAL]], [[SCALAR_PH]] ], [ [[INDVARS_IV_NEXT:%.*]], [[FOR_BODY]] ]
+; CHECK-NEXT:    [[ACC_010:%.*]] = phi i32 [ [[BC_MERGE_RDX]], [[SCALAR_PH]] ], [ [[ADD]], [[FOR_BODY]] ]
+; CHECK-NEXT:    [[ARRAYIDX:%.*]] = getelementptr i8, ptr [[A]], i64 [[INDVARS_IV]]
+; CHECK-NEXT:    [[TMP18:%.*]] = load i8, ptr [[ARRAYIDX]], align 1
+; CHECK-NEXT:    [[CONV:%.*]] = zext i8 [[TMP18]] to i32
+; CHECK-NEXT:    [[ARRAYIDX2:%.*]] = getelementptr i8, ptr [[B]], i64 [[INDVARS_IV]]
+; CHECK-NEXT:    [[TMP22:%.*]] = load i8, ptr [[ARRAYIDX2]], align 1
+; CHECK-NEXT:    [[CONV3:%.*]] = zext i8 [[TMP22]] to i32
+; CHECK-NEXT:    [[MUL:%.*]] = mul i32 [[CONV3]], [[CONV]]
+; CHECK-NEXT:    [[ADD]] = add i32 [[MUL]], [[ACC_010]]
+; CHECK-NEXT:    [[INDVARS_IV_NEXT]] = add i64 [[INDVARS_IV]], 1
+; CHECK-NEXT:    [[EXITCOND_NOT:%.*]] = icmp eq i64 [[INDVARS_IV_NEXT]], 0
+; CHECK-NEXT:    br i1 [[EXITCOND_NOT]], label [[FOR_COND_CLEANUP_LOOPEXIT]], label [[FOR_BODY]], !llvm.loop [[LOOP3:![0-9]+]]
+;
+entry:
+  br label %for.body
+
+for.cond.cleanup.loopexit:                        ; preds = %for.body
+  %0 = lshr i32 %add, 0
+  ret void
+
+for.body:                                         ; preds = %for.body, %entry
+  %indvars.iv = phi i64 [ 0, %entry ], [ %indvars.iv.next, %for.body ]
+  %acc.010 = phi i32 [ 0, %entry ], [ %add, %for.body ]
+  %arrayidx = getelementptr i8, ptr %a, i64 %indvars.iv
+  %1 = load i8, ptr %arrayidx, align 1
+  %conv = zext i8 %1 to i32
+  %arrayidx2 = getelementptr i8, ptr %b, i64 %indvars.iv
+  %2 = load i8, ptr %arrayidx2, align 1
+  %conv3 = zext i8 %2 to i32
+  %mul = mul i32 %conv3, %conv
+  %add = add i32 %mul, %acc.010
+  %indvars.iv.next = add i64 %indvars.iv, 1
+  %exitcond.not = icmp eq i64 %indvars.iv.next, 0
+  br i1 %exitcond.not, label %for.cond.cleanup.loopexit, label %for.body
+
+; uselistorder directives
+  uselistorder i32 %add, { 1, 0 }
+}
+
+attributes #0 = { "target-features"="+fp-armv8,+fullfp16,+neon,+sve,+sve2,+v8a" }
+;.
+; CHECK: [[LOOP0]] = distinct !{[[LOOP0]], [[META1:![0-9]+]], [[META2:![0-9]+]]}
+; CHECK: [[META1]] = !{!"llvm.loop.isvectorized", i32 1}
+; CHECK: [[META2]] = !{!"llvm.loop.unroll.runtime.disable"}
+; CHECK: [[LOOP3]] = distinct !{[[LOOP3]], [[META2]], [[META1]]}
+;.

>From faee9ed664bee0d166cdd5339b352cae49dd27be Mon Sep 17 00:00:00 2001
From: Nicholas Guy <nicholas.guy at arm.com>
Date: Fri, 17 May 2024 11:17:26 +0100
Subject: [PATCH 02/54] [LoopVectorizer] Add support for partial reductions

---
 llvm/include/llvm/IR/DerivedTypes.h           |  10 ++
 llvm/include/llvm/IR/Intrinsics.h             |   5 +-
 llvm/include/llvm/IR/Intrinsics.td            |   4 +
 llvm/lib/IR/Function.cpp                      |  16 +++
 .../Transforms/Vectorize/LoopVectorize.cpp    | 122 ++++++++++++++++++
 .../Transforms/Vectorize/VPRecipeBuilder.h    |   2 +
 llvm/lib/Transforms/Vectorize/VPlan.h         |  43 +++++-
 .../Transforms/Vectorize/VPlanAnalysis.cpp    |   6 +-
 llvm/lib/Transforms/Vectorize/VPlanAnalysis.h |   2 +
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp |  79 +++++++++++-
 llvm/lib/Transforms/Vectorize/VPlanValue.h    |   1 +
 .../CodeGen/AArch64/partial-reduce-sdot.ll    |   7 +-
 12 files changed, 285 insertions(+), 12 deletions(-)

diff --git a/llvm/include/llvm/IR/DerivedTypes.h b/llvm/include/llvm/IR/DerivedTypes.h
index 975c142f1a4572..eb98af66de4c54 100644
--- a/llvm/include/llvm/IR/DerivedTypes.h
+++ b/llvm/include/llvm/IR/DerivedTypes.h
@@ -513,6 +513,16 @@ class VectorType : public Type {
                            EltCnt.divideCoefficientBy(2));
   }
 
+  /// This static method returns a VectorType with a quarter as many elements
+  /// as the input type and the same element type.
+  static VectorType *getQuarterElementsVectorType(VectorType *VTy) {
+    auto EltCnt = VTy->getElementCount();
+    assert(EltCnt.isKnownMultipleOf(4) &&
+           "Cannot quarter vector with element count not divisible by 4.");
+    return VectorType::get(VTy->getElementType(),
+                           EltCnt.divideCoefficientBy(4));
+  }
+
   /// This static method returns a VectorType with twice as many elements as the
   /// input type and the same element type.
   static VectorType *getDoubleElementsVectorType(VectorType *VTy) {
diff --git a/llvm/include/llvm/IR/Intrinsics.h b/llvm/include/llvm/IR/Intrinsics.h
index 4bd7fda77f3132..3038eb8dd6af6c 100644
--- a/llvm/include/llvm/IR/Intrinsics.h
+++ b/llvm/include/llvm/IR/Intrinsics.h
@@ -133,6 +133,7 @@ namespace Intrinsic {
       ExtendArgument,
       TruncArgument,
       HalfVecArgument,
+      QuarterVecArgument,
       SameVecWidthArgument,
       VecOfAnyPtrsToElt,
       VecElementArgument,
@@ -162,7 +163,7 @@ namespace Intrinsic {
 
     unsigned getArgumentNumber() const {
       assert(Kind == Argument || Kind == ExtendArgument ||
-             Kind == TruncArgument || Kind == HalfVecArgument ||
+             Kind == TruncArgument || Kind == HalfVecArgument || Kind == QuarterVecArgument ||
              Kind == SameVecWidthArgument || Kind == VecElementArgument ||
              Kind == Subdivide2Argument || Kind == Subdivide4Argument ||
              Kind == VecOfBitcastsToInt);
@@ -170,7 +171,7 @@ namespace Intrinsic {
     }
     ArgKind getArgumentKind() const {
       assert(Kind == Argument || Kind == ExtendArgument ||
-             Kind == TruncArgument || Kind == HalfVecArgument ||
+             Kind == TruncArgument || Kind == HalfVecArgument || Kind == QuarterVecArgument ||
              Kind == SameVecWidthArgument ||
              Kind == VecElementArgument || Kind == Subdivide2Argument ||
              Kind == Subdivide4Argument || Kind == VecOfBitcastsToInt);
diff --git a/llvm/include/llvm/IR/Intrinsics.td b/llvm/include/llvm/IR/Intrinsics.td
index 0a74a217a5f010..d872cd613ba6b7 100644
--- a/llvm/include/llvm/IR/Intrinsics.td
+++ b/llvm/include/llvm/IR/Intrinsics.td
@@ -321,6 +321,7 @@ def IIT_I4 : IIT_Int<4, 58>;
 def IIT_AARCH64_SVCOUNT : IIT_VT<aarch64svcount, 59>;
 def IIT_V6 : IIT_Vec<6, 60>;
 def IIT_V10 : IIT_Vec<10, 61>;
+def IIT_QUARTER_VEC_ARG : IIT_Base<62>;
 }
 
 defvar IIT_all_FixedTypes = !filter(iit, IIT_all,
@@ -457,6 +458,9 @@ class LLVMVectorElementType<int num> : LLVMMatchType<num, IIT_VEC_ELEMENT>;
 class LLVMHalfElementsVectorType<int num>
   : LLVMMatchType<num, IIT_HALF_VEC_ARG>;
 
+class LLVMQuarterElementsVectorType<int num>
+  : LLVMMatchType<num, IIT_QUARTER_VEC_ARG>;
+
 // Match the type of another intrinsic parameter that is expected to be a
 // vector type (i.e. <N x iM>) but with each element subdivided to
 // form a vector with more elements that are smaller than the original.
diff --git a/llvm/lib/IR/Function.cpp b/llvm/lib/IR/Function.cpp
index 8767c2971f62c8..00bcfdf8c45241 100644
--- a/llvm/lib/IR/Function.cpp
+++ b/llvm/lib/IR/Function.cpp
@@ -1317,6 +1317,12 @@ static void DecodeIITType(unsigned &NextElt, ArrayRef<unsigned char> Infos,
                                              ArgInfo));
     return;
   }
+  case IIT_QUARTER_VEC_ARG: {
+    unsigned ArgInfo = (NextElt == Infos.size() ? 0 : Infos[NextElt++]);
+    OutputTable.push_back(IITDescriptor::get(IITDescriptor::QuarterVecArgument,
+                                             ArgInfo));
+    return;
+  }
   case IIT_SAME_VEC_WIDTH_ARG: {
     unsigned ArgInfo = (NextElt == Infos.size() ? 0 : Infos[NextElt++]);
     OutputTable.push_back(IITDescriptor::get(IITDescriptor::SameVecWidthArgument,
@@ -1484,6 +1490,9 @@ static Type *DecodeFixedType(ArrayRef<Intrinsic::IITDescriptor> &Infos,
   case IITDescriptor::HalfVecArgument:
     return VectorType::getHalfElementsVectorType(cast<VectorType>(
                                                   Tys[D.getArgumentNumber()]));
+  case IITDescriptor::QuarterVecArgument:  {
+    return VectorType::getQuarterElementsVectorType(cast<VectorType>(Tys[D.getArgumentNumber()]));
+  }
   case IITDescriptor::SameVecWidthArgument: {
     Type *EltTy = DecodeFixedType(Infos, Tys, Context);
     Type *Ty = Tys[D.getArgumentNumber()];
@@ -1715,6 +1724,13 @@ static bool matchIntrinsicType(
       return !isa<VectorType>(ArgTys[D.getArgumentNumber()]) ||
              VectorType::getHalfElementsVectorType(
                      cast<VectorType>(ArgTys[D.getArgumentNumber()])) != Ty;
+    case IITDescriptor::QuarterVecArgument: {
+    if (D.getArgumentNumber() >= ArgTys.size())
+        return IsDeferredCheck || DeferCheck(Ty);
+      return !isa<VectorType>(ArgTys[D.getArgumentNumber()]) ||
+             VectorType::getQuarterElementsVectorType(
+                     cast<VectorType>(ArgTys[D.getArgumentNumber()])) != Ty;
+    }
     case IITDescriptor::SameVecWidthArgument: {
       if (D.getArgumentNumber() >= ArgTys.size()) {
         // Defer check and subsequent check for the vector element type.
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 9685e7d124b7d1..62e6f1af684675 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -2113,6 +2113,92 @@ static bool useActiveLaneMaskForControlFlow(TailFoldingStyle Style) {
          Style == TailFoldingStyle::DataAndControlFlowWithoutRuntimeCheck;
 }
 
+/// Append to \p Chain the multiply, its two zero-extends and the accumulator
+/// operand of \p Instr, which must satisfy isInstrPartialReduction.
+static void getPartialReductionInstrChain(Instruction *Instr,
+                                          SmallVectorImpl<Value *> &Chain) {
+  auto *Mul = cast<Instruction>(Instr->getOperand(0));
+  Chain.push_back(Mul);
+  Chain.push_back(cast<ZExtInst>(Mul->getOperand(0)));
+  Chain.push_back(cast<ZExtInst>(Mul->getOperand(1)));
+  Chain.push_back(Instr->getOperand(1));
+}
+
+
+/// Return true if \p Instr is the root of a reduction of the form
+///   add(mul(zext(load (gep A, Idx)), zext(load (gep B, Idx))), phi)
+/// where the loads are i8, the extends produce i32, the mul/zext/load values
+/// each have one use, and phi has exactly two inputs: constant zero and Instr.
+/// @param Instr The root instruction to scan; null is rejected.
+static bool isInstrPartialReduction(Instruction *Instr) {
+  using namespace llvm::PatternMatch;
+
+  // Callers pass a recurrence's loop-exit instruction; tolerate a missing one.
+  if (!Instr)
+    return false;
+
+  Value *A, *B, *InductionA, *InductionB, *ExpectedPhi;
+  auto Pattern = m_Add(
+      m_OneUse(m_Mul(
+          m_OneUse(m_ZExt(
+              m_OneUse(m_Load(m_GEP(m_Value(A), m_Value(InductionA)))))),
+          m_OneUse(m_ZExt(
+              m_OneUse(m_Load(m_GEP(m_Value(B), m_Value(InductionB)))))))),
+      m_Value(ExpectedPhi));
+  if (!match(Instr, Pattern))
+    return false;
+
+  // Check that the two induction variable uses are to the same induction variable
+  if (InductionA != InductionB) {
+    LLVM_DEBUG(dbgs() << "Loop uses different induction variables for each input variable, cannot create a partial reduction.\n");
+    return false;
+  }
+
+  Instruction *Mul = cast<Instruction>(Instr->getOperand(0));
+  Instruction *Ext0 = cast<ZExtInst>(Mul->getOperand(0));
+  Instruction *Ext1 = cast<ZExtInst>(Mul->getOperand(1));
+
+  // Check that the extends extend to i32
+  if (!Ext0->getType()->isIntegerTy(32) || !Ext1->getType()->isIntegerTy(32)) {
+    LLVM_DEBUG(dbgs() << "Extends don't extend to the correct width, cannot create a partial reduction.\n");
+    return false;
+  }
+
+  // Check that the loads are loading i8
+  LoadInst *Load0 = cast<LoadInst>(Ext0->getOperand(0));
+  LoadInst *Load1 = cast<LoadInst>(Ext1->getOperand(0));
+  if (!Load0->getType()->isIntegerTy(8) || !Load1->getType()->isIntegerTy(8)) {
+    LLVM_DEBUG(dbgs() << "Loads don't load the correct width, cannot create a partial reduction\n");
+    return false;
+  }
+
+  // Check that the add feeds into ExpectedPhi
+  PHINode *PhiNode = dyn_cast<PHINode>(ExpectedPhi);
+  if (!PhiNode) {
+    LLVM_DEBUG(dbgs() << "Expected Phi node was not a phi, cannot create a partial reduction.\n");
+    return false;
+  }
+
+  // The order of a phi's incoming values carries no meaning in LLVM IR, so find
+  // the edge that carries the add rather than assuming it is the second one.
+  if (PhiNode->getNumIncomingValues() != 2 ||
+      (PhiNode->getIncomingValue(0) != Instr &&
+       PhiNode->getIncomingValue(1) != Instr)) {
+    LLVM_DEBUG(dbgs() << "PHI is not fed by the root add, cannot create a partial reduction.\n");
+    return false;
+  }
+
+  // Check that the other phi value, the start of the reduction, is zero
+  unsigned StartIdx = PhiNode->getIncomingValue(0) == Instr ? 1 : 0;
+  auto *ZeroInit = dyn_cast<ConstantInt>(PhiNode->getIncomingValue(StartIdx));
+  if (!ZeroInit || !ZeroInit->isZero()) {
+    LLVM_DEBUG(dbgs() << "PHI start value is not a constant zero, cannot create a partial reduction.\n");
+    return false;
+  }
+
+  return true;
+}
+
 // Return true if \p OuterLp is an outer loop annotated with hints for explicit
 // vectorization. The loop needs to be annotated with #pragma omp simd
 // simdlen(#) or #pragma clang vectorize(enable) vectorize_width(#). If the
@@ -4632,6 +4718,13 @@ bool LoopVectorizationPlanner::isCandidateForEpilogueVectorization(
         return false;
   }
 
+  // Prevent epilogue vectorization if a partial reduction is involved
+  // TODO Is there a cleaner way to check this?
+  if(any_of(Legal->getReductionVars(), [&](const std::pair<PHINode *, RecurrenceDescriptor> &Reduction) {
+    return isInstrPartialReduction(Reduction.second.getLoopExitInstr());
+  }))
+    return false;
+
   // Epilogue vectorization code has not been auditted to ensure it handles
   // non-latch exits properly.  It may be fine, but it needs auditted and
   // tested.
@@ -6891,6 +6984,17 @@ void LoopVectorizationCostModel::collectValuesToIgnore() {
     const SmallVectorImpl<Instruction *> &Casts = IndDes.getCastInsts();
     VecValuesToIgnore.insert(Casts.begin(), Casts.end());
   }
+
+  // Ignore any values that we know will be flattened
+  for(auto Reduction : this->Legal->getReductionVars()) {
+    auto &Recurrence = Reduction.second;
+    if(isInstrPartialReduction(Recurrence.getLoopExitInstr())) {
+      SmallVector<Value*, 4> PartialReductionValues;
+      getPartialReductionInstrChain(Recurrence.getLoopExitInstr(), PartialReductionValues);
+      ValuesToIgnore.insert(PartialReductionValues.begin(), PartialReductionValues.end());
+      VecValuesToIgnore.insert(PartialReductionValues.begin(), PartialReductionValues.end());
+    }
+  }
 }
 
 void LoopVectorizationCostModel::collectInLoopReductions() {
@@ -8573,9 +8677,24 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
                                  *CI);
   }
 
+  if(auto *PartialReduce = tryToCreatePartialReduction(Range, Instr, Operands))
+    return PartialReduce;
+
   return tryToWiden(Instr, Operands, VPBB);
 }
 
+VPRecipeBase *VPRecipeBuilder::tryToCreatePartialReduction(
+    VFRange &Range, Instruction *Instr, ArrayRef<VPValue *> Operands) {
+  // Only <vscale x 16> is supported; clamp Range so all of its VFs agree.
+  auto IsSupportedVF = [](ElementCount VF) {
+    return VF == ElementCount::getScalable(16);
+  };
+  if (!isInstrPartialReduction(Instr) ||
+      !LoopVectorizationPlanner::getDecisionAndClampRange(IsSupportedVF, Range))
+    return nullptr;
+  return new VPPartialReductionRecipe(*Instr, make_range(Operands.begin(), Operands.end()));
+}
+
 void LoopVectorizationPlanner::buildVPlansWithVPRecipes(ElementCount MinVF,
                                                         ElementCount MaxVF) {
   assert(OrigLoop->isInnermost() && "Inner loop expected.");
@@ -8982,6 +9101,9 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
         VPBB->appendRecipe(Recipe);
     }
 
+    for(auto &Recipe : *VPBB)
+      Recipe.postInsertionOp();
+
     VPBlockUtils::insertBlockAfter(new VPBasicBlock(), VPBB);
     VPBB = cast<VPBasicBlock>(VPBB->getSingleSuccessor());
   }
diff --git a/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h b/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
index 25b8bf3e089e54..92724daaa26912 100644
--- a/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
+++ b/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
@@ -117,6 +117,8 @@ class VPRecipeBuilder {
                                        ArrayRef<VPValue *> Operands,
                                        VFRange &Range, VPBasicBlock *VPBB);
 
+  VPRecipeBase* tryToCreatePartialReduction(VFRange &Range, Instruction* Instr, ArrayRef<VPValue*> Operands);
+
   /// Set the recipe created for given ingredient.
   void setRecipe(Instruction *I, VPRecipeBase *R) {
     assert(!Ingredient2Recipe.contains(I) &&
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index fda0a8907b4ab0..3f11a5281efe5d 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -831,6 +831,8 @@ class VPRecipeBase : public ilist_node_with_parent<VPRecipeBase, VPBasicBlock>,
   /// \returns an iterator pointing to the element after the erased one
   iplist<VPRecipeBase>::iterator eraseFromParent();
 
+  virtual void postInsertionOp() {}
+
   /// Method to support type inquiry through isa, cast, and dyn_cast.
   static inline bool classof(const VPDef *D) {
     // All VPDefs are also VPRecipeBases.
@@ -2138,14 +2140,19 @@ class VPReductionPHIRecipe : public VPHeaderPHIRecipe,
   /// The phi is part of an ordered reduction. Requires IsInLoop to be true.
   bool IsOrdered;
 
+  /// The amount that the VF should be divided by during ::execute
+  unsigned VFScaleFactor = 1;
+
 public:
+
   /// Create a new VPReductionPHIRecipe for the reduction \p Phi described by \p
   /// RdxDesc.
   VPReductionPHIRecipe(PHINode *Phi, const RecurrenceDescriptor &RdxDesc,
                        VPValue &Start, bool IsInLoop = false,
-                       bool IsOrdered = false)
+                       bool IsOrdered = false, unsigned VFScaleFactor = 1)
       : VPHeaderPHIRecipe(VPDef::VPReductionPHISC, Phi, &Start),
-        RdxDesc(RdxDesc), IsInLoop(IsInLoop), IsOrdered(IsOrdered) {
+        RdxDesc(RdxDesc), IsInLoop(IsInLoop), IsOrdered(IsOrdered),
+        VFScaleFactor(VFScaleFactor) {
     assert((!IsOrdered || IsInLoop) && "IsOrdered requires IsInLoop");
   }
 
@@ -2154,7 +2161,7 @@ class VPReductionPHIRecipe : public VPHeaderPHIRecipe,
   VPReductionPHIRecipe *clone() override {
     auto *R =
         new VPReductionPHIRecipe(cast<PHINode>(getUnderlyingInstr()), RdxDesc,
-                                 *getOperand(0), IsInLoop, IsOrdered);
+                                 *getOperand(0), IsInLoop, IsOrdered, VFScaleFactor);
     R->addOperand(getBackedgeValue());
     return R;
   }
@@ -2165,6 +2172,10 @@ class VPReductionPHIRecipe : public VPHeaderPHIRecipe,
     return R->getVPDefID() == VPDef::VPReductionPHISC;
   }
 
+  void SetVFScaleFactor(unsigned ScaleFactor) {
+    VFScaleFactor = ScaleFactor;
+  }
+
   /// Generate the phi/select nodes.
   void execute(VPTransformState &State) override;
 
@@ -2185,6 +2196,32 @@ class VPReductionPHIRecipe : public VPHeaderPHIRecipe,
   bool isInLoop() const { return IsInLoop; }
 };
 
+/// A recipe that lowers a reduction add to a partial-reduction intrinsic call.
+class VPPartialReductionRecipe : public VPRecipeWithIRFlags {
+  unsigned Opcode;
+public:
+  template <typename IterT>
+  VPPartialReductionRecipe(Instruction &I, iterator_range<IterT> Operands)
+      : VPRecipeWithIRFlags(VPDef::VPPartialReductionSC, Operands, I),
+        Opcode(I.getOpcode()) {}
+  ~VPPartialReductionRecipe() override = default;
+  VPPartialReductionRecipe *clone() override {
+    auto *R = new VPPartialReductionRecipe(*getUnderlyingInstr(), operands());
+    R->transferFlags(*this);
+    return R;
+  }
+  VP_CLASSOF_IMPL(VPDef::VPPartialReductionSC)
+  /// Generate the partial reduction in the vector loop.
+  void execute(VPTransformState &State) override;
+  void postInsertionOp() override;
+  unsigned getOpcode() const { return Opcode; }
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+  /// Print the recipe.
+  void print(raw_ostream &O, const Twine &Indent,
+             VPSlotTracker &SlotTracker) const override;
+#endif
+};
+
 /// A recipe for vectorizing a phi-node as a sequence of mask-based select
 /// instructions.
 class VPBlendRecipe : public VPSingleDefRecipe {
diff --git a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
index 277df0637372d8..8d6f51ca1b3b86 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
@@ -233,6 +233,10 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPReplicateRecipe *R) {
   llvm_unreachable("Unhandled opcode");
 }
 
+Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPPartialReductionRecipe *R) {
+  return R->getUnderlyingInstr()->getType(); // Type of the IR instruction the recipe wraps.
+}
+
 Type *VPTypeAnalysis::inferScalarType(const VPValue *V) {
   if (Type *CachedTy = CachedTypes.lookup(V))
     return CachedTy;
@@ -266,7 +270,7 @@ Type *VPTypeAnalysis::inferScalarType(const VPValue *V) {
           })
           .Case<VPBlendRecipe, VPInstruction, VPWidenRecipe, VPWidenEVLRecipe,
                 VPReplicateRecipe, VPWidenCallRecipe, VPWidenMemoryRecipe,
-                VPWidenSelectRecipe>(
+                VPWidenSelectRecipe, VPPartialReductionRecipe>(
               [this](const auto *R) { return inferScalarTypeForRecipe(R); })
           .Case<VPInterleaveRecipe>([V](const VPInterleaveRecipe *R) {
             // TODO: Use info from interleave group.
diff --git a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h
index cc21870bee2e3b..a34d9629eff9dd 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h
@@ -27,6 +27,7 @@ struct VPWidenSelectRecipe;
 class VPReplicateRecipe;
 class VPRecipeBase;
 class VPlan;
+class VPPartialReductionRecipe;
 class Type;
 
 /// An analysis for type-inference for VPValues.
@@ -53,6 +54,7 @@ class VPTypeAnalysis {
   Type *inferScalarTypeForRecipe(const VPWidenMemoryRecipe *R);
   Type *inferScalarTypeForRecipe(const VPWidenSelectRecipe *R);
   Type *inferScalarTypeForRecipe(const VPReplicateRecipe *R);
+  Type *inferScalarTypeForRecipe(const VPPartialReductionRecipe *R);
 
 public:
   VPTypeAnalysis(Type *CanonicalIVTy)
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 8f4b2951839118..0ce350694508bb 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -318,6 +318,77 @@ InstructionCost VPRecipeBase::computeCost(ElementCount VF,
   return UI ? Ctx.getLegacyCost(UI, VF) : 0;
 }
 
+/// Emit the partial reduction for every unrolled part: the wide input
+/// (operand 0) is narrowed by the partial-reduce intrinsic and the narrowed
+/// value is then combined with the reduction phi (operand 1) using the
+/// recipe's opcode. Only Instruction::Add is supported; any other opcode is
+/// treated as unreachable.
+void VPPartialReductionRecipe::execute(VPTransformState &State) {
+  State.setDebugLocFrom(getDebugLoc());
+  auto &Builder = State.Builder;
+
+  // Map the opcode to its intrinsic once, before any IR is emitted.
+  Intrinsic::ID PartialIntrinsic = Intrinsic::not_intrinsic;
+  switch (Opcode) {
+  case Instruction::Add:
+    PartialIntrinsic = Intrinsic::experimental_vector_partial_reduce_add;
+    break;
+  default:
+    LLVM_DEBUG(dbgs() << "LV: Found an unhandled opcode : "
+                      << Instruction::getOpcodeName(Opcode));
+    llvm_unreachable("Unhandled instruction!");
+  }
+  assert(PartialIntrinsic != Intrinsic::not_intrinsic);
+
+  unsigned UF = getParent()->getPlan()->getUF();
+  for (unsigned Part = 0; Part < UF; ++Part) {
+    // postInsertionOp() already relies on operand 1 being the reduction phi,
+    // so index the operands directly instead of guessing their roles from the
+    // kind of IR value that was generated for them.
+    Value *Mul = State.get(getOperand(0), Part);
+    Value *Phi = State.get(getOperand(1), Part);
+    assert(Phi && Mul && "Phi and Mul must be set");
+    assert(isa<ScalableVectorType>(Mul->getType()) &&
+           "Type must be a scalable vector");
+
+    // Take the result type from the phi rather than hard-coding a lane count:
+    // the intrinsic's result is added to the phi, so the two types have to be
+    // identical for the add below to be well-formed.
+    Type *RetTy = Phi->getType();
+    assert(RetTy->getScalarType() == Mul->getType()->getScalarType() &&
+           "Accumulator and input must have the same element type");
+
+    // Narrow the wide input, then accumulate it into the running value.
+    Value *V = Builder.CreateIntrinsic(RetTy, PartialIntrinsic, Mul, nullptr,
+                                       Twine("partial.reduce"));
+    V = Builder.CreateNAryOp(Opcode, {V, Phi});
+
+    // Carry the IR flags recorded on the recipe over to the new instruction.
+    if (auto *VecOp = dyn_cast<Instruction>(V))
+      setFlags(VecOp);
+
+    // Use this vector value for all users of the original instruction.
+    State.set(this, V, Part);
+    State.addMetadata(V, dyn_cast_or_null<Instruction>(getUnderlyingValue()));
+  }
+}
+
+
+void VPPartialReductionRecipe::postInsertionOp() {
+  cast<VPReductionPHIRecipe>(this->getOperand(1))->SetVFScaleFactor(4); // Operand 1 must be the reduction phi; it is then built with VF / 4.
+}
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+void VPPartialReductionRecipe::print(raw_ostream &O, const Twine &Indent,
+  VPSlotTracker &SlotTracker) const {
+  O << Indent << "PARTIAL-REDUCE "; // Tag identifying the recipe kind in the printed output.
+  printAsOperand(O, SlotTracker);
+  O << " = " << Instruction::getOpcodeName(Opcode); // Opcode captured from the underlying instruction.
+  printFlags(O);
+  printOperands(O, SlotTracker);
+}
+#endif
+
 FastMathFlags VPRecipeWithIRFlags::getFastMathFlags() const {
   assert(OpType == OperationType::FPMathOp &&
          "recipe doesn't have fast math flags");
@@ -2960,6 +3031,8 @@ void VPFirstOrderRecurrencePHIRecipe::print(raw_ostream &O, const Twine &Indent,
 void VPReductionPHIRecipe::execute(VPTransformState &State) {
   auto &Builder = State.Builder;
 
+  auto VF = State.VF.divideCoefficientBy(VFScaleFactor);
+
   // Reductions do not have to start at zero. They can start with
   // any loop invariant values.
   VPValue *StartVPV = getStartValue();
@@ -2969,9 +3042,9 @@ void VPReductionPHIRecipe::execute(VPTransformState &State) {
   // Phi nodes have cycles, so we need to vectorize them in two stages. This is
   // stage #1: We create a new vector PHI node with no incoming edges. We'll use
   // this value when we vectorize all of the instructions that use the PHI.
-  bool ScalarPHI = State.VF.isScalar() || IsInLoop;
+  bool ScalarPHI = VF.isScalar() || IsInLoop;
   Type *VecTy = ScalarPHI ? StartV->getType()
-                          : VectorType::get(StartV->getType(), State.VF);
+                          : VectorType::get(StartV->getType(), VF);
 
   BasicBlock *HeaderBB = State.CFG.PrevBB;
   assert(State.CurrentVectorLoop->getHeader() == HeaderBB &&
@@ -3005,7 +3078,7 @@ void VPReductionPHIRecipe::execute(VPTransformState &State) {
         // Create start and identity vector values for the reduction in the
         // preheader.
         // TODO: Introduce recipes in VPlan preheader to create initial values.
-        Iden = Builder.CreateVectorSplat(State.VF, Iden);
+        Iden = Builder.CreateVectorSplat(VF, Iden);
         IRBuilderBase::InsertPointGuard IPBuilder(Builder);
         Builder.SetInsertPoint(VectorPH->getTerminator());
         Constant *Zero = Builder.getInt32(0);
diff --git a/llvm/lib/Transforms/Vectorize/VPlanValue.h b/llvm/lib/Transforms/Vectorize/VPlanValue.h
index a23a59aa2f11c2..25c61ff775d9f6 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanValue.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanValue.h
@@ -342,6 +342,7 @@ class VPDef {
     VPInterleaveSC,
     VPReductionEVLSC,
     VPReductionSC,
+    VPPartialReductionSC,
     VPReplicateSC,
     VPScalarCastSC,
     VPScalarIVStepsSC,
diff --git a/llvm/test/CodeGen/AArch64/partial-reduce-sdot.ll b/llvm/test/CodeGen/AArch64/partial-reduce-sdot.ll
index fc6e3239a1b43c..1eafd505b199eb 100644
--- a/llvm/test/CodeGen/AArch64/partial-reduce-sdot.ll
+++ b/llvm/test/CodeGen/AArch64/partial-reduce-sdot.ll
@@ -22,7 +22,7 @@ define void @dotp(ptr %a, ptr %b) #0 {
 ; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
 ; CHECK:       vector.body:
 ; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
-; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi <vscale x 16 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP14:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi <vscale x 4 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP14:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[TMP11:%.*]] = add i64 [[INDEX]], 0
 ; CHECK-NEXT:    [[TMP13:%.*]] = getelementptr i8, ptr [[A]], i64 [[TMP11]]
 ; CHECK-NEXT:    [[TMP17:%.*]] = getelementptr i8, ptr [[TMP13]], i32 0
@@ -33,12 +33,13 @@ define void @dotp(ptr %a, ptr %b) #0 {
 ; CHECK-NEXT:    [[WIDE_LOAD4:%.*]] = load <vscale x 16 x i8>, ptr [[TMP25]], align 1
 ; CHECK-NEXT:    [[TMP27:%.*]] = zext <vscale x 16 x i8> [[WIDE_LOAD4]] to <vscale x 16 x i32>
 ; CHECK-NEXT:    [[TMP29:%.*]] = mul <vscale x 16 x i32> [[TMP27]], [[TMP19]]
-; CHECK-NEXT:    [[TMP14]] = add <vscale x 16 x i32> [[TMP29]], [[VEC_PHI]]
+; CHECK-NEXT:    [[PARTIAL_REDUCE:%.*]] = call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv16i32(<vscale x 16 x i32> [[TMP29]])
+; CHECK-NEXT:    [[TMP14]] = add <vscale x 4 x i32> [[PARTIAL_REDUCE]], [[VEC_PHI]]
 ; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP5]]
 ; CHECK-NEXT:    [[TMP32:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
 ; CHECK-NEXT:    br i1 [[TMP32]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
 ; CHECK:       middle.block:
-; CHECK-NEXT:    [[TMP33:%.*]] = call i32 @llvm.vector.reduce.add.nxv16i32(<vscale x 16 x i32> [[TMP14]])
+; CHECK-NEXT:    [[TMP33:%.*]] = call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> [[TMP14]])
 ; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i64 0, [[N_VEC]]
 ; CHECK-NEXT:    br i1 [[CMP_N]], label [[FOR_COND_CLEANUP_LOOPEXIT:%.*]], label [[SCALAR_PH]]
 ; CHECK:       scalar.ph:

>From 52f0d21f8ad2f7b9afe7f20eaf4846d8e6c3679c Mon Sep 17 00:00:00 2001
From: Nicholas Guy <nicholas.guy at arm.com>
Date: Thu, 30 May 2024 15:04:55 +0100
Subject: [PATCH 03/54] [LoopVectorizer] Removed 4x restriction from partial
 reduction intrinsic

---
 llvm/include/llvm/IR/DerivedTypes.h              | 10 ----------
 llvm/include/llvm/IR/Intrinsics.h                |  5 ++---
 llvm/include/llvm/IR/Intrinsics.td               |  4 ----
 llvm/lib/IR/Function.cpp                         | 16 ----------------
 llvm/lib/Transforms/Vectorize/LoopVectorize.cpp  |  2 +-
 llvm/lib/Transforms/Vectorize/VPlan.h            |  7 ++++---
 llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp   |  5 +++--
 llvm/test/CodeGen/AArch64/partial-reduce-sdot.ll |  2 +-
 8 files changed, 11 insertions(+), 40 deletions(-)

diff --git a/llvm/include/llvm/IR/DerivedTypes.h b/llvm/include/llvm/IR/DerivedTypes.h
index eb98af66de4c54..975c142f1a4572 100644
--- a/llvm/include/llvm/IR/DerivedTypes.h
+++ b/llvm/include/llvm/IR/DerivedTypes.h
@@ -513,16 +513,6 @@ class VectorType : public Type {
                            EltCnt.divideCoefficientBy(2));
   }
 
-  /// This static method returns a VectorType with quarter as many elements as the
-  /// input type and the same element type.
-  static VectorType *getQuarterElementsVectorType(VectorType *VTy) {
-    auto EltCnt = VTy->getElementCount();
-    assert(EltCnt.isKnownEven() &&
-           "Cannot halve vector with odd number of elements.");
-    return VectorType::get(VTy->getElementType(),
-                           EltCnt.divideCoefficientBy(4));
-  }
-
   /// This static method returns a VectorType with twice as many elements as the
   /// input type and the same element type.
   static VectorType *getDoubleElementsVectorType(VectorType *VTy) {
diff --git a/llvm/include/llvm/IR/Intrinsics.h b/llvm/include/llvm/IR/Intrinsics.h
index 3038eb8dd6af6c..4bd7fda77f3132 100644
--- a/llvm/include/llvm/IR/Intrinsics.h
+++ b/llvm/include/llvm/IR/Intrinsics.h
@@ -133,7 +133,6 @@ namespace Intrinsic {
       ExtendArgument,
       TruncArgument,
       HalfVecArgument,
-      QuarterVecArgument,
       SameVecWidthArgument,
       VecOfAnyPtrsToElt,
       VecElementArgument,
@@ -163,7 +162,7 @@ namespace Intrinsic {
 
     unsigned getArgumentNumber() const {
       assert(Kind == Argument || Kind == ExtendArgument ||
-             Kind == TruncArgument || Kind == HalfVecArgument || Kind == QuarterVecArgument ||
+             Kind == TruncArgument || Kind == HalfVecArgument ||
              Kind == SameVecWidthArgument || Kind == VecElementArgument ||
              Kind == Subdivide2Argument || Kind == Subdivide4Argument ||
              Kind == VecOfBitcastsToInt);
@@ -171,7 +170,7 @@ namespace Intrinsic {
     }
     ArgKind getArgumentKind() const {
       assert(Kind == Argument || Kind == ExtendArgument ||
-             Kind == TruncArgument || Kind == HalfVecArgument || Kind == QuarterVecArgument ||
+             Kind == TruncArgument || Kind == HalfVecArgument ||
              Kind == SameVecWidthArgument ||
              Kind == VecElementArgument || Kind == Subdivide2Argument ||
              Kind == Subdivide4Argument || Kind == VecOfBitcastsToInt);
diff --git a/llvm/include/llvm/IR/Intrinsics.td b/llvm/include/llvm/IR/Intrinsics.td
index d872cd613ba6b7..0a74a217a5f010 100644
--- a/llvm/include/llvm/IR/Intrinsics.td
+++ b/llvm/include/llvm/IR/Intrinsics.td
@@ -321,7 +321,6 @@ def IIT_I4 : IIT_Int<4, 58>;
 def IIT_AARCH64_SVCOUNT : IIT_VT<aarch64svcount, 59>;
 def IIT_V6 : IIT_Vec<6, 60>;
 def IIT_V10 : IIT_Vec<10, 61>;
-def IIT_QUARTER_VEC_ARG : IIT_Base<62>;
 }
 
 defvar IIT_all_FixedTypes = !filter(iit, IIT_all,
@@ -458,9 +457,6 @@ class LLVMVectorElementType<int num> : LLVMMatchType<num, IIT_VEC_ELEMENT>;
 class LLVMHalfElementsVectorType<int num>
   : LLVMMatchType<num, IIT_HALF_VEC_ARG>;
 
-class LLVMQuarterElementsVectorType<int num>
-  : LLVMMatchType<num, IIT_QUARTER_VEC_ARG>;
-
 // Match the type of another intrinsic parameter that is expected to be a
 // vector type (i.e. <N x iM>) but with each element subdivided to
 // form a vector with more elements that are smaller than the original.
diff --git a/llvm/lib/IR/Function.cpp b/llvm/lib/IR/Function.cpp
index 00bcfdf8c45241..8767c2971f62c8 100644
--- a/llvm/lib/IR/Function.cpp
+++ b/llvm/lib/IR/Function.cpp
@@ -1317,12 +1317,6 @@ static void DecodeIITType(unsigned &NextElt, ArrayRef<unsigned char> Infos,
                                              ArgInfo));
     return;
   }
-  case IIT_QUARTER_VEC_ARG: {
-    unsigned ArgInfo = (NextElt == Infos.size() ? 0 : Infos[NextElt++]);
-    OutputTable.push_back(IITDescriptor::get(IITDescriptor::QuarterVecArgument,
-                                             ArgInfo));
-    return;
-  }
   case IIT_SAME_VEC_WIDTH_ARG: {
     unsigned ArgInfo = (NextElt == Infos.size() ? 0 : Infos[NextElt++]);
     OutputTable.push_back(IITDescriptor::get(IITDescriptor::SameVecWidthArgument,
@@ -1490,9 +1484,6 @@ static Type *DecodeFixedType(ArrayRef<Intrinsic::IITDescriptor> &Infos,
   case IITDescriptor::HalfVecArgument:
     return VectorType::getHalfElementsVectorType(cast<VectorType>(
                                                   Tys[D.getArgumentNumber()]));
-  case IITDescriptor::QuarterVecArgument:  {
-    return VectorType::getQuarterElementsVectorType(cast<VectorType>(Tys[D.getArgumentNumber()]));
-  }
   case IITDescriptor::SameVecWidthArgument: {
     Type *EltTy = DecodeFixedType(Infos, Tys, Context);
     Type *Ty = Tys[D.getArgumentNumber()];
@@ -1724,13 +1715,6 @@ static bool matchIntrinsicType(
       return !isa<VectorType>(ArgTys[D.getArgumentNumber()]) ||
              VectorType::getHalfElementsVectorType(
                      cast<VectorType>(ArgTys[D.getArgumentNumber()])) != Ty;
-    case IITDescriptor::QuarterVecArgument: {
-    if (D.getArgumentNumber() >= ArgTys.size())
-        return IsDeferredCheck || DeferCheck(Ty);
-      return !isa<VectorType>(ArgTys[D.getArgumentNumber()]) ||
-             VectorType::getQuarterElementsVectorType(
-                     cast<VectorType>(ArgTys[D.getArgumentNumber()])) != Ty;
-    }
     case IITDescriptor::SameVecWidthArgument: {
       if (D.getArgumentNumber() >= ArgTys.size()) {
         // Defer check and subsequent check for the vector element type.
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 62e6f1af684675..170c1f890b7ea1 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -8690,7 +8690,7 @@ VPRecipeBase *VPRecipeBuilder::tryToCreatePartialReduction(
     auto EC = ElementCount::getScalable(16);
     if(std::find(Range.begin(), Range.end(), EC) == Range.end())
       return nullptr;
-    return new VPPartialReductionRecipe(*Instr, make_range(Operands.begin(), Operands.end()));
+    return new VPPartialReductionRecipe(*Instr, make_range(Operands.begin(), Operands.end()), 4);
   }
   return nullptr;
 }
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 3f11a5281efe5d..b22b08dc84bf09 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -2198,15 +2198,16 @@ class VPReductionPHIRecipe : public VPHeaderPHIRecipe,
 
 class VPPartialReductionRecipe : public VPRecipeWithIRFlags {
   unsigned Opcode;
+  unsigned Scale;
 public:
   template <typename IterT>
   VPPartialReductionRecipe(Instruction &I,
-                           iterator_range<IterT> Operands) : VPRecipeWithIRFlags(
-    VPDef::VPPartialReductionSC, Operands, I), Opcode(I.getOpcode())
+                           iterator_range<IterT> Operands, unsigned Scale) : VPRecipeWithIRFlags(
+    VPDef::VPPartialReductionSC, Operands, I), Opcode(I.getOpcode()), Scale(Scale)
   {}
   ~VPPartialReductionRecipe() override = default;
   VPPartialReductionRecipe *clone() override {
-    auto *R = new VPPartialReductionRecipe(*getUnderlyingInstr(), operands());
+    auto *R = new VPPartialReductionRecipe(*getUnderlyingInstr(), operands(), Scale);
     R->transferFlags(*this);
     return R;
   }
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 0ce350694508bb..0786d39aa7da65 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -343,7 +343,8 @@ void VPPartialReductionRecipe::execute(VPTransformState &State) {
       assert(isa<ScalableVectorType>(Ops[0]->getType()) && "Type must be a scalable vector");
 
       ScalableVectorType *FullTy = cast<ScalableVectorType>(Ops[0]->getType());
-      Type *RetTy = ScalableVectorType::get(FullTy->getScalarType(), 4);
+      auto EC = FullTy->getElementCount();
+      Type *RetTy = ScalableVectorType::get(FullTy->getScalarType(), EC.divideCoefficientBy(Scale).getKnownMinValue());
 
       Intrinsic::ID PartialIntrinsic = Intrinsic::not_intrinsic;
       switch(Opcode) {
@@ -375,7 +376,7 @@ void VPPartialReductionRecipe::execute(VPTransformState &State) {
 }
 
 void VPPartialReductionRecipe::postInsertionOp() {
-  cast<VPReductionPHIRecipe>(this->getOperand(1))->SetVFScaleFactor(4);
+  cast<VPReductionPHIRecipe>(this->getOperand(1))->SetVFScaleFactor(Scale);
 }
 
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
diff --git a/llvm/test/CodeGen/AArch64/partial-reduce-sdot.ll b/llvm/test/CodeGen/AArch64/partial-reduce-sdot.ll
index 1eafd505b199eb..7883cfc05a13b3 100644
--- a/llvm/test/CodeGen/AArch64/partial-reduce-sdot.ll
+++ b/llvm/test/CodeGen/AArch64/partial-reduce-sdot.ll
@@ -33,7 +33,7 @@ define void @dotp(ptr %a, ptr %b) #0 {
 ; CHECK-NEXT:    [[WIDE_LOAD4:%.*]] = load <vscale x 16 x i8>, ptr [[TMP25]], align 1
 ; CHECK-NEXT:    [[TMP27:%.*]] = zext <vscale x 16 x i8> [[WIDE_LOAD4]] to <vscale x 16 x i32>
 ; CHECK-NEXT:    [[TMP29:%.*]] = mul <vscale x 16 x i32> [[TMP27]], [[TMP19]]
-; CHECK-NEXT:    [[PARTIAL_REDUCE:%.*]] = call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv16i32(<vscale x 16 x i32> [[TMP29]])
+; CHECK-NEXT:    [[PARTIAL_REDUCE:%.*]] = call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 16 x i32> [[TMP29]])
 ; CHECK-NEXT:    [[TMP14]] = add <vscale x 4 x i32> [[PARTIAL_REDUCE]], [[VEC_PHI]]
 ; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP5]]
 ; CHECK-NEXT:    [[TMP32:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]

>From 7a3144d37782951122bfe224375f88092f4437ca Mon Sep 17 00:00:00 2001
From: Nicholas Guy <nicholas.guy at arm.com>
Date: Wed, 5 Jun 2024 14:12:04 +0100
Subject: [PATCH 04/54] Commit of test files

---
 .../CodeGen/AArch64/partial-reduce-sdot-ir.ll | 99 +++++++++++++++++++
 ...rtial-reduce-sdot.ll => partial-reduce.ll} |  0
 2 files changed, 99 insertions(+)
 create mode 100644 llvm/test/CodeGen/AArch64/partial-reduce-sdot-ir.ll
 rename llvm/test/CodeGen/AArch64/{partial-reduce-sdot.ll => partial-reduce.ll} (100%)

diff --git a/llvm/test/CodeGen/AArch64/partial-reduce-sdot-ir.ll b/llvm/test/CodeGen/AArch64/partial-reduce-sdot-ir.ll
new file mode 100644
index 00000000000000..3519ba58b3df34
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/partial-reduce-sdot-ir.ll
@@ -0,0 +1,99 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
+; RUN: opt -passes="default<O3>" -force-vector-interleave=1 -S < %s | FileCheck %s
+
+target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128"
+target triple = "aarch64-none-unknown-elf"
+
+define void @dotp(ptr %out, ptr %a, ptr %b, i64 %wide.trip.count) #0 {
+; CHECK-LABEL: define void @dotp(
+; CHECK-SAME: ptr nocapture writeonly [[OUT:%.*]], ptr nocapture readonly [[A:%.*]], ptr nocapture readonly [[B:%.*]], i64 [[WIDE_TRIP_COUNT:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = add i64 [[WIDE_TRIP_COUNT]], 1
+; CHECK-NEXT:    [[TMP1:%.*]] = tail call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP2:%.*]] = shl i64 [[TMP1]], 4
+; CHECK-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[TMP0]], [[TMP2]]
+; CHECK-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[FOR_BODY_PREHEADER:%.*]], label [[VECTOR_PH:%.*]]
+; CHECK:       vector.ph:
+; CHECK-NEXT:    [[TMP3:%.*]] = tail call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP4:%.*]] = shl i64 [[TMP3]], 4
+; CHECK-NEXT:    [[N_MOD_VF:%.*]] = urem i64 [[TMP0]], [[TMP4]]
+; CHECK-NEXT:    [[N_VEC:%.*]] = sub i64 [[TMP0]], [[N_MOD_VF]]
+; CHECK-NEXT:    [[TMP5:%.*]] = tail call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP6:%.*]] = shl i64 [[TMP5]], 4
+; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
+; CHECK:       vector.body:
+; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi <vscale x 16 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP12:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[TMP7:%.*]] = getelementptr i8, ptr [[A]], i64 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <vscale x 16 x i8>, ptr [[TMP7]], align 1
+; CHECK-NEXT:    [[TMP8:%.*]] = zext <vscale x 16 x i8> [[WIDE_LOAD]] to <vscale x 16 x i32>
+; CHECK-NEXT:    [[TMP9:%.*]] = getelementptr i8, ptr [[B]], i64 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_LOAD1:%.*]] = load <vscale x 16 x i8>, ptr [[TMP9]], align 1
+; CHECK-NEXT:    [[TMP10:%.*]] = zext <vscale x 16 x i8> [[WIDE_LOAD1]] to <vscale x 16 x i32>
+; CHECK-NEXT:    [[TMP11:%.*]] = mul nuw nsw <vscale x 16 x i32> [[TMP10]], [[TMP8]]
+; CHECK-NEXT:    [[TMP12]] = add <vscale x 16 x i32> [[TMP11]], [[VEC_PHI]]
+; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP6]]
+; CHECK-NEXT:    [[TMP13:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
+; CHECK-NEXT:    br i1 [[TMP13]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
+; CHECK:       middle.block:
+; CHECK-NEXT:    [[TMP14:%.*]] = tail call i32 @llvm.vector.reduce.add.nxv16i32(<vscale x 16 x i32> [[TMP12]])
+; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i64 [[N_MOD_VF]], 0
+; CHECK-NEXT:    br i1 [[CMP_N]], label [[FOR_COND_CLEANUP_LOOPEXIT:%.*]], label [[FOR_BODY_PREHEADER]]
+; CHECK:       for.body.preheader:
+; CHECK-NEXT:    [[INDVARS_IV_PH:%.*]] = phi i64 [ 0, [[ENTRY:%.*]] ], [ [[N_VEC]], [[MIDDLE_BLOCK]] ]
+; CHECK-NEXT:    [[ACC_010_PH:%.*]] = phi i32 [ 0, [[ENTRY]] ], [ [[TMP14]], [[MIDDLE_BLOCK]] ]
+; CHECK-NEXT:    br label [[FOR_BODY:%.*]]
+; CHECK:       for.cond.cleanup.loopexit:
+; CHECK-NEXT:    [[ADD_LCSSA:%.*]] = phi i32 [ [[TMP14]], [[MIDDLE_BLOCK]] ], [ [[ADD:%.*]], [[FOR_BODY]] ]
+; CHECK-NEXT:    [[TMP15:%.*]] = trunc i32 [[ADD_LCSSA]] to i8
+; CHECK-NEXT:    store i8 [[TMP15]], ptr [[OUT]], align 1
+; CHECK-NEXT:    ret void
+; CHECK:       for.body:
+; CHECK-NEXT:    [[INDVARS_IV:%.*]] = phi i64 [ [[INDVARS_IV_NEXT:%.*]], [[FOR_BODY]] ], [ [[INDVARS_IV_PH]], [[FOR_BODY_PREHEADER]] ]
+; CHECK-NEXT:    [[ACC_010:%.*]] = phi i32 [ [[ADD]], [[FOR_BODY]] ], [ [[ACC_010_PH]], [[FOR_BODY_PREHEADER]] ]
+; CHECK-NEXT:    [[ARRAYIDX:%.*]] = getelementptr i8, ptr [[A]], i64 [[INDVARS_IV]]
+; CHECK-NEXT:    [[TMP16:%.*]] = load i8, ptr [[ARRAYIDX]], align 1
+; CHECK-NEXT:    [[CONV:%.*]] = zext i8 [[TMP16]] to i32
+; CHECK-NEXT:    [[ARRAYIDX2:%.*]] = getelementptr i8, ptr [[B]], i64 [[INDVARS_IV]]
+; CHECK-NEXT:    [[TMP17:%.*]] = load i8, ptr [[ARRAYIDX2]], align 1
+; CHECK-NEXT:    [[CONV3:%.*]] = zext i8 [[TMP17]] to i32
+; CHECK-NEXT:    [[MUL:%.*]] = mul nuw nsw i32 [[CONV3]], [[CONV]]
+; CHECK-NEXT:    [[ADD]] = add i32 [[MUL]], [[ACC_010]]
+; CHECK-NEXT:    [[INDVARS_IV_NEXT]] = add i64 [[INDVARS_IV]], 1
+; CHECK-NEXT:    [[EXITCOND_NOT:%.*]] = icmp eq i64 [[INDVARS_IV]], [[WIDE_TRIP_COUNT]]
+; CHECK-NEXT:    br i1 [[EXITCOND_NOT]], label [[FOR_COND_CLEANUP_LOOPEXIT]], label [[FOR_BODY]], !llvm.loop [[LOOP3:![0-9]+]]
+;
+entry:
+  br label %for.body
+
+for.cond.cleanup.loopexit:                        ; preds = %for.body
+  %0 = trunc i32 %add to i8
+  store i8 %0, ptr %out, align 1
+  ret void
+
+for.body:                                         ; preds = %for.body, %entry
+  %indvars.iv = phi i64 [ 0, %entry ], [ %indvars.iv.next, %for.body ]
+  %acc.010 = phi i32 [ 0, %entry ], [ %add, %for.body ]
+  %arrayidx = getelementptr i8, ptr %a, i64 %indvars.iv
+  %1 = load i8, ptr %arrayidx, align 1
+  %conv = zext i8 %1 to i32
+  %arrayidx2 = getelementptr i8, ptr %b, i64 %indvars.iv
+  %2 = load i8, ptr %arrayidx2, align 1
+  %conv3 = zext i8 %2 to i32
+  %mul = mul i32 %conv3, %conv
+  %add = add i32 %mul, %acc.010
+  %indvars.iv.next = add i64 %indvars.iv, 1
+  %exitcond.not = icmp eq i64 %indvars.iv, %wide.trip.count
+  br i1 %exitcond.not, label %for.cond.cleanup.loopexit, label %for.body
+
+; uselistorder directives
+  uselistorder i32 %add, { 1, 0 }
+}
+
+attributes #0 = { "target-features"="+fp-armv8,+fullfp16,+neon,+sve,+sve2,+v8a" }
+;.
+; CHECK: [[LOOP0]] = distinct !{[[LOOP0]], [[META1:![0-9]+]], [[META2:![0-9]+]]}
+; CHECK: [[META1]] = !{!"llvm.loop.isvectorized", i32 1}
+; CHECK: [[META2]] = !{!"llvm.loop.unroll.runtime.disable"}
+; CHECK: [[LOOP3]] = distinct !{[[LOOP3]], [[META2]], [[META1]]}
+;.
diff --git a/llvm/test/CodeGen/AArch64/partial-reduce-sdot.ll b/llvm/test/CodeGen/AArch64/partial-reduce.ll
similarity index 100%
rename from llvm/test/CodeGen/AArch64/partial-reduce-sdot.ll
rename to llvm/test/CodeGen/AArch64/partial-reduce.ll

>From d39d1d03e8d9e25bbe414950dc99fab8f83449aa Mon Sep 17 00:00:00 2001
From: Nicholas Guy <nicholas.guy at arm.com>
Date: Wed, 5 Jun 2024 14:12:58 +0100
Subject: [PATCH 05/54] Add generic decomposition of partial reduction
 intrinsic

---
 llvm/include/llvm/CodeGen/SelectionDAGNodes.h | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
index 6067b3b29ea181..7741fc8ccec005 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
@@ -971,6 +971,11 @@ END_TWO_BYTE_PACK()
   inline const APInt &getAsAPIntVal() const;
 
   const SDValue &getOperand(unsigned Num) const {
+    if(Num >= NumOperands) {
+      dbgs() << Num << ">=" << NumOperands << "\n";
+      printr(dbgs());
+      dbgs() << "\n";
+    }
     assert(Num < NumOperands && "Invalid child # of SDNode!");
     return OperandList[Num];
   }

>From e8c9fefe92f05cf612bc41c8405a2faa8965abf4 Mon Sep 17 00:00:00 2001
From: Nicholas Guy <nicholas.guy at arm.com>
Date: Wed, 5 Jun 2024 14:14:15 +0100
Subject: [PATCH 06/54] Add basic cost modeling for partial reductions

---
 llvm/include/llvm/Analysis/TargetTransformInfo.h   | 13 +++++++++++++
 .../llvm/Analysis/TargetTransformInfoImpl.h        |  7 +++++++
 llvm/include/llvm/CodeGen/BasicTTIImpl.h           |  7 +++++++
 llvm/lib/Analysis/TargetTransformInfo.cpp          |  6 ++++++
 llvm/lib/Transforms/Vectorize/LoopVectorize.cpp    | 14 +++++++++++++-
 5 files changed, 46 insertions(+), 1 deletion(-)

diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index cd69a8a371b6e8..a30870b79e57bb 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1528,6 +1528,10 @@ class TargetTransformInfo {
       FastMathFlags FMF,
       TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput) const;
 
+  InstructionCost getPartialReductionCost(
+    unsigned Opcode, bool IsUnsigned, VectorType *ResTy, VectorType *Ty,
+    FastMathFlags FMF, TargetCostKind CostKind = TCK_RecipThroughput) const;
+
   /// \returns The cost of Intrinsic instructions. Analyses the real arguments.
   /// Three cases are handled: 1. scalar instruction 2. vector instruction
   /// 3. scalar instruction which is to be vectorized.
@@ -2106,6 +2110,9 @@ class TargetTransformInfo::Concept {
       unsigned Opcode, bool IsUnsigned, Type *ResTy, VectorType *Ty,
       FastMathFlags FMF,
       TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput) = 0;
+  virtual InstructionCost getPartialReductionCost(
+      unsigned Opcode, bool IsUnsigned, VectorType *ResTy, VectorType *Ty,
+      FastMathFlags FMF, TargetCostKind CostKind = TCK_RecipThroughput) = 0;
   virtual InstructionCost getMulAccReductionCost(
       bool IsUnsigned, Type *ResTy, VectorType *Ty,
       TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput) = 0;
@@ -2794,6 +2801,12 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
     return Impl.getExtendedReductionCost(Opcode, IsUnsigned, ResTy, Ty, FMF,
                                          CostKind);
   }
+  InstructionCost getPartialReductionCost(unsigned Opcode, bool IsUnsigned,
+      VectorType *ResTy, VectorType *Ty, FastMathFlags FMF,
+      TargetCostKind CostKind = TCK_RecipThroughput) override {
+      return Impl.getPartialReductionCost(Opcode, IsUnsigned, ResTy, Ty, FMF,
+          CostKind);
+  }
   InstructionCost
   getMulAccReductionCost(bool IsUnsigned, Type *ResTy, VectorType *Ty,
                          TTI::TargetCostKind CostKind) override {
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 79c8bafbc6c0df..26f478ba2581d4 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -816,6 +816,13 @@ class TargetTransformInfoImplBase {
     return 1;
   }
 
+  InstructionCost getPartialReductionCost(unsigned Opcode, bool IsUnsigned,
+                                           VectorType *ResTy, VectorType *Ty,
+                                           FastMathFlags FMF,
+                                           TTI::TargetCostKind CostKind) const {
+    return InstructionCost::getMax();
+  }
+
   InstructionCost getMulAccReductionCost(bool IsUnsigned, Type *ResTy,
                                          VectorType *Ty,
                                          TTI::TargetCostKind CostKind) const {
diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index 7198e134a2d262..b5a8f461c16931 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -2732,6 +2732,13 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
     return RedCost + ExtCost;
   }
 
+  InstructionCost getPartialReductionCost(unsigned Opcode, bool IsUnsigned,
+                                          VectorType *ResTy, VectorType *Ty,
+                                          FastMathFlags FMD,
+                                          TTI::TargetCostKind CostKind) {
+    return InstructionCost::getMax();
+  }
+
   InstructionCost getMulAccReductionCost(bool IsUnsigned, Type *ResTy,
                                          VectorType *Ty,
                                          TTI::TargetCostKind CostKind) {
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 67b626f300a101..21db0a3efa61d4 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -1172,6 +1172,12 @@ InstructionCost TargetTransformInfo::getExtendedReductionCost(
                                            CostKind);
 }
 
+InstructionCost TargetTransformInfo::getPartialReductionCost(
+  unsigned Opcode, bool IsUnsigned, VectorType *ResTy, VectorType *Ty,
+  FastMathFlags FMF, TargetCostKind CostKind) const {
+  return TTIImpl->getPartialReductionCost(Opcode, IsUnsigned, ResTy, Ty, FMF, CostKind);
+}
+
 InstructionCost TargetTransformInfo::getMulAccReductionCost(
     bool IsUnsigned, Type *ResTy, VectorType *Ty,
     TTI::TargetCostKind CostKind) const {
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 170c1f890b7ea1..bea069340dab4f 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -8690,7 +8690,19 @@ VPRecipeBase *VPRecipeBuilder::tryToCreatePartialReduction(
     auto EC = ElementCount::getScalable(16);
     if(std::find(Range.begin(), Range.end(), EC) == Range.end())
       return nullptr;
-    return new VPPartialReductionRecipe(*Instr, make_range(Operands.begin(), Operands.end()), 4);
+
+    // Scale factor of 4 for sdot/udot.
+    unsigned Scale = 4;
+    VectorType* ResTy = ScalableVectorType::get(Instr->getType(), Scale);
+    VectorType* ValTy = ScalableVectorType::get(Instr->getType(), EC.getKnownMinValue());
+    using namespace llvm::PatternMatch;
+    bool IsUnsigned = match(Instr, m_Add(m_Mul(m_ZExt(m_Value()), m_ZExt(m_Value())), m_Value()));
+    auto RecipeCost = this->CM.TTI.getPartialReductionCost(Instr->getOpcode(), IsUnsigned, ResTy, ValTy, FastMathFlags::getFast());
+    // TODO replace with more informed cost check
+    if(RecipeCost == InstructionCost::getMax())
+      return nullptr;
+
+    return new VPPartialReductionRecipe(*Instr, make_range(Operands.begin(), Operands.end()), Scale);
   }
   return nullptr;
 }

>From 3b22b36924e41ada208a5b1b5b677ea21362ab65 Mon Sep 17 00:00:00 2001
From: Nicholas Guy <nicholas.guy at arm.com>
Date: Wed, 5 Jun 2024 14:14:55 +0100
Subject: [PATCH 07/54] Add missing support for sign-extends

---
 .../lib/Transforms/Vectorize/LoopVectorize.cpp | 18 ++++++++++++++----
 1 file changed, 14 insertions(+), 4 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index bea069340dab4f..fbdf841331fbac 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -2125,6 +2125,9 @@ static void getPartialReductionInstrChain(Instruction *Instr, SmallVector<Value*
 }
 
 
+/// Checks if the given instruction is the root of a partial reduction chain
+///
+/// Note: This currently only supports udot/sdot chains
 /// @param Instr The root instruction to scan
 static bool isInstrPartialReduction(Instruction *Instr) {
   Value *ExpectedPhi;
@@ -2134,12 +2137,12 @@ static bool isInstrPartialReduction(Instruction *Instr) {
   using namespace llvm::PatternMatch;
   auto Pattern = m_Add(
     m_OneUse(m_Mul(
-      m_OneUse(m_ZExt(
+      m_OneUse(m_ZExtOrSExt(
         m_OneUse(m_Load(
           m_GEP(
               m_Value(A),
               m_Value(InductionA)))))),
-      m_OneUse(m_ZExt(
+      m_OneUse(m_ZExtOrSExt(
         m_OneUse(m_Load(
           m_GEP(
               m_Value(B),
@@ -2158,8 +2161,13 @@ static bool isInstrPartialReduction(Instruction *Instr) {
   }
 
   Instruction *Mul = cast<Instruction>(Instr->getOperand(0));
-  Instruction *Ext0 = cast<ZExtInst>(Mul->getOperand(0));
-  Instruction *Ext1 = cast<ZExtInst>(Mul->getOperand(1));
+  Instruction *Ext0 = cast<CastInst>(Mul->getOperand(0));
+  Instruction *Ext1 = cast<CastInst>(Mul->getOperand(1));
+
+  if(Ext0->getOpcode() != Ext1->getOpcode()) {
+    LLVM_DEBUG(dbgs() << "Extends aren't of the same type, cannot create a partial reduction.\n");
+    return false;
+  }
 
   // Check that the extends extend to i32
   if(!Ext0->getType()->isIntegerTy(32) || !Ext1->getType()->isIntegerTy(32)) {
@@ -8687,6 +8695,8 @@ VPRecipeBase *VPRecipeBuilder::tryToCreatePartialReduction(
     VFRange &Range, Instruction *Instr, ArrayRef<VPValue *> Operands) {
 
   if(isInstrPartialReduction(Instr)) {
+    // Restricting this case to 16x means that, using a scale of 4, we avoid
+    // trying to generate illegal types such as <vscale x 2 x i32>
     auto EC = ElementCount::getScalable(16);
     if(std::find(Range.begin(), Range.end(), EC) == Range.end())
       return nullptr;

>From b29b11279f22da61bf2afae2a61485ec210ea4d2 Mon Sep 17 00:00:00 2001
From: Nicholas Guy <nicholas.guy at arm.com>
Date: Wed, 5 Jun 2024 14:27:39 +0100
Subject: [PATCH 08/54] Remove debug statements

---
 llvm/include/llvm/CodeGen/SelectionDAGNodes.h | 5 -----
 1 file changed, 5 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
index 7741fc8ccec005..6067b3b29ea181 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
@@ -971,11 +971,6 @@ END_TWO_BYTE_PACK()
   inline const APInt &getAsAPIntVal() const;
 
   const SDValue &getOperand(unsigned Num) const {
-    if(Num >= NumOperands) {
-      dbgs() << Num << ">=" << NumOperands << "\n";
-      printr(dbgs());
-      dbgs() << "\n";
-    }
     assert(Num < NumOperands && "Invalid child # of SDNode!");
     return OperandList[Num];
   }

>From c7eeb2a7d8ad54e3e7d8acb71ebe4a1c4a90b41a Mon Sep 17 00:00:00 2001
From: Nicholas Guy <nicholas.guy at arm.com>
Date: Fri, 14 Jun 2024 12:54:52 +0100
Subject: [PATCH 09/54] Redesign how the LoopVectorizer identifies partial
 reductions

---
 .../llvm/Analysis/TargetTransformInfo.h       |  27 +--
 .../llvm/Analysis/TargetTransformInfoImpl.h   |  14 +-
 llvm/include/llvm/CodeGen/BasicTTIImpl.h      |   7 -
 llvm/lib/Analysis/TargetTransformInfo.cpp     |  13 +-
 .../Vectorize/LoopVectorizationPlanner.h      |  23 +++
 .../Transforms/Vectorize/LoopVectorize.cpp    | 177 +++++++++---------
 .../Transforms/Vectorize/VPRecipeBuilder.h    |   2 +-
 llvm/lib/Transforms/Vectorize/VPlan.h         |   2 +
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp |  10 +-
 .../CodeGen/AArch64/partial-reduce-sdot-ir.ll |  99 ----------
 llvm/test/CodeGen/AArch64/partial-reduce.ll   | 100 ----------
 11 files changed, 142 insertions(+), 332 deletions(-)
 delete mode 100644 llvm/test/CodeGen/AArch64/partial-reduce-sdot-ir.ll
 delete mode 100644 llvm/test/CodeGen/AArch64/partial-reduce.ll

diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index a30870b79e57bb..4efbd76689a69e 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1248,6 +1248,8 @@ class TargetTransformInfo {
   /// \return if target want to issue a prefetch in address space \p AS.
   bool shouldPrefetchAddressSpace(unsigned AS) const;
 
+  bool isPartialReductionSupported(const Instruction* ReductionInstr, Type* InputType, unsigned ScaleFactor, bool IsInputASignExtended, bool IsInputBSignExtended, const Instruction* BinOp = nullptr) const;
+    
   /// \return The maximum interleave factor that any transform should try to
   /// perform for this target. This number depends on the level of parallelism
   /// and the number of execution units in the CPU.
@@ -1528,10 +1530,6 @@ class TargetTransformInfo {
       FastMathFlags FMF,
       TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput) const;
 
-  InstructionCost getPartialReductionCost(
-    unsigned Opcode, bool IsUnsigned, VectorType *ResTy, VectorType *Ty,
-    FastMathFlags FMF, TargetCostKind CostKind = TCK_RecipThroughput) const;
-
   /// \returns The cost of Intrinsic instructions. Analyses the real arguments.
   /// Three cases are handled: 1. scalar instruction 2. vector instruction
   /// 3. scalar instruction which is to be vectorized.
@@ -2029,6 +2027,11 @@ class TargetTransformInfo::Concept {
   /// \return if target want to issue a prefetch in address space \p AS.
   virtual bool shouldPrefetchAddressSpace(unsigned AS) const = 0;
 
+  virtual bool isPartialReductionSupported(const Instruction* ReductionInstr,
+      Type* InputType, unsigned ScaleFactor,
+      bool IsInputASignExtended, bool IsInputBSignExtended,
+      const Instruction* BinOp = nullptr) const = 0;
+    
   virtual unsigned getMaxInterleaveFactor(ElementCount VF) = 0;
   virtual InstructionCost getArithmeticInstrCost(
       unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind,
@@ -2110,9 +2113,6 @@ class TargetTransformInfo::Concept {
       unsigned Opcode, bool IsUnsigned, Type *ResTy, VectorType *Ty,
       FastMathFlags FMF,
       TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput) = 0;
-  virtual InstructionCost getPartialReductionCost(
-      unsigned Opcode, bool IsUnsigned, VectorType *ResTy, VectorType *Ty,
-      FastMathFlags FMF, TargetCostKind CostKind = TCK_RecipThroughput) = 0;
   virtual InstructionCost getMulAccReductionCost(
       bool IsUnsigned, Type *ResTy, VectorType *Ty,
       TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput) = 0;
@@ -2666,6 +2666,13 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
     return Impl.shouldPrefetchAddressSpace(AS);
   }
 
+  bool isPartialReductionSupported(const Instruction* ReductionInstr, Type* InputType, unsigned ScaleFactor,
+                                              bool IsInputASignExtended, bool IsInputBSignExtended,
+                                              const Instruction* BinOp = nullptr) const override
+  {
+      return Impl.isPartialReductionSupported(ReductionInstr, InputType, ScaleFactor, IsInputASignExtended, IsInputBSignExtended, BinOp);
+  }
+
   unsigned getMaxInterleaveFactor(ElementCount VF) override {
     return Impl.getMaxInterleaveFactor(VF);
   }
@@ -2801,12 +2808,6 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
     return Impl.getExtendedReductionCost(Opcode, IsUnsigned, ResTy, Ty, FMF,
                                          CostKind);
   }
-  InstructionCost getPartialReductionCost(unsigned Opcode, bool IsUnsigned,
-      VectorType *ResTy, VectorType *Ty, FastMathFlags FMF,
-      TargetCostKind CostKind = TCK_RecipThroughput) override {
-      return Impl.getPartialReductionCost(Opcode, IsUnsigned, ResTy, Ty, FMF,
-          CostKind);
-  }
   InstructionCost
   getMulAccReductionCost(bool IsUnsigned, Type *ResTy, VectorType *Ty,
                          TTI::TargetCostKind CostKind) override {
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 26f478ba2581d4..283449a98d586e 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -551,6 +551,13 @@ class TargetTransformInfoImplBase {
   bool enableWritePrefetching() const { return false; }
   bool shouldPrefetchAddressSpace(unsigned AS) const { return !AS; }
 
+  bool isPartialReductionSupported(const Instruction* ReductionInstr,
+      Type* InputType, unsigned ScaleFactor,
+      bool IsInputASignExtended, bool IsInputBSignExtended,
+      const Instruction* BinOp = nullptr) const {
+    return false;
+  }
+
   unsigned getMaxInterleaveFactor(ElementCount VF) const { return 1; }
 
   InstructionCost getArithmeticInstrCost(
@@ -816,13 +823,6 @@ class TargetTransformInfoImplBase {
     return 1;
   }
 
-  InstructionCost getPartialReductionCost(unsigned Opcode, bool IsUnsigned,
-                                           VectorType *ResTy, VectorType *Ty,
-                                           FastMathFlags FMF,
-                                           TTI::TargetCostKind CostKind) const {
-    return InstructionCost::getMax();
-  }
-
   InstructionCost getMulAccReductionCost(bool IsUnsigned, Type *ResTy,
                                          VectorType *Ty,
                                          TTI::TargetCostKind CostKind) const {
diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index b5a8f461c16931..7198e134a2d262 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -2732,13 +2732,6 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
     return RedCost + ExtCost;
   }
 
-  InstructionCost getPartialReductionCost(unsigned Opcode, bool IsUnsigned,
-                                          VectorType *ResTy, VectorType *Ty,
-                                          FastMathFlags FMD,
-                                          TTI::TargetCostKind CostKind) {
-    return InstructionCost::getMax();
-  }
-
   InstructionCost getMulAccReductionCost(bool IsUnsigned, Type *ResTy,
                                          VectorType *Ty,
                                          TTI::TargetCostKind CostKind) {
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 21db0a3efa61d4..4010d618c914fc 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -830,6 +830,13 @@ bool TargetTransformInfo::shouldPrefetchAddressSpace(unsigned AS) const {
   return TTIImpl->shouldPrefetchAddressSpace(AS);
 }
 
+bool TargetTransformInfo::isPartialReductionSupported(
+    const Instruction *ReductionInstr, Type *InputType, unsigned ScaleFactor,
+    bool IsInputASignExtended, bool IsInputBSignExtended,
+    const Instruction *BinOp) const {
+  return TTIImpl->isPartialReductionSupported(ReductionInstr, InputType, ScaleFactor, IsInputASignExtended, IsInputBSignExtended, BinOp);
+}
+
 unsigned TargetTransformInfo::getMaxInterleaveFactor(ElementCount VF) const {
   return TTIImpl->getMaxInterleaveFactor(VF);
 }
@@ -1172,12 +1179,6 @@ InstructionCost TargetTransformInfo::getExtendedReductionCost(
                                            CostKind);
 }
 
-InstructionCost TargetTransformInfo::getPartialReductionCost(
-  unsigned Opcode, bool IsUnsigned, VectorType *ResTy, VectorType *Ty,
-  FastMathFlags FMF, TargetCostKind CostKind) const {
-  return TTIImpl->getPartialReductionCost(Opcode, IsUnsigned, ResTy, Ty, FMF, CostKind);
-}
-
 InstructionCost TargetTransformInfo::getMulAccReductionCost(
     bool IsUnsigned, Type *ResTy, VectorType *Ty,
     TTI::TargetCostKind CostKind) const {
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
index 034fdf4233de37..58a9bf22904040 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
@@ -337,6 +337,19 @@ struct FixedScalableVFPair {
   bool hasVector() const { return FixedVF.isVector() || ScalableVF.isVector(); }
 };
 
+struct PartialReductionChain { // One binop(binop(ext(A), ext(B)), Acc) candidate.
+  Instruction *Reduction = nullptr; ///< Root instruction of the chain.
+  Instruction *BinOp = nullptr;     ///< Operand 0 of Reduction.
+  Instruction *ExtendA = nullptr;   ///< Operand 0 of BinOp.
+  Instruction *ExtendB = nullptr;   ///< Operand 1 of BinOp.
+  // Values feeding the chain (defaults keep a default-constructed chain well defined).
+  Value *InputA = nullptr;      ///< Operand 0 of ExtendA.
+  Value *InputB = nullptr;      ///< Operand 0 of ExtendB.
+  Value *Accumulator = nullptr; ///< Operand 1 of Reduction.
+  /// Scalar bit width of Reduction's type divided by that of InputA's type.
+  unsigned ScaleFactor = 0;
+};
+
 /// Planner drives the vectorization process after having passed
 /// Legality checks.
 class LoopVectorizationPlanner {
@@ -375,6 +388,8 @@ class LoopVectorizationPlanner {
   /// Profitable vector factors.
   SmallVector<VectorizationFactor, 8> ProfitableVFs;
 
+  SmallVector<PartialReductionChain> PartialReductionChains;
+
   /// A builder used to construct the current plan.
   VPBuilder Builder;
 
@@ -467,6 +482,10 @@ class LoopVectorizationPlanner {
   /// Emit remarks for recipes with invalid costs in the available VPlans.
   void emitInvalidCostRemarks(OptimizationRemarkEmitter *ORE);
 
+  /// \returns the partial reduction chains collected so far, by reference (no copy).
+  const SmallVector<PartialReductionChain> &
+  getPartialReductionChains() const { return PartialReductionChains; }
+
 protected:
   /// Build VPlans for power-of-2 VF's between \p MinVF and \p MaxVF inclusive,
   /// according to the information gathered by Legal when it checked if it is
@@ -518,6 +537,10 @@ class LoopVectorizationPlanner {
   /// Determines if we have the infrastructure to vectorize the loop and its
   /// epilogue, assuming the main loop is vectorized by \p VF.
   bool isCandidateForEpilogueVectorization(const ElementCount VF) const;
+
+  /// Copies into \p Chain the recorded chain whose Reduction is \p I and
+  /// returns true; returns false when \p I is not the root of a recorded chain.
+  bool getInstructionsPartialReduction(Instruction* I, PartialReductionChain &Chain) const;
 };
 
 } // namespace llvm
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index fbdf841331fbac..51a6e3215ed7c3 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -1060,7 +1060,7 @@ class LoopVectorizationCostModel {
   calculateRegisterUsage(ArrayRef<ElementCount> VFs);
 
   /// Collect values we want to ignore in the cost model.
-  void collectValuesToIgnore();
+  void collectValuesToIgnore(LoopVectorizationPlanner *LVP);
 
   /// Collect all element types in the loop for which widening is needed.
   void collectElementTypesForWidening();
@@ -2113,73 +2113,59 @@ static bool useActiveLaneMaskForControlFlow(TailFoldingStyle Style) {
          Style == TailFoldingStyle::DataAndControlFlowWithoutRuntimeCheck;
 }
 
-static void getPartialReductionInstrChain(Instruction *Instr, SmallVector<Value*, 4> &Chain) {
-  Instruction *Mul = cast<Instruction>(Instr->getOperand(0));
-  Instruction *Ext0 = cast<ZExtInst>(Mul->getOperand(0));
-  Instruction *Ext1 = cast<ZExtInst>(Mul->getOperand(1));
+static PartialReductionChain getPartialReductionInstrChain(Instruction *Instr) {
+  Instruction *BinOp = cast<Instruction>(Instr->getOperand(0)); // cast<> asserts: callers must have matched the chain shape first
+  Instruction *Ext0 = cast<Instruction>(BinOp->getOperand(0));
+  Instruction *Ext1 = cast<Instruction>(BinOp->getOperand(1));
 
-  Chain.push_back(Mul);
-  Chain.push_back(Ext0);
-  Chain.push_back(Ext1);
-  Chain.push_back(Instr->getOperand(1));
+  PartialReductionChain Chain;
+  Chain.Reduction = Instr;
+  Chain.BinOp = BinOp;
+  Chain.ExtendA = Ext0;
+  Chain.ExtendB = Ext1;
+  Chain.InputA = Ext0->getOperand(0); // value before the first extend
+  Chain.InputB = Ext1->getOperand(0); // value before the second extend
+  Chain.Accumulator = Instr->getOperand(1);
+  // Scale factor: scalar width of the reduction result over that of the input.
+  unsigned InputSizeBits = Chain.InputA->getType()->getScalarSizeInBits();
+  unsigned ResultSizeBits = Chain.Reduction->getType()->getScalarSizeInBits();
+  Chain.ScaleFactor = ResultSizeBits / InputSizeBits; // truncating; divisibility is checked by isPartialReductionChainValid
+  return Chain;
 }
 
 
 /// Checks if the given instruction the root of a partial reduction chain
 ///
-/// Note: This currently only supports udot/sdot chains
 /// @param Instr The root instruction to scan
 static bool isInstrPartialReduction(Instruction *Instr) {
   Value *ExpectedPhi;
   Value *A, *B;
-  Value *InductionA, *InductionB;
 
   using namespace llvm::PatternMatch;
-  auto Pattern = m_Add(
-    m_OneUse(m_Mul(
-      m_OneUse(m_ZExtOrSExt(
-        m_OneUse(m_Load(
-          m_GEP(
-              m_Value(A),
-              m_Value(InductionA)))))),
-      m_OneUse(m_ZExtOrSExt(
-        m_OneUse(m_Load(
-          m_GEP(
-              m_Value(B),
-              m_Value(InductionB))))))
-        )), m_Value(ExpectedPhi));
+  auto Pattern = m_BinOp(
+      m_OneUse(m_BinOp(
+        m_OneUse(m_ZExtOrSExt(m_OneUse(m_Value(A)))),
+        m_OneUse(m_ZExtOrSExt(m_OneUse(m_Value(B)))))),
+      m_Value(ExpectedPhi));
 
   bool Matches = match(Instr, Pattern);
 
   if(!Matches)
     return false;
 
-  // Check that the two induction variable uses are to the same induction variable
-  if(InductionA != InductionB) {
-    LLVM_DEBUG(dbgs() << "Loop uses different induction variables for each input variable, cannot create a partial reduction.\n");
-    return false;
-  }
-
-  Instruction *Mul = cast<Instruction>(Instr->getOperand(0));
-  Instruction *Ext0 = cast<CastInst>(Mul->getOperand(0));
-  Instruction *Ext1 = cast<CastInst>(Mul->getOperand(1));
-
-  if(Ext0->getOpcode() != Ext1->getOpcode()) {
-    LLVM_DEBUG(dbgs() << "Extends aren't of the same type, cannot create a partial reduction.\n");
+  // Check that the extends extend from the same type
+  if(A->getType() != B->getType()) {
+    LLVM_DEBUG(dbgs() << "Extends don't extend from the same type, cannot create a partial reduction.\n");
     return false;
   }
 
-  // Check that the extends extend to i32
-  if(!Ext0->getType()->isIntegerTy(32) || !Ext1->getType()->isIntegerTy(32)) {
-    LLVM_DEBUG(dbgs() << "Extends don't extend to the correct width, cannot create a partial reduction.\n");
-    return false;
-  }
+  // A and B are one-use, so the first user of each should be the respective extend 
+  Instruction *Ext0 = cast<CastInst>(*A->user_begin());
+  Instruction *Ext1 = cast<CastInst>(*B->user_begin());
 
-  // Check that the loads are loading i8
-  LoadInst *Load0 = cast<LoadInst>(Ext0->getOperand(0));
-  LoadInst *Load1 = cast<LoadInst>(Ext1->getOperand(0));
-  if(!Load0->getType()->isIntegerTy(8) || !Load1->getType()->isIntegerTy(8)) {
-    LLVM_DEBUG(dbgs() << "Loads don't load the correct width, cannot create a partial reduction\n");
+  // Check that the extends extend to the same type
+  if(Ext0->getType() != Ext1->getType()) {
+    LLVM_DEBUG(dbgs() << "Extends don't extend to the same type, cannot create a partial reduction.\n");
     return false;
   }
 
@@ -2190,23 +2176,32 @@ static bool isInstrPartialReduction(Instruction *Instr) {
     return false;
   }
 
-  // Check that the first phi value is a zero initializer
-  ConstantInt *ZeroInit = dyn_cast<ConstantInt>(PhiNode->getIncomingValue(0));
-  if(!ZeroInit || !ZeroInit->isZero()) {
-    LLVM_DEBUG(dbgs() << "First PHI value is not a constant zero, cannot create a partial reduction.\n");
-    return false;
-  }
-
   // Check that the second phi value is the instruction we're looking at
   Instruction *MaybeAdd = dyn_cast<Instruction>(PhiNode->getIncomingValue(1));
   if(!MaybeAdd || MaybeAdd != Instr) {
-    LLVM_DEBUG(dbgs() << "Second PHI value is not the root add, cannot create a partial reduction.\n");
+    LLVM_DEBUG(dbgs() << "Second PHI value is not the root binop, cannot create a partial reduction.\n");
     return false;
   }
 
   return true;
 }
 
+/// Returns true when \p Chain may become a partial reduction: its root must be
+/// an add, the result width a whole multiple of the input width, and \p TTI must accept it.
+static bool isPartialReductionChainValid(const PartialReductionChain &Chain, const TargetTransformInfo &TTI) {
+  if (Chain.Reduction->getOpcode() != Instruction::Add)
+    return false;
+  // Compare scalar widths of the pre-extension input and of the reduction result.
+  const unsigned InputSizeBits = Chain.InputA->getType()->getScalarSizeInBits();
+  const unsigned ResultSizeBits = Chain.Reduction->getType()->getScalarSizeInBits();
+  if (ResultSizeBits < InputSizeBits || ResultSizeBits % InputSizeBits != 0)
+    return false;
+  // Tell the target, per input, whether its extend is a sign extension.
+  return TTI.isPartialReductionSupported(
+      Chain.Reduction, Chain.InputA->getType(), Chain.ScaleFactor,
+      isa<SExtInst>(Chain.ExtendA), isa<SExtInst>(Chain.ExtendB), Chain.BinOp);
+}
+
 // Return true if \p OuterLp is an outer loop annotated with hints for explicit
 // vectorization. The loop needs to be annotated with #pragma omp simd
 // simdlen(#) or #pragma clang vectorize(enable) vectorize_width(#). If the
@@ -4742,6 +4737,16 @@ bool LoopVectorizationPlanner::isCandidateForEpilogueVectorization(
   return true;
 }
 
+/// Copies into \p Chain the recorded chain rooted at \p I; returns false if there is none.
+bool LoopVectorizationPlanner::getInstructionsPartialReduction(Instruction *I, PartialReductionChain &Chain) const {
+  for (const PartialReductionChain &Candidate : PartialReductionChains)
+    if (Candidate.Reduction == I) {
+      Chain = Candidate;
+      return true;
+    }
+  return false;
+}
+
 bool LoopVectorizationCostModel::isEpilogueVectorizationProfitable(
     const ElementCount VF) const {
   // FIXME: We need a much better cost-model to take different parameters such
@@ -6837,7 +6842,7 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I,
   } // end of switch.
 }
 
-void LoopVectorizationCostModel::collectValuesToIgnore() {
+void LoopVectorizationCostModel::collectValuesToIgnore(LoopVectorizationPlanner* LVP) {
   // Ignore ephemeral values.
   CodeMetrics::collectEphemeralValues(TheLoop, AC, ValuesToIgnore);
 
@@ -6994,14 +6999,10 @@ void LoopVectorizationCostModel::collectValuesToIgnore() {
   }
 
   // Ignore any values that we know will be flattened
-  for(auto Reduction : this->Legal->getReductionVars()) {
-    auto &Recurrence = Reduction.second;
-    if(isInstrPartialReduction(Recurrence.getLoopExitInstr())) {
-      SmallVector<Value*, 4> PartialReductionValues;
-      getPartialReductionInstrChain(Recurrence.getLoopExitInstr(), PartialReductionValues);
-      ValuesToIgnore.insert(PartialReductionValues.begin(), PartialReductionValues.end());
-      VecValuesToIgnore.insert(PartialReductionValues.begin(), PartialReductionValues.end());
-    }
+  for(auto Chain : LVP->getPartialReductionChains()) {
+    SmallVector<Value*> PartialReductionValues{Chain.Reduction, Chain.BinOp, Chain.ExtendA, Chain.ExtendB, Chain.Accumulator};
+    ValuesToIgnore.insert(PartialReductionValues.begin(), PartialReductionValues.end());
+    VecValuesToIgnore.insert(PartialReductionValues.begin(), PartialReductionValues.end());
   }
 }
 
@@ -7119,7 +7120,17 @@ LoopVectorizationPlanner::planInVPlanNativePath(ElementCount UserVF) {
 
 void LoopVectorizationPlanner::plan(ElementCount UserVF, unsigned UserIC) {
   assert(OrigLoop->isInnermost() && "Inner loop expected.");
-  CM.collectValuesToIgnore();
+
+  for(auto ReductionVar : Legal->getReductionVars()) {
+    auto *ReductionExitInstr = ReductionVar.second.getLoopExitInstr();
+    if(isInstrPartialReduction(ReductionExitInstr)) {
+      auto Chain = getPartialReductionInstrChain(ReductionExitInstr);
+      if(isPartialReductionChainValid(Chain, TTI)) 
+        PartialReductionChains.push_back(Chain);
+    }
+  }
+  
+  CM.collectValuesToIgnore(this);
   CM.collectElementTypesForWidening();
 
   FixedScalableVFPair MaxFactors = CM.computeMaxVF(UserVF, UserIC);
@@ -8685,36 +8696,12 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
                                  *CI);
   }
 
-  if(auto *PartialReduce = tryToCreatePartialReduction(Range, Instr, Operands))
-    return PartialReduce;
-
   return tryToWiden(Instr, Operands, VPBB);
 }
 
 VPRecipeBase *VPRecipeBuilder::tryToCreatePartialReduction(
-    VFRange &Range, Instruction *Instr, ArrayRef<VPValue *> Operands) {
-
-  if(isInstrPartialReduction(Instr)) {
-    // Restricting this case to 16x means that, using a scale of 4, we avoid
-    // trying to generate illegal types such as <vscale x 2 x i32>
-    auto EC = ElementCount::getScalable(16);
-    if(std::find(Range.begin(), Range.end(), EC) == Range.end())
-      return nullptr;
-
-    // Scale factor of 4 for sdot/udot.
-    unsigned Scale = 4;
-    VectorType* ResTy = ScalableVectorType::get(Instr->getType(), Scale);
-    VectorType* ValTy = ScalableVectorType::get(Instr->getType(), EC.getKnownMinValue());
-    using namespace llvm::PatternMatch;
-    bool IsUnsigned = match(Instr, m_Add(m_Mul(m_ZExt(m_Value()), m_ZExt(m_Value())), m_Value()));
-    auto RecipeCost = this->CM.TTI.getPartialReductionCost(Instr->getOpcode(), IsUnsigned, ResTy, ValTy, FastMathFlags::getFast());
-    // TODO replace with more informed cost check
-    if(RecipeCost == InstructionCost::getMax())
-      return nullptr;
-
-    return new VPPartialReductionRecipe(*Instr, make_range(Operands.begin(), Operands.end()), Scale);
-  }
-  return nullptr;
+    VFRange &Range, PartialReductionChain &Chain, ArrayRef<VPValue *> Operands) {
+  return new VPPartialReductionRecipe(*Chain.Reduction, make_range(Operands.begin(), Operands.end()), Chain.ScaleFactor);
 }
 
 void LoopVectorizationPlanner::buildVPlansWithVPRecipes(ElementCount MinVF,
@@ -9100,8 +9087,14 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
           Legal->isInvariantAddressOfReduction(SI->getPointerOperand()))
         continue;
 
-      VPRecipeBase *Recipe =
-          RecipeBuilder.tryToCreateWidenRecipe(Instr, Operands, Range, VPBB);
+      VPRecipeBase *Recipe = nullptr;
+
+      PartialReductionChain Chain;
+      if(getInstructionsPartialReduction(Instr, Chain)) 
+        Recipe = RecipeBuilder.tryToCreatePartialReduction(Range, Chain, Operands);
+      
+      if (!Recipe)
+        Recipe = RecipeBuilder.tryToCreateWidenRecipe(Instr, Operands, Range, VPBB);
       if (!Recipe)
         Recipe = RecipeBuilder.handleReplication(Instr, Range);
 
diff --git a/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h b/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
index 92724daaa26912..eecb5ff3b49646 100644
--- a/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
+++ b/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
@@ -117,7 +117,7 @@ class VPRecipeBuilder {
                                        ArrayRef<VPValue *> Operands,
                                        VFRange &Range, VPBasicBlock *VPBB);
 
-  VPRecipeBase* tryToCreatePartialReduction(VFRange &Range, Instruction* Instr, ArrayRef<VPValue*> Operands);
+  VPRecipeBase* tryToCreatePartialReduction(VFRange &Range, PartialReductionChain &Chain, ArrayRef<VPValue*> Operands);
 
   /// Set the recipe created for given ingredient.
   void setRecipe(Instruction *I, VPRecipeBase *R) {
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index b22b08dc84bf09..2b468455ea0a25 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -1021,7 +1021,9 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
     FastMathFlagsTy(const FastMathFlags &FMF);
   };
 
+public:
   OperationType OpType;
+private:
 
   union {
     CmpInst::Predicate CmpPredicate;
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 0786d39aa7da65..1ba44e738abeb0 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -340,11 +340,10 @@ void VPPartialReductionRecipe::execute(VPTransformState &State) {
       }
 
       assert(Phi && Mul && "Phi and Mul must be set");
-      assert(isa<ScalableVectorType>(Ops[0]->getType()) && "Type must be a scalable vector");
 
-      ScalableVectorType *FullTy = cast<ScalableVectorType>(Ops[0]->getType());
+      VectorType *FullTy = cast<VectorType>(Ops[0]->getType());
       auto EC = FullTy->getElementCount();
-      Type *RetTy = ScalableVectorType::get(FullTy->getScalarType(), EC.divideCoefficientBy(Scale).getKnownMinValue());
+      Type *RetTy = VectorType::get(FullTy->getScalarType(), EC.divideCoefficientBy(Scale));
 
       Intrinsic::ID PartialIntrinsic = Intrinsic::not_intrinsic;
       switch(Opcode) {
@@ -358,10 +357,7 @@ void VPPartialReductionRecipe::execute(VPTransformState &State) {
 
       assert(PartialIntrinsic != Intrinsic::not_intrinsic);
 
-      Value *V = Builder.CreateIntrinsic(RetTy, PartialIntrinsic, Mul, nullptr, Twine("partial.reduce"));
-      V = Builder.CreateNAryOp(Opcode, {V, Phi});
-      if (auto *VecOp = dyn_cast<Instruction>(V))
-        setFlags(VecOp);
+      CallInst *V = Builder.CreateIntrinsic(RetTy, PartialIntrinsic, {Phi, Mul}, nullptr, Twine("partial.reduce"));
 
       // Use this vector value for all users of the original instruction.
       State.set(this, V, Part);
diff --git a/llvm/test/CodeGen/AArch64/partial-reduce-sdot-ir.ll b/llvm/test/CodeGen/AArch64/partial-reduce-sdot-ir.ll
deleted file mode 100644
index 3519ba58b3df34..00000000000000
--- a/llvm/test/CodeGen/AArch64/partial-reduce-sdot-ir.ll
+++ /dev/null
@@ -1,99 +0,0 @@
-; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
-; RUN: opt -passes="default<O3>" -force-vector-interleave=1 -S < %s | FileCheck %s
-
-target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128"
-target triple = "aarch64-none-unknown-elf"
-
-define void @dotp(ptr %out, ptr %a, ptr %b, i64 %wide.trip.count) #0 {
-; CHECK-LABEL: define void @dotp(
-; CHECK-SAME: ptr nocapture writeonly [[OUT:%.*]], ptr nocapture readonly [[A:%.*]], ptr nocapture readonly [[B:%.*]], i64 [[WIDE_TRIP_COUNT:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
-; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[TMP0:%.*]] = add i64 [[WIDE_TRIP_COUNT]], 1
-; CHECK-NEXT:    [[TMP1:%.*]] = tail call i64 @llvm.vscale.i64()
-; CHECK-NEXT:    [[TMP2:%.*]] = shl i64 [[TMP1]], 4
-; CHECK-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[TMP0]], [[TMP2]]
-; CHECK-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[FOR_BODY_PREHEADER:%.*]], label [[VECTOR_PH:%.*]]
-; CHECK:       vector.ph:
-; CHECK-NEXT:    [[TMP3:%.*]] = tail call i64 @llvm.vscale.i64()
-; CHECK-NEXT:    [[TMP4:%.*]] = shl i64 [[TMP3]], 4
-; CHECK-NEXT:    [[N_MOD_VF:%.*]] = urem i64 [[TMP0]], [[TMP4]]
-; CHECK-NEXT:    [[N_VEC:%.*]] = sub i64 [[TMP0]], [[N_MOD_VF]]
-; CHECK-NEXT:    [[TMP5:%.*]] = tail call i64 @llvm.vscale.i64()
-; CHECK-NEXT:    [[TMP6:%.*]] = shl i64 [[TMP5]], 4
-; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
-; CHECK:       vector.body:
-; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
-; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi <vscale x 16 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP12:%.*]], [[VECTOR_BODY]] ]
-; CHECK-NEXT:    [[TMP7:%.*]] = getelementptr i8, ptr [[A]], i64 [[INDEX]]
-; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <vscale x 16 x i8>, ptr [[TMP7]], align 1
-; CHECK-NEXT:    [[TMP8:%.*]] = zext <vscale x 16 x i8> [[WIDE_LOAD]] to <vscale x 16 x i32>
-; CHECK-NEXT:    [[TMP9:%.*]] = getelementptr i8, ptr [[B]], i64 [[INDEX]]
-; CHECK-NEXT:    [[WIDE_LOAD1:%.*]] = load <vscale x 16 x i8>, ptr [[TMP9]], align 1
-; CHECK-NEXT:    [[TMP10:%.*]] = zext <vscale x 16 x i8> [[WIDE_LOAD1]] to <vscale x 16 x i32>
-; CHECK-NEXT:    [[TMP11:%.*]] = mul nuw nsw <vscale x 16 x i32> [[TMP10]], [[TMP8]]
-; CHECK-NEXT:    [[TMP12]] = add <vscale x 16 x i32> [[TMP11]], [[VEC_PHI]]
-; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP6]]
-; CHECK-NEXT:    [[TMP13:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
-; CHECK-NEXT:    br i1 [[TMP13]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
-; CHECK:       middle.block:
-; CHECK-NEXT:    [[TMP14:%.*]] = tail call i32 @llvm.vector.reduce.add.nxv16i32(<vscale x 16 x i32> [[TMP12]])
-; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i64 [[N_MOD_VF]], 0
-; CHECK-NEXT:    br i1 [[CMP_N]], label [[FOR_COND_CLEANUP_LOOPEXIT:%.*]], label [[FOR_BODY_PREHEADER]]
-; CHECK:       for.body.preheader:
-; CHECK-NEXT:    [[INDVARS_IV_PH:%.*]] = phi i64 [ 0, [[ENTRY:%.*]] ], [ [[N_VEC]], [[MIDDLE_BLOCK]] ]
-; CHECK-NEXT:    [[ACC_010_PH:%.*]] = phi i32 [ 0, [[ENTRY]] ], [ [[TMP14]], [[MIDDLE_BLOCK]] ]
-; CHECK-NEXT:    br label [[FOR_BODY:%.*]]
-; CHECK:       for.cond.cleanup.loopexit:
-; CHECK-NEXT:    [[ADD_LCSSA:%.*]] = phi i32 [ [[TMP14]], [[MIDDLE_BLOCK]] ], [ [[ADD:%.*]], [[FOR_BODY]] ]
-; CHECK-NEXT:    [[TMP15:%.*]] = trunc i32 [[ADD_LCSSA]] to i8
-; CHECK-NEXT:    store i8 [[TMP15]], ptr [[OUT]], align 1
-; CHECK-NEXT:    ret void
-; CHECK:       for.body:
-; CHECK-NEXT:    [[INDVARS_IV:%.*]] = phi i64 [ [[INDVARS_IV_NEXT:%.*]], [[FOR_BODY]] ], [ [[INDVARS_IV_PH]], [[FOR_BODY_PREHEADER]] ]
-; CHECK-NEXT:    [[ACC_010:%.*]] = phi i32 [ [[ADD]], [[FOR_BODY]] ], [ [[ACC_010_PH]], [[FOR_BODY_PREHEADER]] ]
-; CHECK-NEXT:    [[ARRAYIDX:%.*]] = getelementptr i8, ptr [[A]], i64 [[INDVARS_IV]]
-; CHECK-NEXT:    [[TMP16:%.*]] = load i8, ptr [[ARRAYIDX]], align 1
-; CHECK-NEXT:    [[CONV:%.*]] = zext i8 [[TMP16]] to i32
-; CHECK-NEXT:    [[ARRAYIDX2:%.*]] = getelementptr i8, ptr [[B]], i64 [[INDVARS_IV]]
-; CHECK-NEXT:    [[TMP17:%.*]] = load i8, ptr [[ARRAYIDX2]], align 1
-; CHECK-NEXT:    [[CONV3:%.*]] = zext i8 [[TMP17]] to i32
-; CHECK-NEXT:    [[MUL:%.*]] = mul nuw nsw i32 [[CONV3]], [[CONV]]
-; CHECK-NEXT:    [[ADD]] = add i32 [[MUL]], [[ACC_010]]
-; CHECK-NEXT:    [[INDVARS_IV_NEXT]] = add i64 [[INDVARS_IV]], 1
-; CHECK-NEXT:    [[EXITCOND_NOT:%.*]] = icmp eq i64 [[INDVARS_IV]], [[WIDE_TRIP_COUNT]]
-; CHECK-NEXT:    br i1 [[EXITCOND_NOT]], label [[FOR_COND_CLEANUP_LOOPEXIT]], label [[FOR_BODY]], !llvm.loop [[LOOP3:![0-9]+]]
-;
-entry:
-  br label %for.body
-
-for.cond.cleanup.loopexit:                        ; preds = %for.body
-  %0 = trunc i32 %add to i8
-  store i8 %0, ptr %out, align 1
-  ret void
-
-for.body:                                         ; preds = %for.body, %entry
-  %indvars.iv = phi i64 [ 0, %entry ], [ %indvars.iv.next, %for.body ]
-  %acc.010 = phi i32 [ 0, %entry ], [ %add, %for.body ]
-  %arrayidx = getelementptr i8, ptr %a, i64 %indvars.iv
-  %1 = load i8, ptr %arrayidx, align 1
-  %conv = zext i8 %1 to i32
-  %arrayidx2 = getelementptr i8, ptr %b, i64 %indvars.iv
-  %2 = load i8, ptr %arrayidx2, align 1
-  %conv3 = zext i8 %2 to i32
-  %mul = mul i32 %conv3, %conv
-  %add = add i32 %mul, %acc.010
-  %indvars.iv.next = add i64 %indvars.iv, 1
-  %exitcond.not = icmp eq i64 %indvars.iv, %wide.trip.count
-  br i1 %exitcond.not, label %for.cond.cleanup.loopexit, label %for.body
-
-; uselistorder directives
-  uselistorder i32 %add, { 1, 0 }
-}
-
-attributes #0 = { "target-features"="+fp-armv8,+fullfp16,+neon,+sve,+sve2,+v8a" }
-;.
-; CHECK: [[LOOP0]] = distinct !{[[LOOP0]], [[META1:![0-9]+]], [[META2:![0-9]+]]}
-; CHECK: [[META1]] = !{!"llvm.loop.isvectorized", i32 1}
-; CHECK: [[META2]] = !{!"llvm.loop.unroll.runtime.disable"}
-; CHECK: [[LOOP3]] = distinct !{[[LOOP3]], [[META2]], [[META1]]}
-;.
diff --git a/llvm/test/CodeGen/AArch64/partial-reduce.ll b/llvm/test/CodeGen/AArch64/partial-reduce.ll
deleted file mode 100644
index 7883cfc05a13b3..00000000000000
--- a/llvm/test/CodeGen/AArch64/partial-reduce.ll
+++ /dev/null
@@ -1,100 +0,0 @@
-; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
-; RUN: opt -passes=loop-vectorize -force-vector-interleave=1 -S < %s | FileCheck %s
-
-target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128"
-target triple = "aarch64-none-unknown-elf"
-
-define void @dotp(ptr %a, ptr %b) #0 {
-; CHECK-LABEL: define void @dotp(
-; CHECK-SAME: ptr [[A:%.*]], ptr [[B:%.*]]) #[[ATTR0:[0-9]+]] {
-; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[TMP0:%.*]] = call i64 @llvm.vscale.i64()
-; CHECK-NEXT:    [[TMP1:%.*]] = mul i64 [[TMP0]], 16
-; CHECK-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 0, [[TMP1]]
-; CHECK-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
-; CHECK:       vector.ph:
-; CHECK-NEXT:    [[TMP2:%.*]] = call i64 @llvm.vscale.i64()
-; CHECK-NEXT:    [[TMP3:%.*]] = mul i64 [[TMP2]], 16
-; CHECK-NEXT:    [[N_MOD_VF:%.*]] = urem i64 0, [[TMP3]]
-; CHECK-NEXT:    [[N_VEC:%.*]] = sub i64 0, [[N_MOD_VF]]
-; CHECK-NEXT:    [[TMP4:%.*]] = call i64 @llvm.vscale.i64()
-; CHECK-NEXT:    [[TMP5:%.*]] = mul i64 [[TMP4]], 16
-; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
-; CHECK:       vector.body:
-; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
-; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi <vscale x 4 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP14:%.*]], [[VECTOR_BODY]] ]
-; CHECK-NEXT:    [[TMP11:%.*]] = add i64 [[INDEX]], 0
-; CHECK-NEXT:    [[TMP13:%.*]] = getelementptr i8, ptr [[A]], i64 [[TMP11]]
-; CHECK-NEXT:    [[TMP17:%.*]] = getelementptr i8, ptr [[TMP13]], i32 0
-; CHECK-NEXT:    [[WIDE_LOAD2:%.*]] = load <vscale x 16 x i8>, ptr [[TMP17]], align 1
-; CHECK-NEXT:    [[TMP19:%.*]] = zext <vscale x 16 x i8> [[WIDE_LOAD2]] to <vscale x 16 x i32>
-; CHECK-NEXT:    [[TMP21:%.*]] = getelementptr i8, ptr [[B]], i64 [[TMP11]]
-; CHECK-NEXT:    [[TMP25:%.*]] = getelementptr i8, ptr [[TMP21]], i32 0
-; CHECK-NEXT:    [[WIDE_LOAD4:%.*]] = load <vscale x 16 x i8>, ptr [[TMP25]], align 1
-; CHECK-NEXT:    [[TMP27:%.*]] = zext <vscale x 16 x i8> [[WIDE_LOAD4]] to <vscale x 16 x i32>
-; CHECK-NEXT:    [[TMP29:%.*]] = mul <vscale x 16 x i32> [[TMP27]], [[TMP19]]
-; CHECK-NEXT:    [[PARTIAL_REDUCE:%.*]] = call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 16 x i32> [[TMP29]])
-; CHECK-NEXT:    [[TMP14]] = add <vscale x 4 x i32> [[PARTIAL_REDUCE]], [[VEC_PHI]]
-; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP5]]
-; CHECK-NEXT:    [[TMP32:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
-; CHECK-NEXT:    br i1 [[TMP32]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
-; CHECK:       middle.block:
-; CHECK-NEXT:    [[TMP33:%.*]] = call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> [[TMP14]])
-; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i64 0, [[N_VEC]]
-; CHECK-NEXT:    br i1 [[CMP_N]], label [[FOR_COND_CLEANUP_LOOPEXIT:%.*]], label [[SCALAR_PH]]
-; CHECK:       scalar.ph:
-; CHECK-NEXT:    [[BC_RESUME_VAL:%.*]] = phi i64 [ [[N_VEC]], [[MIDDLE_BLOCK]] ], [ 0, [[ENTRY:%.*]] ]
-; CHECK-NEXT:    [[BC_MERGE_RDX:%.*]] = phi i32 [ 0, [[ENTRY]] ], [ [[TMP33]], [[MIDDLE_BLOCK]] ]
-; CHECK-NEXT:    br label [[FOR_BODY:%.*]]
-; CHECK:       for.cond.cleanup.loopexit:
-; CHECK-NEXT:    [[ADD_LCSSA:%.*]] = phi i32 [ [[ADD:%.*]], [[FOR_BODY]] ], [ [[TMP33]], [[MIDDLE_BLOCK]] ]
-; CHECK-NEXT:    [[TMP20:%.*]] = lshr i32 [[ADD_LCSSA]], 0
-; CHECK-NEXT:    ret void
-; CHECK:       for.body:
-; CHECK-NEXT:    [[INDVARS_IV:%.*]] = phi i64 [ [[BC_RESUME_VAL]], [[SCALAR_PH]] ], [ [[INDVARS_IV_NEXT:%.*]], [[FOR_BODY]] ]
-; CHECK-NEXT:    [[ACC_010:%.*]] = phi i32 [ [[BC_MERGE_RDX]], [[SCALAR_PH]] ], [ [[ADD]], [[FOR_BODY]] ]
-; CHECK-NEXT:    [[ARRAYIDX:%.*]] = getelementptr i8, ptr [[A]], i64 [[INDVARS_IV]]
-; CHECK-NEXT:    [[TMP18:%.*]] = load i8, ptr [[ARRAYIDX]], align 1
-; CHECK-NEXT:    [[CONV:%.*]] = zext i8 [[TMP18]] to i32
-; CHECK-NEXT:    [[ARRAYIDX2:%.*]] = getelementptr i8, ptr [[B]], i64 [[INDVARS_IV]]
-; CHECK-NEXT:    [[TMP22:%.*]] = load i8, ptr [[ARRAYIDX2]], align 1
-; CHECK-NEXT:    [[CONV3:%.*]] = zext i8 [[TMP22]] to i32
-; CHECK-NEXT:    [[MUL:%.*]] = mul i32 [[CONV3]], [[CONV]]
-; CHECK-NEXT:    [[ADD]] = add i32 [[MUL]], [[ACC_010]]
-; CHECK-NEXT:    [[INDVARS_IV_NEXT]] = add i64 [[INDVARS_IV]], 1
-; CHECK-NEXT:    [[EXITCOND_NOT:%.*]] = icmp eq i64 [[INDVARS_IV_NEXT]], 0
-; CHECK-NEXT:    br i1 [[EXITCOND_NOT]], label [[FOR_COND_CLEANUP_LOOPEXIT]], label [[FOR_BODY]], !llvm.loop [[LOOP3:![0-9]+]]
-;
-entry:
-  br label %for.body
-
-for.cond.cleanup.loopexit:                        ; preds = %for.body
-  %0 = lshr i32 %add, 0
-  ret void
-
-for.body:                                         ; preds = %for.body, %entry
-  %indvars.iv = phi i64 [ 0, %entry ], [ %indvars.iv.next, %for.body ]
-  %acc.010 = phi i32 [ 0, %entry ], [ %add, %for.body ]
-  %arrayidx = getelementptr i8, ptr %a, i64 %indvars.iv
-  %1 = load i8, ptr %arrayidx, align 1
-  %conv = zext i8 %1 to i32
-  %arrayidx2 = getelementptr i8, ptr %b, i64 %indvars.iv
-  %2 = load i8, ptr %arrayidx2, align 1
-  %conv3 = zext i8 %2 to i32
-  %mul = mul i32 %conv3, %conv
-  %add = add i32 %mul, %acc.010
-  %indvars.iv.next = add i64 %indvars.iv, 1
-  %exitcond.not = icmp eq i64 %indvars.iv.next, 0
-  br i1 %exitcond.not, label %for.cond.cleanup.loopexit, label %for.body
-
-; uselistorder directives
-  uselistorder i32 %add, { 1, 0 }
-}
-
-attributes #0 = { "target-features"="+fp-armv8,+fullfp16,+neon,+sve,+sve2,+v8a" }
-;.
-; CHECK: [[LOOP0]] = distinct !{[[LOOP0]], [[META1:![0-9]+]], [[META2:![0-9]+]]}
-; CHECK: [[META1]] = !{!"llvm.loop.isvectorized", i32 1}
-; CHECK: [[META2]] = !{!"llvm.loop.unroll.runtime.disable"}
-; CHECK: [[LOOP3]] = distinct !{[[LOOP3]], [[META2]], [[META1]]}
-;.

>From f7338843eaf16d4664e933f2c0d2bb99fca7ca75 Mon Sep 17 00:00:00 2001
From: Nicholas Guy <nicholas.guy at arm.com>
Date: Thu, 4 Jul 2024 16:07:43 +0100
Subject: [PATCH 10/54] Format

---
 .../llvm/Analysis/TargetTransformInfo.h       |  28 +++--
 .../llvm/Analysis/TargetTransformInfoImpl.h   |   9 +-
 llvm/lib/Analysis/TargetTransformInfo.cpp     |   4 +-
 .../Vectorize/LoopVectorizationPlanner.h      |   9 +-
 .../Transforms/Vectorize/LoopVectorize.cpp    | 110 +++++++++++-------
 .../Transforms/Vectorize/VPRecipeBuilder.h    |   4 +-
 llvm/lib/Transforms/Vectorize/VPlan.h         |  25 ++--
 .../Transforms/Vectorize/VPlanAnalysis.cpp    |   3 +-
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp |  30 ++---
 9 files changed, 126 insertions(+), 96 deletions(-)

diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 4efbd76689a69e..6e31686483cc11 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1248,8 +1248,12 @@ class TargetTransformInfo {
   /// \return if target want to issue a prefetch in address space \p AS.
   bool shouldPrefetchAddressSpace(unsigned AS) const;
 
-  bool isPartialReductionSupported(const Instruction* ReductionInstr, Type* InputType, unsigned ScaleFactor, bool IsInputASignExtended, bool IsInputBSignExtended, const Instruction* BinOp = nullptr) const;
-    
+  bool isPartialReductionSupported(const Instruction *ReductionInstr,
+                                   Type *InputType, unsigned ScaleFactor,
+                                   bool IsInputASignExtended,
+                                   bool IsInputBSignExtended,
+                                   const Instruction *BinOp = nullptr) const;
+
   /// \return The maximum interleave factor that any transform should try to
   /// perform for this target. This number depends on the level of parallelism
   /// and the number of execution units in the CPU.
@@ -2027,11 +2031,11 @@ class TargetTransformInfo::Concept {
   /// \return if target want to issue a prefetch in address space \p AS.
   virtual bool shouldPrefetchAddressSpace(unsigned AS) const = 0;
 
-  virtual bool isPartialReductionSupported(const Instruction* ReductionInstr,
-      Type* InputType, unsigned ScaleFactor,
+  virtual bool isPartialReductionSupported(
+      const Instruction *ReductionInstr, Type *InputType, unsigned ScaleFactor,
       bool IsInputASignExtended, bool IsInputBSignExtended,
-      const Instruction* BinOp = nullptr) const = 0;
-    
+      const Instruction *BinOp = nullptr) const = 0;
+
   virtual unsigned getMaxInterleaveFactor(ElementCount VF) = 0;
   virtual InstructionCost getArithmeticInstrCost(
       unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind,
@@ -2666,11 +2670,13 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
     return Impl.shouldPrefetchAddressSpace(AS);
   }
 
-  bool isPartialReductionSupported(const Instruction* ReductionInstr, Type* InputType, unsigned ScaleFactor,
-                                              bool IsInputASignExtended, bool IsInputBSignExtended,
-                                              const Instruction* BinOp = nullptr) const override
-  {
-      return Impl.isPartialReductionSupported(ReductionInstr, InputType, ScaleFactor, IsInputASignExtended, IsInputBSignExtended, BinOp);
+  bool isPartialReductionSupported(
+      const Instruction *ReductionInstr, Type *InputType, unsigned ScaleFactor,
+      bool IsInputASignExtended, bool IsInputBSignExtended,
+      const Instruction *BinOp = nullptr) const override {
+    return Impl.isPartialReductionSupported(ReductionInstr, InputType,
+                                            ScaleFactor, IsInputASignExtended,
+                                            IsInputBSignExtended, BinOp);
   }
 
   unsigned getMaxInterleaveFactor(ElementCount VF) override {
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 283449a98d586e..ad72bcddd242e6 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -551,10 +551,11 @@ class TargetTransformInfoImplBase {
   bool enableWritePrefetching() const { return false; }
   bool shouldPrefetchAddressSpace(unsigned AS) const { return !AS; }
 
-  bool isPartialReductionSupported(const Instruction* ReductionInstr,
-      Type* InputType, unsigned ScaleFactor,
-      bool IsInputASignExtended, bool IsInputBSignExtended,
-      const Instruction* BinOp = nullptr) const {
+  bool isPartialReductionSupported(const Instruction *ReductionInstr,
+                                   Type *InputType, unsigned ScaleFactor,
+                                   bool IsInputASignExtended,
+                                   bool IsInputBSignExtended,
+                                   const Instruction *BinOp = nullptr) const {
     return false;
   }
 
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 4010d618c914fc..964259eccb7807 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -834,7 +834,9 @@ bool TargetTransformInfo::isPartialReductionSupported(
     const Instruction *ReductionInstr, Type *InputType, unsigned ScaleFactor,
     bool IsInputASignExtended, bool IsInputBSignExtended,
     const Instruction *BinOp) const {
-  return TTIImpl->isPartialReductionSupported(ReductionInstr, InputType, ScaleFactor, IsInputASignExtended, IsInputBSignExtended, BinOp);
+  return TTIImpl->isPartialReductionSupported(ReductionInstr, InputType,
+                                              ScaleFactor, IsInputASignExtended,
+                                              IsInputBSignExtended, BinOp);
 }
 
 unsigned TargetTransformInfo::getMaxInterleaveFactor(ElementCount VF) const {
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
index 58a9bf22904040..1237cf385a3ebe 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
@@ -342,7 +342,7 @@ struct PartialReductionChain {
   Instruction *BinOp;
   Instruction *ExtendA;
   Instruction *ExtendB;
-  
+
   Value *InputA;
   Value *InputB;
   Value *Accumulator;
@@ -484,7 +484,7 @@ class LoopVectorizationPlanner {
 
   SmallVector<PartialReductionChain> getPartialReductionChains() const {
     return PartialReductionChains;
-  } 
+  }
 
 protected:
   /// Build VPlans for power-of-2 VF's between \p MinVF and \p MaxVF inclusive,
@@ -538,9 +538,8 @@ class LoopVectorizationPlanner {
   /// epilogue, assuming the main loop is vectorized by \p VF.
   bool isCandidateForEpilogueVectorization(const ElementCount VF) const;
 
-  bool getInstructionsPartialReduction(Instruction* I, PartialReductionChain &Chain) const;
-
-  
+  bool getInstructionsPartialReduction(Instruction *I,
+                                       PartialReductionChain &Chain) const;
 };
 
 } // namespace llvm
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 51a6e3215ed7c3..a604a6d641146a 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -2129,11 +2129,10 @@ static PartialReductionChain getPartialReductionInstrChain(Instruction *Instr) {
 
   unsigned InputSizeBits = Chain.InputA->getType()->getScalarSizeInBits();
   unsigned ResultSizeBits = Chain.Reduction->getType()->getScalarSizeInBits();
-  Chain.ScaleFactor = ResultSizeBits / InputSizeBits; 
+  Chain.ScaleFactor = ResultSizeBits / InputSizeBits;
   return Chain;
 }
 
-
 /// Checks if the given instruction the root of a partial reduction chain
 ///
 /// @param Instr The root instruction to scan
@@ -2142,64 +2141,71 @@ static bool isInstrPartialReduction(Instruction *Instr) {
   Value *A, *B;
 
   using namespace llvm::PatternMatch;
-  auto Pattern = m_BinOp(
-      m_OneUse(m_BinOp(
-        m_OneUse(m_ZExtOrSExt(m_OneUse(m_Value(A)))),
-        m_OneUse(m_ZExtOrSExt(m_OneUse(m_Value(B)))))),
-      m_Value(ExpectedPhi));
+  auto Pattern =
+      m_BinOp(m_OneUse(m_BinOp(m_OneUse(m_ZExtOrSExt(m_OneUse(m_Value(A)))),
+                               m_OneUse(m_ZExtOrSExt(m_OneUse(m_Value(B)))))),
+              m_Value(ExpectedPhi));
 
   bool Matches = match(Instr, Pattern);
 
-  if(!Matches)
+  if (!Matches)
     return false;
 
   // Check that the extends extend from the same type
-  if(A->getType() != B->getType()) {
-    LLVM_DEBUG(dbgs() << "Extends don't extend from the same type, cannot create a partial reduction.\n");
+  if (A->getType() != B->getType()) {
+    LLVM_DEBUG(dbgs() << "Extends don't extend from the same type, cannot "
+                         "create a partial reduction.\n");
     return false;
   }
 
-  // A and B are one-use, so the first user of each should be the respective extend 
+  // A and B are one-use, so the first user of each should be the respective
+  // extend
   Instruction *Ext0 = cast<CastInst>(*A->user_begin());
   Instruction *Ext1 = cast<CastInst>(*B->user_begin());
 
   // Check that the extends extend to the same type
-  if(Ext0->getType() != Ext1->getType()) {
-    LLVM_DEBUG(dbgs() << "Extends don't extend to the same type, cannot create a partial reduction.\n");
+  if (Ext0->getType() != Ext1->getType()) {
+    LLVM_DEBUG(dbgs() << "Extends don't extend to the same type, cannot create "
+                         "a partial reduction.\n");
     return false;
   }
 
   // Check that the add feeds into ExpectedPhi
   PHINode *PhiNode = dyn_cast<PHINode>(ExpectedPhi);
-  if(!PhiNode) {
-    LLVM_DEBUG(dbgs() << "Expected Phi node was not a phi, cannot create a partial reduction.\n");
+  if (!PhiNode) {
+    LLVM_DEBUG(dbgs() << "Expected Phi node was not a phi, cannot create a "
+                         "partial reduction.\n");
     return false;
   }
 
   // Check that the second phi value is the instruction we're looking at
   Instruction *MaybeAdd = dyn_cast<Instruction>(PhiNode->getIncomingValue(1));
-  if(!MaybeAdd || MaybeAdd != Instr) {
-    LLVM_DEBUG(dbgs() << "Second PHI value is not the root binop, cannot create a partial reduction.\n");
+  if (!MaybeAdd || MaybeAdd != Instr) {
+    LLVM_DEBUG(dbgs() << "Second PHI value is not the root binop, cannot "
+                         "create a partial reduction.\n");
     return false;
   }
 
   return true;
 }
 
-static bool isPartialReductionChainValid(PartialReductionChain &Chain, const TargetTransformInfo &TTI) {
-  if(Chain.Reduction->getOpcode() != Instruction::Add)
+static bool isPartialReductionChainValid(PartialReductionChain &Chain,
+                                         const TargetTransformInfo &TTI) {
+  if (Chain.Reduction->getOpcode() != Instruction::Add)
     return false;
 
   unsigned InputSizeBits = Chain.InputA->getType()->getScalarSizeInBits();
   unsigned ResultSizeBits = Chain.Reduction->getType()->getScalarSizeInBits();
 
-  if(ResultSizeBits < InputSizeBits || (ResultSizeBits % InputSizeBits) != 0)
+  if (ResultSizeBits < InputSizeBits || (ResultSizeBits % InputSizeBits) != 0)
     return false;
-  
+
   bool IsASignExtended = isa<SExtInst>(Chain.ExtendA);
   bool IsBSignExtended = isa<SExtInst>(Chain.ExtendB);
 
-  return TTI.isPartialReductionSupported(Chain.Reduction, Chain.InputA->getType(), Chain.ScaleFactor, IsASignExtended, IsBSignExtended, Chain.BinOp);
+  return TTI.isPartialReductionSupported(
+      Chain.Reduction, Chain.InputA->getType(), Chain.ScaleFactor,
+      IsASignExtended, IsBSignExtended, Chain.BinOp);
 }
 
 // Return true if \p OuterLp is an outer loop annotated with hints for explicit
@@ -4723,9 +4729,11 @@ bool LoopVectorizationPlanner::isCandidateForEpilogueVectorization(
 
   // Prevent epilogue vectorization if a partial reduction is involved
   // TODO Is there a cleaner way to check this?
-  if(any_of(Legal->getReductionVars(), [&](const std::pair<PHINode *, RecurrenceDescriptor> &Reduction) {
-    return isInstrPartialReduction(Reduction.second.getLoopExitInstr());
-  }))
+  if (any_of(Legal->getReductionVars(),
+             [&](const std::pair<PHINode *, RecurrenceDescriptor> &Reduction) {
+               return isInstrPartialReduction(
+                   Reduction.second.getLoopExitInstr());
+             }))
     return false;
 
   // Epilogue vectorization code has not been auditted to ensure it handles
@@ -4737,9 +4745,10 @@ bool LoopVectorizationPlanner::isCandidateForEpilogueVectorization(
   return true;
 }
 
-bool LoopVectorizationPlanner::getInstructionsPartialReduction(Instruction *I, PartialReductionChain &Chain) const {
-  for(auto &C : PartialReductionChains) {
-    if(C.Reduction == I) {
+bool LoopVectorizationPlanner::getInstructionsPartialReduction(
+    Instruction *I, PartialReductionChain &Chain) const {
+  for (auto &C : PartialReductionChains) {
+    if (C.Reduction == I) {
       Chain = C;
       return true;
     }
@@ -6842,7 +6851,8 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I,
   } // end of switch.
 }
 
-void LoopVectorizationCostModel::collectValuesToIgnore(LoopVectorizationPlanner* LVP) {
+void LoopVectorizationCostModel::collectValuesToIgnore(
+    LoopVectorizationPlanner *LVP) {
   // Ignore ephemeral values.
   CodeMetrics::collectEphemeralValues(TheLoop, AC, ValuesToIgnore);
 
@@ -6999,10 +7009,14 @@ void LoopVectorizationCostModel::collectValuesToIgnore(LoopVectorizationPlanner*
   }
 
   // Ignore any values that we know will be flattened
-  for(auto Chain : LVP->getPartialReductionChains()) {
-    SmallVector<Value*> PartialReductionValues{Chain.Reduction, Chain.BinOp, Chain.ExtendA, Chain.ExtendB, Chain.Accumulator};
-    ValuesToIgnore.insert(PartialReductionValues.begin(), PartialReductionValues.end());
-    VecValuesToIgnore.insert(PartialReductionValues.begin(), PartialReductionValues.end());
+  for (auto Chain : LVP->getPartialReductionChains()) {
+    SmallVector<Value *> PartialReductionValues{Chain.Reduction, Chain.BinOp,
+                                                Chain.ExtendA, Chain.ExtendB,
+                                                Chain.Accumulator};
+    ValuesToIgnore.insert(PartialReductionValues.begin(),
+                          PartialReductionValues.end());
+    VecValuesToIgnore.insert(PartialReductionValues.begin(),
+                             PartialReductionValues.end());
   }
 }
 
@@ -7121,15 +7135,15 @@ LoopVectorizationPlanner::planInVPlanNativePath(ElementCount UserVF) {
 void LoopVectorizationPlanner::plan(ElementCount UserVF, unsigned UserIC) {
   assert(OrigLoop->isInnermost() && "Inner loop expected.");
 
-  for(auto ReductionVar : Legal->getReductionVars()) {
+  for (auto ReductionVar : Legal->getReductionVars()) {
     auto *ReductionExitInstr = ReductionVar.second.getLoopExitInstr();
-    if(isInstrPartialReduction(ReductionExitInstr)) {
+    if (isInstrPartialReduction(ReductionExitInstr)) {
       auto Chain = getPartialReductionInstrChain(ReductionExitInstr);
-      if(isPartialReductionChainValid(Chain, TTI)) 
+      if (isPartialReductionChainValid(Chain, TTI))
         PartialReductionChains.push_back(Chain);
     }
   }
-  
+
   CM.collectValuesToIgnore(this);
   CM.collectElementTypesForWidening();
 
@@ -8699,9 +8713,13 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
   return tryToWiden(Instr, Operands, VPBB);
 }
 
-VPRecipeBase *VPRecipeBuilder::tryToCreatePartialReduction(
-    VFRange &Range, PartialReductionChain &Chain, ArrayRef<VPValue *> Operands) {
-  return new VPPartialReductionRecipe(*Chain.Reduction, make_range(Operands.begin(), Operands.end()), Chain.ScaleFactor);
+VPRecipeBase *
+VPRecipeBuilder::tryToCreatePartialReduction(VFRange &Range,
+                                             PartialReductionChain &Chain,
+                                             ArrayRef<VPValue *> Operands) {
+  return new VPPartialReductionRecipe(
+      *Chain.Reduction, make_range(Operands.begin(), Operands.end()),
+      Chain.ScaleFactor);
 }
 
 void LoopVectorizationPlanner::buildVPlansWithVPRecipes(ElementCount MinVF,
@@ -9090,11 +9108,13 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
       VPRecipeBase *Recipe = nullptr;
 
       PartialReductionChain Chain;
-      if(getInstructionsPartialReduction(Instr, Chain)) 
-        Recipe = RecipeBuilder.tryToCreatePartialReduction(Range, Chain, Operands);
-      
+      if (getInstructionsPartialReduction(Instr, Chain))
+        Recipe =
+            RecipeBuilder.tryToCreatePartialReduction(Range, Chain, Operands);
+
       if (!Recipe)
-        Recipe = RecipeBuilder.tryToCreateWidenRecipe(Instr, Operands, Range, VPBB);
+        Recipe =
+            RecipeBuilder.tryToCreateWidenRecipe(Instr, Operands, Range, VPBB);
       if (!Recipe)
         Recipe = RecipeBuilder.handleReplication(Instr, Range);
 
@@ -9116,7 +9136,7 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
         VPBB->appendRecipe(Recipe);
     }
 
-    for(auto &Recipe : *VPBB)
+    for (auto &Recipe : *VPBB)
       Recipe.postInsertionOp();
 
     VPBlockUtils::insertBlockAfter(new VPBasicBlock(), VPBB);
diff --git a/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h b/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
index eecb5ff3b49646..5c15e6f16bd082 100644
--- a/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
+++ b/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
@@ -117,7 +117,9 @@ class VPRecipeBuilder {
                                        ArrayRef<VPValue *> Operands,
                                        VFRange &Range, VPBasicBlock *VPBB);
 
-  VPRecipeBase* tryToCreatePartialReduction(VFRange &Range, PartialReductionChain &Chain, ArrayRef<VPValue*> Operands);
+  VPRecipeBase *tryToCreatePartialReduction(VFRange &Range,
+                                            PartialReductionChain &Chain,
+                                            ArrayRef<VPValue *> Operands);
 
   /// Set the recipe created for given ingredient.
   void setRecipe(Instruction *I, VPRecipeBase *R) {
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 2b468455ea0a25..a6821b9d014613 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -1021,9 +1021,7 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
     FastMathFlagsTy(const FastMathFlags &FMF);
   };
 
-public:
   OperationType OpType;
-private:
 
   union {
     CmpInst::Predicate CmpPredicate;
@@ -2146,7 +2144,6 @@ class VPReductionPHIRecipe : public VPHeaderPHIRecipe,
   unsigned VFScaleFactor = 1;
 
 public:
-
   /// Create a new VPReductionPHIRecipe for the reduction \p Phi described by \p
   /// RdxDesc.
   VPReductionPHIRecipe(PHINode *Phi, const RecurrenceDescriptor &RdxDesc,
@@ -2161,9 +2158,9 @@ class VPReductionPHIRecipe : public VPHeaderPHIRecipe,
   ~VPReductionPHIRecipe() override = default;
 
   VPReductionPHIRecipe *clone() override {
-    auto *R =
-        new VPReductionPHIRecipe(cast<PHINode>(getUnderlyingInstr()), RdxDesc,
-                                 *getOperand(0), IsInLoop, IsOrdered, VFScaleFactor);
+    auto *R = new VPReductionPHIRecipe(cast<PHINode>(getUnderlyingInstr()),
+                                       RdxDesc, *getOperand(0), IsInLoop,
+                                       IsOrdered, VFScaleFactor);
     R->addOperand(getBackedgeValue());
     return R;
   }
@@ -2174,9 +2171,7 @@ class VPReductionPHIRecipe : public VPHeaderPHIRecipe,
     return R->getVPDefID() == VPDef::VPReductionPHISC;
   }
 
-  void SetVFScaleFactor(unsigned ScaleFactor) {
-    VFScaleFactor = ScaleFactor;
-  }
+  void SetVFScaleFactor(unsigned ScaleFactor) { VFScaleFactor = ScaleFactor; }
 
   /// Generate the phi/select nodes.
   void execute(VPTransformState &State) override;
@@ -2201,15 +2196,17 @@ class VPReductionPHIRecipe : public VPHeaderPHIRecipe,
 class VPPartialReductionRecipe : public VPRecipeWithIRFlags {
   unsigned Opcode;
   unsigned Scale;
+
 public:
   template <typename IterT>
-  VPPartialReductionRecipe(Instruction &I,
-                           iterator_range<IterT> Operands, unsigned Scale) : VPRecipeWithIRFlags(
-    VPDef::VPPartialReductionSC, Operands, I), Opcode(I.getOpcode()), Scale(Scale)
-  {}
+  VPPartialReductionRecipe(Instruction &I, iterator_range<IterT> Operands,
+                           unsigned Scale)
+      : VPRecipeWithIRFlags(VPDef::VPPartialReductionSC, Operands, I),
+        Opcode(I.getOpcode()), Scale(Scale) {}
   ~VPPartialReductionRecipe() override = default;
   VPPartialReductionRecipe *clone() override {
-    auto *R = new VPPartialReductionRecipe(*getUnderlyingInstr(), operands(), Scale);
+    auto *R =
+        new VPPartialReductionRecipe(*getUnderlyingInstr(), operands(), Scale);
     R->transferFlags(*this);
     return R;
   }
diff --git a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
index 8d6f51ca1b3b86..523bbb4ad44e85 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
@@ -233,7 +233,8 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPReplicateRecipe *R) {
   llvm_unreachable("Unhandled opcode");
 }
 
-Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPPartialReductionRecipe *R) {
+Type *
+VPTypeAnalysis::inferScalarTypeForRecipe(const VPPartialReductionRecipe *R) {
   return R->getUnderlyingInstr()->getType();
 }
 
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 1ba44e738abeb0..452290ebabb047 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -322,18 +322,18 @@ void VPPartialReductionRecipe::execute(VPTransformState &State) {
   State.setDebugLocFrom(getDebugLoc());
   auto &Builder = State.Builder;
 
-  switch(Opcode) {
+  switch (Opcode) {
   case Instruction::Add: {
 
     unsigned UF = getParent()->getPlan()->getUF();
     for (unsigned Part = 0; Part < UF; ++Part) {
-      Value* Mul = nullptr;
-      Value* Phi = nullptr;
-      SmallVector<Value*, 2> Ops;
+      Value *Mul = nullptr;
+      Value *Phi = nullptr;
+      SmallVector<Value *, 2> Ops;
       for (VPValue *VPOp : operands()) {
         auto *Op = State.get(VPOp, Part);
         Ops.push_back(Op);
-        if(isa<PHINode>(Op))
+        if (isa<PHINode>(Op))
           Phi = Op;
         else
           Mul = Op;
@@ -343,13 +343,13 @@ void VPPartialReductionRecipe::execute(VPTransformState &State) {
 
       VectorType *FullTy = cast<VectorType>(Ops[0]->getType());
       auto EC = FullTy->getElementCount();
-      Type *RetTy = VectorType::get(FullTy->getScalarType(), EC.divideCoefficientBy(Scale));
+      Type *RetTy = VectorType::get(FullTy->getScalarType(),
+                                    EC.divideCoefficientBy(Scale));
 
       Intrinsic::ID PartialIntrinsic = Intrinsic::not_intrinsic;
-      switch(Opcode) {
+      switch (Opcode) {
       case Instruction::Add:
-        PartialIntrinsic =
-            Intrinsic::experimental_vector_partial_reduce_add;
+        PartialIntrinsic = Intrinsic::experimental_vector_partial_reduce_add;
         break;
       default:
         llvm_unreachable("Opcode not handled");
@@ -357,7 +357,8 @@ void VPPartialReductionRecipe::execute(VPTransformState &State) {
 
       assert(PartialIntrinsic != Intrinsic::not_intrinsic);
 
-      CallInst *V = Builder.CreateIntrinsic(RetTy, PartialIntrinsic, {Phi, Mul}, nullptr, Twine("partial.reduce"));
+      CallInst *V = Builder.CreateIntrinsic(RetTy, PartialIntrinsic, {Phi, Mul},
+                                            nullptr, Twine("partial.reduce"));
 
       // Use this vector value for all users of the original instruction.
       State.set(this, V, Part);
@@ -366,7 +367,8 @@ void VPPartialReductionRecipe::execute(VPTransformState &State) {
     break;
   }
   default:
-    LLVM_DEBUG(dbgs() << "LV: Found an unhandled opcode : " << Instruction::getOpcodeName(Opcode));
+    LLVM_DEBUG(dbgs() << "LV: Found an unhandled opcode : "
+                      << Instruction::getOpcodeName(Opcode));
     llvm_unreachable("Unhandled instruction!");
   }
 }
@@ -377,7 +379,7 @@ void VPPartialReductionRecipe::postInsertionOp() {
 
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
 void VPPartialReductionRecipe::print(raw_ostream &O, const Twine &Indent,
-  VPSlotTracker &SlotTracker) const {
+                                     VPSlotTracker &SlotTracker) const {
   O << Indent << "PARTIAL-REDUCE ";
   printAsOperand(O, SlotTracker);
   O << " = " << Instruction::getOpcodeName(Opcode);
@@ -3040,8 +3042,8 @@ void VPReductionPHIRecipe::execute(VPTransformState &State) {
   // stage #1: We create a new vector PHI node with no incoming edges. We'll use
   // this value when we vectorize all of the instructions that use the PHI.
   bool ScalarPHI = VF.isScalar() || IsInLoop;
-  Type *VecTy = ScalarPHI ? StartV->getType()
-                          : VectorType::get(StartV->getType(), VF);
+  Type *VecTy =
+      ScalarPHI ? StartV->getType() : VectorType::get(StartV->getType(), VF);
 
   BasicBlock *HeaderBB = State.CFG.PrevBB;
   assert(State.CurrentVectorLoop->getHeader() == HeaderBB &&

>From 5fd6a334e7a66ff7718913294572285f3d76a51d Mon Sep 17 00:00:00 2001
From: Nicholas Guy <nicholas.guy at arm.com>
Date: Fri, 5 Jul 2024 13:36:04 +0100
Subject: [PATCH 11/54] Add TLI hook for delegating intrinsic lowering to the
 target

---
 llvm/include/llvm/CodeGen/TargetLowering.h | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 3842af56e6b3d7..98c74412875356 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -473,6 +473,12 @@ class TargetLoweringBase {
     return true;
   }
 
+  /// Return true if the @llvm.experimental.vector.partial.reduce.* intrinsic
+  /// should be expanded using generic code in SelectionDAGBuilder.
+  virtual bool shouldExpandPartialReductionIntrinsic(const CallInst *I) const {
+    return true;
+  }
+
   /// Return true if the @llvm.experimental.cttz.elts intrinsic should be
   /// expanded using generic code in SelectionDAGBuilder.
   virtual bool shouldExpandCttzElements(EVT VT) const { return true; }

>From 4df58e6ebebf7790cb8d3a4a85c7bdcdf7e573b0 Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Wed, 17 Jul 2024 11:15:45 +0100
Subject: [PATCH 12/54] Add to VPSingleDefRecipe::classof

---
 llvm/lib/Transforms/Vectorize/VPlan.h | 1 +
 1 file changed, 1 insertion(+)

diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index a6821b9d014613..876cf131dd1317 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -939,6 +939,7 @@ class VPSingleDefRecipe : public VPRecipeBase, public VPValue {
     case VPRecipeBase::VPWidenPointerInductionSC:
     case VPRecipeBase::VPReductionPHISC:
     case VPRecipeBase::VPScalarCastSC:
+    case VPRecipeBase::VPPartialReductionSC:
       return true;
     case VPRecipeBase::VPBranchOnMaskSC:
     case VPRecipeBase::VPInterleaveSC:

>From ebb33191c8092b6adf2b21594cfe5c2217adfa3f Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Wed, 17 Jul 2024 16:40:50 +0100
Subject: [PATCH 13/54] Move test and add negative tests

---
 .../partial-reduce-dot-product.ll             | 272 ++++++++++++++++++
 1 file changed, 272 insertions(+)
 create mode 100644 llvm/test/Transforms/LoopVectorize/partial-reduce-dot-product.ll

diff --git a/llvm/test/Transforms/LoopVectorize/partial-reduce-dot-product.ll b/llvm/test/Transforms/LoopVectorize/partial-reduce-dot-product.ll
new file mode 100644
index 00000000000000..949dfb5f8844b1
--- /dev/null
+++ b/llvm/test/Transforms/LoopVectorize/partial-reduce-dot-product.ll
@@ -0,0 +1,272 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
+; RUN: opt -passes=loop-vectorize -force-vector-interleave=1 -S < %s | FileCheck %s
+
+target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128"
+target triple = "aarch64-none-unknown-elf"
+
+define void @dotp(ptr %a, ptr %b) {
+; CHECK-LABEL: define void @dotp(
+; CHECK-SAME: ptr [[A:%.*]], ptr [[B:%.*]]) {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br i1 true, label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
+; CHECK:       vector.ph:
+; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
+; CHECK:       vector.body:
+; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi <16 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP14:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[TMP6:%.*]] = add i64 [[INDEX]], 0
+; CHECK-NEXT:    [[TMP7:%.*]] = getelementptr i8, ptr [[A]], i64 [[TMP6]]
+; CHECK-NEXT:    [[TMP8:%.*]] = getelementptr i8, ptr [[TMP7]], i32 0
+; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[TMP8]], align 1
+; CHECK-NEXT:    [[TMP3:%.*]] = zext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
+; CHECK-NEXT:    [[TMP10:%.*]] = getelementptr i8, ptr [[B]], i64 [[TMP6]]
+; CHECK-NEXT:    [[TMP11:%.*]] = getelementptr i8, ptr [[TMP10]], i32 0
+; CHECK-NEXT:    [[WIDE_LOAD1:%.*]] = load <16 x i8>, ptr [[TMP11]], align 1
+; CHECK-NEXT:    [[TMP15:%.*]] = zext <16 x i8> [[WIDE_LOAD1]] to <16 x i32>
+; CHECK-NEXT:    [[TMP16:%.*]] = mul <16 x i32> [[TMP15]], [[TMP3]]
+; CHECK-NEXT:    [[TMP14]] = add <16 x i32> [[TMP16]], [[VEC_PHI]]
+; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 16
+; CHECK-NEXT:    [[TMP9:%.*]] = icmp eq i64 [[INDEX_NEXT]], 0
+; CHECK-NEXT:    br i1 [[TMP9]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
+; CHECK:       middle.block:
+; CHECK-NEXT:    [[TMP17:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP14]])
+; CHECK-NEXT:    br i1 true, label [[FOR_COND_CLEANUP_LOOPEXIT:%.*]], label [[SCALAR_PH]]
+;
+entry:
+  br label %for.body
+
+for.cond.cleanup.loopexit:                        ; preds = %for.body
+  %0 = lshr i32 %add, 0
+  ret void
+
+for.body:                                         ; preds = %for.body, %entry
+  %indvars.iv = phi i64 [ 0, %entry ], [ %indvars.iv.next, %for.body ]
+  %acc.010 = phi i32 [ 0, %entry ], [ %add, %for.body ]
+  %arrayidx = getelementptr i8, ptr %a, i64 %indvars.iv
+  %1 = load i8, ptr %arrayidx, align 1
+  %conv = zext i8 %1 to i32
+  %arrayidx2 = getelementptr i8, ptr %b, i64 %indvars.iv
+  %2 = load i8, ptr %arrayidx2, align 1
+  %conv3 = zext i8 %2 to i32
+  %mul = mul i32 %conv3, %conv
+  %add = add i32 %mul, %acc.010
+  %indvars.iv.next = add i64 %indvars.iv, 1
+  %exitcond.not = icmp eq i64 %indvars.iv.next, 0
+  br i1 %exitcond.not, label %for.cond.cleanup.loopexit, label %for.body
+}
+
+define void @dotp_different_types(ptr %a, ptr %b) {
+; CHECK-LABEL: define void @dotp_different_types(
+; CHECK-SAME: ptr [[A:%.*]], ptr [[B:%.*]]) {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br i1 true, label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
+; CHECK:       vector.ph:
+; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
+; CHECK:       vector.body:
+; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi <16 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP69:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[TMP0:%.*]] = add i64 [[INDEX]], 0
+; CHECK-NEXT:    [[TMP1:%.*]] = add i64 [[INDEX]], 1
+; CHECK-NEXT:    [[TMP2:%.*]] = add i64 [[INDEX]], 2
+; CHECK-NEXT:    [[TMP3:%.*]] = add i64 [[INDEX]], 3
+; CHECK-NEXT:    [[TMP4:%.*]] = add i64 [[INDEX]], 4
+; CHECK-NEXT:    [[TMP5:%.*]] = add i64 [[INDEX]], 5
+; CHECK-NEXT:    [[TMP6:%.*]] = add i64 [[INDEX]], 6
+; CHECK-NEXT:    [[TMP7:%.*]] = add i64 [[INDEX]], 7
+; CHECK-NEXT:    [[TMP8:%.*]] = add i64 [[INDEX]], 8
+; CHECK-NEXT:    [[TMP9:%.*]] = add i64 [[INDEX]], 9
+; CHECK-NEXT:    [[TMP10:%.*]] = add i64 [[INDEX]], 10
+; CHECK-NEXT:    [[TMP11:%.*]] = add i64 [[INDEX]], 11
+; CHECK-NEXT:    [[TMP12:%.*]] = add i64 [[INDEX]], 12
+; CHECK-NEXT:    [[TMP13:%.*]] = add i64 [[INDEX]], 13
+; CHECK-NEXT:    [[TMP14:%.*]] = add i64 [[INDEX]], 14
+; CHECK-NEXT:    [[TMP15:%.*]] = add i64 [[INDEX]], 15
+; CHECK-NEXT:    [[TMP16:%.*]] = getelementptr i8, ptr [[A]], i64 [[TMP0]]
+; CHECK-NEXT:    [[TMP17:%.*]] = getelementptr i8, ptr [[TMP16]], i32 0
+; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[TMP17]], align 1
+; CHECK-NEXT:    [[TMP18:%.*]] = zext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
+; CHECK-NEXT:    [[TMP19:%.*]] = getelementptr i8, ptr [[B]], i64 [[TMP0]]
+; CHECK-NEXT:    [[TMP20:%.*]] = getelementptr i8, ptr [[B]], i64 [[TMP1]]
+; CHECK-NEXT:    [[TMP21:%.*]] = getelementptr i8, ptr [[B]], i64 [[TMP2]]
+; CHECK-NEXT:    [[TMP22:%.*]] = getelementptr i8, ptr [[B]], i64 [[TMP3]]
+; CHECK-NEXT:    [[TMP23:%.*]] = getelementptr i8, ptr [[B]], i64 [[TMP4]]
+; CHECK-NEXT:    [[TMP24:%.*]] = getelementptr i8, ptr [[B]], i64 [[TMP5]]
+; CHECK-NEXT:    [[TMP25:%.*]] = getelementptr i8, ptr [[B]], i64 [[TMP6]]
+; CHECK-NEXT:    [[TMP26:%.*]] = getelementptr i8, ptr [[B]], i64 [[TMP7]]
+; CHECK-NEXT:    [[TMP27:%.*]] = getelementptr i8, ptr [[B]], i64 [[TMP8]]
+; CHECK-NEXT:    [[TMP28:%.*]] = getelementptr i8, ptr [[B]], i64 [[TMP9]]
+; CHECK-NEXT:    [[TMP29:%.*]] = getelementptr i8, ptr [[B]], i64 [[TMP10]]
+; CHECK-NEXT:    [[TMP30:%.*]] = getelementptr i8, ptr [[B]], i64 [[TMP11]]
+; CHECK-NEXT:    [[TMP31:%.*]] = getelementptr i8, ptr [[B]], i64 [[TMP12]]
+; CHECK-NEXT:    [[TMP32:%.*]] = getelementptr i8, ptr [[B]], i64 [[TMP13]]
+; CHECK-NEXT:    [[TMP33:%.*]] = getelementptr i8, ptr [[B]], i64 [[TMP14]]
+; CHECK-NEXT:    [[TMP34:%.*]] = getelementptr i8, ptr [[B]], i64 [[TMP15]]
+; CHECK-NEXT:    [[TMP35:%.*]] = load i16, ptr [[TMP19]], align 2
+; CHECK-NEXT:    [[TMP36:%.*]] = load i16, ptr [[TMP20]], align 2
+; CHECK-NEXT:    [[TMP37:%.*]] = load i16, ptr [[TMP21]], align 2
+; CHECK-NEXT:    [[TMP38:%.*]] = load i16, ptr [[TMP22]], align 2
+; CHECK-NEXT:    [[TMP39:%.*]] = load i16, ptr [[TMP23]], align 2
+; CHECK-NEXT:    [[TMP40:%.*]] = load i16, ptr [[TMP24]], align 2
+; CHECK-NEXT:    [[TMP41:%.*]] = load i16, ptr [[TMP25]], align 2
+; CHECK-NEXT:    [[TMP42:%.*]] = load i16, ptr [[TMP26]], align 2
+; CHECK-NEXT:    [[TMP43:%.*]] = load i16, ptr [[TMP27]], align 2
+; CHECK-NEXT:    [[TMP44:%.*]] = load i16, ptr [[TMP28]], align 2
+; CHECK-NEXT:    [[TMP45:%.*]] = load i16, ptr [[TMP29]], align 2
+; CHECK-NEXT:    [[TMP46:%.*]] = load i16, ptr [[TMP30]], align 2
+; CHECK-NEXT:    [[TMP47:%.*]] = load i16, ptr [[TMP31]], align 2
+; CHECK-NEXT:    [[TMP48:%.*]] = load i16, ptr [[TMP32]], align 2
+; CHECK-NEXT:    [[TMP49:%.*]] = load i16, ptr [[TMP33]], align 2
+; CHECK-NEXT:    [[TMP50:%.*]] = load i16, ptr [[TMP34]], align 2
+; CHECK-NEXT:    [[TMP51:%.*]] = insertelement <16 x i16> poison, i16 [[TMP35]], i32 0
+; CHECK-NEXT:    [[TMP52:%.*]] = insertelement <16 x i16> [[TMP51]], i16 [[TMP36]], i32 1
+; CHECK-NEXT:    [[TMP53:%.*]] = insertelement <16 x i16> [[TMP52]], i16 [[TMP37]], i32 2
+; CHECK-NEXT:    [[TMP54:%.*]] = insertelement <16 x i16> [[TMP53]], i16 [[TMP38]], i32 3
+; CHECK-NEXT:    [[TMP55:%.*]] = insertelement <16 x i16> [[TMP54]], i16 [[TMP39]], i32 4
+; CHECK-NEXT:    [[TMP56:%.*]] = insertelement <16 x i16> [[TMP55]], i16 [[TMP40]], i32 5
+; CHECK-NEXT:    [[TMP57:%.*]] = insertelement <16 x i16> [[TMP56]], i16 [[TMP41]], i32 6
+; CHECK-NEXT:    [[TMP58:%.*]] = insertelement <16 x i16> [[TMP57]], i16 [[TMP42]], i32 7
+; CHECK-NEXT:    [[TMP59:%.*]] = insertelement <16 x i16> [[TMP58]], i16 [[TMP43]], i32 8
+; CHECK-NEXT:    [[TMP60:%.*]] = insertelement <16 x i16> [[TMP59]], i16 [[TMP44]], i32 9
+; CHECK-NEXT:    [[TMP61:%.*]] = insertelement <16 x i16> [[TMP60]], i16 [[TMP45]], i32 10
+; CHECK-NEXT:    [[TMP62:%.*]] = insertelement <16 x i16> [[TMP61]], i16 [[TMP46]], i32 11
+; CHECK-NEXT:    [[TMP63:%.*]] = insertelement <16 x i16> [[TMP62]], i16 [[TMP47]], i32 12
+; CHECK-NEXT:    [[TMP64:%.*]] = insertelement <16 x i16> [[TMP63]], i16 [[TMP48]], i32 13
+; CHECK-NEXT:    [[TMP65:%.*]] = insertelement <16 x i16> [[TMP64]], i16 [[TMP49]], i32 14
+; CHECK-NEXT:    [[TMP66:%.*]] = insertelement <16 x i16> [[TMP65]], i16 [[TMP50]], i32 15
+; CHECK-NEXT:    [[TMP67:%.*]] = zext <16 x i16> [[TMP66]] to <16 x i32>
+; CHECK-NEXT:    [[TMP68:%.*]] = mul <16 x i32> [[TMP67]], [[TMP18]]
+; CHECK-NEXT:    [[TMP69]] = add <16 x i32> [[TMP68]], [[VEC_PHI]]
+; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 16
+; CHECK-NEXT:    [[TMP70:%.*]] = icmp eq i64 [[INDEX_NEXT]], 0
+; CHECK-NEXT:    br i1 [[TMP70]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP4:![0-9]+]]
+; CHECK:       middle.block:
+; CHECK-NEXT:    [[TMP71:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP69]])
+; CHECK-NEXT:    br i1 true, label [[FOR_COND_CLEANUP_LOOPEXIT:%.*]], label [[SCALAR_PH]]
+;
+entry:
+  br label %for.body
+
+for.cond.cleanup.loopexit:                        ; preds = %for.body
+  %0 = lshr i32 %add, 0
+  ret void
+
+for.body:                                         ; preds = %for.body, %entry
+  %indvars.iv = phi i64 [ 0, %entry ], [ %indvars.iv.next, %for.body ]
+  %acc.010 = phi i32 [ 0, %entry ], [ %add, %for.body ]
+  %arrayidx = getelementptr i8, ptr %a, i64 %indvars.iv
+  %1 = load i8, ptr %arrayidx, align 1
+  %conv = zext i8 %1 to i32
+  %arrayidx2 = getelementptr i8, ptr %b, i64 %indvars.iv
+  %2 = load i16, ptr %arrayidx2, align 2
+  %conv3 = zext i16 %2 to i32
+  %mul = mul i32 %conv3, %conv
+  %add = add i32 %mul, %acc.010
+  %indvars.iv.next = add i64 %indvars.iv, 1
+  %exitcond.not = icmp eq i64 %indvars.iv.next, 0
+  br i1 %exitcond.not, label %for.cond.cleanup.loopexit, label %for.body
+}
+
+define void @dotp_not_loop_carried(ptr %a, ptr %b) {
+; CHECK-LABEL: define void @dotp_not_loop_carried(
+; CHECK-SAME: ptr [[A:%.*]], ptr [[B:%.*]]) {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br i1 true, label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
+; CHECK:       vector.ph:
+; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
+; CHECK:       vector.body:
+; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[VECTOR_RECUR:%.*]] = phi <16 x i32> [ <i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 0>, [[VECTOR_PH]] ], [ [[TMP7:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[TMP9:%.*]] = add i64 [[INDEX]], 0
+; CHECK-NEXT:    [[TMP10:%.*]] = getelementptr i8, ptr [[A]], i64 [[TMP9]]
+; CHECK-NEXT:    [[TMP11:%.*]] = getelementptr i8, ptr [[TMP10]], i32 0
+; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[TMP11]], align 1
+; CHECK-NEXT:    [[TMP3:%.*]] = zext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
+; CHECK-NEXT:    [[TMP13:%.*]] = getelementptr i8, ptr [[B]], i64 [[TMP9]]
+; CHECK-NEXT:    [[TMP14:%.*]] = getelementptr i8, ptr [[TMP13]], i32 0
+; CHECK-NEXT:    [[WIDE_LOAD1:%.*]] = load <16 x i8>, ptr [[TMP14]], align 1
+; CHECK-NEXT:    [[TMP6:%.*]] = zext <16 x i8> [[WIDE_LOAD1]] to <16 x i32>
+; CHECK-NEXT:    [[TMP7]] = mul <16 x i32> [[TMP6]], [[TMP3]]
+; CHECK-NEXT:    [[TMP8:%.*]] = shufflevector <16 x i32> [[VECTOR_RECUR]], <16 x i32> [[TMP7]], <16 x i32> <i32 15, i32 16, i32 17, i32 18, i32 19, i32 20, i32 21, i32 22, i32 23, i32 24, i32 25, i32 26, i32 27, i32 28, i32 29, i32 30>
+; CHECK-NEXT:    [[TMP15:%.*]] = add <16 x i32> [[TMP7]], [[TMP8]]
+; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 16
+; CHECK-NEXT:    [[TMP16:%.*]] = icmp eq i64 [[INDEX_NEXT]], 0
+; CHECK-NEXT:    br i1 [[TMP16]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP6:![0-9]+]]
+; CHECK:       middle.block:
+; CHECK-NEXT:    [[TMP17:%.*]] = extractelement <16 x i32> [[TMP15]], i32 15
+; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT:%.*]] = extractelement <16 x i32> [[TMP7]], i32 15
+; CHECK-NEXT:    br i1 true, label [[FOR_COND_CLEANUP_LOOPEXIT:%.*]], label [[SCALAR_PH]]
+;
+entry:
+  br label %for.body
+
+for.cond.cleanup.loopexit:                        ; preds = %for.body
+  %0 = lshr i32 %add, 0
+  ret void
+
+for.body:                                         ; preds = %for.body, %entry
+  %indvars.iv = phi i64 [ 0, %entry ], [ %indvars.iv.next, %for.body ]
+  %acc.010 = phi i32 [ 0, %entry ], [ %mul, %for.body ]
+  %arrayidx = getelementptr i8, ptr %a, i64 %indvars.iv
+  %1 = load i8, ptr %arrayidx, align 1
+  %conv = zext i8 %1 to i32
+  %arrayidx2 = getelementptr i8, ptr %b, i64 %indvars.iv
+  %2 = load i8, ptr %arrayidx2, align 1
+  %conv3 = zext i8 %2 to i32
+  %mul = mul i32 %conv3, %conv
+  %add = add i32 %mul, %acc.010
+  %indvars.iv.next = add i64 %indvars.iv, 1
+  %exitcond.not = icmp eq i64 %indvars.iv.next, 0
+  br i1 %exitcond.not, label %for.cond.cleanup.loopexit, label %for.body
+}
+
+define void @dotp_not_phi(ptr %a, ptr %b) {
+; CHECK-LABEL: define void @dotp_not_phi(
+; CHECK-SAME: ptr [[A:%.*]], ptr [[B:%.*]]) {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br i1 true, label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
+; CHECK:       vector.ph:
+; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
+; CHECK:       vector.body:
+; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[VECTOR_RECUR:%.*]] = phi <16 x i32> [ <i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 0>, [[VECTOR_PH]] ], [ [[TMP8:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[TMP9:%.*]] = add i64 [[INDEX]], 0
+; CHECK-NEXT:    [[TMP10:%.*]] = getelementptr i8, ptr [[A]], i64 [[TMP9]]
+; CHECK-NEXT:    [[TMP11:%.*]] = getelementptr i8, ptr [[TMP10]], i32 0
+; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[TMP11]], align 1
+; CHECK-NEXT:    [[TMP3:%.*]] = zext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
+; CHECK-NEXT:    [[TMP13:%.*]] = getelementptr i8, ptr [[B]], i64 [[TMP9]]
+; CHECK-NEXT:    [[TMP14:%.*]] = getelementptr i8, ptr [[TMP13]], i32 0
+; CHECK-NEXT:    [[WIDE_LOAD1:%.*]] = load <16 x i8>, ptr [[TMP14]], align 1
+; CHECK-NEXT:    [[TMP6:%.*]] = zext <16 x i8> [[WIDE_LOAD1]] to <16 x i32>
+; CHECK-NEXT:    [[TMP7:%.*]] = mul <16 x i32> [[TMP6]], [[TMP3]]
+; CHECK-NEXT:    [[TMP8]] = add <16 x i32> [[TMP7]], [[TMP6]]
+; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 16
+; CHECK-NEXT:    [[TMP15:%.*]] = icmp eq i64 [[INDEX_NEXT]], 0
+; CHECK-NEXT:    br i1 [[TMP15]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP8:![0-9]+]]
+; CHECK:       middle.block:
+; CHECK-NEXT:    [[TMP16:%.*]] = extractelement <16 x i32> [[TMP8]], i32 15
+; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT:%.*]] = extractelement <16 x i32> [[TMP8]], i32 15
+; CHECK-NEXT:    br i1 true, label [[FOR_COND_CLEANUP_LOOPEXIT:%.*]], label [[SCALAR_PH]]
+;
+entry:
+  br label %for.body
+
+for.cond.cleanup.loopexit:                        ; preds = %for.body
+  %0 = lshr i32 %add, 0
+  ret void
+
+for.body:                                         ; preds = %for.body, %entry
+  %indvars.iv = phi i64 [ 0, %entry ], [ %indvars.iv.next, %for.body ]
+  %acc.010 = phi i32 [ 0, %entry ], [ %add, %for.body ]
+  %arrayidx = getelementptr i8, ptr %a, i64 %indvars.iv
+  %1 = load i8, ptr %arrayidx, align 1
+  %conv = zext i8 %1 to i32
+  %arrayidx2 = getelementptr i8, ptr %b, i64 %indvars.iv
+  %2 = load i8, ptr %arrayidx2, align 1
+  %conv3 = zext i8 %2 to i32
+  %mul = mul i32 %conv3, %conv
+  %add = add i32 %mul, %conv3
+  %indvars.iv.next = add i64 %indvars.iv, 1
+  %exitcond.not = icmp eq i64 %indvars.iv.next, 0
+  br i1 %exitcond.not, label %for.cond.cleanup.loopexit, label %for.body
+}

>From 44f504208d8cd048897601c810b6afe547b2ab91 Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Thu, 18 Jul 2024 10:41:13 +0100
Subject: [PATCH 14/54] Add a printing test

---
 .../LoopVectorize/AArch64/vplan-printing.ll   | 102 ++++++++++++++++++
 1 file changed, 102 insertions(+)
 create mode 100644 llvm/test/Transforms/LoopVectorize/AArch64/vplan-printing.ll

diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/vplan-printing.ll b/llvm/test/Transforms/LoopVectorize/AArch64/vplan-printing.ll
new file mode 100644
index 00000000000000..f5d178cf085ece
--- /dev/null
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/vplan-printing.ll
@@ -0,0 +1,102 @@
+; REQUIRES: asserts
+
+; RUN: opt -passes=loop-vectorize -debug-only=loop-vectorize -force-vector-interleave=1 -disable-output %s 2>&1 | FileCheck %s
+
+target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128"
+target triple = "aarch64-none-unknown-elf"
+
+; Tests for printing VPlans that are enabled under AArch64
+
+define void @print_partial_reduction(ptr %a, ptr %b) {
+; CHECK-LABEL: Checking a loop in 'print_partial_reduction'
+; CHECK:      VPlan 'Initial VPlan for VF={2,4,8,16},UF>=1' {
+; CHECK-NEXT: Live-in vp<[[VFxUF:%.]]> = VF * UF
+; CHECK-NEXT: Live-in vp<[[VEC_TC:%.+]]> = vector-trip-count
+; CHECK-NEXT: Live-in ir<0> = original trip-count
+; CHECK-EMPTY:
+; CHECK-NEXT: vector.ph:
+; CHECK-NEXT: Successor(s): vector loop
+; CHECK-EMPTY:
+; CHECK-NEXT: <x1> vector loop: {
+; CHECK-NEXT: vector.body:
+; CHECK-NEXT:   EMIT vp<[[CAN_IV:%.+]]> = CANONICAL-INDUCTION ir<0>, vp<[[CAN_IV_NEXT:%.+]]>
+; CHECK-NEXT:   WIDEN-REDUCTION-PHI ir<[[ACC:%.+]]> = phi ir<0>, ir<%add>
+; CHECK-NEXT:   vp<[[STEPS:%.+]]> = SCALAR-STEPS vp<[[CAN_IV]]>, ir<1>
+; CHECK-NEXT:   CLONE ir<%arrayidx> = getelementptr ir<%a>, vp<[[STEPS]]>
+; CHECK-NEXT:   vp<%4> = vector-pointer ir<%arrayidx>
+; CHECK-NEXT:   WIDEN ir<%1> = load vp<%4>
+; CHECK-NEXT:   WIDEN-CAST ir<%conv> = zext ir<%1> to i32
+; CHECK-NEXT:   CLONE ir<%arrayidx2> = getelementptr ir<%b>, vp<[[STEPS]]>
+; CHECK-NEXT:   vp<%5> = vector-pointer ir<%arrayidx2>
+; CHECK-NEXT:   WIDEN ir<%2> = load vp<%5>
+; CHECK-NEXT:   WIDEN-CAST ir<%conv3> = zext ir<%2> to i32
+; CHECK-NEXT:   WIDEN ir<%mul> = mul ir<%conv3>, ir<%conv>
+; CHECK-NEXT:   WIDEN ir<%add> = add ir<%mul>, ir<[[ACC]]>
+; CHECK-NEXT:   EMIT vp<[[CAN_IV_NEXT]]> = add nuw vp<[[CAN_IV]]>, vp<[[VFxUF]]>
+; CHECK-NEXT:   EMIT branch-on-count vp<[[CAN_IV_NEXT]]>, vp<[[VEC_TC]]>
+; CHECK-NEXT: No successors
+; CHECK-NEXT: }
+; CHECK-NEXT: Successor(s): middle.block
+; CHECK-EMPTY:
+; CHECK-NEXT: middle.block:
+; CHECK-NEXT:   EMIT vp<%8> = compute-reduction-result ir<[[ACC]]>, ir<%add>
+; CHECK-NEXT:   EMIT vp<%9> = icmp eq ir<0>, vp<%1>
+; CHECK-NEXT:   EMIT branch-on-cond vp<%9>
+; CHECK-NEXT: Successor(s): ir-bb<for.cond.cleanup.loopexit>, scalar.ph
+; CHECK-EMPTY:
+; CHECK-NEXT: ir-bb<for.cond.cleanup.loopexit>:
+; CHECK-NEXT: No successors
+; CHECK-EMPTY:
+; CHECK-NEXT: scalar.ph:
+; CHECK-NEXT: No successors
+; CHECK-EMPTY:
+; CHECK-NEXT: Live-out i32 %add.lcssa = vp<%8>
+; CHECK-NEXT: }
+;
+entry:
+  br label %for.body
+
+for.cond.cleanup.loopexit:                        ; preds = %for.body
+  %0 = lshr i32 %add, 0
+  ret void
+
+for.body:                                         ; preds = %for.body, %entry
+  %indvars.iv = phi i64 [ 0, %entry ], [ %indvars.iv.next, %for.body ]
+  %acc.010 = phi i32 [ 0, %entry ], [ %add, %for.body ]
+  %arrayidx = getelementptr i8, ptr %a, i64 %indvars.iv
+  %1 = load i8, ptr %arrayidx, align 1
+  %conv = zext i8 %1 to i32
+  %arrayidx2 = getelementptr i8, ptr %b, i64 %indvars.iv
+  %2 = load i8, ptr %arrayidx2, align 1
+  %conv3 = zext i8 %2 to i32
+  %mul = mul i32 %conv3, %conv
+  %add = add i32 %mul, %acc.010
+  %indvars.iv.next = add i64 %indvars.iv, 1
+  %exitcond.not = icmp eq i64 %indvars.iv.next, 0
+  br i1 %exitcond.not, label %for.cond.cleanup.loopexit, label %for.body
+}
+
+!llvm.dbg.cu = !{!0}
+!llvm.module.flags = !{!3, !4}
+
+declare float @foo(float) #0
+declare <2 x float> @vector_foo(<2 x float>, <2 x i1>)
+
+; We need a vector variant in order to allow for vectorization at present, but
+; we want to test scalarization of conditional calls. If we provide a variant
+; with a different number of lanes than the VF we force via
+; "-force-vector-width=4", then it should pass the legality checks but
+; scalarize. TODO: Remove the requirement to have a variant.
+attributes #0 = { readonly nounwind "vector-function-abi-variant"="_ZGV_LLVM_M2v_foo(vector_foo)" }
+
+!0 = distinct !DICompileUnit(language: DW_LANG_C99, file: !1, producer: "clang", isOptimized: true, runtimeVersion: 0, emissionKind: NoDebug, enums: !2)
+!1 = !DIFile(filename: "/tmp/s.c", directory: "/tmp")
+!2 = !{}
+!3 = !{i32 2, !"Debug Info Version", i32 3}
+!4 = !{i32 7, !"PIC Level", i32 2}
+!5 = distinct !DISubprogram(name: "f", scope: !1, file: !1, line: 4, type: !6, scopeLine: 4, flags: DIFlagPrototyped, spFlags: DISPFlagDefinition | DISPFlagOptimized, unit: !0, retainedNodes: !2)
+!6 = !DISubroutineType(types: !2)
+!7 = !DILocation(line: 5, column: 3, scope: !5)
+!8 = !DILocation(line: 5, column: 21, scope: !5)
+;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
+; CHECK: {{.*}}

>From 42e16fa7ce5093639a9b2051708ac0ac80773db0 Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Thu, 18 Jul 2024 11:24:51 +0100
Subject: [PATCH 15/54] Add llvm_unreachable in partial reduction clone func

---
 llvm/lib/Transforms/Vectorize/VPlan.h | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 876cf131dd1317..c2675477933271 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -2206,10 +2206,8 @@ class VPPartialReductionRecipe : public VPRecipeWithIRFlags {
         Opcode(I.getOpcode()), Scale(Scale) {}
   ~VPPartialReductionRecipe() override = default;
   VPPartialReductionRecipe *clone() override {
-    auto *R =
-        new VPPartialReductionRecipe(*getUnderlyingInstr(), operands(), Scale);
-    R->transferFlags(*this);
-    return R;
+    llvm_unreachable("Partial reductions with epilogue vectorization isn't supported yet.");
+    return nullptr;
   }
   VP_CLASSOF_IMPL(VPDef::VPPartialReductionSC)
   /// Generate the reduction in the loop

>From 4e33507d84ca95bc13df4865a252ab4f28469b3a Mon Sep 17 00:00:00 2001
From: Sam Tebbs <samuel.tebbs at arm.com>
Date: Mon, 29 Jul 2024 09:50:32 +0100
Subject: [PATCH 16/54] Remove target hook

---
 llvm/include/llvm/CodeGen/TargetLowering.h | 6 ------
 1 file changed, 6 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 98c74412875356..3842af56e6b3d7 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -473,12 +473,6 @@ class TargetLoweringBase {
     return true;
   }
 
-  /// Return true if the @llvm.experimental.vector.partial.reduce.* intrinsic
-  /// should be expanded using generic code in SelectionDAGBuilder.
-  virtual bool shouldExpandPartialReductionIntrinsic(const CallInst *I) const {
-    return true;
-  }
-
   /// Return true if the @llvm.experimental.cttz.elts intrinsic should be
   /// expanded using generic code in SelectionDAGBuilder.
   virtual bool shouldExpandCttzElements(EVT VT) const { return true; }

>From 1b0b39d78dc00c1ca087f42c12f94bd232576e73 Mon Sep 17 00:00:00 2001
From: Sam Tebbs <samuel.tebbs at arm.com>
Date: Fri, 19 Jul 2024 16:30:00 +0100
Subject: [PATCH 17/54] Add doc for postInsertionOp

---
 llvm/lib/Transforms/Vectorize/VPlan.h | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index c2675477933271..e74440f277b601 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -831,6 +831,8 @@ class VPRecipeBase : public ilist_node_with_parent<VPRecipeBase, VPBasicBlock>,
   /// \returns an iterator pointing to the element after the erased one
   iplist<VPRecipeBase>::iterator eraseFromParent();
 
+  /// Run any required modifications to the recipe after it has been inserted
+  /// into the plan.
   virtual void postInsertionOp() {}
 
   /// Method to support type inquiry through isa, cast, and dyn_cast.

>From be6dd675a1ad59cd19eb701d25fc1eea279c37ae Mon Sep 17 00:00:00 2001
From: Sam Tebbs <samuel.tebbs at arm.com>
Date: Fri, 19 Jul 2024 17:00:12 +0100
Subject: [PATCH 18/54] Add doc for PartialReductionChain

---
 .../Transforms/Vectorize/LoopVectorizationPlanner.h  | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
index 1237cf385a3ebe..36438553ddcf99 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
@@ -337,16 +337,28 @@ struct FixedScalableVFPair {
   bool hasVector() const { return FixedVF.isVector() || ScalableVF.isVector(); }
 };
 
+/// A chain of instructions that form a partial reduction.
+/// Designed to match: reduction_bin_op (bin_op (extend (A), extend (B)),
+/// accumulator)
 struct PartialReductionChain {
+  /// The top-level binary operation that forms the reduction to a scalar after
+  /// the loop body
   Instruction *Reduction;
+  /// The inner binary operation that forms the reduction to a vector value
+  /// within the loop body
   Instruction *BinOp;
+  /// The extension of each of the inner binary operation's operands
   Instruction *ExtendA;
   Instruction *ExtendB;
 
+  /// The inner binary operation's operands
   Value *InputA;
   Value *InputB;
+  /// The accumulator that is reduced to a scalar after the loop body
   Value *Accumulator;
 
+  /// The scaling factor between the size of the reduction type and the
+  /// (possibly extended) inputs
   unsigned ScaleFactor;
 };
 

>From beea00ef7f8ab3d594c8c7b5ede1296be654c640 Mon Sep 17 00:00:00 2001
From: Sam Tebbs <samuel.tebbs at arm.com>
Date: Fri, 19 Jul 2024 17:00:27 +0100
Subject: [PATCH 19/54] Add doc for VPPartialReductionRecipe

---
 llvm/lib/Transforms/Vectorize/VPlan.h | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index e74440f277b601..373bc858b8b824 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -2196,6 +2196,10 @@ class VPReductionPHIRecipe : public VPHeaderPHIRecipe,
   bool isInLoop() const { return IsInLoop; }
 };
 
+/// A recipe for forming partial reductions. In the loop, an accumulator and
+/// vector operand are added together and passed to the next iteration as the
+/// next accumulator. After the loop body, the accumulator is reduced to a
+/// scalar value.
 class VPPartialReductionRecipe : public VPRecipeWithIRFlags {
   unsigned Opcode;
   unsigned Scale;

>From 5e256098815538780b7d92f7ce4e193c9f2677e9 Mon Sep 17 00:00:00 2001
From: Sam Tebbs <samuel.tebbs at arm.com>
Date: Fri, 19 Jul 2024 17:12:03 +0100
Subject: [PATCH 20/54] Add doc for VFScaleFactor

---
 llvm/lib/Transforms/Vectorize/VPlan.h | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 373bc858b8b824..632dfe8f08a64c 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -2143,7 +2143,10 @@ class VPReductionPHIRecipe : public VPHeaderPHIRecipe,
   /// The phi is part of an ordered reduction. Requires IsInLoop to be true.
   bool IsOrdered;
 
-  /// The amount that the VF should be divided by during ::execute
+  /// The scaling difference between the size of the output of the entire
+  /// reduction and the size of the inputs. When expanding the reduction PHI,
+  /// the plan's VF element count is divided by this factor to form the
+  /// reduction phi's VF.
   unsigned VFScaleFactor = 1;
 
 public:

>From bafacfbb729df053127bc3af0fff7a1dae7581b0 Mon Sep 17 00:00:00 2001
From: Sam Tebbs <samuel.tebbs at arm.com>
Date: Fri, 19 Jul 2024 17:15:29 +0100
Subject: [PATCH 21/54] Format

---
 llvm/lib/Transforms/Vectorize/VPlan.h | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 632dfe8f08a64c..cb0b7c822d6ff6 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -2215,7 +2215,8 @@ class VPPartialReductionRecipe : public VPRecipeWithIRFlags {
         Opcode(I.getOpcode()), Scale(Scale) {}
   ~VPPartialReductionRecipe() override = default;
   VPPartialReductionRecipe *clone() override {
-    llvm_unreachable("Partial reductions with epilogue vectorization isn't supported yet.");
+    llvm_unreachable(
+        "Partial reductions with epilogue vectorization isn't supported yet.");
     return nullptr;
   }
   VP_CLASSOF_IMPL(VPDef::VPPartialReductionSC)

>From a6d6f0b9cb4179f13b872a38a5f58f2e7925a9cb Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Tue, 3 Sep 2024 15:32:15 +0100
Subject: [PATCH 22/54] Rebase

---
 .../Transforms/LoopVectorize/AArch64/vplan-printing.ll     | 7 ++++---
 .../Transforms/LoopVectorize/partial-reduce-dot-product.ll | 2 +-
 2 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/vplan-printing.ll b/llvm/test/Transforms/LoopVectorize/AArch64/vplan-printing.ll
index f5d178cf085ece..96d51a55b37098 100644
--- a/llvm/test/Transforms/LoopVectorize/AArch64/vplan-printing.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/vplan-printing.ll
@@ -40,8 +40,9 @@ define void @print_partial_reduction(ptr %a, ptr %b) {
 ; CHECK-EMPTY:
 ; CHECK-NEXT: middle.block:
 ; CHECK-NEXT:   EMIT vp<%8> = compute-reduction-result ir<[[ACC]]>, ir<%add>
-; CHECK-NEXT:   EMIT vp<%9> = icmp eq ir<0>, vp<%1>
-; CHECK-NEXT:   EMIT branch-on-cond vp<%9>
+; CHECK-NEXT:   EMIT vp<%9> = extract-from-end vp<%8>, ir<1>
+; CHECK-NEXT:   EMIT vp<%10> = icmp eq ir<0>, vp<%1>
+; CHECK-NEXT:   EMIT branch-on-cond vp<%10>
 ; CHECK-NEXT: Successor(s): ir-bb<for.cond.cleanup.loopexit>, scalar.ph
 ; CHECK-EMPTY:
 ; CHECK-NEXT: ir-bb<for.cond.cleanup.loopexit>:
@@ -50,7 +51,7 @@ define void @print_partial_reduction(ptr %a, ptr %b) {
 ; CHECK-NEXT: scalar.ph:
 ; CHECK-NEXT: No successors
 ; CHECK-EMPTY:
-; CHECK-NEXT: Live-out i32 %add.lcssa = vp<%8>
+; CHECK-NEXT: Live-out i32 %add.lcssa = vp<%9>
 ; CHECK-NEXT: }
 ;
 entry:
diff --git a/llvm/test/Transforms/LoopVectorize/partial-reduce-dot-product.ll b/llvm/test/Transforms/LoopVectorize/partial-reduce-dot-product.ll
index 949dfb5f8844b1..d194893eefaf03 100644
--- a/llvm/test/Transforms/LoopVectorize/partial-reduce-dot-product.ll
+++ b/llvm/test/Transforms/LoopVectorize/partial-reduce-dot-product.ll
@@ -192,8 +192,8 @@ define void @dotp_not_loop_carried(ptr %a, ptr %b) {
 ; CHECK-NEXT:    [[TMP16:%.*]] = icmp eq i64 [[INDEX_NEXT]], 0
 ; CHECK-NEXT:    br i1 [[TMP16]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP6:![0-9]+]]
 ; CHECK:       middle.block:
-; CHECK-NEXT:    [[TMP17:%.*]] = extractelement <16 x i32> [[TMP15]], i32 15
 ; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT:%.*]] = extractelement <16 x i32> [[TMP7]], i32 15
+; CHECK-NEXT:    [[TMP17:%.*]] = extractelement <16 x i32> [[TMP15]], i32 15
 ; CHECK-NEXT:    br i1 true, label [[FOR_COND_CLEANUP_LOOPEXIT:%.*]], label [[SCALAR_PH]]
 ;
 entry:

>From a57ac1477e4aaff346d62807e2e63048be61d497 Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Wed, 4 Sep 2024 16:50:57 +0100
Subject: [PATCH 23/54] Enable partial reductions for AArch64

---
 .../AArch64/AArch64TargetTransformInfo.h      |  9 ++++
 .../partial-reduce-dot-product.ll             | 41 ++++++++++++-------
 2 files changed, 35 insertions(+), 15 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index 22bba21eedcc5a..38855fcef6b78e 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -340,6 +340,15 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
     return BaseT::isLegalNTLoad(DataType, Alignment);
   }
 
+  bool isPartialReductionSupported(const Instruction *ReductionInstr,
+                                   Type *InputType, unsigned ScaleFactor,
+                                   bool IsInputASignExtended,
+                                   bool IsInputBSignExtended,
+                                   const Instruction *BinOp) const {
+    return ScaleFactor == 4 && (ST->isSVEorStreamingSVEAvailable() ||
+                                (ST->isNeonAvailable() && ST->hasDotProd()));
+  }
+
   bool enableOrderedReductions() const { return true; }
 
   InstructionCost getInterleavedMemoryOpCost(
diff --git a/llvm/test/Transforms/LoopVectorize/partial-reduce-dot-product.ll b/llvm/test/Transforms/LoopVectorize/partial-reduce-dot-product.ll
index d194893eefaf03..76bec75a8e3580 100644
--- a/llvm/test/Transforms/LoopVectorize/partial-reduce-dot-product.ll
+++ b/llvm/test/Transforms/LoopVectorize/partial-reduce-dot-product.ll
@@ -4,33 +4,43 @@
 target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128"
 target triple = "aarch64-none-unknown-elf"
 
-define void @dotp(ptr %a, ptr %b) {
+define void @dotp(ptr %a, ptr %b) #0 {
 ; CHECK-LABEL: define void @dotp(
-; CHECK-SAME: ptr [[A:%.*]], ptr [[B:%.*]]) {
+; CHECK-SAME: ptr [[A:%.*]], ptr [[B:%.*]]) #[[ATTR0:[0-9]+]] {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    br i1 true, label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
+; CHECK-NEXT:    [[TMP0:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP1:%.*]] = mul i64 [[TMP0]], 16
+; CHECK-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 0, [[TMP1]]
+; CHECK-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
 ; CHECK:       vector.ph:
+; CHECK-NEXT:    [[TMP2:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP3:%.*]] = mul i64 [[TMP2]], 16
+; CHECK-NEXT:    [[N_MOD_VF:%.*]] = urem i64 0, [[TMP3]]
+; CHECK-NEXT:    [[N_VEC:%.*]] = sub i64 0, [[N_MOD_VF]]
+; CHECK-NEXT:    [[TMP4:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP5:%.*]] = mul i64 [[TMP4]], 16
 ; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
 ; CHECK:       vector.body:
 ; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
-; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi <16 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP14:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi <vscale x 4 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[PARTIAL_REDUCE:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[TMP6:%.*]] = add i64 [[INDEX]], 0
 ; CHECK-NEXT:    [[TMP7:%.*]] = getelementptr i8, ptr [[A]], i64 [[TMP6]]
 ; CHECK-NEXT:    [[TMP8:%.*]] = getelementptr i8, ptr [[TMP7]], i32 0
-; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[TMP8]], align 1
-; CHECK-NEXT:    [[TMP3:%.*]] = zext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
+; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <vscale x 16 x i8>, ptr [[TMP8]], align 1
+; CHECK-NEXT:    [[TMP9:%.*]] = zext <vscale x 16 x i8> [[WIDE_LOAD]] to <vscale x 16 x i32>
 ; CHECK-NEXT:    [[TMP10:%.*]] = getelementptr i8, ptr [[B]], i64 [[TMP6]]
 ; CHECK-NEXT:    [[TMP11:%.*]] = getelementptr i8, ptr [[TMP10]], i32 0
-; CHECK-NEXT:    [[WIDE_LOAD1:%.*]] = load <16 x i8>, ptr [[TMP11]], align 1
-; CHECK-NEXT:    [[TMP15:%.*]] = zext <16 x i8> [[WIDE_LOAD1]] to <16 x i32>
-; CHECK-NEXT:    [[TMP16:%.*]] = mul <16 x i32> [[TMP15]], [[TMP3]]
-; CHECK-NEXT:    [[TMP14]] = add <16 x i32> [[TMP16]], [[VEC_PHI]]
-; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 16
-; CHECK-NEXT:    [[TMP9:%.*]] = icmp eq i64 [[INDEX_NEXT]], 0
-; CHECK-NEXT:    br i1 [[TMP9]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
+; CHECK-NEXT:    [[WIDE_LOAD1:%.*]] = load <vscale x 16 x i8>, ptr [[TMP11]], align 1
+; CHECK-NEXT:    [[TMP15:%.*]] = zext <vscale x 16 x i8> [[WIDE_LOAD1]] to <vscale x 16 x i32>
+; CHECK-NEXT:    [[TMP16:%.*]] = mul <vscale x 16 x i32> [[TMP15]], [[TMP9]]
+; CHECK-NEXT:    [[PARTIAL_REDUCE]] = call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> [[VEC_PHI]], <vscale x 16 x i32> [[TMP16]])
+; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP5]]
+; CHECK-NEXT:    [[TMP14:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
+; CHECK-NEXT:    br i1 [[TMP14]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
 ; CHECK:       middle.block:
-; CHECK-NEXT:    [[TMP17:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP14]])
-; CHECK-NEXT:    br i1 true, label [[FOR_COND_CLEANUP_LOOPEXIT:%.*]], label [[SCALAR_PH]]
+; CHECK-NEXT:    [[TMP17:%.*]] = call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> [[PARTIAL_REDUCE]])
+; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i64 0, [[N_VEC]]
+; CHECK-NEXT:    br i1 [[CMP_N]], label [[FOR_COND_CLEANUP_LOOPEXIT:%.*]], label [[SCALAR_PH]]
 ;
 entry:
   br label %for.body
@@ -270,3 +280,4 @@ for.body:                                         ; preds = %for.body, %entry
   %exitcond.not = icmp eq i64 %indvars.iv.next, 0
   br i1 %exitcond.not, label %for.cond.cleanup.loopexit, label %for.body
 }
+attributes #0 = { nofree norecurse nosync nounwind memory(argmem: readwrite) uwtable vscale_range(1,16) "target-features"="+sve" }

>From ed7d0ee789592d89d08d516a7a054ea465272106 Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Wed, 4 Sep 2024 16:51:35 +0100
Subject: [PATCH 24/54] Move test to AArch64 subdir

---
 .../LoopVectorize/{ => AArch64}/partial-reduce-dot-product.ll     | 0
 1 file changed, 0 insertions(+), 0 deletions(-)
 rename llvm/test/Transforms/LoopVectorize/{ => AArch64}/partial-reduce-dot-product.ll (100%)

diff --git a/llvm/test/Transforms/LoopVectorize/partial-reduce-dot-product.ll b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product.ll
similarity index 100%
rename from llvm/test/Transforms/LoopVectorize/partial-reduce-dot-product.ll
rename to llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product.ll

>From 3e36818aa65b1feeaa9efdb9411fd850170cc4c2 Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Tue, 10 Sep 2024 10:58:23 +0100
Subject: [PATCH 25/54] Add PartialReductionExtendKind

---
 llvm/include/llvm/Analysis/TargetTransformInfo.h | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 6e31686483cc11..78559161625078 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -24,6 +24,7 @@
 #include "llvm/ADT/APInt.h"
 #include "llvm/IR/FMF.h"
 #include "llvm/IR/InstrTypes.h"
+#include "llvm/IR/Instructions.h"
 #include "llvm/IR/PassManager.h"
 #include "llvm/Pass.h"
 #include "llvm/Support/AtomicOrdering.h"
@@ -91,6 +92,8 @@ struct MemIntrinsicInfo {
   }
 };
 
+enum PartialReductionExtendKind { PR_None, PR_SignExtend, PR_ZeroExtend };
+
 /// Attributes of a target dependent hardware loop.
 struct HardwareLoopInfo {
   HardwareLoopInfo() = delete;
@@ -210,6 +213,15 @@ typedef TargetTransformInfo TTI;
 /// for IR-level transformations.
 class TargetTransformInfo {
 public:
+  static PartialReductionExtendKind
+  getPartialReductionExtendKind(Instruction *I) {
+    if (isa<SExtInst>(I))
+      return PR_SignExtend;
+    if (isa<ZExtInst>(I))
+      return PR_ZeroExtend;
+    return PR_None;
+  }
+
   /// Construct a TTI object using a type implementing the \c Concept
   /// API below.
   ///

>From 8fd2748fd32e0173ebc45a4ded002c904ffc6c6b Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Tue, 10 Sep 2024 11:00:44 +0100
Subject: [PATCH 26/54] isPartialReductionSupported -> getPartialReductionCost

---
 .../llvm/Analysis/TargetTransformInfo.h       |  33 +--
 .../llvm/Analysis/TargetTransformInfoImpl.h   |  12 +-
 llvm/lib/Analysis/TargetTransformInfo.cpp     |  13 +-
 .../AArch64/AArch64TargetTransformInfo.h      |  42 +++-
 .../Vectorize/LoopVectorizationPlanner.h      |   9 -
 .../Transforms/Vectorize/LoopVectorize.cpp    | 228 ++++++++----------
 6 files changed, 170 insertions(+), 167 deletions(-)

diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 78559161625078..50004a23059845 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1260,11 +1260,11 @@ class TargetTransformInfo {
   /// \return if target want to issue a prefetch in address space \p AS.
   bool shouldPrefetchAddressSpace(unsigned AS) const;
 
-  bool isPartialReductionSupported(const Instruction *ReductionInstr,
-                                   Type *InputType, unsigned ScaleFactor,
-                                   bool IsInputASignExtended,
-                                   bool IsInputBSignExtended,
-                                   const Instruction *BinOp = nullptr) const;
+  InstructionCost
+  getPartialReductionCost(unsigned Opcode, Type *InputType, Type *AccumType,
+                          ElementCount VF, PartialReductionExtendKind OpAExtend,
+                          PartialReductionExtendKind OpBExtend,
+                          std::optional<unsigned> BinOp = std::nullopt) const;
 
   /// \return The maximum interleave factor that any transform should try to
   /// perform for this target. This number depends on the level of parallelism
@@ -2043,10 +2043,11 @@ class TargetTransformInfo::Concept {
   /// \return if target want to issue a prefetch in address space \p AS.
   virtual bool shouldPrefetchAddressSpace(unsigned AS) const = 0;
 
-  virtual bool isPartialReductionSupported(
-      const Instruction *ReductionInstr, Type *InputType, unsigned ScaleFactor,
-      bool IsInputASignExtended, bool IsInputBSignExtended,
-      const Instruction *BinOp = nullptr) const = 0;
+  virtual InstructionCost
+  getPartialReductionCost(unsigned Opcode, Type *InputType, Type *AccumType,
+                          ElementCount VF, PartialReductionExtendKind OpAExtend,
+                          PartialReductionExtendKind OpBExtend,
+                          std::optional<unsigned> BinOp) const = 0;
 
   virtual unsigned getMaxInterleaveFactor(ElementCount VF) = 0;
   virtual InstructionCost getArithmeticInstrCost(
@@ -2682,13 +2683,13 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
     return Impl.shouldPrefetchAddressSpace(AS);
   }
 
-  bool isPartialReductionSupported(
-      const Instruction *ReductionInstr, Type *InputType, unsigned ScaleFactor,
-      bool IsInputASignExtended, bool IsInputBSignExtended,
-      const Instruction *BinOp = nullptr) const override {
-    return Impl.isPartialReductionSupported(ReductionInstr, InputType,
-                                            ScaleFactor, IsInputASignExtended,
-                                            IsInputBSignExtended, BinOp);
+  InstructionCost getPartialReductionCost(
+      unsigned Opcode, Type *InputType, Type *AccumType, ElementCount VF,
+      PartialReductionExtendKind OpAExtend,
+      PartialReductionExtendKind OpBExtend,
+      std::optional<unsigned> BinOp = std::nullopt) const override {
+    return Impl.getPartialReductionCost(Opcode, InputType, AccumType, VF,
+                                        OpAExtend, OpBExtend, BinOp);
   }
 
   unsigned getMaxInterleaveFactor(ElementCount VF) override {
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index ad72bcddd242e6..02a499e538fcf0 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -551,12 +551,12 @@ class TargetTransformInfoImplBase {
   bool enableWritePrefetching() const { return false; }
   bool shouldPrefetchAddressSpace(unsigned AS) const { return !AS; }
 
-  bool isPartialReductionSupported(const Instruction *ReductionInstr,
-                                   Type *InputType, unsigned ScaleFactor,
-                                   bool IsInputASignExtended,
-                                   bool IsInputBSignExtended,
-                                   const Instruction *BinOp = nullptr) const {
-    return false;
+  InstructionCost
+  getPartialReductionCost(unsigned Opcode, Type *InputType, Type *AccumType,
+                          ElementCount VF, PartialReductionExtendKind OpAExtend,
+                          PartialReductionExtendKind OpBExtend,
+                          std::optional<unsigned> BinOp = std::nullopt) const {
+    return InstructionCost::getInvalid();
   }
 
   unsigned getMaxInterleaveFactor(ElementCount VF) const { return 1; }
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 964259eccb7807..1e4a75eedb6a69 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -830,13 +830,12 @@ bool TargetTransformInfo::shouldPrefetchAddressSpace(unsigned AS) const {
   return TTIImpl->shouldPrefetchAddressSpace(AS);
 }
 
-bool TargetTransformInfo::isPartialReductionSupported(
-    const Instruction *ReductionInstr, Type *InputType, unsigned ScaleFactor,
-    bool IsInputASignExtended, bool IsInputBSignExtended,
-    const Instruction *BinOp) const {
-  return TTIImpl->isPartialReductionSupported(ReductionInstr, InputType,
-                                              ScaleFactor, IsInputASignExtended,
-                                              IsInputBSignExtended, BinOp);
+InstructionCost TargetTransformInfo::getPartialReductionCost(
+    unsigned Opcode, Type *InputType, Type *AccumType, ElementCount VF,
+    PartialReductionExtendKind OpAExtend, PartialReductionExtendKind OpBExtend,
+    std::optional<unsigned> BinOp) const {
+  return TTIImpl->getPartialReductionCost(Opcode, InputType, AccumType, VF,
+                                          OpAExtend, OpBExtend, BinOp);
 }
 
 unsigned TargetTransformInfo::getMaxInterleaveFactor(ElementCount VF) const {
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index 38855fcef6b78e..3f63a2255d9e50 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -340,13 +340,41 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
     return BaseT::isLegalNTLoad(DataType, Alignment);
   }
 
-  bool isPartialReductionSupported(const Instruction *ReductionInstr,
-                                   Type *InputType, unsigned ScaleFactor,
-                                   bool IsInputASignExtended,
-                                   bool IsInputBSignExtended,
-                                   const Instruction *BinOp) const {
-    return ScaleFactor == 4 && (ST->isSVEorStreamingSVEAvailable() ||
-                                (ST->isNeonAvailable() && ST->hasDotProd()));
+  InstructionCost getPartialReductionCost(unsigned Opcode, Type *InputType,
+                                          Type *AccumType, ElementCount VF,
+                                          PartialReductionExtendKind OpAExtend,
+                                          PartialReductionExtendKind OpBExtend,
+                                          std::optional<unsigned> BinOp) const {
+    InstructionCost Cost = InstructionCost::getInvalid();
+
+    if (Opcode != Instruction::Add)
+      return Cost;
+
+    EVT InputEVT = EVT::getEVT(InputType);
+    EVT AccumEVT = EVT::getEVT(AccumType);
+
+    if (AccumEVT.isScalableVector() && !ST->isSVEorStreamingSVEAvailable())
+      return Cost;
+    if (!AccumEVT.isScalableVector() && !ST->isNeonAvailable() &&
+        !ST->hasDotProd())
+      return Cost;
+
+    if (InputEVT == MVT::i8) {
+      if (AccumEVT != MVT::i32)
+        return Cost;
+    } else if (InputEVT == MVT::i16) {
+      if (AccumEVT != MVT::i64)
+        return Cost;
+    } else
+      return Cost;
+
+    if (OpAExtend == PR_None || OpBExtend == PR_None)
+      return Cost;
+
+    if (!BinOp || (*BinOp) != Instruction::Mul)
+      return Cost;
+
+    return InstructionCost::getMin();
   }
 
   bool enableOrderedReductions() const { return true; }
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
index 36438553ddcf99..70850e62feb7bf 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
@@ -400,8 +400,6 @@ class LoopVectorizationPlanner {
   /// Profitable vector factors.
   SmallVector<VectorizationFactor, 8> ProfitableVFs;
 
-  SmallVector<PartialReductionChain> PartialReductionChains;
-
   /// A builder used to construct the current plan.
   VPBuilder Builder;
 
@@ -494,10 +492,6 @@ class LoopVectorizationPlanner {
   /// Emit remarks for recipes with invalid costs in the available VPlans.
   void emitInvalidCostRemarks(OptimizationRemarkEmitter *ORE);
 
-  SmallVector<PartialReductionChain> getPartialReductionChains() const {
-    return PartialReductionChains;
-  }
-
 protected:
   /// Build VPlans for power-of-2 VF's between \p MinVF and \p MaxVF inclusive,
   /// according to the information gathered by Legal when it checked if it is
@@ -549,9 +543,6 @@ class LoopVectorizationPlanner {
   /// Determines if we have the infrastructure to vectorize the loop and its
   /// epilogue, assuming the main loop is vectorized by \p VF.
   bool isCandidateForEpilogueVectorization(const ElementCount VF) const;
-
-  bool getInstructionsPartialReduction(Instruction *I,
-                                       PartialReductionChain &Chain) const;
 };
 
 } // namespace llvm
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index a604a6d641146a..f98ee42bffe29f 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -1060,7 +1060,7 @@ class LoopVectorizationCostModel {
   calculateRegisterUsage(ArrayRef<ElementCount> VFs);
 
   /// Collect values we want to ignore in the cost model.
-  void collectValuesToIgnore(LoopVectorizationPlanner *LVP);
+  void collectValuesToIgnore();
 
   /// Collect all element types in the loop for which widening is needed.
   void collectElementTypesForWidening();
@@ -1545,7 +1545,105 @@ class LoopVectorizationCostModel {
   getReductionPatternCost(Instruction *I, ElementCount VF, Type *VectorTy,
                           TTI::TargetCostKind CostKind) const;
 
+  using PartialReductionList = DenseMap<Instruction *, PartialReductionChain>;
+
+  PartialReductionList getPartialReductionChains() {
+    return PartialReductionChains;
+  }
+
+  bool getInstructionsPartialReduction(Instruction *I,
+                                       PartialReductionChain &Chain) const {
+    auto PairIt = PartialReductionChains.find(I);
+    if (PairIt == PartialReductionChains.end())
+      return false;
+    Chain = PairIt->second;
+    return true;
+  }
+
+  void addPartialReductionIfSupported(Instruction *Instr, ElementCount VF) {
+    Value *ExpectedPhi;
+    Value *A, *B;
+
+    using namespace llvm::PatternMatch;
+    auto Pattern =
+        m_BinOp(m_OneUse(m_BinOp(m_OneUse(m_ZExtOrSExt(m_OneUse(m_Value(A)))),
+                                 m_OneUse(m_ZExtOrSExt(m_OneUse(m_Value(B)))))),
+                m_Value(ExpectedPhi));
+
+    bool Matches = match(Instr, Pattern);
+
+    if (!Matches)
+      return;
+
+    // Check that the extends extend from the same type
+    if (A->getType() != B->getType()) {
+      LLVM_DEBUG(dbgs() << "Extends don't extend from the same type, cannot "
+                           "create a partial reduction.\n");
+      return;
+    }
+
+    // A and B are one-use, so the first user of each should be the respective
+    // extend
+    Instruction *Ext0 = cast<CastInst>(*A->user_begin());
+    Instruction *Ext1 = cast<CastInst>(*B->user_begin());
+
+    // Check that the extends extend to the same type
+    if (Ext0->getType() != Ext1->getType()) {
+      LLVM_DEBUG(
+          dbgs() << "Extends don't extend to the same type, cannot create "
+                    "a partial reduction.\n");
+      return;
+    }
+
+    // Check that the add feeds into ExpectedPhi
+    PHINode *PhiNode = dyn_cast<PHINode>(ExpectedPhi);
+    if (!PhiNode) {
+      LLVM_DEBUG(dbgs() << "Expected Phi node was not a phi, cannot create a "
+                           "partial reduction.\n");
+      return;
+    }
+
+    // Check that the second phi value is the instruction we're looking at
+    Instruction *MaybeAdd = dyn_cast<Instruction>(PhiNode->getIncomingValue(1));
+    if (!MaybeAdd || MaybeAdd != Instr) {
+      LLVM_DEBUG(dbgs() << "Second PHI value is not the root binop, cannot "
+                           "create a partial reduction.\n");
+      return;
+    }
+
+    Instruction *BinOp = cast<Instruction>(Instr->getOperand(0));
+    Value *InputA = Ext0->getOperand(0);
+    Value *InputB = Ext1->getOperand(0);
+    PartialReductionExtendKind OpAExtend =
+        TargetTransformInfo::getPartialReductionExtendKind(Ext0);
+    PartialReductionExtendKind OpBExtend =
+        TargetTransformInfo::getPartialReductionExtendKind(Ext1);
+    InstructionCost Cost = TTI.getPartialReductionCost(
+        Instr->getOpcode(), InputA->getType(), ExpectedPhi->getType(), VF,
+        OpAExtend, OpBExtend,
+        BinOp ? std::make_optional(BinOp->getOpcode()) : std::nullopt);
+    if (Cost == InstructionCost::getInvalid())
+      return;
+
+    PartialReductionChain Chain;
+    Chain.Reduction = Instr;
+    Chain.BinOp = BinOp;
+    Chain.ExtendA = Ext0;
+    Chain.ExtendB = Ext1;
+    Chain.InputA = InputA;
+    Chain.InputB = InputB;
+    Chain.Accumulator = ExpectedPhi;
+
+    unsigned InputSizeBits = Chain.InputA->getType()->getScalarSizeInBits();
+    unsigned ResultSizeBits = Chain.Reduction->getType()->getScalarSizeInBits();
+    Chain.ScaleFactor = ResultSizeBits / InputSizeBits;
+
+    PartialReductionChains[Instr] = Chain;
+  }
+
 private:
+  PartialReductionList PartialReductionChains;
+
   unsigned NumPredStores = 0;
 
   /// \return An upper bound for the vectorization factors for both
@@ -2113,101 +2211,6 @@ static bool useActiveLaneMaskForControlFlow(TailFoldingStyle Style) {
          Style == TailFoldingStyle::DataAndControlFlowWithoutRuntimeCheck;
 }
 
-static PartialReductionChain getPartialReductionInstrChain(Instruction *Instr) {
-  Instruction *BinOp = cast<Instruction>(Instr->getOperand(0));
-  Instruction *Ext0 = cast<Instruction>(BinOp->getOperand(0));
-  Instruction *Ext1 = cast<Instruction>(BinOp->getOperand(1));
-
-  PartialReductionChain Chain;
-  Chain.Reduction = Instr;
-  Chain.BinOp = BinOp;
-  Chain.ExtendA = Ext0;
-  Chain.ExtendB = Ext1;
-  Chain.InputA = Ext0->getOperand(0);
-  Chain.InputB = Ext1->getOperand(0);
-  Chain.Accumulator = Instr->getOperand(1);
-
-  unsigned InputSizeBits = Chain.InputA->getType()->getScalarSizeInBits();
-  unsigned ResultSizeBits = Chain.Reduction->getType()->getScalarSizeInBits();
-  Chain.ScaleFactor = ResultSizeBits / InputSizeBits;
-  return Chain;
-}
-
-/// Checks if the given instruction the root of a partial reduction chain
-///
-/// @param Instr The root instruction to scan
-static bool isInstrPartialReduction(Instruction *Instr) {
-  Value *ExpectedPhi;
-  Value *A, *B;
-
-  using namespace llvm::PatternMatch;
-  auto Pattern =
-      m_BinOp(m_OneUse(m_BinOp(m_OneUse(m_ZExtOrSExt(m_OneUse(m_Value(A)))),
-                               m_OneUse(m_ZExtOrSExt(m_OneUse(m_Value(B)))))),
-              m_Value(ExpectedPhi));
-
-  bool Matches = match(Instr, Pattern);
-
-  if (!Matches)
-    return false;
-
-  // Check that the extends extend from the same type
-  if (A->getType() != B->getType()) {
-    LLVM_DEBUG(dbgs() << "Extends don't extend from the same type, cannot "
-                         "create a partial reduction.\n");
-    return false;
-  }
-
-  // A and B are one-use, so the first user of each should be the respective
-  // extend
-  Instruction *Ext0 = cast<CastInst>(*A->user_begin());
-  Instruction *Ext1 = cast<CastInst>(*B->user_begin());
-
-  // Check that the extends extend to the same type
-  if (Ext0->getType() != Ext1->getType()) {
-    LLVM_DEBUG(dbgs() << "Extends don't extend to the same type, cannot create "
-                         "a partial reduction.\n");
-    return false;
-  }
-
-  // Check that the add feeds into ExpectedPhi
-  PHINode *PhiNode = dyn_cast<PHINode>(ExpectedPhi);
-  if (!PhiNode) {
-    LLVM_DEBUG(dbgs() << "Expected Phi node was not a phi, cannot create a "
-                         "partial reduction.\n");
-    return false;
-  }
-
-  // Check that the second phi value is the instruction we're looking at
-  Instruction *MaybeAdd = dyn_cast<Instruction>(PhiNode->getIncomingValue(1));
-  if (!MaybeAdd || MaybeAdd != Instr) {
-    LLVM_DEBUG(dbgs() << "Second PHI value is not the root binop, cannot "
-                         "create a partial reduction.\n");
-    return false;
-  }
-
-  return true;
-}
-
-static bool isPartialReductionChainValid(PartialReductionChain &Chain,
-                                         const TargetTransformInfo &TTI) {
-  if (Chain.Reduction->getOpcode() != Instruction::Add)
-    return false;
-
-  unsigned InputSizeBits = Chain.InputA->getType()->getScalarSizeInBits();
-  unsigned ResultSizeBits = Chain.Reduction->getType()->getScalarSizeInBits();
-
-  if (ResultSizeBits < InputSizeBits || (ResultSizeBits % InputSizeBits) != 0)
-    return false;
-
-  bool IsASignExtended = isa<SExtInst>(Chain.ExtendA);
-  bool IsBSignExtended = isa<SExtInst>(Chain.ExtendB);
-
-  return TTI.isPartialReductionSupported(
-      Chain.Reduction, Chain.InputA->getType(), Chain.ScaleFactor,
-      IsASignExtended, IsBSignExtended, Chain.BinOp);
-}
-
 // Return true if \p OuterLp is an outer loop annotated with hints for explicit
 // vectorization. The loop needs to be annotated with #pragma omp simd
 // simdlen(#) or #pragma clang vectorize(enable) vectorize_width(#). If the
@@ -4729,11 +4732,7 @@ bool LoopVectorizationPlanner::isCandidateForEpilogueVectorization(
 
   // Prevent epilogue vectorization if a partial reduction is involved
   // TODO Is there a cleaner way to check this?
-  if (any_of(Legal->getReductionVars(),
-             [&](const std::pair<PHINode *, RecurrenceDescriptor> &Reduction) {
-               return isInstrPartialReduction(
-                   Reduction.second.getLoopExitInstr());
-             }))
+  if (CM.getPartialReductionChains().size() > 0)
     return false;
 
   // Epilogue vectorization code has not been auditted to ensure it handles
@@ -4745,17 +4744,6 @@ bool LoopVectorizationPlanner::isCandidateForEpilogueVectorization(
   return true;
 }
 
-bool LoopVectorizationPlanner::getInstructionsPartialReduction(
-    Instruction *I, PartialReductionChain &Chain) const {
-  for (auto &C : PartialReductionChains) {
-    if (C.Reduction == I) {
-      Chain = C;
-      return true;
-    }
-  }
-  return false;
-}
-
 bool LoopVectorizationCostModel::isEpilogueVectorizationProfitable(
     const ElementCount VF) const {
   // FIXME: We need a much better cost-model to take different parameters such
@@ -6851,8 +6839,7 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I,
   } // end of switch.
 }
 
-void LoopVectorizationCostModel::collectValuesToIgnore(
-    LoopVectorizationPlanner *LVP) {
+void LoopVectorizationCostModel::collectValuesToIgnore() {
   // Ignore ephemeral values.
   CodeMetrics::collectEphemeralValues(TheLoop, AC, ValuesToIgnore);
 
@@ -7009,7 +6996,8 @@ void LoopVectorizationCostModel::collectValuesToIgnore(
   }
 
   // Ignore any values that we know will be flattened
-  for (auto Chain : LVP->getPartialReductionChains()) {
+  for (auto It : getPartialReductionChains()) {
+    PartialReductionChain Chain = It.second;
     SmallVector<Value *> PartialReductionValues{Chain.Reduction, Chain.BinOp,
                                                 Chain.ExtendA, Chain.ExtendB,
                                                 Chain.Accumulator};
@@ -7137,14 +7125,10 @@ void LoopVectorizationPlanner::plan(ElementCount UserVF, unsigned UserIC) {
 
   for (auto ReductionVar : Legal->getReductionVars()) {
     auto *ReductionExitInstr = ReductionVar.second.getLoopExitInstr();
-    if (isInstrPartialReduction(ReductionExitInstr)) {
-      auto Chain = getPartialReductionInstrChain(ReductionExitInstr);
-      if (isPartialReductionChainValid(Chain, TTI))
-        PartialReductionChains.push_back(Chain);
-    }
+    CM.addPartialReductionIfSupported(ReductionExitInstr, UserVF);
   }
 
-  CM.collectValuesToIgnore(this);
+  CM.collectValuesToIgnore();
   CM.collectElementTypesForWidening();
 
   FixedScalableVFPair MaxFactors = CM.computeMaxVF(UserVF, UserIC);
@@ -9108,7 +9092,7 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
       VPRecipeBase *Recipe = nullptr;
 
       PartialReductionChain Chain;
-      if (getInstructionsPartialReduction(Instr, Chain))
+      if (CM.getInstructionsPartialReduction(Instr, Chain))
         Recipe =
             RecipeBuilder.tryToCreatePartialReduction(Range, Chain, Operands);
 

>From a043d80f9341dca1ca1dd1ce3d261057e36ccb98 Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Tue, 10 Sep 2024 13:38:14 +0100
Subject: [PATCH 27/54] Remove postInsertionOp

---
 .../Transforms/Vectorize/LoopVectorize.cpp    | 22 ++++++++++++++-----
 .../Transforms/Vectorize/VPRecipeBuilder.h    |  2 +-
 llvm/lib/Transforms/Vectorize/VPlan.h         |  7 ------
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp |  4 ----
 .../LoopVectorize/AArch64/vplan-printing.ll   |  2 +-
 5 files changed, 18 insertions(+), 19 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index f98ee42bffe29f..c54484a1527b2e 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -8647,9 +8647,22 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
           Legal->getReductionVars().find(Phi)->second;
       assert(RdxDesc.getRecurrenceStartValue() ==
              Phi->getIncomingValueForBlock(OrigLoop->getLoopPreheader()));
+
+      // If the PHI is used by a partial reduction, set the scale factor
+      unsigned ScaleFactor = 1;
+      for (auto *User : Phi->users()) {
+        if (auto *I = dyn_cast<Instruction>(User)) {
+            PartialReductionChain Chain;
+            if (CM.getInstructionsPartialReduction(I, Chain)) {
+                ScaleFactor = Chain.ScaleFactor;
+                break;
+            }
+        }
+      }
       PhiRecipe = new VPReductionPHIRecipe(Phi, RdxDesc, *StartV,
                                            CM.isInLoopReduction(Phi),
-                                           CM.useOrderedReductions(RdxDesc));
+                                           CM.useOrderedReductions(RdxDesc),
+                                           ScaleFactor);
     } else {
       // TODO: Currently fixed-order recurrences are modeled as chains of
       // first-order recurrences. If there are no users of the intermediate
@@ -8698,7 +8711,7 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
 }
 
 VPRecipeBase *
-VPRecipeBuilder::tryToCreatePartialReduction(VFRange &Range,
+VPRecipeBuilder::tryToCreatePartialReduction(
                                              PartialReductionChain &Chain,
                                              ArrayRef<VPValue *> Operands) {
   return new VPPartialReductionRecipe(
@@ -9094,7 +9107,7 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
       PartialReductionChain Chain;
       if (CM.getInstructionsPartialReduction(Instr, Chain))
         Recipe =
-            RecipeBuilder.tryToCreatePartialReduction(Range, Chain, Operands);
+            RecipeBuilder.tryToCreatePartialReduction(Chain, Operands);
 
       if (!Recipe)
         Recipe =
@@ -9120,9 +9133,6 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
         VPBB->appendRecipe(Recipe);
     }
 
-    for (auto &Recipe : *VPBB)
-      Recipe.postInsertionOp();
-
     VPBlockUtils::insertBlockAfter(new VPBasicBlock(), VPBB);
     VPBB = cast<VPBasicBlock>(VPBB->getSingleSuccessor());
   }
diff --git a/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h b/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
index 5c15e6f16bd082..01ecac21c7f1ee 100644
--- a/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
+++ b/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
@@ -117,7 +117,7 @@ class VPRecipeBuilder {
                                        ArrayRef<VPValue *> Operands,
                                        VFRange &Range, VPBasicBlock *VPBB);
 
-  VPRecipeBase *tryToCreatePartialReduction(VFRange &Range,
+  VPRecipeBase *tryToCreatePartialReduction(
                                             PartialReductionChain &Chain,
                                             ArrayRef<VPValue *> Operands);
 
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index cb0b7c822d6ff6..06249faf035ad5 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -831,10 +831,6 @@ class VPRecipeBase : public ilist_node_with_parent<VPRecipeBase, VPBasicBlock>,
   /// \returns an iterator pointing to the element after the erased one
   iplist<VPRecipeBase>::iterator eraseFromParent();
 
-  /// Run any required modifications to the recipe after it has been inserted
-  /// into the plan.
-  virtual void postInsertionOp() {}
-
   /// Method to support type inquiry through isa, cast, and dyn_cast.
   static inline bool classof(const VPDef *D) {
     // All VPDefs are also VPRecipeBases.
@@ -2177,8 +2173,6 @@ class VPReductionPHIRecipe : public VPHeaderPHIRecipe,
     return R->getVPDefID() == VPDef::VPReductionPHISC;
   }
 
-  void SetVFScaleFactor(unsigned ScaleFactor) { VFScaleFactor = ScaleFactor; }
-
   /// Generate the phi/select nodes.
   void execute(VPTransformState &State) override;
 
@@ -2222,7 +2216,6 @@ class VPPartialReductionRecipe : public VPRecipeWithIRFlags {
   VP_CLASSOF_IMPL(VPDef::VPPartialReductionSC)
   /// Generate the reduction in the loop
   void execute(VPTransformState &State) override;
-  void postInsertionOp() override;
   unsigned getOpcode() { return Opcode; }
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
   /// Print the recipe.
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 452290ebabb047..87d6293af586c9 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -373,10 +373,6 @@ void VPPartialReductionRecipe::execute(VPTransformState &State) {
   }
 }
 
-void VPPartialReductionRecipe::postInsertionOp() {
-  cast<VPReductionPHIRecipe>(this->getOperand(1))->SetVFScaleFactor(Scale);
-}
-
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
 void VPPartialReductionRecipe::print(raw_ostream &O, const Twine &Indent,
                                      VPSlotTracker &SlotTracker) const {
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/vplan-printing.ll b/llvm/test/Transforms/LoopVectorize/AArch64/vplan-printing.ll
index 96d51a55b37098..1bf03963ad40e7 100644
--- a/llvm/test/Transforms/LoopVectorize/AArch64/vplan-printing.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/vplan-printing.ll
@@ -31,7 +31,7 @@ define void @print_partial_reduction(ptr %a, ptr %b) {
 ; CHECK-NEXT:   WIDEN ir<%2> = load vp<%5>
 ; CHECK-NEXT:   WIDEN-CAST ir<%conv3> = zext ir<%2> to i32
 ; CHECK-NEXT:   WIDEN ir<%mul> = mul ir<%conv3>, ir<%conv>
-; CHECK-NEXT:   WIDEN ir<%add> = add ir<%mul>, ir<[[ACC]]>
+; CHECK-NEXT:   PARTIAL-REDUCE ir<%add> = add ir<%mul>, ir<%acc.010>
 ; CHECK-NEXT:   EMIT vp<[[CAN_IV_NEXT]]> = add nuw vp<[[CAN_IV]]>, vp<[[VFxUF]]>
 ; CHECK-NEXT:   EMIT branch-on-count vp<[[CAN_IV_NEXT]]>, vp<[[VEC_TC]]>
 ; CHECK-NEXT: No successors

>From 9d7d04735ec948720d8fa35a8db7703f2effe6cf Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Tue, 10 Sep 2024 14:41:01 +0100
Subject: [PATCH 28/54] Clean up

---
 .../Transforms/Vectorize/LoopVectorize.cpp    | 23 +++----
 .../Transforms/Vectorize/VPRecipeBuilder.h    |  3 +-
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp | 68 +++++++------------
 3 files changed, 36 insertions(+), 58 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index c54484a1527b2e..9f6e4b4fc4b35d 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -8652,17 +8652,16 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
       unsigned ScaleFactor = 1;
       for (auto *User : Phi->users()) {
         if (auto *I = dyn_cast<Instruction>(User)) {
-            PartialReductionChain Chain;
-            if (CM.getInstructionsPartialReduction(I, Chain)) {
-                ScaleFactor = Chain.ScaleFactor;
-                break;
-            }
+          PartialReductionChain Chain;
+          if (CM.getInstructionsPartialReduction(I, Chain)) {
+            ScaleFactor = Chain.ScaleFactor;
+            break;
+          }
         }
       }
-      PhiRecipe = new VPReductionPHIRecipe(Phi, RdxDesc, *StartV,
-                                           CM.isInLoopReduction(Phi),
-                                           CM.useOrderedReductions(RdxDesc),
-                                           ScaleFactor);
+      PhiRecipe = new VPReductionPHIRecipe(
+          Phi, RdxDesc, *StartV, CM.isInLoopReduction(Phi),
+          CM.useOrderedReductions(RdxDesc), ScaleFactor);
     } else {
       // TODO: Currently fixed-order recurrences are modeled as chains of
       // first-order recurrences. If there are no users of the intermediate
@@ -8711,8 +8710,7 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
 }
 
 VPRecipeBase *
-VPRecipeBuilder::tryToCreatePartialReduction(
-                                             PartialReductionChain &Chain,
+VPRecipeBuilder::tryToCreatePartialReduction(PartialReductionChain &Chain,
                                              ArrayRef<VPValue *> Operands) {
   return new VPPartialReductionRecipe(
       *Chain.Reduction, make_range(Operands.begin(), Operands.end()),
@@ -9106,8 +9104,7 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
 
       PartialReductionChain Chain;
       if (CM.getInstructionsPartialReduction(Instr, Chain))
-        Recipe =
-            RecipeBuilder.tryToCreatePartialReduction(Chain, Operands);
+        Recipe = RecipeBuilder.tryToCreatePartialReduction(Chain, Operands);
 
       if (!Recipe)
         Recipe =
diff --git a/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h b/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
index 01ecac21c7f1ee..9ebc55f65ea50a 100644
--- a/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
+++ b/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
@@ -117,8 +117,7 @@ class VPRecipeBuilder {
                                        ArrayRef<VPValue *> Operands,
                                        VFRange &Range, VPBasicBlock *VPBB);
 
-  VPRecipeBase *tryToCreatePartialReduction(
-                                            PartialReductionChain &Chain,
+  VPRecipeBase *tryToCreatePartialReduction(PartialReductionChain &Chain,
                                             ArrayRef<VPValue *> Operands);
 
   /// Set the recipe created for given ingredient.
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 87d6293af586c9..854023aa3f5661 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -322,54 +322,36 @@ void VPPartialReductionRecipe::execute(VPTransformState &State) {
   State.setDebugLocFrom(getDebugLoc());
   auto &Builder = State.Builder;
 
-  switch (Opcode) {
-  case Instruction::Add: {
-
-    unsigned UF = getParent()->getPlan()->getUF();
-    for (unsigned Part = 0; Part < UF; ++Part) {
-      Value *Mul = nullptr;
-      Value *Phi = nullptr;
-      SmallVector<Value *, 2> Ops;
-      for (VPValue *VPOp : operands()) {
-        auto *Op = State.get(VPOp, Part);
-        Ops.push_back(Op);
-        if (isa<PHINode>(Op))
-          Phi = Op;
-        else
-          Mul = Op;
-      }
+  assert(Opcode == Instruction::Add && "Unhandled partial reduction opcode");
 
-      assert(Phi && Mul && "Phi and Mul must be set");
-
-      VectorType *FullTy = cast<VectorType>(Ops[0]->getType());
-      auto EC = FullTy->getElementCount();
-      Type *RetTy = VectorType::get(FullTy->getScalarType(),
-                                    EC.divideCoefficientBy(Scale));
+  unsigned UF = getParent()->getPlan()->getUF();
+  for (unsigned Part = 0; Part < UF; ++Part) {
+    Value *Mul = nullptr;
+    Value *Phi = nullptr;
+    SmallVector<Value *, 2> Ops;
+    for (VPValue *VPOp : operands()) {
+      auto *Op = State.get(VPOp, Part);
+      Ops.push_back(Op);
+      if (isa<PHINode>(Op))
+        Phi = Op;
+      else
+        Mul = Op;
+    }
 
-      Intrinsic::ID PartialIntrinsic = Intrinsic::not_intrinsic;
-      switch (Opcode) {
-      case Instruction::Add:
-        PartialIntrinsic = Intrinsic::experimental_vector_partial_reduce_add;
-        break;
-      default:
-        llvm_unreachable("Opcode not handled");
-      }
+    assert(Phi && Mul && "Phi and Mul must be set");
 
-      assert(PartialIntrinsic != Intrinsic::not_intrinsic);
+    VectorType *FullTy = cast<VectorType>(Ops[0]->getType());
+    ElementCount EC = FullTy->getElementCount();
+    Type *RetTy =
+        VectorType::get(FullTy->getScalarType(), EC.divideCoefficientBy(Scale));
 
-      CallInst *V = Builder.CreateIntrinsic(RetTy, PartialIntrinsic, {Phi, Mul},
-                                            nullptr, Twine("partial.reduce"));
+    CallInst *V = Builder.CreateIntrinsic(
+        RetTy, Intrinsic::experimental_vector_partial_reduce_add, {Phi, Mul},
+        nullptr, Twine("partial.reduce"));
 
-      // Use this vector value for all users of the original instruction.
-      State.set(this, V, Part);
-      State.addMetadata(V, dyn_cast_or_null<Instruction>(getUnderlyingValue()));
-    }
-    break;
-  }
-  default:
-    LLVM_DEBUG(dbgs() << "LV: Found an unhandled opcode : "
-                      << Instruction::getOpcodeName(Opcode));
-    llvm_unreachable("Unhandled instruction!");
+    // Use this vector value for all users of the original instruction.
+    State.set(this, V, Part);
+    State.addMetadata(V, dyn_cast_or_null<Instruction>(getUnderlyingValue()));
   }
 }
 

>From 3614a268eb681585fa7ef0400dbe375dd4f3737e Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Tue, 10 Sep 2024 14:48:28 +0100
Subject: [PATCH 29/54] Clean up test

---
 .../LoopVectorize/AArch64/vplan-printing.ll   | 26 -------------------
 1 file changed, 26 deletions(-)

diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/vplan-printing.ll b/llvm/test/Transforms/LoopVectorize/AArch64/vplan-printing.ll
index 1bf03963ad40e7..b1832ab159c44a 100644
--- a/llvm/test/Transforms/LoopVectorize/AArch64/vplan-printing.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/vplan-printing.ll
@@ -1,5 +1,4 @@
 ; REQUIRES: asserts
-
 ; RUN: opt -passes=loop-vectorize -debug-only=loop-vectorize -force-vector-interleave=1 -disable-output %s 2>&1 | FileCheck %s
 
 target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128"
@@ -76,28 +75,3 @@ for.body:                                         ; preds = %for.body, %entry
   %exitcond.not = icmp eq i64 %indvars.iv.next, 0
   br i1 %exitcond.not, label %for.cond.cleanup.loopexit, label %for.body
 }
-
-!llvm.dbg.cu = !{!0}
-!llvm.module.flags = !{!3, !4}
-
-declare float @foo(float) #0
-declare <2 x float> @vector_foo(<2 x float>, <2 x i1>)
-
-; We need a vector variant in order to allow for vectorization at present, but
-; we want to test scalarization of conditional calls. If we provide a variant
-; with a different number of lanes than the VF we force via
-; "-force-vector-width=4", then it should pass the legality checks but
-; scalarize. TODO: Remove the requirement to have a variant.
-attributes #0 = { readonly nounwind "vector-function-abi-variant"="_ZGV_LLVM_M2v_foo(vector_foo)" }
-
-!0 = distinct !DICompileUnit(language: DW_LANG_C99, file: !1, producer: "clang", isOptimized: true, runtimeVersion: 0, emissionKind: NoDebug, enums: !2)
-!1 = !DIFile(filename: "/tmp/s.c", directory: "/tmp")
-!2 = !{}
-!3 = !{i32 2, !"Debug Info Version", i32 3}
-!4 = !{i32 7, !"PIC Level", i32 2}
-!5 = distinct !DISubprogram(name: "f", scope: !1, file: !1, line: 4, type: !6, scopeLine: 4, flags: DIFlagPrototyped, spFlags: DISPFlagDefinition | DISPFlagOptimized, unit: !0, retainedNodes: !2)
-!6 = !DISubroutineType(types: !2)
-!7 = !DILocation(line: 5, column: 3, scope: !5)
-!8 = !DILocation(line: 5, column: 21, scope: !5)
-;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
-; CHECK: {{.*}}

>From 0c710b1af26db7cadaf2bb6965a210a2dc60a148 Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Fri, 13 Sep 2024 14:04:13 +0100
Subject: [PATCH 30/54] Move PartialReductionExtendKind

---
 llvm/include/llvm/Analysis/TargetTransformInfo.h     | 4 ++--
 llvm/include/llvm/Analysis/TargetTransformInfoImpl.h | 4 ++--
 llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h | 6 +++---
 llvm/lib/Transforms/Vectorize/LoopVectorize.cpp      | 4 ++--
 4 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 50004a23059845..c788e8eb40febe 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -92,8 +92,6 @@ struct MemIntrinsicInfo {
   }
 };
 
-enum PartialReductionExtendKind { PR_None, PR_SignExtend, PR_ZeroExtend };
-
 /// Attributes of a target dependent hardware loop.
 struct HardwareLoopInfo {
   HardwareLoopInfo() = delete;
@@ -213,6 +211,8 @@ typedef TargetTransformInfo TTI;
 /// for IR-level transformations.
 class TargetTransformInfo {
 public:
+  enum PartialReductionExtendKind { PR_None, PR_SignExtend, PR_ZeroExtend };
+
   static PartialReductionExtendKind
   getPartialReductionExtendKind(Instruction *I) {
     if (isa<SExtInst>(I))
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 02a499e538fcf0..e24aa729d1c1bd 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -553,8 +553,8 @@ class TargetTransformInfoImplBase {
 
   InstructionCost
   getPartialReductionCost(unsigned Opcode, Type *InputType, Type *AccumType,
-                          ElementCount VF, PartialReductionExtendKind OpAExtend,
-                          PartialReductionExtendKind OpBExtend,
+                          ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
+                          TTI::PartialReductionExtendKind OpBExtend,
                           std::optional<unsigned> BinOp = std::nullopt) const {
     return InstructionCost::getInvalid();
   }
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index 3f63a2255d9e50..989ee72d250376 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -342,8 +342,8 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
 
   InstructionCost getPartialReductionCost(unsigned Opcode, Type *InputType,
                                           Type *AccumType, ElementCount VF,
-                                          PartialReductionExtendKind OpAExtend,
-                                          PartialReductionExtendKind OpBExtend,
+                                          TTI::PartialReductionExtendKind OpAExtend,
+                                          TTI::PartialReductionExtendKind OpBExtend,
                                           std::optional<unsigned> BinOp) const {
     InstructionCost Cost = InstructionCost::getInvalid();
 
@@ -368,7 +368,7 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
     } else
       return Cost;
 
-    if (OpAExtend == PR_None || OpBExtend == PR_None)
+    if (OpAExtend == TTI::PR_None || OpBExtend == TTI::PR_None)
       return Cost;
 
     if (!BinOp || (*BinOp) != Instruction::Mul)
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 9f6e4b4fc4b35d..f98e06c1510a37 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -1614,9 +1614,9 @@ class LoopVectorizationCostModel {
     Instruction *BinOp = cast<Instruction>(Instr->getOperand(0));
     Value *InputA = Ext0->getOperand(0);
     Value *InputB = Ext1->getOperand(0);
-    PartialReductionExtendKind OpAExtend =
+    TTI::PartialReductionExtendKind OpAExtend =
         TargetTransformInfo::getPartialReductionExtendKind(Ext0);
-    PartialReductionExtendKind OpBExtend =
+    TTI::PartialReductionExtendKind OpBExtend =
         TargetTransformInfo::getPartialReductionExtendKind(Ext1);
     InstructionCost Cost = TTI.getPartialReductionCost(
         Instr->getOpcode(), InputA->getType(), ExpectedPhi->getType(), VF,

>From 95a72e8ca25645178fad86f8ddaeafd92d60cc0c Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Fri, 13 Sep 2024 14:04:25 +0100
Subject: [PATCH 31/54] Cost -> Invalid

---
 .../AArch64/AArch64TargetTransformInfo.h       | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index 989ee72d250376..257ace26400c5c 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -345,34 +345,34 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
                                           TTI::PartialReductionExtendKind OpAExtend,
                                           TTI::PartialReductionExtendKind OpBExtend,
                                           std::optional<unsigned> BinOp) const {
-    InstructionCost Cost = InstructionCost::getInvalid();
+    InstructionCost Invalid = InstructionCost::getInvalid();
 
     if (Opcode != Instruction::Add)
-      return Cost;
+      return Invalid;
 
     EVT InputEVT = EVT::getEVT(InputType);
     EVT AccumEVT = EVT::getEVT(AccumType);
 
     if (AccumEVT.isScalableVector() && !ST->isSVEorStreamingSVEAvailable())
-      return Cost;
+      return Invalid;
     if (!AccumEVT.isScalableVector() && !ST->isNeonAvailable() &&
         !ST->hasDotProd())
-      return Cost;
+      return Invalid;
 
     if (InputEVT == MVT::i8) {
       if (AccumEVT != MVT::i32)
-        return Cost;
+        return Invalid;
     } else if (InputEVT == MVT::i16) {
       if (AccumEVT != MVT::i64)
-        return Cost;
+        return Invalid;
     } else
-      return Cost;
+      return Invalid;
 
     if (OpAExtend == TTI::PR_None || OpBExtend == TTI::PR_None)
-      return Cost;
+      return Invalid;
 
     if (!BinOp || (*BinOp) != Instruction::Mul)
-      return Cost;
+      return Invalid;
 
     return InstructionCost::getMin();
   }

>From 597f3039ad37ff3976adf568ab0119a982f8cbdf Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Fri, 13 Sep 2024 14:05:16 +0100
Subject: [PATCH 32/54] Use isFixedLengthVector

---
 llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index 257ace26400c5c..e3fa007791301a 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -355,7 +355,7 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
 
     if (AccumEVT.isScalableVector() && !ST->isSVEorStreamingSVEAvailable())
       return Invalid;
-    if (!AccumEVT.isScalableVector() && !ST->isNeonAvailable() &&
+    if (AccumEVT.isFixedLengthVector() && !ST->isNeonAvailable() &&
         !ST->hasDotProd())
       return Invalid;
 

>From 2bae011482cd541317b232a1f6863966802e6d44 Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Fri, 13 Sep 2024 14:41:54 +0100
Subject: [PATCH 33/54] Move PartialReductionChain

---
 .../Vectorize/LoopVectorizationPlanner.h      | 25 ------------
 .../Transforms/Vectorize/LoopVectorize.cpp    | 38 ++++++++++++++++---
 .../Transforms/Vectorize/VPRecipeBuilder.h    |  3 +-
 3 files changed, 34 insertions(+), 32 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
index 70850e62feb7bf..034fdf4233de37 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
@@ -337,31 +337,6 @@ struct FixedScalableVFPair {
   bool hasVector() const { return FixedVF.isVector() || ScalableVF.isVector(); }
 };
 
-/// A chain of instructions that form a partial reduction.
-/// Designed to match: reduction_bin_op (bin_op (extend (A), (extend (B))),
-/// accumulator)
-struct PartialReductionChain {
-  /// The top-level binary operation that forms the reduction to a scalar after
-  /// the loop body
-  Instruction *Reduction;
-  /// The inner binary operation that forms the reduction to a vector value
-  /// within the loop body
-  Instruction *BinOp;
-  /// The extension of each of the inner binary operation's operands
-  Instruction *ExtendA;
-  Instruction *ExtendB;
-
-  /// The inner binary operation's operands
-  Value *InputA;
-  Value *InputB;
-  /// The accumulator that is reduced to a scalar after the loop body
-  Value *Accumulator;
-
-  /// The scaling factor between the size of the reduction type and the
-  /// (possibly extended) inputs
-  unsigned ScaleFactor;
-};
-
 /// Planner drives the vectorization process after having passed
 /// Legality checks.
 class LoopVectorizationPlanner {
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index f98e06c1510a37..12ca8c68dedba6 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -1545,6 +1545,31 @@ class LoopVectorizationCostModel {
   getReductionPatternCost(Instruction *I, ElementCount VF, Type *VectorTy,
                           TTI::TargetCostKind CostKind) const;
 
+  /// A chain of instructions that form a partial reduction.
+  /// Designed to match: reduction_bin_op (bin_op (extend (A), (extend (B))),
+  /// accumulator)
+  struct PartialReductionChain {
+    /// The top-level binary operation that forms the reduction to a scalar
+    /// after the loop body
+    Instruction *Reduction;
+    /// The inner binary operation that forms the reduction to a vector value
+    /// within the loop body
+    Instruction *BinOp;
+    /// The extension of each of the inner binary operation's operands
+    Instruction *ExtendA;
+    Instruction *ExtendB;
+
+    /// The inner binary operation's operands
+    Value *InputA;
+    Value *InputB;
+    /// The accumulator that is reduced to a scalar after the loop body
+    Value *Accumulator;
+
+    /// The scaling factor between the size of the reduction type and the
+    /// (possibly extended) inputs
+    unsigned ScaleFactor;
+  };
+
   using PartialReductionList = DenseMap<Instruction *, PartialReductionChain>;
 
   PartialReductionList getPartialReductionChains() {
@@ -8652,7 +8677,7 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
       unsigned ScaleFactor = 1;
       for (auto *User : Phi->users()) {
         if (auto *I = dyn_cast<Instruction>(User)) {
-          PartialReductionChain Chain;
+          LoopVectorizationCostModel::PartialReductionChain Chain;
           if (CM.getInstructionsPartialReduction(I, Chain)) {
             ScaleFactor = Chain.ScaleFactor;
             break;
@@ -8710,11 +8735,11 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
 }
 
 VPRecipeBase *
-VPRecipeBuilder::tryToCreatePartialReduction(PartialReductionChain &Chain,
+VPRecipeBuilder::tryToCreatePartialReduction(Instruction *Reduction,
+                                             unsigned ScaleFactor,
                                              ArrayRef<VPValue *> Operands) {
   return new VPPartialReductionRecipe(
-      *Chain.Reduction, make_range(Operands.begin(), Operands.end()),
-      Chain.ScaleFactor);
+      *Reduction, make_range(Operands.begin(), Operands.end()), ScaleFactor);
 }
 
 void LoopVectorizationPlanner::buildVPlansWithVPRecipes(ElementCount MinVF,
@@ -9102,9 +9127,10 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
 
       VPRecipeBase *Recipe = nullptr;
 
-      PartialReductionChain Chain;
+      LoopVectorizationCostModel::PartialReductionChain Chain;
       if (CM.getInstructionsPartialReduction(Instr, Chain))
-        Recipe = RecipeBuilder.tryToCreatePartialReduction(Chain, Operands);
+        Recipe = RecipeBuilder.tryToCreatePartialReduction(
+            Chain.Reduction, Chain.ScaleFactor, Operands);
 
       if (!Recipe)
         Recipe =
diff --git a/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h b/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
index 9ebc55f65ea50a..5c4192147ac20f 100644
--- a/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
+++ b/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
@@ -117,7 +117,8 @@ class VPRecipeBuilder {
                                        ArrayRef<VPValue *> Operands,
                                        VFRange &Range, VPBasicBlock *VPBB);
 
-  VPRecipeBase *tryToCreatePartialReduction(PartialReductionChain &Chain,
+  VPRecipeBase *tryToCreatePartialReduction(Instruction *Reduction,
+                                            unsigned ScaleFactor,
                                             ArrayRef<VPValue *> Operands);
 
   /// Set the recipe created for given ingredient.

>From 449ad7fa2d93451d63482e4fb609ae958ed6a9d5 Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Fri, 13 Sep 2024 15:11:48 +0100
Subject: [PATCH 34/54] Remove InputA and InputB

---
 llvm/lib/Transforms/Vectorize/LoopVectorize.cpp | 8 +-------
 1 file changed, 1 insertion(+), 7 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 12ca8c68dedba6..1c37c43c1afb90 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -1559,9 +1559,6 @@ class LoopVectorizationCostModel {
     Instruction *ExtendA;
     Instruction *ExtendB;
 
-    /// The inner binary operation's operands
-    Value *InputA;
-    Value *InputB;
     /// The accumulator that is reduced to a scalar after the loop body
     Value *Accumulator;
 
@@ -1638,7 +1635,6 @@ class LoopVectorizationCostModel {
 
     Instruction *BinOp = cast<Instruction>(Instr->getOperand(0));
     Value *InputA = Ext0->getOperand(0);
-    Value *InputB = Ext1->getOperand(0);
     TTI::PartialReductionExtendKind OpAExtend =
         TargetTransformInfo::getPartialReductionExtendKind(Ext0);
     TTI::PartialReductionExtendKind OpBExtend =
@@ -1655,11 +1651,9 @@ class LoopVectorizationCostModel {
     Chain.BinOp = BinOp;
     Chain.ExtendA = Ext0;
     Chain.ExtendB = Ext1;
-    Chain.InputA = InputA;
-    Chain.InputB = InputB;
     Chain.Accumulator = ExpectedPhi;
 
-    unsigned InputSizeBits = Chain.InputA->getType()->getScalarSizeInBits();
+    unsigned InputSizeBits = InputA->getType()->getScalarSizeInBits();
     unsigned ResultSizeBits = Chain.Reduction->getType()->getScalarSizeInBits();
     Chain.ScaleFactor = ResultSizeBits / InputSizeBits;
 

>From dafe017572f3f2e644c1c19c435997c614141a58 Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Fri, 13 Sep 2024 15:26:47 +0100
Subject: [PATCH 35/54] Return optional from getInstructionsPartialReduction

---
 .../Transforms/Vectorize/LoopVectorize.cpp    | 19 ++++++++-----------
 1 file changed, 8 insertions(+), 11 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 1c37c43c1afb90..c73f34ffd0d140 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -1573,13 +1573,12 @@ class LoopVectorizationCostModel {
     return PartialReductionChains;
   }
 
-  bool getInstructionsPartialReduction(Instruction *I,
-                                       PartialReductionChain &Chain) const {
+  std::optional<PartialReductionChain> getInstructionsPartialReduction(Instruction *I
+ ) const {
     auto PairIt = PartialReductionChains.find(I);
     if (PairIt == PartialReductionChains.end())
-      return false;
-    Chain = PairIt->second;
-    return true;
+      return std::nullopt;
+    return PairIt->second;
   }
 
   void addPartialReductionIfSupported(Instruction *Instr, ElementCount VF) {
@@ -8671,9 +8670,8 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
       unsigned ScaleFactor = 1;
       for (auto *User : Phi->users()) {
         if (auto *I = dyn_cast<Instruction>(User)) {
-          LoopVectorizationCostModel::PartialReductionChain Chain;
-          if (CM.getInstructionsPartialReduction(I, Chain)) {
-            ScaleFactor = Chain.ScaleFactor;
+          if (auto Chain = CM.getInstructionsPartialReduction(I)) {
+            ScaleFactor = Chain->ScaleFactor;
             break;
           }
         }
@@ -9121,10 +9119,9 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
 
       VPRecipeBase *Recipe = nullptr;
 
-      LoopVectorizationCostModel::PartialReductionChain Chain;
-      if (CM.getInstructionsPartialReduction(Instr, Chain))
+      if (auto Chain = CM.getInstructionsPartialReduction(Instr))
         Recipe = RecipeBuilder.tryToCreatePartialReduction(
-            Chain.Reduction, Chain.ScaleFactor, Operands);
+            Chain->Reduction, Chain->ScaleFactor, Operands);
 
       if (!Recipe)
         Recipe =

>From 16d3db6bfef7898fb1fccb081d932f0401a58e7b Mon Sep 17 00:00:00 2001
From: Sam Tebbs <samuel.tebbs at arm.com>
Date: Mon, 16 Sep 2024 09:54:40 +0100
Subject: [PATCH 36/54] Remove Pattern and Matches

---
 llvm/lib/Transforms/Vectorize/LoopVectorize.cpp | 12 +++++-------
 1 file changed, 5 insertions(+), 7 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index c73f34ffd0d140..3de71c081f97b7 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -1586,14 +1586,12 @@ class LoopVectorizationCostModel {
     Value *A, *B;
 
     using namespace llvm::PatternMatch;
-    auto Pattern =
-        m_BinOp(m_OneUse(m_BinOp(m_OneUse(m_ZExtOrSExt(m_OneUse(m_Value(A)))),
-                                 m_OneUse(m_ZExtOrSExt(m_OneUse(m_Value(B)))))),
-                m_Value(ExpectedPhi));
 
-    bool Matches = match(Instr, Pattern);
-
-    if (!Matches)
+    if (!match(Instr,
+               m_BinOp(m_OneUse(m_BinOp(
+                           m_OneUse(m_ZExtOrSExt(m_OneUse(m_Value(A)))),
+                           m_OneUse(m_ZExtOrSExt(m_OneUse(m_Value(B)))))),
+                       m_Value(ExpectedPhi))))
       return;
 
     // Check that the extends extend from the same type

>From 78cac58dc781207c4460e96bb01d6cb0c60a1cb8 Mon Sep 17 00:00:00 2001
From: Sam Tebbs <samuel.tebbs at arm.com>
Date: Mon, 16 Sep 2024 09:58:11 +0100
Subject: [PATCH 37/54] Use getIncomingValueForBlock

---
 llvm/lib/Transforms/Vectorize/LoopVectorize.cpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 3de71c081f97b7..7db0e8f3b25a81 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -1623,7 +1623,8 @@ class LoopVectorizationCostModel {
     }
 
     // Check that the second phi value is the instruction we're looking at
-    Instruction *MaybeAdd = dyn_cast<Instruction>(PhiNode->getIncomingValue(1));
+    Instruction *MaybeAdd = dyn_cast<Instruction>(
+        PhiNode->getIncomingValueForBlock(Instr->getParent()));
     if (!MaybeAdd || MaybeAdd != Instr) {
       LLVM_DEBUG(dbgs() << "Second PHI value is not the root binop, cannot "
                            "create a partial reduction.\n");

>From d6a3ff7b7ea2a59de468fd4470b8ec72ff2115ce Mon Sep 17 00:00:00 2001
From: Sam Tebbs <samuel.tebbs at arm.com>
Date: Mon, 16 Sep 2024 10:26:06 +0100
Subject: [PATCH 38/54] Check for commutativity in pattern

---
 .../Transforms/Vectorize/LoopVectorize.cpp    | 22 ++++++++++++++-----
 1 file changed, 16 insertions(+), 6 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 7db0e8f3b25a81..a6a84f98211698 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -1587,11 +1587,21 @@ class LoopVectorizationCostModel {
 
     using namespace llvm::PatternMatch;
 
-    if (!match(Instr,
-               m_BinOp(m_OneUse(m_BinOp(
-                           m_OneUse(m_ZExtOrSExt(m_OneUse(m_Value(A)))),
-                           m_OneUse(m_ZExtOrSExt(m_OneUse(m_Value(B)))))),
-                       m_Value(ExpectedPhi))))
+    unsigned BinOpIdx = 0;
+
+    // The binary operator can be commutative
+    if (match(Instr, m_BinOp(m_OneUse(m_BinOp(
+                                 m_OneUse(m_ZExtOrSExt(m_OneUse(m_Value(A)))),
+                                 m_OneUse(m_ZExtOrSExt(m_OneUse(m_Value(B)))))),
+                             m_Value(ExpectedPhi))))
+      BinOpIdx = 0;
+    else if (match(Instr,
+                   m_BinOp(m_Value(ExpectedPhi),
+                           m_OneUse(m_BinOp(
+                               m_OneUse(m_ZExtOrSExt(m_OneUse(m_Value(A)))),
+                               m_OneUse(m_ZExtOrSExt(m_OneUse(m_Value(B)))))))))
+      BinOpIdx = 1;
+    else
       return;
 
     // Check that the extends extend from the same type
@@ -1631,7 +1641,7 @@ class LoopVectorizationCostModel {
       return;
     }
 
-    Instruction *BinOp = cast<Instruction>(Instr->getOperand(0));
+    Instruction *BinOp = cast<Instruction>(Instr->getOperand(BinOpIdx));
     Value *InputA = Ext0->getOperand(0);
     TTI::PartialReductionExtendKind OpAExtend =
         TargetTransformInfo::getPartialReductionExtendKind(Ext0);

>From 50c32f49c443d484e67599ed9a85f639877561dd Mon Sep 17 00:00:00 2001
From: Sam Tebbs <samuel.tebbs at arm.com>
Date: Mon, 16 Sep 2024 10:26:27 +0100
Subject: [PATCH 39/54] InputA -> A

---
 llvm/lib/Transforms/Vectorize/LoopVectorize.cpp | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index a6a84f98211698..9e4f5749c61a85 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -1642,14 +1642,13 @@ class LoopVectorizationCostModel {
     }
 
     Instruction *BinOp = cast<Instruction>(Instr->getOperand(BinOpIdx));
-    Value *InputA = Ext0->getOperand(0);
     TTI::PartialReductionExtendKind OpAExtend =
         TargetTransformInfo::getPartialReductionExtendKind(Ext0);
     TTI::PartialReductionExtendKind OpBExtend =
         TargetTransformInfo::getPartialReductionExtendKind(Ext1);
     InstructionCost Cost = TTI.getPartialReductionCost(
-        Instr->getOpcode(), InputA->getType(), ExpectedPhi->getType(), VF,
-        OpAExtend, OpBExtend,
+        Instr->getOpcode(), A->getType(), ExpectedPhi->getType(), VF, OpAExtend,
+        OpBExtend,
         BinOp ? std::make_optional(BinOp->getOpcode()) : std::nullopt);
     if (Cost == InstructionCost::getInvalid())
       return;
@@ -1661,7 +1660,7 @@ class LoopVectorizationCostModel {
     Chain.ExtendB = Ext1;
     Chain.Accumulator = ExpectedPhi;
 
-    unsigned InputSizeBits = InputA->getType()->getScalarSizeInBits();
+    unsigned InputSizeBits = A->getType()->getScalarSizeInBits();
     unsigned ResultSizeBits = Chain.Reduction->getType()->getScalarSizeInBits();
     Chain.ScaleFactor = ResultSizeBits / InputSizeBits;
 

>From 96d326afe619c0166e4300886aaa6fdd8642e4b5 Mon Sep 17 00:00:00 2001
From: Sam Tebbs <samuel.tebbs at arm.com>
Date: Mon, 16 Sep 2024 10:26:43 +0100
Subject: [PATCH 40/54] Remove ternary on BinOp

---
 llvm/lib/Transforms/Vectorize/LoopVectorize.cpp | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 9e4f5749c61a85..b7d1db93e25a95 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -1648,8 +1648,7 @@ class LoopVectorizationCostModel {
         TargetTransformInfo::getPartialReductionExtendKind(Ext1);
     InstructionCost Cost = TTI.getPartialReductionCost(
         Instr->getOpcode(), A->getType(), ExpectedPhi->getType(), VF, OpAExtend,
-        OpBExtend,
-        BinOp ? std::make_optional(BinOp->getOpcode()) : std::nullopt);
+        OpBExtend, std::make_optional(BinOp->getOpcode()));
     if (Cost == InstructionCost::getInvalid())
       return;
 

>From 510de44c25b43c8e200db1d25677099a97ccf0e1 Mon Sep 17 00:00:00 2001
From: Sam Tebbs <samuel.tebbs at arm.com>
Date: Mon, 16 Sep 2024 10:26:57 +0100
Subject: [PATCH 41/54] if -> else if for recipe check

---
 llvm/lib/Transforms/Vectorize/LoopVectorize.cpp | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index b7d1db93e25a95..666f176461368a 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -9129,8 +9129,7 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
       if (auto Chain = CM.getInstructionsPartialReduction(Instr))
         Recipe = RecipeBuilder.tryToCreatePartialReduction(
             Chain->Reduction, Chain->ScaleFactor, Operands);
-
-      if (!Recipe)
+      else if (!Recipe)
         Recipe =
             RecipeBuilder.tryToCreateWidenRecipe(Instr, Operands, Range, VPBB);
       if (!Recipe)

>From aa8f4c016bbf32a08e45a75181b821ed4ff1dc42 Mon Sep 17 00:00:00 2001
From: Sam Tebbs <samuel.tebbs at arm.com>
Date: Mon, 16 Sep 2024 15:03:01 +0100
Subject: [PATCH 42/54] Remove Scale from recipe and pre-set order of recipe
 operands

---
 .../Transforms/Vectorize/LoopVectorize.cpp    | 12 +++++++-
 llvm/lib/Transforms/Vectorize/VPlan.h         |  5 ++--
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp | 29 +++++++------------
 3 files changed, 23 insertions(+), 23 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 666f176461368a..653f4b926edb0e 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -8737,8 +8737,18 @@ VPRecipeBase *
 VPRecipeBuilder::tryToCreatePartialReduction(Instruction *Reduction,
                                              unsigned ScaleFactor,
                                              ArrayRef<VPValue *> Operands) {
+  assert(Operands.size() == 2 &&
+         "Unexpected number of operands for partial reduction");
+
+  VPValue *BinOp = Operands[0];
+  VPValue *Phi = Operands[1];
+  VPRecipeBase *BinOpRecipe = BinOp->getDefiningRecipe();
+  if (isa<VPReductionPHIRecipe>(BinOpRecipe))
+    std::swap(BinOp, Phi);
+
+  SmallVector<VPValue *, 2> OrderedOperands = {BinOp, Phi};
   return new VPPartialReductionRecipe(
-      *Reduction, make_range(Operands.begin(), Operands.end()), ScaleFactor);
+      *Reduction, make_range(OrderedOperands.begin(), OrderedOperands.end()));
 }
 
 void LoopVectorizationPlanner::buildVPlansWithVPRecipes(ElementCount MinVF,
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 06249faf035ad5..3ca8f7d8d305a8 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -2203,10 +2203,9 @@ class VPPartialReductionRecipe : public VPRecipeWithIRFlags {
 
 public:
   template <typename IterT>
-  VPPartialReductionRecipe(Instruction &I, iterator_range<IterT> Operands,
-                           unsigned Scale)
+  VPPartialReductionRecipe(Instruction &I, iterator_range<IterT> Operands)
       : VPRecipeWithIRFlags(VPDef::VPPartialReductionSC, Operands, I),
-        Opcode(I.getOpcode()), Scale(Scale) {}
+        Opcode(I.getOpcode()) {}
   ~VPPartialReductionRecipe() override = default;
   VPPartialReductionRecipe *clone() override {
     llvm_unreachable(
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 854023aa3f5661..fe0c0c240e260e 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -324,30 +324,21 @@ void VPPartialReductionRecipe::execute(VPTransformState &State) {
 
   assert(Opcode == Instruction::Add && "Unhandled partial reduction opcode");
 
+  SmallVector<VPValue *, 2> Operands;
+  for (auto *Op : operands())
+    Operands.push_back(Op);
+
   unsigned UF = getParent()->getPlan()->getUF();
   for (unsigned Part = 0; Part < UF; ++Part) {
-    Value *Mul = nullptr;
-    Value *Phi = nullptr;
-    SmallVector<Value *, 2> Ops;
-    for (VPValue *VPOp : operands()) {
-      auto *Op = State.get(VPOp, Part);
-      Ops.push_back(Op);
-      if (isa<PHINode>(Op))
-        Phi = Op;
-      else
-        Mul = Op;
-    }
-
-    assert(Phi && Mul && "Phi and Mul must be set");
+    Value *BinOpVal = State.get(Operands[0], Part);
+    Value *PhiVal = State.get(Operands[1], Part);
+    assert(PhiVal && BinOpVal && "Phi and Mul must be set");
 
-    VectorType *FullTy = cast<VectorType>(Ops[0]->getType());
-    ElementCount EC = FullTy->getElementCount();
-    Type *RetTy =
-        VectorType::get(FullTy->getScalarType(), EC.divideCoefficientBy(Scale));
+    Type *RetTy = PhiVal->getType();
 
     CallInst *V = Builder.CreateIntrinsic(
-        RetTy, Intrinsic::experimental_vector_partial_reduce_add, {Phi, Mul},
-        nullptr, Twine("partial.reduce"));
+        RetTy, Intrinsic::experimental_vector_partial_reduce_add,
+        {PhiVal, BinOpVal}, nullptr, Twine("partial.reduce"));
 
     // Use this vector value for all users of the original instruction.
     State.set(this, V, Part);

>From dca93ce28f88242f3fa2e6b48855b153ef6745d5 Mon Sep 17 00:00:00 2001
From: Sam Tebbs <samuel.tebbs at arm.com>
Date: Mon, 16 Sep 2024 15:19:25 +0100
Subject: [PATCH 43/54] Improve test IR naming

---
 .../AArch64/partial-reduce-dot-product.ll     | 104 +++++++++---------
 1 file changed, 52 insertions(+), 52 deletions(-)

diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product.ll b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product.ll
index 76bec75a8e3580..a332c720eaacf9 100644
--- a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product.ll
@@ -4,8 +4,8 @@
 target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128"
 target triple = "aarch64-none-unknown-elf"
 
-define void @dotp(ptr %a, ptr %b) #0 {
-; CHECK-LABEL: define void @dotp(
+define i32 @dotp(ptr %a, ptr %b) #0 {
+; CHECK-LABEL: define i32 @dotp(
 ; CHECK-SAME: ptr [[A:%.*]], ptr [[B:%.*]]) #[[ATTR0:[0-9]+]] {
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[TMP0:%.*]] = call i64 @llvm.vscale.i64()
@@ -46,22 +46,22 @@ entry:
   br label %for.body
 
 for.cond.cleanup.loopexit:                        ; preds = %for.body
-  %0 = lshr i32 %add, 0
-  ret void
+  %result = lshr i32 %add, 0
+  ret i32 %result
 
 for.body:                                         ; preds = %for.body, %entry
-  %indvars.iv = phi i64 [ 0, %entry ], [ %indvars.iv.next, %for.body ]
-  %acc.010 = phi i32 [ 0, %entry ], [ %add, %for.body ]
-  %arrayidx = getelementptr i8, ptr %a, i64 %indvars.iv
-  %1 = load i8, ptr %arrayidx, align 1
-  %conv = zext i8 %1 to i32
-  %arrayidx2 = getelementptr i8, ptr %b, i64 %indvars.iv
-  %2 = load i8, ptr %arrayidx2, align 1
-  %conv3 = zext i8 %2 to i32
-  %mul = mul i32 %conv3, %conv
-  %add = add i32 %mul, %acc.010
-  %indvars.iv.next = add i64 %indvars.iv, 1
-  %exitcond.not = icmp eq i64 %indvars.iv.next, 0
+  %iv = phi i64 [ 0, %entry ], [ %iv.next, %for.body ]
+  %accum = phi i32 [ 0, %entry ], [ %add, %for.body ]
+  %gep.a = getelementptr i8, ptr %a, i64 %iv
+  %load.a = load i8, ptr %gep.a, align 1
+  %ext.a = zext i8 %load.a to i32
+  %gep.b = getelementptr i8, ptr %b, i64 %iv
+  %load.b = load i8, ptr %gep.b, align 1
+  %ext.b = zext i8 %load.b to i32
+  %mul = mul i32 %ext.b, %ext.a
+  %add = add i32 %mul, %accum
+  %iv.next = add i64 %iv, 1
+  %exitcond.not = icmp eq i64 %iv.next, 0
   br i1 %exitcond.not, label %for.cond.cleanup.loopexit, label %for.body
 }
 
@@ -161,18 +161,18 @@ for.cond.cleanup.loopexit:                        ; preds = %for.body
   ret void
 
 for.body:                                         ; preds = %for.body, %entry
-  %indvars.iv = phi i64 [ 0, %entry ], [ %indvars.iv.next, %for.body ]
-  %acc.010 = phi i32 [ 0, %entry ], [ %add, %for.body ]
-  %arrayidx = getelementptr i8, ptr %a, i64 %indvars.iv
-  %1 = load i8, ptr %arrayidx, align 1
-  %conv = zext i8 %1 to i32
-  %arrayidx2 = getelementptr i8, ptr %b, i64 %indvars.iv
-  %2 = load i16, ptr %arrayidx2, align 2
-  %conv3 = zext i16 %2 to i32
-  %mul = mul i32 %conv3, %conv
-  %add = add i32 %mul, %acc.010
-  %indvars.iv.next = add i64 %indvars.iv, 1
-  %exitcond.not = icmp eq i64 %indvars.iv.next, 0
+  %iv = phi i64 [ 0, %entry ], [ %iv.next, %for.body ]
+  %accum = phi i32 [ 0, %entry ], [ %add, %for.body ]
+  %gep.a = getelementptr i8, ptr %a, i64 %iv
+  %load.a = load i8, ptr %gep.a, align 1
+  %ext.a = zext i8 %load.a to i32
+  %gep.b = getelementptr i8, ptr %b, i64 %iv
+  %load.b = load i16, ptr %gep.b, align 2
+  %ext.b = zext i16 %load.b to i32
+  %mul = mul i32 %ext.b, %ext.a
+  %add = add i32 %mul, %accum
+  %iv.next = add i64 %iv, 1
+  %exitcond.not = icmp eq i64 %iv.next, 0
   br i1 %exitcond.not, label %for.cond.cleanup.loopexit, label %for.body
 }
 
@@ -214,18 +214,18 @@ for.cond.cleanup.loopexit:                        ; preds = %for.body
   ret void
 
 for.body:                                         ; preds = %for.body, %entry
-  %indvars.iv = phi i64 [ 0, %entry ], [ %indvars.iv.next, %for.body ]
-  %acc.010 = phi i32 [ 0, %entry ], [ %mul, %for.body ]
-  %arrayidx = getelementptr i8, ptr %a, i64 %indvars.iv
-  %1 = load i8, ptr %arrayidx, align 1
-  %conv = zext i8 %1 to i32
-  %arrayidx2 = getelementptr i8, ptr %b, i64 %indvars.iv
-  %2 = load i8, ptr %arrayidx2, align 1
-  %conv3 = zext i8 %2 to i32
-  %mul = mul i32 %conv3, %conv
-  %add = add i32 %mul, %acc.010
-  %indvars.iv.next = add i64 %indvars.iv, 1
-  %exitcond.not = icmp eq i64 %indvars.iv.next, 0
+  %iv = phi i64 [ 0, %entry ], [ %iv.next, %for.body ]
+  %accum = phi i32 [ 0, %entry ], [ %mul, %for.body ]
+  %gep.a = getelementptr i8, ptr %a, i64 %iv
+  %load.a = load i8, ptr %gep.a, align 1
+  %ext.a = zext i8 %load.a to i32
+  %gep.b = getelementptr i8, ptr %b, i64 %iv
+  %load.b = load i8, ptr %gep.b, align 1
+  %ext.b = zext i8 %load.b to i32
+  %mul = mul i32 %ext.b, %ext.a
+  %add = add i32 %mul, %accum
+  %iv.next = add i64 %iv, 1
+  %exitcond.not = icmp eq i64 %iv.next, 0
   br i1 %exitcond.not, label %for.cond.cleanup.loopexit, label %for.body
 }
 
@@ -266,18 +266,18 @@ for.cond.cleanup.loopexit:                        ; preds = %for.body
   ret void
 
 for.body:                                         ; preds = %for.body, %entry
-  %indvars.iv = phi i64 [ 0, %entry ], [ %indvars.iv.next, %for.body ]
-  %acc.010 = phi i32 [ 0, %entry ], [ %add, %for.body ]
-  %arrayidx = getelementptr i8, ptr %a, i64 %indvars.iv
-  %1 = load i8, ptr %arrayidx, align 1
-  %conv = zext i8 %1 to i32
-  %arrayidx2 = getelementptr i8, ptr %b, i64 %indvars.iv
-  %2 = load i8, ptr %arrayidx2, align 1
-  %conv3 = zext i8 %2 to i32
-  %mul = mul i32 %conv3, %conv
-  %add = add i32 %mul, %conv3
-  %indvars.iv.next = add i64 %indvars.iv, 1
-  %exitcond.not = icmp eq i64 %indvars.iv.next, 0
+  %iv = phi i64 [ 0, %entry ], [ %iv.next, %for.body ]
+  %accum = phi i32 [ 0, %entry ], [ %add, %for.body ]
+  %gep.a = getelementptr i8, ptr %a, i64 %iv
+  %load.a = load i8, ptr %gep.a, align 1
+  %ext.a = zext i8 %load.a to i32
+  %gep.b = getelementptr i8, ptr %b, i64 %iv
+  %load.b = load i8, ptr %gep.b, align 1
+  %ext.b = zext i8 %load.b to i32
+  %mul = mul i32 %ext.b, %ext.a
+  %add = add i32 %mul, %ext.b
+  %iv.next = add i64 %iv, 1
+  %exitcond.not = icmp eq i64 %iv.next, 0
   br i1 %exitcond.not, label %for.cond.cleanup.loopexit, label %for.body
 }
 attributes #0 = { nofree norecurse nosync nounwind memory(argmem: readwrite) uwtable vscale_range(1,16) "target-features"="+sve" }

>From 3b45be254d3359f1d88ea28f59bd05796bd938e8 Mon Sep 17 00:00:00 2001
From: Sam Tebbs <samuel.tebbs at arm.com>
Date: Mon, 16 Sep 2024 15:21:03 +0100
Subject: [PATCH 44/54] Add missing attributes to test

---
 .../AArch64/partial-reduce-dot-product.ll     | 110 ++++++++++++------
 1 file changed, 75 insertions(+), 35 deletions(-)

diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product.ll b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product.ll
index a332c720eaacf9..ea46c0446a7505 100644
--- a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product.ll
@@ -65,9 +65,9 @@ for.body:                                         ; preds = %for.body, %entry
   br i1 %exitcond.not, label %for.cond.cleanup.loopexit, label %for.body
 }
 
-define void @dotp_different_types(ptr %a, ptr %b) {
+define void @dotp_different_types(ptr %a, ptr %b) #0 {
 ; CHECK-LABEL: define void @dotp_different_types(
-; CHECK-SAME: ptr [[A:%.*]], ptr [[B:%.*]]) {
+; CHECK-SAME: ptr [[A:%.*]], ptr [[B:%.*]]) #[[ATTR0]] {
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    br i1 true, label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
 ; CHECK:       vector.ph:
@@ -176,35 +176,55 @@ for.body:                                         ; preds = %for.body, %entry
   br i1 %exitcond.not, label %for.cond.cleanup.loopexit, label %for.body
 }
 
-define void @dotp_not_loop_carried(ptr %a, ptr %b) {
+define void @dotp_not_loop_carried(ptr %a, ptr %b) #0 {
 ; CHECK-LABEL: define void @dotp_not_loop_carried(
-; CHECK-SAME: ptr [[A:%.*]], ptr [[B:%.*]]) {
+; CHECK-SAME: ptr [[A:%.*]], ptr [[B:%.*]]) #[[ATTR0]] {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    br i1 true, label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
+; CHECK-NEXT:    [[TMP0:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP1:%.*]] = mul i64 [[TMP0]], 8
+; CHECK-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 0, [[TMP1]]
+; CHECK-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
 ; CHECK:       vector.ph:
+; CHECK-NEXT:    [[TMP2:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP3:%.*]] = mul i64 [[TMP2]], 8
+; CHECK-NEXT:    [[N_MOD_VF:%.*]] = urem i64 0, [[TMP3]]
+; CHECK-NEXT:    [[N_VEC:%.*]] = sub i64 0, [[N_MOD_VF]]
+; CHECK-NEXT:    [[TMP4:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP5:%.*]] = mul i64 [[TMP4]], 8
+; CHECK-NEXT:    [[TMP6:%.*]] = call i32 @llvm.vscale.i32()
+; CHECK-NEXT:    [[TMP7:%.*]] = mul i32 [[TMP6]], 8
+; CHECK-NEXT:    [[TMP8:%.*]] = sub i32 [[TMP7]], 1
+; CHECK-NEXT:    [[VECTOR_RECUR_INIT:%.*]] = insertelement <vscale x 8 x i32> poison, i32 0, i32 [[TMP8]]
 ; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
 ; CHECK:       vector.body:
 ; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
-; CHECK-NEXT:    [[VECTOR_RECUR:%.*]] = phi <16 x i32> [ <i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 0>, [[VECTOR_PH]] ], [ [[TMP7:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[VECTOR_RECUR:%.*]] = phi <vscale x 8 x i32> [ [[VECTOR_RECUR_INIT]], [[VECTOR_PH]] ], [ [[TMP16:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[TMP9:%.*]] = add i64 [[INDEX]], 0
 ; CHECK-NEXT:    [[TMP10:%.*]] = getelementptr i8, ptr [[A]], i64 [[TMP9]]
 ; CHECK-NEXT:    [[TMP11:%.*]] = getelementptr i8, ptr [[TMP10]], i32 0
-; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[TMP11]], align 1
-; CHECK-NEXT:    [[TMP3:%.*]] = zext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
+; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <vscale x 8 x i8>, ptr [[TMP11]], align 1
+; CHECK-NEXT:    [[TMP12:%.*]] = zext <vscale x 8 x i8> [[WIDE_LOAD]] to <vscale x 8 x i32>
 ; CHECK-NEXT:    [[TMP13:%.*]] = getelementptr i8, ptr [[B]], i64 [[TMP9]]
 ; CHECK-NEXT:    [[TMP14:%.*]] = getelementptr i8, ptr [[TMP13]], i32 0
-; CHECK-NEXT:    [[WIDE_LOAD1:%.*]] = load <16 x i8>, ptr [[TMP14]], align 1
-; CHECK-NEXT:    [[TMP6:%.*]] = zext <16 x i8> [[WIDE_LOAD1]] to <16 x i32>
-; CHECK-NEXT:    [[TMP7]] = mul <16 x i32> [[TMP6]], [[TMP3]]
-; CHECK-NEXT:    [[TMP8:%.*]] = shufflevector <16 x i32> [[VECTOR_RECUR]], <16 x i32> [[TMP7]], <16 x i32> <i32 15, i32 16, i32 17, i32 18, i32 19, i32 20, i32 21, i32 22, i32 23, i32 24, i32 25, i32 26, i32 27, i32 28, i32 29, i32 30>
-; CHECK-NEXT:    [[TMP15:%.*]] = add <16 x i32> [[TMP7]], [[TMP8]]
-; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 16
-; CHECK-NEXT:    [[TMP16:%.*]] = icmp eq i64 [[INDEX_NEXT]], 0
-; CHECK-NEXT:    br i1 [[TMP16]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP6:![0-9]+]]
+; CHECK-NEXT:    [[WIDE_LOAD1:%.*]] = load <vscale x 8 x i8>, ptr [[TMP14]], align 1
+; CHECK-NEXT:    [[TMP15:%.*]] = zext <vscale x 8 x i8> [[WIDE_LOAD1]] to <vscale x 8 x i32>
+; CHECK-NEXT:    [[TMP16]] = mul <vscale x 8 x i32> [[TMP15]], [[TMP12]]
+; CHECK-NEXT:    [[TMP17:%.*]] = call <vscale x 8 x i32> @llvm.vector.splice.nxv8i32(<vscale x 8 x i32> [[VECTOR_RECUR]], <vscale x 8 x i32> [[TMP16]], i32 -1)
+; CHECK-NEXT:    [[TMP18:%.*]] = add <vscale x 8 x i32> [[TMP16]], [[TMP17]]
+; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP5]]
+; CHECK-NEXT:    [[TMP19:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
+; CHECK-NEXT:    br i1 [[TMP19]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP6:![0-9]+]]
 ; CHECK:       middle.block:
-; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT:%.*]] = extractelement <16 x i32> [[TMP7]], i32 15
-; CHECK-NEXT:    [[TMP17:%.*]] = extractelement <16 x i32> [[TMP15]], i32 15
-; CHECK-NEXT:    br i1 true, label [[FOR_COND_CLEANUP_LOOPEXIT:%.*]], label [[SCALAR_PH]]
+; CHECK-NEXT:    [[TMP20:%.*]] = call i32 @llvm.vscale.i32()
+; CHECK-NEXT:    [[TMP21:%.*]] = mul i32 [[TMP20]], 8
+; CHECK-NEXT:    [[TMP22:%.*]] = sub i32 [[TMP21]], 1
+; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT:%.*]] = extractelement <vscale x 8 x i32> [[TMP16]], i32 [[TMP22]]
+; CHECK-NEXT:    [[TMP23:%.*]] = call i32 @llvm.vscale.i32()
+; CHECK-NEXT:    [[TMP24:%.*]] = mul i32 [[TMP23]], 8
+; CHECK-NEXT:    [[TMP25:%.*]] = sub i32 [[TMP24]], 1
+; CHECK-NEXT:    [[TMP26:%.*]] = extractelement <vscale x 8 x i32> [[TMP18]], i32 [[TMP25]]
+; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i64 0, [[N_VEC]]
+; CHECK-NEXT:    br i1 [[CMP_N]], label [[FOR_COND_CLEANUP_LOOPEXIT:%.*]], label [[SCALAR_PH]]
 ;
 entry:
   br label %for.body
@@ -229,34 +249,54 @@ for.body:                                         ; preds = %for.body, %entry
   br i1 %exitcond.not, label %for.cond.cleanup.loopexit, label %for.body
 }
 
-define void @dotp_not_phi(ptr %a, ptr %b) {
+define void @dotp_not_phi(ptr %a, ptr %b) #0 {
 ; CHECK-LABEL: define void @dotp_not_phi(
-; CHECK-SAME: ptr [[A:%.*]], ptr [[B:%.*]]) {
+; CHECK-SAME: ptr [[A:%.*]], ptr [[B:%.*]]) #[[ATTR0]] {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    br i1 true, label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
+; CHECK-NEXT:    [[TMP0:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP1:%.*]] = mul i64 [[TMP0]], 8
+; CHECK-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 0, [[TMP1]]
+; CHECK-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
 ; CHECK:       vector.ph:
+; CHECK-NEXT:    [[TMP2:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP3:%.*]] = mul i64 [[TMP2]], 8
+; CHECK-NEXT:    [[N_MOD_VF:%.*]] = urem i64 0, [[TMP3]]
+; CHECK-NEXT:    [[N_VEC:%.*]] = sub i64 0, [[N_MOD_VF]]
+; CHECK-NEXT:    [[TMP4:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP5:%.*]] = mul i64 [[TMP4]], 8
+; CHECK-NEXT:    [[TMP6:%.*]] = call i32 @llvm.vscale.i32()
+; CHECK-NEXT:    [[TMP7:%.*]] = mul i32 [[TMP6]], 8
+; CHECK-NEXT:    [[TMP8:%.*]] = sub i32 [[TMP7]], 1
+; CHECK-NEXT:    [[VECTOR_RECUR_INIT:%.*]] = insertelement <vscale x 8 x i32> poison, i32 0, i32 [[TMP8]]
 ; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
 ; CHECK:       vector.body:
 ; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
-; CHECK-NEXT:    [[VECTOR_RECUR:%.*]] = phi <16 x i32> [ <i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 0>, [[VECTOR_PH]] ], [ [[TMP8:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[VECTOR_RECUR:%.*]] = phi <vscale x 8 x i32> [ [[VECTOR_RECUR_INIT]], [[VECTOR_PH]] ], [ [[TMP17:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[TMP9:%.*]] = add i64 [[INDEX]], 0
 ; CHECK-NEXT:    [[TMP10:%.*]] = getelementptr i8, ptr [[A]], i64 [[TMP9]]
 ; CHECK-NEXT:    [[TMP11:%.*]] = getelementptr i8, ptr [[TMP10]], i32 0
-; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[TMP11]], align 1
-; CHECK-NEXT:    [[TMP3:%.*]] = zext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
+; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <vscale x 8 x i8>, ptr [[TMP11]], align 1
+; CHECK-NEXT:    [[TMP12:%.*]] = zext <vscale x 8 x i8> [[WIDE_LOAD]] to <vscale x 8 x i32>
 ; CHECK-NEXT:    [[TMP13:%.*]] = getelementptr i8, ptr [[B]], i64 [[TMP9]]
 ; CHECK-NEXT:    [[TMP14:%.*]] = getelementptr i8, ptr [[TMP13]], i32 0
-; CHECK-NEXT:    [[WIDE_LOAD1:%.*]] = load <16 x i8>, ptr [[TMP14]], align 1
-; CHECK-NEXT:    [[TMP6:%.*]] = zext <16 x i8> [[WIDE_LOAD1]] to <16 x i32>
-; CHECK-NEXT:    [[TMP7:%.*]] = mul <16 x i32> [[TMP6]], [[TMP3]]
-; CHECK-NEXT:    [[TMP8]] = add <16 x i32> [[TMP7]], [[TMP6]]
-; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 16
-; CHECK-NEXT:    [[TMP15:%.*]] = icmp eq i64 [[INDEX_NEXT]], 0
-; CHECK-NEXT:    br i1 [[TMP15]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP8:![0-9]+]]
+; CHECK-NEXT:    [[WIDE_LOAD1:%.*]] = load <vscale x 8 x i8>, ptr [[TMP14]], align 1
+; CHECK-NEXT:    [[TMP15:%.*]] = zext <vscale x 8 x i8> [[WIDE_LOAD1]] to <vscale x 8 x i32>
+; CHECK-NEXT:    [[TMP16:%.*]] = mul <vscale x 8 x i32> [[TMP15]], [[TMP12]]
+; CHECK-NEXT:    [[TMP17]] = add <vscale x 8 x i32> [[TMP16]], [[TMP15]]
+; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP5]]
+; CHECK-NEXT:    [[TMP18:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
+; CHECK-NEXT:    br i1 [[TMP18]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP8:![0-9]+]]
 ; CHECK:       middle.block:
-; CHECK-NEXT:    [[TMP16:%.*]] = extractelement <16 x i32> [[TMP8]], i32 15
-; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT:%.*]] = extractelement <16 x i32> [[TMP8]], i32 15
-; CHECK-NEXT:    br i1 true, label [[FOR_COND_CLEANUP_LOOPEXIT:%.*]], label [[SCALAR_PH]]
+; CHECK-NEXT:    [[TMP19:%.*]] = call i32 @llvm.vscale.i32()
+; CHECK-NEXT:    [[TMP20:%.*]] = mul i32 [[TMP19]], 8
+; CHECK-NEXT:    [[TMP21:%.*]] = sub i32 [[TMP20]], 1
+; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT:%.*]] = extractelement <vscale x 8 x i32> [[TMP17]], i32 [[TMP21]]
+; CHECK-NEXT:    [[TMP22:%.*]] = call i32 @llvm.vscale.i32()
+; CHECK-NEXT:    [[TMP23:%.*]] = mul i32 [[TMP22]], 8
+; CHECK-NEXT:    [[TMP24:%.*]] = sub i32 [[TMP23]], 1
+; CHECK-NEXT:    [[TMP25:%.*]] = extractelement <vscale x 8 x i32> [[TMP17]], i32 [[TMP24]]
+; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i64 0, [[N_VEC]]
+; CHECK-NEXT:    br i1 [[CMP_N]], label [[FOR_COND_CLEANUP_LOOPEXIT:%.*]], label [[SCALAR_PH]]
 ;
 entry:
   br label %for.body

>From 68c193a76ce01b96fe438357395a772d0b6ec4a0 Mon Sep 17 00:00:00 2001
From: Sam Tebbs <samuel.tebbs at arm.com>
Date: Mon, 16 Sep 2024 15:54:58 +0100
Subject: [PATCH 45/54] Remove one-use restriction

---
 .../Transforms/Vectorize/LoopVectorize.cpp    |  39 +++-
 .../AArch64/partial-reduce-dot-product.ll     | 193 ++++++++++++++++++
 2 files changed, 228 insertions(+), 4 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 653f4b926edb0e..2a0dfb14d62eab 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -1591,15 +1591,15 @@ class LoopVectorizationCostModel {
 
     // The binary operator can be commutative
     if (match(Instr, m_BinOp(m_OneUse(m_BinOp(
-                                 m_OneUse(m_ZExtOrSExt(m_OneUse(m_Value(A)))),
-                                 m_OneUse(m_ZExtOrSExt(m_OneUse(m_Value(B)))))),
+                                 m_ZExtOrSExt(m_Value(A)),
+                                 m_ZExtOrSExt(m_Value(B)))),
                              m_Value(ExpectedPhi))))
       BinOpIdx = 0;
     else if (match(Instr,
                    m_BinOp(m_Value(ExpectedPhi),
                            m_OneUse(m_BinOp(
-                               m_OneUse(m_ZExtOrSExt(m_OneUse(m_Value(A)))),
-                               m_OneUse(m_ZExtOrSExt(m_OneUse(m_Value(B)))))))))
+                               m_ZExtOrSExt(m_Value(A)),
+                               m_ZExtOrSExt(m_Value(B)))))))
       BinOpIdx = 1;
     else
       return;
@@ -7153,6 +7153,37 @@ void LoopVectorizationPlanner::plan(ElementCount UserVF, unsigned UserIC) {
     CM.addPartialReductionIfSupported(ReductionExitInstr, UserVF);
   }
 
+  // Wider-than-legal vector types (coming from extends in partial reductions) should only be used by partial reductions so that they are lowered properly
+
+  // Build up a set of partial reduction bin ops for efficient use checking
+  SmallSet<Instruction *, 4> PartialReductionBinOps;
+  for (auto It : CM.getPartialReductionChains()) {
+    if (It.second.BinOp) PartialReductionBinOps.insert(It.second.BinOp);
+  }
+
+  auto ExtendIsOnlyUsedByPartialReductions = [PartialReductionBinOps](Instruction *Extend) {
+    for (auto *Use : Extend->users()) {
+      Instruction *UseInstr = dyn_cast<Instruction>(Use);
+      if (!PartialReductionBinOps.contains(UseInstr))
+        return false;
+    }
+    return true;
+  };
+
+  // Check if each use of a chain's two extends is a partial reduction
+  SmallVector<Instruction *, 2> ChainsToRemove;
+  for (auto It : CM.getPartialReductionChains()) {
+      LoopVectorizationCostModel::PartialReductionChain Chain = It.second;
+    if (!ExtendIsOnlyUsedByPartialReductions(Chain.ExtendA))
+      ChainsToRemove.push_back(Chain.Reduction);
+    else if (!ExtendIsOnlyUsedByPartialReductions(Chain.ExtendB))
+      ChainsToRemove.push_back(Chain.Reduction);
+  }
+
+  // Remove those that have non-partial reduction users
+  for (auto *It : ChainsToRemove)
+    CM.getPartialReductionChains().erase(It);
+
   CM.collectValuesToIgnore();
   CM.collectElementTypesForWidening();
 
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product.ll b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product.ll
index ea46c0446a7505..ae3856e61fe54a 100644
--- a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product.ll
@@ -320,4 +320,197 @@ for.body:                                         ; preds = %for.body, %entry
   %exitcond.not = icmp eq i64 %iv.next, 0
   br i1 %exitcond.not, label %for.cond.cleanup.loopexit, label %for.body
 }
+
+define void @dotp_unrolled(i32 %num_out, i32 %num_in, ptr %w, ptr %scales, ptr %u, ptr %v) #0 {
+entry:
+  %cmp154 = icmp sgt i32 %num_out, 3
+  br i1 %cmp154, label %for.body.lr.ph, label %for.end98
+
+for.body.lr.ph:                                   ; preds = %entry
+  %div = sdiv i32 %num_out, 4
+  %mul = shl nsw i32 %div, 2
+  %cmp11145 = icmp sgt i32 %num_in, 0
+  %idxprom44 = sext i32 %num_in to i64
+  %0 = zext nneg i32 %mul to i64
+  br i1 %cmp11145, label %for.body.us.preheader, label %for.body.preheader
+
+for.body.preheader:                               ; preds = %for.body.lr.ph
+  br label %for.end98
+
+for.body.us.preheader:                            ; preds = %for.body.lr.ph
+  %wide.trip.count = zext nneg i32 %num_in to i64
+  br label %for.body.us
+
+for.body.us:                                      ; preds = %for.body.us.preheader, %for.cond10.for.cond.cleanup_crit_edge.us
+  %indvars.iv164 = phi i64 [ 0, %for.body.us.preheader ], [ %indvars.iv.next165, %for.cond10.for.cond.cleanup_crit_edge.us ]
+  %arrayidx.us = getelementptr inbounds ptr, ptr %w, i64 %indvars.iv164
+  %1 = load ptr, ptr %arrayidx.us, align 8
+  %2 = or disjoint i64 %indvars.iv164, 1
+  %arrayidx3.us = getelementptr inbounds ptr, ptr %w, i64 %2
+  %3 = load ptr, ptr %arrayidx3.us, align 8
+  %4 = or disjoint i64 %indvars.iv164, 2
+  %arrayidx6.us = getelementptr inbounds ptr, ptr %w, i64 %4
+  %5 = load ptr, ptr %arrayidx6.us, align 8
+  %6 = or disjoint i64 %indvars.iv164, 3
+  %arrayidx9.us = getelementptr inbounds ptr, ptr %w, i64 %6
+  %7 = load ptr, ptr %arrayidx9.us, align 8
+  %8 = call i64 @llvm.vscale.i64()
+  %9 = mul i64 %8, 16
+  %min.iters.check = icmp ult i64 %wide.trip.count, %9
+  br i1 %min.iters.check, label %scalar.ph, label %vector.ph
+
+vector.ph:                                        ; preds = %for.body.us
+  %10 = call i64 @llvm.vscale.i64()
+  %11 = mul i64 %10, 16
+  %n.mod.vf = urem i64 %wide.trip.count, %11
+  %n.vec = sub i64 %wide.trip.count, %n.mod.vf
+  %12 = call i64 @llvm.vscale.i64()
+  %13 = mul i64 %12, 16
+  br label %vector.body
+
+vector.body:                                      ; preds = %vector.body, %vector.ph
+  %index = phi i64 [ 0, %vector.ph ], [ %index.next, %vector.body ]
+  %vec.phi = phi <vscale x 4 x i32> [ zeroinitializer, %vector.ph ], [ %partial.reduce181, %vector.body ]
+  %vec.phi172 = phi <vscale x 4 x i32> [ zeroinitializer, %vector.ph ], [ %partial.reduce179, %vector.body ]
+  %vec.phi173 = phi <vscale x 4 x i32> [ zeroinitializer, %vector.ph ], [ %partial.reduce177, %vector.body ]
+  %vec.phi174 = phi <vscale x 4 x i32> [ zeroinitializer, %vector.ph ], [ %partial.reduce, %vector.body ]
+  %14 = add i64 %index, 0
+  %15 = getelementptr inbounds i8, ptr %1, i64 %14
+  %16 = getelementptr inbounds i8, ptr %15, i32 0
+  %wide.load = load <vscale x 16 x i8>, ptr %16, align 1
+  %17 = sext <vscale x 16 x i8> %wide.load to <vscale x 16 x i32>
+  %18 = getelementptr inbounds i8, ptr %u, i64 %14
+  %19 = getelementptr inbounds i8, ptr %18, i32 0
+  %wide.load175 = load <vscale x 16 x i8>, ptr %19, align 1
+  %20 = sext <vscale x 16 x i8> %wide.load175 to <vscale x 16 x i32>
+  %21 = mul nsw <vscale x 16 x i32> %20, %17
+  %partial.reduce = call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %vec.phi174, <vscale x 16 x i32> %21)
+  %22 = getelementptr inbounds i8, ptr %3, i64 %14
+  %23 = getelementptr inbounds i8, ptr %22, i32 0
+  %wide.load176 = load <vscale x 16 x i8>, ptr %23, align 1
+  %24 = sext <vscale x 16 x i8> %wide.load176 to <vscale x 16 x i32>
+  %25 = mul nsw <vscale x 16 x i32> %24, %20
+  %partial.reduce177 = call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %vec.phi173, <vscale x 16 x i32> %25)
+  %26 = getelementptr inbounds i8, ptr %5, i64 %14
+  %27 = getelementptr inbounds i8, ptr %26, i32 0
+  %wide.load178 = load <vscale x 16 x i8>, ptr %27, align 1
+  %28 = sext <vscale x 16 x i8> %wide.load178 to <vscale x 16 x i32>
+  %29 = mul nsw <vscale x 16 x i32> %28, %20
+  %partial.reduce179 = call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %vec.phi172, <vscale x 16 x i32> %29)
+  %30 = getelementptr inbounds i8, ptr %7, i64 %14
+  %31 = getelementptr inbounds i8, ptr %30, i32 0
+  %wide.load180 = load <vscale x 16 x i8>, ptr %31, align 1
+  %32 = sext <vscale x 16 x i8> %wide.load180 to <vscale x 16 x i32>
+  %33 = mul nsw <vscale x 16 x i32> %32, %20
+  %partial.reduce181 = call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %vec.phi, <vscale x 16 x i32> %33)
+  %index.next = add nuw i64 %index, %13
+  %34 = icmp eq i64 %index.next, %n.vec
+  br i1 %34, label %middle.block, label %vector.body
+
+middle.block:                                     ; preds = %vector.body
+  %35 = call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> %partial.reduce181)
+  %36 = call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> %partial.reduce179)
+  %37 = call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> %partial.reduce177)
+  %38 = call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> %partial.reduce)
+  %cmp.n = icmp eq i64 %wide.trip.count, %n.vec
+  br i1 %cmp.n, label %for.cond10.for.cond.cleanup_crit_edge.us, label %scalar.ph
+
+scalar.ph:                                        ; preds = %middle.block, %for.body.us
+  %bc.resume.val = phi i64 [ %n.vec, %middle.block ], [ 0, %for.body.us ]
+  %bc.merge.rdx = phi i32 [ %35, %middle.block ], [ 0, %for.body.us ]
+  %bc.merge.rdx182 = phi i32 [ %36, %middle.block ], [ 0, %for.body.us ]
+  %bc.merge.rdx183 = phi i32 [ %37, %middle.block ], [ 0, %for.body.us ]
+  %bc.merge.rdx184 = phi i32 [ %38, %middle.block ], [ 0, %for.body.us ]
+  br label %for.body12.us
+
+for.body12.us:                                    ; preds = %scalar.ph, %for.body12.us
+  %indvars.iv161 = phi i64 [ %bc.resume.val, %scalar.ph ], [ %indvars.iv.next162, %for.body12.us ]
+  %total3.0149.us = phi i32 [ %bc.merge.rdx, %scalar.ph ], [ %add43.us, %for.body12.us ]
+  %total2.0148.us = phi i32 [ %bc.merge.rdx182, %scalar.ph ], [ %add35.us, %for.body12.us ]
+  %total1.0147.us = phi i32 [ %bc.merge.rdx183, %scalar.ph ], [ %add27.us, %for.body12.us ]
+  %total0.0146.us = phi i32 [ %bc.merge.rdx184, %scalar.ph ], [ %add19.us, %for.body12.us ]
+  %arrayidx14.us = getelementptr inbounds i8, ptr %1, i64 %indvars.iv161
+  %39 = load i8, ptr %arrayidx14.us, align 1
+  %conv.us = sext i8 %39 to i32
+  %arrayidx16.us = getelementptr inbounds i8, ptr %u, i64 %indvars.iv161
+  %40 = load i8, ptr %arrayidx16.us, align 1
+  %conv17.us = sext i8 %40 to i32
+  %mul18.us = mul nsw i32 %conv17.us, %conv.us
+  %add19.us = add nsw i32 %mul18.us, %total0.0146.us
+  %arrayidx21.us = getelementptr inbounds i8, ptr %3, i64 %indvars.iv161
+  %41 = load i8, ptr %arrayidx21.us, align 1
+  %conv22.us = sext i8 %41 to i32
+  %mul26.us = mul nsw i32 %conv22.us, %conv17.us
+  %add27.us = add nsw i32 %mul26.us, %total1.0147.us
+  %arrayidx29.us = getelementptr inbounds i8, ptr %5, i64 %indvars.iv161
+  %42 = load i8, ptr %arrayidx29.us, align 1
+  %conv30.us = sext i8 %42 to i32
+  %mul34.us = mul nsw i32 %conv30.us, %conv17.us
+  %add35.us = add nsw i32 %mul34.us, %total2.0148.us
+  %arrayidx37.us = getelementptr inbounds i8, ptr %7, i64 %indvars.iv161
+  %43 = load i8, ptr %arrayidx37.us, align 1
+  %conv38.us = sext i8 %43 to i32
+  %mul42.us = mul nsw i32 %conv38.us, %conv17.us
+  %add43.us = add nsw i32 %mul42.us, %total3.0149.us
+  %indvars.iv.next162 = add nuw nsw i64 %indvars.iv161, 1
+  %exitcond.not = icmp eq i64 %indvars.iv.next162, %wide.trip.count
+  br i1 %exitcond.not, label %for.cond10.for.cond.cleanup_crit_edge.us, label %for.body12.us
+
+for.cond10.for.cond.cleanup_crit_edge.us:         ; preds = %middle.block, %for.body12.us
+  %add19.us.lcssa = phi i32 [ %add19.us, %for.body12.us ], [ %38, %middle.block ]
+  %add27.us.lcssa = phi i32 [ %add27.us, %for.body12.us ], [ %37, %middle.block ]
+  %add35.us.lcssa = phi i32 [ %add35.us, %for.body12.us ], [ %36, %middle.block ]
+  %add43.us.lcssa = phi i32 [ %add43.us, %for.body12.us ], [ %35, %middle.block ]
+  %arrayidx45.us = getelementptr inbounds i8, ptr %1, i64 %idxprom44
+  %44 = load i8, ptr %arrayidx45.us, align 1
+  %conv46.us = sext i8 %44 to i32
+  %mul47.us = mul nsw i32 %conv46.us, 127
+  %add48.us = add nsw i32 %mul47.us, %add19.us.lcssa
+  %conv49.us = sitofp i32 %add48.us to float
+  %arrayidx52.us = getelementptr inbounds float, ptr %scales, i64 %indvars.iv164
+  %45 = load float, ptr %arrayidx52.us, align 4
+  %mul53.us = fmul float %45, %conv49.us
+  %arrayidx56.us = getelementptr inbounds float, ptr %v, i64 %indvars.iv164
+  store float %mul53.us, ptr %arrayidx56.us, align 4
+  %arrayidx58.us = getelementptr inbounds i8, ptr %3, i64 %idxprom44
+  %46 = load i8, ptr %arrayidx58.us, align 1
+  %conv59.us = sext i8 %46 to i32
+  %mul60.us = mul nsw i32 %conv59.us, 127
+  %add61.us = add nsw i32 %mul60.us, %add27.us.lcssa
+  %conv62.us = sitofp i32 %add61.us to float
+  %arrayidx65.us = getelementptr inbounds float, ptr %scales, i64 %2
+  %47 = load float, ptr %arrayidx65.us, align 4
+  %mul66.us = fmul float %47, %conv62.us
+  %arrayidx69.us = getelementptr inbounds float, ptr %v, i64 %2
+  store float %mul66.us, ptr %arrayidx69.us, align 4
+  %arrayidx71.us = getelementptr inbounds i8, ptr %5, i64 %idxprom44
+  %48 = load i8, ptr %arrayidx71.us, align 1
+  %conv72.us = sext i8 %48 to i32
+  %mul73.us = mul nsw i32 %conv72.us, 127
+  %add74.us = add nsw i32 %mul73.us, %add35.us.lcssa
+  %conv75.us = sitofp i32 %add74.us to float
+  %arrayidx78.us = getelementptr inbounds float, ptr %scales, i64 %4
+  %49 = load float, ptr %arrayidx78.us, align 4
+  %mul79.us = fmul float %49, %conv75.us
+  %arrayidx82.us = getelementptr inbounds float, ptr %v, i64 %4
+  store float %mul79.us, ptr %arrayidx82.us, align 4
+  %arrayidx84.us = getelementptr inbounds i8, ptr %7, i64 %idxprom44
+  %50 = load i8, ptr %arrayidx84.us, align 1
+  %conv85.us = sext i8 %50 to i32
+  %mul86.us = mul nsw i32 %conv85.us, 127
+  %add87.us = add nsw i32 %mul86.us, %add43.us.lcssa
+  %conv88.us = sitofp i32 %add87.us to float
+  %arrayidx91.us = getelementptr inbounds float, ptr %scales, i64 %6
+  %51 = load float, ptr %arrayidx91.us, align 4
+  %mul92.us = fmul float %51, %conv88.us
+  %arrayidx95.us = getelementptr inbounds float, ptr %v, i64 %6
+  store float %mul92.us, ptr %arrayidx95.us, align 4
+  %indvars.iv.next165 = add nuw nsw i64 %indvars.iv164, 4
+  %cmp.us = icmp ult i64 %indvars.iv.next165, %0
+  br i1 %cmp.us, label %for.body.us, label %for.end98
+
+for.end98:                                        ; preds = %for.end98.loopexit171, %for.end98.loopexit, %entry
+  ret void
+}
+
 attributes #0 = { nofree norecurse nosync nounwind memory(argmem: readwrite) uwtable vscale_range(1,16) "target-features"="+sve" }

>From 42490f893554e97edbd3aa7360610fc08be7a6cf Mon Sep 17 00:00:00 2001
From: Sam Tebbs <samuel.tebbs at arm.com>
Date: Mon, 16 Sep 2024 16:13:28 +0100
Subject: [PATCH 46/54] Add scale factor helper function

---
 .../Transforms/Vectorize/LoopVectorize.cpp    | 22 +++++++++++--------
 1 file changed, 13 insertions(+), 9 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 2a0dfb14d62eab..8ad322e5367537 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -8679,6 +8679,18 @@ VPReplicateRecipe *VPRecipeBuilder::handleReplication(Instruction *I,
   return Recipe;
 }
 
+// Walk the users of Phi and return the ScaleFactor of the first one for which
+// CM.getInstructionsPartialReduction yields a chain. Returns 1 when no user
+// of Phi has such a chain. File-local helper, hence static.
+static unsigned getScaleFactorForReductionPhi(PHINode *Phi,
+                                              LoopVectorizationCostModel &CM) {
+  for (auto *User : Phi->users())
+    if (auto *I = dyn_cast<Instruction>(User))
+      if (auto Chain = CM.getInstructionsPartialReduction(I))
+        return Chain->ScaleFactor;
+  return 1;
+}
+
 VPRecipeBase *
 VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
                                         ArrayRef<VPValue *> Operands,
@@ -8705,15 +8717,7 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
              Phi->getIncomingValueForBlock(OrigLoop->getLoopPreheader()));
 
       // If the PHI is used by a partial reduction, set the scale factor
-      unsigned ScaleFactor = 1;
-      for (auto *User : Phi->users()) {
-        if (auto *I = dyn_cast<Instruction>(User)) {
-          if (auto Chain = CM.getInstructionsPartialReduction(I)) {
-            ScaleFactor = Chain->ScaleFactor;
-            break;
-          }
-        }
-      }
+      unsigned ScaleFactor = getScaleFactorForReductionPhi(Phi, CM);
       PhiRecipe = new VPReductionPHIRecipe(
           Phi, RdxDesc, *StartV, CM.isInLoopReduction(Phi),
           CM.useOrderedReductions(RdxDesc), ScaleFactor);

>From 0d265aa4bb9ba983c4d2c3c6147d16043f79899f Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Thu, 19 Sep 2024 14:29:34 +0100
Subject: [PATCH 47/54] Manually match nodes

---
 .../Transforms/Vectorize/LoopVectorize.cpp    | 83 ++++++++++---------
 1 file changed, 43 insertions(+), 40 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 8ad322e5367537..be8746d28a0edd 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -1582,28 +1582,45 @@ class LoopVectorizationCostModel {
   }
 
   void addPartialReductionIfSupported(Instruction *Instr, ElementCount VF) {
-    Value *ExpectedPhi;
-    Value *A, *B;
-
     using namespace llvm::PatternMatch;
 
-    unsigned BinOpIdx = 0;
-
-    // The binary operator can be commutative
-    if (match(Instr, m_BinOp(m_OneUse(m_BinOp(
-                                 m_ZExtOrSExt(m_Value(A)),
-                                 m_ZExtOrSExt(m_Value(B)))),
-                             m_Value(ExpectedPhi))))
-      BinOpIdx = 0;
-    else if (match(Instr,
-                   m_BinOp(m_Value(ExpectedPhi),
-                           m_OneUse(m_BinOp(
-                               m_ZExtOrSExt(m_Value(A)),
-                               m_ZExtOrSExt(m_Value(B)))))))
-      BinOpIdx = 1;
-    else
+    // Try to commutatively match:
+    // bin_op (one_use bin_op (z_or_sext, z_or_sext), phi)
+
+    auto *Root = dyn_cast<BinaryOperator>(Instr);
+    if (!Root) return;
+
+    auto *BinOp = dyn_cast<BinaryOperator>(Root->getOperand(0));
+    auto *Phi = dyn_cast<PHINode>(Root->getOperand(1));
+    if (!BinOp) {
+        BinOp = dyn_cast<BinaryOperator>(Root->getOperand(1));
+        Phi = dyn_cast<PHINode>(Root->getOperand(0));
+    }
+    if (!BinOp || !BinOp->hasOneUse()) {
+      LLVM_DEBUG(dbgs() << "Root was not a one-use binary operator, cannot create a "
+                           "partial reduction.\n");
       return;
+    }
+    if (!Phi) {
+      LLVM_DEBUG(dbgs() << "Expected Phi node was not a phi, cannot create a "
+                           "partial reduction.\n");
+      return;
+    }
 
+    auto IsSextOrZext = [](Instruction *I) {
+        return I && (I->getOpcode() == Instruction::ZExt || I->getOpcode() == Instruction::SExt);
+    };
+
+    auto *ExtA = dyn_cast<Instruction>(BinOp->getOperand(0));
+    auto *ExtB = dyn_cast<Instruction>(BinOp->getOperand(1));
+    if (!IsSextOrZext(ExtA) || !IsSextOrZext(ExtB)) {
+      LLVM_DEBUG(dbgs() << "Expected extends were not extends, cannot create a "
+                           "partial reduction.\n");
+        return;
+    }
+
+    Value *A = ExtA->getOperand(0);
+    Value *B = ExtB->getOperand(0);
     // Check that the extends extend from the same type
     if (A->getType() != B->getType()) {
       LLVM_DEBUG(dbgs() << "Extends don't extend from the same type, cannot "
@@ -1611,43 +1628,29 @@ class LoopVectorizationCostModel {
       return;
     }
 
-    // A and B are one-use, so the first user of each should be the respective
-    // extend
-    Instruction *Ext0 = cast<CastInst>(*A->user_begin());
-    Instruction *Ext1 = cast<CastInst>(*B->user_begin());
-
     // Check that the extends extend to the same type
-    if (Ext0->getType() != Ext1->getType()) {
+    if (ExtA->getType() != ExtB->getType()) {
       LLVM_DEBUG(
           dbgs() << "Extends don't extend to the same type, cannot create "
                     "a partial reduction.\n");
       return;
     }
 
-    // Check that the add feeds into ExpectedPhi
-    PHINode *PhiNode = dyn_cast<PHINode>(ExpectedPhi);
-    if (!PhiNode) {
-      LLVM_DEBUG(dbgs() << "Expected Phi node was not a phi, cannot create a "
-                           "partial reduction.\n");
-      return;
-    }
-
     // Check that the second phi value is the instruction we're looking at
     Instruction *MaybeAdd = dyn_cast<Instruction>(
-        PhiNode->getIncomingValueForBlock(Instr->getParent()));
+        Phi->getIncomingValueForBlock(Instr->getParent()));
     if (!MaybeAdd || MaybeAdd != Instr) {
       LLVM_DEBUG(dbgs() << "Second PHI value is not the root binop, cannot "
                            "create a partial reduction.\n");
       return;
     }
 
-    Instruction *BinOp = cast<Instruction>(Instr->getOperand(BinOpIdx));
     TTI::PartialReductionExtendKind OpAExtend =
-        TargetTransformInfo::getPartialReductionExtendKind(Ext0);
+        TargetTransformInfo::getPartialReductionExtendKind(ExtA);
     TTI::PartialReductionExtendKind OpBExtend =
-        TargetTransformInfo::getPartialReductionExtendKind(Ext1);
+        TargetTransformInfo::getPartialReductionExtendKind(ExtB);
     InstructionCost Cost = TTI.getPartialReductionCost(
-        Instr->getOpcode(), A->getType(), ExpectedPhi->getType(), VF, OpAExtend,
+        Instr->getOpcode(), A->getType(), Phi->getType(), VF, OpAExtend,
         OpBExtend, std::make_optional(BinOp->getOpcode()));
     if (Cost == InstructionCost::getInvalid())
       return;
@@ -1655,9 +1658,9 @@ class LoopVectorizationCostModel {
     PartialReductionChain Chain;
     Chain.Reduction = Instr;
     Chain.BinOp = BinOp;
-    Chain.ExtendA = Ext0;
-    Chain.ExtendB = Ext1;
-    Chain.Accumulator = ExpectedPhi;
+    Chain.ExtendA = ExtA;
+    Chain.ExtendB = ExtB;
+    Chain.Accumulator = Phi;
 
     unsigned InputSizeBits = A->getType()->getScalarSizeInBits();
     unsigned ResultSizeBits = Chain.Reduction->getType()->getScalarSizeInBits();

>From 91f05483c759803ff01cd7f49a22d64e970a45f0 Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Thu, 19 Sep 2024 14:36:44 +0100
Subject: [PATCH 48/54] Use loop latch rather than parent

---
 llvm/lib/Transforms/Vectorize/LoopVectorize.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index be8746d28a0edd..937e6330cc07d8 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -1638,7 +1638,7 @@ class LoopVectorizationCostModel {
 
     // Check that the second phi value is the instruction we're looking at
     Instruction *MaybeAdd = dyn_cast<Instruction>(
-        Phi->getIncomingValueForBlock(Instr->getParent()));
+        Phi->getIncomingValueForBlock(TheLoop->getLoopLatch()));
     if (!MaybeAdd || MaybeAdd != Instr) {
       LLVM_DEBUG(dbgs() << "Second PHI value is not the root binop, cannot "
                            "create a partial reduction.\n");

>From da54287c220b2acc536618cd5ebb707ea65d76f4 Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Thu, 19 Sep 2024 16:18:02 +0100
Subject: [PATCH 49/54] Rebase

---
 .../AArch64/partial-reduce-dot-product.ll     | 96 ++++++++++++++++++-
 1 file changed, 94 insertions(+), 2 deletions(-)

diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product.ll b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product.ll
index ae3856e61fe54a..f6af7b943d5077 100644
--- a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product.ll
@@ -218,11 +218,11 @@ define void @dotp_not_loop_carried(ptr %a, ptr %b) #0 {
 ; CHECK-NEXT:    [[TMP20:%.*]] = call i32 @llvm.vscale.i32()
 ; CHECK-NEXT:    [[TMP21:%.*]] = mul i32 [[TMP20]], 8
 ; CHECK-NEXT:    [[TMP22:%.*]] = sub i32 [[TMP21]], 1
-; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT:%.*]] = extractelement <vscale x 8 x i32> [[TMP16]], i32 [[TMP22]]
+; CHECK-NEXT:    [[TMP26:%.*]] = extractelement <vscale x 8 x i32> [[TMP18]], i32 [[TMP22]]
 ; CHECK-NEXT:    [[TMP23:%.*]] = call i32 @llvm.vscale.i32()
 ; CHECK-NEXT:    [[TMP24:%.*]] = mul i32 [[TMP23]], 8
 ; CHECK-NEXT:    [[TMP25:%.*]] = sub i32 [[TMP24]], 1
-; CHECK-NEXT:    [[TMP26:%.*]] = extractelement <vscale x 8 x i32> [[TMP18]], i32 [[TMP25]]
+; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT:%.*]] = extractelement <vscale x 8 x i32> [[TMP16]], i32 [[TMP25]]
 ; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i64 0, [[N_VEC]]
 ; CHECK-NEXT:    br i1 [[CMP_N]], label [[FOR_COND_CLEANUP_LOOPEXIT:%.*]], label [[SCALAR_PH]]
 ;
@@ -322,6 +322,98 @@ for.body:                                         ; preds = %for.body, %entry
 }
 
 define void @dotp_unrolled(i32 %num_out, i32 %num_in, ptr %w, ptr %scales, ptr %u, ptr %v) #0 {
+; CHECK-LABEL: define void @dotp_unrolled(
+; CHECK-SAME: i32 [[NUM_OUT:%.*]], i32 [[NUM_IN:%.*]], ptr [[W:%.*]], ptr [[SCALES:%.*]], ptr [[U:%.*]], ptr [[V:%.*]]) #[[ATTR0]] {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[CMP154:%.*]] = icmp sgt i32 [[NUM_OUT]], 3
+; CHECK-NEXT:    br i1 [[CMP154]], label [[FOR_BODY_LR_PH:%.*]], label [[FOR_END98:%.*]]
+; CHECK:       for.body.lr.ph:
+; CHECK-NEXT:    [[DIV:%.*]] = sdiv i32 [[NUM_OUT]], 4
+; CHECK-NEXT:    [[MUL:%.*]] = shl nsw i32 [[DIV]], 2
+; CHECK-NEXT:    [[CMP11145:%.*]] = icmp sgt i32 [[NUM_IN]], 0
+; CHECK-NEXT:    [[IDXPROM44:%.*]] = sext i32 [[NUM_IN]] to i64
+; CHECK-NEXT:    [[TMP0:%.*]] = zext nneg i32 [[MUL]] to i64
+; CHECK-NEXT:    br i1 [[CMP11145]], label [[FOR_BODY_US_PREHEADER:%.*]], label [[FOR_BODY_PREHEADER:%.*]]
+; CHECK:       for.body.preheader:
+; CHECK-NEXT:    br label [[FOR_END98]]
+; CHECK:       for.body.us.preheader:
+; CHECK-NEXT:    [[WIDE_TRIP_COUNT:%.*]] = zext nneg i32 [[NUM_IN]] to i64
+; CHECK-NEXT:    br label [[FOR_BODY_US:%.*]]
+; CHECK:       for.body.us:
+; CHECK-NEXT:    [[INDVARS_IV164:%.*]] = phi i64 [ 0, [[FOR_BODY_US_PREHEADER]] ], [ [[INDVARS_IV_NEXT165:%.*]], [[FOR_COND10_FOR_COND_CLEANUP_CRIT_EDGE_US:%.*]] ]
+; CHECK-NEXT:    [[ARRAYIDX_US:%.*]] = getelementptr inbounds ptr, ptr [[W]], i64 [[INDVARS_IV164]]
+; CHECK-NEXT:    [[TMP1:%.*]] = load ptr, ptr [[ARRAYIDX_US]], align 8
+; CHECK-NEXT:    [[TMP2:%.*]] = or disjoint i64 [[INDVARS_IV164]], 1
+; CHECK-NEXT:    [[ARRAYIDX3_US:%.*]] = getelementptr inbounds ptr, ptr [[W]], i64 [[TMP2]]
+; CHECK-NEXT:    [[TMP3:%.*]] = load ptr, ptr [[ARRAYIDX3_US]], align 8
+; CHECK-NEXT:    [[TMP4:%.*]] = or disjoint i64 [[INDVARS_IV164]], 2
+; CHECK-NEXT:    [[ARRAYIDX6_US:%.*]] = getelementptr inbounds ptr, ptr [[W]], i64 [[TMP4]]
+; CHECK-NEXT:    [[TMP5:%.*]] = load ptr, ptr [[ARRAYIDX6_US]], align 8
+; CHECK-NEXT:    [[TMP6:%.*]] = or disjoint i64 [[INDVARS_IV164]], 3
+; CHECK-NEXT:    [[ARRAYIDX9_US:%.*]] = getelementptr inbounds ptr, ptr [[W]], i64 [[TMP6]]
+; CHECK-NEXT:    [[TMP7:%.*]] = load ptr, ptr [[ARRAYIDX9_US]], align 8
+; CHECK-NEXT:    [[TMP8:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP9:%.*]] = mul i64 [[TMP8]], 16
+; CHECK-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[WIDE_TRIP_COUNT]], [[TMP9]]
+; CHECK-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
+; CHECK:       vector.ph:
+; CHECK-NEXT:    [[TMP10:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP11:%.*]] = mul i64 [[TMP10]], 16
+; CHECK-NEXT:    [[N_MOD_VF:%.*]] = urem i64 [[WIDE_TRIP_COUNT]], [[TMP11]]
+; CHECK-NEXT:    [[N_VEC:%.*]] = sub i64 [[WIDE_TRIP_COUNT]], [[N_MOD_VF]]
+; CHECK-NEXT:    [[TMP12:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP13:%.*]] = mul i64 [[TMP12]], 16
+; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
+; CHECK:       vector.body:
+; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi <vscale x 4 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[PARTIAL_REDUCE181:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[VEC_PHI172:%.*]] = phi <vscale x 4 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[PARTIAL_REDUCE179:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[VEC_PHI173:%.*]] = phi <vscale x 4 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[PARTIAL_REDUCE177:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[VEC_PHI174:%.*]] = phi <vscale x 4 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[PARTIAL_REDUCE:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[TMP14:%.*]] = add i64 [[INDEX]], 0
+; CHECK-NEXT:    [[TMP15:%.*]] = getelementptr inbounds i8, ptr [[TMP1]], i64 [[TMP14]]
+; CHECK-NEXT:    [[TMP16:%.*]] = getelementptr inbounds i8, ptr [[TMP15]], i32 0
+; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <vscale x 16 x i8>, ptr [[TMP16]], align 1
+; CHECK-NEXT:    [[TMP17:%.*]] = sext <vscale x 16 x i8> [[WIDE_LOAD]] to <vscale x 16 x i32>
+; CHECK-NEXT:    [[TMP18:%.*]] = getelementptr inbounds i8, ptr [[U]], i64 [[TMP14]]
+; CHECK-NEXT:    [[TMP19:%.*]] = getelementptr inbounds i8, ptr [[TMP18]], i32 0
+; CHECK-NEXT:    [[WIDE_LOAD175:%.*]] = load <vscale x 16 x i8>, ptr [[TMP19]], align 1
+; CHECK-NEXT:    [[TMP20:%.*]] = sext <vscale x 16 x i8> [[WIDE_LOAD175]] to <vscale x 16 x i32>
+; CHECK-NEXT:    [[TMP21:%.*]] = mul nsw <vscale x 16 x i32> [[TMP20]], [[TMP17]]
+; CHECK-NEXT:    [[PARTIAL_REDUCE]] = call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> [[VEC_PHI174]], <vscale x 16 x i32> [[TMP21]])
+; CHECK-NEXT:    [[TMP22:%.*]] = getelementptr inbounds i8, ptr [[TMP3]], i64 [[TMP14]]
+; CHECK-NEXT:    [[TMP23:%.*]] = getelementptr inbounds i8, ptr [[TMP22]], i32 0
+; CHECK-NEXT:    [[WIDE_LOAD176:%.*]] = load <vscale x 16 x i8>, ptr [[TMP23]], align 1
+; CHECK-NEXT:    [[TMP24:%.*]] = sext <vscale x 16 x i8> [[WIDE_LOAD176]] to <vscale x 16 x i32>
+; CHECK-NEXT:    [[TMP25:%.*]] = mul nsw <vscale x 16 x i32> [[TMP24]], [[TMP20]]
+; CHECK-NEXT:    [[PARTIAL_REDUCE177]] = call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> [[VEC_PHI173]], <vscale x 16 x i32> [[TMP25]])
+; CHECK-NEXT:    [[TMP26:%.*]] = getelementptr inbounds i8, ptr [[TMP5]], i64 [[TMP14]]
+; CHECK-NEXT:    [[TMP27:%.*]] = getelementptr inbounds i8, ptr [[TMP26]], i32 0
+; CHECK-NEXT:    [[WIDE_LOAD178:%.*]] = load <vscale x 16 x i8>, ptr [[TMP27]], align 1
+; CHECK-NEXT:    [[TMP28:%.*]] = sext <vscale x 16 x i8> [[WIDE_LOAD178]] to <vscale x 16 x i32>
+; CHECK-NEXT:    [[TMP29:%.*]] = mul nsw <vscale x 16 x i32> [[TMP28]], [[TMP20]]
+; CHECK-NEXT:    [[PARTIAL_REDUCE179]] = call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> [[VEC_PHI172]], <vscale x 16 x i32> [[TMP29]])
+; CHECK-NEXT:    [[TMP30:%.*]] = getelementptr inbounds i8, ptr [[TMP7]], i64 [[TMP14]]
+; CHECK-NEXT:    [[TMP31:%.*]] = getelementptr inbounds i8, ptr [[TMP30]], i32 0
+; CHECK-NEXT:    [[WIDE_LOAD180:%.*]] = load <vscale x 16 x i8>, ptr [[TMP31]], align 1
+; CHECK-NEXT:    [[TMP32:%.*]] = sext <vscale x 16 x i8> [[WIDE_LOAD180]] to <vscale x 16 x i32>
+; CHECK-NEXT:    [[TMP33:%.*]] = mul nsw <vscale x 16 x i32> [[TMP32]], [[TMP20]]
+; CHECK-NEXT:    [[PARTIAL_REDUCE181]] = call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> [[VEC_PHI]], <vscale x 16 x i32> [[TMP33]])
+; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP13]]
+; CHECK-NEXT:    [[TMP34:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
+; CHECK-NEXT:    br i1 [[TMP34]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]]
+; CHECK:       middle.block:
+; CHECK-NEXT:    [[PARTIAL_REDUCE_LCSSA:%.*]] = phi <vscale x 4 x i32> [ [[PARTIAL_REDUCE]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[PARTIAL_REDUCE177_LCSSA:%.*]] = phi <vscale x 4 x i32> [ [[PARTIAL_REDUCE177]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[PARTIAL_REDUCE179_LCSSA:%.*]] = phi <vscale x 4 x i32> [ [[PARTIAL_REDUCE179]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[PARTIAL_REDUCE181_LCSSA:%.*]] = phi <vscale x 4 x i32> [ [[PARTIAL_REDUCE181]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[TMP35:%.*]] = call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> [[PARTIAL_REDUCE181_LCSSA]])
+; CHECK-NEXT:    [[TMP36:%.*]] = call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> [[PARTIAL_REDUCE179_LCSSA]])
+; CHECK-NEXT:    [[TMP37:%.*]] = call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> [[PARTIAL_REDUCE177_LCSSA]])
+; CHECK-NEXT:    [[TMP38:%.*]] = call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> [[PARTIAL_REDUCE_LCSSA]])
+; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i64 [[WIDE_TRIP_COUNT]], [[N_VEC]]
+; CHECK-NEXT:    br i1 [[CMP_N]], label [[FOR_COND10_FOR_COND_CLEANUP_CRIT_EDGE_US]], label [[SCALAR_PH]]
+;
 entry:
   %cmp154 = icmp sgt i32 %num_out, 3
   br i1 %cmp154, label %for.body.lr.ph, label %for.end98

>From 57d3dab06032f8df44142718a99fce1719f74323 Mon Sep 17 00:00:00 2001
From: Sam Tebbs <samuel.tebbs at arm.com>
Date: Fri, 20 Sep 2024 14:23:58 +0100
Subject: [PATCH 50/54] Rebase aarch64 vplan printing test

---
 llvm/test/Transforms/LoopVectorize/AArch64/vplan-printing.ll | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/vplan-printing.ll b/llvm/test/Transforms/LoopVectorize/AArch64/vplan-printing.ll
index b1832ab159c44a..7fcb33b8584f33 100644
--- a/llvm/test/Transforms/LoopVectorize/AArch64/vplan-printing.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/vplan-printing.ll
@@ -45,12 +45,12 @@ define void @print_partial_reduction(ptr %a, ptr %b) {
 ; CHECK-NEXT: Successor(s): ir-bb<for.cond.cleanup.loopexit>, scalar.ph
 ; CHECK-EMPTY:
 ; CHECK-NEXT: ir-bb<for.cond.cleanup.loopexit>:
+; CHECK-NEXT:   IR   %add.lcssa = phi i32 [ %add, %for.body ] (extra operand: vp<%9>)
+; CHECK-NEXT:   IR   %0 = lshr i32 %add.lcssa, 0
 ; CHECK-NEXT: No successors
 ; CHECK-EMPTY:
 ; CHECK-NEXT: scalar.ph:
 ; CHECK-NEXT: No successors
-; CHECK-EMPTY:
-; CHECK-NEXT: Live-out i32 %add.lcssa = vp<%9>
 ; CHECK-NEXT: }
 ;
 entry:

>From 375c4cdf02574dc634b7d9b310efeee744cf4f9d Mon Sep 17 00:00:00 2001
From: Sam Tebbs <samuel.tebbs at arm.com>
Date: Fri, 20 Sep 2024 16:29:15 +0100
Subject: [PATCH 51/54] Format

---
 .../llvm/Analysis/TargetTransformInfoImpl.h   |  3 +-
 llvm/include/llvm/IR/Intrinsics.h             |  6 +--
 llvm/lib/IR/Function.cpp                      |  6 +--
 .../AArch64/AArch64TargetTransformInfo.h      | 11 +++--
 .../Transforms/Vectorize/LoopVectorize.cpp    | 46 +++++++++++--------
 5 files changed, 40 insertions(+), 32 deletions(-)

diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index e24aa729d1c1bd..3b0bd3c72eba1a 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -553,7 +553,8 @@ class TargetTransformInfoImplBase {
 
   InstructionCost
   getPartialReductionCost(unsigned Opcode, Type *InputType, Type *AccumType,
-                          ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
+                          ElementCount VF,
+                          TTI::PartialReductionExtendKind OpAExtend,
                           TTI::PartialReductionExtendKind OpBExtend,
                           std::optional<unsigned> BinOp = std::nullopt) const {
     return InstructionCost::getInvalid();
diff --git a/llvm/include/llvm/IR/Intrinsics.h b/llvm/include/llvm/IR/Intrinsics.h
index 4bd7fda77f3132..cd76059ab65abe 100644
--- a/llvm/include/llvm/IR/Intrinsics.h
+++ b/llvm/include/llvm/IR/Intrinsics.h
@@ -171,9 +171,9 @@ namespace Intrinsic {
     ArgKind getArgumentKind() const {
       assert(Kind == Argument || Kind == ExtendArgument ||
              Kind == TruncArgument || Kind == HalfVecArgument ||
-             Kind == SameVecWidthArgument ||
-             Kind == VecElementArgument || Kind == Subdivide2Argument ||
-             Kind == Subdivide4Argument || Kind == VecOfBitcastsToInt);
+             Kind == SameVecWidthArgument || Kind == VecElementArgument ||
+             Kind == Subdivide2Argument || Kind == Subdivide4Argument ||
+             Kind == VecOfBitcastsToInt);
       return (ArgKind)(Argument_Info & 7);
     }
 
diff --git a/llvm/lib/IR/Function.cpp b/llvm/lib/IR/Function.cpp
index 8767c2971f62c8..d30cb86525ea7f 100644
--- a/llvm/lib/IR/Function.cpp
+++ b/llvm/lib/IR/Function.cpp
@@ -1482,8 +1482,8 @@ static Type *DecodeFixedType(ArrayRef<Intrinsic::IITDescriptor> &Infos,
     return VectorType::getSubdividedVectorType(VTy, SubDivs);
   }
   case IITDescriptor::HalfVecArgument:
-    return VectorType::getHalfElementsVectorType(cast<VectorType>(
-                                                  Tys[D.getArgumentNumber()]));
+    return VectorType::getHalfElementsVectorType(
+        cast<VectorType>(Tys[D.getArgumentNumber()]));
   case IITDescriptor::SameVecWidthArgument: {
     Type *EltTy = DecodeFixedType(Infos, Tys, Context);
     Type *Ty = Tys[D.getArgumentNumber()];
@@ -1714,7 +1714,7 @@ static bool matchIntrinsicType(
         return IsDeferredCheck || DeferCheck(Ty);
       return !isa<VectorType>(ArgTys[D.getArgumentNumber()]) ||
              VectorType::getHalfElementsVectorType(
-                     cast<VectorType>(ArgTys[D.getArgumentNumber()])) != Ty;
+                 cast<VectorType>(ArgTys[D.getArgumentNumber()])) != Ty;
     case IITDescriptor::SameVecWidthArgument: {
       if (D.getArgumentNumber() >= ArgTys.size()) {
         // Defer check and subsequent check for the vector element type.
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index e3fa007791301a..bd9e4c11b9523c 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -340,11 +340,12 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
     return BaseT::isLegalNTLoad(DataType, Alignment);
   }
 
-  InstructionCost getPartialReductionCost(unsigned Opcode, Type *InputType,
-                                          Type *AccumType, ElementCount VF,
-                                          TTI::PartialReductionExtendKind OpAExtend,
-                                          TTI::PartialReductionExtendKind OpBExtend,
-                                          std::optional<unsigned> BinOp) const {
+  InstructionCost
+  getPartialReductionCost(unsigned Opcode, Type *InputType, Type *AccumType,
+                          ElementCount VF,
+                          TTI::PartialReductionExtendKind OpAExtend,
+                          TTI::PartialReductionExtendKind OpBExtend,
+                          std::optional<unsigned> BinOp) const {
     InstructionCost Invalid = InstructionCost::getInvalid();
 
     if (Opcode != Instruction::Add)
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 937e6330cc07d8..218020839b0787 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -1573,8 +1573,8 @@ class LoopVectorizationCostModel {
     return PartialReductionChains;
   }
 
-  std::optional<PartialReductionChain> getInstructionsPartialReduction(Instruction *I
- ) const {
+  std::optional<PartialReductionChain>
+  getInstructionsPartialReduction(Instruction *I) const {
     auto PairIt = PartialReductionChains.find(I);
     if (PairIt == PartialReductionChains.end())
       return std::nullopt;
@@ -1588,17 +1588,19 @@ class LoopVectorizationCostModel {
     // bin_op (one_use bin_op (z_or_sext, z_or_sext), phi)
 
     auto *Root = dyn_cast<BinaryOperator>(Instr);
-    if (!Root) return;
+    if (!Root)
+      return;
 
     auto *BinOp = dyn_cast<BinaryOperator>(Root->getOperand(0));
     auto *Phi = dyn_cast<PHINode>(Root->getOperand(1));
     if (!BinOp) {
-        BinOp = dyn_cast<BinaryOperator>(Root->getOperand(1));
-        Phi = dyn_cast<PHINode>(Root->getOperand(0));
+      BinOp = dyn_cast<BinaryOperator>(Root->getOperand(1));
+      Phi = dyn_cast<PHINode>(Root->getOperand(0));
     }
     if (!BinOp || !BinOp->hasOneUse()) {
-      LLVM_DEBUG(dbgs() << "Root was not a one-use binary operator, cannot create a "
-                           "partial reduction.\n");
+      LLVM_DEBUG(
+          dbgs() << "Root was not a one-use binary operator, cannot create a "
+                    "partial reduction.\n");
       return;
     }
     if (!Phi) {
@@ -1608,7 +1610,8 @@ class LoopVectorizationCostModel {
     }
 
     auto IsSextOrZext = [](Instruction *I) {
-        return I && (I->getOpcode() == Instruction::ZExt || I->getOpcode() == Instruction::SExt);
+      return I && (I->getOpcode() == Instruction::ZExt ||
+                   I->getOpcode() == Instruction::SExt);
     };
 
     auto *ExtA = dyn_cast<Instruction>(BinOp->getOperand(0));
@@ -1616,7 +1619,7 @@ class LoopVectorizationCostModel {
     if (!IsSextOrZext(ExtA) || !IsSextOrZext(ExtB)) {
       LLVM_DEBUG(dbgs() << "Expected extends were not extends, cannot create a "
                            "partial reduction.\n");
-        return;
+      return;
     }
 
     Value *A = ExtA->getOperand(0);
@@ -7156,27 +7159,30 @@ void LoopVectorizationPlanner::plan(ElementCount UserVF, unsigned UserIC) {
     CM.addPartialReductionIfSupported(ReductionExitInstr, UserVF);
   }
 
-  // Wider-than-legal vector types (coming from extends in partial reductions) should only be used by partial reductions so that they are lowered properly
+  // Wider-than-legal vector types (coming from extends in partial reductions)
+  // should only be used by partial reductions so that they are lowered properly
 
   // Build up a set of partial reduction bin ops for efficient use checking
   SmallSet<Instruction *, 4> PartialReductionBinOps;
   for (auto It : CM.getPartialReductionChains()) {
-    if (It.second.BinOp) PartialReductionBinOps.insert(It.second.BinOp);
+    if (It.second.BinOp)
+      PartialReductionBinOps.insert(It.second.BinOp);
   }
 
-  auto ExtendIsOnlyUsedByPartialReductions = [PartialReductionBinOps](Instruction *Extend) {
-    for (auto *Use : Extend->users()) {
-      Instruction *UseInstr = dyn_cast<Instruction>(Use);
-      if (!PartialReductionBinOps.contains(UseInstr))
-        return false;
-    }
-    return true;
-  };
+  auto ExtendIsOnlyUsedByPartialReductions =
+      [PartialReductionBinOps](Instruction *Extend) {
+        for (auto *Use : Extend->users()) {
+          Instruction *UseInstr = dyn_cast<Instruction>(Use);
+          if (!PartialReductionBinOps.contains(UseInstr))
+            return false;
+        }
+        return true;
+      };
 
   // Check if each use of a chain's two extends is a partial reduction
   SmallVector<Instruction *, 2> ChainsToRemove;
   for (auto It : CM.getPartialReductionChains()) {
-      LoopVectorizationCostModel::PartialReductionChain Chain = It.second;
+    LoopVectorizationCostModel::PartialReductionChain Chain = It.second;
     if (!ExtendIsOnlyUsedByPartialReductions(Chain.ExtendA))
       ChainsToRemove.push_back(Chain.Reduction);
     else if (!ExtendIsOnlyUsedByPartialReductions(Chain.ExtendB))

>From 9ca05776c45234a070c93f519fcb20f8802cd490 Mon Sep 17 00:00:00 2001
From: Sam Tebbs <samuel.tebbs at arm.com>
Date: Mon, 23 Sep 2024 14:22:42 +0100
Subject: [PATCH 52/54] Remove formatting changes from Function.cpp

---
 llvm/include/llvm/IR/Intrinsics.h | 6 +++---
 llvm/lib/IR/Function.cpp          | 6 +++---
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/llvm/include/llvm/IR/Intrinsics.h b/llvm/include/llvm/IR/Intrinsics.h
index cd76059ab65abe..4bd7fda77f3132 100644
--- a/llvm/include/llvm/IR/Intrinsics.h
+++ b/llvm/include/llvm/IR/Intrinsics.h
@@ -171,9 +171,9 @@ namespace Intrinsic {
     ArgKind getArgumentKind() const {
       assert(Kind == Argument || Kind == ExtendArgument ||
              Kind == TruncArgument || Kind == HalfVecArgument ||
-             Kind == SameVecWidthArgument || Kind == VecElementArgument ||
-             Kind == Subdivide2Argument || Kind == Subdivide4Argument ||
-             Kind == VecOfBitcastsToInt);
+             Kind == SameVecWidthArgument ||
+             Kind == VecElementArgument || Kind == Subdivide2Argument ||
+             Kind == Subdivide4Argument || Kind == VecOfBitcastsToInt);
       return (ArgKind)(Argument_Info & 7);
     }
 
diff --git a/llvm/lib/IR/Function.cpp b/llvm/lib/IR/Function.cpp
index d30cb86525ea7f..8767c2971f62c8 100644
--- a/llvm/lib/IR/Function.cpp
+++ b/llvm/lib/IR/Function.cpp
@@ -1482,8 +1482,8 @@ static Type *DecodeFixedType(ArrayRef<Intrinsic::IITDescriptor> &Infos,
     return VectorType::getSubdividedVectorType(VTy, SubDivs);
   }
   case IITDescriptor::HalfVecArgument:
-    return VectorType::getHalfElementsVectorType(
-        cast<VectorType>(Tys[D.getArgumentNumber()]));
+    return VectorType::getHalfElementsVectorType(cast<VectorType>(
+                                                  Tys[D.getArgumentNumber()]));
   case IITDescriptor::SameVecWidthArgument: {
     Type *EltTy = DecodeFixedType(Infos, Tys, Context);
     Type *Ty = Tys[D.getArgumentNumber()];
@@ -1714,7 +1714,7 @@ static bool matchIntrinsicType(
         return IsDeferredCheck || DeferCheck(Ty);
       return !isa<VectorType>(ArgTys[D.getArgumentNumber()]) ||
              VectorType::getHalfElementsVectorType(
-                 cast<VectorType>(ArgTys[D.getArgumentNumber()])) != Ty;
+                     cast<VectorType>(ArgTys[D.getArgumentNumber()])) != Ty;
     case IITDescriptor::SameVecWidthArgument: {
       if (D.getArgumentNumber() >= ArgTys.size()) {
         // Defer check and subsequent check for the vector element type.

>From 342419050c8f962318c3049cbfff46ac6701f737 Mon Sep 17 00:00:00 2001
From: Sam Tebbs <samuel.tebbs at arm.com>
Date: Mon, 23 Sep 2024 14:28:13 +0100
Subject: [PATCH 53/54] Fix comment

---
 llvm/lib/Transforms/Vectorize/VPlan.h | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 3ca8f7d8d305a8..082f041a3a3984 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -2140,9 +2140,10 @@ class VPReductionPHIRecipe : public VPHeaderPHIRecipe,
   bool IsOrdered;
 
   /// The scaling difference between the size of the output of the entire
-  /// reduction and the size of the inputs When expanding the reduction PHI, the
-  /// plan's VF element count is divided by this factor to form the reduction
-  /// phi's VF.
+  /// reduction and the size of the input.
+  ///
+  /// When expanding the reduction PHI, the plan's VF element count is divided
+  /// by this factor to form the reduction phi's VF.
   unsigned VFScaleFactor = 1;
 
 public:

>From 05c6caca9eba511d3b9e59d1913c41e38665de34 Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Tue, 24 Sep 2024 14:09:52 +0100
Subject: [PATCH 54/54] Check VF for scalable-ness in getPartialReductionCost

---
 llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index bd9e4c11b9523c..c6e65bf2fd6337 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -354,10 +354,9 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
     EVT InputEVT = EVT::getEVT(InputType);
     EVT AccumEVT = EVT::getEVT(AccumType);
 
-    if (AccumEVT.isScalableVector() && !ST->isSVEorStreamingSVEAvailable())
+    if (VF.isScalable() && !ST->isSVEorStreamingSVEAvailable())
       return Invalid;
-    if (AccumEVT.isFixedLengthVector() && !ST->isNeonAvailable() &&
-        !ST->hasDotProd())
+    if (VF.isFixed() && !ST->isNeonAvailable() && !ST->hasDotProd())
       return Invalid;
 
     if (InputEVT == MVT::i8) {



More information about the llvm-commits mailing list