[llvm] [VPlan] Add ComputeAnyOfResult VPInstruction (NFC) (PR #141932)

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Sun Jun 1 14:16:52 PDT 2025


https://github.com/fhahn updated https://github.com/llvm/llvm-project/pull/141932

>From b709387276d5b994484ea9b5c4fe9343c3de859f Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Wed, 28 May 2025 13:59:24 +0100
Subject: [PATCH] [VPlan] Add start VPV to compute-reduction-result.

---
 .../Transforms/Vectorize/LoopVectorize.cpp    | 56 ++++++++++++++-----
 llvm/lib/Transforms/Vectorize/VPlan.h         |  1 +
 .../Transforms/Vectorize/VPlanAnalysis.cpp    |  2 +
 .../Transforms/Vectorize/VPlanPatternMatch.h  |  6 ++
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp | 30 +++++++---
 llvm/lib/Transforms/Vectorize/VPlanUnroll.cpp |  4 +-
 6 files changed, 76 insertions(+), 23 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index e9ace195684b3..b8b0e13e1ae61 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7485,6 +7485,13 @@ static void addRuntimeUnrollDisableMetaData(Loop *L) {
   }
 }
 
+static Value *getStartValueFromReductionResult(VPInstruction *RdxResult) { // Returns the IR value behind operand 1 of \p RdxResult, looking through a freeze.
+  using namespace VPlanPatternMatch;
+  VPValue *StartVPV = RdxResult->getOperand(1); // Operand 1 is the start value for results built as {PhiR, Start, NewExitingVPV}.
+  match(StartVPV, m_Freeze(m_VPValue(StartVPV))); // Match result is ignored; on a freeze, StartVPV is rebound to the frozen operand (presumably unchanged otherwise -- TODO confirm).
+  return StartVPV->getLiveInIRValue(); // NOTE(review): assumes the (unwrapped) start value is a live-in -- confirm for all callers.
+}
+
 // If \p R is a ComputeReductionResult when vectorizing the epilog loop,
 // fix the reduction's scalar PHI node by adding the incoming value from the
 // main vector loop.
