[Mlir-commits] [mlir] [mlir] Add OpAsmTypeInterface for pretty-print (PR #121187)

Hongren Zheng llvmlistbot at llvm.org
Mon Jan 13 19:59:07 PST 2025


https://github.com/ZenithalHourlyRate updated https://github.com/llvm/llvm-project/pull/121187

>From b5ae5e42a05290a05d659e5ee53710dd965ba88e Mon Sep 17 00:00:00 2001
From: Zenithal <i at zenithal.me>
Date: Sun, 22 Dec 2024 11:20:22 +0000
Subject: [PATCH] [mlir] Add OpAsmTypeInterface for pretty-print

---
 mlir/include/mlir/IR/CMakeLists.txt        |  9 +++++-
 mlir/include/mlir/IR/OpAsmInterface.td     | 23 +++++++++++++++-
 mlir/include/mlir/IR/OpImplementation.h    | 11 ++++++--
 mlir/lib/IR/AsmPrinter.cpp                 |  3 +-
 mlir/test/IR/op-asm-interface.mlir         | 24 ++++++++++++++++
 mlir/test/lib/Dialect/Test/TestOpDefs.cpp  | 32 ++++++++++++++++++++++
 mlir/test/lib/Dialect/Test/TestOps.td      | 15 ++++++++++
 mlir/test/lib/Dialect/Test/TestTypeDefs.td |  5 ++++
 mlir/test/lib/Dialect/Test/TestTypes.cpp   |  5 ++++
 9 files changed, 121 insertions(+), 6 deletions(-)
 create mode 100644 mlir/test/IR/op-asm-interface.mlir

diff --git a/mlir/include/mlir/IR/CMakeLists.txt b/mlir/include/mlir/IR/CMakeLists.txt
index b741eb18d47916..0c7937dfd69e55 100644
--- a/mlir/include/mlir/IR/CMakeLists.txt
+++ b/mlir/include/mlir/IR/CMakeLists.txt
@@ -1,7 +1,14 @@
-add_mlir_interface(OpAsmInterface)
 add_mlir_interface(SymbolInterfaces)
 add_mlir_interface(RegionKindInterface)
 
+set(LLVM_TARGET_DEFINITIONS OpAsmInterface.td)
+mlir_tablegen(OpAsmOpInterface.h.inc -gen-op-interface-decls)
+mlir_tablegen(OpAsmOpInterface.cpp.inc -gen-op-interface-defs)
+mlir_tablegen(OpAsmTypeInterface.h.inc -gen-type-interface-decls)
+mlir_tablegen(OpAsmTypeInterface.cpp.inc -gen-type-interface-defs)
+add_public_tablegen_target(MLIROpAsmInterfaceIncGen)
+add_dependencies(mlir-generic-headers MLIROpAsmInterfaceIncGen)
+
 set(LLVM_TARGET_DEFINITIONS BuiltinAttributes.td)
 mlir_tablegen(BuiltinAttributes.h.inc -gen-attrdef-decls)
 mlir_tablegen(BuiltinAttributes.cpp.inc -gen-attrdef-defs)
diff --git a/mlir/include/mlir/IR/OpAsmInterface.td b/mlir/include/mlir/IR/OpAsmInterface.td
index 98b5095ff2d665..f7384feeef0bd2 100644
--- a/mlir/include/mlir/IR/OpAsmInterface.td
+++ b/mlir/include/mlir/IR/OpAsmInterface.td
@@ -23,7 +23,7 @@ include "mlir/IR/OpBase.td"
 
 def OpAsmOpInterface : OpInterface<"OpAsmOpInterface"> {
   let description = [{
-    This interface provides hooks to interact with the AsmPrinter and AsmParser
+    This op interface provides hooks to interact with the AsmPrinter and AsmParser
     classes.
   }];
   let cppNamespace = "::mlir";
@@ -109,6 +109,27 @@ def OpAsmOpInterface : OpInterface<"OpAsmOpInterface"> {
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// OpAsmTypeInterface
+//===----------------------------------------------------------------------===//
+
+def OpAsmTypeInterface : TypeInterface<"OpAsmTypeInterface"> {
+  let description = [{
+    This type interface provides hooks to interact with the AsmPrinter and
+    AsmParser classes.
+  }];
+  let cppNamespace = "::mlir";
+
+  let methods = [
+    InterfaceMethod<[{
+        Get a name to use when printing a value of this type.
+      }],
+      "void", "getAsmName",
+      (ins "::mlir::OpAsmSetNameFn":$setNameFn), "", ";"
+    >,
+  ];
+}
+
 //===----------------------------------------------------------------------===//
 // ResourceHandleParameter
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 6c1ff4d0e5e6b9..d9c925a9c56e6c 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -734,7 +734,7 @@ class AsmParser {
   virtual OptionalParseResult parseOptionalInteger(APInt &result) = 0;
   virtual OptionalParseResult parseOptionalDecimalInteger(APInt &result) = 0;
 
- private:
+private:
   template <typename IntT, typename ParseFn>
   OptionalParseResult parseOptionalIntegerAndCheck(IntT &result,
                                                    ParseFn &&parseFn) {
@@ -756,7 +756,7 @@ class AsmParser {
     return success();
   }
 
- public:
+public:
   template <typename IntT>
   OptionalParseResult parseOptionalInteger(IntT &result) {
     return parseOptionalIntegerAndCheck(
@@ -1727,6 +1727,10 @@ class OpAsmParser : public AsmParser {
 // Dialect OpAsm interface.
 //===--------------------------------------------------------------------===//
 
+/// A functor that receives a suggested asm name; it is the callback type of
+/// OpAsmTypeInterface::getAsmName. See OpAsmInterface.td for more details.
+using OpAsmSetNameFn = function_ref<void(StringRef)>;
+
 /// A functor used to set the name of the start of a result group of an
 /// operation. See 'getAsmResultNames' below for more details.
 using OpAsmSetValueNameFn = function_ref<void(Value, StringRef)>;
@@ -1820,7 +1824,8 @@ ParseResult parseDimensionList(OpAsmParser &parser,
 //===--------------------------------------------------------------------===//
 
 /// The OpAsmOpInterface, see OpAsmInterface.td for more details.
-#include "mlir/IR/OpAsmInterface.h.inc"
+#include "mlir/IR/OpAsmOpInterface.h.inc"
+#include "mlir/IR/OpAsmTypeInterface.h.inc"
 
 namespace llvm {
 template <>
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index c603db450cbdd0..fa4a1b4b72b024 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -125,7 +125,8 @@ void OpAsmPrinter::printFunctionalType(Operation *op) {
 //===----------------------------------------------------------------------===//
 
 /// The OpAsmOpInterface, see OpAsmInterface.td for more details.
-#include "mlir/IR/OpAsmInterface.cpp.inc"
+#include "mlir/IR/OpAsmOpInterface.cpp.inc"
+#include "mlir/IR/OpAsmTypeInterface.cpp.inc"
 
 LogicalResult
 OpAsmDialectInterface::parseResource(AsmParsedResourceEntry &entry) const {
diff --git a/mlir/test/IR/op-asm-interface.mlir b/mlir/test/IR/op-asm-interface.mlir
new file mode 100644
index 00000000000000..a9c199e3dc9736
--- /dev/null
+++ b/mlir/test/IR/op-asm-interface.mlir
@@ -0,0 +1,24 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// Test OpAsmTypeInterface
+//===----------------------------------------------------------------------===//
+
+func.func @result_name_from_op_asm_type_interface() {
+  // CHECK-LABEL: @result_name_from_op_asm_type_interface
+  // CHECK: %op_asm_type_interface = "test.result_name_from_type"()
+  %0 = "test.result_name_from_type"() : () -> !test.op_asm_type_interface
+  return
+}
+
+// -----
+
+func.func @block_argument_name_from_op_asm_type_interface() {
+  // CHECK-LABEL: @block_argument_name_from_op_asm_type_interface
+  // CHECK: ^bb0(%op_asm_type_interface: !test.op_asm_type_interface)
+  test.block_argument_name_from_type {
+    ^bb0(%arg0: !test.op_asm_type_interface):
+      "test.terminator"() : ()->()
+  }
+  return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
index c6be26d0a44d9a..f6b8a0005f2854 100644
--- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
@@ -506,6 +506,38 @@ void CustomResultsNameOp::getAsmResultNames(
         setNameFn(getResult(i), str.getValue());
 }
 
+//===----------------------------------------------------------------------===//
+// ResultNameFromTypeOp
+//===----------------------------------------------------------------------===//
+
+void ResultNameFromTypeOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  auto result = getResult();
+  // The result is declared as AnyType, so only name it when its type actually
+  // implements OpAsmTypeInterface; `cast` would assert on any other type.
+  if (auto opAsmTypeInterface =
+          ::mlir::dyn_cast<::mlir::OpAsmTypeInterface>(result.getType()))
+    opAsmTypeInterface.getAsmName(
+        [&](::llvm::StringRef name) { setNameFn(result, name); });
+}
+
+//===----------------------------------------------------------------------===//
+// BlockArgumentNameFromTypeOp
+//===----------------------------------------------------------------------===//
+
+void BlockArgumentNameFromTypeOp::getAsmBlockArgumentNames(
+    ::mlir::Region &region, ::mlir::OpAsmSetValueNameFn setNameFn) {
+  // Name every block argument whose type implements OpAsmTypeInterface.
+  for (::mlir::Block &blk : region)
+    for (::mlir::BlockArgument barg : blk.getArguments()) {
+      auto typeIface =
+          ::mlir::dyn_cast<::mlir::OpAsmTypeInterface>(barg.getType());
+      if (!typeIface)
+        continue;
+      typeIface.getAsmName([&](StringRef name) { setNameFn(barg, name); });
+    }
+}
+
 //===----------------------------------------------------------------------===//
 // ResultTypeWithTraitOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 0b1f22b3ee9323..f37573c1351cec 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -924,6 +924,21 @@ def CustomResultsNameOp
   let results = (outs Variadic<AnyInteger>:$r);
 }
 
+// Used to test OpAsmTypeInterface::getAsmName for op result names.
+def ResultNameFromTypeOp
+  : TEST_Op<"result_name_from_type",
+      [DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
+  let results = (outs AnyType:$r);
+}
+
+// Used to test OpAsmTypeInterface::getAsmName for block argument names.
+def BlockArgumentNameFromTypeOp
+  : TEST_Op<"block_argument_name_from_type",
+      [DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmBlockArgumentNames"]>]> {
+  let regions = (region AnyRegion:$body);
+  let assemblyFormat = "regions attr-dict-with-keyword";
+}
+
 // This is used to test the OpAsmOpInterface::getDefaultDialect() feature:
 // operations nested in a region under this op will drop the "test." dialect
 // prefix.
diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
index 60108ac86d1edd..6335701786ecc6 100644
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -398,4 +398,9 @@ def TestTypeVerification : Test_Type<"TestTypeVerification"> {
   let assemblyFormat = "`<` $param `>`";
 }
 
+def TestTypeOpAsmTypeInterface : Test_Type<"TestTypeOpAsmTypeInterface", [
+    DeclareTypeInterfaceMethods<OpAsmTypeInterface, ["getAsmName"]>]> {
+  let mnemonic = "op_asm_type_interface";
+}
+
 #endif // TEST_TYPEDEFS
diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp
index b822e019e09d24..1ae7ac472d989e 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp
@@ -532,3 +532,8 @@ void TestRecursiveAliasType::print(AsmPrinter &printer) const {
   }
   printer << ">";
 }
+
+void TestTypeOpAsmTypeInterfaceType::getAsmName(
+    OpAsmSetNameFn setName) const {
+  setName("op_asm_type_interface"); // Reuse the type mnemonic as the name.
+}



More information about the Mlir-commits mailing list