[Mlir-commits] [mlir] 82c1f94 - [mlir][Transforms] `remove-dead-values`: Rely on canonicalizer for region simplification (#173505)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jan 7 05:51:44 PST 2026


Author: Matthias Springer
Date: 2026-01-07T14:51:40+01:00
New Revision: 82c1f9435d72e57e5bb949f8cdb8a0bd71ca86f3

URL: https://github.com/llvm/llvm-project/commit/82c1f9435d72e57e5bb949f8cdb8a0bd71ca86f3
DIFF: https://github.com/llvm/llvm-project/commit/82c1f9435d72e57e5bb949f8cdb8a0bd71ca86f3.diff

LOG: [mlir][Transforms] `remove-dead-values`: Rely on canonicalizer for region simplification (#173505)

This commit simplifies the `remove-dead-values` pass and fixes a bug in
the handling of `RegionBranchOpInterface` ops. The pass used to produce
invalid IR ("null value found") for the newly added test case.

`remove-dead-values` is a pass for additional IR simplification that
cannot be performed by the canonicalizer pass. Based on a liveness
analysis, it erases dead values / IR. (The liveness analysis is a
dataflow analysis that has more information about the IR than a
canonicalization pattern, which can see only "local" information.)

Region-based ops are difficult. The liveness analysis may determine that
an SSA value is dead. However, that does not mean that the value can
actually be removed. Doing so may violate a region data flow (as
modeled by the `RegionBranchOpInterface`). As an example, consider the
case where a region branch terminator may dispatch to one of two region
successors with the same forwarded values. A successor input (block
argument) can be erased only if it is dead on both successors.

Before this commit, there used to be complex logic to determine when it
is safe to erase an SSA value. That logic was broken. The new
implementation does not remove any block arguments or op results of
region-based ops. Instead, operands of region-based ops and region
branch terminators are replaced with `ub.poison` if all of their
successor values are dead. This simplifies the IR well enough for the
canonicalizer to perform the remaining region simplification (i.e.,
dropping block arguments etc.).

RFC:
https://discourse.llvm.org/t/rfc-delegate-simplification-of-region-based-ops-from-remove-dead-values-to-canonicalizer/89194

Added: 
    

Modified: 
    mlir/include/mlir/Transforms/Passes.h
    mlir/include/mlir/Transforms/Passes.td
    mlir/lib/Transforms/RemoveDeadValues.cpp
    mlir/test/Transforms/remove-dead-values.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index 724da009e70f1..9983944d374c5 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -42,6 +42,7 @@ class GreedyRewriteConfig;
 #define GEN_PASS_DECL_MEM2REG
 #define GEN_PASS_DECL_PRINTIRPASS
 #define GEN_PASS_DECL_PRINTOPSTATS
+#define GEN_PASS_DECL_REMOVEDEADVALUES
 #define GEN_PASS_DECL_SCCP
 #define GEN_PASS_DECL_SROA
 #define GEN_PASS_DECL_STRIPDEBUGINFO

diff  --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index 55addfdb693e4..fc2d60d198cd6 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -246,7 +246,17 @@ def RemoveDeadValues : Pass<"remove-dead-values"> {
     do = square_and_double_of_y(5)
     print(do)
     ```
+
+    Note: If `canonicalize` is set to "false", this pass does not remove any
+    block arguments / op results from ops that implement the
+    RegionBranchOpInterface. Instead, it just sets dead operands to
+    "ub.poison".
   }];
+
+  let options = [
+    Option<"canonicalize", "canonicalize", "bool", /*default=*/"true",
+           "Canonicalize region branch ops">,
+  ];
   let constructor = "mlir::createRemoveDeadValuesPass()";
   let dependentDialects = ["ub::UBDialect"];
 }

diff  --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index fc2c2acf8afd3..44b1bcf8e4300 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -94,8 +94,11 @@ struct ResultsToCleanup {
 struct OperandsToCleanup {
   Operation *op;
   BitVector nonLive;
-  Operation *callee =
-      nullptr; // Optional: For CallOpInterface ops, stores the callee function
+  // Optional: For CallOpInterface ops, stores the callee function.
+  Operation *callee = nullptr;
+  // Determines whether the operand should be replaced with a ub.poison result
+  // or erased entirely.
+  bool replaceWithPoison = false;
 };
 
 struct BlockArgsToCleanup {
@@ -201,25 +204,16 @@ static void collectNonLiveValues(DenseSet<Value> &nonLiveSet, ValueRange range,
 
 /// Drop the uses of the i-th result of `op` and then erase it iff toErase[i]
 /// is 1.
-static void dropUsesAndEraseResults(Operation *op, BitVector toErase) {
+static void dropUsesAndEraseResults(RewriterBase &rewriter, Operation *op,
+                                    BitVector toErase) {
   assert(op->getNumResults() == toErase.size() &&
          "expected the number of results in `op` and the size of `toErase` to "
          "be the same");
   for (auto idx : toErase.set_bits())
     op->getResult(idx).dropAllUses();
-  IRRewriter rewriter(op);
   rewriter.eraseOpResults(op, toErase);
 }
 
-/// Convert a list of `Operand`s to a list of `OpOperand`s.
-static SmallVector<OpOperand *> operandsToOpOperands(OperandRange operands) {
-  OpOperand *values = operands.getBase();
-  SmallVector<OpOperand *> opOperands;
-  for (unsigned i = 0, e = operands.size(); i < e; i++)
-    opOperands.push_back(&values[i]);
-  return opOperands;
-}
-
 /// Process a simple operation `op` using the liveness analysis `la`.
 /// If the operation has no memory effects and none of its results are live:
 ///   1. Add the operation to a list for future removal, and
@@ -379,30 +373,20 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
 ///
 /// Scenario 1: If the operation has no memory effects and none of its results
 /// are live:
-///   (1') Enqueue all its uses for deletion.
-///   (2') Enqueue the branch itself for deletion.
+///   1.1. Enqueue all its uses for deletion.
+///   1.2. Enqueue the branch itself for deletion.
 ///
 /// Scenario 2: Otherwise:
-///   (1) Collect its unnecessary operands (operands forwarded to unnecessary
-///       results or arguments).
-///   (2) Process each of its regions.
-///   (3) Collect the uses of its unnecessary results (results forwarded from
-///       unnecessary operands
-///       or terminator operands).
-///   (4) Add these results to the deletion list.
-///
-/// Processing a region includes:
-///   (a) Collecting the uses of its unnecessary arguments (arguments forwarded
-///       from unnecessary operands
-///       or terminator operands).
-///   (b) Collecting these unnecessary arguments.
-///   (c) Collecting its unnecessary terminator operands (terminator operands
-///       forwarded to unnecessary results
-///       or arguments).
+///   2.1. Find all operands that are forwarded to only dead region successor
+///        inputs. I.e., forwarded to block arguments / op results that we do
+///        not want to keep.
+///   2.2. Also find operands who's values are dead (i.e., are scheduled for
+///        erasure) due to other operations.
+///   2.3. Enqueue all such operands for replacement with ub.poison.
 ///
-/// Value Flow Note: In this operation, values flow as follows:
-/// - From operands and terminator operands (successor operands)
-/// - To arguments and results (successor inputs).
+/// Note: In scenario 2, block arguments and op results are not removed.
+/// However, the IR is simplified such that canonicalization patterns can
+/// remove them later.
 static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
                                   RunLivenessAnalysis &la,
                                   DenseSet<Value> &nonLiveSet,
@@ -416,282 +400,67 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
   // case, a non-forwarded operand of `regionBranchOp` could be live/non-live.
   // It could never be live because of this op but its liveness could have been
   // attributed to something else.
-  // Do (1') and (2').
   if (isMemoryEffectFree(regionBranchOp.getOperation()) &&
       !hasLive(regionBranchOp->getResults(), nonLiveSet, la)) {
     cl.operations.push_back(regionBranchOp.getOperation());
     return;
   }
 
-  // Mark live results of `regionBranchOp` in `liveResults`.
-  auto markLiveResults = [&](BitVector &liveResults) {
-    liveResults = markLives(regionBranchOp->getResults(), nonLiveSet, la);
-  };
-
-  // Mark live arguments in the regions of `regionBranchOp` in `liveArgs`.
-  auto markLiveArgs = [&](DenseMap<Region *, BitVector> &liveArgs) {
-    for (Region &region : regionBranchOp->getRegions()) {
-      if (region.empty())
-        continue;
-      SmallVector<Value> arguments(region.front().getArguments());
-      BitVector regionLiveArgs = markLives(arguments, nonLiveSet, la);
-      liveArgs[&region] = regionLiveArgs;
-    }
-  };
-
-  // Return the successors of `region` if the latter is not null. Else return
-  // the successors of `regionBranchOp`.
-  auto getSuccessors = [&](RegionBranchPoint point) {
-    SmallVector<RegionSuccessor> successors;
-    regionBranchOp.getSuccessorRegions(point, successors);
-    return successors;
-  };
-
-  // Return the operands of `terminator` that are forwarded to `successor` if
-  // the former is not null. Else return the operands of `regionBranchOp`
-  // forwarded to `successor`.
-  auto getForwardedOpOperands = [&](RegionBranchPoint src,
-                                    const RegionSuccessor &successor) {
-    SmallVector<OpOperand *> opOperands = operandsToOpOperands(
-        regionBranchOp.getSuccessorOperands(src, successor));
-    return opOperands;
-  };
-
-  // Mark the non-forwarded operands of `regionBranchOp` in
-  // `nonForwardedOperands`.
-  auto markNonForwardedOperands = [&](BitVector &nonForwardedOperands) {
-    nonForwardedOperands.resize(regionBranchOp->getNumOperands(), true);
-    for (const RegionSuccessor &successor :
-         getSuccessors(RegionBranchPoint::parent())) {
-      for (OpOperand *opOperand :
-           getForwardedOpOperands(RegionBranchPoint::parent(), successor))
-        nonForwardedOperands.reset(opOperand->getOperandNumber());
+  // Mapping from operands to forwarded successor inputs. An operand can be
+  // forwarded to multiple successors.
+  //
+  // Example:
+  //
+  // %0 = scf.while : () -> i32 {
+  //   scf.condition(...) %forwarded_value : i32
+  // } do {
+  // ^bb0(%arg0: i32):
+  //   scf.yield
+  // }
+  // // No uses of %0.
+  //
+  // In the above example, %forwarded_value is forwarded to %arg0 and %0. Both
+  // %arg0 and %0 are dead, so %forwarded_value can be replaced with a
+  // ub.poison result.
+  //
+  // operandToSuccessorInputs[%forwarded_value] = {%arg0, %0}
+  //
+  RegionBranchSuccessorMapping operandToSuccessorInputs;
+  regionBranchOp.getSuccessorOperandInputMapping(operandToSuccessorInputs);
+
+  DenseMap<Operation *, BitVector> deadOperandsPerOp;
+  for (auto [opOperand, successorInputs] : operandToSuccessorInputs) {
+    // Helper function to mark the operand as dead, to be replaced with a
+    // ub.poison result.
+    auto markOperandDead = [&opOperand = opOperand, &deadOperandsPerOp]() {
+      // Create an entry in `deadOperandsPerOp` (initialized to "false", i.e.,
+      // no "dead" op operands) if it's the first time that we are seeing an op
+      // operand for this op. Otherwise, just take the existing bit vector from
+      // the map.
+      BitVector &deadOperands =
+          deadOperandsPerOp
+              .try_emplace(opOperand->getOwner(),
+                           opOperand->getOwner()->getNumOperands(), false)
+              .first->second;
+      deadOperands.set(opOperand->getOperandNumber());
+    };
+
+    // The operand value is scheduled for removal. Mark it as dead.
+    if (!hasLive(opOperand->get(), nonLiveSet, la)) {
+      markOperandDead();
+      continue;
     }
-  };
-
-  // Mark the non-forwarded terminator operands of the various regions of
-  // `regionBranchOp` in `nonForwardedRets`.
-  auto markNonForwardedReturnValues =
-      [&](DenseMap<Operation *, BitVector> &nonForwardedRets) {
-        for (Region &region : regionBranchOp->getRegions()) {
-          if (region.empty())
-            continue;
-          // TODO: this isn't correct in face of multiple terminators.
-          auto terminator = cast<RegionBranchTerminatorOpInterface>(
-              region.front().getTerminator());
-          nonForwardedRets[terminator] =
-              BitVector(terminator->getNumOperands(), true);
-          for (const RegionSuccessor &successor : getSuccessors(terminator)) {
-            for (OpOperand *opOperand : getForwardedOpOperands(
-                     RegionBranchPoint(terminator), successor))
-              nonForwardedRets[terminator].reset(opOperand->getOperandNumber());
-          }
-        }
-      };
-
-  // Update `valuesToKeep` (which is expected to correspond to operands or
-  // terminator operands) based on `resultsToKeep` and `argsToKeep`, given
-  // `region`. When `valuesToKeep` correspond to operands, `region` is null.
-  // Else, `region` is the parent region of the terminator.
-  auto updateOperandsOrTerminatorOperandsToKeep =
-      [&](BitVector &valuesToKeep, BitVector &resultsToKeep,
-          DenseMap<Region *, BitVector> &argsToKeep, Region *region = nullptr) {
-        Operation *terminator =
-            region ? region->front().getTerminator() : nullptr;
-        RegionBranchPoint point =
-            terminator
-                ? RegionBranchPoint(
-                      cast<RegionBranchTerminatorOpInterface>(terminator))
-                : RegionBranchPoint::parent();
-
-        for (const RegionSuccessor &successor : getSuccessors(point)) {
-          Region *successorRegion = successor.getSuccessor();
-          for (auto [opOperand, input] :
-               llvm::zip(getForwardedOpOperands(point, successor),
-                         successor.getSuccessorInputs())) {
-            size_t operandNum = opOperand->getOperandNumber();
-            bool updateBasedOn =
-                successorRegion
-                    ? argsToKeep[successorRegion]
-                                [cast<BlockArgument>(input).getArgNumber()]
-                    : resultsToKeep[cast<OpResult>(input).getResultNumber()];
-            valuesToKeep[operandNum] = valuesToKeep[operandNum] | updateBasedOn;
-          }
-        }
-      };
-
-  // Recompute `resultsToKeep` and `argsToKeep` based on `operandsToKeep` and
-  // `terminatorOperandsToKeep`. Store true in `resultsOrArgsToKeepChanged` if a
-  // value is modified, else, false.
-  auto recomputeResultsAndArgsToKeep =
-      [&](BitVector &resultsToKeep, DenseMap<Region *, BitVector> &argsToKeep,
-          BitVector &operandsToKeep,
-          DenseMap<Operation *, BitVector> &terminatorOperandsToKeep,
-          bool &resultsOrArgsToKeepChanged) {
-        resultsOrArgsToKeepChanged = false;
-
-        // Recompute `resultsToKeep` and `argsToKeep` based on `operandsToKeep`.
-        for (const RegionSuccessor &successor :
-             getSuccessors(RegionBranchPoint::parent())) {
-          Region *successorRegion = successor.getSuccessor();
-          for (auto [opOperand, input] :
-               llvm::zip(getForwardedOpOperands(RegionBranchPoint::parent(),
-                                                successor),
-                         successor.getSuccessorInputs())) {
-            bool recomputeBasedOn =
-                operandsToKeep[opOperand->getOperandNumber()];
-            bool toRecompute =
-                successorRegion
-                    ? argsToKeep[successorRegion]
-                                [cast<BlockArgument>(input).getArgNumber()]
-                    : resultsToKeep[cast<OpResult>(input).getResultNumber()];
-            if (!toRecompute && recomputeBasedOn)
-              resultsOrArgsToKeepChanged = true;
-            if (successorRegion) {
-              argsToKeep[successorRegion][cast<BlockArgument>(input)
-                                              .getArgNumber()] =
-                  argsToKeep[successorRegion]
-                            [cast<BlockArgument>(input).getArgNumber()] |
-                  recomputeBasedOn;
-            } else {
-              resultsToKeep[cast<OpResult>(input).getResultNumber()] =
-                  resultsToKeep[cast<OpResult>(input).getResultNumber()] |
-                  recomputeBasedOn;
-            }
-          }
-        }
-
-        // Recompute `resultsToKeep` and `argsToKeep` based on
-        // `terminatorOperandsToKeep`.
-        for (Region &region : regionBranchOp->getRegions()) {
-          if (region.empty())
-            continue;
-          auto terminator = cast<RegionBranchTerminatorOpInterface>(
-              region.front().getTerminator());
-          for (const RegionSuccessor &successor : getSuccessors(terminator)) {
-            Region *successorRegion = successor.getSuccessor();
-            for (auto [opOperand, input] :
-                 llvm::zip(getForwardedOpOperands(RegionBranchPoint(terminator),
-                                                  successor),
-                           successor.getSuccessorInputs())) {
-              bool recomputeBasedOn =
-                  terminatorOperandsToKeep[region.back().getTerminator()]
-                                          [opOperand->getOperandNumber()];
-              bool toRecompute =
-                  successorRegion
-                      ? argsToKeep[successorRegion]
-                                  [cast<BlockArgument>(input).getArgNumber()]
-                      : resultsToKeep[cast<OpResult>(input).getResultNumber()];
-              if (!toRecompute && recomputeBasedOn)
-                resultsOrArgsToKeepChanged = true;
-              if (successorRegion) {
-                argsToKeep[successorRegion][cast<BlockArgument>(input)
-                                                .getArgNumber()] =
-                    argsToKeep[successorRegion]
-                              [cast<BlockArgument>(input).getArgNumber()] |
-                    recomputeBasedOn;
-              } else {
-                resultsToKeep[cast<OpResult>(input).getResultNumber()] =
-                    resultsToKeep[cast<OpResult>(input).getResultNumber()] |
-                    recomputeBasedOn;
-              }
-            }
-          }
-        }
-      };
-
-  // Mark the values that we want to keep in `resultsToKeep`, `argsToKeep`,
-  // `operandsToKeep`, and `terminatorOperandsToKeep`.
-  auto markValuesToKeep =
-      [&](BitVector &resultsToKeep, DenseMap<Region *, BitVector> &argsToKeep,
-          BitVector &operandsToKeep,
-          DenseMap<Operation *, BitVector> &terminatorOperandsToKeep) {
-        bool resultsOrArgsToKeepChanged = true;
-        // We keep updating and recomputing the values until we reach a point
-        // where they stop changing.
-        while (resultsOrArgsToKeepChanged) {
-          // Update the operands that need to be kept.
-          updateOperandsOrTerminatorOperandsToKeep(operandsToKeep,
-                                                   resultsToKeep, argsToKeep);
-
-          // Update the terminator operands that need to be kept.
-          for (Region &region : regionBranchOp->getRegions()) {
-            if (region.empty())
-              continue;
-            updateOperandsOrTerminatorOperandsToKeep(
-                terminatorOperandsToKeep[region.back().getTerminator()],
-                resultsToKeep, argsToKeep, &region);
-          }
 
-          // Recompute the results and arguments that need to be kept.
-          recomputeResultsAndArgsToKeep(
-              resultsToKeep, argsToKeep, operandsToKeep,
-              terminatorOperandsToKeep, resultsOrArgsToKeepChanged);
-        }
-      };
-
-  // Scenario 2.
-  // At this point, we know that every non-forwarded operand of `regionBranchOp`
-  // is live.
-
-  // Stores the results of `regionBranchOp` that we want to keep.
-  BitVector resultsToKeep;
-  // Stores the mapping from regions of `regionBranchOp` to their arguments that
-  // we want to keep.
-  DenseMap<Region *, BitVector> argsToKeep;
-  // Stores the operands of `regionBranchOp` that we want to keep.
-  BitVector operandsToKeep;
-  // Stores the mapping from region terminators in `regionBranchOp` to their
-  // operands that we want to keep.
-  DenseMap<Operation *, BitVector> terminatorOperandsToKeep;
-
-  // Initializing the above variables...
-
-  // The live results of `regionBranchOp` definitely need to be kept.
-  markLiveResults(resultsToKeep);
-  // Similarly, the live arguments of the regions in `regionBranchOp` definitely
-  // need to be kept.
-  markLiveArgs(argsToKeep);
-  // The non-forwarded operands of `regionBranchOp` definitely need to be kept.
-  // A live forwarded operand can be removed but no non-forwarded operand can be
-  // removed since it "controls" the flow of data in this control flow op.
-  markNonForwardedOperands(operandsToKeep);
-  // Similarly, the non-forwarded terminator operands of the regions in
-  // `regionBranchOp` definitely need to be kept.
-  markNonForwardedReturnValues(terminatorOperandsToKeep);
-
-  // Mark the values (results, arguments, operands, and terminator operands)
-  // that we want to keep.
-  markValuesToKeep(resultsToKeep, argsToKeep, operandsToKeep,
-                   terminatorOperandsToKeep);
-
-  // Do (1).
-  cl.operands.push_back({regionBranchOp, operandsToKeep.flip()});
-
-  // Do (2.a) and (2.b).
-  for (Region &region : regionBranchOp->getRegions()) {
-    if (region.empty())
-      continue;
-    BitVector argsToRemove = argsToKeep[&region].flip();
-    cl.blocks.push_back({&region.front(), argsToRemove});
-    collectNonLiveValues(nonLiveSet, region.front().getArguments(),
-                         argsToRemove);
+    // If one of the successor inputs is live, the respective operand must be
+    // kept. Otherwise, ub.poison can be passed as operand.
+    if (!hasLive(successorInputs, nonLiveSet, la))
+      markOperandDead();
   }
 
-  // Do (2.c).
-  for (Region &region : regionBranchOp->getRegions()) {
-    if (region.empty())
-      continue;
-    Operation *terminator = region.front().getTerminator();
+  for (auto [op, deadOperands] : deadOperandsPerOp) {
     cl.operands.push_back(
-        {terminator, terminatorOperandsToKeep[terminator].flip()});
+        {op, deadOperands, nullptr, /*replaceWithPoison=*/true});
   }
-
-  // Do (3) and (4).
-  BitVector resultsToRemove = resultsToKeep.flip();
-  collectNonLiveValues(nonLiveSet, regionBranchOp.getOperation()->getResults(),
-                       resultsToRemove);
-  cl.results.push_back({regionBranchOp.getOperation(), resultsToRemove});
 }
 
 /// Steps to process a `BranchOpInterface` operation:
@@ -751,11 +520,44 @@ static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la,
   }
 }
 
+/// Create ub.poison ops for the given values. If a value has no uses, return
+/// an "empty" value.
+static SmallVector<Value> createPoisonedValues(OpBuilder &b,
+                                               ValueRange values) {
+  return llvm::map_to_vector(values, [&](Value value) {
+    if (value.use_empty())
+      return Value();
+    return ub::PoisonOp::create(b, value.getLoc(), value.getType()).getResult();
+  });
+}
+
+namespace {
+/// A listener that keeps track of ub.poison ops.
+struct TrackingListener : public RewriterBase::Listener {
+  void notifyOperationErased(Operation *op) override {
+    if (auto poisonOp = dyn_cast<ub::PoisonOp>(op))
+      poisonOps.erase(poisonOp);
+  }
+  void notifyOperationInserted(Operation *op,
+                               OpBuilder::InsertPoint previous) override {
+    if (auto poisonOp = dyn_cast<ub::PoisonOp>(op))
+      poisonOps.insert(poisonOp);
+  }
+  DenseSet<ub::PoisonOp> poisonOps;
+};
+} // namespace
+
 /// Removes dead values collected in RDVFinalCleanupList.
 /// To be run once when all dead values have been collected.
-static void cleanUpDeadVals(RDVFinalCleanupList &list) {
+static void cleanUpDeadVals(MLIRContext *ctx, RDVFinalCleanupList &list) {
   LDBG() << "Starting cleanup of dead values...";
 
+  // New ub.poison ops may be inserted during cleanup. Some of these ops may no
+  // longer be needed after the cleanup. A tracking listener keeps track of all
+  // new ub.poison ops, so that they can be removed again after the cleanup.
+  TrackingListener listener;
+  IRRewriter rewriter(ctx, &listener);
+
   // 1. Blocks, We must remove the block arguments and successor operands before
   // deleting the operation, as they may reside in the region operation.
   LDBG() << "Cleaning up " << list.blocks.size() << " block argument lists";
@@ -881,7 +683,16 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) {
            << OpWithFlags(o.op,
                           OpPrintingFlags().skipRegions().printGenericOpForm());
       });
-      o.op->eraseOperands(o.nonLive);
+      if (o.replaceWithPoison) {
+        rewriter.setInsertionPoint(o.op);
+        for (auto deadIdx : o.nonLive.set_bits()) {
+          o.op->setOperand(
+              deadIdx, createPoisonedValues(rewriter, o.op->getOperand(deadIdx))
+                           .front());
+        }
+      } else {
+        o.op->eraseOperands(o.nonLive);
+      }
     }
   }
 
@@ -895,7 +706,7 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) {
          << OpWithFlags(r.op,
                         OpPrintingFlags().skipRegions().printGenericOpForm());
     });
-    dropUsesAndEraseResults(r.op, r.nonLive);
+    dropUsesAndEraseResults(rewriter, r.op, r.nonLive);
   }
 
   // 6. Operations
@@ -904,13 +715,19 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) {
     LDBG() << "Erasing operation: "
            << OpWithFlags(op,
                           OpPrintingFlags().skipRegions().printGenericOpForm());
+    rewriter.setInsertionPoint(op);
     if (op->hasTrait<OpTrait::IsTerminator>()) {
       // When erasing a terminator, insert an unreachable op in its place.
-      OpBuilder b(op);
-      ub::UnreachableOp::create(b, op->getLoc());
+      ub::UnreachableOp::create(rewriter, op->getLoc());
     }
     op->dropAllUses();
-    op->erase();
+    rewriter.eraseOp(op);
+  }
+
+  // 7. Remove all dead poison ops.
+  for (ub::PoisonOp poisonOp : listener.poisonOps) {
+    if (poisonOp.use_empty())
+      poisonOp.erase();
   }
 
   LDBG() << "Finished cleanup of dead values";
@@ -951,7 +768,29 @@ void RemoveDeadValues::runOnOperation() {
     }
   });
 
-  cleanUpDeadVals(finalCleanupList);
+  MLIRContext *context = module->getContext();
+  cleanUpDeadVals(context, finalCleanupList);
+
+  if (!canonicalize)
+    return;
+
+  // Canonicalize all region branch ops.
+  SmallVector<Operation *> opsToCanonicalize;
+  module->walk([&](RegionBranchOpInterface regionBranchOp) {
+    opsToCanonicalize.push_back(regionBranchOp.getOperation());
+  });
+  // TODO: Apply only region branch op canonicalization patterns or find a
+  // better API to collect all canonicalization patterns.
+  RewritePatternSet owningPatterns(context);
+  for (auto *dialect : context->getLoadedDialects())
+    dialect->getCanonicalizationPatterns(owningPatterns);
+  for (RegisteredOperationName op : context->getRegisteredOperations())
+    op.getCanonicalizationPatterns(owningPatterns, context);
+  if (failed(applyOpPatternsGreedily(opsToCanonicalize,
+                                     std::move(owningPatterns)))) {
+    module->emitError("greedy pattern rewrite failed to converge");
+    signalPassFailure();
+  }
 }
 
 std::unique_ptr<Pass> mlir::createRemoveDeadValuesPass() {

diff  --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir
index bd730915c6dcd..2584573c8b4dc 100644
--- a/mlir/test/Transforms/remove-dead-values.mlir
+++ b/mlir/test/Transforms/remove-dead-values.mlir
@@ -1,4 +1,5 @@
-// RUN: mlir-opt %s -remove-dead-values -split-input-file -verify-diagnostics | FileCheck %s
+// RUN: mlir-opt %s -remove-dead-values="canonicalize=0" -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -remove-dead-values="canonicalize=1" -split-input-file | FileCheck %s --check-prefix=CHECK-CANONICALIZE
 
 // The IR is updated regardless of memref.global private constant
 //
@@ -55,19 +56,20 @@ func.func @acceptable_ir_has_cleanable_loop_of_conditional_and_branch_op(%arg0:
 
 // Checking that iter_args are properly handled
 //
+// CHECK-CANONICALIZE-LABEL: func @cleanable_loop_iter_args_value
 func.func @cleanable_loop_iter_args_value(%arg0: index) -> index {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
   %c10 = arith.constant 10 : index
   %non_live = arith.constant 0 : index
-  // CHECK: [[RESULT:%.+]] = scf.for [[ARG_1:%.*]] = %c0 to %c10 step %c1 iter_args([[ARG_2:%.*]] = %arg0) -> (index) {
+  // CHECK-CANONICALIZE: [[RESULT:%.+]] = scf.for [[ARG_1:%.*]] = %c0 to %c10 step %c1 iter_args([[ARG_2:%.*]] = %arg0) -> (index) {
   %result, %result_non_live = scf.for %i = %c0 to %c10 step %c1 iter_args(%live_arg = %arg0, %non_live_arg = %non_live) -> (index, index) {
-    // CHECK: [[SUM:%.+]] = arith.addi [[ARG_2]], [[ARG_1]] : index
+    // CHECK-CANONICALIZE: [[SUM:%.+]] = arith.addi [[ARG_2]], [[ARG_1]] : index
     %new_live = arith.addi %live_arg, %i : index
-    // CHECK: scf.yield [[SUM:%.+]]
+    // CHECK-CANONICALIZE: scf.yield [[SUM:%.+]]
     scf.yield %new_live, %non_live_arg : index, index
   }
-  // CHECK: return [[RESULT]] : index
+  // CHECK-CANONICALIZE: return [[RESULT]] : index
   return %result : index
 }
 
@@ -79,7 +81,8 @@ func.func @cleanable_loop_iter_args_value(%arg0: index) -> index {
 #map = affine_map<(d0, d1, d2) -> (0, d1, d2)>
 #map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 module {
-  func.func @main() {
+  // CHECK-LABEL: @dead_linalg_generic
+  func.func @dead_linalg_generic() {
     %cst_3 = arith.constant dense<54> : tensor<1x25x13xi32>
     %cst_7 = arith.constant dense<11> : tensor<1x25x13xi32>
     // CHECK-NOT: arith.constant
@@ -229,18 +232,34 @@ func.func @main() -> (i32, i32) {
 //  anywhere else. Thus, %arg7 is also not kept in the `scf.yield` op.
 //
 // Note that this cleanup cannot be done by the `canonicalize` pass.
-//
-// CHECK:       func.func @clean_region_branch_op_dont_remove_first_2_results_but_remove_first_operand(%[[arg0:.*]]: i1, %[[arg1:.*]]: i32, %[[arg2:.*]]: i32) -> i32 {
-// CHECK-NEXT:    %[[live_and_non_live:.*]]:2 = scf.while (%[[arg4:.*]] = %[[arg2]]) : (i32) -> (i32, i32) {
-// CHECK-NEXT:      %[[live_0:.*]] = arith.addi %[[arg4]], %[[arg4]]
-// CHECK-NEXT:      scf.condition(%arg0) %[[live_0]], %[[arg4]] : i32, i32
+
+// CHECK-LABEL: func.func @clean_region_branch_op_dont_remove_first_2_results_but_remove_first_operand(
+// CHECK-SAME:      %[[arg0:.*]]: i1, %[[arg1:.*]]: i32, %[[arg2:.*]]: i32) -> i32 {
+// CHECK-NEXT:    %[[p0:.*]] = ub.poison : i32
+// CHECK-NEXT:    %[[while:.*]]:3 = scf.while (%{{.*}} = %[[p0]], %[[arg4:.*]] = %[[arg2]]) : (i32, i32) -> (i32, i32, i32) {
+// CHECK-NEXT:      %[[add1:.*]] = arith.addi %[[arg4]], %[[arg4]] : i32
+// CHECK-NEXT:      %[[p1:.*]] = ub.poison : i32
+// CHECK-NEXT:      scf.condition(%[[arg0]]) %[[add1]], %[[arg4]], %[[p1]] : i32, i32, i32
 // CHECK-NEXT:    } do {
-// CHECK-NEXT:    ^bb0(%[[arg5:.*]]: i32, %[[arg6:.*]]: i32):
-// CHECK-NEXT:      %[[live_1:.*]] = arith.addi %[[arg6]], %[[arg6]]
-// CHECK-NEXT:      scf.yield %[[live_1]] : i32
+// CHECK-NEXT:    ^bb0(%{{.*}}: i32, %[[arg6:.*]]: i32, %{{.*}}: i32):
+// CHECK-NEXT:      %[[add2:.*]] = arith.addi %[[arg6]], %[[arg6]] : i32
+// CHECK-NEXT:      %[[p2:.*]] = ub.poison : i32
+// CHECK-NEXT:      scf.yield %[[p2]], %[[add2]] : i32, i32
 // CHECK-NEXT:    }
-// CHECK-NEXT:    return %[[live_and_non_live]]#0
+// CHECK-NEXT:    return %[[while]]#0 : i32
 // CHECK-NEXT:  }
+
+// CHECK-CANONICALIZE:       func.func @clean_region_branch_op_dont_remove_first_2_results_but_remove_first_operand(%[[arg0:.*]]: i1, %[[arg1:.*]]: i32, %[[arg2:.*]]: i32) -> i32 {
+// CHECK-CANONICALIZE-NEXT:    %[[live_and_non_live:.*]]:2 = scf.while (%[[arg4:.*]] = %[[arg2]]) : (i32) -> (i32, i32) {
+// CHECK-CANONICALIZE-NEXT:      %[[live_0:.*]] = arith.addi %[[arg4]], %[[arg4]]
+// CHECK-CANONICALIZE-NEXT:      scf.condition(%arg0) %[[live_0]], %[[arg4]] : i32, i32
+// CHECK-CANONICALIZE-NEXT:    } do {
+// CHECK-CANONICALIZE-NEXT:    ^bb0(%[[arg5:.*]]: i32, %[[arg6:.*]]: i32):
+// CHECK-CANONICALIZE-NEXT:      %[[live_1:.*]] = arith.addi %[[arg6]], %[[arg6]]
+// CHECK-CANONICALIZE-NEXT:      scf.yield %[[live_1]] : i32
+// CHECK-CANONICALIZE-NEXT:    }
+// CHECK-CANONICALIZE-NEXT:    return %[[live_and_non_live]]#0
+// CHECK-CANONICALIZE-NEXT:  }
 func.func @clean_region_branch_op_dont_remove_first_2_results_but_remove_first_operand(%arg0: i1, %arg1: i32, %arg2: i32) -> (i32) {
   %live, %non_live, %non_live_0 = scf.while (%arg3 = %arg1, %arg4 = %arg2) : (i32, i32) -> (i32, i32, i32) {
     %live_0 = arith.addi %arg4, %arg4 : i32
@@ -284,21 +303,21 @@ func.func @clean_region_branch_op_dont_remove_first_2_results_but_remove_first_o
 //
 // Note that this cleanup cannot be done by the `canonicalize` pass.
 //
-// CHECK:       func.func @clean_region_branch_op_remove_last_2_results_last_2_arguments_and_last_operand(%[[arg2:.*]]: i1) -> i32 {
-// CHECK-NEXT:    %[[c0:.*]] = arith.constant 0
-// CHECK-NEXT:    %[[c1:.*]] = arith.constant 1
-// CHECK-NEXT:    %[[live_and_non_live:.*]]:2 = scf.while (%[[arg3:.*]] = %[[c0]], %[[arg4:.*]] = %[[c1]]) : (i32, i32) -> (i32, i32) {
-// CHECK-NEXT:      func.call @identity() : () -> ()
-// CHECK-NEXT:      scf.condition(%[[arg2]]) %[[arg4]], %[[arg3]] : i32, i32
-// CHECK-NEXT:    } do {
-// CHECK-NEXT:    ^bb0(%[[arg5:.*]]: i32, %[[arg6:.*]]: i32):
-// CHECK-NEXT:      scf.yield %[[arg5]], %[[arg6]] : i32, i32
-// CHECK-NEXT:    }
-// CHECK-NEXT:    return %[[live_and_non_live]]#0 : i32
-// CHECK-NEXT:  }
-// CHECK:       func.func private @identity() {
-// CHECK-NEXT:    return
-// CHECK-NEXT:  }
+// CHECK-CANONICALIZE:       func.func @clean_region_branch_op_remove_last_2_results_last_2_arguments_and_last_operand(%[[arg2:.*]]: i1) -> i32 {
+// CHECK-CANONICALIZE-NEXT:    %[[c0:.*]] = arith.constant 0
+// CHECK-CANONICALIZE-NEXT:    %[[c1:.*]] = arith.constant 1
+// CHECK-CANONICALIZE-NEXT:    %[[live_and_non_live:.*]]:2 = scf.while (%[[arg3:.*]] = %[[c0]], %[[arg4:.*]] = %[[c1]]) : (i32, i32) -> (i32, i32) {
+// CHECK-CANONICALIZE-NEXT:      func.call @identity() : () -> ()
+// CHECK-CANONICALIZE-NEXT:      scf.condition(%[[arg2]]) %[[arg3]], %[[arg4]] : i32, i32
+// CHECK-CANONICALIZE-NEXT:    } do {
+// CHECK-CANONICALIZE-NEXT:    ^bb0(%[[arg5:.*]]: i32, %[[arg6:.*]]: i32):
+// CHECK-CANONICALIZE-NEXT:      scf.yield %[[arg6]], %[[arg5]] : i32, i32
+// CHECK-CANONICALIZE-NEXT:    }
+// CHECK-CANONICALIZE-NEXT:    return %[[live_and_non_live]]#1 : i32
+// CHECK-CANONICALIZE-NEXT:  }
+// CHECK-CANONICALIZE:       func.func private @identity() {
+// CHECK-CANONICALIZE-NEXT:    return
+// CHECK-CANONICALIZE-NEXT:  }
 func.func @clean_region_branch_op_remove_last_2_results_last_2_arguments_and_last_operand(%arg2: i1) -> (i32) {
   %c0 = arith.constant 0 : i32
   %c1 = arith.constant 1 : i32
@@ -325,17 +344,17 @@ func.func private @identity(%arg1 : i32) -> (i32) {
 //
 // Note that this cleanup cannot be done by the `canonicalize` pass.
 //
-// CHECK:       func.func @clean_region_branch_op_remove_result(%[[arg0:.*]]: index, %[[arg1:.*]]: memref<i32>) {
-// CHECK-NEXT:    scf.index_switch %[[arg0]]
-// CHECK-NEXT:    case 1 {
-// CHECK-NEXT:      %[[c10:.*]] = arith.constant 10
-// CHECK-NEXT:      memref.store %[[c10]], %[[arg1]][]
-// CHECK-NEXT:      scf.yield
-// CHECK-NEXT:    }
-// CHECK-NEXT:    default {
-// CHECK-NEXT:    }
-// CHECK-NEXT:    return
-// CHECK-NEXT:  }
+// CHECK-CANONICALIZE:       func.func @clean_region_branch_op_remove_result(%[[arg0:.*]]: index, %[[arg1:.*]]: memref<i32>) {
+// CHECK-CANONICALIZE-NEXT:    scf.index_switch %[[arg0]]
+// CHECK-CANONICALIZE-NEXT:    case 1 {
+// CHECK-CANONICALIZE-NEXT:      %[[c10:.*]] = arith.constant 10
+// CHECK-CANONICALIZE-NEXT:      memref.store %[[c10]], %[[arg1]][]
+// CHECK-CANONICALIZE:           scf.yield
+// CHECK-CANONICALIZE-NEXT:    }
+// CHECK-CANONICALIZE-NEXT:    default {
+// CHECK-CANONICALIZE:         }
+// CHECK-CANONICALIZE-NEXT:    return
+// CHECK-CANONICALIZE-NEXT:  }
 func.func @clean_region_branch_op_remove_result(%arg0 : index, %arg1 : memref<i32>) {
   %non_live = scf.index_switch %arg0 -> i32
   case 1 {
@@ -539,10 +558,10 @@ module {
   }
 }
 
-// CHECK-LABEL: func @test_zero_operands
-// CHECK: memref.alloca_scope
-// CHECK: memref.store
-// CHECK-NOT: memref.alloca_scope.return
+// CHECK-CANONICALIZE-LABEL: func @test_zero_operands
+// CHECK-CANONICALIZE-NEXT:    %[[c0:.*]] = arith.constant 0
+// CHECK-CANONICALIZE-NEXT:    memref.store %[[c0]]
+// CHECK-CANONICALIZE-NOT:     memref.alloca_scope.return
 
 // -----
 
@@ -731,3 +750,49 @@ func.func @affine_loop_no_use_iv_has_side_effect_op() {
 // CHECK: }
   return
 }
+
+// -----
+
+// CHECK-LABEL: func @scf_while_dead_iter_args()
+// CHECK:         %[[c5:.*]] = arith.constant 5 : i32
+// CHECK:         %[[while:.*]]:2 = scf.while (%[[arg0:.*]] = %[[c5]]) : (i32) -> (i32, i32) {
+// CHECK:           vector.print %[[arg0]]
+// CHECK:           %[[cmpi:.*]] = arith.cmpi
+// CHECK:           %[[p0:.*]] = ub.poison : i32
+// CHECK:           scf.condition(%[[cmpi]]) %[[arg0]], %[[p0]]
+// CHECK:         } do {
+// CHECK:         ^bb0(%[[arg1:.*]]: i32, %[[arg2:.*]]: i32):
+// CHECK:           %[[p1:.*]] = ub.poison : i32
+// CHECK:           scf.yield %[[p1]]
+// CHECK:         }
+// CHECK:         return %[[while]]#0
+
+// CHECK-CANONICALIZE-LABEL: func @scf_while_dead_iter_args()
+// CHECK-CANONICALIZE:         %[[c5:.*]] = arith.constant 5 : i32
+// CHECK-CANONICALIZE:         %[[while:.*]] = scf.while (%[[arg0:.*]] = %[[c5]]) : (i32) -> i32 {
+// CHECK-CANONICALIZE:           vector.print %[[arg0]]
+// CHECK-CANONICALIZE:           %[[cmpi:.*]] = arith.cmpi
+// CHECK-CANONICALIZE:           scf.condition(%[[cmpi]]) %[[arg0]]
+// CHECK-CANONICALIZE:         } do {
+// CHECK-CANONICALIZE:         ^bb0(%[[arg1:.*]]: i32):
+// CHECK-CANONICALIZE:           %[[p0:.*]] = ub.poison : i32
+// CHECK-CANONICALIZE:           scf.yield %[[p0]]
+// CHECK-CANONICALIZE:         }
+// CHECK-CANONICALIZE:         return %[[while]]
+func.func @scf_while_dead_iter_args() -> i32 {
+  %c5 = arith.constant 5 : i32
+  %result:2 = scf.while (%arg0 = %c5) : (i32) -> (i32, i32) {
+    vector.print %arg0 : i32
+    // Note: %arg0 starts as %c5, so this "slt" condition is always "false" and
+    // the "do" region never runs. (The liveness analysis can figure that out.)
+    %cmp2 = arith.cmpi slt, %arg0, %c5 : i32
+    scf.condition(%cmp2) %arg0, %arg0 : i32, i32
+  } do {
+  ^bb0(%arg1: i32, %arg2: i32):
+    %x = scf.execute_region -> i32 {
+      scf.yield %arg2 : i32
+    }
+    scf.yield %x : i32
+  }
+  return %result#0 : i32
+}


        


More information about the Mlir-commits mailing list