[Mlir-commits] [mlir] [SCF] Added canonicalizer for recursively dead uses of iter_args. (PR #191085)
Slava Zakharin
llvmlistbot at llvm.org
Wed Apr 8 17:06:37 PDT 2026
https://github.com/vzakhari created https://github.com/llvm/llvm-project/pull/191085
This pattern may appear after Mem2Reg, which conservatively
returns live values of the memory slots from loops.
If those values are not used, we can get rid of the loops'
results and corresponding iter_args.
Co-authored-by: Claude Opus 4.6
From 798ad11ee126761ba75c383504d023fb46ae1540 Mon Sep 17 00:00:00 2001
From: Slava Zakharin <szakharin at nvidia.com>
Date: Wed, 8 Apr 2026 16:56:50 -0700
Subject: [PATCH] [SCF] Added canonicalizer for recursively dead uses of
iter_args.
This pattern may appear after Mem2Reg, which conservatively
returns live values of the memory slots from loops.
If those values are not used, we can get rid of the loops'
results and corresponding iter_args.
Co-authored-by: Claude Opus 4.6
---
mlir/lib/Dialect/SCF/IR/SCF.cpp | 144 ++++++++++++++++-
mlir/test/Dialect/SCF/canonicalize.mlir | 148 ++++++++++++++++++
.../mem2reg-with-canonicalization.mlir | 62 ++++++++
3 files changed, 353 insertions(+), 1 deletion(-)
create mode 100644 mlir/test/Transforms/mem2reg-with-canonicalization.mlir
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 9f4f4dc9f58e6..db50b2c1d3314 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1002,11 +1002,153 @@ struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
return failure();
}
};
+
+/// Remove iter_arg/result pairs from scf.for when the result is unused and
+/// the corresponding iter_arg block argument is "effectively unused" --
+/// meaning it has no uses, or its only uses are as init operands for nested
+/// scf.for iter_args whose block arguments are also effectively unused.
+///
+/// This handles cases that the generic RemoveDeadRegionBranchOpSuccessorInputs
+/// pattern cannot, specifically when inner loop results are used by outer
+/// loop yields, creating cross-loop use chains that appear live but are
+/// semantically dead.
+///
+/// A yield position is only treated as dead when its result is unused *and*
+/// its iter_arg is effectively unused. An unused result alone is not enough:
+/// the yielded value is still visible through the iter_arg on the next
+/// iteration, and a nested loop that runs zero times forwards its init
+/// operand to its result. The dead positions are therefore computed as a
+/// fixed point (see pruneLivePositions).
+///
+/// Example:
+///   %r = scf.for %i = %lb to %ub step %s iter_args(%a = %init) -> (f32) {
+///     %inner = scf.for %j = %lb to %ub step %s
+///         iter_args(%b = %a) -> (f32) {
+///       // %b is unused in the body
+///       scf.yield %val : f32
+///     }
+///     scf.yield %inner : f32
+///   }
+///   // %r is unused
+///
+/// After canonicalization:
+///   scf.for %i = %lb to %ub step %s {
+///     scf.for %j = %lb to %ub step %s {
+///       // body without iter_args
+///     }
+///   }
+struct ForOpUnusedIterArgElimination : public OpRewritePattern<ForOp> {
+  using OpRewritePattern<ForOp>::OpRewritePattern;
+
+  /// Shrink `dead` to the largest self-consistent subset of iter_arg
+  /// positions of `loop`: every position left in `dead` has an iter_arg that
+  /// is effectively unused when only the positions left in `dead` are treated
+  /// as dead yield positions. Dropping one position can invalidate another
+  /// one that relied on it, so iterate until nothing changes.
+  static void pruneLivePositions(ForOp loop, BitVector &dead) {
+    bool changed = true;
+    while (changed) {
+      changed = false;
+      // Index loop on purpose: `dead` is modified while it is being scanned.
+      for (unsigned i = 0, e = dead.size(); i < e; ++i) {
+        if (!dead.test(i))
+          continue;
+        if (isBlockArgEffectivelyUnused(loop.getRegionIterArg(i), loop, dead))
+          continue;
+        dead.reset(i);
+        changed = true;
+      }
+    }
+  }
+
+  /// Check if a block argument is effectively unused. A block argument is
+  /// effectively unused if it has no uses, or all its uses are init operands
+  /// for nested scf.for iter_args where: (a) the inner for's result at that
+  /// position is only used as yield operands of `parentFor` at positions set
+  /// in `parentCandidates`, and (b) the inner position survives the
+  /// fixed-point pruning of the inner loop's own dead positions, i.e. the
+  /// inner block arg is effectively unused as well.
+  static bool isBlockArgEffectivelyUnused(Value blockArg, ForOp parentFor,
+                                          const BitVector &parentCandidates) {
+    if (blockArg.use_empty())
+      return true;
+
+    for (OpOperand &use : blockArg.getUses()) {
+      auto innerFor = dyn_cast<ForOp>(use.getOwner());
+      if (!innerFor)
+        return false;
+
+      // A use as a control operand (not an iter_arg init) is a real use.
+      unsigned opNum = use.getOperandNumber();
+      if (opNum < innerFor.getNumControlOperands())
+        return false;
+
+      unsigned innerIdx = opNum - innerFor.getNumControlOperands();
+      Value innerResult = innerFor.getResult(innerIdx);
+
+      // Inner result must only be used at yield positions of parentFor
+      // that are considered dead.
+      for (OpOperand &resUse : innerResult.getUses()) {
+        auto yieldOp = dyn_cast<YieldOp>(resUse.getOwner());
+        if (!yieldOp || yieldOp->getParentOp() != parentFor.getOperation())
+          return false;
+        if (!parentCandidates.test(resUse.getOperandNumber()))
+          return false;
+      }
+
+      // Seed the inner loop's dead positions: innerIdx (its result only
+      // reaches dead parent yield positions) plus every position whose
+      // result is unused. Then drop the positions whose iter_arg is still
+      // observable inside the inner loop.
+      BitVector innerDead(innerFor.getNumResults(), false);
+      innerDead.set(innerIdx);
+      for (unsigned j = 0; j < innerFor.getNumResults(); ++j) {
+        if (innerFor.getResult(j).use_empty())
+          innerDead.set(j);
+      }
+      pruneLivePositions(innerFor, innerDead);
+      if (!innerDead.test(innerIdx))
+        return false;
+    }
+    return true;
+  }
+
+  LogicalResult matchAndRewrite(ForOp forOp,
+                                PatternRewriter &rewriter) const override {
+    unsigned numResults = forOp.getNumResults();
+    if (numResults == 0)
+      return failure();
+
+    // Step 1: Seed the dead set with every position whose result is unused.
+    BitVector dead(numResults, false);
+    for (unsigned i = 0; i < numResults; ++i) {
+      if (forOp.getResult(i).use_empty())
+        dead.set(i);
+    }
+    if (dead.none())
+      return failure();
+
+    // Step 2: Keep only positions whose iter_arg is effectively unused. A
+    // position with an unused result but a really-used iter_arg is live and
+    // must not justify the removal of another position.
+    pruneLivePositions(forOp, dead);
+
+    // Only rewrite positions whose block arg still has uses; the generic
+    // patterns handle the trivially unused ones.
+    BitVector toRemove(numResults, false);
+    for (unsigned i : dead.set_bits()) {
+      if (!forOp.getRegionIterArg(i).use_empty())
+        toRemove.set(i);
+    }
+    if (toRemove.none())
+      return failure();
+
+    // Step 3: Replace block arg uses with init values. This is safe because
+    // the block arg is effectively unused and the init value dominates the
+    // body.
+    for (unsigned i : toRemove.set_bits()) {
+      Value blockArg = forOp.getRegionIterArg(i);
+      Value initVal = forOp.getInitArgs()[i];
+      rewriter.replaceAllUsesWith(blockArg, initVal);
+    }
+
+    // Step 4: Erase yield operands for removed positions.
+    auto yieldOp = cast<YieldOp>(forOp.getBody()->getTerminator());
+    rewriter.modifyOpInPlace(yieldOp,
+                             [&]() { yieldOp->eraseOperands(toRemove); });
+
+    // Step 5: Erase block arguments for removed positions. Block arguments
+    // are offset by the induction variables.
+    BitVector blockArgsToErase(forOp.getBody()->getNumArguments(), false);
+    for (unsigned i : toRemove.set_bits())
+      blockArgsToErase.set(i + forOp.getNumInductionVars());
+    rewriter.modifyOpInPlace(
+        forOp, [&]() { forOp.getBody()->eraseArguments(blockArgsToErase); });
+
+    // Step 6: Erase init operands for removed positions. Init operands are
+    // offset by the control operands.
+    BitVector initOperandsToErase(forOp->getNumOperands(), false);
+    for (unsigned i : toRemove.set_bits())
+      initOperandsToErase.set(i + forOp.getNumControlOperands());
+    rewriter.modifyOpInPlace(
+        forOp, [&]() { forOp->eraseOperands(initOperandsToErase); });
+
+    // Step 7: Erase results for removed positions.
+    rewriter.eraseOpResults(forOp, toRemove);
+
+    return success();
+  }
+};
} // namespace
void ForOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<ForOpTensorCastFolder>(context);
+ results.add<ForOpTensorCastFolder, ForOpUnusedIterArgElimination>(context);
populateRegionBranchOpInterfaceCanonicalizationPatterns(
results, ForOp::getOperationName());
populateRegionBranchOpInterfaceInliningPattern(
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index c324d34942bf8..145a3417eec04 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -2360,3 +2360,151 @@ func.func @fold_tensor_cast_into_forall_non_sequential_writes(
// %0#0 contains %arg1 data; %0#1 contains %arg0 data.
return %0#0, %0#1 : tensor<?x32xf32>, tensor<?x32xf32>
}
+
+// -----
+
+// Test: nested for loops with effectively-dead iter_args.
+// The outer result is unused, the outer block arg feeds the inner iter_arg
+// whose block arg is also unused. Both iter_args should be removed.
+// The inner result is only consumed by the outer scf.yield, so nothing outside
+// the iter_arg chain observes it. The CHECK lines require both loops to be
+// printed without iter_args and the memref.store to be preserved.
+
+// CHECK-LABEL: func @nested_for_effectively_dead_iter_args
+// CHECK-SAME: (%[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index, %[[MEM:.*]]: memref<f32>)
+// CHECK: %[[CST:.*]] = arith.constant
+// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
+// CHECK: scf.for %[[J:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
+// CHECK: memref.store %[[CST]], %[[MEM]][]
+// CHECK: }
+// CHECK: }
+// CHECK: return
+func.func @nested_for_effectively_dead_iter_args(
+ %lb: index, %ub: index, %step: index, %mem: memref<f32>) {
+ %cst = arith.constant 1.0 : f32
+ %init = arith.constant 0.0 : f32
+ %outer = scf.for %i = %lb to %ub step %step iter_args(%a = %init) -> (f32) {
+ %inner = scf.for %j = %lb to %ub step %step iter_args(%b = %a) -> (f32) {
+ memref.store %cst, %mem[] : memref<f32>
+ scf.yield %cst : f32
+ }
+ scf.yield %inner : f32
+ }
+ return
+}
+
+// -----
+
+// Test: three levels of nesting with effectively-dead iter_args.
+// Each loop forwards its iter_arg to the loop nested inside it and yields
+// that loop's result; only the innermost body has a side effect.
+//
+// The step operand is matched as a single token (`[^ ]+`) so that a loop line
+// that still carries `iter_args(...) -> (...)` before the opening brace does
+// not match.
+
+// CHECK-LABEL: func @triple_nested_effectively_dead_iter_args
+// CHECK:       scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{[^ ]+}} {
+// CHECK:         scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{[^ ]+}} {
+// CHECK:           scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{[^ ]+}} {
+// CHECK:             memref.store
+// CHECK:           }
+// CHECK:         }
+// CHECK:       }
+// CHECK:       return
+func.func @triple_nested_effectively_dead_iter_args(
+    %lb: index, %ub: index, %step: index, %mem: memref<f32>) {
+  %cst = arith.constant 1.0 : f32
+  %init = arith.constant 0.0 : f32
+  %outer = scf.for %i = %lb to %ub step %step iter_args(%a = %init) -> (f32) {
+    %mid = scf.for %j = %lb to %ub step %step iter_args(%b = %a) -> (f32) {
+      %inner = scf.for %k = %lb to %ub step %step iter_args(%c = %b) -> (f32) {
+        memref.store %cst, %mem[] : memref<f32>
+        scf.yield %cst : f32
+      }
+      scf.yield %inner : f32
+    }
+    scf.yield %mid : f32
+  }
+  return
+}
+
+// -----
+
+// Negative test: block arg is used in the loop body.
+// The iter_arg must NOT be removed even though the result is unused.
+// %a is stored to memory and also feeds the yielded value through arith.addf,
+// so it has real uses that are not scf.for init operands.
+
+// CHECK-LABEL: func @iter_arg_used_in_body
+// CHECK: scf.for {{.*}} iter_args
+// CHECK: memref.store
+// CHECK: scf.yield
+// CHECK: }
+func.func @iter_arg_used_in_body(
+ %lb: index, %ub: index, %step: index, %mem: memref<f32>) {
+ %init = arith.constant 0.0 : f32
+ %r = scf.for %i = %lb to %ub step %step iter_args(%a = %init) -> (f32) {
+ memref.store %a, %mem[] : memref<f32>
+ %next = arith.addf %a, %a : f32
+ scf.yield %next : f32
+ }
+ return
+}
+
+// -----
+
+// Test: two chained for loops where the second loop's result is unused and
+// its iter_arg chain through an inner loop is effectively dead.
+// The first loop's result is only used as the init operand of the second
+// loop, so it has no uses left once the second loop's iter_arg is removed.
+// The CHECK lines expect both loop nests to be printed without iter_args.
+
+// CHECK-LABEL: func @chained_for_effectively_dead
+// CHECK-SAME: (%[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index, %[[MEM:.*]]: memref<f32>)
+// CHECK: %[[CST:.*]] = arith.constant
+// CHECK: scf.for %{{.*}} = %[[LB]] to %{{.*}} step %[[STEP]] {
+// CHECK: scf.for %{{.*}} = %[[LB]] to %{{.*}} step %[[STEP]] {
+// CHECK: memref.store
+// CHECK: }
+// CHECK: }
+// CHECK: scf.for %{{.*}} = %[[LB]] to %{{.*}} step %[[STEP]] {
+// CHECK: scf.for %{{.*}} = %[[LB]] to %{{.*}} step %[[STEP]] {
+// CHECK: memref.store
+// CHECK: }
+// CHECK: }
+// CHECK: return
+func.func @chained_for_effectively_dead(
+ %lb: index, %ub: index, %step: index, %mem: memref<f32>) {
+ %cst = arith.constant 1.0 : f32
+ %init = arith.constant 0.0 : f32
+ %first = scf.for %i = %lb to %ub step %step iter_args(%a = %init) -> (f32) {
+ %inner1 = scf.for %j = %lb to %ub step %step iter_args(%b = %a) -> (f32) {
+ memref.store %cst, %mem[] : memref<f32>
+ scf.yield %cst : f32
+ }
+ scf.yield %inner1 : f32
+ }
+ %second = scf.for %i = %lb to %ub step %step iter_args(%a = %first) -> (f32) {
+ %inner2 = scf.for %j = %lb to %ub step %step iter_args(%b = %a) -> (f32) {
+ memref.store %cst, %mem[] : memref<f32>
+ scf.yield %cst : f32
+ }
+ scf.yield %inner2 : f32
+ }
+ return
+}
+
+// -----
+
+// Test: 2-level loop nest with two iter_args, both effectively unused.
+// Each outer block arg feeds the inner init operand at the same position, and
+// each inner result is yielded by the outer loop at that same position.
+// Neither inner block arg (%x, %y) is used in the inner body.
+
+// CHECK-LABEL: func @nested_for_two_iter_args
+// CHECK-SAME: (%[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index, %[[MEM:.*]]: memref<f32>)
+// CHECK: scf.for %{{.*}} = %[[LB]] to %[[UB]] step %[[STEP]] {
+// CHECK: scf.for %{{.*}} = %[[LB]] to %[[UB]] step %[[STEP]] {
+// CHECK: memref.store
+// CHECK: }
+// CHECK: }
+// CHECK: return
+func.func @nested_for_two_iter_args(
+ %lb: index, %ub: index, %step: index, %mem: memref<f32>) {
+ %c0 = arith.constant 0.0 : f32
+ %c1 = arith.constant 1.0 : f32
+ %r:2 = scf.for %i = %lb to %ub step %step
+ iter_args(%a = %c0, %b = %c1) -> (f32, f32) {
+ %inner:2 = scf.for %j = %lb to %ub step %step
+ iter_args(%x = %a, %y = %b) -> (f32, f32) {
+ memref.store %c1, %mem[] : memref<f32>
+ scf.yield %c1, %c0 : f32, f32
+ }
+ scf.yield %inner#0, %inner#1 : f32, f32
+ }
+ return
+}
diff --git a/mlir/test/Transforms/mem2reg-with-canonicalization.mlir b/mlir/test/Transforms/mem2reg-with-canonicalization.mlir
new file mode 100644
index 0000000000000..44e315cf620d3
--- /dev/null
+++ b/mlir/test/Transforms/mem2reg-with-canonicalization.mlir
@@ -0,0 +1,62 @@
+// RUN: mlir-opt %s --pass-pipeline='builtin.module(any(mem2reg))' | FileCheck %s --check-prefix=MEM2REG
+// RUN: mlir-opt %s --pass-pipeline='builtin.module(any(mem2reg,canonicalize))' | FileCheck %s --check-prefix=CANON
+
+// Two loop nests share the same alloca, causing the first loop's result
+// to chain into the second loop's init -- demonstrating the cross-loop
+// use chain that the generic canonicalization patterns cannot handle.
+//
+// MEM2REG pins the IR produced by mem2reg alone (loops carry iter_args).
+// CANON pins the IR after canonicalization (no iter_args left).
+
+// MEM2REG-LABEL: func.func @redundant_iter_args
+// MEM2REG:       %[[POISON:.*]] = ub.poison : f32
+// MEM2REG:       %[[CST:.*]] = arith.constant 1.000000e+00 : f32
+// MEM2REG:       %[[R1:.*]] = scf.for {{.*}} iter_args(%{{.*}} = %[[POISON]]) -> (f32) {
+// MEM2REG:         %[[R1I:.*]] = scf.for {{.*}} iter_args(%{{.*}} = %{{.*}}) -> (f32) {
+// MEM2REG:           memref.store %[[CST]],
+// MEM2REG:           scf.yield %[[CST]] : f32
+// MEM2REG:         }
+// MEM2REG:         scf.yield %[[R1I]] : f32
+// MEM2REG:       }
+// MEM2REG:       scf.for {{.*}} iter_args(%{{.*}} = %[[R1]]) -> (f32) {
+// MEM2REG:         scf.for {{.*}} iter_args(%{{.*}} = %{{.*}}) -> (f32) {
+// MEM2REG:           memref.store %[[CST]],
+// MEM2REG:           scf.yield %[[CST]] : f32
+// MEM2REG:         }
+// MEM2REG:       }
+
+// The step operand is matched as a single token (`[^ ]+`): a greedy `.*`
+// would also match a loop line that still ends in `iter_args(...) -> (...) {`,
+// and the CANON-NOT lines below only scan the text between two matches, not
+// the loop lines themselves.
+
+// CANON-LABEL: func.func @redundant_iter_args
+// CANON-SAME:  (%[[N:.*]]: index, %[[MEM:.*]]: memref<f32>)
+// CANON:       %[[CST:.*]] = arith.constant 1.000000e+00 : f32
+// CANON:       scf.for %{{.*}} = %{{.*}} to %[[N]] step %{{[^ ]+}} {
+// CANON:         scf.for %{{.*}} = %{{.*}} to %[[N]] step %{{[^ ]+}} {
+// CANON-NOT:       iter_args
+// CANON:           memref.store %[[CST]], %[[MEM]][] : memref<f32>
+// CANON:         }
+// CANON:       }
+// CANON:       scf.for %{{.*}} = %{{.*}} to %[[N]] step %{{[^ ]+}} {
+// CANON:         scf.for %{{.*}} = %{{.*}} to %[[N]] step %{{[^ ]+}} {
+// CANON-NOT:       iter_args
+// CANON:           memref.store %[[CST]], %[[MEM]][] : memref<f32>
+// CANON:         }
+// CANON:       }
+// CANON:       return
+
+func.func @redundant_iter_args(%n: index, %mem: memref<f32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %cst = arith.constant 1.0 : f32
+  %tmp = memref.alloca() : memref<f32>
+  scf.for %i = %c0 to %n step %c1 {
+    scf.for %j = %c0 to %n step %c1 {
+      memref.store %cst, %tmp[] : memref<f32>
+      %v = memref.load %tmp[] : memref<f32>
+      memref.store %v, %mem[] : memref<f32>
+    }
+  }
+  scf.for %i = %c0 to %n step %c1 {
+    scf.for %j = %c0 to %n step %c1 {
+      memref.store %cst, %tmp[] : memref<f32>
+      %v = memref.load %tmp[] : memref<f32>
+      memref.store %v, %mem[] : memref<f32>
+    }
+  }
+  return
+}
More information about the Mlir-commits
mailing list