[Mlir-commits] [mlir] [mlir] Enable setting alias on AsmState. (PR #153776)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Aug 15 03:08:51 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Author: Jacques Pienaar (jpienaar)

<details>
<summary>Changes</summary>

This is useful for application output, where a more meaningful name can be given to a type, or where the type is long and cumbersome to read.

---
Full diff: https://github.com/llvm/llvm-project/pull/153776.diff


3 Files Affected:

- (modified) mlir/include/mlir/IR/AsmState.h (+9) 
- (modified) mlir/lib/IR/AsmPrinter.cpp (+34) 
- (modified) mlir/unittests/IR/TypeTest.cpp (+25) 


``````````diff
diff --git a/mlir/include/mlir/IR/AsmState.h b/mlir/include/mlir/IR/AsmState.h
index 5e9311742bd94..bc25d9e05d73c 100644
--- a/mlir/include/mlir/IR/AsmState.h
+++ b/mlir/include/mlir/IR/AsmState.h
@@ -560,6 +560,15 @@ class AsmState {
            FallbackAsmResourceMap *map = nullptr);
   ~AsmState();
 
+  /// Add an alias for the given attribute. The alias will be used as a
+  /// suggestion when printing. The final alias may be modified to resolve
+  /// conflicts.
+  void addAlias(Attribute attr, StringRef alias);
+
+  /// Add an alias for the given type. The alias will be used as a suggestion
+  /// when printing. The final alias may be modified to resolve conflicts.
+  void addAlias(Type type, StringRef alias);
+
   /// Get the printer flags.
   const OpPrintingFlags &getPrinterFlags() const;
 
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index de52fbd3f215c..775afc02427d5 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -558,6 +558,9 @@ class SymbolAlias {
     }
   }
 
+  /// Returns the alias name.
+  StringRef getAliasName() const { return name; }
+
   /// Returns true if this is a type alias.
   bool isTypeAlias() const { return isType; }
 
@@ -1139,6 +1142,17 @@ void AliasInitializer::initializeAliases(
 void AliasInitializer::initialize(
     Operation *op, const OpPrintingFlags &printerFlags,
     llvm::MapVector<const void *, SymbolAlias> &attrTypeToAlias) {
+  // Pre-populate the alias list with the user-provided aliases.
+  for (auto &it : attrTypeToAlias) {
+    SymbolAlias &symbolAlias = it.second;
+    InProgressAliasInfo info(symbolAlias.getAliasName());
+    info.isType = symbolAlias.isTypeAlias();
+    info.canBeDeferred = symbolAlias.canBeDeferred();
+    info.aliasDepth = 0;
+    aliases.insert({it.first, info});
+  }
+  attrTypeToAlias.clear();
+
   // Use a dummy printer when walking the IR so that we can collect the
   // attributes/types that will actually be used during printing when
   // considering aliases.
@@ -1272,6 +1286,9 @@ class AliasState {
     printAliases(p, newLine, /*isDeferred=*/true);
   }
 
+  /// Add an alias for the given symbol. All aliases are non-deferred.
+  void addAlias(const void *symbol, bool isType, StringRef alias);
+
 private:
   /// Print all of the referenced aliases that support the provided resolution
   /// behavior.
@@ -1286,6 +1303,14 @@ class AliasState {
 };
 } // namespace
 
+void AliasState::addAlias(const void *symbol, bool isType, StringRef alias) {
+  StringRef name = alias.copy(aliasAllocator);
+  // We don't know if this alias is unique. Set suffix index to 0 and defer
+  // conflict resolution to aliases initialization.
+  attrTypeToAlias.insert(
+      {symbol, SymbolAlias(name, 0, isType, /*isDeferrable=*/false)});
+}
+
 void AliasState::initialize(
     Operation *op, const OpPrintingFlags &printerFlags,
     DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) {
@@ -2093,6 +2118,15 @@ AsmState::AsmState(MLIRContext *ctx, const OpPrintingFlags &printerFlags,
 }
 AsmState::~AsmState() = default;
 
+void AsmState::addAlias(Attribute attr, StringRef alias) {
+  impl->getAliasState().addAlias(attr.getAsOpaquePointer(), /*isType=*/false,
+                                 alias);
+}
+void AsmState::addAlias(Type type, StringRef alias) {
+  impl->getAliasState().addAlias(type.getAsOpaquePointer(), /*isType=*/true,
+                                 alias);
+}
+
 const OpPrintingFlags &AsmState::getPrinterFlags() const {
   return impl->getPrinterFlags();
 }
diff --git a/mlir/unittests/IR/TypeTest.cpp b/mlir/unittests/IR/TypeTest.cpp
index 30f6642a9ca71..4222c18f30742 100644
--- a/mlir/unittests/IR/TypeTest.cpp
+++ b/mlir/unittests/IR/TypeTest.cpp
@@ -6,14 +6,19 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/IR/AsmState.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
+#include "mlir/IR/Operation.h"
 #include "mlir/IR/Types.h"
 #include "mlir/IR/Value.h"
+#include "gmock/gmock.h"
 #include "gtest/gtest.h"
 
 using namespace mlir;
 
+using testing::HasSubstr;
+
 /// Mock implementations of a Type hierarchy
 struct LeafType;
 
@@ -69,3 +74,23 @@ TEST(Type, Casting) {
 
   EXPECT_EQ(8u, cast<IntegerType>(intTy).getWidth());
 }
+
+TEST(Type, UserAlias) {
+  MLIRContext ctx;
+  ctx.allowUnregisteredDialects();
+
+  Type intTy = IntegerType::get(&ctx, 8);
+  AsmState state(&ctx);
+  state.addAlias(intTy, "test.alias");
+  Operation *op = Operation::create(
+      UnknownLoc::get(&ctx), OperationName("test.op", &ctx), TypeRange(intTy),
+      ValueRange(), NamedAttrList(), OpaqueProperties(nullptr), BlockRange(),
+      /*numRegions=*/0);
+
+  std::string str;
+  llvm::raw_string_ostream os(str);
+  op->print(os, state);
+  EXPECT_THAT(str, HasSubstr("!test.alias = i8"));
+  EXPECT_THAT(str, HasSubstr("\"test.op\"() : () -> !test.alias\n"));
+  op->erase();
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/153776


More information about the Mlir-commits mailing list