[Mlir-commits] [mlir] 5a10f20 - [mlir][transform] Add region to ApplyPatternsOp

Matthias Springer llvmlistbot at llvm.org
Tue Jun 6 00:16:05 PDT 2023


Author: Matthias Springer
Date: 2023-06-06T09:09:41+02:00
New Revision: 5a10f207cc3714195281b6db11ec3f0fe9110228

URL: https://github.com/llvm/llvm-project/commit/5a10f207cc3714195281b6db11ec3f0fe9110228
DIFF: https://github.com/llvm/llvm-project/commit/5a10f207cc3714195281b6db11ec3f0fe9110228.diff

LOG: [mlir][transform] Add region to ApplyPatternsOp

Patterns should be selected by adding ops that implement `PatternDescriptorOpInterface` to the region of `transform.apply_patterns` ops. Such ops can have operands, allowing for pattern parameterization. The existing mechanism of selecting patterns by name from the `PatternRegistry` is deprecated.

Differential Revision: https://reviews.llvm.org/D152167

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
    mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
    mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
    mlir/lib/Dialect/Transform/IR/TransformOps.cpp
    mlir/test/Dialect/Transform/ops-invalid.mlir
    mlir/test/Dialect/Transform/test-pattern-application.mlir
    mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
    mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
index c1c4387d1587e..0a9a6c18ad883 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
@@ -228,4 +228,25 @@ def FindPayloadReplacementOpInterface
   ];
 }
 
+def PatternDescriptorOpInterface : OpInterface<"PatternDescriptorOpInterface"> {
+  let description = [{
+    This interface should be implemented by ops that select patterns of a
+    `transform.apply_patterns` op. It provides a method to populate a rewrite
+    pattern set with patterns.
+  }];
+
+  let cppNamespace = "::mlir::transform";
+
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/[{
+        Populate patterns into the given pattern set.
+      }],
+      /*returnType=*/"void",
+      /*name=*/"populatePatterns",
+      /*arguments=*/(ins "RewritePatternSet &":$patterns)
+    >,
+  ];
+}
+
 #endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORM_INTERFACES_TD

diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index b6740502d2bd1..5b7e6ca8ca14e 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -16,6 +16,7 @@ include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/IR/FunctionInterfaces.td"
 include "mlir/IR/OpAsmInterface.td"
+include "mlir/IR/RegionKindInterface.td"
 include "mlir/IR/SymbolInterfaces.td"
 include "mlir/Dialect/Transform/IR/MatchInterfaces.td"
 include "mlir/Dialect/Transform/IR/TransformAttrs.td"
@@ -128,17 +129,20 @@ def AnnotateOp : TransformDialectOp<"annotate",
 
 def ApplyPatternsOp : TransformDialectOp<"apply_patterns",
     [TransformOpInterface, TransformEachOpTrait,
-     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]
+        # GraphRegionNoTerminator.traits> {
   let summary = "Greedily applies patterns to the body of the targeted op";
   let description = [{
     This transform greedily applies the specified patterns to the body of the
     targeted op until a fixpoint was reached. Patterns are not applied to the
     targeted op itself.
 
-    Only patterns that were registered in the transform dialect's
-    `PatternRegistry` are available. Additional patterns can be registered as
-    part of transform dialect extensions. "canonicalization" is a special set
-    of patterns that refers to all canonicalization patterns of all loaded
+    The patterns that should be applied are specified in the graph region of
+    this op. They must implement the `PatternDescriptorOpInterface`.
+
+    (Deprecated) In addition, patterns that were registered in the transform
+    dialect's `PatternRegistry` are available. "canonicalization" is a special
+    set of patterns that refers to all canonicalization patterns of all loaded
     dialects.
 
     This transform only reads the target handle and modifies the payload. If a
@@ -160,7 +164,9 @@ def ApplyPatternsOp : TransformDialectOp<"apply_patterns",
     TransformHandleTypeInterface:$target, ArrayAttr:$patterns,
     DefaultValuedAttr<BoolAttr, "true">:$fail_on_payload_replacement_not_found);
   let results = (outs);
-  let assemblyFormat = "$patterns `to` $target attr-dict `:` type($target)";
+  let regions = (region MaxSizedRegion<1>:$region);
+
+  let assemblyFormat = "$patterns `to` $target $region attr-dict `:` type($target)";
   let hasVerifier = 1;
 
   let extraClassDeclaration = [{
@@ -171,6 +177,17 @@ def ApplyPatternsOp : TransformDialectOp<"apply_patterns",
   }];
 }
 
