[Mlir-commits] [mlir] [mlir][scf] Fold away `scf.for` iter args cycles (PR #173436)
Matthias Springer
llvmlistbot at llvm.org
Wed Dec 24 03:02:36 PST 2025
================
@@ -1236,12 +1234,106 @@ struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
}
};
+/// Rewriting pattern that folds away cycles in the yield of a scf.for op.
+///
+/// ```
+/// %res:2 = scf.for ... iter_args(%arg0 = %init, %arg1 = %init) {
+/// ...
+/// use %arg0, %arg1
+/// scf.yield %arg1, %arg0
+/// }
+/// return %res#0, %res#1
+/// ```
+///
+/// folds into:
+///
+/// ```
+/// scf.for ... iter_args() {
+/// ...
+/// use %init, %init
+/// scf.yield
+/// }
+/// return %init, %init
+/// ```
+struct ForOpYieldCyclesFolder : public OpRewritePattern<ForOp> {
+ using Base::Base;
+
+ LogicalResult matchAndRewrite(ForOp op,
+ PatternRewriter &rewriter) const override {
+ ValueRange yieldedValues = op.getYieldedValues();
+ ValueRange initArgs = op.getInitArgs();
+ ValueRange results = op.getResults();
+ ValueRange regionIterArgs = op.getRegionIterArgs();
+ Block *body = op.getBody();
+
+ unsigned numYieldedValues = op.getNumRegionIterArgs();
+
+ bool changed = false;
+ SmallVector<unsigned> cycle;
+ llvm::SmallBitVector visited(numYieldedValues, false);
+
+ // Go through all possible start points for the cycle.
+ for (auto start : llvm::seq(numYieldedValues)) {
+ if (visited[start])
+ continue;
+
+ cycle.clear();
+ unsigned current = start;
+ bool validCycle = true;
+ Value initValue = initArgs[start];
+ // Go through yield -> block arg -> yield cycles and check if all values
+ // are always equal to the init.
+ while (!visited[current]) {
+ cycle.push_back(current);
+ visited[current] = true;
+
+ // Find whether this yield is from a region iter arg.
+ auto yieldedValue = yieldedValues[current];
+ if (auto arg = dyn_cast<BlockArgument>(yieldedValue);
+ !arg || arg.getOwner() != body) {
+ validCycle = false;
+ break;
+ }
+
+ // Next yield position.
+ unsigned next = cast<BlockArgument>(yieldedValue).getArgNumber() -
+ op.getNumInductionVars();
+
+ // Check if next position has the same init value.
+ if (initArgs[next] != initValue) {
+ validCycle = false;
+ break;
+ }
+
+ current = next;
+
+ // Completed the cycle.
+ if (current == start)
----------------
matthias-springer wrote:
Is it possible that there's a cycle but we don't find it because we started at the wrong index? And we won't visit it again because `visited` is not reset?
Something along the lines of:
```
for ... iter_args(%a = %0, %b = %1, %c = %1) {
yield %b, %c, %b
}
```
Starting from 0, we find the cycle 1 <-> 2, but `current != start`. We then never revisit it, because those indices are already marked as visited. Is that possible?
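A minimal, MLIR-free model of the traversal can help check this concern. It is a sketch under stated assumptions, not the patch's code. The yield structure is reduced to a hypothetical `next` vector: yield operand `i` is region iter arg `next[i]`, or `-1` if it is any other value. Init values are plain strings. `naiveFoldable` mirrors the visible part of the patch and assumes the truncated remainder accepts a cycle only when the walk returns to `start`. `fixedFoldable` is one possible fix. Under these assumptions the literal example above is still handled. The walk from position 0 stops at the init check (`%0` vs `%1`) after marking only position 0 as visited. The cycle 1 <-> 2 is missed once all three init values are the same, e.g. `%1`.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical model of a loop's yield structure (not the MLIR API):
//   next[i] = k   if yield operand i is region iter arg k,
//   next[i] = -1  if yield operand i is any other value;
//   init[i]       is the init value of iter arg i, as an SSA name.

// Mirrors the walk in the patch. ASSUMPTION: the truncated remainder of the
// patch accepts a cycle only when the walk returns to `start`.
std::vector<bool> naiveFoldable(const std::vector<int> &next,
                                const std::vector<std::string> &init) {
  int n = static_cast<int>(next.size());
  std::vector<bool> visited(n, false), foldable(n, false);
  for (int start = 0; start < n; ++start) {
    if (visited[start])
      continue;
    std::vector<int> cycle;
    int current = start;
    bool valid = true;
    while (!visited[current]) {
      cycle.push_back(current);
      visited[current] = true;
      // Not an iter arg, or an init value that differs from the start's.
      if (next[current] < 0 || init[next[current]] != init[start]) {
        valid = false;
        break;
      }
      current = next[current];
    }
    // Positions marked `visited` here are never used as a start again,
    // even when the walk ended on a cycle that does not contain `start`.
    if (valid && current == start)
      for (int i : cycle)
        foldable[i] = true;
  }
  return foldable;
}

// One possible fix (sketch only): distinguish positions on the current walk
// from positions finished by earlier walks. If a walk reaches a position
// that is already on its own path, the cycle is the path suffix starting at
// that position; init values are compared only within that cycle.
std::vector<bool> fixedFoldable(const std::vector<int> &next,
                                const std::vector<std::string> &init) {
  assert(next.size() == init.size());
  int n = static_cast<int>(next.size());
  enum State { Unvisited, OnPath, Done };
  std::vector<State> state(n, Unvisited);
  std::vector<bool> foldable(n, false);
  for (int start = 0; start < n; ++start) {
    if (state[start] != Unvisited)
      continue;
    std::vector<int> path;
    int current = start;
    // Follow yield -> iter arg edges until leaving the iter args or reaching
    // a position that has already been seen.
    while (current >= 0 && state[current] == Unvisited) {
      state[current] = OnPath;
      path.push_back(current);
      current = next[current];
    }
    if (current >= 0 && state[current] == OnPath) {
      // `current` was reached twice in this walk: the cycle is the suffix
      // of `path` that begins at `current`.
      size_t begin = 0;
      while (path[begin] != current)
        ++begin;
      bool sameInit = true;
      for (size_t i = begin; i < path.size(); ++i)
        sameInit = sameInit && init[path[i]] == init[current];
      for (size_t i = begin; sameInit && i < path.size(); ++i)
        foldable[path[i]] = true;
    }
    // Every position on this path is now fully classified.
    for (int p : path)
      state[p] = Done;
  }
  return foldable;
}
```

With `next = {1, 2, 1}` and all inits `"%1"`, `naiveFoldable` returns `{false, false, false}` and `fixedFoldable` returns `{false, true, true}`. The fix is the usual way to find cycles in a graph where every node has exactly one outgoing edge. It keeps a per-walk "on path" state next to the global "done" state. Each position is still entered only once, so the walk stays linear in the number of iter args.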
https://github.com/llvm/llvm-project/pull/173436