[Mlir-commits] [mlir] [mlir][TD] Allow op printing flags as `transform.print` attrs (PR #86846)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Mar 27 11:13:50 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Jakub Kuderski (kuhar)
<details>
<summary>Changes</summary>
Introduce 3 new optional attributes to the `transform.print` op:
* `assume_verified`
* `use_local_scope`
* `skip_regions`
The primary motivation is to allow printing on large inputs that otherwise take forever to print and verify. For the full context, see this IREE issue: https://github.com/openxla/iree/issues/16901.
Also add some tests and fix the op description.
---
Full diff: https://github.com/llvm/llvm-project/pull/86846.diff
4 Files Affected:
- (modified) mlir/include/mlir/Dialect/Transform/IR/TransformOps.td (+16-3)
- (modified) mlir/lib/Dialect/Transform/IR/TransformOps.cpp (+13-2)
- (modified) mlir/test/Dialect/Transform/ops.mlir (+6-4)
- (added) mlir/test/Dialect/Transform/test-interpreter-printing.mlir (+48)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 9caa7632c177de..df42677979a817 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -1098,15 +1098,28 @@ def PrintOp : TransformDialectOp<"print",
MatchOpInterface]> {
let summary = "Dump each payload op";
let description = [{
- This op dumps each payload op that is associated with the `target` operand
- to stderr. It also prints the `name` string attribute. If no target is
+ Prints each payload op that is associated with the `target` operand to
+ `stdout`. It also prints the `name` string attribute. If no target is
specified, the top-level op is dumped.
This op is useful for printf-style debugging.
+
+ Supported printing flag attributes:
+ * `assume_verified` -- skips verification when the unit attribute is
specified. This improves performance but may lead to crashes and
+ unexpected behavior when the printed payload op is invalid.
+ * `use_local_scope` -- prints in local scope when the unit attribute is
+ specified. This improves performance but may not be identical to
+ printing within the full module.
+ * `skip_regions` -- does not print regions of operations when the unit
+ attribute is specified.
}];
let arguments = (ins Optional<TransformHandleTypeInterface>:$target,
- OptionalAttr<StrAttr>:$name);
+ OptionalAttr<StrAttr>:$name,
+ OptionalAttr<UnitAttr>:$assume_verified,
+ OptionalAttr<UnitAttr>:$use_local_scope,
+ OptionalAttr<UnitAttr>:$skip_regions);
let results = (outs);
let builders = [
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index abd557a508a103..7573125cbe2475 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -19,6 +19,7 @@
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dominance.h"
+#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Interfaces/CallInterfaces.h"
@@ -2620,8 +2621,18 @@ transform::PrintOp::apply(transform::TransformRewriter &rewriter,
}
llvm::outs() << "]]]\n";
- for (Operation *target : state.getPayloadOps(getTarget()))
- llvm::outs() << *target << "\n";
+  // Translate the optional unit attributes into printer flags. The flags do
+  // not depend on the payload op, so build them once rather than per iteration.
+  // NOTE(review): only the payload ops printed in the loop below receive these
+  // flags; the top-level (no `target`) print path outside this hunk presumably
+  // does not -- confirm.
+  OpPrintingFlags printFlags;
+  if (getAssumeVerified().value_or(false))
+    printFlags.assumeVerified();
+  // `use_local_scope` must map to local-scope printing, not region elision;
+  // region elision is controlled solely by `skip_regions`.
+  if (getUseLocalScope().value_or(false))
+    printFlags.useLocalScope();
+  if (getSkipRegions().value_or(false))
+    printFlags.skipRegions();
+
+  for (Operation *target : state.getPayloadOps(getTarget())) {
+    target->print(llvm::outs(), printFlags);
+    llvm::outs() << "\n";
+  }
return DiagnosedSilenceableFailure::success();
}
diff --git a/mlir/test/Dialect/Transform/ops.mlir b/mlir/test/Dialect/Transform/ops.mlir
index a718d6a9e9fd90..ecef7e181e9039 100644
--- a/mlir/test/Dialect/Transform/ops.mlir
+++ b/mlir/test/Dialect/Transform/ops.mlir
@@ -86,16 +86,18 @@ transform.sequence failures(propagate) {
}
// CHECK: transform.sequence
-// CHECK: print
-// CHECK: print
-// CHECK: print
-// CHECK: print
+// CHECK-COUNT-9: print
transform.sequence failures(propagate) {
^bb0(%arg0: !transform.any_op):
transform.print %arg0 : !transform.any_op
transform.print
transform.print %arg0 {name = "test"} : !transform.any_op
transform.print {name = "test"}
+ // Round-trip the printing-flag unit attributes, with and without a target,
+ // individually and combined.
+ transform.print {name = "test", assume_verified}
+ transform.print %arg0 {assume_verified} : !transform.any_op
+ transform.print %arg0 {use_local_scope} : !transform.any_op
+ transform.print %arg0 {skip_regions} : !transform.any_op
+ transform.print %arg0 {assume_verified, use_local_scope, skip_regions} : !transform.any_op
}
// CHECK: transform.sequence
diff --git a/mlir/test/Dialect/Transform/test-interpreter-printing.mlir b/mlir/test/Dialect/Transform/test-interpreter-printing.mlir
new file mode 100644
index 00000000000000..34e914bf5a76cf
--- /dev/null
+++ b/mlir/test/Dialect/Transform/test-interpreter-printing.mlir
@@ -0,0 +1,48 @@
+// RUN: mlir-opt %s --transform-interpreter -allow-unregistered-dialect --split-input-file --verify-diagnostics | FileCheck %s
+
+// Checks the output of `transform.print` with and without its printing-flag
+// attributes. The payload is a chain of nested single-region ops, so whether a
+// region body is printed or elided is observable in the output.
+
+func.func @nested_ops() {
+  "test.foo"() ({
+    "test.foo"() ({
+      "test.qux"() ({
+        "test.baz"() ({
+          "test.bar"() : () -> ()
+        }) : () -> ()
+      }) : () -> ()
+    }) : () -> ()
+  }) : () -> ()
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
+    %baz = transform.structured.match ops{["test.baz"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+
+    // CHECK-LABEL{LITERAL}: [[[ IR printer: START top-level ]]]
+    // CHECK-NEXT: module {
+    transform.print {name = "START"}
+    // CHECK{LITERAL}: [[[ IR printer: ]]]
+    // CHECK-NEXT: "test.baz"() ({
+    // CHECK-NEXT:   "test.bar"() : () -> ()
+    // CHECK-NEXT: }) : () -> ()
+    transform.print %baz : !transform.any_op
+
+    // FileCheck matches substrings, so `"test.baz"() ({` alone would also
+    // accept the elided `({...})` form; the "test.bar" line pins the body.
+    // CHECK{LITERAL}: [[[ IR printer: Baz ]]]
+    // CHECK-NEXT: "test.baz"() ({
+    // CHECK-NEXT:   "test.bar"() : () -> ()
+    transform.print %baz {name = "Baz"} : !transform.any_op
+
+    // CHECK{LITERAL}: [[[ IR printer: No verify ]]]
+    // CHECK-NEXT: "test.baz"() ({
+    // CHECK-NEXT:   "test.bar"() : () -> ()
+    transform.print %baz {name = "No verify", assume_verified} : !transform.any_op
+
+    // `use_local_scope` must not elide regions; only `skip_regions` does.
+    // CHECK{LITERAL}: [[[ IR printer: Local scope ]]]
+    // CHECK-NEXT: "test.baz"() ({
+    // CHECK-NEXT:   "test.bar"() : () -> ()
+    transform.print %baz {name = "Local scope", use_local_scope} : !transform.any_op
+
+    // CHECK{LITERAL}: [[[ IR printer: No region ]]]
+    // CHECK-NEXT: "test.baz"() ({...}) : () -> ()
+    transform.print %baz {name = "No region", skip_regions} : !transform.any_op
+
+    // CHECK-LABEL{LITERAL}: [[[ IR printer: END top-level ]]]
+    transform.print {name = "END"}
+    transform.yield
+  }
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/86846
More information about the Mlir-commits
mailing list