[Mlir-commits] [mlir] [mlir] Integrate OpAsmTypeInterface with AsmPrinter (PR #124700)
Hongren Zheng
llvmlistbot at llvm.org
Mon Jan 27 22:49:53 PST 2025
https://github.com/ZenithalHourlyRate updated https://github.com/llvm/llvm-project/pull/124700
>From 0b8231328888a5418fb373f5a9750df55d7c7c91 Mon Sep 17 00:00:00 2001
From: Zenithal <i at zenithal.me>
Date: Tue, 28 Jan 2025 05:32:20 +0000
Subject: [PATCH] [mlir] Integrate OpAsmTypeInterface with AsmPrinter
---
mlir/lib/IR/AsmPrinter.cpp | 35 ++++++++++++++++++++++++++
mlir/test/IR/op-asm-interface.mlir | 36 +++++++++++++++++++++++++++
mlir/test/lib/Dialect/Test/TestOps.td | 19 ++++++++++++++
3 files changed, 90 insertions(+)
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index fa4a1b4b72b024..5103a9b51c74c1 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1536,10 +1536,13 @@ StringRef maybeGetValueNameFromLoc(Value value, StringRef name) {
} // namespace
void SSANameState::numberValuesInRegion(Region ®ion) {
+ // indicate whether OpAsmOpInterface set a name.
+ bool opAsmOpInterfaceUsed = false;
auto setBlockArgNameFn = [&](Value arg, StringRef name) {
assert(!valueIDs.count(arg) && "arg numbered multiple times");
assert(llvm::cast<BlockArgument>(arg).getOwner()->getParent() == ®ion &&
"arg not defined in current region");
+ opAsmOpInterfaceUsed = true;
if (LLVM_UNLIKELY(printerFlags.shouldUseNameLocAsPrefix()))
name = maybeGetValueNameFromLoc(arg, name);
setValueName(arg, name);
@@ -1549,6 +1552,18 @@ void SSANameState::numberValuesInRegion(Region &region) {
if (Operation *op = region.getParentOp()) {
if (auto asmInterface = dyn_cast<OpAsmOpInterface>(op))
asmInterface.getAsmBlockArgumentNames(region, setBlockArgNameFn);
+ if (!opAsmOpInterfaceUsed) {
+ // If the OpAsmOpInterface didn't set a name, get name from the type.
+ for (auto arg : region.getArguments()) {
+ if (auto typeInterface =
+ mlir::dyn_cast<OpAsmTypeInterface>(arg.getType())) {
+ auto setNameFn = [&](StringRef name) {
+ setBlockArgNameFn(arg, name);
+ };
+ typeInterface.getAsmName(setNameFn);
+ }
+ }
+ }
}
}
@@ -1598,9 +1613,12 @@ void SSANameState::numberValuesInBlock(Block &block) {
void SSANameState::numberValuesInOp(Operation &op) {
// Function used to set the special result names for the operation.
SmallVector<int, 2> resultGroups(/*Size=*/1, /*Value=*/0);
+ // indicating whether OpAsmOpInterface set a name.
+ bool opAsmOpInterfaceUsed = false;
auto setResultNameFn = [&](Value result, StringRef name) {
assert(!valueIDs.count(result) && "result numbered multiple times");
assert(result.getDefiningOp() == &op && "result not defined by 'op'");
+ opAsmOpInterfaceUsed = true;
if (LLVM_UNLIKELY(printerFlags.shouldUseNameLocAsPrefix()))
name = maybeGetValueNameFromLoc(result, name);
setValueName(result, name);
@@ -1630,6 +1648,23 @@ void SSANameState::numberValuesInOp(Operation &op) {
asmInterface.getAsmBlockNames(setBlockNameFn);
asmInterface.getAsmResultNames(setResultNameFn);
}
+ if (!opAsmOpInterfaceUsed) {
+ // If the OpAsmOpInterface didn't set a name, and
+ // all results have OpAsmTypeInterface, get names from types.
+ bool allHaveOpAsmTypeInterface =
+ llvm::all_of(op.getResultTypes(), [&](Type type) {
+ return mlir::isa<OpAsmTypeInterface>(type);
+ });
+ if (allHaveOpAsmTypeInterface) {
+ for (auto result : op.getResults()) {
+ auto typeInterface = mlir::cast<OpAsmTypeInterface>(result.getType());
+ auto setNameFn = [&](StringRef name) {
+ setResultNameFn(result, name);
+ };
+ typeInterface.getAsmName(setNameFn);
+ }
+ }
+ }
}
unsigned numResults = op.getNumResults();
diff --git a/mlir/test/IR/op-asm-interface.mlir b/mlir/test/IR/op-asm-interface.mlir
index a9c199e3dc9736..fe73750ba0edf5 100644
--- a/mlir/test/IR/op-asm-interface.mlir
+++ b/mlir/test/IR/op-asm-interface.mlir
@@ -22,3 +22,39 @@ func.func @block_argument_name_from_op_asm_type_interface() {
}
return
}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Test OpAsmTypeInterface fallback naming in AsmPrinter
+//===----------------------------------------------------------------------===//
+
+// The test op keeps the default (no-op) OpAsmOpInterface::getAsmResultNames,
+// so the printed result name is expected to come from the result type's
+// OpAsmTypeInterface::getAsmName.
+func.func @result_name_from_op_asm_type_interface_asmprinter() {
+ // CHECK-LABEL: @result_name_from_op_asm_type_interface_asmprinter
+ // CHECK: %op_asm_type_interface
+ %0 = "test.result_name_from_type_interface"() : () -> !test.op_asm_type_interface
+ return
+}
+
+// -----
+
+// Type-derived result names are used only when every result type implements
+// OpAsmTypeInterface. i1 does not, so neither result should get a type-derived
+// name and the results keep the default `%0:2` numbering.
+func.func @result_name_from_op_asm_type_interface_not_all() {
+ // CHECK-LABEL: @result_name_from_op_asm_type_interface_not_all
+ // CHECK-NOT: %op_asm_type_interface
+ // CHECK: %0:2
+ %0:2 = "test.result_name_from_type_interface"() : () -> (!test.op_asm_type_interface, i1)
+ return
+}
+
+// -----
+
+// The test op keeps the default (no-op) implementation of
+// OpAsmOpInterface::getAsmBlockArgumentNames, so the printed block argument
+// name is expected to come from OpAsmTypeInterface::getAsmName on its type.
+func.func @block_argument_name_from_op_asm_type_interface_asmprinter() {
+ // CHECK-LABEL: @block_argument_name_from_op_asm_type_interface_asmprinter
+ // CHECK: ^bb0(%op_asm_type_interface
+ test.block_argument_name_from_type_interface {
+ ^bb0(%arg0: !test.op_asm_type_interface):
+ "test.terminator"() : ()->()
+ }
+ return
+}
\ No newline at end of file
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index f37573c1351cec..4880b6697055be 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -939,6 +939,25 @@ def BlockArgumentNameFromTypeOp
let assemblyFormat = "regions attr-dict-with-keyword";
}
+// This is used to test OpAsmTypeInterface::getAsmName's integration with
+// AsmPrinter for op result names. The op keeps the default implementation of
+// OpAsmOpInterface::getAsmResultNames (which does nothing), so result names
+// must come from the result types. Results are variadic so a test can mix
+// types that do and do not implement OpAsmTypeInterface.
+def ResultNameFromTypeInterfaceOp
+ : TEST_Op<"result_name_from_type_interface",
+ [OpAsmOpInterface]> {
+ let results = (outs Variadic<AnyType>:$r);
+}
+
+// This is used to test OpAsmTypeInterface::getAsmName's integration with
+// AsmPrinter for block argument names. The op keeps the default
+// implementation of OpAsmOpInterface::getAsmBlockArgumentNames (which does
+// nothing), so block argument names must come from the argument types.
+def BlockArgumentNameFromTypeInterfaceOp
+ : TEST_Op<"block_argument_name_from_type_interface",
+ [OpAsmOpInterface]> {
+ let regions = (region AnyRegion:$body);
+ let assemblyFormat = "regions attr-dict-with-keyword";
+}
+
// This is used to test the OpAsmOpInterface::getDefaultDialect() feature:
// operations nested in a region under this op will drop the "test." dialect
// prefix.
More information about the Mlir-commits
mailing list