+def ApplyCanonicalizationPatternsOp
+    : TransformDialectOp<"apply_patterns.canonicalization",
+        [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let summary = "Populates canonicalization patterns";
+  let description = [{
+    This op populates all canonicalization patterns of all loaded dialects in
+    an `apply_patterns` transform.
+  }];
+  let assemblyFormat = "attr-dict";
+}
+
 def CastOp : TransformDialectOp<"cast",
     [TransformOpInterface, TransformEachOpTrait,
      DeclareOpInterfaceMethods<CastOpInterface>,

diff  --git a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
index 20bed31c34203..d4c4327392c28 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
@@ -27,12 +27,15 @@ void transform::detail::checkImplementsTransformOpInterface(
   RegisteredOperationName opName =
       *RegisteredOperationName::lookup(name, context);
   assert((opName.hasInterface<TransformOpInterface>() ||
+          opName.hasInterface<PatternDescriptorOpInterface>() ||
           opName.hasTrait<OpTrait::IsTerminator>()) &&
          "non-terminator ops injected into the transform dialect must "
-         "implement TransformOpInterface");
-  assert(opName.hasInterface<MemoryEffectOpInterface>() &&
-         "ops injected into the transform dialect must implement "
-         "MemoryEffectsOpInterface");
+         "implement TransformOpInterface or PatternDescriptorOpInterface");
+  if (!opName.hasInterface<PatternDescriptorOpInterface>()) {
+    assert(opName.hasInterface<MemoryEffectOpInterface>() &&
+           "ops injected into the transform dialect must implement "
+           "MemoryEffectsOpInterface");
+  }
 }
 
 void transform::detail::checkImplementsTransformHandleTypeInterface(
@@ -57,16 +60,6 @@ void transform::TransformDialect::initialize() {
 #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
       >();
   initializeTypes();
-
-  // Register all canonicalization patterns.
-  getOrCreateExtraData<transform::PatternRegistry>().registerPatterns(
-      "canonicalization", [](RewritePatternSet &patterns) {
-        MLIRContext *ctx = patterns.getContext();
-        for (Dialect *dialect : ctx->getLoadedDialects())
-          dialect->getCanonicalizationPatterns(patterns);
-        for (RegisteredOperationName op : ctx->getRegisteredOperations())
-          op.getCanonicalizationPatterns(patterns, ctx);
-      });
 }
 
 Type transform::TransformDialect::parseType(DialectAsmParser &parser) const {

diff  --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 987c8489703c5..2dff9a903a261 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -445,6 +445,12 @@ transform::ApplyPatternsOp::applyToOne(Operation *target,
                              ->getExtraData<transform::PatternRegistry>();
   for (Attribute attr : getPatterns())
     registry.populatePatterns(attr.cast<StringAttr>(), patterns);
+  if (!getRegion().empty()) {
+    for (Operation &op : getRegion().front()) {
+      cast<transform::PatternDescriptorOpInterface>(&op).populatePatterns(
+          patterns);
+    }
+  }
 
   // Configure the GreedyPatternRewriteDriver.
   ErrorCheckingTrackingListener listener(state, *this);
@@ -491,6 +497,17 @@ LogicalResult transform::ApplyPatternsOp::verify() {
     if (!registry.hasPatterns(strAttr))
       return emitOpError() << "patterns not registered: " << strAttr.strref();
   }
+  if (!getRegion().empty()) {
+    for (Operation &op : getRegion().front()) {
+      if (!isa<transform::PatternDescriptorOpInterface>(&op)) {
+        InFlightDiagnostic diag = emitOpError()
+                                  << "expected children ops to implement "
+                                     "PatternDescriptorOpInterface";
+        diag.attachNote(op.getLoc()) << "op without interface";
+        return diag;
+      }
+    }
+  }
   return success();
 }
 
@@ -500,6 +517,19 @@ void transform::ApplyPatternsOp::getEffects(
   transform::modifiesPayload(effects);
 }
 
+//===----------------------------------------------------------------------===//
+// ApplyCanonicalizationPatternsOp
+//===----------------------------------------------------------------------===//
+
+void transform::ApplyCanonicalizationPatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  MLIRContext *ctx = patterns.getContext();
+  for (Dialect *dialect : ctx->getLoadedDialects())
+    dialect->getCanonicalizationPatterns(patterns);
+  for (RegisteredOperationName op : ctx->getRegisteredOperations())
+    op.getCanonicalizationPatterns(patterns, ctx);
+}
+
 //===----------------------------------------------------------------------===//
 // CastOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Transform/ops-invalid.mlir b/mlir/test/Dialect/Transform/ops-invalid.mlir
index 6436c7d860c37..427b1b174acbb 100644
--- a/mlir/test/Dialect/Transform/ops-invalid.mlir
+++ b/mlir/test/Dialect/Transform/ops-invalid.mlir
@@ -678,7 +678,7 @@ module attributes { transform.with_named_sequence } {
 transform.sequence failures(propagate) {
 ^bb0(%arg0: !transform.any_op):
   // expected-error @below {{patterns not registered: transform.invalid_pattern_identifier}}
-  transform.apply_patterns ["transform.invalid_pattern_identifier"] to %arg0 : !transform.any_op
+  transform.apply_patterns ["transform.invalid_pattern_identifier"] to %arg0 {} : !transform.any_op
 }
 
 // -----
@@ -686,5 +686,15 @@ transform.sequence failures(propagate) {
 transform.sequence failures(propagate) {
 ^bb0(%arg0: !transform.any_op):
   // expected-error @below {{expected "patterns" to be an array of strings}}
-  transform.apply_patterns [3, 9] to %arg0 : !transform.any_op
+  transform.apply_patterns [3, 9] to %arg0 {} : !transform.any_op
+}
+
+// -----
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+  // expected-error @below {{expected children ops to implement PatternDescriptorOpInterface}}
+  transform.apply_patterns [] to %arg0 {
+    // expected-note @below {{op without interface}}
+    transform.named_sequence @foo()
+  } : !transform.any_op
 }

diff  --git a/mlir/test/Dialect/Transform/test-pattern-application.mlir b/mlir/test/Dialect/Transform/test-pattern-application.mlir
index c51543e6be4c4..55bb083eb9833 100644
--- a/mlir/test/Dialect/Transform/test-pattern-application.mlir
+++ b/mlir/test/Dialect/Transform/test-pattern-application.mlir
@@ -15,7 +15,31 @@ transform.sequence failures(propagate) {
 ^bb1(%arg1: !transform.any_op):
   %0 = transform.structured.match ops{["test.container"]} in %arg1 : (!transform.any_op) -> !transform.any_op
   %1 = transform.structured.match ops{["test.foo"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-  transform.apply_patterns ["transform.test"] to %0 : !transform.any_op
+  transform.apply_patterns ["transform.test"] to %0 {} : !transform.any_op
+  // Add an attribute to %1, which is now mapped to a new op.
+  transform.annotate %1 "annotated" : !transform.any_op
+}
+
+// -----
+
+// CHECK-LABEL: func @update_tracked_op_mapping_region()
+//       CHECK:   "test.container"() ({
+//       CHECK:     %0 = "test.foo"() {annotated} : () -> i32
+//       CHECK:   }) : () -> ()
+func.func @update_tracked_op_mapping_region() {
+  "test.container"() ({
+    %0 = "test.foo"() {replace_with_new_op = "test.foo"} : () -> (i32)
+  }) : () -> ()
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["test.container"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  %1 = transform.structured.match ops{["test.foo"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  transform.apply_patterns [] to %0 {
+    transform.apply_patterns.transform.test_patterns
+  } : !transform.any_op
   // Add an attribute to %1, which is now mapped to a new op.
   transform.annotate %1 "annotated" : !transform.any_op
 }
@@ -36,7 +60,7 @@ transform.sequence failures(propagate) {
   %0 = transform.structured.match ops{["test.container"]} in %arg1 : (!transform.any_op) -> !transform.any_op
   %1 = transform.structured.match ops{["test.foo"]} in %arg1 : (!transform.any_op) -> !transform.any_op
   // expected-error @below {{tracking listener failed to find replacement op}}
-  transform.apply_patterns ["transform.test"] to %0 : !transform.any_op
+  transform.apply_patterns ["transform.test"] to %0 {} : !transform.any_op
   // %1 must be used in some way. If no replacement payload op could be found,
   // an error is thrown only if the handle is not dead.
   transform.annotate %1 "annotated" : !transform.any_op
@@ -60,7 +84,7 @@ transform.sequence failures(propagate) {
   %0 = transform.structured.match ops{["test.container"]} in %arg1 : (!transform.any_op) -> !transform.any_op
   %1 = transform.structured.match ops{["test.foo"]} in %arg1 : (!transform.any_op) -> !transform.any_op
   // No error because %1 is dead.
-  transform.apply_patterns ["transform.test"] to %0 : !transform.any_op
+  transform.apply_patterns ["transform.test"] to %0 {} : !transform.any_op
 }
 
 // -----
@@ -80,7 +104,7 @@ transform.sequence failures(propagate) {
 ^bb1(%arg1: !transform.any_op):
   %0 = transform.structured.match ops{["test.container"]} in %arg1 : (!transform.any_op) -> !transform.any_op
   %1 = transform.structured.match ops{["test.foo"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-  transform.apply_patterns ["transform.test"] to %0 {fail_on_payload_replacement_not_found = false}: !transform.any_op
+  transform.apply_patterns ["transform.test"] to %0 {} {fail_on_payload_replacement_not_found = false}: !transform.any_op
   transform.annotate %1 "annotated" : !transform.any_op
 }
 
@@ -95,8 +119,8 @@ func.func @patterns_apply_only_to_target_body() {
 
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !transform.any_op):
-  %0 = transform.structured.match ops{["test.foo"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-  transform.apply_patterns ["transform.test"] to %0 : !transform.any_op
+  %0 = transform.structured.match ops{["test.foo"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  transform.apply_patterns ["transform.test"] to %0 {} : !transform.any_op
 }
 
 // -----
@@ -118,7 +142,7 @@ transform.sequence failures(propagate) {
   %0 = transform.structured.match ops{["test.container"]} in %arg1 : (!transform.any_op) -> !transform.any_op
   %1 = transform.structured.match ops{["test.erase_op"]} in %arg1 : (!transform.any_op) -> !transform.any_op
   transform.test_print_remark_at_operand %1, "matched op" : !transform.any_op
-  transform.apply_patterns ["transform.test"] to %0 : !transform.any_op
+  transform.apply_patterns ["transform.test"] to %0 {} : !transform.any_op
   transform.test_print_remark_at_operand %1, "op was deleted" : !transform.any_op
 }
 
@@ -138,6 +162,8 @@ transform.sequence failures(propagate) {
 ^bb1(%arg1: !transform.any_op):
   %0 = transform.structured.match ops{["tensor.dim"]} in %arg1 : (!transform.any_op) -> !transform.any_op
   %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-  transform.apply_patterns ["canonicalization"] to %1 : !transform.any_op
+  transform.apply_patterns [] to %1 {
+    transform.apply_patterns.canonicalization
+  } : !transform.any_op
   transform.test_print_remark_at_operand %0, "op was replaced" : !transform.any_op
 }

diff  --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
index 9af4c53cb1c86..9243723947b6d 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -776,7 +776,14 @@ class EraseOp : public RewritePattern {
     return success();
   }
 };
+} // namespace
 
+void mlir::test::ApplyTestPatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  patterns.insert<ReplaceWithNewOp, EraseOp>(patterns.getContext());
+}
+
+namespace {
 void populateTestPatterns(RewritePatternSet &patterns) {
   patterns.insert<ReplaceWithNewOp, EraseOp>(patterns.getContext());
 }

diff  --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
index c02e2d97663d1..f7a6120666b8d 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
@@ -59,7 +59,6 @@ def TestProduceValueHandleToSelfOperand
   let results = (outs TransformValueHandleTypeInterface:$out);
   let assemblyFormat = "$in attr-dict `:` functional-type(operands, results)";
   let cppNamespace = "::mlir::test";
-  
 }
 
 def TestProduceValueHandleToResult
@@ -478,4 +477,13 @@ def TestTrackedRewriteOp
   let cppNamespace = "::mlir::test";
 }
 
+def ApplyTestPatternsOp
+  : Op<Transform_Dialect, "apply_patterns.transform.test_patterns",
+      [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let arguments = (ins);
+  let results = (outs);
+  let assemblyFormat = "attr-dict";
+  let cppNamespace = "::mlir::test";
+}
+
 #endif // MLIR_TESTTRANSFORMDIALECTEXTENSION_TD


        


More information about the Mlir-commits mailing list