[Mlir-commits] [mlir] [mlir][Transforms] Add dead code elimination pass (PR #106258)

Matthias Springer llvmlistbot at llvm.org
Tue Aug 27 10:51:40 PDT 2024


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/106258

In the absence of a dedicated DCE pass, MLIR users sometimes resort to the canonicalizer pass to remove dead IR. The canonicalizer pass is quite expensive to run. This PR adds a lightweight dead code elimination pass that removes dead operations and dead blocks.

The pass performs 3 walks over the input IR.
1. Visit all operations in reverse dominance order and remove dead ops.
2. Visit all blocks in forward dominance order. This walk enumerates only reachable blocks. Collect all reachable blocks. (This walk could be combined with Step 1, but that would require an extension of the walker API.)
3. Visit all blocks and erase the ones that were not visited in Step 2.


>From f84ba23986f5c5dd84259ec1a44f604b7ad02608 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Tue, 27 Aug 2024 19:29:54 +0200
Subject: [PATCH] [mlir][Transforms] Add dead code elimination pass

---
 mlir/include/mlir/Transforms/Passes.h         |   4 +
 mlir/include/mlir/Transforms/Passes.td        |  14 ++
 mlir/lib/Transforms/CMakeLists.txt            |   1 +
 mlir/lib/Transforms/DeadCodeElimination.cpp   |  75 ++++++++++
 .../Transforms/dead-code-elimination.mlir     | 130 ++++++++++++++++++
 5 files changed, 224 insertions(+)
 create mode 100644 mlir/lib/Transforms/DeadCodeElimination.cpp
 create mode 100644 mlir/test/Transforms/dead-code-elimination.mlir

diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index 8e4a43c3f24586..d03b879f405af7 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -33,6 +33,7 @@ class GreedyRewriteConfig;
 #define GEN_PASS_DECL_CANONICALIZER
 #define GEN_PASS_DECL_CONTROLFLOWSINK
 #define GEN_PASS_DECL_CSEPASS
+#define GEN_PASS_DECL_DEADCODEELIMINATION
 #define GEN_PASS_DECL_INLINER
 #define GEN_PASS_DECL_LOOPINVARIANTCODEMOTION
 #define GEN_PASS_DECL_MEM2REG
@@ -111,6 +112,9 @@ std::unique_ptr<Pass>
 createInlinerPass(llvm::StringMap<OpPassManager> opPipelines,
                   std::function<void(OpPassManager &)> defaultPipelineBuilder);
 
+/// Creates a pass that erases trivially dead ops and unreachable blocks.
+std::unique_ptr<Pass> createDeadCodeEliminationPass();
+
 /// Creates an optimization pass to remove dead values.
 std::unique_ptr<Pass> createRemoveDeadValuesPass();
 
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index 000d9f697618e6..b10b1be74d0fcb 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -93,6 +93,20 @@ def CSE : Pass<"cse"> {
   ];
 }
 
