[Mlir-commits] [mlir] [mlir][SCF] Improve `ForOp::getSuccessorRegions` (PR #177116)
Matthias Springer
llvmlistbot at llvm.org
Wed Jan 21 09:55:43 PST 2026
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/177116
>From 39d0f71ea91c87a6567024b5fa3bc7821fb15249 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Wed, 21 Jan 2026 08:27:20 +0000
Subject: [PATCH 1/2] [mlir][SCF] Improve `ForOp::getSuccessorRegions`
---
mlir/lib/Dialect/SCF/IR/SCF.cpp | 21 ++++++++++
.../DataFlow/test-dead-code-analysis.mlir | 38 +++++++++++++++++++
.../Dialect/Arith/int-range-narrowing.mlir | 2 +-
3 files changed, 60 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 5b6e9304de505..86e66dbaf6171 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -704,6 +704,27 @@ OperandRange ForOp::getEntrySuccessorOperands(RegionSuccessor successor) {
void ForOp::getSuccessorRegions(RegionBranchPoint point,
SmallVectorImpl<RegionSuccessor> ®ions) {
+ std::optional<APInt> tripCount = getStaticTripCount();
+ if (tripCount.has_value()) {
+ // The loop has a known static trip count.
+ if (point.isParent()) {
+ if (*tripCount == 0) {
+ // The loop has zero iterations. It branches directly back to the
+ // parent.
+ regions.push_back(RegionSuccessor::parent());
+ } else {
+ // The loop has at least one iteration. It branches into the body.
+ regions.push_back(RegionSuccessor(&getRegion()));
+ }
+ return;
+ } else if (*tripCount == 1) {
+ // The loop has exactly 1 iteration. Therefore, it branches from the
+ // region to the parent. (No further iteration.)
+ regions.push_back(RegionSuccessor::parent());
+ return;
+ }
+ }
+
// Both the operation itself and the region may be branching into the body or
// back into the operation itself. It is possible for loop not to enter the
// body.
diff --git a/mlir/test/Analysis/DataFlow/test-dead-code-analysis.mlir b/mlir/test/Analysis/DataFlow/test-dead-code-analysis.mlir
index 7ce5c0f9e3d5a..4d3a61601a85c 100644
--- a/mlir/test/Analysis/DataFlow/test-dead-code-analysis.mlir
+++ b/mlir/test/Analysis/DataFlow/test-dead-code-analysis.mlir
@@ -283,3 +283,41 @@ func.func @test_forall_op_control_flow(%num_threads: index) {
} {tag = "test_forall_op_control_flow"}
return
}
+
+func.func @test_for_op_control_flow() {
+ %c1 = arith.constant 1 : index
+ %c5 = arith.constant 5 : index
+ %c6 = arith.constant 6 : index
+ %c7 = arith.constant 7 : index
+
+ // Test case 1: Zero iterations (lb == ub): body block is dead; the scf.for op itself is the only op_preds entry.
+ // CHECK: test_for_op_control_flow_zero:
+ // CHECK: region #0
+ // CHECK: ^bb0 = dead
+ // CHECK: op_preds: (all) predecessors:
+ // CHECK: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {...} {tag = "test_for_op_control_flow_zero"}
+ scf.for %iv = %c5 to %c5 step %c1 {} {tag = "test_for_op_control_flow_zero"}
+
+ // Test case 2: One iteration: body is entered only from scf.for (no back edge); op_preds is only scf.yield.
+ // CHECK: test_for_op_control_flow_one:
+ // CHECK: region #0
+ // CHECK: ^bb0 = live
+ // CHECK: region_preds: (all) predecessors:
+ // CHECK: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {...} {tag = "test_for_op_control_flow_one"}
+ // CHECK: op_preds: (all) predecessors:
+ // CHECK: scf.yield
+ scf.for %iv = %c5 to %c6 step %c1 {} {tag = "test_for_op_control_flow_one"}
+
+ // Test case 3: Two iterations: body is entered from scf.for and from scf.yield (back edge); op_preds is scf.yield.
+ // CHECK: test_for_op_control_flow_multi:
+ // CHECK: region #0
+ // CHECK: ^bb0 = live
+ // CHECK: region_preds: (all) predecessors:
+ // CHECK: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {...} {tag = "test_for_op_control_flow_multi"}
+ // CHECK: scf.yield
+ // CHECK: op_preds: (all) predecessors:
+ // CHECK: scf.yield
+ scf.for %iv = %c5 to %c7 step %c1 {} {tag = "test_for_op_control_flow_multi"}
+
+ return
+}
diff --git a/mlir/test/Dialect/Arith/int-range-narrowing.mlir b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
index 9107bf649b561..e2cd9b50f6736 100644
--- a/mlir/test/Dialect/Arith/int-range-narrowing.mlir
+++ b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
@@ -361,7 +361,7 @@ func.func private @use_i64(i64)
// CHECK-LABEL: func.func @loop_with_iter_arg
func.func @loop_with_iter_arg() {
%c0 = arith.constant 0 : index
- %c1 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
%c16 = arith.constant 16 : index
%cst = arith.constant dense<0.000000e+00> : vector<4xf32>
>From 10fe36f37f9e0c756d07ff70ec73b7f9037152af Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Wed, 21 Jan 2026 18:55:33 +0100
Subject: [PATCH 2/2] Update mlir/lib/Dialect/SCF/IR/SCF.cpp
Co-authored-by: Jakub Kuderski <jakub at nod-labs.com>
---
mlir/lib/Dialect/SCF/IR/SCF.cpp | 3 +--
1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 86e66dbaf6171..983f2de63784a 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -704,8 +704,7 @@ OperandRange ForOp::getEntrySuccessorOperands(RegionSuccessor successor) {
void ForOp::getSuccessorRegions(RegionBranchPoint point,
SmallVectorImpl<RegionSuccessor> ®ions) {
- std::optional<APInt> tripCount = getStaticTripCount();
- if (tripCount.has_value()) {
+ if (std::optional<APInt> tripCount = getStaticTripCount()) {
// The loop has a known static trip count.
if (point.isParent()) {
if (*tripCount == 0) {
More information about the Mlir-commits
mailing list