[llvm] [LoopFusion] Fix sink instructions (PR #147501)

via llvm-commits llvm-commits at lists.llvm.org
Tue Jul 8 03:41:26 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: Madhur Amilkanthwar (madhur13490)

<details>
<summary>Changes</summary>

If the second loop's preheader contains instructions that can be sunk, any PHI nodes among
them must also be updated to receive their values from the fused loop's latch block.

Fixes #<!-- -->128600

---
Full diff: https://github.com/llvm/llvm-project/pull/147501.diff


2 Files Affected:

- (modified) llvm/lib/Transforms/Scalar/LoopFuse.cpp (+29-5) 
- (added) llvm/test/Transforms/LoopFusion/sunk-phi-nodes.ll (+86) 


``````````diff
diff --git a/llvm/lib/Transforms/Scalar/LoopFuse.cpp b/llvm/lib/Transforms/Scalar/LoopFuse.cpp
index d6bd92d520e28..6e1556a4d90b4 100644
--- a/llvm/lib/Transforms/Scalar/LoopFuse.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopFuse.cpp
@@ -988,8 +988,8 @@ struct LoopFuser {
 
             // If it is not safe to hoist/sink all instructions in the
             // pre-header, we cannot fuse these loops.
-            if (!collectMovablePreheaderInsts(*FC0, *FC1, SafeToHoist,
-                                              SafeToSink)) {
+            if (!collectAndFixMovablePreheaderInsts(*FC0, *FC1, SafeToHoist,
+                                                    SafeToSink)) {
               LLVM_DEBUG(dbgs() << "Could not hoist/sink all instructions in "
                                    "Fusion Candidate Pre-header.\n"
                                 << "Not Fusing.\n");
@@ -1033,8 +1033,8 @@ struct LoopFuser {
                                                FuseCounter);
 
           FusionCandidate FusedCand(
-              performFusion((Peel ? FC0Copy : *FC0), *FC1), DT, &PDT, ORE,
-              FC0Copy.PP);
+              performFusion((Peel ? FC0Copy : *FC0), *FC1, SafeToSink), DT,
+              &PDT, ORE, FC0Copy.PP);
           FusedCand.verify();
           assert(FusedCand.isEligibleForFusion(SE) &&
                  "Fused candidate should be eligible for fusion!");
@@ -1176,9 +1176,31 @@ struct LoopFuser {
     return true;
   }
 
+  /// For each PHI node in \p SafeToSink, rewrite every incoming entry whose
+  /// block is \p FC0's latch so that it names \p FC1's latch instead.
+  ///
+  /// NOTE(review): this mutates the IR, yet it is called from
+  /// collectAndFixMovablePreheaderInsts while fusion legality is still being
+  /// decided. If fusion can be rejected after that call, FC1's preheader PHIs
+  /// are left naming a block that is not their predecessor; consider invoking
+  /// it only once fusion is committed (e.g. where the instructions are sunk).
+  void fixPHINodes(const SmallVector<Instruction *, 4> &SafeToSink,
+                   const FusionCandidate &FC0,
+                   const FusionCandidate &FC1) const {
+    assert(FC1.Latch && "FC1 latch is not set");
+    for (Instruction *Inst : SafeToSink) {
+      // Only PHI nodes reference incoming blocks; skip everything else.
+      auto *Phi = dyn_cast<PHINode>(Inst);
+      if (!Phi)
+        continue;
+      LLVM_DEBUG(dbgs() << "UPDATING: PHI node: " << *Phi << "\n");
+      Phi->replaceIncomingBlockWith(FC0.Latch, FC1.Latch);
+    }
+  }
+
   /// Collect instructions in the \p FC1 Preheader that can be hoisted
   /// to the \p FC0 Preheader or sunk into the \p FC1 Body
-  bool collectMovablePreheaderInsts(
+  bool collectAndFixMovablePreheaderInsts(
       const FusionCandidate &FC0, const FusionCandidate &FC1,
       SmallVector<Instruction *, 4> &SafeToHoist,
       SmallVector<Instruction *, 4> &SafeToSink) const {
@@ -1226,6 +1248,8 @@ struct LoopFuser {
     }
     LLVM_DEBUG(
         dbgs() << "All preheader instructions could be sunk or hoisted!\n");
+
+    fixPHINodes(SafeToSink, FC0, FC1);
     return true;
   }
 
diff --git a/llvm/test/Transforms/LoopFusion/sunk-phi-nodes.ll b/llvm/test/Transforms/LoopFusion/sunk-phi-nodes.ll
new file mode 100644
index 0000000000000..3c72df8ae19fb
--- /dev/null
+++ b/llvm/test/Transforms/LoopFusion/sunk-phi-nodes.ll
@@ -0,0 +1,86 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -S -passes=mem2reg,loop-rotate,loop-fusion < %s 2>&1 | FileCheck %s
+define i32 @main() {
+; CHECK-LABEL: define i32 @main() {
+; CHECK-NEXT:  [[ENTRY:.*]]:
+; CHECK-NEXT:    br label %[[FOR_BODY:.*]]
+; CHECK:       [[FOR_BODY]]:
+; CHECK-NEXT:    [[SUM1_02:%.*]] = phi i32 [ 0, %[[ENTRY]] ], [ [[ADD:%.*]], %[[FOR_INC6:.*]] ]
+; CHECK-NEXT:    [[I_01:%.*]] = phi i32 [ 0, %[[ENTRY]] ], [ [[INC:%.*]], %[[FOR_INC6]] ]
+; CHECK-NEXT:    [[I1_04:%.*]] = phi i32 [ 0, %[[ENTRY]] ], [ [[INC7:%.*]], %[[FOR_INC6]] ]
+; CHECK-NEXT:    [[SUM2_03:%.*]] = phi i32 [ 0, %[[ENTRY]] ], [ [[ADD5:%.*]], %[[FOR_INC6]] ]
+; CHECK-NEXT:    [[ADD]] = add nsw i32 [[SUM1_02]], [[I_01]]
+; CHECK-NEXT:    br label %[[FOR_INC:.*]]
+; CHECK:       [[FOR_INC]]:
+; CHECK-NEXT:    [[MUL:%.*]] = mul nsw i32 [[I1_04]], [[I1_04]]
+; CHECK-NEXT:    [[ADD5]] = add nsw i32 [[SUM2_03]], [[MUL]]
+; CHECK-NEXT:    br label %[[FOR_INC6]]
+; CHECK:       [[FOR_INC6]]:
+; CHECK-NEXT:    [[INC]] = add nsw i32 [[I_01]], 1
+; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i32 [[INC]], 10
+; CHECK-NEXT:    [[INC7]] = add nsw i32 [[I1_04]], 1
+; CHECK-NEXT:    [[CMP3:%.*]] = icmp slt i32 [[INC7]], 10
+; CHECK-NEXT:    br i1 [[CMP3]], label %[[FOR_BODY]], label %[[FOR_END8:.*]]
+; CHECK:       [[FOR_END8]]:
+; CHECK-NEXT:    ret i32 0
+;
+entry:                  ; Two adjacent 10-iteration loops: loop 1 adds %i into %sum1, loop 2 adds %i1*%i1 into %sum2.
+  %retval = alloca i32, align 4
+  %sum1 = alloca i32, align 4
+  %sum2 = alloca i32, align 4
+  %i = alloca i32, align 4
+  %i1 = alloca i32, align 4
+  store i32 0, ptr %retval, align 4
+  store i32 0, ptr %sum1, align 4
+  store i32 0, ptr %sum2, align 4
+  store i32 0, ptr %i, align 4
+  br label %for.cond
+
+for.cond:
+  %0 = load i32, ptr %i, align 4
+  %cmp = icmp slt i32 %0, 10
+  br i1 %cmp, label %for.body, label %for.end
+
+for.body:
+  %1 = load i32, ptr %i, align 4
+  %2 = load i32, ptr %sum1, align 4
+  %add = add nsw i32 %2, %1
+  store i32 %add, ptr %sum1, align 4
+  br label %for.inc
+
+for.inc:
+  %3 = load i32, ptr %i, align 4
+  %inc = add nsw i32 %3, 1
+  store i32 %inc, ptr %i, align 4
+  br label %for.cond
+
+for.end:                ; Exit of loop 1 and the only entry into loop 2 (its preheader).
+  store i32 0, ptr %i1, align 4
+  br label %for.cond2
+
+for.cond2:
+  %4 = load i32, ptr %i1, align 4
+  %cmp3 = icmp slt i32 %4, 10
+  br i1 %cmp3, label %for.body4, label %for.end8
+
+for.body4:
+  %5 = load i32, ptr %i1, align 4
+  %6 = load i32, ptr %i1, align 4
+  %mul = mul nsw i32 %5, %6
+  %7 = load i32, ptr %sum2, align 4
+  %add5 = add nsw i32 %7, %mul
+  store i32 %add5, ptr %sum2, align 4
+  br label %for.inc6
+
+for.inc6:
+  %8 = load i32, ptr %i1, align 4
+  %inc7 = add nsw i32 %8, 1
+  store i32 %inc7, ptr %i1, align 4
+  br label %for.cond2
+
+for.end8:               ; Reads both sums after the loops; the loaded values are otherwise unused.
+  %9 = load i32, ptr %sum1, align 4
+  %10 = load i32, ptr %sum2, align 4
+  ret i32 0
+}
+

``````````

</details>


https://github.com/llvm/llvm-project/pull/147501


More information about the llvm-commits mailing list