[Mlir-commits] [mlir] [mlir][scf] Fold away `scf.for` iter args cycles (PR #173436)
Matthias Springer
llvmlistbot at llvm.org
Wed Dec 24 01:26:09 PST 2025
================
@@ -1236,12 +1236,101 @@ struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
}
};
+/// Rewriting pattern that folds away cycles in the yield of an scf.for op.
+///
+/// ```
+/// %res:2 = scf.for ... iter_args(%arg0 = %init, %arg1 = %init) {
+/// ...
+/// use %arg0, %arg1
+/// scf.yield %arg1, %arg0
+/// }
+/// return %res#0, %res#1
+/// ```
+///
+/// folds into:
+///
+/// ```
+/// scf.for ... iter_args() {
+/// ...
+/// use %init, %init
+/// scf.yield
+/// }
+/// return %init, %init
+/// ```
+struct ForOpYieldCyclesFolder : public OpRewritePattern<ForOp> {
+ using Base::Base;
+
+ LogicalResult matchAndRewrite(ForOp op,
+ PatternRewriter &rewriter) const override {
+ ValueRange yieldedValues = op.getYieldedValues();
+ ValueRange initArgs = op.getInitArgs();
+ ValueRange results = op.getResults();
+ ValueRange regionIterArgs = op.getRegionIterArgs();
+ Block *body = op.getBody();
+
+ unsigned numYieldedValues = op.getNumRegionIterArgs();
+
+ bool changed = false;
+ SmallVector<unsigned> cycle;
+ llvm::SmallBitVector visited(numYieldedValues, false);
+ for (auto start : llvm::seq(numYieldedValues)) {
+ if (visited[start])
+ continue;
+
+ cycle.clear();
+ unsigned current = start;
+ bool validCycle = true;
+ Value initValue = initArgs[start];
+ while (!visited[current]) {
----------------
matthias-springer wrote:
Can you sprinkle a few more comments to explain what this is doing?
https://github.com/llvm/llvm-project/pull/173436
More information about the Mlir-commits mailing list