[Mlir-commits] [mlir] [mlir] Add the concept of ASM dialect aliases (PR #86033)
Fabian Mora
llvmlistbot at llvm.org
Wed Mar 27 03:59:32 PDT 2024
https://github.com/fabianmcg updated https://github.com/llvm/llvm-project/pull/86033
>From 162f62816b3b2d7e948bccd8f334056200292de7 Mon Sep 17 00:00:00 2001
From: Fabian Mora <fmora.dev at gmail.com>
Date: Wed, 20 Mar 2024 23:19:34 +0000
Subject: [PATCH 1/2] [mlir] Add the concept of ASM dialect aliases
ASM dialect aliases provide a mechanism for assigning arbitrary aliases to types
and attributes in pretty form.
To use these aliases, users must complete several steps. For printing an alias,
they need to:
- Implement the method `OpAsmDialectInterface::getAlias` and return
`AliasResult::DialectAlias` for the aliased type and attribute instances.
- Implement the method `OpAsmDialectInterface::printDialectAlias`, printing the
alias however the user sees fit.
For parsing an alias, the steps are:
- Implement `OpAsmDialectInterface::parseDialectAlias` for the aliased types and
attributes.
Users must also attach the `OpAsmDialectInterface` interface to the dialect
creating the alias.
Examples of this mechanism were added to the tests, specifically:
- `"test_dialect_alias:..."` is aliased to `#test.test_string<...>`
- `tensor<3x!test.int<...>>` is aliased to `!test.tensor_int3<!test.int<...>>`
This change is needed to alias "!llvm.ptr" with "ptr".
---
mlir/include/mlir/IR/OpImplementation.h | 28 ++++-
mlir/lib/AsmParser/DialectSymbolParser.cpp | 12 ++
mlir/lib/IR/AsmPrinter.cpp | 104 ++++++++++++++++--
mlir/test/IR/print-attr-type-aliases.mlir | 15 +++
.../Dialect/Test/TestDialectInterfaces.cpp | 67 +++++++++++
5 files changed, 216 insertions(+), 10 deletions(-)
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 5333d7446df5ca..5c3d93f8e0cff6 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -1715,7 +1715,10 @@ class OpAsmDialectInterface
OverridableAlias,
/// An alias was provided and it should be used
/// (no other hooks will be checked).
- FinalAlias
+ FinalAlias,
+ /// A dialect alias was provided and it will be used
+ /// (no other hooks will be checked).
+ DialectAlias
};
/// Hooks for getting an alias identifier alias for a given symbol, that is
@@ -1729,6 +1732,29 @@ class OpAsmDialectInterface
return AliasResult::NoAlias;
}
+ /// Hooks for parsing a dialect alias. The method returns success if the
+ /// dialect has an alias for the symbol, otherwise it must return failure.
+ /// If there was an error during parsing, this method should return success
+ /// and set the attribute to null.
+ virtual LogicalResult parseDialectAlias(DialectAsmParser &parser,
+ Attribute &attr, Type type) const {
+ return failure();
+ }
+ virtual LogicalResult parseDialectAlias(DialectAsmParser &parser,
+ Type &type) const {
+ return failure();
+ }
+ /// Hooks for printing a dialect alias.
+ virtual void printDialectAlias(DialectAsmPrinter &printer,
+ Attribute attr) const {
+ llvm_unreachable("Dialect must implement `printDialectAlias` when defining "
+ "dialect aliases");
+ }
+ virtual void printDialectAlias(DialectAsmPrinter &printer, Type type) const {
+ llvm_unreachable("Dialect must implement `printDialectAlias` when defining "
+ "dialect aliases");
+ }
+
//===--------------------------------------------------------------------===//
// Resources
//===--------------------------------------------------------------------===//
diff --git a/mlir/lib/AsmParser/DialectSymbolParser.cpp b/mlir/lib/AsmParser/DialectSymbolParser.cpp
index 80cce7e6ae43f5..9261ef2fb3eb95 100644
--- a/mlir/lib/AsmParser/DialectSymbolParser.cpp
+++ b/mlir/lib/AsmParser/DialectSymbolParser.cpp
@@ -269,6 +269,17 @@ Attribute Parser::parseExtendedAttr(Type type) {
// Parse the attribute.
CustomDialectAsmParser customParser(symbolData, *this);
+ if (auto iface = dyn_cast<OpAsmDialectInterface>(dialect)) {
+ // Give the dialect a chance to parse an alias first. Use `attrType` so that
+ // an explicit trailing `: type` is honored.
+ Attribute attr{};
+ if (succeeded(iface->parseDialectAlias(customParser, attr, attrType))) {
+ // Move the lexer past the symbol, exactly as the regular path does.
+ resetToken(curLexerPos);
+ return attr;
+ }
+ resetToken(symbolData.data());
+ }
Attribute attr = dialect->parseAttribute(customParser, attrType);
resetToken(curLexerPos);
return attr;
@@ -310,6 +321,15 @@ Type Parser::parseExtendedType() {
// Parse the type.
CustomDialectAsmParser customParser(symbolData, *this);
+ if (auto iface = dyn_cast<OpAsmDialectInterface>(dialect)) {
+ Type type{};
+ if (succeeded(iface->parseDialectAlias(customParser, type))) {
+ // Move the lexer past the symbol, exactly as the regular path does.
+ resetToken(curLexerPos);
+ return type;
+ }
+ resetToken(symbolData.data());
+ }
Type type = dialect->parseType(customParser);
resetToken(curLexerPos);
return type;
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 456cf6a2c27783..7aabc360517ac0 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -542,7 +542,9 @@ class AliasInitializer {
aliasOS(aliasBuffer) {}
void initialize(Operation *op, const OpPrintingFlags &printerFlags,
- llvm::MapVector<const void *, SymbolAlias> &attrTypeToAlias);
+ llvm::MapVector<const void *, SymbolAlias> &attrTypeToAlias,
+ llvm::DenseMap<const void *, const OpAsmDialectInterface *>
+ &attrTypeToDialectAlias);
/// Visit the given attribute to see if it has an alias. `canBeDeferred` is
/// set to true if the originator of this attribute can resolve the alias
@@ -570,6 +572,10 @@ class AliasInitializer {
InProgressAliasInfo(StringRef alias, bool isType, bool canBeDeferred)
: alias(alias), aliasDepth(1), isType(isType),
canBeDeferred(canBeDeferred) {}
+ InProgressAliasInfo(const OpAsmDialectInterface *aliasDialect, bool isType,
+ bool canBeDeferred)
+ : alias(std::nullopt), aliasDepth(1), isType(isType),
+ canBeDeferred(canBeDeferred), aliasDialect(aliasDialect) {}
bool operator<(const InProgressAliasInfo &rhs) const {
// Order first by depth, then by attr/type kind, and then by name.
@@ -577,6 +583,8 @@ class AliasInitializer {
return aliasDepth < rhs.aliasDepth;
if (isType != rhs.isType)
return isType;
+ if (aliasDialect != rhs.aliasDialect)
+ return aliasDialect < rhs.aliasDialect;
return alias < rhs.alias;
}
@@ -592,6 +600,8 @@ class AliasInitializer {
bool canBeDeferred : 1;
/// Indices for child aliases.
SmallVector<size_t> childIndices;
+ /// Dialect interface used to print the alias.
+ const OpAsmDialectInterface *aliasDialect{};
};
/// Visit the given attribute or type to see if it has an alias.
@@ -617,7 +627,9 @@ class AliasInitializer {
/// symbol to a given alias.
static void initializeAliases(
llvm::MapVector<const void *, InProgressAliasInfo> &visitedSymbols,
- llvm::MapVector<const void *, SymbolAlias> &symbolToAlias);
+ llvm::MapVector<const void *, SymbolAlias> &symbolToAlias,
+ llvm::DenseMap<const void *, const OpAsmDialectInterface *>
+ &attrTypeToDialectAlias);
/// The set of asm interfaces within the context.
DialectInterfaceCollection<OpAsmDialectInterface> &interfaces;
@@ -1027,7 +1039,9 @@ static StringRef sanitizeIdentifier(StringRef name, SmallString<16> &buffer,
/// symbol to a given alias.
void AliasInitializer::initializeAliases(
llvm::MapVector<const void *, InProgressAliasInfo> &visitedSymbols,
- llvm::MapVector<const void *, SymbolAlias> &symbolToAlias) {
+ llvm::MapVector<const void *, SymbolAlias> &symbolToAlias,
+ llvm::DenseMap<const void *, const OpAsmDialectInterface *>
+ &attrTypeToDialectAlias) {
SmallVector<std::pair<const void *, InProgressAliasInfo>, 0>
unprocessedAliases = visitedSymbols.takeVector();
llvm::stable_sort(unprocessedAliases, [](const auto &lhs, const auto &rhs) {
@@ -1036,8 +1050,12 @@ void AliasInitializer::initializeAliases(
llvm::StringMap<unsigned> nameCounts;
for (auto &[symbol, aliasInfo] : unprocessedAliases) {
- if (!aliasInfo.alias)
+ if (!aliasInfo.alias && !aliasInfo.aliasDialect)
continue;
+ if (aliasInfo.aliasDialect) {
+ attrTypeToDialectAlias.insert({symbol, aliasInfo.aliasDialect});
+ continue;
+ }
StringRef alias = *aliasInfo.alias;
unsigned nameIndex = nameCounts[alias]++;
symbolToAlias.insert(
@@ -1048,7 +1066,9 @@ void AliasInitializer::initializeAliases(
void AliasInitializer::initialize(
Operation *op, const OpPrintingFlags &printerFlags,
- llvm::MapVector<const void *, SymbolAlias> &attrTypeToAlias) {
+ llvm::MapVector<const void *, SymbolAlias> &attrTypeToAlias,
+ llvm::DenseMap<const void *, const OpAsmDialectInterface *>
+ &attrTypeToDialectAlias) {
// Use a dummy printer when walking the IR so that we can collect the
// attributes/types that will actually be used during printing when
// considering aliases.
@@ -1056,7 +1076,7 @@ void AliasInitializer::initialize(
aliasPrinter.printCustomOrGenericOp(op);
// Initialize the aliases.
- initializeAliases(aliases, attrTypeToAlias);
+ initializeAliases(aliases, attrTypeToAlias, attrTypeToDialectAlias);
}
template <typename T, typename... PrintArgs>
@@ -1113,9 +1133,14 @@ template <typename T>
void AliasInitializer::generateAlias(T symbol, InProgressAliasInfo &alias,
bool canBeDeferred) {
SmallString<32> nameBuffer;
+ const OpAsmDialectInterface *dialectAlias = nullptr;
for (const auto &interface : interfaces) {
OpAsmDialectInterface::AliasResult result =
interface.getAlias(symbol, aliasOS);
+ if (result == OpAsmDialectInterface::AliasResult::DialectAlias) {
+ dialectAlias = &interface;
+ break;
+ }
if (result == OpAsmDialectInterface::AliasResult::NoAlias)
continue;
nameBuffer = std::move(aliasBuffer);
@@ -1123,6 +1148,11 @@ void AliasInitializer::generateAlias(T symbol, InProgressAliasInfo &alias,
if (result == OpAsmDialectInterface::AliasResult::FinalAlias)
break;
}
+ if (dialectAlias) {
+ alias = InProgressAliasInfo(
+ dialectAlias, /*isType=*/std::is_base_of_v<Type, T>, canBeDeferred);
+ return;
+ }
if (nameBuffer.empty())
return;
@@ -1157,6 +1187,13 @@ class AliasState {
/// Returns success if an alias was printed, failure otherwise.
LogicalResult getAlias(Type ty, raw_ostream &os) const;
+ /// Get a dialect alias for the given attribute if it has one or return
+ /// nullptr.
+ const OpAsmDialectInterface *getDialectAlias(Attribute attr) const;
+
+ /// Get a dialect alias for the given type if it has one or return nullptr.
+ const OpAsmDialectInterface *getDialectAlias(Type ty) const;
+
/// Print all of the referenced aliases that can not be resolved in a deferred
/// manner.
void printNonDeferredAliases(AsmPrinter::Impl &p, NewLineCounter &newLine) {
@@ -1177,6 +1214,10 @@ class AliasState {
/// Mapping between attribute/type and alias.
llvm::MapVector<const void *, SymbolAlias> attrTypeToAlias;
+ /// Mapping between attribute/type and alias dialect interfaces.
+ llvm::DenseMap<const void *, const OpAsmDialectInterface *>
+ attrTypeToDialectAlias;
+
/// An allocator used for alias names.
llvm::BumpPtrAllocator aliasAllocator;
};
@@ -1186,7 +1227,8 @@ void AliasState::initialize(
Operation *op, const OpPrintingFlags &printerFlags,
DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) {
AliasInitializer initializer(interfaces, aliasAllocator);
- initializer.initialize(op, printerFlags, attrTypeToAlias);
+ initializer.initialize(op, printerFlags, attrTypeToAlias,
+ attrTypeToDialectAlias);
}
LogicalResult AliasState::getAlias(Attribute attr, raw_ostream &os) const {
@@ -1206,6 +1248,20 @@ LogicalResult AliasState::getAlias(Type ty, raw_ostream &os) const {
return success();
}
+const OpAsmDialectInterface *AliasState::getDialectAlias(Attribute attr) const {
+ auto it = attrTypeToDialectAlias.find(attr.getAsOpaquePointer());
+ if (it == attrTypeToDialectAlias.end())
+ return nullptr;
+ return it->second;
+}
+
+const OpAsmDialectInterface *AliasState::getDialectAlias(Type ty) const {
+ auto it = attrTypeToDialectAlias.find(ty.getAsOpaquePointer());
+ if (it == attrTypeToDialectAlias.end())
+ return nullptr;
+ return it->second;
+}
+
void AliasState::printAliases(AsmPrinter::Impl &p, NewLineCounter &newLine,
bool isDeferred) {
auto filterFn = [=](const auto &aliasIt) {
@@ -2189,11 +2245,41 @@ static void printElidedElementsAttr(raw_ostream &os) {
}
LogicalResult AsmPrinter::Impl::printAlias(Attribute attr) {
- return state.getAliasState().getAlias(attr, os);
+ if (succeeded(state.getAliasState().getAlias(attr, os)))
+ return success();
+ const OpAsmDialectInterface *iface =
+ state.getAliasState().getDialectAlias(attr);
+ if (!iface)
+ return failure();
+ // Ask the dialect to serialize the attribute to a string.
+ std::string attrName;
+ {
+ llvm::raw_string_ostream attrNameStr(attrName);
+ Impl subPrinter(attrNameStr, state);
+ DialectAsmPrinter printer(subPrinter);
+ iface->printDialectAlias(printer, attr);
+ }
+ printDialectSymbol(os, "#", iface->getDialect()->getNamespace(), attrName);
+ return success();
}
LogicalResult AsmPrinter::Impl::printAlias(Type type) {
- return state.getAliasState().getAlias(type, os);
+ if (succeeded(state.getAliasState().getAlias(type, os)))
+ return success();
+ const OpAsmDialectInterface *iface =
+ state.getAliasState().getDialectAlias(type);
+ if (!iface)
+ return failure();
+ // Ask the dialect to serialize the type to a string.
+ std::string typeName;
+ {
+ llvm::raw_string_ostream typeNameStr(typeName);
+ Impl subPrinter(typeNameStr, state);
+ DialectAsmPrinter printer(subPrinter);
+ iface->printDialectAlias(printer, type);
+ }
+ printDialectSymbol(os, "!", iface->getDialect()->getNamespace(), typeName);
+ return success();
}
void AsmPrinter::Impl::printAttribute(Attribute attr,
diff --git a/mlir/test/IR/print-attr-type-aliases.mlir b/mlir/test/IR/print-attr-type-aliases.mlir
index 162eacd0022832..6065573072e638 100644
--- a/mlir/test/IR/print-attr-type-aliases.mlir
+++ b/mlir/test/IR/print-attr-type-aliases.mlir
@@ -21,6 +21,12 @@
// CHECK-DAG: !test_tuple = tuple<!test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla>
"test.op"() {alias_test = "alias_test:large_tuple"} : () -> (tuple<!test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla>)
+// CHECK-DAG: #test.test_string<"hello">, #test.test_string<" world">
+"test.op"() {alias_test = ["test_dialect_alias:hello", "test_dialect_alias: world"]} : () -> ()
+
+// CHECK-DAG: #test.test_string<"hello">, #test.test_string<" world">
+"test.op"() {alias_test = [#test.test_string<"hello">, #test.test_string<" world">]} : () -> ()
+
// CHECK-DAG: #test_encoding = "alias_test:tensor_encoding"
// CHECK-DAG: tensor<32xf32, #test_encoding>
"test.op"() : () -> tensor<32xf32, "alias_test:tensor_encoding">
@@ -29,6 +35,15 @@
// CHECK-DAG: tensor<32x!test_ui8_>
"test.op"() : () -> tensor<32x!test.int<unsigned, 8>>
+// CHECK-DAG: !test.tensor_int3<!test_ui8_>
+"test.op"() : () -> tensor<3x!test.int<unsigned, 8>>
+
+// CHECK-DAG: !test.tensor_int3<!test.int<signed, 8>>
+"test.op"() : () -> !test.tensor_int3<!test.int<signed, 8>>
+
+// CHECK-DAG: tensor<3xi3>
+"test.op"() : () -> !test.tensor_int3<i3>
+
// CHECK-DAG: #loc = loc("nested")
// CHECK-DAG: #loc1 = loc("test.mlir":10:8)
// CHECK-DAG: #loc2 = loc(fused<#loc>[#loc1])
diff --git a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
index 66578b246afab1..8ef3cb82fe159e 100644
--- a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
@@ -193,6 +193,11 @@ struct TestOpAsmInterface : public OpAsmDialectInterface {
StringRef("test_alias_conflict0_"))
.Case("alias_test:tensor_encoding", StringRef("test_encoding"))
.Default(std::nullopt);
+
+ // Create a dialect alias for strings starting with "test_dialect_alias:"
+ if (strAttr.getValue().starts_with("test_dialect_alias:"))
+ return AliasResult::DialectAlias;
+
if (!aliasName)
return AliasResult::NoAlias;
@@ -200,6 +205,33 @@ struct TestOpAsmInterface : public OpAsmDialectInterface {
return AliasResult::FinalAlias;
}
+ void printDialectAlias(DialectAsmPrinter &printer,
+ Attribute attr) const override {
+ if (StringAttr strAttr = dyn_cast<StringAttr>(attr)) {
+ // Drop "test_dialect_alias:" from the front of the string
+ StringRef value = strAttr.getValue();
+ value.consume_front("test_dialect_alias:");
+ printer << "test_string<\"" << value << "\">";
+ }
+ }
+
+ LogicalResult parseDialectAlias(DialectAsmParser &parser, Attribute &attr,
+ Type type) const override {
+ return AsmParser::KeywordSwitch<LogicalResult>(parser)
+ // Alias !test.test_string<"..."> to StringAttr
+ .Case("test_string",
+ [&](llvm::StringRef, llvm::SMLoc) {
+ std::string str;
+ if (parser.parseLess() || parser.parseString(&str) ||
+ parser.parseGreater())
+ return success();
+ attr = parser.getBuilder().getStringAttr("test_dialect_alias:" +
+ str);
+ return success();
+ })
+ .Default([&](StringRef keyword, SMLoc) { return failure(); });
+ }
+
AliasResult getAlias(Type type, raw_ostream &os) const final {
if (auto tupleType = dyn_cast<TupleType>(type)) {
if (tupleType.size() > 0 &&
@@ -229,9 +261,44 @@ struct TestOpAsmInterface : public OpAsmDialectInterface {
os << recAliasType.getName();
return AliasResult::FinalAlias;
}
+ // Create a dialect alias for tensor<3x!test.int<...>>
+ if (auto tensorTy = dyn_cast<TensorType>(type);
+ tensorTy && isa<TestIntegerType>(tensorTy.getElementType()) &&
+ tensorTy.hasRank()) {
+ ArrayRef<int64_t> shape = tensorTy.getShape();
+ if (shape.size() == 1 && shape[0] == 3)
+ return AliasResult::DialectAlias;
+ }
return AliasResult::NoAlias;
}
+ void printDialectAlias(DialectAsmPrinter &printer, Type type) const override {
+ if (auto tensorTy = dyn_cast<TensorType>(type);
+ tensorTy && isa<TestIntegerType>(tensorTy.getElementType()) &&
+ tensorTy.hasRank()) {
+ // Alias tensor<3x!test.int<...>> to !test.tensor_int3<!test.int<...>>
+ ArrayRef<int64_t> shape = tensorTy.getShape();
+ if (shape.size() == 1 && shape[0] == 3)
+ printer << "tensor_int3" << "<" << tensorTy.getElementType() << ">";
+ }
+ }
+
+ LogicalResult parseDialectAlias(DialectAsmParser &parser,
+ Type &type) const override {
+ return AsmParser::KeywordSwitch<LogicalResult>(parser)
+ // Alias !test.tensor_int3<IntType> to tensor<3xIntType>
+ .Case("tensor_int3",
+ [&](llvm::StringRef, llvm::SMLoc loc) {
+ // Per the `parseDialectAlias` contract, errors are signaled by
+ // returning success with a null type.
+ type = nullptr;
+ Type elementType;
+ if (parser.parseLess() || parser.parseType(elementType) ||
+ parser.parseGreater())
+ return success();
+ // Only integer element types are aliased; reject anything else
+ // instead of silently returning the bare element type.
+ if (!isa<TestIntegerType, IntegerType>(elementType)) {
+ parser.emitError(loc, "expected an integer element type");
+ return success();
+ }
+ type = RankedTensorType::get({3}, elementType);
+ return success();
+ })
+ .Default([&](StringRef keyword, SMLoc) { return failure(); });
+ }
+
//===------------------------------------------------------------------===//
// Resources
//===------------------------------------------------------------------===//
>From 2a52522e57cdf8bb6927ac6bc7b711259c9a0b7f Mon Sep 17 00:00:00 2001
From: Fabian Mora <fmora.dev at gmail.com>
Date: Wed, 27 Mar 2024 10:58:59 +0000
Subject: [PATCH 2/2] Add AttrAsmAliasAttrInterface and
TypeAsmAliasTypeInterface
---
mlir/include/mlir/IR/AsmInterfaces.h | 19 ++++++
mlir/include/mlir/IR/AsmInterfaces.td | 60 +++++++++++++++++++
mlir/include/mlir/IR/CMakeLists.txt | 7 +++
mlir/include/mlir/IR/OpImplementation.h | 18 +++---
mlir/lib/IR/AsmInterfaces.cpp | 19 ++++++
mlir/lib/IR/AsmPrinter.cpp | 23 ++++---
mlir/lib/IR/CMakeLists.txt | 2 +
.../Dialect/Test/TestDialectInterfaces.cpp | 11 +++-
8 files changed, 139 insertions(+), 20 deletions(-)
create mode 100644 mlir/include/mlir/IR/AsmInterfaces.h
create mode 100644 mlir/include/mlir/IR/AsmInterfaces.td
create mode 100644 mlir/lib/IR/AsmInterfaces.cpp
diff --git a/mlir/include/mlir/IR/AsmInterfaces.h b/mlir/include/mlir/IR/AsmInterfaces.h
new file mode 100644
index 00000000000000..00c4ecf58867d3
--- /dev/null
+++ b/mlir/include/mlir/IR/AsmInterfaces.h
@@ -0,0 +1,24 @@
+//===- AsmInterfaces.h - Asm Interfaces -------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains interfaces and other utilities for interacting with the
+// AsmParser and AsmPrinter.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_ASMINTERFACES_H
+#define MLIR_IR_ASMINTERFACES_H
+
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Types.h"
+
+#include "mlir/IR/AsmAttrInterfaces.h.inc"
+
+#include "mlir/IR/AsmTypeInterfaces.h.inc"
+
+#endif // MLIR_IR_ASMINTERFACES_H
diff --git a/mlir/include/mlir/IR/AsmInterfaces.td b/mlir/include/mlir/IR/AsmInterfaces.td
new file mode 100644
index 00000000000000..88f8d37a86f68c
--- /dev/null
+++ b/mlir/include/mlir/IR/AsmInterfaces.td
@@ -0,0 +1,60 @@
+//===- AsmInterfaces.td - Asm Interfaces -------------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains interfaces and other utilities for interacting with the
+// AsmParser and AsmPrinter.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_ASMINTERFACES_TD
+#define MLIR_IR_ASMINTERFACES_TD
+
+include "mlir/IR/AttrTypeBase.td"
+include "mlir/IR/OpBase.td"
+
+//===----------------------------------------------------------------------===//
+// AttrAsmAliasAttrInterface
+//===----------------------------------------------------------------------===//
+
+def AttrAsmAliasAttrInterface : AttrInterface<"AttrAsmAliasAttrInterface"> {
+ let cppNamespace = "::mlir";
+ let description = [{
+ This interface allows aliasing an attribute between dialects, allowing
+ custom printing of an attribute by an external dialect.
+ }];
+ let methods = [
+ InterfaceMethod<[{
+ Returns the dialect responsible for printing and parsing the attribute
+ instance.
+ }],
+ "Dialect*", "getAliasDialect", (ins), [{}], [{}]
+ >
+ ];
+}
+
+//===----------------------------------------------------------------------===//
+// TypeAsmAliasTypeInterface
+//===----------------------------------------------------------------------===//
+
+def TypeAsmAliasTypeInterface : TypeInterface<"TypeAsmAliasTypeInterface"> {
+ let cppNamespace = "::mlir";
+ let description = [{
+ This interface allows aliasing a type between dialects, allowing custom
+ printing of a type by an external dialect.
+ }];
+ let methods = [
+ InterfaceMethod<[{
+ Returns the dialect responsible for printing and parsing the type
+ instance.
+ }],
+ "Dialect*", "getAliasDialect", (ins), [{}], [{}]
+ >
+ ];
+}
+
+#endif // MLIR_IR_ASMINTERFACES_TD
diff --git a/mlir/include/mlir/IR/CMakeLists.txt b/mlir/include/mlir/IR/CMakeLists.txt
index 04a57d26a068d5..d10a1bec682bc2 100644
--- a/mlir/include/mlir/IR/CMakeLists.txt
+++ b/mlir/include/mlir/IR/CMakeLists.txt
@@ -2,6 +2,13 @@ add_mlir_interface(OpAsmInterface)
add_mlir_interface(SymbolInterfaces)
add_mlir_interface(RegionKindInterface)
+set(LLVM_TARGET_DEFINITIONS AsmInterfaces.td)
+mlir_tablegen(AsmAttrInterfaces.h.inc -gen-attr-interface-decls)
+mlir_tablegen(AsmAttrInterfaces.cpp.inc -gen-attr-interface-defs)
+mlir_tablegen(AsmTypeInterfaces.h.inc -gen-type-interface-decls)
+mlir_tablegen(AsmTypeInterfaces.cpp.inc -gen-type-interface-defs)
+add_public_tablegen_target(MLIRAsmInterfacesIncGen)
+
set(LLVM_TARGET_DEFINITIONS BuiltinAttributes.td)
mlir_tablegen(BuiltinAttributes.h.inc -gen-attrdef-decls)
mlir_tablegen(BuiltinAttributes.cpp.inc -gen-attrdef-defs)
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 5c3d93f8e0cff6..5176618578d872 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -1744,15 +1744,15 @@ class OpAsmDialectInterface
Type &type) const {
return failure();
}
- /// Hooks for printing a dialect alias.
- virtual void printDialectAlias(DialectAsmPrinter &printer,
- Attribute attr) const {
- llvm_unreachable("Dialect must implement `printDialectAlias` when defining "
- "dialect aliases");
- }
- virtual void printDialectAlias(DialectAsmPrinter &printer, Type type) const {
- llvm_unreachable("Dialect must implement `printDialectAlias` when defining "
- "dialect aliases");
+ /// Hooks for printing a dialect alias. The method returns success if the
+ /// dialect has an alias for the symbol, otherwise it must return failure.
+ virtual LogicalResult printDialectAlias(DialectAsmPrinter &printer,
+ Attribute attr) const {
+ return failure();
+ }
+ virtual LogicalResult printDialectAlias(DialectAsmPrinter &printer,
+ Type type) const {
+ return failure();
}
//===--------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/AsmInterfaces.cpp b/mlir/lib/IR/AsmInterfaces.cpp
new file mode 100644
index 00000000000000..009701912cf2d3
--- /dev/null
+++ b/mlir/lib/IR/AsmInterfaces.cpp
@@ -0,0 +1,19 @@
+//===- AsmInterfaces.cpp --------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/AsmInterfaces.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+/// Tablegen Interface Definitions
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/AsmAttrInterfaces.cpp.inc"
+
+#include "mlir/IR/AsmTypeInterfaces.cpp.inc"
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 7aabc360517ac0..cd837cbd2b1b40 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -13,6 +13,7 @@
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/AsmInterfaces.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
@@ -2257,7 +2258,8 @@ LogicalResult AsmPrinter::Impl::printAlias(Attribute attr) {
llvm::raw_string_ostream attrNameStr(attrName);
Impl subPrinter(attrNameStr, state);
DialectAsmPrinter printer(subPrinter);
- iface->printDialectAlias(printer, attr);
+ if (failed(iface->printDialectAlias(printer, attr)))
+ return failure();
}
printDialectSymbol(os, "#", iface->getDialect()->getNamespace(), attrName);
return success();
@@ -2276,7 +2278,8 @@ LogicalResult AsmPrinter::Impl::printAlias(Type type) {
llvm::raw_string_ostream typeNameStr(typeName);
Impl subPrinter(typeNameStr, state);
DialectAsmPrinter printer(subPrinter);
- iface->printDialectAlias(printer, type);
+ if (failed(iface->printDialectAlias(printer, type)))
+ return failure();
}
printDialectSymbol(os, "!", iface->getDialect()->getNamespace(), typeName);
return success();
@@ -2787,7 +2790,9 @@ void AsmPrinter::Impl::printNamedAttribute(NamedAttribute attr) {
}
void AsmPrinter::Impl::printDialectAttribute(Attribute attr) {
- auto &dialect = attr.getDialect();
+ Dialect *dialect = &attr.getDialect();
+ if (auto iface = dyn_cast<AttrAsmAliasAttrInterface>(attr))
+ dialect = iface.getAliasDialect();
// Ask the dialect to serialize the attribute to a string.
std::string attrName;
@@ -2795,13 +2800,15 @@ void AsmPrinter::Impl::printDialectAttribute(Attribute attr) {
llvm::raw_string_ostream attrNameStr(attrName);
Impl subPrinter(attrNameStr, state);
DialectAsmPrinter printer(subPrinter);
- dialect.printAttribute(attr, printer);
+ dialect->printAttribute(attr, printer);
}
- printDialectSymbol(os, "#", dialect.getNamespace(), attrName);
+ printDialectSymbol(os, "#", dialect->getNamespace(), attrName);
}
void AsmPrinter::Impl::printDialectType(Type type) {
- auto &dialect = type.getDialect();
+ Dialect *dialect = &type.getDialect();
+ if (auto iface = dyn_cast<TypeAsmAliasTypeInterface>(type))
+ dialect = iface.getAliasDialect();
// Ask the dialect to serialize the type to a string.
std::string typeName;
@@ -2809,9 +2816,9 @@ void AsmPrinter::Impl::printDialectType(Type type) {
llvm::raw_string_ostream typeNameStr(typeName);
Impl subPrinter(typeNameStr, state);
DialectAsmPrinter printer(subPrinter);
- dialect.printType(type, printer);
+ dialect->printType(type, printer);
}
- printDialectSymbol(os, "!", dialect.getNamespace(), typeName);
+ printDialectSymbol(os, "!", dialect->getNamespace(), typeName);
}
void AsmPrinter::Impl::printEscapedString(StringRef str) {
diff --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt
index c38ce6c058a006..0fd71dfa5c4499 100644
--- a/mlir/lib/IR/CMakeLists.txt
+++ b/mlir/lib/IR/CMakeLists.txt
@@ -7,6 +7,7 @@ endif()
add_mlir_library(MLIRIR
AffineExpr.cpp
AffineMap.cpp
+ AsmInterfaces.cpp
AsmPrinter.cpp
Attributes.cpp
AttrTypeSubElements.cpp
@@ -48,6 +49,7 @@ add_mlir_library(MLIRIR
${MLIR_MAIN_INCLUDE_DIR}/mlir/IR
DEPENDS
+ MLIRAsmInterfacesIncGen
MLIRBuiltinAttributesIncGen
MLIRBuiltinAttributeInterfacesIncGen
MLIRBuiltinDialectBytecodeIncGen
diff --git a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
index 8ef3cb82fe159e..db131207c19c8d 100644
--- a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
@@ -205,14 +205,20 @@ struct TestOpAsmInterface : public OpAsmDialectInterface {
return AliasResult::FinalAlias;
}
- void printDialectAlias(DialectAsmPrinter &printer,
- Attribute attr) const override {
+ LogicalResult printDialectAlias(DialectAsmPrinter &printer,
+ Attribute attr) const override {
if (StringAttr strAttr = dyn_cast<StringAttr>(attr)) {
// Drop "test_dialect_alias:" from the front of the string
StringRef value = strAttr.getValue();
value.consume_front("test_dialect_alias:");
- printer << "test_string<\"" << value << "\">";
+ // Print the payload as an escaped string literal so that it round-trips
+ // through `parseString` in `parseDialectAlias`.
+ printer << "test_string<";
+ printer.printString(value);
+ printer << ">";
+ return success();
}
+ return failure();
}
LogicalResult parseDialectAlias(DialectAsmParser &parser, Attribute &attr,
@@ -272,7 +278,8 @@ struct TestOpAsmInterface : public OpAsmDialectInterface {
return AliasResult::NoAlias;
}
- void printDialectAlias(DialectAsmPrinter &printer, Type type) const override {
+ LogicalResult printDialectAlias(DialectAsmPrinter &printer,
+ Type type) const override {
if (auto tensorTy = dyn_cast<TensorType>(type);
tensorTy && isa<TestIntegerType>(tensorTy.getElementType()) &&
tensorTy.hasRank()) {
@@ -280,7 +287,11 @@ struct TestOpAsmInterface : public OpAsmDialectInterface {
ArrayRef<int64_t> shape = tensorTy.getShape();
- if (shape.size() == 1 && shape[0] == 3)
- printer << "tensor_int3" << "<" << tensorTy.getElementType() << ">";
+ if (shape.size() == 1 && shape[0] == 3) {
+ printer << "tensor_int3" << "<" << tensorTy.getElementType() << ">";
+ // Only report success when an alias was actually printed.
+ return success();
+ }
}
+ return failure();
}
LogicalResult parseDialectAlias(DialectAsmParser &parser,
More information about the Mlir-commits
mailing list