[Mlir-commits] [mlir] [mlir][TD] Allow op printing flags as `transform.print` attrs (PR #86846)

Jakub Kuderski llvmlistbot at llvm.org
Sat Mar 30 20:18:28 PDT 2024


https://github.com/kuhar updated https://github.com/llvm/llvm-project/pull/86846

>From 9b8b00d97c122fcbf3279a168751db1f8fbb83b4 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Wed, 27 Mar 2024 14:09:09 -0400
Subject: [PATCH 1/4] [mlir][TD] Allow op printing flags as `transform.print`
 attrs

Introduce 3 new optional attributes to the `transform.print` ops:
* `assume_verified`
* `use_local_scope`
* `skip_regions`

The primary motivation is to allow printing on large inputs that
otherwise take forever to print and verify. For the full context, see
this IREE issue: https://github.com/openxla/iree/issues/16901.

Also add some tests and fix the op description.
---
 .../mlir/Dialect/Transform/IR/TransformOps.td | 19 ++++++--
 .../lib/Dialect/Transform/IR/TransformOps.cpp | 15 +++++-
 mlir/test/Dialect/Transform/ops.mlir          | 10 ++--
 .../Transform/test-interpreter-printing.mlir  | 48 +++++++++++++++++++
 4 files changed, 83 insertions(+), 9 deletions(-)
 create mode 100644 mlir/test/Dialect/Transform/test-interpreter-printing.mlir

diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index bf1a8016cd9df6..21c9595860d4c5 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -1098,15 +1098,28 @@ def PrintOp : TransformDialectOp<"print",
      MatchOpInterface]> {
   let summary = "Dump each payload op";
   let description = [{
-    This op dumps each payload op that is associated with the `target` operand
-    to stderr. It also prints the `name` string attribute. If no target is
+    Prints each payload op that is associated with the `target` operand to
+    `stdout`. It also prints the `name` string attribute. If no target is
     specified, the top-level op is dumped.
 
     This op is useful for printf-style debugging.
+
+    Supported printing flag attributes:
+    * `assume_verified` -- skips verification when the unit attribute is
+      specified. This improves performance but may lead to crashes and
+      unexpected behavior when the printed payload op is invalid.
+    * `use_local_scope` -- prints in local scope when the unit attribute is
+      specified. This improves performance but may not be identical to
+      printing within the full module.
+    * `skip_regions` -- does not print regions of operations when the unit
+      attribute is specified.
   }];
 
   let arguments = (ins Optional<TransformHandleTypeInterface>:$target,
-                       OptionalAttr<StrAttr>:$name);
+                       OptionalAttr<StrAttr>:$name,
+                       OptionalAttr<UnitAttr>:$assume_verified,
+                       OptionalAttr<UnitAttr>:$use_local_scope,
+                       OptionalAttr<UnitAttr>:$skip_regions);
   let results = (outs);
 
   let builders = [
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 578b2492bbab46..cd26bdcc5a492f 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -19,6 +19,7 @@
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/Dominance.h"
+#include "mlir/IR/OperationSupport.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/Verifier.h"
 #include "mlir/Interfaces/CallInterfaces.h"
@@ -2633,8 +2634,18 @@ transform::PrintOp::apply(transform::TransformRewriter &rewriter,
   }
 
   llvm::outs() << "]]]\n";
-  for (Operation *target : state.getPayloadOps(getTarget()))
-    llvm::outs() << *target << "\n";
+  for (Operation *target : state.getPayloadOps(getTarget())) {
+    OpPrintingFlags printFlags;
+    if (getAssumeVerified().value_or(false))
+      printFlags.assumeVerified();
+    if (getUseLocalScope().value_or(false))
+      printFlags.skipRegions();
+    if (getSkipRegions().value_or(false))
+      printFlags.skipRegions();
+
+    target->print(llvm::outs(), printFlags);
+    llvm::outs() << "\n";
+  }
 
   return DiagnosedSilenceableFailure::success();
 }
diff --git a/mlir/test/Dialect/Transform/ops.mlir b/mlir/test/Dialect/Transform/ops.mlir
index a718d6a9e9fd90..ecef7e181e9039 100644
--- a/mlir/test/Dialect/Transform/ops.mlir
+++ b/mlir/test/Dialect/Transform/ops.mlir
@@ -86,16 +86,18 @@ transform.sequence failures(propagate) {
 }
 
 // CHECK: transform.sequence
