[Mlir-commits] [mlir] [MLIR] Fix RemoveDeadValues handling of unreachable code (PR #153973)

Mehdi Amini llvmlistbot at llvm.org
Sun Aug 17 03:59:51 PDT 2025


https://github.com/joker-eph updated https://github.com/llvm/llvm-project/pull/153973

>From 0f0602e02ce153cfe36c6187f58790675659cf13 Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Sun, 17 Aug 2025 03:53:38 -0700
Subject: [PATCH] [MLIR] Fix RemoveDeadValues handling of unreachable code

The code was "conservatively" considering that the absence of liveness
information meant that we weren't sure if a value was dead. However
the dataflow framework will skip visiting these values when it already
knows that a block is dynamically unreachable.
This leads to a crash in the provided test case since the producer operation
in a reachable block is actually deleted (its liveness is correctly set to
"dead") but the user of the operation was "conservatively" preserved and left
with nullptr operands.

Fixes #153906
---
 .../Analysis/DataFlow/LivenessAnalysis.cpp    | 29 +++++++++++-
 mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp | 46 +++++++++++++++++--
 mlir/lib/Transforms/RemoveDeadValues.cpp      | 46 ++++++++++++++++---
 .../DataFlow/test-liveness-analysis.mlir      | 20 ++++++++
 mlir/test/Transforms/remove-dead-values.mlir  | 21 +++++++++
 .../DataFlow/TestLivenessAnalysis.cpp         |  1 -
 6 files changed, 150 insertions(+), 13 deletions(-)

diff --git a/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp b/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp
index 509f5202be08d..761098778a371 100644
--- a/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp
@@ -294,7 +294,34 @@ RunLivenessAnalysis::RunLivenessAnalysis(Operation *op) {
   solver.load<LivenessAnalysis>(symbolTable);
   LDBG() << "Initializing and running solver";
   (void)solver.initializeAndRun(op);
-  LDBG() << "RunLivenessAnalysis initialized for op: " << op->getName();
+  LDBG() << "RunLivenessAnalysis initialized for op: " << op->getName()
+         << " check on dead code now:";
+  // The framework doesn't visit operations in dead blocks, so we need to
+  // explicitly mark them as dead.
+  op->walk([&](Operation *op) {
+    if (op->getNumResults() == 0)
+      return;
+    for (auto result : llvm::enumerate(op->getResults())) {
+      if (getLiveness(result.value()))
+        continue;
+      LDBG() << "Result: " << result.index() << " of "
+             << OpWithFlags(op, OpPrintingFlags().skipRegions())
+             << " has no liveness info, mark dead";
+      solver.getOrCreateState<Liveness>(result.value());
+    }
+    for (auto &region : op->getRegions()) {
+      for (auto &block : region) {
+        for (auto blockArg : llvm::enumerate(block.getArguments())) {
+          if (getLiveness(blockArg.value()))
+            continue;
+          LDBG() << "Block argument: " << blockArg.index() << " of "
+                 << OpWithFlags(op, OpPrintingFlags().skipRegions())
+                 << " has no liveness info, mark dead";
+          solver.getOrCreateState<Liveness>(blockArg.value());
+        }
+      }
+    }
+  });
 }
 
 const Liveness *RunLivenessAnalysis::getLiveness(Value val) {
diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
index e625f626d12fd..13a3e1480c836 100644
--- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
@@ -19,12 +19,15 @@
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/DebugLog.h"
 #include <cassert>
 #include <optional>
 
 using namespace mlir;
 using namespace mlir::dataflow;
 
+#define DEBUG_TYPE "dataflow"
+
 //===----------------------------------------------------------------------===//
 // AbstractSparseLattice
 //===----------------------------------------------------------------------===//
@@ -64,22 +67,36 @@ AbstractSparseForwardDataFlowAnalysis::initialize(Operation *top) {
 
 LogicalResult
 AbstractSparseForwardDataFlowAnalysis::initializeRecursively(Operation *op) {
+  LDBG() << "Initializing recursively for operation: " << op->getName();
+
   // Initialize the analysis by visiting every owner of an SSA value (all
   // operations and blocks).
-  if (failed(visitOperation(op)))
+  if (failed(visitOperation(op))) {
+    LDBG() << "Failed to visit operation: " << op->getName();
     return failure();
+  }
 
   for (Region &region : op->getRegions()) {
+    LDBG() << "Processing region with " << region.getBlocks().size()
+           << " blocks";
     for (Block &block : region) {
+      LDBG() << "Processing block with " << block.getNumArguments()
+             << " arguments";
       getOrCreate<Executable>(getProgramPointBefore(&block))
           ->blockContentSubscribe(this);
       visitBlock(&block);
-      for (Operation &op : block)
-        if (failed(initializeRecursively(&op)))
+      for (Operation &op : block) {
+        LDBG() << "Recursively initializing nested operation: " << op.getName();
+        if (failed(initializeRecursively(&op))) {
+          LDBG() << "Failed to initialize nested operation: " << op.getName();
           return failure();
+        }
+      }
     }
   }
 
+  LDBG() << "Successfully completed recursive initialization for operation: "
+         << op->getName();
   return success();
 }
 
@@ -409,11 +426,20 @@ static MutableArrayRef<OpOperand> operandsToOpOperands(OperandRange &operands) {
 
 LogicalResult
 AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
+  LDBG() << "Visiting operation: " << op->getName() << " with "
+         << op->getNumOperands() << " operands and " << op->getNumResults()
+         << " results";
+
   // If we're in a dead block, bail out.
   if (op->getBlock() != nullptr &&
-      !getOrCreate<Executable>(getProgramPointBefore(op->getBlock()))->isLive())
+      !getOrCreate<Executable>(getProgramPointBefore(op->getBlock()))
+           ->isLive()) {
+    LDBG() << "Operation is in dead block, bailing out";
     return success();
+  }
 
+  LDBG() << "Creating lattice elements for " << op->getNumOperands()
+         << " operands and " << op->getNumResults() << " results";
   SmallVector<AbstractSparseLattice *> operandLattices =
       getLatticeElements(op->getOperands());
   SmallVector<const AbstractSparseLattice *> resultLattices =
@@ -422,11 +448,15 @@ AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
   // Block arguments of region branch operations flow back into the operands
   // of the parent op
   if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
+    LDBG() << "Processing RegionBranchOpInterface operation";
     visitRegionSuccessors(branch, operandLattices);
     return success();
   }
 
   if (auto branch = dyn_cast<BranchOpInterface>(op)) {
+    LDBG() << "Processing BranchOpInterface operation with "
+           << op->getNumSuccessors() << " successors";
+
     // Block arguments of successor blocks flow back into our operands.
 
     // We remember all operands not forwarded to any block in a BitVector.
@@ -463,6 +493,7 @@ AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
   // For function calls, connect the arguments of the entry blocks to the
   // operands of the call op that are forwarded to these arguments.
   if (auto call = dyn_cast<CallOpInterface>(op)) {
+    LDBG() << "Processing CallOpInterface operation";
     Operation *callableOp = call.resolveCallableInTable(&symbolTable);
     if (auto callable = dyn_cast_or_null<CallableOpInterface>(callableOp)) {
       // Not all operands of a call op forward to arguments. Such operands are
@@ -513,6 +544,7 @@ AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
   // of this op itself and the operands of the terminators of the regions of
   // this op.
   if (auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(op)) {
+    LDBG() << "Processing RegionBranchTerminatorOpInterface operation";
     if (auto branch = dyn_cast<RegionBranchOpInterface>(op->getParentOp())) {
       visitRegionSuccessorsFromTerminator(terminator, branch);
       return success();
@@ -520,12 +552,16 @@ AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
   }
 
   if (op->hasTrait<OpTrait::ReturnLike>()) {
+    LDBG() << "Processing ReturnLike operation";
     // Going backwards, the operands of the return are derived from the
     // results of all CallOps calling this CallableOp.
-    if (auto callable = dyn_cast<CallableOpInterface>(op->getParentOp()))
+    if (auto callable = dyn_cast<CallableOpInterface>(op->getParentOp())) {
+      LDBG() << "Callable parent found, visiting callable operation";
       return visitCallableOperation(op, callable, operandLattices);
+    }
   }
 
+  LDBG() << "Using default visitOperationImpl for operation: " << op->getName();
   return visitOperationImpl(op, operandLattices, resultLattices);
 }
 
diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index 4ccb83f3ad298..02dad69e49614 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -258,18 +258,17 @@ static SmallVector<OpOperand *> operandsToOpOperands(OperandRange operands) {
 static void processSimpleOp(Operation *op, RunLivenessAnalysis &la,
                             DenseSet<Value> &nonLiveSet,
                             RDVFinalCleanupList &cl) {
-  LDBG() << "Processing simple op: " << *op;
   if (!isMemoryEffectFree(op) || hasLive(op->getResults(), nonLiveSet, la)) {
-    LDBG()
-        << "Simple op is not memory effect free or has live results, skipping: "
-        << *op;
+    LDBG() << "Simple op is not memory effect free or has live results, "
+              "preserving it: "
+           << OpWithFlags(op, OpPrintingFlags().skipRegions());
     return;
   }
 
   LDBG()
       << "Simple op has all dead results and is memory effect free, scheduling "
          "for removal: "
-      << *op;
+      << OpWithFlags(op, OpPrintingFlags().skipRegions());
   cl.operations.push_back(op);
   collectNonLiveValues(nonLiveSet, op->getResults(),
                        BitVector(op->getNumResults(), true));
@@ -728,19 +727,31 @@ static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la,
 /// Removes dead values collected in RDVFinalCleanupList.
 /// To be run once when all dead values have been collected.
 static void cleanUpDeadVals(RDVFinalCleanupList &list) {
+  LDBG() << "Starting cleanup of dead values...";
+
   // 1. Operations
+  LDBG() << "Cleaning up " << list.operations.size() << " operations";
   for (auto &op : list.operations) {
+    LDBG() << "Erasing operation: "
+           << OpWithFlags(op, OpPrintingFlags().skipRegions());
     op->dropAllUses();
     op->erase();
   }
 
   // 2. Values
+  LDBG() << "Cleaning up " << list.values.size() << " values";
   for (auto &v : list.values) {
+    LDBG() << "Dropping all uses of value: " << v;
     v.dropAllUses();
   }
 
   // 3. Functions
+  LDBG() << "Cleaning up " << list.functions.size() << " functions";
   for (auto &f : list.functions) {
+    LDBG() << "Cleaning up function: " << f.funcOp.getOperation()->getName();
+    LDBG() << "  Erasing " << f.nonLiveArgs.count() << " non-live arguments";
+    LDBG() << "  Erasing " << f.nonLiveRets.count()
+           << " non-live return values";
     // Some functions may not allow erasing arguments or results. These calls
     // return failure in such cases without modifying the function, so it's okay
     // to proceed.
@@ -749,44 +760,67 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) {
   }
 
   // 4. Operands
+  LDBG() << "Cleaning up " << list.operands.size() << " operand lists";
   for (OperationToCleanup &o : list.operands) {
-    if (o.op->getNumOperands() > 0)
+    if (o.op->getNumOperands() > 0) {
+      LDBG() << "Erasing " << o.nonLive.count()
+             << " non-live operands from operation: "
+             << OpWithFlags(o.op, OpPrintingFlags().skipRegions());
       o.op->eraseOperands(o.nonLive);
+    }
   }
 
   // 5. Results
+  LDBG() << "Cleaning up " << list.results.size() << " result lists";
   for (auto &r : list.results) {
+    LDBG() << "Erasing " << r.nonLive.count()
+           << " non-live results from operation: "
+           << OpWithFlags(r.op, OpPrintingFlags().skipRegions());
     dropUsesAndEraseResults(r.op, r.nonLive);
   }
 
   // 6. Blocks
+  LDBG() << "Cleaning up " << list.blocks.size() << " block argument lists";
   for (auto &b : list.blocks) {
     // blocks that are accessed via multiple codepaths processed once
     if (b.b->getNumArguments() != b.nonLiveArgs.size())
       continue;
+    LDBG() << "Erasing " << b.nonLiveArgs.count()
+           << " non-live arguments from block: " << b.b;
     // it iterates backwards because erase invalidates all successor indexes
     for (int i = b.nonLiveArgs.size() - 1; i >= 0; --i) {
       if (!b.nonLiveArgs[i])
         continue;
+      LDBG() << "  Erasing block argument " << i << ": " << b.b->getArgument(i);
       b.b->getArgument(i).dropAllUses();
       b.b->eraseArgument(i);
     }
   }
 
   // 7. Successor Operands
+  LDBG() << "Cleaning up " << list.successorOperands.size()
+         << " successor operand lists";
   for (auto &op : list.successorOperands) {
     SuccessorOperands successorOperands =
         op.branch.getSuccessorOperands(op.successorIndex);
     // blocks that are accessed via multiple codepaths processed once
     if (successorOperands.size() != op.nonLiveOperands.size())
       continue;
+    LDBG() << "Erasing " << op.nonLiveOperands.count()
+           << " non-live successor operands from successor "
+           << op.successorIndex << " of branch: "
+           << OpWithFlags(op.branch, OpPrintingFlags().skipRegions());
     // it iterates backwards because erase invalidates all successor indexes
     for (int i = successorOperands.size() - 1; i >= 0; --i) {
       if (!op.nonLiveOperands[i])
         continue;
+      LDBG() << "  Erasing successor operand " << i << ": "
+             << successorOperands[i];
       successorOperands.erase(i);
     }
   }
+
+  LDBG() << "Finished cleanup of dead values";
 }
 
 struct RemoveDeadValues : public impl::RemoveDeadValuesBase<RemoveDeadValues> {
diff --git a/mlir/test/Analysis/DataFlow/test-liveness-analysis.mlir b/mlir/test/Analysis/DataFlow/test-liveness-analysis.mlir
index a89a0f4084e99..4fece5520c9d3 100644
--- a/mlir/test/Analysis/DataFlow/test-liveness-analysis.mlir
+++ b/mlir/test/Analysis/DataFlow/test-liveness-analysis.mlir
@@ -283,3 +283,23 @@ func.func @test_10_negative() -> (i32) {
   %0:2 = func.call @private_1() : () -> (i32, i32)
   return %0#0 : i32
 }
+
+// -----
+
+// Test that we correctly set a liveness value for operations
+// in dead blocks. These won't be visited by the dataflow framework,
+// so the analysis needs to explicitly manage them.
+// CHECK-LABEL: test_tag: dead_block_cmpi:
+// CHECK-NEXT: operand #0: not live
+// CHECK-NEXT: operand #1: not live
+// CHECK-NEXT: result #0: not live
+func.func @dead_block() {
+  %false = arith.constant false
+  %zero = arith.constant 0 : i64
+  cf.cond_br %false, ^bb1, ^bb4
+  ^bb1:
+    %3 = arith.cmpi eq, %zero, %zero  {tag = "dead_block_cmpi"} : i64
+    cf.br ^bb1
+  ^bb4:
+    return
+}
diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir
index 9ded6a30d9c95..0f8d757086e87 100644
--- a/mlir/test/Transforms/remove-dead-values.mlir
+++ b/mlir/test/Transforms/remove-dead-values.mlir
@@ -571,3 +571,24 @@ module @return_void_with_unused_argument {
   }
 }
 
+// -----
+
+// CHECK-LABEL: module @dynamically_unreachable
+module @dynamically_unreachable {
+  func.func @dynamically_unreachable() {
+    // This value is used by an operation in a dynamically unreachable block.
+    %zero = arith.constant 0 : i64
+
+    // Dataflow analysis knows from the constant condition that
+    // ^bb1 is unreachable
+    %false = arith.constant false
+    cf.cond_br %false, ^bb1, ^bb4
+  ^bb1:
+    // This unreachable operation should be removed.
+    // CHECK-NOT: arith.cmpi
+    %3 = arith.cmpi eq, %zero, %zero : i64
+    cf.br ^bb1
+  ^bb4:
+    return
+  }
+}
diff --git a/mlir/test/lib/Analysis/DataFlow/TestLivenessAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestLivenessAnalysis.cpp
index 43005e22584c2..8e2f03b644e49 100644
--- a/mlir/test/lib/Analysis/DataFlow/TestLivenessAnalysis.cpp
+++ b/mlir/test/lib/Analysis/DataFlow/TestLivenessAnalysis.cpp
@@ -33,7 +33,6 @@ struct TestLivenessAnalysisPass
 
   void runOnOperation() override {
     auto &livenessAnalysis = getAnalysis<RunLivenessAnalysis>();
-
     Operation *op = getOperation();
 
     raw_ostream &os = llvm::outs();



More information about the Mlir-commits mailing list