[Mlir-commits] [mlir] [MLIR][RemoveDeadValues] Privatize public function with NonLive arguments before RDV. (PR #162038)

xin liu llvmlistbot at llvm.org
Wed Oct 29 23:08:26 PDT 2025


https://github.com/navyxliu updated https://github.com/llvm/llvm-project/pull/162038

>From 52cf8cbbcc8bdd1ae3e9ee5e1dd5f62534f563bb Mon Sep 17 00:00:00 2001
From: Xin Liu <xxinliu at xinliu-pc.lan>
Date: Sun, 21 Sep 2025 23:33:41 -0700
Subject: [PATCH 01/11] [MLIR][RemoveDeadValues] Mark arguments of a public
 function Live

This diff also changes the traversal order from forward to backward for
regions/blocks/ops. This order guarantees that liveness updates at a call
site can propagate to the definitions of the arguments.

```
./bin/llvm-lit -v ../mlir/test/Transforms/remove-dead-values.mlir
```
---
 mlir/include/mlir/IR/Visitors.h              | 14 ++++++
 mlir/lib/Transforms/RemoveDeadValues.cpp     | 52 +++++++++++++++++---
 mlir/test/Transforms/remove-dead-values.mlir | 18 +++++++
 3 files changed, 78 insertions(+), 6 deletions(-)

diff --git a/mlir/include/mlir/IR/Visitors.h b/mlir/include/mlir/IR/Visitors.h
index 893f66ae33deb..5766d262796d6 100644
--- a/mlir/include/mlir/IR/Visitors.h
+++ b/mlir/include/mlir/IR/Visitors.h
@@ -39,6 +39,20 @@ struct ForwardIterator {
   }
 };
 
+/// This iterator enumerates the elements in "backward" order.
+struct BackwardIterator {
+  template <typename T>
+  static auto makeIterable(T &range) {
+    if constexpr (std::is_same<T, Operation>()) {
+      /// Make operations iterable: return the list of regions.
+      return llvm::reverse(range.getRegions());
+    } else {
+      /// Regions and block are already iterable.
+      return llvm::reverse(range);
+    }
+  }
+};
+
 /// A utility class to encode the current walk stage for "generic" walkers.
 /// When walking an operation, we can either choose a Pre/Post order walker
 /// which invokes the callback on an operation before/after all its attached
diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index e0c65b0e09774..5dad922ff4b69 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -117,9 +117,15 @@ struct RDVFinalCleanupList {
 
 /// Return true iff at least one value in `values` is live, given the liveness
 /// information in `la`.
-static bool hasLive(ValueRange values, const DenseSet<Value> &nonLiveSet,
+static bool hasLive(ValueRange values, const DenseSet<Value> &nonLiveSet, const DenseSet<Value> &liveSet,
+
                     RunLivenessAnalysis &la) {
   for (Value value : values) {
+    if (liveSet.contains(value)) {
+      LDBG() << "Value " << value << " is marked live by CallOp";
+      return true;
+    }
+
     if (nonLiveSet.contains(value)) {
       LDBG() << "Value " << value << " is already marked non-live (dead)";
       continue;
@@ -259,8 +265,9 @@ static SmallVector<OpOperand *> operandsToOpOperands(OperandRange operands) {
 ///   - Return-like
 static void processSimpleOp(Operation *op, RunLivenessAnalysis &la,
                             DenseSet<Value> &nonLiveSet,
+                            DenseSet<Value> &liveSet,
                             RDVFinalCleanupList &cl) {
-  if (!isMemoryEffectFree(op) || hasLive(op->getResults(), nonLiveSet, la)) {
+  if (!isMemoryEffectFree(op) || hasLive(op->getResults(), nonLiveSet, liveSet, la)) {
     LDBG() << "Simple op is not memory effect free or has live results, "
               "preserving it: "
            << OpWithFlags(op, OpPrintingFlags().skipRegions());
@@ -379,6 +386,31 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
   }
 }
 
+static void processCallOp(CallOpInterface callOp, Operation *module,
+                          RunLivenessAnalysis &la,
+                          DenseSet<Value> &liveSet) {
+  auto callable = callOp.getCallableForCallee();
+
+  if (auto symbolRef = callable.dyn_cast<SymbolRefAttr>()) {
+    Operation *calleeOp = SymbolTable::lookupSymbolIn(module, symbolRef);
+
+    if (auto funcOp = llvm::dyn_cast_or_null<mlir::FunctionOpInterface>(calleeOp)) {
+      // Ensure the outgoing arguments of PUBLIC functions are live
+      // because processFuncOp can not process them.
+      //
+      // Liveness treats the external function as a blackbox.
+      if (funcOp.isPublic()) {
+        for (Value arg:  callOp.getArgOperands()) {
+          const Liveness *liveness = la.getLiveness(arg);
+          if (liveness && !liveness->isLive) {
+            liveSet.insert(arg);
+          }
+        }
+      }
+    }
+  }
+}
+
 /// Process a region branch operation `regionBranchOp` using the liveness
 /// information in `la`. The processing involves two scenarios:
 ///
@@ -411,6 +443,7 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
 static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
                                   RunLivenessAnalysis &la,
                                   DenseSet<Value> &nonLiveSet,
+                                  DenseSet<Value> &liveSet,
                                   RDVFinalCleanupList &cl) {
   LDBG() << "Processing region branch op: "
          << OpWithFlags(regionBranchOp, OpPrintingFlags().skipRegions());
@@ -619,7 +652,7 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
   // attributed to something else.
   // Do (1') and (2').
   if (isMemoryEffectFree(regionBranchOp.getOperation()) &&
-      !hasLive(regionBranchOp->getResults(), nonLiveSet, la)) {
+      !hasLive(regionBranchOp->getResults(), nonLiveSet, liveSet, la)) {
     cl.operations.push_back(regionBranchOp.getOperation());
     return;
   }
@@ -876,16 +909,18 @@ void RemoveDeadValues::runOnOperation() {
   // Tracks values eligible for erasure - complements liveness analysis to
   // identify "droppable" values.
   DenseSet<Value> deadVals;
+  // mark outgoing arguments to a public function LIVE.
+  DenseSet<Value> liveVals;
 
   // Maintains a list of Ops, values, branches, etc., slated for cleanup at the
   // end of this pass.
   RDVFinalCleanupList finalCleanupList;
 
-  module->walk([&](Operation *op) {
+  module->walk<WalkOrder::PostOrder, BackwardIterator>([&](Operation *op) {
     if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
       processFuncOp(funcOp, module, la, deadVals, finalCleanupList);
     } else if (auto regionBranchOp = dyn_cast<RegionBranchOpInterface>(op)) {
-      processRegionBranchOp(regionBranchOp, la, deadVals, finalCleanupList);
+      processRegionBranchOp(regionBranchOp, la, deadVals, liveVals, finalCleanupList);
     } else if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
       processBranchOp(branchOp, la, deadVals, finalCleanupList);
     } else if (op->hasTrait<::mlir::OpTrait::IsTerminator>()) {
@@ -894,8 +929,13 @@ void RemoveDeadValues::runOnOperation() {
     } else if (isa<CallOpInterface>(op)) {
       // Nothing to do because this op is associated with a function op and gets
       // cleaned when the latter is cleaned.
+      //
+      // The only exception is public callee. By default, Liveness analysis is inter-procedural.
+      // Unused arguments of a public function nonLive and are propagated to the caller.
+      // processCallOp puts them to liveVals.
+      processCallOp(cast<CallOpInterface>(op), module, la, liveVals);
     } else {
-      processSimpleOp(op, la, deadVals, finalCleanupList);
+      processSimpleOp(op, la, deadVals, liveVals, finalCleanupList);
     }
   });
 
diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir
index 56449469dc29f..fc857f2989f18 100644
--- a/mlir/test/Transforms/remove-dead-values.mlir
+++ b/mlir/test/Transforms/remove-dead-values.mlir
@@ -569,6 +569,24 @@ module @return_void_with_unused_argument {
     call @fn_return_void_with_unused_argument(%arg0, %unused) : (i32, memref<4xi32>) -> ()
     return %unused : memref<4xi32>
   }
+
+  // the function is immutable because it is public.
+  func.func public @immutable_fn_return_void_with_unused_argument(%arg0: i32, %unused: i32) -> () {
+    %sum = arith.addi %arg0, %arg0 : i32
+    %c0 = arith.constant 0 : index
+    %buf = memref.alloc() : memref<1xi32>
+    memref.store %sum, %buf[%c0] : memref<1xi32>
+    return
+  }
+  // CHECK-LABEL: func.func @main2
+  // CHECK-SAME: (%[[ARG0_MAIN:.*]]: i32)
+  // CHECK: %[[UNUSED:.*]] = arith.constant 0 : i32
+  // CHECK: call @immutable_fn_return_void_with_unused_argument(%[[ARG0_MAIN]], %[[UNUSED]]) : (i32, i32) -> ()
+  func.func @main2(%arg0: i32) -> () {
+    %zero = arith.constant 0 : i32
+    call @immutable_fn_return_void_with_unused_argument(%arg0, %zero) : (i32, i32) -> ()
+    return
+  }
 }
 
 // -----

>From be516b70225f331e51c4075c5e5f267519ef9d6b Mon Sep 17 00:00:00 2001
From: Xin Liu <xxinliu at meta.com>
Date: Tue, 23 Sep 2025 09:27:34 -0700
Subject: [PATCH 02/11] Fix formatter issue.

---
 mlir/lib/Transforms/RemoveDeadValues.cpp | 26 +++++++++++++-----------
 1 file changed, 14 insertions(+), 12 deletions(-)

diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index 5dad922ff4b69..50055ff7788a4 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -117,7 +117,8 @@ struct RDVFinalCleanupList {
 
 /// Return true iff at least one value in `values` is live, given the liveness
 /// information in `la`.
-static bool hasLive(ValueRange values, const DenseSet<Value> &nonLiveSet, const DenseSet<Value> &liveSet,
+static bool hasLive(ValueRange values, const DenseSet<Value> &nonLiveSet,
+                    const DenseSet<Value> &liveSet,
 
                     RunLivenessAnalysis &la) {
   for (Value value : values) {
@@ -265,9 +266,9 @@ static SmallVector<OpOperand *> operandsToOpOperands(OperandRange operands) {
 ///   - Return-like
 static void processSimpleOp(Operation *op, RunLivenessAnalysis &la,
                             DenseSet<Value> &nonLiveSet,
-                            DenseSet<Value> &liveSet,
-                            RDVFinalCleanupList &cl) {
-  if (!isMemoryEffectFree(op) || hasLive(op->getResults(), nonLiveSet, liveSet, la)) {
+                            DenseSet<Value> &liveSet, RDVFinalCleanupList &cl) {
+  if (!isMemoryEffectFree(op) ||
+      hasLive(op->getResults(), nonLiveSet, liveSet, la)) {
     LDBG() << "Simple op is not memory effect free or has live results, "
               "preserving it: "
            << OpWithFlags(op, OpPrintingFlags().skipRegions());
@@ -387,20 +388,20 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
 }
 
 static void processCallOp(CallOpInterface callOp, Operation *module,
-                          RunLivenessAnalysis &la,
-                          DenseSet<Value> &liveSet) {
+                          RunLivenessAnalysis &la, DenseSet<Value> &liveSet) {
   auto callable = callOp.getCallableForCallee();
 
   if (auto symbolRef = callable.dyn_cast<SymbolRefAttr>()) {
     Operation *calleeOp = SymbolTable::lookupSymbolIn(module, symbolRef);
 
-    if (auto funcOp = llvm::dyn_cast_or_null<mlir::FunctionOpInterface>(calleeOp)) {
+    if (auto funcOp =
+            llvm::dyn_cast_or_null<mlir::FunctionOpInterface>(calleeOp)) {
       // Ensure the outgoing arguments of PUBLIC functions are live
       // because processFuncOp can not process them.
       //
       // Liveness treats the external function as a blackbox.
       if (funcOp.isPublic()) {
-        for (Value arg:  callOp.getArgOperands()) {
+        for (Value arg : callOp.getArgOperands()) {
           const Liveness *liveness = la.getLiveness(arg);
           if (liveness && !liveness->isLive) {
             liveSet.insert(arg);
@@ -920,7 +921,8 @@ void RemoveDeadValues::runOnOperation() {
     if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
       processFuncOp(funcOp, module, la, deadVals, finalCleanupList);
     } else if (auto regionBranchOp = dyn_cast<RegionBranchOpInterface>(op)) {
-      processRegionBranchOp(regionBranchOp, la, deadVals, liveVals, finalCleanupList);
+      processRegionBranchOp(regionBranchOp, la, deadVals, liveVals,
+                            finalCleanupList);
     } else if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
       processBranchOp(branchOp, la, deadVals, finalCleanupList);
     } else if (op->hasTrait<::mlir::OpTrait::IsTerminator>()) {
@@ -930,9 +932,9 @@ void RemoveDeadValues::runOnOperation() {
       // Nothing to do because this op is associated with a function op and gets
       // cleaned when the latter is cleaned.
       //
-      // The only exception is public callee. By default, Liveness analysis is inter-procedural.
-      // Unused arguments of a public function nonLive and are propagated to the caller.
-      // processCallOp puts them to liveVals.
+      // The only exception is public callee. By default, Liveness analysis is
+      // inter-procedural. Unused arguments of a public function nonLive and are
+      // propagated to the caller. processCallOp puts them to liveVals.
       processCallOp(cast<CallOpInterface>(op), module, la, liveVals);
     } else {
       processSimpleOp(op, la, deadVals, liveVals, finalCleanupList);

>From 6c62398cbf139d8384f712500311f11377bfd0a4 Mon Sep 17 00:00:00 2001
From: Xin Liu <xxinliu at meta.com>
Date: Thu, 2 Oct 2025 22:41:04 -0700
Subject: [PATCH 03/11] update processCallOp.

---
 .../mlir/Analysis/DataFlow/LivenessAnalysis.h |   3 +
 mlir/lib/Transforms/RemoveDeadValues.cpp      | 107 +++++++++++-------
 2 files changed, 71 insertions(+), 39 deletions(-)

diff --git a/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h
index cf1fd6e2d48ca..be7e027b95f64 100644
--- a/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h
@@ -102,6 +102,9 @@ struct RunLivenessAnalysis {
 
   const Liveness *getLiveness(Value val);
 
+  /// Return the configuration of the solver used for this analysis.
+  const DataFlowConfig &getSolverConfig() const { return solver.getConfig(); }
+
 private:
   /// Stores the result of the liveness analysis that was run.
   DataFlowSolver solver;
diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index 50055ff7788a4..621ec7b3827d3 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -33,6 +33,7 @@
 
 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
 #include "mlir/Analysis/DataFlow/LivenessAnalysis.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/Dialect.h"
@@ -119,7 +120,6 @@ struct RDVFinalCleanupList {
 /// information in `la`.
 static bool hasLive(ValueRange values, const DenseSet<Value> &nonLiveSet,
                     const DenseSet<Value> &liveSet,
-
                     RunLivenessAnalysis &la) {
   for (Value value : values) {
     if (liveSet.contains(value)) {
@@ -151,7 +151,7 @@ static bool hasLive(ValueRange values, const DenseSet<Value> &nonLiveSet,
 /// Return a BitVector of size `values.size()` where its i-th bit is 1 iff the
 /// i-th value in `values` is live, given the liveness information in `la`.
 static BitVector markLives(ValueRange values, const DenseSet<Value> &nonLiveSet,
-                           RunLivenessAnalysis &la) {
+                           const DenseSet<Value> &liveSet, RunLivenessAnalysis &la) {
   BitVector lives(values.size(), true);
 
   for (auto [index, value] : llvm::enumerate(values)) {
@@ -161,7 +161,9 @@ static BitVector markLives(ValueRange values, const DenseSet<Value> &nonLiveSet,
              << " is already marked non-live (dead) at index " << index;
       continue;
     }
-
+    if (liveSet.contains(value)) {
+      continue;
+    }
     const Liveness *liveness = la.getLiveness(value);
     // It is important to note that when `liveness` is null, we can't tell if
     // `value` is live or not. So, the safe option is to consider it live. Also,
@@ -267,6 +269,16 @@ static SmallVector<OpOperand *> operandsToOpOperands(OperandRange operands) {
 static void processSimpleOp(Operation *op, RunLivenessAnalysis &la,
                             DenseSet<Value> &nonLiveSet,
                             DenseSet<Value> &liveSet, RDVFinalCleanupList &cl) {
+  for (Value val: op->getResults()) {
+    if (liveSet.contains(val)) {
+      LDBG() << "Simple op is used by a public function, "
+                "preserving it: "
+             << OpWithFlags(op, OpPrintingFlags().skipRegions());
+      liveSet.insert_range(op->getOperands());
+      return;
+    }
+  }
+
   if (!isMemoryEffectFree(op) ||
       hasLive(op->getResults(), nonLiveSet, liveSet, la)) {
     LDBG() << "Simple op is not memory effect free or has live results, "
@@ -295,7 +307,7 @@ static void processSimpleOp(Operation *op, RunLivenessAnalysis &la,
 ///       removal.
 ///   (6) Marking all its results as non-live values.
 static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
-                          RunLivenessAnalysis &la, DenseSet<Value> &nonLiveSet,
+                          RunLivenessAnalysis &la, DenseSet<Value> &nonLiveSet, DenseSet<Value> &liveSet,
                           RDVFinalCleanupList &cl) {
   LDBG() << "Processing function op: "
          << OpWithFlags(funcOp, OpPrintingFlags().skipRegions());
@@ -307,7 +319,7 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
 
   // Get the list of unnecessary (non-live) arguments in `nonLiveArgs`.
   SmallVector<Value> arguments(funcOp.getArguments());
-  BitVector nonLiveArgs = markLives(arguments, nonLiveSet, la);
+  BitVector nonLiveArgs = markLives(arguments, nonLiveSet, liveSet, la);
   nonLiveArgs = nonLiveArgs.flip();
 
   // Do (1).
@@ -360,7 +372,7 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
   for (SymbolTable::SymbolUse use : uses) {
     Operation *callOp = use.getUser();
     assert(isa<CallOpInterface>(callOp) && "expected a call-like user");
-    BitVector liveCallRets = markLives(callOp->getResults(), nonLiveSet, la);
+    BitVector liveCallRets = markLives(callOp->getResults(), nonLiveSet, liveSet, la);
     nonLiveRets &= liveCallRets.flip();
   }
 
@@ -386,29 +398,52 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
     collectNonLiveValues(nonLiveSet, callOp->getResults(), nonLiveRets);
   }
 }
+// create a cheaper value with the same type of oldVal in front of CallOp.
+static Value createDummyArgument(CallOpInterface callOp, Value oldVal) {
+  OpBuilder builder(callOp.getOperation());
+  Type type = oldVal.getType();
+
+  // Create zero constant for any supported type
+  if (TypedAttr zeroAttr = builder.getZeroAttr(type)) {
+      return builder.create<arith::ConstantOp>(oldVal.getLoc(), type, zeroAttr);
+  }
+  return {};
+}
 
 static void processCallOp(CallOpInterface callOp, Operation *module,
-                          RunLivenessAnalysis &la, DenseSet<Value> &liveSet) {
-  auto callable = callOp.getCallableForCallee();
-
-  if (auto symbolRef = callable.dyn_cast<SymbolRefAttr>()) {
-    Operation *calleeOp = SymbolTable::lookupSymbolIn(module, symbolRef);
-
-    if (auto funcOp =
-            llvm::dyn_cast_or_null<mlir::FunctionOpInterface>(calleeOp)) {
-      // Ensure the outgoing arguments of PUBLIC functions are live
-      // because processFuncOp can not process them.
-      //
-      // Liveness treats the external function as a blackbox.
-      if (funcOp.isPublic()) {
-        for (Value arg : callOp.getArgOperands()) {
-          const Liveness *liveness = la.getLiveness(arg);
-          if (liveness && !liveness->isLive) {
-            liveSet.insert(arg);
-          }
-        }
+                          RunLivenessAnalysis &la, DenseSet<Value> &nonLiveSet, DenseSet<Value> &liveSet) {
+  if (!la.getSolverConfig().isInterprocedural())
+    return;
+
+  Operation *callableOp = callOp.resolveCallable();
+  auto funcOp = dyn_cast<FunctionOpInterface>(callableOp);
+  if (!funcOp || !funcOp.isPublic()) {
+    return;
+  }
+  LDBG() << "processCallOp" << funcOp.getName();
+  // Get the list of unnecessary (non-live) arguments in `nonLiveArgs`.
+  SmallVector<Value> arguments(funcOp.getArguments());
+  BitVector nonLiveArgs = markLives(arguments, nonLiveSet, liveSet, la);
+  nonLiveArgs = nonLiveArgs.flip();
+
+  if (nonLiveArgs.count() > 0) {
+    LDBG() << funcOp.getName() << " contains NonLive arguments";
+    // The number of operands in the call op may not match the number of
+    // arguments in the func op.
+    SmallVector<OpOperand *> callOpOperands =
+        operandsToOpOperands(callOp.getArgOperands());
+
+    for (int index : nonLiveArgs.set_bits()) {
+      OpOperand *operand = callOpOperands[index];
+      Value oldVal = operand->get();
+      if (Value dummy = createDummyArgument(callOp, oldVal)) {
+        callOp->setOperand(operand->getOperandNumber(), dummy);
+        nonLiveSet.insert(oldVal);
+      } else {
+        liveSet.insert(oldVal);
       }
     }
+    LDBG() << "after changed " << callOp;
   }
 }
 
@@ -450,7 +485,7 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
          << OpWithFlags(regionBranchOp, OpPrintingFlags().skipRegions());
   // Mark live results of `regionBranchOp` in `liveResults`.
   auto markLiveResults = [&](BitVector &liveResults) {
-    liveResults = markLives(regionBranchOp->getResults(), nonLiveSet, la);
+    liveResults = markLives(regionBranchOp->getResults(), nonLiveSet, liveSet, la);
   };
 
   // Mark live arguments in the regions of `regionBranchOp` in `liveArgs`.
@@ -459,7 +494,7 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
       if (region.empty())
         continue;
       SmallVector<Value> arguments(region.front().getArguments());
-      BitVector regionLiveArgs = markLives(arguments, nonLiveSet, la);
+      BitVector regionLiveArgs = markLives(arguments, nonLiveSet, liveSet, la);
       liveArgs[&region] = regionLiveArgs;
     }
   };
@@ -731,7 +766,7 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
 ///     as well as their corresponding arguments.
 
 static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la,
-                            DenseSet<Value> &nonLiveSet,
+                            DenseSet<Value> &nonLiveSet, DenseSet<Value> &liveSet,
                             RDVFinalCleanupList &cl) {
   LDBG() << "Processing branch op: " << *branchOp;
   unsigned numSuccessors = branchOp->getNumSuccessors();
@@ -750,7 +785,7 @@ static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la,
 
     // Do (2)
     BitVector successorNonLive =
-        markLives(operandValues, nonLiveSet, la).flip();
+        markLives(operandValues, nonLiveSet, liveSet, la).flip();
     collectNonLiveValues(nonLiveSet, successorBlock->getArguments(),
                          successorNonLive);
 
@@ -910,7 +945,7 @@ void RemoveDeadValues::runOnOperation() {
   // Tracks values eligible for erasure - complements liveness analysis to
   // identify "droppable" values.
   DenseSet<Value> deadVals;
-  // mark outgoing arguments to a public function LIVE.
+  // mark outgoing arguments to a public function LIVE. We also propagate liveness backward.
   DenseSet<Value> liveVals;
 
   // Maintains a list of Ops, values, branches, etc., slated for cleanup at the
@@ -919,23 +954,17 @@ void RemoveDeadValues::runOnOperation() {
 
   module->walk<WalkOrder::PostOrder, BackwardIterator>([&](Operation *op) {
     if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
-      processFuncOp(funcOp, module, la, deadVals, finalCleanupList);
+      processFuncOp(funcOp, module, la, deadVals, liveVals, finalCleanupList);
     } else if (auto regionBranchOp = dyn_cast<RegionBranchOpInterface>(op)) {
       processRegionBranchOp(regionBranchOp, la, deadVals, liveVals,
                             finalCleanupList);
     } else if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
-      processBranchOp(branchOp, la, deadVals, finalCleanupList);
+      processBranchOp(branchOp, la, deadVals, liveVals, finalCleanupList);
     } else if (op->hasTrait<::mlir::OpTrait::IsTerminator>()) {
       // Nothing to do here because this is a terminator op and it should be
       // honored with respect to its parent
     } else if (isa<CallOpInterface>(op)) {
-      // Nothing to do because this op is associated with a function op and gets
-      // cleaned when the latter is cleaned.
-      //
-      // The only exception is public callee. By default, Liveness analysis is
-      // inter-procedural. Unused arguments of a public function nonLive and are
-      // propagated to the caller. processCallOp puts them to liveVals.
-      processCallOp(cast<CallOpInterface>(op), module, la, liveVals);
+      processCallOp(cast<CallOpInterface>(op), module, la, deadVals, liveVals);
     } else {
       processSimpleOp(op, la, deadVals, liveVals, finalCleanupList);
     }

>From 3f4d69f8910007656105a8798b0a5934310df489 Mon Sep 17 00:00:00 2001
From: Xin Liu <xxinliu at meta.com>
Date: Sun, 5 Oct 2025 20:45:24 -0700
Subject: [PATCH 04/11] Update the lit test and formatter.

---
 mlir/lib/Transforms/RemoveDeadValues.cpp     | 37 +++++++++++---------
 mlir/test/Transforms/remove-dead-values.mlir | 22 ++++++------
 2 files changed, 32 insertions(+), 27 deletions(-)

diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index 621ec7b3827d3..ed5c6a8d2ead0 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -119,8 +119,7 @@ struct RDVFinalCleanupList {
 /// Return true iff at least one value in `values` is live, given the liveness
 /// information in `la`.
 static bool hasLive(ValueRange values, const DenseSet<Value> &nonLiveSet,
-                    const DenseSet<Value> &liveSet,
-                    RunLivenessAnalysis &la) {
+                    const DenseSet<Value> &liveSet, RunLivenessAnalysis &la) {
   for (Value value : values) {
     if (liveSet.contains(value)) {
       LDBG() << "Value " << value << " is marked live by CallOp";
@@ -151,7 +150,8 @@ static bool hasLive(ValueRange values, const DenseSet<Value> &nonLiveSet,
 /// Return a BitVector of size `values.size()` where its i-th bit is 1 iff the
 /// i-th value in `values` is live, given the liveness information in `la`.
 static BitVector markLives(ValueRange values, const DenseSet<Value> &nonLiveSet,
-                           const DenseSet<Value> &liveSet, RunLivenessAnalysis &la) {
+                           const DenseSet<Value> &liveSet,
+                           RunLivenessAnalysis &la) {
   BitVector lives(values.size(), true);
 
   for (auto [index, value] : llvm::enumerate(values)) {
@@ -269,7 +269,7 @@ static SmallVector<OpOperand *> operandsToOpOperands(OperandRange operands) {
 static void processSimpleOp(Operation *op, RunLivenessAnalysis &la,
                             DenseSet<Value> &nonLiveSet,
                             DenseSet<Value> &liveSet, RDVFinalCleanupList &cl) {
-  for (Value val: op->getResults()) {
+  for (Value val : op->getResults()) {
     if (liveSet.contains(val)) {
       LDBG() << "Simple op is used by a public function, "
                 "preserving it: "
@@ -307,8 +307,8 @@ static void processSimpleOp(Operation *op, RunLivenessAnalysis &la,
 ///       removal.
 ///   (6) Marking all its results as non-live values.
 static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
-                          RunLivenessAnalysis &la, DenseSet<Value> &nonLiveSet, DenseSet<Value> &liveSet,
-                          RDVFinalCleanupList &cl) {
+                          RunLivenessAnalysis &la, DenseSet<Value> &nonLiveSet,
+                          DenseSet<Value> &liveSet, RDVFinalCleanupList &cl) {
   LDBG() << "Processing function op: "
          << OpWithFlags(funcOp, OpPrintingFlags().skipRegions());
   if (funcOp.isPublic() || funcOp.isExternal()) {
@@ -372,7 +372,8 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
   for (SymbolTable::SymbolUse use : uses) {
     Operation *callOp = use.getUser();
     assert(isa<CallOpInterface>(callOp) && "expected a call-like user");
-    BitVector liveCallRets = markLives(callOp->getResults(), nonLiveSet, liveSet, la);
+    BitVector liveCallRets =
+        markLives(callOp->getResults(), nonLiveSet, liveSet, la);
     nonLiveRets &= liveCallRets.flip();
   }
 
@@ -398,20 +399,22 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
     collectNonLiveValues(nonLiveSet, callOp->getResults(), nonLiveRets);
   }
 }
-// create a cheaper value with the same type of oldVal in front of CallOp.
+
+// Create a cheaper value with the same type of oldVal in front of CallOp.
 static Value createDummyArgument(CallOpInterface callOp, Value oldVal) {
   OpBuilder builder(callOp.getOperation());
   Type type = oldVal.getType();
 
   // Create zero constant for any supported type
   if (TypedAttr zeroAttr = builder.getZeroAttr(type)) {
-      return builder.create<arith::ConstantOp>(oldVal.getLoc(), type, zeroAttr);
+    return builder.create<arith::ConstantOp>(oldVal.getLoc(), type, zeroAttr);
   }
   return {};
 }
 
 static void processCallOp(CallOpInterface callOp, Operation *module,
-                          RunLivenessAnalysis &la, DenseSet<Value> &nonLiveSet, DenseSet<Value> &liveSet) {
+                          RunLivenessAnalysis &la, DenseSet<Value> &nonLiveSet,
+                          DenseSet<Value> &liveSet) {
   if (!la.getSolverConfig().isInterprocedural())
     return;
 
@@ -420,7 +423,8 @@ static void processCallOp(CallOpInterface callOp, Operation *module,
   if (!funcOp || !funcOp.isPublic()) {
     return;
   }
-  LDBG() << "processCallOp" << funcOp.getName();
+
+  LDBG() << "processCallOp to a public function: " << funcOp.getName();
   // Get the list of unnecessary (non-live) arguments in `nonLiveArgs`.
   SmallVector<Value> arguments(funcOp.getArguments());
   BitVector nonLiveArgs = markLives(arguments, nonLiveSet, liveSet, la);
@@ -443,7 +447,6 @@ static void processCallOp(CallOpInterface callOp, Operation *module,
         liveSet.insert(oldVal);
       }
     }
-    LDBG() << "after changed " << callOp;
   }
 }
 
@@ -485,7 +488,8 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
          << OpWithFlags(regionBranchOp, OpPrintingFlags().skipRegions());
   // Mark live results of `regionBranchOp` in `liveResults`.
   auto markLiveResults = [&](BitVector &liveResults) {
-    liveResults = markLives(regionBranchOp->getResults(), nonLiveSet, liveSet, la);
+    liveResults =
+        markLives(regionBranchOp->getResults(), nonLiveSet, liveSet, la);
   };
 
   // Mark live arguments in the regions of `regionBranchOp` in `liveArgs`.
@@ -766,8 +770,8 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
 ///     as well as their corresponding arguments.
 
 static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la,
-                            DenseSet<Value> &nonLiveSet, DenseSet<Value> &liveSet,
-                            RDVFinalCleanupList &cl) {
+                            DenseSet<Value> &nonLiveSet,
+                            DenseSet<Value> &liveSet, RDVFinalCleanupList &cl) {
   LDBG() << "Processing branch op: " << *branchOp;
   unsigned numSuccessors = branchOp->getNumSuccessors();
 
@@ -945,7 +949,8 @@ void RemoveDeadValues::runOnOperation() {
   // Tracks values eligible for erasure - complements liveness analysis to
   // identify "droppable" values.
   DenseSet<Value> deadVals;
-  // mark outgoing arguments to a public function LIVE. We also propagate liveness backward.
+  // mark outgoing arguments to a public function LIVE. We also propagate
+  // liveness backward.
   DenseSet<Value> liveVals;
 
   // Maintains a list of Ops, values, branches, etc., slated for cleanup at the
diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir
index fc857f2989f18..ebb76a8835ceb 100644
--- a/mlir/test/Transforms/remove-dead-values.mlir
+++ b/mlir/test/Transforms/remove-dead-values.mlir
@@ -570,21 +570,21 @@ module @return_void_with_unused_argument {
     return %unused : memref<4xi32>
   }
 
-  // the function is immutable because it is public.
-  func.func public @immutable_fn_return_void_with_unused_argument(%arg0: i32, %unused: i32) -> () {
-    %sum = arith.addi %arg0, %arg0 : i32
-    %c0 = arith.constant 0 : index
-    %buf = memref.alloc() : memref<1xi32>
-    memref.store %sum, %buf[%c0] : memref<1xi32>
+  // the function signature is immutable because it is public.
+  func.func public @immutable_fn_with_unused_argument(%arg0: i32, %arg1: memref<4xf32>) -> () {
     return
   }
+
   // CHECK-LABEL: func.func @main2
-  // CHECK-SAME: (%[[ARG0_MAIN:.*]]: i32)
+  // CHECK: %[[MEM:.*]] = memref.alloc() : memref<4xf32>
   // CHECK: %[[UNUSED:.*]] = arith.constant 0 : i32
-  // CHECK: call @immutable_fn_return_void_with_unused_argument(%[[ARG0_MAIN]], %[[UNUSED]]) : (i32, i32) -> ()
-  func.func @main2(%arg0: i32) -> () {
-    %zero = arith.constant 0 : i32
-    call @immutable_fn_return_void_with_unused_argument(%arg0, %zero) : (i32, i32) -> ()
+  // CHECK: call @immutable_fn_with_unused_argument(%[[UNUSED]], %[[MEM]]) : (i32, memref<4xf32>) -> ()
+  func.func @main2() -> () {
+    %one = arith.constant 1 : i32
+    %scalar = arith.addi %one, %one: i32
+    %mem = memref.alloc() : memref<4xf32>
+
+    call @immutable_fn_with_unused_argument(%scalar, %mem) : (i32, memref<4xf32>) -> ()
     return
   }
 }

>From 4904fb9ab8162477b13e7b0827696d11a65c9389 Mon Sep 17 00:00:00 2001
From: Xin Liu <xxinliu at meta.com>
Date: Tue, 7 Oct 2025 20:28:14 -0700
Subject: [PATCH 05/11] Add a LIT test containing SCF

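This diff also propagates liveness recursively: when a call operand is marked
live, its whole definition chain (including values yielded through
RegionBranchOpInterface regions) is marked live as well.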
---
 mlir/include/mlir/IR/Visitors.h              |  2 +-
 mlir/lib/Transforms/RemoveDeadValues.cpp     | 95 ++++++++++++++++----
 mlir/test/Transforms/remove-dead-values.mlir | 21 ++++-
 3 files changed, 99 insertions(+), 19 deletions(-)

diff --git a/mlir/include/mlir/IR/Visitors.h b/mlir/include/mlir/IR/Visitors.h
index 5766d262796d6..632db0a7b5cd4 100644
--- a/mlir/include/mlir/IR/Visitors.h
+++ b/mlir/include/mlir/IR/Visitors.h
@@ -45,7 +45,7 @@ struct BackwardIterator {
   static auto makeIterable(T &range) {
     if constexpr (std::is_same<T, Operation>()) {
       /// Make operations iterable: return the list of regions.
-      return llvm::reverse(range.getRegions());
+      return range.getRegions();
     } else {
       /// Regions and block are already iterable.
       return llvm::reverse(range);
diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index ed5c6a8d2ead0..ba2c56e5d9661 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -269,16 +269,6 @@ static SmallVector<OpOperand *> operandsToOpOperands(OperandRange operands) {
 static void processSimpleOp(Operation *op, RunLivenessAnalysis &la,
                             DenseSet<Value> &nonLiveSet,
                             DenseSet<Value> &liveSet, RDVFinalCleanupList &cl) {
-  for (Value val : op->getResults()) {
-    if (liveSet.contains(val)) {
-      LDBG() << "Simple op is used by a public function, "
-                "preserving it: "
-             << OpWithFlags(op, OpPrintingFlags().skipRegions());
-      liveSet.insert_range(op->getOperands());
-      return;
-    }
-  }
-
   if (!isMemoryEffectFree(op) ||
       hasLive(op->getResults(), nonLiveSet, liveSet, la)) {
     LDBG() << "Simple op is not memory effect free or has live results, "
@@ -412,6 +402,82 @@ static Value createDummyArgument(CallOpInterface callOp, Value oldVal) {
   return {};
 }
 
+// When you mark a call operand as live, also mark its definition chain, recursively.
+// We handle RegionBranchOpInterface here. I think we should handle BranchOpInterface as well.
+void propagateBackward(Value val, DenseSet<Value> &liveSet) {
+  if (liveSet.contains(val)) return;
+  liveSet.insert(val);
+
+  if (auto defOp = val.getDefiningOp()) {
+    // Mark operands of live results as live
+    for (Value operand : defOp->getOperands()) {
+      propagateBackward(operand, liveSet);
+    }
+
+    // Handle RegionBranchOpInterface specially
+    if (auto regionBranchOp = dyn_cast<RegionBranchOpInterface>(defOp)) {
+      // If this is a result of a RegionBranchOpInterface, we need to trace back
+      // through the control flow to find the sources that contribute to this result
+
+      OpResult result = cast<OpResult>(val);
+      unsigned resultIndex = result.getResultNumber();
+
+      // Find all possible sources that can contribute to this result
+      // by examining all regions and their terminators
+      for (Region &region : regionBranchOp->getRegions()) {
+        if (region.empty()) continue;
+
+        // Get the successors from this region
+        SmallVector<RegionSuccessor> successors;
+        regionBranchOp.getSuccessorRegions(RegionBranchPoint(&region), successors);
+
+        // Check if any successor can produce this result
+        for (const RegionSuccessor &successor : successors) {
+          if (successor.isParent()) {
+            // This region can return to the parent operation
+            ValueRange successorInputs = successor.getSuccessorInputs();
+            if (resultIndex < successorInputs.size()) {
+              // Find the terminator that contributes to this result
+              Operation *terminator = region.back().getTerminator();
+              if (auto regionBranchTerm =
+                  dyn_cast<RegionBranchTerminatorOpInterface>(terminator)) {
+                OperandRange terminatorOperands =
+                    regionBranchTerm.getSuccessorOperands(RegionBranchPoint::parent());
+                if (resultIndex < terminatorOperands.size()) {
+                  // This terminator operand contributes to our result
+                  propagateBackward(terminatorOperands[resultIndex], liveSet);
+                }
+              }
+            }
+          }
+        }
+
+        // Also mark region arguments as live if they might contribute to this result
+        // Find which operand of the parent operation corresponds to region arguments
+        Block &entryBlock = region.front();
+        for (BlockArgument arg : entryBlock.getArguments()) {
+          // Get entry successor operands - these are the operands that flow
+          // from the parent operation to this region
+          SmallVector<RegionSuccessor> entrySuccessors;
+          regionBranchOp.getSuccessorRegions(RegionBranchPoint::parent(), entrySuccessors);
+
+          for (const RegionSuccessor &entrySuccessor : entrySuccessors) {
+            if (entrySuccessor.getSuccessor() == &region) {
+              // Get the operands that are forwarded to this region
+              OperandRange entryOperands =
+                  regionBranchOp.getEntrySuccessorOperands(RegionBranchPoint::parent());
+              unsigned argIndex = arg.getArgNumber();
+              if (argIndex < entryOperands.size()) {
+                propagateBackward(entryOperands[argIndex], liveSet);
+              }
+              break;
+            }
+          }
+        }
+      }
+    }
+  }
+}
 static void processCallOp(CallOpInterface callOp, Operation *module,
                           RunLivenessAnalysis &la, DenseSet<Value> &nonLiveSet,
                           DenseSet<Value> &liveSet) {
@@ -439,13 +505,8 @@ static void processCallOp(CallOpInterface callOp, Operation *module,
 
     for (int index : nonLiveArgs.set_bits()) {
       OpOperand *operand = callOpOperands[index];
-      Value oldVal = operand->get();
-      if (Value dummy = createDummyArgument(callOp, oldVal)) {
-        callOp->setOperand(operand->getOperandNumber(), dummy);
-        nonLiveSet.insert(oldVal);
-      } else {
-        liveSet.insert(oldVal);
-      }
+      LDBG() << "mark operand " << index << " live " << operand->get();
+      propagateBackward(operand->get(), liveSet);
     }
   }
 }
diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir
index ebb76a8835ceb..a60efa45fe943 100644
--- a/mlir/test/Transforms/remove-dead-values.mlir
+++ b/mlir/test/Transforms/remove-dead-values.mlir
@@ -576,8 +576,9 @@ module @return_void_with_unused_argument {
   }
 
   // CHECK-LABEL: func.func @main2
+  // CHECK: %[[ONE:.*]] = arith.constant 1 : i32
+  // CHECK: %[[UNUSED:.*]] = arith.addi %[[ONE]], %[[ONE]] : i32
   // CHECK: %[[MEM:.*]] = memref.alloc() : memref<4xf32>
-  // CHECK: %[[UNUSED:.*]] = arith.constant 0 : i32
   // CHECK: call @immutable_fn_with_unused_argument(%[[UNUSED]], %[[MEM]]) : (i32, memref<4xf32>) -> ()
   func.func @main2() -> () {
     %one = arith.constant 1 : i32
@@ -587,6 +588,24 @@ module @return_void_with_unused_argument {
     call @immutable_fn_with_unused_argument(%scalar, %mem) : (i32, memref<4xf32>) -> ()
     return
   }
+
+  // CHECK-LABEL: func.func @main3
+  // CHECK: %[[UNUSED:.*]] = scf.if %arg0 -> (i32)
+  // CHECK: %[[MEM:.*]] = memref.alloc() : memref<4xf32>
+  // CHECK: call @immutable_fn_with_unused_argument(%[[UNUSED]], %[[MEM]]) : (i32, memref<4xf32>) -> ()
+  func.func @main3(%arg0: i1) {
+    %0 = scf.if %arg0 -> (i32) {
+      %c1_i32 = arith.constant 1 : i32
+      scf.yield %c1_i32 : i32
+    } else {
+      %c0_i32 = arith.constant 0 : i32
+      scf.yield %c0_i32 : i32
+    }
+    %mem = memref.alloc() : memref<4xf32>
+
+    call @immutable_fn_with_unused_argument(%0, %mem) : (i32, memref<4xf32>) -> ()
+    return
+  }
 }
 
 // -----

>From 8f92fee035fd1fdbda438481975770ecff2a1651 Mon Sep 17 00:00:00 2001
From: Xin Liu <xxinliu at meta.com>
Date: Thu, 9 Oct 2025 20:28:52 -0700
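Subject: [PATCH 06/11] [MLIR][RemoveDeadValues] Privatize public functions
 with NonLive arguments before RDV.

When liveness analysis is interprocedural, every public function that is
called with at least one NonLive argument is first cloned into a private
copy (`__<name>_privatized`). All of its callsites are redirected to that
copy, and the original public function becomes a thin wrapper that forwards
to it. RunLivenessAnalysis is then invalidated and recomputed, so the regular
RemoveDeadValues walk can drop the unused arguments of the private copy. This
replaces the previous approach of marking call operands (and their definition
chains) live.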
---
 .../mlir/Analysis/DataFlow/LivenessAnalysis.h |   7 +-
 mlir/include/mlir/IR/Visitors.h               |  14 -
 .../Analysis/DataFlow/LivenessAnalysis.cpp    |   1 +
 mlir/lib/Transforms/RemoveDeadValues.cpp      | 330 +++++++++---------
 mlir/test/Transforms/remove-dead-values.mlir  |  51 +--
 5 files changed, 211 insertions(+), 192 deletions(-)

diff --git a/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h
index be7e027b95f64..7bf3ada5847e6 100644
--- a/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h
@@ -24,6 +24,7 @@
 #define MLIR_ANALYSIS_DATAFLOW_LIVENESSANALYSIS_H
 
 #include <mlir/Analysis/DataFlow/SparseAnalysis.h>
+#include <mlir/Pass/AnalysisManager.h>
 #include <optional>
 
 namespace mlir::dataflow {
@@ -101,13 +102,17 @@ struct RunLivenessAnalysis {
   RunLivenessAnalysis(Operation *op);
 
   const Liveness *getLiveness(Value val);
-
+  // This only mark Liveness results are stale.
+  void invalidate() { valid = false; }
   /// Return the configuration of the solver used for this analysis.
   const DataFlowConfig &getSolverConfig() const { return solver.getConfig(); }
+  /// The function is called by analysis_impl::isInvalidated.
+  bool isInvalidated(AnalysisManager::PreservedAnalyses&) const { return !valid; }
 
 private:
   /// Stores the result of the liveness analysis that was run.
   DataFlowSolver solver;
+  bool valid {true};
 };
 
 } // end namespace mlir::dataflow
diff --git a/mlir/include/mlir/IR/Visitors.h b/mlir/include/mlir/IR/Visitors.h
index 632db0a7b5cd4..893f66ae33deb 100644
--- a/mlir/include/mlir/IR/Visitors.h
+++ b/mlir/include/mlir/IR/Visitors.h
@@ -39,20 +39,6 @@ struct ForwardIterator {
   }
 };
 
-/// This iterator enumerates the elements in "backward" order.
-struct BackwardIterator {
-  template <typename T>
-  static auto makeIterable(T &range) {
-    if constexpr (std::is_same<T, Operation>()) {
-      /// Make operations iterable: return the list of regions.
-      return range.getRegions();
-    } else {
-      /// Regions and block are already iterable.
-      return llvm::reverse(range);
-    }
-  }
-};
-
 /// A utility class to encode the current walk stage for "generic" walkers.
 /// When walking an operation, we can either choose a Pre/Post order walker
 /// which invokes the callback on an operation before/after all its attached
diff --git a/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp b/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp
index d705d8d4c7819..943c60bda9de6 100644
--- a/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp
@@ -325,5 +325,6 @@ RunLivenessAnalysis::RunLivenessAnalysis(Operation *op) {
 }
 
 const Liveness *RunLivenessAnalysis::getLiveness(Value val) {
+  assert(valid && "getLiveness called after invalidate");
   return solver.lookupState<Liveness>(val);
 }
diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index ba2c56e5d9661..05d49d8beec39 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -33,7 +33,6 @@
 
 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
 #include "mlir/Analysis/DataFlow/LivenessAnalysis.h"
-#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/Dialect.h"
@@ -47,6 +46,7 @@
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/FunctionInterfaces.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Pass/AnalysisManager.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/FoldUtils.h"
@@ -119,13 +119,8 @@ struct RDVFinalCleanupList {
 /// Return true iff at least one value in `values` is live, given the liveness
 /// information in `la`.
 static bool hasLive(ValueRange values, const DenseSet<Value> &nonLiveSet,
-                    const DenseSet<Value> &liveSet, RunLivenessAnalysis &la) {
+                    RunLivenessAnalysis &la) {
   for (Value value : values) {
-    if (liveSet.contains(value)) {
-      LDBG() << "Value " << value << " is marked live by CallOp";
-      return true;
-    }
-
     if (nonLiveSet.contains(value)) {
       LDBG() << "Value " << value << " is already marked non-live (dead)";
       continue;
@@ -150,7 +145,6 @@ static bool hasLive(ValueRange values, const DenseSet<Value> &nonLiveSet,
 /// Return a BitVector of size `values.size()` where its i-th bit is 1 iff the
 /// i-th value in `values` is live, given the liveness information in `la`.
 static BitVector markLives(ValueRange values, const DenseSet<Value> &nonLiveSet,
-                           const DenseSet<Value> &liveSet,
                            RunLivenessAnalysis &la) {
   BitVector lives(values.size(), true);
 
@@ -161,9 +155,7 @@ static BitVector markLives(ValueRange values, const DenseSet<Value> &nonLiveSet,
              << " is already marked non-live (dead) at index " << index;
       continue;
     }
-    if (liveSet.contains(value)) {
-      continue;
-    }
+
     const Liveness *liveness = la.getLiveness(value);
     // It is important to note that when `liveness` is null, we can't tell if
     // `value` is live or not. So, the safe option is to consider it live. Also,
@@ -268,9 +260,8 @@ static SmallVector<OpOperand *> operandsToOpOperands(OperandRange operands) {
 ///   - Return-like
 static void processSimpleOp(Operation *op, RunLivenessAnalysis &la,
                             DenseSet<Value> &nonLiveSet,
-                            DenseSet<Value> &liveSet, RDVFinalCleanupList &cl) {
-  if (!isMemoryEffectFree(op) ||
-      hasLive(op->getResults(), nonLiveSet, liveSet, la)) {
+                            RDVFinalCleanupList &cl) {
+  if (!isMemoryEffectFree(op) || hasLive(op->getResults(), nonLiveSet, la)) {
     LDBG() << "Simple op is not memory effect free or has live results, "
               "preserving it: "
            << OpWithFlags(op, OpPrintingFlags().skipRegions());
@@ -298,7 +289,7 @@ static void processSimpleOp(Operation *op, RunLivenessAnalysis &la,
 ///   (6) Marking all its results as non-live values.
 static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
                           RunLivenessAnalysis &la, DenseSet<Value> &nonLiveSet,
-                          DenseSet<Value> &liveSet, RDVFinalCleanupList &cl) {
+                          RDVFinalCleanupList &cl) {
   LDBG() << "Processing function op: "
          << OpWithFlags(funcOp, OpPrintingFlags().skipRegions());
   if (funcOp.isPublic() || funcOp.isExternal()) {
@@ -309,7 +300,7 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
 
   // Get the list of unnecessary (non-live) arguments in `nonLiveArgs`.
   SmallVector<Value> arguments(funcOp.getArguments());
-  BitVector nonLiveArgs = markLives(arguments, nonLiveSet, liveSet, la);
+  BitVector nonLiveArgs = markLives(arguments, nonLiveSet, la);
   nonLiveArgs = nonLiveArgs.flip();
 
   // Do (1).
@@ -362,8 +353,7 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
   for (SymbolTable::SymbolUse use : uses) {
     Operation *callOp = use.getUser();
     assert(isa<CallOpInterface>(callOp) && "expected a call-like user");
-    BitVector liveCallRets =
-        markLives(callOp->getResults(), nonLiveSet, liveSet, la);
+    BitVector liveCallRets = markLives(callOp->getResults(), nonLiveSet, la);
     nonLiveRets &= liveCallRets.flip();
   }
 
@@ -390,127 +380,6 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
   }
 }
 
-// Create a cheaper value with the same type of oldVal in front of CallOp.
-static Value createDummyArgument(CallOpInterface callOp, Value oldVal) {
-  OpBuilder builder(callOp.getOperation());
-  Type type = oldVal.getType();
-
-  // Create zero constant for any supported type
-  if (TypedAttr zeroAttr = builder.getZeroAttr(type)) {
-    return builder.create<arith::ConstantOp>(oldVal.getLoc(), type, zeroAttr);
-  }
-  return {};
-}
-
-// When you mark a call operand as live, also mark its definition chain, recursively.
-// We handle RegionBranchOpInterface here. I think we should handle BranchOpInterface as well.
-void propagateBackward(Value val, DenseSet<Value> &liveSet) {
-  if (liveSet.contains(val)) return;
-  liveSet.insert(val);
-
-  if (auto defOp = val.getDefiningOp()) {
-    // Mark operands of live results as live
-    for (Value operand : defOp->getOperands()) {
-      propagateBackward(operand, liveSet);
-    }
-
-    // Handle RegionBranchOpInterface specially
-    if (auto regionBranchOp = dyn_cast<RegionBranchOpInterface>(defOp)) {
-      // If this is a result of a RegionBranchOpInterface, we need to trace back
-      // through the control flow to find the sources that contribute to this result
-
-      OpResult result = cast<OpResult>(val);
-      unsigned resultIndex = result.getResultNumber();
-
-      // Find all possible sources that can contribute to this result
-      // by examining all regions and their terminators
-      for (Region &region : regionBranchOp->getRegions()) {
-        if (region.empty()) continue;
-
-        // Get the successors from this region
-        SmallVector<RegionSuccessor> successors;
-        regionBranchOp.getSuccessorRegions(RegionBranchPoint(&region), successors);
-
-        // Check if any successor can produce this result
-        for (const RegionSuccessor &successor : successors) {
-          if (successor.isParent()) {
-            // This region can return to the parent operation
-            ValueRange successorInputs = successor.getSuccessorInputs();
-            if (resultIndex < successorInputs.size()) {
-              // Find the terminator that contributes to this result
-              Operation *terminator = region.back().getTerminator();
-              if (auto regionBranchTerm =
-                  dyn_cast<RegionBranchTerminatorOpInterface>(terminator)) {
-                OperandRange terminatorOperands =
-                    regionBranchTerm.getSuccessorOperands(RegionBranchPoint::parent());
-                if (resultIndex < terminatorOperands.size()) {
-                  // This terminator operand contributes to our result
-                  propagateBackward(terminatorOperands[resultIndex], liveSet);
-                }
-              }
-            }
-          }
-        }
-
-        // Also mark region arguments as live if they might contribute to this result
-        // Find which operand of the parent operation corresponds to region arguments
-        Block &entryBlock = region.front();
-        for (BlockArgument arg : entryBlock.getArguments()) {
-          // Get entry successor operands - these are the operands that flow
-          // from the parent operation to this region
-          SmallVector<RegionSuccessor> entrySuccessors;
-          regionBranchOp.getSuccessorRegions(RegionBranchPoint::parent(), entrySuccessors);
-
-          for (const RegionSuccessor &entrySuccessor : entrySuccessors) {
-            if (entrySuccessor.getSuccessor() == &region) {
-              // Get the operands that are forwarded to this region
-              OperandRange entryOperands =
-                  regionBranchOp.getEntrySuccessorOperands(RegionBranchPoint::parent());
-              unsigned argIndex = arg.getArgNumber();
-              if (argIndex < entryOperands.size()) {
-                propagateBackward(entryOperands[argIndex], liveSet);
-              }
-              break;
-            }
-          }
-        }
-      }
-    }
-  }
-}
-static void processCallOp(CallOpInterface callOp, Operation *module,
-                          RunLivenessAnalysis &la, DenseSet<Value> &nonLiveSet,
-                          DenseSet<Value> &liveSet) {
-  if (!la.getSolverConfig().isInterprocedural())
-    return;
-
-  Operation *callableOp = callOp.resolveCallable();
-  auto funcOp = dyn_cast<FunctionOpInterface>(callableOp);
-  if (!funcOp || !funcOp.isPublic()) {
-    return;
-  }
-
-  LDBG() << "processCallOp to a public function: " << funcOp.getName();
-  // Get the list of unnecessary (non-live) arguments in `nonLiveArgs`.
-  SmallVector<Value> arguments(funcOp.getArguments());
-  BitVector nonLiveArgs = markLives(arguments, nonLiveSet, liveSet, la);
-  nonLiveArgs = nonLiveArgs.flip();
-
-  if (nonLiveArgs.count() > 0) {
-    LDBG() << funcOp.getName() << " contains NonLive arguments";
-    // The number of operands in the call op may not match the number of
-    // arguments in the func op.
-    SmallVector<OpOperand *> callOpOperands =
-        operandsToOpOperands(callOp.getArgOperands());
-
-    for (int index : nonLiveArgs.set_bits()) {
-      OpOperand *operand = callOpOperands[index];
-      LDBG() << "mark operand " << index << " live " << operand->get();
-      propagateBackward(operand->get(), liveSet);
-    }
-  }
-}
-
 /// Process a region branch operation `regionBranchOp` using the liveness
 /// information in `la`. The processing involves two scenarios:
 ///
@@ -543,14 +412,12 @@ static void processCallOp(CallOpInterface callOp, Operation *module,
 static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
                                   RunLivenessAnalysis &la,
                                   DenseSet<Value> &nonLiveSet,
-                                  DenseSet<Value> &liveSet,
                                   RDVFinalCleanupList &cl) {
   LDBG() << "Processing region branch op: "
          << OpWithFlags(regionBranchOp, OpPrintingFlags().skipRegions());
   // Mark live results of `regionBranchOp` in `liveResults`.
   auto markLiveResults = [&](BitVector &liveResults) {
-    liveResults =
-        markLives(regionBranchOp->getResults(), nonLiveSet, liveSet, la);
+    liveResults = markLives(regionBranchOp->getResults(), nonLiveSet, la);
   };
 
   // Mark live arguments in the regions of `regionBranchOp` in `liveArgs`.
@@ -559,7 +426,7 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
       if (region.empty())
         continue;
       SmallVector<Value> arguments(region.front().getArguments());
-      BitVector regionLiveArgs = markLives(arguments, nonLiveSet, liveSet, la);
+      BitVector regionLiveArgs = markLives(arguments, nonLiveSet, la);
       liveArgs[&region] = regionLiveArgs;
     }
   };
@@ -753,7 +620,7 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
   // attributed to something else.
   // Do (1') and (2').
   if (isMemoryEffectFree(regionBranchOp.getOperation()) &&
-      !hasLive(regionBranchOp->getResults(), nonLiveSet, liveSet, la)) {
+      !hasLive(regionBranchOp->getResults(), nonLiveSet, la)) {
     cl.operations.push_back(regionBranchOp.getOperation());
     return;
   }
@@ -832,7 +699,7 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
 
 static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la,
                             DenseSet<Value> &nonLiveSet,
-                            DenseSet<Value> &liveSet, RDVFinalCleanupList &cl) {
+                            RDVFinalCleanupList &cl) {
   LDBG() << "Processing branch op: " << *branchOp;
   unsigned numSuccessors = branchOp->getNumSuccessors();
 
@@ -850,7 +717,7 @@ static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la,
 
     // Do (2)
     BitVector successorNonLive =
-        markLives(operandValues, nonLiveSet, liveSet, la).flip();
+        markLives(operandValues, nonLiveSet, la).flip();
     collectNonLiveValues(nonLiveSet, successorBlock->getArguments(),
                          successorNonLive);
 
@@ -1003,36 +870,185 @@ struct RemoveDeadValues : public impl::RemoveDeadValuesBase<RemoveDeadValues> {
 };
 } // namespace
 
+/// If the target of CallOP is a public function and at least one argument is NonLive,
+/// privatize the function.
+/// Our strategy here is separation interface and implementation. eg.
+///
+/// public void foo(int unused){...}
+/// =>
+/// public void foo(int unused) {          // old function, interface
+///   return __foo_impl(unused);
+/// }
+///
+/// private void __foo_impl(int unused) { // the new private function, or implementation.
+///  ...                                  // the function body of the original function.
+/// }
+///
+/// Returns true if any IR changes were made, false otherwise.
+static bool processCallOp(CallOpInterface callOp, Operation *module,
+                          RunLivenessAnalysis &la) {
+  Operation *callableOp = callOp.resolveCallable();
+  auto funcOp = dyn_cast<FunctionOpInterface>(callableOp);
+  if (!funcOp || !funcOp.isPublic()) {
+    return false;
+  }
+
+  LDBG() << "Processing callOp " << callOp
+         << " target is a public function: " << funcOp.getOperation()->getName();
+
+  // Get the list of unnecessary (non-live) arguments in `nonLiveArgs`.
+  SmallVector<Value> arguments(callOp.getArgOperands());
+  BitVector nonLiveArgs = markLives(arguments, DenseSet<Value>(), la);
+  nonLiveArgs = nonLiveArgs.flip();
+
+  if (nonLiveArgs.count() > 0) {
+    auto moduleOp = cast<ModuleOp>(module);
+    OpBuilder rewriter(moduleOp.getContext());
+
+    // Clone function and create private version
+    FunctionOpInterface clonedFunc = cast<FunctionOpInterface>(funcOp.clone());
+
+    // Set visibility = 'private' and a new name for the cloned function
+    SymbolTable::setSymbolVisibility(clonedFunc,
+                                  SymbolTable::Visibility::Private);
+    std::string newName = "__" + funcOp.getName().str() + "_privatized";
+    clonedFunc.setName(newName);
+
+    // Insert the cloned function into the module
+    rewriter.setInsertionPointAfter(funcOp);
+    rewriter.insert(clonedFunc);
+
+    // Replace ALL callsites of the original function to call the cloned function directly
+    LogicalResult result = SymbolTable::replaceAllSymbolUses(
+        funcOp,
+        clonedFunc.getNameAttr(),
+        moduleOp
+    );
+
+    if (result.failed()) {
+      LDBG() << "Failed to replace all symbol uses for " << funcOp.getName();
+      return false;
+    }
+
+    LDBG() << "Redirected all callsites from " << funcOp.getName()
+           << " to " << newName;
+
+    // Transform the original funcOp into a wrapper that calls the cloned function
+    Region &funcBody = funcOp.getFunctionBody();
+
+    // Clean the original function body
+    funcBody.dropAllReferences();
+    funcBody.getBlocks().clear();
+
+    // Create a new entry block for the wrapper function
+    Block *wrapperBlock = rewriter.createBlock(&funcBody);
+
+    // Add block arguments that match the function signature
+    for (Type argType : funcOp.getArgumentTypes()) {
+      wrapperBlock->addArgument(argType, funcOp.getLoc());
+    }
+
+    // Set insertion point to the new block
+    rewriter.setInsertionPointToStart(wrapperBlock);
+
+    // Clone the original call operation and update its callee
+    Operation *clonedCallOp = callOp->clone();
+
+    // Update the callee symbol reference to point to the new private function
+    if (auto callableOp = dyn_cast<CallOpInterface>(clonedCallOp)) {
+      auto symbolRef = SymbolRefAttr::get(funcOp.getContext(), clonedFunc.getName());
+      callableOp.setCalleeFromCallable(symbolRef);
+    }
+
+    // Set the call arguments to use the wrapper block's arguments
+    clonedCallOp->setOperands(wrapperBlock->getArguments());
+
+    // Insert the cloned call operation
+    rewriter.insert(clonedCallOp);
+
+    // Create return operation of the same type as the original function's return
+    SmallVector<Operation*> returnOps;
+    for (Block &block : clonedFunc.getFunctionBody()) {
+      if (block.getNumSuccessors() > 0)
+        continue;
+
+      Operation *terminator = block.getTerminator();
+      if (terminator && terminator->hasTrait<OpTrait::ReturnLike>()) {
+        returnOps.push_back(terminator);
+        break; // Use first return as template
+      }
+    }
+
+    if (!returnOps.empty()) {
+      Operation *templateReturnOp = returnOps[0];
+      Operation *newReturnOp = templateReturnOp->clone();
+      newReturnOp->setOperands(clonedCallOp->getResults());
+      newReturnOp->setLoc(funcOp.getLoc());
+      rewriter.insert(newReturnOp);
+    }
+
+    LDBG() << "Created wrapper function " << funcOp.getName()
+           << " that calls " << newName;
+
+    return true; // Changes were made
+  }
+
+  return false; // No changes made
+}
+
 void RemoveDeadValues::runOnOperation() {
-  auto &la = getAnalysis<RunLivenessAnalysis>();
+  AnalysisManager am = getAnalysisManager();
+  RunLivenessAnalysis *la = &am.getAnalysis<RunLivenessAnalysis>();
   Operation *module = getOperation();
 
+  // Only privatize public funciton if liveness analysis is inter-procedural.
+  if (la->getSolverConfig().isInterprocedural()) {
+    bool changed = false;
+    module->walk([&](CallOpInterface callOp) {
+      if (processCallOp(callOp, module, *la)) {
+        changed = true;
+      }
+    });
+
+    if (changed) {
+      LDBG() << "IR has changed, invalidate RunLivenessAnalysis only";
+      auto & pa = getPassState().preservedAnalyses;
+      bool preserved = pa.isPreserved<RunLivenessAnalysis>();
+      la->invalidate();
+      am.invalidate(pa);
+
+      la = &am.getAnalysis<RunLivenessAnalysis>();
+      // if RunLivenessAnalysis was previously preserved, preserved the updated results.
+      if (preserved) {
+        pa.preserve<RunLivenessAnalysis>();
+      }
+    }
+  }
+
+
   // Tracks values eligible for erasure - complements liveness analysis to
   // identify "droppable" values.
   DenseSet<Value> deadVals;
-  // mark outgoing arguments to a public function LIVE. We also propagate
-  // liveness backward.
-  DenseSet<Value> liveVals;
 
   // Maintains a list of Ops, values, branches, etc., slated for cleanup at the
   // end of this pass.
   RDVFinalCleanupList finalCleanupList;
 
-  module->walk<WalkOrder::PostOrder, BackwardIterator>([&](Operation *op) {
+  module->walk([&](Operation *op) {
     if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
-      processFuncOp(funcOp, module, la, deadVals, liveVals, finalCleanupList);
+      processFuncOp(funcOp, module, *la, deadVals, finalCleanupList);
     } else if (auto regionBranchOp = dyn_cast<RegionBranchOpInterface>(op)) {
-      processRegionBranchOp(regionBranchOp, la, deadVals, liveVals,
-                            finalCleanupList);
+      processRegionBranchOp(regionBranchOp, *la, deadVals, finalCleanupList);
     } else if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
-      processBranchOp(branchOp, la, deadVals, liveVals, finalCleanupList);
+      processBranchOp(branchOp, *la, deadVals, finalCleanupList);
     } else if (op->hasTrait<::mlir::OpTrait::IsTerminator>()) {
       // Nothing to do here because this is a terminator op and it should be
       // honored with respect to its parent
     } else if (isa<CallOpInterface>(op)) {
-      processCallOp(cast<CallOpInterface>(op), module, la, deadVals, liveVals);
+      // Nothing to do because this op is associated with a function op and gets
+      // cleaned when the latter is cleaned.
     } else {
-      processSimpleOp(op, la, deadVals, liveVals, finalCleanupList);
+      processSimpleOp(op, *la, deadVals, finalCleanupList);
     }
   });
 
diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir
index a60efa45fe943..bcac4638604fa 100644
--- a/mlir/test/Transforms/remove-dead-values.mlir
+++ b/mlir/test/Transforms/remove-dead-values.mlir
@@ -569,31 +569,25 @@ module @return_void_with_unused_argument {
     call @fn_return_void_with_unused_argument(%arg0, %unused) : (i32, memref<4xi32>) -> ()
     return %unused : memref<4xi32>
   }
+}
 
+// Check that public functions with non-live arguments are handled correctly.
+module @public_function_with_nonlive_arguments {
   // the function signature is immutable because it is public.
-  func.func public @immutable_fn_with_unused_argument(%arg0: i32, %arg1: memref<4xf32>) -> () {
+  func.func public @public_fn_with_unused_argument(%unused: i32) -> () {
     return
   }
-
-  // CHECK-LABEL: func.func @main2
-  // CHECK: %[[ONE:.*]] = arith.constant 1 : i32
-  // CHECK: %[[UNUSED:.*]] = arith.addi %[[ONE]], %[[ONE]] : i32
-  // CHECK: %[[MEM:.*]] = memref.alloc() : memref<4xf32>
-  // CHECK: call @immutable_fn_with_unused_argument(%[[UNUSED]], %[[MEM]]) : (i32, memref<4xf32>) -> ()
-  func.func @main2() -> () {
-    %one = arith.constant 1 : i32
-    %scalar = arith.addi %one, %one: i32
-    %mem = memref.alloc() : memref<4xf32>
-
-    call @immutable_fn_with_unused_argument(%scalar, %mem) : (i32, memref<4xf32>) -> ()
+  // CHECK-LABEL: func.func @main
+  // CHECK: call @__public_fn_with_unused_argument_privatized() : () -> ()
+  func.func @main() -> () {
+    %zero = arith.constant 0 : i32
+    call @public_fn_with_unused_argument(%zero) : (i32) -> ()
     return
   }
 
-  // CHECK-LABEL: func.func @main3
-  // CHECK: %[[UNUSED:.*]] = scf.if %arg0 -> (i32)
-  // CHECK: %[[MEM:.*]] = memref.alloc() : memref<4xf32>
-  // CHECK: call @immutable_fn_with_unused_argument(%[[UNUSED]], %[[MEM]]) : (i32, memref<4xf32>) -> ()
-  func.func @main3(%arg0: i1) {
+  // CHECK-LABEL: func.func @main2
+  // CHECK: call @__public_fn_with_unused_argument_privatized() : () -> ()
+  func.func @main2(%arg0: i1) {
     %0 = scf.if %arg0 -> (i32) {
       %c1_i32 = arith.constant 1 : i32
       scf.yield %c1_i32 : i32
@@ -601,9 +595,26 @@ module @return_void_with_unused_argument {
       %c0_i32 = arith.constant 0 : i32
       scf.yield %c0_i32 : i32
     }
-    %mem = memref.alloc() : memref<4xf32>
 
-    call @immutable_fn_with_unused_argument(%0, %mem) : (i32, memref<4xf32>) -> ()
+    call @public_fn_with_unused_argument(%0) : (i32) -> ()
+    return
+  }
+
+  func.func public @fn_return_multiple(%arg0: i32) -> (i32, i32, i32) {
+    %one = arith.constant 1 : i32
+    %two = arith.constant 2 : i32
+    %three = arith.constant 4 : i32
+
+    return %one, %two, %three: i32, i32, i32
+  }
+
+  // CHECK-LABEL: func.func @main3
+  // CHECK: call @__fn_return_multiple_privatized() : () -> (i32, i32, i32)
+  func.func @main3(%arg: i32) -> () {
+    %one = arith.constant 1 : i32
+    %scalar = arith.addi %arg, %one: i32
+
+    call @fn_return_multiple(%scalar) : (i32) -> (i32, i32, i32)
     return
   }
 }

>From 443dff206e45b8aecb1f78e86ab9a80d1a815a2d Mon Sep 17 00:00:00 2001
From: Xin Liu <xxinliu at meta.com>
Date: Fri, 10 Oct 2025 00:14:25 -0700
Subject: [PATCH 07/11] Update format.

---
 .../mlir/Analysis/DataFlow/LivenessAnalysis.h |  8 ++-
 mlir/lib/Transforms/RemoveDeadValues.cpp      | 70 ++++++++-----------
 2 files changed, 36 insertions(+), 42 deletions(-)

diff --git a/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h
index 7bf3ada5847e6..20162e6586600 100644
--- a/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h
@@ -102,17 +102,19 @@ struct RunLivenessAnalysis {
   RunLivenessAnalysis(Operation *op);
 
   const Liveness *getLiveness(Value val);
-  // This only mark Liveness results are stale.
+  // This only remarks that Liveness results are stale.
   void invalidate() { valid = false; }
   /// Return the configuration of the solver used for this analysis.
   const DataFlowConfig &getSolverConfig() const { return solver.getConfig(); }
   /// The function is called by analysis_impl::isInvalidated.
-  bool isInvalidated(AnalysisManager::PreservedAnalyses&) const { return !valid; }
+  bool isInvalidated(AnalysisManager::PreservedAnalyses &) const {
+    return !valid;
+  }
 
 private:
   /// Stores the result of the liveness analysis that was run.
   DataFlowSolver solver;
-  bool valid {true};
+  bool valid{true};
 };
 
 } // end namespace mlir::dataflow
diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index 05d49d8beec39..0b0d2c1485c1d 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -870,18 +870,20 @@ struct RemoveDeadValues : public impl::RemoveDeadValuesBase<RemoveDeadValues> {
 };
 } // namespace
 
-/// If the target of CallOP is a public function and at least one argument is NonLive,
-/// privatize the function.
-/// Our strategy here is separation interface and implementation. eg.
+/// If the target of CallOp is a public function and at least one argument is
+/// NonLive, privatize the function. Our strategy here is separation interface
+/// and implementation. eg.
 ///
 /// public void foo(int unused){...}
 /// =>
 /// public void foo(int unused) {          // old function, interface
-///   return __foo_impl(unused);
+///   return __foo_privatized(unused);
 /// }
 ///
-/// private void __foo_impl(int unused) { // the new private function, or implementation.
-///  ...                                  // the function body of the original function.
+/// private void __foo_privatized(int unused) { // the new private function, or
+/// implementation.
+///  ...                                        // the function body of the
+///  original function.
 /// }
 ///
 /// Returns true if any IR changes were made, false otherwise.
@@ -893,8 +895,8 @@ static bool processCallOp(CallOpInterface callOp, Operation *module,
     return false;
   }
 
-  LDBG() << "Processing callOp " << callOp
-         << " target is a public function: " << funcOp.getOperation()->getName();
+  LDBG() << "Processing callOp " << callOp << " target is a public function: "
+         << funcOp.getOperation()->getName();
 
   // Get the list of unnecessary (non-live) arguments in `nonLiveArgs`.
   SmallVector<Value> arguments(callOp.getArgOperands());
@@ -910,7 +912,7 @@ static bool processCallOp(CallOpInterface callOp, Operation *module,
 
     // Set visibility = 'private' and a new name for the cloned function
     SymbolTable::setSymbolVisibility(clonedFunc,
-                                  SymbolTable::Visibility::Private);
+                                     SymbolTable::Visibility::Private);
     std::string newName = "__" + funcOp.getName().str() + "_privatized";
     clonedFunc.setName(newName);
 
@@ -918,22 +920,21 @@ static bool processCallOp(CallOpInterface callOp, Operation *module,
     rewriter.setInsertionPointAfter(funcOp);
     rewriter.insert(clonedFunc);
 
-    // Replace ALL callsites of the original function to call the cloned function directly
+    // Replace ALL callsites of the original function to call the cloned
+    // function directly
     LogicalResult result = SymbolTable::replaceAllSymbolUses(
-        funcOp,
-        clonedFunc.getNameAttr(),
-        moduleOp
-    );
+        funcOp, clonedFunc.getNameAttr(), moduleOp);
 
     if (result.failed()) {
       LDBG() << "Failed to replace all symbol uses for " << funcOp.getName();
       return false;
     }
 
-    LDBG() << "Redirected all callsites from " << funcOp.getName()
-           << " to " << newName;
+    LDBG() << "Redirected all callsites from " << funcOp.getName() << " to "
+           << newName;
 
-    // Transform the original funcOp into a wrapper that calls the cloned function
+    // Transform the original funcOp into a wrapper that calls the cloned
+    // function
     Region &funcBody = funcOp.getFunctionBody();
 
     // Clean the original function body
@@ -952,44 +953,35 @@ static bool processCallOp(CallOpInterface callOp, Operation *module,
     rewriter.setInsertionPointToStart(wrapperBlock);
 
     // Clone the original call operation and update its callee
-    Operation *clonedCallOp = callOp->clone();
-
+    auto clonedCallOp = cast<CallOpInterface>(callOp->clone());
     // Update the callee symbol reference to point to the new private function
-    if (auto callableOp = dyn_cast<CallOpInterface>(clonedCallOp)) {
-      auto symbolRef = SymbolRefAttr::get(funcOp.getContext(), clonedFunc.getName());
-      callableOp.setCalleeFromCallable(symbolRef);
-    }
-
+    auto symbolRef =
+        SymbolRefAttr::get(funcOp.getContext(), clonedFunc.getName());
+    clonedCallOp.setCalleeFromCallable(symbolRef);
     // Set the call arguments to use the wrapper block's arguments
     clonedCallOp->setOperands(wrapperBlock->getArguments());
-
-    // Insert the cloned call operation
     rewriter.insert(clonedCallOp);
 
-    // Create return operation of the same type as the original function's return
-    SmallVector<Operation*> returnOps;
+    // Create return operation of the same type as the original function's
+    // return
+    Operation *returnOp = nullptr;
     for (Block &block : clonedFunc.getFunctionBody()) {
       if (block.getNumSuccessors() > 0)
         continue;
 
       Operation *terminator = block.getTerminator();
       if (terminator && terminator->hasTrait<OpTrait::ReturnLike>()) {
-        returnOps.push_back(terminator);
+        returnOp = terminator;
         break; // Use first return as template
       }
     }
 
-    if (!returnOps.empty()) {
-      Operation *templateReturnOp = returnOps[0];
-      Operation *newReturnOp = templateReturnOp->clone();
+    if (returnOp) {
+      Operation *newReturnOp = returnOp->clone();
       newReturnOp->setOperands(clonedCallOp->getResults());
       newReturnOp->setLoc(funcOp.getLoc());
       rewriter.insert(newReturnOp);
     }
-
-    LDBG() << "Created wrapper function " << funcOp.getName()
-           << " that calls " << newName;
-
     return true; // Changes were made
   }
 
@@ -1012,20 +1004,20 @@ void RemoveDeadValues::runOnOperation() {
 
     if (changed) {
       LDBG() << "IR has changed, invalidate RunLivenessAnalysis only";
-      auto & pa = getPassState().preservedAnalyses;
+      auto &pa = getPassState().preservedAnalyses;
       bool preserved = pa.isPreserved<RunLivenessAnalysis>();
       la->invalidate();
       am.invalidate(pa);
 
       la = &am.getAnalysis<RunLivenessAnalysis>();
-      // if RunLivenessAnalysis was previously preserved, preserved the updated results.
+      // If RunLivenessAnalysis was previously preserved, preserved the updated
+      // results.
       if (preserved) {
         pa.preserve<RunLivenessAnalysis>();
       }
     }
   }
 
-
   // Tracks values eligible for erasure - complements liveness analysis to
   // identify "droppable" values.
   DenseSet<Value> deadVals;

>From 12322aa286e7151bcc9f4cf3187cb617ce35e6a9 Mon Sep 17 00:00:00 2001
From: Xin Liu <xxinliu at meta.com>
Date: Fri, 10 Oct 2025 10:12:04 -0700
Subject: [PATCH 08/11] Update for formatter.  No trivial braces.

---
 mlir/lib/Transforms/RemoveDeadValues.cpp | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index 0b0d2c1485c1d..af5795019d9d8 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -891,9 +891,8 @@ static bool processCallOp(CallOpInterface callOp, Operation *module,
                           RunLivenessAnalysis &la) {
   Operation *callableOp = callOp.resolveCallable();
   auto funcOp = dyn_cast<FunctionOpInterface>(callableOp);
-  if (!funcOp || !funcOp.isPublic()) {
+  if (!funcOp || !funcOp.isPublic())
     return false;
-  }
 
   LDBG() << "Processing callOp " << callOp << " target is a public function: "
          << funcOp.getOperation()->getName();
@@ -985,7 +984,7 @@ static bool processCallOp(CallOpInterface callOp, Operation *module,
     return true; // Changes were made
   }
 
-  return false; // No changes made
+  return false;
 }
 
 void RemoveDeadValues::runOnOperation() {

>From f6a9b624d34995fcd71c36a56701254f3d84cf1f Mon Sep 17 00:00:00 2001
From: Xin Liu <xxinliu at meta.com>
Date: Fri, 10 Oct 2025 10:28:21 -0700
Subject: [PATCH 09/11] fix a typo

---
 mlir/lib/Transforms/RemoveDeadValues.cpp | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index af5795019d9d8..0f34ebf263e27 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -887,7 +887,7 @@ struct RemoveDeadValues : public impl::RemoveDeadValuesBase<RemoveDeadValues> {
 /// }
 ///
 /// Returns true if any IR changes were made, false otherwise.
-static bool processCallOp(CallOpInterface callOp, Operation *module,
+static bool processCallOp(CallOpInterface callOp, ModuleOp moduleOp,
                           RunLivenessAnalysis &la) {
   Operation *callableOp = callOp.resolveCallable();
   auto funcOp = dyn_cast<FunctionOpInterface>(callableOp);
@@ -903,7 +903,6 @@ static bool processCallOp(CallOpInterface callOp, Operation *module,
   nonLiveArgs = nonLiveArgs.flip();
 
   if (nonLiveArgs.count() > 0) {
-    auto moduleOp = cast<ModuleOp>(module);
     OpBuilder rewriter(moduleOp.getContext());
 
     // Clone function and create private version
@@ -992,11 +991,11 @@ void RemoveDeadValues::runOnOperation() {
   RunLivenessAnalysis *la = &am.getAnalysis<RunLivenessAnalysis>();
   Operation *module = getOperation();
 
-  // Only privatize public funciton if liveness analysis is inter-procedural.
+  // Only privatize public functions if liveness analysis is inter-procedural.
   if (la->getSolverConfig().isInterprocedural()) {
     bool changed = false;
     module->walk([&](CallOpInterface callOp) {
-      if (processCallOp(callOp, module, *la)) {
+      if (processCallOp(callOp, cast<ModuleOp>(module), *la)) {
         changed = true;
       }
     });
@@ -1007,7 +1006,6 @@ void RemoveDeadValues::runOnOperation() {
       bool preserved = pa.isPreserved<RunLivenessAnalysis>();
       la->invalidate();
       am.invalidate(pa);
-
       la = &am.getAnalysis<RunLivenessAnalysis>();
       // If RunLivenessAnalysis was previously preserved, preserved the updated
       // results.

>From 44ccd989b470cad11c2628c8f9c1a1dccd0e7cc4 Mon Sep 17 00:00:00 2001
From: Xin Liu <xxinliu at meta.com>
Date: Tue, 28 Oct 2025 08:38:53 -0700
Subject: [PATCH 10/11] Update per reviewer's feedbacks.

Also return LogicalResult and exit early when an error is encountered.
---
 mlir/lib/Transforms/RemoveDeadValues.cpp | 39 ++++++++++++++----------
 1 file changed, 23 insertions(+), 16 deletions(-)

diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index 0f34ebf263e27..86e02709b20f8 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -54,6 +54,7 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/DebugLog.h"
+#include "llvm/Support/LogicalResult.h"
 #include <cassert>
 #include <cstddef>
 #include <memory>
@@ -887,12 +888,12 @@ struct RemoveDeadValues : public impl::RemoveDeadValuesBase<RemoveDeadValues> {
 /// }
 ///
 /// Returns true if any IR changes were made, false otherwise.
-static bool processCallOp(CallOpInterface callOp, ModuleOp moduleOp,
-                          RunLivenessAnalysis &la) {
+static LogicalResult processCallOp(CallOpInterface callOp, ModuleOp moduleOp,
+                                   RunLivenessAnalysis &la, bool &changed) {
   Operation *callableOp = callOp.resolveCallable();
   auto funcOp = dyn_cast<FunctionOpInterface>(callableOp);
   if (!funcOp || !funcOp.isPublic())
-    return false;
+    return LogicalResult::success();
 
   LDBG() << "Processing callOp " << callOp << " target is a public function: "
          << funcOp.getOperation()->getName();
@@ -924,8 +925,10 @@ static bool processCallOp(CallOpInterface callOp, ModuleOp moduleOp,
         funcOp, clonedFunc.getNameAttr(), moduleOp);
 
     if (result.failed()) {
-      LDBG() << "Failed to replace all symbol uses for " << funcOp.getName();
-      return false;
+      callOp.emitError(
+          "Failed to replace all symbol uses when privatizing function ")
+          << funcOp.getName();
+      return result;
     }
 
     LDBG() << "Redirected all callsites from " << funcOp.getName() << " to "
@@ -980,10 +983,10 @@ static bool processCallOp(CallOpInterface callOp, ModuleOp moduleOp,
       newReturnOp->setLoc(funcOp.getLoc());
       rewriter.insert(newReturnOp);
     }
-    return true; // Changes were made
+    changed = true; // Changes were made
   }
 
-  return false;
+  return LogicalResult::success();
 }
 
 void RemoveDeadValues::runOnOperation() {
@@ -991,14 +994,19 @@ void RemoveDeadValues::runOnOperation() {
   RunLivenessAnalysis *la = &am.getAnalysis<RunLivenessAnalysis>();
   Operation *module = getOperation();
 
-  // Only privatize public functions if liveness analysis is inter-procedural.
-  if (la->getSolverConfig().isInterprocedural()) {
+  // In a module, only privatize public functions if liveness analysis is
+  // inter-procedural.
+  if (la->getSolverConfig().isInterprocedural() && isa<ModuleOp>(module)) {
     bool changed = false;
-    module->walk([&](CallOpInterface callOp) {
-      if (processCallOp(callOp, cast<ModuleOp>(module), *la)) {
-        changed = true;
-      }
-    });
+    WalkResult walkResult =
+        module->walk([&](CallOpInterface callOp) -> WalkResult {
+          return processCallOp(callOp, cast<ModuleOp>(module), *la, changed);
+        });
+
+    if (walkResult.wasInterrupted()) {
+      signalPassFailure();
+      return;
+    }
 
     if (changed) {
       LDBG() << "IR has changed, invalidate RunLivenessAnalysis only";
@@ -1009,9 +1017,8 @@ void RemoveDeadValues::runOnOperation() {
       la = &am.getAnalysis<RunLivenessAnalysis>();
       // If RunLivenessAnalysis was previously preserved, preserved the updated
       // results.
-      if (preserved) {
+      if (preserved)
         pa.preserve<RunLivenessAnalysis>();
-      }
     }
   }
 

>From 81e387c812ccdb6c91ae61bc52ae0ce3f705bd6e Mon Sep 17 00:00:00 2001
From: Xin Liu <xxinliu at meta.com>
Date: Wed, 29 Oct 2025 23:06:24 -0700
Subject: [PATCH 11/11] Use SymbolTable and also don't clone function body.

One optimization: the function body is no longer cloned. Instead, the body is
moved into the new private function, and a new wrapper body is then created
for the original function.
---
 mlir/lib/Transforms/RemoveDeadValues.cpp | 42 +++++++++++++-----------
 1 file changed, 22 insertions(+), 20 deletions(-)

diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index 86e02709b20f8..142e464702b4b 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -887,11 +887,16 @@ struct RemoveDeadValues : public impl::RemoveDeadValuesBase<RemoveDeadValues> {
 ///  original function.
 /// }
 ///
-/// Returns true if any IR changes were made, false otherwise.
+/// Sets `changed` to true if any IR changes were made.
+///
+/// Cloning has to be Interface-based because downstream projects may use their
+/// own func/call/return ops.
 static LogicalResult processCallOp(CallOpInterface callOp, ModuleOp moduleOp,
-                                   RunLivenessAnalysis &la, bool &changed) {
-  Operation *callableOp = callOp.resolveCallable();
-  auto funcOp = dyn_cast<FunctionOpInterface>(callableOp);
+                                   RunLivenessAnalysis &la,
+                                   SymbolTableCollection *symbolTable,
+                                   bool &changed) {
+  Operation *callableOp = callOp.resolveCallableInTable(symbolTable);
+  auto funcOp = dyn_cast_or_null<FunctionOpInterface>(callableOp);
   if (!funcOp || !funcOp.isPublic())
     return LogicalResult::success();
 
@@ -907,7 +912,8 @@ static LogicalResult processCallOp(CallOpInterface callOp, ModuleOp moduleOp,
     OpBuilder rewriter(moduleOp.getContext());
 
     // Clone function and create private version
-    FunctionOpInterface clonedFunc = cast<FunctionOpInterface>(funcOp.clone());
+    FunctionOpInterface clonedFunc =
+        cast<FunctionOpInterface>(funcOp->cloneWithoutRegions());
 
     // Set visibility = 'private' and a new name for the cloned function
     SymbolTable::setSymbolVisibility(clonedFunc,
@@ -930,20 +936,15 @@ static LogicalResult processCallOp(CallOpInterface callOp, ModuleOp moduleOp,
           << funcOp.getName();
       return result;
     }
-
     LDBG() << "Redirected all callsites from " << funcOp.getName() << " to "
            << newName;
 
-    // Transform the original funcOp into a wrapper that calls the cloned
-    // function
-    Region &funcBody = funcOp.getFunctionBody();
+    Region &clonedFuncBody = clonedFunc.getFunctionBody();
+    // Move the body from funcOp to clonedFunc
+    clonedFuncBody.takeBody(funcOp.getFunctionBody());
 
-    // Clean the original function body
-    funcBody.dropAllReferences();
-    funcBody.getBlocks().clear();
-
-    // Create a new entry block for the wrapper function
-    Block *wrapperBlock = rewriter.createBlock(&funcBody);
+    // Create a new entry block for the wrapper function in funcOp
+    Block *wrapperBlock = rewriter.createBlock(&funcOp.getFunctionBody());
 
     // Add block arguments that match the function signature
     for (Type argType : funcOp.getArgumentTypes()) {
@@ -964,9 +965,9 @@ static LogicalResult processCallOp(CallOpInterface callOp, ModuleOp moduleOp,
     rewriter.insert(clonedCallOp);
 
     // Create return operation of the same type as the original function's
-    // return
+    // returnOp.
     Operation *returnOp = nullptr;
-    for (Block &block : clonedFunc.getFunctionBody()) {
+    for (Block &block : clonedFuncBody) {
       if (block.getNumSuccessors() > 0)
         continue;
 
@@ -980,7 +981,7 @@ static LogicalResult processCallOp(CallOpInterface callOp, ModuleOp moduleOp,
     if (returnOp) {
       Operation *newReturnOp = returnOp->clone();
       newReturnOp->setOperands(clonedCallOp->getResults());
-      newReturnOp->setLoc(funcOp.getLoc());
+      newReturnOp->setLoc(returnOp->getLoc());
       rewriter.insert(newReturnOp);
     }
     changed = true; // Changes were made
@@ -998,11 +999,12 @@ void RemoveDeadValues::runOnOperation() {
   // inter-procedural.
   if (la->getSolverConfig().isInterprocedural() && isa<ModuleOp>(module)) {
     bool changed = false;
+    SymbolTableCollection symbolTable;
     WalkResult walkResult =
         module->walk([&](CallOpInterface callOp) -> WalkResult {
-          return processCallOp(callOp, cast<ModuleOp>(module), *la, changed);
+          return processCallOp(callOp, cast<ModuleOp>(module), *la,
+                               &symbolTable, changed);
         });
-
     if (walkResult.wasInterrupted()) {
       signalPassFailure();
       return;



More information about the Mlir-commits mailing list