[Mlir-commits] [mlir] 5a10f20 - [mlir][transform] Add region to ApplyPatternsOp
Matthias Springer
llvmlistbot at llvm.org
Tue Jun 6 00:16:05 PDT 2023
Author: Matthias Springer
Date: 2023-06-06T09:09:41+02:00
New Revision: 5a10f207cc3714195281b6db11ec3f0fe9110228
URL: https://github.com/llvm/llvm-project/commit/5a10f207cc3714195281b6db11ec3f0fe9110228
DIFF: https://github.com/llvm/llvm-project/commit/5a10f207cc3714195281b6db11ec3f0fe9110228.diff
LOG: [mlir][transform] Add region to ApplyPatternsOp
Patterns should be selected by adding ops that implement `PatternDescriptorOpInterface` to the region of `apply_patterns` ops. Such ops can have operands, allowing for pattern parameterization. The existing way of selecting patterns from the PatternRegistry is deprecated.
Differential Revision: https://reviews.llvm.org/D152167
Added:
Modified:
mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
mlir/lib/Dialect/Transform/IR/TransformOps.cpp
mlir/test/Dialect/Transform/ops-invalid.mlir
mlir/test/Dialect/Transform/test-pattern-application.mlir
mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
index c1c4387d1587e..0a9a6c18ad883 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
@@ -228,4 +228,27 @@ def FindPayloadReplacementOpInterface
];
}
+def PatternDescriptorOpInterface : OpInterface<"PatternDescriptorOpInterface"> {
+ let description = [{
+ This interface should be implemented by ops that select patterns of a
+ `transform.apply_patterns` op. Such ops are placed in the region of the
+ `transform.apply_patterns` op, which calls `populatePatterns` on each of
+ them and applies the collected patterns to the body of its target.
+ }];
+
+ let cppNamespace = "::mlir::transform";
+
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/[{
+ Populate patterns into the given pattern set. The set may already
+ contain patterns from other sources of the enclosing `apply_patterns` op.
+ }],
+ /*returnType=*/"void",
+ /*name=*/"populatePatterns",
+ /*arguments=*/(ins "RewritePatternSet &":$patterns)
+ >,
+ ];
+}
+
#endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORM_INTERFACES_TD
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index b6740502d2bd1..5b7e6ca8ca14e 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -16,6 +16,7 @@ include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/FunctionInterfaces.td"
include "mlir/IR/OpAsmInterface.td"
+include "mlir/IR/RegionKindInterface.td"
include "mlir/IR/SymbolInterfaces.td"
include "mlir/Dialect/Transform/IR/MatchInterfaces.td"
include "mlir/Dialect/Transform/IR/TransformAttrs.td"
@@ -128,17 +129,20 @@ def AnnotateOp : TransformDialectOp<"annotate",
def ApplyPatternsOp : TransformDialectOp<"apply_patterns",
[TransformOpInterface, TransformEachOpTrait,
- DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]
+ # GraphRegionNoTerminator.traits> {
let summary = "Greedily applies patterns to the body of the targeted op";
let description = [{
This transform greedily applies the specified patterns to the body of the
targeted op until a fixpoint was reached. Patterns are not applied to the
targeted op itself.
- Only patterns that were registered in the transform dialect's
- `PatternRegistry` are available. Additional patterns can be registered as
- part of transform dialect extensions. "canonicalization" is a special set
- of patterns that refers to all canonicalization patterns of all loaded
- dialects.
+ The patterns that should be applied are selected by the ops in the graph
+ region of this op. These ops must implement `PatternDescriptorOpInterface`.
+
+ (Deprecated) In addition, patterns that were registered in the transform
+ dialect's `PatternRegistry` are available. Canonicalization patterns are no
+ longer registered there; use `transform.apply_patterns.canonicalization`
+ in the region instead.
This transform only reads the target handle and modifies the payload. If a
@@ -160,7 +164,9 @@ def ApplyPatternsOp : TransformDialectOp<"apply_patterns",
TransformHandleTypeInterface:$target, ArrayAttr:$patterns,
DefaultValuedAttr<BoolAttr, "true">:$fail_on_payload_replacement_not_found);
let results = (outs);
- let assemblyFormat = "$patterns `to` $target attr-dict `:` type($target)";
+ let regions = (region MaxSizedRegion<1>:$region);
+
+ let assemblyFormat = "$patterns `to` $target $region attr-dict `:` type($target)";
let hasVerifier = 1;
let extraClassDeclaration = [{
@@ -171,6 +177,17 @@ def ApplyPatternsOp : TransformDialectOp<"apply_patterns",
}];
}
+def ApplyCanonicalizationPatternsOp
+ : TransformDialectOp<"apply_patterns.canonicalization",
+ [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+ let summary = "Populates canonicalization patterns";
+ let description = [{
+ This op populates all canonicalization patterns of all loaded dialects in
+ an `apply_patterns` transform.
+ }];
+ let assemblyFormat = "attr-dict";
+}
+
def CastOp : TransformDialectOp<"cast",
[TransformOpInterface, TransformEachOpTrait,
DeclareOpInterfaceMethods<CastOpInterface>,
diff --git a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
index 20bed31c34203..d4c4327392c28 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
@@ -27,12 +27,15 @@ void transform::detail::checkImplementsTransformOpInterface(
RegisteredOperationName opName =
*RegisteredOperationName::lookup(name, context);
assert((opName.hasInterface<TransformOpInterface>() ||
+ opName.hasInterface<PatternDescriptorOpInterface>() || // ops that select patterns inside transform.apply_patterns
opName.hasTrait<OpTrait::IsTerminator>()) &&
"non-terminator ops injected into the transform dialect must "
- "implement TransformOpInterface");
- assert(opName.hasInterface<MemoryEffectOpInterface>() &&
- "ops injected into the transform dialect must implement "
- "MemoryEffectsOpInterface");
+ "implement TransformOpInterface or PatternDescriptorOpInterface");
+ if (!opName.hasInterface<PatternDescriptorOpInterface>()) { // pattern descriptor ops only provide populatePatterns and are exempt from the memory-effects requirement
+ assert(opName.hasInterface<MemoryEffectOpInterface>() &&
+ "ops injected into the transform dialect must implement "
+ "MemoryEffectsOpInterface");
+ }
}
void transform::detail::checkImplementsTransformHandleTypeInterface(
@@ -57,16 +60,6 @@ void transform::TransformDialect::initialize() {
#include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
>();
initializeTypes();
-
- // Register all canonicalization patterns.
- getOrCreateExtraData<transform::PatternRegistry>().registerPatterns(
- "canonicalization", [](RewritePatternSet &patterns) {
- MLIRContext *ctx = patterns.getContext();
- for (Dialect *dialect : ctx->getLoadedDialects())
- dialect->getCanonicalizationPatterns(patterns);
- for (RegisteredOperationName op : ctx->getRegisteredOperations())
- op.getCanonicalizationPatterns(patterns, ctx);
- });
}
Type transform::TransformDialect::parseType(DialectAsmParser &parser) const {
Type transform::TransformDialect::parseType(DialectAsmParser &parser) const {
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 987c8489703c5..2dff9a903a261 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -445,6 +445,12 @@ transform::ApplyPatternsOp::applyToOne(Operation *target,
->getExtraData<transform::PatternRegistry>();
for (Attribute attr : getPatterns())
registry.populatePatterns(attr.cast<StringAttr>(), patterns);
+ if (!getRegion().empty()) { // add the patterns selected by each op in the region to the same set
+ for (Operation &op : getRegion().front()) {
+ cast<transform::PatternDescriptorOpInterface>(&op).populatePatterns( // cast is safe: verify() rejects region ops without the interface
+ patterns);
+ }
+ }
// Configure the GreedyPatternRewriteDriver.
ErrorCheckingTrackingListener listener(state, *this);
@@ -491,6 +497,17 @@ LogicalResult transform::ApplyPatternsOp::verify() {
if (!registry.hasPatterns(strAttr))
return emitOpError() << "patterns not registered: " << strAttr.strref();
}
+ if (!getRegion().empty()) { // every op in the pattern region must be a pattern descriptor
+ for (Operation &op : getRegion().front()) {
+ if (!isa<transform::PatternDescriptorOpInterface>(&op)) {
+ InFlightDiagnostic diag = emitOpError()
+ << "expected children ops to implement "
+ "PatternDescriptorOpInterface";
+ diag.attachNote(op.getLoc()) << "op without interface"; // point at the offending child op
+ return diag;
+ }
+ }
+ }
return success();
}
@@ -500,6 +517,19 @@ void transform::ApplyPatternsOp::getEffects(
transform::modifiesPayload(effects);
}
+//===----------------------------------------------------------------------===//
+// ApplyCanonicalizationPatternsOp
+//===----------------------------------------------------------------------===//
+
+void transform::ApplyCanonicalizationPatternsOp::populatePatterns(
+ RewritePatternSet &patterns) {
+ MLIRContext *ctx = patterns.getContext();
+ for (Dialect *dialect : ctx->getLoadedDialects()) // dialect-level canonicalization patterns
+ dialect->getCanonicalizationPatterns(patterns);
+ for (RegisteredOperationName op : ctx->getRegisteredOperations()) // per-op canonicalization patterns
+ op.getCanonicalizationPatterns(patterns, ctx);
+}
+
//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Transform/ops-invalid.mlir b/mlir/test/Dialect/Transform/ops-invalid.mlir
index 6436c7d860c37..427b1b174acbb 100644
--- a/mlir/test/Dialect/Transform/ops-invalid.mlir
+++ b/mlir/test/Dialect/Transform/ops-invalid.mlir
@@ -678,7 +678,7 @@ module attributes { transform.with_named_sequence } {
transform.sequence failures(propagate) {
^bb0(%arg0: !transform.any_op):
// expected-error @below {{patterns not registered: transform.invalid_pattern_identifier}}
- transform.apply_patterns ["transform.invalid_pattern_identifier"] to %arg0 : !transform.any_op
+ transform.apply_patterns ["transform.invalid_pattern_identifier"] to %arg0 {} : !transform.any_op // `{}` is the empty pattern region required by the assembly format
}
// -----
@@ -686,5 +686,15 @@ transform.sequence failures(propagate) {
transform.sequence failures(propagate) {
^bb0(%arg0: !transform.any_op):
// expected-error @below {{expected "patterns" to be an array of strings}}
- transform.apply_patterns [3, 9] to %arg0 : !transform.any_op
+ transform.apply_patterns [3, 9] to %arg0 {} : !transform.any_op
+}
+
+// -----
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+ // expected-error @below {{expected children ops to implement PatternDescriptorOpInterface}}
+ transform.apply_patterns [] to %arg0 {
+ // expected-note @below {{op without interface}}
+ transform.named_sequence @foo() // does not implement PatternDescriptorOpInterface
+ } : !transform.any_op
}
diff --git a/mlir/test/Dialect/Transform/test-pattern-application.mlir b/mlir/test/Dialect/Transform/test-pattern-application.mlir
index c51543e6be4c4..55bb083eb9833 100644
--- a/mlir/test/Dialect/Transform/test-pattern-application.mlir
+++ b/mlir/test/Dialect/Transform/test-pattern-application.mlir
@@ -15,7 +15,31 @@ transform.sequence failures(propagate) {
^bb1(%arg1: !transform.any_op):
%0 = transform.structured.match ops{["test.container"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.structured.match ops{["test.foo"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- transform.apply_patterns ["transform.test"] to %0 : !transform.any_op
+ transform.apply_patterns ["transform.test"] to %0 {} : !transform.any_op
+ // Add an attribute to %1, which is now mapped to a new op.
+ transform.annotate %1 "annotated" : !transform.any_op
+}
+
+// -----
+
+// CHECK-LABEL: func @update_tracked_op_mapping_region()
+// CHECK: "test.container"() ({
+// CHECK: %0 = "test.foo"() {annotated} : () -> i32
+// CHECK: }) : () -> ()
+func.func @update_tracked_op_mapping_region() {
+ "test.container"() ({
+ %0 = "test.foo"() {replace_with_new_op = "test.foo"} : () -> (i32)
+ }) : () -> ()
+ return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+ %0 = transform.structured.match ops{["test.container"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.structured.match ops{["test.foo"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns [] to %0 {
+ transform.apply_patterns.transform.test_patterns
+ } : !transform.any_op
// Add an attribute to %1, which is now mapped to a new op.
transform.annotate %1 "annotated" : !transform.any_op
}
@@ -36,7 +60,7 @@ transform.sequence failures(propagate) {
%0 = transform.structured.match ops{["test.container"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.structured.match ops{["test.foo"]} in %arg1 : (!transform.any_op) -> !transform.any_op
// expected-error @below {{tracking listener failed to find replacement op}}
- transform.apply_patterns ["transform.test"] to %0 : !transform.any_op
+ transform.apply_patterns ["transform.test"] to %0 {} : !transform.any_op
// %1 must be used in some way. If no replacement payload op could be found,
// an error is thrown only if the handle is not dead.
transform.annotate %1 "annotated" : !transform.any_op
@@ -60,7 +84,7 @@ transform.sequence failures(propagate) {
%0 = transform.structured.match ops{["test.container"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.structured.match ops{["test.foo"]} in %arg1 : (!transform.any_op) -> !transform.any_op
// No error because %1 is dead.
- transform.apply_patterns ["transform.test"] to %0 : !transform.any_op
+ transform.apply_patterns ["transform.test"] to %0 {} : !transform.any_op
}
// -----
@@ -80,7 +104,7 @@ transform.sequence failures(propagate) {
^bb1(%arg1: !transform.any_op):
%0 = transform.structured.match ops{["test.container"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.structured.match ops{["test.foo"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- transform.apply_patterns ["transform.test"] to %0 {fail_on_payload_replacement_not_found = false}: !transform.any_op
+ transform.apply_patterns ["transform.test"] to %0 {} {fail_on_payload_replacement_not_found = false}: !transform.any_op
transform.annotate %1 "annotated" : !transform.any_op
}
@@ -95,8 +119,8 @@ func.func @patterns_apply_only_to_target_body() {
transform.sequence failures(propagate) {
^bb1(%arg1: !transform.any_op):
%0 = transform.structured.match ops{["test.foo"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- transform.apply_patterns ["transform.test"] to %0 : !transform.any_op
+ transform.apply_patterns ["transform.test"] to %0 {} : !transform.any_op
}
// -----
@@ -118,7 +142,7 @@ transform.sequence failures(propagate) {
%0 = transform.structured.match ops{["test.container"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.structured.match ops{["test.erase_op"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.test_print_remark_at_operand %1, "matched op" : !transform.any_op
- transform.apply_patterns ["transform.test"] to %0 : !transform.any_op
+ transform.apply_patterns ["transform.test"] to %0 {} : !transform.any_op
transform.test_print_remark_at_operand %1, "op was deleted" : !transform.any_op
}
@@ -138,6 +162,8 @@ transform.sequence failures(propagate) {
^bb1(%arg1: !transform.any_op):
%0 = transform.structured.match ops{["tensor.dim"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- transform.apply_patterns ["canonicalization"] to %1 : !transform.any_op
+ transform.apply_patterns [] to %1 {
+ transform.apply_patterns.canonicalization
+ } : !transform.any_op
transform.test_print_remark_at_operand %0, "op was replaced" : !transform.any_op
}
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
index 9af4c53cb1c86..9243723947b6d 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -776,7 +776,17 @@ class EraseOp : public RewritePattern {
return success();
}
};
void populateTestPatterns(RewritePatternSet &patterns) {
patterns.insert<ReplaceWithNewOp, EraseOp>(patterns.getContext());
}
+} // namespace
+
+// Delegate to populateTestPatterns so that the list of test patterns is
+// defined in exactly one place (the free function is also used elsewhere).
+void mlir::test::ApplyTestPatternsOp::populatePatterns(
+ RewritePatternSet &patterns) {
+ populateTestPatterns(patterns);
+}
+
+namespace {
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
index c02e2d97663d1..f7a6120666b8d 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
@@ -59,7 +59,6 @@ def TestProduceValueHandleToSelfOperand
let results = (outs TransformValueHandleTypeInterface:$out);
let assemblyFormat = "$in attr-dict `:` functional-type(operands, results)";
let cppNamespace = "::mlir::test";
-
}
def TestProduceValueHandleToResult
@@ -478,4 +477,18 @@ def TestTrackedRewriteOp
let cppNamespace = "::mlir::test";
}
+def ApplyTestPatternsOp
+ : Op<Transform_Dialect, "apply_patterns.transform.test_patterns",
+ [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+ let summary = "Populates the test rewrite patterns";
+ let description = [{
+ This op populates the test rewrite patterns (`ReplaceWithNewOp`,
+ `EraseOp`) in an `apply_patterns` transform.
+ }];
+ let arguments = (ins);
+ let results = (outs);
+ let assemblyFormat = "attr-dict";
+ let cppNamespace = "::mlir::test";
+}
+
#endif // MLIR_TESTTRANSFORMDIALECTEXTENSION_TD
More information about the Mlir-commits mailing list