[llvm] [VPlan] Model FOR extract of exit value in VPlan. (PR #93395)

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Mon Jun 3 07:58:26 PDT 2024


https://github.com/fhahn updated https://github.com/llvm/llvm-project/pull/93395

>From c6067dce10c72373de8a07ac2eb92cb21c90aecb Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Sat, 25 May 2024 22:45:52 +0100
Subject: [PATCH 1/5] [VPlan] Model FOR extract of exit value in VPlan.

This patch introduces a new ExtractRecurrenceResult VPInstruction opcode
to extract the value of a first-order recurrence (FOR) for users outside
the loop (i.e. in the scalar loop's exits). This moves the first part of
fixing first-order recurrences to VPlan, and removes some additional code
to patch up live-outs, which is now handled automatically.

The majority of test changes are due to changes in the order in which the
extracts are now generated. As we are now using VPTransformState to
generate the extracts, we may be able to re-use existing extracts in the
loop body in some cases. For scalable vectors, in some cases we now have
to compute the runtime VF twice, as each extract is now independent, but
the redundant computations should be trivial for later passes to clean
up (and this is in line with other places in the code that also
liberally re-compute runtime VFs).
---
 .../Transforms/Vectorize/LoopVectorize.cpp    | 38 -------------
 llvm/lib/Transforms/Vectorize/VPlan.h         |  8 ++-
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp | 27 +++++++++
 .../Transforms/Vectorize/VPlanTransforms.cpp  |  8 +++
 .../LoopVectorize/AArch64/induction-costs.ll  |  6 +-
 .../AArch64/loop-vectorization-factors.ll     |  2 +-
 .../first-order-recurrence-chains.ll          | 56 +++++++++----------
 .../LoopVectorize/first-order-recurrence.ll   | 24 ++++----
 .../scalable-first-order-recurrence.ll        | 12 ++--
 .../LoopVectorize/vplan-printing.ll           |  3 +-
 10 files changed, 93 insertions(+), 91 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 5939ce5b917aa..188bfc164f30a 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -3536,44 +3536,6 @@ void InnerLoopVectorizer::fixFixedOrderRecurrence(
         Builder.CreateExtractElement(Incoming, LastIdx, "vector.recur.extract");
   }
 
-  auto RecurSplice = cast<VPInstruction>(*PhiR->user_begin());
-  assert(PhiR->getNumUsers() == 1 &&
-         RecurSplice->getOpcode() ==
-             VPInstruction::FirstOrderRecurrenceSplice &&
-         "recurrence phi must have a single user: FirstOrderRecurrenceSplice");
-  SmallVector<VPLiveOut *> LiveOuts;
-  for (VPUser *U : RecurSplice->users())
-    if (auto *LiveOut = dyn_cast<VPLiveOut>(U))
-      LiveOuts.push_back(LiveOut);
-
-  if (!LiveOuts.empty()) {
-    // Extract the second last element in the middle block if the
-    // Phi is used outside the loop. We need to extract the phi itself
-    // and not the last element (the phi update in the current iteration). This
-    // will be the value when jumping to the exit block from the
-    // LoopMiddleBlock, when the scalar loop is not run at all.
-    Value *ExtractForPhiUsedOutsideLoop = nullptr;
-    if (VF.isVector()) {
-      auto *Idx = Builder.CreateSub(RuntimeVF, ConstantInt::get(IdxTy, 2));
-      ExtractForPhiUsedOutsideLoop = Builder.CreateExtractElement(
-          Incoming, Idx, "vector.recur.extract.for.phi");
-    } else {
-      assert(UF > 1 && "VF and UF cannot both be 1");
-      // When loop is unrolled without vectorizing, initialize
-      // ExtractForPhiUsedOutsideLoop with the value just prior to unrolled
-      // value of `Incoming`. This is analogous to the vectorized case above:
-      // extracting the second last element when VF > 1.
-      ExtractForPhiUsedOutsideLoop = State.get(PreviousDef, UF - 2);
-    }
-
-    for (VPLiveOut *LiveOut : LiveOuts) {
-      assert(!Cost->requiresScalarEpilogue(VF.isVector()));
-      PHINode *LCSSAPhi = LiveOut->getPhi();
-      LCSSAPhi->addIncoming(ExtractForPhiUsedOutsideLoop, LoopMiddleBlock);
-      State.Plan->removeLiveOut(LCSSAPhi);
-    }
-  }
-
   // Fix the initial value of the original recurrence in the scalar loop.
   Builder.SetInsertPoint(LoopScalarPreHeader, LoopScalarPreHeader->begin());
   PHINode *Phi = cast<PHINode>(PhiR->getUnderlyingValue());
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index d6b63a5d43c46..659f28ec9d832 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -167,8 +167,8 @@ class VPLane {
 
   static VPLane getFirstLane() { return VPLane(0, VPLane::Kind::First); }
 
-  static VPLane getLastLaneForVF(const ElementCount &VF) {
-    unsigned LaneOffset = VF.getKnownMinValue() - 1;
+  static VPLane getLastLaneForVF(const ElementCount &VF, unsigned Offset = 1) {
+    unsigned LaneOffset = VF.getKnownMinValue() - Offset;
     Kind LaneKind;
     if (VF.isScalable())
       // In this case 'LaneOffset' refers to the offset from the start of the
@@ -1182,6 +1182,7 @@ class VPInstruction : public VPRecipeWithIRFlags {
     BranchOnCount,
     BranchOnCond,
     ComputeReductionResult,
+    ExtractRecurrenceResult,
     LogicalAnd, // Non-poison propagating logical And.
     // Add an offset in bytes (second operand) to a base pointer (first
     // operand). Only generates scalar values (either for the first lane only or
@@ -3657,7 +3658,8 @@ inline bool isUniformAfterVectorization(VPValue *VPV) {
   if (auto *GEP = dyn_cast<VPWidenGEPRecipe>(Def))
     return all_of(GEP->operands(), isUniformAfterVectorization);
   if (auto *VPI = dyn_cast<VPInstruction>(Def))
-    return VPI->getOpcode() == VPInstruction::ComputeReductionResult;
+    return VPI->getOpcode() == VPInstruction::ComputeReductionResult ||
+           VPI->getOpcode() == VPInstruction::ExtractRecurrenceResult;
   return false;
 }
 } // end namespace vputils
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 5eb99ffd1e10e..82e3bd530a8d6 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -137,6 +137,7 @@ bool VPRecipeBase::mayHaveSideEffects() const {
     case VPInstruction::Not:
     case VPInstruction::CalculateTripCountMinusVF:
     case VPInstruction::CanonicalIVIncrementForPart:
+    case VPInstruction::ExtractRecurrenceResult:
     case VPInstruction::LogicalAnd:
     case VPInstruction::PtrAdd:
       return false;
@@ -300,6 +301,7 @@ bool VPInstruction::canGenerateScalarForFirstLane() const {
   case VPInstruction::CalculateTripCountMinusVF:
   case VPInstruction::CanonicalIVIncrementForPart:
   case VPInstruction::ComputeReductionResult:
+  case VPInstruction::ExtractRecurrenceResult:
   case VPInstruction::PtrAdd:
   case VPInstruction::ExplicitVectorLength:
     return true;
@@ -558,6 +560,27 @@ Value *VPInstruction::generatePerPart(VPTransformState &State, unsigned Part) {
 
     return ReducedPartRdx;
   }
+  case VPInstruction::ExtractRecurrenceResult: {
+    if (Part != 0)
+      return State.get(this, 0, /*IsScalar*/ true);
+
+    // Extract the second last element in the middle block for users outside the
+    // loop.
+    Value *Res;
+    if (State.VF.isVector()) {
+      Res = State.get(
+          getOperand(0),
+          VPIteration(State.UF - 1, VPLane::getLastLaneForVF(State.VF, 2)));
+    } else {
+      assert(State.UF > 1 && "VF and UF cannot both be 1");
+      // When loop is unrolled without vectorizing, retrieve the value just
+      // prior to the final unrolled value. This is analogous to the vectorized
+      // case above: extracting the second last element when VF > 1.
+      Res = State.get(getOperand(0), State.UF - 2);
+    }
+    Res->setName(Name);
+    return Res;
+  }
   case VPInstruction::LogicalAnd: {
     Value *A = State.get(getOperand(0), Part);
     Value *B = State.get(getOperand(1), Part);
@@ -598,6 +621,7 @@ void VPInstruction::execute(VPTransformState &State) {
   bool GeneratesPerFirstLaneOnly =
       canGenerateScalarForFirstLane() &&
       (vputils::onlyFirstLaneUsed(this) ||
+       getOpcode() == VPInstruction::ExtractRecurrenceResult ||
        getOpcode() == VPInstruction::ComputeReductionResult);
   bool GeneratesPerAllLanes = doesGeneratePerAllLanes();
   for (unsigned Part = 0; Part < State.UF; ++Part) {
@@ -692,6 +716,9 @@ void VPInstruction::print(raw_ostream &O, const Twine &Indent,
   case VPInstruction::BranchOnCount:
     O << "branch-on-count";
     break;
+  case VPInstruction::ExtractRecurrenceResult:
+    O << "extract-recurrence-result";
+    break;
   case VPInstruction::ComputeReductionResult:
     O << "compute-reduction-result";
     break;
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index 422579ea8b84f..f4f43ed15615e 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -812,6 +812,8 @@ bool VPlanTransforms::adjustFixedOrderRecurrences(VPlan &Plan,
     if (auto *FOR = dyn_cast<VPFirstOrderRecurrencePHIRecipe>(&R))
       RecurrencePhis.push_back(FOR);
 
+  VPBuilder MiddleBuilder(
+      cast<VPBasicBlock>(Plan.getVectorLoopRegion()->getSingleSuccessor()));
   for (VPFirstOrderRecurrencePHIRecipe *FOR : RecurrencePhis) {
     SmallPtrSet<VPFirstOrderRecurrencePHIRecipe *, 4> SeenPhis;
     VPRecipeBase *Previous = FOR->getBackedgeValue()->getDefiningRecipe();
@@ -843,6 +845,12 @@ bool VPlanTransforms::adjustFixedOrderRecurrences(VPlan &Plan,
     // Set the first operand of RecurSplice to FOR again, after replacing
     // all users.
     RecurSplice->setOperand(0, FOR);
+
+    auto *Result = cast<VPInstruction>(MiddleBuilder.createNaryOp(
+        VPInstruction::ExtractRecurrenceResult, {FOR->getBackedgeValue()}, {},
+        "vector.recur.extract.for.phi"));
+    RecurSplice->replaceUsesWithIf(
+        Result, [](VPUser &U, unsigned) { return isa<VPLiveOut>(&U); });
   }
   return true;
 }
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/induction-costs.ll b/llvm/test/Transforms/LoopVectorize/AArch64/induction-costs.ll
index 4d9c850abdf3d..162fb9c802dd2 100644
--- a/llvm/test/Transforms/LoopVectorize/AArch64/induction-costs.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/induction-costs.ll
@@ -126,9 +126,9 @@ define i64 @pointer_induction_only(ptr %start, ptr %end) {
 ; CHECK-NEXT:    [[TMP12:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
 ; CHECK-NEXT:    br i1 [[TMP12]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP4:![0-9]+]]
 ; CHECK:       middle.block:
+; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT_FOR_PHI:%.*]] = extractelement <2 x i64> [[TMP9]], i32 0
 ; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i64 [[TMP2]], [[N_VEC]]
 ; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT:%.*]] = extractelement <2 x i64> [[TMP9]], i32 1
-; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT_FOR_PHI:%.*]] = extractelement <2 x i64> [[TMP9]], i32 0
 ; CHECK-NEXT:    br i1 [[CMP_N]], label [[EXIT:%.*]], label [[SCALAR_PH]]
 ; CHECK:       scalar.ph:
 ; CHECK-NEXT:    [[SCALAR_RECUR_INIT:%.*]] = phi i64 [ 0, [[ENTRY:%.*]] ], [ [[VECTOR_RECUR_EXTRACT]], [[MIDDLE_BLOCK]] ]
@@ -191,8 +191,8 @@ define i64 @int_and_pointer_iv(ptr %start, i32 %N) {
 ; CHECK-NEXT:    [[TMP8:%.*]] = icmp eq i64 [[INDEX_NEXT]], 1000
 ; CHECK-NEXT:    br i1 [[TMP8]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP6:![0-9]+]]
 ; CHECK:       middle.block:
-; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT:%.*]] = extractelement <4 x i64> [[TMP5]], i32 3
 ; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT_FOR_PHI:%.*]] = extractelement <4 x i64> [[TMP5]], i32 2
+; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT:%.*]] = extractelement <4 x i64> [[TMP5]], i32 3
 ; CHECK-NEXT:    br i1 true, label [[EXIT:%.*]], label [[SCALAR_PH]]
 ; CHECK:       scalar.ph:
 ; CHECK-NEXT:    [[SCALAR_RECUR_INIT:%.*]] = phi i64 [ 0, [[ENTRY:%.*]] ], [ [[VECTOR_RECUR_EXTRACT]], [[MIDDLE_BLOCK]] ]
@@ -332,9 +332,9 @@ define i64 @test_ptr_ivs_and_widened_ivs(ptr %src, i32 %N) {
 ; DEFAULT-NEXT:    [[TMP18:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
 ; DEFAULT-NEXT:    br i1 [[TMP18]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP6:![0-9]+]]
 ; DEFAULT:       middle.block:
+; DEFAULT-NEXT:    [[VECTOR_RECUR_EXTRACT_FOR_PHI:%.*]] = extractelement <4 x i64> [[TMP15]], i32 2
 ; DEFAULT-NEXT:    [[CMP_N:%.*]] = icmp eq i64 [[TMP2]], [[N_VEC]]
 ; DEFAULT-NEXT:    [[VECTOR_RECUR_EXTRACT:%.*]] = extractelement <4 x i64> [[TMP15]], i32 3
-; DEFAULT-NEXT:    [[VECTOR_RECUR_EXTRACT_FOR_PHI:%.*]] = extractelement <4 x i64> [[TMP15]], i32 2
 ; DEFAULT-NEXT:    br i1 [[CMP_N]], label [[EXIT:%.*]], label [[SCALAR_PH]]
 ; DEFAULT:       scalar.ph:
 ; DEFAULT-NEXT:    [[SCALAR_RECUR_INIT:%.*]] = phi i64 [ 0, [[ENTRY:%.*]] ], [ [[VECTOR_RECUR_EXTRACT]], [[MIDDLE_BLOCK]] ]
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/loop-vectorization-factors.ll b/llvm/test/Transforms/LoopVectorize/AArch64/loop-vectorization-factors.ll
index f2f58f89eb95d..718148a67fcc8 100644
--- a/llvm/test/Transforms/LoopVectorize/AArch64/loop-vectorization-factors.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/loop-vectorization-factors.ll
@@ -870,9 +870,9 @@ define i8 @add_phifail2(ptr noalias nocapture readonly %p, ptr noalias nocapture
 ; CHECK-NEXT:    [[TMP12:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
 ; CHECK-NEXT:    br i1 [[TMP12]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP23:![0-9]+]]
 ; CHECK:       middle.block:
+; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT_FOR_PHI:%.*]] = extractelement <16 x i32> [[TMP6]], i32 14
 ; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i64 [[TMP2]], [[N_VEC]]
 ; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT:%.*]] = extractelement <16 x i32> [[TMP6]], i32 15
-; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT_FOR_PHI:%.*]] = extractelement <16 x i32> [[TMP6]], i32 14
 ; CHECK-NEXT:    br i1 [[CMP_N]], label [[FOR_COND_CLEANUP:%.*]], label [[SCALAR_PH]]
 ; CHECK:       scalar.ph:
 ; CHECK-NEXT:    [[SCALAR_RECUR_INIT:%.*]] = phi i32 [ 0, [[ENTRY:%.*]] ], [ [[VECTOR_RECUR_EXTRACT]], [[MIDDLE_BLOCK]] ]
