[Mlir-commits] [mlir] [MLIR] Add logging/tracing to DataFlow analysis and RemoveDeadValues (NFC) (PR #144695)

Mehdi Amini llvmlistbot at llvm.org
Sun Jun 22 02:39:29 PDT 2025


https://github.com/joker-eph updated https://github.com/llvm/llvm-project/pull/144695

>From b8c1d4aa67275289b3b6226e6f802664ac03dea5 Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Wed, 18 Jun 2025 06:04:42 -0700
Subject: [PATCH] [MLIR] Add logging/tracing to DataFlow analysis and
 RemoveDeadValues (NFC)

Debugging issues with this pass is quite difficult at the moment; this
should help.
---
 .../Analysis/DataFlow/DeadCodeAnalysis.cpp    | 81 +++++++++++++++++--
 .../Analysis/DataFlow/LivenessAnalysis.cpp    | 51 +++++++++++-
 mlir/lib/Transforms/RemoveDeadValues.cpp      | 56 +++++++++++--
 3 files changed, 176 insertions(+), 12 deletions(-)

diff --git a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
index e805e21d878bf..1abdfcbf3496f 100644
--- a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
@@ -22,9 +22,14 @@
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Support/LLVM.h"
 #include "llvm/Support/Casting.h"
+#include "llvm/Support/Debug.h"
 #include <cassert>
 #include <optional>
 
+#define DEBUG_TYPE "dead-code-analysis"
+#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
+#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
+
 using namespace mlir;
 using namespace mlir::dataflow;
 
@@ -122,6 +127,7 @@ DeadCodeAnalysis::DeadCodeAnalysis(DataFlowSolver &solver)
 }
 
 LogicalResult DeadCodeAnalysis::initialize(Operation *top) {
+  LDBG("Initializing DeadCodeAnalysis for top-level op: " << top->getName());
   // Mark the top-level blocks as executable.
   for (Region &region : top->getRegions()) {
     if (region.empty())
@@ -129,6 +135,7 @@ LogicalResult DeadCodeAnalysis::initialize(Operation *top) {
     auto *state =
         getOrCreate<Executable>(getProgramPointBefore(&region.front()));
     propagateIfChanged(state, state->setToLive());
+    LDBG("Marked entry block live for region in op: " << top->getName());
   }
 
   // Mark as overdefined the predecessors of symbol callables with potentially
@@ -139,13 +146,18 @@ LogicalResult DeadCodeAnalysis::initialize(Operation *top) {
 }
 
 void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) {
+  LDBG("[init] Entering initializeSymbolCallables for top-level op: "
+       << top->getName());
   analysisScope = top;
   auto walkFn = [&](Operation *symTable, bool allUsesVisible) {
+    LDBG("[init] Processing symbol table op: " << symTable->getName());
     Region &symbolTableRegion = symTable->getRegion(0);
     Block *symbolTableBlock = &symbolTableRegion.front();
 
     bool foundSymbolCallable = false;
     for (auto callable : symbolTableBlock->getOps<CallableOpInterface>()) {
+      LDBG("[init] Found CallableOpInterface: "
+           << callable.getOperation()->getName());
       Region *callableRegion = callable.getCallableRegion();
       if (!callableRegion)
         continue;
@@ -159,6 +171,8 @@ void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) {
         auto *state =
             getOrCreate<PredecessorState>(getProgramPointAfter(callable));
         propagateIfChanged(state, state->setHasUnknownPredecessors());
+        LDBG("[init] Marked callable as having unknown predecessors: "
+             << callable.getOperation()->getName());
       }
       foundSymbolCallable = true;
     }
@@ -173,10 +187,15 @@ void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) {
     if (!uses) {
       // If we couldn't gather the symbol uses, conservatively assume that
       // we can't track information for any nested symbols.
+      LDBG("[init] Could not gather symbol uses, conservatively marking "
+           "all nested callables as having unknown predecessors");
       return top->walk([&](CallableOpInterface callable) {
         auto *state =
             getOrCreate<PredecessorState>(getProgramPointAfter(callable));
         propagateIfChanged(state, state->setHasUnknownPredecessors());
+        LDBG("[init] Marked nested callable as "
+             "having unknown predecessors: "
+             << callable.getOperation()->getName());
       });
     }
 
@@ -190,10 +209,15 @@ void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) {
         continue;
       auto *state = getOrCreate<PredecessorState>(getProgramPointAfter(symbol));
       propagateIfChanged(state, state->setHasUnknownPredecessors());
+      LDBG("[init] Found non-call use for symbol, "
+           "marked as having unknown predecessors: "
+           << symbol->getName());
     }
   };
   SymbolTable::walkSymbolTables(top, /*allSymUsesVisible=*/!top->getBlock(),
                                 walkFn);
