[Mlir-commits] [mlir] [MLIR][RemoveDeadValues] Mark arguments of a public function Live (PR #162038)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Oct 5 21:30:00 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-core
Author: xin liu (navyxliu)
<details>
<summary>Changes</summary>
This diff is a redo of [#<!-- -->160242](https://github.com/llvm/llvm-project/pull/160242).
# Problem
Liveness analysis is inter-procedural: if a public function has unused arguments, their non-liveness propagates to the operands its callers pass in. RemoveDeadValues, however, treats the signature of a public function as immutable and cannot cope with this situation. On one side it deletes the outgoing call arguments; on the other side it keeps the function intact. The call therefore ends up with null operands.
# Solution
We deploy two methods to fix this bug.
1) `createDummyArgument` replaces a dead outgoing argument with a dummy value when its type has a zero initializer (analogous to `T x{0}` in C++). The dummy is materialized as an `arith.constant` right before the call.
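2) As a fallback, we add another DenseSet called `liveSet`. It is seeded with the outgoing operands of calls to public functions that cannot be replaced by a dummy. Liveness is then propagated backward from those values to the values they are computed from, just like the liveness analysis does.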
# Test plan
```
./bin/llvm-lit -v ../mlir/test/Transforms/remove-dead-values.mlir
before:
./input.mlir:13:5: error: null operand found
call @<!-- -->immutable_fn_return_void_with_unused_argument(%arg0, %zero) : (i32, i32) -> ()
^
./input.mlir:13:5: note: see current operation: "func.call"(%arg0, <<NULL VALUE>>) <{callee = @<!-- -->immutable_fn_return_void_with_unused_argument}> : (i32, <<NULL TYPE>>) -> ()
after: pass
```
---
Full diff: https://github.com/llvm/llvm-project/pull/162038.diff
4 Files Affected:
- (modified) mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h (+3)
- (modified) mlir/include/mlir/IR/Visitors.h (+14)
- (modified) mlir/lib/Transforms/RemoveDeadValues.cpp (+95-19)
- (modified) mlir/test/Transforms/remove-dead-values.mlir (+18)
``````````diff
diff --git a/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h
index cf1fd6e2d48ca..be7e027b95f64 100644
--- a/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h
@@ -102,6 +102,9 @@ struct RunLivenessAnalysis {
const Liveness *getLiveness(Value val);
+  /// Accessor for the DataFlowConfig of the solver that ran this analysis
+  /// (for instance, to query whether it was run inter-procedurally).
+  const DataFlowConfig &getSolverConfig() const {
+    return solver.getConfig();
+  }
+
private:
/// Stores the result of the liveness analysis that was run.
DataFlowSolver solver;
diff --git a/mlir/include/mlir/IR/Visitors.h b/mlir/include/mlir/IR/Visitors.h
index 893f66ae33deb..5766d262796d6 100644
--- a/mlir/include/mlir/IR/Visitors.h
+++ b/mlir/include/mlir/IR/Visitors.h
@@ -39,6 +39,20 @@ struct ForwardIterator {
}
};
+/// Iteration policy for the generic walkers that visits the children of a
+/// node last-to-first; the reverse counterpart of ForwardIterator.
+struct BackwardIterator {
+  /// Regions and Blocks are ranges themselves and are reversed directly; for
+  /// an Operation the children to visit are its regions.
+  template <typename T>
+  static auto makeIterable(T &range) {
+    if constexpr (!std::is_same_v<T, Operation>)
+      return llvm::reverse(range);
+    else
+      return llvm::reverse(range.getRegions());
+  }
+};
+
/// A utility class to encode the current walk stage for "generic" walkers.
/// When walking an operation, we can either choose a Pre/Post order walker
/// which invokes the callback on an operation before/after all its attached
diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index e0c65b0e09774..ed5c6a8d2ead0 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -33,6 +33,7 @@
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Analysis/DataFlow/LivenessAnalysis.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Dialect.h"
@@ -118,8 +119,13 @@ struct RDVFinalCleanupList {
/// Return true iff at least one value in `values` is live, given the liveness
/// information in `la`.
static bool hasLive(ValueRange values, const DenseSet<Value> &nonLiveSet,
- RunLivenessAnalysis &la) {
+ const DenseSet<Value> &liveSet, RunLivenessAnalysis &la) {
for (Value value : values) {
+ if (liveSet.contains(value)) {
+ LDBG() << "Value " << value << " is marked live by CallOp";
+ return true;
+ }
+
if (nonLiveSet.contains(value)) {
LDBG() << "Value " << value << " is already marked non-live (dead)";
continue;
@@ -144,6 +150,7 @@ static bool hasLive(ValueRange values, const DenseSet<Value> &nonLiveSet,
/// Return a BitVector of size `values.size()` where its i-th bit is 1 iff the
/// i-th value in `values` is live, given the liveness information in `la`.
static BitVector markLives(ValueRange values, const DenseSet<Value> &nonLiveSet,
+ const DenseSet<Value> &liveSet,
RunLivenessAnalysis &la) {
BitVector lives(values.size(), true);
@@ -154,7 +161,9 @@ static BitVector markLives(ValueRange values, const DenseSet<Value> &nonLiveSet,
<< " is already marked non-live (dead) at index " << index;
continue;
}
-
+ if (liveSet.contains(value)) {
+ continue;
+ }
const Liveness *liveness = la.getLiveness(value);
// It is important to note that when `liveness` is null, we can't tell if
// `value` is live or not. So, the safe option is to consider it live. Also,
@@ -259,8 +268,19 @@ static SmallVector<OpOperand *> operandsToOpOperands(OperandRange operands) {
/// - Return-like
static void processSimpleOp(Operation *op, RunLivenessAnalysis &la,
DenseSet<Value> &nonLiveSet,
- RDVFinalCleanupList &cl) {
- if (!isMemoryEffectFree(op) || hasLive(op->getResults(), nonLiveSet, la)) {
+ DenseSet<Value> &liveSet, RDVFinalCleanupList &cl) {
+ for (Value val : op->getResults()) {
+ if (liveSet.contains(val)) {
+ LDBG() << "Simple op is used by a public function, "
+ "preserving it: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
+ liveSet.insert_range(op->getOperands());
+ return;
+ }
+ }
+
+ if (!isMemoryEffectFree(op) ||
+ hasLive(op->getResults(), nonLiveSet, liveSet, la)) {
LDBG() << "Simple op is not memory effect free or has live results, "
"preserving it: "
<< OpWithFlags(op, OpPrintingFlags().skipRegions());
@@ -288,7 +308,7 @@ static void processSimpleOp(Operation *op, RunLivenessAnalysis &la,
/// (6) Marking all its results as non-live values.
static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
RunLivenessAnalysis &la, DenseSet<Value> &nonLiveSet,
- RDVFinalCleanupList &cl) {
+ DenseSet<Value> &liveSet, RDVFinalCleanupList &cl) {
LDBG() << "Processing function op: "
<< OpWithFlags(funcOp, OpPrintingFlags().skipRegions());
if (funcOp.isPublic() || funcOp.isExternal()) {
@@ -299,7 +319,7 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
// Get the list of unnecessary (non-live) arguments in `nonLiveArgs`.
SmallVector<Value> arguments(funcOp.getArguments());
- BitVector nonLiveArgs = markLives(arguments, nonLiveSet, la);
+ BitVector nonLiveArgs = markLives(arguments, nonLiveSet, liveSet, la);
nonLiveArgs = nonLiveArgs.flip();
// Do (1).
@@ -352,7 +372,8 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
for (SymbolTable::SymbolUse use : uses) {
Operation *callOp = use.getUser();
assert(isa<CallOpInterface>(callOp) && "expected a call-like user");
- BitVector liveCallRets = markLives(callOp->getResults(), nonLiveSet, la);
+ BitVector liveCallRets =
+ markLives(callOp->getResults(), nonLiveSet, liveSet, la);
nonLiveRets &= liveCallRets.flip();
}
@@ -379,6 +400,56 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
}
}
+// Materialize a cheap stand-in of `oldVal`'s type right before `callOp`, so
+// that `oldVal` no longer has to be kept alive just to feed a non-live argument
+// of a public callee. The stand-in is an `arith.constant` holding the zero
+// attribute of the type. Returns a null Value when no such constant can be
+// built; the caller must then keep `oldVal` instead.
+static Value createDummyArgument(CallOpInterface callOp, Value oldVal) {
+  Type type = oldVal.getType();
+
+  // Ops of a dialect that is not loaded cannot be built, and a running pass
+  // must not load dialects itself. Fall back (return null) unless arith is
+  // already loaded. NOTE(review): declaring arith as a dependent dialect of
+  // this pass would make the dummy path always available.
+  if (!type.getContext()->getLoadedDialect<arith::ArithDialect>())
+    return {};
+
+  // `getZeroAttr` builds a DenseElementsAttr for ranked tensors and vectors,
+  // which requires a static shape; dynamically shaped (or unranked) types
+  // cannot get a zero constant here.
+  if (auto shapedType = dyn_cast<ShapedType>(type);
+      shapedType && !shapedType.hasStaticShape())
+    return {};
+
+  OpBuilder builder(callOp.getOperation());
+  TypedAttr zeroAttr = builder.getZeroAttr(type);
+  if (!zeroAttr)
+    return {};
+  return builder.create<arith::ConstantOp>(oldVal.getLoc(), type, zeroAttr);
+}
+
+/// Process a call to a public function, whose signature RemoveDeadValues keeps
+/// intact, so that every operand feeding a non-live argument of the callee
+/// stays valid after dead values are erased:
+///   - if `createDummyArgument` can build a stand-in, the operand is replaced
+///     by it, and the old value is marked non-live when the call was its last
+///     user;
+///   - otherwise the old value is added to `liveSet`, so it (and, through
+///     `processSimpleOp`, the operands it is computed from) is preserved.
+/// Does nothing unless the liveness analysis ran inter-procedurally, or when
+/// the callee cannot be resolved to a public FunctionOpInterface.
+/// `module` is currently unused.
+static void processCallOp(CallOpInterface callOp, Operation *module,
+                          RunLivenessAnalysis &la, DenseSet<Value> &nonLiveSet,
+                          DenseSet<Value> &liveSet) {
+  if (!la.getSolverConfig().isInterprocedural())
+    return;
+
+  // `resolveCallable` yields null for indirect calls or unresolvable symbols;
+  // `dyn_cast_or_null` skips those instead of asserting like `dyn_cast`.
+  auto funcOp =
+      dyn_cast_or_null<FunctionOpInterface>(callOp.resolveCallable());
+  if (!funcOp || !funcOp.isPublic())
+    return;
+
+  LDBG() << "processCallOp to a public function: " << funcOp.getName();
+  // After the flip, bit i is set iff the i-th callee argument is non-live.
+  SmallVector<Value> arguments(funcOp.getArguments());
+  BitVector nonLiveArgs = markLives(arguments, nonLiveSet, liveSet, la);
+  nonLiveArgs.flip();
+  if (nonLiveArgs.none())
+    return;
+
+  LDBG() << funcOp.getName() << " contains NonLive arguments";
+  // The number of operands in the call op may not match the number of
+  // arguments in the func op, so every index is bounds-checked below.
+  SmallVector<OpOperand *> callOpOperands =
+      operandsToOpOperands(callOp.getArgOperands());
+
+  // `set_bits` yields indices in increasing order, so the first out-of-range
+  // index ends the loop.
+  for (unsigned index : nonLiveArgs.set_bits()) {
+    if (index >= callOpOperands.size())
+      break;
+    OpOperand *operand = callOpOperands[index];
+    Value oldVal = operand->get();
+    if (Value dummy = createDummyArgument(callOp, oldVal)) {
+      operand->set(dummy);
+      // Force `oldVal` non-live only if this call was its last user. If other
+      // users remain, marking it dead would let its defining op be erased
+      // while still in use.
+      if (oldVal.use_empty())
+        nonLiveSet.insert(oldVal);
+    } else {
+      liveSet.insert(oldVal);
+    }
+  }
+}
+
/// Process a region branch operation `regionBranchOp` using the liveness
/// information in `la`. The processing involves two scenarios:
///
@@ -411,12 +482,14 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
RunLivenessAnalysis &la,
DenseSet<Value> &nonLiveSet,
+ DenseSet<Value> &liveSet,
RDVFinalCleanupList &cl) {
LDBG() << "Processing region branch op: "
<< OpWithFlags(regionBranchOp, OpPrintingFlags().skipRegions());
// Mark live results of `regionBranchOp` in `liveResults`.
auto markLiveResults = [&](BitVector &liveResults) {
- liveResults = markLives(regionBranchOp->getResults(), nonLiveSet, la);
+ liveResults =
+ markLives(regionBranchOp->getResults(), nonLiveSet, liveSet, la);
};
// Mark live arguments in the regions of `regionBranchOp` in `liveArgs`.
@@ -425,7 +498,7 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
if (region.empty())
continue;
SmallVector<Value> arguments(region.front().getArguments());
- BitVector regionLiveArgs = markLives(arguments, nonLiveSet, la);
+ BitVector regionLiveArgs = markLives(arguments, nonLiveSet, liveSet, la);
liveArgs[®ion] = regionLiveArgs;
}
};
@@ -619,7 +692,7 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
// attributed to something else.
// Do (1') and (2').
if (isMemoryEffectFree(regionBranchOp.getOperation()) &&
- !hasLive(regionBranchOp->getResults(), nonLiveSet, la)) {
+ !hasLive(regionBranchOp->getResults(), nonLiveSet, liveSet, la)) {
cl.operations.push_back(regionBranchOp.getOperation());
return;
}
@@ -698,7 +771,7 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la,
DenseSet<Value> &nonLiveSet,
- RDVFinalCleanupList &cl) {
+ DenseSet<Value> &liveSet, RDVFinalCleanupList &cl) {
LDBG() << "Processing branch op: " << *branchOp;
unsigned numSuccessors = branchOp->getNumSuccessors();
@@ -716,7 +789,7 @@ static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la,
// Do (2)
BitVector successorNonLive =
- markLives(operandValues, nonLiveSet, la).flip();
+ markLives(operandValues, nonLiveSet, liveSet, la).flip();
collectNonLiveValues(nonLiveSet, successorBlock->getArguments(),
successorNonLive);
@@ -876,26 +949,29 @@ void RemoveDeadValues::runOnOperation() {
// Tracks values eligible for erasure - complements liveness analysis to
// identify "droppable" values.
DenseSet<Value> deadVals;
+ // mark outgoing arguments to a public function LIVE. We also propagate
+ // liveness backward.
+ DenseSet<Value> liveVals;
// Maintains a list of Ops, values, branches, etc., slated for cleanup at the
// end of this pass.
RDVFinalCleanupList finalCleanupList;
- module->walk([&](Operation *op) {
+ module->walk<WalkOrder::PostOrder, BackwardIterator>([&](Operation *op) {
if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
- processFuncOp(funcOp, module, la, deadVals, finalCleanupList);
+ processFuncOp(funcOp, module, la, deadVals, liveVals, finalCleanupList);
} else if (auto regionBranchOp = dyn_cast<RegionBranchOpInterface>(op)) {
- processRegionBranchOp(regionBranchOp, la, deadVals, finalCleanupList);
+ processRegionBranchOp(regionBranchOp, la, deadVals, liveVals,
+ finalCleanupList);
} else if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
- processBranchOp(branchOp, la, deadVals, finalCleanupList);
+ processBranchOp(branchOp, la, deadVals, liveVals, finalCleanupList);
} else if (op->hasTrait<::mlir::OpTrait::IsTerminator>()) {
// Nothing to do here because this is a terminator op and it should be
// honored with respect to its parent
} else if (isa<CallOpInterface>(op)) {
- // Nothing to do because this op is associated with a function op and gets
- // cleaned when the latter is cleaned.
+ processCallOp(cast<CallOpInterface>(op), module, la, deadVals, liveVals);
} else {
- processSimpleOp(op, la, deadVals, finalCleanupList);
+ processSimpleOp(op, la, deadVals, liveVals, finalCleanupList);
}
});
diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir
index 56449469dc29f..ebb76a8835ceb 100644
--- a/mlir/test/Transforms/remove-dead-values.mlir
+++ b/mlir/test/Transforms/remove-dead-values.mlir
@@ -569,6 +569,24 @@ module @return_void_with_unused_argument {
call @fn_return_void_with_unused_argument(%arg0, %unused) : (i32, memref<4xi32>) -> ()
return %unused : memref<4xi32>
}
+
+ // The function signature is immutable because it is public: neither unused
+ // argument can be dropped, so every call site must keep passing valid operands.
+ func.func public @immutable_fn_with_unused_argument(%arg0: i32, %arg1: memref<4xf32>) -> () {
+ return
+ }
+
+ // Both callee arguments are unused. The expectations below require the i32
+ // operand (%scalar) to be replaced by a zero constant, while the memref
+ // operand (%mem) is passed through unchanged.
+ // CHECK-LABEL: func.func @main2
+ // CHECK: %[[MEM:.*]] = memref.alloc() : memref<4xf32>
+ // CHECK: %[[UNUSED:.*]] = arith.constant 0 : i32
+ // CHECK: call @immutable_fn_with_unused_argument(%[[UNUSED]], %[[MEM]]) : (i32, memref<4xf32>) -> ()
+ func.func @main2() -> () {
+ %one = arith.constant 1 : i32
+ %scalar = arith.addi %one, %one: i32
+ %mem = memref.alloc() : memref<4xf32>
+
+ call @immutable_fn_with_unused_argument(%scalar, %mem) : (i32, memref<4xf32>) -> ()
+ return
+ }
}
// -----
``````````
</details>
https://github.com/llvm/llvm-project/pull/162038
More information about the Mlir-commits
mailing list