[Mlir-commits] [mlir] [MLIR][RemoveDeadValues] Mark arguments of a public function Live (PR #160242)

xin liu llvmlistbot at llvm.org
Tue Sep 23 09:14:32 PDT 2025


https://github.com/navyxliu updated https://github.com/llvm/llvm-project/pull/160242

>From b83d8dc35fc19afabffb0773940f33c82a8e8505 Mon Sep 17 00:00:00 2001
From: Xin Liu <xxinliu at xinliu-pc.lan>
Date: Sun, 21 Sep 2025 23:33:41 -0700
Subject: [PATCH] [MLIR][RemoveDeadValues] Mark arguments of a public function
 Live

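Liveness analysis propagates the non-liveness of a public function's unused
arguments to its call sites, although RemoveDeadValues keeps the signature of
public functions. Mark such call operands live so that the values passed to
the callee are not erased.

This diff also changes the traversal order of regions/blocks/ops from forward
to backward. This order guarantees that liveness updates at a call site
propagate to the definitions of the call's arguments.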

```
./bin/llvm-lit -v ../mlir/test/Transforms/remove-dead-values.mlir
```
---
 mlir/include/mlir/IR/Visitors.h              | 14 ++++++
 mlir/lib/Transforms/RemoveDeadValues.cpp     | 52 +++++++++++++++++---
 mlir/test/Transforms/remove-dead-values.mlir | 18 +++++++
 3 files changed, 78 insertions(+), 6 deletions(-)

diff --git a/mlir/include/mlir/IR/Visitors.h b/mlir/include/mlir/IR/Visitors.h
index 893f66ae33deb..5766d262796d6 100644
--- a/mlir/include/mlir/IR/Visitors.h
+++ b/mlir/include/mlir/IR/Visitors.h
@@ -39,6 +39,20 @@ struct ForwardIterator {
   }
 };
 
+/// This iterator enumerates the elements in "backward" order: the regions of
+/// an operation, the blocks of a region and the operations of a block are
+/// visited from last to first.
+/// NOTE(review): mlir/IR/Iterators.h appears to already provide an equivalent
+/// `ReverseIterator`; confirm and reuse it instead of adding a second one.
+struct BackwardIterator {
+  template <typename T>
+  static auto makeIterable(T &range) {
+    // Reuse ForwardIterator's mapping from `range` to its children (regions
+    // of an operation, blocks of a region, ops of a block) and reverse it.
+    return llvm::reverse(ForwardIterator::makeIterable(range));
+  }
+};
+
 /// A utility class to encode the current walk stage for "generic" walkers.
 /// When walking an operation, we can either choose a Pre/Post order walker
 /// which invokes the callback on an operation before/after all its attached
diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index 0e84b6dd17f29..4d1cd991af6e5 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -115,9 +115,16 @@ struct RDVFinalCleanupList {

 /// Return true iff at least one value in `values` is live, given the liveness
 /// information in `la`.
-static bool hasLive(ValueRange values, const DenseSet<Value> &nonLiveSet,
+/// Values in `liveSet` are always treated as live.
+static bool hasLive(ValueRange values, const DenseSet<Value> &nonLiveSet,
+                    const DenseSet<Value> &liveSet,
                     RunLivenessAnalysis &la) {
   for (Value value : values) {
+    if (liveSet.contains(value)) {
+      LDBG() << "Value " << value << " is marked live by CallOp";
+      return true;
+    }
+
     if (nonLiveSet.contains(value)) {
       LDBG() << "Value " << value << " is already marked non-live (dead)";
       continue;
@@ -257,8 +264,10 @@ static SmallVector<OpOperand *> operandsToOpOperands(OperandRange operands) {
 ///   - Return-like
 static void processSimpleOp(Operation *op, RunLivenessAnalysis &la,
                             DenseSet<Value> &nonLiveSet,
+                            const DenseSet<Value> &liveSet,
                             RDVFinalCleanupList &cl) {
-  if (!isMemoryEffectFree(op) || hasLive(op->getResults(), nonLiveSet, la)) {
+  if (!isMemoryEffectFree(op) ||
+      hasLive(op->getResults(), nonLiveSet, liveSet, la)) {
     LDBG() << "Simple op is not memory effect free or has live results, "
               "preserving it: "
            << OpWithFlags(op, OpPrintingFlags().skipRegions());
@@ -376,6 +385,27 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
   }
 }

+/// If `callOp` calls a public function, record in `liveSet` the call operands
+/// that liveness analysis reports as not live. processFuncOp leaves the
+/// arguments of public functions untouched, so the values passed to them must
+/// be kept even when the callee does not use them.
+static void processCallOp(CallOpInterface callOp, Operation *module,
+                          RunLivenessAnalysis &la, DenseSet<Value> &liveSet) {
+  CallInterfaceCallable callable = callOp.getCallableForCallee();
+  auto symbolRef = dyn_cast<SymbolRefAttr>(callable);
+  if (!symbolRef)
+    return;
+  auto funcOp = llvm::dyn_cast_or_null<FunctionOpInterface>(
+      SymbolTable::lookupSymbolIn(module, symbolRef));
+  if (!funcOp || !funcOp.isPublic())
+    return;
+  for (Value arg : callOp.getArgOperands()) {
+    const Liveness *liveness = la.getLiveness(arg);
+    if (liveness && !liveness->isLive)
+      liveSet.insert(arg);
+  }
+}
+
 /// Process a region branch operation `regionBranchOp` using the liveness
 /// information in `la`. The processing involves two scenarios:
 ///
@@ -408,6 +438,7 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
 static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
                                   RunLivenessAnalysis &la,
                                   DenseSet<Value> &nonLiveSet,
+                                  const DenseSet<Value> &liveSet,
                                   RDVFinalCleanupList &cl) {
   LDBG() << "Processing region branch op: "
          << OpWithFlags(regionBranchOp, OpPrintingFlags().skipRegions());
@@ -616,7 +647,7 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
   // attributed to something else.
   // Do (1') and (2').
   if (isMemoryEffectFree(regionBranchOp.getOperation()) &&
-      !hasLive(regionBranchOp->getResults(), nonLiveSet, la)) {
+      !hasLive(regionBranchOp->getResults(), nonLiveSet, liveSet, la)) {
     cl.operations.push_back(regionBranchOp.getOperation());
     return;
   }
@@ -834,16 +865,20 @@ void RemoveDeadValues::runOnOperation() {
   // Tracks values eligible for erasure - complements liveness analysis to
   // identify "droppable" values.
   DenseSet<Value> deadVals;
+  // Operands of calls to public functions, which must be kept alive.
+  DenseSet<Value> liveVals;

   // Maintains a list of Ops, values, branches, etc., slated for cleanup at the
   // end of this pass.
   RDVFinalCleanupList finalCleanupList;

-  module->walk([&](Operation *op) {
+  // Walk backward so a call is visited before the ops defining its operands.
+  module->walk<WalkOrder::PostOrder, BackwardIterator>([&](Operation *op) {
     if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
       processFuncOp(funcOp, module, la, deadVals, finalCleanupList);
     } else if (auto regionBranchOp = dyn_cast<RegionBranchOpInterface>(op)) {
-      processRegionBranchOp(regionBranchOp, la, deadVals, finalCleanupList);
+      processRegionBranchOp(regionBranchOp, la, deadVals, liveVals,
+                            finalCleanupList);
     } else if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
       processBranchOp(branchOp, la, deadVals, finalCleanupList);
     } else if (op->hasTrait<::mlir::OpTrait::IsTerminator>()) {
@@ -852,8 +887,13 @@ void RemoveDeadValues::runOnOperation() {
     } else if (isa<CallOpInterface>(op)) {
       // Nothing to do because this op is associated with a function op and gets
       // cleaned when the latter is cleaned.
+      //
+      // The exception is a call to a public function: its signature is kept,
+      // yet liveness analysis may report the operands that feed its unused
+      // arguments as not live. processCallOp records them in liveVals.
+      processCallOp(cast<CallOpInterface>(op), module, la, liveVals);
     } else {
-      processSimpleOp(op, la, deadVals, finalCleanupList);
+      processSimpleOp(op, la, deadVals, liveVals, finalCleanupList);
     }
   });
 
diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir
index fa2c145bd3701..1580009c74d4d 100644
--- a/mlir/test/Transforms/remove-dead-values.mlir
+++ b/mlir/test/Transforms/remove-dead-values.mlir
@@ -569,6 +569,24 @@ module @return_void_with_unused_argument {
     call @fn_return_void_with_unused_argument(%arg0, %unused) : (i32, memref<4xi32>) -> ()
     return %unused : memref<4xi32>
   }
+
+  // The signature of a public function is kept, so main2 must keep passing %zero.
+  func.func public @immutable_fn_return_void_with_unused_argument(%arg0: i32, %unused: i32) -> () {
+    %sum = arith.addi %arg0, %arg0 : i32
+    %c0 = arith.constant 0 : index
+    %buf = memref.alloc() : memref<1xi32>
+    memref.store %sum, %buf[%c0] : memref<1xi32>
+    return
+  }
+  // CHECK-LABEL: func.func @main2
+  // CHECK-SAME: (%[[ARG0_MAIN:.*]]: i32)
+  // CHECK: %[[UNUSED:.*]] = arith.constant 0 : i32
+  // CHECK: call @immutable_fn_return_void_with_unused_argument(%[[ARG0_MAIN]], %[[UNUSED]]) : (i32, i32) -> ()
+  func.func @main2(%arg0: i32) -> () {
+    %zero = arith.constant 0 : i32
+    call @immutable_fn_return_void_with_unused_argument(%arg0, %zero) : (i32, i32) -> ()
+    return
+  }
 }
 
 // -----



More information about the Mlir-commits mailing list