@@ -7493,7 +7500,8 @@ static void fixReductionScalarResumeWhenVectorizingEpilog(
     BasicBlock *BypassBlock) {
   auto *EpiRedResult = dyn_cast<VPInstruction>(R);
   if (!EpiRedResult ||
-      (EpiRedResult->getOpcode() != VPInstruction::ComputeReductionResult &&
+      (EpiRedResult->getOpcode() != VPInstruction::ComputeAnyOfResult &&
+       EpiRedResult->getOpcode() != VPInstruction::ComputeReductionResult &&
        EpiRedResult->getOpcode() != VPInstruction::ComputeFindLastIVResult))
     return;
 
@@ -7505,15 +7513,19 @@ static void fixReductionScalarResumeWhenVectorizingEpilog(
       EpiRedHeaderPhi->getStartValue()->getUnderlyingValue();
   if (RecurrenceDescriptor::isAnyOfRecurrenceKind(
           RdxDesc.getRecurrenceKind())) {
+    Value *StartV = EpiRedResult->getOperand(1)->getLiveInIRValue();
+    (void)StartV;
     auto *Cmp = cast<ICmpInst>(MainResumeValue);
     assert(Cmp->getPredicate() == CmpInst::ICMP_NE &&
            "AnyOf expected to start with ICMP_NE");
-    assert(Cmp->getOperand(1) == RdxDesc.getRecurrenceStartValue() &&
+    assert(Cmp->getOperand(1) == StartV &&
            "AnyOf expected to start by comparing main resume value to original "
            "start value");
     MainResumeValue = Cmp->getOperand(0);
   } else if (RecurrenceDescriptor::isFindLastIVRecurrenceKind(
                  RdxDesc.getRecurrenceKind())) {
+    Value *StartV = getStartValueFromReductionResult(EpiRedResult);
+    (void)StartV;
     using namespace llvm::PatternMatch;
     Value *Cmp, *OrigResumeV, *CmpOp;
     bool IsExpectedPattern =
@@ -7522,10 +7534,7 @@ static void fixReductionScalarResumeWhenVectorizingEpilog(
                                         m_Value(OrigResumeV))) &&
         (match(Cmp, m_SpecificICmp(ICmpInst::ICMP_EQ, m_Specific(OrigResumeV),
                                    m_Value(CmpOp))) &&
-         (match(CmpOp,
-                m_Freeze(m_Specific(RdxDesc.getRecurrenceStartValue()))) ||
-          (CmpOp == RdxDesc.getRecurrenceStartValue() &&
-           isGuaranteedNotToBeUndefOrPoison(CmpOp))));
+         ((CmpOp == StartV && isGuaranteedNotToBeUndefOrPoison(CmpOp))));
     assert(IsExpectedPattern && "Unexpected reduction resume pattern");
     (void)IsExpectedPattern;
     MainResumeValue = OrigResumeV;
@@ -9460,7 +9469,10 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
       OrigExitingVPV->replaceUsesWithIf(NewExitingVPV, [](VPUser &U, unsigned) {
         return isa<VPInstruction>(&U) &&
                (cast<VPInstruction>(&U)->getOpcode() ==
+                    VPInstruction::ComputeAnyOfResult ||
+                cast<VPInstruction>(&U)->getOpcode() ==
                     VPInstruction::ComputeReductionResult ||
+
                 cast<VPInstruction>(&U)->getOpcode() ==
                     VPInstruction::ComputeFindLastIVResult);
       });
@@ -9512,6 +9524,12 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
       FinalReductionResult =
           Builder.createNaryOp(VPInstruction::ComputeFindLastIVResult,
                                {PhiR, Start, NewExitingVPV}, ExitDL);
+    } else if (RecurrenceDescriptor::isAnyOfRecurrenceKind(
+                   RdxDesc.getRecurrenceKind())) {
+      VPValue *Start = PhiR->getStartValue();
+      FinalReductionResult =
+          Builder.createNaryOp(VPInstruction::ComputeAnyOfResult,
+                               {PhiR, Start, NewExitingVPV}, ExitDL);
     } else {
       VPIRFlags Flags = RecurrenceDescriptor::isFloatingPointRecurrenceKind(
                             RdxDesc.getRecurrenceKind())
@@ -10040,23 +10058,36 @@ preparePlanForEpilogueVectorLoop(VPlan &Plan, Loop *L,
     Value *ResumeV = nullptr;
     // TODO: Move setting of resume values to prepareToExecute.
     if (auto *ReductionPhi = dyn_cast<VPReductionPHIRecipe>(&R)) {
+      auto *RdxResult =
+          cast<VPInstruction>(*find_if(ReductionPhi->users(), [](VPUser *U) {
+            auto *VPI = dyn_cast<VPInstruction>(U);
+            return VPI &&
+                   (VPI->getOpcode() == VPInstruction::ComputeReductionResult ||
+                    VPI->getOpcode() == VPInstruction::ComputeFindLastIVResult);
+          }));
       ResumeV = cast<PHINode>(ReductionPhi->getUnderlyingInstr())
                     ->getIncomingValueForBlock(L->getLoopPreheader());
       const RecurrenceDescriptor &RdxDesc =
           ReductionPhi->getRecurrenceDescriptor();
       RecurKind RK = RdxDesc.getRecurrenceKind();
       if (RecurrenceDescriptor::isAnyOfRecurrenceKind(RK)) {
+        Value *StartV = RdxResult->getOperand(1)->getLiveInIRValue();
+        assert(RdxDesc.getRecurrenceStartValue() == StartV &&
+               "start value from ComputeAnyOfResult must match");
+
         // VPReductionPHIRecipes for AnyOf reductions expect a boolean as
         // start value; compare the final value from the main vector loop
         // to the start value.
         BasicBlock *PBB = cast<Instruction>(ResumeV)->getParent();
         IRBuilder<> Builder(PBB, PBB->getFirstNonPHIIt());
-        ResumeV =
-            Builder.CreateICmpNE(ResumeV, RdxDesc.getRecurrenceStartValue());
+        ResumeV = Builder.CreateICmpNE(ResumeV, StartV);
       } else if (RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK)) {
-        ToFrozen[RdxDesc.getRecurrenceStartValue()] =
-            cast<PHINode>(ResumeV)->getIncomingValueForBlock(
-                EPI.MainLoopIterationCountCheck);
+        Value *StartV = getStartValueFromReductionResult(RdxResult);
+        assert(RdxDesc.getRecurrenceStartValue() == StartV &&
+               "start value from ComputeFindLastIVResult must match");
+
+        ToFrozen[StartV] = cast<PHINode>(ResumeV)->getIncomingValueForBlock(
+            EPI.MainLoopIterationCountCheck);
 
         // VPReductionPHIRecipe for FindLastIV reductions requires an adjustment
         // to the resume value. The resume value is adjusted to the sentinel
@@ -10066,8 +10097,7 @@ preparePlanForEpilogueVectorLoop(VPlan &Plan, Loop *L,
         // variable.
         BasicBlock *ResumeBB = cast<Instruction>(ResumeV)->getParent();
         IRBuilder<> Builder(ResumeBB, ResumeBB->getFirstNonPHIIt());
-        Value *Cmp = Builder.CreateICmpEQ(
-            ResumeV, ToFrozen[RdxDesc.getRecurrenceStartValue()]);
+        Value *Cmp = Builder.CreateICmpEQ(ResumeV, ToFrozen[StartV]);
         ResumeV =
             Builder.CreateSelect(Cmp, RdxDesc.getSentinelValue(), ResumeV);
       }
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 44f0b6d964a6e..273df55188c16 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -907,6 +907,7 @@ class VPInstruction : public VPRecipeWithIRFlags,
     BranchOnCount,
     BranchOnCond,
     Broadcast,
+    ComputeAnyOfResult,
     ComputeFindLastIVResult,
     ComputeReductionResult,
     // Extracts the last lane from its operand if it is a vector, or the last
diff --git a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
index 926490bfad7d0..37adf03328fc8 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
@@ -87,6 +87,8 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPInstruction *R) {
                inferScalarType(R->getOperand(1)) &&
            "different types inferred for different operands");
     return IntegerType::get(Ctx, 1);
+  case VPInstruction::ComputeAnyOfResult:
+    return inferScalarType(R->getOperand(1));
   case VPInstruction::ComputeFindLastIVResult:
   case VPInstruction::ComputeReductionResult: {
     auto *PhiR = cast<VPReductionPHIRecipe>(R->getOperand(0));
diff --git a/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h b/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
index f2a7f16e19a79..dfd9fc3d4d719 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
@@ -318,6 +318,12 @@ m_VPInstruction(const Op0_t &Op0, const Op1_t &Op1, const Op2_t &Op2) {
       {Op0, Op1, Op2});
 }
 
+template <typename Op0_t>
+inline UnaryVPInstruction_match<Op0_t, Instruction::Freeze>
+m_Freeze(const Op0_t &Op0) { // Builds a unary matcher for opcode Instruction::Freeze whose single operand is matched by \p Op0.
+  return m_VPInstruction<Instruction::Freeze>(Op0);
+}
+
 template <typename Op0_t>
 inline UnaryVPInstruction_match<Op0_t, VPInstruction::Not>
 m_Not(const Op0_t &Op0) {
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index a4831ea7c11f7..2aa5dd1b48c00 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -604,6 +604,20 @@ Value *VPInstruction::generate(VPTransformState &State) {
     return Builder.CreateVectorSplat(
         State.VF, State.get(getOperand(0), /*IsScalar*/ true), "broadcast");
   }
+  case VPInstruction::ComputeAnyOfResult: {
+    // FIXME: The cross-recipe dependency on VPReductionPHIRecipe is temporary
+    // and will be removed by breaking up the recipe further.
+    auto *PhiR = cast<VPReductionPHIRecipe>(getOperand(0));
+    auto *OrigPhi = cast<PHINode>(PhiR->getUnderlyingValue());
+    Value *ReducedPartRdx = State.get(getOperand(2));
+    for (unsigned Idx = 3; Idx < getNumOperands(); ++Idx)
+      ReducedPartRdx = Builder.CreateBinOp(
+          (Instruction::BinaryOps)RecurrenceDescriptor::getOpcode(
+              RecurKind::AnyOf),
+          State.get(getOperand(Idx)), ReducedPartRdx, "bin.rdx");
+    return createAnyOfReduction(Builder, ReducedPartRdx,
+                                State.get(getOperand(1), VPLane(0)), OrigPhi);
+  }
   case VPInstruction::ComputeFindLastIVResult: {
     // FIXME: The cross-recipe dependency on VPReductionPHIRecipe is temporary
     // and will be removed by breaking up the recipe further.
@@ -681,18 +695,11 @@ Value *VPInstruction::generate(VPTransformState &State) {
 
     // Create the reduction after the loop. Note that inloop reductions create
     // the target reduction in the loop using a Reduction recipe.
-    if ((State.VF.isVector() ||
-         RecurrenceDescriptor::isAnyOfRecurrenceKind(RK)) &&
-        !PhiR->isInLoop()) {
+    if (State.VF.isVector() && !PhiR->isInLoop()) {
       // TODO: Support in-order reductions based on the recurrence descriptor.
       // All ops in the reduction inherit fast-math-flags from the recurrence
       // descriptor.
-      if (RecurrenceDescriptor::isAnyOfRecurrenceKind(RK))
-        ReducedPartRdx =
-            createAnyOfReduction(Builder, ReducedPartRdx,
-                                 RdxDesc.getRecurrenceStartValue(), OrigPhi);
-      else
-        ReducedPartRdx = createSimpleReduction(Builder, ReducedPartRdx, RK);
+      ReducedPartRdx = createSimpleReduction(Builder, ReducedPartRdx, RK);
 
       // If the reduction can be performed in a smaller type, we need to extend
       // the reduction to the wider type before we branch to the original loop.
@@ -830,6 +837,7 @@ bool VPInstruction::isVectorToScalar() const {
          getOpcode() == VPInstruction::ExtractPenultimateElement ||
          getOpcode() == Instruction::ExtractElement ||
          getOpcode() == VPInstruction::FirstActiveLane ||
+         getOpcode() == VPInstruction::ComputeAnyOfResult ||
          getOpcode() == VPInstruction::ComputeFindLastIVResult ||
          getOpcode() == VPInstruction::ComputeReductionResult ||
          getOpcode() == VPInstruction::AnyOf;
@@ -925,6 +933,7 @@ bool VPInstruction::onlyFirstLaneUsed(const VPValue *Op) const {
     return true;
   case VPInstruction::PtrAdd:
     return Op == getOperand(0) || vputils::onlyFirstLaneUsed(this);
+  case VPInstruction::ComputeAnyOfResult:
   case VPInstruction::ComputeFindLastIVResult:
     return Op == getOperand(1);
   };
@@ -1005,6 +1014,9 @@ void VPInstruction::print(raw_ostream &O, const Twine &Indent,
   case VPInstruction::ExtractPenultimateElement:
     O << "extract-penultimate-element";
     break;
+  case VPInstruction::ComputeAnyOfResult:
+    O << "compute-anyof-result";
+    break;
   case VPInstruction::ComputeFindLastIVResult:
     O << "compute-find-last-iv-result";
     break;
diff --git a/llvm/lib/Transforms/Vectorize/VPlanUnroll.cpp b/llvm/lib/Transforms/Vectorize/VPlanUnroll.cpp
index e1fb3d476c58d..335301a927ceb 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanUnroll.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanUnroll.cpp
@@ -327,7 +327,9 @@ void UnrollState::unrollBlock(VPBlockBase *VPB) {
     // Add all VPValues for all parts to ComputeReductionResult which combines
     // the parts to compute the final reduction value.
     VPValue *Op1;
-    if (match(&R, m_VPInstruction<VPInstruction::ComputeReductionResult>(
+    if (match(&R, m_VPInstruction<VPInstruction::ComputeAnyOfResult>(
+                      m_VPValue(), m_VPValue(), m_VPValue(Op1))) ||
+        match(&R, m_VPInstruction<VPInstruction::ComputeReductionResult>(
                       m_VPValue(), m_VPValue(Op1))) ||
         match(&R, m_VPInstruction<VPInstruction::ComputeFindLastIVResult>(
                       m_VPValue(), m_VPValue(), m_VPValue(Op1)))) {



More information about the llvm-commits mailing list