[Mlir-commits] [mlir] [mlir] Allow trailing digit for alias in AsmPrinter (PR #127993)

Hongren Zheng llvmlistbot at llvm.org
Thu Feb 20 11:18:07 PST 2025


https://github.com/ZenithalHourlyRate updated https://github.com/llvm/llvm-project/pull/127993

From cf2b899bc394cc563f9688f848a67b552b12e332 Mon Sep 17 00:00:00 2001
From: Zenithal <i at zenithal.me>
Date: Thu, 20 Feb 2025 11:09:42 +0000
Subject: [PATCH] [mlir] Allow trailing digit for alias in AsmPrinter

---
 mlir/lib/IR/AsmPrinter.cpp                    | 57 ++++++++++++++-----
 mlir/test/IR/print-attr-type-aliases.mlir     | 24 ++++----
 mlir/test/IR/recursive-type.mlir              |  4 +-
 .../Dialect/Test/TestDialectInterfaces.cpp    | 18 ++++--
 4 files changed, 72 insertions(+), 31 deletions(-)

diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 1f22d4f37a813..b00032febda11 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -547,8 +547,11 @@ class SymbolAlias {
   /// Print this alias to the given stream.
   void print(raw_ostream &os) const {
     os << (isType ? "!" : "#") << name;
-    if (suffixIndex)
+    if (suffixIndex) {
+      if (isdigit(name.back())) // e.g. "a1" + 2 prints "a1_2", not "a12".
+        os << '_';
       os << suffixIndex;
+    }
   }
 
   /// Returns true if this is a type alias.
@@ -1020,8 +1023,7 @@ class DummyAliasDialectAsmPrinter : public DialectAsmPrinter {
 /// the string needs to be modified in any way, the provided buffer is used to
 /// store the new copy,
 static StringRef sanitizeIdentifier(StringRef name, SmallString<16> &buffer,
-                                    StringRef allowedPunctChars = "$._-",
-                                    bool allowTrailingDigit = true) {
+                                    StringRef allowedPunctChars = "$._-") {
   assert(!name.empty() && "Shouldn't have an empty name here");
 
   auto validChar = [&](char ch) {
@@ -1048,14 +1050,6 @@ static StringRef sanitizeIdentifier(StringRef name, SmallString<16> &buffer,
     return buffer;
   }
 
-  // If the name ends with a trailing digit, add a '_' to avoid potential
-  // conflicts with autogenerated ID's.
-  if (!allowTrailingDigit && isdigit(name.back())) {
-    copyNameToBuffer();
-    buffer.push_back('_');
-    return buffer;
-  }
-
   // Check to see that the name consists of only valid identifier characters.
   for (char ch : name) {
     if (!validChar(ch)) {
@@ -1084,7 +1078,43 @@ void AliasInitializer::initializeAliases(
     if (!aliasInfo.alias)
       continue;
     StringRef alias = *aliasInfo.alias;
-    unsigned nameIndex = nameCounts[alias]++;
+    unsigned nameIndex;
+    // If the alias ends with a digit, treat it as if it had a trailing
+    // underscore so that it gets a unique nameIndex.
+    if (isdigit(alias.back())) {
+      SmallString<16> aliasBuffer(alias);
+      aliasBuffer.push_back('_');
+      // Check if it is safe to use the alias as is.
+      if (nameCounts[aliasBuffer] == 0) {
+        aliasBuffer.pop_back();
+        // Get the prefix stripped of trailing digits. A prefix ending with a
+        // digit never generates prefix + digits; instead it generates either
+        // prefix or prefix + '_' + digits.
+        int numOfDigits = 0;
+        while (!aliasBuffer.empty() && isdigit(aliasBuffer.back())) {
+          numOfDigits++;
+          aliasBuffer.pop_back();
+        }
+        unsigned trailingNumber =
+            std::stoi(std::string(alias.take_back(numOfDigits)));
+        // Check if the prefix auto-generated this alias. If so, start directly
+        // from _1 to avoid a conflict.
+        bool autoGenerated = nameCounts[aliasBuffer] > trailingNumber;
+        // Reinitialize the aliasBuffer to the original alias.
+        aliasBuffer = alias;
+        aliasBuffer.push_back('_');
+        if (autoGenerated) {
+          nameIndex = 1;
+          nameCounts[aliasBuffer] = 2;
+        } else {
+          nameIndex = nameCounts[aliasBuffer]++;
+        }
+      } else {
+        nameIndex = nameCounts[aliasBuffer]++;
+      }
+    } else {
+      nameIndex = nameCounts[alias]++;
+    }
     symbolToAlias.insert(
         {symbol, SymbolAlias(alias, nameIndex, aliasInfo.isType,
                              aliasInfo.canBeDeferred)});
@@ -1191,8 +1221,7 @@ void AliasInitializer::generateAlias(T symbol, InProgressAliasInfo &alias,
 
   SmallString<16> tempBuffer;
   StringRef name =
-      sanitizeIdentifier(nameBuffer, tempBuffer, /*allowedPunctChars=*/"$_-",
-                         /*allowTrailingDigit=*/false);
+      sanitizeIdentifier(nameBuffer, tempBuffer, /*allowedPunctChars=*/"$_-");
   name = name.copy(aliasAllocator);
   alias = InProgressAliasInfo(name);
 }
diff --git a/mlir/test/IR/print-attr-type-aliases.mlir b/mlir/test/IR/print-attr-type-aliases.mlir
index e878d862076c9..8772dd4db5c71 100644
--- a/mlir/test/IR/print-attr-type-aliases.mlir
+++ b/mlir/test/IR/print-attr-type-aliases.mlir
@@ -5,7 +5,7 @@
 // CHECK-DAG: #test2Ealias = "alias_test:dot_in_name"
 "test.op"() {alias_test = "alias_test:dot_in_name"} : () -> ()
 
-// CHECK-DAG: #test_alias0_ = "alias_test:trailing_digit"
+// CHECK-DAG: #test_alias0 = "alias_test:trailing_digit"
 "test.op"() {alias_test = "alias_test:trailing_digit"} : () -> ()
 
 // CHECK-DAG: #_0_test_alias = "alias_test:prefixed_digit"
@@ -14,12 +14,14 @@
 // CHECK-DAG: #_25test = "alias_test:prefixed_symbol"
 "test.op"() {alias_test = "alias_test:prefixed_symbol"} : () -> ()
 
-// CHECK-DAG: #test_alias_conflict0_ = "alias_test:sanitize_conflict_a"
-// CHECK-DAG: #test_alias_conflict0_1 = "alias_test:sanitize_conflict_b"
-"test.op"() {alias_test = ["alias_test:sanitize_conflict_a", "alias_test:sanitize_conflict_b"]} : () -> ()
-
-// CHECK-DAG: !tuple = tuple<i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32>
-"test.op"() {alias_test = "alias_test:large_tuple"} : () -> (tuple<i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32>)
+// CHECK-DAG: #test_alias_conflict0 = "alias_test:trailing_digit_conflict_a"
+// CHECK-DAG: #test_alias_conflict0_1 = "alias_test:trailing_digit_conflict_b"
+// CHECK-DAG: #test_alias_conflict0_2 = "alias_test:trailing_digit_conflict_c"
+// CHECK-DAG: #test_alias_conflict0_1_1 = "alias_test:trailing_digit_conflict_d"
+// CHECK-DAG: #test_alias_conflict0_1_2 = "alias_test:trailing_digit_conflict_e"
+// CHECK-DAG: #test_alias_conflict0_1_3 = "alias_test:trailing_digit_conflict_f"
+// CHECK-DAG: #test_alias_conflict0_1_1_1 = "alias_test:trailing_digit_conflict_g"
+"test.op"() {alias_test = ["alias_test:trailing_digit_conflict_a", "alias_test:trailing_digit_conflict_b", "alias_test:trailing_digit_conflict_c", "alias_test:trailing_digit_conflict_d", "alias_test:trailing_digit_conflict_e", "alias_test:trailing_digit_conflict_f", "alias_test:trailing_digit_conflict_g"]} : () -> ()
 
 // CHECK-DAG: !test_tuple = tuple<!test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla>
 "test.op"() {alias_test = "alias_test:large_tuple"} : () -> (tuple<!test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla>)
@@ -28,8 +30,8 @@
 // CHECK-DAG: tensor<32xf32, #test_encoding>
 "test.op"() : () -> tensor<32xf32, "alias_test:tensor_encoding">
 
-// CHECK-DAG: !test_ui8_ = !test.int<unsigned, 8>
-// CHECK-DAG: tensor<32x!test_ui8_>
+// CHECK-DAG: !test_ui8 = !test.int<unsigned, 8>
+// CHECK-DAG: tensor<32x!test_ui8>
 "test.op"() : () -> tensor<32x!test.int<unsigned, 8>>
 
 // CHECK-DAG: #[[LOC_NESTED:.+]] = loc("nested")
@@ -47,8 +49,8 @@
 // -----
 
 // Ensure self type parameters get considered for aliases.
-// CHECK: !test_ui8_ = !test.int<unsigned, 8>
-// CHECK: #test.attr_with_self_type_param : !test_ui8_
+// CHECK: !test_ui8 = !test.int<unsigned, 8>
+// CHECK: #test.attr_with_self_type_param : !test_ui8
 "test.op"() {alias_test = #test.attr_with_self_type_param : !test.int<unsigned, 8> } : () -> ()
 
 // -----
diff --git a/mlir/test/IR/recursive-type.mlir b/mlir/test/IR/recursive-type.mlir
index 42aecb41d998d..b8111d9601e48 100644
--- a/mlir/test/IR/recursive-type.mlir
+++ b/mlir/test/IR/recursive-type.mlir
@@ -4,8 +4,8 @@
 // CHECK: ![[$NAME:.*]] = !test.test_rec_alias<name, !test.test_rec_alias<name>>
 // CHECK: ![[$NAME5:.*]] = !test.test_rec_alias<name5, !test.test_rec_alias<name3, !test.test_rec_alias<name4, !test.test_rec_alias<name5>>>>
 // CHECK: ![[$NAME2:.*]] = !test.test_rec_alias<name2, tuple<!test.test_rec_alias<name2>, i32>>
-// CHECK: ![[$NAME4:.*]] = !test.test_rec_alias<name4, !name5_>
-// CHECK: ![[$NAME3:.*]] = !test.test_rec_alias<name3, !name4_>
+// CHECK: ![[$NAME4:.*]] = !test.test_rec_alias<name4, !name5>
+// CHECK: ![[$NAME3:.*]] = !test.test_rec_alias<name3, !name4>
 
 // CHECK-LABEL: @roundtrip
 func.func @roundtrip() {
diff --git a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
index 64add8cef3698..b5668c5494e0f 100644
--- a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
@@ -187,12 +187,22 @@ struct TestOpAsmInterface : public OpAsmDialectInterface {
         StringSwitch<std::optional<StringRef>>(strAttr.getValue())
             .Case("alias_test:dot_in_name", StringRef("test.alias"))
             .Case("alias_test:trailing_digit", StringRef("test_alias0"))
-            .Case("alias_test:prefixed_digit", StringRef("0_test_alias"))
-            .Case("alias_test:prefixed_symbol", StringRef("%test"))
-            .Case("alias_test:sanitize_conflict_a",
+            .Case("alias_test:trailing_digit_conflict_a",
+                  StringRef("test_alias_conflict0"))
+            .Case("alias_test:trailing_digit_conflict_b",
                   StringRef("test_alias_conflict0"))
-            .Case("alias_test:sanitize_conflict_b",
+            .Case("alias_test:trailing_digit_conflict_c",
                   StringRef("test_alias_conflict0_"))
+            .Case("alias_test:trailing_digit_conflict_d",
+                  StringRef("test_alias_conflict0_1"))
+            .Case("alias_test:trailing_digit_conflict_e",
+                  StringRef("test_alias_conflict0_1"))
+            .Case("alias_test:trailing_digit_conflict_f",
+                  StringRef("test_alias_conflict0_1_"))
+            .Case("alias_test:trailing_digit_conflict_g",
+                  StringRef("test_alias_conflict0_1_1"))
+            .Case("alias_test:prefixed_digit", StringRef("0_test_alias"))
+            .Case("alias_test:prefixed_symbol", StringRef("%test"))
             .Case("alias_test:tensor_encoding", StringRef("test_encoding"))
             .Default(std::nullopt);
     if (!aliasName)



More information about the Mlir-commits mailing list