[Mlir-commits] [mlir] 37b26bf - [mlir] transform.apply_patterns support more config options (#88484)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Apr 17 05:24:54 PDT 2024


Author: Oleksandr "Alex" Zinenko
Date: 2024-04-17T14:24:51+02:00
New Revision: 37b26bf48b9894ed0c13fd1aede23472660fb75e

URL: https://github.com/llvm/llvm-project/commit/37b26bf48b9894ed0c13fd1aede23472660fb75e
DIFF: https://github.com/llvm/llvm-project/commit/37b26bf48b9894ed0c13fd1aede23472660fb75e.diff

LOG: [mlir] transform.apply_patterns support more config options (#88484)

The greedy rewrite driver has options to control the number of iterations and
rewrites it applies. Expose those via the corresponding transform op.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
    mlir/lib/Dialect/Transform/IR/TransformOps.cpp
    mlir/test/Dialect/Transform/test-pattern-application.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 21c9595860d4c5..fbac1ffb621fd2 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -331,7 +331,10 @@ def ApplyPatternsOp : TransformDialectOp<"apply_patterns",
   }];
 
   let arguments = (ins
-    TransformHandleTypeInterface:$target, UnitAttr:$apply_cse);
+    TransformHandleTypeInterface:$target,
+    UnitAttr:$apply_cse,
+    DefaultValuedAttr<I64Attr, "static_cast<uint64_t>(-1)">:$max_iterations,
+    DefaultValuedAttr<I64Attr, "static_cast<uint64_t>(-1)">:$max_num_rewrites);
   let results = (outs);
   let regions = (region MaxSizedRegion<1>:$patterns);
 

diff  --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index dc19022219e5b2..53f958caa0bdb7 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -396,6 +396,13 @@ DiagnosedSilenceableFailure transform::ApplyPatternsOp::applyToOne(
       static_cast<RewriterBase::Listener *>(rewriter.getListener());
   FrozenRewritePatternSet frozenPatterns(std::move(patterns));
 
+  config.maxIterations = getMaxIterations() == static_cast<uint64_t>(-1)
+                             ? GreedyRewriteConfig::kNoLimit
+                             : getMaxIterations();
+  config.maxNumRewrites = getMaxNumRewrites() == static_cast<uint64_t>(-1)
+                              ? GreedyRewriteConfig::kNoLimit
+                              : getMaxNumRewrites();
+
   // Apply patterns and CSE repetitively until a fixpoint is reached. If no CSE
   // was requested, apply the greedy pattern rewrite only once. (The greedy
   // pattern rewrite driver already iterates to a fixpoint internally.)

diff  --git a/mlir/test/Dialect/Transform/test-pattern-application.mlir b/mlir/test/Dialect/Transform/test-pattern-application.mlir
index fa8a555af92188..f78b4b6f6798c5 100644
--- a/mlir/test/Dialect/Transform/test-pattern-application.mlir
+++ b/mlir/test/Dialect/Transform/test-pattern-application.mlir
@@ -26,6 +26,36 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// CHECK-LABEL: @limited_updates
+func.func @limited_updates() {
+  "test.container"() ({
+    // Only one is replaced.
+    // CHECK: "test.foo"() {replace_with_new_op = "test.foo"}
+    // CHECK: "test.foo"() : ()
+    %0 = "test.foo"() {replace_with_new_op = "test.foo"} : () -> (i32)
+    %1 = "test.foo"() {replace_with_new_op = "test.foo"} : () -> (i32)
+  }) : () -> ()
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
+    // Pattern application will fail because of the upper limit, wrap in
+    // sequence to suppress the error message.
+    transform.sequence %arg0 : !transform.any_op failures(suppress) {
+    ^bb0(%arg1: !transform.any_op):
+      %0 = transform.structured.match ops{["test.container"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+      %1 = transform.structured.match ops{["test.foo"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+      transform.apply_patterns to %0 {
+        transform.apply_patterns.transform.test_patterns
+      }  {max_num_rewrites = 1} : !transform.any_op
+    }
+    transform.yield
+  }
+}
+
+// -----
+
 func.func @replacement_op_not_found() {
   "test.container"() ({
     // expected-note @below {{[0] replaced op}}


        


More information about the Mlir-commits mailing list