[Mlir-commits] [mlir] [mlir][Transforms] Add dead code elimination pass (PR #106258)

Matthias Springer llvmlistbot at llvm.org
Wed Aug 28 11:35:07 PDT 2024


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/106258

>From b6d5e8f35687a320500ca499f35076db4137dddf Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Tue, 27 Aug 2024 19:29:54 +0200
Subject: [PATCH] [mlir][Transforms] Add dead code elimination pass

---
 mlir/include/mlir/Transforms/Passes.h         |   4 +
 mlir/include/mlir/Transforms/Passes.td        |  14 ++
 mlir/lib/Transforms/CMakeLists.txt            |   1 +
 mlir/lib/Transforms/DeadCodeElimination.cpp   |  63 +++++++++
 .../Transforms/dead-code-elimination.mlir     | 130 ++++++++++++++++++
 5 files changed, 212 insertions(+)
 create mode 100644 mlir/lib/Transforms/DeadCodeElimination.cpp
 create mode 100644 mlir/test/Transforms/dead-code-elimination.mlir

diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index 8e4a43c3f24586..d03b879f405af7 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -33,6 +33,7 @@ class GreedyRewriteConfig;
 #define GEN_PASS_DECL_CANONICALIZER
 #define GEN_PASS_DECL_CONTROLFLOWSINK
 #define GEN_PASS_DECL_CSEPASS
+#define GEN_PASS_DECL_DEADCODEELIMINATION
 #define GEN_PASS_DECL_INLINER
 #define GEN_PASS_DECL_LOOPINVARIANTCODEMOTION
 #define GEN_PASS_DECL_MEM2REG
@@ -111,6 +112,9 @@ std::unique_ptr<Pass>
 createInlinerPass(llvm::StringMap<OpPassManager> opPipelines,
                   std::function<void(OpPassManager &)> defaultPipelineBuilder);
 
+/// Creates an optimization pass to remove dead operations and blocks.
+std::unique_ptr<Pass> createDeadCodeEliminationPass();
+
 /// Creates an optimization pass to remove dead values.
 std::unique_ptr<Pass> createRemoveDeadValuesPass();
 
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index 000d9f697618e6..b10b1be74d0fcb 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -93,6 +93,20 @@ def CSE : Pass<"cse"> {
   ];
 }
 
