[Mlir-commits] [mlir] [mlir] Allow trailing digit for alias in AsmPrinter (PR #127993)
Hongren Zheng
llvmlistbot at llvm.org
Thu Feb 20 11:18:07 PST 2025
https://github.com/ZenithalHourlyRate updated https://github.com/llvm/llvm-project/pull/127993
>From cf2b899bc394cc563f9688f848a67b552b12e332 Mon Sep 17 00:00:00 2001
From: Zenithal <i at zenithal.me>
Date: Thu, 20 Feb 2025 11:09:42 +0000
Subject: [PATCH] [mlir] Allow trailing digit for alias in AsmPrinter
---
mlir/lib/IR/AsmPrinter.cpp | 57 ++++++++++++++-----
mlir/test/IR/print-attr-type-aliases.mlir | 24 ++++----
mlir/test/IR/recursive-type.mlir | 4 +-
.../Dialect/Test/TestDialectInterfaces.cpp | 18 ++++--
4 files changed, 72 insertions(+), 31 deletions(-)
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 1f22d4f37a813..b00032febda11 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -547,8 +547,11 @@ class SymbolAlias {
/// Print this alias to the given stream.
void print(raw_ostream &os) const {
os << (isType ? "!" : "#") << name;
- if (suffixIndex)
+ if (suffixIndex) {
+ if (isdigit(name.back()))
+ os << '_';
os << suffixIndex;
+ }
}
/// Returns true if this is a type alias.
@@ -1020,8 +1023,7 @@ class DummyAliasDialectAsmPrinter : public DialectAsmPrinter {
/// the string needs to be modified in any way, the provided buffer is used to
/// store the new copy,
static StringRef sanitizeIdentifier(StringRef name, SmallString<16> &buffer,
- StringRef allowedPunctChars = "$._-",
- bool allowTrailingDigit = true) {
+ StringRef allowedPunctChars = "$._-") {
assert(!name.empty() && "Shouldn't have an empty name here");
auto validChar = [&](char ch) {
@@ -1048,14 +1050,6 @@ static StringRef sanitizeIdentifier(StringRef name, SmallString<16> &buffer,
return buffer;
}
- // If the name ends with a trailing digit, add a '_' to avoid potential
- // conflicts with autogenerated ID's.
- if (!allowTrailingDigit && isdigit(name.back())) {
- copyNameToBuffer();
- buffer.push_back('_');
- return buffer;
- }
-
// Check to see that the name consists of only valid identifier characters.
for (char ch : name) {
if (!validChar(ch)) {
@@ -1084,7 +1078,43 @@ void AliasInitializer::initializeAliases(
if (!aliasInfo.alias)
continue;
StringRef alias = *aliasInfo.alias;
- unsigned nameIndex = nameCounts[alias]++;
+ unsigned nameIndex;
+ // If the alias ends with a digit, treat it as if it had a trailing
+ // underscore so that it gets a unique nameIndex.
+ if (isdigit(alias.back())) {
+ SmallString<16> aliasBuffer(alias);
+ aliasBuffer.push_back('_');
+ // Check if it is safe to use the alias as is.
+ if (nameCounts[aliasBuffer] == 0) {
+ aliasBuffer.pop_back();
+ // Strip all trailing digits to get the only prefix that could have
+ // auto-generated this alias: a prefix ending with a digit never produces
+ // prefix + digits, only prefix or prefix + '_' + digits.
+ int numOfDigits = 0;
+ while (!aliasBuffer.empty() && isdigit(aliasBuffer.back())) {
+ numOfDigits++;
+ aliasBuffer.pop_back();
+ }
+ unsigned trailingNumber =
+ std::stoi(std::string(alias.take_back(numOfDigits)));
+ // Check whether the prefix auto-generated this alias. If so, start
+ // directly from _1 to avoid a conflict.
+ bool autoGenerated = nameCounts[aliasBuffer] > trailingNumber;
+ // Reinitialize the aliasBuffer to the original alias.
+ aliasBuffer = alias;
+ aliasBuffer.push_back('_');
+ if (autoGenerated) {
+ nameIndex = 1;
+ nameCounts[aliasBuffer] = 2;
+ } else {
+ nameIndex = nameCounts[aliasBuffer]++;
+ }
+ } else {
+ nameIndex = nameCounts[aliasBuffer]++;
+ }
+ } else {
+ nameIndex = nameCounts[alias]++;
+ }
symbolToAlias.insert(
{symbol, SymbolAlias(alias, nameIndex, aliasInfo.isType,
aliasInfo.canBeDeferred)});
@@ -1191,8 +1221,7 @@ void AliasInitializer::generateAlias(T symbol, InProgressAliasInfo &alias,
SmallString<16> tempBuffer;
StringRef name =
- sanitizeIdentifier(nameBuffer, tempBuffer, /*allowedPunctChars=*/"$_-",
- /*allowTrailingDigit=*/false);
+ sanitizeIdentifier(nameBuffer, tempBuffer, /*allowedPunctChars=*/"$_-");
name = name.copy(aliasAllocator);
alias = InProgressAliasInfo(name);
}
diff --git a/mlir/test/IR/print-attr-type-aliases.mlir b/mlir/test/IR/print-attr-type-aliases.mlir
index e878d862076c9..8772dd4db5c71 100644
--- a/mlir/test/IR/print-attr-type-aliases.mlir
+++ b/mlir/test/IR/print-attr-type-aliases.mlir
@@ -5,7 +5,7 @@
// CHECK-DAG: #test2Ealias = "alias_test:dot_in_name"
"test.op"() {alias_test = "alias_test:dot_in_name"} : () -> ()
-// CHECK-DAG: #test_alias0_ = "alias_test:trailing_digit"
+// CHECK-DAG: #test_alias0 = "alias_test:trailing_digit"
"test.op"() {alias_test = "alias_test:trailing_digit"} : () -> ()
// CHECK-DAG: #_0_test_alias = "alias_test:prefixed_digit"
@@ -14,12 +14,14 @@
// CHECK-DAG: #_25test = "alias_test:prefixed_symbol"
"test.op"() {alias_test = "alias_test:prefixed_symbol"} : () -> ()
-// CHECK-DAG: #test_alias_conflict0_ = "alias_test:sanitize_conflict_a"
-// CHECK-DAG: #test_alias_conflict0_1 = "alias_test:sanitize_conflict_b"
-"test.op"() {alias_test = ["alias_test:sanitize_conflict_a", "alias_test:sanitize_conflict_b"]} : () -> ()
-
-// CHECK-DAG: !tuple = tuple<i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32>
-"test.op"() {alias_test = "alias_test:large_tuple"} : () -> (tuple<i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32>)
+// CHECK-DAG: #test_alias_conflict0 = "alias_test:trailing_digit_conflict_a"
+// CHECK-DAG: #test_alias_conflict0_1 = "alias_test:trailing_digit_conflict_b"
+// CHECK-DAG: #test_alias_conflict0_2 = "alias_test:trailing_digit_conflict_c"
+// CHECK-DAG: #test_alias_conflict0_1_1 = "alias_test:trailing_digit_conflict_d"
+// CHECK-DAG: #test_alias_conflict0_1_2 = "alias_test:trailing_digit_conflict_e"
+// CHECK-DAG: #test_alias_conflict0_1_3 = "alias_test:trailing_digit_conflict_f"
+// CHECK-DAG: #test_alias_conflict0_1_1_1 = "alias_test:trailing_digit_conflict_g"
+"test.op"() {alias_test = ["alias_test:trailing_digit_conflict_a", "alias_test:trailing_digit_conflict_b", "alias_test:trailing_digit_conflict_c", "alias_test:trailing_digit_conflict_d", "alias_test:trailing_digit_conflict_e", "alias_test:trailing_digit_conflict_f", "alias_test:trailing_digit_conflict_g"]} : () -> ()
// CHECK-DAG: !test_tuple = tuple<!test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla>
"test.op"() {alias_test = "alias_test:large_tuple"} : () -> (tuple<!test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla>)
@@ -28,8 +30,8 @@
// CHECK-DAG: tensor<32xf32, #test_encoding>
"test.op"() : () -> tensor<32xf32, "alias_test:tensor_encoding">
-// CHECK-DAG: !test_ui8_ = !test.int<unsigned, 8>
-// CHECK-DAG: tensor<32x!test_ui8_>
+// CHECK-DAG: !test_ui8 = !test.int<unsigned, 8>
+// CHECK-DAG: tensor<32x!test_ui8>
"test.op"() : () -> tensor<32x!test.int<unsigned, 8>>
// CHECK-DAG: #[[LOC_NESTED:.+]] = loc("nested")
@@ -47,8 +49,8 @@
// -----
// Ensure self type parameters get considered for aliases.
-// CHECK: !test_ui8_ = !test.int<unsigned, 8>
-// CHECK: #test.attr_with_self_type_param : !test_ui8_
+// CHECK: !test_ui8 = !test.int<unsigned, 8>
+// CHECK: #test.attr_with_self_type_param : !test_ui8
"test.op"() {alias_test = #test.attr_with_self_type_param : !test.int<unsigned, 8> } : () -> ()
// -----
diff --git a/mlir/test/IR/recursive-type.mlir b/mlir/test/IR/recursive-type.mlir
index 42aecb41d998d..b8111d9601e48 100644
--- a/mlir/test/IR/recursive-type.mlir
+++ b/mlir/test/IR/recursive-type.mlir
@@ -4,8 +4,8 @@
// CHECK: ![[$NAME:.*]] = !test.test_rec_alias<name, !test.test_rec_alias<name>>
// CHECK: ![[$NAME5:.*]] = !test.test_rec_alias<name5, !test.test_rec_alias<name3, !test.test_rec_alias<name4, !test.test_rec_alias<name5>>>>
// CHECK: ![[$NAME2:.*]] = !test.test_rec_alias<name2, tuple<!test.test_rec_alias<name2>, i32>>
-// CHECK: ![[$NAME4:.*]] = !test.test_rec_alias<name4, !name5_>
-// CHECK: ![[$NAME3:.*]] = !test.test_rec_alias<name3, !name4_>
+// CHECK: ![[$NAME4:.*]] = !test.test_rec_alias<name4, !name5>
+// CHECK: ![[$NAME3:.*]] = !test.test_rec_alias<name3, !name4>
// CHECK-LABEL: @roundtrip
func.func @roundtrip() {
diff --git a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
index 64add8cef3698..b5668c5494e0f 100644
--- a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
@@ -187,12 +187,22 @@ struct TestOpAsmInterface : public OpAsmDialectInterface {
StringSwitch<std::optional<StringRef>>(strAttr.getValue())
.Case("alias_test:dot_in_name", StringRef("test.alias"))
.Case("alias_test:trailing_digit", StringRef("test_alias0"))
- .Case("alias_test:prefixed_digit", StringRef("0_test_alias"))
- .Case("alias_test:prefixed_symbol", StringRef("%test"))
- .Case("alias_test:sanitize_conflict_a",
+ .Case("alias_test:trailing_digit_conflict_a",
+ StringRef("test_alias_conflict0"))
+ .Case("alias_test:trailing_digit_conflict_b",
StringRef("test_alias_conflict0"))
- .Case("alias_test:sanitize_conflict_b",
+ .Case("alias_test:trailing_digit_conflict_c",
StringRef("test_alias_conflict0_"))
+ .Case("alias_test:trailing_digit_conflict_d",
+ StringRef("test_alias_conflict0_1"))
+ .Case("alias_test:trailing_digit_conflict_e",
+ StringRef("test_alias_conflict0_1"))
+ .Case("alias_test:trailing_digit_conflict_f",
+ StringRef("test_alias_conflict0_1_"))
+ .Case("alias_test:trailing_digit_conflict_g",
+ StringRef("test_alias_conflict0_1_1"))
+ .Case("alias_test:prefixed_digit", StringRef("0_test_alias"))
+ .Case("alias_test:prefixed_symbol", StringRef("%test"))
.Case("alias_test:tensor_encoding", StringRef("test_encoding"))
.Default(std::nullopt);
if (!aliasName)
More information about the Mlir-commits
mailing list