[Mlir-commits] [mlir] [mlir] Add OpAsmTypeInterface for pretty-print (PR #121187)
Hongren Zheng
llvmlistbot at llvm.org
Mon Jan 13 19:59:07 PST 2025
https://github.com/ZenithalHourlyRate updated https://github.com/llvm/llvm-project/pull/121187
>From b5ae5e42a05290a05d659e5ee53710dd965ba88e Mon Sep 17 00:00:00 2001
From: Zenithal <i at zenithal.me>
Date: Sun, 22 Dec 2024 11:20:22 +0000
Subject: [PATCH] [mlir] Add OpAsmTypeInterface for pretty-print
---
mlir/include/mlir/IR/CMakeLists.txt | 9 +++++-
mlir/include/mlir/IR/OpAsmInterface.td | 23 +++++++++++++++-
mlir/include/mlir/IR/OpImplementation.h | 11 ++++++--
mlir/lib/IR/AsmPrinter.cpp | 3 +-
mlir/test/IR/op-asm-interface.mlir | 24 ++++++++++++++++
mlir/test/lib/Dialect/Test/TestOpDefs.cpp | 32 ++++++++++++++++++++++
mlir/test/lib/Dialect/Test/TestOps.td | 15 ++++++++++
mlir/test/lib/Dialect/Test/TestTypeDefs.td | 5 ++++
mlir/test/lib/Dialect/Test/TestTypes.cpp | 5 ++++
9 files changed, 121 insertions(+), 6 deletions(-)
create mode 100644 mlir/test/IR/op-asm-interface.mlir
diff --git a/mlir/include/mlir/IR/CMakeLists.txt b/mlir/include/mlir/IR/CMakeLists.txt
index b741eb18d47916..0c7937dfd69e55 100644
--- a/mlir/include/mlir/IR/CMakeLists.txt
+++ b/mlir/include/mlir/IR/CMakeLists.txt
@@ -1,7 +1,14 @@
-add_mlir_interface(OpAsmInterface)
add_mlir_interface(SymbolInterfaces)
add_mlir_interface(RegionKindInterface)
+set(LLVM_TARGET_DEFINITIONS OpAsmInterface.td) # Declares OpAsmOpInterface and OpAsmTypeInterface.
+mlir_tablegen(OpAsmOpInterface.h.inc -gen-op-interface-decls)
+mlir_tablegen(OpAsmOpInterface.cpp.inc -gen-op-interface-defs)
+mlir_tablegen(OpAsmTypeInterface.h.inc -gen-type-interface-decls)
+mlir_tablegen(OpAsmTypeInterface.cpp.inc -gen-type-interface-defs)
+add_public_tablegen_target(MLIROpAsmInterfaceIncGen)
+add_dependencies(mlir-generic-headers MLIROpAsmInterfaceIncGen)
+
set(LLVM_TARGET_DEFINITIONS BuiltinAttributes.td)
mlir_tablegen(BuiltinAttributes.h.inc -gen-attrdef-decls)
mlir_tablegen(BuiltinAttributes.cpp.inc -gen-attrdef-defs)
diff --git a/mlir/include/mlir/IR/OpAsmInterface.td b/mlir/include/mlir/IR/OpAsmInterface.td
index 98b5095ff2d665..f7384feeef0bd2 100644
--- a/mlir/include/mlir/IR/OpAsmInterface.td
+++ b/mlir/include/mlir/IR/OpAsmInterface.td
@@ -23,7 +23,7 @@ include "mlir/IR/OpBase.td"
def OpAsmOpInterface : OpInterface<"OpAsmOpInterface"> {
let description = [{
- This interface provides hooks to interact with the AsmPrinter and AsmParser
+ This op interface provides hooks to interact with the AsmPrinter and AsmParser
classes.
}];
let cppNamespace = "::mlir";
@@ -109,6 +109,27 @@ def OpAsmOpInterface : OpInterface<"OpAsmOpInterface"> {
];
}
+//===----------------------------------------------------------------------===//
+// OpAsmTypeInterface
+//===----------------------------------------------------------------------===//
+
+def OpAsmTypeInterface : TypeInterface<"OpAsmTypeInterface"> {
+  let description = [{
+    This type interface provides hooks to interact with the AsmPrinter and
+    AsmParser classes.
+  }];
+  let cppNamespace = "::mlir";
+
+  let methods = [
+    InterfaceMethod<[{
+        Get a name to use when printing a value of this type (no-op by default).
+      }],
+      "void", "getAsmName",
+      (ins "::mlir::OpAsmSetNameFn":$setNameFn), "", ";"
+    >,
+  ];
+}
+
//===----------------------------------------------------------------------===//
// ResourceHandleParameter
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 6c1ff4d0e5e6b9..d9c925a9c56e6c 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -734,7 +734,7 @@ class AsmParser {
virtual OptionalParseResult parseOptionalInteger(APInt &result) = 0;
virtual OptionalParseResult parseOptionalDecimalInteger(APInt &result) = 0;
- private:
+private:
template <typename IntT, typename ParseFn>
OptionalParseResult parseOptionalIntegerAndCheck(IntT &result,
ParseFn &&parseFn) {
@@ -756,7 +756,7 @@ class AsmParser {
return success();
}
- public:
+public:
template <typename IntT>
OptionalParseResult parseOptionalInteger(IntT &result) {
return parseOptionalIntegerAndCheck(
@@ -1727,6 +1727,10 @@ class OpAsmParser : public AsmParser {
// Dialect OpAsm interface.
//===--------------------------------------------------------------------===//
+/// A functor used to set a single name that is not tied to a Value, e.g. by
+/// OpAsmTypeInterface::getAsmName. See OpAsmInterface.td for more details.
+using OpAsmSetNameFn = function_ref<void(StringRef)>;
+
/// A functor used to set the name of the start of a result group of an
/// operation. See 'getAsmResultNames' below for more details.
using OpAsmSetValueNameFn = function_ref<void(Value, StringRef)>;
@@ -1820,7 +1824,8 @@ ParseResult parseDimensionList(OpAsmParser &parser,
//===--------------------------------------------------------------------===//
/// The OpAsmOpInterface, see OpAsmInterface.td for more details.
-#include "mlir/IR/OpAsmInterface.h.inc"
+#include "mlir/IR/OpAsmOpInterface.h.inc"
+#include "mlir/IR/OpAsmTypeInterface.h.inc"
namespace llvm {
template <>
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index c603db450cbdd0..fa4a1b4b72b024 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -125,7 +125,8 @@ void OpAsmPrinter::printFunctionalType(Operation *op) {
//===----------------------------------------------------------------------===//
/// The OpAsmOpInterface, see OpAsmInterface.td for more details.
-#include "mlir/IR/OpAsmInterface.cpp.inc"
+#include "mlir/IR/OpAsmOpInterface.cpp.inc"
+#include "mlir/IR/OpAsmTypeInterface.cpp.inc"
LogicalResult
OpAsmDialectInterface::parseResource(AsmParsedResourceEntry &entry) const {
diff --git a/mlir/test/IR/op-asm-interface.mlir b/mlir/test/IR/op-asm-interface.mlir
new file mode 100644
index 00000000000000..a9c199e3dc9736
--- /dev/null
+++ b/mlir/test/IR/op-asm-interface.mlir
@@ -0,0 +1,24 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// Test OpAsmTypeInterface
+//===----------------------------------------------------------------------===//
+
+func.func @result_name_from_op_asm_type_interface() {
+  // CHECK-LABEL: @result_name_from_op_asm_type_interface
+  // CHECK: %op_asm_type_interface = "test.result_name_from_type"
+  %0 = "test.result_name_from_type"() : () -> !test.op_asm_type_interface
+  return
+}
+
+// -----
+
+func.func @block_argument_name_from_op_asm_type_interface() {
+  // CHECK-LABEL: @block_argument_name_from_op_asm_type_interface
+  // CHECK: ^bb0(%op_asm_type_interface: !test.op_asm_type_interface):
+  test.block_argument_name_from_type {
+  ^bb0(%arg0: !test.op_asm_type_interface):
+    "test.terminator"() : () -> ()
+  }
+  return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
index c6be26d0a44d9a..f6b8a0005f2854 100644
--- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
@@ -506,6 +506,38 @@ void CustomResultsNameOp::getAsmResultNames(
setNameFn(getResult(i), str.getValue());
}
+//===----------------------------------------------------------------------===//
+// ResultNameFromTypeOp
+//===----------------------------------------------------------------------===//
+
+void ResultNameFromTypeOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  // Ask the result type for a name via OpAsmTypeInterface. The result is
+  // AnyType (see TestOps.td), so types without the interface are skipped.
+  auto result = getResult();
+  if (auto opAsmTypeInterface =
+          ::mlir::dyn_cast<::mlir::OpAsmTypeInterface>(result.getType()))
+    opAsmTypeInterface.getAsmName(
+        [&](::llvm::StringRef name) { setNameFn(result, name); });
+}
+
+//===----------------------------------------------------------------------===//
+// BlockArgumentNameFromTypeOp
+//===----------------------------------------------------------------------===//
+
+void BlockArgumentNameFromTypeOp::getAsmBlockArgumentNames(
+    ::mlir::Region &region, ::mlir::OpAsmSetValueNameFn setNameFn) {
+  // Ask each block argument's type (in every block of the region) for a name;
+  // arguments whose type lacks OpAsmTypeInterface are not named here.
+  for (auto &block : region) {
+    for (auto arg : block.getArguments())
+      if (auto opAsmTypeInterface =
+              ::mlir::dyn_cast<::mlir::OpAsmTypeInterface>(arg.getType()))
+        opAsmTypeInterface.getAsmName(
+            [&](::llvm::StringRef name) { setNameFn(arg, name); });
+  }
+}
+
//===----------------------------------------------------------------------===//
// ResultTypeWithTraitOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 0b1f22b3ee9323..f37573c1351cec 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -924,6 +924,21 @@ def CustomResultsNameOp
let results = (outs Variadic<AnyInteger>:$r);
}
+// Used to test OpAsmTypeInterface::getAsmName when naming an op result.
+def ResultNameFromTypeOp
+  : TEST_Op<"result_name_from_type",
+    [DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
+  let results = (outs AnyType:$r);
+}
+
+// Used to test OpAsmTypeInterface::getAsmName when naming block arguments.
+def BlockArgumentNameFromTypeOp
+  : TEST_Op<"block_argument_name_from_type",
+    [DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmBlockArgumentNames"]>]> {
+  let regions = (region AnyRegion:$body);
+  let assemblyFormat = "regions attr-dict-with-keyword";
+}
+
// This is used to test the OpAsmOpInterface::getDefaultDialect() feature:
// operations nested in a region under this op will drop the "test." dialect
// prefix.
diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
index 60108ac86d1edd..6335701786ecc6 100644
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -398,4 +398,9 @@ def TestTypeVerification : Test_Type<"TestTypeVerification"> {
let assemblyFormat = "`<` $param `>`";
}
+def TestTypeOpAsmTypeInterface : Test_Type<"TestTypeOpAsmTypeInterface",
+    [DeclareTypeInterfaceMethods<OpAsmTypeInterface, ["getAsmName"]>]> {
+  let mnemonic = "op_asm_type_interface"; // getAsmName: see TestTypes.cpp.
+}
+
#endif // TEST_TYPEDEFS
diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp
index b822e019e09d24..1ae7ac472d989e 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp
@@ -532,3 +532,8 @@ void TestRecursiveAliasType::print(AsmPrinter &printer) const {
}
printer << ">";
}
+
+void TestTypeOpAsmTypeInterfaceType::getAsmName(
+    OpAsmSetNameFn setName) const {
+  setName("op_asm_type_interface"); // Same name suggested for every value.
+}
More information about the Mlir-commits
mailing list