[Mlir-commits] [mlir] c2d5d34 - [mlir][transform] Add `transform.apply_dce` op

Matthias Springer llvmlistbot at llvm.org
Fri Jul 21 23:30:11 PDT 2023


Author: Matthias Springer
Date: 2023-07-22T08:25:02+02:00
New Revision: c2d5d348a81b60be322d461b70fc0cc2fbee8b73

URL: https://github.com/llvm/llvm-project/commit/c2d5d348a81b60be322d461b70fc0cc2fbee8b73
DIFF: https://github.com/llvm/llvm-project/commit/c2d5d348a81b60be322d461b70fc0cc2fbee8b73.diff

LOG: [mlir][transform] Add `transform.apply_dce` op

Add a transform that eliminates dead operations. This is useful after certain transforms (such as fusion) that create/clone new IR but leave the original IR in place.

Differential Revision: https://reviews.llvm.org/D155954

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
    mlir/lib/Dialect/Transform/IR/TransformOps.cpp
    mlir/test/Dialect/Transform/test-interpreter.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 5af2649ae519fc..9e0b7b95d006cc 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -156,6 +156,36 @@ def ApplyCommonSubexpressionEliminationOp : TransformDialectOp<"apply_cse",
   }];
 }
 
+def ApplyDeadCodeEliminationOp : TransformDialectOp<"apply_dce",
+    [TransformOpInterface, TransformEachOpTrait,
+     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+     ReportTrackingListenerFailuresOpTrait]> {
+  let summary = "Eliminate dead operations in the body of the target op";
+  let description = [{
+    This transform applies dead code elimination (DCE) to the body of the
+    targeted op.
+
+    Note: "transform.apply_patterns" with an empty region can also be used to
+    remove dead ops. However, that op applies additional simplifications such as
+    op folding and region simplification.
+
+    This transform reads the target handle and modifies the payload. Note that
+    this transform may silently remove payload ops from handles.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target); // Not consumed.
+  let results = (outs); // Produces no handles.
+  let assemblyFormat = "`to` $target attr-dict `:` type($target)";
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne( // Called per target op.
+      ::mlir::transform::TransformRewriter &rewriter,
+      ::mlir::Operation *target,
+      ::mlir::transform::ApplyToEachResultList &results,
+      ::mlir::transform::TransformState &state);
+  }];
+}
+
 def ApplyPatternsOp : TransformDialectOp<"apply_patterns",
     [TransformOpInterface, TransformEachOpTrait,
      DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,

diff  --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index c9ecec7659ccba..5327a5f7f2524d 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -287,6 +287,71 @@ void transform::ApplyCommonSubexpressionEliminationOp::getEffects(
   transform::modifiesPayload(effects);
 }
 
+//===----------------------------------------------------------------------===//
+// ApplyDeadCodeEliminationOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure transform::ApplyDeadCodeEliminationOp::applyToOne(
+    transform::TransformRewriter &rewriter, Operation *target,
+    ApplyToEachResultList &results, transform::TransformState &state) {
+  // Make sure that this transform is not applied to itself. Modifying the
+  // transform IR while it is being interpreted is generally dangerous.
+  DiagnosedSilenceableFailure payloadCheck =
+      ensurePayloadIsSeparateFromTransform(*this, target);
+  if (!payloadCheck.succeeded())
+    return payloadCheck;
+
+  // Worklist of ops that fed an erased op; they may have become dead.
+  SetVector<Operation *> worklist;
+
+  // Adds to the worklist the defining ops of all operands of `op` and of its
+  // nested ops, restricted to ops properly nested inside `target`.
+  auto addDefiningOpsToWorklist = [&](Operation *op) {
+    op->walk([&](Operation *op) {
+      for (Value v : op->getOperands())
+        if (Operation *defOp = v.getDefiningOp())
+          if (target->isProperAncestor(defOp))
+            worklist.insert(defOp);
+    });
+  };
+
+  // Erases `op` after dropping it and all ops nested in it from the worklist.
+  auto eraseOp = [&](Operation *op) {
+    // NOTE(review): linear scan per op; worklist.remove(op) checks set first.
+    op->walk([&](Operation *op) {
+      auto it = llvm::find(worklist, op);
+      if (it != worklist.end())
+        worklist.erase(it);
+    });
+    rewriter.eraseOp(op);
+  };
+
+  // Initial post-order walk: erase every trivially dead op except `target`.
+  target->walk<WalkOrder::PostOrder>([&](Operation *op) {
+    if (op != target && isOpTriviallyDead(op)) {
+      addDefiningOpsToWorklist(op);
+      eraseOp(op);
+    }
+  });
+
+  // Drain the worklist; erasing an op may make its operands' producers dead.
+  while (!worklist.empty()) {
+    Operation *op = worklist.pop_back_val();
+    if (!isOpTriviallyDead(op))
+      continue;
+    addDefiningOpsToWorklist(op);
+    eraseOp(op);
+  }
+
+  return DiagnosedSilenceableFailure::success();
+}
+
+void transform::ApplyDeadCodeEliminationOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  transform::onlyReadsHandle(getTarget(), effects); // Handle not consumed.
+  transform::modifiesPayload(effects); // Dead payload ops get erased.
+}
+
 //===----------------------------------------------------------------------===//
 // ApplyPatternsOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index 442ef625ffde9c..4ebecd8c933925 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -1974,3 +1974,26 @@ transform.sequence failures(propagate) {
   %bar = transform.select "test.bar" in %0 : (!transform.any_op) -> !transform.any_op
   test_print_remark_at_operand %bar, "found bar" : !transform.any_op
 }
+
+// -----
+
+// CHECK-LABEL: func @apply_dce(
+//  CHECK-NEXT:   memref.store
+//  CHECK-NEXT:   return
+func.func @apply_dce(%f: f32, %m: memref<5xf32>, %idx: index) {
+  // Two dead ops (%1, then %0 once %1 is erased), interleaved with a live store.
+  %0 = tensor.empty() : tensor<5xf32>
+  memref.store %f, %m[%idx] : memref<5xf32>
+  %1 = tensor.insert %f into %0[%idx] : tensor<5xf32>
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+  %func_op = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+  %empty_op = transform.structured.match ops{["tensor.empty"]} in %func_op : (!transform.any_op) -> !transform.any_op
+  transform.apply_dce to %func_op : !transform.any_op // Also empties %empty_op.
+
+  // expected-remark @below{{0}}
+  test_print_number_of_associated_payload_ir_ops %empty_op : !transform.any_op
+}


        


More information about the Mlir-commits mailing list