[Mlir-commits] [mlir] [mlir] Add OpAsmTypeInterface for pretty-print (PR #121187)

Hongren Zheng llvmlistbot at llvm.org
Fri Dec 27 00:11:01 PST 2024


https://github.com/ZenithalHourlyRate created https://github.com/llvm/llvm-project/pull/121187

See https://discourse.llvm.org/t/rfc-introduce-opasm-type-attr-interface-for-pretty-print-in-asmprinter/83792 for a detailed introduction.

This PR implements the first part of that proposal:
* Add `OpAsmTypeInterface` with a `getAsmName` API for deducing an ASM name from a type.
* Add a default implementation in `OpAsmOpInterface` that uses this API when it is available.

The `OpAsmAttrInterface` part (hooking into the alias system via a `getAlias` API) will be sent as a separate PR.

### Discussion

* Instead of using `StringRef getAsmName()` as the API, I use `void getAsmName(OpAsmSetNameFn)`, because returning a `StringRef` might be unsafe (an implementation could construct a `std::string` internally and then return a _reference_ to it). This also aligns with the design of `getAsmResultNames`.
* On result packing: when not all of an op's result types implement `OpAsmTypeInterface`, the default implementation does nothing (i.e. it keeps the old default behavior).

### Review 

Cc @j2kun and @Alexanderviand-intel for downstream; Cc @River707 and @joker-eph for relevant commit history; Cc @ftynse for Discourse.

>From 0fa78e4194b1c4827fa18b6e8ada33972b97af8d Mon Sep 17 00:00:00 2001
From: Zenithal <i at zenithal.me>
Date: Sun, 22 Dec 2024 11:20:22 +0000
Subject: [PATCH] [mlir] Add OpAsmTypeInterface for pretty-print

---
 mlir/include/mlir/IR/CMakeLists.txt        |  9 +++-
 mlir/include/mlir/IR/OpAsmInterface.td     | 55 +++++++++++++++++++++-
 mlir/include/mlir/IR/OpImplementation.h    | 12 +++--
 mlir/lib/IR/AsmPrinter.cpp                 |  3 +-
 mlir/test/IR/op-asm-interface.mlir         | 43 +++++++++++++++++
 mlir/test/lib/Dialect/Test/TestOps.td      | 24 ++++++++++
 mlir/test/lib/Dialect/Test/TestTypeDefs.td |  4 ++
 mlir/test/lib/Dialect/Test/TestTypes.cpp   |  5 ++
 8 files changed, 148 insertions(+), 7 deletions(-)
 create mode 100644 mlir/test/IR/op-asm-interface.mlir

diff --git a/mlir/include/mlir/IR/CMakeLists.txt b/mlir/include/mlir/IR/CMakeLists.txt
index b741eb18d47916..0c7937dfd69e55 100644
--- a/mlir/include/mlir/IR/CMakeLists.txt
+++ b/mlir/include/mlir/IR/CMakeLists.txt
@@ -1,7 +1,14 @@
-add_mlir_interface(OpAsmInterface)
 add_mlir_interface(SymbolInterfaces)
 add_mlir_interface(RegionKindInterface)
 
+set(LLVM_TARGET_DEFINITIONS OpAsmInterface.td)
+mlir_tablegen(OpAsmOpInterface.h.inc -gen-op-interface-decls)
+mlir_tablegen(OpAsmOpInterface.cpp.inc -gen-op-interface-defs)
+mlir_tablegen(OpAsmTypeInterface.h.inc -gen-type-interface-decls)
+mlir_tablegen(OpAsmTypeInterface.cpp.inc -gen-type-interface-defs)
+add_public_tablegen_target(MLIROpAsmInterfaceIncGen)
+add_dependencies(mlir-generic-headers MLIROpAsmInterfaceIncGen)
+
 set(LLVM_TARGET_DEFINITIONS BuiltinAttributes.td)
 mlir_tablegen(BuiltinAttributes.h.inc -gen-attrdef-decls)
 mlir_tablegen(BuiltinAttributes.cpp.inc -gen-attrdef-defs)
diff --git a/mlir/include/mlir/IR/OpAsmInterface.td b/mlir/include/mlir/IR/OpAsmInterface.td
index 98b5095ff2d665..d2dfb60b2ac142 100644
--- a/mlir/include/mlir/IR/OpAsmInterface.td
+++ b/mlir/include/mlir/IR/OpAsmInterface.td
@@ -50,10 +50,27 @@ def OpAsmOpInterface : OpInterface<"OpAsmOpInterface"> {
         ```mlir
           %first_result, %middle_results:2, %0 = "my.op" ...
         ```
+
+        The default implementation uses `OpAsmTypeInterface` to get the name for
+        each result from its type.
+
+        If not all of the result types have `OpAsmTypeInterface`, the default implementation
+        does nothing, as the packing behavior should be decided by the operation itself.
       }],
       "void", "getAsmResultNames",
       (ins "::mlir::OpAsmSetValueNameFn":$setNameFn),
