[Mlir-commits] [mlir] [MLIR] [Mem2Reg] Fix unused block argument removal logic (PR #188484)
Théo Degioanni
llvmlistbot at llvm.org
Wed Mar 25 06:34:55 PDT 2026
https://github.com/tdegioanni-nvidia created https://github.com/llvm/llvm-project/pull/188484
Mem2Reg removed unused block arguments incorrectly. When connecting successor operands, it assumed that the reaching definition was still available, but that definition may already have been removed during the promotion process. This patch instead places successor operands eagerly, and removes them together with the block argument if the argument turns out to be unused (similar to how this was done before region support was added).
This also fixes what appears to be a long-standing issue: a block argument used only by operations scheduled for deletion was not considered unused.
Fixes #188252
>From 73b4a95bb59ace9d15e0361d472b2c82db191fbf Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Th=C3=A9o=20Degioanni?= <tdegioanni at nvidia.com>
Date: Wed, 25 Mar 2026 13:30:06 +0000
Subject: [PATCH] fix unused block argument removal logic
---
mlir/lib/Transforms/Mem2Reg.cpp | 132 +++++++++++++++++++-------
mlir/test/Dialect/MemRef/mem2reg.mlir | 76 +++++++++++++++
2 files changed, 176 insertions(+), 32 deletions(-)
diff --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp
index 5bd0c70f1d33f..77f87dd20d185 100644
--- a/mlir/lib/Transforms/Mem2Reg.cpp
+++ b/mlir/lib/Transforms/Mem2Reg.cpp
@@ -264,9 +264,9 @@ class MemorySlotPromoter {
/// to a different region, the new region will be processed instead.
void removeBlockingUses(Region *region);
- /// Links merge point block arguments to the terminators targeting the merge
- /// point or remove the argument if it is not used.
- void linkMergePoints();
+ /// Removes operations and merge point block arguments that ended up not being
+ /// necessary.
+ void removeUnusedItems();
/// Lazily-constructed default value representing the content of the slot when
/// no store has been executed. This function may mutate IR.
@@ -656,6 +656,18 @@ void MemorySlotPromoter::promoteInRegion(Region *region, Value reachingDef) {
job.reachingDef = promoteInBlock(block, job.reachingDef);
+ if (auto terminator = dyn_cast<BranchOpInterface>(block->getTerminator())) {
+ for (BlockOperand &blockOperand : terminator->getBlockOperands()) {
+ if (info.mergePoints.contains(blockOperand.get())) {
+ if (!job.reachingDef)
+ job.reachingDef = getOrCreateDefaultValue();
+
+ terminator.getSuccessorOperands(blockOperand.getOperandNumber())
+ .append(job.reachingDef);
+ }
+ }
+ }
+
for (auto *child : job.block->children())
dfsStack.emplace_back<DfsJob>({child, job.reachingDef});
}
@@ -762,40 +774,90 @@ void MemorySlotPromoter::removeBlockingUses(Region *region) {
}
}
-void MemorySlotPromoter::linkMergePoints() {
- // We want to eliminate unused block arguments. In case connecting a block
- // argument to its predecessor would trigger the use of the predecessor's
- // unused block argument, we need to process merge points in an expanding
- // worklist, `mergePointArgsToProcess`.
+void MemorySlotPromoter::removeUnusedItems() {
+ // We want to eliminate unused block arguments. Because block arguments can be
+ // used to populate other block arguments, there might be cycles of arguments
+ // that are only used to populate each-other. We therefore need a small
+ // dataflow analysis to identify which block arguments are truly used.
SmallPtrSet<BlockArgument, 8> mergePointArgsUnused;
- SmallVector<BlockArgument> mergePointArgsToProcess;
+ SmallVector<BlockArgument> usedMergePointArgsToProcess;
+
+ // First, separate the block arguments that are not used or only used for the
+ // purpose of populating a merge point block argument from the others. These
+ // block arguments are potentially unused. Meanwhile, arguments that are
+ // definitely used will be the starting point of the propagation of the
+ // analysis.
+ auto isDefinitelyUsed = [&](BlockArgument arg) {
+ for (auto &use : arg.getUses()) {
+ if (llvm::is_contained(toErase, use.getOwner()))
+ continue;
+
+ // We now want to detect whether the use is to populate a merge point
+ // block argument. If it is not, the argument is definitely used.
+
+ auto branchOp = dyn_cast<BranchOpInterface>(use.getOwner());
+ if (!branchOp)
+ return true;
+
+ auto successorArgument =
+ branchOp.getSuccessorBlockArgument(use.getOperandNumber());
+ if (!successorArgument.has_value())
+ return true;
+
+ if (!info.mergePoints.contains(successorArgument->getOwner()))
+ return true;
+
+ bool isLastBlockArgument =
+ successorArgument->getArgNumber() ==
+ successorArgument->getOwner()->getNumArguments() - 1;
+ if (!isLastBlockArgument)
+ return true;
+ }
+ return false;
+ };
+
for (Block *mergePoint : info.mergePoints) {
BlockArgument arg = mergePoint->getArguments().back();
- if (arg.use_empty())
- mergePointArgsUnused.insert(arg);
+ if (isDefinitelyUsed(arg))
+ usedMergePointArgsToProcess.push_back(arg);
else
- mergePointArgsToProcess.push_back(arg);
+ mergePointArgsUnused.insert(arg);
}
- while (!mergePointArgsToProcess.empty()) {
- BlockArgument arg = mergePointArgsToProcess.pop_back_val();
+ // We now refine mergePointArgsUnused from the information of which block
+ // arguments are definitely used.
+ while (!usedMergePointArgsToProcess.empty()) {
+ BlockArgument arg = usedMergePointArgsToProcess.pop_back_val();
Block *mergePoint = arg.getOwner();
- for (BlockOperand &use : mergePoint->getUses()) {
- Value reachingDef = reachingAtBlockEnd[use.getOwner()->getBlock()];
- if (!reachingDef)
- reachingDef = getOrCreateDefaultValue();
+ assert(arg.getArgNumber() == mergePoint->getNumArguments() - 1 &&
+ "merge point argument must be the last argument of the merge point");
- // If the reaching definition is a block argument of an unused merge
- // point, mark it as used and process it as such later.
- auto reachingDefArgument = dyn_cast<BlockArgument>(reachingDef);
- if (reachingDefArgument &&
- mergePointArgsUnused.erase(reachingDefArgument))
- mergePointArgsToProcess.push_back(reachingDefArgument);
-
- BranchOpInterface user = cast<BranchOpInterface>(use.getOwner());
- user.getSuccessorOperands(use.getOperandNumber()).append(reachingDef);
+ for (BlockOperand &use : mergePoint->getUses()) {
+ // If a value used to populate this used merge point argument is another
+ // merge point block argument that is currently considered unused, it must
+ // now be considered used and processed as such later.
+
+ auto branch = cast<BranchOpInterface>(use.getOwner());
+ SuccessorOperands succOperands =
+ branch.getSuccessorOperands(use.getOperandNumber());
+
+ // The successor operand is either the last one or is not present if the
+ // user block is dead.
+ assert(succOperands.size() == mergePoint->getNumArguments() ||
+ succOperands.size() + 1 == mergePoint->getNumArguments());
+
+ // If the user block is dead, the default value acts as a placeholder
+ // dummy value.
+ if (succOperands.size() + 1 == mergePoint->getNumArguments())
+ succOperands.append(getOrCreateDefaultValue());
+
+ Value populatedValue = succOperands[arg.getArgNumber()];
+ auto populatedValueAsArg = dyn_cast<BlockArgument>(populatedValue);
+ if (populatedValueAsArg &&
+ mergePointArgsUnused.erase(populatedValueAsArg))
+ usedMergePointArgsToProcess.push_back(populatedValueAsArg);
}
builder.setInsertionPointToStart(mergePoint);
@@ -804,8 +866,17 @@ void MemorySlotPromoter::linkMergePoints() {
(*statistics.newBlockArgumentAmount)++;
}
+ for (Operation *toEraseOp : toErase)
+ toEraseOp->erase();
+
for (BlockArgument arg : mergePointArgsUnused) {
Block *mergePoint = arg.getOwner();
+ for (BlockOperand &use : mergePoint->getUses()) {
+ auto branch = cast<BranchOpInterface>(use.getOwner());
+ SuccessorOperands succOperands =
+ branch.getSuccessorOperands(use.getOperandNumber());
+ succOperands.erase(arg.getArgNumber());
+ }
mergePoint->eraseArgument(mergePoint->getNumArguments() - 1);
}
}
@@ -832,11 +903,8 @@ MemorySlotPromoter::promoteSlot() {
op.visitReplacedValues(replacedValues, builder);
}
- // Finally, connect merge points to their predecessor's reaching definitions.
- linkMergePoints();
-
- for (Operation *toEraseOp : toErase)
- toEraseOp->erase();
+ // Finally, remove unused operations and merge point block arguments.
+ removeUnusedItems();
assert(slot.ptr.use_empty() &&
"after promotion, the slot pointer should not be used anymore");
diff --git a/mlir/test/Dialect/MemRef/mem2reg.mlir b/mlir/test/Dialect/MemRef/mem2reg.mlir
index 30a268521c69b..20c49267ce7fc 100644
--- a/mlir/test/Dialect/MemRef/mem2reg.mlir
+++ b/mlir/test/Dialect/MemRef/mem2reg.mlir
@@ -171,8 +171,84 @@ func.func @unused_alloca_store_loop() {
// CHECK-NOT: memref.alloca
%cst = arith.constant 1 : i32
%alloca = memref.alloca() : memref<i32>
+ // CHECK: cf.br ^[[BB1:.*]]
cf.br ^bb1
+
+// CHECK: ^[[BB1]]:
^bb1:
+ // CHECK-NOT: memref.store
memref.store %cst, %alloca[] : memref<i32>
+ // CHECK: cf.br ^[[BB1]]
cf.br ^bb1
}
+
+// -----
+
+// CHECK-LABEL: func.func @store_back_to_alloca
+// CHECK-SAME: (%[[COND:.*]]: i1)
+func.func @store_back_to_alloca(%cond: i1) -> i32 {
+ // CHECK-NOT: memref.alloca
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : i32
+ %c0 = arith.constant 0 : i32
+ // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : i32
+ %c1 = arith.constant 1 : i32
+ %alloca = memref.alloca() : memref<i32>
+ memref.store %c0, %alloca[] : memref<i32>
+ %loaded = memref.load %alloca[] : memref<i32>
+ // CHECK: cf.cond_br %[[COND]], ^[[STORE_BACK:.*]], ^[[SKIP:.*]]
+ cf.cond_br %cond, ^store_back, ^skip
+
+// CHECK: ^[[STORE_BACK]]:
+^store_back:
+ memref.store %loaded, %alloca[] : memref<i32>
+ // CHECK: cf.br ^[[MERGE:.*]](%[[C0]] : i32)
+ cf.br ^merge
+
+// CHECK: ^[[SKIP]]:
+^skip:
+ memref.store %c1, %alloca[] : memref<i32>
+ // CHECK: cf.br ^[[MERGE]](%[[C1]] : i32)
+ cf.br ^merge
+
+// CHECK: ^[[MERGE]](%[[RESULT:.*]]: i32):
+^merge:
+ %result = memref.load %alloca[] : memref<i32>
+ // CHECK: return %[[RESULT]] : i32
+ return %result : i32
+}
+
+// -----
+
+// Ensure that a merge point used by an erased operation is not considered used.
+
+// CHECK-LABEL: func.func @merge_point_used_by_erased_op
+// CHECK-SAME: (%[[COND:.*]]: i1)
+func.func @merge_point_used_by_erased_op(%cond: i1) -> i32 {
+ // CHECK-NOT: memref.alloca
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : i32
+ %c0 = arith.constant 0 : i32
+ // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : i32
+ %c1 = arith.constant 1 : i32
+ %alloca = memref.alloca() : memref<i32>
+ // CHECK: cf.cond_br %[[COND]], ^[[MERGE:.*]], ^[[SKIP:.*]]
+ cf.cond_br %cond, ^merge, ^skip
+
+// CHECK: ^[[MERGE]]:
+^merge:
+ memref.store %c0, %alloca[] : memref<i32>
+ // CHECK: cf.br ^[[FINAL:.*]]{{$}}
+ cf.br ^final
+
+// CHECK: ^[[SKIP]]:
+^skip:
+ memref.store %c1, %alloca[] : memref<i32>
+ // CHECK: cf.br ^[[FINAL]]{{$}}
+ cf.br ^final
+
+// CHECK: ^[[FINAL]]:
+^final:
+ %result = memref.load %alloca[] : memref<i32>
+ memref.store %result, %alloca[] : memref<i32>
+ // CHECK: return %[[C0]] : i32
+ return %c0 : i32
+}
More information about the Mlir-commits
mailing list