[Mlir-commits] [mlir] e3c2e6c - [mlir][scf] Fold away `scf.for` iter args cycles (#173436)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Dec 24 05:57:03 PST 2025


Author: Ivan Butygin
Date: 2025-12-24T16:56:59+03:00
New Revision: e3c2e6c56b08a097e0046853eee40fbf7a84b226

URL: https://github.com/llvm/llvm-project/commit/e3c2e6c56b08a097e0046853eee40fbf7a84b226
DIFF: https://github.com/llvm/llvm-project/commit/e3c2e6c56b08a097e0046853eee40fbf7a84b226.diff

LOG: [mlir][scf] Fold away `scf.for` iter args cycles (#173436)

When iter args form a cycle through region args/yields with the same init
value, we can replace them all with that init value.

---------

Signed-off-by: Ivan Butygin <ivan.butygin at gmail.com>

Added: 
    

Modified: 
    mlir/lib/Dialect/SCF/IR/SCF.cpp
    mlir/test/Dialect/SCF/canonicalize.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 652414f6cbe54..4a6b8aa7b1125 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -990,9 +990,8 @@ mlir::scf::replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp,
 
 namespace {
 // Fold away ForOp iter arguments when:
-// 1) The op yields the iter arguments.
-// 2) The argument's corresponding outer region iterators (inputs) are yielded.
-// 3) The iter arguments have no use and the corresponding (operation) results
+// 1) The argument's corresponding outer region iterators (inputs) are yielded.
+// 2) The iter arguments have no use and the corresponding (operation) results
 // have no use.
 //
 // These arguments must be defined outside of the ForOp region and can just be
@@ -1001,7 +1000,7 @@ namespace {
 // The implementation uses `inlineBlockBefore` to steal the content of the
 // original ForOp and avoid cloning.
 struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
-  using OpRewritePattern<scf::ForOp>::OpRewritePattern;
+  using Base::Base;
 
   LogicalResult matchAndRewrite(scf::ForOp forOp,
                                 PatternRewriter &rewriter) const final {
@@ -1030,12 +1029,11 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
                    forOp.getYieldedValues()   // iter yield
                    )) {
       // Forwarded is `true` when:
-      // 1) The region `iter` argument is yielded.
-      // 2) The region `iter` argument the corresponding input is yielded.
-      // 3) The region `iter` argument has no use, and the corresponding op
+      // 1) The region `iter` argument's corresponding input is yielded.
+      // 2) The region `iter` argument has no use, and the corresponding op
       // result has no use.
-      bool forwarded = (arg == yielded) || (init == yielded) ||
-                       (arg.use_empty() && result.use_empty());
+      bool forwarded =
+          (init == yielded) || (arg.use_empty() && result.use_empty());
       if (forwarded) {
         canonicalize = true;
         keepMask.push_back(false);
@@ -1133,7 +1131,7 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
 /// single-iteration loops with their bodies, and removes empty loops that
 /// iterate at least once and only return values defined outside of the loop.
 struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
-  using OpRewritePattern<ForOp>::OpRewritePattern;
+  using Base::Base;
 
   LogicalResult matchAndRewrite(ForOp op,
                                 PatternRewriter &rewriter) const override {
@@ -1204,7 +1202,7 @@ struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
 ///   use_of(%1)
 /// ```
 struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
-  using OpRewritePattern<ForOp>::OpRewritePattern;
+  using Base::Base;
 
   LogicalResult matchAndRewrite(ForOp op,
                                 PatternRewriter &rewriter) const override {
@@ -1236,12 +1234,100 @@ struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
   }
 };
 
+/// Rewriting pattern that folds away cycles in the yields of a scf.for op.
+///
+/// Starting from each iter arg position, the pattern follows the chain
+/// "position i yields region iter arg j, position j yields region iter arg k,
+/// ...". If that chain closes on itself and every position on it has the same
+/// init value, the loop body only ever moves that init value around, so every
+/// region iter arg and result on the chain is equal to it. Yielding a region
+/// iter arg at its own position forms a cycle of length 1.
+///
+/// ```
+/// %res:2 = scf.for ... iter_args(%arg0 = %init, %arg1 = %init) {
+///   ...
+///   use %arg0, %arg1
+///   scf.yield %arg1, %arg0
+/// }
+/// return %res#0, %res#1
+/// ```
+///
+/// folds into (once other canonicalization patterns have removed the iter
+/// args left dead by this one):
+///
+/// ```
+/// scf.for ... iter_args() {
+///   ...
+///   use %init, %init
+///   scf.yield
+/// }
+/// return %init, %init
+/// ```
+///
+/// A chain is rejected when one of its yielded values is not a region iter arg
+/// of this loop (e.g. a value computed in the body, a value defined above the
+/// loop, or the induction variable), when one of its positions has a
+/// different init value, or when it runs into a position already explored
+/// from an earlier start instead of closing on itself. In the last case the
+/// earlier chain may not be a valid cycle: e.g. with
+/// `iter_args(%x = %v, %y = %v)` and `scf.yield %computed, %x`, `%y` equals
+/// `%v` only during the first two iterations.
+struct ForOpYieldCyclesFolder : public OpRewritePattern<ForOp> {
+  using Base::Base;
+
+  LogicalResult matchAndRewrite(ForOp op,
+                                PatternRewriter &rewriter) const override {
+    ValueRange yieldedValues = op.getYieldedValues();
+    ValueRange initArgs = op.getInitArgs();
+    ValueRange results = op.getResults();
+    ValueRange regionIterArgs = op.getRegionIterArgs();
+    Block *body = op.getBody();
+
+    unsigned numYieldedValues = op.getNumRegionIterArgs();
+    unsigned numInductionVars = op.getNumInductionVars();
+
+    // `visitedFrom[i]` is the start position of the walk that first reached
+    // position `i`, or `kUnvisited` if no walk has reached it yet. Recording
+    // the start (not just a visited bit) lets us tell a walk that closes on
+    // itself apart from one that merges into an earlier walk.
+    constexpr unsigned kUnvisited = ~0u;
+    SmallVector<unsigned> visitedFrom(numYieldedValues, kUnvisited);
+
+    bool changed = false;
+    SmallVector<unsigned> cycle;
+
+    // Go through all possible start points for the cycle.
+    for (unsigned start : llvm::seq(numYieldedValues)) {
+      if (visitedFrom[start] != kUnvisited)
+        continue;
+
+      cycle.clear();
+      unsigned current = start;
+      bool validCycle = true;
+      Value initValue = initArgs[start];
+      // Go through yield -> block arg -> yield chains and check if all values
+      // are always equal to the init.
+      while (visitedFrom[current] == kUnvisited) {
+        cycle.push_back(current);
+        visitedFrom[current] = start;
+
+        // The yielded value must be one of this loop's region iter args. The
+        // induction variable(s) are body block arguments too, but they come
+        // before the iter args; without the bound check the position
+        // computed below would wrap around and index out of range.
+        auto arg = dyn_cast<BlockArgument>(yieldedValues[current]);
+        if (!arg || arg.getOwner() != body ||
+            arg.getArgNumber() < numInductionVars) {
+          validCycle = false;
+          break;
+        }
+
+        // Next yield position.
+        current = arg.getArgNumber() - numInductionVars;
+
+        // Check if next position has the same init value.
+        if (initArgs[current] != initValue) {
+          validCycle = false;
+          break;
+        }
+      }
+
+      // The walk must close on its own path. Reaching a position first
+      // visited from an earlier start means this chain feeds into a chain that
+      // was already explored and may not be a valid cycle, so reject it.
+      if (validCycle && visitedFrom[current] != start)
+        validCycle = false;
+
+      // If we found a valid cycle (yielding own iter arg forms cycle of length
+      // 1), all values on the walk are always equal to initValue.
+      if (!validCycle)
+        continue;
+
+      changed = true;
+      for (unsigned idx : cycle) {
+        // This will leave region args and results dead so other
+        // canonicalization patterns can clean them up.
+        rewriter.replaceAllUsesWith(regionIterArgs[idx], initValue);
+        rewriter.replaceAllUsesWith(results[idx], initValue);
+      }
+    }
+    return success(changed);
+  }
+};
+
 } // namespace
 
+/// Registers the scf.for canonicalization patterns: ForOpIterArgsFolder,
+/// SimplifyTrivialLoops, ForOpTensorCastFolder and ForOpYieldCyclesFolder.
 void ForOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
-  results.add<ForOpIterArgsFolder, SimplifyTrivialLoops, ForOpTensorCastFolder>(
-      context);
+  results.add<ForOpIterArgsFolder, SimplifyTrivialLoops, ForOpTensorCastFolder,
+              ForOpYieldCyclesFolder>(context);
 }
 
 std::optional<APInt> ForOp::getConstantStep() {

diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index ac590fc0c47b9..37851710ef010 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -1665,11 +1665,11 @@ func.func @func_execute_region_inline_multi_yield() {
 module {
 func.func private @foo()->()
 func.func private @execute_region_yeilding_external_value() -> memref<1x60xui8> {
-  %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>  
-  %1 = scf.execute_region -> memref<1x60xui8> no_inline {    
+  %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
+  %1 = scf.execute_region -> memref<1x60xui8> no_inline {
     func.call @foo():()->()
     scf.yield %alloc: memref<1x60xui8>
-  }  
+  }
   return %1 : memref<1x60xui8>
 }
 }
@@ -1688,12 +1688,12 @@ func.func private @execute_region_yeilding_external_value() -> memref<1x60xui8>
 module {
 func.func private @foo()->()
 func.func private @execute_region_yeilding_external_and_local_values() -> (memref<1x60xui8>, memref<1x120xui8>) {
-  %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>  
-  %1, %2 = scf.execute_region -> (memref<1x60xui8>, memref<1x120xui8>) no_inline {    
+  %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
+  %1, %2 = scf.execute_region -> (memref<1x60xui8>, memref<1x120xui8>) no_inline {
     %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x120xui8>
     func.call @foo():()->()
     scf.yield %alloc, %alloc_1: memref<1x60xui8>,  memref<1x120xui8>
-  }  
+  }
   return %1, %2 : memref<1x60xui8>, memref<1x120xui8>
 }
 }
@@ -1716,18 +1716,18 @@ func.func private @execute_region_yeilding_external_and_local_values() -> (memre
 module {
   func.func private @foo()->()
   func.func private @execute_region_multiple_yields_same_operands() -> (memref<1x60xui8>, memref<1x120xui8>) {
-    %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>  
-    %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x120xui8>  
+    %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
+    %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x120xui8>
     %1, %2 = scf.execute_region -> (memref<1x60xui8>, memref<1x120xui8>) no_inline {
       %c = "test.cmp"() : () -> i1
       cf.cond_br %c, ^bb2, ^bb3
-    ^bb2:    
+    ^bb2:
       func.call @foo():()->()
       scf.yield %alloc, %alloc_1 : memref<1x60xui8>, memref<1x120xui8>
-    ^bb3: 
-      func.call @foo():()->()   
+    ^bb3:
+      func.call @foo():()->()
       scf.yield %alloc, %alloc_1 : memref<1x60xui8>, memref<1x120xui8>
-    }  
+    }
     return %1, %2 : memref<1x60xui8>, memref<1x120xui8>
   }
 }
@@ -1746,19 +1746,19 @@ module {
 module {
   func.func private @foo()->()
   func.func private @execute_region_multiple_yields_different_operands() -> (memref<1x60xui8>, memref<1x120xui8>) {
-    %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>  
-    %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x120xui8>  
-    %alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<1x120xui8>  
+    %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
+    %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x120xui8>
+    %alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<1x120xui8>
     %1, %2 = scf.execute_region -> (memref<1x60xui8>, memref<1x120xui8>) no_inline {
       %c = "test.cmp"() : () -> i1
       cf.cond_br %c, ^bb2, ^bb3
-    ^bb2:    
+    ^bb2:
       func.call @foo():()->()
       scf.yield %alloc, %alloc_1 : memref<1x60xui8>, memref<1x120xui8>
-    ^bb3: 
-      func.call @foo():()->()   
+    ^bb3:
+      func.call @foo():()->()
       scf.yield %alloc, %alloc_2 : memref<1x60xui8>, memref<1x120xui8>
-    }  
+    }
     return %1, %2 : memref<1x60xui8>, memref<1x120xui8>
   }
 }
@@ -1778,18 +1778,18 @@ module {
 module {
 func.func private @foo()->()
 func.func private @execute_region_multiple_yields_different_operands() -> (memref<1x60xui8>) {
-  %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>  
-  %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>   
+  %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
+  %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
   %1 = scf.execute_region -> (memref<1x60xui8>) no_inline {
     %c = "test.cmp"() : () -> i1
     cf.cond_br %c, ^bb2, ^bb3
-  ^bb2:    
+  ^bb2:
     func.call @foo():()->()
     scf.yield %alloc : memref<1x60xui8>
-  ^bb3:    
+  ^bb3:
     func.call @foo():()->()
     scf.yield %alloc_1 : memref<1x60xui8>
-  }    
+  }
   return %1 : memref<1x60xui8>
 }
 }
@@ -2171,3 +2171,39 @@ func.func @scf_for_all_step_size_0()  {
   }
   return
 }
+
+// -----
+
+func.func private @side_effect()
+
+// The yields permute the iter args into three cycles, each with a single init
+// value: {%0, %2} (init %a), {%1, %4, %3} (init %b) and {%5} (yields itself,
+// init %c). Every iter arg and result folds to its init value, so the loop is
+// left without iter args.
+// CHECK-LABEL: func @iter_args_cycles
+//  CHECK-SAME:   (%[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index, %[[A:.*]]: i32, %[[B:.*]]: i64, %[[C:.*]]: f32)
+//       CHECK:   scf.for %[[IV:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
+//       CHECK:   func.call @side_effect()
+//   CHECK-NOT:   yield
+//       CHECK:   return %[[A]], %[[B]], %[[A]], %[[B]], %[[B]], %[[C]] : i32, i64, i32, i64, i64, f32
+func.func @iter_args_cycles(%lb : index, %ub : index, %step : index, %a : i32, %b : i64, %c : f32) -> (i32, i64, i32, i64, i64, f32) {
+  %res:6 = scf.for %i = %lb to %ub step %step iter_args(%0 = %a, %1 = %b, %2 = %a, %3 = %b, %4 = %b, %5 = %c) -> (i32, i64, i32, i64, i64, f32) {
+    func.call @side_effect() : () -> ()
+    scf.yield %2, %4, %0, %1, %3, %5 : i32, i64, i32, i64, i64, f32
+  }
+  return %res#0, %res#1, %res#2, %res#3, %res#4, %res#5 : i32, i64, i32, i64, i64, f32
+}
+
+// -----
+
+func.func private @side_effect(i32)
+
+// %1 and %2 (both init %b) are yielded into each other's positions, forming a
+// cycle that folds to %b. %0 (init %a) receives %1, so it is not part of the
+// cycle and stays loop-carried, yielding %b once %1 is folded.
+// CHECK-LABEL: func @iter_args_cycles_non_cycle_start
+//  CHECK-SAME:   (%[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index, %[[A:.*]]: i32, %[[B:.*]]: i32)
+//       CHECK:   %[[RES:.*]] = scf.for %[[IV:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args(%[[ITER_ARG:.*]] = %[[A]]) -> (i32) {
+//       CHECK:   func.call @side_effect(%[[ITER_ARG]])
+//       CHECK:   yield %[[B]] : i32
+//       CHECK:   return %[[RES]], %[[B]], %[[B]] : i32, i32, i32
+func.func @iter_args_cycles_non_cycle_start(%lb : index, %ub : index, %step : index, %a : i32, %b : i32) -> (i32, i32, i32) {
+  %res:3 = scf.for %i = %lb to %ub step %step iter_args(%0 = %a, %1 = %b, %2 = %b) -> (i32, i32, i32) {
+    func.call @side_effect(%0) : (i32) -> ()
+    scf.yield %1, %2, %1 : i32, i32, i32
+  }
+  return %res#0, %res#1, %res#2 : i32, i32, i32
+}


        


More information about the Mlir-commits mailing list