[Mlir-commits] [mlir] Reapply "[MLIR] [Mem2Reg] Fix unused block argument removal logic (#188484)" (#188571) (PR #188599)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Mar 25 13:46:58 PDT 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-memref

Author: Slava Zakharin (vzakhari)

<details>
<summary>Changes</summary>

This reverts commit d9402d087ab90610d3ff8a78a50eb66d3be4cffd, which was the revert
of #188484.

It re-applies commit e5adddc5be63b8bb8c36572f68ac64c8042cb282 together with the
additional change in https://github.com/cathyzhyi/llvm-project/commit/62eafb5cd1f2d3df9a3d37bfe03bb21f85615f3c

Co-authored-by: Yi Zhang <cathyzhyi@google.com>


---
Full diff: https://github.com/llvm/llvm-project/pull/188599.diff


3 Files Affected:

- (modified) mlir/lib/Transforms/Mem2Reg.cpp (+117-35) 
- (modified) mlir/test/Dialect/MemRef/mem2reg.mlir (+133) 
- (modified) mlir/test/Transforms/mem2reg.mlir (+40) 


``````````diff
diff --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp
index 0341956e923ce..deaa226f3e809 100644
--- a/mlir/lib/Transforms/Mem2Reg.cpp
+++ b/mlir/lib/Transforms/Mem2Reg.cpp
@@ -19,6 +19,7 @@
 #include "mlir/Interfaces/MemorySlotInterfaces.h"
 #include "mlir/Transforms/Passes.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SetVector.h"
 #include "llvm/Support/DebugLog.h"
 #include "llvm/Support/GenericIteratedDominanceFrontier.h"
 
@@ -264,9 +265,9 @@ class MemorySlotPromoter {
   /// to a different region, the new region will be processed instead.
   void removeBlockingUses(Region *region);
 
-  /// Links merge point block arguments to the terminators targeting the merge
-  /// point or remove the argument if it is not used.
-  void linkMergePoints();
+  /// Removes operations and merge point block arguments that ended up not being
+  /// necessary.
+  void removeUnusedItems();
 
   /// Lazily-constructed default value representing the content of the slot when
   /// no store has been executed. This function may mutate IR.
@@ -294,7 +295,7 @@ class MemorySlotPromoter {
   /// the promotion.
   llvm::SmallVector<PromotableOpInterface> toVisitReplacedValues;
   /// Operations to be erased at the end of the promotion.
-  llvm::SmallVector<Operation *> toErase;
+  llvm::SmallSetVector<Operation *, 8> toErase;
 
   DominanceInfo &dominance;
   const DataLayout &dataLayout;
@@ -664,6 +665,18 @@ void MemorySlotPromoter::promoteInRegion(Region *region, Value reachingDef) {
 
     job.reachingDef = promoteInBlock(block, job.reachingDef);
 
+    if (auto terminator = dyn_cast<BranchOpInterface>(block->getTerminator())) {
+      for (BlockOperand &blockOperand : terminator->getBlockOperands()) {
+        if (info.mergePoints.contains(blockOperand.get())) {
+          if (!job.reachingDef)
+            job.reachingDef = getOrCreateDefaultValue();
+
+          terminator.getSuccessorOperands(blockOperand.getOperandNumber())
+              .append(job.reachingDef);
+        }
+      }
+    }
+
     for (auto *child : job.block->children())
       dfsStack.emplace_back<DfsJob>({child, job.reachingDef});
   }
@@ -753,7 +766,7 @@ void MemorySlotPromoter::removeBlockingUses(Region *region) {
       if (toPromoteMemOp.removeBlockingUses(slot, blockingUsesMap[toPromote],
                                             builder, reachingDef,
                                             dataLayout) == DeletionKind::Delete)
-        toErase.push_back(toPromote);
+        toErase.insert(toPromote);
       if (toPromoteMemOp.storesTo(slot))
         if (Value replacedValue = replacedValuesMap[toPromoteMemOp])
           replacedValues.push_back({toPromoteMemOp, replacedValue});
@@ -764,46 +777,99 @@ void MemorySlotPromoter::removeBlockingUses(Region *region) {
     builder.setInsertionPointAfter(toPromote);
     if (toPromoteBasic.removeBlockingUses(blockingUsesMap[toPromote],
                                           builder) == DeletionKind::Delete)
-      toErase.push_back(toPromote);
+      toErase.insert(toPromote);
     if (toPromoteBasic.requiresReplacedValues())
       toVisitReplacedValues.push_back(toPromoteBasic);
   }
 }
 
-void MemorySlotPromoter::linkMergePoints() {
-  // We want to eliminate unused block arguments. In case connecting a block
-  // argument to its predecessor would trigger the use of the predecessor's
-  // unused block argument, we need to process merge points in an expanding
-  // worklist, `mergePointArgsToProcess`.
+void MemorySlotPromoter::removeUnusedItems() {
+  // We want to eliminate unused block arguments. Because block arguments can be
+  // used to populate other block arguments, there might be cycles of arguments
+  // that are only used to populate each-other. We therefore need a small
+  // dataflow analysis to identify which block arguments are truly used.
 
   SmallPtrSet<BlockArgument, 8> mergePointArgsUnused;
-  SmallVector<BlockArgument> mergePointArgsToProcess;
+  SmallVector<BlockArgument> usedMergePointArgsToProcess;
+
+  // First, separate the block arguments that are not used or only used for the
+  // purpose of populating a merge point block argument from the others. These
+  // block arguments are potentially unused. Meanwhile, arguments that are
+  // definitely used will be the starting point of the propagation of the
+  // analysis.
+  auto isDefinitelyUsed = [&](BlockArgument arg) {
+    for (auto &use : arg.getUses()) {
+      if (llvm::is_contained(toErase, use.getOwner()))
+        continue;
+
+      // We now want to detect whether the use is to populate a merge point
+      // block argument. If it is not, the argument is definitely used.
+
+      auto branchOp = dyn_cast<BranchOpInterface>(use.getOwner());
+      if (!branchOp)
+        return true;
+
+      std::optional<BlockArgument> successorArgument =
+          branchOp.getSuccessorBlockArgument(use.getOperandNumber());
+      if (!successorArgument)
+        return true;
+
+      if (!info.mergePoints.contains(successorArgument->getOwner()))
+        return true;
+
+      // The last block argument of a merge point is its reaching definition
+      // argument. If the argument being populated is not the last one, it is a
+      // genuine use of the value.
+      bool isLastBlockArgument =
+          successorArgument->getArgNumber() ==
+          successorArgument->getOwner()->getNumArguments() - 1;
+      if (!isLastBlockArgument)
+        return true;
+    }
+    return false;
+  };
+
   for (Block *mergePoint : info.mergePoints) {
     BlockArgument arg = mergePoint->getArguments().back();
-    if (arg.use_empty())
-      mergePointArgsUnused.insert(arg);
+    if (isDefinitelyUsed(arg))
+      usedMergePointArgsToProcess.push_back(arg);
     else
-      mergePointArgsToProcess.push_back(arg);
+      mergePointArgsUnused.insert(arg);
   }
 
-  while (!mergePointArgsToProcess.empty()) {
-    BlockArgument arg = mergePointArgsToProcess.pop_back_val();
+  // We now refine mergePointArgsUnused from the information of which block
+  // arguments are definitely used.
+  while (!usedMergePointArgsToProcess.empty()) {
+    BlockArgument arg = usedMergePointArgsToProcess.pop_back_val();
     Block *mergePoint = arg.getOwner();
 
-    for (BlockOperand &use : mergePoint->getUses()) {
-      Value reachingDef = reachingAtBlockEnd[use.getOwner()->getBlock()];
-      if (!reachingDef)
-        reachingDef = getOrCreateDefaultValue();
+    assert(arg.getArgNumber() == mergePoint->getNumArguments() - 1 &&
+           "merge point argument must be the last argument of the merge point");
 
-      // If the reaching definition is a block argument of an unused merge
-      // point, mark it as used and process it as such later.
-      auto reachingDefArgument = dyn_cast<BlockArgument>(reachingDef);
-      if (reachingDefArgument &&
-          mergePointArgsUnused.erase(reachingDefArgument))
-        mergePointArgsToProcess.push_back(reachingDefArgument);
-
-      BranchOpInterface user = cast<BranchOpInterface>(use.getOwner());
-      user.getSuccessorOperands(use.getOperandNumber()).append(reachingDef);
+    for (BlockOperand &use : mergePoint->getUses()) {
+      // If a value used to populate this used merge point argument is another
+      // merge point block argument that is currently considered unused, it must
+      // now be considered used and processed as such later.
+
+      auto branch = cast<BranchOpInterface>(use.getOwner());
+      SuccessorOperands succOperands =
+          branch.getSuccessorOperands(use.getOperandNumber());
+
+      // The successor operand is either the last one or is not present if the
+      // user block is dead.
+      assert(succOperands.size() == mergePoint->getNumArguments() ||
+             succOperands.size() + 1 == mergePoint->getNumArguments());
+
+      // If the user block is dead, the default value acts as a placeholder
+      // dummy value.
+      if (succOperands.size() + 1 == mergePoint->getNumArguments())
+        succOperands.append(getOrCreateDefaultValue());
+
+      Value populatedValue = succOperands[arg.getArgNumber()];
+      auto populatedValueAsArg = dyn_cast<BlockArgument>(populatedValue);
+      if (populatedValueAsArg &&
+          mergePointArgsUnused.erase(populatedValueAsArg))
+        usedMergePointArgsToProcess.push_back(populatedValueAsArg);
     }
 
     builder.setInsertionPointToStart(mergePoint);
@@ -812,6 +878,25 @@ void MemorySlotPromoter::linkMergePoints() {
       (*statistics.newBlockArgumentAmount)++;
   }
 
+  for (Operation *toEraseOp : toErase)
+    toEraseOp->erase();
+
+  // First, erase all successor operands that feed into unused merge point
+  // block arguments. This must be done before erasing the block arguments
+  // themselves because an unused merge point argument may be used to
+  // populate another unused merge point argument via a branch operation.
+  for (BlockArgument arg : mergePointArgsUnused) {
+    Block *mergePoint = arg.getOwner();
+    for (BlockOperand &use : mergePoint->getUses()) {
+      auto branch = cast<BranchOpInterface>(use.getOwner());
+      SuccessorOperands succOperands =
+          branch.getSuccessorOperands(use.getOperandNumber());
+      succOperands.erase(arg.getArgNumber());
+    }
+  }
+
+  // Now that all successor operands feeding unused args have been removed,
+  // erase the block arguments themselves.
   for (BlockArgument arg : mergePointArgsUnused) {
     Block *mergePoint = arg.getOwner();
     mergePoint->eraseArgument(mergePoint->getNumArguments() - 1);
@@ -840,11 +925,8 @@ MemorySlotPromoter::promoteSlot() {
     op.visitReplacedValues(replacedValues, builder);
   }
 
-  // Finally, connect merge points to their predecessor's reaching definitions.
-  linkMergePoints();
-
-  for (Operation *toEraseOp : toErase)
-    toEraseOp->erase();
+  // Finally, remove unused operations and merge point block arguments.
+  removeUnusedItems();
 
   assert(slot.ptr.use_empty() &&
          "after promotion, the slot pointer should not be used anymore");
diff --git a/mlir/test/Dialect/MemRef/mem2reg.mlir b/mlir/test/Dialect/MemRef/mem2reg.mlir
index 30a268521c69b..8f937c4efe75e 100644
--- a/mlir/test/Dialect/MemRef/mem2reg.mlir
+++ b/mlir/test/Dialect/MemRef/mem2reg.mlir
@@ -171,8 +171,141 @@ func.func @unused_alloca_store_loop() {
   // CHECK-NOT: memref.alloca
   %cst = arith.constant 1 : i32
   %alloca = memref.alloca() : memref<i32>
+  // CHECK: cf.br ^[[BB1:.*]]
   cf.br ^bb1
+
+// CHECK: ^[[BB1]]:
 ^bb1:
+  // CHECK-NOT: memref.store
   memref.store %cst, %alloca[] : memref<i32>
+  // CHECK: cf.br ^[[BB1]]
   cf.br ^bb1
 }
+
+// -----
+
+// CHECK-LABEL: func.func @store_back_to_alloca
+// CHECK-SAME: (%[[COND:.*]]: i1)
+func.func @store_back_to_alloca(%cond: i1) -> i32 {
+  // CHECK-NOT: memref.alloca
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : i32
+  %c0 = arith.constant 0 : i32
+  // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : i32
+  %c1 = arith.constant 1 : i32
+  // CHECK-NOT: memref.alloca
+  %alloca = memref.alloca() : memref<i32>
+  memref.store %c0, %alloca[] : memref<i32>
+  %loaded = memref.load %alloca[] : memref<i32>
+  // CHECK: cf.cond_br %[[COND]], ^[[STORE_BACK:.*]], ^[[SKIP:.*]]
+  cf.cond_br %cond, ^store_back, ^skip
+
+// CHECK: ^[[STORE_BACK]]:
+^store_back:
+  memref.store %loaded, %alloca[] : memref<i32>
+  // CHECK: cf.br ^[[MERGE:.*]](%[[C0]] : i32)
+  cf.br ^merge
+
+// CHECK: ^[[SKIP]]:
+^skip:
+  memref.store %c1, %alloca[] : memref<i32>
+  // CHECK: cf.br ^[[MERGE]](%[[C1]] : i32)
+  cf.br ^merge
+
+// CHECK: ^[[MERGE]](%[[RESULT:.*]]: i32):
+^merge:
+  %result = memref.load %alloca[] : memref<i32>
+  // CHECK: return %[[RESULT]] : i32
+  return %result : i32
+}
+
+// -----
+
+// Ensure that a merge point used by an erased operation is not considered used.
+
+// CHECK-LABEL: func.func @merge_point_used_by_erased_op
+// CHECK-SAME: (%[[COND:.*]]: i1)
+func.func @merge_point_used_by_erased_op(%cond: i1) -> i32 {
+  // CHECK-NOT: memref.alloca
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : i32
+  %c0 = arith.constant 0 : i32
+  // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : i32
+  %c1 = arith.constant 1 : i32
+  // CHECK-NOT: memref.alloca
+  %alloca = memref.alloca() : memref<i32>
+  // CHECK: cf.cond_br %[[COND]], ^[[PRED1:.*]], ^[[PRED2:.*]]
+  cf.cond_br %cond, ^pred1, ^pred2
+
+// CHECK: ^[[PRED1]]:
+^pred1:
+  memref.store %c0, %alloca[] : memref<i32>
+  // CHECK: cf.br ^[[MERGE:.*]]{{$}}
+  cf.br ^merge
+
+// CHECK: ^[[PRED2]]:
+^pred2:
+  memref.store %c1, %alloca[] : memref<i32>
+  // CHECK: cf.br ^[[MERGE]]{{$}}
+  cf.br ^merge
+
+// CHECK: ^[[MERGE]]:
+^merge:
+  %result = memref.load %alloca[] : memref<i32>
+  memref.store %result, %alloca[] : memref<i32>
+  // CHECK: return %[[C0]] : i32
+  return %c0 : i32
+}
+
+// -----
+
+// Two consecutive merge points: pred1 and pred2 merge at merge1, then merge1
+// and pred3 merge at merge2.
+
+// CHECK-LABEL: func.func @two_consecutive_merge_points
+// CHECK-SAME: (%[[COND1:.*]]: i1, %[[COND2:.*]]: i1)
+func.func @two_consecutive_merge_points(%cond1: i1, %cond2: i1) -> i32 {
+  // CHECK-NOT: memref.alloca
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : i32
+  // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : i32
+  // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : i32
+  // CHECK-NOT: memref.alloca
+  %c0 = arith.constant 0 : i32
+  %c1 = arith.constant 1 : i32
+  %c2 = arith.constant 2 : i32
+  %alloca = memref.alloca() : memref<i32>
+  // CHECK: cf.cond_br %[[COND1]], ^[[PRED1:.*]], ^[[MID:.*]]
+  cf.cond_br %cond1, ^pred1, ^mid
+
+// CHECK: ^[[MID]]:
+^mid:
+  // CHECK: cf.cond_br %[[COND2]], ^[[PRED2:.*]], ^[[PRED3:.*]]
+  cf.cond_br %cond2, ^pred2, ^pred3
+
+// CHECK: ^[[PRED1]]:
+^pred1:
+  memref.store %c0, %alloca[] : memref<i32>
+  // CHECK: cf.br ^[[MERGE1:.*]](%[[C0]] : i32)
+  cf.br ^merge1
+
+// CHECK: ^[[PRED2]]:
+^pred2:
+  memref.store %c1, %alloca[] : memref<i32>
+  // CHECK: cf.br ^[[MERGE1]](%[[C1]] : i32)
+  cf.br ^merge1
+
+// CHECK: ^[[MERGE1]](%[[MARG:.*]]: i32):
+^merge1:
+  // CHECK: cf.br ^[[MERGE2:.*]](%[[MARG]] : i32)
+  cf.br ^merge2
+
+// CHECK: ^[[PRED3]]:
+^pred3:
+  memref.store %c2, %alloca[] : memref<i32>
+  // CHECK: cf.br ^[[MERGE2]](%[[C2]] : i32)
+  cf.br ^merge2
+
+// CHECK: ^[[MERGE2]](%[[RESULT:.*]]: i32):
+^merge2:
+  %result = memref.load %alloca[] : memref<i32>
+  // CHECK: return %[[RESULT]] : i32
+  return %result : i32
+}
diff --git a/mlir/test/Transforms/mem2reg.mlir b/mlir/test/Transforms/mem2reg.mlir
index 70fbddcb25b2a..484d3e46881d8 100644
--- a/mlir/test/Transforms/mem2reg.mlir
+++ b/mlir/test/Transforms/mem2reg.mlir
@@ -113,3 +113,43 @@ func.func @unknown_region_op_load() {
   }) : () -> ()
   return
 }
+
+// -----
+
+// A cycle of merge points where both merge point block arguments are unused.
+// merge1 branches to merge2, and merge2 branches back to merge1, so each
+// merge point's reaching definition arg is used as a successor operand
+// feeding the other. During removeUnusedItems, the successor operand erasure
+// and block argument erasure must be performed in separate phases. Otherwise,
+// regardless of iteration order, erasing either arg first will crash because
+// the other's successor operand still uses it.
+
+// CHECK-LABEL: func.func @cyclic_unused_merge_points
+// CHECK-SAME: (%[[COND:.*]]: i1)
+// CHECK-NOT: memref.alloca
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : i32
+// CHECK: cf.br ^[[MERGE1:.*]]{{$}}
+// CHECK: ^[[MERGE1]]:
+// CHECK:   cf.cond_br %[[COND]], ^[[MERGE2:.*]], ^[[STORE:.*]]
+// CHECK: ^[[STORE]]:
+// CHECK:   cf.br ^[[MERGE2]]{{$}}
+// CHECK: ^[[MERGE2]]:
+// CHECK:   cf.cond_br %[[COND]], ^[[MERGE1]], ^[[EXIT:.*]]
+// CHECK: ^[[EXIT]]:
+// CHECK:   return %[[C0]] : i32
+func.func @cyclic_unused_merge_points(%cond: i1) -> i32 {
+  %c0 = arith.constant 0 : i32
+  %c1 = arith.constant 1 : i32
+  %alloca = memref.alloca() : memref<i32>
+  memref.store %c0, %alloca[] : memref<i32>
+  cf.br ^merge1
+^merge1:
+  cf.cond_br %cond, ^merge2, ^store
+^store:
+  memref.store %c1, %alloca[] : memref<i32>
+  cf.br ^merge2
+^merge2:
+  cf.cond_br %cond, ^merge1, ^exit
+^exit:
+  return %c0 : i32
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/188599


More information about the Mlir-commits mailing list