diff --git a/llvm/test/Transforms/LoopVectorize/first-order-recurrence-chains.ll b/llvm/test/Transforms/LoopVectorize/first-order-recurrence-chains.ll
index c663d2b15b587..a1173c6b46a2c 100644
--- a/llvm/test/Transforms/LoopVectorize/first-order-recurrence-chains.ll
+++ b/llvm/test/Transforms/LoopVectorize/first-order-recurrence-chains.ll
@@ -18,10 +18,10 @@ define i16 @test_chained_first_order_recurrences_1(ptr %ptr) {
 ; CHECK-NEXT:    [[TMP8:%.*]] = icmp eq i64 [[INDEX_NEXT]], 1000
 ; CHECK-NEXT:    br i1 [[TMP8]], label %middle.block, label %vector.body
 ; CHECK:       middle.block:
-; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT:%.*]] = extractelement <4 x i16> [[WIDE_LOAD]], i32 3
 ; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT_FOR_PHI:%.*]] = extractelement <4 x i16> [[WIDE_LOAD]], i32 2
-; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT2:%.*]] = extractelement <4 x i16> [[TMP4]], i32 3
 ; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT_FOR_PHI3:%.*]] = extractelement <4 x i16> [[TMP4]], i32 2
+; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT:%.*]] = extractelement <4 x i16> [[WIDE_LOAD]], i32 3
+; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT2:%.*]] = extractelement <4 x i16> [[TMP4]], i32 3
 ;
 entry:
   br label %loop
@@ -61,10 +61,10 @@ define i16 @test_chained_first_order_recurrences_2(ptr %ptr) {
 ; CHECK-NEXT:    [[TMP8:%.*]] = icmp eq i64 [[INDEX_NEXT]], 1000
 ; CHECK-NEXT:    br i1 [[TMP8]], label %middle.block, label %vector.body, !llvm.loop [[LOOP4:![0-9]+]]
 ; CHECK:      middle.block:
-; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT:%.*]] = extractelement <4 x i16> [[TMP4]], i32 3
 ; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT_FOR_PHI:%.*]] = extractelement <4 x i16> [[TMP4]], i32 2
-; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT2:%.*]] = extractelement <4 x i16> [[WIDE_LOAD]], i32 3
 ; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT_FOR_PHI3:%.*]] = extractelement <4 x i16> [[WIDE_LOAD]], i32 2
+; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT:%.*]] = extractelement <4 x i16> [[TMP4]], i32 3
+; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT2:%.*]] = extractelement <4 x i16> [[WIDE_LOAD]], i32 3
 ;
 entry:
   br label %loop
@@ -107,12 +107,12 @@ define i16 @test_chained_first_order_recurrences_3(ptr %ptr) {
 ; CHECK-NEXT:    [[TMP10:%.*]] = icmp eq i64 [[INDEX_NEXT]], 1000
 ; CHECK-NEXT:    br i1 [[TMP10]], label %middle.block, label %vector.body, !llvm.loop [[LOOP6:![0-9]+]]
 ; CHECK:      middle.block:
-; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT:%.*]] = extractelement <4 x i16> [[WIDE_LOAD]], i32 3
 ; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT_FOR_PHI:%.*]] = extractelement <4 x i16> [[WIDE_LOAD]], i32 2
-; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT3:%.*]] = extractelement <4 x i16> [[TMP4]], i32 3
 ; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT_FOR_PHI4:%.*]] = extractelement <4 x i16> [[TMP4]], i32 2
-; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT7:%.*]] = extractelement <4 x i16> [[TMP5]], i32 3
 ; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT_FOR_PHI8:%.*]] = extractelement <4 x i16> [[TMP5]], i32 2
+; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT:%.*]] = extractelement <4 x i16> [[WIDE_LOAD]], i32 3
+; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT3:%.*]] = extractelement <4 x i16> [[TMP4]], i32 3
+; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT7:%.*]] = extractelement <4 x i16> [[TMP5]], i32 3
 ;
 entry:
   br label %loop
@@ -219,12 +219,12 @@ define i16 @test_chained_first_order_recurrences_3_reordered_1(ptr %ptr) {
 ; CHECK-NEXT:    [[TMP10:%.*]] = icmp eq i64 [[INDEX_NEXT]], 1000
 ; CHECK-NEXT:    br i1 [[TMP10]], label %middle.block, label %vector.body, !llvm.loop [[LOOP6:![0-9]+]]
 ; CHECK:      middle.block:
-; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT7:%.*]] = extractelement <4 x i16> [[TMP5]], i32 3
 ; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT_FOR_PHI8:%.*]] = extractelement <4 x i16> [[TMP5]], i32 2
-; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT3:%.*]] = extractelement <4 x i16> [[TMP4]], i32 3
 ; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT_FOR_PHI4:%.*]] = extractelement <4 x i16> [[TMP4]], i32 2
-; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT:%.*]] = extractelement <4 x i16> [[WIDE_LOAD]], i32 3
 ; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT_FOR_PHI:%.*]] = extractelement <4 x i16> [[WIDE_LOAD]], i32 2
+; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT7:%.*]] = extractelement <4 x i16> [[TMP5]], i32 3
+; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT3:%.*]] = extractelement <4 x i16> [[TMP4]], i32 3
+; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT:%.*]] = extractelement <4 x i16> [[WIDE_LOAD]], i32 3
 ;
 entry:
   br label %loop
@@ -270,12 +270,12 @@ define i16 @test_chained_first_order_recurrences_3_reordered_2(ptr %ptr) {
 ; CHECK-NEXT:    [[TMP10:%.*]] = icmp eq i64 [[INDEX_NEXT]], 1000
 ; CHECK-NEXT:    br i1 [[TMP10]], label %middle.block, label %vector.body, !llvm.loop [[LOOP6:![0-9]+]]
 ; CHECK:      middle.block:
-; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT3:%.*]] = extractelement <4 x i16> [[TMP4]], i32 3
 ; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT_FOR_PHI4:%.*]] = extractelement <4 x i16> [[TMP4]], i32 2
-; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT7:%.*]] = extractelement <4 x i16> [[TMP5]], i32 3
 ; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT_FOR_PHI8:%.*]] = extractelement <4 x i16> [[TMP5]], i32 2
-; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT:%.*]] = extractelement <4 x i16> [[WIDE_LOAD]], i32 3
 ; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT_FOR_PHI:%.*]] = extractelement <4 x i16> [[WIDE_LOAD]], i32 2
+; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT3:%.*]] = extractelement <4 x i16> [[TMP4]], i32 3
+; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT7:%.*]] = extractelement <4 x i16> [[TMP5]], i32 3
+; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT:%.*]] = extractelement <4 x i16> [[WIDE_LOAD]], i32 3
 ;
 entry:
   br label %loop
@@ -321,12 +321,12 @@ define i16 @test_chained_first_order_recurrences_3_for2_no_other_uses(ptr %ptr)
 ; CHECK-NEXT:    [[TMP10:%.*]] = icmp eq i64 [[INDEX_NEXT]], 1000
 ; CHECK-NEXT:    br i1 [[TMP10]], label %middle.block, label %vector.body, !llvm.loop [[LOOP6:![0-9]+]]
 ; CHECK:      middle.block:
-; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT:%.*]] = extractelement <4 x i16> [[WIDE_LOAD]], i32 3
 ; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT_FOR_PHI:%.*]] = extractelement <4 x i16> [[WIDE_LOAD]], i32 2
-; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT3:%.*]] = extractelement <4 x i16> [[TMP4]], i32 3
 ; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT_FOR_PHI4:%.*]] = extractelement <4 x i16> [[TMP4]], i32 2
-; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT7:%.*]] = extractelement <4 x i16> [[TMP5]], i32 3
 ; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT_FOR_PHI8:%.*]] = extractelement <4 x i16> [[TMP5]], i32 2
+; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT:%.*]] = extractelement <4 x i16> [[WIDE_LOAD]], i32 3
+; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT3:%.*]] = extractelement <4 x i16> [[TMP4]], i32 3
+; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT7:%.*]] = extractelement <4 x i16> [[TMP5]], i32 3
 ;
 entry:
   br label %loop
@@ -371,12 +371,12 @@ define i16 @test_chained_first_order_recurrences_3_for1_for2_no_other_uses(ptr %
 ; CHECK-NEXT:    [[TMP10:%.*]] = icmp eq i64 [[INDEX_NEXT]], 1000
 ; CHECK-NEXT:    br i1 [[TMP10]], label %middle.block, label %vector.body, !llvm.loop [[LOOP6:![0-9]+]]
 ; CHECK:      middle.block:
-; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT:%.*]] = extractelement <4 x i16> [[WIDE_LOAD]], i32 3
 ; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT_FOR_PHI:%.*]] = extractelement <4 x i16> [[WIDE_LOAD]], i32 2
-; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT3:%.*]] = extractelement <4 x i16> [[TMP4]], i32 3
 ; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT_FOR_PHI4:%.*]] = extractelement <4 x i16> [[TMP4]], i32 2
-; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT7:%.*]] = extractelement <4 x i16> [[TMP5]], i32 3
 ; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT_FOR_PHI8:%.*]] = extractelement <4 x i16> [[TMP5]], i32 2
+; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT:%.*]] = extractelement <4 x i16> [[WIDE_LOAD]], i32 3
+; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT3:%.*]] = extractelement <4 x i16> [[TMP4]], i32 3
+; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT7:%.*]] = extractelement <4 x i16> [[TMP5]], i32 3
 ;
 entry:
   br label %loop
@@ -420,10 +420,10 @@ define double @test_chained_first_order_recurrence_sink_users_1(ptr %ptr) {
 ; CHECK-NEXT:    [[TMP9:%.*]] = icmp eq i64 [[INDEX_NEXT]], 996
 ; CHECK-NEXT:    br i1 [[TMP9]], label %middle.block, label %vector.body, !llvm.loop [[LOOP10:![0-9]+]]
 ; CHECK:       middle.block:
-; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT:%.*]] = extractelement <4 x double> [[WIDE_LOAD]], i32 3
 ; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT_FOR_PHI:%.*]] = extractelement <4 x double> [[WIDE_LOAD]], i32 2
-; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT2:%.*]] = extractelement <4 x double> [[TMP4]], i32 3
 ; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT_FOR_PHI3:%.*]] = extractelement <4 x double> [[TMP4]], i32 2
+; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT:%.*]] = extractelement <4 x double> [[WIDE_LOAD]], i32 3
+; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT2:%.*]] = extractelement <4 x double> [[TMP4]], i32 3
 ;
 entry:
   br label %loop
@@ -488,8 +488,8 @@ define i64 @test_first_order_recurrences_and_induction(ptr %ptr) {
 ; CHECK-NEXT:    [[TMP5:%.*]] = icmp eq i64 [[INDEX_NEXT]], 1000
 ; CHECK-NEXT:    br i1 [[TMP5]], label %middle.block, label %vector.body
 ; CHECK:       middle.block:
-; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT:%.*]] = extractelement <4 x i64> [[VEC_IND]], i32 3
 ; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT_FOR_PHI:%.*]] = extractelement <4 x i64> [[VEC_IND]], i32 2
+; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT:%.*]] = extractelement <4 x i64> [[VEC_IND]], i32 3
 ; CHECK-NEXT:    br i1 true
 
 entry:
@@ -528,8 +528,8 @@ define i64 @test_first_order_recurrences_and_induction2(ptr %ptr) {
 ; CHECK-NEXT:    [[TMP5:%.*]] = icmp eq i64 [[INDEX_NEXT]], 1000
 ; CHECK-NEXT:    br i1 [[TMP5]], label %middle.block, label %vector.body
 ; CHECK:       middle.block:
-; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT:%.*]] = extractelement <4 x i64> [[VEC_IND]], i32 3
 ; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT_FOR_PHI:%.*]] = extractelement <4 x i64> [[VEC_IND]], i32 2
+; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT:%.*]] = extractelement <4 x i64> [[VEC_IND]], i32 3
 ; CHECK-NEXT:    br i1 true
 ;
 entry:
@@ -569,8 +569,8 @@ define ptr @test_first_order_recurrences_and_pointer_induction1(ptr %ptr) {
 ; CHECK-NEXT:    [[TMP5:%.*]] = icmp eq i64 [[INDEX_NEXT]], 1000
 ; CHECK-NEXT:    br i1 [[TMP5]], label %middle.block, label %vector.body
 ; CHECK:       middle.block:
-; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT:%.*]] = extractelement <4 x ptr> [[TMP0]], i32 3
 ; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT_FOR_PHI:%.*]] = extractelement <4 x ptr> [[TMP0]], i32 2
+; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT:%.*]] = extractelement <4 x ptr> [[TMP0]], i32 3
 ; CHECK-NEXT:    br i1 true
 ;
 entry:
@@ -613,8 +613,8 @@ define ptr @test_first_order_recurrences_and_pointer_induction2(ptr %ptr) {
 ; CHECK-NEXT:    [[TMP5:%.*]] = icmp eq i64 [[INDEX_NEXT]], 1000
 ; CHECK-NEXT:    br i1 [[TMP5]], label %middle.block, label %vector.body
 ; CHECK:       middle.block:
-; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT:%.*]] = extractelement <4 x ptr> [[TMP0]], i32 3
 ; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT_FOR_PHI:%.*]] = extractelement <4 x ptr> [[TMP0]], i32 2
+; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT:%.*]] = extractelement <4 x ptr> [[TMP0]], i32 3
 ; CHECK-NEXT:    br i1 true
 ;
 entry:
@@ -660,12 +660,12 @@ define double @test_resinking_required(ptr %p, ptr noalias %a, ptr noalias %b) {
 ; CHECK-NEXT:    [[TMP7:%.*]] = icmp eq i64 [[INDEX_NEXT]], 0
 ; CHECK-NEXT:    br i1 [[TMP7]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP28:![0-9]+]]
 ; CHECK:       middle.block:
-; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT:%.*]] = extractelement <4 x double> [[BROADCAST_SPLAT]], i32 3
 ; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT_FOR_PHI:%.*]] = extractelement <4 x double> [[BROADCAST_SPLAT]], i32 2
-; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT5:%.*]] = extractelement <4 x double> [[BROADCAST_SPLAT4]], i32 3
 ; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT_FOR_PHI6:%.*]] = extractelement <4 x double> [[BROADCAST_SPLAT4]], i32 2
-; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT9:%.*]] = extractelement <4 x double> [[TMP4]], i32 3
 ; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT_FOR_PHI10:%.*]] = extractelement <4 x double> [[TMP4]], i32 2
+; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT:%.*]] = extractelement <4 x double> [[BROADCAST_SPLAT]], i32 3
+; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT5:%.*]] = extractelement <4 x double> [[BROADCAST_SPLAT4]], i32 3
+; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT9:%.*]] = extractelement <4 x double> [[TMP4]], i32 3
 ; CHECK-NEXT:    br i1 true, label %End, label %scalar.ph
 ;
 Entry:
