[Mlir-commits] [mlir] [mlir] Enable setting alias on AsmState. (PR #153776)
Jacques Pienaar
llvmlistbot at llvm.org
Fri Aug 15 03:08:19 PDT 2025
https://github.com/jpienaar created https://github.com/llvm/llvm-project/pull/153776
This is useful for application output, where a more descriptive name can be given to a type, or where a type is long and cumbersome to print in full.
>From f5cf13b8666009048292da112908463cf66fd718 Mon Sep 17 00:00:00 2001
From: Jacques Pienaar <jacques+gh at japienaar.info>
Date: Fri, 15 Aug 2025 09:27:21 +0000
Subject: [PATCH] [mlir] Enable setting alias on AsmState.
This is useful for application output, where a more descriptive name can be given to a type, or where a type is long and cumbersome to print in full.
---
mlir/include/mlir/IR/AsmState.h | 9 +++++++++
mlir/lib/IR/AsmPrinter.cpp | 34 +++++++++++++++++++++++++++++++++
mlir/unittests/IR/TypeTest.cpp | 25 ++++++++++++++++++++++++
3 files changed, 68 insertions(+)
diff --git a/mlir/include/mlir/IR/AsmState.h b/mlir/include/mlir/IR/AsmState.h
index 5e9311742bd94..bc25d9e05d73c 100644
--- a/mlir/include/mlir/IR/AsmState.h
+++ b/mlir/include/mlir/IR/AsmState.h
@@ -560,6 +560,15 @@ class AsmState {
FallbackAsmResourceMap *map = nullptr);
~AsmState();
+ /// Add a user-provided alias for the given attribute. The alias is used as a
+ /// suggestion when printing with this state; the final alias may be modified
+ /// (e.g. given a numeric suffix) to resolve conflicts. If an alias is already
+ /// registered for `attr` in this state, the existing one is kept. The alias
+ /// string is copied, so it does not need to outlive this call.
+ void addAlias(Attribute attr, StringRef alias);
+
+ /// Add a user-provided alias for the given type. The alias is used as a
+ /// suggestion when printing with this state; the final alias may be modified
+ /// (e.g. given a numeric suffix) to resolve conflicts. If an alias is already
+ /// registered for `type` in this state, the existing one is kept. The alias
+ /// string is copied, so it does not need to outlive this call.
+ void addAlias(Type type, StringRef alias);
+
/// Get the printer flags.
const OpPrintingFlags &getPrinterFlags() const;
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index de52fbd3f215c..775afc02427d5 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -558,6 +558,9 @@ class SymbolAlias {
}
}
+  /// Returns the base alias name as registered, excluding the suffix index
+  /// that is used to disambiguate conflicting aliases.
+  StringRef getAliasName() const {
+    return name;
+  }
+
/// Returns true if this is a type alias.
bool isTypeAlias() const { return isType; }
@@ -1139,6 +1142,17 @@ void AliasInitializer::initializeAliases(
void AliasInitializer::initialize(
Operation *op, const OpPrintingFlags &printerFlags,
llvm::MapVector<const void *, SymbolAlias> &attrTypeToAlias) {
+ // Pre-populate the alias list with the user-provided aliases.
+ for (auto &it : attrTypeToAlias) {
+ SymbolAlias &symbolAlias = it.second;
+ InProgressAliasInfo info(symbolAlias.getAliasName());
+ info.isType = symbolAlias.isTypeAlias();
+ info.canBeDeferred = symbolAlias.canBeDeferred();
+ info.aliasDepth = 0;
+ aliases.insert({it.first, info});
+ }
+ attrTypeToAlias.clear();
+
// Use a dummy printer when walking the IR so that we can collect the
// attributes/types that will actually be used during printing when
// considering aliases.
@@ -1272,6 +1286,9 @@ class AliasState {
printAliases(p, newLine, /*isDeferred=*/true);
}
+ /// Register `alias` for `symbol`, the opaque pointer of an attribute or type
+ /// (`isType` tells which). User-provided aliases are always registered as
+ /// non-deferrable. If `symbol` already has an alias entry, it is left
+ /// unchanged. Conflicts with other alias names are resolved later, during
+ /// alias initialization.
+ void addAlias(const void *symbol, bool isType, StringRef alias);
+
private:
/// Print all of the referenced aliases that support the provided resolution
/// behavior.
@@ -1286,6 +1303,14 @@ class AliasState {
};
} // namespace
+void AliasState::addAlias(const void *symbol, bool isType, StringRef alias) {
+  // Whether `alias` is unique is not known at this point, so it is recorded
+  // with suffix index 0 and conflicts are resolved during alias
+  // initialization. The name is copied into `aliasAllocator` because `alias`
+  // may refer to caller-owned storage.
+  SymbolAlias userAlias(alias.copy(aliasAllocator), 0, isType,
+                        /*isDeferrable=*/false);
+  attrTypeToAlias.insert({symbol, userAlias});
+}
+
void AliasState::initialize(
Operation *op, const OpPrintingFlags &printerFlags,
DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) {
@@ -2093,6 +2118,15 @@ AsmState::AsmState(MLIRContext *ctx, const OpPrintingFlags &printerFlags,
}
AsmState::~AsmState() = default;
+void AsmState::addAlias(Attribute attr, StringRef alias) {
+  // A null attribute or an empty alias name can never yield a valid alias
+  // definition, and a null attribute would register a nullptr key in the alias
+  // table, so treat either as a caller error.
+  assert(attr && "expected a non-null attribute");
+  assert(!alias.empty() && "expected a non-empty alias name");
+  impl->getAliasState().addAlias(attr.getAsOpaquePointer(), /*isType=*/false,
+                                 alias);
+}
+
+void AsmState::addAlias(Type type, StringRef alias) {
+  // Same preconditions as the attribute overload: a non-null type and a
+  // non-empty alias name.
+  assert(type && "expected a non-null type");
+  assert(!alias.empty() && "expected a non-empty alias name");
+  impl->getAliasState().addAlias(type.getAsOpaquePointer(), /*isType=*/true,
+                                 alias);
+}
+
const OpPrintingFlags &AsmState::getPrinterFlags() const {
return impl->getPrinterFlags();
}
diff --git a/mlir/unittests/IR/TypeTest.cpp b/mlir/unittests/IR/TypeTest.cpp
index 30f6642a9ca71..4222c18f30742 100644
--- a/mlir/unittests/IR/TypeTest.cpp
+++ b/mlir/unittests/IR/TypeTest.cpp
@@ -6,14 +6,19 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/IR/AsmState.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
+#include "mlir/IR/Operation.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
+#include "gmock/gmock.h"
#include "gtest/gtest.h"
using namespace mlir;
+using testing::HasSubstr;
+
/// Mock implementations of a Type hierarchy
struct LeafType;
@@ -69,3 +74,23 @@ TEST(Type, Casting) {
EXPECT_EQ(8u, cast<IntegerType>(intTy).getWidth());
}
+
+TEST(Type, UserAlias) {
+  MLIRContext context;
+  context.allowUnregisteredDialects();
+
+  // Register a user alias for `i8` on the printing state before printing.
+  Type i8Type = IntegerType::get(&context, 8);
+  AsmState asmState(&context);
+  asmState.addAlias(i8Type, "test.alias");
+
+  // Build a generic, unregistered op whose only result has type i8.
+  Operation *testOp = Operation::create(
+      UnknownLoc::get(&context), OperationName("test.op", &context),
+      TypeRange(i8Type), ValueRange(), NamedAttrList(),
+      OpaqueProperties(nullptr), BlockRange(), /*numRegions=*/0);
+
+  std::string printed;
+  llvm::raw_string_ostream stream(printed);
+  testOp->print(stream, asmState);
+
+  // The alias definition is emitted and the result type refers to it.
+  EXPECT_THAT(printed, HasSubstr("!test.alias = i8"));
+  EXPECT_THAT(printed, HasSubstr("\"test.op\"() : () -> !test.alias\n"));
+  testOp->erase();
+}
More information about the Mlir-commits
mailing list