[Mlir-commits] [mlir] [mlir] Introduce OpAsmAttrInterface for pretty-print (PR #124721)

Hongren Zheng llvmlistbot at llvm.org
Tue Jan 28 00:34:42 PST 2025


https://github.com/ZenithalHourlyRate created https://github.com/llvm/llvm-project/pull/124721

See https://discourse.llvm.org/t/rfc-introduce-opasm-type-attr-interface-for-pretty-print-in-asmprinter/83792 for a detailed introduction.

This PR adds

* Definition of `OpAsmAttrInterface`
* Integration of `OpAsmAttrInterface` with `AsmPrinter`

In https://github.com/llvm/llvm-project/pull/121187#discussion_r1931472250 I mentioned splitting this into two PRs. However, a PR containing only the definition of `OpAsmAttrInterface` would be hard to test: exercising it requires a custom dialect whose `OpAsmDialectInterface` hooks it into `AsmPrinter`. I therefore combined both parts so that the change comes with an end-to-end test.

Cc @River707 @jpienaar @ftynse for review.

>From fb0eeaa63e7570a1cb5790e0621254b8b8017a41 Mon Sep 17 00:00:00 2001
From: Zenithal <i at zenithal.me>
Date: Tue, 28 Jan 2025 08:24:57 +0000
Subject: [PATCH] [mlir] Introduce OpAsmAttrInterface for pretty-print

---
 mlir/include/mlir/IR/CMakeLists.txt           |  2 ++
 mlir/include/mlir/IR/OpAsmInterface.td        | 22 ++++++++++++
 mlir/include/mlir/IR/OpImplementation.h       |  1 +
 mlir/lib/IR/AsmPrinter.cpp                    | 35 ++++++++++++++-----
 mlir/test/IR/op-asm-interface.mlir            | 15 ++++++++
 mlir/test/lib/Dialect/Test/TestAttrDefs.td    | 10 ++++++
 mlir/test/lib/Dialect/Test/TestAttributes.cpp | 13 ++++++-
 7 files changed, 88 insertions(+), 10 deletions(-)

diff --git a/mlir/include/mlir/IR/CMakeLists.txt b/mlir/include/mlir/IR/CMakeLists.txt
index 0c7937dfd69e55..846547ff131e3a 100644
--- a/mlir/include/mlir/IR/CMakeLists.txt
+++ b/mlir/include/mlir/IR/CMakeLists.txt
@@ -2,6 +2,8 @@ add_mlir_interface(SymbolInterfaces)
 add_mlir_interface(RegionKindInterface)
 
 set(LLVM_TARGET_DEFINITIONS OpAsmInterface.td)
+mlir_tablegen(OpAsmAttrInterface.h.inc -gen-attr-interface-decls)
+mlir_tablegen(OpAsmAttrInterface.cpp.inc -gen-attr-interface-defs)
 mlir_tablegen(OpAsmOpInterface.h.inc -gen-op-interface-decls)
 mlir_tablegen(OpAsmOpInterface.cpp.inc -gen-op-interface-defs)
 mlir_tablegen(OpAsmTypeInterface.h.inc -gen-type-interface-decls)
diff --git a/mlir/include/mlir/IR/OpAsmInterface.td b/mlir/include/mlir/IR/OpAsmInterface.td
index 34c830a12856fa..c3e84bccc5dee5 100644
--- a/mlir/include/mlir/IR/OpAsmInterface.td
+++ b/mlir/include/mlir/IR/OpAsmInterface.td
@@ -130,6 +130,28 @@ def OpAsmTypeInterface : TypeInterface<"OpAsmTypeInterface"> {
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// OpAsmAttrInterface
+//===----------------------------------------------------------------------===//
+
+def OpAsmAttrInterface : AttrInterface<"OpAsmAttrInterface"> {
+  let cppNamespace = "::mlir";
+  let description = [{
+    This interface provides hooks to interact with the AsmPrinter and AsmParser
+    classes.
+  }];
+  let methods = [
+    InterfaceMethod<[{
+        Get a name to use when generating an alias for this attribute. The
+        name is written to `os`; returning `NoAlias` (the default) means the
+        attribute itself provides no alias.
+      }],
+      "::mlir::OpAsmDialectInterface::AliasResult", "getAlias",
+      (ins "::llvm::raw_ostream&":$os), /*methodBody=*/"",
+      "return ::mlir::OpAsmDialectInterface::AliasResult::NoAlias;">,
+  ];
+}
+
 //===----------------------------------------------------------------------===//
 // ResourceHandleParameter
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index d9c925a9c56e6c..24e74e636fe51b 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -1824,6 +1824,7 @@ ParseResult parseDimensionList(OpAsmParser &parser,
 //===--------------------------------------------------------------------===//
 
 /// The OpAsmOpInterface, see OpAsmInterface.td for more details.
+#include "mlir/IR/OpAsmAttrInterface.h.inc"
 #include "mlir/IR/OpAsmOpInterface.h.inc"
 #include "mlir/IR/OpAsmTypeInterface.h.inc"
 
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index fa4a1b4b72b024..cfac0a0d2f6474 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -125,6 +125,7 @@ void OpAsmPrinter::printFunctionalType(Operation *op) {
 //===----------------------------------------------------------------------===//
 
 /// The OpAsmOpInterface, see OpAsmInterface.td for more details.
+#include "mlir/IR/OpAsmAttrInterface.cpp.inc"
 #include "mlir/IR/OpAsmOpInterface.cpp.inc"
 #include "mlir/IR/OpAsmTypeInterface.cpp.inc"
 
@@ -1159,15 +1160,31 @@ template <typename T>
 void AliasInitializer::generateAlias(T symbol, InProgressAliasInfo &alias,
                                      bool canBeDeferred) {
   SmallString<32> nameBuffer;
-  for (const auto &interface : interfaces) {
-    OpAsmDialectInterface::AliasResult result =
-        interface.getAlias(symbol, aliasOS);
-    if (result == OpAsmDialectInterface::AliasResult::NoAlias)
-      continue;
-    nameBuffer = std::move(aliasBuffer);
-    assert(!nameBuffer.empty() && "expected valid alias name");
-    if (result == OpAsmDialectInterface::AliasResult::FinalAlias)
-      break;
+
+  OpAsmDialectInterface::AliasResult symbolInterfaceResult =
+      OpAsmDialectInterface::AliasResult::NoAlias;
+  if constexpr (std::is_base_of_v<Attribute, T>) {
+    if (auto symbolInterface = mlir::dyn_cast<OpAsmAttrInterface>(symbol)) {
+      symbolInterfaceResult = symbolInterface.getAlias(aliasOS);
+      if (symbolInterfaceResult !=
+          OpAsmDialectInterface::AliasResult::NoAlias) {
+        nameBuffer = std::move(aliasBuffer);
+        assert(!nameBuffer.empty() && "expected valid alias name");
+      }
+    }
+  }
+
+  if (symbolInterfaceResult != OpAsmDialectInterface::AliasResult::FinalAlias) {
+    for (const auto &interface : interfaces) {
+      OpAsmDialectInterface::AliasResult result =
+          interface.getAlias(symbol, aliasOS);
+      if (result == OpAsmDialectInterface::AliasResult::NoAlias)
+        continue;
+      nameBuffer = std::move(aliasBuffer);
+      assert(!nameBuffer.empty() && "expected valid alias name");
+      if (result == OpAsmDialectInterface::AliasResult::FinalAlias)
+        break;
+    }
   }
 
   if (nameBuffer.empty())
diff --git a/mlir/test/IR/op-asm-interface.mlir b/mlir/test/IR/op-asm-interface.mlir
index a9c199e3dc9736..5aa9bc7c5acaee 100644
--- a/mlir/test/IR/op-asm-interface.mlir
+++ b/mlir/test/IR/op-asm-interface.mlir
@@ -22,3 +22,18 @@ func.func @block_argument_name_from_op_asm_type_interface() {
   }
   return
 }
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Test OpAsmAttrInterface
+//===----------------------------------------------------------------------===//
+
+// CHECK: #op_asm_attr_interface_test = #test.op_asm_attr_interface<value = "test">
+#attr = #test.op_asm_attr_interface<value = "test">
+func.func @test_op_asm_attr_interface() {
+  // CHECK-LABEL: @test_op_asm_attr_interface
+  // CHECK: attr = #op_asm_attr_interface_test
+  %1 = "test.result_name_from_type"() {attr = #attr} : () -> !test.op_asm_type_interface
+  return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
index 0fd272f85d39bd..4b809c1c0a7656 100644
--- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
@@ -395,4 +395,14 @@ def TestCustomLocationAttr : Test_LocAttr<"TestCustomLocation"> {
   let assemblyFormat = "`<` $file `*` $line `>`";
 }
 
+// Test OpAsmAttrInterface.
+def TestOpAsmAttrInterfaceAttr : Test_Attr<"TestOpAsmAttrInterface",
+    [DeclareAttrInterfaceMethods<OpAsmAttrInterface, ["getAlias"]>]> {
+  let mnemonic = "op_asm_attr_interface";
+  // A single string parameter; getAlias() appends it to the alias prefix.
+  let parameters = (ins "mlir::StringAttr":$value);
+  // Parsed and printed as `#test.op_asm_attr_interface<value = "...">`.
+  let assemblyFormat = "`<` struct(params) `>`";
+}
+
 #endif // TEST_ATTRDEFS
diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
index e09ea109061648..7c467308386f1f 100644
--- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
@@ -67,7 +67,7 @@ void CompoundAAttr::print(AsmPrinter &printer) const {
 //===----------------------------------------------------------------------===//
 
 Attribute TestDecimalShapeAttr::parse(AsmParser &parser, Type type) {
-  if (parser.parseLess()){
+  if (parser.parseLess()) {
     return Attribute();
   }
   SmallVector<int64_t> shape;
@@ -316,6 +316,17 @@ static ParseResult parseCustomFloatAttr(AsmParser &p, StringAttr &typeStrAttr,
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// TestOpAsmAttrInterfaceAttr
+//===----------------------------------------------------------------------===//
+
+::mlir::OpAsmDialectInterface::AliasResult
+TestOpAsmAttrInterfaceAttr::getAlias(::llvm::raw_ostream &os) const {
+  // Alias = fixed prefix + the attribute's string payload; final, no override.
+  os << "op_asm_attr_interface_" << getValue().getValue();
+  return ::mlir::OpAsmDialectInterface::AliasResult::FinalAlias;
+}
+
 //===----------------------------------------------------------------------===//
 // Tablegen Generated Definitions
 //===----------------------------------------------------------------------===//



More information about the Mlir-commits mailing list