[Mlir-commits] [mlir] [MLIR][RemoveDeadValues] Mark arguments of a public function Live (PR #162038)
xin liu
llvmlistbot at llvm.org
Fri Oct 10 00:16:44 PDT 2025
https://github.com/navyxliu updated https://github.com/llvm/llvm-project/pull/162038
>From 52cf8cbbcc8bdd1ae3e9ee5e1dd5f62534f563bb Mon Sep 17 00:00:00 2001
From: Xin Liu <xxinliu at xinliu-pc.lan>
Date: Sun, 21 Sep 2025 23:33:41 -0700
Subject: [PATCH 1/7] [MLIR][RemoveDeadValues] Mark arguments of a public
function Live
This diff also changes the traversal order from forward to backward for
regions/blocks/ops. This order guarantees that liveness updates at a callsite
can propagate to the definitions of the arguments.
```
./bin/llvm-lit -v ../mlir/test/Transforms/remove-dead-values.mlir
```
---
mlir/include/mlir/IR/Visitors.h | 14 ++++++
mlir/lib/Transforms/RemoveDeadValues.cpp | 52 +++++++++++++++++---
mlir/test/Transforms/remove-dead-values.mlir | 18 +++++++
3 files changed, 78 insertions(+), 6 deletions(-)
diff --git a/mlir/include/mlir/IR/Visitors.h b/mlir/include/mlir/IR/Visitors.h
index 893f66ae33deb..5766d262796d6 100644
--- a/mlir/include/mlir/IR/Visitors.h
+++ b/mlir/include/mlir/IR/Visitors.h
@@ -39,6 +39,20 @@ struct ForwardIterator {
}
};
+/// This iterator enumerates the elements in "backward" order.
+struct BackwardIterator {
+ template <typename T>
+ static auto makeIterable(T &range) {
+ if constexpr (std::is_same<T, Operation>()) {
+ /// Make operations iterable: return the list of regions.
+ return llvm::reverse(range.getRegions());
+ } else {
+ /// Regions and block are already iterable.
+ return llvm::reverse(range);
+ }
+ }
+};
+
/// A utility class to encode the current walk stage for "generic" walkers.
/// When walking an operation, we can either choose a Pre/Post order walker
/// which invokes the callback on an operation before/after all its attached
diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index e0c65b0e09774..5dad922ff4b69 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -117,9 +117,15 @@ struct RDVFinalCleanupList {
/// Return true iff at least one value in `values` is live, given the liveness
/// information in `la`.
-static bool hasLive(ValueRange values, const DenseSet<Value> &nonLiveSet,
+static bool hasLive(ValueRange values, const DenseSet<Value> &nonLiveSet, const DenseSet<Value> &liveSet,
+
RunLivenessAnalysis &la) {
for (Value value : values) {
+ if (liveSet.contains(value)) {
+ LDBG() << "Value " << value << " is marked live by CallOp";
+ return true;
+ }
+
if (nonLiveSet.contains(value)) {
LDBG() << "Value " << value << " is already marked non-live (dead)";
continue;
@@ -259,8 +265,9 @@ static SmallVector<OpOperand *> operandsToOpOperands(OperandRange operands) {
/// - Return-like
static void processSimpleOp(Operation *op, RunLivenessAnalysis &la,
DenseSet<Value> &nonLiveSet,
+ DenseSet<Value> &liveSet,
RDVFinalCleanupList &cl) {
- if (!isMemoryEffectFree(op) || hasLive(op->getResults(), nonLiveSet, la)) {
+ if (!isMemoryEffectFree(op) || hasLive(op->getResults(), nonLiveSet, liveSet, la)) {
LDBG() << "Simple op is not memory effect free or has live results, "
"preserving it: "
<< OpWithFlags(op, OpPrintingFlags().skipRegions());
@@ -379,6 +386,31 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
}
}
+static void processCallOp(CallOpInterface callOp, Operation *module,
+ RunLivenessAnalysis &la,
+ DenseSet<Value> &liveSet) {
+ auto callable = callOp.getCallableForCallee();
+
+ if (auto symbolRef = callable.dyn_cast<SymbolRefAttr>()) {
+ Operation *calleeOp = SymbolTable::lookupSymbolIn(module, symbolRef);
+
+ if (auto funcOp = llvm::dyn_cast_or_null<mlir::FunctionOpInterface>(calleeOp)) {
+ // Ensure the outgoing arguments of PUBLIC functions are live
+ // because processFuncOp can not process them.
+ //
+ // Liveness treats the external function as a blackbox.
+ if (funcOp.isPublic()) {
+ for (Value arg: callOp.getArgOperands()) {
+ const Liveness *liveness = la.getLiveness(arg);
+ if (liveness && !liveness->isLive) {
+ liveSet.insert(arg);
+ }
+ }
+ }
+ }
+ }
+}
+
/// Process a region branch operation `regionBranchOp` using the liveness
/// information in `la`. The processing involves two scenarios:
///
@@ -411,6 +443,7 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
RunLivenessAnalysis &la,
DenseSet<Value> &nonLiveSet,
+ DenseSet<Value> &liveSet,
RDVFinalCleanupList &cl) {
LDBG() << "Processing region branch op: "
<< OpWithFlags(regionBranchOp, OpPrintingFlags().skipRegions());
@@ -619,7 +652,7 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
// attributed to something else.
// Do (1') and (2').
if (isMemoryEffectFree(regionBranchOp.getOperation()) &&
- !hasLive(regionBranchOp->getResults(), nonLiveSet, la)) {
+ !hasLive(regionBranchOp->getResults(), nonLiveSet, liveSet, la)) {
cl.operations.push_back(regionBranchOp.getOperation());
return;
}
@@ -876,16 +909,18 @@ void RemoveDeadValues::runOnOperation() {
// Tracks values eligible for erasure - complements liveness analysis to
// identify "droppable" values.
DenseSet<Value> deadVals;
+ // mark outgoing arguments to a public function LIVE.
+ DenseSet<Value> liveVals;
// Maintains a list of Ops, values, branches, etc., slated for cleanup at the
// end of this pass.
RDVFinalCleanupList finalCleanupList;
- module->walk([&](Operation *op) {
+ module->walk<WalkOrder::PostOrder, BackwardIterator>([&](Operation *op) {
if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
processFuncOp(funcOp, module, la, deadVals, finalCleanupList);
} else if (auto regionBranchOp = dyn_cast<RegionBranchOpInterface>(op)) {
- processRegionBranchOp(regionBranchOp, la, deadVals, finalCleanupList);
+ processRegionBranchOp(regionBranchOp, la, deadVals, liveVals, finalCleanupList);
} else if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
processBranchOp(branchOp, la, deadVals, finalCleanupList);
} else if (op->hasTrait<::mlir::OpTrait::IsTerminator>()) {
@@ -894,8 +929,13 @@ void RemoveDeadValues::runOnOperation() {
} else if (isa<CallOpInterface>(op)) {
// Nothing to do because this op is associated with a function op and gets
// cleaned when the latter is cleaned.
+ //
+ // The only exception is public callee. By default, Liveness analysis is inter-procedural.
+ // Unused arguments of a public function nonLive and are propagated to the caller.
+ // processCallOp puts them to liveVals.
+ processCallOp(cast<CallOpInterface>(op), module, la, liveVals);
} else {
- processSimpleOp(op, la, deadVals, finalCleanupList);
+ processSimpleOp(op, la, deadVals, liveVals, finalCleanupList);
}
});
diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir
index 56449469dc29f..fc857f2989f18 100644
--- a/mlir/test/Transforms/remove-dead-values.mlir
+++ b/mlir/test/Transforms/remove-dead-values.mlir
@@ -569,6 +569,24 @@ module @return_void_with_unused_argument {
call @fn_return_void_with_unused_argument(%arg0, %unused) : (i32, memref<4xi32>) -> ()
return %unused : memref<4xi32>
}
+
+ // the function is immutable because it is public.
+ func.func public @immutable_fn_return_void_with_unused_argument(%arg0: i32, %unused: i32) -> () {
+ %sum = arith.addi %arg0, %arg0 : i32
+ %c0 = arith.constant 0 : index
+ %buf = memref.alloc() : memref<1xi32>
+ memref.store %sum, %buf[%c0] : memref<1xi32>
+ return
+ }
+ // CHECK-LABEL: func.func @main2
+ // CHECK-SAME: (%[[ARG0_MAIN:.*]]: i32)
+ // CHECK: %[[UNUSED:.*]] = arith.constant 0 : i32
+ // CHECK: call @immutable_fn_return_void_with_unused_argument(%[[ARG0_MAIN]], %[[UNUSED]]) : (i32, i32) -> ()
+ func.func @main2(%arg0: i32) -> () {
+ %zero = arith.constant 0 : i32
+ call @immutable_fn_return_void_with_unused_argument(%arg0, %zero) : (i32, i32) -> ()
+ return
+ }
}
// -----
>From be516b70225f331e51c4075c5e5f267519ef9d6b Mon Sep 17 00:00:00 2001
From: Xin Liu <xxinliu at meta.com>
Date: Tue, 23 Sep 2025 09:27:34 -0700
Subject: [PATCH 2/7] Fix formatter issue.
---
mlir/lib/Transforms/RemoveDeadValues.cpp | 26 +++++++++++++-----------
1 file changed, 14 insertions(+), 12 deletions(-)
diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index 5dad922ff4b69..50055ff7788a4 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -117,7 +117,8 @@ struct RDVFinalCleanupList {
/// Return true iff at least one value in `values` is live, given the liveness
/// information in `la`.
-static bool hasLive(ValueRange values, const DenseSet<Value> &nonLiveSet, const DenseSet<Value> &liveSet,
+static bool hasLive(ValueRange values, const DenseSet<Value> &nonLiveSet,
+ const DenseSet<Value> &liveSet,
RunLivenessAnalysis &la) {
for (Value value : values) {
@@ -265,9 +266,9 @@ static SmallVector<OpOperand *> operandsToOpOperands(OperandRange operands) {
/// - Return-like
static void processSimpleOp(Operation *op, RunLivenessAnalysis &la,
DenseSet<Value> &nonLiveSet,
- DenseSet<Value> &liveSet,
- RDVFinalCleanupList &cl) {
- if (!isMemoryEffectFree(op) || hasLive(op->getResults(), nonLiveSet, liveSet, la)) {
+ DenseSet<Value> &liveSet, RDVFinalCleanupList &cl) {
+ if (!isMemoryEffectFree(op) ||
+ hasLive(op->getResults(), nonLiveSet, liveSet, la)) {
LDBG() << "Simple op is not memory effect free or has live results, "
"preserving it: "
<< OpWithFlags(op, OpPrintingFlags().skipRegions());
@@ -387,20 +388,20 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
}
static void processCallOp(CallOpInterface callOp, Operation *module,
- RunLivenessAnalysis &la,
- DenseSet<Value> &liveSet) {
+ RunLivenessAnalysis &la, DenseSet<Value> &liveSet) {
auto callable = callOp.getCallableForCallee();
if (auto symbolRef = callable.dyn_cast<SymbolRefAttr>()) {
Operation *calleeOp = SymbolTable::lookupSymbolIn(module, symbolRef);
- if (auto funcOp = llvm::dyn_cast_or_null<mlir::FunctionOpInterface>(calleeOp)) {
+ if (auto funcOp =
+ llvm::dyn_cast_or_null<mlir::FunctionOpInterface>(calleeOp)) {
// Ensure the outgoing arguments of PUBLIC functions are live
// because processFuncOp can not process them.
//
// Liveness treats the external function as a blackbox.
if (funcOp.isPublic()) {
- for (Value arg: callOp.getArgOperands()) {
+ for (Value arg : callOp.getArgOperands()) {
const Liveness *liveness = la.getLiveness(arg);
if (liveness && !liveness->isLive) {
liveSet.insert(arg);
@@ -920,7 +921,8 @@ void RemoveDeadValues::runOnOperation() {
if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
processFuncOp(funcOp, module, la, deadVals, finalCleanupList);
} else if (auto regionBranchOp = dyn_cast<RegionBranchOpInterface>(op)) {
- processRegionBranchOp(regionBranchOp, la, deadVals, liveVals, finalCleanupList);
+ processRegionBranchOp(regionBranchOp, la, deadVals, liveVals,
+ finalCleanupList);
} else if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
processBranchOp(branchOp, la, deadVals, finalCleanupList);
} else if (op->hasTrait<::mlir::OpTrait::IsTerminator>()) {
@@ -930,9 +932,9 @@ void RemoveDeadValues::runOnOperation() {
// Nothing to do because this op is associated with a function op and gets
// cleaned when the latter is cleaned.
//
- // The only exception is public callee. By default, Liveness analysis is inter-procedural.
- // Unused arguments of a public function nonLive and are propagated to the caller.
- // processCallOp puts them to liveVals.
+ // The only exception is public callee. By default, Liveness analysis is
+ // inter-procedural. Unused arguments of a public function nonLive and are
+ // propagated to the caller. processCallOp puts them to liveVals.
processCallOp(cast<CallOpInterface>(op), module, la, liveVals);
} else {
processSimpleOp(op, la, deadVals, liveVals, finalCleanupList);
>From 6c62398cbf139d8384f712500311f11377bfd0a4 Mon Sep 17 00:00:00 2001
From: Xin Liu <xxinliu at meta.com>
Date: Thu, 2 Oct 2025 22:41:04 -0700
Subject: [PATCH 3/7] update processCallOp.
---
.../mlir/Analysis/DataFlow/LivenessAnalysis.h | 3 +
mlir/lib/Transforms/RemoveDeadValues.cpp | 107 +++++++++++-------
2 files changed, 71 insertions(+), 39 deletions(-)
diff --git a/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h
index cf1fd6e2d48ca..be7e027b95f64 100644
--- a/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h
@@ -102,6 +102,9 @@ struct RunLivenessAnalysis {
const Liveness *getLiveness(Value val);
+ /// Return the configuration of the solver used for this analysis.
+ const DataFlowConfig &getSolverConfig() const { return solver.getConfig(); }
+
private:
/// Stores the result of the liveness analysis that was run.
DataFlowSolver solver;
diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index 50055ff7788a4..621ec7b3827d3 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -33,6 +33,7 @@
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Analysis/DataFlow/LivenessAnalysis.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Dialect.h"
@@ -119,7 +120,6 @@ struct RDVFinalCleanupList {
/// information in `la`.
static bool hasLive(ValueRange values, const DenseSet<Value> &nonLiveSet,
const DenseSet<Value> &liveSet,
-
RunLivenessAnalysis &la) {
for (Value value : values) {
if (liveSet.contains(value)) {
@@ -151,7 +151,7 @@ static bool hasLive(ValueRange values, const DenseSet<Value> &nonLiveSet,
/// Return a BitVector of size `values.size()` where its i-th bit is 1 iff the
/// i-th value in `values` is live, given the liveness information in `la`.
static BitVector markLives(ValueRange values, const DenseSet<Value> &nonLiveSet,
- RunLivenessAnalysis &la) {
+ const DenseSet<Value> &liveSet, RunLivenessAnalysis &la) {
BitVector lives(values.size(), true);
for (auto [index, value] : llvm::enumerate(values)) {
@@ -161,7 +161,9 @@ static BitVector markLives(ValueRange values, const DenseSet<Value> &nonLiveSet,
<< " is already marked non-live (dead) at index " << index;
continue;
}
-
+ if (liveSet.contains(value)) {
+ continue;
+ }
const Liveness *liveness = la.getLiveness(value);
// It is important to note that when `liveness` is null, we can't tell if
// `value` is live or not. So, the safe option is to consider it live. Also,
@@ -267,6 +269,16 @@ static SmallVector<OpOperand *> operandsToOpOperands(OperandRange operands) {
static void processSimpleOp(Operation *op, RunLivenessAnalysis &la,
DenseSet<Value> &nonLiveSet,
DenseSet<Value> &liveSet, RDVFinalCleanupList &cl) {
+ for (Value val: op->getResults()) {
+ if (liveSet.contains(val)) {
+ LDBG() << "Simple op is used by a public function, "
+ "preserving it: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
+ liveSet.insert_range(op->getOperands());
+ return;
+ }
+ }
+
if (!isMemoryEffectFree(op) ||
hasLive(op->getResults(), nonLiveSet, liveSet, la)) {
LDBG() << "Simple op is not memory effect free or has live results, "
@@ -295,7 +307,7 @@ static void processSimpleOp(Operation *op, RunLivenessAnalysis &la,
/// removal.
/// (6) Marking all its results as non-live values.
static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
- RunLivenessAnalysis &la, DenseSet<Value> &nonLiveSet,
+ RunLivenessAnalysis &la, DenseSet<Value> &nonLiveSet, DenseSet<Value> &liveSet,
RDVFinalCleanupList &cl) {
LDBG() << "Processing function op: "
<< OpWithFlags(funcOp, OpPrintingFlags().skipRegions());
@@ -307,7 +319,7 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
// Get the list of unnecessary (non-live) arguments in `nonLiveArgs`.
SmallVector<Value> arguments(funcOp.getArguments());
- BitVector nonLiveArgs = markLives(arguments, nonLiveSet, la);
+ BitVector nonLiveArgs = markLives(arguments, nonLiveSet, liveSet, la);
nonLiveArgs = nonLiveArgs.flip();
// Do (1).
@@ -360,7 +372,7 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
for (SymbolTable::SymbolUse use : uses) {
Operation *callOp = use.getUser();
assert(isa<CallOpInterface>(callOp) && "expected a call-like user");
- BitVector liveCallRets = markLives(callOp->getResults(), nonLiveSet, la);
+ BitVector liveCallRets = markLives(callOp->getResults(), nonLiveSet, liveSet, la);
nonLiveRets &= liveCallRets.flip();
}
@@ -386,29 +398,52 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
collectNonLiveValues(nonLiveSet, callOp->getResults(), nonLiveRets);
}
}
+// create a cheaper value with the same type of oldVal in front of CallOp.
+static Value createDummyArgument(CallOpInterface callOp, Value oldVal) {
+ OpBuilder builder(callOp.getOperation());
+ Type type = oldVal.getType();
+
+ // Create zero constant for any supported type
+ if (TypedAttr zeroAttr = builder.getZeroAttr(type)) {
+ return builder.create<arith::ConstantOp>(oldVal.getLoc(), type, zeroAttr);
+ }
+ return {};
+}
static void processCallOp(CallOpInterface callOp, Operation *module,
- RunLivenessAnalysis &la, DenseSet<Value> &liveSet) {
- auto callable = callOp.getCallableForCallee();
-
- if (auto symbolRef = callable.dyn_cast<SymbolRefAttr>()) {
- Operation *calleeOp = SymbolTable::lookupSymbolIn(module, symbolRef);
-
- if (auto funcOp =
- llvm::dyn_cast_or_null<mlir::FunctionOpInterface>(calleeOp)) {
- // Ensure the outgoing arguments of PUBLIC functions are live
- // because processFuncOp can not process them.
- //
- // Liveness treats the external function as a blackbox.
- if (funcOp.isPublic()) {
- for (Value arg : callOp.getArgOperands()) {
- const Liveness *liveness = la.getLiveness(arg);
- if (liveness && !liveness->isLive) {
- liveSet.insert(arg);
- }
- }
+ RunLivenessAnalysis &la, DenseSet<Value> &nonLiveSet, DenseSet<Value> &liveSet) {
+ if (!la.getSolverConfig().isInterprocedural())
+ return;
+
+ Operation *callableOp = callOp.resolveCallable();
+ auto funcOp = dyn_cast<FunctionOpInterface>(callableOp);
+ if (!funcOp || !funcOp.isPublic()) {
+ return;
+ }
+ LDBG() << "processCallOp" << funcOp.getName();
+ // Get the list of unnecessary (non-live) arguments in `nonLiveArgs`.
+ SmallVector<Value> arguments(funcOp.getArguments());
+ BitVector nonLiveArgs = markLives(arguments, nonLiveSet, liveSet, la);
+ nonLiveArgs = nonLiveArgs.flip();
+
+ if (nonLiveArgs.count() > 0) {
+ LDBG() << funcOp.getName() << " contains NonLive arguments";
+ // The number of operands in the call op may not match the number of
+ // arguments in the func op.
+ SmallVector<OpOperand *> callOpOperands =
+ operandsToOpOperands(callOp.getArgOperands());
+
+ for (int index : nonLiveArgs.set_bits()) {
+ OpOperand *operand = callOpOperands[index];
+ Value oldVal = operand->get();
+ if (Value dummy = createDummyArgument(callOp, oldVal)) {
+ callOp->setOperand(operand->getOperandNumber(), dummy);
+ nonLiveSet.insert(oldVal);
+ } else {
+ liveSet.insert(oldVal);
}
}
+ LDBG() << "after changed " << callOp;
}
}
@@ -450,7 +485,7 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
<< OpWithFlags(regionBranchOp, OpPrintingFlags().skipRegions());
// Mark live results of `regionBranchOp` in `liveResults`.
auto markLiveResults = [&](BitVector &liveResults) {
- liveResults = markLives(regionBranchOp->getResults(), nonLiveSet, la);
+ liveResults = markLives(regionBranchOp->getResults(), nonLiveSet, liveSet, la);
};
// Mark live arguments in the regions of `regionBranchOp` in `liveArgs`.
@@ -459,7 +494,7 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
if (region.empty())
continue;
SmallVector<Value> arguments(region.front().getArguments());
- BitVector regionLiveArgs = markLives(arguments, nonLiveSet, la);
+ BitVector regionLiveArgs = markLives(arguments, nonLiveSet, liveSet, la);
liveArgs[®ion] = regionLiveArgs;
}
};
@@ -731,7 +766,7 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
/// as well as their corresponding arguments.
static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la,
- DenseSet<Value> &nonLiveSet,
+ DenseSet<Value> &nonLiveSet, DenseSet<Value> &liveSet,
RDVFinalCleanupList &cl) {
LDBG() << "Processing branch op: " << *branchOp;
unsigned numSuccessors = branchOp->getNumSuccessors();
@@ -750,7 +785,7 @@ static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la,
// Do (2)
BitVector successorNonLive =
- markLives(operandValues, nonLiveSet, la).flip();
+ markLives(operandValues, nonLiveSet, liveSet, la).flip();
collectNonLiveValues(nonLiveSet, successorBlock->getArguments(),
successorNonLive);
@@ -910,7 +945,7 @@ void RemoveDeadValues::runOnOperation() {
// Tracks values eligible for erasure - complements liveness analysis to
// identify "droppable" values.
DenseSet<Value> deadVals;
- // mark outgoing arguments to a public function LIVE.
+ // mark outgoing arguments to a public function LIVE. We also propagate liveness backward.
DenseSet<Value> liveVals;
// Maintains a list of Ops, values, branches, etc., slated for cleanup at the
@@ -919,23 +954,17 @@ void RemoveDeadValues::runOnOperation() {
module->walk<WalkOrder::PostOrder, BackwardIterator>([&](Operation *op) {
if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
- processFuncOp(funcOp, module, la, deadVals, finalCleanupList);
+ processFuncOp(funcOp, module, la, deadVals, liveVals, finalCleanupList);
} else if (auto regionBranchOp = dyn_cast<RegionBranchOpInterface>(op)) {
processRegionBranchOp(regionBranchOp, la, deadVals, liveVals,
finalCleanupList);
} else if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
- processBranchOp(branchOp, la, deadVals, finalCleanupList);
+ processBranchOp(branchOp, la, deadVals, liveVals, finalCleanupList);
} else if (op->hasTrait<::mlir::OpTrait::IsTerminator>()) {
// Nothing to do here because this is a terminator op and it should be
// honored with respect to its parent
} else if (isa<CallOpInterface>(op)) {
- // Nothing to do because this op is associated with a function op and gets
- // cleaned when the latter is cleaned.
- //
- // The only exception is public callee. By default, Liveness analysis is
- // inter-procedural. Unused arguments of a public function nonLive and are
- // propagated to the caller. processCallOp puts them to liveVals.
- processCallOp(cast<CallOpInterface>(op), module, la, liveVals);
+ processCallOp(cast<CallOpInterface>(op), module, la, deadVals, liveVals);
} else {
processSimpleOp(op, la, deadVals, liveVals, finalCleanupList);
}
>From 3f4d69f8910007656105a8798b0a5934310df489 Mon Sep 17 00:00:00 2001
From: Xin Liu <xxinliu at meta.com>
Date: Sun, 5 Oct 2025 20:45:24 -0700
Subject: [PATCH 4/7] Update the lit test and formatter.
---
mlir/lib/Transforms/RemoveDeadValues.cpp | 37 +++++++++++---------
mlir/test/Transforms/remove-dead-values.mlir | 22 ++++++------
2 files changed, 32 insertions(+), 27 deletions(-)
diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index 621ec7b3827d3..ed5c6a8d2ead0 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -119,8 +119,7 @@ struct RDVFinalCleanupList {
/// Return true iff at least one value in `values` is live, given the liveness
/// information in `la`.
static bool hasLive(ValueRange values, const DenseSet<Value> &nonLiveSet,
- const DenseSet<Value> &liveSet,
- RunLivenessAnalysis &la) {
+ const DenseSet<Value> &liveSet, RunLivenessAnalysis &la) {
for (Value value : values) {
if (liveSet.contains(value)) {
LDBG() << "Value " << value << " is marked live by CallOp";
@@ -151,7 +150,8 @@ static bool hasLive(ValueRange values, const DenseSet<Value> &nonLiveSet,
/// Return a BitVector of size `values.size()` where its i-th bit is 1 iff the
/// i-th value in `values` is live, given the liveness information in `la`.
static BitVector markLives(ValueRange values, const DenseSet<Value> &nonLiveSet,
- const DenseSet<Value> &liveSet, RunLivenessAnalysis &la) {
+ const DenseSet<Value> &liveSet,
+ RunLivenessAnalysis &la) {
BitVector lives(values.size(), true);
for (auto [index, value] : llvm::enumerate(values)) {
@@ -269,7 +269,7 @@ static SmallVector<OpOperand *> operandsToOpOperands(OperandRange operands) {
static void processSimpleOp(Operation *op, RunLivenessAnalysis &la,
DenseSet<Value> &nonLiveSet,
DenseSet<Value> &liveSet, RDVFinalCleanupList &cl) {
- for (Value val: op->getResults()) {
+ for (Value val : op->getResults()) {
if (liveSet.contains(val)) {
LDBG() << "Simple op is used by a public function, "
"preserving it: "
@@ -307,8 +307,8 @@ static void processSimpleOp(Operation *op, RunLivenessAnalysis &la,
/// removal.
/// (6) Marking all its results as non-live values.
static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
- RunLivenessAnalysis &la, DenseSet<Value> &nonLiveSet, DenseSet<Value> &liveSet,
- RDVFinalCleanupList &cl) {
+ RunLivenessAnalysis &la, DenseSet<Value> &nonLiveSet,
+ DenseSet<Value> &liveSet, RDVFinalCleanupList &cl) {
LDBG() << "Processing function op: "
<< OpWithFlags(funcOp, OpPrintingFlags().skipRegions());
if (funcOp.isPublic() || funcOp.isExternal()) {
@@ -372,7 +372,8 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
for (SymbolTable::SymbolUse use : uses) {
Operation *callOp = use.getUser();
assert(isa<CallOpInterface>(callOp) && "expected a call-like user");
- BitVector liveCallRets = markLives(callOp->getResults(), nonLiveSet, liveSet, la);
+ BitVector liveCallRets =
+ markLives(callOp->getResults(), nonLiveSet, liveSet, la);
nonLiveRets &= liveCallRets.flip();
}
@@ -398,20 +399,22 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
collectNonLiveValues(nonLiveSet, callOp->getResults(), nonLiveRets);
}
}
-// create a cheaper value with the same type of oldVal in front of CallOp.
+
+// Create a cheaper value with the same type of oldVal in front of CallOp.
static Value createDummyArgument(CallOpInterface callOp, Value oldVal) {
OpBuilder builder(callOp.getOperation());
Type type = oldVal.getType();
// Create zero constant for any supported type
if (TypedAttr zeroAttr = builder.getZeroAttr(type)) {
- return builder.create<arith::ConstantOp>(oldVal.getLoc(), type, zeroAttr);
+ return builder.create<arith::ConstantOp>(oldVal.getLoc(), type, zeroAttr);
}
return {};
}
static void processCallOp(CallOpInterface callOp, Operation *module,
- RunLivenessAnalysis &la, DenseSet<Value> &nonLiveSet, DenseSet<Value> &liveSet) {
+ RunLivenessAnalysis &la, DenseSet<Value> &nonLiveSet,
+ DenseSet<Value> &liveSet) {
if (!la.getSolverConfig().isInterprocedural())
return;
@@ -420,7 +423,8 @@ static void processCallOp(CallOpInterface callOp, Operation *module,
if (!funcOp || !funcOp.isPublic()) {
return;
}
- LDBG() << "processCallOp" << funcOp.getName();
+
+ LDBG() << "processCallOp to a public function: " << funcOp.getName();
// Get the list of unnecessary (non-live) arguments in `nonLiveArgs`.
SmallVector<Value> arguments(funcOp.getArguments());
BitVector nonLiveArgs = markLives(arguments, nonLiveSet, liveSet, la);
@@ -443,7 +447,6 @@ static void processCallOp(CallOpInterface callOp, Operation *module,
liveSet.insert(oldVal);
}
}
- LDBG() << "after changed " << callOp;
}
}
@@ -485,7 +488,8 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
<< OpWithFlags(regionBranchOp, OpPrintingFlags().skipRegions());
// Mark live results of `regionBranchOp` in `liveResults`.
auto markLiveResults = [&](BitVector &liveResults) {
- liveResults = markLives(regionBranchOp->getResults(), nonLiveSet, liveSet, la);
+ liveResults =
+ markLives(regionBranchOp->getResults(), nonLiveSet, liveSet, la);
};
// Mark live arguments in the regions of `regionBranchOp` in `liveArgs`.
@@ -766,8 +770,8 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
/// as well as their corresponding arguments.
static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la,
- DenseSet<Value> &nonLiveSet, DenseSet<Value> &liveSet,
- RDVFinalCleanupList &cl) {
+ DenseSet<Value> &nonLiveSet,
+ DenseSet<Value> &liveSet, RDVFinalCleanupList &cl) {
LDBG() << "Processing branch op: " << *branchOp;
unsigned numSuccessors = branchOp->getNumSuccessors();
@@ -945,7 +949,8 @@ void RemoveDeadValues::runOnOperation() {
// Tracks values eligible for erasure - complements liveness analysis to
// identify "droppable" values.
DenseSet<Value> deadVals;
- // mark outgoing arguments to a public function LIVE. We also propagate liveness backward.
+ // mark outgoing arguments to a public function LIVE. We also propagate
+ // liveness backward.
DenseSet<Value> liveVals;
// Maintains a list of Ops, values, branches, etc., slated for cleanup at the
diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir
index fc857f2989f18..ebb76a8835ceb 100644
--- a/mlir/test/Transforms/remove-dead-values.mlir
+++ b/mlir/test/Transforms/remove-dead-values.mlir
@@ -570,21 +570,21 @@ module @return_void_with_unused_argument {
return %unused : memref<4xi32>
}
- // the function is immutable because it is public.
- func.func public @immutable_fn_return_void_with_unused_argument(%arg0: i32, %unused: i32) -> () {
- %sum = arith.addi %arg0, %arg0 : i32
- %c0 = arith.constant 0 : index
- %buf = memref.alloc() : memref<1xi32>
- memref.store %sum, %buf[%c0] : memref<1xi32>
+ // the function signature is immutable because it is public.
+ func.func public @immutable_fn_with_unused_argument(%arg0: i32, %arg1: memref<4xf32>) -> () {
return
}
+
// CHECK-LABEL: func.func @main2
- // CHECK-SAME: (%[[ARG0_MAIN:.*]]: i32)
+ // CHECK: %[[MEM:.*]] = memref.alloc() : memref<4xf32>
// CHECK: %[[UNUSED:.*]] = arith.constant 0 : i32
- // CHECK: call @immutable_fn_return_void_with_unused_argument(%[[ARG0_MAIN]], %[[UNUSED]]) : (i32, i32) -> ()
- func.func @main2(%arg0: i32) -> () {
- %zero = arith.constant 0 : i32
- call @immutable_fn_return_void_with_unused_argument(%arg0, %zero) : (i32, i32) -> ()
+ // CHECK: call @immutable_fn_with_unused_argument(%[[UNUSED]], %[[MEM]]) : (i32, memref<4xf32>) -> ()
+ func.func @main2() -> () {
+ %one = arith.constant 1 : i32
+ %scalar = arith.addi %one, %one: i32
+ %mem = memref.alloc() : memref<4xf32>
+
+ call @immutable_fn_with_unused_argument(%scalar, %mem) : (i32, memref<4xf32>) -> ()
return
}
}
>From 4904fb9ab8162477b13e7b0827696d11a65c9389 Mon Sep 17 00:00:00 2001
From: Xin Liu <xxinliu at meta.com>
Date: Tue, 7 Oct 2025 20:28:14 -0700
Subject: [PATCH 5/7] Add LIT test containing SCF
This diff also tries to propagate liveness recursively.
---
mlir/include/mlir/IR/Visitors.h | 2 +-
mlir/lib/Transforms/RemoveDeadValues.cpp | 95 ++++++++++++++++----
mlir/test/Transforms/remove-dead-values.mlir | 21 ++++-
3 files changed, 99 insertions(+), 19 deletions(-)
diff --git a/mlir/include/mlir/IR/Visitors.h b/mlir/include/mlir/IR/Visitors.h
index 5766d262796d6..632db0a7b5cd4 100644
--- a/mlir/include/mlir/IR/Visitors.h
+++ b/mlir/include/mlir/IR/Visitors.h
@@ -45,7 +45,7 @@ struct BackwardIterator {
static auto makeIterable(T &range) {
if constexpr (std::is_same<T, Operation>()) {
/// Make operations iterable: return the list of regions.
- return llvm::reverse(range.getRegions());
+ return range.getRegions();
} else {
/// Regions and block are already iterable.
return llvm::reverse(range);
diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index ed5c6a8d2ead0..ba2c56e5d9661 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -269,16 +269,6 @@ static SmallVector<OpOperand *> operandsToOpOperands(OperandRange operands) {
static void processSimpleOp(Operation *op, RunLivenessAnalysis &la,
DenseSet<Value> &nonLiveSet,
DenseSet<Value> &liveSet, RDVFinalCleanupList &cl) {
- for (Value val : op->getResults()) {
- if (liveSet.contains(val)) {
- LDBG() << "Simple op is used by a public function, "
- "preserving it: "
- << OpWithFlags(op, OpPrintingFlags().skipRegions());
- liveSet.insert_range(op->getOperands());
- return;
- }
- }
-
if (!isMemoryEffectFree(op) ||
hasLive(op->getResults(), nonLiveSet, liveSet, la)) {
LDBG() << "Simple op is not memory effect free or has live results, "
@@ -412,6 +402,82 @@ static Value createDummyArgument(CallOpInterface callOp, Value oldVal) {
return {};
}
+// When you mark a call operand as live, also mark its definition chain, recursively.
+// We handle RegionBranchOpInterface here. I think we should handle BranchOpInterface as well.
+void propagateBackward(Value val, DenseSet<Value> &liveSet) {
+ if (liveSet.contains(val)) return;
+ liveSet.insert(val);
+
+ if (auto defOp = val.getDefiningOp()) {
+ // Mark operands of live results as live
+ for (Value operand : defOp->getOperands()) {
+ propagateBackward(operand, liveSet);
+ }
+
+ // Handle RegionBranchOpInterface specially
+ if (auto regionBranchOp = dyn_cast<RegionBranchOpInterface>(defOp)) {
+ // If this is a result of a RegionBranchOpInterface, we need to trace back
+ // through the control flow to find the sources that contribute to this result
+
+ OpResult result = cast<OpResult>(val);
+ unsigned resultIndex = result.getResultNumber();
+
+ // Find all possible sources that can contribute to this result
+ // by examining all regions and their terminators
+ for (Region &region : regionBranchOp->getRegions()) {
+ if (region.empty()) continue;
+
+ // Get the successors from this region
+ SmallVector<RegionSuccessor> successors;
+ regionBranchOp.getSuccessorRegions(RegionBranchPoint(&region), successors);
+
+ // Check if any successor can produce this result
+ for (const RegionSuccessor &successor : successors) {
+ if (successor.isParent()) {
+ // This region can return to the parent operation
+ ValueRange successorInputs = successor.getSuccessorInputs();
+ if (resultIndex < successorInputs.size()) {
+ // Find the terminator that contributes to this result
+ Operation *terminator = region.back().getTerminator();
+ if (auto regionBranchTerm =
+ dyn_cast<RegionBranchTerminatorOpInterface>(terminator)) {
+ OperandRange terminatorOperands =
+ regionBranchTerm.getSuccessorOperands(RegionBranchPoint::parent());
+ if (resultIndex < terminatorOperands.size()) {
+ // This terminator operand contributes to our result
+ propagateBackward(terminatorOperands[resultIndex], liveSet);
+ }
+ }
+ }
+ }
+ }
+
+ // Also mark region arguments as live if they might contribute to this result
+ // Find which operand of the parent operation corresponds to region arguments
+ Block &entryBlock = region.front();
+ for (BlockArgument arg : entryBlock.getArguments()) {
+ // Get entry successor operands - these are the operands that flow
+ // from the parent operation to this region
+ SmallVector<RegionSuccessor> entrySuccessors;
+ regionBranchOp.getSuccessorRegions(RegionBranchPoint::parent(), entrySuccessors);
+
+ for (const RegionSuccessor &entrySuccessor : entrySuccessors) {
+ if (entrySuccessor.getSuccessor() == &region) {
+ // Get the operands that are forwarded to this region
+ OperandRange entryOperands =
+ regionBranchOp.getEntrySuccessorOperands(RegionBranchPoint::parent());
+ unsigned argIndex = arg.getArgNumber();
+ if (argIndex < entryOperands.size()) {
+ propagateBackward(entryOperands[argIndex], liveSet);
+ }
+ break;
+ }
+ }
+ }
+ }
+ }
+ }
+}
static void processCallOp(CallOpInterface callOp, Operation *module,
RunLivenessAnalysis &la, DenseSet<Value> &nonLiveSet,
DenseSet<Value> &liveSet) {
@@ -439,13 +505,8 @@ static void processCallOp(CallOpInterface callOp, Operation *module,
for (int index : nonLiveArgs.set_bits()) {
OpOperand *operand = callOpOperands[index];
- Value oldVal = operand->get();
- if (Value dummy = createDummyArgument(callOp, oldVal)) {
- callOp->setOperand(operand->getOperandNumber(), dummy);
- nonLiveSet.insert(oldVal);
- } else {
- liveSet.insert(oldVal);
- }
+ LDBG() << "mark operand " << index << " live " << operand->get();
+ propagateBackward(operand->get(), liveSet);
}
}
}
diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir
index ebb76a8835ceb..a60efa45fe943 100644
--- a/mlir/test/Transforms/remove-dead-values.mlir
+++ b/mlir/test/Transforms/remove-dead-values.mlir
@@ -576,8 +576,9 @@ module @return_void_with_unused_argument {
}
// CHECK-LABEL: func.func @main2
+ // CHECK: %[[ONE:.*]] = arith.constant 1 : i32
+ // CHECK: %[[UNUSED:.*]] = arith.addi %[[ONE]], %[[ONE]] : i32
// CHECK: %[[MEM:.*]] = memref.alloc() : memref<4xf32>
- // CHECK: %[[UNUSED:.*]] = arith.constant 0 : i32
// CHECK: call @immutable_fn_with_unused_argument(%[[UNUSED]], %[[MEM]]) : (i32, memref<4xf32>) -> ()
func.func @main2() -> () {
%one = arith.constant 1 : i32
@@ -587,6 +588,24 @@ module @return_void_with_unused_argument {
call @immutable_fn_with_unused_argument(%scalar, %mem) : (i32, memref<4xf32>) -> ()
return
}
+
+ // CHECK-LABEL: func.func @main3
+ // CHECK: %[[UNUSED:.*]] = scf.if %arg0 -> (i32)
+ // CHECK: %[[MEM:.*]] = memref.alloc() : memref<4xf32>
+ // CHECK: call @immutable_fn_with_unused_argument(%[[UNUSED]], %[[MEM]]) : (i32, memref<4xf32>) -> ()
+ func.func @main3(%arg0: i1) {
+ %0 = scf.if %arg0 -> (i32) {
+ %c1_i32 = arith.constant 1 : i32
+ scf.yield %c1_i32 : i32
+ } else {
+ %c0_i32 = arith.constant 0 : i32
+ scf.yield %c0_i32 : i32
+ }
+ %mem = memref.alloc() : memref<4xf32>
+
+ call @immutable_fn_with_unused_argument(%0, %mem) : (i32, memref<4xf32>) -> ()
+ return
+ }
}
// -----
>From 8f92fee035fd1fdbda438481975770ecff2a1651 Mon Sep 17 00:00:00 2001
From: Xin Liu <xxinliu at meta.com>
Date: Thu, 9 Oct 2025 20:28:52 -0700
Subject: [PATCH 6/7] [MLIR][RemoveDeadValues] Privatize public function
beforehand.
---
.../mlir/Analysis/DataFlow/LivenessAnalysis.h | 7 +-
mlir/include/mlir/IR/Visitors.h | 14 -
.../Analysis/DataFlow/LivenessAnalysis.cpp | 1 +
mlir/lib/Transforms/RemoveDeadValues.cpp | 330 +++++++++---------
mlir/test/Transforms/remove-dead-values.mlir | 51 +--
5 files changed, 211 insertions(+), 192 deletions(-)
diff --git a/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h
index be7e027b95f64..7bf3ada5847e6 100644
--- a/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h
@@ -24,6 +24,7 @@
#define MLIR_ANALYSIS_DATAFLOW_LIVENESSANALYSIS_H
#include <mlir/Analysis/DataFlow/SparseAnalysis.h>
+#include <mlir/Pass/AnalysisManager.h>
#include <optional>
namespace mlir::dataflow {
@@ -101,13 +102,17 @@ struct RunLivenessAnalysis {
RunLivenessAnalysis(Operation *op);
const Liveness *getLiveness(Value val);
-
+ // This only mark Liveness results are stale.
+ void invalidate() { valid = false; }
/// Return the configuration of the solver used for this analysis.
const DataFlowConfig &getSolverConfig() const { return solver.getConfig(); }
+ /// The function is called by analysis_impl::isInvalidated.
+ bool isInvalidated(AnalysisManager::PreservedAnalyses&) const { return !valid; }
private:
/// Stores the result of the liveness analysis that was run.
DataFlowSolver solver;
+ bool valid {true};
};
} // end namespace mlir::dataflow
diff --git a/mlir/include/mlir/IR/Visitors.h b/mlir/include/mlir/IR/Visitors.h
index 632db0a7b5cd4..893f66ae33deb 100644
--- a/mlir/include/mlir/IR/Visitors.h
+++ b/mlir/include/mlir/IR/Visitors.h
@@ -39,20 +39,6 @@ struct ForwardIterator {
}
};
-/// This iterator enumerates the elements in "backward" order.
-struct BackwardIterator {
- template <typename T>
- static auto makeIterable(T &range) {
- if constexpr (std::is_same<T, Operation>()) {
- /// Make operations iterable: return the list of regions.
- return range.getRegions();
- } else {
- /// Regions and block are already iterable.
- return llvm::reverse(range);
- }
- }
-};
-
/// A utility class to encode the current walk stage for "generic" walkers.
/// When walking an operation, we can either choose a Pre/Post order walker
/// which invokes the callback on an operation before/after all its attached
diff --git a/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp b/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp
index d705d8d4c7819..943c60bda9de6 100644
--- a/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp
@@ -325,5 +325,6 @@ RunLivenessAnalysis::RunLivenessAnalysis(Operation *op) {
}
const Liveness *RunLivenessAnalysis::getLiveness(Value val) {
+ assert(valid && "getLiveness called after invalidate");
return solver.lookupState<Liveness>(val);
}
diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index ba2c56e5d9661..05d49d8beec39 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -33,7 +33,6 @@
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Analysis/DataFlow/LivenessAnalysis.h"
-#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Dialect.h"
@@ -47,6 +46,7 @@
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Pass/AnalysisManager.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/FoldUtils.h"
@@ -119,13 +119,8 @@ struct RDVFinalCleanupList {
/// Return true iff at least one value in `values` is live, given the liveness
/// information in `la`.
static bool hasLive(ValueRange values, const DenseSet<Value> &nonLiveSet,
- const DenseSet<Value> &liveSet, RunLivenessAnalysis &la) {
+ RunLivenessAnalysis &la) {
for (Value value : values) {
- if (liveSet.contains(value)) {
- LDBG() << "Value " << value << " is marked live by CallOp";
- return true;
- }
-
if (nonLiveSet.contains(value)) {
LDBG() << "Value " << value << " is already marked non-live (dead)";
continue;
@@ -150,7 +145,6 @@ static bool hasLive(ValueRange values, const DenseSet<Value> &nonLiveSet,
/// Return a BitVector of size `values.size()` where its i-th bit is 1 iff the
/// i-th value in `values` is live, given the liveness information in `la`.
static BitVector markLives(ValueRange values, const DenseSet<Value> &nonLiveSet,
- const DenseSet<Value> &liveSet,
RunLivenessAnalysis &la) {
BitVector lives(values.size(), true);
@@ -161,9 +155,7 @@ static BitVector markLives(ValueRange values, const DenseSet<Value> &nonLiveSet,
<< " is already marked non-live (dead) at index " << index;
continue;
}
- if (liveSet.contains(value)) {
- continue;
- }
+
const Liveness *liveness = la.getLiveness(value);
// It is important to note that when `liveness` is null, we can't tell if
// `value` is live or not. So, the safe option is to consider it live. Also,
@@ -268,9 +260,8 @@ static SmallVector<OpOperand *> operandsToOpOperands(OperandRange operands) {
/// - Return-like
static void processSimpleOp(Operation *op, RunLivenessAnalysis &la,
DenseSet<Value> &nonLiveSet,
- DenseSet<Value> &liveSet, RDVFinalCleanupList &cl) {
- if (!isMemoryEffectFree(op) ||
- hasLive(op->getResults(), nonLiveSet, liveSet, la)) {
+ RDVFinalCleanupList &cl) {
+ if (!isMemoryEffectFree(op) || hasLive(op->getResults(), nonLiveSet, la)) {
LDBG() << "Simple op is not memory effect free or has live results, "
"preserving it: "
<< OpWithFlags(op, OpPrintingFlags().skipRegions());
@@ -298,7 +289,7 @@ static void processSimpleOp(Operation *op, RunLivenessAnalysis &la,
/// (6) Marking all its results as non-live values.
static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
RunLivenessAnalysis &la, DenseSet<Value> &nonLiveSet,
- DenseSet<Value> &liveSet, RDVFinalCleanupList &cl) {
+ RDVFinalCleanupList &cl) {
LDBG() << "Processing function op: "
<< OpWithFlags(funcOp, OpPrintingFlags().skipRegions());
if (funcOp.isPublic() || funcOp.isExternal()) {
@@ -309,7 +300,7 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
// Get the list of unnecessary (non-live) arguments in `nonLiveArgs`.
SmallVector<Value> arguments(funcOp.getArguments());
- BitVector nonLiveArgs = markLives(arguments, nonLiveSet, liveSet, la);
+ BitVector nonLiveArgs = markLives(arguments, nonLiveSet, la);
nonLiveArgs = nonLiveArgs.flip();
// Do (1).
@@ -362,8 +353,7 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
for (SymbolTable::SymbolUse use : uses) {
Operation *callOp = use.getUser();
assert(isa<CallOpInterface>(callOp) && "expected a call-like user");
- BitVector liveCallRets =
- markLives(callOp->getResults(), nonLiveSet, liveSet, la);
+ BitVector liveCallRets = markLives(callOp->getResults(), nonLiveSet, la);
nonLiveRets &= liveCallRets.flip();
}
@@ -390,127 +380,6 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
}
}
-// Create a cheaper value with the same type of oldVal in front of CallOp.
-static Value createDummyArgument(CallOpInterface callOp, Value oldVal) {
- OpBuilder builder(callOp.getOperation());
- Type type = oldVal.getType();
-
- // Create zero constant for any supported type
- if (TypedAttr zeroAttr = builder.getZeroAttr(type)) {
- return builder.create<arith::ConstantOp>(oldVal.getLoc(), type, zeroAttr);
- }
- return {};
-}
-
-// When you mark a call operand as live, also mark its definition chain, recursively.
-// We handle RegionBranchOpInterface here. I think we should handle BranchOpInterface as well.
-void propagateBackward(Value val, DenseSet<Value> &liveSet) {
- if (liveSet.contains(val)) return;
- liveSet.insert(val);
-
- if (auto defOp = val.getDefiningOp()) {
- // Mark operands of live results as live
- for (Value operand : defOp->getOperands()) {
- propagateBackward(operand, liveSet);
- }
-
- // Handle RegionBranchOpInterface specially
- if (auto regionBranchOp = dyn_cast<RegionBranchOpInterface>(defOp)) {
- // If this is a result of a RegionBranchOpInterface, we need to trace back
- // through the control flow to find the sources that contribute to this result
-
- OpResult result = cast<OpResult>(val);
- unsigned resultIndex = result.getResultNumber();
-
- // Find all possible sources that can contribute to this result
- // by examining all regions and their terminators
- for (Region &region : regionBranchOp->getRegions()) {
- if (region.empty()) continue;
-
- // Get the successors from this region
- SmallVector<RegionSuccessor> successors;
- regionBranchOp.getSuccessorRegions(RegionBranchPoint(&region), successors);
-
- // Check if any successor can produce this result
- for (const RegionSuccessor &successor : successors) {
- if (successor.isParent()) {
- // This region can return to the parent operation
- ValueRange successorInputs = successor.getSuccessorInputs();
- if (resultIndex < successorInputs.size()) {
- // Find the terminator that contributes to this result
- Operation *terminator = region.back().getTerminator();
- if (auto regionBranchTerm =
- dyn_cast<RegionBranchTerminatorOpInterface>(terminator)) {
- OperandRange terminatorOperands =
- regionBranchTerm.getSuccessorOperands(RegionBranchPoint::parent());
- if (resultIndex < terminatorOperands.size()) {
- // This terminator operand contributes to our result
- propagateBackward(terminatorOperands[resultIndex], liveSet);
- }
- }
- }
- }
- }
-
- // Also mark region arguments as live if they might contribute to this result
- // Find which operand of the parent operation corresponds to region arguments
- Block &entryBlock = region.front();
- for (BlockArgument arg : entryBlock.getArguments()) {
- // Get entry successor operands - these are the operands that flow
- // from the parent operation to this region
- SmallVector<RegionSuccessor> entrySuccessors;
- regionBranchOp.getSuccessorRegions(RegionBranchPoint::parent(), entrySuccessors);
-
- for (const RegionSuccessor &entrySuccessor : entrySuccessors) {
- if (entrySuccessor.getSuccessor() == &region) {
- // Get the operands that are forwarded to this region
- OperandRange entryOperands =
- regionBranchOp.getEntrySuccessorOperands(RegionBranchPoint::parent());
- unsigned argIndex = arg.getArgNumber();
- if (argIndex < entryOperands.size()) {
- propagateBackward(entryOperands[argIndex], liveSet);
- }
- break;
- }
- }
- }
- }
- }
- }
-}
-static void processCallOp(CallOpInterface callOp, Operation *module,
- RunLivenessAnalysis &la, DenseSet<Value> &nonLiveSet,
- DenseSet<Value> &liveSet) {
- if (!la.getSolverConfig().isInterprocedural())
- return;
-
- Operation *callableOp = callOp.resolveCallable();
- auto funcOp = dyn_cast<FunctionOpInterface>(callableOp);
- if (!funcOp || !funcOp.isPublic()) {
- return;
- }
-
- LDBG() << "processCallOp to a public function: " << funcOp.getName();
- // Get the list of unnecessary (non-live) arguments in `nonLiveArgs`.
- SmallVector<Value> arguments(funcOp.getArguments());
- BitVector nonLiveArgs = markLives(arguments, nonLiveSet, liveSet, la);
- nonLiveArgs = nonLiveArgs.flip();
-
- if (nonLiveArgs.count() > 0) {
- LDBG() << funcOp.getName() << " contains NonLive arguments";
- // The number of operands in the call op may not match the number of
- // arguments in the func op.
- SmallVector<OpOperand *> callOpOperands =
- operandsToOpOperands(callOp.getArgOperands());
-
- for (int index : nonLiveArgs.set_bits()) {
- OpOperand *operand = callOpOperands[index];
- LDBG() << "mark operand " << index << " live " << operand->get();
- propagateBackward(operand->get(), liveSet);
- }
- }
-}
-
/// Process a region branch operation `regionBranchOp` using the liveness
/// information in `la`. The processing involves two scenarios:
///
@@ -543,14 +412,12 @@ static void processCallOp(CallOpInterface callOp, Operation *module,
static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
RunLivenessAnalysis &la,
DenseSet<Value> &nonLiveSet,
- DenseSet<Value> &liveSet,
RDVFinalCleanupList &cl) {
LDBG() << "Processing region branch op: "
<< OpWithFlags(regionBranchOp, OpPrintingFlags().skipRegions());
// Mark live results of `regionBranchOp` in `liveResults`.
auto markLiveResults = [&](BitVector &liveResults) {
- liveResults =
- markLives(regionBranchOp->getResults(), nonLiveSet, liveSet, la);
+ liveResults = markLives(regionBranchOp->getResults(), nonLiveSet, la);
};
// Mark live arguments in the regions of `regionBranchOp` in `liveArgs`.
@@ -559,7 +426,7 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
if (region.empty())
continue;
SmallVector<Value> arguments(region.front().getArguments());
- BitVector regionLiveArgs = markLives(arguments, nonLiveSet, liveSet, la);
+ BitVector regionLiveArgs = markLives(arguments, nonLiveSet, la);
liveArgs[&region] = regionLiveArgs;
}
};
@@ -753,7 +620,7 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
// attributed to something else.
// Do (1') and (2').
if (isMemoryEffectFree(regionBranchOp.getOperation()) &&
- !hasLive(regionBranchOp->getResults(), nonLiveSet, liveSet, la)) {
+ !hasLive(regionBranchOp->getResults(), nonLiveSet, la)) {
cl.operations.push_back(regionBranchOp.getOperation());
return;
}
@@ -832,7 +699,7 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la,
DenseSet<Value> &nonLiveSet,
- DenseSet<Value> &liveSet, RDVFinalCleanupList &cl) {
+ RDVFinalCleanupList &cl) {
LDBG() << "Processing branch op: " << *branchOp;
unsigned numSuccessors = branchOp->getNumSuccessors();
@@ -850,7 +717,7 @@ static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la,
// Do (2)
BitVector successorNonLive =
- markLives(operandValues, nonLiveSet, liveSet, la).flip();
+ markLives(operandValues, nonLiveSet, la).flip();
collectNonLiveValues(nonLiveSet, successorBlock->getArguments(),
successorNonLive);
@@ -1003,36 +870,185 @@ struct RemoveDeadValues : public impl::RemoveDeadValuesBase<RemoveDeadValues> {
};
} // namespace
+/// If the target of CallOP is a public function and at least one argument is NonLive,
+/// privatize the function.
+/// Our strategy here is separation interface and implementation. eg.
+///
+/// public void foo(int unused){...}
+/// =>
+/// public void foo(int unused) { // old function, interface
+/// return __foo_impl(unused);
+/// }
+///
+/// private void __foo_impl(int unused) { // the new private function, or implementation.
+/// ... // the function body of the original function.
+/// }
+///
+/// Returns true if any IR changes were made, false otherwise.
+static bool processCallOp(CallOpInterface callOp, Operation *module,
+ RunLivenessAnalysis &la) {
+ Operation *callableOp = callOp.resolveCallable();
+ auto funcOp = dyn_cast<FunctionOpInterface>(callableOp);
+ if (!funcOp || !funcOp.isPublic()) {
+ return false;
+ }
+
+ LDBG() << "Processing callOp " << callOp
+ << " target is a public function: " << funcOp.getOperation()->getName();
+
+ // Get the list of unnecessary (non-live) arguments in `nonLiveArgs`.
+ SmallVector<Value> arguments(callOp.getArgOperands());
+ BitVector nonLiveArgs = markLives(arguments, DenseSet<Value>(), la);
+ nonLiveArgs = nonLiveArgs.flip();
+
+ if (nonLiveArgs.count() > 0) {
+ auto moduleOp = cast<ModuleOp>(module);
+ OpBuilder rewriter(moduleOp.getContext());
+
+ // Clone function and create private version
+ FunctionOpInterface clonedFunc = cast<FunctionOpInterface>(funcOp.clone());
+
+ // Set visibility = 'private' and a new name for the cloned function
+ SymbolTable::setSymbolVisibility(clonedFunc,
+ SymbolTable::Visibility::Private);
+ std::string newName = "__" + funcOp.getName().str() + "_privatized";
+ clonedFunc.setName(newName);
+
+ // Insert the cloned function into the module
+ rewriter.setInsertionPointAfter(funcOp);
+ rewriter.insert(clonedFunc);
+
+ // Replace ALL callsites of the original function to call the cloned function directly
+ LogicalResult result = SymbolTable::replaceAllSymbolUses(
+ funcOp,
+ clonedFunc.getNameAttr(),
+ moduleOp
+ );
+
+ if (result.failed()) {
+ LDBG() << "Failed to replace all symbol uses for " << funcOp.getName();
+ return false;
+ }
+
+ LDBG() << "Redirected all callsites from " << funcOp.getName()
+ << " to " << newName;
+
+ // Transform the original funcOp into a wrapper that calls the cloned function
+ Region &funcBody = funcOp.getFunctionBody();
+
+ // Clean the original function body
+ funcBody.dropAllReferences();
+ funcBody.getBlocks().clear();
+
+ // Create a new entry block for the wrapper function
+ Block *wrapperBlock = rewriter.createBlock(&funcBody);
+
+ // Add block arguments that match the function signature
+ for (Type argType : funcOp.getArgumentTypes()) {
+ wrapperBlock->addArgument(argType, funcOp.getLoc());
+ }
+
+ // Set insertion point to the new block
+ rewriter.setInsertionPointToStart(wrapperBlock);
+
+ // Clone the original call operation and update its callee
+ Operation *clonedCallOp = callOp->clone();
+
+ // Update the callee symbol reference to point to the new private function
+ if (auto callableOp = dyn_cast<CallOpInterface>(clonedCallOp)) {
+ auto symbolRef = SymbolRefAttr::get(funcOp.getContext(), clonedFunc.getName());
+ callableOp.setCalleeFromCallable(symbolRef);
+ }
+
+ // Set the call arguments to use the wrapper block's arguments
+ clonedCallOp->setOperands(wrapperBlock->getArguments());
+
+ // Insert the cloned call operation
+ rewriter.insert(clonedCallOp);
+
+ // Create return operation of the same type as the original function's return
+ SmallVector<Operation*> returnOps;
+ for (Block &block : clonedFunc.getFunctionBody()) {
+ if (block.getNumSuccessors() > 0)
+ continue;
+
+ Operation *terminator = block.getTerminator();
+ if (terminator && terminator->hasTrait<OpTrait::ReturnLike>()) {
+ returnOps.push_back(terminator);
+ break; // Use first return as template
+ }
+ }
+
+ if (!returnOps.empty()) {
+ Operation *templateReturnOp = returnOps[0];
+ Operation *newReturnOp = templateReturnOp->clone();
+ newReturnOp->setOperands(clonedCallOp->getResults());
+ newReturnOp->setLoc(funcOp.getLoc());
+ rewriter.insert(newReturnOp);
+ }
+
+ LDBG() << "Created wrapper function " << funcOp.getName()
+ << " that calls " << newName;
+
+ return true; // Changes were made
+ }
+
+ return false; // No changes made
+}
+
void RemoveDeadValues::runOnOperation() {
- auto &la = getAnalysis<RunLivenessAnalysis>();
+ AnalysisManager am = getAnalysisManager();
+ RunLivenessAnalysis *la = &am.getAnalysis<RunLivenessAnalysis>();
Operation *module = getOperation();
+ // Only privatize public funciton if liveness analysis is inter-procedural.
+ if (la->getSolverConfig().isInterprocedural()) {
+ bool changed = false;
+ module->walk([&](CallOpInterface callOp) {
+ if (processCallOp(callOp, module, *la)) {
+ changed = true;
+ }
+ });
+
+ if (changed) {
+ LDBG() << "IR has changed, invalidate RunLivenessAnalysis only";
+ auto & pa = getPassState().preservedAnalyses;
+ bool preserved = pa.isPreserved<RunLivenessAnalysis>();
+ la->invalidate();
+ am.invalidate(pa);
+
+ la = &am.getAnalysis<RunLivenessAnalysis>();
+ // if RunLivenessAnalysis was previously preserved, preserved the updated results.
+ if (preserved) {
+ pa.preserve<RunLivenessAnalysis>();
+ }
+ }
+ }
+
+
// Tracks values eligible for erasure - complements liveness analysis to
// identify "droppable" values.
DenseSet<Value> deadVals;
- // mark outgoing arguments to a public function LIVE. We also propagate
- // liveness backward.
- DenseSet<Value> liveVals;
// Maintains a list of Ops, values, branches, etc., slated for cleanup at the
// end of this pass.
RDVFinalCleanupList finalCleanupList;
- module->walk<WalkOrder::PostOrder, BackwardIterator>([&](Operation *op) {
+ module->walk([&](Operation *op) {
if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
- processFuncOp(funcOp, module, la, deadVals, liveVals, finalCleanupList);
+ processFuncOp(funcOp, module, *la, deadVals, finalCleanupList);
} else if (auto regionBranchOp = dyn_cast<RegionBranchOpInterface>(op)) {
- processRegionBranchOp(regionBranchOp, la, deadVals, liveVals,
- finalCleanupList);
+ processRegionBranchOp(regionBranchOp, *la, deadVals, finalCleanupList);
} else if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
- processBranchOp(branchOp, la, deadVals, liveVals, finalCleanupList);
+ processBranchOp(branchOp, *la, deadVals, finalCleanupList);
} else if (op->hasTrait<::mlir::OpTrait::IsTerminator>()) {
// Nothing to do here because this is a terminator op and it should be
// honored with respect to its parent
} else if (isa<CallOpInterface>(op)) {
- processCallOp(cast<CallOpInterface>(op), module, la, deadVals, liveVals);
+ // Nothing to do because this op is associated with a function op and gets
+ // cleaned when the latter is cleaned.
} else {
- processSimpleOp(op, la, deadVals, liveVals, finalCleanupList);
+ processSimpleOp(op, *la, deadVals, finalCleanupList);
}
});
diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir
index a60efa45fe943..bcac4638604fa 100644
--- a/mlir/test/Transforms/remove-dead-values.mlir
+++ b/mlir/test/Transforms/remove-dead-values.mlir
@@ -569,31 +569,25 @@ module @return_void_with_unused_argument {
call @fn_return_void_with_unused_argument(%arg0, %unused) : (i32, memref<4xi32>) -> ()
return %unused : memref<4xi32>
}
+}
+// check that public functions with non-live arguments correctly.
+module @public_function_with_nonlive_arguments {
// the function signature is immutable because it is public.
- func.func public @immutable_fn_with_unused_argument(%arg0: i32, %arg1: memref<4xf32>) -> () {
+ func.func public @public_fn_with_unused_argument(%unused: i32) -> () {
return
}
-
- // CHECK-LABEL: func.func @main2
- // CHECK: %[[ONE:.*]] = arith.constant 1 : i32
- // CHECK: %[[UNUSED:.*]] = arith.addi %[[ONE]], %[[ONE]] : i32
- // CHECK: %[[MEM:.*]] = memref.alloc() : memref<4xf32>
- // CHECK: call @immutable_fn_with_unused_argument(%[[UNUSED]], %[[MEM]]) : (i32, memref<4xf32>) -> ()
- func.func @main2() -> () {
- %one = arith.constant 1 : i32
- %scalar = arith.addi %one, %one: i32
- %mem = memref.alloc() : memref<4xf32>
-
- call @immutable_fn_with_unused_argument(%scalar, %mem) : (i32, memref<4xf32>) -> ()
+ // CHECK-LABEL: func.func @main
+ // CHECK: call @__public_fn_with_unused_argument_privatized() : () -> ()
+ func.func @main() -> () {
+ %zero = arith.constant 0 : i32
+ call @public_fn_with_unused_argument(%zero) : (i32) -> ()
return
}
- // CHECK-LABEL: func.func @main3
- // CHECK: %[[UNUSED:.*]] = scf.if %arg0 -> (i32)
- // CHECK: %[[MEM:.*]] = memref.alloc() : memref<4xf32>
- // CHECK: call @immutable_fn_with_unused_argument(%[[UNUSED]], %[[MEM]]) : (i32, memref<4xf32>) -> ()
- func.func @main3(%arg0: i1) {
+ // CHECK-LABEL: func.func @main2
+ // CHECK: call @__public_fn_with_unused_argument_privatized() : () -> ()
+ func.func @main2(%arg0: i1) {
%0 = scf.if %arg0 -> (i32) {
%c1_i32 = arith.constant 1 : i32
scf.yield %c1_i32 : i32
@@ -601,9 +595,26 @@ module @return_void_with_unused_argument {
%c0_i32 = arith.constant 0 : i32
scf.yield %c0_i32 : i32
}
- %mem = memref.alloc() : memref<4xf32>
- call @immutable_fn_with_unused_argument(%0, %mem) : (i32, memref<4xf32>) -> ()
+ call @public_fn_with_unused_argument(%0) : (i32) -> ()
+ return
+ }
+
+ func.func public @fn_return_multiple(%arg0: i32) -> (i32, i32, i32) {
+ %one = arith.constant 1 : i32
+ %two = arith.constant 2 : i32
+ %three = arith.constant 4 : i32
+
+ return %one, %two, %three: i32, i32, i32
+ }
+
+ // CHECK-LABEL: func.func @main3
+ // CHECK: call @__fn_return_multiple_privatized() : () -> (i32, i32, i32)
+ func.func @main3(%arg: i32) -> () {
+ %one = arith.constant 1 : i32
+ %scalar = arith.addi %arg, %one: i32
+
+ call @fn_return_multiple(%scalar) : (i32) -> (i32, i32, i32)
return
}
}
>From 443dff206e45b8aecb1f78e86ab9a80d1a815a2d Mon Sep 17 00:00:00 2001
From: Xin Liu <xxinliu at meta.com>
Date: Fri, 10 Oct 2025 00:14:25 -0700
Subject: [PATCH 7/7] Update format.
---
.../mlir/Analysis/DataFlow/LivenessAnalysis.h | 8 ++-
mlir/lib/Transforms/RemoveDeadValues.cpp | 70 ++++++++-----------
2 files changed, 36 insertions(+), 42 deletions(-)
diff --git a/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h
index 7bf3ada5847e6..20162e6586600 100644
--- a/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h
@@ -102,17 +102,19 @@ struct RunLivenessAnalysis {
RunLivenessAnalysis(Operation *op);
const Liveness *getLiveness(Value val);
- // This only mark Liveness results are stale.
+ // This only remarks that Liveness results are stale.
void invalidate() { valid = false; }
/// Return the configuration of the solver used for this analysis.
const DataFlowConfig &getSolverConfig() const { return solver.getConfig(); }
/// The function is called by analysis_impl::isInvalidated.
- bool isInvalidated(AnalysisManager::PreservedAnalyses&) const { return !valid; }
+ bool isInvalidated(AnalysisManager::PreservedAnalyses &) const {
+ return !valid;
+ }
private:
/// Stores the result of the liveness analysis that was run.
DataFlowSolver solver;
- bool valid {true};
+ bool valid{true};
};
} // end namespace mlir::dataflow
diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index 05d49d8beec39..0b0d2c1485c1d 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -870,18 +870,20 @@ struct RemoveDeadValues : public impl::RemoveDeadValuesBase<RemoveDeadValues> {
};
} // namespace
-/// If the target of CallOP is a public function and at least one argument is NonLive,
-/// privatize the function.
-/// Our strategy here is separation interface and implementation. eg.
+/// If the target of CallOp is a public function and at least one argument is
+/// NonLive, privatize the function. Our strategy here is separation interface
+/// and implementation. eg.
///
/// public void foo(int unused){...}
/// =>
/// public void foo(int unused) { // old function, interface
-/// return __foo_impl(unused);
+/// return __foo_privatized(unused);
/// }
///
-/// private void __foo_impl(int unused) { // the new private function, or implementation.
-/// ... // the function body of the original function.
+/// private void __foo_privatized(int unused) {
+///   // The new private function, i.e. the implementation, containing
+///   // the function body of the original function.
+///   ...
/// }
///
/// Returns true if any IR changes were made, false otherwise.
@@ -893,8 +895,8 @@ static bool processCallOp(CallOpInterface callOp, Operation *module,
return false;
}
- LDBG() << "Processing callOp " << callOp
- << " target is a public function: " << funcOp.getOperation()->getName();
+ LDBG() << "Processing callOp " << callOp << " target is a public function: "
+ << funcOp.getOperation()->getName();
// Get the list of unnecessary (non-live) arguments in `nonLiveArgs`.
SmallVector<Value> arguments(callOp.getArgOperands());
@@ -910,7 +912,7 @@ static bool processCallOp(CallOpInterface callOp, Operation *module,
// Set visibility = 'private' and a new name for the cloned function
SymbolTable::setSymbolVisibility(clonedFunc,
- SymbolTable::Visibility::Private);
+ SymbolTable::Visibility::Private);
std::string newName = "__" + funcOp.getName().str() + "_privatized";
clonedFunc.setName(newName);
@@ -918,22 +920,21 @@ static bool processCallOp(CallOpInterface callOp, Operation *module,
rewriter.setInsertionPointAfter(funcOp);
rewriter.insert(clonedFunc);
- // Replace ALL callsites of the original function to call the cloned function directly
+ // Redirect ALL callsites of the original function to the cloned
+ // function.
LogicalResult result = SymbolTable::replaceAllSymbolUses(
- funcOp,
- clonedFunc.getNameAttr(),
- moduleOp
- );
+ funcOp, clonedFunc.getNameAttr(), moduleOp);
if (result.failed()) {
LDBG() << "Failed to replace all symbol uses for " << funcOp.getName();
return false;
}
- LDBG() << "Redirected all callsites from " << funcOp.getName()
- << " to " << newName;
+ LDBG() << "Redirected all callsites from " << funcOp.getName() << " to "
+ << newName;
- // Transform the original funcOp into a wrapper that calls the cloned function
+ // Transform the original funcOp into a wrapper that calls the cloned
+ // function
Region &funcBody = funcOp.getFunctionBody();
// Clean the original function body
@@ -952,44 +953,35 @@ static bool processCallOp(CallOpInterface callOp, Operation *module,
rewriter.setInsertionPointToStart(wrapperBlock);
// Clone the original call operation and update its callee
- Operation *clonedCallOp = callOp->clone();
-
+ auto clonedCallOp = cast<CallOpInterface>(callOp->clone());
// Update the callee symbol reference to point to the new private function
- if (auto callableOp = dyn_cast<CallOpInterface>(clonedCallOp)) {
- auto symbolRef = SymbolRefAttr::get(funcOp.getContext(), clonedFunc.getName());
- callableOp.setCalleeFromCallable(symbolRef);
- }
-
+ auto symbolRef =
+ SymbolRefAttr::get(funcOp.getContext(), clonedFunc.getName());
+ clonedCallOp.setCalleeFromCallable(symbolRef);
// Set the call arguments to use the wrapper block's arguments
clonedCallOp->setOperands(wrapperBlock->getArguments());
-
- // Insert the cloned call operation
rewriter.insert(clonedCallOp);
- // Create return operation of the same type as the original function's return
- SmallVector<Operation*> returnOps;
+ // Create return operation of the same type as the original function's
+ // return
+ Operation *returnOp = nullptr;
for (Block &block : clonedFunc.getFunctionBody()) {
if (block.getNumSuccessors() > 0)
continue;
Operation *terminator = block.getTerminator();
if (terminator && terminator->hasTrait<OpTrait::ReturnLike>()) {
- returnOps.push_back(terminator);
+ returnOp = terminator;
break; // Use first return as template
}
}
- if (!returnOps.empty()) {
- Operation *templateReturnOp = returnOps[0];
- Operation *newReturnOp = templateReturnOp->clone();
+ if (returnOp) {
+ Operation *newReturnOp = returnOp->clone();
newReturnOp->setOperands(clonedCallOp->getResults());
newReturnOp->setLoc(funcOp.getLoc());
rewriter.insert(newReturnOp);
}
-
- LDBG() << "Created wrapper function " << funcOp.getName()
- << " that calls " << newName;
-
return true; // Changes were made
}
@@ -1012,20 +1004,20 @@ void RemoveDeadValues::runOnOperation() {
if (changed) {
LDBG() << "IR has changed, invalidate RunLivenessAnalysis only";
- auto & pa = getPassState().preservedAnalyses;
+ auto &pa = getPassState().preservedAnalyses;
bool preserved = pa.isPreserved<RunLivenessAnalysis>();
la->invalidate();
am.invalidate(pa);
la = &am.getAnalysis<RunLivenessAnalysis>();
- // if RunLivenessAnalysis was previously preserved, preserved the updated results.
+ // If RunLivenessAnalysis was previously preserved, preserve the updated
+ // results.
if (preserved) {
pa.preserve<RunLivenessAnalysis>();
}
}
}
-
// Tracks values eligible for erasure - complements liveness analysis to
// identify "droppable" values.
DenseSet<Value> deadVals;
More information about the Mlir-commits
mailing list