[Mlir-commits] [mlir] [mlir][TD] Allow op printing flags as `transform.print` attrs (PR #86846)
Jakub Kuderski
llvmlistbot at llvm.org
Sat Mar 30 20:18:28 PDT 2024
https://github.com/kuhar updated https://github.com/llvm/llvm-project/pull/86846
>From 9b8b00d97c122fcbf3279a168751db1f8fbb83b4 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Wed, 27 Mar 2024 14:09:09 -0400
Subject: [PATCH 1/4] [mlir][TD] Allow op printing flags as `transform.print`
attrs
Introduce 3 new optional attributes to the `transform.print` op:
* `assume_verified`
* `use_local_scope`
* `skip_regions`
The primary motivation is to allow printing on large inputs that
otherwise take forever to print and verify. For the full context, see
this IREE issue: https://github.com/openxla/iree/issues/16901.
Also add some tests and fix the op description.
---
.../mlir/Dialect/Transform/IR/TransformOps.td | 19 ++++++--
.../lib/Dialect/Transform/IR/TransformOps.cpp | 15 +++++-
mlir/test/Dialect/Transform/ops.mlir | 10 ++--
.../Transform/test-interpreter-printing.mlir | 48 +++++++++++++++++++
4 files changed, 83 insertions(+), 9 deletions(-)
create mode 100644 mlir/test/Dialect/Transform/test-interpreter-printing.mlir
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index bf1a8016cd9df6..21c9595860d4c5 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -1098,15 +1098,28 @@ def PrintOp : TransformDialectOp<"print",
MatchOpInterface]> {
let summary = "Dump each payload op";
let description = [{
- This op dumps each payload op that is associated with the `target` operand
- to stderr. It also prints the `name` string attribute. If no target is
+ Prints each payload op that is associated with the `target` operand to
+ `stdout`. It also prints the `name` string attribute. If no target is
specified, the top-level op is dumped.
This op is useful for printf-style debugging.
+
+ Supported printing flag attributes:
+ * `assume_verified` -- skips verification when the unit attribute is
+ specified. This improves performance but may lead to crashes and
+ unexpected behavior when the printed payload op is invalid.
+ * `use_local_scope` -- prints in local scope when the unit attribute is
+ specified. This improves performance, but the output may differ from
+ printing within the full module.
+ * `skip_regions` -- does not print regions of operations when the unit
+ attribute is specified.
}];
let arguments = (ins Optional<TransformHandleTypeInterface>:$target,
- OptionalAttr<StrAttr>:$name);
+ OptionalAttr<StrAttr>:$name,
+ OptionalAttr<UnitAttr>:$assume_verified,
+ OptionalAttr<UnitAttr>:$use_local_scope,
+ OptionalAttr<UnitAttr>:$skip_regions);
let results = (outs);
let builders = [
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 578b2492bbab46..cd26bdcc5a492f 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -19,6 +19,7 @@
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dominance.h"
+#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Interfaces/CallInterfaces.h"
@@ -2633,8 +2634,18 @@ transform::PrintOp::apply(transform::TransformRewriter &rewriter,
}
llvm::outs() << "]]]\n";
- for (Operation *target : state.getPayloadOps(getTarget()))
- llvm::outs() << *target << "\n";
+ for (Operation *target : state.getPayloadOps(getTarget())) {
+ OpPrintingFlags printFlags;
+ if (getAssumeVerified().value_or(false))
+ printFlags.assumeVerified();
+ if (getUseLocalScope().value_or(false))
+ printFlags.skipRegions();
+ if (getSkipRegions().value_or(false))
+ printFlags.skipRegions();
+
+ target->print(llvm::outs(), printFlags);
+ llvm::outs() << "\n";
+ }
return DiagnosedSilenceableFailure::success();
}
diff --git a/mlir/test/Dialect/Transform/ops.mlir b/mlir/test/Dialect/Transform/ops.mlir
index a718d6a9e9fd90..ecef7e181e9039 100644
--- a/mlir/test/Dialect/Transform/ops.mlir
+++ b/mlir/test/Dialect/Transform/ops.mlir
@@ -86,16 +86,18 @@ transform.sequence failures(propagate) {
}
// CHECK: transform.sequence
-// CHECK: print
-// CHECK: print
-// CHECK: print
-// CHECK: print
+// CHECK-COUNT-9: print
transform.sequence failures(propagate) {
^bb0(%arg0: !transform.any_op):
transform.print %arg0 : !transform.any_op
transform.print
transform.print %arg0 {name = "test"} : !transform.any_op
transform.print {name = "test"}
+ transform.print {name = "test", assume_verified}
+ transform.print %arg0 {assume_verified} : !transform.any_op
+ transform.print %arg0 {use_local_scope} : !transform.any_op
+ transform.print %arg0 {skip_regions} : !transform.any_op
+ transform.print %arg0 {assume_verified, use_local_scope, skip_regions} : !transform.any_op
}
// CHECK: transform.sequence
diff --git a/mlir/test/Dialect/Transform/test-interpreter-printing.mlir b/mlir/test/Dialect/Transform/test-interpreter-printing.mlir
new file mode 100644
index 00000000000000..34e914bf5a76cf
--- /dev/null
+++ b/mlir/test/Dialect/Transform/test-interpreter-printing.mlir
@@ -0,0 +1,48 @@
+// RUN: mlir-opt %s --transform-interpreter -allow-unregistered-dialect --split-input-file --verify-diagnostics | FileCheck %s
+
+func.func @nested_ops() {
+ "test.foo"() ({
+ "test.foo"() ({
+ "test.qux"() ({
+ "test.baz"() ({
+ "test.bar"() : () -> ()
+ }) : () -> ()
+ }) : () -> ()
+ }) : () -> ()
+ }) : () -> ()
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
+ %baz = transform.structured.match ops{["test.baz"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+
+ // CHECK-LABEL{LITERAL}: [[[ IR printer: START top-level ]]]
+ // CHECK-NEXT: module {
+ transform.print {name = "START"}
+ // CHECK{LITERAL}: [[[ IR printer: ]]]
+ // CHECK-NEXT: "test.baz"() ({
+ // CHECK-NEXT: "test.bar"() : () -> ()
+ // CHECK-NEXT: }) : () -> ()
+ transform.print %baz : !transform.any_op
+
+ // CHECK{LITERAL}: [[[ IR printer: Baz ]]]
+ // CHECK-NEXT: "test.baz"() ({
+ transform.print %baz {name = "Baz"} : !transform.any_op
+
+ // CHECK{LITERAL}: [[[ IR printer: No verify ]]]
+ // CHECK-NEXT: "test.baz"() ({
+ transform.print %baz {name = "No verify", assume_verified} : !transform.any_op
+
+ // CHECK{LITERAL}: [[[ IR printer: Local scope ]]]
+ // CHECK-NEXT: "test.baz"() ({...}) : () -> ()
+ transform.print %baz {name = "Local scope", use_local_scope} : !transform.any_op
+
+ // CHECK{LITERAL}: [[[ IR printer: No region ]]]
+ // CHECK-NEXT: "test.baz"() ({...}) : () -> ()
+ transform.print %baz {name = "No region", skip_regions} : !transform.any_op
+
+ // CHECK-LABEL{LITERAL}: [[[ IR printer: END top-level ]]]
+ transform.print {name = "END"}
+ transform.yield
+ }
+}
>From 81231765799f79e70c94eea9bc6d20ca6961e27b Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Wed, 27 Mar 2024 14:15:09 -0400
Subject: [PATCH 2/4] Fixup
---
mlir/lib/Dialect/Transform/IR/TransformOps.cpp | 2 +-
mlir/test/Dialect/Transform/test-interpreter-printing.mlir | 3 ++-
2 files changed, 3 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index cd26bdcc5a492f..cfbdd13a517779 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -2639,7 +2639,7 @@ transform::PrintOp::apply(transform::TransformRewriter &rewriter,
if (getAssumeVerified().value_or(false))
printFlags.assumeVerified();
if (getUseLocalScope().value_or(false))
- printFlags.skipRegions();
+ printFlags.useLocalScope();
if (getSkipRegions().value_or(false))
printFlags.skipRegions();
diff --git a/mlir/test/Dialect/Transform/test-interpreter-printing.mlir b/mlir/test/Dialect/Transform/test-interpreter-printing.mlir
index 34e914bf5a76cf..24fc2281c5abfa 100644
--- a/mlir/test/Dialect/Transform/test-interpreter-printing.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter-printing.mlir
@@ -34,7 +34,8 @@ module attributes {transform.with_named_sequence} {
transform.print %baz {name = "No verify", assume_verified} : !transform.any_op
// CHECK{LITERAL}: [[[ IR printer: Local scope ]]]
- // CHECK-NEXT: "test.baz"() ({...}) : () -> ()
+ // CHECK-NEXT: "test.baz"() ({
+ // CHECK-NEXT: "test.bar"() : () -> ()
transform.print %baz {name = "Local scope", use_local_scope} : !transform.any_op
// CHECK{LITERAL}: [[[ IR printer: No region ]]]
>From 4b5257ec8181105f7085f45d0bb87733f566a019 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Wed, 27 Mar 2024 15:44:15 -0400
Subject: [PATCH 3/4] Address comments
---
.../Transform/test-interpreter-printing.mlir | 16 ++++++++++------
1 file changed, 10 insertions(+), 6 deletions(-)
diff --git a/mlir/test/Dialect/Transform/test-interpreter-printing.mlir b/mlir/test/Dialect/Transform/test-interpreter-printing.mlir
index 24fc2281c5abfa..07c14ce8f2db34 100644
--- a/mlir/test/Dialect/Transform/test-interpreter-printing.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter-printing.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --transform-interpreter -allow-unregistered-dialect --split-input-file --verify-diagnostics | FileCheck %s
+// RUN: mlir-opt %s --transform-interpreter -allow-unregistered-dialect | FileCheck %s
func.func @nested_ops() {
"test.foo"() ({
@@ -19,6 +19,7 @@ module attributes {transform.with_named_sequence} {
// CHECK-LABEL{LITERAL}: [[[ IR printer: START top-level ]]]
// CHECK-NEXT: module {
transform.print {name = "START"}
+
// CHECK{LITERAL}: [[[ IR printer: ]]]
// CHECK-NEXT: "test.baz"() ({
// CHECK-NEXT: "test.bar"() : () -> ()
@@ -29,19 +30,22 @@ module attributes {transform.with_named_sequence} {
// CHECK-NEXT: "test.baz"() ({
transform.print %baz {name = "Baz"} : !transform.any_op
+ // CHECK{LITERAL}: [[[ IR printer: No region ]]]
+ // CHECK-NEXT: "test.baz"() ({...}) : () -> ()
+ transform.print %baz {name = "No region", skip_regions} : !transform.any_op
+
+ // This is difficult to test properly, only check that this prints the op
+ // and does not crash.
// CHECK{LITERAL}: [[[ IR printer: No verify ]]]
// CHECK-NEXT: "test.baz"() ({
transform.print %baz {name = "No verify", assume_verified} : !transform.any_op
+ // This is difficult to test properly, only check that this prints the op
+ // and does not crash.
// CHECK{LITERAL}: [[[ IR printer: Local scope ]]]
// CHECK-NEXT: "test.baz"() ({
- // CHECK-NEXT: "test.bar"() : () -> ()
transform.print %baz {name = "Local scope", use_local_scope} : !transform.any_op
- // CHECK{LITERAL}: [[[ IR printer: No region ]]]
- // CHECK-NEXT: "test.baz"() ({...}) : () -> ()
- transform.print %baz {name = "No region", skip_regions} : !transform.any_op
-
// CHECK-LABEL{LITERAL}: [[[ IR printer: END top-level ]]]
transform.print {name = "END"}
transform.yield
>From 83ca76b1acc94e014bdc48c0d232e83ba2229a68 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Sat, 30 Mar 2024 23:13:32 -0400
Subject: [PATCH 4/4] Add local scope and verification tests
---
.../lib/Dialect/Transform/IR/TransformOps.cpp | 20 +++++++------
.../Transform/test-interpreter-printing.mlir | 29 ++++++++++++-------
2 files changed, 29 insertions(+), 20 deletions(-)
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index cfbdd13a517779..043a4ebf894d5f 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -2628,21 +2628,23 @@ transform::PrintOp::apply(transform::TransformRewriter &rewriter,
if (getName().has_value())
llvm::outs() << *getName() << " ";
+ OpPrintingFlags printFlags;
+ if (getAssumeVerified().value_or(false))
+ printFlags.assumeVerified();
+ if (getUseLocalScope().value_or(false))
+ printFlags.useLocalScope();
+ if (getSkipRegions().value_or(false))
+ printFlags.skipRegions();
+
if (!getTarget()) {
- llvm::outs() << "top-level ]]]\n" << *state.getTopLevel() << "\n";
+ llvm::outs() << "top-level ]]]\n";
+ state.getTopLevel()->print(llvm::outs(), printFlags);
+ llvm::outs() << "\n";
return DiagnosedSilenceableFailure::success();
}
llvm::outs() << "]]]\n";
for (Operation *target : state.getPayloadOps(getTarget())) {
- OpPrintingFlags printFlags;
- if (getAssumeVerified().value_or(false))
- printFlags.assumeVerified();
- if (getUseLocalScope().value_or(false))
- printFlags.useLocalScope();
- if (getSkipRegions().value_or(false))
- printFlags.skipRegions();
-
target->print(llvm::outs(), printFlags);
llvm::outs() << "\n";
}
diff --git a/mlir/test/Dialect/Transform/test-interpreter-printing.mlir b/mlir/test/Dialect/Transform/test-interpreter-printing.mlir
index 07c14ce8f2db34..d6a4bd36214084 100644
--- a/mlir/test/Dialect/Transform/test-interpreter-printing.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter-printing.mlir
@@ -1,9 +1,13 @@
-// RUN: mlir-opt %s --transform-interpreter -allow-unregistered-dialect | FileCheck %s
+// RUN: mlir-opt %s --transform-interpreter --allow-unregistered-dialect --verify-diagnostics | FileCheck %s
+
+// RUN: mlir-opt %s --transform-interpreter --allow-unregistered-dialect --verify-diagnostics \
+// RUN: --mlir-print-debuginfo | FileCheck %s --check-prefix=CHECK-LOC
func.func @nested_ops() {
"test.foo"() ({
"test.foo"() ({
"test.qux"() ({
+ // expected-error @below{{fail_to_verify is set}}
"test.baz"() ({
"test.bar"() : () -> ()
}) : () -> ()
@@ -14,12 +18,21 @@ func.func @nested_ops() {
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
- %baz = transform.structured.match ops{["test.baz"]} in %arg0 : (!transform.any_op) -> !transform.any_op
-
// CHECK-LABEL{LITERAL}: [[[ IR printer: START top-level ]]]
// CHECK-NEXT: module {
+ // CHECK-LOC-LABEL{LITERAL}: [[[ IR printer: START top-level ]]]
+ // CHECK-LOC-NEXT: #{{.+}} = loc(
+ // CHECK-LOC-NEXT: module {
transform.print {name = "START"}
+ // CHECK{LITERAL}: [[[ IR printer: Local scope top-level ]]]
+ // CHECK-NEXT: module {
+ // CHECK-LOC{LITERAL}: [[[ IR printer: Local scope top-level ]]]
+ // CHECK-LOC-NEXT: module {
+ transform.print {name = "Local scope", use_local_scope}
+
+ %baz = transform.structured.match ops{["test.baz"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+
// CHECK{LITERAL}: [[[ IR printer: ]]]
// CHECK-NEXT: "test.baz"() ({
// CHECK-NEXT: "test.bar"() : () -> ()
@@ -34,18 +47,12 @@ module attributes {transform.with_named_sequence} {
// CHECK-NEXT: "test.baz"() ({...}) : () -> ()
transform.print %baz {name = "No region", skip_regions} : !transform.any_op
- // This is difficult to test properly, only check that this prints the op
- // and does not crash.
// CHECK{LITERAL}: [[[ IR printer: No verify ]]]
// CHECK-NEXT: "test.baz"() ({
+ // CHECK-NEXT: transform.test_dummy_payload_op {fail_to_verify} : () -> ()
+ transform.test_produce_invalid_ir %baz : !transform.any_op
transform.print %baz {name = "No verify", assume_verified} : !transform.any_op
- // This is difficult to test properly, only check that this prints the op
- // and does not crash.
- // CHECK{LITERAL}: [[[ IR printer: Local scope ]]]
- // CHECK-NEXT: "test.baz"() ({
- transform.print %baz {name = "Local scope", use_local_scope} : !transform.any_op
-
// CHECK-LABEL{LITERAL}: [[[ IR printer: END top-level ]]]
transform.print {name = "END"}
transform.yield
More information about the Mlir-commits
mailing list