+def DeadCodeElimination : Pass<"dce"> {
+  let summary = "Remove dead operations and blocks";
+  let description = [{
+    This pass erases trivially dead operations and unreachable blocks. An
+    operation is erased if it has no uses and no side effects; for operations
+    with regions this includes the side effects of the nested operations. A
+    block is erased if it is not reachable from the entry block of its region.
+
+    Note: Operations directly nested in a graph region are never erased, but
+    regions with SSA dominance nested inside such operations are processed.
+  }];
+  let constructor = "mlir::createDeadCodeEliminationPass()";
+}
+
 def RemoveDeadValues : Pass<"remove-dead-values"> {
   let summary = "Remove dead values";
   let description = [{
diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
index 90c0298fb5e46a..4b90774f972ced 100644
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -5,6 +5,7 @@ add_mlir_library(MLIRTransforms
   CompositePass.cpp
   ControlFlowSink.cpp
   CSE.cpp
+  DeadCodeElimination.cpp
   GenerateRuntimeVerification.cpp
   InlinerPass.cpp
   LocationSnapshot.cpp
diff --git a/mlir/lib/Transforms/DeadCodeElimination.cpp b/mlir/lib/Transforms/DeadCodeElimination.cpp
new file mode 100644
index 00000000000000..417fb110dfbfe9
--- /dev/null
+++ b/mlir/lib/Transforms/DeadCodeElimination.cpp
@@ -0,0 +1,63 @@
+//===- DeadCodeElimination.cpp - Dead Code Elimination --------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/Iterators.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/Passes.h"
+#include "mlir/Transforms/RegionUtils.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_DEADCODEELIMINATION
+#include "mlir/Transforms/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+struct DeadCodeElimination final // Driver of the `-dce` pass.
+    : impl::DeadCodeEliminationBase<DeadCodeElimination> {
+  void runOnOperation() override;
+};
+} // namespace
+
+void DeadCodeElimination::runOnOperation() {
+  Operation *topLevel = getOperation();
+  bool changed = false;
+  // Visit ops in reverse dominance order, also across unstructured control
+  // flow between blocks, so that users are visited before their definitions.
+  // Erasing a dead user can thereby expose dead producers that are erased in
+  // the same walk.
+  topLevel->walk<WalkOrder::PostOrder,
+                 ReverseDominanceIterator</*NoGraphRegions=*/false>>(
+      [&](Operation *op) {
+        // Never erase the anchor op; checked first as it may have no parent.
+        if (op == topLevel)
+          return WalkResult::advance();
+        // Keep ops whose parent region may be a graph region (unsupported).
+        if (mayBeGraphRegion(*op->getParentRegion()))
+          return WalkResult::advance();
+        if (!isOpTriviallyDead(op))
+          return WalkResult::advance();
+        // Post-order: any ops nested in `op` were already visited.
+        op->erase();
+        changed = true;
+        return WalkResult::skip();
+      });
+  // Erase all blocks that are unreachable from their region's entry block.
+  IRRewriter rewriter(topLevel->getContext());
+  if (succeeded(eraseUnreachableBlocks(rewriter, topLevel->getRegions())))
+    changed = true;
+  if (!changed)
+    markAllAnalysesPreserved(); // IR untouched: keep cached analyses valid.
+}
+
+std::unique_ptr<Pass> mlir::createDeadCodeEliminationPass() { // `-dce` pass.
+  return std::unique_ptr<Pass>(std::make_unique<DeadCodeElimination>());
+}
diff --git a/mlir/test/Transforms/dead-code-elimination.mlir b/mlir/test/Transforms/dead-code-elimination.mlir
new file mode 100644
index 00000000000000..35bed2ae3ba11d
--- /dev/null
+++ b/mlir/test/Transforms/dead-code-elimination.mlir
@@ -0,0 +1,130 @@
+// RUN: mlir-opt -dce -split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: func @simple_test(
+//  CHECK-SAME:     %[[arg0:.*]]: i16)
+//  CHECK-NEXT:   %[[c5:.*]] = arith.constant 5 : i16
+//  CHECK-NEXT:   %[[add:.*]] = arith.addi %[[c5]], %[[arg0]]
+//  CHECK-NEXT:   return %[[add]]
+func.func @simple_test(%arg0: i16) -> i16 {
+  %0 = arith.constant 5 : i16
+  %1 = arith.addi %0, %arg0 : i16
+  %2 = arith.addi %1, %1 : i16
+  %3 = arith.addi %2, %1 : i16
+  return %1 : i16
+}
+
+// -----
+
+// CHECK-LABEL: func @eliminate_from_region
+//  CHECK-NEXT:   scf.for {{.*}} {
+//  CHECK-NEXT:     arith.constant
+//  CHECK-NEXT:     "test.print"
+//  CHECK-NEXT:   }
+//  CHECK-NEXT:   return
+func.func @eliminate_from_region(%lb: index, %ub: index, %step: index) {
+  scf.for %iv = %lb to %ub step %step {
+    %0 = arith.constant 5 : i16
+    %1 = arith.constant 10 : i16
+    "test.print"(%0) : (i16) -> ()
+  }
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @eliminate_op_with_region
+//  CHECK-NEXT:   return
+func.func @eliminate_op_with_region(%lb: index, %ub: index, %step: index) {
+  %c0 = arith.constant 0 : i16
+  %0 = scf.for %iv = %lb to %ub step %step iter_args(%iter = %c0) -> i16 {
+    %c5 = arith.constant 5 : i16
+    %added = arith.addi %iter, %c5 : i16
+    scf.yield %added : i16
+  }
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @unstructured_control_flow(
+//  CHECK-SAME:     %[[arg0:.*]]: i16)
+//  CHECK-NEXT:   %[[c5:.*]] = arith.constant 5 : i16
+//  CHECK-NEXT:   cf.br ^[[bb2:.*]]
+//  CHECK-NEXT: ^[[bb1:.*]]:  // pred
+//  CHECK-NEXT:   cf.br ^[[bb3:.*]]
+//  CHECK-NEXT: ^[[bb2]]:
+//  CHECK-NEXT:   %[[add:.*]] = arith.addi %[[c5]], %[[arg0]]
+//  CHECK-NEXT:   cf.br ^[[bb1]]
+//  CHECK-NEXT: ^[[bb3]]:
+//  CHECK-NEXT:   return %[[add]]
+func.func @unstructured_control_flow(%arg0: i16) -> i16 {
+  %0 = arith.constant 5 : i16
+  cf.br ^bb2
+^bb1:
+  %3 = arith.addi %1, %1 : i16
+  %4 = arith.addi %3, %2 : i16
+  cf.br ^bb3
+^bb2:
+  %1 = arith.addi %0, %arg0 : i16
+  %2 = arith.subi %0, %arg0 : i16
+  cf.br ^bb1
+^bb3:
+  return %1 : i16
+}
+
+// -----
+
+// CHECK-LABEL: func @remove_dead_block()
+//  CHECK-NEXT:   cf.br ^[[bb2:.*]]
+//  CHECK-NEXT: ^[[bb2]]:
+//  CHECK-NEXT:   return
+func.func @remove_dead_block() {
+  cf.br ^bb2
+^bb1:
+  %0 = arith.constant 0 : i16
+  cf.br ^bb2
+^bb2:
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @potentially_side_effecting_op()
+//  CHECK-NEXT:   "test.print"
+//  CHECK-NEXT:   return
+func.func @potentially_side_effecting_op() {
+  "test.print"() : () -> ()
+  return
+}
+
+// -----
+
+// Note: Ops directly nested in a graph region are never erased, even if dead.
+
+// CHECK-LABEL: test.graph_region {
+//  CHECK-NEXT:   arith.addi
+//  CHECK-NEXT:   arith.constant 5 : i16
+//  CHECK-NEXT:   "test.baz"
+//  CHECK-NEXT: }
+test.graph_region {
+  %1 = arith.addi %0, %0 : i16
+  %0 = arith.constant 5 : i16
+  "test.baz"() : () -> i32
+}
+
+// -----
+
+// CHECK-LABEL: dead_blocks()
+//  CHECK-NEXT:   cf.br ^[[bb3:.*]]
+//  CHECK-NEXT: ^[[bb3]]:
+//  CHECK-NEXT:   return
+func.func @dead_blocks() {
+  cf.br ^bb3
+^bb1:
+  "test.print"() : () -> ()
+  cf.br ^bb2
+^bb2:
+  cf.br ^bb1
+^bb3:
+  return
+}



More information about the Mlir-commits mailing list