[Mlir-commits] [mlir] [mlir] Allow trailing digit for alias in AsmPrinter (PR #127993)
Hongren Zheng
llvmlistbot at llvm.org
Thu Feb 20 03:54:32 PST 2025
https://github.com/ZenithalHourlyRate created https://github.com/llvm/llvm-project/pull/127993
When generating aliases from `OpAsm{Dialect,Type,Attr}Interface`, the result is sanitized, and if the alias provided by the interface ends with a digit, AsmPrinter appends an underscore to it, presumably to prevent conflicts.
#### Motivation
There are two reasons for switching from the old behavior to the proposed behavior:
1. If the type/attribute can generate a unique alias from its own content, the extra trailing underscore added by AsmPrinter looks strange:
```mlir
func.func @add(%ct: !ct_L0_) -> !ct_L0_ {
%ct_0 = bgv.add %ct, %ct : (!ct_L0_, !ct_L0_) -> !ct_L0_
%ct_1 = bgv.add %ct_0, %ct_0 : (!ct_L0_, !ct_L0_) -> !ct_L0_
%ct_2 = bgv.add %ct_1, %ct_1 : (!ct_L0_, !ct_L0_) -> !ct_L0_
return %ct_2 : !ct_L0_
}
```
It would read better as `(!ct_L0, !ct_L0) -> !ct_L0`.
2. For value names, the first instance gets no `_N` suffix, and the same scheme can be applied to alias names. In the IR above, the first value is named `%ct` and the others are named `%ct_N`. See `uniqueValueName` for details.
#### Conflict detection
```mlir
!test.type<a = 3> // suggest !name0
!test.type<a = 4> // suggest !name0
!test.another<b = 3> // suggest !name0_
!test.another<b = 4> // suggest !name0_
```
Conflict detection is based on `nameCounts` in `initializeAliases`.
With the old behavior, the first two aliases are sanitized to `!name0_`, so all four share the key `name0_`. `initializeAliases` then assigns them the unique indices `0, 1, 2, 3`, printed as `!name0_`, `!name0_1`, `!name0_2` and `!name0_3`.
With the new behavior, `initializeAliases` knows that a suffixed `!name0` is printed as `!name0_N`. When building `nameCounts`, it therefore treats `!name0` as if it had been sanitized to `!name0_`. All four aliases still receive unique indices, printed as `!name0`, `!name0_1`, `!name0_2` and `!name0_3`. See the updated test for an example.
>From d95141853a6dcd9fad806331926ae161d498b20e Mon Sep 17 00:00:00 2001
From: Zenithal <i at zenithal.me>
Date: Thu, 20 Feb 2025 11:09:42 +0000
Subject: [PATCH] [mlir] Allow trailing digit for alias in AsmPrinter
---
mlir/lib/IR/AsmPrinter.cpp | 30 ++++++++++---------
mlir/test/IR/print-attr-type-aliases.mlir | 21 +++++++------
mlir/test/IR/recursive-type.mlir | 4 +--
.../Dialect/Test/TestDialectInterfaces.cpp | 12 +++++---
4 files changed, 36 insertions(+), 31 deletions(-)
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 1f22d4f37a813..8044a1c8507e8 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -547,8 +547,11 @@ class SymbolAlias {
/// Print this alias to the given stream.
void print(raw_ostream &os) const {
os << (isType ? "!" : "#") << name;
- if (suffixIndex)
+ if (suffixIndex) {
+ if (isdigit(name.back()))
+ os << '_';
os << suffixIndex;
+ }
}
/// Returns true if this is a type alias.
@@ -1020,8 +1023,7 @@ class DummyAliasDialectAsmPrinter : public DialectAsmPrinter {
/// the string needs to be modified in any way, the provided buffer is used to
/// store the new copy,
static StringRef sanitizeIdentifier(StringRef name, SmallString<16> &buffer,
- StringRef allowedPunctChars = "$._-",
- bool allowTrailingDigit = true) {
+ StringRef allowedPunctChars = "$._-") {
assert(!name.empty() && "Shouldn't have an empty name here");
auto validChar = [&](char ch) {
@@ -1048,14 +1050,6 @@ static StringRef sanitizeIdentifier(StringRef name, SmallString<16> &buffer,
return buffer;
}
- // If the name ends with a trailing digit, add a '_' to avoid potential
- // conflicts with autogenerated ID's.
- if (!allowTrailingDigit && isdigit(name.back())) {
- copyNameToBuffer();
- buffer.push_back('_');
- return buffer;
- }
-
// Check to see that the name consists of only valid identifier characters.
for (char ch : name) {
if (!validChar(ch)) {
@@ -1084,7 +1078,16 @@ void AliasInitializer::initializeAliases(
if (!aliasInfo.alias)
continue;
StringRef alias = *aliasInfo.alias;
- unsigned nameIndex = nameCounts[alias]++;
+ unsigned nameIndex;
+ // If the alias ends with a digit, we need to pretend as if it has trailing
+ // underscore to get a unique nameIndex.
+ if (isdigit(alias.back())) {
+ SmallString<16> aliasBuffer(alias);
+ aliasBuffer.push_back('_');
+ nameIndex = nameCounts[aliasBuffer]++;
+ } else {
+ nameIndex = nameCounts[alias]++;
+ }
symbolToAlias.insert(
{symbol, SymbolAlias(alias, nameIndex, aliasInfo.isType,
aliasInfo.canBeDeferred)});
@@ -1191,8 +1194,7 @@ void AliasInitializer::generateAlias(T symbol, InProgressAliasInfo &alias,
SmallString<16> tempBuffer;
StringRef name =
- sanitizeIdentifier(nameBuffer, tempBuffer, /*allowedPunctChars=*/"$_-",
- /*allowTrailingDigit=*/false);
+ sanitizeIdentifier(nameBuffer, tempBuffer, /*allowedPunctChars=*/"$_-");
name = name.copy(aliasAllocator);
alias = InProgressAliasInfo(name);
}
diff --git a/mlir/test/IR/print-attr-type-aliases.mlir b/mlir/test/IR/print-attr-type-aliases.mlir
index e878d862076c9..97cda270e2c2e 100644
--- a/mlir/test/IR/print-attr-type-aliases.mlir
+++ b/mlir/test/IR/print-attr-type-aliases.mlir
@@ -5,7 +5,7 @@
// CHECK-DAG: #test2Ealias = "alias_test:dot_in_name"
"test.op"() {alias_test = "alias_test:dot_in_name"} : () -> ()
-// CHECK-DAG: #test_alias0_ = "alias_test:trailing_digit"
+// CHECK-DAG: #test_alias0 = "alias_test:trailing_digit"
"test.op"() {alias_test = "alias_test:trailing_digit"} : () -> ()
// CHECK-DAG: #_0_test_alias = "alias_test:prefixed_digit"
@@ -14,12 +14,11 @@
// CHECK-DAG: #_25test = "alias_test:prefixed_symbol"
"test.op"() {alias_test = "alias_test:prefixed_symbol"} : () -> ()
-// CHECK-DAG: #test_alias_conflict0_ = "alias_test:sanitize_conflict_a"
-// CHECK-DAG: #test_alias_conflict0_1 = "alias_test:sanitize_conflict_b"
-"test.op"() {alias_test = ["alias_test:sanitize_conflict_a", "alias_test:sanitize_conflict_b"]} : () -> ()
-
-// CHECK-DAG: !tuple = tuple<i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32>
-"test.op"() {alias_test = "alias_test:large_tuple"} : () -> (tuple<i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32>)
+// CHECK-DAG: #test_alias_conflict0 = "alias_test:trailing_digit_conflict_a"
+// CHECK-DAG: #test_alias_conflict0_1 = "alias_test:trailing_digit_conflict_b"
+// CHECK-DAG: #test_alias_conflict0_2 = "alias_test:trailing_digit_conflict_c"
+// CHECK-DAG: #test_alias_conflict0_3 = "alias_test:trailing_digit_conflict_d"
+"test.op"() {alias_test = ["alias_test:trailing_digit_conflict_a", "alias_test:trailing_digit_conflict_b", "alias_test:trailing_digit_conflict_c", "alias_test:trailing_digit_conflict_d"]} : () -> ()
// CHECK-DAG: !test_tuple = tuple<!test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla>
"test.op"() {alias_test = "alias_test:large_tuple"} : () -> (tuple<!test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla>)
@@ -28,8 +27,8 @@
// CHECK-DAG: tensor<32xf32, #test_encoding>
"test.op"() : () -> tensor<32xf32, "alias_test:tensor_encoding">
-// CHECK-DAG: !test_ui8_ = !test.int<unsigned, 8>
-// CHECK-DAG: tensor<32x!test_ui8_>
+// CHECK-DAG: !test_ui8 = !test.int<unsigned, 8>
+// CHECK-DAG: tensor<32x!test_ui8>
"test.op"() : () -> tensor<32x!test.int<unsigned, 8>>
// CHECK-DAG: #[[LOC_NESTED:.+]] = loc("nested")
@@ -47,8 +46,8 @@
// -----
// Ensure self type parameters get considered for aliases.
-// CHECK: !test_ui8_ = !test.int<unsigned, 8>
-// CHECK: #test.attr_with_self_type_param : !test_ui8_
+// CHECK: !test_ui8 = !test.int<unsigned, 8>
+// CHECK: #test.attr_with_self_type_param : !test_ui8
"test.op"() {alias_test = #test.attr_with_self_type_param : !test.int<unsigned, 8> } : () -> ()
// -----
diff --git a/mlir/test/IR/recursive-type.mlir b/mlir/test/IR/recursive-type.mlir
index 42aecb41d998d..b8111d9601e48 100644
--- a/mlir/test/IR/recursive-type.mlir
+++ b/mlir/test/IR/recursive-type.mlir
@@ -4,8 +4,8 @@
// CHECK: ![[$NAME:.*]] = !test.test_rec_alias<name, !test.test_rec_alias<name>>
// CHECK: ![[$NAME5:.*]] = !test.test_rec_alias<name5, !test.test_rec_alias<name3, !test.test_rec_alias<name4, !test.test_rec_alias<name5>>>>
// CHECK: ![[$NAME2:.*]] = !test.test_rec_alias<name2, tuple<!test.test_rec_alias<name2>, i32>>
-// CHECK: ![[$NAME4:.*]] = !test.test_rec_alias<name4, !name5_>
-// CHECK: ![[$NAME3:.*]] = !test.test_rec_alias<name3, !name4_>
+// CHECK: ![[$NAME4:.*]] = !test.test_rec_alias<name4, !name5>
+// CHECK: ![[$NAME3:.*]] = !test.test_rec_alias<name3, !name4>
// CHECK-LABEL: @roundtrip
func.func @roundtrip() {
diff --git a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
index 64add8cef3698..065692b98f219 100644
--- a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
@@ -187,12 +187,16 @@ struct TestOpAsmInterface : public OpAsmDialectInterface {
StringSwitch<std::optional<StringRef>>(strAttr.getValue())
.Case("alias_test:dot_in_name", StringRef("test.alias"))
.Case("alias_test:trailing_digit", StringRef("test_alias0"))
- .Case("alias_test:prefixed_digit", StringRef("0_test_alias"))
- .Case("alias_test:prefixed_symbol", StringRef("%test"))
- .Case("alias_test:sanitize_conflict_a",
+ .Case("alias_test:trailing_digit_conflict_a",
+ StringRef("test_alias_conflict0"))
+ .Case("alias_test:trailing_digit_conflict_b",
StringRef("test_alias_conflict0"))
- .Case("alias_test:sanitize_conflict_b",
+ .Case("alias_test:trailing_digit_conflict_c",
StringRef("test_alias_conflict0_"))
+ .Case("alias_test:trailing_digit_conflict_d",
+ StringRef("test_alias_conflict0_"))
+ .Case("alias_test:prefixed_digit", StringRef("0_test_alias"))
+ .Case("alias_test:prefixed_symbol", StringRef("%test"))
.Case("alias_test:tensor_encoding", StringRef("test_encoding"))
.Default(std::nullopt);
if (!aliasName)
More information about the Mlir-commits
mailing list