[llvm] [LoopVectorizer] Add support for partial reductions (PR #92418)

via llvm-commits llvm-commits at lists.llvm.org
Thu May 16 09:00:13 PDT 2024


https://github.com/NickGuy-Arm created https://github.com/llvm/llvm-project/pull/92418

This patch adds support for partial reductions to the loop vectorizer; that is, reductions from a wider vector to a narrower vector.

>From 1640de07bf8facb2b5284d88481a96cd026e3613 Mon Sep 17 00:00:00 2001
From: Nicholas Guy <nicholas.guy at arm.com>
Date: Fri, 3 May 2024 16:53:22 +0100
Subject: [PATCH] [LoopVectorizer] Add support for partial reductions

---
 llvm/include/llvm/IR/DerivedTypes.h           |  10 ++
 llvm/include/llvm/IR/Intrinsics.h             |   5 +-
 llvm/include/llvm/IR/Intrinsics.td            |  10 ++
 llvm/lib/IR/Function.cpp                      |  16 +++
 .../Transforms/Vectorize/LoopVectorize.cpp    | 122 ++++++++++++++++++
 .../Transforms/Vectorize/VPRecipeBuilder.h    |   2 +
 llvm/lib/Transforms/Vectorize/VPlan.h         |  43 +++++-
 .../Transforms/Vectorize/VPlanAnalysis.cpp    |   6 +-
 llvm/lib/Transforms/Vectorize/VPlanAnalysis.h |   2 +
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp |  80 +++++++++++-
 llvm/lib/Transforms/Vectorize/VPlanValue.h    |   1 +
 .../CodeGen/AArch64/partial-reduce-sdot.ll    | 100 ++++++++++++++
 12 files changed, 387 insertions(+), 10 deletions(-)
 create mode 100644 llvm/test/CodeGen/AArch64/partial-reduce-sdot.ll

diff --git a/llvm/include/llvm/IR/DerivedTypes.h b/llvm/include/llvm/IR/DerivedTypes.h
index 443fb7de3b821..866a01c9afebd 100644
--- a/llvm/include/llvm/IR/DerivedTypes.h
+++ b/llvm/include/llvm/IR/DerivedTypes.h
@@ -512,6 +512,16 @@ class VectorType : public Type {
                            EltCnt.divideCoefficientBy(2));
   }
 
