[Mlir-commits] [mlir] [mlir][SCF] Improve `ForOp::getSuccessorRegions` (PR #177116)

Matthias Springer llvmlistbot at llvm.org
Wed Jan 21 09:55:43 PST 2026


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/177116

>From 39d0f71ea91c87a6567024b5fa3bc7821fb15249 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Wed, 21 Jan 2026 08:27:20 +0000
Subject: [PATCH 1/2] [mlir][SCF] Improve `ForOp::getSuccessorRegions`

---
 mlir/lib/Dialect/SCF/IR/SCF.cpp               | 21 ++++++++++
 .../DataFlow/test-dead-code-analysis.mlir     | 38 +++++++++++++++++++
 .../Dialect/Arith/int-range-narrowing.mlir    |  2 +-
 3 files changed, 60 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 5b6e9304de505..86e66dbaf6171 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -704,6 +704,27 @@ OperandRange ForOp::getEntrySuccessorOperands(RegionSuccessor successor) {
 
 void ForOp::getSuccessorRegions(RegionBranchPoint point,
                                 SmallVectorImpl<RegionSuccessor> &regions) {
+  std::optional<APInt> tripCount = getStaticTripCount();
+  if (tripCount.has_value()) {
+    // The loop has a known static trip count.
+    if (point.isParent()) {
+      if (*tripCount == 0) {
+        // The loop has zero iterations. It branches directly back to the
+        // parent.
+        regions.push_back(RegionSuccessor::parent());
+      } else {
+        // The loop has at least one iteration. It branches into the body.
+        regions.push_back(RegionSuccessor(&getRegion()));
+      }
+      return;
+    } else if (*tripCount == 1) {
+      // The loop has exactly 1 iteration. Therefore, it branches from the
+      // region to the parent. (No further iteration.)
+      regions.push_back(RegionSuccessor::parent());
+      return;
+    }
+  }
+
   // Both the operation itself and the region may be branching into the body or
   // back into the operation itself. It is possible for loop not to enter the
   // body.
diff --git a/mlir/test/Analysis/DataFlow/test-dead-code-analysis.mlir b/mlir/test/Analysis/DataFlow/test-dead-code-analysis.mlir
index 7ce5c0f9e3d5a..4d3a61601a85c 100644
--- a/mlir/test/Analysis/DataFlow/test-dead-code-analysis.mlir
+++ b/mlir/test/Analysis/DataFlow/test-dead-code-analysis.mlir
@@ -283,3 +283,41 @@ func.func @test_forall_op_control_flow(%num_threads: index) {
   } {tag = "test_forall_op_control_flow"}
   return
 }
+
+func.func @test_for_op_control_flow() {
+  %c1 = arith.constant 1 : index
+  %c5 = arith.constant 5 : index
+  %c6 = arith.constant 6 : index
+  %c7 = arith.constant 7 : index
+
+  // Test case 1: Zero loop iterations.
+  // CHECK: test_for_op_control_flow_zero:
+  // CHECK:  region #0
+  // CHECK:   ^bb0 = dead
+  // CHECK: op_preds: (all) predecessors:
+  // CHECK:   scf.for %{{.*}} = %{{.*}} to %{{.*}} step %c1 {...} {tag = "test_for_op_control_flow_zero"}
+  scf.for %iv = %c5 to %c5 step %c1 {} {tag = "test_for_op_control_flow_zero"}
+
+  // Test case 2: One loop iteration.
+  // CHECK: test_for_op_control_flow_one:
+  // CHECK:  region #0
+  // CHECK:   ^bb0 = live
+  // CHECK: region_preds: (all) predecessors:
+  // CHECK:   scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {...} {tag = "test_for_op_control_flow_one"}
+  // CHECK: op_preds: (all) predecessors:
+  // CHECK:   scf.yield
+  scf.for %iv = %c5 to %c6 step %c1 {} {tag = "test_for_op_control_flow_one"}
+
+  // Test case 3: More than one loop iteration.
+  // CHECK: test_for_op_control_flow_multi:
+  // CHECK:  region #0
+  // CHECK:   ^bb0 = live
+  // CHECK: region_preds: (all) predecessors:
+  // CHECK:   scf.for %arg0 = %{{.*}} to %{{.*}} step %{{.*}} {...} {tag = "test_for_op_control_flow_multi"}
+  // CHECK:   scf.yield
+  // CHECK: op_preds: (all) predecessors:
+  // CHECK:   scf.yield
+  scf.for %iv = %c5 to %c7 step %c1 {} {tag = "test_for_op_control_flow_multi"}
+
+  return
+}
diff --git a/mlir/test/Dialect/Arith/int-range-narrowing.mlir b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
index 9107bf649b561..e2cd9b50f6736 100644
--- a/mlir/test/Dialect/Arith/int-range-narrowing.mlir
+++ b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
@@ -361,7 +361,7 @@ func.func private @use_i64(i64)
 // CHECK-LABEL: func.func @loop_with_iter_arg
 func.func @loop_with_iter_arg() {
   %c0 = arith.constant 0 : index
-  %c1 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
   %c16 = arith.constant 16 : index
 
   %cst = arith.constant dense<0.000000e+00> : vector<4xf32>

>From 10fe36f37f9e0c756d07ff70ec73b7f9037152af Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Wed, 21 Jan 2026 18:55:33 +0100
Subject: [PATCH 2/2] Update mlir/lib/Dialect/SCF/IR/SCF.cpp

Co-authored-by: Jakub Kuderski <jakub at nod-labs.com>
---
 mlir/lib/Dialect/SCF/IR/SCF.cpp | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 86e66dbaf6171..983f2de63784a 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -704,8 +704,7 @@ OperandRange ForOp::getEntrySuccessorOperands(RegionSuccessor successor) {
 
 void ForOp::getSuccessorRegions(RegionBranchPoint point,
                                 SmallVectorImpl<RegionSuccessor> &regions) {
-  std::optional<APInt> tripCount = getStaticTripCount();
-  if (tripCount.has_value()) {
+  if (std::optional<APInt> tripCount = getStaticTripCount()) {
     // The loop has a known static trip count.
     if (point.isParent()) {
       if (*tripCount == 0) {



More information about the Mlir-commits mailing list