[Mlir-commits] [mlir] 1b390f5 - [mlir][transform] Simplify TrackingListener test case

Matthias Springer llvmlistbot at llvm.org
Fri Jun 9 03:03:30 PDT 2023


Author: Matthias Springer
Date: 2023-06-09T12:03:19+02:00
New Revision: 1b390f5e75b6309a0e4e6952c883ee35b2baa121

URL: https://github.com/llvm/llvm-project/commit/1b390f5e75b6309a0e4e6952c883ee35b2baa121
DIFF: https://github.com/llvm/llvm-project/commit/1b390f5e75b6309a0e4e6952c883ee35b2baa121.diff

LOG: [mlir][transform] Simplify TrackingListener test case

Use the default TrackingListener. There is no need to set up a derived listener just for the test case. This revision is in preparation for a future change that adds a TrackingRewriter infrastructure.

Differential Revision: https://reviews.llvm.org/D152446

Added: 
    

Modified: 
    mlir/test/Dialect/Transform/test-interpreter.mlir
    mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
    mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td

Removed: 
    


################################################################################
diff  --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index 932b2cb011350..b2d7e7a4bdb5e 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -1601,19 +1601,20 @@ module attributes { transform.with_named_sequence } {
 // -----
 
 // CHECK-LABEL: func @test_tracked_rewrite() {
-//  CHECK-NEXT:   "test.update_mapping"() {original_op = "test.replace_me"}
-//  CHECK-NEXT:   "test.drop_mapping"() {original_op = "test.replace_me"}
-//  CHECK-NEXT:   "test.update_mapping"() {original_op = "test.replace_me"}
+//  CHECK-NEXT:   transform.test_dummy_payload_op  {new_op} : () -> i1
+//  CHECK-NEXT:   transform.test_dummy_payload_op  {new_op} : () -> i1
+//  CHECK-NEXT:   return
 //  CHECK-NEXT: }
 func.func @test_tracked_rewrite() {
-  %0 = "test.replace_me"() {replacement = "test.update_mapping"} : () -> (i1)
-  %1 = "test.replace_me"() {replacement = "test.drop_mapping"} : () -> (i1)
-  %2 = "test.replace_me"() {replacement = "test.update_mapping"} : () -> (i1)
+  %0 = transform.test_dummy_payload_op {replace_me} : () -> (i1)
+  %1 = transform.test_dummy_payload_op {erase_me} : () -> (i1)
+  %2 = transform.test_dummy_payload_op {replace_me} : () -> (i1)
+  func.return
 }
 
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !transform.any_op):
-  %0 = transform.structured.match ops{["test.replace_me"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  %0 = transform.structured.match ops{["transform.test_dummy_payload_op"]} in %arg1 : (!transform.any_op) -> !transform.any_op
   // expected-remark @below {{2 iterations}}
   transform.test_tracked_rewrite %0 : (!transform.any_op) -> ()
   // One replacement op (test.drop_mapping) is dropped from the mapping.

diff  --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
index 0c3697d1171ff..835cbb3ae2d5e 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -687,33 +687,16 @@ void mlir::test::TestTrackedRewriteOp::getEffects(
   transform::modifiesPayload(effects);
 }
 
-namespace {
-/// A TrackingListener for test cases. When the replacement op is
-/// "test.update_mapping", it is considered as a replacement op in the transform
-/// state mapping. Otherwise, it is not and the original op is simply removed
-/// from the mapping.
-class TestTrackingListener : public transform::TrackingListener {
-  using transform::TrackingListener::TrackingListener;
-
-protected:
-  FailureOr<Operation *>
-  findReplacementOp(Operation *op, ValueRange newValues) const override {
-    if (newValues.size() != 1)
-      return failure();
-    Operation *replacement = newValues[0].getDefiningOp();
-    if (!replacement)
-      return failure();
-    if (replacement->getName().getStringRef() != "test.update_mapping")
-      return failure();
-    return replacement;
-  }
-};
-} // namespace
+void mlir::test::TestDummyPayloadOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  for (OpResult result : getResults())
+    transform::producesHandle(result, effects);
+}
 
 DiagnosedSilenceableFailure
 mlir::test::TestTrackedRewriteOp::apply(transform::TransformResults &results,
                                         transform::TransformState &state) {
-  TestTrackingListener listener(state, *this);
+  transform::ErrorCheckingTrackingListener listener(state, *this);
   IRRewriter rewriter(getContext(), &listener);
   int64_t numIterations = 0;
 
@@ -721,19 +704,23 @@ mlir::test::TestTrackedRewriteOp::apply(transform::TransformResults &results,
   // loop body. Replacement ops are not enumerated.
   for (Operation *op : state.getPayloadOps(getIn())) {
     ++numIterations;
-    rewriter.setInsertionPointToEnd(op->getBlock());
+    (void)op;
 
     // Erase all payload ops. The outer loop should have only one iteration.
     for (Operation *op : state.getPayloadOps(getIn())) {
-      if (op->getName().getStringRef() != "test.replace_me")
+      rewriter.setInsertionPoint(op);
+      if (op->hasAttr("erase_me")) {
+        rewriter.eraseOp(op);
         continue;
-      auto replacementName = op->getAttrOfType<StringAttr>("replacement");
-      if (!replacementName)
+      }
+      if (!op->hasAttr("replace_me")) {
         continue;
+      }
+
       SmallVector<NamedAttribute> attributes;
-      attributes.emplace_back(rewriter.getStringAttr("original_op"),
-                              op->getName().getIdentifier());
-      OperationState opState(op->getLoc(), replacementName,
+      attributes.emplace_back(rewriter.getStringAttr("new_op"),
+                              rewriter.getUnitAttr());
+      OperationState opState(op->getLoc(), op->getName().getIdentifier(),
                              /*operands=*/ValueRange(),
                              /*types=*/op->getResultTypes(), attributes);
       Operation *newOp = rewriter.create(opState);

diff  --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
index f7a6120666b8d..85b0440277dc1 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
@@ -467,6 +467,29 @@ def TestRequiredMemoryEffectsOp
   let cppNamespace = "::mlir::test";
 }
 
+// This op is used as a payload op. It must be a registered op, so that it can
+// be created with "RewriterBase::replaceOpWithNewOp" (needed for a test case).
+// Since only TransformOpInterface can be injected into the transform dialect,
+// this op implements the interface, even though it is not used as a transform
+// op.
+def TestDummyPayloadOp
+  : Op<Transform_Dialect, "test_dummy_payload_op",
+      [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+       TransformOpInterface]> {
+  let arguments = (ins Variadic<AnyType>:$args);
+  let results = (outs Variadic<AnyType>:$outs);
+  let assemblyFormat = "$args attr-dict `:` functional-type(operands, results)";
+  let cppNamespace = "::mlir::test";
+
+  let extraClassDeclaration = [{
+    DiagnosedSilenceableFailure apply(transform::TransformResults &results,
+                                      transform::TransformState &state) {
+      llvm_unreachable("op should not be used as a transform");
+      return DiagnosedSilenceableFailure::definiteFailure();
+    }
+  }];
+}
+
 def TestTrackedRewriteOp
   : Op<Transform_Dialect, "test_tracked_rewrite",
       [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,


        


More information about the Mlir-commits mailing list