+  /// This static method returns a VectorType with a quarter as many elements
+  /// as the input type and the same element type.
+  static VectorType *getQuarterElementsVectorType(VectorType *VTy) {
+    auto EltCnt = VTy->getElementCount();
+    assert(EltCnt.isKnownMultipleOf(4) &&
+           "Cannot quarter vector unless element count is a multiple of 4.");
+    return VectorType::get(VTy->getElementType(),
+                           EltCnt.divideCoefficientBy(4));
+  }
+
   /// This static method returns a VectorType with twice as many elements as the
   /// input type and the same element type.
   static VectorType *getDoubleElementsVectorType(VectorType *VTy) {
diff --git a/llvm/include/llvm/IR/Intrinsics.h b/llvm/include/llvm/IR/Intrinsics.h
index 340c1c326d066..e03e7e0bf50de 100644
--- a/llvm/include/llvm/IR/Intrinsics.h
+++ b/llvm/include/llvm/IR/Intrinsics.h
@@ -131,6 +131,7 @@ namespace Intrinsic {
       ExtendArgument,
       TruncArgument,
       HalfVecArgument,
+      QuarterVecArgument,
       SameVecWidthArgument,
       VecOfAnyPtrsToElt,
       VecElementArgument,
@@ -160,7 +161,7 @@ namespace Intrinsic {
 
     unsigned getArgumentNumber() const {
       assert(Kind == Argument || Kind == ExtendArgument ||
-             Kind == TruncArgument || Kind == HalfVecArgument ||
+             Kind == TruncArgument || Kind == HalfVecArgument || Kind == QuarterVecArgument ||
              Kind == SameVecWidthArgument || Kind == VecElementArgument ||
              Kind == Subdivide2Argument || Kind == Subdivide4Argument ||
              Kind == VecOfBitcastsToInt);
@@ -168,7 +169,7 @@ namespace Intrinsic {
     }
     ArgKind getArgumentKind() const {
       assert(Kind == Argument || Kind == ExtendArgument ||
-             Kind == TruncArgument || Kind == HalfVecArgument ||
+             Kind == TruncArgument || Kind == HalfVecArgument || Kind == QuarterVecArgument ||
              Kind == SameVecWidthArgument ||
              Kind == VecElementArgument || Kind == Subdivide2Argument ||
              Kind == Subdivide4Argument || Kind == VecOfBitcastsToInt);
diff --git a/llvm/include/llvm/IR/Intrinsics.td b/llvm/include/llvm/IR/Intrinsics.td
index 1d20f7e1b1985..dad177e595341 100644
--- a/llvm/include/llvm/IR/Intrinsics.td
+++ b/llvm/include/llvm/IR/Intrinsics.td
@@ -321,6 +321,7 @@ def IIT_I4 : IIT_Int<4, 58>;
 def IIT_AARCH64_SVCOUNT : IIT_VT<aarch64svcount, 59>;
 def IIT_V6 : IIT_Vec<6, 60>;
 def IIT_V10 : IIT_Vec<10, 61>;
+def IIT_QUARTER_VEC_ARG : IIT_Base<62>;
 }
 
 defvar IIT_all_FixedTypes = !filter(iit, IIT_all,
@@ -457,6 +458,9 @@ class LLVMVectorElementType<int num> : LLVMMatchType<num, IIT_VEC_ELEMENT>;
 class LLVMHalfElementsVectorType<int num>
   : LLVMMatchType<num, IIT_HALF_VEC_ARG>;
 
+class LLVMQuarterElementsVectorType<int num>
+  : LLVMMatchType<num, IIT_QUARTER_VEC_ARG>;
+
 // Match the type of another intrinsic parameter that is expected to be a
 // vector type (i.e. <N x iM>) but with each element subdivided to
 // form a vector with more elements that are smaller than the original.
@@ -2605,6 +2609,12 @@ def int_experimental_vector_deinterleave2 : DefaultAttrsIntrinsic<[LLVMHalfEleme
                                                                   [llvm_anyvector_ty],
                                                                   [IntrNoMem]>;
 
+//===-------------- Intrinsics to perform partial reduction ---------------===//
+
+def int_experimental_vector_partial_reduce_add : DefaultAttrsIntrinsic<[LLVMQuarterElementsVectorType<0>],
+                                                                       [llvm_anyvector_ty],
+                                                                       [IntrNoMem]>;
+
 //===----------------- Pointer Authentication Intrinsics ------------------===//
 //
 
diff --git a/llvm/lib/IR/Function.cpp b/llvm/lib/IR/Function.cpp
index e66fe73425e86..e9eebd5e35300 100644
--- a/llvm/lib/IR/Function.cpp
+++ b/llvm/lib/IR/Function.cpp
@@ -1240,6 +1240,12 @@ static void DecodeIITType(unsigned &NextElt, ArrayRef<unsigned char> Infos,
                                              ArgInfo));
     return;
   }
+  case IIT_QUARTER_VEC_ARG: {
+    unsigned ArgInfo = (NextElt == Infos.size() ? 0 : Infos[NextElt++]);
+    OutputTable.push_back(IITDescriptor::get(IITDescriptor::QuarterVecArgument,
+                                             ArgInfo));
+    return;
+  }
   case IIT_SAME_VEC_WIDTH_ARG: {
     unsigned ArgInfo = (NextElt == Infos.size() ? 0 : Infos[NextElt++]);
     OutputTable.push_back(IITDescriptor::get(IITDescriptor::SameVecWidthArgument,
@@ -1404,6 +1410,9 @@ static Type *DecodeFixedType(ArrayRef<Intrinsic::IITDescriptor> &Infos,
   case IITDescriptor::HalfVecArgument:
     return VectorType::getHalfElementsVectorType(cast<VectorType>(
                                                   Tys[D.getArgumentNumber()]));
+  case IITDescriptor::QuarterVecArgument:  {
+    return VectorType::getQuarterElementsVectorType(cast<VectorType>(Tys[D.getArgumentNumber()]));
+  }
   case IITDescriptor::SameVecWidthArgument: {
     Type *EltTy = DecodeFixedType(Infos, Tys, Context);
     Type *Ty = Tys[D.getArgumentNumber()];
@@ -1619,6 +1628,13 @@ static bool matchIntrinsicType(
       return !isa<VectorType>(ArgTys[D.getArgumentNumber()]) ||
              VectorType::getHalfElementsVectorType(
                      cast<VectorType>(ArgTys[D.getArgumentNumber()])) != Ty;
+    case IITDescriptor::QuarterVecArgument: {
+    if (D.getArgumentNumber() >= ArgTys.size())
+        return IsDeferredCheck || DeferCheck(Ty);
+      return !isa<VectorType>(ArgTys[D.getArgumentNumber()]) ||
+             VectorType::getQuarterElementsVectorType(
+                     cast<VectorType>(ArgTys[D.getArgumentNumber()])) != Ty;
+    }
     case IITDescriptor::SameVecWidthArgument: {
       if (D.getArgumentNumber() >= ArgTys.size()) {
         // Defer check and subsequent check for the vector element type.
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 33c4decd58a6c..1f37df061bbf7 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -2203,6 +2203,92 @@ static bool useActiveLaneMaskForControlFlow(TailFoldingStyle Style) {
          Style == TailFoldingStyle::DataAndControlFlowWithoutRuntimeCheck;
 }
 
+static void getPartialReductionInstrChain(Instruction *Instr, SmallVector<Value*, 4> &Chain) {
+  Instruction *Mul = cast<Instruction>(Instr->getOperand(0));
+  Instruction *Ext0 = cast<ZExtInst>(Mul->getOperand(0));
+  Instruction *Ext1 = cast<ZExtInst>(Mul->getOperand(1));
+
+  Chain.push_back(Mul);
+  Chain.push_back(Ext0);
+  Chain.push_back(Ext1);
+  Chain.push_back(Instr->getOperand(1));
+}
+
+
+/// @param Instr The root instruction to scan
+static bool isInstrPartialReduction(Instruction *Instr) {
+  Value *ExpectedPhi;
+  Value *A, *B;
+  Value *InductionA, *InductionB;
+
+  using namespace llvm::PatternMatch;
+  auto Pattern = m_Add(
+    m_OneUse(m_Mul(
+      m_OneUse(m_ZExt(
+        m_OneUse(m_Load(
+          m_GEP(
+              m_Value(A),
+              m_Value(InductionA)))))),
+      m_OneUse(m_ZExt(
+        m_OneUse(m_Load(
+          m_GEP(
+              m_Value(B),
+              m_Value(InductionB))))))
+        )), m_Value(ExpectedPhi));
+
+  bool Matches = match(Instr, Pattern);
+
+  if(!Matches)
+    return false;
+
+  // Check that the two induction variable uses are to the same induction variable
+  if(InductionA != InductionB) {
+    LLVM_DEBUG(dbgs() << "Loop uses different induction variables for each input variable, cannot create a partial reduction.\n");
+    return false;
+  }
+
+  Instruction *Mul = cast<Instruction>(Instr->getOperand(0));
+  Instruction *Ext0 = cast<ZExtInst>(Mul->getOperand(0));
+  Instruction *Ext1 = cast<ZExtInst>(Mul->getOperand(1));
+
+  // Check that the extends extend to i32
+  if(!Ext0->getType()->isIntegerTy(32) || !Ext1->getType()->isIntegerTy(32)) {
+    LLVM_DEBUG(dbgs() << "Extends don't extend to the correct width, cannot create a partial reduction.\n");
+    return false;
+  }
+
+  // Check that the loads are loading i8
+  LoadInst *Load0 = cast<LoadInst>(Ext0->getOperand(0));
+  LoadInst *Load1 = cast<LoadInst>(Ext1->getOperand(0));
+  if(!Load0->getType()->isIntegerTy(8) || !Load1->getType()->isIntegerTy(8)) {
+    LLVM_DEBUG(dbgs() << "Loads don't load the correct width, cannot create a partial reduction\n");
+    return false;
+  }
+
+  // Check that the add feeds into ExpectedPhi
+  PHINode *PhiNode = dyn_cast<PHINode>(ExpectedPhi);
+  if(!PhiNode) {
+    LLVM_DEBUG(dbgs() << "Expected Phi node was not a phi, cannot create a partial reduction.\n");
+    return false;
+  }
+
+  // Check that the first phi value is a zero initializer
+  ConstantInt *ZeroInit = dyn_cast<ConstantInt>(PhiNode->getIncomingValue(0));
+  if(!ZeroInit || !ZeroInit->isZero()) {
+    LLVM_DEBUG(dbgs() << "First PHI value is not a constant zero, cannot create a partial reduction.\n");
+    return false;
+  }
+
+  // Check that the second phi value is the instruction we're looking at
+  Instruction *MaybeAdd = dyn_cast<Instruction>(PhiNode->getIncomingValue(1));
+  if(!MaybeAdd || MaybeAdd != Instr) {
+    LLVM_DEBUG(dbgs() << "Second PHI value is not the root add, cannot create a partial reduction.\n");
+    return false;
+  }
+
+  return true;
+}
+
 // Return true if \p OuterLp is an outer loop annotated with hints for explicit
 // vectorization. The loop needs to be annotated with #pragma omp simd
 // simdlen(#) or #pragma clang vectorize(enable) vectorize_width(#). If the
@@ -5084,6 +5170,13 @@ bool LoopVectorizationPlanner::isCandidateForEpilogueVectorization(
         return false;
   }
 
+  // Prevent epilogue vectorization if a partial reduction is involved
+  // TODO Is there a cleaner way to check this?
+  if(any_of(Legal->getReductionVars(), [&](const std::pair<PHINode *, RecurrenceDescriptor> &Reduction) {
+    return isInstrPartialReduction(Reduction.second.getLoopExitInstr());
+  }))
+    return false;
+
   // Epilogue vectorization code has not been auditted to ensure it handles
   // non-latch exits properly.  It may be fine, but it needs auditted and
   // tested.
@@ -7182,6 +7275,17 @@ void LoopVectorizationCostModel::collectValuesToIgnore() {
     const SmallVectorImpl<Instruction *> &Casts = IndDes.getCastInsts();
     VecValuesToIgnore.insert(Casts.begin(), Casts.end());
   }
+
+  // Ignore any values that we know will be flattened
+  for(auto Reduction : this->Legal->getReductionVars()) {
+    auto &Recurrence = Reduction.second;
+    if(isInstrPartialReduction(Recurrence.getLoopExitInstr())) {
+      SmallVector<Value*, 4> PartialReductionValues;
+      getPartialReductionInstrChain(Recurrence.getLoopExitInstr(), PartialReductionValues);
+      ValuesToIgnore.insert(PartialReductionValues.begin(), PartialReductionValues.end());
+      VecValuesToIgnore.insert(PartialReductionValues.begin(), PartialReductionValues.end());
+    }
+  }
 }
 
 void LoopVectorizationCostModel::collectInLoopReductions() {
@@ -8536,9 +8640,24 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
                                  *CI);
   }
 
+  if(auto *PartialReduce = tryToCreatePartialReduction(Range, Instr, Operands))
+    return PartialReduce;
+
   return tryToWiden(Instr, Operands, VPBB);
 }
 
+VPRecipeBase *VPRecipeBuilder::tryToCreatePartialReduction(
+    VFRange &Range, Instruction *Instr, ArrayRef<VPValue *> Operands) {
+
+  if(isInstrPartialReduction(Instr)) {
+    auto EC = ElementCount::getScalable(16);
+    if(std::find(Range.begin(), Range.end(), EC) == Range.end())
+      return nullptr;
+    return new VPPartialReductionRecipe(*Instr, make_range(Operands.begin(), Operands.end()));
+  }
+  return nullptr;
+}
+
 void LoopVectorizationPlanner::buildVPlansWithVPRecipes(ElementCount MinVF,
                                                         ElementCount MaxVF) {
   assert(OrigLoop->isInnermost() && "Inner loop expected.");
@@ -8746,6 +8865,9 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
         VPBB->appendRecipe(Recipe);
     }
 
+    for(auto &Recipe : *VPBB)
+      Recipe.postInsertionOp();
+
     VPBlockUtils::insertBlockAfter(new VPBasicBlock(), VPBB);
     VPBB = cast<VPBasicBlock>(VPBB->getSingleSuccessor());
   }
diff --git a/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h b/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
index b4c7ab02f928f..c439f221709e1 100644
--- a/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
+++ b/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
@@ -116,6 +116,8 @@ class VPRecipeBuilder {
                                        ArrayRef<VPValue *> Operands,
                                        VFRange &Range, VPBasicBlock *VPBB);
 
+  VPRecipeBase* tryToCreatePartialReduction(VFRange &Range, Instruction* Instr, ArrayRef<VPValue*> Operands);
+
   /// Set the recipe created for given ingredient.
   void setRecipe(Instruction *I, VPRecipeBase *R) {
     assert(!Ingredient2Recipe.contains(I) &&
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index c74329a0bcc4a..5a572ecb798d6 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -767,6 +767,8 @@ class VPRecipeBase : public ilist_node_with_parent<VPRecipeBase, VPBasicBlock>,
   /// \returns an iterator pointing to the element after the erased one
   iplist<VPRecipeBase>::iterator eraseFromParent();
 
+  virtual void postInsertionOp() {}
+
   /// Method to support type inquiry through isa, cast, and dyn_cast.
   static inline bool classof(const VPDef *D) {
     // All VPDefs are also VPRecipeBases.
@@ -1881,14 +1883,19 @@ class VPReductionPHIRecipe : public VPHeaderPHIRecipe {
   /// The phi is part of an ordered reduction. Requires IsInLoop to be true.
   bool IsOrdered;
 
+  /// The amount that the VF should be divided by during ::execute
+  unsigned VFScaleFactor = 1;
+
 public:
+
   /// Create a new VPReductionPHIRecipe for the reduction \p Phi described by \p
   /// RdxDesc.
   VPReductionPHIRecipe(PHINode *Phi, const RecurrenceDescriptor &RdxDesc,
                        VPValue &Start, bool IsInLoop = false,
-                       bool IsOrdered = false)
+                       bool IsOrdered = false, unsigned VFScaleFactor = 1)
       : VPHeaderPHIRecipe(VPDef::VPReductionPHISC, Phi, &Start),
-        RdxDesc(RdxDesc), IsInLoop(IsInLoop), IsOrdered(IsOrdered) {
+        RdxDesc(RdxDesc), IsInLoop(IsInLoop), IsOrdered(IsOrdered),
+        VFScaleFactor(VFScaleFactor) {
     assert((!IsOrdered || IsInLoop) && "IsOrdered requires IsInLoop");
   }
 
@@ -1897,7 +1904,7 @@ class VPReductionPHIRecipe : public VPHeaderPHIRecipe {
   VPReductionPHIRecipe *clone() override {
     auto *R =
         new VPReductionPHIRecipe(cast<PHINode>(getUnderlyingInstr()), RdxDesc,
-                                 *getOperand(0), IsInLoop, IsOrdered);
+                                 *getOperand(0), IsInLoop, IsOrdered, VFScaleFactor);
     R->addOperand(getBackedgeValue());
     return R;
   }
@@ -1908,6 +1915,10 @@ class VPReductionPHIRecipe : public VPHeaderPHIRecipe {
     return R->getVPDefID() == VPDef::VPReductionPHISC;
   }
 
+  void SetVFScaleFactor(unsigned ScaleFactor) {
+    VFScaleFactor = ScaleFactor;
+  }
+
   /// Generate the phi/select nodes.
   void execute(VPTransformState &State) override;
 
@@ -1928,6 +1939,32 @@ class VPReductionPHIRecipe : public VPHeaderPHIRecipe {
   bool isInLoop() const { return IsInLoop; }
 };
 
+class VPPartialReductionRecipe : public VPRecipeWithIRFlags {
+  unsigned Opcode;
+public:
+  template <typename IterT>
+  VPPartialReductionRecipe(Instruction &I,
+                           iterator_range<IterT> Operands) : VPRecipeWithIRFlags(
+    VPDef::VPPartialReductionSC, Operands, I), Opcode(I.getOpcode())
+  {}
+  ~VPPartialReductionRecipe() override = default;
+  VPPartialReductionRecipe *clone() override {
+    auto *R = new VPPartialReductionRecipe(*getUnderlyingInstr(), operands());
+    R->transferFlags(*this);
+    return R;
+  }
+  VP_CLASSOF_IMPL(VPDef::VPPartialReductionSC)
+  /// Generate the reduction in the loop
+  void execute(VPTransformState &State) override;
+  void postInsertionOp() override;
+  unsigned getOpcode() { return Opcode; }
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+  /// Print the recipe.
+  void print(raw_ostream &O, const Twine &Indent,
+             VPSlotTracker &SlotTracker) const override;
+#endif
+};
+
 /// A recipe for vectorizing a phi-node as a sequence of mask-based select
 /// instructions.
 class VPBlendRecipe : public VPSingleDefRecipe {
diff --git a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
index 5f93339083f0c..8a75668886599 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
@@ -208,6 +208,10 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPReplicateRecipe *R) {
   llvm_unreachable("Unhandled opcode");
 }
 
+Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPPartialReductionRecipe *R) {
+  return R->getUnderlyingInstr()->getType();
+}
+
 Type *VPTypeAnalysis::inferScalarType(const VPValue *V) {
   if (Type *CachedTy = CachedTypes.lookup(V))
     return CachedTy;
@@ -238,7 +242,7 @@ Type *VPTypeAnalysis::inferScalarType(const VPValue *V) {
             return inferScalarType(R->getOperand(0));
           })
           .Case<VPBlendRecipe, VPInstruction, VPWidenRecipe, VPReplicateRecipe,
-                VPWidenCallRecipe, VPWidenMemoryRecipe, VPWidenSelectRecipe>(
+                VPWidenCallRecipe, VPWidenMemoryRecipe, VPWidenSelectRecipe, VPPartialReductionRecipe>(
               [this](const auto *R) { return inferScalarTypeForRecipe(R); })
           .Case<VPInterleaveRecipe>([V](const VPInterleaveRecipe *R) {
             // TODO: Use info from interleave group.
diff --git a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h
index 7d310b1b31b6f..3bd8d24542199 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h
@@ -23,6 +23,7 @@ class VPWidenIntOrFpInductionRecipe;
 class VPWidenMemoryRecipe;
 struct VPWidenSelectRecipe;
 class VPReplicateRecipe;
+class VPPartialReductionRecipe;
 class Type;
 
 /// An analysis for type-inference for VPValues.
@@ -49,6 +50,7 @@ class VPTypeAnalysis {
   Type *inferScalarTypeForRecipe(const VPWidenMemoryRecipe *R);
   Type *inferScalarTypeForRecipe(const VPWidenSelectRecipe *R);
   Type *inferScalarTypeForRecipe(const VPReplicateRecipe *R);
+  Type *inferScalarTypeForRecipe(const VPPartialReductionRecipe *R);
 
 public:
   VPTypeAnalysis(Type *CanonicalIVTy, LLVMContext &Ctx)
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 9ec422ec002c8..9aff5dd0a7771 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -245,6 +245,76 @@ void VPRecipeBase::moveBefore(VPBasicBlock &BB,
   insertBefore(BB, I);
 }
 
+void VPPartialReductionRecipe::execute(VPTransformState &State) {
+  State.setDebugLocFrom(getDebugLoc());
+  auto &Builder = State.Builder;
+
+  switch(Opcode) {
+  case Instruction::Add: {
+
+    for (unsigned Part = 0; Part < State.UF; ++Part) {
+      Value* Mul = nullptr;
+      Value* Phi = nullptr;
+      SmallVector<Value*, 2> Ops;
+      for (VPValue *VPOp : operands()) {
+        auto *Op = State.get(VPOp, Part);
+        Ops.push_back(Op);
+        if(isa<PHINode>(Op))
+          Phi = Op;
+        else
+          Mul = Op;
+      }
+
+      assert(Phi && Mul && "Phi and Mul must be set");
+      assert(isa<ScalableVectorType>(Ops[0]->getType()) && "Type must be a scalable vector");
+
+      ScalableVectorType *FullTy = cast<ScalableVectorType>(Ops[0]->getType());
+      Type *RetTy = ScalableVectorType::get(FullTy->getScalarType(), 4);
+
+      Intrinsic::ID PartialIntrinsic = Intrinsic::not_intrinsic;
+      switch(Opcode) {
+      case Instruction::Add:
+        PartialIntrinsic =
+            Intrinsic::experimental_vector_partial_reduce_add;
+        break;
+      default:
+        llvm_unreachable("Opcode not handled");
+      }
+
+      assert(PartialIntrinsic != Intrinsic::not_intrinsic);
+
+      Value *V = Builder.CreateIntrinsic(RetTy, PartialIntrinsic, Mul, nullptr, Twine("partial.reduce"));
+      V = Builder.CreateNAryOp(Opcode, {V, Phi});
+      if (auto *VecOp = dyn_cast<Instruction>(V))
+        setFlags(VecOp);
+
+      // Use this vector value for all users of the original instruction.
+      State.set(this, V, Part);
+      State.addMetadata(V, dyn_cast_or_null<Instruction>(getUnderlyingValue()));
+    }
+    break;
+  }
+  default:
+    LLVM_DEBUG(dbgs() << "LV: Found an unhandled opcode : " << Instruction::getOpcodeName(Opcode));
+    llvm_unreachable("Unhandled instruction!");
+  }
+}
+
+void VPPartialReductionRecipe::postInsertionOp() { // Reduction PHI uses VF/4.
+  cast<VPReductionPHIRecipe>(this->getOperand(1))->SetVFScaleFactor(4);
+}
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+void VPPartialReductionRecipe::print(raw_ostream &O, const Twine &Indent,
+  VPSlotTracker &SlotTracker) const {
+  O << Indent << "PARTIAL-REDUCE ";
+  printAsOperand(O, SlotTracker);
+  O << " = " << Instruction::getOpcodeName(Opcode);
+  printFlags(O);
+  printOperands(O, SlotTracker);
+}
+#endif
+
 FastMathFlags VPRecipeWithIRFlags::getFastMathFlags() const {
   assert(OpType == OperationType::FPMathOp &&
          "recipe doesn't have fast math flags");
@@ -1953,6 +2023,8 @@ void VPFirstOrderRecurrencePHIRecipe::print(raw_ostream &O, const Twine &Indent,
 void VPReductionPHIRecipe::execute(VPTransformState &State) {
   auto &Builder = State.Builder;
 
+  auto VF = State.VF.divideCoefficientBy(VFScaleFactor);
+
   // Reductions do not have to start at zero. They can start with
   // any loop invariant values.
   VPValue *StartVPV = getStartValue();
@@ -1962,9 +2034,9 @@ void VPReductionPHIRecipe::execute(VPTransformState &State) {
   // Phi nodes have cycles, so we need to vectorize them in two stages. This is
   // stage #1: We create a new vector PHI node with no incoming edges. We'll use
   // this value when we vectorize all of the instructions that use the PHI.
-  bool ScalarPHI = State.VF.isScalar() || IsInLoop;
+  bool ScalarPHI = VF.isScalar() || IsInLoop;
   Type *VecTy = ScalarPHI ? StartV->getType()
-                          : VectorType::get(StartV->getType(), State.VF);
+                          : VectorType::get(StartV->getType(), VF);
 
   BasicBlock *HeaderBB = State.CFG.PrevBB;
   assert(State.CurrentVectorLoop->getHeader() == HeaderBB &&
@@ -1989,14 +2061,14 @@ void VPReductionPHIRecipe::execute(VPTransformState &State) {
       IRBuilderBase::InsertPointGuard IPBuilder(Builder);
       Builder.SetInsertPoint(VectorPH->getTerminator());
       StartV = Iden =
-          Builder.CreateVectorSplat(State.VF, StartV, "minmax.ident");
+          Builder.CreateVectorSplat(VF, StartV, "minmax.ident");
     }
   } else {
     Iden = RdxDesc.getRecurrenceIdentity(RK, VecTy->getScalarType(),
                                          RdxDesc.getFastMathFlags());
 
     if (!ScalarPHI) {
-      Iden = Builder.CreateVectorSplat(State.VF, Iden);
+      Iden = Builder.CreateVectorSplat(VF, Iden);
       IRBuilderBase::InsertPointGuard IPBuilder(Builder);
       Builder.SetInsertPoint(VectorPH->getTerminator());
       Constant *Zero = Builder.getInt32(0);
diff --git a/llvm/lib/Transforms/Vectorize/VPlanValue.h b/llvm/lib/Transforms/Vectorize/VPlanValue.h
index 96d04271850f7..c971c9a48364e 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanValue.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanValue.h
@@ -348,6 +348,7 @@ class VPDef {
     VPInstructionSC,
     VPInterleaveSC,
     VPReductionSC,
+    VPPartialReductionSC,
     VPReplicateSC,
     VPScalarCastSC,
     VPScalarIVStepsSC,
diff --git a/llvm/test/CodeGen/AArch64/partial-reduce-sdot.ll b/llvm/test/CodeGen/AArch64/partial-reduce-sdot.ll
new file mode 100644
index 0000000000000..2931c43261b28
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/partial-reduce-sdot.ll
@@ -0,0 +1,100 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
+; RUN: opt -passes=loop-vectorize -force-vector-interleave=1 -S < %s | FileCheck %s
+
+target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128"
+target triple = "aarch64-none-unknown-elf"
+
+define void @dotp(ptr %a, ptr %b) #0 {
+; CHECK-LABEL: define void @dotp(
+; CHECK-SAME: ptr [[A:%.*]], ptr [[B:%.*]]) #[[ATTR0:[0-9]+]] {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP1:%.*]] = mul i64 [[TMP0]], 16
+; CHECK-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 0, [[TMP1]]
+; CHECK-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
+; CHECK:       vector.ph:
+; CHECK-NEXT:    [[TMP2:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP3:%.*]] = mul i64 [[TMP2]], 16
+; CHECK-NEXT:    [[N_MOD_VF:%.*]] = urem i64 0, [[TMP3]]
+; CHECK-NEXT:    [[N_VEC:%.*]] = sub i64 0, [[N_MOD_VF]]
+; CHECK-NEXT:    [[TMP4:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP5:%.*]] = mul i64 [[TMP4]], 16
+; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
+; CHECK:       vector.body:
+; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[VEC_PHI1:%.*]] = phi <vscale x 4 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP31:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[TMP11:%.*]] = add i64 [[INDEX]], 0
+; CHECK-NEXT:    [[TMP13:%.*]] = getelementptr i8, ptr [[A]], i64 [[TMP11]]
+; CHECK-NEXT:    [[TMP17:%.*]] = getelementptr i8, ptr [[TMP13]], i32 0
+; CHECK-NEXT:    [[WIDE_LOAD2:%.*]] = load <vscale x 16 x i8>, ptr [[TMP17]], align 1
+; CHECK-NEXT:    [[TMP19:%.*]] = zext <vscale x 16 x i8> [[WIDE_LOAD2]] to <vscale x 16 x i32>
+; CHECK-NEXT:    [[TMP21:%.*]] = getelementptr i8, ptr [[B]], i64 [[TMP11]]
+; CHECK-NEXT:    [[TMP25:%.*]] = getelementptr i8, ptr [[TMP21]], i32 0
+; CHECK-NEXT:    [[WIDE_LOAD4:%.*]] = load <vscale x 16 x i8>, ptr [[TMP25]], align 1
+; CHECK-NEXT:    [[TMP27:%.*]] = zext <vscale x 16 x i8> [[WIDE_LOAD4]] to <vscale x 16 x i32>
+; CHECK-NEXT:    [[TMP29:%.*]] = mul <vscale x 16 x i32> [[TMP27]], [[TMP19]]
+; CHECK-NEXT:    [[PARTIAL_REDUCE5:%.*]] = call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv16i32(<vscale x 16 x i32> [[TMP29]])
+; CHECK-NEXT:    [[TMP31]] = add <vscale x 4 x i32> [[PARTIAL_REDUCE5]], [[VEC_PHI1]]
+; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP5]]
+; CHECK-NEXT:    [[TMP32:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
+; CHECK-NEXT:    br i1 [[TMP32]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
+; CHECK:       middle.block:
+; CHECK-NEXT:    [[TMP33:%.*]] = call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> [[TMP31]])
+; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i64 0, [[N_VEC]]
+; CHECK-NEXT:    br i1 [[CMP_N]], label [[FOR_COND_CLEANUP_LOOPEXIT:%.*]], label [[SCALAR_PH]]
+; CHECK:       scalar.ph:
+; CHECK-NEXT:    [[BC_RESUME_VAL:%.*]] = phi i64 [ [[N_VEC]], [[MIDDLE_BLOCK]] ], [ 0, [[ENTRY:%.*]] ]
+; CHECK-NEXT:    [[BC_MERGE_RDX:%.*]] = phi i32 [ 0, [[ENTRY]] ], [ [[TMP33]], [[MIDDLE_BLOCK]] ]
+; CHECK-NEXT:    br label [[FOR_BODY:%.*]]
+; CHECK:       for.cond.cleanup.loopexit:
+; CHECK-NEXT:    [[ADD_LCSSA:%.*]] = phi i32 [ [[ADD:%.*]], [[FOR_BODY]] ], [ [[TMP33]], [[MIDDLE_BLOCK]] ]
+; CHECK-NEXT:    [[TMP20:%.*]] = lshr i32 [[ADD_LCSSA]], 0
+; CHECK-NEXT:    ret void
+; CHECK:       for.body:
+; CHECK-NEXT:    [[INDVARS_IV:%.*]] = phi i64 [ [[BC_RESUME_VAL]], [[SCALAR_PH]] ], [ [[INDVARS_IV_NEXT:%.*]], [[FOR_BODY]] ]
+; CHECK-NEXT:    [[ACC_010:%.*]] = phi i32 [ [[BC_MERGE_RDX]], [[SCALAR_PH]] ], [ [[ADD]], [[FOR_BODY]] ]
+; CHECK-NEXT:    [[ARRAYIDX:%.*]] = getelementptr i8, ptr [[A]], i64 [[INDVARS_IV]]
+; CHECK-NEXT:    [[TMP18:%.*]] = load i8, ptr [[ARRAYIDX]], align 1
+; CHECK-NEXT:    [[CONV:%.*]] = zext i8 [[TMP18]] to i32
+; CHECK-NEXT:    [[ARRAYIDX2:%.*]] = getelementptr i8, ptr [[B]], i64 [[INDVARS_IV]]
+; CHECK-NEXT:    [[TMP22:%.*]] = load i8, ptr [[ARRAYIDX2]], align 1
+; CHECK-NEXT:    [[CONV3:%.*]] = zext i8 [[TMP22]] to i32
+; CHECK-NEXT:    [[MUL:%.*]] = mul i32 [[CONV3]], [[CONV]]
+; CHECK-NEXT:    [[ADD]] = add i32 [[MUL]], [[ACC_010]]
+; CHECK-NEXT:    [[INDVARS_IV_NEXT]] = add i64 [[INDVARS_IV]], 1
+; CHECK-NEXT:    [[EXITCOND_NOT:%.*]] = icmp eq i64 [[INDVARS_IV_NEXT]], 0
+; CHECK-NEXT:    br i1 [[EXITCOND_NOT]], label [[FOR_COND_CLEANUP_LOOPEXIT]], label [[FOR_BODY]], !llvm.loop [[LOOP3:![0-9]+]]
+;
+entry:
+  br label %for.body
+
+for.cond.cleanup.loopexit:                        ; preds = %for.body
+  %0 = lshr i32 %add, 0
+  ret void
+
+for.body:                                         ; preds = %for.body, %entry
+  %indvars.iv = phi i64 [ 0, %entry ], [ %indvars.iv.next, %for.body ]
+  %acc.010 = phi i32 [ 0, %entry ], [ %add, %for.body ]
+  %arrayidx = getelementptr i8, ptr %a, i64 %indvars.iv
+  %1 = load i8, ptr %arrayidx, align 1
+  %conv = zext i8 %1 to i32
+  %arrayidx2 = getelementptr i8, ptr %b, i64 %indvars.iv
+  %2 = load i8, ptr %arrayidx2, align 1
+  %conv3 = zext i8 %2 to i32
+  %mul = mul i32 %conv3, %conv
+  %add = add i32 %mul, %acc.010
+  %indvars.iv.next = add i64 %indvars.iv, 1
+  %exitcond.not = icmp eq i64 %indvars.iv.next, 0
+  br i1 %exitcond.not, label %for.cond.cleanup.loopexit, label %for.body
+
+; uselistorder directives
+  uselistorder i32 %add, { 1, 0 }
+}
+
+attributes #0 = { "target-features"="+fp-armv8,+fullfp16,+neon,+sve,+sve2,+v8a" }
+;.
+; CHECK: [[LOOP0]] = distinct !{[[LOOP0]], [[META1:![0-9]+]], [[META2:![0-9]+]]}
+; CHECK: [[META1]] = !{!"llvm.loop.isvectorized", i32 1}
+; CHECK: [[META2]] = !{!"llvm.loop.unroll.runtime.disable"}
+; CHECK: [[LOOP3]] = distinct !{[[LOOP3]], [[META2]], [[META1]]}
+;.



More information about the llvm-commits mailing list