-// CHECK: print
-// CHECK: print
-// CHECK: print
-// CHECK: print
+// CHECK-COUNT-9: print
 transform.sequence failures(propagate) {
 ^bb0(%arg0: !transform.any_op):
   transform.print %arg0 : !transform.any_op
   transform.print
   transform.print %arg0 {name = "test"} : !transform.any_op
   transform.print {name = "test"}
+  transform.print {name = "test", assume_verified}
+  transform.print %arg0 {assume_verified} : !transform.any_op
+  transform.print %arg0 {use_local_scope} : !transform.any_op
+  transform.print %arg0 {skip_regions} : !transform.any_op
+  transform.print %arg0 {assume_verified, use_local_scope, skip_regions} : !transform.any_op
 }
 
 // CHECK: transform.sequence
diff --git a/mlir/test/Dialect/Transform/test-interpreter-printing.mlir b/mlir/test/Dialect/Transform/test-interpreter-printing.mlir
new file mode 100644
index 00000000000000..34e914bf5a76cf
--- /dev/null
+++ b/mlir/test/Dialect/Transform/test-interpreter-printing.mlir
@@ -0,0 +1,48 @@
+// RUN: mlir-opt %s --transform-interpreter -allow-unregistered-dialect --split-input-file --verify-diagnostics | FileCheck %s
+
+func.func @nested_ops() {
+  "test.foo"() ({
+    "test.foo"() ({
+      "test.qux"() ({
+        "test.baz"() ({
+          "test.bar"() : () -> ()
+        }) : () -> ()
+      }) : () -> ()
+    }) : () -> ()
+  }) : () -> ()
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
+    %baz = transform.structured.match ops{["test.baz"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+
+    // CHECK-LABEL{LITERAL}: [[[ IR printer: START top-level ]]]
+    // CHECK-NEXT:  module {
+    transform.print {name = "START"}
+    // CHECK{LITERAL}: [[[ IR printer: ]]]
+    // CHECK-NEXT:      "test.baz"() ({
+    // CHECK-NEXT:        "test.bar"() : () -> ()
+    // CHECK-NEXT:       }) : () -> ()
+    transform.print %baz : !transform.any_op
+
+    // CHECK{LITERAL}: [[[ IR printer: Baz ]]]
+    // CHECK-NEXT:      "test.baz"() ({
+    transform.print %baz {name = "Baz"} : !transform.any_op
+
+    // CHECK{LITERAL}: [[[ IR printer: No verify ]]]
+    // CHECK-NEXT:      "test.baz"() ({
+    transform.print %baz {name = "No verify", assume_verified} : !transform.any_op
+
+    // CHECK{LITERAL}: [[[ IR printer: Local scope ]]]
+    // CHECK-NEXT:      "test.baz"() ({...}) : () -> ()
+    transform.print %baz {name = "Local scope", use_local_scope} : !transform.any_op
+
+    // CHECK{LITERAL}: [[[ IR printer: No region ]]]
+    // CHECK-NEXT:      "test.baz"() ({...}) : () -> ()
+    transform.print %baz {name = "No region", skip_regions} : !transform.any_op
+
+    // CHECK-LABEL{LITERAL}: [[[ IR printer: END top-level ]]]
+    transform.print {name = "END"}
+    transform.yield
+  }
+}

>From 81231765799f79e70c94eea9bc6d20ca6961e27b Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Wed, 27 Mar 2024 14:15:09 -0400
Subject: [PATCH 2/4] Fixup

---
 mlir/lib/Dialect/Transform/IR/TransformOps.cpp             | 2 +-
 mlir/test/Dialect/Transform/test-interpreter-printing.mlir | 3 ++-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index cd26bdcc5a492f..cfbdd13a517779 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -2639,7 +2639,7 @@ transform::PrintOp::apply(transform::TransformRewriter &rewriter,
     if (getAssumeVerified().value_or(false))
       printFlags.assumeVerified();
     if (getUseLocalScope().value_or(false))
-      printFlags.skipRegions();
+      printFlags.useLocalScope();
     if (getSkipRegions().value_or(false))
       printFlags.skipRegions();
 
diff --git a/mlir/test/Dialect/Transform/test-interpreter-printing.mlir b/mlir/test/Dialect/Transform/test-interpreter-printing.mlir
index 34e914bf5a76cf..24fc2281c5abfa 100644
--- a/mlir/test/Dialect/Transform/test-interpreter-printing.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter-printing.mlir
@@ -34,7 +34,8 @@ module attributes {transform.with_named_sequence} {
     transform.print %baz {name = "No verify", assume_verified} : !transform.any_op
 
     // CHECK{LITERAL}: [[[ IR printer: Local scope ]]]
-    // CHECK-NEXT:      "test.baz"() ({...}) : () -> ()
+    // CHECK-NEXT:      "test.baz"() ({
+    // CHECK-NEXT:        "test.bar"() : () -> ()
     transform.print %baz {name = "Local scope", use_local_scope} : !transform.any_op
 
     // CHECK{LITERAL}: [[[ IR printer: No region ]]]

>From 4b5257ec8181105f7085f45d0bb87733f566a019 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Wed, 27 Mar 2024 15:44:15 -0400
Subject: [PATCH 3/4] Address comments

---
 .../Transform/test-interpreter-printing.mlir     | 16 ++++++++++------
 1 file changed, 10 insertions(+), 6 deletions(-)

diff --git a/mlir/test/Dialect/Transform/test-interpreter-printing.mlir b/mlir/test/Dialect/Transform/test-interpreter-printing.mlir
index 24fc2281c5abfa..07c14ce8f2db34 100644
--- a/mlir/test/Dialect/Transform/test-interpreter-printing.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter-printing.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --transform-interpreter -allow-unregistered-dialect --split-input-file --verify-diagnostics | FileCheck %s
+// RUN: mlir-opt %s --transform-interpreter -allow-unregistered-dialect | FileCheck %s
 
 func.func @nested_ops() {
   "test.foo"() ({
@@ -19,6 +19,7 @@ module attributes {transform.with_named_sequence} {
     // CHECK-LABEL{LITERAL}: [[[ IR printer: START top-level ]]]
     // CHECK-NEXT:  module {
     transform.print {name = "START"}
+
     // CHECK{LITERAL}: [[[ IR printer: ]]]
     // CHECK-NEXT:      "test.baz"() ({
     // CHECK-NEXT:        "test.bar"() : () -> ()
@@ -29,19 +30,22 @@ module attributes {transform.with_named_sequence} {
     // CHECK-NEXT:      "test.baz"() ({
     transform.print %baz {name = "Baz"} : !transform.any_op
 
+    // CHECK{LITERAL}: [[[ IR printer: No region ]]]
+    // CHECK-NEXT:      "test.baz"() ({...}) : () -> ()
+    transform.print %baz {name = "No region", skip_regions} : !transform.any_op
+
+    // This is difficult to test properly, only check that this prints the op
+    // and does not crash.
     // CHECK{LITERAL}: [[[ IR printer: No verify ]]]
     // CHECK-NEXT:      "test.baz"() ({
     transform.print %baz {name = "No verify", assume_verified} : !transform.any_op
 
+    // This is difficult to test properly, only check that this prints the op
+    // and does not crash.
     // CHECK{LITERAL}: [[[ IR printer: Local scope ]]]
     // CHECK-NEXT:      "test.baz"() ({
-    // CHECK-NEXT:        "test.bar"() : () -> ()
     transform.print %baz {name = "Local scope", use_local_scope} : !transform.any_op
 
-    // CHECK{LITERAL}: [[[ IR printer: No region ]]]
-    // CHECK-NEXT:      "test.baz"() ({...}) : () -> ()
-    transform.print %baz {name = "No region", skip_regions} : !transform.any_op
-
     // CHECK-LABEL{LITERAL}: [[[ IR printer: END top-level ]]]
     transform.print {name = "END"}
     transform.yield

>From 83ca76b1acc94e014bdc48c0d232e83ba2229a68 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Sat, 30 Mar 2024 23:13:32 -0400
Subject: [PATCH 4/4] Add local scope and verification tests

---
 .../lib/Dialect/Transform/IR/TransformOps.cpp | 20 +++++++------
 .../Transform/test-interpreter-printing.mlir  | 29 ++++++++++++-------
 2 files changed, 29 insertions(+), 20 deletions(-)

diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index cfbdd13a517779..043a4ebf894d5f 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -2628,21 +2628,23 @@ transform::PrintOp::apply(transform::TransformRewriter &rewriter,
   if (getName().has_value())
     llvm::outs() << *getName() << " ";
 
+  OpPrintingFlags printFlags;
+  if (getAssumeVerified().value_or(false))
+    printFlags.assumeVerified();
+  if (getUseLocalScope().value_or(false))
+    printFlags.useLocalScope();
+  if (getSkipRegions().value_or(false))
+    printFlags.skipRegions();
+
   if (!getTarget()) {
-    llvm::outs() << "top-level ]]]\n" << *state.getTopLevel() << "\n";
+    llvm::outs() << "top-level ]]]\n";
+    state.getTopLevel()->print(llvm::outs(), printFlags);
+    llvm::outs() << "\n";
     return DiagnosedSilenceableFailure::success();
   }
 
   llvm::outs() << "]]]\n";
   for (Operation *target : state.getPayloadOps(getTarget())) {
-    OpPrintingFlags printFlags;
-    if (getAssumeVerified().value_or(false))
-      printFlags.assumeVerified();
-    if (getUseLocalScope().value_or(false))
-      printFlags.useLocalScope();
-    if (getSkipRegions().value_or(false))
-      printFlags.skipRegions();
-
     target->print(llvm::outs(), printFlags);
     llvm::outs() << "\n";
   }
diff --git a/mlir/test/Dialect/Transform/test-interpreter-printing.mlir b/mlir/test/Dialect/Transform/test-interpreter-printing.mlir
index 07c14ce8f2db34..d6a4bd36214084 100644
--- a/mlir/test/Dialect/Transform/test-interpreter-printing.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter-printing.mlir
@@ -1,9 +1,13 @@
-// RUN: mlir-opt %s --transform-interpreter -allow-unregistered-dialect | FileCheck %s
+// RUN: mlir-opt %s --transform-interpreter --allow-unregistered-dialect --verify-diagnostics | FileCheck %s
+
+// RUN: mlir-opt %s --transform-interpreter --allow-unregistered-dialect --verify-diagnostics \
+// RUN:   --mlir-print-debuginfo | FileCheck %s --check-prefix=CHECK-LOC
 
 func.func @nested_ops() {
   "test.foo"() ({
     "test.foo"() ({
       "test.qux"() ({
+        // expected-error @below{{fail_to_verify is set}}
         "test.baz"() ({
           "test.bar"() : () -> ()
         }) : () -> ()
@@ -14,12 +18,21 @@ func.func @nested_ops() {
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
-    %baz = transform.structured.match ops{["test.baz"]} in %arg0 : (!transform.any_op) -> !transform.any_op
-
     // CHECK-LABEL{LITERAL}: [[[ IR printer: START top-level ]]]
     // CHECK-NEXT:  module {
+    // CHECK-LOC-LABEL{LITERAL}: [[[ IR printer: START top-level ]]]
+    // CHECK-LOC-NEXT:  #{{.+}} = loc(
+    // CHECK-LOC-NEXT:  module {
     transform.print {name = "START"}
 
+    // CHECK{LITERAL}: [[[ IR printer: Local scope top-level ]]]
+    // CHECK-NEXT:      module {
+    // CHECK-LOC{LITERAL}: [[[ IR printer: Local scope top-level ]]]
+    // CHECK-LOC-NEXT:      module {
+    transform.print {name = "Local scope", use_local_scope}
+
+    %baz = transform.structured.match ops{["test.baz"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+
     // CHECK{LITERAL}: [[[ IR printer: ]]]
     // CHECK-NEXT:      "test.baz"() ({
     // CHECK-NEXT:        "test.bar"() : () -> ()
@@ -34,18 +47,12 @@ module attributes {transform.with_named_sequence} {
     // CHECK-NEXT:      "test.baz"() ({...}) : () -> ()
     transform.print %baz {name = "No region", skip_regions} : !transform.any_op
 
-    // This is difficult to test properly, only check that this prints the op
-    // and does not crash.
     // CHECK{LITERAL}: [[[ IR printer: No verify ]]]
     // CHECK-NEXT:      "test.baz"() ({
+    // CHECK-NEXT:        transform.test_dummy_payload_op  {fail_to_verify} : () -> ()
+    transform.test_produce_invalid_ir %baz : !transform.any_op
     transform.print %baz {name = "No verify", assume_verified} : !transform.any_op
 
-    // This is difficult to test properly, only check that this prints the op
-    // and does not crash.
-    // CHECK{LITERAL}: [[[ IR printer: Local scope ]]]
-    // CHECK-NEXT:      "test.baz"() ({
-    transform.print %baz {name = "Local scope", use_local_scope} : !transform.any_op
-
     // CHECK-LABEL{LITERAL}: [[[ IR printer: END top-level ]]]
     transform.print {name = "END"}
     transform.yield



More information about the Mlir-commits mailing list