[Mlir-commits] [mlir] [mlir][Transforms] Add dead code elimination pass (PR #106258)
Matthias Springer
llvmlistbot at llvm.org
Wed Aug 28 11:35:07 PDT 2024
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/106258
>From b6d5e8f35687a320500ca499f35076db4137dddf Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Tue, 27 Aug 2024 19:29:54 +0200
Subject: [PATCH] [mlir][Transforms] Add dead code elimination pass
---
mlir/include/mlir/Transforms/Passes.h | 4 +
mlir/include/mlir/Transforms/Passes.td | 14 ++
mlir/lib/Transforms/CMakeLists.txt | 1 +
mlir/lib/Transforms/DeadCodeElimination.cpp | 63 +++++++++
.../Transforms/dead-code-elimination.mlir | 130 ++++++++++++++++++
5 files changed, 212 insertions(+)
create mode 100644 mlir/lib/Transforms/DeadCodeElimination.cpp
create mode 100644 mlir/test/Transforms/dead-code-elimination.mlir
diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index 8e4a43c3f24586..d03b879f405af7 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -33,6 +33,7 @@ class GreedyRewriteConfig;
#define GEN_PASS_DECL_CANONICALIZER
#define GEN_PASS_DECL_CONTROLFLOWSINK
#define GEN_PASS_DECL_CSEPASS
+#define GEN_PASS_DECL_DEADCODEELIMINATION
#define GEN_PASS_DECL_INLINER
#define GEN_PASS_DECL_LOOPINVARIANTCODEMOTION
#define GEN_PASS_DECL_MEM2REG
@@ -111,6 +112,9 @@ std::unique_ptr<Pass>
createInlinerPass(llvm::StringMap<OpPassManager> opPipelines,
std::function<void(OpPassManager &)> defaultPipelineBuilder);
+/// Returns a new pass that erases trivially dead ops and unreachable blocks.
+std::unique_ptr<Pass> createDeadCodeEliminationPass();
+
/// Creates an optimization pass to remove dead values.
std::unique_ptr<Pass> createRemoveDeadValuesPass();
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index 000d9f697618e6..b10b1be74d0fcb 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -93,6 +93,20 @@ def CSE : Pass<"cse"> {
];
}
+def DeadCodeElimination : Pass<"dce"> {
+  let summary = "Remove dead operations and blocks";
+  let description = [{
+    This pass erases operations that have no users and no side effects, as
+    well as blocks that are not reachable from the entry block of their region.
+
+    Note: Graph regions are currently not supported. Dead operations directly
+    inside a region that may be a graph region are kept, but non-graph regions
+    nested inside such operations are still processed.
+  }];
+
+  let constructor = "mlir::createDeadCodeEliminationPass()";
+}
+
def RemoveDeadValues : Pass<"remove-dead-values"> {
let summary = "Remove dead values";
let description = [{
diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
index 90c0298fb5e46a..4b90774f972ced 100644
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -5,6 +5,7 @@ add_mlir_library(MLIRTransforms
CompositePass.cpp
ControlFlowSink.cpp
CSE.cpp
+ DeadCodeElimination.cpp
GenerateRuntimeVerification.cpp
InlinerPass.cpp
LocationSnapshot.cpp
diff --git a/mlir/lib/Transforms/DeadCodeElimination.cpp b/mlir/lib/Transforms/DeadCodeElimination.cpp
new file mode 100644
index 00000000000000..417fb110dfbfe9
--- /dev/null
+++ b/mlir/lib/Transforms/DeadCodeElimination.cpp
@@ -0,0 +1,63 @@
+//===- DeadCodeElimination.cpp - Dead Code Elimination --------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/Iterators.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/Passes.h"
+#include "mlir/Transforms/RegionUtils.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_DEADCODEELIMINATION
+#include "mlir/Transforms/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+/// Pass that erases trivially dead operations and unreachable blocks nested
+/// under the op it runs on; dead ops in (potential) graph regions are kept.
+struct DeadCodeElimination
+    : public impl::DeadCodeEliminationBase<DeadCodeElimination> {
+  void runOnOperation() override;
+};
+} // namespace
+
+void DeadCodeElimination::runOnOperation() {
+  Operation *topLevel = getOperation();
+
+  // Erase unreachable blocks first: the reverse dominance walk below does not
+  // enumerate them, and their ops may be the only users of values defined in
+  // reachable blocks. Erasing them last would keep such values alive.
+  IRRewriter rewriter(topLevel->getContext());
+  (void)eraseUnreachableBlocks(rewriter, topLevel->getRegions());
+
+  // Visit operations in reverse dominance order, so that all users of a value
+  // are visited before its definition (also across blocks with unstructured
+  // control flow). Erasing a user can thus expose its operands' defs as dead.
+  topLevel->walk<WalkOrder::PostOrder,
+                 ReverseDominanceIterator</*NoGraphRegions=*/false>>(
+      [&](Operation *op) {
+        // Keep the top-level op and ops in regions that may be graph regions
+        // (graph regions are not supported by this pass).
+        if (op == topLevel || mayBeGraphRegion(*op->getParentRegion()))
+          return WalkResult::advance();
+
+        // Erase ops without users and side effects. In a post-order walk the
+        // nested regions were already visited, so there is nothing to `skip`.
+        if (isOpTriviallyDead(op))
+          op->erase();
+        return WalkResult::advance();
+      });
+}
+
+/// Factory referenced by the `constructor` of the `DeadCodeElimination` pass.
+std::unique_ptr<Pass> mlir::createDeadCodeEliminationPass() {
+  return std::make_unique<DeadCodeElimination>();
+}
diff --git a/mlir/test/Transforms/dead-code-elimination.mlir b/mlir/test/Transforms/dead-code-elimination.mlir
new file mode 100644
index 00000000000000..35bed2ae3ba11d
--- /dev/null
+++ b/mlir/test/Transforms/dead-code-elimination.mlir
@@ -0,0 +1,130 @@
+// RUN: mlir-opt -dce -split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: func @simple_test(
+// CHECK-SAME: %[[arg0:.*]]: i16)
+// CHECK-NEXT: %[[c5:.*]] = arith.constant 5 : i16
+// CHECK-NEXT: %[[add:.*]] = arith.addi %[[c5]], %[[arg0]]
+// CHECK-NEXT: return %[[add]]
+func.func @simple_test(%arg0: i16) -> i16 {
+  %0 = arith.constant 5 : i16
+  %1 = arith.addi %0, %arg0 : i16
+  %2 = arith.addi %1, %1 : i16
+  %3 = arith.addi %2, %1 : i16
+  return %1 : i16
+}
+
+// -----
+
+// CHECK-LABEL: func @eliminate_from_region
+// CHECK-NEXT: scf.for {{.*}} {
+// CHECK-NEXT: arith.constant
+// CHECK-NEXT: "test.print"
+// CHECK-NEXT: }
+// CHECK-NEXT: return
+func.func @eliminate_from_region(%lb: index, %ub: index, %step: index) {
+  scf.for %iv = %lb to %ub step %step {
+    %0 = arith.constant 5 : i16
+    %1 = arith.constant 10 : i16
+    "test.print"(%0) : (i16) -> ()
+  }
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @eliminate_op_with_region
+// CHECK-NEXT: return
+func.func @eliminate_op_with_region(%lb: index, %ub: index, %step: index) {
+  %c0 = arith.constant 0 : i16
+  %0 = scf.for %iv = %lb to %ub step %step iter_args(%iter = %c0) -> i16 {
+    %c5 = arith.constant 5 : i16
+    %added = arith.addi %iter, %c5 : i16
+    scf.yield %added : i16
+  }
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @unstructured_control_flow(
+// CHECK-SAME: %[[arg0:.*]]: i16)
+// CHECK-NEXT: %[[c5:.*]] = arith.constant 5 : i16
+// CHECK-NEXT: cf.br ^[[bb2:.*]]
+// CHECK-NEXT: ^[[bb1:.*]]: // pred
+// CHECK-NEXT: cf.br ^[[bb3:.*]]
+// CHECK-NEXT: ^[[bb2]]:
+// CHECK-NEXT: %[[add:.*]] = arith.addi %[[c5]], %[[arg0]]
+// CHECK-NEXT: cf.br ^[[bb1]]
+// CHECK-NEXT: ^[[bb3]]:
+// CHECK-NEXT: return %[[add]]
+func.func @unstructured_control_flow(%arg0: i16) -> i16 {
+  %0 = arith.constant 5 : i16
+  cf.br ^bb2
+^bb1:
+  %3 = arith.addi %1, %1 : i16
+  %4 = arith.addi %3, %2 : i16
+  cf.br ^bb3
+^bb2:
+  %1 = arith.addi %0, %arg0 : i16
+  %2 = arith.subi %0, %arg0 : i16
+  cf.br ^bb1
+^bb3:
+  return %1 : i16
+}
+
+// -----
+
+// CHECK-LABEL: func @remove_dead_block()
+// CHECK-NEXT: cf.br ^[[bb2:.*]]
+// CHECK-NEXT: ^[[bb2]]:
+// CHECK-NEXT: return
+func.func @remove_dead_block() {
+  cf.br ^bb2
+^bb1:
+  %0 = arith.constant 0 : i16
+  cf.br ^bb2
+^bb2:
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @potentially_side_effecting_op()
+// CHECK-NEXT: "test.print"
+// CHECK-NEXT: return
+func.func @potentially_side_effecting_op() {
+  "test.print"() : () -> ()
+  return
+}
+
+// -----
+
+// Note: Graph regions are not supported; dead ops in them are not erased.
+
+// CHECK-LABEL: test.graph_region {
+// CHECK-NEXT: arith.addi
+// CHECK-NEXT: arith.constant 5 : i16
+// CHECK-NEXT: "test.baz"
+// CHECK-NEXT: }
+test.graph_region {
+  %1 = arith.addi %0, %0 : i16
+  %0 = arith.constant 5 : i16
+  "test.baz"() : () -> i32
+}
+
+// -----
+
+// CHECK-LABEL: func @dead_blocks()
+// CHECK-NEXT: cf.br ^[[bb3:.*]]
+// CHECK-NEXT: ^[[bb3]]:
+// CHECK-NEXT: return
+func.func @dead_blocks() {
+  cf.br ^bb3
+^bb1:
+  "test.print"() : () -> ()
+  cf.br ^bb2
+^bb2:
+  cf.br ^bb1
+^bb3:
+  return
+}
More information about the Mlir-commits
mailing list