[Mlir-commits] [mlir] Prevent invalid IR from being passed outside of RemoveDeadValues (PR #121079)
Renat Idrisov
llvmlistbot at llvm.org
Tue Dec 24 18:39:55 PST 2024
https://github.com/parsifal-47 created https://github.com/llvm/llvm-project/pull/121079
This is a follow-up for https://github.com/llvm/llvm-project/pull/119110 and a fix for https://github.com/llvm/llvm-project/issues/118450
RemoveDeadValues used to delete Values while it was still analyzing the IR. Because of that, `isMemoryEffectFree` received invalid IR containing a half-deleted linalg.generic operation. This PR separates the analysis phase from the cleanup phase to prevent such situations.
Thank you!
>From 2db7609b851d249e16a96cc40cdc7f710575b0bc Mon Sep 17 00:00:00 2001
From: Renat Idrisov <parsifal-47 at users.noreply.github.com>
Date: Wed, 25 Dec 2024 02:30:22 +0000
Subject: [PATCH] Prevent invalid IR from being passed outside of
RemoveDeadValues
---
mlir/lib/Transforms/RemoveDeadValues.cpp | 206 +++++++++++++------
mlir/test/Transforms/remove-dead-values.mlir | 26 +++
2 files changed, 172 insertions(+), 60 deletions(-)
diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index 3429008b2f241a..5d4ec66d6905a4 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -72,15 +72,54 @@ using namespace mlir::dataflow;
namespace {
+// Set of structures below to be filled with operations and arguments to erase.
+// This is done to separate analysis and tree modification phases,
+// otherwise analysis is operating on half-deleted tree which is incorrect.
+
+struct CleanupFunction {
+ FunctionOpInterface funcOp;
+ BitVector nonLiveArgs;
+ BitVector nonLiveRets;
+};
+
+struct CleanupOperands {
+ Operation *op;
+ BitVector nonLiveOperands;
+};
+
+struct CleanupResults {
+ Operation *op;
+ BitVector nonLiveResults;
+};
+
+struct CleanupBlockArgs {
+ Block *b;
+ BitVector nonLiveArgs;
+};
+
+struct CleanupSuccessorOperands {
+ BranchOpInterface branch;
+ unsigned index;
+ BitVector nonLiveOperands;
+};
+
+struct CleanupList {
+ SmallVector<Operation *> operations;
+ SmallVector<Value> values;
+ SmallVector<CleanupFunction> functions;
+ SmallVector<CleanupOperands> operands;
+ SmallVector<CleanupResults> results;
+ SmallVector<CleanupBlockArgs> blocks;
+ SmallVector<CleanupSuccessorOperands> successorOperands;
+};
+
// Some helper functions...
/// Return true iff at least one value in `values` is live, given the liveness
/// information in `la`.
-static bool hasLive(ValueRange values, RunLivenessAnalysis &la) {
+static bool hasLive(ValueRange values, const DenseSet<Value> &deletionSet, RunLivenessAnalysis &la) {
for (Value value : values) {
- // If there is a null value, it implies that it was dropped during the
- // execution of this pass, implying that it was non-live.
- if (!value)
+ if (deletionSet.contains(value))
continue;
const Liveness *liveness = la.getLiveness(value);
@@ -92,11 +131,11 @@ static bool hasLive(ValueRange values, RunLivenessAnalysis &la) {
/// Return a BitVector of size `values.size()` where its i-th bit is 1 iff the
/// i-th value in `values` is live, given the liveness information in `la`.
-static BitVector markLives(ValueRange values, RunLivenessAnalysis &la) {
+static BitVector markLives(ValueRange values, const DenseSet<Value> &deletionSet, RunLivenessAnalysis &la) {
BitVector lives(values.size(), true);
for (auto [index, value] : llvm::enumerate(values)) {
- if (!value) {
+ if (deletionSet.contains(value)) {
lives.reset(index);
continue;
}
@@ -115,6 +154,18 @@ static BitVector markLives(ValueRange values, RunLivenessAnalysis &la) {
return lives;
}
+// DeletionSet is used to track the Values that are scheduled for removal
+void updateDeletionSet(DenseSet<Value> &deletionSet, ValueRange range, const BitVector &nonLive) {
+ for (auto [index, result] : llvm::enumerate(range)) {
+ if (!nonLive[index]) continue;
+ deletionSet.insert(result);
+ }
+}
+
+void updateDeletionSet(DenseSet<Value> &deletionSet, Operation *op, const BitVector &nonLive) {
+ updateDeletionSet(deletionSet, op->getResults(), nonLive);
+}
+
/// Drop the uses of the i-th result of `op` and then erase it iff toErase[i]
/// is 1.
static void dropUsesAndEraseResults(Operation *op, BitVector toErase) {
@@ -174,43 +225,43 @@ static SmallVector<OpOperand *> operandsToOpOperands(OperandRange operands) {
/// It is assumed that `op` is simple. Here, a simple op is one which isn't a
/// function-like op, a call-like op, a region branch op, a branch op, a region
/// branch terminator op, or return-like.
-static void cleanSimpleOp(Operation *op, RunLivenessAnalysis &la) {
- if (!isMemoryEffectFree(op) || hasLive(op->getResults(), la))
+static void cleanSimpleOp(CleanupList &cl, DenseSet<Value> &deletionSet, Operation *op, RunLivenessAnalysis &la) {
+ if (!isMemoryEffectFree(op) || hasLive(op->getResults(), deletionSet, la))
return;
- op->dropAllUses();
- op->erase();
+ cl.operations.push_back(op);
+ updateDeletionSet(deletionSet, op, BitVector(op->getNumResults(), true));
}
/// Clean a function-like op `funcOp`, given the liveness information in `la`
/// and the IR in `module`. Here, cleaning means:
/// (1) Dropping the uses of its unnecessary (non-live) arguments,
-/// (2) Erasing these arguments,
-/// (3) Erasing their corresponding operands from its callers,
+/// (2) Erasing their corresponding operands from its callers,
+/// (3) Erasing these arguments,
/// (4) Erasing its unnecessary terminator operands (return values that are
/// non-live across all callers),
/// (5) Dropping the uses of these return values from its callers, AND
/// (6) Erasing these return values
/// iff it is not public or external.
-static void cleanFuncOp(FunctionOpInterface funcOp, Operation *module,
+static void cleanFuncOp(CleanupList &cl, DenseSet<Value> &deletionSet,
+ FunctionOpInterface funcOp, Operation *module,
RunLivenessAnalysis &la) {
if (funcOp.isPublic() || funcOp.isExternal())
return;
// Get the list of unnecessary (non-live) arguments in `nonLiveArgs`.
SmallVector<Value> arguments(funcOp.getArguments());
- BitVector nonLiveArgs = markLives(arguments, la);
+ BitVector nonLiveArgs = markLives(arguments, deletionSet, la);
nonLiveArgs = nonLiveArgs.flip();
// Do (1).
for (auto [index, arg] : llvm::enumerate(arguments))
- if (arg && nonLiveArgs[index])
- arg.dropAllUses();
+ if (arg && nonLiveArgs[index]) {
+ cl.values.push_back(arg);
+ deletionSet.insert(arg);
+ }
// Do (2).
- funcOp.eraseArguments(nonLiveArgs);
-
- // Do (3).
SymbolTable::UseRange uses = *funcOp.getSymbolUses(module);
for (SymbolTable::SymbolUse use : uses) {
Operation *callOp = use.getUser();
@@ -222,7 +273,7 @@ static void cleanFuncOp(FunctionOpInterface funcOp, Operation *module,
operandsToOpOperands(cast<CallOpInterface>(callOp).getArgOperands());
for (int index : nonLiveArgs.set_bits())
nonLiveCallOperands.set(callOpOperands[index]->getOperandNumber());
- callOp->eraseOperands(nonLiveCallOperands);
+ cl.operands.push_back({callOp, nonLiveCallOperands});
}
// Get the list of unnecessary terminator operands (return values that are
@@ -253,26 +304,27 @@ static void cleanFuncOp(FunctionOpInterface funcOp, Operation *module,
for (SymbolTable::SymbolUse use : uses) {
Operation *callOp = use.getUser();
assert(isa<CallOpInterface>(callOp) && "expected a call-like user");
- BitVector liveCallRets = markLives(callOp->getResults(), la);
+ BitVector liveCallRets = markLives(callOp->getResults(), deletionSet, la);
nonLiveRets &= liveCallRets.flip();
}
- // Do (4).
+ // Do (3).
// Note that in the absence of control flow ops forcing the control to go from
// the entry (first) block to the other blocks, the control never reaches any
// block other than the entry block, because every block has a terminator.
for (Block &block : funcOp.getBlocks()) {
Operation *returnOp = block.getTerminator();
if (returnOp && returnOp->getNumOperands() == numReturns)
- returnOp->eraseOperands(nonLiveRets);
+ cl.operands.push_back({returnOp, nonLiveRets});
}
- funcOp.eraseResults(nonLiveRets);
+ cl.functions.push_back({funcOp, nonLiveArgs, nonLiveRets});
// Do (5) and (6).
for (SymbolTable::SymbolUse use : uses) {
Operation *callOp = use.getUser();
assert(isa<CallOpInterface>(callOp) && "expected a call-like user");
- dropUsesAndEraseResults(callOp, nonLiveRets);
+ cl.results.push_back({callOp, nonLiveRets});
+ updateDeletionSet(deletionSet, callOp, nonLiveRets);
}
}
@@ -297,18 +349,19 @@ static void cleanFuncOp(FunctionOpInterface funcOp, Operation *module,
/// It is important to note that values in this op flow from operands and
/// terminator operands (successor operands) to arguments and results (successor
/// inputs).
-static void cleanRegionBranchOp(RegionBranchOpInterface regionBranchOp,
+static void cleanRegionBranchOp(CleanupList &cl, DenseSet<Value> &deletionSet,
+ RegionBranchOpInterface regionBranchOp,
RunLivenessAnalysis &la) {
// Mark live results of `regionBranchOp` in `liveResults`.
auto markLiveResults = [&](BitVector &liveResults) {
- liveResults = markLives(regionBranchOp->getResults(), la);
+ liveResults = markLives(regionBranchOp->getResults(), deletionSet, la);
};
// Mark live arguments in the regions of `regionBranchOp` in `liveArgs`.
auto markLiveArgs = [&](DenseMap<Region *, BitVector> &liveArgs) {
for (Region ®ion : regionBranchOp->getRegions()) {
SmallVector<Value> arguments(region.front().getArguments());
- BitVector regionLiveArgs = markLives(arguments, la);
+ BitVector regionLiveArgs = markLives(arguments, deletionSet, la);
liveArgs[®ion] = regionLiveArgs;
}
};
@@ -497,9 +550,8 @@ static void cleanRegionBranchOp(RegionBranchOpInterface regionBranchOp,
// It could never be live because of this op but its liveness could have been
// attributed to something else.
if (isMemoryEffectFree(regionBranchOp.getOperation()) &&
- !hasLive(regionBranchOp->getResults(), la)) {
- regionBranchOp->dropAllUses();
- regionBranchOp->erase();
+ !hasLive(regionBranchOp->getResults(), deletionSet, la)) {
+ cl.operations.push_back(regionBranchOp.getOperation());
return;
}
@@ -538,29 +590,27 @@ static void cleanRegionBranchOp(RegionBranchOpInterface regionBranchOp,
terminatorOperandsToKeep);
// Do (1).
- regionBranchOp->eraseOperands(operandsToKeep.flip());
+ cl.operands.push_back({regionBranchOp, operandsToKeep.flip()});
// Do (2.a) and (2.b).
for (Region ®ion : regionBranchOp->getRegions()) {
assert(!region.empty() && "expected a non-empty region in an op "
"implementing `RegionBranchOpInterface`");
- for (auto [index, arg] : llvm::enumerate(region.front().getArguments())) {
- if (argsToKeep[®ion][index])
- continue;
- if (arg)
- arg.dropAllUses();
- }
- region.front().eraseArguments(argsToKeep[®ion].flip());
+ BitVector argsToRemove = argsToKeep[®ion].flip();
+ cl.blocks.push_back({®ion.front(), argsToRemove});
+ updateDeletionSet(deletionSet, region.front().getArguments(), argsToRemove);
}
// Do (2.c).
for (Region ®ion : regionBranchOp->getRegions()) {
Operation *terminator = region.front().getTerminator();
- terminator->eraseOperands(terminatorOperandsToKeep[terminator].flip());
+ cl.operands.push_back({terminator, terminatorOperandsToKeep[terminator].flip()});
}
// Do (3) and (4).
- dropUsesAndEraseResults(regionBranchOp.getOperation(), resultsToKeep.flip());
+ BitVector resultsToRemove = resultsToKeep.flip();
+ updateDeletionSet(deletionSet, regionBranchOp.getOperation(), resultsToRemove);
+ cl.results.push_back({regionBranchOp.getOperation(), resultsToRemove});
}
// 1. Iterate over each successor block of the given BranchOpInterface
@@ -572,7 +622,8 @@ static void cleanRegionBranchOp(RegionBranchOpInterface regionBranchOp,
// c. Mark each operand as live or dead based on the analysis.
// 3. Remove dead operands from the branch operation and arguments accordingly
-static void cleanBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la) {
+static void cleanBranchOp(CleanupList &cl, DenseSet<Value> &deletionSet,
+ BranchOpInterface branchOp, RunLivenessAnalysis &la) {
unsigned numSuccessors = branchOp->getNumSuccessors();
// Do (1)
@@ -588,22 +639,53 @@ static void cleanBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la) {
operandValues.push_back(successorOperands[operandIdx]);
}
- BitVector successorLiveOperands = markLives(operandValues, la);
-
// Do (3)
- for (int argIdx = successorLiveOperands.size() - 1; argIdx >= 0; --argIdx) {
- if (!successorLiveOperands[argIdx]) {
- if (successorBlock->getNumArguments() < successorOperands.size()) {
- // if block was cleaned through a different code path
- // we only need to remove operands from the invokation
- successorOperands.erase(argIdx);
- continue;
- }
+ BitVector successorNonLive = markLives(operandValues, deletionSet, la).flip();
+ updateDeletionSet(deletionSet, successorBlock->getArguments(), successorNonLive);
+ cl.blocks.push_back({successorBlock, successorNonLive});
+ cl.successorOperands.push_back({branchOp, succIdx, successorNonLive});
+ }
+}
+
+void cleanup(CleanupList &cl) {
+ for (auto &op: cl.operations) {
+ op->dropAllUses();
+ op->erase();
+ }
+
+ for (auto &v: cl.values) {
+ v.dropAllUses();
+ }
+
+ for (auto &f: cl.functions) {
+ f.funcOp.eraseArguments(f.nonLiveArgs);
+ f.funcOp.eraseResults(f.nonLiveRets);
+ }
+
+ for (auto &o: cl.operands) {
+ o.op->eraseOperands(o.nonLiveOperands); }
+
+ for (auto &r: cl.results) {
+ dropUsesAndEraseResults(r.op, r.nonLiveResults);
+ }
- successorBlock->getArgument(argIdx).dropAllUses();
- successorOperands.erase(argIdx);
- successorBlock->eraseArgument(argIdx);
- }
+ for (auto &b: cl.blocks) {
+ // blocks that are accessed via multiple codepaths processed once
+ if (b.b->getNumArguments() != b.nonLiveArgs.size()) continue;
+ for (int i = b.nonLiveArgs.size() - 1; i >= 0; --i) {
+ if (!b.nonLiveArgs[i]) continue;
+ b.b->getArgument(i).dropAllUses();
+ b.b->eraseArgument(i);
+ }
+ }
+ for (auto &op: cl.successorOperands) {
+ SuccessorOperands successorOperands =
+ op.branch.getSuccessorOperands(op.index);
+ // blocks that are accessed via multiple codepaths processed once
+ if (successorOperands.size() != op.nonLiveOperands.size()) continue;
+ for (int i = successorOperands.size() - 1; i >= 0; --i) {
+ if (!op.nonLiveOperands[i]) continue;
+ successorOperands.erase(i);
}
}
}
@@ -616,14 +698,16 @@ struct RemoveDeadValues : public impl::RemoveDeadValuesBase<RemoveDeadValues> {
void RemoveDeadValues::runOnOperation() {
auto &la = getAnalysis<RunLivenessAnalysis>();
Operation *module = getOperation();
+ DenseSet<Value> deletionSet;
+ CleanupList cl;
module->walk([&](Operation *op) {
if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
- cleanFuncOp(funcOp, module, la);
+ cleanFuncOp(cl, deletionSet, funcOp, module, la);
} else if (auto regionBranchOp = dyn_cast<RegionBranchOpInterface>(op)) {
- cleanRegionBranchOp(regionBranchOp, la);
+ cleanRegionBranchOp(cl, deletionSet, regionBranchOp, la);
} else if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
- cleanBranchOp(branchOp, la);
+ cleanBranchOp(cl, deletionSet, branchOp, la);
} else if (op->hasTrait<::mlir::OpTrait::IsTerminator>()) {
// Nothing to do here because this is a terminator op and it should be
// honored with respect to its parent
@@ -631,9 +715,11 @@ void RemoveDeadValues::runOnOperation() {
// Nothing to do because this op is associated with a function op and gets
// cleaned when the latter is cleaned.
} else {
- cleanSimpleOp(op, la);
+ cleanSimpleOp(cl, deletionSet, op, la);
}
});
+
+ cleanup(cl);
}
std::unique_ptr<Pass> mlir::createRemoveDeadValuesPass() {
diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir
index 9273ac01e7ccec..fe7bcbc7c490b6 100644
--- a/mlir/test/Transforms/remove-dead-values.mlir
+++ b/mlir/test/Transforms/remove-dead-values.mlir
@@ -73,6 +73,32 @@ func.func @cleanable_loop_iter_args_value(%arg0: index) -> index {
// -----
+// Checking that the arguments of linalg.generic are properly handled
+// All code below is removed as a result of the pass
+//
+#map = affine_map<(d0, d1, d2) -> (0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+module {
+ func.func @main() {
+ %cst_3 = arith.constant dense<54> : tensor<1x25x13xi32>
+ %cst_7 = arith.constant dense<11> : tensor<1x25x13xi32>
+ // CHECK-NOT: arith.constant
+ %0 = tensor.empty() : tensor<1x25x13xi32>
+ // CHECK-NOT: tensor
+ %1 = linalg.generic {indexing_maps = [#map, #map, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_3, %cst_7 : tensor<1x25x13xi32>, tensor<1x25x13xi32>) outs(%0 : tensor<1x25x13xi32>) {
+ // CHECK-NOT: linalg.generic
+ ^bb0(%in: i32, %in_15: i32, %out: i32):
+ %29 = arith.xori %in, %in_15 : i32
+ // CHECK-NOT: arith.xori
+ linalg.yield %29 : i32
+ // CHECK-NOT: linalg.yield
+ } -> tensor<1x25x13xi32>
+ return
+ }
+}
+
+// -----
+
// Note that this cleanup cannot be done by the `canonicalize` pass.
//
// CHECK-LABEL: func.func private @clean_func_op_remove_argument_and_return_value() {
More information about the Mlir-commits
mailing list