[Mlir-commits] [mlir] 055872a - [mlir] Integrate OpAsmTypeInterface with AsmPrinter (#124700)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Feb 18 10:07:21 PST 2025


Author: Hongren Zheng
Date: 2025-02-19T02:07:17+08:00
New Revision: 055872acc28afdd8d29acdbec24f4bd415481d33

URL: https://github.com/llvm/llvm-project/commit/055872acc28afdd8d29acdbec24f4bd415481d33
DIFF: https://github.com/llvm/llvm-project/commit/055872acc28afdd8d29acdbec24f4bd415481d33.diff

LOG: [mlir] Integrate OpAsmTypeInterface with AsmPrinter (#124700)

See
https://discourse.llvm.org/t/rfc-introduce-opasm-type-attr-interface-for-pretty-print-in-asmprinter/83792
for detailed introduction.

This is a follow-up PR to #121187 that integrates OpAsmTypeInterface
with AsmPrinter. OpAsmTypeInterface comes into play under the
following conditions:

* There is no OpAsmOpInterface,
* or OpAsmOpInterface::getAsmResultNames/getAsmBlockArgumentNames does
not invoke `setName` (i.e. the default impl);
* and, for op results, all results have OpAsmTypeInterface (otherwise
we cannot handle result grouping behavior).

Cc @River707 @jpienaar @ftynse for review.

Added: 
    

Modified: 
    mlir/lib/IR/AsmPrinter.cpp
    mlir/test/IR/op-asm-interface.mlir
    mlir/test/lib/Dialect/Test/TestOps.td

Removed: 
    


################################################################################
diff  --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 0fa97f1f38079..03b8fca0fa7ab 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1536,10 +1536,13 @@ StringRef maybeGetValueNameFromLoc(Value value, StringRef name) {
 } // namespace
 
 void SSANameState::numberValuesInRegion(Region &region) {
+  // Indicates whether OpAsmOpInterface set a name.
+  bool opAsmOpInterfaceUsed = false;
   auto setBlockArgNameFn = [&](Value arg, StringRef name) {
     assert(!valueIDs.count(arg) && "arg numbered multiple times");
     assert(llvm::cast<BlockArgument>(arg).getOwner()->getParent() == &region &&
            "arg not defined in current region");
+    opAsmOpInterfaceUsed = true;
     if (LLVM_UNLIKELY(printerFlags.shouldUseNameLocAsPrefix()))
       name = maybeGetValueNameFromLoc(arg, name);
     setValueName(arg, name);
@@ -1549,6 +1552,15 @@ void SSANameState::numberValuesInRegion(Region &region) {
     if (Operation *op = region.getParentOp()) {
       if (auto asmInterface = dyn_cast<OpAsmOpInterface>(op))
         asmInterface.getAsmBlockArgumentNames(region, setBlockArgNameFn);
+      // If the OpAsmOpInterface didn't set a name, get name from the type.
+      if (!opAsmOpInterfaceUsed) {
+        for (BlockArgument arg : region.getArguments()) {
+          if (auto interface = dyn_cast<OpAsmTypeInterface>(arg.getType())) {
+            interface.getAsmName(
+                [&](StringRef name) { setBlockArgNameFn(arg, name); });
+          }
+        }
+      }
     }
   }
 
@@ -1598,9 +1610,12 @@ void SSANameState::numberValuesInBlock(Block &block) {
 void SSANameState::numberValuesInOp(Operation &op) {
   // Function used to set the special result names for the operation.
   SmallVector<int, 2> resultGroups(/*Size=*/1, /*Value=*/0);
+  // Indicates whether OpAsmOpInterface set a name.
+  bool opAsmOpInterfaceUsed = false;
   auto setResultNameFn = [&](Value result, StringRef name) {
     assert(!valueIDs.count(result) && "result numbered multiple times");
     assert(result.getDefiningOp() == &op && "result not defined by 'op'");
+    opAsmOpInterfaceUsed = true;
     if (LLVM_UNLIKELY(printerFlags.shouldUseNameLocAsPrefix()))
       name = maybeGetValueNameFromLoc(result, name);
     setValueName(result, name);
@@ -1630,6 +1645,21 @@ void SSANameState::numberValuesInOp(Operation &op) {
       asmInterface.getAsmBlockNames(setBlockNameFn);
       asmInterface.getAsmResultNames(setResultNameFn);
     }
+    if (!opAsmOpInterfaceUsed) {
+      // If the OpAsmOpInterface didn't set a name, and all results have
+      // OpAsmTypeInterface, get names from types.
+      bool allHaveOpAsmTypeInterface =
+          llvm::all_of(op.getResultTypes(), [&](Type type) {
+            return isa<OpAsmTypeInterface>(type);
+          });
+      if (allHaveOpAsmTypeInterface) {
+        for (OpResult result : op.getResults()) {
+          auto interface = cast<OpAsmTypeInterface>(result.getType());
+          interface.getAsmName(
+              [&](StringRef name) { setResultNameFn(result, name); });
+        }
+      }
+    }
   }
 
   unsigned numResults = op.getNumResults();

diff  --git a/mlir/test/IR/op-asm-interface.mlir b/mlir/test/IR/op-asm-interface.mlir
index a9c199e3dc973..6cce9572f4fc2 100644
--- a/mlir/test/IR/op-asm-interface.mlir
+++ b/mlir/test/IR/op-asm-interface.mlir
@@ -22,3 +22,39 @@ func.func @block_argument_name_from_op_asm_type_interface() {
   }
   return
 }
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Test OpAsmTypeInterface
+//===----------------------------------------------------------------------===//
+
+func.func @result_name_from_op_asm_type_interface_asmprinter() {
+  // CHECK-LABEL: @result_name_from_op_asm_type_interface_asmprinter
+  // CHECK: %op_asm_type_interface
+  %0 = "test.result_name_from_type_interface"() : () -> !test.op_asm_type_interface
+  return
+}
+
+// -----
+
+// i1 does not have OpAsmTypeInterface, should not get named.
+func.func @result_name_from_op_asm_type_interface_not_all() {
+  // CHECK-LABEL: @result_name_from_op_asm_type_interface_not_all
+  // CHECK-NOT: %op_asm_type_interface
+  // CHECK: %0:2
+  %0:2 = "test.result_name_from_type_interface"() : () -> (!test.op_asm_type_interface, i1)
+  return
+}
+
+// -----
+
+func.func @block_argument_name_from_op_asm_type_interface_asmprinter() {
+  // CHECK-LABEL: @block_argument_name_from_op_asm_type_interface_asmprinter
+  // CHECK: ^bb0(%op_asm_type_interface
+  test.block_argument_name_from_type_interface {
+    ^bb0(%arg0: !test.op_asm_type_interface):
+      "test.terminator"() : ()->()
+  }
+  return
+}

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 2aa0658ab0e5d..cdc1237ec8c5a 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -955,6 +955,25 @@ def BlockArgumentNameFromTypeOp
   let assemblyFormat = "regions attr-dict-with-keyword";
 }
 
+// This is used to test OpAsmTypeInterface::getAsmName's integration with AsmPrinter
+// for op result name when OpAsmOpInterface::getAsmResultNames is the default implementation
+// i.e. does nothing.
+def ResultNameFromTypeInterfaceOp
+ : TEST_Op<"result_name_from_type_interface",
+           [OpAsmOpInterface]> {
+  let results = (outs Variadic<AnyType>:$r);
+}
+
+// This is used to test OpAsmTypeInterface::getAsmName's integration with AsmPrinter
+// for block argument name when OpAsmOpInterface::getAsmBlockArgumentNames is the default implementation
+// i.e. does nothing.
+def BlockArgumentNameFromTypeInterfaceOp
+  : TEST_Op<"block_argument_name_from_type_interface",
+      [OpAsmOpInterface]> {
+  let regions = (region AnyRegion:$body);
+  let assemblyFormat = "regions attr-dict-with-keyword";
+}
+
 // This is used to test the OpAsmOpInterface::getDefaultDialect() feature:
 // operations nested in a region under this op will drop the "test." dialect
 // prefix.


        


More information about the Mlir-commits mailing list