[Mlir-commits] [mlir] [mlir][scf] Fold away `scf.for` iter args cycles (PR #173436)
Ivan Butygin
llvmlistbot at llvm.org
Wed Dec 24 04:08:57 PST 2025
https://github.com/Hardcode84 updated https://github.com/llvm/llvm-project/pull/173436
>From 9be865a06df966cb7be2c7ec76318e28d83e64d0 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Wed, 24 Dec 2025 01:41:50 +0100
Subject: [PATCH 1/5] [mlir][scf] Fold away `scf.for` iter args cycles
---
mlir/lib/Dialect/SCF/IR/SCF.cpp | 99 +++++++++++++++++++++++--
mlir/test/Dialect/SCF/canonicalize.mlir | 66 +++++++++++------
2 files changed, 136 insertions(+), 29 deletions(-)
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 652414f6cbe54..99802cfe4b662 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1001,7 +1001,7 @@ namespace {
// The implementation uses `inlineBlockBefore` to steal the content of the
// original ForOp and avoid cloning.
struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
- using OpRewritePattern<scf::ForOp>::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(scf::ForOp forOp,
PatternRewriter &rewriter) const final {
@@ -1133,7 +1133,7 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
/// single-iteration loops with their bodies, and removes empty loops that
/// iterate at least once and only return values defined outside of the loop.
struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
- using OpRewritePattern<ForOp>::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(ForOp op,
PatternRewriter &rewriter) const override {
@@ -1204,7 +1204,7 @@ struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
/// use_of(%1)
/// ```
struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
- using OpRewritePattern<ForOp>::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(ForOp op,
PatternRewriter &rewriter) const override {
@@ -1236,12 +1236,101 @@ struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
}
};
+/// Rewriting pattern that folds away cycles in the yield of an scf.for op.
+///
+/// For each yield position, the walk below follows "position i yields the
+/// region iter arg of position j" links. If every position on the walk has
+/// the same init value and the walk closes a cycle, those iter args and loop
+/// results always hold that init value, so all their uses are replaced by it.
+/// The iter args are left dead for other canonicalization patterns to remove.
+///
+/// ```
+/// %res:2 = scf.for ... iter_args(%arg0 = %init, %arg1 = %init) {
+///   ...
+///   use %arg0, %arg1
+///   scf.yield %arg1, %arg0
+/// }
+/// return %res#0, %res#1
+/// ```
+///
+/// folds into (after the dead iter args are removed):
+///
+/// ```
+/// scf.for ... {
+///   ...
+///   use %init, %init
+/// }
+/// return %init, %init
+/// ```
+///
+/// NOTE(review): the walk also stops at a position visited from an *earlier*
+/// start and then accepts the chain, even if that earlier chain was rejected.
+/// E.g. with `iter_args(%0 = %a, %1 = %a, %2 = %a)` and `scf.yield %x, %2, %0`
+/// (%x not an iter arg), positions 1 and 2 are folded to %a although %2 holds
+/// %x from the third iteration on. Suggested fix: once the walk ends, keep
+/// `validCycle` only if `llvm::is_contained(cycle, current)`; positions of
+/// earlier accepted chains cannot be reached, as their iter args were already
+/// replaced in the yield.
+struct ForOpYieldCyclesFolder : public OpRewritePattern<ForOp> {
+ using Base::Base;
+
+ LogicalResult matchAndRewrite(ForOp op,
+ PatternRewriter &rewriter) const override {
+ ValueRange yieldedValues = op.getYieldedValues();
+ ValueRange initArgs = op.getInitArgs();
+ ValueRange results = op.getResults();
+ ValueRange regionIterArgs = op.getRegionIterArgs();
+ Block *body = op.getBody();
+
+ unsigned numYieldedValues = op.getNumRegionIterArgs();
+
+ bool changed = false;
+ SmallVector<unsigned> cycle;
+ llvm::SmallBitVector visited(numYieldedValues, false);
+ for (auto start : llvm::seq(numYieldedValues)) {
+ if (visited[start])
+ continue;
+
+ cycle.clear();
+ unsigned current = start;
+ bool validCycle = true;
+ Value initValue = initArgs[start];
+ while (!visited[current]) {
+ cycle.push_back(current);
+ visited[current] = true;
+
+ // Find whether this yield is from a region iter arg.
+ auto yieldedValue = yieldedValues[current];
+ if (auto arg = dyn_cast<BlockArgument>(yieldedValue);
+ !arg || arg.getOwner() != body) {
+ validCycle = false;
+ break;
+ }
+
+ unsigned next = cast<BlockArgument>(yieldedValue).getArgNumber() -
+ op.getNumInductionVars();
+
+ // Check if next position has the same init value.
+ if (initArgs[next] != initValue) {
+ validCycle = false;
+ break;
+ }
+
+ current = next;
+
+ // Completed the cycle.
+ if (current == start)
+ break;
+ }
+
+ // If we found a valid cycle of length > 1, all values in it
+ // are always equal to initValue.
+ if (validCycle && cycle.size() > 1) {
+ changed = true;
+ for (unsigned idx : cycle) {
+ // This will leave region args and results dead so other
+ // canonicalization patterns can clean them up.
+ rewriter.replaceAllUsesWith(regionIterArgs[idx], initValue);
+ rewriter.replaceAllUsesWith(results[idx], initValue);
+ }
+ }
+ }
+ return success(changed);
+ }
+};
+
} // namespace
void ForOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<ForOpIterArgsFolder, SimplifyTrivialLoops, ForOpTensorCastFolder>(
- context);
+ results.add<ForOpIterArgsFolder, SimplifyTrivialLoops, ForOpTensorCastFolder,
+ ForOpYieldCyclesFolder>(context);
}
std::optional<APInt> ForOp::getConstantStep() {
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index ac590fc0c47b9..e69bbff0254e1 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -1665,11 +1665,11 @@ func.func @func_execute_region_inline_multi_yield() {
module {
func.func private @foo()->()
func.func private @execute_region_yeilding_external_value() -> memref<1x60xui8> {
- %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
- %1 = scf.execute_region -> memref<1x60xui8> no_inline {
+ %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
+ %1 = scf.execute_region -> memref<1x60xui8> no_inline {
func.call @foo():()->()
scf.yield %alloc: memref<1x60xui8>
- }
+ }
return %1 : memref<1x60xui8>
}
}
@@ -1688,12 +1688,12 @@ func.func private @execute_region_yeilding_external_value() -> memref<1x60xui8>
module {
func.func private @foo()->()
func.func private @execute_region_yeilding_external_and_local_values() -> (memref<1x60xui8>, memref<1x120xui8>) {
- %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
- %1, %2 = scf.execute_region -> (memref<1x60xui8>, memref<1x120xui8>) no_inline {
+ %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
+ %1, %2 = scf.execute_region -> (memref<1x60xui8>, memref<1x120xui8>) no_inline {
%alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x120xui8>
func.call @foo():()->()
scf.yield %alloc, %alloc_1: memref<1x60xui8>, memref<1x120xui8>
- }
+ }
return %1, %2 : memref<1x60xui8>, memref<1x120xui8>
}
}
@@ -1716,18 +1716,18 @@ func.func private @execute_region_yeilding_external_and_local_values() -> (memre
module {
func.func private @foo()->()
func.func private @execute_region_multiple_yields_same_operands() -> (memref<1x60xui8>, memref<1x120xui8>) {
- %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
- %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x120xui8>
+ %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
+ %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x120xui8>
%1, %2 = scf.execute_region -> (memref<1x60xui8>, memref<1x120xui8>) no_inline {
%c = "test.cmp"() : () -> i1
cf.cond_br %c, ^bb2, ^bb3
- ^bb2:
+ ^bb2:
func.call @foo():()->()
scf.yield %alloc, %alloc_1 : memref<1x60xui8>, memref<1x120xui8>
- ^bb3:
- func.call @foo():()->()
+ ^bb3:
+ func.call @foo():()->()
scf.yield %alloc, %alloc_1 : memref<1x60xui8>, memref<1x120xui8>
- }
+ }
return %1, %2 : memref<1x60xui8>, memref<1x120xui8>
}
}
@@ -1746,19 +1746,19 @@ module {
module {
func.func private @foo()->()
func.func private @execute_region_multiple_yields_different_operands() -> (memref<1x60xui8>, memref<1x120xui8>) {
- %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
- %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x120xui8>
- %alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<1x120xui8>
+ %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
+ %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x120xui8>
+ %alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<1x120xui8>
%1, %2 = scf.execute_region -> (memref<1x60xui8>, memref<1x120xui8>) no_inline {
%c = "test.cmp"() : () -> i1
cf.cond_br %c, ^bb2, ^bb3
- ^bb2:
+ ^bb2:
func.call @foo():()->()
scf.yield %alloc, %alloc_1 : memref<1x60xui8>, memref<1x120xui8>
- ^bb3:
- func.call @foo():()->()
+ ^bb3:
+ func.call @foo():()->()
scf.yield %alloc, %alloc_2 : memref<1x60xui8>, memref<1x120xui8>
- }
+ }
return %1, %2 : memref<1x60xui8>, memref<1x120xui8>
}
}
@@ -1778,18 +1778,18 @@ module {
module {
func.func private @foo()->()
func.func private @execute_region_multiple_yields_different_operands() -> (memref<1x60xui8>) {
- %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
- %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
+ %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
+ %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
%1 = scf.execute_region -> (memref<1x60xui8>) no_inline {
%c = "test.cmp"() : () -> i1
cf.cond_br %c, ^bb2, ^bb3
- ^bb2:
+ ^bb2:
func.call @foo():()->()
scf.yield %alloc : memref<1x60xui8>
- ^bb3:
+ ^bb3:
func.call @foo():()->()
scf.yield %alloc_1 : memref<1x60xui8>
- }
+ }
return %1 : memref<1x60xui8>
}
}
@@ -2171,3 +2171,21 @@ func.func @scf_for_all_step_size_0() {
}
return
}
+
+// -----
+
+func.func private @side_effect()
+
+// CHECK-LABEL: func @iter_args_cycles
+// CHECK-SAME: (%[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index, %[[A:.*]]: i32, %[[B:.*]]: i64, %[[C:.*]]: f32)
+// CHECK: scf.for %[[IV:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
+// CHECK: func.call @side_effect() : () -> ()
+// CHECK-NOT: yield
+// CHECK: return %[[A]], %[[B]], %[[A]], %[[B]], %[[B]], %[[C]] : i32, i64, i32, i64, i64, f32
+func.func @iter_args_cycles(%lb : index, %ub : index, %step : index, %a : i32, %b : i64, %c : f32) -> (i32, i64, i32, i64, i64, f32) {
+ %res:6 = scf.for %i = %lb to %ub step %step iter_args(%0 = %a, %1 = %b, %2 = %a, %3 = %b, %4 = %b, %5 = %c) -> (i32, i64, i32, i64, i64, f32) {
+ func.call @side_effect() : () -> ()
+ scf.yield %2, %4, %0, %1, %3, %5 : i32, i64, i32, i64, i64, f32
+ }
+ return %res#0, %res#1, %res#2, %res#3, %res#4, %res#5 : i32, i64, i32, i64, i64, f32
+}
>From d652c62461592ebd869436215c94d4a38704e0fd Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Wed, 24 Dec 2025 10:28:57 +0100
Subject: [PATCH 2/5] handle cycles of size 1
Signed-off-by: Ivan Butygin <ivan.butygin at gmail.com>
---
mlir/lib/Dialect/SCF/IR/SCF.cpp | 20 +++++++++-----------
1 file changed, 9 insertions(+), 11 deletions(-)
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 99802cfe4b662..7978681da07c5 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -990,9 +990,8 @@ mlir::scf::replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp,
namespace {
// Fold away ForOp iter arguments when:
-// 1) The op yields the iter arguments.
-// 2) The argument's corresponding outer region iterators (inputs) are yielded.
-// 3) The iter arguments have no use and the corresponding (operation) results
+// 1) The argument's corresponding outer region iterators (inputs) are yielded.
+// 2) The iter arguments have no use and the corresponding (operation) results
// have no use.
//
// These arguments must be defined outside of the ForOp region and can just be
@@ -1030,12 +1029,11 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
forOp.getYieldedValues() // iter yield
)) {
// Forwarded is `true` when:
- // 1) The region `iter` argument is yielded.
- // 2) The region `iter` argument the corresponding input is yielded.
- // 3) The region `iter` argument has no use, and the corresponding op
+ // 1) The input corresponding to the region `iter` argument is yielded.
+ // 2) The region `iter` argument has no use, and the corresponding op
// result has no use.
- bool forwarded = (arg == yielded) || (init == yielded) ||
- (arg.use_empty() && result.use_empty());
+ bool forwarded =
+ (init == yielded) || (arg.use_empty() && result.use_empty());
if (forwarded) {
canonicalize = true;
keepMask.push_back(false);
@@ -1309,9 +1307,9 @@ struct ForOpYieldCyclesFolder : public OpRewritePattern<ForOp> {
break;
}
- // If we found a valid cycle of length > 1, all values in it
- // are always equal to initValue.
- if (validCycle && cycle.size() > 1) {
+ // If we found a valid cycle (yielding own iter arg is also a cycle), all
+ // values in it are always equal to initValue.
+ if (validCycle) {
changed = true;
for (unsigned idx : cycle) {
// This will leave region args and results dead so other
>From 385c6056898c315d60fb91138e16baf8bb4f7a0a Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Wed, 24 Dec 2025 10:40:32 +0100
Subject: [PATCH 3/5] code comments
---
mlir/lib/Dialect/SCF/IR/SCF.cpp | 9 +++++++--
1 file changed, 7 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 7978681da07c5..a53ec60573d50 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1271,6 +1271,8 @@ struct ForOpYieldCyclesFolder : public OpRewritePattern<ForOp> {
bool changed = false;
SmallVector<unsigned> cycle;
llvm::SmallBitVector visited(numYieldedValues, false);
+
+ // Go through all possible start points for the cycle.
for (auto start : llvm::seq(numYieldedValues)) {
if (visited[start])
continue;
@@ -1279,6 +1281,8 @@ struct ForOpYieldCyclesFolder : public OpRewritePattern<ForOp> {
unsigned current = start;
bool validCycle = true;
Value initValue = initArgs[start];
+ // Go through yield -> block arg -> yield cycles and check if all values
+ // are always equal to the init.
while (!visited[current]) {
cycle.push_back(current);
visited[current] = true;
@@ -1291,6 +1295,7 @@ struct ForOpYieldCyclesFolder : public OpRewritePattern<ForOp> {
break;
}
+ // Next yield position.
unsigned next = cast<BlockArgument>(yieldedValue).getArgNumber() -
op.getNumInductionVars();
@@ -1307,8 +1312,8 @@ struct ForOpYieldCyclesFolder : public OpRewritePattern<ForOp> {
break;
}
- // If we found a valid cycle (yielding own iter arg is also a cycle), all
- // values in it are always equal to initValue.
+ // If we found a valid cycle (yielding own iter arg forms cycle of length
+ // 1), all values in it are always equal to initValue.
if (validCycle) {
changed = true;
for (unsigned idx : cycle) {
>From 0a47e239d2893976d627251ee804976c45303031 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Wed, 24 Dec 2025 12:27:04 +0100
Subject: [PATCH 4/5] simplify cycle handling
---
mlir/lib/Dialect/SCF/IR/SCF.cpp | 4 ----
mlir/test/Dialect/SCF/canonicalize.mlir | 20 +++++++++++++++++++-
2 files changed, 19 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index a53ec60573d50..5174cb1e85d4b 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1306,10 +1306,6 @@ struct ForOpYieldCyclesFolder : public OpRewritePattern<ForOp> {
}
current = next;
-
- // Completed the cycle.
- if (current == start)
- break;
}
// If we found a valid cycle (yielding own iter arg forms cycle of length
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index e69bbff0254e1..37851710ef010 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -2179,7 +2179,7 @@ func.func private @side_effect()
// CHECK-LABEL: func @iter_args_cycles
// CHECK-SAME: (%[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index, %[[A:.*]]: i32, %[[B:.*]]: i64, %[[C:.*]]: f32)
// CHECK: scf.for %[[IV:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
-// CHECK: func.call @side_effect() : () -> ()
+// CHECK: func.call @side_effect()
// CHECK-NOT: yield
// CHECK: return %[[A]], %[[B]], %[[A]], %[[B]], %[[B]], %[[C]] : i32, i64, i32, i64, i64, f32
func.func @iter_args_cycles(%lb : index, %ub : index, %step : index, %a : i32, %b : i64, %c : f32) -> (i32, i64, i32, i64, i64, f32) {
@@ -2189,3 +2189,21 @@ func.func @iter_args_cycles(%lb : index, %ub : index, %step : index, %a : i32, %
}
return %res#0, %res#1, %res#2, %res#3, %res#4, %res#5 : i32, i64, i32, i64, i64, f32
}
+
+// -----
+
+func.func private @side_effect(i32)
+
+// CHECK-LABEL: func @iter_args_cycles_non_cycle_start
+// CHECK-SAME: (%[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index, %[[A:.*]]: i32, %[[B:.*]]: i32)
+// CHECK: %[[RES:.*]] = scf.for %[[IV:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args(%[[ITER_ARG:.*]] = %[[A]]) -> (i32) {
+// CHECK: func.call @side_effect(%[[ITER_ARG]])
+// CHECK: yield %[[B]] : i32
+// CHECK: return %[[RES]], %[[B]], %[[B]] : i32, i32, i32
+func.func @iter_args_cycles_non_cycle_start(%lb : index, %ub : index, %step : index, %a : i32, %b : i32) -> (i32, i32, i32) {
+ %res:3 = scf.for %i = %lb to %ub step %step iter_args(%0 = %a, %1 = %b, %2 = %b) -> (i32, i32, i32) {
+ func.call @side_effect(%0) : (i32) -> ()
+ scf.yield %1, %2, %1 : i32, i32, i32
+ }
+ return %res#0, %res#1, %res#2 : i32, i32, i32
+}
>From 9a5ec075c185e2702aaecf2ed4232cd3a2be99ae Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Wed, 24 Dec 2025 13:02:10 +0100
Subject: [PATCH 5/5] remove next
---
mlir/lib/Dialect/SCF/IR/SCF.cpp | 8 +++-----
1 file changed, 3 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 5174cb1e85d4b..4a6b8aa7b1125 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1296,16 +1296,14 @@ struct ForOpYieldCyclesFolder : public OpRewritePattern<ForOp> {
}
// Next yield position.
- unsigned next = cast<BlockArgument>(yieldedValue).getArgNumber() -
- op.getNumInductionVars();
+ current = cast<BlockArgument>(yieldedValue).getArgNumber() -
+ op.getNumInductionVars();
// Check if next position has the same init value.
- if (initArgs[next] != initValue) {
+ if (initArgs[current] != initValue) {
validCycle = false;
break;
}
-
- current = next;
}
// If we found a valid cycle (yielding own iter arg forms cycle of length
More information about the Mlir-commits
mailing list