[Mlir-commits] [mlir] [mlir] Integrate OpAsmTypeInterface with AsmPrinter (PR #124700)

Hongren Zheng llvmlistbot at llvm.org
Mon Jan 27 22:49:53 PST 2025


https://github.com/ZenithalHourlyRate updated https://github.com/llvm/llvm-project/pull/124700

>From 0b8231328888a5418fb373f5a9750df55d7c7c91 Mon Sep 17 00:00:00 2001
From: Zenithal <i at zenithal.me>
Date: Tue, 28 Jan 2025 05:32:20 +0000
Subject: [PATCH] [mlir] Integrate OpAsmTypeInterface with AsmPrinter

---
 mlir/lib/IR/AsmPrinter.cpp            | 35 ++++++++++++++++++++++++++
 mlir/test/IR/op-asm-interface.mlir    | 36 +++++++++++++++++++++++++++
 mlir/test/lib/Dialect/Test/TestOps.td | 19 ++++++++++++++
 3 files changed, 90 insertions(+)

diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index fa4a1b4b72b024..5103a9b51c74c1 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1536,10 +1536,13 @@ StringRef maybeGetValueNameFromLoc(Value value, StringRef name) {
 } // namespace
 
 void SSANameState::numberValuesInRegion(Region &region) {
+  // Indicates whether OpAsmOpInterface set a name.
+  bool opAsmOpInterfaceUsed = false;
   auto setBlockArgNameFn = [&](Value arg, StringRef name) {
     assert(!valueIDs.count(arg) && "arg numbered multiple times");
     assert(llvm::cast<BlockArgument>(arg).getOwner()->getParent() == &region &&
            "arg not defined in current region");
+    opAsmOpInterfaceUsed = true;
     if (LLVM_UNLIKELY(printerFlags.shouldUseNameLocAsPrefix()))
       name = maybeGetValueNameFromLoc(arg, name);
     setValueName(arg, name);
@@ -1549,6 +1552,18 @@ void SSANameState::numberValuesInRegion(Region &region) {
     if (Operation *op = region.getParentOp()) {
       if (auto asmInterface = dyn_cast<OpAsmOpInterface>(op))
         asmInterface.getAsmBlockArgumentNames(region, setBlockArgNameFn);
+      if (!opAsmOpInterfaceUsed) {
+        // If the OpAsmOpInterface didn't set a name, get name from the type.
+        for (auto arg : region.getArguments()) {
+          if (auto typeInterface =
+                  mlir::dyn_cast<OpAsmTypeInterface>(arg.getType())) {
+            auto setNameFn = [&](StringRef name) {
+              setBlockArgNameFn(arg, name);
+            };
+            typeInterface.getAsmName(setNameFn);
+          }
+        }
+      }
     }
   }
 
@@ -1598,9 +1613,12 @@ void SSANameState::numberValuesInBlock(Block &block) {
 void SSANameState::numberValuesInOp(Operation &op) {
   // Function used to set the special result names for the operation.
   SmallVector<int, 2> resultGroups(/*Size=*/1, /*Value=*/0);
+  // Indicates whether OpAsmOpInterface set a name.
+  bool opAsmOpInterfaceUsed = false;
   auto setResultNameFn = [&](Value result, StringRef name) {
     assert(!valueIDs.count(result) && "result numbered multiple times");
     assert(result.getDefiningOp() == &op && "result not defined by 'op'");
+    opAsmOpInterfaceUsed = true;
     if (LLVM_UNLIKELY(printerFlags.shouldUseNameLocAsPrefix()))
       name = maybeGetValueNameFromLoc(result, name);
     setValueName(result, name);
@@ -1630,6 +1648,23 @@ void SSANameState::numberValuesInOp(Operation &op) {
       asmInterface.getAsmBlockNames(setBlockNameFn);
       asmInterface.getAsmResultNames(setResultNameFn);
     }
+    if (!opAsmOpInterfaceUsed) {
+      // If the OpAsmOpInterface didn't set a name, and
+      // all results have OpAsmTypeInterface, get names from types.
+      bool allHaveOpAsmTypeInterface =
+          llvm::all_of(op.getResultTypes(), [&](Type type) {
+            return mlir::isa<OpAsmTypeInterface>(type);
+          });
+      if (allHaveOpAsmTypeInterface) {
+        for (auto result : op.getResults()) {
+          auto typeInterface = mlir::cast<OpAsmTypeInterface>(result.getType());
+          auto setNameFn = [&](StringRef name) {
+            setResultNameFn(result, name);
+          };
+          typeInterface.getAsmName(setNameFn);
+        }
+      }
+    }
   }
 
   unsigned numResults = op.getNumResults();
diff --git a/mlir/test/IR/op-asm-interface.mlir b/mlir/test/IR/op-asm-interface.mlir
index a9c199e3dc9736..fe73750ba0edf5 100644
--- a/mlir/test/IR/op-asm-interface.mlir
+++ b/mlir/test/IR/op-asm-interface.mlir
@@ -22,3 +22,39 @@ func.func @block_argument_name_from_op_asm_type_interface() {
   }
   return
 }
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Test AsmPrinter fallback to OpAsmTypeInterface for value names
+//===----------------------------------------------------------------------===//
+
+func.func @result_name_from_op_asm_type_interface_asmprinter() {
+  // CHECK-LABEL: @result_name_from_op_asm_type_interface_asmprinter
+  // CHECK: %op_asm_type_interface
+  %0 = "test.result_name_from_type_interface"() : () -> !test.op_asm_type_interface
+  return
+}
+
+// -----
+
+// i1 does not implement OpAsmTypeInterface, so no result gets a type name.
+func.func @result_name_from_op_asm_type_interface_not_all() {
+  // CHECK-LABEL: @result_name_from_op_asm_type_interface_not_all
+  // CHECK-NOT: %op_asm_type_interface
+  // CHECK: %0:2
+  %0:2 = "test.result_name_from_type_interface"() : () -> (!test.op_asm_type_interface, i1)
+  return
+}
+
+// -----
+
+func.func @block_argument_name_from_op_asm_type_interface_asmprinter() {
+  // CHECK-LABEL: @block_argument_name_from_op_asm_type_interface_asmprinter
+  // CHECK: ^bb0(%op_asm_type_interface
+  test.block_argument_name_from_type_interface {
+    ^bb0(%arg0: !test.op_asm_type_interface):
+      "test.terminator"() : ()->()
+  }
+  return
+}
\ No newline at end of file
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index f37573c1351cec..4880b6697055be 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -939,6 +939,25 @@ def BlockArgumentNameFromTypeOp
   let assemblyFormat = "regions attr-dict-with-keyword";
 }
 
+// Used to test the AsmPrinter integration of OpAsmTypeInterface::getAsmName
+// for op result names when OpAsmOpInterface::getAsmResultNames keeps its
+// default implementation, i.e. does not name any result.
+def ResultNameFromTypeInterfaceOp
+ : TEST_Op<"result_name_from_type_interface",
+           [OpAsmOpInterface]> {
+  let results = (outs Variadic<AnyType>:$r);
+}
+
+// Used to test the AsmPrinter integration of OpAsmTypeInterface::getAsmName
+// for block argument names when OpAsmOpInterface::getAsmBlockArgumentNames
+// keeps its default implementation, i.e. does not name any argument.
+def BlockArgumentNameFromTypeInterfaceOp
+  : TEST_Op<"block_argument_name_from_type_interface",
+      [OpAsmOpInterface]> {
+  let regions = (region AnyRegion:$body);
+  let assemblyFormat = "regions attr-dict-with-keyword";
+}
+
 // This is used to test the OpAsmOpInterface::getDefaultDialect() feature:
 // operations nested in a region under this op will drop the "test." dialect
 // prefix.



More information about the Mlir-commits mailing list