[Mlir-commits] [mlir] Revert "[mlir][scf] Fold away scf.for iter args cycles (#173436)" (PR #173991)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Dec 30 06:20:56 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-scf

Author: Walter Lee (googlewalt)

<details>
<summary>Changes</summary>

The reverted change (#173436) causes issues with Triton usage.

Also revert the dependent "[mlir][SCF] index_switch results (#173560)".


---
Full diff: https://github.com/llvm/llvm-project/pull/173991.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/SCF/IR/SCF.cpp (+14-150) 
- (modified) mlir/test/Dialect/SCF/canonicalize.mlir (+24-91) 


``````````diff
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 46d09abd89d69..652414f6cbe54 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -990,8 +990,9 @@ mlir::scf::replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp,
 
 namespace {
 // Fold away ForOp iter arguments when:
-// 1) The argument's corresponding outer region iterators (inputs) are yielded.
-// 2) The iter arguments have no use and the corresponding (operation) results
+// 1) The op yields the iter arguments.
+// 2) The argument's corresponding outer region iterators (inputs) are yielded.
+// 3) The iter arguments have no use and the corresponding (operation) results
 // have no use.
 //
 // These arguments must be defined outside of the ForOp region and can just be
@@ -1000,7 +1001,7 @@ namespace {
 // The implementation uses `inlineBlockBefore` to steal the content of the
 // original ForOp and avoid cloning.
 struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
-  using Base::Base;
+  using OpRewritePattern<scf::ForOp>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(scf::ForOp forOp,
                                 PatternRewriter &rewriter) const final {
@@ -1029,11 +1030,12 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
                    forOp.getYieldedValues()   // iter yield
                    )) {
       // Forwarded is `true` when:
-      // 1) The region `iter` argument the corresponding input is yielded.
-      // 2) The region `iter` argument has no use, and the corresponding op
+      // 1) The region `iter` argument is yielded.
+      // 2) The region `iter` argument the corresponding input is yielded.
+      // 3) The region `iter` argument has no use, and the corresponding op
       // result has no use.
-      bool forwarded =
-          (init == yielded) || (arg.use_empty() && result.use_empty());
+      bool forwarded = (arg == yielded) || (init == yielded) ||
+                       (arg.use_empty() && result.use_empty());
       if (forwarded) {
         canonicalize = true;
         keepMask.push_back(false);
@@ -1131,7 +1133,7 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
 /// single-iteration loops with their bodies, and removes empty loops that
 /// iterate at least once and only return values defined outside of the loop.
 struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
-  using Base::Base;
+  using OpRewritePattern<ForOp>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(ForOp op,
                                 PatternRewriter &rewriter) const override {
@@ -1202,7 +1204,7 @@ struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
 ///   use_of(%1)
 /// ```
 struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
-  using Base::Base;
+  using OpRewritePattern<ForOp>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(ForOp op,
                                 PatternRewriter &rewriter) const override {
@@ -1234,100 +1236,12 @@ struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
   }
 };
 
-/// Rewriting pattern that folds away cycles in the yield of a scf.for op.
-///
-/// ```
-/// %res:2 = scf.for ... iter_args(%arg0 = %init, %arg1 = %init) {
-///   ...
-///   use %arg0, %arg1
-///   scf.yield %arg1, %arg0
-/// }
-/// return %res#0, %res#1
-/// ```
-///
-/// folds into:
-///
-/// ```
-/// scf.for ... iter_args() {
-///   ...
-///   use %init, %init
-///   scf.yield
-/// }
-/// return %init, %init
-/// ```
-struct ForOpYieldCyclesFolder : public OpRewritePattern<ForOp> {
-  using Base::Base;
-
-  LogicalResult matchAndRewrite(ForOp op,
-                                PatternRewriter &rewriter) const override {
-    ValueRange yieldedValues = op.getYieldedValues();
-    ValueRange initArgs = op.getInitArgs();
-    ValueRange results = op.getResults();
-    ValueRange regionIterArgs = op.getRegionIterArgs();
-    Block *body = op.getBody();
-
-    unsigned numYieldedValues = op.getNumRegionIterArgs();
-
-    bool changed = false;
-    SmallVector<unsigned> cycle;
-    llvm::SmallBitVector visited(numYieldedValues, false);
-
-    // Go through all possible start points for the cycle.
-    for (auto start : llvm::seq(numYieldedValues)) {
-      if (visited[start])
-        continue;
-
-      cycle.clear();
-      unsigned current = start;
-      bool validCycle = true;
-      Value initValue = initArgs[start];
-      // Go through yield -> block arg -> yield cycles and check if all values
-      // are always equal to the init.
-      while (!visited[current]) {
-        cycle.push_back(current);
-        visited[current] = true;
-
-        // Find whether this yield is from a region iter arg.
-        auto yieldedValue = yieldedValues[current];
-        if (auto arg = dyn_cast<BlockArgument>(yieldedValue);
-            !arg || arg.getOwner() != body) {
-          validCycle = false;
-          break;
-        }
-
-        // Next yield position.
-        current = cast<BlockArgument>(yieldedValue).getArgNumber() -
-                  op.getNumInductionVars();
-
-        // Check if next position has the same init value.
-        if (initArgs[current] != initValue) {
-          validCycle = false;
-          break;
-        }
-      }
-
-      // If we found a valid cycle (yielding own iter arg forms cycle of length
-      // 1), all values in it are always equal to initValue.
-      if (validCycle) {
-        changed = true;
-        for (unsigned idx : cycle) {
-          // This will leave region args and results dead so other
-          // canonicalization patterns can clean them up.
-          rewriter.replaceAllUsesWith(regionIterArgs[idx], initValue);
-          rewriter.replaceAllUsesWith(results[idx], initValue);
-        }
-      }
-    }
-    return success(changed);
-  }
-};
-
 } // namespace
 
 void ForOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
-  results.add<ForOpIterArgsFolder, SimplifyTrivialLoops, ForOpTensorCastFolder,
-              ForOpYieldCyclesFolder>(context);
+  results.add<ForOpIterArgsFolder, SimplifyTrivialLoops, ForOpTensorCastFolder>(
+      context);
 }
 
 std::optional<APInt> ForOp::getConstantStep() {
@@ -4797,59 +4711,9 @@ struct FoldConstantCase : OpRewritePattern<scf::IndexSwitchOp> {
   }
 };
 
-/// Canonicalization patterns that folds away dead results of
-/// "scf.index_switch" ops.
-struct FoldUnusedIndexSwitchResults : OpRewritePattern<IndexSwitchOp> {
-  using OpRewritePattern<IndexSwitchOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(IndexSwitchOp op,
-                                PatternRewriter &rewriter) const override {
-    // Find dead results.
-    BitVector deadResults(op.getNumResults(), false);
-    SmallVector<Type> newResultTypes;
-    for (auto [idx, result] : llvm::enumerate(op.getResults())) {
-      if (!result.use_empty()) {
-        newResultTypes.push_back(result.getType());
-      } else {
-        deadResults[idx] = true;
-      }
-    }
-    if (!deadResults.any())
-      return rewriter.notifyMatchFailure(op, "no dead results to fold");
-
-    // Create new op without dead results and inline case regions.
-    auto newOp = IndexSwitchOp::create(rewriter, op.getLoc(), newResultTypes,
-                                       op.getArg(), op.getCases(),
-                                       op.getCaseRegions().size());
-    auto inlineCaseRegion = [&](Region &oldRegion, Region &newRegion) {
-      rewriter.inlineRegionBefore(oldRegion, newRegion, newRegion.begin());
-      // Remove respective operands from yield op.
-      Operation *terminator = newRegion.front().getTerminator();
-      assert(isa<YieldOp>(terminator) && "expected yield op");
-      rewriter.modifyOpInPlace(
-          terminator, [&]() { terminator->eraseOperands(deadResults); });
-    };
-    for (auto [oldRegion, newRegion] :
-         llvm::zip_equal(op.getCaseRegions(), newOp.getCaseRegions()))
-      inlineCaseRegion(oldRegion, newRegion);
-    inlineCaseRegion(op.getDefaultRegion(), newOp.getDefaultRegion());
-
-    // Replace op with new op.
-    SmallVector<Value> newResults(op.getNumResults(), Value());
-    unsigned nextNewResult = 0;
-    for (unsigned idx = 0; idx < op.getNumResults(); ++idx) {
-      if (deadResults[idx])
-        continue;
-      newResults[idx] = newOp.getResult(nextNewResult++);
-    }
-    rewriter.replaceOp(op, newResults);
-    return success();
-  }
-};
-
 void IndexSwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
-  results.add<FoldConstantCase, FoldUnusedIndexSwitchResults>(context);
+  results.add<FoldConstantCase>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 984ea10f7e540..ac590fc0c47b9 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -1665,11 +1665,11 @@ func.func @func_execute_region_inline_multi_yield() {
 module {
 func.func private @foo()->()
 func.func private @execute_region_yeilding_external_value() -> memref<1x60xui8> {
-  %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
-  %1 = scf.execute_region -> memref<1x60xui8> no_inline {
+  %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>  
+  %1 = scf.execute_region -> memref<1x60xui8> no_inline {    
     func.call @foo():()->()
     scf.yield %alloc: memref<1x60xui8>
-  }
+  }  
   return %1 : memref<1x60xui8>
 }
 }
@@ -1688,12 +1688,12 @@ func.func private @execute_region_yeilding_external_value() -> memref<1x60xui8>
 module {
 func.func private @foo()->()
 func.func private @execute_region_yeilding_external_and_local_values() -> (memref<1x60xui8>, memref<1x120xui8>) {
-  %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
-  %1, %2 = scf.execute_region -> (memref<1x60xui8>, memref<1x120xui8>) no_inline {
+  %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>  
+  %1, %2 = scf.execute_region -> (memref<1x60xui8>, memref<1x120xui8>) no_inline {    
     %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x120xui8>
     func.call @foo():()->()
     scf.yield %alloc, %alloc_1: memref<1x60xui8>,  memref<1x120xui8>
-  }
+  }  
   return %1, %2 : memref<1x60xui8>, memref<1x120xui8>
 }
 }
@@ -1716,18 +1716,18 @@ func.func private @execute_region_yeilding_external_and_local_values() -> (memre
 module {
   func.func private @foo()->()
   func.func private @execute_region_multiple_yields_same_operands() -> (memref<1x60xui8>, memref<1x120xui8>) {
-    %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
-    %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x120xui8>
+    %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>  
+    %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x120xui8>  
     %1, %2 = scf.execute_region -> (memref<1x60xui8>, memref<1x120xui8>) no_inline {
       %c = "test.cmp"() : () -> i1
       cf.cond_br %c, ^bb2, ^bb3
-    ^bb2:
+    ^bb2:    
       func.call @foo():()->()
       scf.yield %alloc, %alloc_1 : memref<1x60xui8>, memref<1x120xui8>
-    ^bb3:
-      func.call @foo():()->()
+    ^bb3: 
+      func.call @foo():()->()   
       scf.yield %alloc, %alloc_1 : memref<1x60xui8>, memref<1x120xui8>
-    }
+    }  
     return %1, %2 : memref<1x60xui8>, memref<1x120xui8>
   }
 }
@@ -1746,19 +1746,19 @@ module {
 module {
   func.func private @foo()->()
   func.func private @execute_region_multiple_yields_different_operands() -> (memref<1x60xui8>, memref<1x120xui8>) {
-    %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
-    %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x120xui8>
-    %alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<1x120xui8>
+    %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>  
+    %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x120xui8>  
+    %alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<1x120xui8>  
     %1, %2 = scf.execute_region -> (memref<1x60xui8>, memref<1x120xui8>) no_inline {
       %c = "test.cmp"() : () -> i1
       cf.cond_br %c, ^bb2, ^bb3
-    ^bb2:
+    ^bb2:    
       func.call @foo():()->()
       scf.yield %alloc, %alloc_1 : memref<1x60xui8>, memref<1x120xui8>
-    ^bb3:
-      func.call @foo():()->()
+    ^bb3: 
+      func.call @foo():()->()   
       scf.yield %alloc, %alloc_2 : memref<1x60xui8>, memref<1x120xui8>
-    }
+    }  
     return %1, %2 : memref<1x60xui8>, memref<1x120xui8>
   }
 }
@@ -1778,18 +1778,18 @@ module {
 module {
 func.func private @foo()->()
 func.func private @execute_region_multiple_yields_different_operands() -> (memref<1x60xui8>) {
-  %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
-  %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
+  %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>  
+  %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>   
   %1 = scf.execute_region -> (memref<1x60xui8>) no_inline {
     %c = "test.cmp"() : () -> i1
     cf.cond_br %c, ^bb2, ^bb3
-  ^bb2:
+  ^bb2:    
     func.call @foo():()->()
     scf.yield %alloc : memref<1x60xui8>
-  ^bb3:
+  ^bb3:    
     func.call @foo():()->()
     scf.yield %alloc_1 : memref<1x60xui8>
-  }
+  }    
   return %1 : memref<1x60xui8>
 }
 }
@@ -2171,70 +2171,3 @@ func.func @scf_for_all_step_size_0()  {
   }
   return
 }
-
-// -----
-
-func.func private @side_effect()
-
-// CHECK-LABEL: func @iter_args_cycles
-//  CHECK-SAME:   (%[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index, %[[A:.*]]: i32, %[[B:.*]]: i64, %[[C:.*]]: f32)
-//       CHECK:   scf.for %[[IV:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
-//       CHECK:   func.call @side_effect()
-//   CHECK-NOT:   yield
-//       CHECK:   return %[[A]], %[[B]], %[[A]], %[[B]], %[[B]], %[[C]] : i32, i64, i32, i64, i64, f32
-func.func @iter_args_cycles(%lb : index, %ub : index, %step : index, %a : i32, %b : i64, %c : f32) -> (i32, i64, i32, i64, i64, f32) {
-  %res:6 = scf.for %i = %lb to %ub step %step iter_args(%0 = %a, %1 = %b, %2 = %a, %3 = %b, %4 = %b, %5 = %c) -> (i32, i64, i32, i64, i64, f32) {
-    func.call @side_effect() : () -> ()
-    scf.yield %2, %4, %0, %1, %3, %5 : i32, i64, i32, i64, i64, f32
-  }
-  return %res#0, %res#1, %res#2, %res#3, %res#4, %res#5 : i32, i64, i32, i64, i64, f32
-}
-
-// -----
-
-func.func private @side_effect(i32)
-
-// CHECK-LABEL: func @iter_args_cycles_non_cycle_start
-//  CHECK-SAME:   (%[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index, %[[A:.*]]: i32, %[[B:.*]]: i32)
-//       CHECK:   %[[RES:.*]] = scf.for %[[IV:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args(%[[ITER_ARG:.*]] = %[[A]]) -> (i32) {
-//       CHECK:   func.call @side_effect(%[[ITER_ARG]])
-//       CHECK:   yield %[[B]] : i32
-//       CHECK:   return %[[RES]], %[[B]], %[[B]] : i32, i32, i32
-func.func @iter_args_cycles_non_cycle_start(%lb : index, %ub : index, %step : index, %a : i32, %b : i32) -> (i32, i32, i32) {
-  %res:3 = scf.for %i = %lb to %ub step %step iter_args(%0 = %a, %1 = %b, %2 = %b) -> (i32, i32, i32) {
-    func.call @side_effect(%0) : (i32) -> ()
-    scf.yield %1, %2, %1 : i32, i32, i32
-  }
-  return %res#0, %res#1, %res#2 : i32, i32, i32
-}
-
-// -----
-
-// CHECK-LABEL: func @dead_index_switch_result(
-//  CHECK-SAME:     %[[arg0:.*]]: index
-//   CHECK-DAG:   %[[c10:.*]] = arith.constant 10
-//   CHECK-DAG:   %[[c11:.*]] = arith.constant 11
-//       CHECK:   %[[switch:.*]] = scf.index_switch %[[arg0]] -> index
-//       CHECK:   case 1 {
-//       CHECK:     memref.store %[[c10]]
-//       CHECK:     scf.yield %[[arg0]] : index
-//       CHECK:   } 
-//       CHECK:   default {
-//       CHECK:     memref.store %[[c11]]
-//       CHECK:     scf.yield %[[arg0]] : index
-//       CHECK:   }
-//       CHECK:   return %[[switch]]
-func.func @dead_index_switch_result(%arg0 : index, %arg1 : memref<i32>) -> index {
-  %non_live, %live = scf.index_switch %arg0 -> i32, index
-  case 1 {
-    %c10 = arith.constant 10 : i32
-    memref.store %c10, %arg1[] : memref<i32>
-    scf.yield %c10, %arg0 : i32, index
-  }
-  default {
-    %c11 = arith.constant 11 : i32
-    memref.store %c11, %arg1[] : memref<i32>
-    scf.yield %c11, %arg0 : i32, index
-  }
-  return %live : index
-}

``````````

</details>


https://github.com/llvm/llvm-project/pull/173991


More information about the Mlir-commits mailing list