[llvm] [VPlan] Truncate/Extend ComputeReductionResult at construction (NFC). (PR #141860)

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Mon Jun 9 14:16:41 PDT 2025


https://github.com/fhahn updated https://github.com/llvm/llvm-project/pull/141860

>From 76eb1d84e92863c563489f1d4ed84fcd5f55228e Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Sun, 1 Jun 2025 22:07:35 +0100
Subject: [PATCH] [VPlan] Truncate/Extend ComputeReductionResult at
 construction (NFC).

Instead of looking up the narrower reduction type via getRecurrenceType
we can generate the needed extend directly at construction and re-use the
truncated value from the loop.
---
 .../Transforms/Vectorize/LoopVectorize.cpp    | 52 +++++++++++--------
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp | 18 +------
 .../LoopVectorize/X86/cost-model.ll           |  3 +-
 .../epilog-vectorization-reductions.ll        |  6 +--
 .../LoopVectorize/reduction-small-size.ll     |  9 ++--
 .../scalable-reduction-inloop.ll              |  4 +-
 6 files changed, 37 insertions(+), 55 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 333e50ee98418..ff1f94d84402e 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7226,7 +7226,10 @@ static void fixReductionScalarResumeWhenVectorizingEpilog(
   // Get the VPInstruction computing the reduction result in the middle block.
   // The first operand may not be from the middle block if it is not connected
   // to the scalar preheader. In that case, there's nothing to fix.
-  auto *EpiRedResult = dyn_cast<VPInstruction>(EpiResumePhiR->getOperand(0));
+  VPValue *Incoming = EpiResumePhiR->getOperand(0);
+  match(Incoming, VPlanPatternMatch::m_ZExtOrSExt(
+                      VPlanPatternMatch::m_VPValue(Incoming)));
+  auto *EpiRedResult = dyn_cast<VPInstruction>(Incoming);
   if (!EpiRedResult ||
       (EpiRedResult->getOpcode() != VPInstruction::ComputeAnyOfResult &&
        EpiRedResult->getOpcode() != VPInstruction::ComputeReductionResult &&
@@ -9201,28 +9204,6 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
         PhiR->setOperand(1, NewExitingVPV);
     }
 
-    // If the vector reduction can be performed in a smaller type, we truncate
-    // then extend the loop exit value to enable InstCombine to evaluate the
-    // entire expression in the smaller type.
-    if (MinVF.isVector() && PhiTy != RdxDesc.getRecurrenceType() &&
-        !RecurrenceDescriptor::isAnyOfRecurrenceKind(
-            RdxDesc.getRecurrenceKind())) {
-      assert(!PhiR->isInLoop() && "Unexpected truncated inloop reduction!");
-      Type *RdxTy = RdxDesc.getRecurrenceType();
-      auto *Trunc =
-          new VPWidenCastRecipe(Instruction::Trunc, NewExitingVPV, RdxTy);
-      auto *Extnd =
-          RdxDesc.isSigned()
-              ? new VPWidenCastRecipe(Instruction::SExt, Trunc, PhiTy)
-              : new VPWidenCastRecipe(Instruction::ZExt, Trunc, PhiTy);
-
-      Trunc->insertAfter(NewExitingVPV->getDefiningRecipe());
-      Extnd->insertAfter(Trunc);
-      if (PhiR->getOperand(1) == NewExitingVPV)
-        PhiR->setOperand(1, Extnd->getVPSingleValue());
-      NewExitingVPV = Extnd;
-    }
-
     // We want code in the middle block to appear to execute on the location of
     // the scalar loop's latch terminator because: (a) it is all compiler
     // generated, (b) these instructions are always executed after evaluating
@@ -9260,6 +9241,31 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
           Builder.createNaryOp(VPInstruction::ComputeReductionResult,
                                {PhiR, NewExitingVPV}, Flags, ExitDL);
     }
+    // If the vector reduction can be performed in a smaller type, we truncate
+    // then extend the loop exit value to enable InstCombine to evaluate the
+    // entire expression in the smaller type.
+    if (MinVF.isVector() && PhiTy != RdxDesc.getRecurrenceType() &&
+        !RecurrenceDescriptor::isAnyOfRecurrenceKind(
+            RdxDesc.getRecurrenceKind())) {
+      assert(!PhiR->isInLoop() && "Unexpected truncated inloop reduction!");
+      Type *RdxTy = RdxDesc.getRecurrenceType();
+      auto *Trunc =
+          new VPWidenCastRecipe(Instruction::Trunc, NewExitingVPV, RdxTy);
+      Instruction::CastOps ExtendOpc =
+          RdxDesc.isSigned() ? Instruction::SExt : Instruction::ZExt;
+      auto *Extnd = new VPWidenCastRecipe(ExtendOpc, Trunc, PhiTy);
+      Trunc->insertAfter(NewExitingVPV->getDefiningRecipe());
+      Extnd->insertAfter(Trunc);
+      if (PhiR->getOperand(1) == NewExitingVPV)
+        PhiR->setOperand(1, Extnd->getVPSingleValue());
+
+      // Update ComputeReductionResult with the truncated exiting value and
+      // extend its result.
+      FinalReductionResult->setOperand(1, Trunc);
+      FinalReductionResult =
+          Builder.createScalarCast(ExtendOpc, FinalReductionResult, PhiTy, {});
+    }
+
     // Update all users outside the vector region.
     OrigExitingVPV->replaceUsesWithIf(
         FinalReductionResult, [FinalReductionResult](VPUser &User, unsigned) {
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 62b99d98a2b5e..0f72350b973d5 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -668,7 +668,6 @@ Value *VPInstruction::generate(VPTransformState &State) {
     assert(!RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK) &&
            "should be handled by ComputeFindLastIVResult");
 
-    Type *ResultTy = State.TypeAnalysis.inferScalarType(this);
     // The recipe's operands are the reduction phi, followed by one operand for
     // each part of the reduction.
     unsigned UF = getNumOperands() - 1;
@@ -680,15 +679,6 @@ Value *VPInstruction::generate(VPTransformState &State) {
     if (hasFastMathFlags())
       Builder.setFastMathFlags(getFastMathFlags());
 
-    // If the vector reduction can be performed in a smaller type, we truncate
-    // then extend the loop exit value to enable InstCombine to evaluate the
-    // entire expression in the smaller type.
-    // TODO: Handle this in truncateToMinBW.
-    if (State.VF.isVector() && ResultTy != RdxDesc.getRecurrenceType()) {
-      Type *RdxVecTy = VectorType::get(RdxDesc.getRecurrenceType(), State.VF);
-      for (unsigned Part = 0; Part < UF; ++Part)
-        RdxParts[Part] = Builder.CreateTrunc(RdxParts[Part], RdxVecTy);
-    }
     // Reduce all of the unrolled parts into a single vector.
     Value *ReducedPartRdx = RdxParts[0];
     if (PhiR->isOrdered()) {
@@ -713,13 +703,6 @@ Value *VPInstruction::generate(VPTransformState &State) {
       // All ops in the reduction inherit fast-math-flags from the recurrence
       // descriptor.
       ReducedPartRdx = createSimpleReduction(Builder, ReducedPartRdx, RK);
-
-      // If the reduction can be performed in a smaller type, we need to extend
-      // the reduction to the wider type before we branch to the original loop.
-      if (ResultTy != RdxDesc.getRecurrenceType())
-        ReducedPartRdx = RdxDesc.isSigned()
-                             ? Builder.CreateSExt(ReducedPartRdx, ResultTy)
-                             : Builder.CreateZExt(ReducedPartRdx, ResultTy);
     }
 
     return ReducedPartRdx;
@@ -1070,6 +1053,7 @@ void VPInstruction::print(raw_ostream &O, const Twine &Indent,
 void VPInstructionWithType::execute(VPTransformState &State) {
   State.setDebugLocFrom(getDebugLoc());
   switch (getOpcode()) {
+  case Instruction::SExt:
   case Instruction::ZExt:
   case Instruction::Trunc: {
     Value *Op = State.get(getOperand(0), VPLane(0));
diff --git a/llvm/test/Transforms/LoopVectorize/X86/cost-model.ll b/llvm/test/Transforms/LoopVectorize/X86/cost-model.ll
index 2c6fe4f5c808e..147e949808b54 100644
--- a/llvm/test/Transforms/LoopVectorize/X86/cost-model.ll
+++ b/llvm/test/Transforms/LoopVectorize/X86/cost-model.ll
@@ -1167,8 +1167,7 @@ define i32 @narrowed_reduction(ptr %a, i1 %cmp) #0 {
 ; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 16
 ; CHECK-NEXT:    br i1 true, label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP28:![0-9]+]]
 ; CHECK:       middle.block:
-; CHECK-NEXT:    [[TMP10:%.*]] = trunc <16 x i32> [[TMP7]] to <16 x i1>
-; CHECK-NEXT:    [[TMP20:%.*]] = call i1 @llvm.vector.reduce.or.v16i1(<16 x i1> [[TMP10]])
+; CHECK-NEXT:    [[TMP20:%.*]] = call i1 @llvm.vector.reduce.or.v16i1(<16 x i1> [[TMP5]])
 ; CHECK-NEXT:    [[TMP21:%.*]] = zext i1 [[TMP20]] to i32
 ; CHECK-NEXT:    br i1 true, label [[EXIT:%.*]], label [[VEC_EPILOG_PH]]
 ; CHECK:       scalar.ph:
diff --git a/llvm/test/Transforms/LoopVectorize/epilog-vectorization-reductions.ll b/llvm/test/Transforms/LoopVectorize/epilog-vectorization-reductions.ll
index 0a2bb8d5682f2..c101d6a19aa2e 100644
--- a/llvm/test/Transforms/LoopVectorize/epilog-vectorization-reductions.ll
+++ b/llvm/test/Transforms/LoopVectorize/epilog-vectorization-reductions.ll
@@ -208,8 +208,7 @@ define i16 @reduction_or_trunc(ptr noalias nocapture %ptr) {
 ; CHECK-NEXT:    [[TMP8:%.*]] = icmp eq i32 [[INDEX_NEXT]], 256
 ; CHECK-NEXT:    br i1 [[TMP8]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP8:![0-9]+]]
 ; CHECK:       middle.block:
-; CHECK-NEXT:    [[TMP9:%.*]] = trunc <4 x i32> [[TMP7]] to <4 x i16>
-; CHECK-NEXT:    [[TMP10:%.*]] = call i16 @llvm.vector.reduce.or.v4i16(<4 x i16> [[TMP9]])
+; CHECK-NEXT:    [[TMP10:%.*]] = call i16 @llvm.vector.reduce.or.v4i16(<4 x i16> [[TMP6]])
 ; CHECK-NEXT:    [[TMP11:%.*]] = zext i16 [[TMP10]] to i32
 ; CHECK-NEXT:    br i1 true, label [[FOR_END:%.*]], label [[VEC_EPILOG_ITER_CHECK:%.*]]
 ; CHECK:       vec.epilog.iter.check:
@@ -234,8 +233,7 @@ define i16 @reduction_or_trunc(ptr noalias nocapture %ptr) {
 ; CHECK-NEXT:    [[TMP21:%.*]] = icmp eq i32 [[INDEX_NEXT4]], 256
 ; CHECK-NEXT:    br i1 [[TMP21]], label [[VEC_EPILOG_MIDDLE_BLOCK:%.*]], label [[VEC_EPILOG_VECTOR_BODY]], !llvm.loop [[LOOP9:![0-9]+]]
 ; CHECK:       vec.epilog.middle.block:
-; CHECK-NEXT:    [[TMP22:%.*]] = trunc <4 x i32> [[TMP20]] to <4 x i16>
-; CHECK-NEXT:    [[TMP23:%.*]] = call i16 @llvm.vector.reduce.or.v4i16(<4 x i16> [[TMP22]])
+; CHECK-NEXT:    [[TMP23:%.*]] = call i16 @llvm.vector.reduce.or.v4i16(<4 x i16> [[TMP19]])
 ; CHECK-NEXT:    [[TMP24:%.*]] = zext i16 [[TMP23]] to i32
 ; CHECK-NEXT:    br i1 true, label [[FOR_END]], label [[VEC_EPILOG_SCALAR_PH]]
 ; CHECK:       vec.epilog.scalar.ph:
diff --git a/llvm/test/Transforms/LoopVectorize/reduction-small-size.ll b/llvm/test/Transforms/LoopVectorize/reduction-small-size.ll
index 796c1d116aa19..13cc1b657d231 100644
--- a/llvm/test/Transforms/LoopVectorize/reduction-small-size.ll
+++ b/llvm/test/Transforms/LoopVectorize/reduction-small-size.ll
@@ -25,8 +25,7 @@ define i8 @PR34687(i1 %c, i32 %x, i32 %n) {
 ; CHECK-NEXT:    [[TMP5:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]
 ; CHECK-NEXT:    br i1 [[TMP5]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
 ; CHECK:       middle.block:
-; CHECK-NEXT:    [[TMP6:%.*]] = trunc <4 x i32> [[TMP4]] to <4 x i8>
-; CHECK-NEXT:    [[TMP7:%.*]] = call i8 @llvm.vector.reduce.add.v4i8(<4 x i8> [[TMP6]])
+; CHECK-NEXT:    [[TMP7:%.*]] = call i8 @llvm.vector.reduce.add.v4i8(<4 x i8> [[TMP3]])
 ; CHECK-NEXT:    [[TMP8:%.*]] = zext i8 [[TMP7]] to i32
 ; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i32 [[N]], [[N_VEC]]
 ; CHECK-NEXT:    br i1 [[CMP_N]], label [[FOR_END:%.*]], label [[SCALAR_PH]]
@@ -104,8 +103,7 @@ define i8 @PR34687_no_undef(i1 %c, i32 %x, i32 %n) {
 ; CHECK-NEXT:    [[TMP7:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]
 ; CHECK-NEXT:    br i1 [[TMP7]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP4:![0-9]+]]
 ; CHECK:       middle.block:
-; CHECK-NEXT:    [[TMP8:%.*]] = trunc <4 x i32> [[TMP6]] to <4 x i8>
-; CHECK-NEXT:    [[TMP9:%.*]] = call i8 @llvm.vector.reduce.add.v4i8(<4 x i8> [[TMP8]])
+; CHECK-NEXT:    [[TMP9:%.*]] = call i8 @llvm.vector.reduce.add.v4i8(<4 x i8> [[TMP5]])
 ; CHECK-NEXT:    [[TMP10:%.*]] = zext i8 [[TMP9]] to i32
 ; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i32 [[N]], [[N_VEC]]
 ; CHECK-NEXT:    br i1 [[CMP_N]], label [[FOR_END:%.*]], label [[SCALAR_PH]]
@@ -183,8 +181,7 @@ define i32 @PR35734(i32 %x, i32 %y) {
 ; CHECK-NEXT:    [[TMP7:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]
 ; CHECK-NEXT:    br i1 [[TMP7]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP6:![0-9]+]]
 ; CHECK:       middle.block:
-; CHECK-NEXT:    [[TMP8:%.*]] = trunc <4 x i32> [[TMP6]] to <4 x i1>
-; CHECK-NEXT:    [[TMP9:%.*]] = call i1 @llvm.vector.reduce.add.v4i1(<4 x i1> [[TMP8]])
+; CHECK-NEXT:    [[TMP9:%.*]] = call i1 @llvm.vector.reduce.add.v4i1(<4 x i1> [[TMP5]])
 ; CHECK-NEXT:    [[TMP10:%.*]] = sext i1 [[TMP9]] to i32
 ; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i32 [[TMP1]], [[N_VEC]]
 ; CHECK-NEXT:    br i1 [[CMP_N]], label [[FOR_END:%.*]], label [[SCALAR_PH]]
diff --git a/llvm/test/Transforms/LoopVectorize/scalable-reduction-inloop.ll b/llvm/test/Transforms/LoopVectorize/scalable-reduction-inloop.ll
index 04b4c6fe58db6..6187b4100671c 100644
--- a/llvm/test/Transforms/LoopVectorize/scalable-reduction-inloop.ll
+++ b/llvm/test/Transforms/LoopVectorize/scalable-reduction-inloop.ll
@@ -43,9 +43,7 @@ define i8 @reduction_add_trunc(ptr noalias nocapture %A) {
 ; CHECK-NEXT:    [[TMP21:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]
 ; CHECK-NEXT:    br i1 [[TMP21]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
 ; CHECK:       middle.block:
-; CHECK-NEXT:    [[TMP37:%.*]] = trunc <vscale x 8 x i32> [[TMP34]] to <vscale x 8 x i8>
-; CHECK-NEXT:    [[TMP38:%.*]] = trunc <vscale x 8 x i32> [[TMP36]] to <vscale x 8 x i8>
-; CHECK-NEXT:    [[BIN_RDX:%.*]] = add <vscale x 8 x i8> [[TMP38]], [[TMP37]]
+; CHECK-NEXT:    [[BIN_RDX:%.*]] = add <vscale x 8 x i8> [[TMP35]], [[TMP33]]
 ; CHECK-NEXT:    [[TMP39:%.*]] = call i8 @llvm.vector.reduce.add.nxv8i8(<vscale x 8 x i8> [[BIN_RDX]])
 ; CHECK-NEXT:    [[TMP40:%.*]] = zext i8 [[TMP39]] to i32
 ; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i32 256, [[N_VEC]]



More information about the llvm-commits mailing list