[llvm] [VPlan] Manage FindLastIV start value in ComputeFindLastIVResult (NFC). Vplan find last iv startop (PR #132690)

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Mon Mar 24 01:02:00 PDT 2025


https://github.com/fhahn created https://github.com/llvm/llvm-project/pull/132690

Keep the start value as an operand of ComputeFindLastIVResult. A follow-up
patch will use this to make sure the start value is frozen if needed.

Depends on https://github.com/llvm/llvm-project/pull/132689 (included in this PR)

>From 3e4268330f59a26ccb8eac864b39fadeb09a84b6 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Sun, 23 Mar 2025 10:41:11 +0000
Subject: [PATCH 1/2] [VPlan] Add ComputeFindLastIVResult opcode (NFC).

This moves the logic for computing the FindLastIV reduction result to
its own opcode. A follow-up patch will update the new opcode to also
take the start value, to fix
https://github.com/llvm/llvm-project/issues/126836.
---
 .../Transforms/Vectorize/LoopVectorize.cpp    | 17 ++++++---
 llvm/lib/Transforms/Vectorize/VPlan.h         |  1 +
 .../Transforms/Vectorize/VPlanAnalysis.cpp    |  1 +
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp | 36 ++++++++++++++-----
 llvm/lib/Transforms/Vectorize/VPlanUnroll.cpp |  2 ++
 5 files changed, 44 insertions(+), 13 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 92160a421e59c..1168211e3d87b 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7612,7 +7612,8 @@ static void fixReductionScalarResumeWhenVectorizingEpilog(
     BasicBlock *BypassBlock) {
   auto *EpiRedResult = dyn_cast<VPInstruction>(R);
   if (!EpiRedResult ||
-      EpiRedResult->getOpcode() != VPInstruction::ComputeReductionResult)
+      (EpiRedResult->getOpcode() != VPInstruction::ComputeReductionResult &&
+       EpiRedResult->getOpcode() != VPInstruction::ComputeFindLastIVResult))
     return;
 
   auto *EpiRedHeaderPhi =
@@ -9817,8 +9818,10 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
           Builder.createSelect(Cond, OrigExitingVPV, PhiR, {}, "", FMFs);
       OrigExitingVPV->replaceUsesWithIf(NewExitingVPV, [](VPUser &U, unsigned) {
         return isa<VPInstruction>(&U) &&
-               cast<VPInstruction>(&U)->getOpcode() ==
-                   VPInstruction::ComputeReductionResult;
+               (cast<VPInstruction>(&U)->getOpcode() ==
+                    VPInstruction::ComputeReductionResult ||
+                cast<VPInstruction>(&U)->getOpcode() ==
+                    VPInstruction::ComputeFindLastIVResult);
       });
       if (CM.usePredicatedReductionSelect(
               PhiR->getRecurrenceDescriptor().getOpcode(), PhiTy))
@@ -9863,8 +9866,12 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
     // also modeled in VPlan.
     VPBuilder::InsertPointGuard Guard(Builder);
     Builder.setInsertPoint(MiddleVPBB, IP);
-    auto *FinalReductionResult = Builder.createNaryOp(
-        VPInstruction::ComputeReductionResult, {PhiR, NewExitingVPV}, ExitDL);
+    auto *FinalReductionResult =
+        Builder.createNaryOp(RecurrenceDescriptor::isFindLastIVRecurrenceKind(
+                                 RdxDesc.getRecurrenceKind())
+                                 ? VPInstruction::ComputeFindLastIVResult
+                                 : VPInstruction::ComputeReductionResult,
+                             {PhiR, NewExitingVPV}, ExitDL);
     // Update all users outside the vector region.
     OrigExitingVPV->replaceUsesWithIf(
         FinalReductionResult, [FinalReductionResult](VPUser &User, unsigned) {
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 3059b87ae63c8..64e7f2bddb668 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -866,6 +866,7 @@ class VPInstruction : public VPRecipeWithIRFlags,
     BranchOnCount,
     BranchOnCond,
     Broadcast,
+    ComputeFindLastIVResult,
     ComputeReductionResult,
     // Takes the VPValue to extract from as first operand and the lane or part
     // to extract as second operand, counting from the end starting with 1 for
diff --git a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
index 38bec733dbf73..d404ce46fae4a 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
@@ -66,6 +66,7 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPInstruction *R) {
                inferScalarType(R->getOperand(1)) &&
            "different types inferred for different operands");
     return IntegerType::get(Ctx, 1);
+  case VPInstruction::ComputeFindLastIVResult:
   case VPInstruction::ComputeReductionResult: {
     auto *PhiR = cast<VPReductionPHIRecipe>(R->getOperand(0));
     auto *OrigPhi = cast<PHINode>(PhiR->getUnderlyingValue());
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index c7190b3187d94..2f1182399ee4a 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -614,6 +614,27 @@ Value *VPInstruction::generate(VPTransformState &State) {
     return Builder.CreateVectorSplat(
         State.VF, State.get(getOperand(0), /*IsScalar*/ true), "broadcast");
   }
+  case VPInstruction::ComputeFindLastIVResult: {
+    // The recipe's operands are the reduction phi, followed by one operand for
+    // each part of the reduction.
+    unsigned UF = getNumOperands() - 1;
+    Value *ReducedPartRdx = State.get(getOperand(1));
+    for (unsigned Part = 1; Part < UF; ++Part) {
+      ReducedPartRdx = createMinMaxOp(Builder, RecurKind::SMax, ReducedPartRdx,
+                                      State.get(getOperand(1 + Part)));
+    }
+
+    // FIXME: The cross-recipe dependency on VPReductionPHIRecipe is temporary
+    // and will be removed by breaking up the recipe further.
+    auto *PhiR = cast<VPReductionPHIRecipe>(getOperand(0));
+    // Get its reduction variable descriptor.
+    const RecurrenceDescriptor &RdxDesc = PhiR->getRecurrenceDescriptor();
+    RecurKind RK = RdxDesc.getRecurrenceKind();
+
+    assert(RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK));
+    assert(!PhiR->isInLoop());
+    return createFindLastIVReduction(Builder, ReducedPartRdx, RdxDesc);
+  }
   case VPInstruction::ComputeReductionResult: {
     // FIXME: The cross-recipe dependency on VPReductionPHIRecipe is temporary
     // and will be removed by breaking up the recipe further.
@@ -623,6 +644,8 @@ Value *VPInstruction::generate(VPTransformState &State) {
     const RecurrenceDescriptor &RdxDesc = PhiR->getRecurrenceDescriptor();
 
     RecurKind RK = RdxDesc.getRecurrenceKind();
+    assert(!RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK) &&
+           "should be handled by ComputeFindLastIVResult");
 
     Type *PhiTy = OrigPhi->getType();
     // The recipe's operands are the reduction phi, followed by one operand for
@@ -658,9 +681,6 @@ Value *VPInstruction::generate(VPTransformState &State) {
         if (Op != Instruction::ICmp && Op != Instruction::FCmp)
           ReducedPartRdx = Builder.CreateBinOp(
               (Instruction::BinaryOps)Op, RdxPart, ReducedPartRdx, "bin.rdx");
-        else if (RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK))
-          ReducedPartRdx =
-              createMinMaxOp(Builder, RecurKind::SMax, ReducedPartRdx, RdxPart);
         else
           ReducedPartRdx = createMinMaxOp(Builder, RK, ReducedPartRdx, RdxPart);
       }
@@ -669,8 +689,7 @@ Value *VPInstruction::generate(VPTransformState &State) {
     // Create the reduction after the loop. Note that inloop reductions create
     // the target reduction in the loop using a Reduction recipe.
     if ((State.VF.isVector() ||
-         RecurrenceDescriptor::isAnyOfRecurrenceKind(RK) ||
-         RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK)) &&
+         RecurrenceDescriptor::isAnyOfRecurrenceKind(RK)) &&
         !PhiR->isInLoop()) {
       // TODO: Support in-order reductions based on the recurrence descriptor.
       // All ops in the reduction inherit fast-math-flags from the recurrence
@@ -681,9 +700,6 @@ Value *VPInstruction::generate(VPTransformState &State) {
       if (RecurrenceDescriptor::isAnyOfRecurrenceKind(RK))
         ReducedPartRdx =
             createAnyOfReduction(Builder, ReducedPartRdx, RdxDesc, OrigPhi);
-      else if (RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK))
-        ReducedPartRdx =
-            createFindLastIVReduction(Builder, ReducedPartRdx, RdxDesc);
       else
         ReducedPartRdx = createSimpleReduction(Builder, ReducedPartRdx, RK);
 
@@ -829,6 +845,7 @@ bool VPInstruction::isVectorToScalar() const {
   return getOpcode() == VPInstruction::ExtractFromEnd ||
          getOpcode() == Instruction::ExtractElement ||
          getOpcode() == VPInstruction::FirstActiveLane ||
+         getOpcode() == VPInstruction::ComputeFindLastIVResult ||
          getOpcode() == VPInstruction::ComputeReductionResult ||
          getOpcode() == VPInstruction::AnyOf;
 }
@@ -1011,6 +1028,9 @@ void VPInstruction::print(raw_ostream &O, const Twine &Indent,
   case VPInstruction::ExtractFromEnd:
     O << "extract-from-end";
     break;
+  case VPInstruction::ComputeFindLastIVResult:
+    O << "compute-find-last-iv-result";
+    break;
   case VPInstruction::ComputeReductionResult:
     O << "compute-reduction-result";
     break;
diff --git a/llvm/lib/Transforms/Vectorize/VPlanUnroll.cpp b/llvm/lib/Transforms/Vectorize/VPlanUnroll.cpp
index a36c2aeb3da5c..ad957f33ee699 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanUnroll.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanUnroll.cpp
@@ -348,6 +348,8 @@ void UnrollState::unrollBlock(VPBlockBase *VPB) {
     // the parts to compute the final reduction value.
     VPValue *Op1;
     if (match(&R, m_VPInstruction<VPInstruction::ComputeReductionResult>(
+                      m_VPValue(), m_VPValue(Op1))) ||
+        match(&R, m_VPInstruction<VPInstruction::ComputeFindLastIVResult>(
                       m_VPValue(), m_VPValue(Op1)))) {
       addUniformForAllParts(cast<VPInstruction>(&R));
       for (unsigned Part = 1; Part != UF; ++Part)

>From c5511666e4a7591c2cce347adf94333ede7edf2b Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Sun, 23 Mar 2025 10:43:13 +0000
Subject: [PATCH 2/2] [VPlan] Manage FindLastIV start value in
 ComputeFindLastIVResult (NFC).

Keep the start value as an operand of ComputeFindLastIVResult. A follow-up
patch will use this to make sure the start value is frozen if needed.
---
 llvm/include/llvm/Transforms/Utils/LoopUtils.h  |  2 +-
 llvm/lib/Transforms/Utils/LoopUtils.cpp         |  4 ++--
 llvm/lib/Transforms/Vectorize/LoopVectorize.cpp | 17 +++++++++++------
 llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp |  1 +
 llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp  | 12 +++++++-----
 llvm/lib/Transforms/Vectorize/VPlanUnroll.cpp   |  2 +-
 6 files changed, 23 insertions(+), 15 deletions(-)

diff --git a/llvm/include/llvm/Transforms/Utils/LoopUtils.h b/llvm/include/llvm/Transforms/Utils/LoopUtils.h
index 193f505fb03fe..416a0a70325d1 100644
--- a/llvm/include/llvm/Transforms/Utils/LoopUtils.h
+++ b/llvm/include/llvm/Transforms/Utils/LoopUtils.h
@@ -423,7 +423,7 @@ Value *createAnyOfReduction(IRBuilderBase &B, Value *Src,
 /// Create a reduction of the given vector \p Src for a reduction of the
 /// kind RecurKind::IFindLastIV or RecurKind::FFindLastIV. The reduction
 /// operation is described by \p Desc.
-Value *createFindLastIVReduction(IRBuilderBase &B, Value *Src,
+Value *createFindLastIVReduction(IRBuilderBase &B, Value *Src, Value *Start,
                                  const RecurrenceDescriptor &Desc);
 
 /// Create an ordered reduction intrinsic using the given recurrence
diff --git a/llvm/lib/Transforms/Utils/LoopUtils.cpp b/llvm/lib/Transforms/Utils/LoopUtils.cpp
index 2e7685254f512..f57d95e7722dc 100644
--- a/llvm/lib/Transforms/Utils/LoopUtils.cpp
+++ b/llvm/lib/Transforms/Utils/LoopUtils.cpp
@@ -1233,11 +1233,11 @@ Value *llvm::createAnyOfReduction(IRBuilderBase &Builder, Value *Src,
 }
 
 Value *llvm::createFindLastIVReduction(IRBuilderBase &Builder, Value *Src,
+                                       Value *Start,
                                        const RecurrenceDescriptor &Desc) {
   assert(RecurrenceDescriptor::isFindLastIVRecurrenceKind(
              Desc.getRecurrenceKind()) &&
          "Unexpected reduction kind");
-  Value *StartVal = Desc.getRecurrenceStartValue();
   Value *Sentinel = Desc.getSentinelValue();
   Value *MaxRdx = Src->getType()->isVectorTy()
                       ? Builder.CreateIntMaxReduce(Src, true)
@@ -1246,7 +1246,7 @@ Value *llvm::createFindLastIVReduction(IRBuilderBase &Builder, Value *Src,
   // reduction is sentinel value.
   Value *Cmp =
       Builder.CreateCmp(CmpInst::ICMP_NE, MaxRdx, Sentinel, "rdx.select.cmp");
-  return Builder.CreateSelect(Cmp, MaxRdx, StartVal, "rdx.select");
+  return Builder.CreateSelect(Cmp, MaxRdx, Start, "rdx.select");
 }
 
 Value *llvm::getReductionIdentity(Intrinsic::ID RdxID, Type *Ty,
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 1168211e3d87b..b47b444e5cfbc 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -9864,14 +9864,19 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
     // bc.merge.rdx phi nodes, hence it needs to be created unconditionally here
     // even for in-loop reductions, until the reduction resume value handling is
     // also modeled in VPlan.
+    VPInstruction *FinalReductionResult;
     VPBuilder::InsertPointGuard Guard(Builder);
     Builder.setInsertPoint(MiddleVPBB, IP);
-    auto *FinalReductionResult =
-        Builder.createNaryOp(RecurrenceDescriptor::isFindLastIVRecurrenceKind(
-                                 RdxDesc.getRecurrenceKind())
-                                 ? VPInstruction::ComputeFindLastIVResult
-                                 : VPInstruction::ComputeReductionResult,
-                             {PhiR, NewExitingVPV}, ExitDL);
+    if (RecurrenceDescriptor::isFindLastIVRecurrenceKind(
+            RdxDesc.getRecurrenceKind())) {
+      VPValue *Start = PhiR->getStartValue();
+      FinalReductionResult =
+          Builder.createNaryOp(VPInstruction::ComputeFindLastIVResult,
+                               {PhiR, Start, NewExitingVPV}, ExitDL);
+    } else {
+      FinalReductionResult = Builder.createNaryOp(
+          VPInstruction::ComputeReductionResult, {PhiR, NewExitingVPV}, ExitDL);
+    }
     // Update all users outside the vector region.
     OrigExitingVPV->replaceUsesWithIf(
         FinalReductionResult, [FinalReductionResult](VPUser &User, unsigned) {
diff --git a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
index d404ce46fae4a..24a166bd336d1 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
@@ -51,6 +51,7 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPInstruction *R) {
 
   switch (Opcode) {
   case Instruction::ExtractElement:
+  case Instruction::Freeze:
     return inferScalarType(R->getOperand(0));
   case Instruction::Select: {
     Type *ResTy = inferScalarType(R->getOperand(1));
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 2f1182399ee4a..02ff3c5dff239 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -617,11 +617,11 @@ Value *VPInstruction::generate(VPTransformState &State) {
   case VPInstruction::ComputeFindLastIVResult: {
     // The recipe's operands are the reduction phi, followed by one operand for
     // each part of the reduction.
-    unsigned UF = getNumOperands() - 1;
-    Value *ReducedPartRdx = State.get(getOperand(1));
+    unsigned UF = getNumOperands() - 2;
+    Value *ReducedPartRdx = State.get(getOperand(2));
     for (unsigned Part = 1; Part < UF; ++Part) {
       ReducedPartRdx = createMinMaxOp(Builder, RecurKind::SMax, ReducedPartRdx,
-                                      State.get(getOperand(1 + Part)));
+                                      State.get(getOperand(2 + Part)));
     }
 
     // FIXME: The cross-recipe dependency on VPReductionPHIRecipe is temporary
@@ -633,7 +633,8 @@ Value *VPInstruction::generate(VPTransformState &State) {
 
     assert(RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK));
     assert(!PhiR->isInLoop());
-    return createFindLastIVReduction(Builder, ReducedPartRdx, RdxDesc);
+    return createFindLastIVReduction(Builder, ReducedPartRdx,
+                                     State.get(getOperand(1), true), RdxDesc);
   }
   case VPInstruction::ComputeReductionResult: {
     // FIXME: The cross-recipe dependency on VPReductionPHIRecipe is temporary
@@ -950,6 +951,8 @@ bool VPInstruction::onlyFirstLaneUsed(const VPValue *Op) const {
     return true;
   case VPInstruction::PtrAdd:
     return Op == getOperand(0) || vputils::onlyFirstLaneUsed(this);
+  case VPInstruction::ComputeFindLastIVResult:
+    return Op == getOperand(1);
   };
   llvm_unreachable("switch should return");
 }
@@ -1591,7 +1594,6 @@ void VPWidenRecipe::execute(VPTransformState &State) {
   }
   case Instruction::Freeze: {
     Value *Op = State.get(getOperand(0));
-
     Value *Freeze = Builder.CreateFreeze(Op);
     State.set(this, Freeze);
     break;
diff --git a/llvm/lib/Transforms/Vectorize/VPlanUnroll.cpp b/llvm/lib/Transforms/Vectorize/VPlanUnroll.cpp
index ad957f33ee699..a513a255344cc 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanUnroll.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanUnroll.cpp
@@ -350,7 +350,7 @@ void UnrollState::unrollBlock(VPBlockBase *VPB) {
     if (match(&R, m_VPInstruction<VPInstruction::ComputeReductionResult>(
                       m_VPValue(), m_VPValue(Op1))) ||
         match(&R, m_VPInstruction<VPInstruction::ComputeFindLastIVResult>(
-                      m_VPValue(), m_VPValue(Op1)))) {
+                      m_VPValue(), m_VPValue(), m_VPValue(Op1)))) {
       addUniformForAllParts(cast<VPInstruction>(&R));
       for (unsigned Part = 1; Part != UF; ++Part)
         R.addOperand(getValueForPart(Op1, Part));



More information about the llvm-commits mailing list