+def DeadCodeElimination : Pass<"dce"> {
+  let summary = "Remove dead operations and blocks";
+  let description = [{
+    This pass eliminates trivially dead operations and unreachable blocks.
+
+    Operations are eliminated if they have no users and no side effects. Blocks
+    are eliminated if they are unreachable from the entry block of their region.
+
+    Note: Graph regions are currently not supported and skipped by this pass.
+  }];
+
+  let constructor = "mlir::createDeadCodeEliminationPass()";
+}
+
 def RemoveDeadValues : Pass<"remove-dead-values"> {
   let summary = "Remove dead values";
   let description = [{
diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
index 90c0298fb5e46a..4b90774f972ced 100644
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -5,6 +5,7 @@ add_mlir_library(MLIRTransforms
   CompositePass.cpp
   ControlFlowSink.cpp
   CSE.cpp
+  DeadCodeElimination.cpp
   GenerateRuntimeVerification.cpp
   InlinerPass.cpp
   LocationSnapshot.cpp
diff --git a/mlir/lib/Transforms/DeadCodeElimination.cpp b/mlir/lib/Transforms/DeadCodeElimination.cpp
new file mode 100644
index 00000000000000..33a12b84daa46f
--- /dev/null
+++ b/mlir/lib/Transforms/DeadCodeElimination.cpp
@@ -0,0 +1,109 @@
+//===- DeadCodeElimination.cpp - Dead Code Elimination --------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a lightweight dead code elimination pass. It erases
+// unreachable blocks and operations that are trivially dead (no users, no
+// side effects). Regions that may be graph regions are not modified.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/Iterators.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/RegionKindInterface.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/Passes.h"
+#include "llvm/ADT/DenseSet.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_DEADCODEELIMINATION
+#include "mlir/Transforms/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+/// Pass that erases unreachable blocks and trivially dead operations.
+struct DeadCodeElimination
+    : public impl::DeadCodeEliminationBase<DeadCodeElimination> {
+  void runOnOperation() override;
+};
+} // namespace
+
+/// Erases all blocks nested under `topLevel` that are not reachable from the
+/// entry block of their region. Blocks of regions that may be graph regions
+/// are kept.
+static void eraseUnreachableBlocks(Operation *topLevel) {
+  // ForwardDominanceIterator enumerates only reachable blocks: blocks that
+  // are reachable from the entry block of their region and whose enclosing
+  // blocks are reachable as well.
+  DenseSet<Block *> reachableBlocks;
+  topLevel->walk<WalkOrder::PostOrder,
+                 ForwardDominanceIterator</*NoGraphRegions=*/false>>(
+      [&](Block *block) { reachableBlocks.insert(block); });
+
+  // Unreachable blocks are not visited by the walk above, so a second walk
+  // over all blocks is needed to find them.
+  topLevel->walk<WalkOrder::PostOrder>([&](Block *block) {
+    if (reachableBlocks.contains(block))
+      return;
+    // This pass does not support graph regions: keep their blocks.
+    if (mayBeGraphRegion(*block->getParent()))
+      return;
+    // Blocks nested in an unreachable block were not enumerated above even if
+    // they are reachable within their own region. They are handled together
+    // with the enclosing block, which is visited later in this post-order
+    // walk (and is either erased along with them or kept).
+    Operation *parentOp = block->getParentOp();
+    if (parentOp != topLevel &&
+        !reachableBlocks.contains(parentOp->getBlock()))
+      return;
+    // Other unreachable blocks may still branch to this block or use values
+    // defined in it. Drop those uses before erasing the block.
+    block->dropAllDefinedValueUses();
+    block->erase();
+  });
+}
+
+/// Erases trivially dead operations nested under `topLevel`, but never
+/// `topLevel` itself. Operations in regions that may be graph regions are
+/// kept.
+static void eraseDeadOps(Operation *topLevel) {
+  // Visit operations in reverse dominance order, so that all users (also
+  // across blocks connected by unstructured control flow) are visited, and
+  // possibly erased, before their definitions. With post-order, nested ops
+  // are visited before the op that owns their region.
+  topLevel->walk<WalkOrder::PostOrder,
+                 ReverseDominanceIterator</*NoGraphRegions=*/false>>(
+      [&](Operation *op) {
+        // Never erase the top-level op; it may also have no parent region.
+        if (op == topLevel)
+          return WalkResult::advance();
+        // Do not remove ops from regions that may be graph regions.
+        if (mayBeGraphRegion(*op->getParentRegion()))
+          return WalkResult::advance();
+        if (isOpTriviallyDead(op)) {
+          op->erase();
+          return WalkResult::skip();
+        }
+        return WalkResult::advance();
+      });
+}
+
+void DeadCodeElimination::runOnOperation() {
+  Operation *topLevel = getOperation();
+  // Erase unreachable blocks first. Otherwise, an op in a reachable block
+  // whose only users live in unreachable blocks would be kept alive by
+  // those users, which would only be erased afterwards.
+  eraseUnreachableBlocks(topLevel);
+  eraseDeadOps(topLevel);
+}
+
+std::unique_ptr<Pass> mlir::createDeadCodeEliminationPass() {
+  return std::make_unique<DeadCodeElimination>();
+}
diff --git a/mlir/test/Transforms/dead-code-elimination.mlir b/mlir/test/Transforms/dead-code-elimination.mlir
new file mode 100644
index 00000000000000..67130bb3366d94
--- /dev/null
+++ b/mlir/test/Transforms/dead-code-elimination.mlir
@@ -0,0 +1,130 @@
+// RUN: mlir-opt -dce -split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: func @simple_test(
+//  CHECK-SAME:     %[[arg0:.*]]: i16)
+//  CHECK-NEXT:   %[[c5:.*]] = arith.constant 5 : i16
+//  CHECK-NEXT:   %[[add:.*]] = arith.addi %[[c5]], %[[arg0]]
+//  CHECK-NEXT:   return %[[add]]
+func.func @simple_test(%arg0: i16) -> i16 {
+  %0 = arith.constant 5 : i16
+  %1 = arith.addi %0, %arg0 : i16
+  %2 = arith.addi %1, %1 : i16
+  %3 = arith.addi %2, %1 : i16
+  return %1 : i16
+}
+
+// -----
+
+// CHECK-LABEL: func @eliminate_from_region
+//  CHECK-NEXT:   scf.for {{.*}} {
+//  CHECK-NEXT:     arith.constant
+//  CHECK-NEXT:     "test.print"
+//  CHECK-NEXT:   }
+//  CHECK-NEXT:   return
+func.func @eliminate_from_region(%lb: index, %ub: index, %step: index) {
+  scf.for %iv = %lb to %ub step %step {
+    %0 = arith.constant 5 : i16
+    %1 = arith.constant 10 : i16
+    "test.print"(%0) : (i16) -> ()
+  }
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @eliminate_op_with_region
+//  CHECK-NEXT:   return
+func.func @eliminate_op_with_region(%lb: index, %ub: index, %step: index) {
+  %c0 = arith.constant 0 : i16
+  %0 = scf.for %iv = %lb to %ub step %step iter_args(%iter = %c0) -> i16 {
+    %0 = arith.constant 5 : i16
+    %added = arith.addi %iter, %0 : i16
+    scf.yield %added : i16
+  }
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @unstructured_control_flow(
+//  CHECK-SAME:     %[[arg0:.*]]: i16)
+//  CHECK-NEXT:   %[[c5:.*]] = arith.constant 5 : i16
+//  CHECK-NEXT:   cf.br ^[[bb2:.*]]
+//  CHECK-NEXT: ^[[bb1:.*]]:  // pred
+//  CHECK-NEXT:   cf.br ^[[bb3:.*]]
+//  CHECK-NEXT: ^[[bb2]]:
+//  CHECK-NEXT:   %[[add:.*]] = arith.addi %[[c5]], %[[arg0]]
+//  CHECK-NEXT:   cf.br ^[[bb1]]
+//  CHECK-NEXT: ^[[bb3]]:
+//  CHECK-NEXT:   return %[[add]]
+func.func @unstructured_control_flow(%arg0: i16) -> i16 {
+  %0 = arith.constant 5 : i16
+  cf.br ^bb2
+^bb1:
+  %3 = arith.addi %1, %1 : i16
+  %4 = arith.addi %3, %2 : i16
+  cf.br ^bb3
+^bb2:
+  %1 = arith.addi %0, %arg0 : i16
+  %2 = arith.subi %0, %arg0 : i16
+  cf.br ^bb1
+^bb3:
+  return %1 : i16
+}
+
+// -----
+
+// CHECK-LABEL: func @remove_dead_block()
+//  CHECK-NEXT:   cf.br ^[[bb2:.*]]
+//  CHECK-NEXT: ^[[bb2]]:
+//  CHECK-NEXT:   return
+func.func @remove_dead_block() {
+  cf.br ^bb2
+^bb1:
+  %0 = arith.constant 0 : i16
+  cf.br ^bb2
+^bb2:
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @potentially_side_effecting_op()
+//  CHECK-NEXT:   "test.print"
+//  CHECK-NEXT:   return
+func.func @potentially_side_effecting_op() {
+  "test.print"() : () -> ()
+  return
+}
+
+// -----
+
+// Note: Graph regions are not supported and skipped.
+
+// CHECK-LABEL: test.graph_region {
+//  CHECK-NEXT:   arith.addi
+//  CHECK-NEXT:   arith.constant 5 : i16
+//  CHECK-NEXT:   "test.baz"
+//  CHECK-NEXT: }
+test.graph_region {
+  %1 = arith.addi %0, %0 : i16
+  %0 = arith.constant 5 : i16
+  "test.baz"() : () -> i32
+}
+
+// -----
+
+// CHECK-LABEL: func @dead_blocks()
+//  CHECK-NEXT:   cf.br ^[[bb3:.*]]
+//  CHECK-NEXT: ^[[bb3]]:
+//  CHECK-NEXT:   return
+func.func @dead_blocks() {
+  cf.br ^bb3
+^bb1:
+  "test.print"() : () -> ()
+  cf.br ^bb2
+^bb2:
+  cf.br ^bb1
+^bb3:
+  return
+}



More information about the Mlir-commits mailing list