[llvm] [VPlan] Truncate/Extend ComputeReductionResult at construction (NFC). (PR #141860)

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Sun Jun 1 14:16:03 PDT 2025


https://github.com/fhahn updated https://github.com/llvm/llvm-project/pull/141860

>From 6eb0f459eb1d1fe7328d49b6f479fffb7b813d1b Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Sun, 1 Jun 2025 22:07:35 +0100
Subject: [PATCH] [VPlan] Truncate/Extend ComputeReductionResult at
 construction (NFC).

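Instead of looking up the narrower reduction type via getRecurrenceType,
we can generate the needed extend directly at construction and re-use the
truncated value from the loop. The corresponding truncate/extend logic is
removed from ComputeReductionResult's code generation.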
---
 .../Transforms/Vectorize/LoopVectorize.cpp    | 54 +++++++++++--------
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp | 23 ++------
 .../LoopVectorize/X86/cost-model.ll           |  3 +-
 .../epilog-vectorization-reductions.ll        |  6 +--
 .../LoopVectorize/reduction-small-size.ll     |  9 ++--
 .../scalable-reduction-inloop.ll              |  4 +-
 6 files changed, 43 insertions(+), 56 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index e9ace195684b3..e803edf18a926 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7536,6 +7536,13 @@ static void fixReductionScalarResumeWhenVectorizingEpilog(
   // created a bc.merge.rdx Phi after the main vector body. Ensure that we carry
   // over the incoming values correctly.
   using namespace VPlanPatternMatch;
+  if (EpiRedResult->getNumUsers() == 1 &&
+      isa<VPInstructionWithType>(*EpiRedResult->user_begin())) {
+    EpiRedResult = cast<VPInstructionWithType>(*EpiRedResult->user_begin());
+    assert((EpiRedResult->getOpcode() == Instruction::SExt ||
+            EpiRedResult->getOpcode() == Instruction::ZExt) &&
+           "can only have SExt/ZExt users");
+  }
   assert(count_if(EpiRedResult->users(), IsaPred<VPPhi>) == 1 &&
          "ResumePhi must have a single user");
   auto *EpiResumePhiVPI =
@@ -9468,28 +9475,6 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
         PhiR->setOperand(1, NewExitingVPV);
     }
 
-    // If the vector reduction can be performed in a smaller type, we truncate
-    // then extend the loop exit value to enable InstCombine to evaluate the
-    // entire expression in the smaller type.
-    if (MinVF.isVector() && PhiTy != RdxDesc.getRecurrenceType() &&
-        !RecurrenceDescriptor::isAnyOfRecurrenceKind(
-            RdxDesc.getRecurrenceKind())) {
-      assert(!PhiR->isInLoop() && "Unexpected truncated inloop reduction!");
-      Type *RdxTy = RdxDesc.getRecurrenceType();
-      auto *Trunc =
-          new VPWidenCastRecipe(Instruction::Trunc, NewExitingVPV, RdxTy);
-      auto *Extnd =
-          RdxDesc.isSigned()
-              ? new VPWidenCastRecipe(Instruction::SExt, Trunc, PhiTy)
-              : new VPWidenCastRecipe(Instruction::ZExt, Trunc, PhiTy);
-
-      Trunc->insertAfter(NewExitingVPV->getDefiningRecipe());
-      Extnd->insertAfter(Trunc);
-      if (PhiR->getOperand(1) == NewExitingVPV)
-        PhiR->setOperand(1, Extnd->getVPSingleValue());
-      NewExitingVPV = Extnd;
-    }
-
     // We want code in the middle block to appear to execute on the location of
     // the scalar loop's latch terminator because: (a) it is all compiler
     // generated, (b) these instructions are always executed after evaluating
@@ -9521,6 +9506,31 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
           Builder.createNaryOp(VPInstruction::ComputeReductionResult,
                                {PhiR, NewExitingVPV}, Flags, ExitDL);
     }
+    // If the vector reduction can be performed in a smaller type, we truncate
+    // then extend the loop exit value to enable InstCombine to evaluate the
+    // entire expression in the smaller type.
+    if (MinVF.isVector() && PhiTy != RdxDesc.getRecurrenceType() &&
+        !RecurrenceDescriptor::isAnyOfRecurrenceKind(
+            RdxDesc.getRecurrenceKind())) {
+      assert(!PhiR->isInLoop() && "Unexpected truncated inloop reduction!");
+      Type *RdxTy = RdxDesc.getRecurrenceType();
+      auto *Trunc =
+          new VPWidenCastRecipe(Instruction::Trunc, NewExitingVPV, RdxTy);
+      Instruction::CastOps ExtendOpc =
+          RdxDesc.isSigned() ? Instruction::SExt : Instruction::ZExt;
+      auto *Extnd = new VPWidenCastRecipe(ExtendOpc, Trunc, PhiTy);
+      Trunc->insertAfter(NewExitingVPV->getDefiningRecipe());
+      Extnd->insertAfter(Trunc);
+      if (PhiR->getOperand(1) == NewExitingVPV)
+        PhiR->setOperand(1, Extnd->getVPSingleValue());
+
+      // Update ComputeReductionResult with the truncated exiting value and
+      // extend its result.
+      FinalReductionResult->setOperand(1, Trunc);
+      FinalReductionResult =
+          Builder.createScalarCast(ExtendOpc, FinalReductionResult, PhiTy, {});
+    }
+
     // Update all users outside the vector region.
     OrigExitingVPV->replaceUsesWithIf(
         FinalReductionResult, [FinalReductionResult](VPUser &User, unsigned) {
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index a4831ea7c11f7..672c1f2d9524c 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -633,7 +633,6 @@ Value *VPInstruction::generate(VPTransformState &State) {
     // FIXME: The cross-recipe dependency on VPReductionPHIRecipe is temporary
     // and will be removed by breaking up the recipe further.
     auto *PhiR = cast<VPReductionPHIRecipe>(getOperand(0));
-    auto *OrigPhi = cast<PHINode>(PhiR->getUnderlyingValue());
     // Get its reduction variable descriptor.
     const RecurrenceDescriptor &RdxDesc = PhiR->getRecurrenceDescriptor();
 
@@ -641,7 +640,6 @@ Value *VPInstruction::generate(VPTransformState &State) {
     assert(!RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK) &&
            "should be handled by ComputeFindLastIVResult");
 
-    Type *PhiTy = OrigPhi->getType();
     // The recipe's operands are the reduction phi, followed by one operand for
     // each part of the reduction.
     unsigned UF = getNumOperands() - 1;
@@ -653,15 +651,6 @@ Value *VPInstruction::generate(VPTransformState &State) {
     if (hasFastMathFlags())
       Builder.setFastMathFlags(getFastMathFlags());
 
-    // If the vector reduction can be performed in a smaller type, we truncate
-    // then extend the loop exit value to enable InstCombine to evaluate the
-    // entire expression in the smaller type.
-    // TODO: Handle this in truncateToMinBW.
-    if (State.VF.isVector() && PhiTy != RdxDesc.getRecurrenceType()) {
-      Type *RdxVecTy = VectorType::get(RdxDesc.getRecurrenceType(), State.VF);
-      for (unsigned Part = 0; Part < UF; ++Part)
-        RdxParts[Part] = Builder.CreateTrunc(RdxParts[Part], RdxVecTy);
-    }
     // Reduce all of the unrolled parts into a single vector.
     Value *ReducedPartRdx = RdxParts[0];
     if (PhiR->isOrdered()) {
@@ -687,19 +676,14 @@ Value *VPInstruction::generate(VPTransformState &State) {
       // TODO: Support in-order reductions based on the recurrence descriptor.
       // All ops in the reduction inherit fast-math-flags from the recurrence
       // descriptor.
-      if (RecurrenceDescriptor::isAnyOfRecurrenceKind(RK))
+      if (RecurrenceDescriptor::isAnyOfRecurrenceKind(RK)) {
+        auto *OrigPhi = cast<PHINode>(PhiR->getUnderlyingValue());
         ReducedPartRdx =
             createAnyOfReduction(Builder, ReducedPartRdx,
                                  RdxDesc.getRecurrenceStartValue(), OrigPhi);
-      else
+      } else
         ReducedPartRdx = createSimpleReduction(Builder, ReducedPartRdx, RK);
 
-      // If the reduction can be performed in a smaller type, we need to extend
-      // the reduction to the wider type before we branch to the original loop.
-      if (PhiTy != RdxDesc.getRecurrenceType())
-        ReducedPartRdx = RdxDesc.isSigned()
-                             ? Builder.CreateSExt(ReducedPartRdx, PhiTy)
-                             : Builder.CreateZExt(ReducedPartRdx, PhiTy);
     }
 
     return ReducedPartRdx;
@@ -1040,6 +1024,7 @@ void VPInstruction::print(raw_ostream &O, const Twine &Indent,
 void VPInstructionWithType::execute(VPTransformState &State) {
   State.setDebugLocFrom(getDebugLoc());
   switch (getOpcode()) {
+  case Instruction::SExt:
   case Instruction::ZExt:
   case Instruction::Trunc: {
     Value *Op = State.get(getOperand(0), VPLane(0));
diff --git a/llvm/test/Transforms/LoopVectorize/X86/cost-model.ll b/llvm/test/Transforms/LoopVectorize/X86/cost-model.ll
index 7c42c3d9cd52e..a6ac9c2886a92 100644
--- a/llvm/test/Transforms/LoopVectorize/X86/cost-model.ll
+++ b/llvm/test/Transforms/LoopVectorize/X86/cost-model.ll
@@ -1167,8 +1167,7 @@ define i32 @narrowed_reduction(ptr %a, i1 %cmp) #0 {
 ; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 16
 ; CHECK-NEXT:    br i1 true, label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP28:![0-9]+]]
 ; CHECK:       middle.block:
-; CHECK-NEXT:    [[TMP10:%.*]] = trunc <16 x i32> [[TMP7]] to <16 x i1>
-; CHECK-NEXT:    [[TMP20:%.*]] = call i1 @llvm.vector.reduce.or.v16i1(<16 x i1> [[TMP10]])
+; CHECK-NEXT:    [[TMP20:%.*]] = call i1 @llvm.vector.reduce.or.v16i1(<16 x i1> [[TMP5]])
 ; CHECK-NEXT:    [[TMP21:%.*]] = zext i1 [[TMP20]] to i32
 ; CHECK-NEXT:    br i1 true, label [[EXIT:%.*]], label [[VEC_EPILOG_PH]]
 ; CHECK:       scalar.ph:
diff --git a/llvm/test/Transforms/LoopVectorize/epilog-vectorization-reductions.ll b/llvm/test/Transforms/LoopVectorize/epilog-vectorization-reductions.ll
index 0a2bb8d5682f2..c101d6a19aa2e 100644
--- a/llvm/test/Transforms/LoopVectorize/epilog-vectorization-reductions.ll
+++ b/llvm/test/Transforms/LoopVectorize/epilog-vectorization-reductions.ll
@@ -208,8 +208,7 @@ define i16 @reduction_or_trunc(ptr noalias nocapture %ptr) {
 ; CHECK-NEXT:    [[TMP8:%.*]] = icmp eq i32 [[INDEX_NEXT]], 256
 ; CHECK-NEXT:    br i1 [[TMP8]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP8:![0-9]+]]
 ; CHECK:       middle.block:
-; CHECK-NEXT:    [[TMP9:%.*]] = trunc <4 x i32> [[TMP7]] to <4 x i16>
-; CHECK-NEXT:    [[TMP10:%.*]] = call i16 @llvm.vector.reduce.or.v4i16(<4 x i16> [[TMP9]])
+; CHECK-NEXT:    [[TMP10:%.*]] = call i16 @llvm.vector.reduce.or.v4i16(<4 x i16> [[TMP6]])
 ; CHECK-NEXT:    [[TMP11:%.*]] = zext i16 [[TMP10]] to i32
 ; CHECK-NEXT:    br i1 true, label [[FOR_END:%.*]], label [[VEC_EPILOG_ITER_CHECK:%.*]]
 ; CHECK:       vec.epilog.iter.check:
@@ -234,8 +233,7 @@ define i16 @reduction_or_trunc(ptr noalias nocapture %ptr) {
 ; CHECK-NEXT:    [[TMP21:%.*]] = icmp eq i32 [[INDEX_NEXT4]], 256
 ; CHECK-NEXT:    br i1 [[TMP21]], label [[VEC_EPILOG_MIDDLE_BLOCK:%.*]], label [[VEC_EPILOG_VECTOR_BODY]], !llvm.loop [[LOOP9:![0-9]+]]
 ; CHECK:       vec.epilog.middle.block:
-; CHECK-NEXT:    [[TMP22:%.*]] = trunc <4 x i32> [[TMP20]] to <4 x i16>
-; CHECK-NEXT:    [[TMP23:%.*]] = call i16 @llvm.vector.reduce.or.v4i16(<4 x i16> [[TMP22]])
+; CHECK-NEXT:    [[TMP23:%.*]] = call i16 @llvm.vector.reduce.or.v4i16(<4 x i16> [[TMP19]])
 ; CHECK-NEXT:    [[TMP24:%.*]] = zext i16 [[TMP23]] to i32
 ; CHECK-NEXT:    br i1 true, label [[FOR_END]], label [[VEC_EPILOG_SCALAR_PH]]
 ; CHECK:       vec.epilog.scalar.ph:
diff --git a/llvm/test/Transforms/LoopVectorize/reduction-small-size.ll b/llvm/test/Transforms/LoopVectorize/reduction-small-size.ll
index 796c1d116aa19..13cc1b657d231 100644
--- a/llvm/test/Transforms/LoopVectorize/reduction-small-size.ll
+++ b/llvm/test/Transforms/LoopVectorize/reduction-small-size.ll
@@ -25,8 +25,7 @@ define i8 @PR34687(i1 %c, i32 %x, i32 %n) {
 ; CHECK-NEXT:    [[TMP5:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]
 ; CHECK-NEXT:    br i1 [[TMP5]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
 ; CHECK:       middle.block:
-; CHECK-NEXT:    [[TMP6:%.*]] = trunc <4 x i32> [[TMP4]] to <4 x i8>
-; CHECK-NEXT:    [[TMP7:%.*]] = call i8 @llvm.vector.reduce.add.v4i8(<4 x i8> [[TMP6]])
+; CHECK-NEXT:    [[TMP7:%.*]] = call i8 @llvm.vector.reduce.add.v4i8(<4 x i8> [[TMP3]])
 ; CHECK-NEXT:    [[TMP8:%.*]] = zext i8 [[TMP7]] to i32
 ; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i32 [[N]], [[N_VEC]]
 ; CHECK-NEXT:    br i1 [[CMP_N]], label [[FOR_END:%.*]], label [[SCALAR_PH]]
@@ -104,8 +103,7 @@ define i8 @PR34687_no_undef(i1 %c, i32 %x, i32 %n) {
 ; CHECK-NEXT:    [[TMP7:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]
 ; CHECK-NEXT:    br i1 [[TMP7]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP4:![0-9]+]]
 ; CHECK:       middle.block:
-; CHECK-NEXT:    [[TMP8:%.*]] = trunc <4 x i32> [[TMP6]] to <4 x i8>
-; CHECK-NEXT:    [[TMP9:%.*]] = call i8 @llvm.vector.reduce.add.v4i8(<4 x i8> [[TMP8]])
+; CHECK-NEXT:    [[TMP9:%.*]] = call i8 @llvm.vector.reduce.add.v4i8(<4 x i8> [[TMP5]])
 ; CHECK-NEXT:    [[TMP10:%.*]] = zext i8 [[TMP9]] to i32
 ; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i32 [[N]], [[N_VEC]]
 ; CHECK-NEXT:    br i1 [[CMP_N]], label [[FOR_END:%.*]], label [[SCALAR_PH]]
@@ -183,8 +181,7 @@ define i32 @PR35734(i32 %x, i32 %y) {
 ; CHECK-NEXT:    [[TMP7:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]
 ; CHECK-NEXT:    br i1 [[TMP7]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP6:![0-9]+]]
 ; CHECK:       middle.block:
-; CHECK-NEXT:    [[TMP8:%.*]] = trunc <4 x i32> [[TMP6]] to <4 x i1>
-; CHECK-NEXT:    [[TMP9:%.*]] = call i1 @llvm.vector.reduce.add.v4i1(<4 x i1> [[TMP8]])
+; CHECK-NEXT:    [[TMP9:%.*]] = call i1 @llvm.vector.reduce.add.v4i1(<4 x i1> [[TMP5]])
 ; CHECK-NEXT:    [[TMP10:%.*]] = sext i1 [[TMP9]] to i32
 ; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i32 [[TMP1]], [[N_VEC]]
 ; CHECK-NEXT:    br i1 [[CMP_N]], label [[FOR_END:%.*]], label [[SCALAR_PH]]
diff --git a/llvm/test/Transforms/LoopVectorize/scalable-reduction-inloop.ll b/llvm/test/Transforms/LoopVectorize/scalable-reduction-inloop.ll
index 223acfa2e3a25..6e251cb7b0ad3 100644
--- a/llvm/test/Transforms/LoopVectorize/scalable-reduction-inloop.ll
+++ b/llvm/test/Transforms/LoopVectorize/scalable-reduction-inloop.ll
@@ -28,9 +28,7 @@ define i8 @reduction_add_trunc(ptr noalias nocapture %A) {
 ; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i32 [[INDEX]], [[TMP31]]
 ; CHECK-NEXT:    [[TMP32:%.*]] = icmp eq i32 [[INDEX_NEXT]], {{%.*}}
 ; CHECK:       middle.block:
-; CHECK-NEXT:    [[TMP37:%.*]] = trunc <vscale x 8 x i32> [[TMP34]] to <vscale x 8 x i8>
-; CHECK-NEXT:    [[TMP38:%.*]] = trunc <vscale x 8 x i32> [[TMP36]] to <vscale x 8 x i8>
-; CHECK-NEXT:    [[BIN_RDX:%.*]] = add <vscale x 8 x i8> [[TMP38]], [[TMP37]]
+; CHECK-NEXT:    [[BIN_RDX:%.*]] = add <vscale x 8 x i8> [[TMP35]], [[TMP33]]
 ; CHECK-NEXT:    [[TMP39:%.*]] = call i8 @llvm.vector.reduce.add.nxv8i8(<vscale x 8 x i8> [[BIN_RDX]])
 ; CHECK-NEXT:    [[TMP40:%.*]] = zext i8 [[TMP39]] to i32
 ; CHECK-NEXT:    %cmp.n = icmp eq i32 256, %n.vec



More information about the llvm-commits mailing list