[Mlir-commits] [mlir] [mlir] Allow trailing digit for alias in AsmPrinter (PR #127993)

Hongren Zheng llvmlistbot at llvm.org
Thu Feb 20 03:54:32 PST 2025


https://github.com/ZenithalHourlyRate created https://github.com/llvm/llvm-project/pull/127993

When generating aliases from `OpAsm{Dialect,Type,Attr}Interface`, the result is sanitized. If the alias provided by the interface has a trailing digit, AsmPrinter appends an underscore to it, presumably to prevent conflicts with autogenerated IDs.

#### Motivation

There are two reasons for changing from the old behavior to the proposed behavior:

1. If the type/attribute can generate a unique alias from its content, the extra trailing underscore added by AsmPrinter looks odd:

```mlir
  func.func @add(%ct: !ct_L0_) -> !ct_L0_
    %ct_0 = bgv.add %ct, %ct : (!ct_L0_, !ct_L0_) -> !ct_L0_
    %ct_1 = bgv.add %ct_0, %ct_0 : (!ct_L0_, !ct_L0_) -> !ct_L0_
    %ct_2 = bgv.add %ct_1, %ct_1 : (!ct_L0_, !ct_L0_) -> !ct_L0_
    return %ct_2 : !ct_L0_
  }
```

Aesthetically, this would look better as `(!ct_L0, !ct_L0) -> !ct_L0`.

2. Value names already behave this way: the first instance gets no `_N` suffix, and the same rule can be applied to alias names. In the IR above, the first value is called `%ct` and the later ones are called `%ct_N`. See `uniqueValueName` for details.

#### Conflict detection


```mlir
!test.type<a = 3> // suggest !name0
!test.type<a = 4> // suggest !name0
!test.another<b = 3> // suggest !name0_
!test.another<b = 4> // suggest !name0_
```

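The conflict detection is based on `nameCounts` in `initializeAliases`, where each alias name is mapped to the number of times it has been seen so far, and that count is used as the alias's suffix index.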

Under the old behavior, the first two are sanitized to `!name0_`, so `initializeAliases` can assign the unique ids `0, 1, 2, 3` to the four aliases.

With this change, `initializeAliases` knows that `!name0` will be printed as `!name0_N`, so it treats `!name0` as if it had been sanitized to `!name0_` when building `nameCounts`, and the alias still gets a unique id; see the test for an example.

>From d95141853a6dcd9fad806331926ae161d498b20e Mon Sep 17 00:00:00 2001
From: Zenithal <i at zenithal.me>
Date: Thu, 20 Feb 2025 11:09:42 +0000
Subject: [PATCH] [mlir] Allow trailing digit for alias in AsmPrinter

---
 mlir/lib/IR/AsmPrinter.cpp                    | 30 ++++++++++---------
 mlir/test/IR/print-attr-type-aliases.mlir     | 21 +++++++------
 mlir/test/IR/recursive-type.mlir              |  4 +--
 .../Dialect/Test/TestDialectInterfaces.cpp    | 12 +++++---
 4 files changed, 36 insertions(+), 31 deletions(-)

diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 1f22d4f37a813..8044a1c8507e8 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -547,8 +547,11 @@ class SymbolAlias {
   /// Print this alias to the given stream.
   void print(raw_ostream &os) const {
     os << (isType ? "!" : "#") << name;
-    if (suffixIndex)
+    if (suffixIndex) {
+      if (isdigit(name.back()))
+        os << '_';
       os << suffixIndex;
+    }
   }
 
   /// Returns true if this is a type alias.
@@ -1020,8 +1023,7 @@ class DummyAliasDialectAsmPrinter : public DialectAsmPrinter {
 /// the string needs to be modified in any way, the provided buffer is used to
 /// store the new copy,
 static StringRef sanitizeIdentifier(StringRef name, SmallString<16> &buffer,
-                                    StringRef allowedPunctChars = "$._-",
-                                    bool allowTrailingDigit = true) {
+                                    StringRef allowedPunctChars = "$._-") {
   assert(!name.empty() && "Shouldn't have an empty name here");
 
   auto validChar = [&](char ch) {
@@ -1048,14 +1050,6 @@ static StringRef sanitizeIdentifier(StringRef name, SmallString<16> &buffer,
     return buffer;
   }
 
-  // If the name ends with a trailing digit, add a '_' to avoid potential
-  // conflicts with autogenerated ID's.
-  if (!allowTrailingDigit && isdigit(name.back())) {
-    copyNameToBuffer();
-    buffer.push_back('_');
-    return buffer;
-  }
-
   // Check to see that the name consists of only valid identifier characters.
   for (char ch : name) {
     if (!validChar(ch)) {
@@ -1084,7 +1078,16 @@ void AliasInitializer::initializeAliases(
     if (!aliasInfo.alias)
       continue;
     StringRef alias = *aliasInfo.alias;
-    unsigned nameIndex = nameCounts[alias]++;
+    unsigned nameIndex;
+    // If the alias ends with a digit, we need to pretend as if it has trailing
+    // underscore to get a unique nameIndex.
+    if (isdigit(alias.back())) {
+      SmallString<16> aliasBuffer(alias);
+      aliasBuffer.push_back('_');
+      nameIndex = nameCounts[aliasBuffer]++;
+    } else {
+      nameIndex = nameCounts[alias]++;
+    }
     symbolToAlias.insert(
         {symbol, SymbolAlias(alias, nameIndex, aliasInfo.isType,
                              aliasInfo.canBeDeferred)});
@@ -1191,8 +1194,7 @@ void AliasInitializer::generateAlias(T symbol, InProgressAliasInfo &alias,
 
   SmallString<16> tempBuffer;
   StringRef name =
-      sanitizeIdentifier(nameBuffer, tempBuffer, /*allowedPunctChars=*/"$_-",
-                         /*allowTrailingDigit=*/false);
+      sanitizeIdentifier(nameBuffer, tempBuffer, /*allowedPunctChars=*/"$_-");
   name = name.copy(aliasAllocator);
   alias = InProgressAliasInfo(name);
 }
diff --git a/mlir/test/IR/print-attr-type-aliases.mlir b/mlir/test/IR/print-attr-type-aliases.mlir
index e878d862076c9..97cda270e2c2e 100644
--- a/mlir/test/IR/print-attr-type-aliases.mlir
+++ b/mlir/test/IR/print-attr-type-aliases.mlir
@@ -5,7 +5,7 @@
 // CHECK-DAG: #test2Ealias = "alias_test:dot_in_name"
 "test.op"() {alias_test = "alias_test:dot_in_name"} : () -> ()
 
-// CHECK-DAG: #test_alias0_ = "alias_test:trailing_digit"
+// CHECK-DAG: #test_alias0 = "alias_test:trailing_digit"
 "test.op"() {alias_test = "alias_test:trailing_digit"} : () -> ()
 
 // CHECK-DAG: #_0_test_alias = "alias_test:prefixed_digit"
@@ -14,12 +14,11 @@
 // CHECK-DAG: #_25test = "alias_test:prefixed_symbol"
 "test.op"() {alias_test = "alias_test:prefixed_symbol"} : () -> ()
 
-// CHECK-DAG: #test_alias_conflict0_ = "alias_test:sanitize_conflict_a"
-// CHECK-DAG: #test_alias_conflict0_1 = "alias_test:sanitize_conflict_b"
-"test.op"() {alias_test = ["alias_test:sanitize_conflict_a", "alias_test:sanitize_conflict_b"]} : () -> ()
-
-// CHECK-DAG: !tuple = tuple<i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32>
-"test.op"() {alias_test = "alias_test:large_tuple"} : () -> (tuple<i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32>)
+// CHECK-DAG: #test_alias_conflict0 = "alias_test:trailing_digit_conflict_a"
+// CHECK-DAG: #test_alias_conflict0_1 = "alias_test:trailing_digit_conflict_b"
+// CHECK-DAG: #test_alias_conflict0_2 = "alias_test:trailing_digit_conflict_c"
+// CHECK-DAG: #test_alias_conflict0_3 = "alias_test:trailing_digit_conflict_d"
+"test.op"() {alias_test = ["alias_test:trailing_digit_conflict_a", "alias_test:trailing_digit_conflict_b", "alias_test:trailing_digit_conflict_c", "alias_test:trailing_digit_conflict_d"]} : () -> ()
 
 // CHECK-DAG: !test_tuple = tuple<!test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla>
 "test.op"() {alias_test = "alias_test:large_tuple"} : () -> (tuple<!test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla>)
@@ -28,8 +27,8 @@
 // CHECK-DAG: tensor<32xf32, #test_encoding>
 "test.op"() : () -> tensor<32xf32, "alias_test:tensor_encoding">
 
-// CHECK-DAG: !test_ui8_ = !test.int<unsigned, 8>
-// CHECK-DAG: tensor<32x!test_ui8_>
+// CHECK-DAG: !test_ui8 = !test.int<unsigned, 8>
+// CHECK-DAG: tensor<32x!test_ui8>
 "test.op"() : () -> tensor<32x!test.int<unsigned, 8>>
 
 // CHECK-DAG: #[[LOC_NESTED:.+]] = loc("nested")
@@ -47,8 +46,8 @@
 // -----
 
 // Ensure self type parameters get considered for aliases.
-// CHECK: !test_ui8_ = !test.int<unsigned, 8>
-// CHECK: #test.attr_with_self_type_param : !test_ui8_
+// CHECK: !test_ui8 = !test.int<unsigned, 8>
+// CHECK: #test.attr_with_self_type_param : !test_ui8
 "test.op"() {alias_test = #test.attr_with_self_type_param : !test.int<unsigned, 8> } : () -> ()
 
 // -----
diff --git a/mlir/test/IR/recursive-type.mlir b/mlir/test/IR/recursive-type.mlir
index 42aecb41d998d..b8111d9601e48 100644
--- a/mlir/test/IR/recursive-type.mlir
+++ b/mlir/test/IR/recursive-type.mlir
@@ -4,8 +4,8 @@
 // CHECK: ![[$NAME:.*]] = !test.test_rec_alias<name, !test.test_rec_alias<name>>
 // CHECK: ![[$NAME5:.*]] = !test.test_rec_alias<name5, !test.test_rec_alias<name3, !test.test_rec_alias<name4, !test.test_rec_alias<name5>>>>
 // CHECK: ![[$NAME2:.*]] = !test.test_rec_alias<name2, tuple<!test.test_rec_alias<name2>, i32>>
-// CHECK: ![[$NAME4:.*]] = !test.test_rec_alias<name4, !name5_>
-// CHECK: ![[$NAME3:.*]] = !test.test_rec_alias<name3, !name4_>
+// CHECK: ![[$NAME4:.*]] = !test.test_rec_alias<name4, !name5>
+// CHECK: ![[$NAME3:.*]] = !test.test_rec_alias<name3, !name4>
 
 // CHECK-LABEL: @roundtrip
 func.func @roundtrip() {
diff --git a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
index 64add8cef3698..065692b98f219 100644
--- a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
@@ -187,12 +187,16 @@ struct TestOpAsmInterface : public OpAsmDialectInterface {
         StringSwitch<std::optional<StringRef>>(strAttr.getValue())
             .Case("alias_test:dot_in_name", StringRef("test.alias"))
             .Case("alias_test:trailing_digit", StringRef("test_alias0"))
-            .Case("alias_test:prefixed_digit", StringRef("0_test_alias"))
-            .Case("alias_test:prefixed_symbol", StringRef("%test"))
-            .Case("alias_test:sanitize_conflict_a",
+            .Case("alias_test:trailing_digit_conflict_a",
+                  StringRef("test_alias_conflict0"))
+            .Case("alias_test:trailing_digit_conflict_b",
                   StringRef("test_alias_conflict0"))
-            .Case("alias_test:sanitize_conflict_b",
+            .Case("alias_test:trailing_digit_conflict_c",
                   StringRef("test_alias_conflict0_"))
+            .Case("alias_test:trailing_digit_conflict_d",
+                  StringRef("test_alias_conflict0_"))
+            .Case("alias_test:prefixed_digit", StringRef("0_test_alias"))
+            .Case("alias_test:prefixed_symbol", StringRef("%test"))
             .Case("alias_test:tensor_encoding", StringRef("test_encoding"))
             .Default(std::nullopt);
     if (!aliasName)



More information about the Mlir-commits mailing list