[Mlir-commits] [mlir] [mlir][TD] Allow op printing flags as `transform.print` attrs (PR #86846)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Mar 27 11:13:50 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Jakub Kuderski (kuhar)

<details>
<summary>Changes</summary>

Introduce 3 new optional attributes to the `transform.print` op:
* `assume_verified`
* `use_local_scope`
* `skip_regions`

The primary motivation is to allow printing on large inputs that otherwise take forever to print and verify. For the full context, see this IREE issue: https://github.com/openxla/iree/issues/16901.

Also add some tests and fix the op description.

---
Full diff: https://github.com/llvm/llvm-project/pull/86846.diff


4 Files Affected:

- (modified) mlir/include/mlir/Dialect/Transform/IR/TransformOps.td (+16-3) 
- (modified) mlir/lib/Dialect/Transform/IR/TransformOps.cpp (+13-2) 
- (modified) mlir/test/Dialect/Transform/ops.mlir (+6-4) 
- (added) mlir/test/Dialect/Transform/test-interpreter-printing.mlir (+48) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 9caa7632c177de..df42677979a817 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -1098,15 +1098,28 @@ def PrintOp : TransformDialectOp<"print",
      MatchOpInterface]> {
   let summary = "Dump each payload op";
   let description = [{
-    This op dumps each payload op that is associated with the `target` operand
-    to stderr. It also prints the `name` string attribute. If no target is
+    Prints each payload op that is associated with the `target` operand to
+    `stdout`. It also prints the `name` string attribute. If no target is
     specified, the top-level op is dumped.
 
     This op is useful for printf-style debugging.
+
+    Supported printing flag attributes:
+    * `assume_verified` -- skips verification when the unit attribute is
+      specified. This improves performance but may lead to crashes and
+      unexpected behavior when the printed payload op is invalid.
+    * `use_local_scope` -- prints in local scope when the unit attribute is
+      specified. This improves performance but may not be identical to
+      printing within the full module.
+    * `skip_regions` -- does not print regions of operations when the unit
+      attribute is specified.
   }];
 
   let arguments = (ins Optional<TransformHandleTypeInterface>:$target,
-                       OptionalAttr<StrAttr>:$name);
+                       OptionalAttr<StrAttr>:$name,
+                       OptionalAttr<UnitAttr>:$assume_verified,
+                       OptionalAttr<UnitAttr>:$use_local_scope,
+                       OptionalAttr<UnitAttr>:$skip_regions);
   let results = (outs);
 
   let builders = [
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index abd557a508a103..7573125cbe2475 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -19,6 +19,7 @@
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/Dominance.h"
+#include "mlir/IR/OperationSupport.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/Verifier.h"
 #include "mlir/Interfaces/CallInterfaces.h"
@@ -2620,8 +2621,18 @@ transform::PrintOp::apply(transform::TransformRewriter &rewriter,
   }

   llvm::outs() << "]]]\n";
-  for (Operation *target : state.getPayloadOps(getTarget()))
-    llvm::outs() << *target << "\n";
+  // The printing flags are identical for every payload op; build them once.
+  OpPrintingFlags printFlags;
+  if (getAssumeVerified().value_or(false))
+    printFlags.assumeVerified();
+  if (getUseLocalScope().value_or(false))
+    printFlags.useLocalScope();
+  if (getSkipRegions().value_or(false))
+    printFlags.skipRegions();
+  for (Operation *target : state.getPayloadOps(getTarget())) {
+    target->print(llvm::outs(), printFlags);
+    llvm::outs() << "\n";
+  }

   return DiagnosedSilenceableFailure::success();
 }
diff --git a/mlir/test/Dialect/Transform/ops.mlir b/mlir/test/Dialect/Transform/ops.mlir
index a718d6a9e9fd90..ecef7e181e9039 100644
--- a/mlir/test/Dialect/Transform/ops.mlir
+++ b/mlir/test/Dialect/Transform/ops.mlir
@@ -86,16 +86,18 @@ transform.sequence failures(propagate) {
 }
 
 // CHECK: transform.sequence
-// CHECK: print
-// CHECK: print
-// CHECK: print
-// CHECK: print
+// CHECK-COUNT-9: print
 transform.sequence failures(propagate) {
 ^bb0(%arg0: !transform.any_op):
   transform.print %arg0 : !transform.any_op
   transform.print
   transform.print %arg0 {name = "test"} : !transform.any_op
   transform.print {name = "test"}
+  transform.print {name = "test", assume_verified}
+  transform.print %arg0 {assume_verified} : !transform.any_op
+  transform.print %arg0 {use_local_scope} : !transform.any_op
+  transform.print %arg0 {skip_regions} : !transform.any_op
+  transform.print %arg0 {assume_verified, use_local_scope, skip_regions} : !transform.any_op
 }
 
 // CHECK: transform.sequence
diff --git a/mlir/test/Dialect/Transform/test-interpreter-printing.mlir b/mlir/test/Dialect/Transform/test-interpreter-printing.mlir
new file mode 100644
index 00000000000000..34e914bf5a76cf
--- /dev/null
+++ b/mlir/test/Dialect/Transform/test-interpreter-printing.mlir
@@ -0,0 +1,49 @@
+// RUN: mlir-opt %s --transform-interpreter -allow-unregistered-dialect --split-input-file --verify-diagnostics | FileCheck %s
+
+func.func @nested_ops() {
+  "test.foo"() ({
+    "test.foo"() ({
+      "test.qux"() ({
+        "test.baz"() ({
+          "test.bar"() : () -> ()
+        }) : () -> ()
+      }) : () -> ()
+    }) : () -> ()
+  }) : () -> ()
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
+    %baz = transform.structured.match ops{["test.baz"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+
+    // CHECK-LABEL{LITERAL}: [[[ IR printer: START top-level ]]]
+    // CHECK-NEXT:  module {
+    transform.print {name = "START"}
+    // CHECK{LITERAL}: [[[ IR printer: ]]]
+    // CHECK-NEXT:      "test.baz"() ({
+    // CHECK-NEXT:        "test.bar"() : () -> ()
+    // CHECK-NEXT:       }) : () -> ()
+    transform.print %baz : !transform.any_op
+
+    // CHECK{LITERAL}: [[[ IR printer: Baz ]]]
+    // CHECK-NEXT:      "test.baz"() ({
+    transform.print %baz {name = "Baz"} : !transform.any_op
+
+    // CHECK{LITERAL}: [[[ IR printer: No verify ]]]
+    // CHECK-NEXT:      "test.baz"() ({
+    transform.print %baz {name = "No verify", assume_verified} : !transform.any_op
+
+    // CHECK{LITERAL}: [[[ IR printer: Local scope ]]]
+    // CHECK-NEXT:      "test.baz"() ({
+    // CHECK-NEXT:        "test.bar"() : () -> ()
+    transform.print %baz {name = "Local scope", use_local_scope} : !transform.any_op
+
+    // CHECK{LITERAL}: [[[ IR printer: No region ]]]
+    // CHECK-NEXT:      "test.baz"() ({...}) : () -> ()
+    transform.print %baz {name = "No region", skip_regions} : !transform.any_op
+
+    // CHECK-LABEL{LITERAL}: [[[ IR printer: END top-level ]]]
+    transform.print {name = "END"}
+    transform.yield
+  }
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/86846


More information about the Mlir-commits mailing list