[Mlir-commits] [mlir] a8cfa7c - [mlir][TD] Allow op printing flags as `transform.print` attrs (#86846)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Apr 1 09:32:27 PDT 2024


Author: Jakub Kuderski
Date: 2024-04-01T12:32:23-04:00
New Revision: a8cfa7cbdf6cc1a94ed25c90897d2e031f77a5a9

URL: https://github.com/llvm/llvm-project/commit/a8cfa7cbdf6cc1a94ed25c90897d2e031f77a5a9
DIFF: https://github.com/llvm/llvm-project/commit/a8cfa7cbdf6cc1a94ed25c90897d2e031f77a5a9.diff

LOG: [mlir][TD] Allow op printing flags as `transform.print` attrs (#86846)

Introduce three new optional attributes to the `transform.print` op:
* `assume_verified`
* `use_local_scope`
* `skip_regions`

The primary motivation is to allow printing large inputs that would
otherwise take a very long time to print and verify. For the full
context, see this IREE issue: https://github.com/openxla/iree/issues/16901.

Also add tests and fix the op description, which incorrectly stated that
the op prints to stderr (it prints to stdout).

Added: 
    mlir/test/Dialect/Transform/test-interpreter-printing.mlir

Modified: 
    mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
    mlir/lib/Dialect/Transform/IR/TransformOps.cpp
    mlir/test/Dialect/Transform/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index bf1a8016cd9df6..21c9595860d4c5 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -1098,15 +1098,28 @@ def PrintOp : TransformDialectOp<"print",
      MatchOpInterface]> {
   let summary = "Dump each payload op";
   let description = [{
-    This op dumps each payload op that is associated with the `target` operand
-    to stderr. It also prints the `name` string attribute. If no target is
+    Prints each payload op that is associated with the `target` operand to
+    `stdout`. It also prints the `name` string attribute. If no target is
     specified, the top-level op is dumped.
 
     This op is useful for printf-style debugging.
+
+    Supported printing flag attributes:
+    * `assume_verified` -- skips verification when the unit attribute is
+      specified. This improves performance but may lead to crashes and
+      unexpected behavior when the printed payload op is invalid.
+    * `use_local_scope` -- prints in local scope when the unit attribute is
+      specified. This improves performance but may not be identical to
+      printing within the full module.
+    * `skip_regions` -- does not print regions of operations when the unit
+      attribute is specified.
   }];
 
   let arguments = (ins Optional<TransformHandleTypeInterface>:$target,
-                       OptionalAttr<StrAttr>:$name);
+                       OptionalAttr<StrAttr>:$name,
+                       OptionalAttr<UnitAttr>:$assume_verified,
+                       OptionalAttr<UnitAttr>:$use_local_scope,
+                       OptionalAttr<UnitAttr>:$skip_regions);
   let results = (outs);
 
   let builders = [

diff  --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index c8d06ba157b904..dc19022219e5b2 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -19,6 +19,7 @@
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/Dominance.h"
+#include "mlir/IR/OperationSupport.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/Verifier.h"
 #include "mlir/Interfaces/CallInterfaces.h"
@@ -2627,14 +2628,26 @@ transform::PrintOp::apply(transform::TransformRewriter &rewriter,
   if (getName().has_value())
     llvm::outs() << *getName() << " ";
 
+  OpPrintingFlags printFlags;
+  if (getAssumeVerified().value_or(false))
+    printFlags.assumeVerified();
+  if (getUseLocalScope().value_or(false))
+    printFlags.useLocalScope();
+  if (getSkipRegions().value_or(false))
+    printFlags.skipRegions();
+
   if (!getTarget()) {
-    llvm::outs() << "top-level ]]]\n" << *state.getTopLevel() << "\n";
+    llvm::outs() << "top-level ]]]\n";
+    state.getTopLevel()->print(llvm::outs(), printFlags);
+    llvm::outs() << "\n";
     return DiagnosedSilenceableFailure::success();
   }
 
   llvm::outs() << "]]]\n";
-  for (Operation *target : state.getPayloadOps(getTarget()))
-    llvm::outs() << *target << "\n";
+  for (Operation *target : state.getPayloadOps(getTarget())) {
+    target->print(llvm::outs(), printFlags);
+    llvm::outs() << "\n";
+  }
 
   return DiagnosedSilenceableFailure::success();
 }

diff  --git a/mlir/test/Dialect/Transform/ops.mlir b/mlir/test/Dialect/Transform/ops.mlir
index a718d6a9e9fd90..ecef7e181e9039 100644
--- a/mlir/test/Dialect/Transform/ops.mlir
+++ b/mlir/test/Dialect/Transform/ops.mlir
@@ -86,16 +86,18 @@ transform.sequence failures(propagate) {
 }
 
 // CHECK: transform.sequence
-// CHECK: print
-// CHECK: print
-// CHECK: print
-// CHECK: print
+// CHECK-COUNT-9: print
 transform.sequence failures(propagate) {
 ^bb0(%arg0: !transform.any_op):
   transform.print %arg0 : !transform.any_op
   transform.print
   transform.print %arg0 {name = "test"} : !transform.any_op
   transform.print {name = "test"}
+  transform.print {name = "test", assume_verified}
+  transform.print %arg0 {assume_verified} : !transform.any_op
+  transform.print %arg0 {use_local_scope} : !transform.any_op
+  transform.print %arg0 {skip_regions} : !transform.any_op
+  transform.print %arg0 {assume_verified, use_local_scope, skip_regions} : !transform.any_op
 }
 
 // CHECK: transform.sequence

diff  --git a/mlir/test/Dialect/Transform/test-interpreter-printing.mlir b/mlir/test/Dialect/Transform/test-interpreter-printing.mlir
new file mode 100644
index 00000000000000..a54c83d2b249eb
--- /dev/null
+++ b/mlir/test/Dialect/Transform/test-interpreter-printing.mlir
@@ -0,0 +1,56 @@
+// RUN: mlir-opt %s --transform-interpreter --allow-unregistered-dialect --verify-diagnostics | FileCheck %s
+
+// RUN: mlir-opt %s --transform-interpreter --allow-unregistered-dialect --verify-diagnostics \
+// RUN:   --mlir-print-debuginfo | FileCheck %s --check-prefix=CHECK-LOC
+
+func.func @nested_ops() {
+  "test.qux"() ({
+    // expected-error @below{{fail_to_verify is set}}
+    "test.baz"() ({
+      "test.bar"() : () -> ()
+    }) : () -> ()
+  }) : () -> ()
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
+    // CHECK-LABEL{LITERAL}: [[[ IR printer: START top-level ]]]
+    // CHECK-NEXT:  module {
+    // CHECK-LOC-LABEL{LITERAL}: [[[ IR printer: START top-level ]]]
+    // CHECK-LOC-NEXT:  #{{.+}} = loc(
+    // CHECK-LOC-NEXT:  module {
+    transform.print {name = "START"}
+
+    // CHECK{LITERAL}: [[[ IR printer: Local scope top-level ]]]
+    // CHECK-NEXT:      module {
+    // CHECK-LOC{LITERAL}: [[[ IR printer: Local scope top-level ]]]
+    // CHECK-LOC-NEXT:      module {
+    transform.print {name = "Local scope", use_local_scope}
+
+    %baz = transform.structured.match ops{["test.baz"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+
+    // CHECK{LITERAL}: [[[ IR printer: ]]]
+    // CHECK-NEXT:      "test.baz"() ({
+    // CHECK-NEXT:        "test.bar"() : () -> ()
+    // CHECK-NEXT:       }) : () -> ()
+    transform.print %baz : !transform.any_op
+
+    // CHECK{LITERAL}: [[[ IR printer: Baz ]]]
+    // CHECK-NEXT:      "test.baz"() ({
+    transform.print %baz {name = "Baz"} : !transform.any_op
+
+    // CHECK{LITERAL}: [[[ IR printer: No region ]]]
+    // CHECK-NEXT:      "test.baz"() ({...}) : () -> ()
+    transform.print %baz {name = "No region", skip_regions} : !transform.any_op
+
+    // CHECK{LITERAL}: [[[ IR printer: No verify ]]]
+    // CHECK-NEXT:      "test.baz"() ({
+    // CHECK-NEXT:        transform.test_dummy_payload_op  {fail_to_verify} : () -> ()
+    transform.test_produce_invalid_ir %baz : !transform.any_op
+    transform.print %baz {name = "No verify", assume_verified} : !transform.any_op
+
+    // CHECK-LABEL{LITERAL}: [[[ IR printer: END top-level ]]]
+    transform.print {name = "END"}
+    transform.yield
+  }
+}


        


More information about the Mlir-commits mailing list