[Mlir-commits] [mlir] 59f59d1 - [mlir] Allow to override type/attr aliases from various hooks

Vladislav Vinogradov llvmlistbot at llvm.org
Fri Aug 6 02:04:51 PDT 2021


Author: Vladislav Vinogradov
Date: 2021-08-06T12:05:31+03:00
New Revision: 59f59d1c621cf6844c41fd92ad32a897fc9d10bd

URL: https://github.com/llvm/llvm-project/commit/59f59d1c621cf6844c41fd92ad32a897fc9d10bd
DIFF: https://github.com/llvm/llvm-project/commit/59f59d1c621cf6844c41fd92ad32a897fc9d10bd.diff

LOG: [mlir] Allow to override type/attr aliases from various hooks

Use a new return type for `OpAsmDialectInterface::getAlias`:

* `AliasResult::NoAlias` if an alias was not provided.
* `AliasResult::OverridableAlias` if an alias was provided, but it might be overridden by another hook.
* `AliasResult::FinalAlias` if an alias was provided and it should be used (no other hooks will be checked).

`AsmPrinter` then uses the first alias with a `FinalAlias` result. If no hook returns `FinalAlias`, it uses the last alias with an `OverridableAlias` result. Which alias wins therefore depends on the order of the dialect interfaces.

`BuiltinOpAsmDialectInterface` now returns `OverridableAlias`, so hooks from other dialects can override its aliases.

Use case: provide a more informative alias for built-in attributes such as `AffineMapAttr`
instead of the generic "map<N>".

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D107437

Added: 
    