+  LDBG("[init] Finished initializeSymbolCallables for top-level op: "
+       << top->getName());
 }
 
 /// Returns true if the operation is a returning terminator in region
@@ -205,9 +229,12 @@ static bool isRegionOrCallableReturn(Operation *op) {
 }
 
 LogicalResult DeadCodeAnalysis::initializeRecursively(Operation *op) {
+  LDBG("[init] Entering initializeRecursively for op: " << op->getName()
+                                                        << " at " << op);
   // Initialize the analysis by visiting every op with control-flow semantics.
   if (op->getNumRegions() || op->getNumSuccessors() ||
       isRegionOrCallableReturn(op) || isa<CallOpInterface>(op)) {
+    LDBG("[init] Visiting op with control-flow semantics: " << *op);
     // When the liveness of the parent block changes, make sure to re-invoke the
     // analysis on the op.
     if (op->getBlock())
@@ -218,14 +245,22 @@ LogicalResult DeadCodeAnalysis::initializeRecursively(Operation *op) {
       return failure();
   }
   // Recurse on nested operations.
-  for (Region &region : op->getRegions())
-    for (Operation &op : region.getOps())
-      if (failed(initializeRecursively(&op)))
+  for (Region &region : op->getRegions()) {
+    LDBG("[init] Recursing into region of op: " << op->getName());
+    for (Operation &nestedOp : region.getOps()) {
+      LDBG("[init] Recursing into nested op: " << nestedOp.getName() << " at "
+                                               << &nestedOp);
+      if (failed(initializeRecursively(&nestedOp)))
         return failure();
+    }
+  }
+  LDBG("[init] Finished initializeRecursively for op: " << op->getName()
+                                                        << " at " << op);
   return success();
 }
 
 void DeadCodeAnalysis::markEdgeLive(Block *from, Block *to) {
+  LDBG("Marking edge live from block " << from << " to block " << to);
   auto *state = getOrCreate<Executable>(getProgramPointBefore(to));
   propagateIfChanged(state, state->setToLive());
   auto *edgeState =
@@ -234,37 +269,48 @@ void DeadCodeAnalysis::markEdgeLive(Block *from, Block *to) {
 }
 
 void DeadCodeAnalysis::markEntryBlocksLive(Operation *op) {
+  LDBG("Marking entry blocks live for op: " << op->getName());
   for (Region &region : op->getRegions()) {
     if (region.empty())
       continue;
     auto *state =
         getOrCreate<Executable>(getProgramPointBefore(&region.front()));
     propagateIfChanged(state, state->setToLive());
+    LDBG("Marked entry block live for region in op: " << op->getName());
   }
 }
 
 LogicalResult DeadCodeAnalysis::visit(ProgramPoint *point) {
+  LDBG("Visiting program point: " << point << " " << *point);
   if (point->isBlockStart())
     return success();
   Operation *op = point->getPrevOp();
+  LDBG("Visiting operation: " << *op);
 
   // If the parent block is not executable, there is nothing to do.
   if (op->getBlock() != nullptr &&
-      !getOrCreate<Executable>(getProgramPointBefore(op->getBlock()))->isLive())
+      !getOrCreate<Executable>(getProgramPointBefore(op->getBlock()))
+           ->isLive()) {
+    LDBG("Parent block not live, skipping op: " << *op);
     return success();
+  }
 
   // We have a live call op. Add this as a live predecessor of the callee.
-  if (auto call = dyn_cast<CallOpInterface>(op))
+  if (auto call = dyn_cast<CallOpInterface>(op)) {
+    LDBG("Visiting call operation: " << *op);
     visitCallOperation(call);
+  }
 
   // Visit the regions.
   if (op->getNumRegions()) {
     // Check if we can reason about the region control-flow.
     if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
+      LDBG("Visiting region branch operation: " << *op);
       visitRegionBranchOperation(branch);
 
       // Check if this is a callable operation.
     } else if (auto callable = dyn_cast<CallableOpInterface>(op)) {
+      LDBG("Visiting callable operation: " << *op);
       const auto *callsites = getOrCreateFor<PredecessorState>(
           getProgramPointAfter(op), getProgramPointAfter(callable));
 
@@ -276,16 +322,19 @@ LogicalResult DeadCodeAnalysis::visit(ProgramPoint *point) {
 
       // Otherwise, conservatively mark all entry blocks as executable.
     } else {
+      LDBG("Marking all entry blocks live for op: " << *op);
       markEntryBlocksLive(op);
     }
   }
 
   if (isRegionOrCallableReturn(op)) {
     if (auto branch = dyn_cast<RegionBranchOpInterface>(op->getParentOp())) {
+      LDBG("Visiting region terminator: " << *op);
       // Visit the exiting terminator of a region.
       visitRegionTerminator(op, branch);
     } else if (auto callable =
                    dyn_cast<CallableOpInterface>(op->getParentOp())) {
+      LDBG("Visiting callable terminator: " << *op);
       // Visit the exiting terminator of a callable.
       visitCallableTerminator(op, callable);
     }
@@ -294,10 +343,12 @@ LogicalResult DeadCodeAnalysis::visit(ProgramPoint *point) {
   if (op->getNumSuccessors()) {
     // Check if we can reason about the control-flow.
     if (auto branch = dyn_cast<BranchOpInterface>(op)) {
+      LDBG("Visiting branch operation: " << *op);
       visitBranchOperation(branch);
 
       // Otherwise, conservatively mark all successors as exectuable.
     } else {
+      LDBG("Marking all successors live for op: " << *op);
       for (Block *successor : op->getSuccessors())
         markEdgeLive(op->getBlock(), successor);
     }
@@ -307,6 +358,7 @@ LogicalResult DeadCodeAnalysis::visit(ProgramPoint *point) {
 }
 
 void DeadCodeAnalysis::visitCallOperation(CallOpInterface call) {
+  LDBG("visitCallOperation: " << call.getOperation()->getName());
   Operation *callableOp = call.resolveCallableInTable(&symbolTable);
 
   // A call to a externally-defined callable has unknown predecessors.
@@ -329,11 +381,15 @@ void DeadCodeAnalysis::visitCallOperation(CallOpInterface call) {
     auto *callsites =
         getOrCreate<PredecessorState>(getProgramPointAfter(callableOp));
     propagateIfChanged(callsites, callsites->join(call));
+    LDBG("Added callsite as predecessor for callable: "
+         << callableOp->getName());
   } else {
     // Mark this call op's predecessors as overdefined.
     auto *predecessors =
         getOrCreate<PredecessorState>(getProgramPointAfter(call));
     propagateIfChanged(predecessors, predecessors->setHasUnknownPredecessors());
+    LDBG("Marked call op's predecessors as unknown for: "
+         << call.getOperation()->getName());
   }
 }
 
@@ -365,6 +421,7 @@ DeadCodeAnalysis::getOperandValues(Operation *op) {
 }
 
 void DeadCodeAnalysis::visitBranchOperation(BranchOpInterface branch) {
+  LDBG("visitBranchOperation: " << branch.getOperation()->getName());
   // Try to deduce a single successor for the branch.
   std::optional<SmallVector<Attribute>> operands = getOperandValues(branch);
   if (!operands)
@@ -372,15 +429,18 @@ void DeadCodeAnalysis::visitBranchOperation(BranchOpInterface branch) {
 
   if (Block *successor = branch.getSuccessorForOperands(*operands)) {
     markEdgeLive(branch->getBlock(), successor);
+    LDBG("Branch has single successor: " << successor);
   } else {
     // Otherwise, mark all successors as executable and outgoing edges.
     for (Block *successor : branch->getSuccessors())
       markEdgeLive(branch->getBlock(), successor);
+    LDBG("Branch has multiple/all successors live");
   }
 }
 
 void DeadCodeAnalysis::visitRegionBranchOperation(
     RegionBranchOpInterface branch) {
+  LDBG("visitRegionBranchOperation: " << branch.getOperation()->getName());
   // Try to deduce which regions are executable.
   std::optional<SmallVector<Attribute>> operands = getOperandValues(branch);
   if (!operands)
@@ -397,16 +457,19 @@ void DeadCodeAnalysis::visitRegionBranchOperation(
     // Mark the entry block as executable.
     auto *state = getOrCreate<Executable>(point);
     propagateIfChanged(state, state->setToLive());
+    LDBG("Marked region successor live: " << point);
     // Add the parent op as a predecessor.
     auto *predecessors = getOrCreate<PredecessorState>(point);
     propagateIfChanged(
         predecessors,
         predecessors->join(branch, successor.getSuccessorInputs()));
+    LDBG("Added region branch as predecessor for successor: " << point);
   }
 }
 
 void DeadCodeAnalysis::visitRegionTerminator(Operation *op,
                                              RegionBranchOpInterface branch) {
+  LDBG("visitRegionTerminator: " << *op);
   std::optional<SmallVector<Attribute>> operands = getOperandValues(op);
   if (!operands)
     return;
@@ -425,6 +488,7 @@ void DeadCodeAnalysis::visitRegionTerminator(Operation *op,
       auto *state =
           getOrCreate<Executable>(getProgramPointBefore(&region->front()));
       propagateIfChanged(state, state->setToLive());
+      LDBG("Marked region entry block live for region: " << region);
       predecessors = getOrCreate<PredecessorState>(
           getProgramPointBefore(&region->front()));
     } else {
@@ -434,11 +498,14 @@ void DeadCodeAnalysis::visitRegionTerminator(Operation *op,
     }
     propagateIfChanged(predecessors,
                        predecessors->join(op, successor.getSuccessorInputs()));
+    LDBG("Added region terminator as predecessor for successor: "
+         << (successor.getSuccessor() ? "region entry" : "parent op"));
   }
 }
 
 void DeadCodeAnalysis::visitCallableTerminator(Operation *op,
                                                CallableOpInterface callable) {
+  LDBG("visitCallableTerminator: " << *op);
   // Add as predecessors to all callsites this return op.
   auto *callsites = getOrCreateFor<PredecessorState>(
       getProgramPointAfter(op), getProgramPointAfter(callable));
@@ -449,11 +516,15 @@ void DeadCodeAnalysis::visitCallableTerminator(Operation *op,
         getOrCreate<PredecessorState>(getProgramPointAfter(predecessor));
     if (canResolve) {
       propagateIfChanged(predecessors, predecessors->join(op));
+      LDBG("Added callable terminator as predecessor for callsite: "
+           << predecessor->getName());
     } else {
       // If the terminator is not a return-like, then conservatively assume we
       // can't resolve the predecessor.
       propagateIfChanged(predecessors,
                          predecessors->setHasUnknownPredecessors());
+      LDBG("Could not resolve callable terminator for callsite: "
+           << predecessor->getName());
     }
   }
 }
diff --git a/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp b/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp
index d61cdb143e7dd..c6c50820f25f8 100644
--- a/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp
@@ -10,6 +10,7 @@
 #include <cassert>
 #include <mlir/Analysis/DataFlow/LivenessAnalysis.h>
 
+#include "llvm/Support/Debug.h"
 #include <mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h>
 #include <mlir/Analysis/DataFlow/DeadCodeAnalysis.h>
 #include <mlir/Analysis/DataFlow/SparseAnalysis.h>
@@ -20,6 +21,10 @@
 #include <mlir/Interfaces/SideEffectInterfaces.h>
 #include <mlir/Support/LLVM.h>
 
+#define DEBUG_TYPE "liveness-analysis"
+#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
+#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
+
 using namespace mlir;
 using namespace mlir::dataflow;
 
@@ -77,28 +82,46 @@ ChangeResult Liveness::meet(const AbstractSparseLattice &other) {
 LogicalResult
 LivenessAnalysis::visitOperation(Operation *op, ArrayRef<Liveness *> operands,
                                  ArrayRef<const Liveness *> results) {
+  LLVM_DEBUG(DBGS() << "[visitOperation] Enter: ";
+             op->print(llvm::dbgs(), OpPrintingFlags().skipRegions());
+             llvm::dbgs() << "\n");
   // This marks values of type (1.a) and (4) liveness as "live".
   if (!isMemoryEffectFree(op) || op->hasTrait<OpTrait::ReturnLike>()) {
-    for (auto *operand : operands)
+    LDBG("[visitOperation] Operation has memory effects or is "
+         "return-like, marking operands live");
+    for (auto *operand : operands) {
+      LDBG(" [visitOperation] Marking operand live: "
+           << operand << " (" << operand->isLive << ")");
       propagateIfChanged(operand, operand->markLive());
+    }
   }
 
   // This marks values of type (3) liveness as "live".
   bool foundLiveResult = false;
   for (const Liveness *r : results) {
     if (r->isLive && !foundLiveResult) {
+      LDBG("[visitOperation] Found live result, "
+           "meeting all operands with result: "
+           << r);
       // It is assumed that each operand is used to compute each result of an
       // op. Thus, if at least one result is live, each operand is live.
-      for (Liveness *operand : operands)
+      for (Liveness *operand : operands) {
+        LDBG(" [visitOperation] Meeting operand: " << operand
+                                                   << " with result: " << r);
         meet(operand, *r);
+      }
       foundLiveResult = true;
     }
+    LDBG("[visitOperation] Adding dependency for result: " << r << " after op: "
+                                                           << *op);
     addDependency(const_cast<Liveness *>(r), getProgramPointAfter(op));
   }
   return success();
 }
 
 void LivenessAnalysis::visitBranchOperand(OpOperand &operand) {
+  LDBG("Visiting branch operand: " << operand.get()
+                                   << " in op: " << *operand.getOwner());
   // We know (at the moment) and assume (for the future) that `operand` is a
   // non-forwarded branch operand of a `RegionBranchOpInterface`,
   // `BranchOpInterface`, `RegionBranchTerminatorOpInterface` or return-like op.
@@ -130,6 +153,9 @@ void LivenessAnalysis::visitBranchOperand(OpOperand &operand) {
       for (Value result : op->getResults()) {
         if (getLatticeElement(result)->isLive) {
           mayLive = true;
+          LDBG("[visitBranchOperand] Non-forwarded branch "
+               "operand may be live due to live result: "
+               << result);
           break;
         }
       }
@@ -149,6 +175,8 @@ void LivenessAnalysis::visitBranchOperand(OpOperand &operand) {
     // Therefore, we conservatively consider the non-forwarded operand of the
     // branch operation may live.
     mayLive = true;
+    LDBG("[visitBranchOperand] Non-forwarded branch operand may "
+         "be live due to branch op interface");
   } else {
     Operation *parentOp = op->getParentOp();
     assert(isa<RegionBranchOpInterface>(parentOp) &&
@@ -164,6 +192,9 @@ void LivenessAnalysis::visitBranchOperand(OpOperand &operand) {
       for (Value result : parentOp->getResults()) {
         if (getLatticeElement(result)->isLive) {
           mayLive = true;
+          LDBG("[visitBranchOperand] Non-forwarded branch "
+               "operand may be live due to parent live result: "
+               << result);
           break;
         }
       }
@@ -184,6 +215,9 @@ void LivenessAnalysis::visitBranchOperand(OpOperand &operand) {
     for (Operation &nestedOp : *block) {
       if (!isMemoryEffectFree(&nestedOp)) {
         mayLive = true;
+        LDBG("Non-forwarded branch operand may be "
+             "live due to memory effect in block: "
+             << block);
         break;
       }
     }
@@ -191,6 +225,7 @@ void LivenessAnalysis::visitBranchOperand(OpOperand &operand) {
 
   if (mayLive) {
     Liveness *operandLiveness = getLatticeElement(operand.get());
+    LDBG("Marking branch operand live: " << operand.get());
     propagateIfChanged(operandLiveness, operandLiveness->markLive());
   }
 
@@ -202,6 +237,7 @@ void LivenessAnalysis::visitBranchOperand(OpOperand &operand) {
   SmallVector<const Liveness *, 4> resultsLiveness;
   for (const Value result : op->getResults())
     resultsLiveness.push_back(getLatticeElement(result));
+  LDBG("Visiting operation for non-forwarded branch operand: " << *op);
   (void)visitOperation(op, operandLiveness, resultsLiveness);
 
   // We also visit the parent op with the parent's results and this operand if
@@ -214,10 +250,14 @@ void LivenessAnalysis::visitBranchOperand(OpOperand &operand) {
   SmallVector<const Liveness *, 4> parentResultsLiveness;
   for (const Value parentResult : parentOp->getResults())
     parentResultsLiveness.push_back(getLatticeElement(parentResult));
+  LDBG("Visiting parent operation for non-forwarded branch operand: "
+       << *parentOp);
   (void)visitOperation(parentOp, operandLiveness, parentResultsLiveness);
 }
 
 void LivenessAnalysis::visitCallOperand(OpOperand &operand) {
+  LDBG("Visiting call operand: " << operand.get()
+                                 << " in op: " << *operand.getOwner());
   // We know (at the moment) and assume (for the future) that `operand` is a
   // non-forwarded call operand of an op implementing `CallOpInterface`.
   assert(isa<CallOpInterface>(operand.getOwner()) &&
@@ -230,14 +270,18 @@ void LivenessAnalysis::visitCallOperand(OpOperand &operand) {
   // This marks values of type (1.c) liveness as "live". A non-forwarded
   // call operand is live.
   Liveness *operandLiveness = getLatticeElement(operand.get());
+  LDBG("Marking call operand live: " << operand.get());
   propagateIfChanged(operandLiveness, operandLiveness->markLive());
 }
 
 void LivenessAnalysis::setToExitState(Liveness *lattice) {
+  LDBG("setToExitState for lattice: " << lattice);
   if (lattice->isLive) {
+    LDBG("Lattice already live, nothing to do");
     return;
   }
   // This marks values of type (2) liveness as "live".
+  LDBG("Marking lattice live due to exit state");
   (void)lattice->markLive();
   propagateIfChanged(lattice, ChangeResult::Change);
 }
@@ -247,12 +291,15 @@ void LivenessAnalysis::setToExitState(Liveness *lattice) {
 //===----------------------------------------------------------------------===//
 
 RunLivenessAnalysis::RunLivenessAnalysis(Operation *op) {
+  LDBG("Constructing RunLivenessAnalysis for op: " << op->getName());
   SymbolTableCollection symbolTable;
 
   solver.load<DeadCodeAnalysis>();
   solver.load<SparseConstantPropagation>();
   solver.load<LivenessAnalysis>(symbolTable);
+  LDBG("Initializing and running solver");
   (void)solver.initializeAndRun(op);
+  LDBG("Dumping liveness state for op");
 }
 
 const Liveness *RunLivenessAnalysis::getLiveness(Value val) {
diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index 08dfea8eb2648..ad21ce8f18700 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -52,12 +52,17 @@
 #include "mlir/Transforms/FoldUtils.h"
 #include "mlir/Transforms/Passes.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/Debug.h"
 #include <cassert>
 #include <cstddef>
 #include <memory>
 #include <optional>
 #include <vector>
 
+#define DEBUG_TYPE "remove-dead-values"
+#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
+#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
+
 namespace mlir {
 #define GEN_PASS_DEF_REMOVEDEADVALUES
 #include "mlir/Transforms/Passes.h.inc"
@@ -115,12 +120,23 @@ struct RDVFinalCleanupList {
 static bool hasLive(ValueRange values, const DenseSet<Value> &nonLiveSet,
                     RunLivenessAnalysis &la) {
   for (Value value : values) {
-    if (nonLiveSet.contains(value))
+    if (nonLiveSet.contains(value)) {
+      LDBG("Value " << value << " is already marked non-live (dead)");
       continue;
+    }
 
     const Liveness *liveness = la.getLiveness(value);
-    if (!liveness || liveness->isLive)
+    if (!liveness) {
+      LDBG("Value " << value
+                    << " has no liveness info, conservatively considered live");
       return true;
+    }
+    if (liveness->isLive) {
+      LDBG("Value " << value << " is live according to liveness analysis");
+      return true;
+    } else {
+      LDBG("Value " << value << " is dead according to liveness analysis");
+    }
   }
   return false;
 }
@@ -134,6 +150,8 @@ static BitVector markLives(ValueRange values, const DenseSet<Value> &nonLiveSet,
   for (auto [index, value] : llvm::enumerate(values)) {
     if (nonLiveSet.contains(value)) {
       lives.reset(index);
+      LDBG("Value " << value << " is already marked non-live (dead) at index "
+                    << index);
       continue;
     }
 
@@ -144,8 +162,19 @@ static BitVector markLives(ValueRange values, const DenseSet<Value> &nonLiveSet,
     // of the results of an op and we know that these new values are live
     // (because they weren't erased) and also their liveness is null because
     // liveness analysis ran before their creation.
-    if (liveness && !liveness->isLive)
+    if (!liveness) {
+      LDBG("Value " << value << " at index " << index
+                    << " has no liveness info, conservatively considered live");
+      continue;
+    }
+    if (!liveness->isLive) {
       lives.reset(index);
+      LDBG("Value " << value << " at index " << index
+                    << " is dead according to liveness analysis");
+    } else {
+      LDBG("Value " << value << " at index " << index
+                    << " is live according to liveness analysis");
+    }
   }
 
   return lives;
@@ -160,6 +189,8 @@ static void collectNonLiveValues(DenseSet<Value> &nonLiveSet, ValueRange range,
     if (!nonLive[index])
       continue;
     nonLiveSet.insert(result);
+    LDBG("Marking value " << result << " as non-live (dead) at index "
+                          << index);
   }
 }
 
@@ -229,9 +260,16 @@ static SmallVector<OpOperand *> operandsToOpOperands(OperandRange operands) {
 static void processSimpleOp(Operation *op, RunLivenessAnalysis &la,
                             DenseSet<Value> &nonLiveSet,
                             RDVFinalCleanupList &cl) {
-  if (!isMemoryEffectFree(op) || hasLive(op->getResults(), nonLiveSet, la))
+  LDBG("Processing simple op: " << *op);
+  if (!isMemoryEffectFree(op) || hasLive(op->getResults(), nonLiveSet, la)) {
+    LDBG("Simple op is not memory effect free or has live results, skipping: "
+         << *op);
     return;
+  }
 
+  LDBG("Simple op has all dead results and is memory effect free, scheduling "
+       "for removal: "
+       << *op);
   cl.operations.push_back(op);
   collectNonLiveValues(nonLiveSet, op->getResults(),
                        BitVector(op->getNumResults(), true));
@@ -250,8 +288,12 @@ static void processSimpleOp(Operation *op, RunLivenessAnalysis &la,
 static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
                           RunLivenessAnalysis &la, DenseSet<Value> &nonLiveSet,
                           RDVFinalCleanupList &cl) {
-  if (funcOp.isPublic() || funcOp.isExternal())
+  LDBG("Processing function op: " << funcOp.getOperation()->getName());
+  if (funcOp.isPublic() || funcOp.isExternal()) {
+    LDBG("Function is public or external, skipping: "
+         << funcOp.getOperation()->getName());
     return;
+  }
 
   // Get the list of unnecessary (non-live) arguments in `nonLiveArgs`.
   SmallVector<Value> arguments(funcOp.getArguments());
@@ -369,6 +411,9 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
                                   RunLivenessAnalysis &la,
                                   DenseSet<Value> &nonLiveSet,
                                   RDVFinalCleanupList &cl) {
+  LLVM_DEBUG(DBGS() << "Processing region branch op: "; regionBranchOp->print(
+      llvm::dbgs(), OpPrintingFlags().skipRegions());
+             llvm::dbgs() << "\n");
   // Mark live results of `regionBranchOp` in `liveResults`.
   auto markLiveResults = [&](BitVector &liveResults) {
     liveResults = markLives(regionBranchOp->getResults(), nonLiveSet, la);
@@ -654,6 +699,7 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
 static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la,
                             DenseSet<Value> &nonLiveSet,
                             RDVFinalCleanupList &cl) {
+  LDBG("Processing branch op: " << *branchOp);
   unsigned numSuccessors = branchOp->getNumSuccessors();
 
   for (unsigned succIdx = 0; succIdx < numSuccessors; ++succIdx) {



More information about the Mlir-commits mailing list