[Mlir-commits] [mlir] 1b390f5 - [mlir][transform] Simplify TrackingListener test case
Matthias Springer
llvmlistbot at llvm.org
Fri Jun 9 03:03:30 PDT 2023
Author: Matthias Springer
Date: 2023-06-09T12:03:19+02:00
New Revision: 1b390f5e75b6309a0e4e6952c883ee35b2baa121
URL: https://github.com/llvm/llvm-project/commit/1b390f5e75b6309a0e4e6952c883ee35b2baa121
DIFF: https://github.com/llvm/llvm-project/commit/1b390f5e75b6309a0e4e6952c883ee35b2baa121.diff
LOG: [mlir][transform] Simplify TrackingListener test case
Use the default TrackingListener. There is no need to set up a derived listener just for the test case. This revision is in preparation for a future change that adds TrackingRewriter infrastructure.
Differential Revision: https://reviews.llvm.org/D152446
Added:
Modified:
mlir/test/Dialect/Transform/test-interpreter.mlir
mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
Removed:
################################################################################
diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index 932b2cb011350..b2d7e7a4bdb5e 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -1601,19 +1601,20 @@ module attributes { transform.with_named_sequence } {
// -----
// CHECK-LABEL: func @test_tracked_rewrite() {
-// CHECK-NEXT: "test.update_mapping"() {original_op = "test.replace_me"}
-// CHECK-NEXT: "test.drop_mapping"() {original_op = "test.replace_me"}
-// CHECK-NEXT: "test.update_mapping"() {original_op = "test.replace_me"}
+// CHECK-NEXT: transform.test_dummy_payload_op {new_op} : () -> i1
+// CHECK-NEXT: transform.test_dummy_payload_op {new_op} : () -> i1
+// CHECK-NEXT: return
// CHECK-NEXT: }
func.func @test_tracked_rewrite() {
- %0 = "test.replace_me"() {replacement = "test.update_mapping"} : () -> (i1)
- %1 = "test.replace_me"() {replacement = "test.drop_mapping"} : () -> (i1)
- %2 = "test.replace_me"() {replacement = "test.update_mapping"} : () -> (i1)
+ %0 = transform.test_dummy_payload_op {replace_me} : () -> (i1)
+ %1 = transform.test_dummy_payload_op {erase_me} : () -> (i1)
+ %2 = transform.test_dummy_payload_op {replace_me} : () -> (i1)
+ func.return
}
transform.sequence failures(propagate) {
^bb1(%arg1: !transform.any_op):
- %0 = transform.structured.match ops{["test.replace_me"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["transform.test_dummy_payload_op"]} in %arg1 : (!transform.any_op) -> !transform.any_op
// expected-remark @below {{2 iterations}}
transform.test_tracked_rewrite %0 : (!transform.any_op) -> ()
-// One replacement op (test.drop_mapping) is dropped from the mapping.
+// One payload op (the "erase_me" op) is erased and dropped from the mapping.
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
index 0c3697d1171ff..835cbb3ae2d5e 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -687,33 +687,16 @@ void mlir::test::TestTrackedRewriteOp::getEffects(
transform::modifiesPayload(effects);
}
-namespace {
-/// A TrackingListener for test cases. When the replacement op is
-/// "test.update_mapping", it is considered as a replacement op in the transform
-/// state mapping. Otherwise, it is not and the original op is simply removed
-/// from the mapping.
-class TestTrackingListener : public transform::TrackingListener {
- using transform::TrackingListener::TrackingListener;
-
-protected:
- FailureOr<Operation *>
- findReplacementOp(Operation *op, ValueRange newValues) const override {
- if (newValues.size() != 1)
- return failure();
- Operation *replacement = newValues[0].getDefiningOp();
- if (!replacement)
- return failure();
- if (replacement->getName().getStringRef() != "test.update_mapping")
- return failure();
- return replacement;
- }
-};
-} // namespace
+/// Reports the memory effects of the dummy payload op: each of its results is
+/// registered as a produced handle.
+void mlir::test::TestDummyPayloadOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  // `producesHandle` accepts a whole value range and processes the results in
+  // order, so a single call covers every result without a manual loop.
+  transform::producesHandle(getResults(), effects);
+}
DiagnosedSilenceableFailure
mlir::test::TestTrackedRewriteOp::apply(transform::TransformResults &results,
transform::TransformState &state) {
- TestTrackingListener listener(state, *this);
+ transform::ErrorCheckingTrackingListener listener(state, *this);
IRRewriter rewriter(getContext(), &listener);
int64_t numIterations = 0;
@@ -721,19 +704,23 @@ mlir::test::TestTrackedRewriteOp::apply(transform::TransformResults &results,
// loop body. Replacement ops are not enumerated.
for (Operation *op : state.getPayloadOps(getIn())) {
++numIterations;
- rewriter.setInsertionPointToEnd(op->getBlock());
+ (void)op;
// Erase all payload ops. The outer loop should have only one iteration.
for (Operation *op : state.getPayloadOps(getIn())) {
- if (op->getName().getStringRef() != "test.replace_me")
+ rewriter.setInsertionPoint(op);
+ if (op->hasAttr("erase_me")) {
+ rewriter.eraseOp(op);
continue;
- auto replacementName = op->getAttrOfType<StringAttr>("replacement");
- if (!replacementName)
+ }
+ if (!op->hasAttr("replace_me")) {
continue;
+ }
+
SmallVector<NamedAttribute> attributes;
- attributes.emplace_back(rewriter.getStringAttr("original_op"),
- op->getName().getIdentifier());
- OperationState opState(op->getLoc(), replacementName,
+ attributes.emplace_back(rewriter.getStringAttr("new_op"),
+ rewriter.getUnitAttr());
+ OperationState opState(op->getLoc(), op->getName().getIdentifier(),
/*operands=*/ValueRange(),
/*types=*/op->getResultTypes(), attributes);
Operation *newOp = rewriter.create(opState);
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
index f7a6120666b8d..85b0440277dc1 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
@@ -467,6 +467,29 @@ def TestRequiredMemoryEffectsOp
let cppNamespace = "::mlir::test";
}
+// This op is used as a payload op. It must be a registered op so that it can
+// be created with "RewriterBase::replaceOpWithNewOp" (needed for a test case).
+// Since only ops that implement TransformOpInterface can be injected into the
+// transform dialect, this op implements the interface, even though it is never
+// used as a transform op.
+def TestDummyPayloadOp
+  : Op<Transform_Dialect, "test_dummy_payload_op",
+       [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+        TransformOpInterface]> {
+  let arguments = (ins Variadic<AnyType>:$args);
+  let results = (outs Variadic<AnyType>:$outs);
+  let assemblyFormat = "$args attr-dict `:` functional-type(operands, results)";
+  let cppNamespace = "::mlir::test";
+
+  let extraClassDeclaration = [{
+    // Required by TransformOpInterface only; this op appears solely in payload
+    // IR and must never be interpreted. `llvm_unreachable` does not return, so
+    // no trailing return statement is needed (it would be dead code).
+    DiagnosedSilenceableFailure apply(transform::TransformResults &results,
+                                      transform::TransformState &state) {
+      llvm_unreachable("op should not be used as a transform");
+    }
+  }];
+}
+
def TestTrackedRewriteOp
: Op<Transform_Dialect, "test_tracked_rewrite",
[DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
More information about the Mlir-commits
mailing list