[Mlir-commits] [mlir] [mlir][Transforms] Add dead code elimination pass (PR #106258)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Aug 27 10:52:09 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-core
Author: Matthias Springer (matthias-springer)
<details>
<summary>Changes</summary>
In the absence of a dedicated DCE pass, MLIR users sometimes resort to the canonicalizer pass to remove dead IR. The canonicalizer pass is quite expensive to run. This PR adds a lightweight dead code elimination pass that removes dead operations and dead blocks.
The pass performs 3 walks over the input IR.
1. Visit all operations in reverse dominance order and remove dead ops.
2. Visit all blocks in forward dominance order. This walk enumerates only reachable blocks. Collect all reachable blocks. (This walk could be combined with Step 1, but that would require an extension of the walker API.)
3. Visit all blocks and erase the ones that were not visited in Step 2.
---
Full diff: https://github.com/llvm/llvm-project/pull/106258.diff
5 Files Affected:
- (modified) mlir/include/mlir/Transforms/Passes.h (+4)
- (modified) mlir/include/mlir/Transforms/Passes.td (+14)
- (modified) mlir/lib/Transforms/CMakeLists.txt (+1)
- (added) mlir/lib/Transforms/DeadCodeElimination.cpp (+75)
- (added) mlir/test/Transforms/dead-code-elimination.mlir (+130)
``````````diff
diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index 8e4a43c3f24586..d03b879f405af7 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -33,6 +33,7 @@ class GreedyRewriteConfig;
#define GEN_PASS_DECL_CANONICALIZER
#define GEN_PASS_DECL_CONTROLFLOWSINK
#define GEN_PASS_DECL_CSEPASS
+#define GEN_PASS_DECL_DEADCODEELIMINATION
#define GEN_PASS_DECL_INLINER
#define GEN_PASS_DECL_LOOPINVARIANTCODEMOTION
#define GEN_PASS_DECL_MEM2REG
@@ -111,6 +112,9 @@ std::unique_ptr<Pass>
createInlinerPass(llvm::StringMap<OpPassManager> opPipelines,
std::function<void(OpPassManager &)> defaultPipelineBuilder);
+/// Creates an optimization pass to remove dead operations and blocks.
+std::unique_ptr<Pass> createDeadCodeEliminationPass();
+
/// Creates an optimization pass to remove dead values.
std::unique_ptr<Pass> createRemoveDeadValuesPass();
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index 000d9f697618e6..b10b1be74d0fcb 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -93,6 +93,20 @@ def CSE : Pass<"cse"> {
];
}
+def DeadCodeElimination : Pass<"dce"> {
+  let summary = "Remove dead operations and blocks";
+  let description = [{
+    This pass eliminates dead operations and blocks.
+
+    Operations are eliminated if they have no users and no side effects. Blocks
+    are eliminated if they are unreachable from their region's entry block.
+
+    Note: Operations directly inside (potential) graph regions are not erased.
+  }];
+
+  let constructor = "mlir::createDeadCodeEliminationPass()";
+}
+
def RemoveDeadValues : Pass<"remove-dead-values"> {
let summary = "Remove dead values";
let description = [{
diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
index 90c0298fb5e46a..4b90774f972ced 100644
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -5,6 +5,7 @@ add_mlir_library(MLIRTransforms
CompositePass.cpp
ControlFlowSink.cpp
CSE.cpp
+ DeadCodeElimination.cpp
GenerateRuntimeVerification.cpp
InlinerPass.cpp
LocationSnapshot.cpp
diff --git a/mlir/lib/Transforms/DeadCodeElimination.cpp b/mlir/lib/Transforms/DeadCodeElimination.cpp
new file mode 100644
index 00000000000000..33a12b84daa46f
--- /dev/null
+++ b/mlir/lib/Transforms/DeadCodeElimination.cpp
@@ -0,0 +1,75 @@
+//===- DeadCodeElimination.cpp - Dead Code Elimination --------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/Iterators.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/Passes.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_DEADCODEELIMINATION
+#include "mlir/Transforms/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+/// Pass that erases unreachable blocks and trivially dead operations.
+struct DeadCodeElimination
+    : public impl::DeadCodeEliminationBase<DeadCodeElimination> {
+  void runOnOperation() override;
+};
+} // namespace
+
+void DeadCodeElimination::runOnOperation() {
+  Operation *topLevel = getOperation();
+
+  // Step 1: Erase unreachable blocks. ForwardDominanceIterator visits only
+  // reachable blocks, so every block that it does not visit is dead. This
+  // runs before the op walk so that ops whose only users were located in
+  // unreachable blocks are recognized as dead as well.
+  // TODO: Extend walker API to provide a callback for both ops and blocks,
+  // so that fewer walks are needed.
+  DenseSet<Block *> reachableBlocks;
+  topLevel->walk<WalkOrder::PostOrder,
+                 ForwardDominanceIterator</*NoGraphRegions=*/false>>(
+      [&](Block *block) { reachableBlocks.insert(block); });
+  topLevel->walk<WalkOrder::PostOrder>([&](Block *block) {
+    if (reachableBlocks.contains(block))
+      return;
+    block->dropAllDefinedValueUses();
+    block->erase();
+  });
+
+  // Step 2: Visit operations in reverse dominance order. This visits all
+  // users before their definitions. (Also takes into account unstructured
+  // control flow between blocks.)
+  topLevel->walk<WalkOrder::PostOrder,
+                 ReverseDominanceIterator</*NoGraphRegions=*/false>>(
+      [&](Operation *op) {
+        // Do not remove the top-level op.
+        if (op == topLevel)
+          return WalkResult::advance();
+
+        // Do not remove ops from regions that may be graph regions.
+        if (mayBeGraphRegion(*op->getParentRegion()))
+          return WalkResult::advance();
+
+        // Remove dead ops. In a post-order walk, the nested regions of `op`
+        // have already been visited, so there is nothing left to skip.
+        if (isOpTriviallyDead(op))
+          op->erase();
+        return WalkResult::advance();
+      });
+}
+
+/// Creates the dead code elimination pass (see `DeadCodeElimination`).
+std::unique_ptr<Pass> mlir::createDeadCodeEliminationPass() {
+  return std::make_unique<DeadCodeElimination>();
+}
diff --git a/mlir/test/Transforms/dead-code-elimination.mlir b/mlir/test/Transforms/dead-code-elimination.mlir
new file mode 100644
index 00000000000000..67130bb3366d94
--- /dev/null
+++ b/mlir/test/Transforms/dead-code-elimination.mlir
@@ -0,0 +1,130 @@
+// RUN: mlir-opt -dce -split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: func @simple_test(
+// CHECK-SAME: %[[arg0:.*]]: i16)
+// CHECK-NEXT: %[[c5:.*]] = arith.constant 5 : i16
+// CHECK-NEXT: %[[add:.*]] = arith.addi %[[c5]], %[[arg0]]
+// CHECK-NEXT: return %[[add]]
+func.func @simple_test(%arg0: i16) -> i16 {
+ %0 = arith.constant 5 : i16
+ %1 = arith.addi %0, %arg0 : i16
+ %2 = arith.addi %1, %1 : i16
+ %3 = arith.addi %2, %1 : i16
+ return %1 : i16
+}
+
+// -----
+
+// CHECK-LABEL: func @eliminate_from_region
+// CHECK-NEXT: scf.for {{.*}} {
+// CHECK-NEXT: arith.constant
+// CHECK-NEXT: "test.print"
+// CHECK-NEXT: }
+// CHECK-NEXT: return
+func.func @eliminate_from_region(%lb: index, %ub: index, %step: index) {
+ scf.for %iv = %lb to %ub step %step {
+ %0 = arith.constant 5 : i16
+ %1 = arith.constant 10 : i16
+ "test.print"(%0) : (i16) -> ()
+ }
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @eliminate_op_with_region
+// CHECK-NEXT: return
+func.func @eliminate_op_with_region(%lb: index, %ub: index, %step: index) {
+ %c0 = arith.constant 0 : i16
+ %0 = scf.for %iv = %lb to %ub step %step iter_args(%iter = %c0) -> i16 {
+ %0 = arith.constant 5 : i16
+ %added = arith.addi %iter, %0 : i16
+ scf.yield %added : i16
+ }
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @unstructured_control_flow(
+// CHECK-SAME: %[[arg0:.*]]: i16)
+// CHECK-NEXT: %[[c5:.*]] = arith.constant 5 : i16
+// CHECK-NEXT: cf.br ^[[bb2:.*]]
+// CHECK-NEXT: ^[[bb1:.*]]: // pred
+// CHECK-NEXT: cf.br ^[[bb3:.*]]
+// CHECK-NEXT: ^[[bb2]]:
+// CHECK-NEXT: %[[add:.*]] = arith.addi %[[c5]], %[[arg0]]
+// CHECK-NEXT: cf.br ^[[bb1]]
+// CHECK-NEXT: ^[[bb3]]:
+// CHECK-NEXT: return %[[add]]
+func.func @unstructured_control_flow(%arg0: i16) -> i16 {
+ %0 = arith.constant 5 : i16
+ cf.br ^bb2
+^bb1:
+ %3 = arith.addi %1, %1 : i16
+ %4 = arith.addi %3, %2 : i16
+ cf.br ^bb3
+^bb2:
+ %1 = arith.addi %0, %arg0 : i16
+ %2 = arith.subi %0, %arg0 : i16
+ cf.br ^bb1
+^bb3:
+ return %1 : i16
+}
+
+// -----
+
+// CHECK-LABEL: func @remove_dead_block()
+// CHECK-NEXT: cf.br ^[[bb2:.*]]
+// CHECK-NEXT: ^[[bb2]]:
+// CHECK-NEXT: return
+func.func @remove_dead_block() {
+ cf.br ^bb2
+^bb1:
+ %0 = arith.constant 0 : i16
+ cf.br ^bb2
+^bb2:
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @potentially_side_effecting_op()
+// CHECK-NEXT: "test.print"
+// CHECK-NEXT: return
+func.func @potentially_side_effecting_op() {
+ "test.print"() : () -> ()
+ return
+}
+
+// -----
+
+// Note: Graph regions are not supported and skipped.
+
+// CHECK-LABEL: test.graph_region {
+// CHECK-NEXT: arith.addi
+// CHECK-NEXT: arith.constant 5 : i16
+// CHECK-NEXT: "test.baz"
+// CHECK-NEXT: }
+test.graph_region {
+ %1 = arith.addi %0, %0 : i16
+ %0 = arith.constant 5 : i16
+ "test.baz"() : () -> i32
+}
+
+// -----
+
+// CHECK-LABEL: dead_blocks()
+// CHECK-NEXT: cf.br ^[[bb3:.*]]
+// CHECK-NEXT: ^[[bb3]]:
+// CHECK-NEXT: return
+func.func @dead_blocks() {
+ cf.br ^bb3
+^bb1:
+ "test.print"() : () -> ()
+ cf.br ^bb2
+^bb2:
+ cf.br ^bb1
+^bb3:
+ return
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/106258
More information about the Mlir-commits
mailing list