[Mlir-commits] [mlir] [mlir] Integrate OpAsmTypeInterface with AsmPrinter (PR #124700)
Hongren Zheng
llvmlistbot at llvm.org
Tue Feb 4 19:38:55 PST 2025
https://github.com/ZenithalHourlyRate updated https://github.com/llvm/llvm-project/pull/124700
>From d6c9bbf62107a801ff7701cff2c75353a87b1119 Mon Sep 17 00:00:00 2001
From: Zenithal <i at zenithal.me>
Date: Tue, 28 Jan 2025 05:32:20 +0000
Subject: [PATCH] [mlir] Integrate OpAsmTypeInterface with AsmPrinter
---
mlir/lib/IR/AsmPrinter.cpp | 30 ++++++++++++++++++++++
mlir/test/IR/op-asm-interface.mlir | 36 +++++++++++++++++++++++++++
mlir/test/lib/Dialect/Test/TestOps.td | 19 ++++++++++++++
3 files changed, 85 insertions(+)
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index eea4f7fa5c4be11..9f17699be29e5dc 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1536,10 +1536,13 @@ StringRef maybeGetValueNameFromLoc(Value value, StringRef name) {
} // namespace
void SSANameState::numberValuesInRegion(Region &region) {
+ // Indicates whether OpAsmOpInterface set a name.
+ bool opAsmOpInterfaceUsed = false;
auto setBlockArgNameFn = [&](Value arg, StringRef name) {
assert(!valueIDs.count(arg) && "arg numbered multiple times");
assert(llvm::cast<BlockArgument>(arg).getOwner()->getParent() == &region &&
"arg not defined in current region");
+ opAsmOpInterfaceUsed = true;
if (LLVM_UNLIKELY(printerFlags.shouldUseNameLocAsPrefix()))
name = maybeGetValueNameFromLoc(arg, name);
setValueName(arg, name);
@@ -1549,6 +1552,15 @@ void SSANameState::numberValuesInRegion(Region &region) {
if (Operation *op = region.getParentOp()) {
if (auto asmInterface = dyn_cast<OpAsmOpInterface>(op))
asmInterface.getAsmBlockArgumentNames(region, setBlockArgNameFn);
+ // If the OpAsmOpInterface didn't set a name, get name from the type.
+ if (!opAsmOpInterfaceUsed) {
+ for (BlockArgument arg : region.getArguments()) {
+ if (auto interface = dyn_cast<OpAsmTypeInterface>(arg.getType())) {
+ interface.getAsmName(
+ [&](StringRef name) { setBlockArgNameFn(arg, name); });
+ }
+ }
+ }
}
}
@@ -1598,9 +1610,12 @@ void SSANameState::numberValuesInBlock(Block &block) {
void SSANameState::numberValuesInOp(Operation &op) {
// Function used to set the special result names for the operation.
SmallVector<int, 2> resultGroups(/*Size=*/1, /*Value=*/0);
+ // Indicates whether OpAsmOpInterface set a name.
+ bool opAsmOpInterfaceUsed = false;
auto setResultNameFn = [&](Value result, StringRef name) {
assert(!valueIDs.count(result) && "result numbered multiple times");
assert(result.getDefiningOp() == &op && "result not defined by 'op'");
+ opAsmOpInterfaceUsed = true;
if (LLVM_UNLIKELY(printerFlags.shouldUseNameLocAsPrefix()))
name = maybeGetValueNameFromLoc(result, name);
setValueName(result, name);
@@ -1630,6 +1645,21 @@ void SSANameState::numberValuesInOp(Operation &op) {
asmInterface.getAsmBlockNames(setBlockNameFn);
asmInterface.getAsmResultNames(setResultNameFn);
}
+ if (!opAsmOpInterfaceUsed) {
+ // If the OpAsmOpInterface didn't set a name, and all results have
+ // OpAsmTypeInterface, get names from types.
+ bool allHaveOpAsmTypeInterface =
+ llvm::all_of(op.getResultTypes(), [&](Type type) {
+ return isa<OpAsmTypeInterface>(type);
+ });
+ if (allHaveOpAsmTypeInterface) {
+ for (OpResult result : op.getResults()) {
+ auto interface = cast<OpAsmTypeInterface>(result.getType());
+ interface.getAsmName(
+ [&](StringRef name) { setResultNameFn(result, name); });
+ }
+ }
+ }
}
unsigned numResults = op.getNumResults();
diff --git a/mlir/test/IR/op-asm-interface.mlir b/mlir/test/IR/op-asm-interface.mlir
index a9c199e3dc97364..6cce9572f4fc26d 100644
--- a/mlir/test/IR/op-asm-interface.mlir
+++ b/mlir/test/IR/op-asm-interface.mlir
@@ -22,3 +22,39 @@ func.func @block_argument_name_from_op_asm_type_interface() {
}
return
}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Test OpAsmTypeInterface
+//===----------------------------------------------------------------------===//
+
+func.func @result_name_from_op_asm_type_interface_asmprinter() {  // Single result, named from its type.
+  // CHECK-LABEL: @result_name_from_op_asm_type_interface_asmprinter
+  // CHECK: %op_asm_type_interface
+  %0 = "test.result_name_from_type_interface"() : () -> !test.op_asm_type_interface
+  return
+}
+
+// -----
+
+// i1 does not implement OpAsmTypeInterface, so neither result is named from its type.
+func.func @result_name_from_op_asm_type_interface_not_all() {
+  // CHECK-LABEL: @result_name_from_op_asm_type_interface_not_all
+  // CHECK-NOT: %op_asm_type_interface
+  // CHECK: %0:2
+  %0:2 = "test.result_name_from_type_interface"() : () -> (!test.op_asm_type_interface, i1)
+  return
+}
+
+// -----
+
+func.func @block_argument_name_from_op_asm_type_interface_asmprinter() {
+  // CHECK-LABEL: @block_argument_name_from_op_asm_type_interface_asmprinter
+  // CHECK: ^bb0(%op_asm_type_interface
+  test.block_argument_name_from_type_interface {
+  ^bb0(%arg0: !test.op_asm_type_interface):  // Entry-block argument, named from its type.
+    "test.terminator"() : ()->()
+  }
+  return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 2aa0658ab0e5d40..cdc1237ec8c5aa8 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -955,6 +955,25 @@ def BlockArgumentNameFromTypeOp
let assemblyFormat = "regions attr-dict-with-keyword";
}
+// Tests the AsmPrinter fallback to OpAsmTypeInterface::getAsmName for op
+// result names: the op declares OpAsmOpInterface but keeps the default
+// getAsmResultNames, which sets no names.
+def ResultNameFromTypeInterfaceOp
+    : TEST_Op<"result_name_from_type_interface",
+        [OpAsmOpInterface]> {
+  let results = (outs Variadic<AnyType>:$r);  // Variadic so tests can mix result types.
+}
+
+// Tests the AsmPrinter fallback to OpAsmTypeInterface::getAsmName for block
+// argument names: the op declares OpAsmOpInterface but keeps the default
+// getAsmBlockArgumentNames, which sets no names.
+def BlockArgumentNameFromTypeInterfaceOp
+    : TEST_Op<"block_argument_name_from_type_interface",
+        [OpAsmOpInterface]> {
+  let regions = (region AnyRegion:$body);
+  let assemblyFormat = "regions attr-dict-with-keyword";
+}
+
// This is used to test the OpAsmOpInterface::getDefaultDialect() feature:
// operations nested in a region under this op will drop the "test." dialect
// prefix.
More information about the Mlir-commits
mailing list