[Mlir-commits] [mlir] [mlir][scf] Fold away `scf.for` iter args cycles (PR #173436)

Matthias Springer llvmlistbot at llvm.org
Wed Dec 24 03:02:36 PST 2025


================
@@ -1236,12 +1234,106 @@ struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
   }
 };
 
+/// Rewriting pattern that folds away cycles in the yield of a scf.for op.
+///
+/// ```
+/// %res:2 = scf.for ... iter_args(%arg0 = %init, %arg1 = %init) {
+///   ...
+///   use %arg0, %arg1
+///   scf.yield %arg1, %arg0
+/// }
+/// return %res#0, %res#1
+/// ```
+///
+/// folds into:
+///
+/// ```
+/// scf.for ... iter_args() {
+///   ...
+///   use %init, %init
+///   scf.yield
+/// }
+/// return %init, %init
+/// ```
+struct ForOpYieldCyclesFolder : public OpRewritePattern<ForOp> {
+  using Base::Base;
+
+  LogicalResult matchAndRewrite(ForOp op,
+                                PatternRewriter &rewriter) const override {
+    ValueRange yieldedValues = op.getYieldedValues();
+    ValueRange initArgs = op.getInitArgs();
+    ValueRange results = op.getResults();
+    ValueRange regionIterArgs = op.getRegionIterArgs();
+    Block *body = op.getBody();
+
+    unsigned numYieldedValues = op.getNumRegionIterArgs();
+
+    bool changed = false;
+    SmallVector<unsigned> cycle;
+    llvm::SmallBitVector visited(numYieldedValues, false);
+
+    // Go through all possible start points for the cycle.
+    for (auto start : llvm::seq(numYieldedValues)) {
+      if (visited[start])
+        continue;
+
+      cycle.clear();
+      unsigned current = start;
+      bool validCycle = true;
+      Value initValue = initArgs[start];
+      // Go through yield -> block arg -> yield cycles and check if all values
+      // are always equal to the init.
+      while (!visited[current]) {
+        cycle.push_back(current);
+        visited[current] = true;
+
+        // Find whether this yield is from a region iter arg.
+        auto yieldedValue = yieldedValues[current];
+        if (auto arg = dyn_cast<BlockArgument>(yieldedValue);
+            !arg || arg.getOwner() != body) {
+          validCycle = false;
+          break;
+        }
+
+        // Next yield position.
+        unsigned next = cast<BlockArgument>(yieldedValue).getArgNumber() -
+                        op.getNumInductionVars();
+
+        // Check if next position has the same init value.
+        if (initArgs[next] != initValue) {
+          validCycle = false;
+          break;
+        }
+
+        current = next;
+
+        // Completed the cycle.
+        if (current == start)
----------------
matthias-springer wrote:

Is it possible that there's a cycle but we don't find it because we started at the wrong index? And we won't visit it again because `visited` is not reset?

Something along the lines of:
```
for ... iter_args(%a = %0, %b = %1, %c = %1) {
  yield %b, %c, %b
}
```

Starting from 0, we find the cycle 1 <-> 2, but `current != start`. Then we don't revisit later because those indices were already marked as visited. Is that possible?


https://github.com/llvm/llvm-project/pull/173436


More information about the Mlir-commits mailing list