Modified: 
    mlir/include/mlir/IR/OpImplementation.h
    mlir/lib/IR/AsmPrinter.cpp
    mlir/lib/IR/BuiltinDialect.cpp
    mlir/test/IR/print-attr-type-aliases.mlir
    mlir/test/lib/Dialect/Test/TestDialect.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index a9c96473a3a1e..39108dd91c5cd 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -928,18 +928,29 @@ using OpAsmSetValueNameFn = function_ref<void(Value, StringRef)>;
 class OpAsmDialectInterface
     : public DialectInterface::Base<OpAsmDialectInterface> {
 public:
+  /// Holds the result of `getAlias` hook call.
+  enum class AliasResult {
+    /// The object (type or attribute) is not supported by the hook
+    /// and an alias was not provided.
+    NoAlias,
+    /// An alias was provided, but it might be overriden by other hook.
+    OverridableAlias,
+    /// An alias was provided and it should be used
+    /// (no other hooks will be checked).
+    FinalAlias
+  };
+
   OpAsmDialectInterface(Dialect *dialect) : Base(dialect) {}
 
   /// Hooks for getting an alias identifier alias for a given symbol, that is
   /// not necessarily a part of this dialect. The identifier is used in place of
   /// the symbol when printing textual IR. These aliases must not contain `.` or
-  /// end with a numeric digit([0-9]+). Returns success if an alias was
-  /// provided, failure otherwise.
-  virtual LogicalResult getAlias(Attribute attr, raw_ostream &os) const {
-    return failure();
+  /// end with a numeric digit([0-9]+).
+  virtual AliasResult getAlias(Attribute attr, raw_ostream &os) const {
+    return AliasResult::NoAlias;
   }
-  virtual LogicalResult getAlias(Type type, raw_ostream &os) const {
-    return failure();
+  virtual AliasResult getAlias(Type type, raw_ostream &os) const {
+    return AliasResult::NoAlias;
   }
 
   /// Get a special name to use when printing the given operation. See

diff  --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index d73ca948538ec..b689225925dbe 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -652,21 +652,28 @@ void AliasInitializer::visit(Type type) {
 template <typename T>
 LogicalResult AliasInitializer::generateAlias(
     T symbol, llvm::MapVector<StringRef, std::vector<T>> &aliasToSymbol) {
-  SmallString<16> tempBuffer;
+  SmallString<32> nameBuffer;
   for (const auto &interface : interfaces) {
-    if (failed(interface.getAlias(symbol, aliasOS)))
+    OpAsmDialectInterface::AliasResult result =
+        interface.getAlias(symbol, aliasOS);
+    if (result == OpAsmDialectInterface::AliasResult::NoAlias)
       continue;
-    StringRef name = aliasOS.str();
-    assert(!name.empty() && "expected valid alias name");
-    name = sanitizeIdentifier(name, tempBuffer, /*allowedPunctChars=*/"$_-",
-                              /*allowTrailingDigit=*/false);
-    name = name.copy(aliasAllocator);
-
-    aliasToSymbol[name].push_back(symbol);
-    aliasBuffer.clear();
-    return success();
+    nameBuffer = std::move(aliasBuffer);
+    assert(!nameBuffer.empty() && "expected valid alias name");
+    if (result == OpAsmDialectInterface::AliasResult::FinalAlias)
+      break;
   }
-  return failure();
+
+  if (nameBuffer.empty())
+    return failure();
+
+  SmallString<16> tempBuffer;
+  StringRef name =
+      sanitizeIdentifier(nameBuffer, tempBuffer, /*allowedPunctChars=*/"$_-",
+                         /*allowTrailingDigit=*/false);
+  name = name.copy(aliasAllocator);
+  aliasToSymbol[name].push_back(symbol);
+  return success();
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/IR/BuiltinDialect.cpp b/mlir/lib/IR/BuiltinDialect.cpp
index 7e79c1b7c7c49..20514957b5fa6 100644
--- a/mlir/lib/IR/BuiltinDialect.cpp
+++ b/mlir/lib/IR/BuiltinDialect.cpp
@@ -33,30 +33,30 @@ namespace {
 struct BuiltinOpAsmDialectInterface : public OpAsmDialectInterface {
   using OpAsmDialectInterface::OpAsmDialectInterface;
 
-  LogicalResult getAlias(Attribute attr, raw_ostream &os) const override {
+  AliasResult getAlias(Attribute attr, raw_ostream &os) const override {
     if (attr.isa<AffineMapAttr>()) {
       os << "map";
-      return success();
+      return AliasResult::OverridableAlias;
     }
     if (attr.isa<IntegerSetAttr>()) {
       os << "set";
-      return success();
+      return AliasResult::OverridableAlias;
     }
     if (attr.isa<LocationAttr>()) {
       os << "loc";
-      return success();
+      return AliasResult::OverridableAlias;
     }
-    return failure();
+    return AliasResult::NoAlias;
   }
 
-  LogicalResult getAlias(Type type, raw_ostream &os) const final {
+  AliasResult getAlias(Type type, raw_ostream &os) const final {
     if (auto tupleType = type.dyn_cast<TupleType>()) {
       if (tupleType.size() > 16) {
         os << "tuple";
-        return success();
+        return AliasResult::OverridableAlias;
       }
     }
-    return failure();
+    return AliasResult::NoAlias;
   }
 };
 } // end anonymous namespace.

diff  --git a/mlir/test/IR/print-attr-type-aliases.mlir b/mlir/test/IR/print-attr-type-aliases.mlir
index 8b4016d49ad56..5bb408d776800 100644
--- a/mlir/test/IR/print-attr-type-aliases.mlir
+++ b/mlir/test/IR/print-attr-type-aliases.mlir
@@ -18,6 +18,9 @@
 // CHECK-DAG: !tuple = type tuple<i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32>
 "test.op"() {alias_test = "alias_test:large_tuple"} : () -> (tuple<i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32>)
 
+// CHECK-DAG: !test_tuple = type tuple<!test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla>
+"test.op"() {alias_test = "alias_test:large_tuple"} : () -> (tuple<!test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla>)
+
 // CHECK-DAG: #test_encoding = "alias_test:tensor_encoding"
 // CHECK-DAG: tensor<32xf32, #test_encoding>
 "test.op"() : () -> tensor<32xf32, "alias_test:tensor_encoding">

diff  --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 76a7f41fd186d..e56c2d1a92d0f 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -51,10 +51,10 @@ static_assert(OpTrait::hasSingleBlockImplicitTerminator<
 struct TestOpAsmInterface : public OpAsmDialectInterface {
   using OpAsmDialectInterface::OpAsmDialectInterface;
 
-  LogicalResult getAlias(Attribute attr, raw_ostream &os) const final {
+  AliasResult getAlias(Attribute attr, raw_ostream &os) const final {
     StringAttr strAttr = attr.dyn_cast<StringAttr>();
     if (!strAttr)
-      return failure();
+      return AliasResult::NoAlias;
 
     // Check the contents of the string attribute to see what the test alias
     // should be named.
@@ -70,10 +70,23 @@ struct TestOpAsmInterface : public OpAsmDialectInterface {
             .Case("alias_test:tensor_encoding", StringRef("test_encoding"))
             .Default(llvm::None);
     if (!aliasName)
-      return failure();
+      return AliasResult::NoAlias;
 
     os << *aliasName;
-    return success();
+    return AliasResult::FinalAlias;
+  }
+
+  AliasResult getAlias(Type type, raw_ostream &os) const final {
+    if (auto tupleType = type.dyn_cast<TupleType>()) {
+      if (tupleType.size() > 0 &&
+          llvm::all_of(tupleType.getTypes(), [](Type elemType) {
+            return elemType.isa<SimpleAType>();
+          })) {
+        os << "test_tuple";
+        return AliasResult::FinalAlias;
+      }
+    }
+    return AliasResult::NoAlias;
   }
 
   void getAsmResultNames(Operation *op,


        


More information about the Mlir-commits mailing list