[Mlir-commits] [mlir] [mlir] Add getAlias for OpAsmTypeInterface (PR #126364)
Hongren Zheng
llvmlistbot at llvm.org
Tue Feb 18 10:25:43 PST 2025
https://github.com/ZenithalHourlyRate updated https://github.com/llvm/llvm-project/pull/126364
>From 70b64caa4d95f179b9b3d4362a1c4f1bbae0a262 Mon Sep 17 00:00:00 2001
From: Zenithal <i at zenithal.me>
Date: Sat, 8 Feb 2025 09:55:11 +0000
Subject: [PATCH] [mlir] Add getAlias for OpAsmTypeInterface
---
mlir/include/mlir/IR/OpAsmInterface.td | 7 +++++++
mlir/lib/IR/AsmPrinter.cpp | 15 +++++++--------
mlir/test/IR/op-asm-interface.mlir | 10 ++++++++++
mlir/test/lib/Dialect/Test/TestTypeDefs.td | 2 +-
mlir/test/lib/Dialect/Test/TestTypes.cpp | 6 ++++++
5 files changed, 31 insertions(+), 9 deletions(-)
diff --git a/mlir/include/mlir/IR/OpAsmInterface.td b/mlir/include/mlir/IR/OpAsmInterface.td
index c3e84bccc5dee..1bd8eb04714c5 100644
--- a/mlir/include/mlir/IR/OpAsmInterface.td
+++ b/mlir/include/mlir/IR/OpAsmInterface.td
@@ -127,6 +127,13 @@ def OpAsmTypeInterface : TypeInterface<"OpAsmTypeInterface"> {
"void", "getAsmName",
(ins "::mlir::OpAsmSetNameFn":$setNameFn), "", ";"
>,
+ InterfaceMethod<[{
+ Get a name to use when generating an alias for this type.
+ }],
+ "::mlir::OpAsmDialectInterface::AliasResult", "getAlias",
+ (ins "::llvm::raw_ostream&":$os), "",
+ "return ::mlir::OpAsmDialectInterface::AliasResult::NoAlias;"
+ >,
];
}
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index cc578eae3ee36..1f22d4f37a813 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1163,14 +1163,13 @@ void AliasInitializer::generateAlias(T symbol, InProgressAliasInfo &alias,
OpAsmDialectInterface::AliasResult symbolInterfaceResult =
OpAsmDialectInterface::AliasResult::NoAlias;
- if constexpr (std::is_base_of_v<Attribute, T>) {
- if (auto symbolInterface = dyn_cast<OpAsmAttrInterface>(symbol)) {
- symbolInterfaceResult = symbolInterface.getAlias(aliasOS);
- if (symbolInterfaceResult !=
- OpAsmDialectInterface::AliasResult::NoAlias) {
- nameBuffer = std::move(aliasBuffer);
- assert(!nameBuffer.empty() && "expected valid alias name");
- }
+ using InterfaceT = std::conditional_t<std::is_base_of_v<Attribute, T>,
+ OpAsmAttrInterface, OpAsmTypeInterface>;
+ if (auto symbolInterface = dyn_cast<InterfaceT>(symbol)) {
+ symbolInterfaceResult = symbolInterface.getAlias(aliasOS);
+ if (symbolInterfaceResult != OpAsmDialectInterface::AliasResult::NoAlias) {
+ nameBuffer = std::move(aliasBuffer);
+ assert(!nameBuffer.empty() && "expected valid alias name");
}
}
diff --git a/mlir/test/IR/op-asm-interface.mlir b/mlir/test/IR/op-asm-interface.mlir
index 44a6e7afece03..086dc7da421c2 100644
--- a/mlir/test/IR/op-asm-interface.mlir
+++ b/mlir/test/IR/op-asm-interface.mlir
@@ -61,6 +61,16 @@ func.func @block_argument_name_from_op_asm_type_interface_asmprinter() {
// -----
+// CHECK: !op_asm_type_interface_type =
+!type = !test.op_asm_type_interface
+
+func.func @alias_from_op_asm_type_interface() {
+ %0 = "test.result_name_from_type"() : () -> !type
+ return
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// Test OpAsmAttrInterface
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
index 6335701786ecc..c048f8b654ec2 100644
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -399,7 +399,7 @@ def TestTypeVerification : Test_Type<"TestTypeVerification"> {
}
def TestTypeOpAsmTypeInterface : Test_Type<"TestTypeOpAsmTypeInterface",
- [DeclareTypeInterfaceMethods<OpAsmTypeInterface, ["getAsmName"]>]> {
+ [DeclareTypeInterfaceMethods<OpAsmTypeInterface, ["getAsmName", "getAlias"]>]> {
let mnemonic = "op_asm_type_interface";
}
diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp
index 1ae7ac472d989..0c237440834ef 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp
@@ -537,3 +537,9 @@ void TestTypeOpAsmTypeInterfaceType::getAsmName(
OpAsmSetNameFn setNameFn) const {
setNameFn("op_asm_type_interface");
}
+
+::mlir::OpAsmDialectInterface::AliasResult
+TestTypeOpAsmTypeInterfaceType::getAlias(::llvm::raw_ostream &os) const {
+ os << "op_asm_type_interface_type";
+ return ::mlir::OpAsmDialectInterface::AliasResult::FinalAlias;
+}
More information about the Mlir-commits
mailing list