[Mlir-commits] [mlir] [mlir][TD] Allow op printing flags as `transform.print` attrs (PR #86846)

Jakub Kuderski llvmlistbot at llvm.org
Wed Mar 27 12:44:32 PDT 2024


https://github.com/kuhar updated https://github.com/llvm/llvm-project/pull/86846

>From ce7913c809644db7afcfd78c31bf34c84dc69075 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Wed, 27 Mar 2024 14:09:09 -0400
Subject: [PATCH 1/3] [mlir][TD] Allow op printing flags as `transform.print`
 attrs

Introduce 3 new optional attributes to the `transform.print` ops:
* `assume_verified`
* `use_local_scope`
* `skip_regions`

The primary motivation is to allow printing of large inputs that
otherwise take forever to print and verify. For the full context, see
this IREE issue: https://github.com/openxla/iree/issues/16901.

Also add some tests and fix the op description.
---
 .../mlir/Dialect/Transform/IR/TransformOps.td | 19 ++++++--
 .../lib/Dialect/Transform/IR/TransformOps.cpp | 15 +++++-
 mlir/test/Dialect/Transform/ops.mlir          | 10 ++--
 .../Transform/test-interpreter-printing.mlir  | 48 +++++++++++++++++++
 4 files changed, 83 insertions(+), 9 deletions(-)
 create mode 100644 mlir/test/Dialect/Transform/test-interpreter-printing.mlir

diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 9caa7632c177de..df42677979a817 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -1098,15 +1098,28 @@ def PrintOp : TransformDialectOp<"print",
      MatchOpInterface]> {
   let summary = "Dump each payload op";
   let description = [{
-    This op dumps each payload op that is associated with the `target` operand
-    to stderr. It also prints the `name` string attribute. If no target is
+    Prints each payload op that is associated with the `target` operand to
+    `stdout`. It also prints the `name` string attribute. If no target is
     specified, the top-level op is dumped.
 
     This op is useful for printf-style debugging.
+
+    Supported printing flag attributes:
+    * `assume_verified` -- skips verification when the unit attribute is
+      specified. This improves performance but may lead to crashes and
+      unexpected behavior when the printed payload op is invalid.
+    * `use_local_scope` -- prints in local scope when the unit attribute is
+      specified. This improves performance but may not be identical to
+      printing within the full module.
+    * `skip_regions` -- does not print regions of operations when the unit
+      attribute is specified.
   }];
 
   let arguments = (ins Optional<TransformHandleTypeInterface>:$target,
-                       OptionalAttr<StrAttr>:$name);
+                       OptionalAttr<StrAttr>:$name,
+                       OptionalAttr<UnitAttr>:$assume_verified,
+                       OptionalAttr<UnitAttr>:$use_local_scope,
+                       OptionalAttr<UnitAttr>:$skip_regions);
   let results = (outs);
 
   let builders = [
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index abd557a508a103..7573125cbe2475 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -19,6 +19,7 @@
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/Dominance.h"
+#include "mlir/IR/OperationSupport.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/Verifier.h"
 #include "mlir/Interfaces/CallInterfaces.h"
@@ -2620,8 +2621,18 @@ transform::PrintOp::apply(transform::TransformRewriter &rewriter,
   }
 
   llvm::outs() << "]]]\n";
-  for (Operation *target : state.getPayloadOps(getTarget()))
-    llvm::outs() << *target << "\n";
+  for (Operation *target : state.getPayloadOps(getTarget())) {
+    OpPrintingFlags printFlags;
+    if (getAssumeVerified().value_or(false))
+      printFlags.assumeVerified();
+    if (getUseLocalScope().value_or(false))
+      printFlags.skipRegions();
+    if (getSkipRegions().value_or(false))
+      printFlags.skipRegions();
+
+    target->print(llvm::outs(), printFlags);
+    llvm::outs() << "\n";
+  }
 
   return DiagnosedSilenceableFailure::success();
 }
