[Mlir-commits] [mlir] 2c95ede - [mlir][transform] Add `transform.apply_cse` op
Matthias Springer
llvmlistbot at llvm.org
Sun Jul 2 23:55:24 PDT 2023
Author: Matthias Springer
Date: 2023-07-03T08:50:50+02:00
New Revision: 2c95ede4d1831e9bf5a5f7075150e92d2c99b7d2
URL: https://github.com/llvm/llvm-project/commit/2c95ede4d1831e9bf5a5f7075150e92d2c99b7d2
DIFF: https://github.com/llvm/llvm-project/commit/2c95ede4d1831e9bf5a5f7075150e92d2c99b7d2.diff
LOG: [mlir][transform] Add `transform.apply_cse` op
This op applies CSE to the targeted op. This is similar to `transform.apply_registered_pass "cse"`, but it keeps handles to ops nested in the targeted op valid, updating them when CSE replaces the ops they point to.
Differential Revision: https://reviews.llvm.org/D154099
Added:
Modified:
mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
mlir/lib/Dialect/Transform/IR/TransformOps.cpp
mlir/test/Dialect/Transform/test-interpreter.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index cf49c451286d83..a4426977c84c66 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -127,6 +127,35 @@ def AnnotateOp : TransformDialectOp<"annotate",
"`:` type($target) (`,` type($param)^)?";
}
+def ApplyCommonSubexpressionEliminationOp : TransformDialectOp<"apply_cse",
+    [TransformOpInterface, TransformEachOpTrait,
+     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+     ReportTrackingListenerFailuresOpTrait]> {
+  let summary = "Eliminate common subexpressions in the body of the target op";
+  let description = [{
+    This transform applies common subexpression elimination (CSE) to the body
+    of the targeted op, which must be separate from the transform IR itself.
+
+    This transform reads the target handle and modifies the payload. Existing
+    handles to operations inside of the targeted op are retained and updated if
+    necessary. Note that this can lead to situations where a handle that was
+    previously mapped to multiple distinct (but equivalent) operations is now
+    mapped to the same operation multiple times.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target);
+  let results = (outs);
+  let assemblyFormat = "`to` $target attr-dict `:` type($target)";
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::Operation *target,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
def ApplyPatternsOp : TransformDialectOp<"apply_patterns",
[TransformOpInterface, TransformEachOpTrait,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 1d89ff45f5dcc5..3f6fb7173015d8 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -14,12 +14,14 @@
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/Dominance.h"
#include "mlir/IR/FunctionImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassRegistry.h"
+#include "mlir/Transforms/CSE.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
@@ -257,6 +259,32 @@ void transform::AnnotateOp::getEffects(
modifiesPayload(effects);
}
+//===----------------------------------------------------------------------===//
+// ApplyCommonSubexpressionEliminationOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::ApplyCommonSubexpressionEliminationOp::applyToOne(
+    transform::TransformRewriter &rewriter, Operation *target,
+    ApplyToEachResultList &results, transform::TransformState &state) {
+  // Refuse to rewrite the transform IR while it is being interpreted.
+  if (DiagnosedSilenceableFailure separation =
+          ensurePayloadIsSeparateFromTransform(*this, target);
+      !separation.succeeded())
+    return separation;
+
+  // CSE goes through `rewriter` so that handles to replaced ops get updated.
+  DominanceInfo dominance;
+  mlir::eliminateCommonSubExpressions(rewriter, dominance, target);
+  return DiagnosedSilenceableFailure::success();
+}
+
+void transform::ApplyCommonSubexpressionEliminationOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  transform::onlyReadsHandle(this->getTarget(), effects); // not consumed
+  transform::modifiesPayload(effects); // CSE replaces duplicate payload ops
+}
+
//===----------------------------------------------------------------------===//
// ApplyPatternsOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index 4ab9d65e6475b7..ba216f6ee74db0 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -1727,3 +1727,66 @@ transform.sequence failures(propagate) {
test_notify_payload_op_replaced %0, %1 : (!transform.any_op, !transform.any_op) -> ()
test_print_remark_at_operand %0, "updated handle" : !transform.any_op
}
+
+// -----
+
+// CHECK-LABEL: func @test_apply_cse()
+//       CHECK:   %[[const:.*]] = arith.constant 0 : index
+//       CHECK:   %[[ex1:.*]] = scf.execute_region -> index {
+//  CHECK-NEXT:     scf.yield %[[const]]
+//  CHECK-NEXT:   }
+//       CHECK:   %[[ex2:.*]] = scf.execute_region -> index {
+//  CHECK-NEXT:     scf.yield %[[const]]
+//  CHECK-NEXT:   }
+//       CHECK:   return %[[const]], %[[ex1]], %[[ex2]]
+func.func @test_apply_cse() -> (index, index, index) {  // 3 equal constants
+  // expected-remark @below{{eliminated 1}}
+  // expected-remark @below{{eliminated 2}}
+  %0 = arith.constant 0 : index  // Dominates both nested duplicates below.
+  %1 = scf.execute_region -> index {
+    %2 = arith.constant 0 : index  // Duplicate of %0; CSE replaces it.
+    scf.yield %2 : index
+  } {first}
+  %3 = scf.execute_region -> index {
+    %4 = arith.constant 0 : index  // Duplicate of %0; CSE replaces it.
+    scf.yield %4 : index
+  } {second}
+  return %0, %1, %3 : index, index, index
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  %first = transform.structured.match attributes{first} in %0 : (!transform.any_op) -> !transform.any_op
+  %elim_first = transform.structured.match ops{["arith.constant"]} in %first : (!transform.any_op) -> !transform.any_op
+  %second = transform.structured.match attributes{second} in %0 : (!transform.any_op) -> !transform.any_op
+  %elim_second = transform.structured.match ops{["arith.constant"]} in %second : (!transform.any_op) -> !transform.any_op
+
+  // There are 3 arith.constant ops: one top-level, one per execute_region.
+  %all = transform.structured.match ops{["arith.constant"]} in %0 : (!transform.any_op) -> !transform.any_op
+  // expected-remark @below{{3}}
+  test_print_number_of_associated_payload_ir_ops %all : !transform.any_op
+  // "deduplicate" has no effect because these are 3 different ops.
+  %merged_before = transform.merge_handles deduplicate %all : !transform.any_op
+  // expected-remark @below{{3}}
+  test_print_number_of_associated_payload_ir_ops %merged_before : !transform.any_op
+
+  // Apply CSE. Handles into the function body are updated, not invalidated.
+  transform.apply_cse to %0 : !transform.any_op
+
+  // The handle is still mapped to 3 arith.constant ops.
+  // expected-remark @below{{3}}
+  test_print_number_of_associated_payload_ir_ops %all : !transform.any_op
+  // But they are all the same op: the surviving top-level constant.
+  %merged_after = transform.merge_handles deduplicate %all : !transform.any_op
+  // expected-remark @below{{1}}
+  test_print_number_of_associated_payload_ir_ops %merged_after : !transform.any_op
+
+  // The per-region handles were also updated to the top-level constant.
+  test_print_remark_at_operand %elim_first, "eliminated 1" : !transform.any_op
+  // expected-remark @below{{1}}
+  test_print_number_of_associated_payload_ir_ops %elim_first : !transform.any_op
+  test_print_remark_at_operand %elim_second, "eliminated 2" : !transform.any_op
+  // expected-remark @below{{1}}
+  test_print_number_of_associated_payload_ir_ops %elim_second : !transform.any_op
+}
More information about the Mlir-commits
mailing list