[llvm] [VPlan] Truncate/Extend ComputeReductionResult at construction (NFC). (PR #141860)
Florian Hahn via llvm-commits
llvm-commits at lists.llvm.org
Sun Jun 1 14:16:03 PDT 2025
https://github.com/fhahn updated https://github.com/llvm/llvm-project/pull/141860
From 6eb0f459eb1d1fe7328d49b6f479fffb7b813d1b Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Sun, 1 Jun 2025 22:07:35 +0100
Subject: [PATCH] [VPlan] Truncate/Extend ComputeReductionResult at
construction (NFC).
Instead of looking up the narrower reduction type via getRecurrenceType,
we can generate the needed extend directly at construction and re-use the
truncated value from the loop.
---
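The net effect on the emitted middle block is visible in the updated tests
below: the standalone vector trunc disappears and the reduction intrinsic
consumes the already-truncated in-loop value directly, while the scalar
extend of the reduced result remains. A minimal sketch (value names are
illustrative, not taken from the tests):

Before:
  %t = trunc <4 x i32> %rdx to <4 x i16>
  %r = call i16 @llvm.vector.reduce.or.v4i16(<4 x i16> %t)
  %ext = zext i16 %r to i32

After:
  %r = call i16 @llvm.vector.reduce.or.v4i16(<4 x i16> %rdx.trunc.loop)
  %ext = zext i16 %r to i32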
.../Transforms/Vectorize/LoopVectorize.cpp | 54 +++++++++++--------
.../lib/Transforms/Vectorize/VPlanRecipes.cpp | 23 ++------
.../LoopVectorize/X86/cost-model.ll | 3 +-
.../epilog-vectorization-reductions.ll | 6 +--
.../LoopVectorize/reduction-small-size.ll | 9 ++--
.../scalable-reduction-inloop.ll | 4 +-
6 files changed, 43 insertions(+), 56 deletions(-)
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index e9ace195684b3..e803edf18a926 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7536,6 +7536,13 @@ static void fixReductionScalarResumeWhenVectorizingEpilog(
// created a bc.merge.rdx Phi after the main vector body. Ensure that we carry
// over the incoming values correctly.
using namespace VPlanPatternMatch;
+ if (EpiRedResult->getNumUsers() == 1 &&
+ isa<VPInstructionWithType>(*EpiRedResult->user_begin())) {
+ EpiRedResult = cast<VPInstructionWithType>(*EpiRedResult->user_begin());
+ assert((EpiRedResult->getOpcode() == Instruction::SExt ||
+ EpiRedResult->getOpcode() == Instruction::ZExt) &&
+ "can only have SExt/ZExt users");
+ }
assert(count_if(EpiRedResult->users(), IsaPred<VPPhi>) == 1 &&
"ResumePhi must have a single user");
auto *EpiResumePhiVPI =
@@ -9468,28 +9475,6 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
PhiR->setOperand(1, NewExitingVPV);
}
- // If the vector reduction can be performed in a smaller type, we truncate
- // then extend the loop exit value to enable InstCombine to evaluate the
- // entire expression in the smaller type.
- if (MinVF.isVector() && PhiTy != RdxDesc.getRecurrenceType() &&
- !RecurrenceDescriptor::isAnyOfRecurrenceKind(
- RdxDesc.getRecurrenceKind())) {
- assert(!PhiR->isInLoop() && "Unexpected truncated inloop reduction!");
- Type *RdxTy = RdxDesc.getRecurrenceType();
- auto *Trunc =
- new VPWidenCastRecipe(Instruction::Trunc, NewExitingVPV, RdxTy);
- auto *Extnd =
- RdxDesc.isSigned()
- ? new VPWidenCastRecipe(Instruction::SExt, Trunc, PhiTy)
- : new VPWidenCastRecipe(Instruction::ZExt, Trunc, PhiTy);
-
- Trunc->insertAfter(NewExitingVPV->getDefiningRecipe());
- Extnd->insertAfter(Trunc);
- if (PhiR->getOperand(1) == NewExitingVPV)
- PhiR->setOperand(1, Extnd->getVPSingleValue());
- NewExitingVPV = Extnd;
- }
-
// We want code in the middle block to appear to execute on the location of
// the scalar loop's latch terminator because: (a) it is all compiler
// generated, (b) these instructions are always executed after evaluating
@@ -9521,6 +9506,31 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
Builder.createNaryOp(VPInstruction::ComputeReductionResult,
{PhiR, NewExitingVPV}, Flags, ExitDL);
}
+ // If the vector reduction can be performed in a smaller type, we truncate
+ // then extend the loop exit value to enable InstCombine to evaluate the
+ // entire expression in the smaller type.
+ if (MinVF.isVector() && PhiTy != RdxDesc.getRecurrenceType() &&
+ !RecurrenceDescriptor::isAnyOfRecurrenceKind(
+ RdxDesc.getRecurrenceKind())) {
+ assert(!PhiR->isInLoop() && "Unexpected truncated inloop reduction!");
+ Type *RdxTy = RdxDesc.getRecurrenceType();
+ auto *Trunc =
+ new VPWidenCastRecipe(Instruction::Trunc, NewExitingVPV, RdxTy);
+ Instruction::CastOps ExtendOpc =
+ RdxDesc.isSigned() ? Instruction::SExt : Instruction::ZExt;
+ auto *Extnd = new VPWidenCastRecipe(ExtendOpc, Trunc, PhiTy);
+ Trunc->insertAfter(NewExitingVPV->getDefiningRecipe());
+ Extnd->insertAfter(Trunc);
+ if (PhiR->getOperand(1) == NewExitingVPV)
+ PhiR->setOperand(1, Extnd->getVPSingleValue());
+
+ // Update ComputeReductionResult with the truncated exiting value and
+ // extend its result.
+ FinalReductionResult->setOperand(1, Trunc);
+ FinalReductionResult =
+ Builder.createScalarCast(ExtendOpc, FinalReductionResult, PhiTy, {});
+ }
+
// Update all users outside the vector region.
OrigExitingVPV->replaceUsesWithIf(
FinalReductionResult, [FinalReductionResult](VPUser &User, unsigned) {
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index a4831ea7c11f7..672c1f2d9524c 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -633,7 +633,6 @@ Value *VPInstruction::generate(VPTransformState &State) {
// FIXME: The cross-recipe dependency on VPReductionPHIRecipe is temporary
// and will be removed by breaking up the recipe further.
auto *PhiR = cast<VPReductionPHIRecipe>(getOperand(0));
- auto *OrigPhi = cast<PHINode>(PhiR->getUnderlyingValue());
// Get its reduction variable descriptor.
const RecurrenceDescriptor &RdxDesc = PhiR->getRecurrenceDescriptor();
@@ -641,7 +640,6 @@ Value *VPInstruction::generate(VPTransformState &State) {
assert(!RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK) &&
"should be handled by ComputeFindLastIVResult");
- Type *PhiTy = OrigPhi->getType();
// The recipe's operands are the reduction phi, followed by one operand for
// each part of the reduction.
unsigned UF = getNumOperands() - 1;
@@ -653,15 +651,6 @@ Value *VPInstruction::generate(VPTransformState &State) {
if (hasFastMathFlags())
Builder.setFastMathFlags(getFastMathFlags());
- // If the vector reduction can be performed in a smaller type, we truncate
- // then extend the loop exit value to enable InstCombine to evaluate the
- // entire expression in the smaller type.
- // TODO: Handle this in truncateToMinBW.
- if (State.VF.isVector() && PhiTy != RdxDesc.getRecurrenceType()) {
- Type *RdxVecTy = VectorType::get(RdxDesc.getRecurrenceType(), State.VF);
- for (unsigned Part = 0; Part < UF; ++Part)
- RdxParts[Part] = Builder.CreateTrunc(RdxParts[Part], RdxVecTy);
- }
// Reduce all of the unrolled parts into a single vector.
Value *ReducedPartRdx = RdxParts[0];
if (PhiR->isOrdered()) {
@@ -687,19 +676,14 @@ Value *VPInstruction::generate(VPTransformState &State) {
// TODO: Support in-order reductions based on the recurrence descriptor.
// All ops in the reduction inherit fast-math-flags from the recurrence
// descriptor.
- if (RecurrenceDescriptor::isAnyOfRecurrenceKind(RK))
+ if (RecurrenceDescriptor::isAnyOfRecurrenceKind(RK)) {
+ auto *OrigPhi = cast<PHINode>(PhiR->getUnderlyingValue());
ReducedPartRdx =
createAnyOfReduction(Builder, ReducedPartRdx,
RdxDesc.getRecurrenceStartValue(), OrigPhi);
- else
+ } else
ReducedPartRdx = createSimpleReduction(Builder, ReducedPartRdx, RK);
- // If the reduction can be performed in a smaller type, we need to extend
- // the reduction to the wider type before we branch to the original loop.
- if (PhiTy != RdxDesc.getRecurrenceType())
- ReducedPartRdx = RdxDesc.isSigned()
- ? Builder.CreateSExt(ReducedPartRdx, PhiTy)
- : Builder.CreateZExt(ReducedPartRdx, PhiTy);
}
return ReducedPartRdx;
@@ -1040,6 +1024,7 @@ void VPInstruction::print(raw_ostream &O, const Twine &Indent,
void VPInstructionWithType::execute(VPTransformState &State) {
State.setDebugLocFrom(getDebugLoc());
switch (getOpcode()) {
+ case Instruction::SExt:
case Instruction::ZExt:
case Instruction::Trunc: {
Value *Op = State.get(getOperand(0), VPLane(0));
diff --git a/llvm/test/Transforms/LoopVectorize/X86/cost-model.ll b/llvm/test/Transforms/LoopVectorize/X86/cost-model.ll
index 7c42c3d9cd52e..a6ac9c2886a92 100644
--- a/llvm/test/Transforms/LoopVectorize/X86/cost-model.ll
+++ b/llvm/test/Transforms/LoopVectorize/X86/cost-model.ll
@@ -1167,8 +1167,7 @@ define i32 @narrowed_reduction(ptr %a, i1 %cmp) #0 {
; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 16
; CHECK-NEXT: br i1 true, label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP28:![0-9]+]]
; CHECK: middle.block:
-; CHECK-NEXT: [[TMP10:%.*]] = trunc <16 x i32> [[TMP7]] to <16 x i1>
-; CHECK-NEXT: [[TMP20:%.*]] = call i1 @llvm.vector.reduce.or.v16i1(<16 x i1> [[TMP10]])
+; CHECK-NEXT: [[TMP20:%.*]] = call i1 @llvm.vector.reduce.or.v16i1(<16 x i1> [[TMP5]])
; CHECK-NEXT: [[TMP21:%.*]] = zext i1 [[TMP20]] to i32
; CHECK-NEXT: br i1 true, label [[EXIT:%.*]], label [[VEC_EPILOG_PH]]
; CHECK: scalar.ph:
diff --git a/llvm/test/Transforms/LoopVectorize/epilog-vectorization-reductions.ll b/llvm/test/Transforms/LoopVectorize/epilog-vectorization-reductions.ll
index 0a2bb8d5682f2..c101d6a19aa2e 100644
--- a/llvm/test/Transforms/LoopVectorize/epilog-vectorization-reductions.ll
+++ b/llvm/test/Transforms/LoopVectorize/epilog-vectorization-reductions.ll
@@ -208,8 +208,7 @@ define i16 @reduction_or_trunc(ptr noalias nocapture %ptr) {
; CHECK-NEXT: [[TMP8:%.*]] = icmp eq i32 [[INDEX_NEXT]], 256
; CHECK-NEXT: br i1 [[TMP8]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP8:![0-9]+]]
; CHECK: middle.block:
-; CHECK-NEXT: [[TMP9:%.*]] = trunc <4 x i32> [[TMP7]] to <4 x i16>
-; CHECK-NEXT: [[TMP10:%.*]] = call i16 @llvm.vector.reduce.or.v4i16(<4 x i16> [[TMP9]])
+; CHECK-NEXT: [[TMP10:%.*]] = call i16 @llvm.vector.reduce.or.v4i16(<4 x i16> [[TMP6]])
; CHECK-NEXT: [[TMP11:%.*]] = zext i16 [[TMP10]] to i32
; CHECK-NEXT: br i1 true, label [[FOR_END:%.*]], label [[VEC_EPILOG_ITER_CHECK:%.*]]
; CHECK: vec.epilog.iter.check:
@@ -234,8 +233,7 @@ define i16 @reduction_or_trunc(ptr noalias nocapture %ptr) {
; CHECK-NEXT: [[TMP21:%.*]] = icmp eq i32 [[INDEX_NEXT4]], 256
; CHECK-NEXT: br i1 [[TMP21]], label [[VEC_EPILOG_MIDDLE_BLOCK:%.*]], label [[VEC_EPILOG_VECTOR_BODY]], !llvm.loop [[LOOP9:![0-9]+]]
; CHECK: vec.epilog.middle.block:
-; CHECK-NEXT: [[TMP22:%.*]] = trunc <4 x i32> [[TMP20]] to <4 x i16>
-; CHECK-NEXT: [[TMP23:%.*]] = call i16 @llvm.vector.reduce.or.v4i16(<4 x i16> [[TMP22]])
+; CHECK-NEXT: [[TMP23:%.*]] = call i16 @llvm.vector.reduce.or.v4i16(<4 x i16> [[TMP19]])
; CHECK-NEXT: [[TMP24:%.*]] = zext i16 [[TMP23]] to i32
; CHECK-NEXT: br i1 true, label [[FOR_END]], label [[VEC_EPILOG_SCALAR_PH]]
; CHECK: vec.epilog.scalar.ph:
diff --git a/llvm/test/Transforms/LoopVectorize/reduction-small-size.ll b/llvm/test/Transforms/LoopVectorize/reduction-small-size.ll
index 796c1d116aa19..13cc1b657d231 100644
--- a/llvm/test/Transforms/LoopVectorize/reduction-small-size.ll
+++ b/llvm/test/Transforms/LoopVectorize/reduction-small-size.ll
@@ -25,8 +25,7 @@ define i8 @PR34687(i1 %c, i32 %x, i32 %n) {
; CHECK-NEXT: [[TMP5:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]
; CHECK-NEXT: br i1 [[TMP5]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
; CHECK: middle.block:
-; CHECK-NEXT: [[TMP6:%.*]] = trunc <4 x i32> [[TMP4]] to <4 x i8>
-; CHECK-NEXT: [[TMP7:%.*]] = call i8 @llvm.vector.reduce.add.v4i8(<4 x i8> [[TMP6]])
+; CHECK-NEXT: [[TMP7:%.*]] = call i8 @llvm.vector.reduce.add.v4i8(<4 x i8> [[TMP3]])
; CHECK-NEXT: [[TMP8:%.*]] = zext i8 [[TMP7]] to i32
; CHECK-NEXT: [[CMP_N:%.*]] = icmp eq i32 [[N]], [[N_VEC]]
; CHECK-NEXT: br i1 [[CMP_N]], label [[FOR_END:%.*]], label [[SCALAR_PH]]
@@ -104,8 +103,7 @@ define i8 @PR34687_no_undef(i1 %c, i32 %x, i32 %n) {
; CHECK-NEXT: [[TMP7:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]
; CHECK-NEXT: br i1 [[TMP7]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP4:![0-9]+]]
; CHECK: middle.block:
-; CHECK-NEXT: [[TMP8:%.*]] = trunc <4 x i32> [[TMP6]] to <4 x i8>
-; CHECK-NEXT: [[TMP9:%.*]] = call i8 @llvm.vector.reduce.add.v4i8(<4 x i8> [[TMP8]])
+; CHECK-NEXT: [[TMP9:%.*]] = call i8 @llvm.vector.reduce.add.v4i8(<4 x i8> [[TMP5]])
; CHECK-NEXT: [[TMP10:%.*]] = zext i8 [[TMP9]] to i32
; CHECK-NEXT: [[CMP_N:%.*]] = icmp eq i32 [[N]], [[N_VEC]]
; CHECK-NEXT: br i1 [[CMP_N]], label [[FOR_END:%.*]], label [[SCALAR_PH]]
@@ -183,8 +181,7 @@ define i32 @PR35734(i32 %x, i32 %y) {
; CHECK-NEXT: [[TMP7:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]
; CHECK-NEXT: br i1 [[TMP7]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP6:![0-9]+]]
; CHECK: middle.block:
-; CHECK-NEXT: [[TMP8:%.*]] = trunc <4 x i32> [[TMP6]] to <4 x i1>
-; CHECK-NEXT: [[TMP9:%.*]] = call i1 @llvm.vector.reduce.add.v4i1(<4 x i1> [[TMP8]])
+; CHECK-NEXT: [[TMP9:%.*]] = call i1 @llvm.vector.reduce.add.v4i1(<4 x i1> [[TMP5]])
; CHECK-NEXT: [[TMP10:%.*]] = sext i1 [[TMP9]] to i32
; CHECK-NEXT: [[CMP_N:%.*]] = icmp eq i32 [[TMP1]], [[N_VEC]]
; CHECK-NEXT: br i1 [[CMP_N]], label [[FOR_END:%.*]], label [[SCALAR_PH]]
diff --git a/llvm/test/Transforms/LoopVectorize/scalable-reduction-inloop.ll b/llvm/test/Transforms/LoopVectorize/scalable-reduction-inloop.ll
index 223acfa2e3a25..6e251cb7b0ad3 100644
--- a/llvm/test/Transforms/LoopVectorize/scalable-reduction-inloop.ll
+++ b/llvm/test/Transforms/LoopVectorize/scalable-reduction-inloop.ll
@@ -28,9 +28,7 @@ define i8 @reduction_add_trunc(ptr noalias nocapture %A) {
; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i32 [[INDEX]], [[TMP31]]
; CHECK-NEXT: [[TMP32:%.*]] = icmp eq i32 [[INDEX_NEXT]], {{%.*}}
; CHECK: middle.block:
-; CHECK-NEXT: [[TMP37:%.*]] = trunc <vscale x 8 x i32> [[TMP34]] to <vscale x 8 x i8>
-; CHECK-NEXT: [[TMP38:%.*]] = trunc <vscale x 8 x i32> [[TMP36]] to <vscale x 8 x i8>
-; CHECK-NEXT: [[BIN_RDX:%.*]] = add <vscale x 8 x i8> [[TMP38]], [[TMP37]]
+; CHECK-NEXT: [[BIN_RDX:%.*]] = add <vscale x 8 x i8> [[TMP35]], [[TMP33]]
; CHECK-NEXT: [[TMP39:%.*]] = call i8 @llvm.vector.reduce.add.nxv8i8(<vscale x 8 x i8> [[BIN_RDX]])
; CHECK-NEXT: [[TMP40:%.*]] = zext i8 [[TMP39]] to i32
; CHECK-NEXT: %cmp.n = icmp eq i32 256, %n.vec