[llvm] [VPlan] Manage FindLastIV start value in ComputeFindLastIVResult (NFC) (PR #132690)

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Wed Mar 26 07:38:18 PDT 2025


https://github.com/fhahn updated https://github.com/llvm/llvm-project/pull/132690

>From f26703c9bc0e62f646bf3d5bd53d539ae7741916 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Sun, 23 Mar 2025 10:43:13 +0000
Subject: [PATCH 1/2] [VPlan] Manage FindLastIV start value in
 ComputeFindLastIVResult (NFC).

Keep the start value as an operand of ComputeFindLastIVResult instead of
reading it from the recurrence descriptor. A follow-up patch will use
this to make sure the start value is frozen if needed.
---
 llvm/include/llvm/Transforms/Utils/LoopUtils.h  |  2 +-
 llvm/lib/Transforms/Utils/LoopUtils.cpp         |  4 ++--
 llvm/lib/Transforms/Vectorize/LoopVectorize.cpp | 17 +++++++++++------
 llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp |  1 +
 llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp  | 12 +++++++-----
 llvm/lib/Transforms/Vectorize/VPlanUnroll.cpp   |  2 +-
 6 files changed, 23 insertions(+), 15 deletions(-)

diff --git a/llvm/include/llvm/Transforms/Utils/LoopUtils.h b/llvm/include/llvm/Transforms/Utils/LoopUtils.h
index 193f505fb03fe..416a0a70325d1 100644
--- a/llvm/include/llvm/Transforms/Utils/LoopUtils.h
+++ b/llvm/include/llvm/Transforms/Utils/LoopUtils.h
@@ -423,7 +423,7 @@ Value *createAnyOfReduction(IRBuilderBase &B, Value *Src,
 /// Create a reduction of the given vector \p Src for a reduction of the
 /// kind RecurKind::IFindLastIV or RecurKind::FFindLastIV. The reduction
 /// operation is described by \p Desc.
-Value *createFindLastIVReduction(IRBuilderBase &B, Value *Src,
+Value *createFindLastIVReduction(IRBuilderBase &B, Value *Src, Value *Start,
                                  const RecurrenceDescriptor &Desc);
 
 /// Create an ordered reduction intrinsic using the given recurrence
diff --git a/llvm/lib/Transforms/Utils/LoopUtils.cpp b/llvm/lib/Transforms/Utils/LoopUtils.cpp
index 2e7685254f512..f57d95e7722dc 100644
--- a/llvm/lib/Transforms/Utils/LoopUtils.cpp
+++ b/llvm/lib/Transforms/Utils/LoopUtils.cpp
@@ -1233,11 +1233,11 @@ Value *llvm::createAnyOfReduction(IRBuilderBase &Builder, Value *Src,
 }
 
 Value *llvm::createFindLastIVReduction(IRBuilderBase &Builder, Value *Src,
+                                       Value *Start,
                                        const RecurrenceDescriptor &Desc) {
   assert(RecurrenceDescriptor::isFindLastIVRecurrenceKind(
              Desc.getRecurrenceKind()) &&
          "Unexpected reduction kind");
-  Value *StartVal = Desc.getRecurrenceStartValue();
   Value *Sentinel = Desc.getSentinelValue();
   Value *MaxRdx = Src->getType()->isVectorTy()
                       ? Builder.CreateIntMaxReduce(Src, true)
@@ -1246,7 +1246,7 @@ Value *llvm::createFindLastIVReduction(IRBuilderBase &Builder, Value *Src,
   // reduction is sentinel value.
   Value *Cmp =
       Builder.CreateCmp(CmpInst::ICMP_NE, MaxRdx, Sentinel, "rdx.select.cmp");
-  return Builder.CreateSelect(Cmp, MaxRdx, StartVal, "rdx.select");
+  return Builder.CreateSelect(Cmp, MaxRdx, Start, "rdx.select");
 }
 
 Value *llvm::getReductionIdentity(Intrinsic::ID RdxID, Type *Ty,
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 5244a5e7b1c41..fc3e7a4e3d10e 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -9866,14 +9866,19 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
     // bc.merge.rdx phi nodes, hence it needs to be created unconditionally here
     // even for in-loop reductions, until the reduction resume value handling is
     // also modeled in VPlan.
+    VPInstruction *FinalReductionResult;
     VPBuilder::InsertPointGuard Guard(Builder);
     Builder.setInsertPoint(MiddleVPBB, IP);
-    auto *FinalReductionResult =
-        Builder.createNaryOp(RecurrenceDescriptor::isFindLastIVRecurrenceKind(
-                                 RdxDesc.getRecurrenceKind())
-                                 ? VPInstruction::ComputeFindLastIVResult
-                                 : VPInstruction::ComputeReductionResult,
-                             {PhiR, NewExitingVPV}, ExitDL);
+    if (RecurrenceDescriptor::isFindLastIVRecurrenceKind(
+            RdxDesc.getRecurrenceKind())) {
+      VPValue *Start = PhiR->getStartValue();
+      FinalReductionResult =
+          Builder.createNaryOp(VPInstruction::ComputeFindLastIVResult,
+                               {PhiR, Start, NewExitingVPV}, ExitDL);
+    } else {
+      FinalReductionResult = Builder.createNaryOp(
+          VPInstruction::ComputeReductionResult, {PhiR, NewExitingVPV}, ExitDL);
+    }
     // Update all users outside the vector region.
     OrigExitingVPV->replaceUsesWithIf(
         FinalReductionResult, [FinalReductionResult](VPUser &User, unsigned) {
diff --git a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
index d404ce46fae4a..24a166bd336d1 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
@@ -51,6 +51,7 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPInstruction *R) {
 
   switch (Opcode) {
   case Instruction::ExtractElement:
+  case Instruction::Freeze:
     return inferScalarType(R->getOperand(0));
   case Instruction::Select: {
     Type *ResTy = inferScalarType(R->getOperand(1));
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index efa238228f6c3..d92417c163a49 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -627,14 +627,15 @@ Value *VPInstruction::generate(VPTransformState &State) {
 
     // The recipe's operands are the reduction phi, followed by one operand for
     // each part of the reduction.
-    unsigned UF = getNumOperands() - 1;
-    Value *ReducedPartRdx = State.get(getOperand(1));
+    unsigned UF = getNumOperands() - 2;
+    Value *ReducedPartRdx = State.get(getOperand(2));
     for (unsigned Part = 1; Part < UF; ++Part) {
       ReducedPartRdx = createMinMaxOp(Builder, RecurKind::SMax, ReducedPartRdx,
-                                      State.get(getOperand(1 + Part)));
+                                      State.get(getOperand(2 + Part)));
     }
 
-    return createFindLastIVReduction(Builder, ReducedPartRdx, RdxDesc);
+    return createFindLastIVReduction(Builder, ReducedPartRdx,
+                                     State.get(getOperand(1), true), RdxDesc);
   }
   case VPInstruction::ComputeReductionResult: {
     // FIXME: The cross-recipe dependency on VPReductionPHIRecipe is temporary
@@ -951,6 +952,8 @@ bool VPInstruction::onlyFirstLaneUsed(const VPValue *Op) const {
     return true;
   case VPInstruction::PtrAdd:
     return Op == getOperand(0) || vputils::onlyFirstLaneUsed(this);
+  case VPInstruction::ComputeFindLastIVResult:
+    return Op == getOperand(1);
   };
   llvm_unreachable("switch should return");
 }
@@ -1592,7 +1595,6 @@ void VPWidenRecipe::execute(VPTransformState &State) {
   }
   case Instruction::Freeze: {
     Value *Op = State.get(getOperand(0));
-
     Value *Freeze = Builder.CreateFreeze(Op);
     State.set(this, Freeze);
     break;
diff --git a/llvm/lib/Transforms/Vectorize/VPlanUnroll.cpp b/llvm/lib/Transforms/Vectorize/VPlanUnroll.cpp
index ad957f33ee699..a513a255344cc 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanUnroll.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanUnroll.cpp
@@ -350,7 +350,7 @@ void UnrollState::unrollBlock(VPBlockBase *VPB) {
     if (match(&R, m_VPInstruction<VPInstruction::ComputeReductionResult>(
                       m_VPValue(), m_VPValue(Op1))) ||
         match(&R, m_VPInstruction<VPInstruction::ComputeFindLastIVResult>(
-                      m_VPValue(), m_VPValue(Op1)))) {
+                      m_VPValue(), m_VPValue(), m_VPValue(Op1)))) {
       addUniformForAllParts(cast<VPInstruction>(&R));
       for (unsigned Part = 1; Part != UF; ++Part)
         R.addOperand(getValueForPart(Op1, Part));

>From da91908d5ff531e1ac6a75ee56b9183717c3eda6 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Sun, 23 Mar 2025 11:08:52 +0000
Subject: [PATCH 2/2] !fixup add missing ternary VPInstruction matcher, update
 test.

---
 .../Transforms/Vectorize/VPlanPatternMatch.h    | 17 +++++++++++++++++
 .../LoopVectorize/vplan-printing-reductions.ll  |  2 +-
 2 files changed, 18 insertions(+), 1 deletion(-)

diff --git a/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h b/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
index 8c11d93734667..3a727866a2875 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
@@ -216,6 +216,16 @@ using BinaryVPInstruction_match =
     BinaryRecipe_match<Op0_t, Op1_t, Opcode, /*Commutative*/ false,
                        VPInstruction>;
 
+template <typename Op0_t, typename Op1_t, typename Op2_t, unsigned Opcode,
+          bool Commutative, typename... RecipeTys>
+using TernaryRecipe_match = Recipe_match<std::tuple<Op0_t, Op1_t, Op2_t>,
+                                         Opcode, Commutative, RecipeTys...>;
+
+template <typename Op0_t, typename Op1_t, typename Op2_t, unsigned Opcode>
+using TernaryVPInstruction_match =
+    TernaryRecipe_match<Op0_t, Op1_t, Op2_t, Opcode, /*Commutative*/ false,
+                        VPInstruction>;
+
 template <typename Op0_t, typename Op1_t, unsigned Opcode,
           bool Commutative = false>
 using AllBinaryRecipe_match =
@@ -234,6 +244,13 @@ m_VPInstruction(const Op0_t &Op0, const Op1_t &Op1) {
   return BinaryVPInstruction_match<Op0_t, Op1_t, Opcode>(Op0, Op1);
 }
 
+template <unsigned Opcode, typename Op0_t, typename Op1_t, typename Op2_t>
+inline TernaryVPInstruction_match<Op0_t, Op1_t, Op2_t, Opcode>
+m_VPInstruction(const Op0_t &Op0, const Op1_t &Op1, const Op2_t &Op2) {
+  return TernaryVPInstruction_match<Op0_t, Op1_t, Op2_t, Opcode>(
+      {Op0, Op1, Op2});
+}
+
 template <typename Op0_t>
 inline UnaryVPInstruction_match<Op0_t, VPInstruction::Not>
 m_Not(const Op0_t &Op0) {
diff --git a/llvm/test/Transforms/LoopVectorize/vplan-printing-reductions.ll b/llvm/test/Transforms/LoopVectorize/vplan-printing-reductions.ll
index b357be63a49cd..11b4efb08bb2e 100644
--- a/llvm/test/Transforms/LoopVectorize/vplan-printing-reductions.ll
+++ b/llvm/test/Transforms/LoopVectorize/vplan-printing-reductions.ll
@@ -234,7 +234,7 @@ define i64 @find_last_iv(ptr %a, i64 %n, i64 %start) {
 ; CHECK-NEXT: Successor(s): middle.block
 ; CHECK-EMPTY:
 ; CHECK-NEXT: middle.block:
-; CHECK-NEXT:   EMIT vp<[[RDX_RES:%.+]]> = compute-find-last-iv-result ir<%rdx>, ir<%cond>
+; CHECK-NEXT:   EMIT vp<[[RDX_RES:%.+]]> = compute-find-last-iv-result ir<%rdx>, ir<%start>, ir<%cond>
 ; CHECK-NEXT:   EMIT vp<[[EXT:%.+]]> = extract-from-end vp<[[RDX_RES]]>, ir<1>
 ; CHECK-NEXT:   EMIT vp<%cmp.n> = icmp eq ir<%n>, vp<{{.+}}>
 ; CHECK-NEXT:   EMIT branch-on-cond vp<%cmp.n>



More information about the llvm-commits mailing list