[Mlir-commits] [mlir] [mlir] Add the concept of ASM dialect aliases (PR #86033)

Fabian Mora llvmlistbot at llvm.org
Wed Mar 27 03:59:32 PDT 2024


https://github.com/fabianmcg updated https://github.com/llvm/llvm-project/pull/86033

>From 162f62816b3b2d7e948bccd8f334056200292de7 Mon Sep 17 00:00:00 2001
From: Fabian Mora <fmora.dev at gmail.com>
Date: Wed, 20 Mar 2024 23:19:34 +0000
Subject: [PATCH 1/2] [mlir] Add the concept of ASM dialect aliases

ASM dialect aliases provide a mechanism for giving types and attributes
arbitrary aliases written in pretty dialect form.

To use these aliases, users must complete several steps. For printing an alias,
they need to:
 - Implement the method `OpAsmDialectInterface::getAlias` and return
   `AliasResult::DialectAlias` for the aliased type or attribute instances.
 - Implement the method `OpAsmDialectInterface::printDialectAlias`, printing the
   alias however the user sees fit.

For parsing an alias, the steps are:
 - Implement `OpAsmDialectInterface::parseDialectAlias` for the aliased types
   and attributes.

Users must also attach the `OpAsmDialectInterface` interface to the dialect
creating the alias.

An example of this mechanism was added to the tests, specifically:
 - `"test_dialect_alias:..."` is aliased to `#test.test_string<...>`
 - `tensor<3x!test.int<...>>` is aliased to `!test.tensor_int3<!test.int<...>>`

This change is needed to alias "!llvm.ptr" with "ptr".
---
 mlir/include/mlir/IR/OpImplementation.h       |  28 ++++-
 mlir/lib/AsmParser/DialectSymbolParser.cpp    |  12 ++
 mlir/lib/IR/AsmPrinter.cpp                    | 104 ++++++++++++++++--
 mlir/test/IR/print-attr-type-aliases.mlir     |  15 +++
 .../Dialect/Test/TestDialectInterfaces.cpp    |  67 +++++++++++
 5 files changed, 216 insertions(+), 10 deletions(-)

diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 5333d7446df5ca..5c3d93f8e0cff6 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -1715,7 +1715,10 @@ class OpAsmDialectInterface
     OverridableAlias,
     /// An alias was provided and it should be used
     /// (no other hooks will be checked).
-    FinalAlias
+    FinalAlias,
+    /// A dialect alias was provided and it will be used
+    /// (no other hooks will be checked).
+    DialectAlias
   };
 
   /// Hooks for getting an alias identifier alias for a given symbol, that is
@@ -1729,6 +1732,29 @@ class OpAsmDialectInterface
     return AliasResult::NoAlias;
   }
 
+  /// Hooks for parsing a dialect alias. Each method returns success if the
+  /// dialect has an alias for the symbol, otherwise it must return failure.
+  /// If there was an error during parsing, the method should return success
+  /// and set the parsed attribute or type to null.
+  virtual LogicalResult parseDialectAlias(DialectAsmParser &parser,
+                                          Attribute &attr, Type type) const {
+    return failure();
+  }
+  virtual LogicalResult parseDialectAlias(DialectAsmParser &parser,
+                                          Type &type) const {
+    return failure();
+  }
+  /// Hooks for printing a dialect alias.
+  virtual void printDialectAlias(DialectAsmPrinter &printer,
+                                 Attribute attr) const {
+    llvm_unreachable("Dialect must implement `printDialectAlias` when defining "
+                     "dialect aliases");
+  }
+  virtual void printDialectAlias(DialectAsmPrinter &printer, Type type) const {
+    llvm_unreachable("Dialect must implement `printDialectAlias` when defining "
+                     "dialect aliases");
+  }
+
   //===--------------------------------------------------------------------===//
   // Resources
   //===--------------------------------------------------------------------===//
diff --git a/mlir/lib/AsmParser/DialectSymbolParser.cpp b/mlir/lib/AsmParser/DialectSymbolParser.cpp
index 80cce7e6ae43f5..9261ef2fb3eb95 100644
--- a/mlir/lib/AsmParser/DialectSymbolParser.cpp
+++ b/mlir/lib/AsmParser/DialectSymbolParser.cpp
@@ -269,6 +269,18 @@ Attribute Parser::parseExtendedAttr(Type type) {

           // Parse the attribute.
           CustomDialectAsmParser customParser(symbolData, *this);
+          // Give the dialect a chance to parse a dialect alias first. The
+          // alias hook receives the same `attrType` as `parseAttribute`.
+          if (auto iface = dyn_cast<OpAsmDialectInterface>(dialect)) {
+            Attribute attr{};
+            if (succeeded(
+                    iface->parseDialectAlias(customParser, attr, attrType))) {
+              // Restore the lexer exactly as the non-alias path below does.
+              resetToken(curLexerPos);
+              return attr;
+            }
+            resetToken(symbolData.data());
+          }
           Attribute attr = dialect->parseAttribute(customParser, attrType);
           resetToken(curLexerPos);
           return attr;
@@ -310,6 +322,16 @@ Type Parser::parseExtendedType() {

           // Parse the type.
           CustomDialectAsmParser customParser(symbolData, *this);
+          // Give the dialect a chance to parse a dialect alias first.
+          if (auto iface = dyn_cast<OpAsmDialectInterface>(dialect)) {
+            Type type{};
+            if (succeeded(iface->parseDialectAlias(customParser, type))) {
+              // Restore the lexer exactly as the non-alias path below does.
+              resetToken(curLexerPos);
+              return type;
+            }
+            resetToken(symbolData.data());
+          }
           Type type = dialect->parseType(customParser);
           resetToken(curLexerPos);
           return type;
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 456cf6a2c27783..7aabc360517ac0 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -542,7 +542,9 @@ class AliasInitializer {
         aliasOS(aliasBuffer) {}
 
   void initialize(Operation *op, const OpPrintingFlags &printerFlags,
-                  llvm::MapVector<const void *, SymbolAlias> &attrTypeToAlias);
+                  llvm::MapVector<const void *, SymbolAlias> &attrTypeToAlias,
+                  llvm::DenseMap<const void *, const OpAsmDialectInterface *>
+                      &attrTypeToDialectAlias);
 
   /// Visit the given attribute to see if it has an alias. `canBeDeferred` is
   /// set to true if the originator of this attribute can resolve the alias
@@ -570,6 +572,10 @@ class AliasInitializer {
     InProgressAliasInfo(StringRef alias, bool isType, bool canBeDeferred)
         : alias(alias), aliasDepth(1), isType(isType),
           canBeDeferred(canBeDeferred) {}
+    InProgressAliasInfo(const OpAsmDialectInterface *aliasDialect, bool isType,
+                        bool canBeDeferred)
+        : alias(std::nullopt), aliasDepth(1), isType(isType),
+          canBeDeferred(canBeDeferred), aliasDialect(aliasDialect) {}
 
     bool operator<(const InProgressAliasInfo &rhs) const {
       // Order first by depth, then by attr/type kind, and then by name.
@@ -577,6 +583,8 @@ class AliasInitializer {
         return aliasDepth < rhs.aliasDepth;
       if (isType != rhs.isType)
         return isType;
+      if (aliasDialect != rhs.aliasDialect)
+        return aliasDialect < rhs.aliasDialect;
       return alias < rhs.alias;
     }
 
@@ -592,6 +600,8 @@ class AliasInitializer {
     bool canBeDeferred : 1;
     /// Indices for child aliases.
     SmallVector<size_t> childIndices;
+    /// Dialect interface used to print the alias.
+    const OpAsmDialectInterface *aliasDialect{};
   };
 
   /// Visit the given attribute or type to see if it has an alias.
@@ -617,7 +627,9 @@ class AliasInitializer {
   /// symbol to a given alias.
   static void initializeAliases(
       llvm::MapVector<const void *, InProgressAliasInfo> &visitedSymbols,
-      llvm::MapVector<const void *, SymbolAlias> &symbolToAlias);
+      llvm::MapVector<const void *, SymbolAlias> &symbolToAlias,
+      llvm::DenseMap<const void *, const OpAsmDialectInterface *>
+          &attrTypeToDialectAlias);
 
   /// The set of asm interfaces within the context.
   DialectInterfaceCollection<OpAsmDialectInterface> &interfaces;
@@ -1027,7 +1039,9 @@ static StringRef sanitizeIdentifier(StringRef name, SmallString<16> &buffer,
 /// symbol to a given alias.
 void AliasInitializer::initializeAliases(
     llvm::MapVector<const void *, InProgressAliasInfo> &visitedSymbols,
-    llvm::MapVector<const void *, SymbolAlias> &symbolToAlias) {
+    llvm::MapVector<const void *, SymbolAlias> &symbolToAlias,
+    llvm::DenseMap<const void *, const OpAsmDialectInterface *>
+        &attrTypeToDialectAlias) {
   SmallVector<std::pair<const void *, InProgressAliasInfo>, 0>
       unprocessedAliases = visitedSymbols.takeVector();
   llvm::stable_sort(unprocessedAliases, [](const auto &lhs, const auto &rhs) {
@@ -1036,8 +1050,12 @@ void AliasInitializer::initializeAliases(
 
   llvm::StringMap<unsigned> nameCounts;
   for (auto &[symbol, aliasInfo] : unprocessedAliases) {
-    if (!aliasInfo.alias)
+    if (!aliasInfo.alias && !aliasInfo.aliasDialect)
       continue;
+    if (aliasInfo.aliasDialect) {
+      attrTypeToDialectAlias.insert({symbol, aliasInfo.aliasDialect});
+      continue;
+    }
     StringRef alias = *aliasInfo.alias;
     unsigned nameIndex = nameCounts[alias]++;
     symbolToAlias.insert(
@@ -1048,7 +1066,9 @@ void AliasInitializer::initializeAliases(
 
 void AliasInitializer::initialize(
     Operation *op, const OpPrintingFlags &printerFlags,
-    llvm::MapVector<const void *, SymbolAlias> &attrTypeToAlias) {
+    llvm::MapVector<const void *, SymbolAlias> &attrTypeToAlias,
+    llvm::DenseMap<const void *, const OpAsmDialectInterface *>
+        &attrTypeToDialectAlias) {
   // Use a dummy printer when walking the IR so that we can collect the
   // attributes/types that will actually be used during printing when
   // considering aliases.
@@ -1056,7 +1076,7 @@ void AliasInitializer::initialize(
   aliasPrinter.printCustomOrGenericOp(op);
 
   // Initialize the aliases.
-  initializeAliases(aliases, attrTypeToAlias);
+  initializeAliases(aliases, attrTypeToAlias, attrTypeToDialectAlias);
 }
 
 template <typename T, typename... PrintArgs>
@@ -1113,9 +1133,14 @@ template <typename T>
 void AliasInitializer::generateAlias(T symbol, InProgressAliasInfo &alias,
                                      bool canBeDeferred) {
   SmallString<32> nameBuffer;
+  const OpAsmDialectInterface *dialectAlias = nullptr;
   for (const auto &interface : interfaces) {
     OpAsmDialectInterface::AliasResult result =
         interface.getAlias(symbol, aliasOS);
+    if (result == OpAsmDialectInterface::AliasResult::DialectAlias) {
+      dialectAlias = &interface;
+      break;
+    }
     if (result == OpAsmDialectInterface::AliasResult::NoAlias)
       continue;
     nameBuffer = std::move(aliasBuffer);
@@ -1123,6 +1148,11 @@ void AliasInitializer::generateAlias(T symbol, InProgressAliasInfo &alias,
     if (result == OpAsmDialectInterface::AliasResult::FinalAlias)
       break;
   }
+  if (dialectAlias) {
+    alias = InProgressAliasInfo(
+        dialectAlias, /*isType=*/std::is_base_of_v<Type, T>, canBeDeferred);
+    return;
+  }
 
   if (nameBuffer.empty())
     return;
@@ -1157,6 +1187,13 @@ class AliasState {
   /// Returns success if an alias was printed, failure otherwise.
   LogicalResult getAlias(Type ty, raw_ostream &os) const;
 
+  /// Get a dialect alias for the given attribute if it has one or return
+  /// nullptr.
+  const OpAsmDialectInterface *getDialectAlias(Attribute attr) const;
+
+  /// Get a dialect alias for the given type if it has one or return nullptr.
+  const OpAsmDialectInterface *getDialectAlias(Type ty) const;
+
   /// Print all of the referenced aliases that can not be resolved in a deferred
   /// manner.
   void printNonDeferredAliases(AsmPrinter::Impl &p, NewLineCounter &newLine) {
@@ -1177,6 +1214,10 @@ class AliasState {
   /// Mapping between attribute/type and alias.
   llvm::MapVector<const void *, SymbolAlias> attrTypeToAlias;
 
+  /// Mapping between attribute/type and alias dialect interfaces.
+  llvm::DenseMap<const void *, const OpAsmDialectInterface *>
+      attrTypeToDialectAlias;
+
   /// An allocator used for alias names.
   llvm::BumpPtrAllocator aliasAllocator;
 };
@@ -1186,7 +1227,8 @@ void AliasState::initialize(
     Operation *op, const OpPrintingFlags &printerFlags,
     DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) {
   AliasInitializer initializer(interfaces, aliasAllocator);
-  initializer.initialize(op, printerFlags, attrTypeToAlias);
+  initializer.initialize(op, printerFlags, attrTypeToAlias,
+                         attrTypeToDialectAlias);
 }
 
 LogicalResult AliasState::getAlias(Attribute attr, raw_ostream &os) const {
@@ -1206,6 +1248,20 @@ LogicalResult AliasState::getAlias(Type ty, raw_ostream &os) const {
   return success();
 }
 
+/// Returns the dialect interface recorded as the dialect alias of `attr`, or
+/// nullptr if `attr` has none.
+const OpAsmDialectInterface *AliasState::getDialectAlias(Attribute attr) const {
+  // DenseMap::lookup yields a default-constructed (null) value when absent.
+  return attrTypeToDialectAlias.lookup(attr.getAsOpaquePointer());
+}
+
+/// Returns the dialect interface recorded as the dialect alias of `ty`, or
+/// nullptr if `ty` has none.
+const OpAsmDialectInterface *AliasState::getDialectAlias(Type ty) const {
+  // DenseMap::lookup yields a default-constructed (null) value when absent.
+  return attrTypeToDialectAlias.lookup(ty.getAsOpaquePointer());
+}
+
 void AliasState::printAliases(AsmPrinter::Impl &p, NewLineCounter &newLine,
                               bool isDeferred) {
   auto filterFn = [=](const auto &aliasIt) {
@@ -2189,11 +2245,41 @@ static void printElidedElementsAttr(raw_ostream &os) {
 }
 
 LogicalResult AsmPrinter::Impl::printAlias(Attribute attr) {
-  return state.getAliasState().getAlias(attr, os);
+  if (succeeded(state.getAliasState().getAlias(attr, os)))
+    return success();
+  const OpAsmDialectInterface *iface =
+      state.getAliasState().getDialectAlias(attr);
+  if (!iface)
+    return failure();
+  // Ask the dialect to serialize the attribute to a string.
+  std::string attrName;
+  {
+    llvm::raw_string_ostream attrNameStr(attrName);
+    Impl subPrinter(attrNameStr, state);
+    DialectAsmPrinter printer(subPrinter);
+    iface->printDialectAlias(printer, attr);
+  }
+  printDialectSymbol(os, "#", iface->getDialect()->getNamespace(), attrName);
+  return success();
 }
 
 LogicalResult AsmPrinter::Impl::printAlias(Type type) {
-  return state.getAliasState().getAlias(type, os);
+  if (succeeded(state.getAliasState().getAlias(type, os)))
+    return success();
+  const OpAsmDialectInterface *iface =
+      state.getAliasState().getDialectAlias(type);
+  if (!iface)
+    return failure();
+  // Ask the dialect to serialize the type to a string.
+  std::string typeName;
+  {
+    llvm::raw_string_ostream typeNameStr(typeName);
+    Impl subPrinter(typeNameStr, state);
+    DialectAsmPrinter printer(subPrinter);
+    iface->printDialectAlias(printer, type);
+  }
+  printDialectSymbol(os, "!", iface->getDialect()->getNamespace(), typeName);
+  return success();
 }
 
 void AsmPrinter::Impl::printAttribute(Attribute attr,
diff --git a/mlir/test/IR/print-attr-type-aliases.mlir b/mlir/test/IR/print-attr-type-aliases.mlir
index 162eacd0022832..6065573072e638 100644
--- a/mlir/test/IR/print-attr-type-aliases.mlir
+++ b/mlir/test/IR/print-attr-type-aliases.mlir
@@ -21,6 +21,12 @@
 // CHECK-DAG: !test_tuple = tuple<!test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla>
 "test.op"() {alias_test = "alias_test:large_tuple"} : () -> (tuple<!test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla>)
 
+// CHECK-DAG: #test.test_string<"hello">, #test.test_string<" world">
+"test.op"() {alias_test = ["test_dialect_alias:hello", "test_dialect_alias: world"]} : () -> ()
+
+// CHECK-DAG: #test.test_string<"hello">, #test.test_string<" world">
+"test.op"() {alias_test = [#test.test_string<"hello">, #test.test_string<" world">]} : () -> ()
+
 // CHECK-DAG: #test_encoding = "alias_test:tensor_encoding"
 // CHECK-DAG: tensor<32xf32, #test_encoding>
 "test.op"() : () -> tensor<32xf32, "alias_test:tensor_encoding">
@@ -29,6 +35,15 @@
 // CHECK-DAG: tensor<32x!test_ui8_>
 "test.op"() : () -> tensor<32x!test.int<unsigned, 8>>
 
+// CHECK-DAG: !test.tensor_int3<!test_ui8_>
+"test.op"() : () -> tensor<3x!test.int<unsigned, 8>>
+
+// CHECK-DAG: !test.tensor_int3<!test.int<signed, 8>>
+"test.op"() : () -> !test.tensor_int3<!test.int<signed, 8>>
+
+// CHECK-DAG: tensor<3xi3>
+"test.op"() : () -> !test.tensor_int3<i3>
+
 // CHECK-DAG: #loc = loc("nested")
 // CHECK-DAG: #loc1 = loc("test.mlir":10:8)
 // CHECK-DAG: #loc2 = loc(fused<#loc>[#loc1])
diff --git a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
index 66578b246afab1..8ef3cb82fe159e 100644
--- a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
@@ -193,6 +193,11 @@ struct TestOpAsmInterface : public OpAsmDialectInterface {
                   StringRef("test_alias_conflict0_"))
             .Case("alias_test:tensor_encoding", StringRef("test_encoding"))
             .Default(std::nullopt);
+
+    // Create a dialect alias for strings starting with "test_dialect_alias:"
+    if (strAttr.getValue().starts_with("test_dialect_alias:"))
+      return AliasResult::DialectAlias;
+
     if (!aliasName)
       return AliasResult::NoAlias;
 
@@ -200,6 +205,33 @@ struct TestOpAsmInterface : public OpAsmDialectInterface {
     return AliasResult::FinalAlias;
   }
 
+  void printDialectAlias(DialectAsmPrinter &printer,
+                         Attribute attr) const override {
+    if (StringAttr strAttr = dyn_cast<StringAttr>(attr)) {
+      // Drop "test_dialect_alias:" from the front of the string
+      StringRef value = strAttr.getValue();
+      value.consume_front("test_dialect_alias:");
+      printer << "test_string<\"" << value << "\">";
+    }
+  }
+
+  LogicalResult parseDialectAlias(DialectAsmParser &parser, Attribute &attr,
+                                  Type type) const override {
+    return AsmParser::KeywordSwitch<LogicalResult>(parser)
+        // Alias !test.test_string<"..."> to StringAttr
+        .Case("test_string",
+              [&](llvm::StringRef, llvm::SMLoc) {
+                std::string str;
+                if (parser.parseLess() || parser.parseString(&str) ||
+                    parser.parseGreater())
+                  return success();
+                attr = parser.getBuilder().getStringAttr("test_dialect_alias:" +
+                                                         str);
+                return success();
+              })
+        .Default([&](StringRef keyword, SMLoc) { return failure(); });
+  }
+
   AliasResult getAlias(Type type, raw_ostream &os) const final {
     if (auto tupleType = dyn_cast<TupleType>(type)) {
       if (tupleType.size() > 0 &&
@@ -229,9 +261,52 @@ struct TestOpAsmInterface : public OpAsmDialectInterface {
       os << recAliasType.getName();
       return AliasResult::FinalAlias;
     }
+    // Create a dialect alias for tensor<3x!test.int<...>>
+    if (auto tensorTy = dyn_cast<TensorType>(type);
+        tensorTy && isa<TestIntegerType>(tensorTy.getElementType()) &&
+        tensorTy.hasRank()) {
+      ArrayRef<int64_t> shape = tensorTy.getShape();
+      if (shape.size() == 1 && shape[0] == 3)
+        return AliasResult::DialectAlias;
+    }
     return AliasResult::NoAlias;
   }

+  void printDialectAlias(DialectAsmPrinter &printer, Type type) const override {
+    if (auto tensorTy = dyn_cast<TensorType>(type);
+        tensorTy && isa<TestIntegerType>(tensorTy.getElementType()) &&
+        tensorTy.hasRank()) {
+      // Alias tensor<3x!test.int<...>> to !test.tensor_int3<!test.int<...>>
+      ArrayRef<int64_t> shape = tensorTy.getShape();
+      if (shape.size() == 1 && shape[0] == 3)
+        printer << "tensor_int3" << "<" << tensorTy.getElementType() << ">";
+    }
+  }
+
+  LogicalResult parseDialectAlias(DialectAsmParser &parser,
+                                  Type &type) const override {
+    return AsmParser::KeywordSwitch<LogicalResult>(parser)
+        // Alias !test.tensor_int3<IntType> to tensor<3xIntType>
+        .Case("tensor_int3",
+              [&](llvm::StringRef, llvm::SMLoc loc) {
+                // Errors are reported as success with a null type; returning
+                // failure would fall back to the regular type parser.
+                if (parser.parseLess() || parser.parseType(type) ||
+                    parser.parseGreater()) {
+                  type = nullptr;
+                  return success();
+                }
+                if (!isa<TestIntegerType, IntegerType>(type)) {
+                  parser.emitError(loc, "expected an integer element type");
+                  type = nullptr;
+                  return success();
+                }
+                type = RankedTensorType::get({3}, type);
+                return success();
+              })
+        .Default([&](StringRef keyword, SMLoc) { return failure(); });
+  }
+
   //===------------------------------------------------------------------===//
   // Resources
   //===------------------------------------------------------------------===//

>From 2a52522e57cdf8bb6927ac6bc7b711259c9a0b7f Mon Sep 17 00:00:00 2001
From: Fabian Mora <fmora.dev at gmail.com>
Date: Wed, 27 Mar 2024 10:58:59 +0000
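Subject: [PATCH 2/2] Add AttrAsmAliasAttrInterface and
 TypeAsmAliasTypeInterface

Add `AttrAsmAliasAttrInterface` and `TypeAsmAliasTypeInterface`, which let an
attribute or type name the dialect responsible for printing and parsing it.
Also make `OpAsmDialectInterface::printDialectAlias` return a `LogicalResult`,
so a dialect can report that it has no alias for a given attribute or type.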
---
 mlir/include/mlir/IR/AsmInterfaces.h          | 19 ++++++
 mlir/include/mlir/IR/AsmInterfaces.td         | 60 +++++++++++++++++++
 mlir/include/mlir/IR/CMakeLists.txt           |  7 +++
 mlir/include/mlir/IR/OpImplementation.h       | 18 +++---
 mlir/lib/IR/AsmInterfaces.cpp                 | 19 ++++++
 mlir/lib/IR/AsmPrinter.cpp                    | 23 ++++---
 mlir/lib/IR/CMakeLists.txt                    |  2 +
 .../Dialect/Test/TestDialectInterfaces.cpp    | 11 +++-
 8 files changed, 139 insertions(+), 20 deletions(-)
 create mode 100644 mlir/include/mlir/IR/AsmInterfaces.h
 create mode 100644 mlir/include/mlir/IR/AsmInterfaces.td
 create mode 100644 mlir/lib/IR/AsmInterfaces.cpp

diff --git a/mlir/include/mlir/IR/AsmInterfaces.h b/mlir/include/mlir/IR/AsmInterfaces.h
new file mode 100644
index 00000000000000..00c4ecf58867d3
--- /dev/null
+++ b/mlir/include/mlir/IR/AsmInterfaces.h
@@ -0,0 +1,25 @@
+//===- AsmInterfaces.h - Asm Interfaces -------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the attribute and type interfaces generated from
+// AsmInterfaces.td, which let an attribute or type name the dialect
+// responsible for printing and parsing it.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_ASMINTERFACES_H
+#define MLIR_IR_ASMINTERFACES_H
+
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Types.h"
+
+#include "mlir/IR/AsmAttrInterfaces.h.inc"
+
+#include "mlir/IR/AsmTypeInterfaces.h.inc"
+
+#endif // MLIR_IR_ASMINTERFACES_H
diff --git a/mlir/include/mlir/IR/AsmInterfaces.td b/mlir/include/mlir/IR/AsmInterfaces.td
new file mode 100644
index 00000000000000..88f8d37a86f68c
--- /dev/null
+++ b/mlir/include/mlir/IR/AsmInterfaces.td
@@ -0,0 +1,60 @@
+//===- AsmInterfaces.td - Asm Interfaces -------------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains interfaces and other utilities for interacting with the
+// AsmParser and AsmPrinter.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_ASMINTERFACES_TD
+#define MLIR_IR_ASMINTERFACES_TD
+
+include "mlir/IR/AttrTypeBase.td"
+include "mlir/IR/OpBase.td"
+
+//===----------------------------------------------------------------------===//
+// AttrAsmAliasAttrInterface
+//===----------------------------------------------------------------------===//
+
+def AttrAsmAliasAttrInterface : AttrInterface<"AttrAsmAliasAttrInterface"> {
+  let cppNamespace = "::mlir";
+  let description = [{
+    This interface allows aliasing an attribute between dialects, allowing
+    custom printing of an attribute by an external dialect.
+  }];
+  let methods = [
+    InterfaceMethod<[{
+        Returns the dialect responsible for printing and parsing the attribute
+        instance.
+      }],
+      "Dialect*", "getAliasDialect", (ins), [{}], [{}]
+    >
+  ];
+}
+
+//===----------------------------------------------------------------------===//
+// TypeAsmAliasTypeInterface
+//===----------------------------------------------------------------------===//
+
+def TypeAsmAliasTypeInterface : TypeInterface<"TypeAsmAliasTypeInterface"> {
+  let cppNamespace = "::mlir";
+  let description = [{
+    This interface allows aliasing a type between dialects, allowing custom
+    printing of a type by an external dialect.
+  }];
+  let methods = [
+    InterfaceMethod<[{
+        Returns the dialect responsible for printing and parsing the type
+        instance.
+      }],
+      "Dialect*", "getAliasDialect", (ins), [{}], [{}]
+    >
+  ];
+}
+
+#endif // MLIR_IR_ASMINTERFACES_TD
diff --git a/mlir/include/mlir/IR/CMakeLists.txt b/mlir/include/mlir/IR/CMakeLists.txt
index 04a57d26a068d5..d10a1bec682bc2 100644
--- a/mlir/include/mlir/IR/CMakeLists.txt
+++ b/mlir/include/mlir/IR/CMakeLists.txt
@@ -2,6 +2,13 @@ add_mlir_interface(OpAsmInterface)
 add_mlir_interface(SymbolInterfaces)
 add_mlir_interface(RegionKindInterface)
 
+set(LLVM_TARGET_DEFINITIONS AsmInterfaces.td)
+mlir_tablegen(AsmAttrInterfaces.h.inc -gen-attr-interface-decls)
+mlir_tablegen(AsmAttrInterfaces.cpp.inc -gen-attr-interface-defs)
+mlir_tablegen(AsmTypeInterfaces.h.inc -gen-type-interface-decls)
+mlir_tablegen(AsmTypeInterfaces.cpp.inc -gen-type-interface-defs)
+add_public_tablegen_target(MLIRAsmInterfacesIncGen)
+
 set(LLVM_TARGET_DEFINITIONS BuiltinAttributes.td)
 mlir_tablegen(BuiltinAttributes.h.inc -gen-attrdef-decls)
 mlir_tablegen(BuiltinAttributes.cpp.inc -gen-attrdef-defs)
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 5c3d93f8e0cff6..5176618578d872 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -1744,15 +1744,15 @@ class OpAsmDialectInterface
                                           Type &type) const {
     return failure();
   }
-  /// Hooks for printing a dialect alias.
-  virtual void printDialectAlias(DialectAsmPrinter &printer,
-                                 Attribute attr) const {
-    llvm_unreachable("Dialect must implement `printDialectAlias` when defining "
-                     "dialect aliases");
-  }
-  virtual void printDialectAlias(DialectAsmPrinter &printer, Type type) const {
-    llvm_unreachable("Dialect must implement `printDialectAlias` when defining "
-                     "dialect aliases");
+  /// Hooks for printing a dialect alias. The method returns success if the
+  /// dialect has an alias for the symbol, otherwise it must return failure.
+  virtual LogicalResult printDialectAlias(DialectAsmPrinter &printer,
+                                          Attribute attr) const {
+    return failure();
+  }
+  virtual LogicalResult printDialectAlias(DialectAsmPrinter &printer,
+                                          Type type) const {
+    return failure();
   }
 
   //===--------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/AsmInterfaces.cpp b/mlir/lib/IR/AsmInterfaces.cpp
new file mode 100644
index 00000000000000..009701912cf2d3
--- /dev/null
+++ b/mlir/lib/IR/AsmInterfaces.cpp
@@ -0,0 +1,19 @@
+//===- AsmInterfaces.cpp --------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/AsmInterfaces.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+/// Tablegen Interface Definitions
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/AsmAttrInterfaces.cpp.inc"
+
+#include "mlir/IR/AsmTypeInterfaces.cpp.inc"
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 7aabc360517ac0..cd837cbd2b1b40 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -13,6 +13,7 @@
 
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
+#include "mlir/IR/AsmInterfaces.h"
 #include "mlir/IR/AsmState.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
@@ -2257,7 +2258,8 @@ LogicalResult AsmPrinter::Impl::printAlias(Attribute attr) {
     llvm::raw_string_ostream attrNameStr(attrName);
     Impl subPrinter(attrNameStr, state);
     DialectAsmPrinter printer(subPrinter);
-    iface->printDialectAlias(printer, attr);
+    if (failed(iface->printDialectAlias(printer, attr)))
+      return failure();
   }
   printDialectSymbol(os, "#", iface->getDialect()->getNamespace(), attrName);
   return success();
@@ -2276,7 +2278,8 @@ LogicalResult AsmPrinter::Impl::printAlias(Type type) {
     llvm::raw_string_ostream typeNameStr(typeName);
     Impl subPrinter(typeNameStr, state);
     DialectAsmPrinter printer(subPrinter);
-    iface->printDialectAlias(printer, type);
+    if (failed(iface->printDialectAlias(printer, type)))
+      return failure();
   }
   printDialectSymbol(os, "!", iface->getDialect()->getNamespace(), typeName);
   return success();
@@ -2787,7 +2790,11 @@ void AsmPrinter::Impl::printNamedAttribute(NamedAttribute attr) {
 }

 void AsmPrinter::Impl::printDialectAttribute(Attribute attr) {
-  auto &dialect = attr.getDialect();
+  Dialect *dialect = &attr.getDialect();
+  // Let the attribute delegate printing to another dialect, if it names one.
+  if (auto iface = dyn_cast<AttrAsmAliasAttrInterface>(attr))
+    if (Dialect *aliasDialect = iface.getAliasDialect())
+      dialect = aliasDialect;

   // Ask the dialect to serialize the attribute to a string.
   std::string attrName;
@@ -2795,13 +2802,17 @@ void AsmPrinter::Impl::printDialectAttribute(Attribute attr) {
     llvm::raw_string_ostream attrNameStr(attrName);
     Impl subPrinter(attrNameStr, state);
     DialectAsmPrinter printer(subPrinter);
-    dialect.printAttribute(attr, printer);
+    dialect->printAttribute(attr, printer);
   }
-  printDialectSymbol(os, "#", dialect.getNamespace(), attrName);
+  printDialectSymbol(os, "#", dialect->getNamespace(), attrName);
 }

 void AsmPrinter::Impl::printDialectType(Type type) {
-  auto &dialect = type.getDialect();
+  Dialect *dialect = &type.getDialect();
+  // Let the type delegate printing to another dialect, if it names one.
+  if (auto iface = dyn_cast<TypeAsmAliasTypeInterface>(type))
+    if (Dialect *aliasDialect = iface.getAliasDialect())
+      dialect = aliasDialect;

   // Ask the dialect to serialize the type to a string.
   std::string typeName;
@@ -2809,9 +2820,9 @@ void AsmPrinter::Impl::printDialectType(Type type) {
     llvm::raw_string_ostream typeNameStr(typeName);
     Impl subPrinter(typeNameStr, state);
     DialectAsmPrinter printer(subPrinter);
-    dialect.printType(type, printer);
+    dialect->printType(type, printer);
   }
-  printDialectSymbol(os, "!", dialect.getNamespace(), typeName);
+  printDialectSymbol(os, "!", dialect->getNamespace(), typeName);
 }
 
 void AsmPrinter::Impl::printEscapedString(StringRef str) {
diff --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt
index c38ce6c058a006..0fd71dfa5c4499 100644
--- a/mlir/lib/IR/CMakeLists.txt
+++ b/mlir/lib/IR/CMakeLists.txt
@@ -7,6 +7,7 @@ endif()
 add_mlir_library(MLIRIR
   AffineExpr.cpp
   AffineMap.cpp
+  AsmInterfaces.cpp
   AsmPrinter.cpp
   Attributes.cpp
   AttrTypeSubElements.cpp
@@ -48,6 +49,7 @@ add_mlir_library(MLIRIR
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/IR
 
   DEPENDS
+  MLIRAsmInterfacesIncGen
   MLIRBuiltinAttributesIncGen
   MLIRBuiltinAttributeInterfacesIncGen
   MLIRBuiltinDialectBytecodeIncGen
diff --git a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
index 8ef3cb82fe159e..db131207c19c8d 100644
--- a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
@@ -205,14 +205,16 @@ struct TestOpAsmInterface : public OpAsmDialectInterface {
     return AliasResult::FinalAlias;
   }
 
-  void printDialectAlias(DialectAsmPrinter &printer,
-                         Attribute attr) const override {
+  LogicalResult printDialectAlias(DialectAsmPrinter &printer,
+                                  Attribute attr) const override {
     if (StringAttr strAttr = dyn_cast<StringAttr>(attr)) {
       // Drop "test_dialect_alias:" from the front of the string
       StringRef value = strAttr.getValue();
       value.consume_front("test_dialect_alias:");
       printer << "test_string<\"" << value << "\">";
+      return success();
     }
+    return failure();
   }
 
   LogicalResult parseDialectAlias(DialectAsmParser &parser, Attribute &attr,
@@ -272,7 +274,8 @@ struct TestOpAsmInterface : public OpAsmDialectInterface {
     return AliasResult::NoAlias;
   }
 
-  void printDialectAlias(DialectAsmPrinter &printer, Type type) const override {
+  LogicalResult printDialectAlias(DialectAsmPrinter &printer,
+                                  Type type) const override {
     if (auto tensorTy = dyn_cast<TensorType>(type);
         tensorTy && isa<TestIntegerType>(tensorTy.getElementType()) &&
         tensorTy.hasRank()) {
@@ -280,7 +283,11 @@ struct TestOpAsmInterface : public OpAsmDialectInterface {
       ArrayRef<int64_t> shape = tensorTy.getShape();
-      if (shape.size() == 1 && shape[0] == 3)
+      if (shape.size() == 1 && shape[0] == 3) {
         printer << "tensor_int3" << "<" << tensorTy.getElementType() << ">";
+        return success();
+      }
     }
+    // Nothing was printed, so report that there is no alias for `type`.
+    return failure();
   }

   LogicalResult parseDialectAlias(DialectAsmParser &parser,



More information about the Mlir-commits mailing list