[Mlir-commits] [mlir] [mlir] Add the concept of ASM dialect aliases (PR #86033)

Fabian Mora llvmlistbot at llvm.org
Wed Mar 20 16:19:56 PDT 2024


https://github.com/fabianmcg updated https://github.com/llvm/llvm-project/pull/86033

>From e32dad6f351faa1a41efc47d3fa956486fd86d0c Mon Sep 17 00:00:00 2001
From: Fabian Mora <fmora.dev at gmail.com>
Date: Wed, 20 Mar 2024 23:19:34 +0000
Subject: [PATCH] [mlir] Add the concept of ASM dialect aliases

ASM dialect aliases let a dialect print types and attributes under arbitrary
aliases in the dialect's pretty form, and parse those aliases back.

To use these aliases, users must complete several steps. For printing an alias,
they need to:
 - Implement the method `OpAsmDialectInterface::getAlias` and return
   `AliasResult::DialectAlias` for the aliased type and attribute instances.
 - Implement the method `OpAsmDialectInterface::printDialectAlias`, printing the
   alias however the user sees fit.

For parsing an alias, the only step is:
 - Implement `OpAsmDialectInterface::parseDialectAlias` for the aliased types
   and attributes.

Users must also attach the `OpAsmDialectInterface` interface to the dialect
that creates the alias.

An example of this mechanism was added to the tests, specifically:
 - `"test_dialect_alias:..."` is aliased to `#test.test_string<...>`.
 - `tensor<3x!test.int<...>>` is aliased to
   `!test.tensor_int3<!test.int<...>>`.

This change is needed to alias "!llvm.ptr" with "ptr".
---
 mlir/include/mlir/IR/OpImplementation.h       |  28 ++++-
 mlir/lib/AsmParser/DialectSymbolParser.cpp    |  12 ++
 mlir/lib/IR/AsmPrinter.cpp                    | 104 ++++++++++++++++--
 mlir/test/IR/print-attr-type-aliases.mlir     |  15 +++
 .../Dialect/Test/TestDialectInterfaces.cpp    |  67 +++++++++++
 5 files changed, 216 insertions(+), 10 deletions(-)

diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 5333d7446df5ca..5c3d93f8e0cff6 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -1715,7 +1715,10 @@ class OpAsmDialectInterface
     OverridableAlias,
     /// An alias was provided and it should be used
     /// (no other hooks will be checked).
-    FinalAlias
+    FinalAlias,
+    /// The dialect prints the symbol itself through `printDialectAlias`
+    /// (no other hooks will be checked).
+    DialectAlias
   };
 
   /// Hooks for getting an alias identifier alias for a given symbol, that is
@@ -1729,6 +1732,29 @@ class OpAsmDialectInterface
     return AliasResult::NoAlias;
   }
 
+  /// Hooks for parsing a dialect alias. Return failure if the dialect has no
+  /// alias for the symbol; the parser then rewinds and falls back to the
+  /// regular dialect parsing hook. If an alias is recognized but fails to
+  /// parse, return success and set the result (`attr` or `type`) to null.
+  virtual LogicalResult parseDialectAlias(DialectAsmParser &parser,
+                                          Attribute &attr, Type type) const {
+    return failure();
+  }
+  virtual LogicalResult parseDialectAlias(DialectAsmParser &parser,
+                                          Type &type) const {
+    return failure();
+  }
+  /// Hooks for printing a symbol for which `getAlias` returned DialectAlias.
+  virtual void printDialectAlias(DialectAsmPrinter &printer,
+                                 Attribute attr) const {
+    llvm_unreachable("Dialect must implement `printDialectAlias` when defining "
+                     "dialect aliases");
+  }
+  virtual void printDialectAlias(DialectAsmPrinter &printer, Type type) const {
+    llvm_unreachable("Dialect must implement `printDialectAlias` when defining "
+                     "dialect aliases");
+  }
+
   //===--------------------------------------------------------------------===//
   // Resources
   //===--------------------------------------------------------------------===//