diff --git a/mlir/test/Dialect/Transform/ops.mlir b/mlir/test/Dialect/Transform/ops.mlir
index a718d6a9e9fd90..ecef7e181e9039 100644
--- a/mlir/test/Dialect/Transform/ops.mlir
+++ b/mlir/test/Dialect/Transform/ops.mlir
@@ -86,16 +86,18 @@ transform.sequence failures(propagate) {
 }
 
 // CHECK: transform.sequence
-// CHECK: print
-// CHECK: print
-// CHECK: print
-// CHECK: print
+// CHECK-COUNT-9: print
 transform.sequence failures(propagate) {
 ^bb0(%arg0: !transform.any_op):
   transform.print %arg0 : !transform.any_op
   transform.print
   transform.print %arg0 {name = "test"} : !transform.any_op
   transform.print {name = "test"}
+  transform.print {name = "test", assume_verified}
+  transform.print %arg0 {assume_verified} : !transform.any_op
+  transform.print %arg0 {use_local_scope} : !transform.any_op
+  transform.print %arg0 {skip_regions} : !transform.any_op
+  transform.print %arg0 {assume_verified, use_local_scope, skip_regions} : !transform.any_op
 }
 
 // CHECK: transform.sequence
diff --git a/mlir/test/Dialect/Transform/test-interpreter-printing.mlir b/mlir/test/Dialect/Transform/test-interpreter-printing.mlir
new file mode 100644
index 00000000000000..34e914bf5a76cf
--- /dev/null
+++ b/mlir/test/Dialect/Transform/test-interpreter-printing.mlir
@@ -0,0 +1,48 @@
+// RUN: mlir-opt %s --transform-interpreter -allow-unregistered-dialect --split-input-file --verify-diagnostics | FileCheck %s
+
+func.func @nested_ops() {
+  "test.foo"() ({
+    "test.foo"() ({
+      "test.qux"() ({
+        "test.baz"() ({
+          "test.bar"() : () -> ()
+        }) : () -> ()
+      }) : () -> ()
+    }) : () -> ()
+  }) : () -> ()
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
+    %baz = transform.structured.match ops{["test.baz"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+
+    // CHECK-LABEL{LITERAL}: [[[ IR printer: START top-level ]]]
+    // CHECK-NEXT:  module {
+    transform.print {name = "START"}
+    // CHECK{LITERAL}: [[[ IR printer: ]]]
+    // CHECK-NEXT:      "test.baz"() ({
+    // CHECK-NEXT:        "test.bar"() : () -> ()
+    // CHECK-NEXT:       }) : () -> ()
+    transform.print %baz : !transform.any_op
+
+    // CHECK{LITERAL}: [[[ IR printer: Baz ]]]
+    // CHECK-NEXT:      "test.baz"() ({
+    transform.print %baz {name = "Baz"} : !transform.any_op
+
+    // CHECK{LITERAL}: [[[ IR printer: No verify ]]]
+    // CHECK-NEXT:      "test.baz"() ({
+    transform.print %baz {name = "No verify", assume_verified} : !transform.any_op
+
+    // CHECK{LITERAL}: [[[ IR printer: Local scope ]]]
+    // CHECK-NEXT:      "test.baz"() ({...}) : () -> ()
+    transform.print %baz {name = "Local scope", use_local_scope} : !transform.any_op
+
+    // CHECK{LITERAL}: [[[ IR printer: No region ]]]
+    // CHECK-NEXT:      "test.baz"() ({...}) : () -> ()
+    transform.print %baz {name = "No region", skip_regions} : !transform.any_op
+
+    // CHECK-LABEL{LITERAL}: [[[ IR printer: END top-level ]]]
+    transform.print {name = "END"}
+    transform.yield
+  }
+}

>From deaa5b24600f33788e777e4fdc253bc6bdfca1de Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Wed, 27 Mar 2024 14:15:09 -0400
Subject: [PATCH 2/3] Fixup

---
 mlir/lib/Dialect/Transform/IR/TransformOps.cpp             | 2 +-
 mlir/test/Dialect/Transform/test-interpreter-printing.mlir | 3 ++-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 7573125cbe2475..4eb4cf366c957f 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -2626,7 +2626,7 @@ transform::PrintOp::apply(transform::TransformRewriter &rewriter,
     if (getAssumeVerified().value_or(false))
       printFlags.assumeVerified();
     if (getUseLocalScope().value_or(false))
-      printFlags.skipRegions();
+      printFlags.useLocalScope();
     if (getSkipRegions().value_or(false))
       printFlags.skipRegions();
 
diff --git a/mlir/test/Dialect/Transform/test-interpreter-printing.mlir b/mlir/test/Dialect/Transform/test-interpreter-printing.mlir
index 34e914bf5a76cf..24fc2281c5abfa 100644
--- a/mlir/test/Dialect/Transform/test-interpreter-printing.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter-printing.mlir
@@ -34,7 +34,8 @@ module attributes {transform.with_named_sequence} {
     transform.print %baz {name = "No verify", assume_verified} : !transform.any_op
 
     // CHECK{LITERAL}: [[[ IR printer: Local scope ]]]
-    // CHECK-NEXT:      "test.baz"() ({...}) : () -> ()
+    // CHECK-NEXT:      "test.baz"() ({
+    // CHECK-NEXT:        "test.bar"() : () -> ()
     transform.print %baz {name = "Local scope", use_local_scope} : !transform.any_op
 
     // CHECK{LITERAL}: [[[ IR printer: No region ]]]

>From 9935c4ed7c888b70a7e6e5fd5c40cd100714990d Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Wed, 27 Mar 2024 15:44:15 -0400
Subject: [PATCH 3/3] Address comments

---
 .../Transform/test-interpreter-printing.mlir     | 16 ++++++++++------
 1 file changed, 10 insertions(+), 6 deletions(-)

diff --git a/mlir/test/Dialect/Transform/test-interpreter-printing.mlir b/mlir/test/Dialect/Transform/test-interpreter-printing.mlir
index 24fc2281c5abfa..07c14ce8f2db34 100644
--- a/mlir/test/Dialect/Transform/test-interpreter-printing.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter-printing.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --transform-interpreter -allow-unregistered-dialect --split-input-file --verify-diagnostics | FileCheck %s
+// RUN: mlir-opt %s --transform-interpreter -allow-unregistered-dialect | FileCheck %s
 
 func.func @nested_ops() {
   "test.foo"() ({
@@ -19,6 +19,7 @@ module attributes {transform.with_named_sequence} {
     // CHECK-LABEL{LITERAL}: [[[ IR printer: START top-level ]]]
     // CHECK-NEXT:  module {
     transform.print {name = "START"}
+
     // CHECK{LITERAL}: [[[ IR printer: ]]]
     // CHECK-NEXT:      "test.baz"() ({
     // CHECK-NEXT:        "test.bar"() : () -> ()
@@ -29,19 +30,22 @@ module attributes {transform.with_named_sequence} {
     // CHECK-NEXT:      "test.baz"() ({
     transform.print %baz {name = "Baz"} : !transform.any_op
 
+    // CHECK{LITERAL}: [[[ IR printer: No region ]]]
+    // CHECK-NEXT:      "test.baz"() ({...}) : () -> ()
+    transform.print %baz {name = "No region", skip_regions} : !transform.any_op
+
+    // This is difficult to test properly; only check that this prints the op
+    // and does not crash.
     // CHECK{LITERAL}: [[[ IR printer: No verify ]]]
     // CHECK-NEXT:      "test.baz"() ({
     transform.print %baz {name = "No verify", assume_verified} : !transform.any_op
 
+    // This is difficult to test properly; only check that this prints the op
+    // and does not crash.
     // CHECK{LITERAL}: [[[ IR printer: Local scope ]]]
     // CHECK-NEXT:      "test.baz"() ({
-    // CHECK-NEXT:        "test.bar"() : () -> ()
     transform.print %baz {name = "Local scope", use_local_scope} : !transform.any_op
 
-    // CHECK{LITERAL}: [[[ IR printer: No region ]]]
-    // CHECK-NEXT:      "test.baz"() ({...}) : () -> ()
-    transform.print %baz {name = "No region", skip_regions} : !transform.any_op
-
     // CHECK-LABEL{LITERAL}: [[[ IR printer: END top-level ]]]
     transform.print {name = "END"}
     transform.yield



More information about the Mlir-commits mailing list