[Mlir-commits] [mlir] [MLIR] Removing dead values for branches (PR #117501)
Renat Idrisov
llvmlistbot at llvm.org
Mon Nov 25 14:33:15 PST 2024
https://github.com/parsifal-47 updated https://github.com/llvm/llvm-project/pull/117501
>From 437db73aa38da618dc10ef1a113b98034944bfdc Mon Sep 17 00:00:00 2001
From: Renat Idrisov <parsifal-47 at users.noreply.github.com>
Date: Sun, 24 Nov 2024 18:26:41 +0000
Subject: [PATCH 1/4] Removing dead values for branches
---
mlir/lib/Transforms/RemoveDeadValues.cpp | 55 ++++++++++++++------
mlir/test/Transforms/remove-dead-values.mlir | 23 ++++++--
2 files changed, 60 insertions(+), 18 deletions(-)
diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index 0aa9dcb36681b3..638726e1212772 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -563,6 +563,44 @@ static void cleanRegionBranchOp(RegionBranchOpInterface regionBranchOp,
dropUsesAndEraseResults(regionBranchOp.getOperation(), resultsToKeep.flip());
}
+// 1. Iterate over each successor block of the given BranchOpInterface
+// operation.
+// 2. For each successor block:
+// a. Retrieve the operands passed to the successor.
+// b. Use the provided liveness analysis (`RunLivenessAnalysis`) to
+// determine which operands are live in the successor block.
+// c. Mark each operand as live or dead based on the analysis.
+// 3. Remove each dead operand from the branch operation, along with the
+// corresponding argument of the successor block.
+
+static void cleanBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la) {
+ unsigned numSuccessors = branchOp->getNumSuccessors();
+
+ // Do (1): visit every successor of the branch op, addressed by its index.
+ for (unsigned succIdx = 0; succIdx < numSuccessors; ++succIdx) {
+ Block *successorBlock = branchOp->getSuccessor(succIdx);
+
+ // Do (2): copy out the operands passed to this successor and compute their liveness.
+ SuccessorOperands successorOperands =
+ branchOp.getSuccessorOperands(succIdx);
+ SmallVector<Value> operandValues;
+ for (unsigned operandIdx = 0; operandIdx < successorOperands.size();
+ ++operandIdx) {
+ operandValues.push_back(successorOperands[operandIdx]);
+ }
+
+ BitVector successorLiveOperands = markLives(operandValues, la);
+
+ // Do (3): walk from the last index down so that erasing index i leaves the lower indices valid.
+ for (int argIdx = successorLiveOperands.size() - 1; argIdx >= 0; --argIdx) {
+ if (!successorLiveOperands[argIdx]) {
+ successorOperands.erase(argIdx);
+ successorBlock->eraseArgument(argIdx); // NOTE(review): uses the operand index as the block-argument index; presumably no other predecessor still forwards a value for this argument and it has no remaining uses - confirm
+ }
+ }
+ }
+}
+
struct RemoveDeadValues : public impl::RemoveDeadValuesBase<RemoveDeadValues> {
void runOnOperation() override;
};
@@ -572,26 +610,13 @@ void RemoveDeadValues::runOnOperation() {
auto &la = getAnalysis<RunLivenessAnalysis>();
Operation *module = getOperation();
- // The removal of non-live values is performed iff there are no branch ops,
- // and all symbol user ops present in the IR are call-like.
- WalkResult acceptableIR = module->walk([&](Operation *op) {
- if (op == module)
- return WalkResult::advance();
- if (isa<BranchOpInterface>(op)) {
- op->emitError() << "cannot optimize an IR with branch ops\n";
- return WalkResult::interrupt();
- }
- return WalkResult::advance();
- });
-
- if (acceptableIR.wasInterrupted())
- return signalPassFailure();
-
module->walk([&](Operation *op) {
if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
cleanFuncOp(funcOp, module, la);
} else if (auto regionBranchOp = dyn_cast<RegionBranchOpInterface>(op)) {
cleanRegionBranchOp(regionBranchOp, la);
+ } else if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
+ cleanBranchOp(branchOp, la);
} else if (op->hasTrait<::mlir::OpTrait::IsTerminator>()) {
// Nothing to do here because this is a terminator op and it should be
// honored with respect to its parent
diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir
index 826f6159a36b67..fda7ef3fe673e4 100644
--- a/mlir/test/Transforms/remove-dead-values.mlir
+++ b/mlir/test/Transforms/remove-dead-values.mlir
@@ -28,15 +28,32 @@ module @named_module_acceptable {
// -----
-// The IR remains untouched because of the presence of a branch op `cf.cond_br`.
+// The IR is optimized regardless of the presence of a branch op `cf.cond_br`.
//
-func.func @dont_touch_unacceptable_ir_has_cleanable_simple_op_with_branch_op(%arg0: i1) {
+func.func @acceptable_ir_has_cleanable_simple_op_with_branch_op(%arg0: i1) {
%non_live = arith.constant 0 : i32
- // expected-error @+1 {{cannot optimize an IR with branch ops}}
+ // CHECK-NOT: non_live
cf.cond_br %arg0, ^bb1(%non_live : i32), ^bb2(%non_live : i32)
^bb1(%non_live_0 : i32):
+ // CHECK-NOT: non_live_0
cf.br ^bb3
^bb2(%non_live_1 : i32):
+ // CHECK-NOT: non_live_1
+ cf.br ^bb3
+^bb3:
+ return
+}
+
+// -----
+
+// Arguments of unconditional branch op `cf.br` are properly removed.
+//
+func.func @acceptable_ir_has_cleanable_simple_op_with_unconditional_branch_op(%arg0: i1) {
+ %non_live = arith.constant 0 : i32
+ // CHECK-NOT: non_live
+ cf.br ^bb1(%non_live : i32)
+^bb1(%non_live_1 : i32):
+ // CHECK-NOT: non_live_1
cf.br ^bb3
^bb3:
return
>From 1663984b655856e91bdda9e120a42b0c318c6954 Mon Sep 17 00:00:00 2001
From: Renat Idrisov <parsifal-47 at users.noreply.github.com>
Date: Sun, 24 Nov 2024 23:00:37 +0000
Subject: [PATCH 2/4] Adding a test with scf.for and iter_args
---
mlir/test/Transforms/remove-dead-values.mlir | 19 +++++++++++++++++++
1 file changed, 19 insertions(+)
diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir
index fda7ef3fe673e4..07136640732195 100644
--- a/mlir/test/Transforms/remove-dead-values.mlir
+++ b/mlir/test/Transforms/remove-dead-values.mlir
@@ -61,6 +61,25 @@ func.func @acceptable_ir_has_cleanable_simple_op_with_unconditional_branch_op(%a
// -----
+// Checking that iter_args are properly handled
+//
+func.func @cleanable_loop_iter_args_value(%arg0: index) -> index {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c10 = arith.constant 10 : index
+ %non_live = arith.constant 0 : index
+ // CHECK-NOT: non_live
+ %result, %result_non_live = scf.for %i = %c0 to %c10 step %c1 iter_args(%live_arg = %arg0, %non_live_arg = %non_live) -> (index, index) {
+ %new_live = arith.addi %live_arg, %i : index
+ // CHECK-NOT: non_live_arg
+ scf.yield %new_live, %non_live_arg : index, index
+ }
+ // CHECK-NOT: result_non_live
+ return %result : index
+}
+
+// -----
+
// Note that this cleanup cannot be done by the `canonicalize` pass.
//
// CHECK-LABEL: func.func private @clean_func_op_remove_argument_and_return_value() {
>From 4617d02432e0d8f04e859d47829eca2bdba2ecf5 Mon Sep 17 00:00:00 2001
From: Renat Idrisov <parsifal-47 at users.noreply.github.com>
Date: Mon, 25 Nov 2024 21:02:35 +0000
Subject: [PATCH 3/4] Addressing review comment
---
mlir/test/Transforms/remove-dead-values.mlir | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir
index 07136640732195..62c575aceeb4da 100644
--- a/mlir/test/Transforms/remove-dead-values.mlir
+++ b/mlir/test/Transforms/remove-dead-values.mlir
@@ -71,7 +71,7 @@ func.func @cleanable_loop_iter_args_value(%arg0: index) -> index {
// CHECK-NOT: non_live
%result, %result_non_live = scf.for %i = %c0 to %c10 step %c1 iter_args(%live_arg = %arg0, %non_live_arg = %non_live) -> (index, index) {
%new_live = arith.addi %live_arg, %i : index
- // CHECK-NOT: non_live_arg
+ // CHECK: scf.for %[[ARG_0:.*]] = %c0 to %c10 step %c1 iter_args(%[[ARG_1:.*]] = %arg0)
scf.yield %new_live, %non_live_arg : index, index
}
// CHECK-NOT: result_non_live
>From 473fc46afbc8d9542354ab34842e1ea5d5529dec Mon Sep 17 00:00:00 2001
From: Renat Idrisov <parsifal-47 at users.noreply.github.com>
Date: Mon, 25 Nov 2024 22:32:38 +0000
Subject: [PATCH 4/4] Addressing Code Review Feedback
---
mlir/lib/Transforms/RemoveDeadValues.cpp | 18 +++++++++++++++++-
mlir/test/Transforms/remove-dead-values.mlir | 5 +++--
2 files changed, 20 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index 638726e1212772..0e43263cf2fe80 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -165,6 +165,17 @@ static SmallVector<OpOperand *> operandsToOpOperands(OperandRange operands) {
return opOperands;
}
+// Return true if at least one operation in `users` implements
+// BranchOpInterface, and false otherwise (including for an empty range).
+template <typename UserRange>
+static bool anyBranchUsers(const UserRange &users) {
+  for (Operation *user : users) {
+    // Only the interface check matters here, so use `isa` rather than binding
+    // the result of a `dyn_cast` to a variable that is never read.
+    if (isa<BranchOpInterface>(user))
+      return true;
+  }
+  return false;
+}
+
/// Clean a simple op `op`, given the liveness analysis information in `la`.
/// Here, cleaning means:
/// (1) Dropping all its uses, AND
@@ -175,7 +186,8 @@ static SmallVector<OpOperand *> operandsToOpOperands(OperandRange operands) {
/// symbol op, a symbol-user op, a region branch op, a branch op, a region
/// branch terminator op, or return-like.
static void cleanSimpleOp(Operation *op, RunLivenessAnalysis &la) {
- if (!isMemoryEffectFree(op) || hasLive(op->getResults(), la))
+ if (!isMemoryEffectFree(op) || hasLive(op->getResults(), la) ||
+ anyBranchUsers(op->getUsers()))
return;
op->dropAllUses();
@@ -594,6 +606,10 @@ static void cleanBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la) {
// Do (3)
for (int argIdx = successorLiveOperands.size() - 1; argIdx >= 0; --argIdx) {
if (!successorLiveOperands[argIdx]) {
+ if (anyBranchUsers(successorBlock->getArgument(argIdx).getUsers())) {
+ continue;
+ }
+
successorOperands.erase(argIdx);
successorBlock->eraseArgument(argIdx);
}
diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir
index 62c575aceeb4da..5f7e518c40649e 100644
--- a/mlir/test/Transforms/remove-dead-values.mlir
+++ b/mlir/test/Transforms/remove-dead-values.mlir
@@ -54,8 +54,9 @@ func.func @acceptable_ir_has_cleanable_simple_op_with_unconditional_branch_op(%a
cf.br ^bb1(%non_live : i32)
^bb1(%non_live_1 : i32):
// CHECK-NOT: non_live_1
- cf.br ^bb3
-^bb3:
+ cf.br ^bb3(%non_live_1 : i32)
+ // CHECK-NOT: non_live_2
+^bb3(%non_live_2 : i32):
return
}
More information about the Mlir-commits
mailing list