[Mlir-commits] [mlir] [mlir] Add dialect hooks for registering custom type and attribute alias printers (PR #173091)
Fabian Mora
llvmlistbot at llvm.org
Mon Dec 22 07:30:48 PST 2025
https://github.com/fabianmcg updated https://github.com/llvm/llvm-project/pull/173091
>From 797c073c0fe38438b4a223d87c7e468ce59ee865 Mon Sep 17 00:00:00 2001
From: Fabian Mora <fabian.mora-cordero at amd.com>
Date: Wed, 17 Dec 2025 16:35:45 +0000
Subject: [PATCH 1/2] [mlir] Add dialect hooks for registering custom type and
attribute alias printers
This patch introduces a mechanism for dialects to register custom alias printers
for types and attributes via the `OpAsmDialectInterface`. This allows dialects
to provide alternative printed representations for types and attributes based
on their TypeID, including types/attributes from other dialects.
The new `registerAttrAliasPrinter` and `registerTypeAliasPrinter` virtual
methods accept callbacks that register printers for specific TypeIDs. When
printing, these custom printers are invoked in registration order, and the
first one to produce output is used.
The precedence for alias resolution is:
1. Explicit type/attribute aliases returned by `getAlias`
2. Dialect-specific alias printers registered via the new hooks
3. Default type/attribute printers
Signed-off-by: Fabian Mora <fabian.mora-cordero at amd.com>
---
mlir/include/mlir/IR/OpImplementation.h | 37 +++
mlir/lib/IR/AsmPrinter.cpp | 215 +++++++++++++++++-
.../IR/print-attr-type-dialect-aliases.mlir | 14 ++
.../Dialect/Test/TestDialectInterfaces.cpp | 48 ++++
4 files changed, 310 insertions(+), 4 deletions(-)
create mode 100644 mlir/test/IR/print-attr-type-dialect-aliases.mlir
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index d70aa346eaa1f..a6d39c4d17a8d 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -17,6 +17,8 @@
#include "mlir/IR/DialectInterface.h"
#include "mlir/IR/OpAsmSupport.h"
#include "mlir/IR/OpDefinition.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/TypeID.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/SMLoc.h"
+#include <functional>
#include <optional>
@@ -169,6 +171,9 @@ class AsmPrinter {
if (succeeded(printAlias(attrOrType)))
return;
+ if (succeeded(printDialectAlias(attrOrType, /*printStripped=*/true)))
+ return;
+
raw_ostream &os = getStream();
uint64_t posPrior = os.tell();
attrOrType.print(*this);
@@ -218,6 +223,14 @@ class AsmPrinter {
/// be printed.
virtual LogicalResult printAlias(Type type);
+ /// Print the dialect alias for the given attribute, as produced by the alias
+ /// printers registered through `OpAsmDialectInterface`. If `printStripped` is
+ /// true, the alias is printed without the dialect namespace prefix. Return
+ /// failure if no alias could be printed.
+ virtual LogicalResult printDialectAlias(Attribute attr, bool printStripped);
+
+ /// Print the dialect alias for the given type, as produced by the alias
+ /// printers registered through `OpAsmDialectInterface`. If `printStripped` is
+ /// true, the alias is printed without the dialect namespace prefix. Return
+ /// failure if no alias could be printed.
+ virtual LogicalResult printDialectAlias(Type type, bool printStripped);
+
/// Print the given string as a keyword, or a quoted and escaped string if it
/// has any special or non-printable characters in it.
virtual void printKeywordOrString(StringRef keyword);
@@ -1799,6 +1812,30 @@ class OpAsmDialectInterface
return AliasResult::NoAlias;
}
+  /// Hooks for registering alias printers for types and attributes. These
+  /// printers are invoked when printing types or attributes of the given
+  /// TypeID. Printers are invoked in the order they are registered, and the
+  /// first one to print an alias is used.
+  /// The precedence of these printers is as follows:
+  ///   1. The type and attribute aliases returned by `getAlias`.
+  ///   2. Dialect-specific alias printers registered here.
+  ///   3. The type and attribute printers.
+  /// The boolean argument to the printer indicates whether the stripped form
+  /// of the type or attribute is being printed.
+  /// NOTE: This mechanism caches the printed object, therefore the printer
+  /// must always produce the same output for the same input.
+  ///
+  /// The registered printers are stored by the AsmPrinter and invoked long
+  /// after the registration hook has returned. They are therefore owning
+  /// `std::function`s: a `function_ref` would dangle as soon as the callable
+  /// passed to `insertFn` is destroyed. The `insert*` callbacks themselves are
+  /// only called during the hook, so they remain `function_ref`s.
+  using AttributeAliasPrinter =
+      std::function<void(Attribute, AsmPrinter &, bool)>;
+  using InsertAttrAliasPrinter =
+      llvm::function_ref<void(TypeID, AttributeAliasPrinter)>;
+  virtual void registerAttrAliasPrinter(InsertAttrAliasPrinter insertFn) const {
+  }
+  using TypeAliasPrinter = std::function<void(Type, AsmPrinter &, bool)>;
+  using InsertTypeAliasPrinter =
+      llvm::function_ref<void(TypeID, TypeAliasPrinter)>;
+  virtual void registerTypeAliasPrinter(InsertTypeAliasPrinter insertFn) const {
+  }
+
//===--------------------------------------------------------------------===//
// Resources
//===--------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 7d991cea6c468..86fa754b3bb54 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -32,13 +32,16 @@
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/PointerIntPair.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/ScopedHashTable.h"
#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/DebugLog.h"
#include "llvm/Support/Endian.h"
@@ -47,6 +50,7 @@
#include "llvm/Support/SaveAndRestore.h"
#include "llvm/Support/Threading.h"
#include "llvm/Support/raw_ostream.h"
+#include <algorithm>
#include <type_traits>
#include <optional>
@@ -414,6 +418,7 @@ class AsmPrinter::Impl {
public:
Impl(raw_ostream &os, AsmStateImpl &state);
explicit Impl(Impl &other) : Impl(other.os, other.state) {}
+ /// Create a printer that writes to `os` while sharing the AsmState of `other`.
+ explicit Impl(raw_ostream &os, Impl &other) : Impl(os, other.state) {}
/// Returns the output stream of the printer.
raw_ostream &getStream() { return os; }
@@ -447,6 +452,10 @@ class AsmPrinter::Impl {
/// be printed.
LogicalResult printAlias(Attribute attr);
+ /// Print the dialect alias for the given attribute, return failure if no
+ /// alias could be printed. If `printStripped` is true, the alias is printed
+ /// without the dialect namespace prefix.
+ LogicalResult printDialectAlias(Attribute attr, bool printStripped);
+
/// Print the given type or an alias.
void printType(Type type);
/// Print the given type.
@@ -456,6 +465,10 @@ class AsmPrinter::Impl {
/// be printed.
LogicalResult printAlias(Type type);
+ /// Print the dialect alias for the given type, return failure if no alias
+ /// could be printed. If `printStripped` is true, the alias is printed without
+ /// the dialect namespace prefix.
+ LogicalResult printDialectAlias(Type type, bool printStripped);
+
/// Print the given location to the stream. If `allowAlias` is true, this
/// allows for the internal location to use an attribute alias.
void printLocation(LocationAttr loc, bool allowAlias = false);
@@ -812,6 +825,12 @@ class DummyAliasOperationPrinter : private OpAsmPrinter {
initializer.visit(type);
return success();
}
+  /// This printer never emits dialect aliases. Returning failure makes the
+  /// caller fall back to printing the entity itself through this printer.
+  LogicalResult printDialectAlias(Attribute, bool) override { return failure(); }
+  LogicalResult printDialectAlias(Type, bool) override { return failure(); }
/// Consider the given location to be printed for an alias.
void printOptionalLocationSpecifier(Location loc) override {
@@ -991,6 +1010,11 @@ class DummyAliasDialectAsmPrinter : public DialectAsmPrinter {
return success();
}
+  /// Dialect aliases are never produced here. Failing lets the caller print
+  /// the attribute/type normally through this printer.
+  LogicalResult printDialectAlias(Attribute attr, bool printStripped) override {
+    return failure();
+  }
+  LogicalResult printDialectAlias(Type type, bool printStripped) override {
+    return failure();
+  }
+
/// Record the alias result of a child element.
void recordAliasResult(std::pair<size_t, size_t> aliasDepthAndIndex) {
childIndices.push_back(aliasDepthAndIndex.second);
@@ -1251,7 +1275,10 @@ namespace {
/// This class manages the state for type and attribute aliases.
class AliasState {
public:
- // Initialize the internal aliases.
+ /// Initialize the alias state for custom dialect aliases by collecting the
+ /// alias printers registered by each interface in `interfaces`.
+ AliasState(DialectInterfaceCollection<OpAsmDialectInterface> &interfaces);
+
+ /// Initialize the internal aliases.
void
initialize(Operation *op, const OpPrintingFlags &printerFlags,
DialectInterfaceCollection<OpAsmDialectInterface> &interfaces);
@@ -1275,20 +1302,99 @@ class AliasState {
printAliases(p, newLine, /*isDeferred=*/true);
}
+  /// Look up the dialect alias of `attr` using the printers registered for
+  /// its TypeID. Returns the registering dialect and the alias text, or a
+  /// `{nullptr, ""}` pair when no printer produced output.
+  /// NOTE: the returned StringRef refers to a string owned by this AliasState;
+  /// presumably it should be consumed before another alias is requested.
+  std::pair<const Dialect *, StringRef>
+  getAttrAlias(AsmPrinter::Impl &p, Attribute attr, bool printStripped) {
+    const void *opaqueAttr = attr.getAsOpaquePointer();
+    TypeID attrKind = attr.getTypeID();
+    return getAlias(p, attrKind, opaqueAttr, printStripped);
+  }
+
+  /// Look up the dialect alias of `type` using the printers registered for
+  /// its TypeID. Returns the registering dialect and the alias text, or a
+  /// `{nullptr, ""}` pair when no printer produced output.
+  std::pair<const Dialect *, StringRef>
+  getTypeAlias(AsmPrinter::Impl &p, Type type, bool printStripped) {
+    const void *opaqueType = type.getAsOpaquePointer();
+    TypeID typeKind = type.getTypeID();
+    return getAlias(p, typeKind, opaqueType, printStripped);
+  }
+
private:
+ /// A registered alias printer: the TypeID of the attribute/type kind it
+ /// handles, the dialect that registered it, and a type-erased callback taking
+ /// the opaque attribute/type pointer, the printer and the `printStripped`
+ /// flag.
+ using TypeIDPrinter =
+ std::tuple<TypeID, const Dialect *,
+ std::function<void(const void *, AsmPrinter &, bool)>>;
+ using PrinterIterator = SmallVectorImpl<TypeIDPrinter>::iterator;
+
/// Print all of the referenced aliases that support the provided resolution
/// behavior.
void printAliases(AsmPrinter::Impl &p, NewLineCounter &newLine,
bool isDeferred);
+ /// Comparison function for TypeIDPrinter. Only the TypeID participates in
+ /// the ordering.
+ static bool comparePrinters(const TypeIDPrinter &lhs,
+ const TypeIDPrinter &rhs);
+
+ /// Find custom printers for the given TypeID. Because `attrTypePrinters` is
+ /// stable-sorted by TypeID, the result is a contiguous range that preserves
+ /// registration order.
+ llvm::iterator_range<PrinterIterator> findPrinters(TypeID typeID);
+
+ /// Get an attribute or type alias if it exists. Returns the alias if found,
+ /// or a default constructed pair otherwise.
+ std::pair<const Dialect *, StringRef> getAlias(AsmPrinter::Impl &p,
+ TypeID typeID,
+ const void *opaqueAttrType,
+ bool printStripped);
+
/// Mapping between attribute/type and alias.
llvm::MapVector<const void *, SymbolAlias> attrTypeToAlias;
/// An allocator used for alias names.
llvm::BumpPtrAllocator aliasAllocator;
+
+ /// Custom alias printers for attribute/type kinds, sorted by TypeID (see
+ /// `comparePrinters`).
+ SmallVector<TypeIDPrinter> attrTypePrinters;
+
+ /// Cache of custom printed attributes/types, keyed by the opaque
+ /// attribute/type pointer and the `printStripped` flag. Each value holds the
+ /// dialect and alias text returned for that key.
+ DenseMap<llvm::PointerIntPair<const void *, 1, bool>,
+ std::pair<const Dialect *, std::string>>
+ printCache;
};
} // namespace
+AliasState::AliasState(
+    DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) {
+  // Ask every dialect interface for its alias printers. Each printer is
+  // type-erased to work on the opaque attribute/type pointer and is tagged
+  // with the dialect that registered it. Null printers are ignored.
+  for (const OpAsmDialectInterface &iface : interfaces) {
+    const Dialect *owner = iface.getDialect();
+    iface.registerAttrAliasPrinter(
+        [&](TypeID id, OpAsmDialectInterface::AttributeAliasPrinter fn) {
+          if (!fn)
+            return;
+          attrTypePrinters.emplace_back(
+              id, owner,
+              [fn](const void *opaque, AsmPrinter &printer, bool stripped) {
+                fn(Attribute::getFromOpaquePointer(opaque), printer, stripped);
+              });
+        });
+    iface.registerTypeAliasPrinter(
+        [&](TypeID id, OpAsmDialectInterface::TypeAliasPrinter fn) {
+          if (!fn)
+            return;
+          attrTypePrinters.emplace_back(
+              id, owner,
+              [fn](const void *opaque, AsmPrinter &printer, bool stripped) {
+                fn(Type::getFromOpaquePointer(opaque), printer, stripped);
+              });
+        });
+  }
+
+  // Group the printers by TypeID so that lookups can binary-search. The sort
+  // is stable, so printers for the same TypeID keep their registration order,
+  // which decides which one is tried first.
+  std::stable_sort(attrTypePrinters.begin(), attrTypePrinters.end(),
+                   comparePrinters);
+}
+
void AliasState::initialize(
Operation *op, const OpPrintingFlags &printerFlags,
DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) {
@@ -1315,6 +1421,24 @@ LogicalResult AliasState::getAlias(Type ty, raw_ostream &os) const {
return success();
}
+bool AliasState::comparePrinters(const TypeIDPrinter &lhs,
+                                 const TypeIDPrinter &rhs) {
+  // Only the TypeID takes part in the ordering. All printers registered for
+  // one TypeID therefore compare equal and end up adjacent once sorted.
+  const void *lhsID = std::get<0>(lhs).getAsOpaquePointer();
+  const void *rhsID = std::get<0>(rhs).getAsOpaquePointer();
+  return lhsID < rhsID;
+}
+
+/// Return the contiguous run of printers registered for `typeID`, in
+/// registration order. The run is empty if no printer handles `typeID`.
+llvm::iterator_range<AliasState::PrinterIterator>
+AliasState::findPrinters(TypeID typeID) {
+  // `comparePrinters` only inspects the TypeID, so the dialect and callback
+  // slots of the search key are placeholders.
+  TypeIDPrinter key(typeID, nullptr, nullptr);
+  // `attrTypePrinters` is stable-sorted by TypeID, so a single equal_range
+  // finds both ends of the run in one call.
+  std::pair<PrinterIterator, PrinterIterator> run = std::equal_range(
+      attrTypePrinters.begin(), attrTypePrinters.end(), key, comparePrinters);
+  return llvm::make_range(run.first, run.second);
+}
+
void AliasState::printAliases(AsmPrinter::Impl &p, NewLineCounter &newLine,
bool isDeferred) {
auto filterFn = [=](const auto &aliasIt) {
@@ -1342,6 +1466,40 @@ void AliasState::printAliases(AsmPrinter::Impl &p, NewLineCounter &newLine,
}
}
+/// Run the custom alias printers registered for `typeID` on the opaque
+/// attribute/type `opaqueAttrType` and return the dialect and alias text of
+/// the first printer that produced non-whitespace output. Returns
+/// `{nullptr, ""}` when no printer applies. Results, including failures, are
+/// cached per (object, printStripped).
+std::pair<const Dialect *, StringRef>
+AliasState::getAlias(AsmPrinter::Impl &p, TypeID typeID,
+                     const void *opaqueAttrType, bool printStripped) {
+  // Fast path: most attribute/type kinds have no alias printer registered.
+  // This runs for every printed attribute and type, so bail out before
+  // touching the cache or building a temporary printer.
+  llvm::iterator_range<PrinterIterator> printers = findPrinters(typeID);
+  if (printers.begin() == printers.end())
+    return {nullptr, StringRef()};
+
+  llvm::PointerIntPair<const void *, 1, bool> key(opaqueAttrType,
+                                                  printStripped);
+  // Check the cache first. A cached entry with a null dialect records that no
+  // printer produced an alias for this object.
+  if (auto it = printCache.find(key); it != printCache.end())
+    return it->second;
+
+  // Try the custom printers in registration order. The first one that prints
+  // anything other than whitespace provides the alias.
+  std::string buffer;
+  llvm::raw_string_ostream os(buffer);
+  AsmPrinter::Impl printImpl(os, p);
+  DialectAsmPrinter printer(printImpl);
+  for (const TypeIDPrinter &printInfo : printers) {
+    std::get<2>(printInfo)(opaqueAttrType, printer, printStripped);
+
+    StringRef alias = StringRef(buffer).trim();
+    if (!alias.empty()) {
+      // Store the trimmed alias in the cache and return a reference to the
+      // cached copy, so the result never points into the local buffer.
+      std::pair<const Dialect *, std::string> &entry = printCache[key];
+      entry = {std::get<1>(printInfo), alias.str()};
+      return {entry.first, entry.second};
+    }
+    buffer.clear();
+  }
+
+  // Remember that no printer produced an alias, so the printers are not re-run
+  // every time this object is printed. Printers are required to be
+  // deterministic (see `OpAsmDialectInterface`).
+  printCache[key] = {nullptr, std::string()};
+  return {nullptr, StringRef()};
+}
+
//===----------------------------------------------------------------------===//
// SSANameState
//===----------------------------------------------------------------------===//
@@ -1948,11 +2106,13 @@ class AsmStateImpl {
public:
explicit AsmStateImpl(Operation *op, const OpPrintingFlags &printerFlags,
AsmState::LocationMap *locationMap)
- : interfaces(op->getContext()), nameState(op, printerFlags),
- printerFlags(printerFlags), locationMap(locationMap) {}
+ // NOTE(review): `aliasState` is built from `interfaces`, which is only valid
+ // if `interfaces` is declared before `aliasState` in this class (members are
+ // initialized in declaration order). Confirm the member order.
+ : interfaces(op->getContext()), aliasState(interfaces),
+ nameState(op, printerFlags), printerFlags(printerFlags),
+ locationMap(locationMap) {}
explicit AsmStateImpl(MLIRContext *ctx, const OpPrintingFlags &printerFlags,
AsmState::LocationMap *locationMap)
- : interfaces(ctx), printerFlags(printerFlags), locationMap(locationMap) {}
+ : interfaces(ctx), aliasState(interfaces), printerFlags(printerFlags),
+ locationMap(locationMap) {}
/// Initialize the alias state to enable the printing of aliases.
void initializeAliases(Operation *op) {
@@ -2377,6 +2537,38 @@ LogicalResult AsmPrinter::Impl::printAlias(Type type) {
return state.getAliasState().getAlias(type, os);
}
+/// Print the dialect alias of `attr`, if a registered alias printer produces
+/// one. With `printStripped` the alias text is emitted as-is. Otherwise it is
+/// emitted as a dialect symbol prefixed by the namespace of the dialect that
+/// registered the printer. Returns failure if no alias was printed.
+/// NOTE(review): the non-stripped form uses the `!` sigil, which is the type
+/// prefix. The LangRef dialect-alias example writes attribute aliases with
+/// `#`. Confirm this is intended; the accompanying test currently expects
+/// `!test.tuple_i7_from_attr`.
+LogicalResult AsmPrinter::Impl::printDialectAlias(Attribute attr,
+ bool printStripped) {
+ // Check to see if there is a dialect alias for this attribute.
+ auto [aliasDialect, alias] =
+ state.getAliasState().getAttrAlias(*this, attr, printStripped);
+ if (aliasDialect && !alias.empty()) {
+ if (printStripped) {
+ os << alias;
+ return success();
+ }
+ printDialectSymbol(os, "!", aliasDialect->getNamespace(), alias);
+ return success();
+ }
+ return failure();
+}
+
+/// Print the dialect alias of `type`, if a registered alias printer produces
+/// one. With `printStripped` the alias text is emitted as-is. Otherwise it is
+/// emitted as `!<namespace>.<alias>`, using the namespace of the dialect that
+/// registered the printer. Returns failure if no alias was printed.
+LogicalResult AsmPrinter::Impl::printDialectAlias(Type type,
+ bool printStripped) {
+ // Check to see if there is a dialect alias for this type.
+ auto [aliasDialect, alias] =
+ state.getAliasState().getTypeAlias(*this, type, printStripped);
+ if (aliasDialect && !alias.empty()) {
+ if (printStripped) {
+ os << alias;
+ return success();
+ }
+ printDialectSymbol(os, "!", aliasDialect->getNamespace(), alias);
+ return success();
+ }
+ return failure();
+}
+
void AsmPrinter::Impl::printAttribute(Attribute attr,
AttrTypeElision typeElision) {
if (!attr) {
@@ -2387,6 +2579,8 @@ void AsmPrinter::Impl::printAttribute(Attribute attr,
// Try to print an alias for this attribute.
if (succeeded(printAlias(attr)))
return;
+ if (succeeded(printDialectAlias(attr, /*printStripped=*/false)))
+ return;
return printAttributeImpl(attr, typeElision);
}
void AsmPrinter::Impl::printAttributeImpl(Attribute attr,
@@ -2715,6 +2909,8 @@ void AsmPrinter::Impl::printType(Type type) {
// Try to print an alias for this type.
if (succeeded(printAlias(type)))
return;
+ if (succeeded(printDialectAlias(type, /*printStripped=*/false)))
+ return;
return printTypeImpl(type);
}
@@ -2987,6 +3183,17 @@ LogicalResult AsmPrinter::printAlias(Type type) {
return impl->printAlias(type);
}
+/// Print the dialect alias of `attr` (see `OpAsmDialectInterface` alias
+/// printers). Forwards to the dialect-alias lookup, not to `printAlias`, and
+/// preserves whether the stripped form was requested.
+LogicalResult AsmPrinter::printDialectAlias(Attribute attr,
+                                            bool printStripped) {
+  assert(impl && "expected AsmPrinter::printDialectAlias to be overriden");
+  return impl->printDialectAlias(attr, printStripped);
+}
+
+/// Print the dialect alias of `type`. Forwards to the dialect-alias lookup and
+/// preserves whether the stripped form was requested.
+LogicalResult AsmPrinter::printDialectAlias(Type type, bool printStripped) {
+  assert(impl && "expected AsmPrinter::printDialectAlias to be overriden");
+  return impl->printDialectAlias(type, printStripped);
+}
+
void AsmPrinter::printAttributeWithoutType(Attribute attr) {
assert(impl &&
"expected AsmPrinter::printAttributeWithoutType to be overriden");
diff --git a/mlir/test/IR/print-attr-type-dialect-aliases.mlir b/mlir/test/IR/print-attr-type-dialect-aliases.mlir
new file mode 100644
index 0000000000000..95adb49d0a59a
--- /dev/null
+++ b/mlir/test/IR/print-attr-type-dialect-aliases.mlir
@@ -0,0 +1,14 @@
+// RUN: mlir-opt %s | FileCheck %s
+
+// Check that attr and type aliases are properly printed.
+
+// CHECK: {types = [!test.tuple_i7_from_attr, !test.tuple_i6_from_attr, tuple<!test.int<signed, 5>>]} : () -> (!test.tuple_i7, !test.tuple_i6, tuple<!test.int<signed, 5>>)
+"test.op"() {types = [
+ tuple<!test.int<s, 7>>,
+ tuple<!test.int<s, 6>>,
+ tuple<!test.int<s, 5>>
+ // The test dialect only registers alias printers for 7- and 6-bit
+ // elements, so the 5-bit tuples keep their default form.
+ ]} : () -> (
+ tuple<!test.int<s, 7>>,
+ tuple<!test.int<s, 6>>,
+ tuple<!test.int<s, 5>>
+ )
diff --git a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
index 3d4aa23ebe78a..8061fadb1da95 100644
--- a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
@@ -8,6 +8,8 @@
#include "TestDialect.h"
#include "TestOps.h"
+#include "TestTypes.h"
+#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Interfaces/FoldInterfaces.h"
#include "mlir/Reducer/ReductionPatternInterface.h"
#include "mlir/Transforms/InliningUtils.h"
@@ -262,6 +264,52 @@ struct TestOpAsmInterface : public OpAsmDialectInterface {
return AliasResult::NoAlias;
}
+  /// Returns true if `type` is a single-element tuple whose only element is a
+  /// `TestIntegerType` of exactly `width` bits.
+  static bool isSingletonTestIntTuple(Type type, unsigned width) {
+    auto tupleTy = dyn_cast<TupleType>(type);
+    if (!tupleTy || tupleTy.size() != 1)
+      return false;
+    auto intTy = dyn_cast<TestIntegerType>(tupleTy.getType(0));
+    return intTy && intTy.getWidth() == width;
+  }
+
+  void registerAttrAliasPrinter(InsertAttrAliasPrinter insertFn) const final {
+    // Aliases for `TypeAttr`s wrapping `tuple<!test.int<_, 7>>` and
+    // `tuple<!test.int<_, 6>>`. The 7-bit printer is registered first.
+    insertFn(TypeID::get<TypeAttr>(),
+             [](Attribute attr, AsmPrinter &printer, bool printStripped) {
+               Type wrapped = cast<TypeAttr>(attr).getValue();
+               if (isSingletonTestIntTuple(wrapped, 7))
+                 printer.getStream() << "tuple_i7_from_attr";
+             });
+    insertFn(TypeID::get<TypeAttr>(),
+             [](Attribute attr, AsmPrinter &printer, bool printStripped) {
+               Type wrapped = cast<TypeAttr>(attr).getValue();
+               if (isSingletonTestIntTuple(wrapped, 6))
+                 printer.getStream() << "tuple_i6_from_attr";
+             });
+  }
+
+  void registerTypeAliasPrinter(InsertTypeAliasPrinter insertFn) const final {
+    // Aliases for `tuple<!test.int<_, 7>>` and `tuple<!test.int<_, 6>>`.
+    insertFn(TypeID::get<TupleType>(),
+             [](Type type, AsmPrinter &printer, bool printStripped) {
+               if (isSingletonTestIntTuple(type, 7))
+                 printer.getStream() << "tuple_i7";
+             });
+    insertFn(TypeID::get<TupleType>(),
+             [](Type type, AsmPrinter &printer, bool printStripped) {
+               if (isSingletonTestIntTuple(type, 6))
+                 printer.getStream() << "tuple_i6";
+             });
+  }
+
//===------------------------------------------------------------------===//
// Resources
//===------------------------------------------------------------------===//
>From 9c77705fd2feb737c9cb5827de3f62716360fa89 Mon Sep 17 00:00:00 2001
From: Fabian Mora <fmora.dev at gmail.com>
Date: Mon, 22 Dec 2025 10:23:43 -0500
Subject: [PATCH 2/2] Dialect alias docs and allow disabling
Signed-off-by: Fabian Mora <fmora.dev at gmail.com>
---
mlir/docs/DefiningDialects/Assembly.md | 56 +++++++++++++++++++
mlir/docs/LangRef.md | 36 +++++++++++-
mlir/include/mlir/IR/OperationSupport.h | 10 ++++
mlir/lib/IR/AsmPrinter.cpp | 25 ++++++++-
.../IR/print-attr-type-dialect-aliases.mlir | 2 +
5 files changed, 127 insertions(+), 2 deletions(-)
diff --git a/mlir/docs/DefiningDialects/Assembly.md b/mlir/docs/DefiningDialects/Assembly.md
index 1c00d5ea9ee9b..52367ab20d2da 100644
--- a/mlir/docs/DefiningDialects/Assembly.md
+++ b/mlir/docs/DefiningDialects/Assembly.md
@@ -4,6 +4,8 @@
## Generating Aliases
+### Named aliases for Types and Attributes
+
`AsmPrinter` can generate aliases for frequently used types and attributes when not printing them in generic form. For example, `!my_dialect.type<a=3,b=4,c=5,d=tuple,e=another_type>` and `#my_dialect.attr<a=3>` can be aliased to `!my_dialect_type` and `#my_dialect_attr`.
There are mainly two ways to hook into the `AsmPrinter`. One is the attribute/type interface and the other is the dialect interface.
@@ -28,6 +30,28 @@ enum class OpAsmAliasResult {
If multiple types/attributes have the same alias from `getAlias` hooks, a number is appended to the alias to avoid conflicts.
+### Dialect aliases
+
+Another mechanism for generating aliases for types and attributes is *dialect aliases*; see the LangRef section on [Dialect Aliases](../LangRef.md#type-and-attribute-dialect-aliases).
+This mechanism provides more flexibility than named aliases: the alias printer has access to the full `AsmPrinter`, and the printed syntax is allowed to be indistinguishable from that of dialect types/attributes.
+In other words, a dialect alias has the same printing capabilities as a custom type/attribute printer specified via `MyType::print` or `MyAttr::print`.
+
+These aliases are specified by implementing `OpAsmDialectInterface`, and registering the printers with `registerTypeAliasPrinter` and `registerAttrAliasPrinter` methods.
+These printers are invoked when printing types or attributes of the given TypeID. Printers are invoked in the order they are registered, and the first one to print an alias is used.
+Further, a dialect alias only takes effect if the dialect providing the alias is already loaded in the context.
+
+The precedence for alias resolution is:
+
+1. Type/attribute named aliases as returned by `OpAsmDialectInterface::getAlias`
+2. Dialect-specific alias printers registered via dialect aliases
+3. Default type/attribute printers
+
+Dialect aliases can be disabled globally via the `--mlir-disable-dialect-aliases` command line option, or programmatically with `OpPrintingFlags::enableDialectAliases(false)`.
+
+NOTE: A dialect alias does not provide a parsing mechanism. To parse a dialect alias, the dialect must implement the parsing logic in its type/attribute parser.
+
+For an example, see [`OpAsmDialectInterface`](#opasmdialectinterface).
+
### `OpAsmDialectInterface`
```cpp
@@ -37,6 +61,7 @@ struct MyDialectOpAsmDialectInterface : public OpAsmDialectInterface {
public:
using OpAsmDialectInterface::OpAsmDialectInterface;
+ // Define named aliases for types.
AliasResult getAlias(Type type, raw_ostream& os) const override {
if (mlir::isa<MyType>(type)) {
os << "my_dialect_type";
@@ -45,6 +70,7 @@ struct MyDialectOpAsmDialectInterface : public OpAsmDialectInterface {
return AliasResult::NoAlias;
}
+ // Define named aliases for attributes.
AliasResult getAlias(Attribute attr, raw_ostream& os) const override {
if (mlir::isa<MyAttribute>(attr)) {
os << "my_dialect_attr";
@@ -52,6 +78,36 @@ struct MyDialectOpAsmDialectInterface : public OpAsmDialectInterface {
}
return AliasResult::NoAlias;
}
+
+ // Register a dialect alias for a type. The printer is only invoked for types
+ // with the given TypeID, so the `cast` below always succeeds. `printStripped`
+ // is true when the type is printed without its dialect prefix.
+ void registerTypeAliasPrinter(InsertTypeAliasPrinter insertFn) const final {
+ insertFn(TypeID::get<foo::IntStringPairType>(),
+ [](Type type, AsmPrinter &printer, bool printStripped) {
+ auto pair = cast<foo::IntStringPairType>(type);
+
+ // Don't print the alias if the value is not 42. Printing nothing
+ // declines the alias, so the next registered printer (or the
+ // default printer) is used instead.
+ if (pair.getFirst() != 42)
+ return;
+
+ // Print the alias.
+ printer.getStream() << "the_answer<" << pair.getSecond() << ">";
+ });
+ }
+
+ // Register a dialect alias for an attribute. As for types, the printer only
+ // sees attributes with the given TypeID.
+ void registerAttrAliasPrinter(InsertAttrAliasPrinter insertFn) const final {
+ insertFn(TypeID::get<foo::IntStringPairAttr>(),
+ [](Attribute attr, AsmPrinter &printer, bool printStripped) {
+ auto pair = cast<foo::IntStringPairAttr>(attr);
+
+ // Don't print the alias if the value is not 42.
+ if (pair.getFirst() != 42)
+ return;
+
+ // Print the alias.
+ printer.getStream() << "the_answer<" << pair.getSecond() << ">";
+ });
+ }
};
void MyDialect::initialize() {
diff --git a/mlir/docs/LangRef.md b/mlir/docs/LangRef.md
index b1da4b9360592..e995b76170bd0 100644
--- a/mlir/docs/LangRef.md
+++ b/mlir/docs/LangRef.md
@@ -868,7 +868,41 @@ that are directly usable by any other dialect in MLIR. These types cover a range
from primitive integer and floating-point values, attribute dictionaries, dense
multi-dimensional arrays, and more.
-### IR Versioning
+## Type and Attribute Dialect Aliases
+
+Type and Attribute dialect aliases provide a mechanism to override the syntax of
+attributes and types in MLIR. Concretely, they allow a dialect to override the
+syntax of its own entities, or entities in the transitive set of dependent dialects.
+Dialect aliases are owned by the dialect that registers them, and only work when
+that dialect is loaded. Consequently, the syntax provided by this mechanism is
+permitted to be virtually indistinguishable from the syntax of entities owned by a
+dialect, see [dialect attributes](#dialect-attribute-values) and [dialect types](#dialect-types)
+for more details. Nevertheless, this sugaring mechanism is optional and can be disabled, e.g. with the `--mlir-disable-dialect-aliases` option.
+
+Dialect aliases are particularly useful when a dialect frequently relies on an
+entity that becomes too verbose in its normal form. Benefits include:
+
+- Readability
+- Parsing and printing speed, as very long strings can be aliased to smaller strings
+
+Compared to attribute and type named aliases, dialect aliases are not required to be
+defined before use. Further, dialect aliases are printed with the dialect prefix
+of the dialect that registered them, even when they alias entities from other dialects.
+
+It is critical to note that dialect aliases do not modify the internal
+representation of the IR, and are only a sugaring of the textual representation.
+In particular, a dialect alias is not a construct in bytecode.
+
+Example:
+
+```mlir
+#foo.my_long_attribute<#bar.some_really_long_attribute<"a_very_long_string_value">>
+// With a dialect alias registered for `#foo.my_long_attribute`, this can be
+// sugared to:
+#foo.bar_attr<"a_very_long_string_value">
+```
+
+## IR Versioning
A dialect can opt-in to handle versioning through the
`BytecodeDialectInterface`. Few hooks are exposed to the dialect to allow
diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index 1ff7c56ddca38..ea6dcf5603fc4 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -1230,6 +1230,9 @@ class OpPrintingFlags {
/// Print SSA IDs using their NameLoc, if provided, as prefix.
OpPrintingFlags &printNameLocAsPrefix(bool enable = true);
+ /// Enable (or, with `enable` set to false, disable) the use of dialect
+ /// aliases when printing attributes or types. Dialect aliases are enabled by
+ /// default unless `--mlir-disable-dialect-aliases` is passed.
+ OpPrintingFlags &enableDialectAliases(bool enable = true);
+
/// Return if the given ElementsAttr should be elided.
bool shouldElideElementsAttr(ElementsAttr attr) const;
@@ -1273,6 +1276,10 @@ class OpPrintingFlags {
/// IDs
bool shouldUseNameLocAsPrefix() const;
+ /// Return if the printer should print dialect aliases when printing
+ /// attributes or types. This is true unless disabled via
+ /// `enableDialectAliases(false)` or `--mlir-disable-dialect-aliases`.
+ bool shouldPrintDialectAliases() const;
+
private:
/// Elide large elements attributes if the number of elements is larger than
/// the upper limit.
@@ -1309,6 +1316,9 @@ class OpPrintingFlags {
/// Print SSA IDs using NameLocs as prefixes
bool useNameLocAsPrefix : 1;
+
+ /// Print dialect attribute or type aliases (enabled by default).
+ bool printDialectAliasesFlag : 1;
};
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 86fa754b3bb54..d1a43a422f5c5 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -205,6 +205,11 @@ struct AsmPrinterOptions {
llvm::cl::opt<bool> useNameLocAsPrefix{
"mlir-use-nameloc-as-prefix", llvm::cl::init(false),
llvm::cl::desc("Print SSA IDs using NameLocs as prefixes")};
+
+ // Stored inverted in `OpPrintingFlags` as `printDialectAliasesFlag`.
+ llvm::cl::opt<bool> disableDialectAliases{
+ "mlir-disable-dialect-aliases", llvm::cl::init(false),
+ llvm::cl::desc(
+ "Disable printing of dialect aliases for attributes and types")};
};
} // namespace
@@ -223,7 +228,7 @@ OpPrintingFlags::OpPrintingFlags()
printGenericOpFormFlag(false), skipRegionsFlag(false),
assumeVerifiedFlag(false), printLocalScope(false),
printValueUsersFlag(false), printUniqueSSAIDsFlag(false),
- useNameLocAsPrefix(false) {
+ useNameLocAsPrefix(false), printDialectAliasesFlag(true) {
// Initialize based upon command line options, if they are available.
if (!clOptions.isConstructed())
return;
@@ -243,6 +248,7 @@ OpPrintingFlags::OpPrintingFlags()
printValueUsersFlag = clOptions->printValueUsers;
printUniqueSSAIDsFlag = clOptions->printUniqueSSAIDs;
useNameLocAsPrefix = clOptions->useNameLocAsPrefix;
+ printDialectAliasesFlag = !clOptions->disableDialectAliases;
}
/// Enable the elision of large elements attributes, by printing a '...'
@@ -335,6 +341,11 @@ OpPrintingFlags &OpPrintingFlags::printNameLocAsPrefix(bool enable) {
return *this;
}
+/// Enable or disable the printing of dialect aliases for attributes and types.
+OpPrintingFlags &OpPrintingFlags::enableDialectAliases(bool enable) {
+  this->printDialectAliasesFlag = enable;
+  return *this;
+}
+
/// Return the size limit for printing large ElementsAttr.
std::optional<int64_t> OpPrintingFlags::getLargeElementsAttrLimit() const {
return elementsAttrElementLimit;
@@ -391,6 +402,12 @@ bool OpPrintingFlags::shouldUseNameLocAsPrefix() const {
return useNameLocAsPrefix;
}
+/// Whether attribute and type dialect aliases should be printed. Controlled by
+/// `enableDialectAliases` and the `--mlir-disable-dialect-aliases` option.
+bool OpPrintingFlags::shouldPrintDialectAliases() const {
+  return this->printDialectAliasesFlag;
+}
+
//===----------------------------------------------------------------------===//
// NewLineCounter
//===----------------------------------------------------------------------===//
@@ -2539,6 +2556,9 @@ LogicalResult AsmPrinter::Impl::printAlias(Type type) {
LogicalResult AsmPrinter::Impl::printDialectAlias(Attribute attr,
bool printStripped) {
+ if (!state.getPrinterFlags().shouldPrintDialectAliases())
+ return failure();
+
// Check to see if there is a dialect alias for this attribute.
auto [aliasDialect, alias] =
state.getAliasState().getAttrAlias(*this, attr, printStripped);
@@ -2555,6 +2575,9 @@ LogicalResult AsmPrinter::Impl::printDialectAlias(Attribute attr,
LogicalResult AsmPrinter::Impl::printDialectAlias(Type type,
bool printStripped) {
+ if (!state.getPrinterFlags().shouldPrintDialectAliases())
+ return failure();
+
// Check to see if there is a dialect alias for this type.
auto [aliasDialect, alias] =
state.getAliasState().getTypeAlias(*this, type, printStripped);
diff --git a/mlir/test/IR/print-attr-type-dialect-aliases.mlir b/mlir/test/IR/print-attr-type-dialect-aliases.mlir
index 95adb49d0a59a..08446173b66f4 100644
--- a/mlir/test/IR/print-attr-type-dialect-aliases.mlir
+++ b/mlir/test/IR/print-attr-type-dialect-aliases.mlir
@@ -1,8 +1,10 @@
// RUN: mlir-opt %s | FileCheck %s
+// RUN: mlir-opt --mlir-disable-dialect-aliases %s | FileCheck %s --check-prefix=CHECK-NO-ALIASES
// Check that attr and type aliases are properly printed.
// CHECK: {types = [!test.tuple_i7_from_attr, !test.tuple_i6_from_attr, tuple<!test.int<signed, 5>>]} : () -> (!test.tuple_i7, !test.tuple_i6, tuple<!test.int<signed, 5>>)
+// CHECK-NO-ALIASES: {types = [tuple<!test.int<signed, 7>>, tuple<!test.int<signed, 6>>, tuple<!test.int<signed, 5>>]} : () -> (tuple<!test.int<signed, 7>>, tuple<!test.int<signed, 6>>, tuple<!test.int<signed, 5>>)
"test.op"() {types = [
tuple<!test.int<s, 7>>,
tuple<!test.int<s, 6>>,
More information about the Mlir-commits
mailing list