[llvm] [mlir] [mlir][transforms]-Extract dead code elimination util (PR #124606)

Amir Bishara via llvm-commits llvm-commits at lists.llvm.org
Mon Jan 27 10:34:43 PST 2025


https://github.com/amirBish created https://github.com/llvm/llvm-project/pull/124606

Extract the dead code elimination logic into an exposed utility file so that it can easily be used through the C++ API.

>From d2a30d9e5c12e97087a941f7d558a03dc9d87a70 Mon Sep 17 00:00:00 2001
From: Amir Bishara <amir.bishara at mobileye.com>
Date: Mon, 27 Jan 2025 20:11:33 +0200
Subject: [PATCH] [mlir][transforms]-Extract dead code elimination util

Extract the dead code elimination logic into an exposed utility
file so that it can easily be used through the C++ API.

Change-Id: Ic013ba06b778f8d5221bea226131e5343b543f51
---
 mlir/include/mlir/Transforms/DCEUtils.h       | 26 ++++++++
 .../lib/Dialect/Transform/IR/TransformOps.cpp | 44 +------------
 mlir/lib/Transforms/Utils/CMakeLists.txt      |  1 +
 mlir/lib/Transforms/Utils/DCEUtils.cpp        | 62 +++++++++++++++++++
 .../llvm-project-overlay/mlir/BUILD.bazel     |  1 +
 5 files changed, 92 insertions(+), 42 deletions(-)
 create mode 100644 mlir/include/mlir/Transforms/DCEUtils.h
 create mode 100644 mlir/lib/Transforms/Utils/DCEUtils.cpp

diff --git a/mlir/include/mlir/Transforms/DCEUtils.h b/mlir/include/mlir/Transforms/DCEUtils.h
new file mode 100644
index 00000000000000..9432a136a811eb
--- /dev/null
+++ b/mlir/include/mlir/Transforms/DCEUtils.h
@@ -0,0 +1,26 @@
+//===- DCEUtils.h - Dead code elimination utilities -------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares utilities for eliminating dead code.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TRANSFORMS_DCEUTILS_H
+#define MLIR_TRANSFORMS_DCEUTILS_H
+
+namespace mlir {
+class Operation;
+class RewriterBase;
+
+/// Erases, via `rewriter`, every trivially dead op properly nested in `target`
+/// (but never `target` itself), iterating until no newly dead ops remain.
+void deadCodeElimination(RewriterBase &rewriter, Operation *target);
+
+} // namespace mlir
+
+#endif // MLIR_TRANSFORMS_DCEUTILS_H
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 798853a75441a1..79257cf16efed5 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -31,6 +31,7 @@
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Pass/PassRegistry.h"
 #include "mlir/Transforms/CSE.h"
+#include "mlir/Transforms/DCEUtils.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
@@ -314,48 +315,7 @@ DiagnosedSilenceableFailure transform::ApplyDeadCodeEliminationOp::applyToOne(
   if (!payloadCheck.succeeded())
     return payloadCheck;
 
-  // Maintain a worklist of potentially dead ops.
-  SetVector<Operation *> worklist;
-
-  // Helper function that adds all defining ops of used values (operands and
-  // operands of nested ops).
-  auto addDefiningOpsToWorklist = [&](Operation *op) {
-    op->walk([&](Operation *op) {
-      for (Value v : op->getOperands())
-        if (Operation *defOp = v.getDefiningOp())
-          if (target->isProperAncestor(defOp))
-            worklist.insert(defOp);
-    });
-  };
-
-  // Helper function that erases an op.
-  auto eraseOp = [&](Operation *op) {
-    // Remove op and nested ops from the worklist.
-    op->walk([&](Operation *op) {
-      const auto *it = llvm::find(worklist, op);
-      if (it != worklist.end())
-        worklist.erase(it);
-    });
-    rewriter.eraseOp(op);
-  };
-
-  // Initial walk over the IR.
-  target->walk<WalkOrder::PostOrder>([&](Operation *op) {
-    if (op != target && isOpTriviallyDead(op)) {
-      addDefiningOpsToWorklist(op);
-      eraseOp(op);
-    }
-  });
-
-  // Erase all ops that have become dead.
-  while (!worklist.empty()) {
-    Operation *op = worklist.pop_back_val();
-    if (!isOpTriviallyDead(op))
-      continue;
-    addDefiningOpsToWorklist(op);
-    eraseOp(op);
-  }
-
+  deadCodeElimination(rewriter, target);
   return DiagnosedSilenceableFailure::success();
 }
 
diff --git a/mlir/lib/Transforms/Utils/CMakeLists.txt b/mlir/lib/Transforms/Utils/CMakeLists.txt
index 72eb34f36cf5f6..4b8c3b185f320b 100644
--- a/mlir/lib/Transforms/Utils/CMakeLists.txt
+++ b/mlir/lib/Transforms/Utils/CMakeLists.txt
@@ -2,6 +2,7 @@ add_mlir_library(MLIRTransformUtils
   CFGToSCF.cpp
   CommutativityUtils.cpp
   ControlFlowSinkUtils.cpp
+  DCEUtils.cpp
   DialectConversion.cpp
   FoldUtils.cpp
   GreedyPatternRewriteDriver.cpp
diff --git a/mlir/lib/Transforms/Utils/DCEUtils.cpp b/mlir/lib/Transforms/Utils/DCEUtils.cpp
new file mode 100644
index 00000000000000..102c04b4f6c734
--- /dev/null
+++ b/mlir/lib/Transforms/Utils/DCEUtils.cpp
@@ -0,0 +1,62 @@
+//===- DCEUtils.cpp - Dead code elimination utilities ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements utilities for eliminating dead code.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/DCEUtils.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "llvm/ADT/SetVector.h"
+
+using namespace mlir;
+
+void mlir::deadCodeElimination(RewriterBase &rewriter, Operation *target) {
+  // Ops properly nested in `target` that may have become dead. SetVector keeps
+  // a deterministic order and gives constant-time membership checks.
+  SetVector<Operation *> worklist;
+
+  // Adds the defining ops of all operands of `op` and of its nested ops. Only
+  // ops properly nested in `target` are added, so `target` itself and ops
+  // outside of it are never considered for erasure.
+  auto addDefiningOpsToWorklist = [&](Operation *op) {
+    op->walk([&](Operation *nested) {
+      for (Value v : nested->getOperands())
+        if (Operation *defOp = v.getDefiningOp())
+          if (target->isProperAncestor(defOp))
+            worklist.insert(defOp);
+    });
+  };
+
+  // Erases `op` through the rewriter. `op` and all of its nested ops are first
+  // dropped from the worklist so that it never holds pointers to erased ops.
+  // `SetVector::remove` is a cheap set lookup for ops that were never queued.
+  auto eraseOp = [&](Operation *op) {
+    op->walk([&](Operation *nested) { worklist.remove(nested); });
+    rewriter.eraseOp(op);
+  };
+
+  // Initial post-order walk over all ops nested in `target`. Defining ops of
+  // erased ops are queued and re-checked below, even if already visited.
+  target->walk<WalkOrder::PostOrder>([&](Operation *op) {
+    if (op != target && isOpTriviallyDead(op)) {
+      addDefiningOpsToWorklist(op);
+      eraseOp(op);
+    }
+  });
+
+  // Erase ops that became dead after their users were erased.
+  while (!worklist.empty()) {
+    Operation *op = worklist.pop_back_val();
+    if (!isOpTriviallyDead(op))
+      continue;
+    addDefiningOpsToWorklist(op);
+    eraseOp(op);
+  }
+}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 4d44396be98eec..10323cd54c522c 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -8022,6 +8022,7 @@ cc_library(
         "include/mlir/Transforms/CFGToSCF.h",
         "include/mlir/Transforms/CommutativityUtils.h",
         "include/mlir/Transforms/ControlFlowSinkUtils.h",
+        "include/mlir/Transforms/DCEUtils.h",
         "include/mlir/Transforms/DialectConversion.h",
         "include/mlir/Transforms/FoldUtils.h",
         "include/mlir/Transforms/GreedyPatternRewriteDriver.h",



More information about the llvm-commits mailing list