[Mlir-commits] [mlir] [mlir] Add OpAsmTypeInterface for pretty-print (PR #121187)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Dec 27 00:11:34 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-ods
Author: Hongren Zheng (ZenithalHourlyRate)
<details>
<summary>Changes</summary>
See https://discourse.llvm.org/t/rfc-introduce-opasm-type-attr-interface-for-pretty-print-in-asmprinter/83792 for a detailed introduction.
This PR implements the first part of it:
* Add `OpAsmTypeInterface` and `getAsmName` API for deducing ASM name from type
* Add default impl in `OpAsmOpInterface` to respect this API when available.
The `OpAsmAttrInterface` and the hook into the alias system will be added in a separate PR, using a `getAlias` API.
### Discussion
* Instead of using `StringRef getAsmName()` as the API, I use `void getAsmName(OpAsmSetNameFn)`, because returning a `StringRef` might be unsafe (an implementation could construct a `std::string` internally and then return a dangling reference to it); this also aligns with the design of `getAsmResultNames`.
* On result packing for an op: when not all of the result types implement `OpAsmTypeInterface`, the default implementation does nothing (the same as the old default implementation).
### Review
Cc @<!-- -->j2kun and @<!-- -->Alexanderviand-intel for downstream; Cc @<!-- -->River707 and @<!-- -->joker-eph for relevant commit history; Cc @<!-- -->ftynse for discourse.
---
Full diff: https://github.com/llvm/llvm-project/pull/121187.diff
8 Files Affected:
- (modified) mlir/include/mlir/IR/CMakeLists.txt (+8-1)
- (modified) mlir/include/mlir/IR/OpAsmInterface.td (+53-2)
- (modified) mlir/include/mlir/IR/OpImplementation.h (+9-3)
- (modified) mlir/lib/IR/AsmPrinter.cpp (+2-1)
- (added) mlir/test/IR/op-asm-interface.mlir (+43)
- (modified) mlir/test/lib/Dialect/Test/TestOps.td (+24)
- (modified) mlir/test/lib/Dialect/Test/TestTypeDefs.td (+4)
- (modified) mlir/test/lib/Dialect/Test/TestTypes.cpp (+5)
``````````diff
diff --git a/mlir/include/mlir/IR/CMakeLists.txt b/mlir/include/mlir/IR/CMakeLists.txt
index b741eb18d47916..0c7937dfd69e55 100644
--- a/mlir/include/mlir/IR/CMakeLists.txt
+++ b/mlir/include/mlir/IR/CMakeLists.txt
@@ -1,7 +1,14 @@
-add_mlir_interface(OpAsmInterface)
add_mlir_interface(SymbolInterfaces)
add_mlir_interface(RegionKindInterface)
+set(LLVM_TARGET_DEFINITIONS OpAsmInterface.td)
+mlir_tablegen(OpAsmOpInterface.h.inc -gen-op-interface-decls)
+mlir_tablegen(OpAsmOpInterface.cpp.inc -gen-op-interface-defs)
+mlir_tablegen(OpAsmTypeInterface.h.inc -gen-type-interface-decls)
+mlir_tablegen(OpAsmTypeInterface.cpp.inc -gen-type-interface-defs)
+add_public_tablegen_target(MLIROpAsmInterfaceIncGen)
+add_dependencies(mlir-generic-headers MLIROpAsmInterfaceIncGen)
+
set(LLVM_TARGET_DEFINITIONS BuiltinAttributes.td)
mlir_tablegen(BuiltinAttributes.h.inc -gen-attrdef-decls)
mlir_tablegen(BuiltinAttributes.cpp.inc -gen-attrdef-defs)
diff --git a/mlir/include/mlir/IR/OpAsmInterface.td b/mlir/include/mlir/IR/OpAsmInterface.td
index 98b5095ff2d665..d2dfb60b2ac142 100644
--- a/mlir/include/mlir/IR/OpAsmInterface.td
+++ b/mlir/include/mlir/IR/OpAsmInterface.td
@@ -50,10 +50,27 @@ def OpAsmOpInterface : OpInterface<"OpAsmOpInterface"> {
```mlir
%first_result, %middle_results:2, %0 = "my.op" ...
```
+
+ The default implementation uses `OpAsmTypeInterface` to get the name for
+ each result from its type.
+
+ If not all of the result types implement `OpAsmTypeInterface`, the default
+ implementation does nothing, as the packing behavior should be decided by the operation itself.
}],
"void", "getAsmResultNames",
(ins "::mlir::OpAsmSetValueNameFn":$setNameFn),
- "", "return;"
+ "", [{
+ bool allResultsHaveAsmName = ::llvm::all_of($_op->getResults(), [&](::mlir::Value result) {
+ return ::mlir::isa<::mlir::OpAsmTypeInterface>(result.getType());
+ });
+ if (!allResultsHaveAsmName)
+ return;
+ for (auto result : $_op->getResults()) {
+ auto setResultNameFn = [&](::llvm::StringRef name) { setNameFn(result, name); };
+ auto opAsmTypeInterface = ::mlir::cast<::mlir::OpAsmTypeInterface>(result.getType());
+ opAsmTypeInterface.getAsmName(setResultNameFn);
+ }
+ }]
>,
InterfaceMethod<[{
Get a special name to use when printing the block arguments for a region
@@ -64,7 +81,16 @@ def OpAsmOpInterface : OpInterface<"OpAsmOpInterface"> {
"::mlir::Region&":$region,
"::mlir::OpAsmSetValueNameFn":$setNameFn
),
- "", "return;"
+ "", [{
+ for (auto &block : region) {
+ for (auto arg : block.getArguments()) {
+ if (auto opAsmTypeInterface = ::mlir::dyn_cast<::mlir::OpAsmTypeInterface>(arg.getType())) {
+ auto setArgNameFn = [&](::llvm::StringRef name) { setNameFn(arg, name); };
+ opAsmTypeInterface.getAsmName(setArgNameFn);
+ }
+ }
+ }
+ }]
>,
InterfaceMethod<[{
Get the name to use for a given block inside a region attached to this
@@ -109,6 +135,31 @@ def OpAsmOpInterface : OpInterface<"OpAsmOpInterface"> {
];
}
+//===----------------------------------------------------------------------===//
+// OpAsmTypeInterface
+//===----------------------------------------------------------------------===//
+
+def OpAsmTypeInterface : TypeInterface<"OpAsmTypeInterface"> {
+ let description = [{
+ This interface provides hooks to interact with the AsmPrinter and AsmParser
+ classes.
+ }];
+ let cppNamespace = "::mlir";
+
+ let methods = [
+ InterfaceMethod<[{
+ Get a special name to use when printing values of this type.
+
+ For example, the default implementations of OpAsmOpInterface's
+ getAsmResultNames and getAsmBlockArgumentNames use this method to
+ name operation results and region block arguments.
+ }],
+ "void", "getAsmName",
+ (ins "::mlir::OpAsmSetNameFn":$setNameFn), "", ";"
+ >,
+ ];
+}
+
//===----------------------------------------------------------------------===//
// ResourceHandleParameter
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 6c1ff4d0e5e6b9..83e2758a2e782f 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -734,7 +734,7 @@ class AsmParser {
virtual OptionalParseResult parseOptionalInteger(APInt &result) = 0;
virtual OptionalParseResult parseOptionalDecimalInteger(APInt &result) = 0;
- private:
+private:
template <typename IntT, typename ParseFn>
OptionalParseResult parseOptionalIntegerAndCheck(IntT &result,
ParseFn &&parseFn) {
@@ -756,7 +756,7 @@ class AsmParser {
return success();
}
- public:
+public:
template <typename IntT>
OptionalParseResult parseOptionalInteger(IntT &result) {
return parseOptionalIntegerAndCheck(
@@ -1727,6 +1727,10 @@ class OpAsmParser : public AsmParser {
// Dialect OpAsm interface.
//===--------------------------------------------------------------------===//
+/// A functor used to set a single suggested name for a value; the value is
+/// implied by the caller. See 'OpAsmTypeInterface::getAsmName' for details.
+using OpAsmSetNameFn = function_ref<void(StringRef)>;
+
/// A functor used to set the name of the start of a result group of an
/// operation. See 'getAsmResultNames' below for more details.
using OpAsmSetValueNameFn = function_ref<void(Value, StringRef)>;
@@ -1820,7 +1824,9 @@ ParseResult parseDimensionList(OpAsmParser &parser,
//===--------------------------------------------------------------------===//
/// The OpAsmOpInterface, see OpAsmInterface.td for more details.
-#include "mlir/IR/OpAsmInterface.h.inc"
+#include "mlir/IR/OpAsmTypeInterface.h.inc"
+// Type interface first: OpAsmOpInterface's default impls reference it.
+#include "mlir/IR/OpAsmOpInterface.h.inc"
namespace llvm {
template <>
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 6fe96504ae100c..9eddcb1a0872ba 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -125,7 +125,8 @@ void OpAsmPrinter::printFunctionalType(Operation *op) {
//===----------------------------------------------------------------------===//
/// The OpAsmOpInterface, see OpAsmInterface.td for more details.
-#include "mlir/IR/OpAsmInterface.cpp.inc"
+#include "mlir/IR/OpAsmOpInterface.cpp.inc"
+#include "mlir/IR/OpAsmTypeInterface.cpp.inc"
LogicalResult
OpAsmDialectInterface::parseResource(AsmParsedResourceEntry &entry) const {
diff --git a/mlir/test/IR/op-asm-interface.mlir b/mlir/test/IR/op-asm-interface.mlir
new file mode 100644
index 00000000000000..9753fd419cb609
--- /dev/null
+++ b/mlir/test/IR/op-asm-interface.mlir
@@ -0,0 +1,43 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// Test OpAsmOpInterface
+//===----------------------------------------------------------------------===//
+
+func.func @result_name_from_op_asm_type_interface() {
+ // CHECK-LABEL: @result_name_from_op_asm_type_interface
+ // CHECK: %op_asm_type_interface
+ %0 = "test.default_result_name"() : () -> !test.op_asm_type_interface
+ return
+}
+
+// -----
+
+func.func @result_name_pack_from_op_asm_type_interface() {
+ // CHECK-LABEL: @result_name_pack_from_op_asm_type_interface
+ // CHECK: %op_asm_type_interface{{.*}}, %op_asm_type_interface{{.*}}
+ // CHECK-NOT: :2
+ %0:2 = "test.default_result_name_packing"() : () -> (!test.op_asm_type_interface, !test.op_asm_type_interface)
+ return
+}
+
+// -----
+
+func.func @result_name_pack_do_nothing() {
+ // CHECK-LABEL: @result_name_pack_do_nothing
+ // CHECK: %0:2
+ %0:2 = "test.default_result_name_packing"() : () -> (i32, !test.op_asm_type_interface)
+ return
+}
+
+// -----
+
+func.func @block_argument_name_from_op_asm_type_interface() {
+ // CHECK-LABEL: @block_argument_name_from_op_asm_type_interface
+ // CHECK: ^bb0(%op_asm_type_interface
+ test.default_block_argument_name {
+ ^bb0(%arg0: !test.op_asm_type_interface):
+ "test.terminator"() : ()->()
+ }
+ return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index bafab155eb9d57..72047c0e9aeb9f 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -924,6 +924,30 @@ def CustomResultsNameOp
let results = (outs Variadic<AnyInteger>:$r);
}
+// This is used to test default implementation of OpAsmOpInterface::getAsmResultNames,
+// which uses OpAsmTypeInterface if available.
+def DefaultResultsNameOp
+ : TEST_Op<"default_result_name",
+ [OpAsmOpInterface]> {
+ let results = (outs AnyType:$r);
+}
+
+// This is used to test default implementation of OpAsmOpInterface::getAsmResultNames
+// with multiple results: names come from OpAsmTypeInterface only when every result
+// type implements it; otherwise no result names are set.
+def DefaultResultsNamePackingOp
+ : TEST_Op<"default_result_name_packing",
+ [OpAsmOpInterface]> {
+ let results = (outs AnyType:$r, AnyType:$s);
+}
+
+// This is used to test default implementation of OpAsmOpInterface::getAsmBlockArgumentNames.
+def DefaultBlockArgumentNameOp : TEST_Op<"default_block_argument_name",
+ [OpAsmOpInterface]> {
+ let regions = (region AnyRegion:$body);
+ let assemblyFormat = "regions attr-dict-with-keyword";
+}
+
// This is used to test the OpAsmOpInterface::getDefaultDialect() feature:
// operations nested in a region under this op will drop the "test." dialect
// prefix.
diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
index 60108ac86d1edd..756552a7ebd63e 100644
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -398,4 +398,8 @@ def TestTypeVerification : Test_Type<"TestTypeVerification"> {
let assemblyFormat = "`<` $param `>`";
}
+def TestTypeOpAsmTypeInterface : Test_Type<"TestTypeOpAsmTypeInterface", [DeclareTypeInterfaceMethods<OpAsmTypeInterface, ["getAsmName"]>]> {
+ let mnemonic = "op_asm_type_interface";
+}
+
#endif // TEST_TYPEDEFS
diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp
index 6e31bb71d04d80..8ab2a16338b116 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp
@@ -531,3 +531,8 @@ void TestRecursiveAliasType::print(AsmPrinter &printer) const {
}
printer << ">";
}
+
+void TestTypeOpAsmTypeInterfaceType::getAsmName(
+ OpAsmSetNameFn setNameFn) const {
+ setNameFn("op_asm_type_interface");
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/121187
More information about the Mlir-commits
mailing list