[Mlir-commits] [mlir] [MLIR][RemoveDeadValues] Mark arguments of a public function Live (PR #162038)

xin liu llvmlistbot at llvm.org
Sun Oct 5 21:29:22 PDT 2025


https://github.com/navyxliu created https://github.com/llvm/llvm-project/pull/162038

This diff is a redo of [#160242](https://github.com/llvm/llvm-project/pull/160242).

# Problem
Liveness analysis is inter-procedural. If a public function has unused arguments, their non-liveness propagates to the callers. From the perspective of RemoveDeadValues, however, the signature of a public function is immutable, and the pass cannot cope with this situation: on one side it deletes the outgoing arguments at the call site, while on the other side it keeps the function signature intact. The call is left with a null operand (see the error in the test plan below).

# Solution
We use two methods to fix this bug.
1) `createDummyArgument` replaces a non-live argument with a dummy value. E.g., if the type supports a zero initializer (such as `T x{0}`), a zero constant of that type is created in front of the call and passed instead.

2) As a fallback, we add another DenseSet called 'liveSet'. Its initial values are the arguments passed to public functions for which no dummy value can be created. Liveness is then propagated backward from them, just like in the liveness analysis.


# Test plan
```
./bin/llvm-lit -v ../mlir/test/Transforms/remove-dead-values.mlir
before:

./input.mlir:13:5: error: null operand found
    call @immutable_fn_return_void_with_unused_argument(%arg0, %zero) : (i32, i32) -> ()
    ^
./input.mlir:13:5: note: see current operation: "func.call"(%arg0, <<NULL VALUE>>) <{callee = @immutable_fn_return_void_with_unused_argument}> : (i32, <<NULL TYPE>>) -> ()
after: pass
```

 


>From 52cf8cbbcc8bdd1ae3e9ee5e1dd5f62534f563bb Mon Sep 17 00:00:00 2001
From: Xin Liu <xxinliu at xinliu-pc.lan>
Date: Sun, 21 Sep 2025 23:33:41 -0700
Subject: [PATCH 1/4] [MLIR][RemoveDeadValues] Mark arguments of a public
 function Live

This diff also changes the traversal order from forward to backward for
regions/blocks/ops. This order guarantees that liveness updates at a call site
can propagate to the defs of the arguments.

```
./bin/llvm-lit -v ../mlir/test/Transforms/remove-dead-values.mlir
```
---
 mlir/include/mlir/IR/Visitors.h              | 14 ++++++
 mlir/lib/Transforms/RemoveDeadValues.cpp     | 52 +++++++++++++++++---
 mlir/test/Transforms/remove-dead-values.mlir | 18 +++++++
 3 files changed, 78 insertions(+), 6 deletions(-)

diff --git a/mlir/include/mlir/IR/Visitors.h b/mlir/include/mlir/IR/Visitors.h
index 893f66ae33deb..5766d262796d6 100644
--- a/mlir/include/mlir/IR/Visitors.h
+++ b/mlir/include/mlir/IR/Visitors.h
@@ -39,6 +39,20 @@ struct ForwardIterator {
   }
 };
 
+/// This iterator enumerates the elements in "backward" order.
+struct BackwardIterator {
+  template <typename T>
+  static auto makeIterable(T &range) {
+    if constexpr (std::is_same<T, Operation>()) {
+      /// Make operations iterable: return the list of regions.
+      return llvm::reverse(range.getRegions());
+    } else {
+      /// Regions and block are already iterable.
+      return llvm::reverse(range);
+    }
+  }
+};
+
 /// A utility class to encode the current walk stage for "generic" walkers.
 /// When walking an operation, we can either choose a Pre/Post order walker
 /// which invokes the callback on an operation before/after all its attached
diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index e0c65b0e09774..5dad922ff4b69 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -117,9 +117,15 @@ struct RDVFinalCleanupList {
 
 /// Return true iff at least one value in `values` is live, given the liveness
 /// information in `la`.
-static bool hasLive(ValueRange values, const DenseSet<Value> &nonLiveSet,
+static bool hasLive(ValueRange values, const DenseSet<Value> &nonLiveSet, const DenseSet<Value> &liveSet,
+
                     RunLivenessAnalysis &la) {
   for (Value value : values) {
+    if (liveSet.contains(value)) {
+      LDBG() << "Value " << value << " is marked live by CallOp";
+      return true;
+    }
+
     if (nonLiveSet.contains(value)) {
       LDBG() << "Value " << value << " is already marked non-live (dead)";
       continue;
@@ -259,8 +265,9 @@ static SmallVector<OpOperand *> operandsToOpOperands(OperandRange operands) {
 ///   - Return-like
 static void processSimpleOp(Operation *op, RunLivenessAnalysis &la,
                             DenseSet<Value> &nonLiveSet,
+                            DenseSet<Value> &liveSet,
                             RDVFinalCleanupList &cl) {
-  if (!isMemoryEffectFree(op) || hasLive(op->getResults(), nonLiveSet, la)) {
+  if (!isMemoryEffectFree(op) || hasLive(op->getResults(), nonLiveSet, liveSet, la)) {
     LDBG() << "Simple op is not memory effect free or has live results, "
               "preserving it: "
            << OpWithFlags(op, OpPrintingFlags().skipRegions());
@@ -379,6 +386,31 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
   }
 }
 
+static void processCallOp(CallOpInterface callOp, Operation *module,
+                          RunLivenessAnalysis &la,
+                          DenseSet<Value> &liveSet) {
+  auto callable = callOp.getCallableForCallee();
+
+  if (auto symbolRef = callable.dyn_cast<SymbolRefAttr>()) {
+    Operation *calleeOp = SymbolTable::lookupSymbolIn(module, symbolRef);
+
+    if (auto funcOp = llvm::dyn_cast_or_null<mlir::FunctionOpInterface>(calleeOp)) {
+      // Ensure the outgoing arguments of PUBLIC functions are live
+      // because processFuncOp can not process them.
+      //
+      // Liveness treats the external function as a blackbox.
+      if (funcOp.isPublic()) {
+        for (Value arg:  callOp.getArgOperands()) {
+          const Liveness *liveness = la.getLiveness(arg);
+          if (liveness && !liveness->isLive) {
+            liveSet.insert(arg);
+          }
+        }
+      }
+    }
+  }
+}
+
 /// Process a region branch operation `regionBranchOp` using the liveness
 /// information in `la`. The processing involves two scenarios:
 ///
@@ -411,6 +443,7 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
 static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
                                   RunLivenessAnalysis &la,
                                   DenseSet<Value> &nonLiveSet,
+                                  DenseSet<Value> &liveSet,
                                   RDVFinalCleanupList &cl) {
   LDBG() << "Processing region branch op: "
          << OpWithFlags(regionBranchOp, OpPrintingFlags().skipRegions());
@@ -619,7 +652,7 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
   // attributed to something else.
   // Do (1') and (2').
   if (isMemoryEffectFree(regionBranchOp.getOperation()) &&
-      !hasLive(regionBranchOp->getResults(), nonLiveSet, la)) {
+      !hasLive(regionBranchOp->getResults(), nonLiveSet, liveSet, la)) {
     cl.operations.push_back(regionBranchOp.getOperation());
     return;
   }
@@ -876,16 +909,18 @@ void RemoveDeadValues::runOnOperation() {
   // Tracks values eligible for erasure - complements liveness analysis to
   // identify "droppable" values.
   DenseSet<Value> deadVals;
+  // mark outgoing arguments to a public function LIVE.
+  DenseSet<Value> liveVals;
 
   // Maintains a list of Ops, values, branches, etc., slated for cleanup at the
   // end of this pass.
   RDVFinalCleanupList finalCleanupList;
 
-  module->walk([&](Operation *op) {
+  module->walk<WalkOrder::PostOrder, BackwardIterator>([&](Operation *op) {
     if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
       processFuncOp(funcOp, module, la, deadVals, finalCleanupList);
     } else if (auto regionBranchOp = dyn_cast<RegionBranchOpInterface>(op)) {
-      processRegionBranchOp(regionBranchOp, la, deadVals, finalCleanupList);
+      processRegionBranchOp(regionBranchOp, la, deadVals, liveVals, finalCleanupList);
     } else if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
       processBranchOp(branchOp, la, deadVals, finalCleanupList);
     } else if (op->hasTrait<::mlir::OpTrait::IsTerminator>()) {
@@ -894,8 +929,13 @@ void RemoveDeadValues::runOnOperation() {
     } else if (isa<CallOpInterface>(op)) {
       // Nothing to do because this op is associated with a function op and gets
       // cleaned when the latter is cleaned.
+      //
+      // The only exception is public callee. By default, Liveness analysis is inter-procedural.
+      // Unused arguments of a public function nonLive and are propagated to the caller.
+      // processCallOp puts them to liveVals.
+      processCallOp(cast<CallOpInterface>(op), module, la, liveVals);
     } else {
-      processSimpleOp(op, la, deadVals, finalCleanupList);
+      processSimpleOp(op, la, deadVals, liveVals, finalCleanupList);
     }
   });
 
diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir
index 56449469dc29f..fc857f2989f18 100644
--- a/mlir/test/Transforms/remove-dead-values.mlir
+++ b/mlir/test/Transforms/remove-dead-values.mlir
@@ -569,6 +569,24 @@ module @return_void_with_unused_argument {
     call @fn_return_void_with_unused_argument(%arg0, %unused) : (i32, memref<4xi32>) -> ()
     return %unused : memref<4xi32>
   }
+
+  // the function is immutable because it is public.
+  func.func public @immutable_fn_return_void_with_unused_argument(%arg0: i32, %unused: i32) -> () {
+    %sum = arith.addi %arg0, %arg0 : i32
+    %c0 = arith.constant 0 : index
+    %buf = memref.alloc() : memref<1xi32>
+    memref.store %sum, %buf[%c0] : memref<1xi32>
+    return
+  }
+  // CHECK-LABEL: func.func @main2
+  // CHECK-SAME: (%[[ARG0_MAIN:.*]]: i32)
+  // CHECK: %[[UNUSED:.*]] = arith.constant 0 : i32
+  // CHECK: call @immutable_fn_return_void_with_unused_argument(%[[ARG0_MAIN]], %[[UNUSED]]) : (i32, i32) -> ()
+  func.func @main2(%arg0: i32) -> () {
+    %zero = arith.constant 0 : i32
+    call @immutable_fn_return_void_with_unused_argument(%arg0, %zero) : (i32, i32) -> ()
+    return
+  }
 }
 
 // -----

>From be516b70225f331e51c4075c5e5f267519ef9d6b Mon Sep 17 00:00:00 2001
From: Xin Liu <xxinliu at meta.com>
Date: Tue, 23 Sep 2025 09:27:34 -0700
Subject: [PATCH 2/4] Fix formatter issue.

---
 mlir/lib/Transforms/RemoveDeadValues.cpp | 26 +++++++++++++-----------
 1 file changed, 14 insertions(+), 12 deletions(-)

diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index 5dad922ff4b69..50055ff7788a4 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -117,7 +117,8 @@ struct RDVFinalCleanupList {
 
 /// Return true iff at least one value in `values` is live, given the liveness
 /// information in `la`.
-static bool hasLive(ValueRange values, const DenseSet<Value> &nonLiveSet, const DenseSet<Value> &liveSet,
+static bool hasLive(ValueRange values, const DenseSet<Value> &nonLiveSet,
+                    const DenseSet<Value> &liveSet,
 
                     RunLivenessAnalysis &la) {
   for (Value value : values) {
@@ -265,9 +266,9 @@ static SmallVector<OpOperand *> operandsToOpOperands(OperandRange operands) {
 ///   - Return-like
 static void processSimpleOp(Operation *op, RunLivenessAnalysis &la,
                             DenseSet<Value> &nonLiveSet,
-                            DenseSet<Value> &liveSet,
-                            RDVFinalCleanupList &cl) {
-  if (!isMemoryEffectFree(op) || hasLive(op->getResults(), nonLiveSet, liveSet, la)) {
+                            DenseSet<Value> &liveSet, RDVFinalCleanupList &cl) {
+  if (!isMemoryEffectFree(op) ||
+      hasLive(op->getResults(), nonLiveSet, liveSet, la)) {
     LDBG() << "Simple op is not memory effect free or has live results, "
               "preserving it: "
            << OpWithFlags(op, OpPrintingFlags().skipRegions());
@@ -387,20 +388,20 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
 }
 
 static void processCallOp(CallOpInterface callOp, Operation *module,
-                          RunLivenessAnalysis &la,
-                          DenseSet<Value> &liveSet) {
+                          RunLivenessAnalysis &la, DenseSet<Value> &liveSet) {
   auto callable = callOp.getCallableForCallee();
 
   if (auto symbolRef = callable.dyn_cast<SymbolRefAttr>()) {
     Operation *calleeOp = SymbolTable::lookupSymbolIn(module, symbolRef);
 
-    if (auto funcOp = llvm::dyn_cast_or_null<mlir::FunctionOpInterface>(calleeOp)) {
+    if (auto funcOp =
+            llvm::dyn_cast_or_null<mlir::FunctionOpInterface>(calleeOp)) {
       // Ensure the outgoing arguments of PUBLIC functions are live
       // because processFuncOp can not process them.
       //
       // Liveness treats the external function as a blackbox.
       if (funcOp.isPublic()) {
-        for (Value arg:  callOp.getArgOperands()) {
+        for (Value arg : callOp.getArgOperands()) {
           const Liveness *liveness = la.getLiveness(arg);
           if (liveness && !liveness->isLive) {
             liveSet.insert(arg);
@@ -920,7 +921,8 @@ void RemoveDeadValues::runOnOperation() {
     if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
       processFuncOp(funcOp, module, la, deadVals, finalCleanupList);
     } else if (auto regionBranchOp = dyn_cast<RegionBranchOpInterface>(op)) {
-      processRegionBranchOp(regionBranchOp, la, deadVals, liveVals, finalCleanupList);
+      processRegionBranchOp(regionBranchOp, la, deadVals, liveVals,
+                            finalCleanupList);
     } else if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
       processBranchOp(branchOp, la, deadVals, finalCleanupList);
     } else if (op->hasTrait<::mlir::OpTrait::IsTerminator>()) {
@@ -930,9 +932,9 @@ void RemoveDeadValues::runOnOperation() {
       // Nothing to do because this op is associated with a function op and gets
       // cleaned when the latter is cleaned.
       //
-      // The only exception is public callee. By default, Liveness analysis is inter-procedural.
-      // Unused arguments of a public function nonLive and are propagated to the caller.
-      // processCallOp puts them to liveVals.
+      // The only exception is public callee. By default, Liveness analysis is
+      // inter-procedural. Unused arguments of a public function nonLive and are
+      // propagated to the caller. processCallOp puts them to liveVals.
       processCallOp(cast<CallOpInterface>(op), module, la, liveVals);
     } else {
       processSimpleOp(op, la, deadVals, liveVals, finalCleanupList);

>From 6c62398cbf139d8384f712500311f11377bfd0a4 Mon Sep 17 00:00:00 2001
From: Xin Liu <xxinliu at meta.com>
Date: Thu, 2 Oct 2025 22:41:04 -0700
Subject: [PATCH 3/4] update processCallOp.

---
 .../mlir/Analysis/DataFlow/LivenessAnalysis.h |   3 +
 mlir/lib/Transforms/RemoveDeadValues.cpp      | 107 +++++++++++-------
 2 files changed, 71 insertions(+), 39 deletions(-)

diff --git a/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h
index cf1fd6e2d48ca..be7e027b95f64 100644
--- a/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h
@@ -102,6 +102,9 @@ struct RunLivenessAnalysis {
 
   const Liveness *getLiveness(Value val);
 
+  /// Return the configuration of the solver used for this analysis.
+  const DataFlowConfig &getSolverConfig() const { return solver.getConfig(); }
+
 private:
   /// Stores the result of the liveness analysis that was run.
   DataFlowSolver solver;
diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index 50055ff7788a4..621ec7b3827d3 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -33,6 +33,7 @@
 
 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
 #include "mlir/Analysis/DataFlow/LivenessAnalysis.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/Dialect.h"
@@ -119,7 +120,6 @@ struct RDVFinalCleanupList {
 /// information in `la`.
 static bool hasLive(ValueRange values, const DenseSet<Value> &nonLiveSet,
                     const DenseSet<Value> &liveSet,
-
                     RunLivenessAnalysis &la) {
   for (Value value : values) {
     if (liveSet.contains(value)) {
@@ -151,7 +151,7 @@ static bool hasLive(ValueRange values, const DenseSet<Value> &nonLiveSet,
 /// Return a BitVector of size `values.size()` where its i-th bit is 1 iff the
 /// i-th value in `values` is live, given the liveness information in `la`.
 static BitVector markLives(ValueRange values, const DenseSet<Value> &nonLiveSet,
-                           RunLivenessAnalysis &la) {
+                           const DenseSet<Value> &liveSet, RunLivenessAnalysis &la) {
   BitVector lives(values.size(), true);
 
   for (auto [index, value] : llvm::enumerate(values)) {
@@ -161,7 +161,9 @@ static BitVector markLives(ValueRange values, const DenseSet<Value> &nonLiveSet,
              << " is already marked non-live (dead) at index " << index;
       continue;
     }
-
+    if (liveSet.contains(value)) {
+      continue;
+    }
     const Liveness *liveness = la.getLiveness(value);
     // It is important to note that when `liveness` is null, we can't tell if
     // `value` is live or not. So, the safe option is to consider it live. Also,
@@ -267,6 +269,16 @@ static SmallVector<OpOperand *> operandsToOpOperands(OperandRange operands) {
 static void processSimpleOp(Operation *op, RunLivenessAnalysis &la,
                             DenseSet<Value> &nonLiveSet,
                             DenseSet<Value> &liveSet, RDVFinalCleanupList &cl) {
+  for (Value val: op->getResults()) {
+    if (liveSet.contains(val)) {
+      LDBG() << "Simple op is used by a public function, "
+                "preserving it: "
+             << OpWithFlags(op, OpPrintingFlags().skipRegions());
+      liveSet.insert_range(op->getOperands());
+      return;
+    }
+  }
+
   if (!isMemoryEffectFree(op) ||
       hasLive(op->getResults(), nonLiveSet, liveSet, la)) {
     LDBG() << "Simple op is not memory effect free or has live results, "
@@ -295,7 +307,7 @@ static void processSimpleOp(Operation *op, RunLivenessAnalysis &la,
 ///       removal.
 ///   (6) Marking all its results as non-live values.
 static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
-                          RunLivenessAnalysis &la, DenseSet<Value> &nonLiveSet,
+                          RunLivenessAnalysis &la, DenseSet<Value> &nonLiveSet, DenseSet<Value> &liveSet,
                           RDVFinalCleanupList &cl) {
   LDBG() << "Processing function op: "
          << OpWithFlags(funcOp, OpPrintingFlags().skipRegions());
@@ -307,7 +319,7 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
 
   // Get the list of unnecessary (non-live) arguments in `nonLiveArgs`.
   SmallVector<Value> arguments(funcOp.getArguments());
-  BitVector nonLiveArgs = markLives(arguments, nonLiveSet, la);
+  BitVector nonLiveArgs = markLives(arguments, nonLiveSet, liveSet, la);
   nonLiveArgs = nonLiveArgs.flip();
 
   // Do (1).
@@ -360,7 +372,7 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
   for (SymbolTable::SymbolUse use : uses) {
     Operation *callOp = use.getUser();
     assert(isa<CallOpInterface>(callOp) && "expected a call-like user");
-    BitVector liveCallRets = markLives(callOp->getResults(), nonLiveSet, la);
+    BitVector liveCallRets = markLives(callOp->getResults(), nonLiveSet, liveSet, la);
     nonLiveRets &= liveCallRets.flip();
   }
 
@@ -386,29 +398,52 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
     collectNonLiveValues(nonLiveSet, callOp->getResults(), nonLiveRets);
   }
 }
+// create a cheaper value with the same type of oldVal in front of CallOp.
+static Value createDummyArgument(CallOpInterface callOp, Value oldVal) {
+  OpBuilder builder(callOp.getOperation());
+  Type type = oldVal.getType();
+
+  // Create zero constant for any supported type
+  if (TypedAttr zeroAttr = builder.getZeroAttr(type)) {
+      return builder.create<arith::ConstantOp>(oldVal.getLoc(), type, zeroAttr);
+  }
+  return {};
+}
 
 static void processCallOp(CallOpInterface callOp, Operation *module,
-                          RunLivenessAnalysis &la, DenseSet<Value> &liveSet) {
-  auto callable = callOp.getCallableForCallee();
-
-  if (auto symbolRef = callable.dyn_cast<SymbolRefAttr>()) {
-    Operation *calleeOp = SymbolTable::lookupSymbolIn(module, symbolRef);
-
-    if (auto funcOp =
-            llvm::dyn_cast_or_null<mlir::FunctionOpInterface>(calleeOp)) {
-      // Ensure the outgoing arguments of PUBLIC functions are live
-      // because processFuncOp can not process them.
-      //
-      // Liveness treats the external function as a blackbox.
-      if (funcOp.isPublic()) {
-        for (Value arg : callOp.getArgOperands()) {
-          const Liveness *liveness = la.getLiveness(arg);
-          if (liveness && !liveness->isLive) {
-            liveSet.insert(arg);
-          }
-        }
+                          RunLivenessAnalysis &la, DenseSet<Value> &nonLiveSet, DenseSet<Value> &liveSet) {
+  if (!la.getSolverConfig().isInterprocedural())
+    return;
+
+  Operation *callableOp = callOp.resolveCallable();
+  auto funcOp = dyn_cast<FunctionOpInterface>(callableOp);
+  if (!funcOp || !funcOp.isPublic()) {
+    return;
+  }
+  LDBG() << "processCallOp" << funcOp.getName();
+  // Get the list of unnecessary (non-live) arguments in `nonLiveArgs`.
+  SmallVector<Value> arguments(funcOp.getArguments());
+  BitVector nonLiveArgs = markLives(arguments, nonLiveSet, liveSet, la);
+  nonLiveArgs = nonLiveArgs.flip();
+
+  if (nonLiveArgs.count() > 0) {
+    LDBG() << funcOp.getName() << " contains NonLive arguments";
+    // The number of operands in the call op may not match the number of
+    // arguments in the func op.
+    SmallVector<OpOperand *> callOpOperands =
+        operandsToOpOperands(callOp.getArgOperands());
+
+    for (int index : nonLiveArgs.set_bits()) {
+      OpOperand *operand = callOpOperands[index];
+      Value oldVal = operand->get();
+      if (Value dummy = createDummyArgument(callOp, oldVal)) {
+        callOp->setOperand(operand->getOperandNumber(), dummy);
+        nonLiveSet.insert(oldVal);
+      } else {
+        liveSet.insert(oldVal);
       }
     }
+    LDBG() << "after changed " << callOp;
   }
 }
 
@@ -450,7 +485,7 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
          << OpWithFlags(regionBranchOp, OpPrintingFlags().skipRegions());
   // Mark live results of `regionBranchOp` in `liveResults`.
   auto markLiveResults = [&](BitVector &liveResults) {
-    liveResults = markLives(regionBranchOp->getResults(), nonLiveSet, la);
+    liveResults = markLives(regionBranchOp->getResults(), nonLiveSet, liveSet, la);
   };
 
   // Mark live arguments in the regions of `regionBranchOp` in `liveArgs`.
@@ -459,7 +494,7 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
       if (region.empty())
         continue;
       SmallVector<Value> arguments(region.front().getArguments());
-      BitVector regionLiveArgs = markLives(arguments, nonLiveSet, la);
+      BitVector regionLiveArgs = markLives(arguments, nonLiveSet, liveSet, la);
       liveArgs[&region] = regionLiveArgs;
     }
   };
@@ -731,7 +766,7 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
 ///     as well as their corresponding arguments.
 
 static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la,
-                            DenseSet<Value> &nonLiveSet,
+                            DenseSet<Value> &nonLiveSet, DenseSet<Value> &liveSet,
                             RDVFinalCleanupList &cl) {
   LDBG() << "Processing branch op: " << *branchOp;
   unsigned numSuccessors = branchOp->getNumSuccessors();
@@ -750,7 +785,7 @@ static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la,
 
     // Do (2)
     BitVector successorNonLive =
-        markLives(operandValues, nonLiveSet, la).flip();
+        markLives(operandValues, nonLiveSet, liveSet, la).flip();
     collectNonLiveValues(nonLiveSet, successorBlock->getArguments(),
                          successorNonLive);
 
@@ -910,7 +945,7 @@ void RemoveDeadValues::runOnOperation() {
   // Tracks values eligible for erasure - complements liveness analysis to
   // identify "droppable" values.
   DenseSet<Value> deadVals;
-  // mark outgoing arguments to a public function LIVE.
+  // mark outgoing arguments to a public function LIVE. We also propagate liveness backward.
   DenseSet<Value> liveVals;
 
   // Maintains a list of Ops, values, branches, etc., slated for cleanup at the
@@ -919,23 +954,17 @@ void RemoveDeadValues::runOnOperation() {
 
   module->walk<WalkOrder::PostOrder, BackwardIterator>([&](Operation *op) {
     if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
-      processFuncOp(funcOp, module, la, deadVals, finalCleanupList);
+      processFuncOp(funcOp, module, la, deadVals, liveVals, finalCleanupList);
     } else if (auto regionBranchOp = dyn_cast<RegionBranchOpInterface>(op)) {
       processRegionBranchOp(regionBranchOp, la, deadVals, liveVals,
                             finalCleanupList);
     } else if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
-      processBranchOp(branchOp, la, deadVals, finalCleanupList);
+      processBranchOp(branchOp, la, deadVals, liveVals, finalCleanupList);
     } else if (op->hasTrait<::mlir::OpTrait::IsTerminator>()) {
       // Nothing to do here because this is a terminator op and it should be
       // honored with respect to its parent
     } else if (isa<CallOpInterface>(op)) {
-      // Nothing to do because this op is associated with a function op and gets
-      // cleaned when the latter is cleaned.
-      //
-      // The only exception is public callee. By default, Liveness analysis is
-      // inter-procedural. Unused arguments of a public function nonLive and are
-      // propagated to the caller. processCallOp puts them to liveVals.
-      processCallOp(cast<CallOpInterface>(op), module, la, liveVals);
+      processCallOp(cast<CallOpInterface>(op), module, la, deadVals, liveVals);
     } else {
       processSimpleOp(op, la, deadVals, liveVals, finalCleanupList);
     }

>From 3f4d69f8910007656105a8798b0a5934310df489 Mon Sep 17 00:00:00 2001
From: Xin Liu <xxinliu at meta.com>
Date: Sun, 5 Oct 2025 20:45:24 -0700
Subject: [PATCH 4/4] Update the lit test and formatter.

---
 mlir/lib/Transforms/RemoveDeadValues.cpp     | 37 +++++++++++---------
 mlir/test/Transforms/remove-dead-values.mlir | 22 ++++++------
 2 files changed, 32 insertions(+), 27 deletions(-)

diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index 621ec7b3827d3..ed5c6a8d2ead0 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -119,8 +119,7 @@ struct RDVFinalCleanupList {
 /// Return true iff at least one value in `values` is live, given the liveness
 /// information in `la`.
 static bool hasLive(ValueRange values, const DenseSet<Value> &nonLiveSet,
-                    const DenseSet<Value> &liveSet,
-                    RunLivenessAnalysis &la) {
+                    const DenseSet<Value> &liveSet, RunLivenessAnalysis &la) {
   for (Value value : values) {
     if (liveSet.contains(value)) {
       LDBG() << "Value " << value << " is marked live by CallOp";
@@ -151,7 +150,8 @@ static bool hasLive(ValueRange values, const DenseSet<Value> &nonLiveSet,
 /// Return a BitVector of size `values.size()` where its i-th bit is 1 iff the
 /// i-th value in `values` is live, given the liveness information in `la`.
 static BitVector markLives(ValueRange values, const DenseSet<Value> &nonLiveSet,
-                           const DenseSet<Value> &liveSet, RunLivenessAnalysis &la) {
+                           const DenseSet<Value> &liveSet,
+                           RunLivenessAnalysis &la) {
   BitVector lives(values.size(), true);
 
   for (auto [index, value] : llvm::enumerate(values)) {
@@ -269,7 +269,7 @@ static SmallVector<OpOperand *> operandsToOpOperands(OperandRange operands) {
 static void processSimpleOp(Operation *op, RunLivenessAnalysis &la,
                             DenseSet<Value> &nonLiveSet,
                             DenseSet<Value> &liveSet, RDVFinalCleanupList &cl) {
-  for (Value val: op->getResults()) {
+  for (Value val : op->getResults()) {
     if (liveSet.contains(val)) {
       LDBG() << "Simple op is used by a public function, "
                 "preserving it: "
@@ -307,8 +307,8 @@ static void processSimpleOp(Operation *op, RunLivenessAnalysis &la,
 ///       removal.
 ///   (6) Marking all its results as non-live values.
 static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
-                          RunLivenessAnalysis &la, DenseSet<Value> &nonLiveSet, DenseSet<Value> &liveSet,
-                          RDVFinalCleanupList &cl) {
+                          RunLivenessAnalysis &la, DenseSet<Value> &nonLiveSet,
+                          DenseSet<Value> &liveSet, RDVFinalCleanupList &cl) {
   LDBG() << "Processing function op: "
          << OpWithFlags(funcOp, OpPrintingFlags().skipRegions());
   if (funcOp.isPublic() || funcOp.isExternal()) {
@@ -372,7 +372,8 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
   for (SymbolTable::SymbolUse use : uses) {
     Operation *callOp = use.getUser();
     assert(isa<CallOpInterface>(callOp) && "expected a call-like user");
-    BitVector liveCallRets = markLives(callOp->getResults(), nonLiveSet, liveSet, la);
+    BitVector liveCallRets =
+        markLives(callOp->getResults(), nonLiveSet, liveSet, la);
     nonLiveRets &= liveCallRets.flip();
   }
 
@@ -398,20 +399,22 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
     collectNonLiveValues(nonLiveSet, callOp->getResults(), nonLiveRets);
   }
 }
-// create a cheaper value with the same type of oldVal in front of CallOp.
+
+// Create a cheaper value with the same type of oldVal in front of CallOp.
 static Value createDummyArgument(CallOpInterface callOp, Value oldVal) {
   OpBuilder builder(callOp.getOperation());
   Type type = oldVal.getType();
 
   // Create zero constant for any supported type
   if (TypedAttr zeroAttr = builder.getZeroAttr(type)) {
-      return builder.create<arith::ConstantOp>(oldVal.getLoc(), type, zeroAttr);
+    return builder.create<arith::ConstantOp>(oldVal.getLoc(), type, zeroAttr);
   }
   return {};
 }
 
 static void processCallOp(CallOpInterface callOp, Operation *module,
-                          RunLivenessAnalysis &la, DenseSet<Value> &nonLiveSet, DenseSet<Value> &liveSet) {
+                          RunLivenessAnalysis &la, DenseSet<Value> &nonLiveSet,
+                          DenseSet<Value> &liveSet) {
   if (!la.getSolverConfig().isInterprocedural())
     return;
 
@@ -420,7 +423,8 @@ static void processCallOp(CallOpInterface callOp, Operation *module,
   if (!funcOp || !funcOp.isPublic()) {
     return;
   }
-  LDBG() << "processCallOp" << funcOp.getName();
+
+  LDBG() << "processCallOp to a public function: " << funcOp.getName();
   // Get the list of unnecessary (non-live) arguments in `nonLiveArgs`.
   SmallVector<Value> arguments(funcOp.getArguments());
   BitVector nonLiveArgs = markLives(arguments, nonLiveSet, liveSet, la);
@@ -443,7 +447,6 @@ static void processCallOp(CallOpInterface callOp, Operation *module,
         liveSet.insert(oldVal);
       }
     }
-    LDBG() << "after changed " << callOp;
   }
 }
 
@@ -485,7 +488,8 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
          << OpWithFlags(regionBranchOp, OpPrintingFlags().skipRegions());
   // Mark live results of `regionBranchOp` in `liveResults`.
   auto markLiveResults = [&](BitVector &liveResults) {
-    liveResults = markLives(regionBranchOp->getResults(), nonLiveSet, liveSet, la);
+    liveResults =
+        markLives(regionBranchOp->getResults(), nonLiveSet, liveSet, la);
   };
 
   // Mark live arguments in the regions of `regionBranchOp` in `liveArgs`.
@@ -766,8 +770,8 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
 ///     as well as their corresponding arguments.
 
 static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la,
-                            DenseSet<Value> &nonLiveSet, DenseSet<Value> &liveSet,
-                            RDVFinalCleanupList &cl) {
+                            DenseSet<Value> &nonLiveSet,
+                            DenseSet<Value> &liveSet, RDVFinalCleanupList &cl) {
   LDBG() << "Processing branch op: " << *branchOp;
   unsigned numSuccessors = branchOp->getNumSuccessors();
 
@@ -945,7 +949,8 @@ void RemoveDeadValues::runOnOperation() {
   // Tracks values eligible for erasure - complements liveness analysis to
   // identify "droppable" values.
   DenseSet<Value> deadVals;
-  // mark outgoing arguments to a public function LIVE. We also propagate liveness backward.
+  // mark outgoing arguments to a public function LIVE. We also propagate
+  // liveness backward.
   DenseSet<Value> liveVals;
 
   // Maintains a list of Ops, values, branches, etc., slated for cleanup at the
diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir
index fc857f2989f18..ebb76a8835ceb 100644
--- a/mlir/test/Transforms/remove-dead-values.mlir
+++ b/mlir/test/Transforms/remove-dead-values.mlir
@@ -570,21 +570,21 @@ module @return_void_with_unused_argument {
     return %unused : memref<4xi32>
   }
 
-  // the function is immutable because it is public.
-  func.func public @immutable_fn_return_void_with_unused_argument(%arg0: i32, %unused: i32) -> () {
-    %sum = arith.addi %arg0, %arg0 : i32
-    %c0 = arith.constant 0 : index
-    %buf = memref.alloc() : memref<1xi32>
-    memref.store %sum, %buf[%c0] : memref<1xi32>
+  // the function signature is immutable because it is public.
+  func.func public @immutable_fn_with_unused_argument(%arg0: i32, %arg1: memref<4xf32>) -> () {
     return
   }
+
   // CHECK-LABEL: func.func @main2
-  // CHECK-SAME: (%[[ARG0_MAIN:.*]]: i32)
+  // CHECK: %[[MEM:.*]] = memref.alloc() : memref<4xf32>
   // CHECK: %[[UNUSED:.*]] = arith.constant 0 : i32
-  // CHECK: call @immutable_fn_return_void_with_unused_argument(%[[ARG0_MAIN]], %[[UNUSED]]) : (i32, i32) -> ()
-  func.func @main2(%arg0: i32) -> () {
-    %zero = arith.constant 0 : i32
-    call @immutable_fn_return_void_with_unused_argument(%arg0, %zero) : (i32, i32) -> ()
+  // CHECK: call @immutable_fn_with_unused_argument(%[[UNUSED]], %[[MEM]]) : (i32, memref<4xf32>) -> ()
+  func.func @main2() -> () {
+    %one = arith.constant 1 : i32
+    %scalar = arith.addi %one, %one: i32
+    %mem = memref.alloc() : memref<4xf32>
+
+    call @immutable_fn_with_unused_argument(%scalar, %mem) : (i32, memref<4xf32>) -> ()
     return
   }
 }



More information about the Mlir-commits mailing list