diff --git a/mlir/lib/AsmParser/DialectSymbolParser.cpp b/mlir/lib/AsmParser/DialectSymbolParser.cpp
index 80cce7e6ae43f5..9261ef2fb3eb95 100644
--- a/mlir/lib/AsmParser/DialectSymbolParser.cpp
+++ b/mlir/lib/AsmParser/DialectSymbolParser.cpp
@@ -269,6 +269,18 @@ Attribute Parser::parseExtendedAttr(Type type) {
 
           // Parse the attribute.
           CustomDialectAsmParser customParser(symbolData, *this);
+          // Give the dialect a chance to parse a dialect alias first. If it
+          // does not recognize one, rewind and use the regular hook below.
+          if (auto iface = dyn_cast<OpAsmDialectInterface>(dialect)) {
+            Attribute attr{};
+            if (succeeded(
+                    iface->parseDialectAlias(customParser, attr, attrType))) {
+              // Restore the lexer position, exactly as the regular path does.
+              resetToken(curLexerPos);
+              return attr;
+            }
+            resetToken(symbolData.data());
+          }
           Attribute attr = dialect->parseAttribute(customParser, attrType);
           resetToken(curLexerPos);
           return attr;
@@ -310,6 +322,17 @@ Type Parser::parseExtendedType() {
 
           // Parse the type.
           CustomDialectAsmParser customParser(symbolData, *this);
+          // Give the dialect a chance to parse a dialect alias first. If it
+          // does not recognize one, rewind and use the regular hook below.
+          if (auto iface = dyn_cast<OpAsmDialectInterface>(dialect)) {
+            Type type{};
+            if (succeeded(iface->parseDialectAlias(customParser, type))) {
+              // Restore the lexer position, exactly as the regular path does.
+              resetToken(curLexerPos);
+              return type;
+            }
+            resetToken(symbolData.data());
+          }
           Type type = dialect->parseType(customParser);
           resetToken(curLexerPos);
           return type;
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 456cf6a2c27783..7aabc360517ac0 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -542,7 +542,9 @@ class AliasInitializer {
         aliasOS(aliasBuffer) {}
 
   void initialize(Operation *op, const OpPrintingFlags &printerFlags,
-                  llvm::MapVector<const void *, SymbolAlias> &attrTypeToAlias);
+                  llvm::MapVector<const void *, SymbolAlias> &attrTypeToAlias,
+                  llvm::DenseMap<const void *, const OpAsmDialectInterface *>
+                      &attrTypeToDialectAlias);
 
   /// Visit the given attribute to see if it has an alias. `canBeDeferred` is
   /// set to true if the originator of this attribute can resolve the alias
@@ -570,6 +572,10 @@ class AliasInitializer {
     InProgressAliasInfo(StringRef alias, bool isType, bool canBeDeferred)
         : alias(alias), aliasDepth(1), isType(isType),
           canBeDeferred(canBeDeferred) {}
+    InProgressAliasInfo(const OpAsmDialectInterface *aliasDialect, bool isType,
+                        bool canBeDeferred)
+        : alias(std::nullopt), aliasDepth(1), isType(isType),
+          canBeDeferred(canBeDeferred), aliasDialect(aliasDialect) {}
 
     bool operator<(const InProgressAliasInfo &rhs) const {
       // Order first by depth, then by attr/type kind, and then by name.
@@ -577,6 +583,8 @@ class AliasInitializer {
         return aliasDepth < rhs.aliasDepth;
       if (isType != rhs.isType)
         return isType;
+      if (aliasDialect != rhs.aliasDialect)
+        return aliasDialect < rhs.aliasDialect;
       return alias < rhs.alias;
     }
 
@@ -592,6 +600,8 @@ class AliasInitializer {
     bool canBeDeferred : 1;
     /// Indices for child aliases.
     SmallVector<size_t> childIndices;
+    /// Dialect interface that prints this dialect alias, or null otherwise.
+    const OpAsmDialectInterface *aliasDialect{};
   };
 
   /// Visit the given attribute or type to see if it has an alias.
@@ -617,7 +627,9 @@ class AliasInitializer {
   /// symbol to a given alias.
   static void initializeAliases(
       llvm::MapVector<const void *, InProgressAliasInfo> &visitedSymbols,
-      llvm::MapVector<const void *, SymbolAlias> &symbolToAlias);
+      llvm::MapVector<const void *, SymbolAlias> &symbolToAlias,
+      llvm::DenseMap<const void *, const OpAsmDialectInterface *>
+          &attrTypeToDialectAlias);
 
   /// The set of asm interfaces within the context.
   DialectInterfaceCollection<OpAsmDialectInterface> &interfaces;
@@ -1027,7 +1039,9 @@ static StringRef sanitizeIdentifier(StringRef name, SmallString<16> &buffer,
 /// symbol to a given alias.
 void AliasInitializer::initializeAliases(
     llvm::MapVector<const void *, InProgressAliasInfo> &visitedSymbols,
-    llvm::MapVector<const void *, SymbolAlias> &symbolToAlias) {
+    llvm::MapVector<const void *, SymbolAlias> &symbolToAlias,
+    llvm::DenseMap<const void *, const OpAsmDialectInterface *>
+        &attrTypeToDialectAlias) {
   SmallVector<std::pair<const void *, InProgressAliasInfo>, 0>
       unprocessedAliases = visitedSymbols.takeVector();
   llvm::stable_sort(unprocessedAliases, [](const auto &lhs, const auto &rhs) {
@@ -1036,8 +1050,12 @@ void AliasInitializer::initializeAliases(
 
   llvm::StringMap<unsigned> nameCounts;
   for (auto &[symbol, aliasInfo] : unprocessedAliases) {
-    if (!aliasInfo.alias)
+    if (!aliasInfo.alias && !aliasInfo.aliasDialect)
       continue;
+    if (aliasInfo.aliasDialect) {
+      attrTypeToDialectAlias.insert({symbol, aliasInfo.aliasDialect});
+      continue;
+    }
     StringRef alias = *aliasInfo.alias;
     unsigned nameIndex = nameCounts[alias]++;
     symbolToAlias.insert(
@@ -1048,7 +1066,9 @@ void AliasInitializer::initializeAliases(
 
 void AliasInitializer::initialize(
     Operation *op, const OpPrintingFlags &printerFlags,
-    llvm::MapVector<const void *, SymbolAlias> &attrTypeToAlias) {
+    llvm::MapVector<const void *, SymbolAlias> &attrTypeToAlias,
+    llvm::DenseMap<const void *, const OpAsmDialectInterface *>
+        &attrTypeToDialectAlias) {
   // Use a dummy printer when walking the IR so that we can collect the
   // attributes/types that will actually be used during printing when
   // considering aliases.
@@ -1056,7 +1076,7 @@ void AliasInitializer::initialize(
   aliasPrinter.printCustomOrGenericOp(op);
 
   // Initialize the aliases.
-  initializeAliases(aliases, attrTypeToAlias);
+  initializeAliases(aliases, attrTypeToAlias, attrTypeToDialectAlias);
 }
 
 template <typename T, typename... PrintArgs>
@@ -1113,9 +1133,14 @@ template <typename T>
 void AliasInitializer::generateAlias(T symbol, InProgressAliasInfo &alias,
                                      bool canBeDeferred) {
   SmallString<32> nameBuffer;
+  const OpAsmDialectInterface *dialectAlias = nullptr;
   for (const auto &interface : interfaces) {
     OpAsmDialectInterface::AliasResult result =
         interface.getAlias(symbol, aliasOS);
+    if (result == OpAsmDialectInterface::AliasResult::DialectAlias) {
+      dialectAlias = &interface;
+      break;
+    }
     if (result == OpAsmDialectInterface::AliasResult::NoAlias)
       continue;
     nameBuffer = std::move(aliasBuffer);
@@ -1123,6 +1148,11 @@ void AliasInitializer::generateAlias(T symbol, InProgressAliasInfo &alias,
     if (result == OpAsmDialectInterface::AliasResult::FinalAlias)
       break;
   }
+  if (dialectAlias) {
+    alias = InProgressAliasInfo(
+        dialectAlias, /*isType=*/std::is_base_of_v<Type, T>, canBeDeferred);
+    return;
+  }
 
   if (nameBuffer.empty())
     return;
@@ -1157,6 +1187,13 @@ class AliasState {
   /// Returns success if an alias was printed, failure otherwise.
   LogicalResult getAlias(Type ty, raw_ostream &os) const;
 
+  /// Return the dialect interface that prints the given attribute as a
+  /// dialect alias, or nullptr if there is none.
+  const OpAsmDialectInterface *getDialectAlias(Attribute attr) const;
+
+  /// Return the dialect interface printing `ty` as a dialect alias or null.
+  const OpAsmDialectInterface *getDialectAlias(Type ty) const;
+
   /// Print all of the referenced aliases that can not be resolved in a deferred
   /// manner.
   void printNonDeferredAliases(AsmPrinter::Impl &p, NewLineCounter &newLine) {
@@ -1177,6 +1214,10 @@ class AliasState {
   /// Mapping between attribute/type and alias.
   llvm::MapVector<const void *, SymbolAlias> attrTypeToAlias;
 
+  /// Mapping from attribute/type to the interface printing its dialect alias.
+  llvm::DenseMap<const void *, const OpAsmDialectInterface *>
+      attrTypeToDialectAlias;
+
   /// An allocator used for alias names.
   llvm::BumpPtrAllocator aliasAllocator;
 };
@@ -1186,7 +1227,8 @@ void AliasState::initialize(
     Operation *op, const OpPrintingFlags &printerFlags,
     DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) {
   AliasInitializer initializer(interfaces, aliasAllocator);
-  initializer.initialize(op, printerFlags, attrTypeToAlias);
+  initializer.initialize(op, printerFlags, attrTypeToAlias,
+                         attrTypeToDialectAlias);
 }
 
 LogicalResult AliasState::getAlias(Attribute attr, raw_ostream &os) const {
@@ -1206,6 +1248,20 @@ LogicalResult AliasState::getAlias(Type ty, raw_ostream &os) const {
   return success();
 }
 
+const OpAsmDialectInterface *AliasState::getDialectAlias(Attribute attr) const {
+  auto it = attrTypeToDialectAlias.find(attr.getAsOpaquePointer());
+  if (it == attrTypeToDialectAlias.end())
+    return nullptr;
+  return it->second;
+}
+
+const OpAsmDialectInterface *AliasState::getDialectAlias(Type ty) const {
+  auto it = attrTypeToDialectAlias.find(ty.getAsOpaquePointer());
+  if (it == attrTypeToDialectAlias.end())
+    return nullptr;
+  return it->second;
+}
+
 void AliasState::printAliases(AsmPrinter::Impl &p, NewLineCounter &newLine,
                               bool isDeferred) {
   auto filterFn = [=](const auto &aliasIt) {
@@ -2189,11 +2245,41 @@ static void printElidedElementsAttr(raw_ostream &os) {
 }
 
 LogicalResult AsmPrinter::Impl::printAlias(Attribute attr) {
-  return state.getAliasState().getAlias(attr, os);
+  if (succeeded(state.getAliasState().getAlias(attr, os)))
+    return success();
+  const OpAsmDialectInterface *iface =
+      state.getAliasState().getDialectAlias(attr);
+  if (!iface)
+    return failure();
+  // Have the dialect interface print the alias body into `attrName`.
+  std::string attrName;
+  {
+    llvm::raw_string_ostream attrNameStr(attrName);
+    Impl subPrinter(attrNameStr, state);
+    DialectAsmPrinter printer(subPrinter);
+    iface->printDialectAlias(printer, attr);
+  }
+  printDialectSymbol(os, "#", iface->getDialect()->getNamespace(), attrName);
+  return success();
 }
 
 LogicalResult AsmPrinter::Impl::printAlias(Type type) {
-  return state.getAliasState().getAlias(type, os);
+  if (succeeded(state.getAliasState().getAlias(type, os)))
+    return success();
+  const OpAsmDialectInterface *iface =
+      state.getAliasState().getDialectAlias(type);
+  if (!iface)
+    return failure();
+  // Have the dialect interface print the alias body into `typeName`.
+  std::string typeName;
+  {
+    llvm::raw_string_ostream typeNameStr(typeName);
+    Impl subPrinter(typeNameStr, state);
+    DialectAsmPrinter printer(subPrinter);
+    iface->printDialectAlias(printer, type);
+  }
+  printDialectSymbol(os, "!", iface->getDialect()->getNamespace(), typeName);
+  return success();
 }
 
 void AsmPrinter::Impl::printAttribute(Attribute attr,
diff --git a/mlir/test/IR/print-attr-type-aliases.mlir b/mlir/test/IR/print-attr-type-aliases.mlir
index 162eacd0022832..6065573072e638 100644
--- a/mlir/test/IR/print-attr-type-aliases.mlir
+++ b/mlir/test/IR/print-attr-type-aliases.mlir
@@ -21,6 +21,12 @@
 // CHECK-DAG: !test_tuple = tuple<!test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla>
 "test.op"() {alias_test = "alias_test:large_tuple"} : () -> (tuple<!test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla>)
 
+// CHECK-DAG: #test.test_string<"hello">, #test.test_string<" world">
+"test.op"() {alias_test = ["test_dialect_alias:hello", "test_dialect_alias: world"]} : () -> ()
+
+// CHECK-DAG: #test.test_string<"hello">, #test.test_string<" world">
+"test.op"() {alias_test = [#test.test_string<"hello">, #test.test_string<" world">]} : () -> ()
+
 // CHECK-DAG: #test_encoding = "alias_test:tensor_encoding"
 // CHECK-DAG: tensor<32xf32, #test_encoding>
 "test.op"() : () -> tensor<32xf32, "alias_test:tensor_encoding">
@@ -29,6 +35,15 @@
 // CHECK-DAG: tensor<32x!test_ui8_>
 "test.op"() : () -> tensor<32x!test.int<unsigned, 8>>
 
+// CHECK-DAG: !test.tensor_int3<!test_ui8_>
+"test.op"() : () -> tensor<3x!test.int<unsigned, 8>>
+
+// CHECK-DAG: !test.tensor_int3<!test.int<signed, 8>>
+"test.op"() : () -> !test.tensor_int3<!test.int<signed, 8>>
+
+// CHECK-DAG: tensor<3xi3>
+"test.op"() : () -> !test.tensor_int3<i3>
+
 // CHECK-DAG: #loc = loc("nested")
 // CHECK-DAG: #loc1 = loc("test.mlir":10:8)
 // CHECK-DAG: #loc2 = loc(fused<#loc>[#loc1])
diff --git a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
index 66578b246afab1..8ef3cb82fe159e 100644
--- a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
@@ -193,6 +193,11 @@ struct TestOpAsmInterface : public OpAsmDialectInterface {
                   StringRef("test_alias_conflict0_"))
             .Case("alias_test:tensor_encoding", StringRef("test_encoding"))
             .Default(std::nullopt);
+
+    // Create a dialect alias for strings starting with "test_dialect_alias:"
+    if (strAttr.getValue().starts_with("test_dialect_alias:"))
+      return AliasResult::DialectAlias;
+
     if (!aliasName)
       return AliasResult::NoAlias;
 
@@ -200,6 +205,40 @@ struct TestOpAsmInterface : public OpAsmDialectInterface {
     return AliasResult::FinalAlias;
   }
 
+  void printDialectAlias(DialectAsmPrinter &printer,
+                         Attribute attr) const override {
+    if (StringAttr strAttr = dyn_cast<StringAttr>(attr)) {
+      // Drop "test_dialect_alias:" from the front of the string.
+      StringRef value = strAttr.getValue();
+      value.consume_front("test_dialect_alias:");
+      // Print an escaped string literal so the value round-trips through
+      // `parseString` in `parseDialectAlias`.
+      printer << "test_string<";
+      printer.printString(value);
+      printer << ">";
+    }
+  }
+
+  LogicalResult parseDialectAlias(DialectAsmParser &parser, Attribute &attr,
+                                  Type type) const override {
+    return AsmParser::KeywordSwitch<LogicalResult>(parser)
+        // Alias #test.test_string<"..."> to StringAttr
+        .Case("test_string",
+              [&](llvm::StringRef, llvm::SMLoc) {
+                std::string str;
+                if (parser.parseLess() || parser.parseString(&str) ||
+                    parser.parseGreater()) {
+                  // Recognized alias that failed to parse: null attribute.
+                  attr = nullptr;
+                  return success();
+                }
+                attr = parser.getBuilder().getStringAttr("test_dialect_alias:" +
+                                                         str);
+                return success();
+              })
+        .Default([&](StringRef keyword, SMLoc) { return failure(); });
+  }
+
   AliasResult getAlias(Type type, raw_ostream &os) const final {
     if (auto tupleType = dyn_cast<TupleType>(type)) {
       if (tupleType.size() > 0 &&
@@ -229,9 +268,53 @@ struct TestOpAsmInterface : public OpAsmDialectInterface {
       os << recAliasType.getName();
       return AliasResult::FinalAlias;
     }
+    // Create a dialect alias for tensor<3x!test.int<...>>.
+    if (isTensorInt3(type))
+      return AliasResult::DialectAlias;
     return AliasResult::NoAlias;
   }
 
+  /// Returns true if `type` is `tensor<3x!test.int<...>>` without an encoding,
+  /// i.e. a type printed as the dialect alias `!test.tensor_int3<...>`.
+  /// Encoded tensors are not aliased because the alias cannot print encodings.
+  static bool isTensorInt3(Type type) {
+    auto tensorTy = dyn_cast<RankedTensorType>(type);
+    return tensorTy && !tensorTy.getEncoding() && tensorTy.getRank() == 1 &&
+           tensorTy.getDimSize(0) == 3 &&
+           isa<TestIntegerType>(tensorTy.getElementType());
+  }
+
+  void printDialectAlias(DialectAsmPrinter &printer, Type type) const override {
+    // Alias tensor<3x!test.int<...>> to !test.tensor_int3<!test.int<...>>
+    if (isTensorInt3(type)) {
+      auto tensorTy = cast<RankedTensorType>(type);
+      printer << "tensor_int3<" << tensorTy.getElementType() << ">";
+    }
+  }
+
+  LogicalResult parseDialectAlias(DialectAsmParser &parser,
+                                  Type &type) const override {
+    return AsmParser::KeywordSwitch<LogicalResult>(parser)
+        // Alias !test.tensor_int3<IntType> to tensor<3xIntType>
+        .Case("tensor_int3",
+              [&](llvm::StringRef, llvm::SMLoc loc) {
+                // A recognized alias that fails to parse yields success with
+                // a null type, per the `parseDialectAlias` contract.
+                type = nullptr;
+                Type elementType;
+                if (parser.parseLess() || parser.parseType(elementType) ||
+                    parser.parseGreater())
+                  return success();
+                if (!isa<TestIntegerType, IntegerType>(elementType)) {
+                  parser.emitError(loc) << "expected an integer element type";
+                  return success();
+                }
+                type = RankedTensorType::get({3}, elementType);
+                return success();
+              })
+        .Default([&](StringRef keyword, SMLoc) { return failure(); });
+  }
+
   //===------------------------------------------------------------------===//
   // Resources
   //===------------------------------------------------------------------===//



More information about the Mlir-commits mailing list