[Mlir-commits] [mlir] [mlir] transform.apply_patterns support more config options (PR #88484)
Oleksandr Alex Zinenko
llvmlistbot at llvm.org
Tue Apr 16 03:05:37 PDT 2024
https://github.com/ftynse updated https://github.com/llvm/llvm-project/pull/88484
>From 26505d67ef588aa9446cf53d80707c187506a6d9 Mon Sep 17 00:00:00 2001
From: Alex Zinenko <zinenko at google.com>
Date: Fri, 12 Apr 2024 08:00:03 +0000
Subject: [PATCH] [mlir] transform.apply_patterns support more config options
The greedy rewrite driver has options to limit the number of iterations and
the number of rewrites it applies. Expose them via the corresponding transform op.
---
.../mlir/Dialect/Transform/IR/TransformOps.td | 5 +++-
.../lib/Dialect/Transform/IR/TransformOps.cpp | 7 +++++
.../Transform/test-pattern-application.mlir | 30 +++++++++++++++++++
3 files changed, 41 insertions(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 21c9595860d4c5..fbac1ffb621fd2 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -331,7 +331,10 @@ def ApplyPatternsOp : TransformDialectOp<"apply_patterns",
}];
let arguments = (ins
- TransformHandleTypeInterface:$target, UnitAttr:$apply_cse);
+ TransformHandleTypeInterface:$target,
+ UnitAttr:$apply_cse,
+ DefaultValuedAttr<I64Attr, "static_cast<uint64_t>(-1)">:$max_iterations,
+ DefaultValuedAttr<I64Attr, "static_cast<uint64_t>(-1)">:$max_num_rewrites);
let results = (outs);
let regions = (region MaxSizedRegion<1>:$patterns);
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index dc19022219e5b2..53f958caa0bdb7 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -396,6 +396,13 @@ DiagnosedSilenceableFailure transform::ApplyPatternsOp::applyToOne(
static_cast<RewriterBase::Listener *>(rewriter.getListener());
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
+ config.maxIterations = getMaxIterations() == static_cast<uint64_t>(-1)
+ ? GreedyRewriteConfig::kNoLimit
+ : getMaxIterations();
+ config.maxNumRewrites = getMaxNumRewrites() == static_cast<uint64_t>(-1)
+ ? GreedyRewriteConfig::kNoLimit
+ : getMaxNumRewrites();
+
// Apply patterns and CSE repetitively until a fixpoint is reached. If no CSE
// was requested, apply the greedy pattern rewrite only once. (The greedy
// pattern rewrite driver already iterates to a fixpoint internally.)
diff --git a/mlir/test/Dialect/Transform/test-pattern-application.mlir b/mlir/test/Dialect/Transform/test-pattern-application.mlir
index fa8a555af92188..f78b4b6f6798c5 100644
--- a/mlir/test/Dialect/Transform/test-pattern-application.mlir
+++ b/mlir/test/Dialect/Transform/test-pattern-application.mlir
@@ -26,6 +26,36 @@ module attributes {transform.with_named_sequence} {
// -----
+// CHECK-LABEL: @limited_updates
+func.func @limited_updates() {
+  "test.container"() ({
+    // max_num_rewrites = 1 in the transform script: only one of the two ops gets replaced.
+    // CHECK: "test.foo"() {replace_with_new_op = "test.foo"}
+    // CHECK: "test.foo"() : ()
+    %0 = "test.foo"() {replace_with_new_op = "test.foo"} : () -> (i32)
+    %1 = "test.foo"() {replace_with_new_op = "test.foo"} : () -> (i32)
+  }) : () -> ()
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
+    // Pattern application fails once the rewrite limit is hit; wrap it in a
+    // sequence with failures(suppress) to silence the resulting error.
+    transform.sequence %arg0 : !transform.any_op failures(suppress) {
+    ^bb0(%arg1: !transform.any_op):
+      %0 = transform.structured.match ops{["test.container"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+      %1 = transform.structured.match ops{["test.foo"]} in %arg1 : (!transform.any_op) -> !transform.any_op  // NOTE(review): %1 is never used below.
+      transform.apply_patterns to %0 {
+        transform.apply_patterns.transform.test_patterns
+      } {max_num_rewrites = 1} : !transform.any_op
+    }
+    transform.yield
+  }
+}
+
+// -----
+
func.func @replacement_op_not_found() {
"test.container"() ({
// expected-note @below {{[0] replaced op}}
More information about the Mlir-commits
mailing list