[Mlir-commits] [mlir] 2c95ede - [mlir][transform] Add `transform.apply_cse` op

Matthias Springer llvmlistbot at llvm.org
Sun Jul 2 23:55:24 PDT 2023


Author: Matthias Springer
Date: 2023-07-03T08:50:50+02:00
New Revision: 2c95ede4d1831e9bf5a5f7075150e92d2c99b7d2

URL: https://github.com/llvm/llvm-project/commit/2c95ede4d1831e9bf5a5f7075150e92d2c99b7d2
DIFF: https://github.com/llvm/llvm-project/commit/2c95ede4d1831e9bf5a5f7075150e92d2c99b7d2.diff

LOG: [mlir][transform] Add `transform.apply_cse` op

This op applies CSE to the targeted op. This is similar to `transform.apply_registered_pass "cse"`, but it retains handles in the body of the targeted op.

Differential Revision: https://reviews.llvm.org/D154099

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
    mlir/lib/Dialect/Transform/IR/TransformOps.cpp
    mlir/test/Dialect/Transform/test-interpreter.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index cf49c451286d83..a4426977c84c66 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -127,6 +127,35 @@ def AnnotateOp : TransformDialectOp<"annotate",
     "`:` type($target) (`,` type($param)^)?";
 }
 
+def ApplyCommonSubexpressionEliminationOp : TransformDialectOp<"apply_cse",
+    [TransformOpInterface, TransformEachOpTrait,
+     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+     ReportTrackingListenerFailuresOpTrait]> {
+  let summary = "Eliminate common subexpressions in the body of the target op";
+  let description = [{
+    This transform applies common subexpression elimination (CSE) to the body
+    of the targeted op.
+
+    This transform reads the target handle and modifies the payload. Existing
+    handles to operations inside of the targeted op are retained and updated if
+    necessary. Note that this can lead to situations where a handle that was
+    previously mapped to multiple distinct (but equivalent) operations is now
+    mapped to the same operation multiple times.
+
+    The targeted op must be separate from the transform IR that is being
+    interpreted: applying this transform to the transform op itself, or to an
+    op that contains it, results in a failure instead of modifying the
+    transform IR while it is being interpreted.
+
+    This op has no results.
+
+    Example:
+
+    ```mlir
+    transform.apply_cse to %func : !transform.any_op
+    ```
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target);
+  let results = (outs);
+  let assemblyFormat = "`to` $target attr-dict `:` type($target)";
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+      ::mlir::transform::TransformRewriter &rewriter,
+      ::mlir::Operation *target,
+      ::mlir::transform::ApplyToEachResultList &results,
+      ::mlir::transform::TransformState &state);
+  }];
+}
+
 def ApplyPatternsOp : TransformDialectOp<"apply_patterns",
     [TransformOpInterface, TransformEachOpTrait,
      DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,

diff  --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 1d89ff45f5dcc5..3f6fb7173015d8 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -14,12 +14,14 @@
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
 #include "mlir/Dialect/Transform/IR/TransformTypes.h"
 #include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/Dominance.h"
 #include "mlir/IR/FunctionImplementation.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Pass/PassRegistry.h"
+#include "mlir/Transforms/CSE.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/ScopeExit.h"
@@ -257,6 +259,32 @@ void transform::AnnotateOp::getEffects(
   modifiesPayload(effects);
 }
 
+//===----------------------------------------------------------------------===//
+// ApplyCommonSubexpressionEliminationOp
+//===----------------------------------------------------------------------===//
+
+/// Runs common subexpression elimination on the body of `target`. The op has
+/// no results, so `results` is left untouched.
+DiagnosedSilenceableFailure
+transform::ApplyCommonSubexpressionEliminationOp::applyToOne(
+    transform::TransformRewriter &rewriter, Operation *target,
+    ApplyToEachResultList &results, transform::TransformState &state) {
+  // CSE rewrites IR nested under `target`. If `target` overlaps with the
+  // transform IR that is currently being interpreted, running CSE would
+  // mutate the interpreter's own program mid-run, so refuse instead.
+  DiagnosedSilenceableFailure separation =
+      ensurePayloadIsSeparateFromTransform(*this, target);
+  if (!separation.succeeded())
+    return separation;
+
+  // Replacements are routed through `rewriter` rather than applied directly;
+  // presumably this lets the interpreter's listener keep handles to
+  // eliminated ops up to date (see the op description) -- confirm.
+  DominanceInfo dominance;
+  mlir::eliminateCommonSubExpressions(rewriter, dominance, target);
+  return DiagnosedSilenceableFailure::success();
+}
+
+/// Memory effects: the target handle is only read (not consumed), and the
+/// payload IR it points into is modified.
+void transform::ApplyCommonSubexpressionEliminationOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  onlyReadsHandle(getTarget(), effects);
+  modifiesPayload(effects);
+}
+
 //===----------------------------------------------------------------------===//
 // ApplyPatternsOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index 4ab9d65e6475b7..ba216f6ee74db0 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -1727,3 +1727,66 @@ transform.sequence failures(propagate) {
   test_notify_payload_op_replaced %0, %1 : (!transform.any_op, !transform.any_op) -> ()
   test_print_remark_at_operand %0, "updated handle" : !transform.any_op
 }
