[Mlir-commits] [mlir] c2d5d34 - [mlir][transform] Add `transform.apply_dce` op
Matthias Springer
llvmlistbot at llvm.org
Fri Jul 21 23:30:11 PDT 2023
Author: Matthias Springer
Date: 2023-07-22T08:25:02+02:00
New Revision: c2d5d348a81b60be322d461b70fc0cc2fbee8b73
URL: https://github.com/llvm/llvm-project/commit/c2d5d348a81b60be322d461b70fc0cc2fbee8b73
DIFF: https://github.com/llvm/llvm-project/commit/c2d5d348a81b60be322d461b70fc0cc2fbee8b73.diff
LOG: [mlir][transform] Add `transform.apply_dce` op
Add a transform that eliminates dead operations. This is useful after certain transforms (such as fusion) that create/clone new IR but leave the original IR in place.
Differential Revision: https://reviews.llvm.org/D155954
Added:
Modified:
mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
mlir/lib/Dialect/Transform/IR/TransformOps.cpp
mlir/test/Dialect/Transform/test-interpreter.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 5af2649ae519fc..9e0b7b95d006cc 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -156,6 +156,36 @@ def ApplyCommonSubexpressionEliminationOp : TransformDialectOp<"apply_cse",
}];
}
+// `transform.apply_dce`: erases dead payload ops nested under the payload
+// op(s) associated with `$target`. The per-op logic lives in `applyToOne`
+// (declared below, implemented in TransformOps.cpp); TransformEachOpTrait
+// presumably drives one `applyToOne` call per associated payload op.
+def ApplyDeadCodeEliminationOp : TransformDialectOp<"apply_dce",
+ [TransformOpInterface, TransformEachOpTrait,
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+ ReportTrackingListenerFailuresOpTrait]> {
+ let summary = "Eliminate dead operations in the body of the target op";
+ let description = [{
+ This transform applies dead code elimination (DCE) to the body of the
+ targeted op.
+
+ Note: "transform.apply_patterns" with an empty region can also be used to
+ remove dead ops. However, that op applies additional simplifications such as
+ op folding and region simplification.
+
+ This transform reads the target handle and modifies the payload. Note that
+ this transform may silently remove payload ops from handles.
+ }];
+
+ // A single handle operand; the op produces no results.
+ let arguments = (ins TransformHandleTypeInterface:$target);
+ let results = (outs);
+ // Printed/parsed as: transform.apply_dce to %h : !transform.any_op
+ let assemblyFormat = "`to` $target attr-dict `:` type($target)";
+
+ // Entry point invoked for each payload op; `getEffects` is generated from
+ // DeclareOpInterfaceMethods<MemoryEffectsOpInterface> and implemented in C++.
+ let extraClassDeclaration = [{
+ ::mlir::DiagnosedSilenceableFailure applyToOne(
+ ::mlir::transform::TransformRewriter &rewriter,
+ ::mlir::Operation *target,
+ ::mlir::transform::ApplyToEachResultList &results,
+ ::mlir::transform::TransformState &state);
+ }];
+}
+
def ApplyPatternsOp : TransformDialectOp<"apply_patterns",
[TransformOpInterface, TransformEachOpTrait,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index c9ecec7659ccba..5327a5f7f2524d 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -287,6 +287,71 @@ void transform::ApplyCommonSubexpressionEliminationOp::getEffects(
transform::modifiesPayload(effects);
}
+//===----------------------------------------------------------------------===//
+// ApplyDeadCodeEliminationOp
+//===----------------------------------------------------------------------===//
+
+/// Erases every trivially dead op properly nested under `target`, repeating
+/// until no further op becomes dead. `target` itself is never erased. This op
+/// produces no results, so `results` is left empty; `state` is unused.
+DiagnosedSilenceableFailure transform::ApplyDeadCodeEliminationOp::applyToOne(
+    transform::TransformRewriter &rewriter, Operation *target,
+    ApplyToEachResultList &results, transform::TransformState &state) {
+  // Make sure that this transform is not applied to itself. Modifying the
+  // transform IR while it is being interpreted is generally dangerous.
+  DiagnosedSilenceableFailure payloadCheck =
+      ensurePayloadIsSeparateFromTransform(*this, target);
+  if (!payloadCheck.succeeded())
+    return payloadCheck;
+
+  // Ops (properly nested under `target`) that may have become dead because
+  // one of their users was erased.
+  SetVector<Operation *> worklist;
+
+  // Enqueues the defining op of every value used by `op` or by any op nested
+  // in it. Only ops properly nested under `target` are enqueued, so nothing
+  // outside the target (nor the target itself) is ever considered.
+  auto addDefiningOpsToWorklist = [&](Operation *op) {
+    op->walk([&](Operation *nestedOp) {
+      for (Value operand : nestedOp->getOperands())
+        if (Operation *defOp = operand.getDefiningOp())
+          if (target->isProperAncestor(defOp))
+            worklist.insert(defOp);
+    });
+  };
+
+  // Erases `op` via `rewriter`. `op` and all ops nested in it are dropped from
+  // the worklist first so that no erased op is popped later.
+  auto eraseOp = [&](Operation *op) {
+    // SetVector::remove consults the set before touching the vector, so this
+    // is O(1) for the common case of an op that is not enqueued (a plain
+    // llvm::find would scan the whole worklist for every nested op).
+    op->walk([&](Operation *nestedOp) { worklist.remove(nestedOp); });
+    rewriter.eraseOp(op);
+  };
+
+  // Initial post-order walk: nested ops are visited before their parents, and
+  // erasing the op currently being visited is allowed in a post-order walk.
+  target->walk<WalkOrder::PostOrder>([&](Operation *op) {
+    if (op == target || !isOpTriviallyDead(op))
+      return;
+    addDefiningOpsToWorklist(op);
+    eraseOp(op);
+  });
+
+  // Drain the worklist: erasing an op may leave the producers of its operands
+  // without users, making them dead in turn.
+  while (!worklist.empty()) {
+    Operation *op = worklist.pop_back_val();
+    if (!isOpTriviallyDead(op))
+      continue;
+    addDefiningOpsToWorklist(op);
+    eraseOp(op);
+  }
+
+  return DiagnosedSilenceableFailure::success();
+}
+
+/// Memory effects of `transform.apply_dce`: a read-only effect on the
+/// `target` handle, followed by a modification of the payload IR.
+void transform::ApplyDeadCodeEliminationOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  Value handle = getTarget();
+  transform::onlyReadsHandle(handle, effects);
+  transform::modifiesPayload(effects);
+}
+
//===----------------------------------------------------------------------===//
// ApplyPatternsOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index 442ef625ffde9c..4ebecd8c933925 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -1974,3 +1974,26 @@ transform.sequence failures(propagate) {
%bar = transform.select "test.bar" in %0 : (!transform.any_op) -> !transform.any_op
test_print_remark_at_operand %bar, "found bar" : !transform.any_op
}
+
+// -----
+
+// After apply_dce, only memref.store and return must remain in the function:
+// the unused tensor.insert and the tensor.empty feeding it are erased.
+// CHECK-LABEL: func @apply_dce(
+// CHECK-NEXT: memref.store
+// CHECK-NEXT: return
+func.func @apply_dce(%f: f32, %m: memref<5xf32>, %idx: index) {
+ // Two dead ops, interleaved with a non-dead op.
+ // tensor.empty is only used by tensor.insert, so it becomes dead only after
+ // tensor.insert has been erased.
+ %0 = tensor.empty() : tensor<5xf32>
+ memref.store %f, %m[%idx] : memref<5xf32>
+ %1 = tensor.insert %f into %0[%idx] : tensor<5xf32>
+ return
+}
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+ %func_op = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ // Handle to tensor.empty taken before DCE, so its size can be inspected
+ // after the op has been erased.
+ %empty_op = transform.structured.match ops{["tensor.empty"]} in %func_op : (!transform.any_op) -> !transform.any_op
+ transform.apply_dce to %func_op : !transform.any_op
+
+ // The erased tensor.empty must no longer be associated with %empty_op.
+ // expected-remark @below{{0}}
+ test_print_number_of_associated_payload_ir_ops %empty_op : !transform.any_op
+}
More information about the Mlir-commits mailing list