[Mlir-commits] [mlir] [MLIR] [Mem2Reg] Fix unused block argument removal logic (PR #188484)
Théo Degioanni
llvmlistbot at llvm.org
Wed Mar 25 07:05:50 PDT 2026
https://github.com/tdegioanni-nvidia updated https://github.com/llvm/llvm-project/pull/188484
>From 73b4a95bb59ace9d15e0361d472b2c82db191fbf Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Th=C3=A9o=20Degioanni?= <tdegioanni at nvidia.com>
Date: Wed, 25 Mar 2026 13:30:06 +0000
Subject: [PATCH 1/5] fix unused block argument removal logic
---
mlir/lib/Transforms/Mem2Reg.cpp | 132 +++++++++++++++++++-------
mlir/test/Dialect/MemRef/mem2reg.mlir | 76 +++++++++++++++
2 files changed, 176 insertions(+), 32 deletions(-)
diff --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp
index 5bd0c70f1d33f..77f87dd20d185 100644
--- a/mlir/lib/Transforms/Mem2Reg.cpp
+++ b/mlir/lib/Transforms/Mem2Reg.cpp
@@ -264,9 +264,9 @@ class MemorySlotPromoter {
/// to a different region, the new region will be processed instead.
void removeBlockingUses(Region *region);
- /// Links merge point block arguments to the terminators targeting the merge
- /// point or remove the argument if it is not used.
- void linkMergePoints();
+ /// Removes operations and merge point block arguments that ended up not being
+ /// necessary.
+ void removeUnusedItems();
/// Lazily-constructed default value representing the content of the slot when
/// no store has been executed. This function may mutate IR.
@@ -656,6 +656,18 @@ void MemorySlotPromoter::promoteInRegion(Region *region, Value reachingDef) {
job.reachingDef = promoteInBlock(block, job.reachingDef);
+ if (auto terminator = dyn_cast<BranchOpInterface>(block->getTerminator())) {
+ for (BlockOperand &blockOperand : terminator->getBlockOperands()) {
+ if (info.mergePoints.contains(blockOperand.get())) {
+ if (!job.reachingDef)
+ job.reachingDef = getOrCreateDefaultValue();
+
+ terminator.getSuccessorOperands(blockOperand.getOperandNumber())
+ .append(job.reachingDef);
+ }
+ }
+ }
+
for (auto *child : job.block->children())
dfsStack.emplace_back<DfsJob>({child, job.reachingDef});
}
@@ -762,40 +774,90 @@ void MemorySlotPromoter::removeBlockingUses(Region *region) {
}
}
-void MemorySlotPromoter::linkMergePoints() {
- // We want to eliminate unused block arguments. In case connecting a block
- // argument to its predecessor would trigger the use of the predecessor's
- // unused block argument, we need to process merge points in an expanding
- // worklist, `mergePointArgsToProcess`.
+void MemorySlotPromoter::removeUnusedItems() {
+ // We want to eliminate unused block arguments. Because block arguments can be
+ // used to populate other block arguments, there might be cycles of arguments
+ // that are only used to populate each other. We therefore need a small
+ // dataflow analysis to identify which block arguments are truly used.
SmallPtrSet<BlockArgument, 8> mergePointArgsUnused;
- SmallVector<BlockArgument> mergePointArgsToProcess;
+ SmallVector<BlockArgument> usedMergePointArgsToProcess;
+
+ // First, separate the block arguments that are not used or only used for the
+ // purpose of populating a merge point block argument from the others. These
+ // block arguments are potentially unused. Meanwhile, arguments that are
+ // definitely used will be the starting point of the propagation of the
+ // analysis.
+ auto isDefinitelyUsed = [&](BlockArgument arg) {
+ for (auto &use : arg.getUses()) {
+ if (llvm::is_contained(toErase, use.getOwner()))
+ continue;
+
+ // We now want to detect whether the use is to populate a merge point
+ // block argument. If it is not, the argument is definitely used.
+
+ auto branchOp = dyn_cast<BranchOpInterface>(use.getOwner());
+ if (!branchOp)
+ return true;
+
+ auto successorArgument =
+ branchOp.getSuccessorBlockArgument(use.getOperandNumber());
+ if (!successorArgument.has_value())
+ return true;
+
+ if (!info.mergePoints.contains(successorArgument->getOwner()))
+ return true;
+
+ bool isLastBlockArgument =
+ successorArgument->getArgNumber() ==
+ successorArgument->getOwner()->getNumArguments() - 1;
+ if (!isLastBlockArgument)
+ return true;
+ }
+ return false;
+ };
+
for (Block *mergePoint : info.mergePoints) {
BlockArgument arg = mergePoint->getArguments().back();
- if (arg.use_empty())
- mergePointArgsUnused.insert(arg);
+ if (isDefinitelyUsed(arg))
+ usedMergePointArgsToProcess.push_back(arg);
else
- mergePointArgsToProcess.push_back(arg);
+ mergePointArgsUnused.insert(arg);
}
- while (!mergePointArgsToProcess.empty()) {
- BlockArgument arg = mergePointArgsToProcess.pop_back_val();
+ // We now refine mergePointArgsUnused from the information of which block
+ // arguments are definitely used.
+ while (!usedMergePointArgsToProcess.empty()) {
+ BlockArgument arg = usedMergePointArgsToProcess.pop_back_val();
Block *mergePoint = arg.getOwner();
- for (BlockOperand &use : mergePoint->getUses()) {
- Value reachingDef = reachingAtBlockEnd[use.getOwner()->getBlock()];
- if (!reachingDef)
- reachingDef = getOrCreateDefaultValue();
+ assert(arg.getArgNumber() == mergePoint->getNumArguments() - 1 &&
+ "merge point argument must be the last argument of the merge point");
- // If the reaching definition is a block argument of an unused merge
- // point, mark it as used and process it as such later.
- auto reachingDefArgument = dyn_cast<BlockArgument>(reachingDef);
- if (reachingDefArgument &&
- mergePointArgsUnused.erase(reachingDefArgument))
- mergePointArgsToProcess.push_back(reachingDefArgument);
-
- BranchOpInterface user = cast<BranchOpInterface>(use.getOwner());
- user.getSuccessorOperands(use.getOperandNumber()).append(reachingDef);
+ for (BlockOperand &use : mergePoint->getUses()) {
+ // If a value used to populate this used merge point argument is another
+ // merge point block argument that is currently considered unused, it must
+ // now be considered used and processed as such later.
+
+ auto branch = cast<BranchOpInterface>(use.getOwner());
+ SuccessorOperands succOperands =
+ branch.getSuccessorOperands(use.getOperandNumber());
+
+ // The successor operand is either the last one or is not present if the
+ // user block is dead.
+ assert(succOperands.size() == mergePoint->getNumArguments() ||
+ succOperands.size() + 1 == mergePoint->getNumArguments());
+
+ // If the user block is dead, the default value acts as a placeholder
+ // dummy value.
+ if (succOperands.size() + 1 == mergePoint->getNumArguments())
+ succOperands.append(getOrCreateDefaultValue());
+
+ Value populatedValue = succOperands[arg.getArgNumber()];
+ auto populatedValueAsArg = dyn_cast<BlockArgument>(populatedValue);
+ if (populatedValueAsArg &&
+ mergePointArgsUnused.erase(populatedValueAsArg))
+ usedMergePointArgsToProcess.push_back(populatedValueAsArg);
}
builder.setInsertionPointToStart(mergePoint);
@@ -804,8 +866,17 @@ void MemorySlotPromoter::linkMergePoints() {
(*statistics.newBlockArgumentAmount)++;
}
+ for (Operation *toEraseOp : toErase)
+ toEraseOp->erase();
+
for (BlockArgument arg : mergePointArgsUnused) {
Block *mergePoint = arg.getOwner();
+ for (BlockOperand &use : mergePoint->getUses()) {
+ auto branch = cast<BranchOpInterface>(use.getOwner());
+ SuccessorOperands succOperands =
+ branch.getSuccessorOperands(use.getOperandNumber());
+ succOperands.erase(arg.getArgNumber());
+ }
mergePoint->eraseArgument(mergePoint->getNumArguments() - 1);
}
}
@@ -832,11 +903,8 @@ MemorySlotPromoter::promoteSlot() {
op.visitReplacedValues(replacedValues, builder);
}
- // Finally, connect merge points to their predecessor's reaching definitions.
- linkMergePoints();
-
- for (Operation *toEraseOp : toErase)
- toEraseOp->erase();
+ // Finally, remove unused operations and merge point block arguments.
+ removeUnusedItems();
assert(slot.ptr.use_empty() &&
"after promotion, the slot pointer should not be used anymore");
diff --git a/mlir/test/Dialect/MemRef/mem2reg.mlir b/mlir/test/Dialect/MemRef/mem2reg.mlir
index 30a268521c69b..20c49267ce7fc 100644
--- a/mlir/test/Dialect/MemRef/mem2reg.mlir
+++ b/mlir/test/Dialect/MemRef/mem2reg.mlir
@@ -171,8 +171,84 @@ func.func @unused_alloca_store_loop() {
// CHECK-NOT: memref.alloca
%cst = arith.constant 1 : i32
%alloca = memref.alloca() : memref<i32>
+ // CHECK: cf.br ^[[BB1:.*]]
cf.br ^bb1
+
+// CHECK: ^[[BB1]]:
^bb1:
+ // CHECK-NOT: memref.store
memref.store %cst, %alloca[] : memref<i32>
+ // CHECK: cf.br ^[[BB1]]
cf.br ^bb1
}
+
+// -----
+
+// CHECK-LABEL: func.func @store_back_to_alloca
+// CHECK-SAME: (%[[COND:.*]]: i1)
+func.func @store_back_to_alloca(%cond: i1) -> i32 {
+ // CHECK-NOT: memref.alloca
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : i32
+ %c0 = arith.constant 0 : i32
+ // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : i32
+ %c1 = arith.constant 1 : i32
+ %alloca = memref.alloca() : memref<i32>
+ memref.store %c0, %alloca[] : memref<i32>
+ %loaded = memref.load %alloca[] : memref<i32>
+ // CHECK: cf.cond_br %[[COND]], ^[[STORE_BACK:.*]], ^[[SKIP:.*]]
+ cf.cond_br %cond, ^store_back, ^skip
+
+// CHECK: ^[[STORE_BACK]]:
+^store_back:
+ memref.store %loaded, %alloca[] : memref<i32>
+ // CHECK: cf.br ^[[MERGE:.*]](%[[C0]] : i32)
+ cf.br ^merge
+
+// CHECK: ^[[SKIP]]:
+^skip:
+ memref.store %c1, %alloca[] : memref<i32>
+ // CHECK: cf.br ^[[MERGE]](%[[C1]] : i32)
+ cf.br ^merge
+
+// CHECK: ^[[MERGE]](%[[RESULT:.*]]: i32):
+^merge:
+ %result = memref.load %alloca[] : memref<i32>
+ // CHECK: return %[[RESULT]] : i32
+ return %result : i32
+}
+
+// -----
+
+// Ensure that a merge point used by an erased operation is not considered used.
+
+// CHECK-LABEL: func.func @merge_point_used_by_erased_op
+// CHECK-SAME: (%[[COND:.*]]: i1)
+func.func @merge_point_used_by_erased_op(%cond: i1) -> i32 {
+ // CHECK-NOT: memref.alloca
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : i32
+ %c0 = arith.constant 0 : i32
+ // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : i32
+ %c1 = arith.constant 1 : i32
+ %alloca = memref.alloca() : memref<i32>
+ // CHECK: cf.cond_br %[[COND]], ^[[MERGE:.*]], ^[[SKIP:.*]]
+ cf.cond_br %cond, ^merge, ^skip
+
+// CHECK: ^[[MERGE]]:
+^merge:
+ memref.store %c0, %alloca[] : memref<i32>
+ // CHECK: cf.br ^[[FINAL:.*]]{{$}}
+ cf.br ^final
+
+// CHECK: ^[[SKIP]]:
+^skip:
+ memref.store %c1, %alloca[] : memref<i32>
+ // CHECK: cf.br ^[[FINAL]]{{$}}
+ cf.br ^final
+
+// CHECK: ^[[FINAL]]:
+^final:
+ %result = memref.load %alloca[] : memref<i32>
+ memref.store %result, %alloca[] : memref<i32>
+ // CHECK: return %[[C0]] : i32
+ return %c0 : i32
+}
>From 3e8bdf3aea748067abe6f169aa4d3d51ca199dd2 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Th=C3=A9o=20Degioanni?= <tdegioanni at nvidia.com>
Date: Wed, 25 Mar 2026 13:43:03 +0000
Subject: [PATCH 2/5] make toErase a set vector for fast lookup
---
mlir/lib/Transforms/Mem2Reg.cpp | 7 ++++---
1 file changed, 4 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp
index 77f87dd20d185..58fc1b845ac73 100644
--- a/mlir/lib/Transforms/Mem2Reg.cpp
+++ b/mlir/lib/Transforms/Mem2Reg.cpp
@@ -19,6 +19,7 @@
#include "mlir/Interfaces/MemorySlotInterfaces.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SetVector.h"
#include "llvm/Support/DebugLog.h"
#include "llvm/Support/GenericIteratedDominanceFrontier.h"
@@ -294,7 +295,7 @@ class MemorySlotPromoter {
/// the promotion.
llvm::SmallVector<PromotableOpInterface> toVisitReplacedValues;
/// Operations to be erased at the end of the promotion.
- llvm::SmallVector<Operation *> toErase;
+ llvm::SmallSetVector<Operation *, 8> toErase;
DominanceInfo &dominance;
const DataLayout &dataLayout;
@@ -757,7 +758,7 @@ void MemorySlotPromoter::removeBlockingUses(Region *region) {
if (toPromoteMemOp.removeBlockingUses(slot, blockingUsesMap[toPromote],
builder, reachingDef,
dataLayout) == DeletionKind::Delete)
- toErase.push_back(toPromote);
+ toErase.insert(toPromote);
if (toPromoteMemOp.storesTo(slot))
if (Value replacedValue = replacedValuesMap[toPromoteMemOp])
replacedValues.push_back({toPromoteMemOp, replacedValue});
@@ -768,7 +769,7 @@ void MemorySlotPromoter::removeBlockingUses(Region *region) {
builder.setInsertionPointAfter(toPromote);
if (toPromoteBasic.removeBlockingUses(blockingUsesMap[toPromote],
builder) == DeletionKind::Delete)
- toErase.push_back(toPromote);
+ toErase.insert(toPromote);
if (toPromoteBasic.requiresReplacedValues())
toVisitReplacedValues.push_back(toPromoteBasic);
}
>From b958ce9e54d4677f2fa19fa5ad57acd5bd744bda Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Th=C3=A9o=20Degioanni?= <tdegioanni at nvidia.com>
Date: Wed, 25 Mar 2026 13:50:20 +0000
Subject: [PATCH 3/5] fix silly test names
---
mlir/test/Dialect/MemRef/mem2reg.mlir | 20 ++++++++++----------
1 file changed, 10 insertions(+), 10 deletions(-)
diff --git a/mlir/test/Dialect/MemRef/mem2reg.mlir b/mlir/test/Dialect/MemRef/mem2reg.mlir
index 20c49267ce7fc..ab0f32bf5f2ca 100644
--- a/mlir/test/Dialect/MemRef/mem2reg.mlir
+++ b/mlir/test/Dialect/MemRef/mem2reg.mlir
@@ -230,23 +230,23 @@ func.func @merge_point_used_by_erased_op(%cond: i1) -> i32 {
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : i32
%c1 = arith.constant 1 : i32
%alloca = memref.alloca() : memref<i32>
- // CHECK: cf.cond_br %[[COND]], ^[[MERGE:.*]], ^[[SKIP:.*]]
- cf.cond_br %cond, ^merge, ^skip
+ // CHECK: cf.cond_br %[[COND]], ^[[PRED1:.*]], ^[[PRED2:.*]]
+ cf.cond_br %cond, ^pred1, ^pred2
-// CHECK: ^[[MERGE]]:
-^merge:
+// CHECK: ^[[PRED1]]:
+^pred1:
memref.store %c0, %alloca[] : memref<i32>
// CHECK: cf.br ^[[FINAL:.*]]{{$}}
- cf.br ^final
+ cf.br ^merge
-// CHECK: ^[[SKIP]]:
-^skip:
+// CHECK: ^[[PRED2]]:
+^pred2:
memref.store %c1, %alloca[] : memref<i32>
// CHECK: cf.br ^[[FINAL]]{{$}}
- cf.br ^final
+ cf.br ^merge
-// CHECK: ^[[FINAL]]:
-^final:
+// CHECK: ^[[MERGE]]:
+^merge:
%result = memref.load %alloca[] : memref<i32>
memref.store %result, %alloca[] : memref<i32>
// CHECK: return %[[C0]] : i32
>From 37255a998a99faebed9ca3a5bfc54790d2b8ecc5 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Th=C3=A9o=20Degioanni?= <tdegioanni at nvidia.com>
Date: Wed, 25 Mar 2026 14:02:57 +0000
Subject: [PATCH 4/5] add consecutive merge point test for unused propagation
---
mlir/test/Dialect/MemRef/mem2reg.mlir | 57 +++++++++++++++++++++++++++
1 file changed, 57 insertions(+)
diff --git a/mlir/test/Dialect/MemRef/mem2reg.mlir b/mlir/test/Dialect/MemRef/mem2reg.mlir
index ab0f32bf5f2ca..59fb3a3b4a78a 100644
--- a/mlir/test/Dialect/MemRef/mem2reg.mlir
+++ b/mlir/test/Dialect/MemRef/mem2reg.mlir
@@ -192,6 +192,7 @@ func.func @store_back_to_alloca(%cond: i1) -> i32 {
%c0 = arith.constant 0 : i32
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : i32
%c1 = arith.constant 1 : i32
+ // CHECK-NOT: memref.alloca
%alloca = memref.alloca() : memref<i32>
memref.store %c0, %alloca[] : memref<i32>
%loaded = memref.load %alloca[] : memref<i32>
@@ -229,6 +230,7 @@ func.func @merge_point_used_by_erased_op(%cond: i1) -> i32 {
%c0 = arith.constant 0 : i32
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : i32
%c1 = arith.constant 1 : i32
+ // CHECK-NOT: memref.alloca
%alloca = memref.alloca() : memref<i32>
// CHECK: cf.cond_br %[[COND]], ^[[PRED1:.*]], ^[[PRED2:.*]]
cf.cond_br %cond, ^pred1, ^pred2
@@ -252,3 +254,58 @@ func.func @merge_point_used_by_erased_op(%cond: i1) -> i32 {
// CHECK: return %[[C0]] : i32
return %c0 : i32
}
+
+// -----
+
+// Two consecutive merge points: pred1 and pred2 merge at merge1, then merge1
+// and pred3 merge at merge2.
+
+// CHECK-LABEL: func.func @two_consecutive_merge_points
+// CHECK-SAME: (%[[COND1:.*]]: i1, %[[COND2:.*]]: i1)
+func.func @two_consecutive_merge_points(%cond1: i1, %cond2: i1) -> i32 {
+ // CHECK-NOT: memref.alloca
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : i32
+ // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : i32
+ // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : i32
+ // CHECK-NOT: memref.alloca
+ %c0 = arith.constant 0 : i32
+ %c1 = arith.constant 1 : i32
+ %c2 = arith.constant 2 : i32
+ %alloca = memref.alloca() : memref<i32>
+ // CHECK: cf.cond_br %[[COND1]], ^[[PRED1:.*]], ^[[MID:.*]]
+ cf.cond_br %cond1, ^pred1, ^mid
+
+// CHECK: ^[[MID]]:
+^mid:
+ // CHECK: cf.cond_br %[[COND2]], ^[[PRED2:.*]], ^[[PRED3:.*]]
+ cf.cond_br %cond2, ^pred2, ^pred3
+
+// CHECK: ^[[PRED1]]:
+^pred1:
+ memref.store %c0, %alloca[] : memref<i32>
+ // CHECK: cf.br ^[[MERGE1:.*]](%[[C0]] : i32)
+ cf.br ^merge1
+
+// CHECK: ^[[PRED2]]:
+^pred2:
+ memref.store %c1, %alloca[] : memref<i32>
+ // CHECK: cf.br ^[[MERGE1]](%[[C1]] : i32)
+ cf.br ^merge1
+
+// CHECK: ^[[MERGE1]](%[[MARG:.*]]: i32):
+^merge1:
+ // CHECK: cf.br ^[[MERGE2:.*]](%[[MARG]] : i32)
+ cf.br ^merge2
+
+// CHECK: ^[[PRED3]]:
+^pred3:
+ memref.store %c2, %alloca[] : memref<i32>
+ // CHECK: cf.br ^[[MERGE2]](%[[C2]] : i32)
+ cf.br ^merge2
+
+// CHECK: ^[[MERGE2]](%[[RESULT:.*]]: i32):
+^merge2:
+ %result = memref.load %alloca[] : memref<i32>
+ // CHECK: return %[[RESULT]] : i32
+ return %result : i32
+}
>From 359b9bf59f8ec887d2b002c79c9fe8c0326f0110 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Th=C3=A9o=20Degioanni?= <tdegioanni at nvidia.com>
Date: Wed, 25 Mar 2026 14:04:46 +0000
Subject: [PATCH 5/5] fix typo
---
mlir/test/Dialect/MemRef/mem2reg.mlir | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/test/Dialect/MemRef/mem2reg.mlir b/mlir/test/Dialect/MemRef/mem2reg.mlir
index 59fb3a3b4a78a..8f937c4efe75e 100644
--- a/mlir/test/Dialect/MemRef/mem2reg.mlir
+++ b/mlir/test/Dialect/MemRef/mem2reg.mlir
@@ -238,13 +238,13 @@ func.func @merge_point_used_by_erased_op(%cond: i1) -> i32 {
// CHECK: ^[[PRED1]]:
^pred1:
memref.store %c0, %alloca[] : memref<i32>
- // CHECK: cf.br ^[[FINAL:.*]]{{$}}
+ // CHECK: cf.br ^[[MERGE:.*]]{{$}}
cf.br ^merge
// CHECK: ^[[PRED2]]:
^pred2:
memref.store %c1, %alloca[] : memref<i32>
- // CHECK: cf.br ^[[FINAL]]{{$}}
+ // CHECK: cf.br ^[[MERGE]]{{$}}
cf.br ^merge
// CHECK: ^[[MERGE]]:
More information about the Mlir-commits
mailing list