+
+// -----
+
+// CHECK-LABEL: func @test_apply_cse()
+//       CHECK:   %[[const:.*]] = arith.constant 0 : index
+//       CHECK:   %[[ex1:.*]] = scf.execute_region -> index {
+//       CHECK:     scf.yield %[[const]]
+//       CHECK:   }
+//       CHECK:   %[[ex2:.*]] = scf.execute_region -> index {
+//       CHECK:     scf.yield %[[const]]
+//       CHECK:   }
+//       CHECK:   return %[[const]], %[[ex1]], %[[ex2]]
+func.func @test_apply_cse() -> (index, index, index) {
+  // After CSE, the constants nested in both execute_regions are replaced by
+  // %0, so the remarks printed at %elim_first and %elim_second land here.
+  // expected-remark @below{{eliminated 1}}
+  // expected-remark @below{{eliminated 2}}
+  %0 = arith.constant 0 : index
+  %1 = scf.execute_region -> index {
+    %2 = arith.constant 0 : index
+    scf.yield %2 : index
+  } {first}
+  %3 = scf.execute_region -> index {
+    %4 = arith.constant 0 : index
+    scf.yield %4 : index
+  } {second}
+  return %0, %1, %3 : index, index, index
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  // Handles to the constant nested in each execute_region; both are
+  // eliminated by CSE.
+  %first = transform.structured.match attributes{first} in %0 : (!transform.any_op) -> !transform.any_op
+  %elim_first = transform.structured.match ops{["arith.constant"]} in %first : (!transform.any_op) -> !transform.any_op
+  %second = transform.structured.match attributes{second} in %0 : (!transform.any_op) -> !transform.any_op
+  %elim_second = transform.structured.match ops{["arith.constant"]} in %second : (!transform.any_op) -> !transform.any_op
+
+  // There are 3 arith.constant ops.
+  %all = transform.structured.match ops{["arith.constant"]} in %0 : (!transform.any_op) -> !transform.any_op
+  // expected-remark @below{{3}}
+  test_print_number_of_associated_payload_ir_ops %all : !transform.any_op
+  // "deduplicate" has no effect because these are 3 different ops.
+  %merged_before = transform.merge_handles deduplicate %all : !transform.any_op
+  // expected-remark @below{{3}}
+  test_print_number_of_associated_payload_ir_ops %merged_before : !transform.any_op
+
+  // Apply CSE.
+  transform.apply_cse to %0 : !transform.any_op
+
+  // The handle is still mapped to 3 arith.constant ops.
+  // expected-remark @below{{3}}
+  test_print_number_of_associated_payload_ir_ops %all : !transform.any_op
+  // But they are all the same op.
+  %merged_after = transform.merge_handles deduplicate %all : !transform.any_op
+  // expected-remark @below{{1}}
+  test_print_number_of_associated_payload_ir_ops %merged_after : !transform.any_op
+
+  // The other handles were also updated.
+  test_print_remark_at_operand %elim_first, "eliminated 1" : !transform.any_op
+  // expected-remark @below{{1}}
+  test_print_number_of_associated_payload_ir_ops %elim_first : !transform.any_op
+  test_print_remark_at_operand %elim_second, "eliminated 2" : !transform.any_op
+  // expected-remark @below{{1}}
+  test_print_number_of_associated_payload_ir_ops %elim_second : !transform.any_op
+}


        


More information about the Mlir-commits mailing list