[Mlir-commits] [mlir] 055872a - [mlir] Integrate OpAsmTypeInterface with AsmPrinter (#124700)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Feb 18 10:07:21 PST 2025
Author: Hongren Zheng
Date: 2025-02-19T02:07:17+08:00
New Revision: 055872acc28afdd8d29acdbec24f4bd415481d33
URL: https://github.com/llvm/llvm-project/commit/055872acc28afdd8d29acdbec24f4bd415481d33
DIFF: https://github.com/llvm/llvm-project/commit/055872acc28afdd8d29acdbec24f4bd415481d33.diff
LOG: [mlir] Integrate OpAsmTypeInterface with AsmPrinter (#124700)
See
https://discourse.llvm.org/t/rfc-introduce-opasm-type-attr-interface-for-pretty-print-in-asmprinter/83792
for a detailed introduction.
This is a follow-up PR to #121187 that integrates OpAsmTypeInterface
with AsmPrinter. There are a few conditions under which OpAsmTypeInterface
comes into play:
* There is no OpAsmOpInterface
* Or OpAsmOpInterface::getAsmResultNames/getAsmBlockArgumentNames does not
invoke `setName` (i.e. the default impl)
* For op results, all results have OpAsmTypeInterface (otherwise we cannot
handle result grouping behavior)
Cc @River707 @jpienaar @ftynse for review.
Added:
Modified:
mlir/lib/IR/AsmPrinter.cpp
mlir/test/IR/op-asm-interface.mlir
mlir/test/lib/Dialect/Test/TestOps.td
Removed:
################################################################################
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 0fa97f1f38079..03b8fca0fa7ab 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1536,10 +1536,13 @@ StringRef maybeGetValueNameFromLoc(Value value, StringRef name) {
} // namespace
void SSANameState::numberValuesInRegion(Region &region) {
+ // Indicates whether OpAsmOpInterface set a name.
+ bool opAsmOpInterfaceUsed = false;
auto setBlockArgNameFn = [&](Value arg, StringRef name) {
assert(!valueIDs.count(arg) && "arg numbered multiple times");
assert(llvm::cast<BlockArgument>(arg).getOwner()->getParent() == &region &&
"arg not defined in current region");
+ opAsmOpInterfaceUsed = true;
if (LLVM_UNLIKELY(printerFlags.shouldUseNameLocAsPrefix()))
name = maybeGetValueNameFromLoc(arg, name);
setValueName(arg, name);
@@ -1549,6 +1552,15 @@ void SSANameState::numberValuesInRegion(Region &region) {
if (Operation *op = region.getParentOp()) {
if (auto asmInterface = dyn_cast<OpAsmOpInterface>(op))
asmInterface.getAsmBlockArgumentNames(region, setBlockArgNameFn);
+ // If the OpAsmOpInterface didn't set a name, get name from the type.
+ if (!opAsmOpInterfaceUsed) {
+ for (BlockArgument arg : region.getArguments()) {
+ if (auto interface = dyn_cast<OpAsmTypeInterface>(arg.getType())) {
+ interface.getAsmName(
+ [&](StringRef name) { setBlockArgNameFn(arg, name); });
+ }
+ }
+ }
}
}
@@ -1598,9 +1610,12 @@ void SSANameState::numberValuesInBlock(Block &block) {
void SSANameState::numberValuesInOp(Operation &op) {
// Function used to set the special result names for the operation.
SmallVector<int, 2> resultGroups(/*Size=*/1, /*Value=*/0);
+ // Indicates whether OpAsmOpInterface set a name.
+ bool opAsmOpInterfaceUsed = false;
auto setResultNameFn = [&](Value result, StringRef name) {
assert(!valueIDs.count(result) && "result numbered multiple times");
assert(result.getDefiningOp() == &op && "result not defined by 'op'");
+ opAsmOpInterfaceUsed = true;
if (LLVM_UNLIKELY(printerFlags.shouldUseNameLocAsPrefix()))
name = maybeGetValueNameFromLoc(result, name);
setValueName(result, name);
@@ -1630,6 +1645,21 @@ void SSANameState::numberValuesInOp(Operation &op) {
asmInterface.getAsmBlockNames(setBlockNameFn);
asmInterface.getAsmResultNames(setResultNameFn);
}
+ if (!opAsmOpInterfaceUsed) {
+ // If the OpAsmOpInterface didn't set a name, and all results have
+ // OpAsmTypeInterface, get names from types.
+ bool allHaveOpAsmTypeInterface =
+ llvm::all_of(op.getResultTypes(), [&](Type type) {
+ return isa<OpAsmTypeInterface>(type);
+ });
+ if (allHaveOpAsmTypeInterface) {
+ for (OpResult result : op.getResults()) {
+ auto interface = cast<OpAsmTypeInterface>(result.getType());
+ interface.getAsmName(
+ [&](StringRef name) { setResultNameFn(result, name); });
+ }
+ }
+ }
}
unsigned numResults = op.getNumResults();
diff --git a/mlir/test/IR/op-asm-interface.mlir b/mlir/test/IR/op-asm-interface.mlir
index a9c199e3dc973..6cce9572f4fc2 100644
--- a/mlir/test/IR/op-asm-interface.mlir
+++ b/mlir/test/IR/op-asm-interface.mlir
@@ -22,3 +22,39 @@ func.func @block_argument_name_from_op_asm_type_interface() {
}
return
}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Test OpAsmTypeInterface
+//===----------------------------------------------------------------------===//
+
+func.func @result_name_from_op_asm_type_interface_asmprinter() {
+ // CHECK-LABEL: @result_name_from_op_asm_type_interface_asmprinter
+ // CHECK: %op_asm_type_interface
+ %0 = "test.result_name_from_type_interface"() : () -> !test.op_asm_type_interface
+ return
+}
+
+// -----
+
+// i1 does not have OpAsmTypeInterface, should not get named.
+func.func @result_name_from_op_asm_type_interface_not_all() {
+ // CHECK-LABEL: @result_name_from_op_asm_type_interface_not_all
+ // CHECK-NOT: %op_asm_type_interface
+ // CHECK: %0:2
+ %0:2 = "test.result_name_from_type_interface"() : () -> (!test.op_asm_type_interface, i1)
+ return
+}
+
+// -----
+
+func.func @block_argument_name_from_op_asm_type_interface_asmprinter() {
+ // CHECK-LABEL: @block_argument_name_from_op_asm_type_interface_asmprinter
+ // CHECK: ^bb0(%op_asm_type_interface
+ test.block_argument_name_from_type_interface {
+ ^bb0(%arg0: !test.op_asm_type_interface):
+ "test.terminator"() : ()->()
+ }
+ return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 2aa0658ab0e5d..cdc1237ec8c5a 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -955,6 +955,25 @@ def BlockArgumentNameFromTypeOp
let assemblyFormat = "regions attr-dict-with-keyword";
}
+// This is used to test OpAsmTypeInterface::getAsmName's integration with AsmPrinter
+// for op result name when OpAsmOpInterface::getAsmResultNames is the default implementation
+// i.e. does nothing.
+def ResultNameFromTypeInterfaceOp
+ : TEST_Op<"result_name_from_type_interface",
+ [OpAsmOpInterface]> {
+ let results = (outs Variadic<AnyType>:$r);
+}
+
+// This is used to test OpAsmTypeInterface::getAsmName's integration with AsmPrinter
+// for block argument name when OpAsmOpInterface::getAsmBlockArgumentNames is the default implementation
+// i.e. does nothing.
+def BlockArgumentNameFromTypeInterfaceOp
+ : TEST_Op<"block_argument_name_from_type_interface",
+ [OpAsmOpInterface]> {
+ let regions = (region AnyRegion:$body);
+ let assemblyFormat = "regions attr-dict-with-keyword";
+}
+
// This is used to test the OpAsmOpInterface::getDefaultDialect() feature:
// operations nested in a region under this op will drop the "test." dialect
// prefix.
More information about the Mlir-commits
mailing list