[llvm] [VPlan] Fix crash with inloop fmuladd reductions with blend (PR #131154)

via llvm-commits llvm-commits at lists.llvm.org
Thu Mar 13 08:11:13 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: Luke Lau (lukel97)

<details>
<summary>Changes</summary>

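When visiting in-loop reduction links, we previously crashed if an fmuladd was followed by a blend in the reduction chain. This fixes the crash by hoisting the existing blend folding out of the non-fmuladd branch to the top of the loop over the links, so that it is also applied to fmuladd reductions.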

This also simplifies the code structure slightly for an upcoming patch I want to post to handle in-loop AnyOf reductions.

I removed the PhiR->isInLoop() check because it is redundant: the same condition is already checked at the top of the parent Header->Phis() loop.


---
Full diff: https://github.com/llvm/llvm-project/pull/131154.diff


2 Files Affected:

- (modified) llvm/lib/Transforms/Vectorize/LoopVectorize.cpp (+13-14) 
- (modified) llvm/test/Transforms/LoopVectorize/reduction-inloop.ll (+90-6) 


``````````diff
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 02bacde3f60a7..58be4a584fd85 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -9713,6 +9713,19 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
     // condition directly.
     VPSingleDefRecipe *PreviousLink = PhiR; // Aka Worklist[0].
     for (VPSingleDefRecipe *CurrentLink : Worklist.getArrayRef().drop_front()) {
+      if (auto *Blend = dyn_cast<VPBlendRecipe>(CurrentLink)) {
+        assert(Blend->getNumIncomingValues() == 2 &&
+               "Blend must have 2 incoming values");
+        if (Blend->getIncomingValue(0) == PhiR)
+          Blend->replaceAllUsesWith(Blend->getIncomingValue(1));
+        else {
+          assert(Blend->getIncomingValue(1) == PhiR &&
+                 "PhiR must be an operand of the blend");
+          Blend->replaceAllUsesWith(Blend->getIncomingValue(0));
+        }
+        continue;
+      }
+
       Instruction *CurrentLinkI = CurrentLink->getUnderlyingInstr();
 
       // Index of the first operand which holds a non-mask vector operand.
@@ -9741,20 +9754,6 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
         LinkVPBB->insert(FMulRecipe, CurrentLink->getIterator());
         VecOp = FMulRecipe;
       } else {
-        auto *Blend = dyn_cast<VPBlendRecipe>(CurrentLink);
-        if (PhiR->isInLoop() && Blend) {
-          assert(Blend->getNumIncomingValues() == 2 &&
-                 "Blend must have 2 incoming values");
-          if (Blend->getIncomingValue(0) == PhiR)
-            Blend->replaceAllUsesWith(Blend->getIncomingValue(1));
-          else {
-            assert(Blend->getIncomingValue(1) == PhiR &&
-                   "PhiR must be an operand of the blend");
-            Blend->replaceAllUsesWith(Blend->getIncomingValue(0));
-          }
-          continue;
-        }
-
         if (RecurrenceDescriptor::isMinMaxRecurrenceKind(Kind)) {
           if (isa<VPWidenRecipe>(CurrentLink)) {
             assert(isa<CmpInst>(CurrentLinkI) &&
diff --git a/llvm/test/Transforms/LoopVectorize/reduction-inloop.ll b/llvm/test/Transforms/LoopVectorize/reduction-inloop.ll
index 6315ce541b38b..9a4aa11da7530 100644
--- a/llvm/test/Transforms/LoopVectorize/reduction-inloop.ll
+++ b/llvm/test/Transforms/LoopVectorize/reduction-inloop.ll
@@ -1106,6 +1106,90 @@ for.end:
   ret float %muladd
 }
 
+define float @reduction_fmuladd_blend(ptr %a, ptr %b, i64 %n, i1 %c) {
+; CHECK-LABEL: @reduction_fmuladd_blend(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[N:%.*]], 4
+; CHECK-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
+; CHECK:       vector.ph:
+; CHECK-NEXT:    [[N_VEC:%.*]] = and i64 [[N]], -4
+; CHECK-NEXT:    [[BROADCAST_SPLATINSERT:%.*]] = insertelement <4 x i1> poison, i1 [[C:%.*]], i64 0
+; CHECK-NEXT:    [[TMP0:%.*]] = xor <4 x i1> [[BROADCAST_SPLATINSERT]], <i1 true, i1 poison, i1 poison, i1 poison>
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <4 x i1> [[TMP0]], <4 x i1> poison, <4 x i32> zeroinitializer
+; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
+; CHECK:       vector.body:
+; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi float [ 0.000000e+00, [[VECTOR_PH]] ], [ [[TMP7:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr inbounds float, ptr [[A:%.*]], i64 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <4 x float>, ptr [[TMP2]], align 4
+; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr inbounds float, ptr [[B:%.*]], i64 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_LOAD1:%.*]] = load <4 x float>, ptr [[TMP3]], align 4
+; CHECK-NEXT:    [[TMP4:%.*]] = fmul <4 x float> [[WIDE_LOAD]], [[WIDE_LOAD1]]
+; CHECK-NEXT:    [[TMP5:%.*]] = select <4 x i1> [[TMP1]], <4 x float> [[TMP4]], <4 x float> splat (float -0.000000e+00)
+; CHECK-NEXT:    [[TMP6:%.*]] = call float @llvm.vector.reduce.fadd.v4f32(float -0.000000e+00, <4 x float> [[TMP5]])
+; CHECK-NEXT:    [[TMP7]] = fadd float [[TMP6]], [[VEC_PHI]]
+; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 4
+; CHECK-NEXT:    [[TMP8:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
+; CHECK-NEXT:    br i1 [[TMP8]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP38:![0-9]+]]
+; CHECK:       middle.block:
+; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i64 [[N]], [[N_VEC]]
+; CHECK-NEXT:    br i1 [[CMP_N]], label [[FOR_END:%.*]], label [[SCALAR_PH]]
+; CHECK:       scalar.ph:
+; CHECK-NEXT:    [[BC_RESUME_VAL:%.*]] = phi i64 [ [[N_VEC]], [[MIDDLE_BLOCK]] ], [ 0, [[ENTRY:%.*]] ]
+; CHECK-NEXT:    [[BC_MERGE_RDX:%.*]] = phi float [ [[TMP7]], [[MIDDLE_BLOCK]] ], [ 0.000000e+00, [[ENTRY]] ]
+; CHECK-NEXT:    br label [[FOR_BODY:%.*]]
+; CHECK:       for.body:
+; CHECK-NEXT:    [[IV:%.*]] = phi i64 [ [[BC_RESUME_VAL]], [[SCALAR_PH]] ], [ [[IV_NEXT:%.*]], [[LATCH:%.*]] ]
+; CHECK-NEXT:    [[SUM:%.*]] = phi float [ [[BC_MERGE_RDX]], [[SCALAR_PH]] ], [ [[SUM_NEXT:%.*]], [[LATCH]] ]
+; CHECK-NEXT:    br i1 [[C]], label [[FOO:%.*]], label [[BAR:%.*]]
+; CHECK:       foo:
+; CHECK-NEXT:    br label [[LATCH]]
+; CHECK:       bar:
+; CHECK-NEXT:    [[ARRAYIDX2:%.*]] = getelementptr inbounds float, ptr [[B]], i64 [[IV]]
+; CHECK-NEXT:    [[TMP9:%.*]] = load float, ptr [[ARRAYIDX2]], align 4
+; CHECK-NEXT:    [[ARRAYIDX:%.*]] = getelementptr inbounds float, ptr [[A]], i64 [[IV]]
+; CHECK-NEXT:    [[TMP10:%.*]] = load float, ptr [[ARRAYIDX]], align 4
+; CHECK-NEXT:    [[MULADD:%.*]] = tail call float @llvm.fmuladd.f32(float [[TMP10]], float [[TMP9]], float [[SUM]])
+; CHECK-NEXT:    br label [[LATCH]]
+; CHECK:       latch:
+; CHECK-NEXT:    [[SUM_NEXT]] = phi float [ [[SUM]], [[FOO]] ], [ [[MULADD]], [[BAR]] ]
+; CHECK-NEXT:    [[IV_NEXT]] = add nuw nsw i64 [[IV]], 1
+; CHECK-NEXT:    [[EXITCOND_NOT:%.*]] = icmp eq i64 [[IV_NEXT]], [[N]]
+; CHECK-NEXT:    br i1 [[EXITCOND_NOT]], label [[FOR_END]], label [[FOR_BODY]], !llvm.loop [[LOOP39:![0-9]+]]
+; CHECK:       for.end:
+; CHECK-NEXT:    [[SUM_NEXT_LCSSA:%.*]] = phi float [ [[SUM_NEXT]], [[LATCH]] ], [ [[TMP7]], [[MIDDLE_BLOCK]] ]
+; CHECK-NEXT:    ret float [[SUM_NEXT_LCSSA]]
+;
+entry:
+  br label %for.body
+
+for.body:
+  %iv = phi i64 [ 0, %entry ], [ %iv.next, %latch ]
+  %sum = phi float [ 0.000000e+00, %entry ], [ %sum.next, %latch ]
+  %arrayidx = getelementptr inbounds float, ptr %a, i64 %iv
+  %0 = load float, ptr %arrayidx, align 4
+  %arrayidx2 = getelementptr inbounds float, ptr %b, i64 %iv
+  %1 = load float, ptr %arrayidx2, align 4
+
+  br i1 %c, label %foo, label %bar
+
+foo:
+  br label %latch
+
+bar:
+  %muladd = tail call float @llvm.fmuladd.f32(float %0, float %1, float %sum)
+  br label %latch
+
+latch:
+  %sum.next = phi float [ %sum, %foo ], [ %muladd, %bar ]
+  %iv.next = add nuw nsw i64 %iv, 1
+  %exitcond.not = icmp eq i64 %iv.next, %n
+  br i1 %exitcond.not, label %for.end, label %for.body
+
+for.end:
+  ret float %sum.next
+}
+
 ; This case was previously failing verification due to the mask for the
 ; reduction being created after the reduction.
 define i32 @predicated_not_dominates_reduction(ptr nocapture noundef readonly %h, i32 noundef %i) {
@@ -1130,7 +1214,7 @@ define i32 @predicated_not_dominates_reduction(ptr nocapture noundef readonly %h
 ; CHECK-NEXT:    [[TMP7]] = add i32 [[TMP6]], [[VEC_PHI]]
 ; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 4
 ; CHECK-NEXT:    [[TMP8:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]
-; CHECK-NEXT:    br i1 [[TMP8]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP38:![0-9]+]]
+; CHECK-NEXT:    br i1 [[TMP8]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP40:![0-9]+]]
 ; CHECK:       middle.block:
 ; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i32 [[I]], [[N_VEC]]
 ; CHECK-NEXT:    br i1 [[CMP_N]], label [[FOR_END7:%.*]], label [[SCALAR_PH]]
@@ -1157,7 +1241,7 @@ define i32 @predicated_not_dominates_reduction(ptr nocapture noundef readonly %h
 ; CHECK-NEXT:    [[G_1]] = phi i32 [ [[ADD]], [[IF_THEN]] ], [ [[G_016]], [[FOR_BODY2]] ]
 ; CHECK-NEXT:    [[INC6]] = add nuw nsw i32 [[A_117]], 1
 ; CHECK-NEXT:    [[EXITCOND_NOT:%.*]] = icmp eq i32 [[INC6]], [[I]]
-; CHECK-NEXT:    br i1 [[EXITCOND_NOT]], label [[FOR_END7]], label [[FOR_BODY2]], !llvm.loop [[LOOP39:![0-9]+]]
+; CHECK-NEXT:    br i1 [[EXITCOND_NOT]], label [[FOR_END7]], label [[FOR_BODY2]], !llvm.loop [[LOOP41:![0-9]+]]
 ; CHECK:       for.end7:
 ; CHECK-NEXT:    [[G_1_LCSSA:%.*]] = phi i32 [ [[G_1]], [[FOR_INC5]] ], [ [[TMP7]], [[MIDDLE_BLOCK]] ]
 ; CHECK-NEXT:    ret i32 [[G_1_LCSSA]]
@@ -1219,7 +1303,7 @@ define i32 @predicated_not_dominates_reduction_twoadd(ptr nocapture noundef read
 ; CHECK-NEXT:    [[TMP11]] = add i32 [[TMP10]], [[TMP8]]
 ; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 4
 ; CHECK-NEXT:    [[TMP12:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]
-; CHECK-NEXT:    br i1 [[TMP12]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP40:![0-9]+]]
+; CHECK-NEXT:    br i1 [[TMP12]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP42:![0-9]+]]
 ; CHECK:       middle.block:
 ; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i32 [[I]], [[N_VEC]]
 ; CHECK-NEXT:    br i1 [[CMP_N]], label [[FOR_END7:%.*]], label [[SCALAR_PH]]
@@ -1247,7 +1331,7 @@ define i32 @predicated_not_dominates_reduction_twoadd(ptr nocapture noundef read
 ; CHECK-NEXT:    [[G_1]] = phi i32 [ [[ADD]], [[IF_THEN]] ], [ [[G_016]], [[FOR_BODY2]] ]
 ; CHECK-NEXT:    [[INC6]] = add nuw nsw i32 [[A_117]], 1
 ; CHECK-NEXT:    [[EXITCOND_NOT:%.*]] = icmp eq i32 [[INC6]], [[I]]
-; CHECK-NEXT:    br i1 [[EXITCOND_NOT]], label [[FOR_END7]], label [[FOR_BODY2]], !llvm.loop [[LOOP41:![0-9]+]]
+; CHECK-NEXT:    br i1 [[EXITCOND_NOT]], label [[FOR_END7]], label [[FOR_BODY2]], !llvm.loop [[LOOP43:![0-9]+]]
 ; CHECK:       for.end7:
 ; CHECK-NEXT:    [[G_1_LCSSA:%.*]] = phi i32 [ [[G_1]], [[FOR_INC5]] ], [ [[TMP11]], [[MIDDLE_BLOCK]] ]
 ; CHECK-NEXT:    ret i32 [[G_1_LCSSA]]
@@ -1362,7 +1446,7 @@ define i32 @predicated_or_dominates_reduction(ptr %b) {
 ; CHECK-NEXT:    [[TMP48]] = add i32 [[VEC_PHI]], [[TMP47]]
 ; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 4
 ; CHECK-NEXT:    [[TMP49:%.*]] = icmp eq i32 [[INDEX_NEXT]], 1000
-; CHECK-NEXT:    br i1 [[TMP49]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP42:![0-9]+]]
+; CHECK-NEXT:    br i1 [[TMP49]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP44:![0-9]+]]
 ; CHECK:       middle.block:
 ; CHECK-NEXT:    br i1 true, label [[FOR_COND_CLEANUP:%.*]], label [[SCALAR_PH]]
 ; CHECK:       scalar.ph:
@@ -1377,7 +1461,7 @@ define i32 @predicated_or_dominates_reduction(ptr %b) {
 ; CHECK:       if.then:
 ; CHECK-NEXT:    br label [[FOR_INC]]
 ; CHECK:       for.inc:
-; CHECK-NEXT:    br i1 poison, label [[FOR_COND_CLEANUP]], label [[FOR_BODY]], !llvm.loop [[LOOP43:![0-9]+]]
+; CHECK-NEXT:    br i1 poison, label [[FOR_COND_CLEANUP]], label [[FOR_BODY]], !llvm.loop [[LOOP45:![0-9]+]]
 ;
 entry:
   br label %for.body

``````````

</details>


https://github.com/llvm/llvm-project/pull/131154


More information about the llvm-commits mailing list