-      "", "return;"
+      "", [{
+        bool hasOpAsmTypeInterface = llvm::all_of($_op->getResults(), [&](Value result) {
+          return ::mlir::isa<::mlir::OpAsmTypeInterface>(result.getType());
+        });
+        if (!hasOpAsmTypeInterface)
+          return;
+        for (auto result : $_op->getResults()) {
+          auto setResultNameFn = [&](StringRef name) { setNameFn(result, name); };
+          auto opAsmTypeInterface = ::mlir::cast<::mlir::OpAsmTypeInterface>(result.getType());
+          opAsmTypeInterface.getAsmName(setResultNameFn);
+        }
+      }]
     >,
     InterfaceMethod<[{
         Get a special name to use when printing the block arguments for a region
@@ -64,7 +81,16 @@ def OpAsmOpInterface : OpInterface<"OpAsmOpInterface"> {
         "::mlir::Region&":$region,
         "::mlir::OpAsmSetValueNameFn":$setNameFn
       ),
-      "", "return;"
+      "", [{
+        for (auto &block : region) {
+          for (auto arg : block.getArguments()) {
+            if (auto opAsmTypeInterface = ::mlir::dyn_cast<::mlir::OpAsmTypeInterface>(arg.getType())) {
+              auto setArgNameFn = [&](StringRef name) { setNameFn(arg, name); };
+              opAsmTypeInterface.getAsmName(setArgNameFn);
+            }
+          }
+        }
+      }]
     >,
     InterfaceMethod<[{
         Get the name to use for a given block inside a region attached to this
@@ -109,6 +135,31 @@ def OpAsmOpInterface : OpInterface<"OpAsmOpInterface"> {
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// OpAsmTypeInterface
+//===----------------------------------------------------------------------===//
+
+def OpAsmTypeInterface : TypeInterface<"OpAsmTypeInterface"> {
+  let description = [{
+    This interface provides hooks to interact with the AsmPrinter and AsmParser
+    classes.
+  }];
+  let cppNamespace = "::mlir";
+
+  let methods = [
+    InterfaceMethod<[{
+        Get a special name to use when printing value of this type.
+
+        For example, the default implementation of OpAsmOpInterface
+        will respect this method when printing the results of an operation
+        and/or block argument of it.
+      }],
+      "void", "getAsmName",
+      (ins "::mlir::OpAsmSetNameFn":$setNameFn), "", ";"
+    >,
+  ];
+}
+
 //===----------------------------------------------------------------------===//
 // ResourceHandleParameter
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 6c1ff4d0e5e6b9..83e2758a2e782f 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -734,7 +734,7 @@ class AsmParser {
   virtual OptionalParseResult parseOptionalInteger(APInt &result) = 0;
   virtual OptionalParseResult parseOptionalDecimalInteger(APInt &result) = 0;
 
- private:
+private:
   template <typename IntT, typename ParseFn>
   OptionalParseResult parseOptionalIntegerAndCheck(IntT &result,
                                                    ParseFn &&parseFn) {
@@ -756,7 +756,7 @@ class AsmParser {
     return success();
   }
 
- public:
+public:
   template <typename IntT>
   OptionalParseResult parseOptionalInteger(IntT &result) {
     return parseOptionalIntegerAndCheck(
@@ -1727,6 +1727,10 @@ class OpAsmParser : public AsmParser {
 // Dialect OpAsm interface.
 //===--------------------------------------------------------------------===//
 
+/// A functor used to set the name of a single value (e.g. an operation result
+/// or block argument). See 'OpAsmTypeInterface::getAsmName' for more details.
+using OpAsmSetNameFn = function_ref<void(StringRef)>;
+
 /// A functor used to set the name of the start of a result group of an
 /// operation. See 'getAsmResultNames' below for more details.
 using OpAsmSetValueNameFn = function_ref<void(Value, StringRef)>;
@@ -1820,7 +1824,9 @@ ParseResult parseDimensionList(OpAsmParser &parser,
 //===--------------------------------------------------------------------===//
 
 /// The OpAsmOpInterface, see OpAsmInterface.td for more details.
-#include "mlir/IR/OpAsmInterface.h.inc"
+#include "mlir/IR/OpAsmTypeInterface.h.inc"
+// The type interface must be declared before the op interface that uses it.
+#include "mlir/IR/OpAsmOpInterface.h.inc"
 
 namespace llvm {
 template <>
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 6fe96504ae100c..9eddcb1a0872ba 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -125,7 +125,8 @@ void OpAsmPrinter::printFunctionalType(Operation *op) {
 //===----------------------------------------------------------------------===//
 
 /// The OpAsmOpInterface, see OpAsmInterface.td for more details.
-#include "mlir/IR/OpAsmInterface.cpp.inc"
+#include "mlir/IR/OpAsmOpInterface.cpp.inc"
+#include "mlir/IR/OpAsmTypeInterface.cpp.inc"
 
 LogicalResult
 OpAsmDialectInterface::parseResource(AsmParsedResourceEntry &entry) const {
diff --git a/mlir/test/IR/op-asm-interface.mlir b/mlir/test/IR/op-asm-interface.mlir
new file mode 100644
index 00000000000000..9753fd419cb609
--- /dev/null
+++ b/mlir/test/IR/op-asm-interface.mlir
@@ -0,0 +1,43 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// Test OpAsmOpInterface
+//===----------------------------------------------------------------------===//
+
+func.func @result_name_from_op_asm_type_interface() {
+  // CHECK-LABEL: @result_name_from_op_asm_type_interface
+  // CHECK: %op_asm_type_interface
+  %0 = "test.default_result_name"() : () -> !test.op_asm_type_interface
+  return
+}
+
+// -----
+
+func.func @result_name_pack_from_op_asm_type_interface() {
+  // CHECK-LABEL: @result_name_pack_from_op_asm_type_interface
+  // CHECK: %op_asm_type_interface{{.*}}, %op_asm_type_interface{{.*}}
+  // CHECK-NOT: :2
+  %0:2 = "test.default_result_name_packing"() : () -> (!test.op_asm_type_interface, !test.op_asm_type_interface)
+  return
+}
+
+// -----
+
+func.func @result_name_pack_do_nothing() {
+  // CHECK-LABEL: @result_name_pack_do_nothing
+  // CHECK: %0:2
+  %0:2 = "test.default_result_name_packing"() : () -> (i32, !test.op_asm_type_interface)
+  return
+}
+
+// -----
+
+func.func @block_argument_name_from_op_asm_type_interface() {
+  // CHECK-LABEL: @block_argument_name_from_op_asm_type_interface
+  // CHECK: ^bb0(%op_asm_type_interface
+  test.default_block_argument_name {
+    ^bb0(%arg0: !test.op_asm_type_interface):
+      "test.terminator"() : ()->()
+  }
+  return
+}
\ No newline at end of file
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index bafab155eb9d57..72047c0e9aeb9f 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -924,6 +924,30 @@ def CustomResultsNameOp
   let results = (outs Variadic<AnyInteger>:$r);
 }
 
+// Tests the default OpAsmOpInterface::getAsmResultNames implementation,
+// which uses OpAsmTypeInterface when available.
+def DefaultResultsNameOp
+ : TEST_Op<"default_result_name",
+           [OpAsmOpInterface]> {
+  let results = (outs AnyType:$r);
+}
+
+// Tests the default OpAsmOpInterface::getAsmResultNames implementation with
+// multiple results: names are derived from OpAsmTypeInterface only when all
+// result types implement it.
+def DefaultResultsNamePackingOp
+ : TEST_Op<"default_result_name_packing",
+           [OpAsmOpInterface]> {
+  let results = (outs AnyType:$r, AnyType:$s);
+}
+
+// Tests the default OpAsmOpInterface::getAsmBlockArgumentNames implementation.
+def DefaultBlockArgumentNameOp : TEST_Op<"default_block_argument_name",
+                                    [OpAsmOpInterface]> {
+  let regions = (region AnyRegion:$body);
+  let assemblyFormat = "regions attr-dict-with-keyword";
+}
+
 // This is used to test the OpAsmOpInterface::getDefaultDialect() feature:
 // operations nested in a region under this op will drop the "test." dialect
 // prefix.
diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
index 60108ac86d1edd..756552a7ebd63e 100644
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -398,4 +398,8 @@ def TestTypeVerification : Test_Type<"TestTypeVerification"> {
   let assemblyFormat = "`<` $param `>`";
 }
 
+def TestTypeOpAsmTypeInterface : Test_Type<"TestTypeOpAsmTypeInterface", [DeclareTypeInterfaceMethods<OpAsmTypeInterface, ["getAsmName"]>]> {
+  let mnemonic = "op_asm_type_interface";
+}
+
 #endif // TEST_TYPEDEFS
diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp
index 6e31bb71d04d80..8ab2a16338b116 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp
@@ -531,3 +531,8 @@ void TestRecursiveAliasType::print(AsmPrinter &printer) const {
   }
   printer << ">";
 }
+
+void TestTypeOpAsmTypeInterfaceType::getAsmName(
+    OpAsmSetNameFn setNameFn) const {
+  setNameFn("op_asm_type_interface");
+}



More information about the Mlir-commits mailing list