diff --git a/llvm/test/Transforms/LoopVectorize/first-order-recurrence.ll b/llvm/test/Transforms/LoopVectorize/first-order-recurrence.ll
index 0a37e5ea0ca00..a588d71a2c76e 100644
--- a/llvm/test/Transforms/LoopVectorize/first-order-recurrence.ll
+++ b/llvm/test/Transforms/LoopVectorize/first-order-recurrence.ll
@@ -918,9 +918,9 @@ define i32 @PR27246() {
 ; UNROLL-NO-IC-NEXT:    [[TMP2:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]
 ; UNROLL-NO-IC-NEXT:    br i1 [[TMP2]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP8:![0-9]+]]
 ; UNROLL-NO-IC:       middle.block:
+; UNROLL-NO-IC-NEXT:    [[VECTOR_RECUR_EXTRACT_FOR_PHI:%.*]] = extractelement <4 x i32> [[STEP_ADD]], i32 2
 ; UNROLL-NO-IC-NEXT:    [[CMP_N:%.*]] = icmp eq i32 [[I_016]], [[N_VEC]]
 ; UNROLL-NO-IC-NEXT:    [[VECTOR_RECUR_EXTRACT:%.*]] = extractelement <4 x i32> [[STEP_ADD]], i32 3
-; UNROLL-NO-IC-NEXT:    [[VECTOR_RECUR_EXTRACT_FOR_PHI:%.*]] = extractelement <4 x i32> [[STEP_ADD]], i32 2
 ; UNROLL-NO-IC-NEXT:    br i1 [[CMP_N]], label [[FOR_COND_CLEANUP3]], label [[SCALAR_PH]]
 ; UNROLL-NO-IC:       scalar.ph:
 ; UNROLL-NO-IC-NEXT:    [[SCALAR_RECUR_INIT:%.*]] = phi i32 [ [[E_015]], [[FOR_COND1_PREHEADER]] ], [ [[VECTOR_RECUR_EXTRACT]], [[MIDDLE_BLOCK]] ]
@@ -1012,9 +1012,9 @@ define i32 @PR27246() {
 ; SINK-AFTER-NEXT:    [[TMP1:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]
 ; SINK-AFTER-NEXT:    br i1 [[TMP1]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP8:![0-9]+]]
 ; SINK-AFTER:       middle.block:
+; SINK-AFTER-NEXT:    [[VECTOR_RECUR_EXTRACT_FOR_PHI:%.*]] = extractelement <4 x i32> [[VEC_IND]], i32 2
 ; SINK-AFTER-NEXT:    [[CMP_N:%.*]] = icmp eq i32 [[I_016]], [[N_VEC]]
 ; SINK-AFTER-NEXT:    [[VECTOR_RECUR_EXTRACT:%.*]] = extractelement <4 x i32> [[VEC_IND]], i32 3
-; SINK-AFTER-NEXT:    [[VECTOR_RECUR_EXTRACT_FOR_PHI:%.*]] = extractelement <4 x i32> [[VEC_IND]], i32 2
 ; SINK-AFTER-NEXT:    br i1 [[CMP_N]], label [[FOR_COND_CLEANUP3]], label [[SCALAR_PH]]
 ; SINK-AFTER:       scalar.ph:
 ; SINK-AFTER-NEXT:    [[SCALAR_RECUR_INIT:%.*]] = phi i32 [ [[E_015]], [[FOR_COND1_PREHEADER]] ], [ [[VECTOR_RECUR_EXTRACT]], [[MIDDLE_BLOCK]] ]
@@ -1114,12 +1114,12 @@ define i32 @PR30183(i32 %pre_load, ptr %a, ptr %b, i64 %n) {
 ; UNROLL-NO-IC-NEXT:    [[TMP34:%.*]] = insertelement <4 x i32> [[TMP33]], i32 [[TMP30]], i32 3
 ; UNROLL-NO-IC-NEXT:    [[TMP35:%.*]] = load i32, ptr [[TMP23]], align 4
 ; UNROLL-NO-IC-NEXT:    [[TMP36:%.*]] = load i32, ptr [[TMP24]], align 4
-; UNROLL-NO-IC-NEXT:    [[TMP37:%.*]] = load i32, ptr [[TMP25]], align 4
+; UNROLL-NO-IC-NEXT:    [[VECTOR_RECUR_EXTRACT_FOR_PHI:%.*]] = load i32, ptr [[TMP25]], align 4
 ; UNROLL-NO-IC-NEXT:    [[TMP38:%.*]] = load i32, ptr [[TMP26]], align 4
 ; UNROLL-NO-IC-NEXT:    [[TMP39:%.*]] = insertelement <4 x i32> poison, i32 [[TMP35]], i32 0
 ; UNROLL-NO-IC-NEXT:    [[TMP40:%.*]] = insertelement <4 x i32> [[TMP39]], i32 [[TMP36]], i32 1
-; UNROLL-NO-IC-NEXT:    [[TMP41:%.*]] = insertelement <4 x i32> [[TMP40]], i32 [[TMP37]], i32 2
-; UNROLL-NO-IC-NEXT:    [[TMP42]] = insertelement <4 x i32> [[TMP41]], i32 [[TMP38]], i32 3
+; UNROLL-NO-IC-NEXT:    [[TMP41:%.*]] = insertelement <4 x i32> [[TMP40]], i32 [[VECTOR_RECUR_EXTRACT_FOR_PHI]], i32 2
+; UNROLL-NO-IC-NEXT:    [[TMP42:%.*]] = insertelement <4 x i32> [[TMP41]], i32 [[TMP38]], i32 3
 ; UNROLL-NO-IC-NEXT:    [[TMP43:%.*]] = shufflevector <4 x i32> [[VECTOR_RECUR]], <4 x i32> [[TMP34]], <4 x i32> <i32 3, i32 4, i32 5, i32 6>
 ; UNROLL-NO-IC-NEXT:    [[TMP44:%.*]] = shufflevector <4 x i32> [[TMP34]], <4 x i32> [[TMP42]], <4 x i32> <i32 3, i32 4, i32 5, i32 6>
 ; UNROLL-NO-IC-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 8
@@ -1128,7 +1128,6 @@ define i32 @PR30183(i32 %pre_load, ptr %a, ptr %b, i64 %n) {
 ; UNROLL-NO-IC:       middle.block:
 ; UNROLL-NO-IC-NEXT:    [[CMP_N:%.*]] = icmp eq i64 [[TMP2]], [[N_VEC]]
 ; UNROLL-NO-IC-NEXT:    [[VECTOR_RECUR_EXTRACT:%.*]] = extractelement <4 x i32> [[TMP42]], i32 3
-; UNROLL-NO-IC-NEXT:    [[VECTOR_RECUR_EXTRACT_FOR_PHI:%.*]] = extractelement <4 x i32> [[TMP42]], i32 2
 ; UNROLL-NO-IC-NEXT:    br i1 [[CMP_N]], label [[FOR_END:%.*]], label [[SCALAR_PH]]
 ; UNROLL-NO-IC:       scalar.ph:
 ; UNROLL-NO-IC-NEXT:    [[SCALAR_RECUR_INIT:%.*]] = phi i32 [ [[PRE_LOAD]], [[ENTRY:%.*]] ], [ [[VECTOR_RECUR_EXTRACT]], [[MIDDLE_BLOCK]] ]
@@ -1223,11 +1222,11 @@ define i32 @PR30183(i32 %pre_load, ptr %a, ptr %b, i64 %n) {
 ; SINK-AFTER-NEXT:    [[TMP14:%.*]] = getelementptr inbounds i32, ptr [[A]], i64 [[TMP10]]
 ; SINK-AFTER-NEXT:    [[TMP15:%.*]] = load i32, ptr [[TMP11]], align 4
 ; SINK-AFTER-NEXT:    [[TMP16:%.*]] = load i32, ptr [[TMP12]], align 4
-; SINK-AFTER-NEXT:    [[TMP17:%.*]] = load i32, ptr [[TMP13]], align 4
+; SINK-AFTER-NEXT:    [[VECTOR_RECUR_EXTRACT_FOR_PHI:%.*]] = load i32, ptr [[TMP13]], align 4
 ; SINK-AFTER-NEXT:    [[TMP18:%.*]] = load i32, ptr [[TMP14]], align 4
 ; SINK-AFTER-NEXT:    [[TMP19:%.*]] = insertelement <4 x i32> poison, i32 [[TMP15]], i32 0
 ; SINK-AFTER-NEXT:    [[TMP20:%.*]] = insertelement <4 x i32> [[TMP19]], i32 [[TMP16]], i32 1
-; SINK-AFTER-NEXT:    [[TMP21:%.*]] = insertelement <4 x i32> [[TMP20]], i32 [[TMP17]], i32 2
+; SINK-AFTER-NEXT:    [[TMP21:%.*]] = insertelement <4 x i32> [[TMP20]], i32 [[VECTOR_RECUR_EXTRACT_FOR_PHI]], i32 2
 ; SINK-AFTER-NEXT:    [[TMP22]] = insertelement <4 x i32> [[TMP21]], i32 [[TMP18]], i32 3
 ; SINK-AFTER-NEXT:    [[TMP23:%.*]] = shufflevector <4 x i32> [[VECTOR_RECUR]], <4 x i32> [[TMP22]], <4 x i32> <i32 3, i32 4, i32 5, i32 6>
 ; SINK-AFTER-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 4
@@ -1236,7 +1235,6 @@ define i32 @PR30183(i32 %pre_load, ptr %a, ptr %b, i64 %n) {
 ; SINK-AFTER:       middle.block:
 ; SINK-AFTER-NEXT:    [[CMP_N:%.*]] = icmp eq i64 [[TMP2]], [[N_VEC]]
 ; SINK-AFTER-NEXT:    [[VECTOR_RECUR_EXTRACT:%.*]] = extractelement <4 x i32> [[TMP22]], i32 3
-; SINK-AFTER-NEXT:    [[VECTOR_RECUR_EXTRACT_FOR_PHI:%.*]] = extractelement <4 x i32> [[TMP22]], i32 2
 ; SINK-AFTER-NEXT:    br i1 [[CMP_N]], label [[FOR_END:%.*]], label [[SCALAR_PH]]
 ; SINK-AFTER:       scalar.ph:
 ; SINK-AFTER-NEXT:    [[SCALAR_RECUR_INIT:%.*]] = phi i32 [ [[PRE_LOAD]], [[ENTRY:%.*]] ], [ [[VECTOR_RECUR_EXTRACT]], [[MIDDLE_BLOCK]] ]
@@ -1403,8 +1401,8 @@ define i32 @extract_second_last_iteration(ptr %cval, i32 %x)  {
 ; UNROLL-NO-IC-NEXT:    [[TMP4:%.*]] = icmp eq i32 [[INDEX_NEXT]], 96
 ; UNROLL-NO-IC-NEXT:    br i1 [[TMP4]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP14:![0-9]+]]
 ; UNROLL-NO-IC:       middle.block:
-; UNROLL-NO-IC-NEXT:    [[VECTOR_RECUR_EXTRACT:%.*]] = extractelement <4 x i32> [[TMP1]], i32 3
 ; UNROLL-NO-IC-NEXT:    [[VECTOR_RECUR_EXTRACT_FOR_PHI:%.*]] = extractelement <4 x i32> [[TMP1]], i32 2
+; UNROLL-NO-IC-NEXT:    [[VECTOR_RECUR_EXTRACT:%.*]] = extractelement <4 x i32> [[TMP1]], i32 3
 ; UNROLL-NO-IC-NEXT:    br i1 true, label [[FOR_END:%.*]], label [[SCALAR_PH]]
 ; UNROLL-NO-IC:       scalar.ph:
 ; UNROLL-NO-IC-NEXT:    [[SCALAR_RECUR_INIT:%.*]] = phi i32 [ 0, [[ENTRY:%.*]] ], [ [[VECTOR_RECUR_EXTRACT]], [[MIDDLE_BLOCK]] ]
@@ -1473,8 +1471,8 @@ define i32 @extract_second_last_iteration(ptr %cval, i32 %x)  {
 ; SINK-AFTER-NEXT:    [[TMP2:%.*]] = icmp eq i32 [[INDEX_NEXT]], 96
 ; SINK-AFTER-NEXT:    br i1 [[TMP2]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP14:![0-9]+]]
 ; SINK-AFTER:       middle.block:
-; SINK-AFTER-NEXT:    [[VECTOR_RECUR_EXTRACT:%.*]] = extractelement <4 x i32> [[TMP0]], i32 3
 ; SINK-AFTER-NEXT:    [[VECTOR_RECUR_EXTRACT_FOR_PHI:%.*]] = extractelement <4 x i32> [[TMP0]], i32 2
+; SINK-AFTER-NEXT:    [[VECTOR_RECUR_EXTRACT:%.*]] = extractelement <4 x i32> [[TMP0]], i32 3
 ; SINK-AFTER-NEXT:    br i1 true, label [[FOR_END:%.*]], label [[SCALAR_PH]]
 ; SINK-AFTER:       scalar.ph:
 ; SINK-AFTER-NEXT:    [[SCALAR_RECUR_INIT:%.*]] = phi i32 [ 0, [[ENTRY:%.*]] ], [ [[VECTOR_RECUR_EXTRACT]], [[MIDDLE_BLOCK]] ]
@@ -3511,8 +3509,8 @@ define i32 @sink_after_dead_inst(ptr %A.ptr) {
 ; UNROLL-NO-IC-NEXT:    [[TMP14:%.*]] = icmp eq i32 [[INDEX_NEXT]], 16
 ; UNROLL-NO-IC-NEXT:    br i1 [[TMP14]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP33:![0-9]+]]
 ; UNROLL-NO-IC:       middle.block:
-; UNROLL-NO-IC-NEXT:    [[VECTOR_RECUR_EXTRACT:%.*]] = extractelement <4 x i32> [[TMP7]], i32 3
 ; UNROLL-NO-IC-NEXT:    [[VECTOR_RECUR_EXTRACT_FOR_PHI:%.*]] = extractelement <4 x i32> [[TMP7]], i32 2
+; UNROLL-NO-IC-NEXT:    [[VECTOR_RECUR_EXTRACT:%.*]] = extractelement <4 x i32> [[TMP7]], i32 3
 ; UNROLL-NO-IC-NEXT:    br i1 true, label [[FOR_END:%.*]], label [[SCALAR_PH]]
 ; UNROLL-NO-IC:       scalar.ph:
 ; UNROLL-NO-IC-NEXT:    [[SCALAR_RECUR_INIT:%.*]] = phi i32 [ 0, [[ENTRY:%.*]] ], [ [[VECTOR_RECUR_EXTRACT]], [[MIDDLE_BLOCK]] ]
@@ -3607,8 +3605,8 @@ define i32 @sink_after_dead_inst(ptr %A.ptr) {
 ; SINK-AFTER-NEXT:    [[TMP7:%.*]] = icmp eq i32 [[INDEX_NEXT]], 16
 ; SINK-AFTER-NEXT:    br i1 [[TMP7]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP33:![0-9]+]]
 ; SINK-AFTER:       middle.block:
-; SINK-AFTER-NEXT:    [[VECTOR_RECUR_EXTRACT:%.*]] = extractelement <4 x i32> [[TMP3]], i32 3
 ; SINK-AFTER-NEXT:    [[VECTOR_RECUR_EXTRACT_FOR_PHI:%.*]] = extractelement <4 x i32> [[TMP3]], i32 2
+; SINK-AFTER-NEXT:    [[VECTOR_RECUR_EXTRACT:%.*]] = extractelement <4 x i32> [[TMP3]], i32 3
 ; SINK-AFTER-NEXT:    br i1 true, label [[FOR_END:%.*]], label [[SCALAR_PH]]
 ; SINK-AFTER:       scalar.ph:
 ; SINK-AFTER-NEXT:    [[SCALAR_RECUR_INIT:%.*]] = phi i32 [ 0, [[ENTRY:%.*]] ], [ [[VECTOR_RECUR_EXTRACT]], [[MIDDLE_BLOCK]] ]
diff --git a/llvm/test/Transforms/LoopVectorize/scalable-first-order-recurrence.ll b/llvm/test/Transforms/LoopVectorize/scalable-first-order-recurrence.ll
index d64755999635c..b959894b671e1 100644
--- a/llvm/test/Transforms/LoopVectorize/scalable-first-order-recurrence.ll
+++ b/llvm/test/Transforms/LoopVectorize/scalable-first-order-recurrence.ll
@@ -25,10 +25,12 @@ define i32 @recurrence_1(ptr nocapture readonly %a, ptr nocapture %b, i32 %n) {
 ; CHECK-VF4UF1: middle.block:
 ; CHECK-VF4UF1: %[[VSCALE2:.*]] = call i32 @llvm.vscale.i32()
 ; CHECK-VF4UF1: %[[MUL2:.*]] = mul i32 %[[VSCALE2]], 4
-; CHECK-VF4UF1: %[[SUB2:.*]] = sub i32 %[[MUL2]], 1
-; CHECK-VF4UF1: %[[VEC_RECUR_EXT:.*]] = extractelement <vscale x 4 x i32> %[[LOAD]], i32 %[[SUB2]]
 ; CHECK-VF4UF1: %[[SUB3:.*]] = sub i32 %[[MUL2]], 2
 ; CHECK-VF4UF1: %[[VEC_RECUR_FOR_PHI:.*]] =  extractelement <vscale x 4 x i32> %[[LOAD]], i32 %[[SUB3]]
+; CHECK-VF4UF1: %[[VSCALE3:.*]] = call i32 @llvm.vscale.i32()
+; CHECK-VF4UF1: %[[MUL3:.*]] = mul i32 %[[VSCALE3]], 4
+; CHECK-VF4UF1: %[[SUB3:.*]] = sub i32 %[[MUL3]], 1
+; CHECK-VF4UF1: %[[VEC_RECUR_EXT:.*]] = extractelement <vscale x 4 x i32> %[[LOAD]], i32 %[[SUB3]]
 entry:
   br label %for.preheader
 
@@ -211,10 +213,12 @@ define i32 @extract_second_last_iteration(ptr %cval, i32 %x)  {
 ; CHECK-VF4UF2: middle.block
 ; CHECK-VF4UF2: %[[VSCALE2:.*]] = call i32 @llvm.vscale.i32()
 ; CHECK-VF4UF2: %[[MUL2:.*]] = mul i32 %[[VSCALE2]], 4
-; CHECK-VF4UF2: %[[SUB2:.*]] = sub i32 %[[MUL2]], 1
-; CHECK-VF4UF2: %vector.recur.extract = extractelement <vscale x 4 x i32> %[[ADD2]], i32 %[[SUB2]]
 ; CHECK-VF4UF2: %[[SUB3:.*]] = sub i32 %[[MUL2]], 2
 ; CHECK-VF4UF2: %vector.recur.extract.for.phi = extractelement <vscale x 4 x i32> %[[ADD2]], i32 %[[SUB3]]
+; CHECK-VF4UF2: %[[VSCALE3:.*]] = call i32 @llvm.vscale.i32()
+; CHECK-VF4UF2: %[[MUL3:.*]] = mul i32 %[[VSCALE3]], 4
+; CHECK-VF4UF2: %[[SUB2:.*]] = sub i32 %[[MUL3]], 1
+; CHECK-VF4UF2: %vector.recur.extract = extractelement <vscale x 4 x i32> %[[ADD2]], i32 %[[SUB2]]
 entry:
   br label %for.body
 
diff --git a/llvm/test/Transforms/LoopVectorize/vplan-printing.ll b/llvm/test/Transforms/LoopVectorize/vplan-printing.ll
index dd7735584737b..4e6e5f890a8b7 100644
--- a/llvm/test/Transforms/LoopVectorize/vplan-printing.ll
+++ b/llvm/test/Transforms/LoopVectorize/vplan-printing.ll
@@ -911,9 +911,10 @@ define i16 @print_first_order_recurrence_and_result(ptr %ptr) {
 ; CHECK-NEXT: Successor(s): middle.block
 ; CHECK-EMPTY:
 ; CHECK-NEXT: middle.block:
+; CHECK-NEXT:   EMIT vp<[[FOR_RESULT:%.+]]> = extract-recurrence-result ir<%for.1.next>
 ; CHECK-NEXT: No successors
 ; CHECK-EMPTY:
-; CHECK-NEXT:   Live-out i16 %for.1.lcssa = vp<[[FOR1_SPLICE]]>
+; CHECK-NEXT:   Live-out i16 %for.1.lcssa = vp<[[FOR_RESULT]]>
 ; CHECK-NEXT: }
 ;
 entry:

>From 254f9e639ccfbf17263ec41151d00e7cb92fbe61 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Sat, 1 Jun 2024 01:05:51 -0700
Subject: [PATCH 2/5] !fixup address comments

---
 llvm/lib/Transforms/Vectorize/VPlan.h         | 10 ++++++---
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp | 22 ++++++++++---------
 .../Transforms/Vectorize/VPlanTransforms.cpp  | 18 +++++++++------
 .../LoopVectorize/vplan-printing.ll           |  2 +-
 4 files changed, 31 insertions(+), 21 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 659f28ec9d832..efcc1c99ea2ff 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -167,7 +167,7 @@ class VPLane {
 
   static VPLane getFirstLane() { return VPLane(0, VPLane::Kind::First); }
 
-  static VPLane getLastLaneForVF(const ElementCount &VF, unsigned Offset = 1) {
+  static VPLane getLaneFromEnd(const ElementCount &VF, unsigned Offset) {
     unsigned LaneOffset = VF.getKnownMinValue() - Offset;
     Kind LaneKind;
     if (VF.isScalable())
@@ -179,6 +179,10 @@ class VPLane {
     return VPLane(LaneOffset, LaneKind);
   }
 
+  static VPLane getLastLaneForVF(const ElementCount &VF) {
+    return getLaneFromEnd(VF, 1);
+  }
+
   /// Returns a compile-time known value for the lane index and asserts if the
   /// lane can only be calculated at runtime.
   unsigned getKnownLane() const {
@@ -1182,7 +1186,7 @@ class VPInstruction : public VPRecipeWithIRFlags {
     BranchOnCount,
     BranchOnCond,
     ComputeReductionResult,
-    ExtractRecurrenceResult,
+    ExtractFromEnd,
     LogicalAnd, // Non-poison propagating logical And.
     // Add an offset in bytes (second operand) to a base pointer (first
     // operand). Only generates scalar values (either for the first lane only or
@@ -3659,7 +3663,7 @@ inline bool isUniformAfterVectorization(VPValue *VPV) {
     return all_of(GEP->operands(), isUniformAfterVectorization);
   if (auto *VPI = dyn_cast<VPInstruction>(Def))
     return VPI->getOpcode() == VPInstruction::ComputeReductionResult ||
-           VPI->getOpcode() == VPInstruction::ExtractRecurrenceResult;
+           VPI->getOpcode() == VPInstruction::ExtractFromEnd;
   return false;
 }
 } // end namespace vputils
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 82e3bd530a8d6..b010268eddfc7 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -137,7 +137,7 @@ bool VPRecipeBase::mayHaveSideEffects() const {
     case VPInstruction::Not:
     case VPInstruction::CalculateTripCountMinusVF:
     case VPInstruction::CanonicalIVIncrementForPart:
-    case VPInstruction::ExtractRecurrenceResult:
+    case VPInstruction::ExtractFromEnd:
     case VPInstruction::LogicalAnd:
     case VPInstruction::PtrAdd:
       return false;
@@ -301,7 +301,7 @@ bool VPInstruction::canGenerateScalarForFirstLane() const {
   case VPInstruction::CalculateTripCountMinusVF:
   case VPInstruction::CanonicalIVIncrementForPart:
   case VPInstruction::ComputeReductionResult:
-  case VPInstruction::ExtractRecurrenceResult:
+  case VPInstruction::ExtractFromEnd:
   case VPInstruction::PtrAdd:
   case VPInstruction::ExplicitVectorLength:
     return true;
@@ -560,23 +560,25 @@ Value *VPInstruction::generatePerPart(VPTransformState &State, unsigned Part) {
 
     return ReducedPartRdx;
   }
-  case VPInstruction::ExtractRecurrenceResult: {
+  case VPInstruction::ExtractFromEnd: {
     if (Part != 0)
       return State.get(this, 0, /*IsScalar*/ true);
 
-    // Extract the second last element in the middle block for users outside the
-    // loop.
+    auto *CI = cast<ConstantInt>(getOperand(1)->getLiveInIRValue());
+    unsigned Offset = CI->getZExtValue();
+
+    // Extract lane VF - Offset in the from the operand.
     Value *Res;
     if (State.VF.isVector()) {
       Res = State.get(
           getOperand(0),
-          VPIteration(State.UF - 1, VPLane::getLastLaneForVF(State.VF, 2)));
+          VPIteration(State.UF - 1, VPLane::getLaneFromEnd(State.VF, Offset)));
     } else {
       assert(State.UF > 1 && "VF and UF cannot both be 1");
       // When loop is unrolled without vectorizing, retrieve the value just
       // prior to the final unrolled value. This is analogous to the vectorized
       // case above: extracting the second last element when VF > 1.
-      Res = State.get(getOperand(0), State.UF - 2);
+      Res = State.get(getOperand(0), State.UF - Offset);
     }
     Res->setName(Name);
     return Res;
@@ -621,7 +623,7 @@ void VPInstruction::execute(VPTransformState &State) {
   bool GeneratesPerFirstLaneOnly =
       canGenerateScalarForFirstLane() &&
       (vputils::onlyFirstLaneUsed(this) ||
-       getOpcode() == VPInstruction::ExtractRecurrenceResult ||
+       getOpcode() == VPInstruction::ExtractFromEnd ||
        getOpcode() == VPInstruction::ComputeReductionResult);
   bool GeneratesPerAllLanes = doesGeneratePerAllLanes();
   for (unsigned Part = 0; Part < State.UF; ++Part) {
@@ -716,8 +718,8 @@ void VPInstruction::print(raw_ostream &O, const Twine &Indent,
   case VPInstruction::BranchOnCount:
     O << "branch-on-count";
     break;
-  case VPInstruction::ExtractRecurrenceResult:
-    O << "extract-recurrence-result";
+  case VPInstruction::ExtractFromEnd:
+    O << "extract-from-end";
     break;
   case VPInstruction::ComputeReductionResult:
     O << "compute-reduction-result";
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index f4f43ed15615e..af3cf0ad7af04 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -802,7 +802,7 @@ sinkRecurrenceUsersAfterPrevious(VPFirstOrderRecurrencePHIRecipe *FOR,
 }
 
 bool VPlanTransforms::adjustFixedOrderRecurrences(VPlan &Plan,
-                                                  VPBuilder &Builder) {
+                                                  VPBuilder &LoopBuilder) {
   VPDominatorTree VPDT;
   VPDT.recalculate(Plan);
 
@@ -833,22 +833,26 @@ bool VPlanTransforms::adjustFixedOrderRecurrences(VPlan &Plan,
     // fixed-order recurrence.
     VPBasicBlock *InsertBlock = Previous->getParent();
     if (isa<VPHeaderPHIRecipe>(Previous))
-      Builder.setInsertPoint(InsertBlock, InsertBlock->getFirstNonPhi());
+      LoopBuilder.setInsertPoint(InsertBlock, InsertBlock->getFirstNonPhi());
     else
-      Builder.setInsertPoint(InsertBlock, std::next(Previous->getIterator()));
+      LoopBuilder.setInsertPoint(InsertBlock,
+                                 std::next(Previous->getIterator()));
 
     auto *RecurSplice = cast<VPInstruction>(
-        Builder.createNaryOp(VPInstruction::FirstOrderRecurrenceSplice,
-                             {FOR, FOR->getBackedgeValue()}));
+        LoopBuilder.createNaryOp(VPInstruction::FirstOrderRecurrenceSplice,
+                                 {FOR, FOR->getBackedgeValue()}));
 
     FOR->replaceAllUsesWith(RecurSplice);
     // Set the first operand of RecurSplice to FOR again, after replacing
     // all users.
     RecurSplice->setOperand(0, FOR);
 
+    Type *IntTy = Plan.getCanonicalIV()->getScalarType();
     auto *Result = cast<VPInstruction>(MiddleBuilder.createNaryOp(
-        VPInstruction::ExtractRecurrenceResult, {FOR->getBackedgeValue()}, {},
-        "vector.recur.extract.for.phi"));
+        VPInstruction::ExtractFromEnd,
+        {FOR->getBackedgeValue(),
+         Plan.getOrAddLiveIn(ConstantInt::get(IntTy, 2))},
+        {}, "vector.recur.extract.for.phi"));
     RecurSplice->replaceUsesWithIf(
         Result, [](VPUser &U, unsigned) { return isa<VPLiveOut>(&U); });
   }
diff --git a/llvm/test/Transforms/LoopVectorize/vplan-printing.ll b/llvm/test/Transforms/LoopVectorize/vplan-printing.ll
index 4e6e5f890a8b7..4a71e18ea3778 100644
--- a/llvm/test/Transforms/LoopVectorize/vplan-printing.ll
+++ b/llvm/test/Transforms/LoopVectorize/vplan-printing.ll
@@ -911,7 +911,7 @@ define i16 @print_first_order_recurrence_and_result(ptr %ptr) {
 ; CHECK-NEXT: Successor(s): middle.block
 ; CHECK-EMPTY:
 ; CHECK-NEXT: middle.block:
-; CHECK-NEXT:   EMIT vp<[[FOR_RESULT:%.+]]> = extract-recurrence-result ir<%for.1.next>
+; CHECK-NEXT:   EMIT vp<[[FOR_RESULT:%.+]]> = extract-from-end ir<%for.1.next>, ir<2>
 ; CHECK-NEXT: No successors
 ; CHECK-EMPTY:
 ; CHECK-NEXT:   Live-out i16 %for.1.lcssa = vp<[[FOR_RESULT]]>

>From f226ee442df2598e3356d23b4ecdd8280fd261ae Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Mon, 3 Jun 2024 10:11:30 +0100
Subject: [PATCH 3/5] !fixup address latest comments, thanks!

---
 llvm/lib/Transforms/Vectorize/VPlan.h         | 10 ++++++++++
 .../Transforms/Vectorize/VPlanAnalysis.cpp    |  6 ++++++
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp | 20 +++++++++----------
 3 files changed, 26 insertions(+), 10 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index efcc1c99ea2ff..69e5de1407926 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -168,6 +168,8 @@ class VPLane {
   static VPLane getFirstLane() { return VPLane(0, VPLane::Kind::First); }
 
   static VPLane getLaneFromEnd(const ElementCount &VF, unsigned Offset) {
+    assert(Offset <= VF.getKnownMinValue() &&
+           "trying to extract with invalid offset");
     unsigned LaneOffset = VF.getKnownMinValue() - Offset;
     Kind LaneKind;
     if (VF.isScalable())
@@ -1186,6 +1188,10 @@ class VPInstruction : public VPRecipeWithIRFlags {
     BranchOnCount,
     BranchOnCond,
     ComputeReductionResult,
+    // Takes the VPValue to extract from as first operand and the lane to
+    // extract from as second operand. The second operand must be a constant and
+    // <= VF when extracting from a vector or <= UF when extracting from a
+    // scalar.
     ExtractFromEnd,
     LogicalAnd, // Non-poison propagating logical And.
     // Add an offset in bytes (second operand) to a base pointer (first
@@ -1224,6 +1230,10 @@ class VPInstruction : public VPRecipeWithIRFlags {
   /// value for lane \p Lane.
   Value *generatePerLane(VPTransformState &State, const VPIteration &Lane);
 
+  /// Returns true if this VPInstruction converts a vector value to a scalar,
+  /// e.g. by performing a reduction or extracting a lane.
+  bool isVectorToScalar() const;
+
 #if !defined(NDEBUG)
   /// Return true if the VPInstruction is a floating point math operation, i.e.
   /// has fast-math flags.
diff --git a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
index efe8c21874a3a..6accebdd8f9d3 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
@@ -45,6 +45,12 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPInstruction *R) {
     CachedTypes[OtherV] = ResTy;
     return ResTy;
   }
+  case VPInstruction::ExtractFromEnd: {
+    Type *BaseTy = inferScalarType(R->getOperand(0));
+    if (auto *VecTy = dyn_cast<VectorType>(BaseTy))
+      return VecTy->getElementType();
+    return BaseTy;
+  }
   case VPInstruction::Not: {
     Type *ResTy = inferScalarType(R->getOperand(0));
     assert(IntegerType::get(Ctx, 1) == ResTy &&
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index b010268eddfc7..956f16563c40b 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -294,14 +294,13 @@ bool VPInstruction::doesGeneratePerAllLanes() const {
 bool VPInstruction::canGenerateScalarForFirstLane() const {
   if (Instruction::isBinaryOp(getOpcode()))
     return true;
-
+  if (isVectorToScalar())
+    return true;
   switch (Opcode) {
   case VPInstruction::BranchOnCond:
   case VPInstruction::BranchOnCount:
   case VPInstruction::CalculateTripCountMinusVF:
   case VPInstruction::CanonicalIVIncrementForPart:
-  case VPInstruction::ComputeReductionResult:
-  case VPInstruction::ExtractFromEnd:
   case VPInstruction::PtrAdd:
   case VPInstruction::ExplicitVectorLength:
     return true;
@@ -567,17 +566,15 @@ Value *VPInstruction::generatePerPart(VPTransformState &State, unsigned Part) {
     auto *CI = cast<ConstantInt>(getOperand(1)->getLiveInIRValue());
     unsigned Offset = CI->getZExtValue();
 
-    // Extract lane VF - Offset in the from the operand.
     Value *Res;
     if (State.VF.isVector()) {
+      // Extract lane VF - Offset from the operand.
       Res = State.get(
           getOperand(0),
           VPIteration(State.UF - 1, VPLane::getLaneFromEnd(State.VF, Offset)));
     } else {
       assert(State.UF > 1 && "VF and UF cannot both be 1");
-      // When loop is unrolled without vectorizing, retrieve the value just
-      // prior to the final unrolled value. This is analogous to the vectorized
-      // case above: extracting the second last element when VF > 1.
+      // When loop is unrolled without vectorizing, retrieve UF - Offset.
       Res = State.get(getOperand(0), State.UF - Offset);
     }
     Res->setName(Name);
@@ -600,6 +597,11 @@ Value *VPInstruction::generatePerPart(VPTransformState &State, unsigned Part) {
   }
 }
 
+bool VPInstruction::isVectorToScalar() const {
+  return getOpcode() == VPInstruction::ExtractFromEnd ||
+         getOpcode() == VPInstruction::ComputeReductionResult;
+}
+
 #if !defined(NDEBUG)
 bool VPInstruction::isFPMathOp() const {
   // Inspired by FPMathOperator::classof. Notable differences are that we don't
@@ -622,9 +624,7 @@ void VPInstruction::execute(VPTransformState &State) {
   State.setDebugLocFrom(getDebugLoc());
   bool GeneratesPerFirstLaneOnly =
       canGenerateScalarForFirstLane() &&
-      (vputils::onlyFirstLaneUsed(this) ||
-       getOpcode() == VPInstruction::ExtractFromEnd ||
-       getOpcode() == VPInstruction::ComputeReductionResult);
+      (vputils::onlyFirstLaneUsed(this) || isVectorToScalar());
   bool GeneratesPerAllLanes = doesGeneratePerAllLanes();
   for (unsigned Part = 0; Part < State.UF; ++Part) {
     if (GeneratesPerAllLanes) {

>From 54505a83d809dad898dafae5144d1af8794d6aa0 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Mon, 3 Jun 2024 10:18:00 +0100
Subject: [PATCH 4/5] !fixup add assert to ExtractFromEnd

---
 llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 956f16563c40b..8ef2e4b5ee09d 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -568,12 +568,14 @@ Value *VPInstruction::generatePerPart(VPTransformState &State, unsigned Part) {
 
     Value *Res;
     if (State.VF.isVector()) {
+      assert(Offset <= State.VF.getKnownMinValue() &&
+             "invalid offset to extract from");
       // Extract lane VF - Offset from the operand.
       Res = State.get(
           getOperand(0),
           VPIteration(State.UF - 1, VPLane::getLaneFromEnd(State.VF, Offset)));
     } else {
-      assert(State.UF > 1 && "VF and UF cannot both be 1");
+      assert(Offset <= State.UF && "invalid offset to extract from");
       // When loop is unrolled without vectorizing, retrieve UF - Offset.
       Res = State.get(getOperand(0), State.UF - Offset);
     }

>From fb5997b7d7c5c905723c2941e4d1b57b5a9d0275 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Mon, 3 Jun 2024 15:57:03 +0100
Subject: [PATCH 5/5] !fixup address latest comments, thanks!

---
 llvm/lib/Transforms/Vectorize/VPlan.h         | 20 +++++++++----------
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp |  2 +-
 2 files changed, 11 insertions(+), 11 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 69e5de1407926..bd500728883b3 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -168,7 +168,7 @@ class VPLane {
   static VPLane getFirstLane() { return VPLane(0, VPLane::Kind::First); }
 
   static VPLane getLaneFromEnd(const ElementCount &VF, unsigned Offset) {
-    assert(Offset <= VF.getKnownMinValue() &&
+    assert(Offset > 0 && Offset <= VF.getKnownMinValue() &&
            "trying to extract with invalid offset");
     unsigned LaneOffset = VF.getKnownMinValue() - Offset;
     Kind LaneKind;
@@ -1188,9 +1188,10 @@ class VPInstruction : public VPRecipeWithIRFlags {
     BranchOnCount,
     BranchOnCond,
     ComputeReductionResult,
-    // Takes the VPValue to extract from as first operand and the lane to
-    // extract from as second operand. The second operand must be a constant and
-    // <= VF when extracting from a vector or <= UF when extracting from a
+    // Takes the VPValue to extract from as first operand and the lane or part
+    // to extract as second operand, counting from the end starting with 1 for
+    // last. The second operand must be a positive constant and <= VF when
+    // extracting from a vector or <= UF when extracting from an unrolled
     // scalar.
     ExtractFromEnd,
     LogicalAnd, // Non-poison propagating logical And.
@@ -1230,10 +1231,6 @@ class VPInstruction : public VPRecipeWithIRFlags {
   /// value for lane \p Lane.
   Value *generatePerLane(VPTransformState &State, const VPIteration &Lane);
 
-  /// Returns true if this VPInstruction converts a vector value to a scalar,
-  /// e.g. by performing a reduction or extracting a lane.
-  bool isVectorToScalar() const;
-
 #if !defined(NDEBUG)
   /// Return true if the VPInstruction is a floating point math operation, i.e.
   /// has fast-math flags.
@@ -1342,6 +1339,10 @@ class VPInstruction : public VPRecipeWithIRFlags {
     };
     llvm_unreachable("switch should return");
   }
+
+  /// Returns true if this VPInstruction produces a scalar value from a vector,
+  /// e.g. by performing a reduction or extracting a lane.
+  bool isVectorToScalar() const;
 };
 
 /// VPWidenRecipe is a recipe for producing a copy of vector type its
@@ -3672,8 +3673,7 @@ inline bool isUniformAfterVectorization(VPValue *VPV) {
   if (auto *GEP = dyn_cast<VPWidenGEPRecipe>(Def))
     return all_of(GEP->operands(), isUniformAfterVectorization);
   if (auto *VPI = dyn_cast<VPInstruction>(Def))
-    return VPI->getOpcode() == VPInstruction::ComputeReductionResult ||
-           VPI->getOpcode() == VPInstruction::ExtractFromEnd;
+    return VPI->isVectorToScalar();
   return false;
 }
 } // end namespace vputils
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 8ef2e4b5ee09d..cb707d7c0e24f 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -565,7 +565,7 @@ Value *VPInstruction::generatePerPart(VPTransformState &State, unsigned Part) {
 
     auto *CI = cast<ConstantInt>(getOperand(1)->getLiveInIRValue());
     unsigned Offset = CI->getZExtValue();
-
+    assert(Offset > 0 && "Offset from end must be positive");
     Value *Res;
     if (State.VF.isVector()) {
       assert(Offset <= State.VF.getKnownMinValue() &&



More information about the llvm-commits mailing list