[Mlir-commits] [mlir] a8cfa7c - [mlir][TD] Allow op printing flags as `transform.print` attrs (#86846)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Apr 1 09:32:27 PDT 2024
Author: Jakub Kuderski
Date: 2024-04-01T12:32:23-04:00
New Revision: a8cfa7cbdf6cc1a94ed25c90897d2e031f77a5a9
URL: https://github.com/llvm/llvm-project/commit/a8cfa7cbdf6cc1a94ed25c90897d2e031f77a5a9
DIFF: https://github.com/llvm/llvm-project/commit/a8cfa7cbdf6cc1a94ed25c90897d2e031f77a5a9.diff
LOG: [mlir][TD] Allow op printing flags as `transform.print` attrs (#86846)
Introduce 3 new optional attributes to the `transform.print` op:
* `assume_verified`
* `use_local_scope`
* `skip_regions`
The primary motivation is to allow printing large inputs that would
otherwise take a prohibitively long time to print and verify. For the
full context, see this IREE issue: https://github.com/openxla/iree/issues/16901.
Also add some tests and fix the op description.
Added:
mlir/test/Dialect/Transform/test-interpreter-printing.mlir
Modified:
mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
mlir/lib/Dialect/Transform/IR/TransformOps.cpp
mlir/test/Dialect/Transform/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index bf1a8016cd9df6..21c9595860d4c5 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -1098,15 +1098,28 @@ def PrintOp : TransformDialectOp<"print",
MatchOpInterface]> {
let summary = "Dump each payload op";
let description = [{
- This op dumps each payload op that is associated with the `target` operand
- to stderr. It also prints the `name` string attribute. If no target is
+ Prints each payload op that is associated with the `target` operand to
+ `stdout`. It also prints the `name` string attribute. If no target is
specified, the top-level op is dumped.
This op is useful for printf-style debugging.
+
+ Supported printing flag attributes:
+ * `assume_verified` -- skips verification when the unit attribute is
specified. This improves performance but may lead to crashes and
+ unexpected behavior when the printed payload op is invalid.
+ * `use_local_scope` -- prints in local scope when the unit attribute is
+ specified. This improves performance but may not be identical to
+ printing within the full module.
+ * `skip_regions` -- does not print regions of operations when the unit
+ attribute is specified.
}];
let arguments = (ins Optional<TransformHandleTypeInterface>:$target,
- OptionalAttr<StrAttr>:$name);
+ OptionalAttr<StrAttr>:$name,
+ OptionalAttr<UnitAttr>:$assume_verified,
+ OptionalAttr<UnitAttr>:$use_local_scope,
+ OptionalAttr<UnitAttr>:$skip_regions);
let results = (outs);
let builders = [
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index c8d06ba157b904..dc19022219e5b2 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -19,6 +19,7 @@
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dominance.h"
+#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Interfaces/CallInterfaces.h"
@@ -2627,14 +2628,26 @@ transform::PrintOp::apply(transform::TransformRewriter &rewriter,
if (getName().has_value())
llvm::outs() << *getName() << " ";
+ OpPrintingFlags printFlags;
+ if (getAssumeVerified().value_or(false))
+ printFlags.assumeVerified();
+ if (getUseLocalScope().value_or(false))
+ printFlags.useLocalScope();
+ if (getSkipRegions().value_or(false))
+ printFlags.skipRegions();
+
if (!getTarget()) {
- llvm::outs() << "top-level ]]]\n" << *state.getTopLevel() << "\n";
+ llvm::outs() << "top-level ]]]\n";
+ state.getTopLevel()->print(llvm::outs(), printFlags);
+ llvm::outs() << "\n";
return DiagnosedSilenceableFailure::success();
}
llvm::outs() << "]]]\n";
- for (Operation *target : state.getPayloadOps(getTarget()))
- llvm::outs() << *target << "\n";
+ for (Operation *target : state.getPayloadOps(getTarget())) {
+ target->print(llvm::outs(), printFlags);
+ llvm::outs() << "\n";
+ }
return DiagnosedSilenceableFailure::success();
}
diff --git a/mlir/test/Dialect/Transform/ops.mlir b/mlir/test/Dialect/Transform/ops.mlir
index a718d6a9e9fd90..ecef7e181e9039 100644
--- a/mlir/test/Dialect/Transform/ops.mlir
+++ b/mlir/test/Dialect/Transform/ops.mlir
@@ -86,16 +86,18 @@ transform.sequence failures(propagate) {
}
// CHECK: transform.sequence
-// CHECK: print
-// CHECK: print
-// CHECK: print
-// CHECK: print
+// CHECK-COUNT-9: print
transform.sequence failures(propagate) {
^bb0(%arg0: !transform.any_op):
transform.print %arg0 : !transform.any_op
transform.print
transform.print %arg0 {name = "test"} : !transform.any_op
transform.print {name = "test"}
+ transform.print {name = "test", assume_verified}
+ transform.print %arg0 {assume_verified} : !transform.any_op
+ transform.print %arg0 {use_local_scope} : !transform.any_op
+ transform.print %arg0 {skip_regions} : !transform.any_op
+ transform.print %arg0 {assume_verified, use_local_scope, skip_regions} : !transform.any_op
}
// CHECK: transform.sequence
diff --git a/mlir/test/Dialect/Transform/test-interpreter-printing.mlir b/mlir/test/Dialect/Transform/test-interpreter-printing.mlir
new file mode 100644
index 00000000000000..a54c83d2b249eb
--- /dev/null
+++ b/mlir/test/Dialect/Transform/test-interpreter-printing.mlir
@@ -0,0 +1,56 @@
+// RUN: mlir-opt %s --transform-interpreter --allow-unregistered-dialect --verify-diagnostics | FileCheck %s
+
+// RUN: mlir-opt %s --transform-interpreter --allow-unregistered-dialect --verify-diagnostics \
+// RUN: --mlir-print-debuginfo | FileCheck %s --check-prefix=CHECK-LOC
+
+func.func @nested_ops() {
+ "test.qux"() ({
+ // expected-error @below{{fail_to_verify is set}}
+ "test.baz"() ({
+ "test.bar"() : () -> ()
+ }) : () -> ()
+ }) : () -> ()
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
+ // CHECK-LABEL{LITERAL}: [[[ IR printer: START top-level ]]]
+ // CHECK-NEXT: module {
+ // CHECK-LOC-LABEL{LITERAL}: [[[ IR printer: START top-level ]]]
+ // CHECK-LOC-NEXT: #{{.+}} = loc(
+ // CHECK-LOC-NEXT: module {
+ transform.print {name = "START"}
+
+ // CHECK{LITERAL}: [[[ IR printer: Local scope top-level ]]]
+ // CHECK-NEXT: module {
+ // CHECK-LOC{LITERAL}: [[[ IR printer: Local scope top-level ]]]
+ // CHECK-LOC-NEXT: module {
+ transform.print {name = "Local scope", use_local_scope}
+
+ %baz = transform.structured.match ops{["test.baz"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+
+ // CHECK{LITERAL}: [[[ IR printer: ]]]
+ // CHECK-NEXT: "test.baz"() ({
+ // CHECK-NEXT: "test.bar"() : () -> ()
+ // CHECK-NEXT: }) : () -> ()
+ transform.print %baz : !transform.any_op
+
+ // CHECK{LITERAL}: [[[ IR printer: Baz ]]]
+ // CHECK-NEXT: "test.baz"() ({
+ transform.print %baz {name = "Baz"} : !transform.any_op
+
+ // CHECK{LITERAL}: [[[ IR printer: No region ]]]
+ // CHECK-NEXT: "test.baz"() ({...}) : () -> ()
+ transform.print %baz {name = "No region", skip_regions} : !transform.any_op
+
+ // CHECK{LITERAL}: [[[ IR printer: No verify ]]]
+ // CHECK-NEXT: "test.baz"() ({
+ // CHECK-NEXT: transform.test_dummy_payload_op {fail_to_verify} : () -> ()
+ transform.test_produce_invalid_ir %baz : !transform.any_op
+ transform.print %baz {name = "No verify", assume_verified} : !transform.any_op
+
+ // CHECK-LABEL{LITERAL}: [[[ IR printer: END top-level ]]]
+ transform.print {name = "END"}
+ transform.yield
+ }
+}
More information about the Mlir-commits
mailing list