[Mlir-commits] [mlir] [mlir][Transforms] `remove-dead-values`: Rely on canonicalizer for region simplification (PR #173505)

Matthias Springer llvmlistbot at llvm.org
Thu Jan 1 10:16:23 PST 2026


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/173505

>From c6ed9f774e860d8cfe42449595bb133bd72b1748 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Thu, 1 Jan 2026 16:40:56 +0000
Subject: [PATCH 1/2] [mlir][IR][NFC] Add `RewriterBase::eraseOpResults`
 convenience helper

---
 mlir/include/mlir/IR/PatternMatch.h      |   6 +
 mlir/lib/Dialect/SCF/IR/SCF.cpp          | 196 +++++++----------------
 mlir/lib/IR/PatternMatch.cpp             |  36 +++++
 mlir/lib/Transforms/RemoveDeadValues.cpp |  31 +---
 4 files changed, 105 insertions(+), 164 deletions(-)

diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 1caab24ac7295..83477c79ff582 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -545,6 +545,12 @@ class RewriterBase : public OpBuilder {
   /// This method erases all operations in a block.
   virtual void eraseBlock(Block *block);
 
+  /// Erase the results of `op` whose indices are set in `eraseIndices`. Since
+  /// results cannot be erased in place, a replacement op is created (with all
+  /// regions moved over) and the original op is erased. Erased results are
+  /// replaced with null values (they should be unused). Returns the new op.
+  Operation *eraseOpResults(Operation *op, const BitVector &eraseIndices);
+
   /// Inline the operations of block 'source' into block 'dest' before the given
   /// position. The source block will be deleted and must have no uses.
   /// 'argValues' is used to replace the block arguments of 'source'.
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 0a123112cf68f..d4e341416fd1b 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1763,89 +1763,61 @@ struct ForallOpIterArgsFolder : public OpRewritePattern<ForallOp> {
     //            ForallOp::getCombiningOps(iter_arg).
     //
     //         Based on the check we maintain the following :-
-    //         a. `resultToDelete` - i-th result of scf.forall that'll be
-    //            deleted.
-    //         b. `resultToReplace` - i-th result of the old scf.forall
-    //            whose uses will be replaced by the new scf.forall.
-    //         c. `newOuts` - the shared_outs' operand of the new scf.forall
-    //            corresponding to the i-th result with at least one use.
-    SetVector<OpResult> resultToDelete;
-    SmallVector<Value> resultToReplace;
+    //         a. op results, block arguments, outputs to delete
+    //         b. new outputs (i.e., outputs to retain)
+    SmallVector<Value> resultsToDelete;
+    SmallVector<Value> outsToDelete;
+    SmallVector<BlockArgument> blockArgsToDelete;
     SmallVector<Value> newOuts;
+    BitVector resultIndicesToDelete(forallOp.getNumResults(), false);
+    BitVector blockIndicesToDelete(forallOp.getBody()->getNumArguments(),
+                                   false);
     for (OpResult result : forallOp.getResults()) {
       OpOperand *opOperand = forallOp.getTiedOpOperand(result);
       BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
       if (result.use_empty() || forallOp.getCombiningOps(blockArg).empty()) {
-        resultToDelete.insert(result);
+        resultsToDelete.push_back(result);
+        outsToDelete.push_back(opOperand->get());
+        blockArgsToDelete.push_back(blockArg);
+        resultIndicesToDelete[result.getResultNumber()] = true;
+        blockIndicesToDelete[blockArg.getArgNumber()] = true;
       } else {
-        resultToReplace.push_back(result);
         newOuts.push_back(opOperand->get());
       }
     }
 
     // Return early if all results of scf.forall have at least one use and being
     // modified within the loop.
-    if (resultToDelete.empty())
+    if (resultsToDelete.empty())
       return failure();
 
-    // Step 2: For the the i-th result, do the following :-
-    //         a. Fetch the corresponding BlockArgument.
-    //         b. Look for store ops (currently tensor.parallel_insert_slice)
-    //            with the BlockArgument as its destination operand.
-    //         c. Remove the operations fetched in b.
-    for (OpResult result : resultToDelete) {
-      OpOperand *opOperand = forallOp.getTiedOpOperand(result);
-      BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
+    // Step 2: Erase combining ops and replace uses of deleted results and
+    //         block arguments with the corresponding outputs.
+    for (auto blockArg : blockArgsToDelete) {
       SmallVector<Operation *> combiningOps =
           forallOp.getCombiningOps(blockArg);
       for (Operation *combiningOp : combiningOps)
         rewriter.eraseOp(combiningOp);
     }
-
-    // Step 3. Create a new scf.forall op with the new shared_outs' operands
-    //         fetched earlier
-    auto newForallOp = scf::ForallOp::create(
-        rewriter, forallOp.getLoc(), forallOp.getMixedLowerBound(),
-        forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOuts,
-        forallOp.getMapping(),
-        /*bodyBuilderFn =*/[](OpBuilder &, Location, ValueRange) {});
-
-    // Step 4. Merge the block of the old scf.forall into the newly created
-    //         scf.forall using the new set of arguments.
-    Block *loopBody = forallOp.getBody();
-    Block *newLoopBody = newForallOp.getBody();
-    ArrayRef<BlockArgument> newBbArgs = newLoopBody->getArguments();
-    // Form initial new bbArg list with just the control operands of the new
-    // scf.forall op.
-    SmallVector<Value> newBlockArgs =
-        llvm::map_to_vector(newBbArgs.take_front(forallOp.getRank()),
-                            [](BlockArgument b) -> Value { return b; });
-    Block::BlockArgListType newSharedOutsArgs = newForallOp.getRegionOutArgs();
-    unsigned index = 0;
-    // Take the new corresponding bbArg if the old bbArg was used as a
-    // destination in the in_parallel op. For all other bbArgs, use the
-    // corresponding init_arg from the old scf.forall op.
-    for (OpResult result : forallOp.getResults()) {
-      if (resultToDelete.count(result)) {
-        newBlockArgs.push_back(forallOp.getTiedOpOperand(result)->get());
-      } else {
-        newBlockArgs.push_back(newSharedOutsArgs[index++]);
-      }
+    for (auto [blockArg, out, result] :
+         llvm::zip_equal(blockArgsToDelete, outsToDelete, resultsToDelete)) {
+      rewriter.replaceAllUsesWith(blockArg, out);
+      rewriter.replaceAllUsesWith(result, out);
     }
-    rewriter.mergeBlocks(loopBody, newLoopBody, newBlockArgs);
-
-    // Step 5. Replace the uses of result of old scf.forall with that of the new
-    //         scf.forall.
-    for (auto &&[oldResult, newResult] :
-         llvm::zip(resultToReplace, newForallOp->getResults()))
-      rewriter.replaceAllUsesWith(oldResult, newResult);
-
-    // Step 6. Replace the uses of those values that either has no use or are
-    //         not being modified within the loop with the corresponding
-    //         OpOperand.
-    for (OpResult oldResult : resultToDelete)
-      rewriter.replaceAllUsesWith(oldResult,
-                                  forallOp.getTiedOpOperand(oldResult)->get());
+    // All IR modifications must go through the rewriter so that listeners
+    // (e.g., the greedy pattern driver) are notified.
+    rewriter.modifyOpInPlace(forallOp, [&]() {
+      forallOp.getBody()->eraseArguments(blockIndicesToDelete);
+    });
+
+    // Step 3. Create a new scf.forall op with only the results that should be
+    //         retained, then drop the shared_outs of the erased results.
+    auto newForallOp = cast<scf::ForallOp>(
+        rewriter.eraseOpResults(forallOp, resultIndicesToDelete));
+    rewriter.modifyOpInPlace(newForallOp, [&]() {
+      newForallOp.getOutputsMutable().assign(newOuts);
+    });
+
     return success();
   }
 };
@@ -2413,53 +2379,27 @@ namespace {
 struct RemoveUnusedResults : public OpRewritePattern<IfOp> {
   using OpRewritePattern<IfOp>::OpRewritePattern;
 
-  void transferBody(Block *source, Block *dest, ArrayRef<OpResult> usedResults,
-                    PatternRewriter &rewriter) const {
-    // Move all operations to the destination block.
-    rewriter.mergeBlocks(source, dest);
-    // Replace the yield op by one that returns only the used values.
-    auto yieldOp = cast<scf::YieldOp>(dest->getTerminator());
-    SmallVector<Value, 4> usedOperands;
-    llvm::transform(usedResults, std::back_inserter(usedOperands),
-                    [&](OpResult result) {
-                      return yieldOp.getOperand(result.getResultNumber());
-                    });
-    rewriter.modifyOpInPlace(yieldOp,
-                             [&]() { yieldOp->setOperands(usedOperands); });
-  }
-
   LogicalResult matchAndRewrite(IfOp op,
                                 PatternRewriter &rewriter) const override {
-    // Compute the list of used results.
-    SmallVector<OpResult, 4> usedResults;
-    llvm::copy_if(op.getResults(), std::back_inserter(usedResults),
-                  [](OpResult result) { return !result.use_empty(); });
-
-    // Replace the operation if only a subset of its results have uses.
-    if (usedResults.size() == op.getNumResults())
-      return failure();
-
-    // Compute the result types of the replacement operation.
-    SmallVector<Type, 4> newTypes;
-    llvm::transform(usedResults, std::back_inserter(newTypes),
-                    [](OpResult result) { return result.getType(); });
+    // Mark the results that have no uses; fail if every result is used.
+    BitVector toErase(op.getNumResults(), false);
+    for (auto [idx, result] : llvm::enumerate(op.getResults()))
+      if (result.use_empty())
+        toErase[idx] = true;
+    if (toErase.none())
+      return rewriter.notifyMatchFailure(op, "no results to erase");
+
+    // Replace `op` with an scf.if that no longer has the unused results.
+    auto newOp = cast<scf::IfOp>(rewriter.eraseOpResults(op, toErase));
+
+    // Drop the matching operands from both yield ops to keep them consistent.
+    rewriter.modifyOpInPlace(newOp.thenYield(), [&]() {
+      newOp.thenYield()->eraseOperands(toErase);
+    });
+    rewriter.modifyOpInPlace(newOp.elseYield(), [&]() {
+      newOp.elseYield()->eraseOperands(toErase);
+    });
 
-    // Create a replacement operation with empty then and else regions.
-    auto newOp =
-        IfOp::create(rewriter, op.getLoc(), newTypes, op.getCondition());
-    rewriter.createBlock(&newOp.getThenRegion());
-    rewriter.createBlock(&newOp.getElseRegion());
-
-    // Move the bodies and replace the terminators (note there is a then and
-    // an else region since the operation returns results).
-    transferBody(op.getBody(0), newOp.getBody(0), usedResults, rewriter);
-    transferBody(op.getBody(1), newOp.getBody(1), usedResults, rewriter);
-
-    // Replace the operation by the new one.
-    SmallVector<Value, 4> repResults(op.getNumResults());
-    for (const auto &en : llvm::enumerate(usedResults))
-      repResults[en.value().getResultNumber()] = newOp.getResult(en.index());
-    rewriter.replaceOp(op, repResults);
     return success();
   }
 };
@@ -4720,43 +4660,27 @@ struct FoldUnusedIndexSwitchResults : OpRewritePattern<IndexSwitchOp> {
                                 PatternRewriter &rewriter) const override {
     // Find dead results.
     BitVector deadResults(op.getNumResults(), false);
-    SmallVector<Type> newResultTypes;
-    for (auto [idx, result] : llvm::enumerate(op.getResults())) {
-      if (!result.use_empty()) {
-        newResultTypes.push_back(result.getType());
-      } else {
+    for (auto [idx, result] : llvm::enumerate(op.getResults()))
+      if (result.use_empty())
         deadResults[idx] = true;
-      }
-    }
     if (!deadResults.any())
       return rewriter.notifyMatchFailure(op, "no dead results to fold");
 
-    // Create new op without dead results and inline case regions.
-    auto newOp = IndexSwitchOp::create(rewriter, op.getLoc(), newResultTypes,
-                                       op.getArg(), op.getCases(),
-                                       op.getCaseRegions().size());
-    auto inlineCaseRegion = [&](Region &oldRegion, Region &newRegion) {
-      rewriter.inlineRegionBefore(oldRegion, newRegion, newRegion.begin());
-      // Remove respective operands from yield op.
-      Operation *terminator = newRegion.front().getTerminator();
+    // Replace `op` with an scf.index_switch that no longer has dead results.
+    auto newOp =
+        cast<scf::IndexSwitchOp>(rewriter.eraseOpResults(op, deadResults));
+
+    // Drop the matching operands from the yield op of each region.
+    auto updateCaseRegion = [&](Region &region) {
+      Operation *terminator = region.front().getTerminator();
       assert(isa<YieldOp>(terminator) && "expected yield op");
       rewriter.modifyOpInPlace(
           terminator, [&]() { terminator->eraseOperands(deadResults); });
     };
-    for (auto [oldRegion, newRegion] :
-         llvm::zip_equal(op.getCaseRegions(), newOp.getCaseRegions()))
-      inlineCaseRegion(oldRegion, newRegion);
-    inlineCaseRegion(op.getDefaultRegion(), newOp.getDefaultRegion());
-
-    // Replace op with new op.
-    SmallVector<Value> newResults(op.getNumResults(), Value());
-    unsigned nextNewResult = 0;
-    for (unsigned idx = 0; idx < op.getNumResults(); ++idx) {
-      if (deadResults[idx])
-        continue;
-      newResults[idx] = newOp.getResult(nextNewResult++);
-    }
-    rewriter.replaceOp(op, newResults);
+    updateCaseRegion(newOp.getDefaultRegion());
+    for (Region &caseRegion : newOp.getCaseRegions())
+      updateCaseRegion(caseRegion);
+
     return success();
   }
 };
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index 226e4e518d3e0..913063c87e1fa 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -244,6 +244,45 @@ void RewriterBase::eraseBlock(Block *block) {
   block->erase();
 }
 
+Operation *RewriterBase::eraseOpResults(Operation *op,
+                                        const BitVector &eraseIndices) {
+  assert(op->getNumResults() == eraseIndices.size() &&
+         "number of op results and bitvector size must match");
+
+  // Gather new result types.
+  SmallVector<Type> newResultTypes;
+  newResultTypes.reserve(op->getNumResults() - eraseIndices.count());
+  for (OpResult result : op->getResults())
+    if (!eraseIndices[result.getResultNumber()])
+      newResultTypes.push_back(result.getType());
+
+  // Create a new operation with the same operands, attributes and successors
+  // (ops may have both results and successors), then move all regions over.
+  InsertionGuard g(*this);
+  setInsertionPoint(op);
+  OperationState state(op->getLoc(), op->getName().getStringRef(),
+                       op->getOperands(), newResultTypes, op->getAttrs());
+  state.addSuccessors(op->getSuccessors());
+  for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i)
+    state.addRegion();
+  Operation *newOp = create(state);
+  for (const auto &[index, region] : llvm::enumerate(op->getRegions())) {
+    // Move all blocks of `region` into `newRegion`.
+    Region &newRegion = newOp->getRegion(index);
+    inlineRegionBefore(region, newRegion, newRegion.begin());
+  }
+
+  // Replace the original operation with the new operation. Erased results
+  // are replaced with null values.
+  SmallVector<Value> replacements(op->getNumResults(), Value());
+  unsigned nextResultIdx = 0;
+  for (unsigned i = 0, e = op->getNumResults(); i < e; ++i)
+    if (!eraseIndices[i])
+      replacements[i] = newOp->getResult(nextResultIdx++);
+  replaceOp(op, replacements);
+  return newOp;
+}
+
 void RewriterBase::finalizeOpModification(Operation *op) {
   // Notify the listener that the operation was modified.
   if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index 45266bc7b34ea..07911c6111043 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -205,35 +205,10 @@ static void dropUsesAndEraseResults(Operation *op, BitVector toErase) {
   assert(op->getNumResults() == toErase.size() &&
          "expected the number of results in `op` and the size of `toErase` to "
          "be the same");
-
-  std::vector<Type> newResultTypes;
-  for (OpResult result : op->getResults())
-    if (!toErase[result.getResultNumber()])
-      newResultTypes.push_back(result.getType());
+  for (auto idx : toErase.set_bits())
+    op->getResult(idx).dropAllUses();
   IRRewriter rewriter(op);
-  rewriter.setInsertionPointAfter(op);
-  OperationState state(op->getLoc(), op->getName().getStringRef(),
-                       op->getOperands(), newResultTypes, op->getAttrs());
-  for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i)
-    state.addRegion();
-  Operation *newOp = rewriter.create(state);
-  for (const auto &[index, region] : llvm::enumerate(op->getRegions())) {
-    // Move all blocks of `region` into `newRegion`.
-    Region &newRegion = newOp->getRegion(index);
-    rewriter.inlineRegionBefore(region, newRegion, newRegion.begin());
-  }
-
-  unsigned indexOfNextNewCallOpResultToReplace = 0;
-  for (auto [index, result] : llvm::enumerate(op->getResults())) {
-    assert(result && "expected result to be non-null");
-    if (toErase[index]) {
-      result.dropAllUses();
-    } else {
-      result.replaceAllUsesWith(
-          newOp->getResult(indexOfNextNewCallOpResultToReplace++));
-    }
-  }
-  op->erase();
+  rewriter.eraseOpResults(op, toErase);
 }
 
 /// Convert a list of `Operand`s to a list of `OpOperand`s.

>From 777c7299cc16ff8fa69d1588612da31ebf4bd176 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Wed, 24 Dec 2025 13:26:00 +0000
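Subject: [PATCH 2/2] [mlir][Transforms] `remove-dead-values`: Rely on
 canonicalizer for region simplification (draft)

Draft: do not erase block arguments / op results of region branch ops;
just replace the uses of dead values (dead operands become `ub.poison`) and
let the canonicalizer remove them. A simple test is working.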
---
 mlir/include/mlir/Transforms/Passes.h        |   1 +
 mlir/include/mlir/Transforms/Passes.td       |  10 +
 mlir/lib/Transforms/RemoveDeadValues.cpp     | 511 +++++++------------
 mlir/test/Transforms/remove-dead-values.mlir | 155 ++++--
 4 files changed, 300 insertions(+), 377 deletions(-)

diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index 724da009e70f1..9983944d374c5 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -42,6 +42,7 @@ class GreedyRewriteConfig;
 #define GEN_PASS_DECL_MEM2REG
 #define GEN_PASS_DECL_PRINTIRPASS
 #define GEN_PASS_DECL_PRINTOPSTATS
+#define GEN_PASS_DECL_REMOVEDEADVALUES
 #define GEN_PASS_DECL_SCCP
 #define GEN_PASS_DECL_SROA
 #define GEN_PASS_DECL_STRIPDEBUGINFO
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index 55addfdb693e4..fc2d60d198cd6 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -246,7 +246,17 @@ def RemoveDeadValues : Pass<"remove-dead-values"> {
     do = square_and_double_of_y(5)
     print(do)
     ```
+
+    Note: If the `canonicalize` option is set to `false`, this pass does not
+    remove any block arguments or op results of ops that implement
+    `RegionBranchOpInterface`. Instead, it only replaces dead operands with
+    `ub.poison` values.
   }];
+
+  let options = [
+    Option<"canonicalize", "canonicalize", "bool", /*default=*/"true",
+           "Canonicalize region branch ops">,
+  ];
   let constructor = "mlir::createRemoveDeadValuesPass()";
   let dependentDialects = ["ub::UBDialect"];
 }
diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index 07911c6111043..94b25f78786f9 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -94,8 +94,11 @@ struct ResultsToCleanup {
 struct OperandsToCleanup {
   Operation *op;
   BitVector nonLive;
-  Operation *callee =
-      nullptr; // Optional: For CallOpInterface ops, stores the callee function
+  // Optional: For CallOpInterface ops, stores the callee function.
+  Operation *callee = nullptr;
+  // If true, the operands to clean up are replaced with `ub.poison` values
+  // instead of being erased entirely.
+  bool replaceWithPoison = false;
 };
 
 struct BlockArgsToCleanup {
@@ -199,27 +202,6 @@ static void collectNonLiveValues(DenseSet<Value> &nonLiveSet, ValueRange range,
   }
 }
 
-/// Drop the uses of the i-th result of `op` and then erase it iff toErase[i]
-/// is 1.
-static void dropUsesAndEraseResults(Operation *op, BitVector toErase) {
-  assert(op->getNumResults() == toErase.size() &&
-         "expected the number of results in `op` and the size of `toErase` to "
-         "be the same");
-  for (auto idx : toErase.set_bits())
-    op->getResult(idx).dropAllUses();
-  IRRewriter rewriter(op);
-  rewriter.eraseOpResults(op, toErase);
-}
-
-/// Convert a list of `Operand`s to a list of `OpOperand`s.
-static SmallVector<OpOperand *> operandsToOpOperands(OperandRange operands) {
-  OpOperand *values = operands.getBase();
-  SmallVector<OpOperand *> opOperands;
-  for (unsigned i = 0, e = operands.size(); i < e; i++)
-    opOperands.push_back(&values[i]);
-  return opOperands;
-}
-
 /// Process a simple operation `op` using the liveness analysis `la`.
 /// If the operation has no memory effects and none of its results are live:
 ///   1. Add the operation to a list for future removal, and
@@ -379,30 +361,20 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
 ///
 /// Scenario 1: If the operation has no memory effects and none of its results
 /// are live:
-///   (1') Enqueue all its uses for deletion.
-///   (2') Enqueue the branch itself for deletion.
+///   1.1. Enqueue all its uses for deletion.
+///   1.2. Enqueue the branch itself for deletion.
 ///
 /// Scenario 2: Otherwise:
-///   (1) Collect its unnecessary operands (operands forwarded to unnecessary
-///       results or arguments).
-///   (2) Process each of its regions.
-///   (3) Collect the uses of its unnecessary results (results forwarded from
-///       unnecessary operands
-///       or terminator operands).
-///   (4) Add these results to the deletion list.
-///
-/// Processing a region includes:
-///   (a) Collecting the uses of its unnecessary arguments (arguments forwarded
-///       from unnecessary operands
-///       or terminator operands).
-///   (b) Collecting these unnecessary arguments.
-///   (c) Collecting its unnecessary terminator operands (terminator operands
-///       forwarded to unnecessary results
-///       or arguments).
+///   2.1. Collect block arguments and op results that we would like to keep,
+///        based on their liveness.
+///   2.2. Find all operands that are forwarded to only dead region successor
+///        inputs. I.e., forwarded to block arguments / op results that we do
+///        not want to keep.
+///   2.3. Enqueue all such operands for replacement with ub.poison.
 ///
-/// Value Flow Note: In this operation, values flow as follows:
-/// - From operands and terminator operands (successor operands)
-/// - To arguments and results (successor inputs).
+/// Note: In scenario 2, block arguments and op results are not removed.
+/// However, the IR is simplified such that canonicalization patterns can
+/// remove them later.
 static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
                                   RunLivenessAnalysis &la,
                                   DenseSet<Value> &nonLiveSet,
@@ -416,282 +388,76 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
   // case, a non-forwarded operand of `regionBranchOp` could be live/non-live.
   // It could never be live because of this op but its liveness could have been
   // attributed to something else.
-  // Do (1') and (2').
   if (isMemoryEffectFree(regionBranchOp.getOperation()) &&
       !hasLive(regionBranchOp->getResults(), nonLiveSet, la)) {
     cl.operations.push_back(regionBranchOp.getOperation());
     return;
   }
 
-  // Mark live results of `regionBranchOp` in `liveResults`.
-  auto markLiveResults = [&](BitVector &liveResults) {
-    liveResults = markLives(regionBranchOp->getResults(), nonLiveSet, la);
-  };
-
-  // Mark live arguments in the regions of `regionBranchOp` in `liveArgs`.
-  auto markLiveArgs = [&](DenseMap<Region *, BitVector> &liveArgs) {
-    for (Region &region : regionBranchOp->getRegions()) {
-      if (region.empty())
-        continue;
-      SmallVector<Value> arguments(region.front().getArguments());
-      BitVector regionLiveArgs = markLives(arguments, nonLiveSet, la);
-      liveArgs[&region] = regionLiveArgs;
-    }
-  };
-
-  // Return the successors of `region` if the latter is not null. Else return
-  // the successors of `regionBranchOp`.
-  auto getSuccessors = [&](RegionBranchPoint point) {
-    SmallVector<RegionSuccessor> successors;
-    regionBranchOp.getSuccessorRegions(point, successors);
-    return successors;
-  };
-
-  // Return the operands of `terminator` that are forwarded to `successor` if
-  // the former is not null. Else return the operands of `regionBranchOp`
-  // forwarded to `successor`.
-  auto getForwardedOpOperands = [&](RegionBranchPoint src,
-                                    const RegionSuccessor &successor) {
-    SmallVector<OpOperand *> opOperands = operandsToOpOperands(
-        regionBranchOp.getSuccessorOperands(src, successor));
-    return opOperands;
-  };
-
-  // Mark the non-forwarded operands of `regionBranchOp` in
-  // `nonForwardedOperands`.
-  auto markNonForwardedOperands = [&](BitVector &nonForwardedOperands) {
-    nonForwardedOperands.resize(regionBranchOp->getNumOperands(), true);
-    for (const RegionSuccessor &successor :
-         getSuccessors(RegionBranchPoint::parent())) {
-      for (OpOperand *opOperand :
-           getForwardedOpOperands(RegionBranchPoint::parent(), successor))
-        nonForwardedOperands.reset(opOperand->getOperandNumber());
-    }
-  };
-
-  // Mark the non-forwarded terminator operands of the various regions of
-  // `regionBranchOp` in `nonForwardedRets`.
-  auto markNonForwardedReturnValues =
-      [&](DenseMap<Operation *, BitVector> &nonForwardedRets) {
-        for (Region &region : regionBranchOp->getRegions()) {
-          if (region.empty())
-            continue;
-          // TODO: this isn't correct in face of multiple terminators.
-          auto terminator = cast<RegionBranchTerminatorOpInterface>(
-              region.front().getTerminator());
-          nonForwardedRets[terminator] =
-              BitVector(terminator->getNumOperands(), true);
-          for (const RegionSuccessor &successor : getSuccessors(terminator)) {
-            for (OpOperand *opOperand : getForwardedOpOperands(
-                     RegionBranchPoint(terminator), successor))
-              nonForwardedRets[terminator].reset(opOperand->getOperandNumber());
-          }
-        }
-      };
-
-  // Update `valuesToKeep` (which is expected to correspond to operands or
-  // terminator operands) based on `resultsToKeep` and `argsToKeep`, given
-  // `region`. When `valuesToKeep` correspond to operands, `region` is null.
-  // Else, `region` is the parent region of the terminator.
-  auto updateOperandsOrTerminatorOperandsToKeep =
-      [&](BitVector &valuesToKeep, BitVector &resultsToKeep,
-          DenseMap<Region *, BitVector> &argsToKeep, Region *region = nullptr) {
-        Operation *terminator =
-            region ? region->front().getTerminator() : nullptr;
-        RegionBranchPoint point =
-            terminator
-                ? RegionBranchPoint(
-                      cast<RegionBranchTerminatorOpInterface>(terminator))
-                : RegionBranchPoint::parent();
-
-        for (const RegionSuccessor &successor : getSuccessors(point)) {
-          Region *successorRegion = successor.getSuccessor();
-          for (auto [opOperand, input] :
-               llvm::zip(getForwardedOpOperands(point, successor),
-                         successor.getSuccessorInputs())) {
-            size_t operandNum = opOperand->getOperandNumber();
-            bool updateBasedOn =
-                successorRegion
-                    ? argsToKeep[successorRegion]
-                                [cast<BlockArgument>(input).getArgNumber()]
-                    : resultsToKeep[cast<OpResult>(input).getResultNumber()];
-            valuesToKeep[operandNum] = valuesToKeep[operandNum] | updateBasedOn;
-          }
-        }
-      };
-
-  // Recompute `resultsToKeep` and `argsToKeep` based on `operandsToKeep` and
-  // `terminatorOperandsToKeep`. Store true in `resultsOrArgsToKeepChanged` if a
-  // value is modified, else, false.
-  auto recomputeResultsAndArgsToKeep =
-      [&](BitVector &resultsToKeep, DenseMap<Region *, BitVector> &argsToKeep,
-          BitVector &operandsToKeep,
-          DenseMap<Operation *, BitVector> &terminatorOperandsToKeep,
-          bool &resultsOrArgsToKeepChanged) {
-        resultsOrArgsToKeepChanged = false;
-
-        // Recompute `resultsToKeep` and `argsToKeep` based on `operandsToKeep`.
-        for (const RegionSuccessor &successor :
-             getSuccessors(RegionBranchPoint::parent())) {
-          Region *successorRegion = successor.getSuccessor();
-          for (auto [opOperand, input] :
-               llvm::zip(getForwardedOpOperands(RegionBranchPoint::parent(),
-                                                successor),
-                         successor.getSuccessorInputs())) {
-            bool recomputeBasedOn =
-                operandsToKeep[opOperand->getOperandNumber()];
-            bool toRecompute =
-                successorRegion
-                    ? argsToKeep[successorRegion]
-                                [cast<BlockArgument>(input).getArgNumber()]
-                    : resultsToKeep[cast<OpResult>(input).getResultNumber()];
-            if (!toRecompute && recomputeBasedOn)
-              resultsOrArgsToKeepChanged = true;
-            if (successorRegion) {
-              argsToKeep[successorRegion][cast<BlockArgument>(input)
-                                              .getArgNumber()] =
-                  argsToKeep[successorRegion]
-                            [cast<BlockArgument>(input).getArgNumber()] |
-                  recomputeBasedOn;
-            } else {
-              resultsToKeep[cast<OpResult>(input).getResultNumber()] =
-                  resultsToKeep[cast<OpResult>(input).getResultNumber()] |
-                  recomputeBasedOn;
-            }
-          }
-        }
-
-        // Recompute `resultsToKeep` and `argsToKeep` based on
-        // `terminatorOperandsToKeep`.
-        for (Region &region : regionBranchOp->getRegions()) {
-          if (region.empty())
-            continue;
-          auto terminator = cast<RegionBranchTerminatorOpInterface>(
-              region.front().getTerminator());
-          for (const RegionSuccessor &successor : getSuccessors(terminator)) {
-            Region *successorRegion = successor.getSuccessor();
-            for (auto [opOperand, input] :
-                 llvm::zip(getForwardedOpOperands(RegionBranchPoint(terminator),
-                                                  successor),
-                           successor.getSuccessorInputs())) {
-              bool recomputeBasedOn =
-                  terminatorOperandsToKeep[region.back().getTerminator()]
-                                          [opOperand->getOperandNumber()];
-              bool toRecompute =
-                  successorRegion
-                      ? argsToKeep[successorRegion]
-                                  [cast<BlockArgument>(input).getArgNumber()]
-                      : resultsToKeep[cast<OpResult>(input).getResultNumber()];
-              if (!toRecompute && recomputeBasedOn)
-                resultsOrArgsToKeepChanged = true;
-              if (successorRegion) {
-                argsToKeep[successorRegion][cast<BlockArgument>(input)
-                                                .getArgNumber()] =
-                    argsToKeep[successorRegion]
-                              [cast<BlockArgument>(input).getArgNumber()] |
-                    recomputeBasedOn;
-              } else {
-                resultsToKeep[cast<OpResult>(input).getResultNumber()] =
-                    resultsToKeep[cast<OpResult>(input).getResultNumber()] |
-                    recomputeBasedOn;
-              }
-            }
-          }
-        }
-      };
-
-  // Mark the values that we want to keep in `resultsToKeep`, `argsToKeep`,
-  // `operandsToKeep`, and `terminatorOperandsToKeep`.
-  auto markValuesToKeep =
-      [&](BitVector &resultsToKeep, DenseMap<Region *, BitVector> &argsToKeep,
-          BitVector &operandsToKeep,
-          DenseMap<Operation *, BitVector> &terminatorOperandsToKeep) {
-        bool resultsOrArgsToKeepChanged = true;
-        // We keep updating and recomputing the values until we reach a point
-        // where they stop changing.
-        while (resultsOrArgsToKeepChanged) {
-          // Update the operands that need to be kept.
-          updateOperandsOrTerminatorOperandsToKeep(operandsToKeep,
-                                                   resultsToKeep, argsToKeep);
-
-          // Update the terminator operands that need to be kept.
-          for (Region &region : regionBranchOp->getRegions()) {
-            if (region.empty())
-              continue;
-            updateOperandsOrTerminatorOperandsToKeep(
-                terminatorOperandsToKeep[region.back().getTerminator()],
-                resultsToKeep, argsToKeep, &region);
-          }
-
-          // Recompute the results and arguments that need to be kept.
-          recomputeResultsAndArgsToKeep(
-              resultsToKeep, argsToKeep, operandsToKeep,
-              terminatorOperandsToKeep, resultsOrArgsToKeepChanged);
-        }
-      };
-
-  // Scenario 2.
-  // At this point, we know that every non-forwarded operand of `regionBranchOp`
-  // is live.
-
-  // Stores the results of `regionBranchOp` that we want to keep.
-  BitVector resultsToKeep;
-  // Stores the mapping from regions of `regionBranchOp` to their arguments that
-  // we want to keep.
-  DenseMap<Region *, BitVector> argsToKeep;
-  // Stores the operands of `regionBranchOp` that we want to keep.
-  BitVector operandsToKeep;
-  // Stores the mapping from region terminators in `regionBranchOp` to their
-  // operands that we want to keep.
-  DenseMap<Operation *, BitVector> terminatorOperandsToKeep;
-
-  // Initializing the above variables...
-
-  // The live results of `regionBranchOp` definitely need to be kept.
-  markLiveResults(resultsToKeep);
-  // Similarly, the live arguments of the regions in `regionBranchOp` definitely
-  // need to be kept.
-  markLiveArgs(argsToKeep);
-  // The non-forwarded operands of `regionBranchOp` definitely need to be kept.
-  // A live forwarded operand can be removed but no non-forwarded operand can be
-  // removed since it "controls" the flow of data in this control flow op.
-  markNonForwardedOperands(operandsToKeep);
-  // Similarly, the non-forwarded terminator operands of the regions in
-  // `regionBranchOp` definitely need to be kept.
-  markNonForwardedReturnValues(terminatorOperandsToKeep);
-
-  // Mark the values (results, arguments, operands, and terminator operands)
-  // that we want to keep.
-  markValuesToKeep(resultsToKeep, argsToKeep, operandsToKeep,
-                   terminatorOperandsToKeep);
-
-  // Do (1).
-  cl.operands.push_back({regionBranchOp, operandsToKeep.flip()});
-
-  // Do (2.a) and (2.b).
+  // Compute values that are alive.
+  DenseSet<Value> valuesToKeep;
+  for (Value result : regionBranchOp->getResults()) {
+    if (hasLive(result, nonLiveSet, la))
+      valuesToKeep.insert(result);
+  }
   for (Region &region : regionBranchOp->getRegions()) {
     if (region.empty())
       continue;
-    BitVector argsToRemove = argsToKeep[&region].flip();
-    cl.blocks.push_back({&region.front(), argsToRemove});
-    collectNonLiveValues(nonLiveSet, region.front().getArguments(),
-                         argsToRemove);
+    for (Value arg : region.front().getArguments()) {
+      if (hasLive(arg, nonLiveSet, la))
+        valuesToKeep.insert(arg);
+    }
   }
 
-  // Do (2.c).
-  for (Region &region : regionBranchOp->getRegions()) {
-    if (region.empty())
+  // Mapping from operands to forwarded successor inputs. An operand can be
+  // forwarded to multiple successors.
+  //
+  // Example:
+  //
+  // %0 = scf.while : () -> i32 {
+  //   scf.condition(...) %forwarded_value : i32
+  // } do {
+  // ^bb0(%arg0: i32):
+  //   scf.yield
+  // }
+  // // No uses of %0.
+  //
+  // In the above example, %forwarded_value is forwarded to %arg0 and %0. Both
+  // %arg0 and %0 are dead, so %forwarded_value can be replaced with a
+  // ub.poison result.
+  //
+  // operandToSuccessorInputs[%forwarded_value] = {%arg0, %0}
+  //
+  RegionBranchSuccessorMapping operandToSuccessorInputs;
+  regionBranchOp.getSuccessorOperandInputMapping(operandToSuccessorInputs);
+
+  DenseMap<Operation *, BitVector> deadOperandsPerOp;
+  for (auto [opOperand, successorInputs] : operandToSuccessorInputs) {
+    // If one of the successor inputs is live, the respective operand must be
+    // kept.
+    bool anyAlive = llvm::any_of(successorInputs, [&](Value input) {
+      return valuesToKeep.contains(input);
+    });
+    if (anyAlive)
       continue;
-    Operation *terminator = region.front().getTerminator();
-    cl.operands.push_back(
-        {terminator, terminatorOperandsToKeep[terminator].flip()});
+
+    // All successor inputs are dead: ub.poison can be passed as operand.
+    // Create an entry in `deadOperandsPerOp` (initialized to "false", i.e.,
+    // no "dead" op operands) if it's the first time that we are seeing an op
+    // operand for this op. Otherwise, just take the existing bit vector from
+    // the map.
+    BitVector &deadOperands =
+        deadOperandsPerOp
+            .try_emplace(opOperand->getOwner(),
+                         opOperand->getOwner()->getNumOperands(), false)
+            .first->second;
+    deadOperands.set(opOperand->getOperandNumber());
   }
 
-  // Do (3) and (4).
-  BitVector resultsToRemove = resultsToKeep.flip();
-  collectNonLiveValues(nonLiveSet, regionBranchOp.getOperation()->getResults(),
-                       resultsToRemove);
-  cl.results.push_back({regionBranchOp.getOperation(), resultsToRemove});
+  for (auto [op, deadOperands] : deadOperandsPerOp) {
+    cl.operands.push_back(
+        {op, deadOperands, nullptr, /*replaceWithPoison=*/true});
+  }
 }
 
 /// Steps to process a `BranchOpInterface` operation:
@@ -751,11 +517,44 @@ static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la,
   }
 }
 
+/// Map every value in `values` to a replacement: a fresh ub.poison op of the
+/// same type, created at `b`'s insertion point, or a null Value if unused.
+static SmallVector<Value> createPoisonedValues(OpBuilder &b,
+                                               ValueRange values) {
+  return llvm::map_to_vector(values, [&](Value v) -> Value {
+    if (v.use_empty())
+      return Value();
+    return ub::PoisonOp::create(b, v.getLoc(), v.getType()).getResult();
+  });
+}
+
+namespace {
+/// Records ub.poison ops inserted via the rewriter; forgets erased ones.
+struct TrackingListener : public RewriterBase::Listener {
+  DenseSet<ub::PoisonOp> poisonOps; // Poison ops currently being tracked.
+  void notifyOperationInserted(Operation *op,
+                               OpBuilder::InsertPoint previous) override {
+    if (auto poison = dyn_cast<ub::PoisonOp>(op))
+      poisonOps.insert(poison);
+  }
+  void notifyOperationErased(Operation *op) override {
+    if (auto poison = dyn_cast<ub::PoisonOp>(op))
+      poisonOps.erase(poison);
+  }
+};
+} // namespace
+
 /// Removes dead values collected in RDVFinalCleanupList.
 /// To be run once when all dead values have been collected.
-static void cleanUpDeadVals(RDVFinalCleanupList &list) {
+static void cleanUpDeadVals(MLIRContext *ctx, RDVFinalCleanupList &list) {
   LDBG() << "Starting cleanup of dead values...";
 
+  // New ub.poison ops may be inserted during cleanup. Some of these ops may no
+  // longer be needed after the cleanup. A tracking listener keeps track of all
+  // new ub.poison ops, so that they can be removed again after the cleanup.
+  TrackingListener listener;
+  IRRewriter rewriter(ctx, &listener);
+
   // 1. Blocks, We must remove the block arguments and successor operands before
   // deleting the operation, as they may reside in the region operation.
   LDBG() << "Cleaning up " << list.blocks.size() << " block argument lists";
@@ -773,10 +572,12 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) {
     });
     // Note: Iterate from the end to make sure that that indices of not yet
     // processes arguments do not change.
+    rewriter.setInsertionPointToStart(b.b);
     for (int i = b.nonLiveArgs.size() - 1; i >= 0; --i) {
       if (!b.nonLiveArgs[i])
         continue;
-      b.b->getArgument(i).dropAllUses();
+      b.b->getArgument(i).replaceAllUsesWith(
+          createPoisonedValues(rewriter, b.b->getArgument(i)).front());
       b.b->eraseArgument(i);
     }
   }
@@ -805,22 +606,7 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) {
     }
   }
 
-  // 3. Operations
-  LDBG() << "Cleaning up " << list.operations.size() << " operations";
-  for (Operation *op : list.operations) {
-    LDBG() << "Erasing operation: "
-           << OpWithFlags(op,
-                          OpPrintingFlags().skipRegions().printGenericOpForm());
-    if (op->hasTrait<OpTrait::IsTerminator>()) {
-      // When erasing a terminator, insert an unreachable op in its place.
-      OpBuilder b(op);
-      ub::UnreachableOp::create(b, op->getLoc());
-    }
-    op->dropAllUses();
-    op->erase();
-  }
-
-  // 4. Functions
+  // 3. Functions
   LDBG() << "Cleaning up " << list.functions.size() << " functions";
   // Record which function arguments were erased so we can shrink call-site
   // argument segments for CallOpInterface operations (e.g. ops using
@@ -837,12 +623,18 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) {
       llvm::interleaveComma(f.nonLiveRets.set_bits(), os);
       os << "]";
     });
-    // Drop all uses of the dead arguments.
-    for (auto deadIdx : f.nonLiveArgs.set_bits())
-      f.funcOp.getArgument(deadIdx).dropAllUses();
     // Some functions may not allow erasing arguments or results. These calls
     // return failure in such cases without modifying the function, so it's okay
     // to proceed.
+    bool hasBody = !f.funcOp.getFunctionBody().empty();
+    if (hasBody) {
+      rewriter.setInsertionPointToStart(&f.funcOp.getFunctionBody().front());
+      for (auto deadIdx : f.nonLiveArgs.set_bits()) {
+        f.funcOp.getArgument(deadIdx).replaceAllUsesWith(
+            createPoisonedValues(rewriter, f.funcOp.getArgument(deadIdx))
+                .front());
+      }
+    }
     if (succeeded(f.funcOp.eraseArguments(f.nonLiveArgs))) {
       // Record only if we actually erased something.
       if (f.nonLiveArgs.any())
@@ -851,7 +643,7 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) {
     (void)f.funcOp.eraseResults(f.nonLiveRets);
   }
 
-  // 5. Operands
+  // 4. Operands
   LDBG() << "Cleaning up " << list.operands.size() << " operand lists";
   for (OperandsToCleanup &o : list.operands) {
     // Handle call-specific cleanup only when we have a cached callee reference.
@@ -896,11 +688,20 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) {
            << OpWithFlags(o.op,
                           OpPrintingFlags().skipRegions().printGenericOpForm());
       });
-      o.op->eraseOperands(o.nonLive);
+      if (o.replaceWithPoison) {
+        rewriter.setInsertionPoint(o.op);
+        for (auto deadIdx : o.nonLive.set_bits()) {
+          o.op->setOperand(
+              deadIdx, createPoisonedValues(rewriter, o.op->getOperand(deadIdx))
+                           .front());
+        }
+      } else {
+        o.op->eraseOperands(o.nonLive);
+      }
     }
   }
 
-  // 6. Results
+  // 5. Results
   LDBG() << "Cleaning up " << list.results.size() << " result lists";
   for (auto &r : list.results) {
     LDBG_OS([&](raw_ostream &os) {
@@ -910,8 +711,34 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) {
          << OpWithFlags(r.op,
                         OpPrintingFlags().skipRegions().printGenericOpForm());
     });
-    dropUsesAndEraseResults(r.op, r.nonLive);
+    rewriter.setInsertionPoint(r.op);
+    for (auto deadIdx : r.nonLive.set_bits()) {
+      r.op->getResult(deadIdx).replaceAllUsesWith(
+          createPoisonedValues(rewriter, r.op->getResult(deadIdx)).front());
+    }
+    rewriter.eraseOpResults(r.op, r.nonLive);
   }
+
+  // 6. Operations
+  LDBG() << "Cleaning up " << list.operations.size() << " operations";
+  for (Operation *op : list.operations) {
+    LDBG() << "Erasing operation: "
+           << OpWithFlags(op,
+                          OpPrintingFlags().skipRegions().printGenericOpForm());
+    rewriter.setInsertionPoint(op);
+    if (op->hasTrait<OpTrait::IsTerminator>()) {
+      // When erasing a terminator, insert an unreachable op in its place.
+      ub::UnreachableOp::create(rewriter, op->getLoc());
+    }
+    rewriter.replaceOp(op, createPoisonedValues(rewriter, op->getResults()));
+  }
+
+  // 7. Remove all dead poison ops.
+  for (ub::PoisonOp poisonOp : listener.poisonOps) {
+    if (poisonOp.use_empty())
+      poisonOp.erase();
+  }
+
   LDBG() << "Finished cleanup of dead values";
 }
 
@@ -950,7 +777,27 @@ void RemoveDeadValues::runOnOperation() {
     }
   });
 
-  cleanUpDeadVals(finalCleanupList);
+  MLIRContext *context = module->getContext();
+  cleanUpDeadVals(context, finalCleanupList);
+
+  if (!canonicalize)
+    return;
+
+  // Canonicalize all region branch ops.
+  SmallVector<Operation *> opsToCanonicalize;
+  module->walk([&](RegionBranchOpInterface regionBranchOp) {
+    opsToCanonicalize.push_back(regionBranchOp.getOperation());
+  });
+  RewritePatternSet owningPatterns(context);
+  for (auto *dialect : context->getLoadedDialects())
+    dialect->getCanonicalizationPatterns(owningPatterns);
+  for (RegisteredOperationName op : context->getRegisteredOperations())
+    op.getCanonicalizationPatterns(owningPatterns, context);
+  if (failed(applyOpPatternsGreedily(opsToCanonicalize,
+                                     std::move(owningPatterns)))) {
+    module->emitError("greedy pattern rewrite failed to converge");
+    signalPassFailure();
+  }
 }
 
 std::unique_ptr<Pass> mlir::createRemoveDeadValuesPass() {
diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir
index 71306676d48e9..b9a883dbd524e 100644
--- a/mlir/test/Transforms/remove-dead-values.mlir
+++ b/mlir/test/Transforms/remove-dead-values.mlir
@@ -1,4 +1,5 @@
-// RUN: mlir-opt %s -remove-dead-values -split-input-file -verify-diagnostics | FileCheck %s
+// RUN: mlir-opt %s -remove-dead-values="canonicalize=0" -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -remove-dead-values="canonicalize=1" -split-input-file | FileCheck %s --check-prefix=CHECK-CANONICALIZE
 
 // The IR is updated regardless of memref.global private constant
 //
@@ -55,19 +56,20 @@ func.func @acceptable_ir_has_cleanable_loop_of_conditional_and_branch_op(%arg0:
 
 // Checking that iter_args are properly handled
 //
+// CHECK-CANONICALIZE-LABEL: func @cleanable_loop_iter_args_value
 func.func @cleanable_loop_iter_args_value(%arg0: index) -> index {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
   %c10 = arith.constant 10 : index
   %non_live = arith.constant 0 : index
-  // CHECK: [[RESULT:%.+]] = scf.for [[ARG_1:%.*]] = %c0 to %c10 step %c1 iter_args([[ARG_2:%.*]] = %arg0) -> (index) {
+  // CHECK-CANONICALIZE: [[RESULT:%.+]] = scf.for [[ARG_1:%.*]] = %c0 to %c10 step %c1 iter_args([[ARG_2:%.*]] = %arg0) -> (index) {
   %result, %result_non_live = scf.for %i = %c0 to %c10 step %c1 iter_args(%live_arg = %arg0, %non_live_arg = %non_live) -> (index, index) {
-    // CHECK: [[SUM:%.+]] = arith.addi [[ARG_2]], [[ARG_1]] : index
+    // CHECK-CANONICALIZE: [[SUM:%.+]] = arith.addi [[ARG_2]], [[ARG_1]] : index
     %new_live = arith.addi %live_arg, %i : index
-    // CHECK: scf.yield [[SUM:%.+]]
+    // CHECK-CANONICALIZE: scf.yield [[SUM]]
     scf.yield %new_live, %non_live_arg : index, index
   }
-  // CHECK: return [[RESULT]] : index
+  // CHECK-CANONICALIZE: return [[RESULT]] : index
   return %result : index
 }
 
@@ -79,7 +81,8 @@ func.func @cleanable_loop_iter_args_value(%arg0: index) -> index {
 #map = affine_map<(d0, d1, d2) -> (0, d1, d2)>
 #map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 module {
-  func.func @main() {
+  // CHECK-LABEL: @dead_linalg_generic
+  func.func @dead_linalg_generic() {
     %cst_3 = arith.constant dense<54> : tensor<1x25x13xi32>
     %cst_7 = arith.constant dense<11> : tensor<1x25x13xi32>
     // CHECK-NOT: arith.constant
@@ -229,18 +232,34 @@ func.func @main() -> (i32, i32) {
 //  anywhere else. Thus, %arg7 is also not kept in the `scf.yield` op.
 //
 // Note that this cleanup cannot be done by the `canonicalize` pass.
-//
-// CHECK:       func.func @clean_region_branch_op_dont_remove_first_2_results_but_remove_first_operand(%[[arg0:.*]]: i1, %[[arg1:.*]]: i32, %[[arg2:.*]]: i32) -> i32 {
-// CHECK-NEXT:    %[[live_and_non_live:.*]]:2 = scf.while (%[[arg4:.*]] = %[[arg2]]) : (i32) -> (i32, i32) {
-// CHECK-NEXT:      %[[live_0:.*]] = arith.addi %[[arg4]], %[[arg4]]
-// CHECK-NEXT:      scf.condition(%arg0) %[[live_0]], %[[arg4]] : i32, i32
+
+// CHECK-LABEL: func.func @clean_region_branch_op_dont_remove_first_2_results_but_remove_first_operand(
+// CHECK-SAME:      %[[arg0:.*]]: i1, %[[arg1:.*]]: i32, %[[arg2:.*]]: i32) -> i32 {
+// CHECK-NEXT:    %[[p0:.*]] = ub.poison : i32
+// CHECK-NEXT:    %[[while:.*]]:3 = scf.while (%{{.*}} = %[[p0]], %[[arg4:.*]] = %[[arg2]]) : (i32, i32) -> (i32, i32, i32) {
+// CHECK-NEXT:      %[[add1:.*]] = arith.addi %[[arg4]], %[[arg4]] : i32
+// CHECK-NEXT:      %[[p1:.*]] = ub.poison : i32
+// CHECK-NEXT:      scf.condition(%[[arg0]]) %[[add1]], %[[arg4]], %[[p1]] : i32, i32, i32
 // CHECK-NEXT:    } do {
-// CHECK-NEXT:    ^bb0(%[[arg5:.*]]: i32, %[[arg6:.*]]: i32):
-// CHECK-NEXT:      %[[live_1:.*]] = arith.addi %[[arg6]], %[[arg6]]
-// CHECK-NEXT:      scf.yield %[[live_1]] : i32
+// CHECK-NEXT:    ^bb0(%{{.*}}: i32, %[[arg6:.*]]: i32, %{{.*}}: i32):
+// CHECK-NEXT:      %[[add2:.*]] = arith.addi %[[arg6]], %[[arg6]] : i32
+// CHECK-NEXT:      %[[p2:.*]] = ub.poison : i32
+// CHECK-NEXT:      scf.yield %[[p2]], %[[add2]] : i32, i32
 // CHECK-NEXT:    }
-// CHECK-NEXT:    return %[[live_and_non_live]]#0
+// CHECK-NEXT:    return %[[while]]#0 : i32
 // CHECK-NEXT:  }
+
+// CHECK-CANONICALIZE:       func.func @clean_region_branch_op_dont_remove_first_2_results_but_remove_first_operand(%[[arg0:.*]]: i1, %[[arg1:.*]]: i32, %[[arg2:.*]]: i32) -> i32 {
+// CHECK-CANONICALIZE-NEXT:    %[[live_and_non_live:.*]]:2 = scf.while (%[[arg4:.*]] = %[[arg2]]) : (i32) -> (i32, i32) {
+// CHECK-CANONICALIZE-NEXT:      %[[live_0:.*]] = arith.addi %[[arg4]], %[[arg4]]
+// CHECK-CANONICALIZE-NEXT:      scf.condition(%arg0) %[[live_0]], %[[arg4]] : i32, i32
+// CHECK-CANONICALIZE-NEXT:    } do {
+// CHECK-CANONICALIZE-NEXT:    ^bb0(%[[arg5:.*]]: i32, %[[arg6:.*]]: i32):
+// CHECK-CANONICALIZE-NEXT:      %[[live_1:.*]] = arith.addi %[[arg6]], %[[arg6]]
+// CHECK-CANONICALIZE-NEXT:      scf.yield %[[live_1]] : i32
+// CHECK-CANONICALIZE-NEXT:    }
+// CHECK-CANONICALIZE-NEXT:    return %[[live_and_non_live]]#0
+// CHECK-CANONICALIZE-NEXT:  }
 func.func @clean_region_branch_op_dont_remove_first_2_results_but_remove_first_operand(%arg0: i1, %arg1: i32, %arg2: i32) -> (i32) {
   %live, %non_live, %non_live_0 = scf.while (%arg3 = %arg1, %arg4 = %arg2) : (i32, i32) -> (i32, i32, i32) {
     %live_0 = arith.addi %arg4, %arg4 : i32
@@ -284,21 +303,21 @@ func.func @clean_region_branch_op_dont_remove_first_2_results_but_remove_first_o
 //
 // Note that this cleanup cannot be done by the `canonicalize` pass.
 //
-// CHECK:       func.func @clean_region_branch_op_remove_last_2_results_last_2_arguments_and_last_operand(%[[arg2:.*]]: i1) -> i32 {
-// CHECK-NEXT:    %[[c0:.*]] = arith.constant 0
-// CHECK-NEXT:    %[[c1:.*]] = arith.constant 1
-// CHECK-NEXT:    %[[live_and_non_live:.*]]:2 = scf.while (%[[arg3:.*]] = %[[c0]], %[[arg4:.*]] = %[[c1]]) : (i32, i32) -> (i32, i32) {
-// CHECK-NEXT:      func.call @identity() : () -> ()
-// CHECK-NEXT:      scf.condition(%[[arg2]]) %[[arg4]], %[[arg3]] : i32, i32
-// CHECK-NEXT:    } do {
-// CHECK-NEXT:    ^bb0(%[[arg5:.*]]: i32, %[[arg6:.*]]: i32):
-// CHECK-NEXT:      scf.yield %[[arg5]], %[[arg6]] : i32, i32
-// CHECK-NEXT:    }
-// CHECK-NEXT:    return %[[live_and_non_live]]#0 : i32
-// CHECK-NEXT:  }
-// CHECK:       func.func private @identity() {
-// CHECK-NEXT:    return
-// CHECK-NEXT:  }
+// CHECK-CANONICALIZE:       func.func @clean_region_branch_op_remove_last_2_results_last_2_arguments_and_last_operand(%[[arg2:.*]]: i1) -> i32 {
+// CHECK-CANONICALIZE-NEXT:    %[[c0:.*]] = arith.constant 0
+// CHECK-CANONICALIZE-NEXT:    %[[c1:.*]] = arith.constant 1
+// CHECK-CANONICALIZE-NEXT:    %[[live_and_non_live:.*]]:2 = scf.while (%[[arg3:.*]] = %[[c0]], %[[arg4:.*]] = %[[c1]]) : (i32, i32) -> (i32, i32) {
+// CHECK-CANONICALIZE-NEXT:      func.call @identity() : () -> ()
+// CHECK-CANONICALIZE-NEXT:      scf.condition(%[[arg2]]) %[[arg3]], %[[arg4]] : i32, i32
+// CHECK-CANONICALIZE-NEXT:    } do {
+// CHECK-CANONICALIZE-NEXT:    ^bb0(%[[arg5:.*]]: i32, %[[arg6:.*]]: i32):
+// CHECK-CANONICALIZE-NEXT:      scf.yield %[[arg6]], %[[arg5]] : i32, i32
+// CHECK-CANONICALIZE-NEXT:    }
+// CHECK-CANONICALIZE-NEXT:    return %[[live_and_non_live]]#1 : i32
+// CHECK-CANONICALIZE-NEXT:  }
+// CHECK-CANONICALIZE:       func.func private @identity() {
+// CHECK-CANONICALIZE-NEXT:    return
+// CHECK-CANONICALIZE-NEXT:  }
 func.func @clean_region_branch_op_remove_last_2_results_last_2_arguments_and_last_operand(%arg2: i1) -> (i32) {
   %c0 = arith.constant 0 : i32
   %c1 = arith.constant 1 : i32
@@ -325,17 +344,17 @@ func.func private @identity(%arg1 : i32) -> (i32) {
 //
 // Note that this cleanup cannot be done by the `canonicalize` pass.
 //
-// CHECK:       func.func @clean_region_branch_op_remove_result(%[[arg0:.*]]: index, %[[arg1:.*]]: memref<i32>) {
-// CHECK-NEXT:    scf.index_switch %[[arg0]]
-// CHECK-NEXT:    case 1 {
-// CHECK-NEXT:      %[[c10:.*]] = arith.constant 10
-// CHECK-NEXT:      memref.store %[[c10]], %[[arg1]][]
-// CHECK-NEXT:      scf.yield
-// CHECK-NEXT:    }
-// CHECK-NEXT:    default {
-// CHECK-NEXT:    }
-// CHECK-NEXT:    return
-// CHECK-NEXT:  }
+// CHECK-CANONICALIZE:       func.func @clean_region_branch_op_remove_result(%[[arg0:.*]]: index, %[[arg1:.*]]: memref<i32>) {
+// CHECK-CANONICALIZE-NEXT:    scf.index_switch %[[arg0]]
+// CHECK-CANONICALIZE-NEXT:    case 1 {
+// CHECK-CANONICALIZE-NEXT:      %[[c10:.*]] = arith.constant 10
+// CHECK-CANONICALIZE-NEXT:      memref.store %[[c10]], %[[arg1]][]
+// CHECK-CANONICALIZE:           scf.yield
+// CHECK-CANONICALIZE-NEXT:    }
+// CHECK-CANONICALIZE-NEXT:    default {
+// CHECK-CANONICALIZE:         }
+// CHECK-CANONICALIZE-NEXT:    return
+// CHECK-CANONICALIZE-NEXT:  }
 func.func @clean_region_branch_op_remove_result(%arg0 : index, %arg1 : memref<i32>) {
   %non_live = scf.index_switch %arg0 -> i32
   case 1 {
@@ -539,10 +558,10 @@ module {
   }
 }
 
-// CHECK-LABEL: func @test_zero_operands
-// CHECK: memref.alloca_scope
-// CHECK: memref.store
-// CHECK-NOT: memref.alloca_scope.return
+// CHECK-CANONICALIZE-LABEL: func @test_zero_operands
+// CHECK-CANONICALIZE-NEXT:    %[[c0:.*]] = arith.constant 0
+// CHECK-CANONICALIZE-NEXT:    memref.store %[[c0]]
+// CHECK-CANONICALIZE-NOT:     memref.alloca_scope.return
 
 // -----
 
@@ -714,3 +733,49 @@ func.func private @remove_dead_branch_op(%c: i1, %arg0: i64, %arg1: i64) -> (i64
 ^bb2:
   return %arg1 : i64
 }
+
+// -----
+
+// CHECK-LABEL: func @scf_while_dead_iter_args()
+// CHECK:         %[[c5:.*]] = arith.constant 5 : i32
+// CHECK:         %[[while:.*]]:2 = scf.while (%[[arg0:.*]] = %[[c5]]) : (i32) -> (i32, i32) {
+// CHECK:           vector.print %[[arg0]]
+// CHECK:           %[[cmpi:.*]] = arith.cmpi
+// CHECK:           %[[p0:.*]] = ub.poison : i32
+// CHECK:           scf.condition(%[[cmpi]]) %[[arg0]], %[[p0]]
+// CHECK:         } do {
+// CHECK:         ^bb0(%[[arg1:.*]]: i32, %[[arg2:.*]]: i32):
+// CHECK:           %[[p1:.*]] = ub.poison : i32
+// CHECK:           scf.yield %[[p1]]
+// CHECK:         }
+// CHECK:         return %[[while]]#0
+
+// CHECK-CANONICALIZE-LABEL: func @scf_while_dead_iter_args()
+// CHECK-CANONICALIZE:         %[[c5:.*]] = arith.constant 5 : i32
+// CHECK-CANONICALIZE:         %[[while:.*]] = scf.while (%[[arg0:.*]] = %[[c5]]) : (i32) -> i32 {
+// CHECK-CANONICALIZE:           vector.print %[[arg0]]
+// CHECK-CANONICALIZE:           %[[cmpi:.*]] = arith.cmpi
+// CHECK-CANONICALIZE:           scf.condition(%[[cmpi]]) %[[arg0]]
+// CHECK-CANONICALIZE:         } do {
+// CHECK-CANONICALIZE:         ^bb0(%[[arg1:.*]]: i32):
+// CHECK-CANONICALIZE:           %[[p0:.*]] = ub.poison : i32
+// CHECK-CANONICALIZE:           scf.yield %[[p0]]
+// CHECK-CANONICALIZE:         }
+// CHECK-CANONICALIZE:         return %[[while]]
+func.func @scf_while_dead_iter_args() -> i32 {
+  %c5 = arith.constant 5 : i32
+  %result:2 = scf.while (%arg0 = %c5) : (i32) -> (i32, i32) {
+    vector.print %arg0 : i32
+    // Note: This condition is always "false": %arg0 starts at 5 and is only
+    // ever re-yielded unchanged. (The liveness analysis can figure that out.)
+    %cmp2 = arith.cmpi slt, %arg0, %c5 : i32
+    scf.condition(%cmp2) %arg0, %arg0 : i32, i32
+  } do {
+  ^bb0(%arg1: i32, %arg2: i32):
+    %x = scf.execute_region -> i32 {
+      scf.yield %arg2 : i32
+    }
+    scf.yield %x : i32
+  }
+  return %result#0 : i32
+}



More information about the Mlir-commits mailing list