[Mlir-commits] [mlir] c6390f1 - [mlir] Fix AsmPrinter for types with sub elements

Vladislav Vinogradov llvmlistbot at llvm.org
Tue Oct 12 02:18:06 PDT 2021


Author: Vladislav Vinogradov
Date: 2021-10-12T12:08:16+03:00
New Revision: c6390f19f20c93687d6366f6e78d6e96f1f0a126

URL: https://github.com/llvm/llvm-project/commit/c6390f19f20c93687d6366f6e78d6e96f1f0a126
DIFF: https://github.com/llvm/llvm-project/commit/c6390f19f20c93687d6366f6e78d6e96f1f0a126.diff

LOG: [mlir] Fix AsmPrinter for types with sub elements

Call `printType(subElemType)` instead of `os << subElemType` when printing the
sub-element types of function, vector, and ranked tensor types. This lets type
aliases be used inside those complex types.

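As a side effect, fix `test.int` parsing. `parseSignedness` used
`compare_insensitive`, which returns 0 on a match, as a boolean. As a result,
every signedness keyword was parsed as `unsigned`. It now uses
`equals_insensitive`.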

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D111536

Added: 
    

Modified: 
    mlir/lib/IR/AsmPrinter.cpp
    mlir/test/IR/print-attr-type-aliases.mlir
    mlir/test/lib/Dialect/Test/TestDialect.cpp
    mlir/test/lib/Dialect/Test/TestTypes.cpp
    mlir/test/mlir-tblgen/testdialect-typedefs.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 3d9090daaf8b..67f135ff4bc1 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1921,7 +1921,7 @@ void AsmPrinter::Impl::printType(Type type) {
         os << ") -> ";
         ArrayRef<Type> results = funcTy.getResults();
         if (results.size() == 1 && !results[0].isa<FunctionType>()) {
-          os << results[0];
+          printType(results[0]);
         } else {
           os << '(';
           interleaveComma(results, [&](Type ty) { printType(ty); });
@@ -1932,7 +1932,8 @@ void AsmPrinter::Impl::printType(Type type) {
         os << "vector<";
         for (int64_t dim : vectorTy.getShape())
           os << dim << 'x';
-        os << vectorTy.getElementType() << '>';
+        printType(vectorTy.getElementType());
+        os << '>';
       })
       .Case<RankedTensorType>([&](RankedTensorType tensorTy) {
         os << "tensor<";
@@ -1943,7 +1944,7 @@ void AsmPrinter::Impl::printType(Type type) {
             os << dim;
           os << 'x';
         }
-        os << tensorTy.getElementType();
+        printType(tensorTy.getElementType());
         // Only print the encoding attribute value if set.
         if (tensorTy.getEncoding()) {
           os << ", ";

diff  --git a/mlir/test/IR/print-attr-type-aliases.mlir b/mlir/test/IR/print-attr-type-aliases.mlir
index 5bb408d77680..286f62c192b9 100644
--- a/mlir/test/IR/print-attr-type-aliases.mlir
+++ b/mlir/test/IR/print-attr-type-aliases.mlir
@@ -24,3 +24,7 @@
 // CHECK-DAG: #test_encoding = "alias_test:tensor_encoding"
 // CHECK-DAG: tensor<32xf32, #test_encoding>
 "test.op"() : () -> tensor<32xf32, "alias_test:tensor_encoding">
+
+// CHECK-DAG: !test_ui8_ = type !test.int<unsigned, 8>
+// CHECK-DAG: tensor<32x!test_ui8_>
+"test.op"() : () -> tensor<32x!test.int<unsigned, 8>>

diff  --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index a0f88452b3d1..658ed0957be3 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -88,6 +88,14 @@ struct TestOpAsmInterface : public OpAsmDialectInterface {
         return AliasResult::FinalAlias;
       }
     }
+    if (auto intType = type.dyn_cast<TestIntegerType>()) {
+      if (intType.getSignedness() ==
+              TestIntegerType::SignednessSemantics::Unsigned &&
+          intType.getWidth() == 8) {
+        os << "test_ui8";
+        return AliasResult::FinalAlias;
+      }
+    }
     return AliasResult::NoAlias;
   }
 

diff  --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp
index 9d618ce1aad6..7278f52a5865 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp
@@ -32,14 +32,13 @@ parseSignedness(DialectAsmParser &parser,
   auto loc = parser.getCurrentLocation();
   if (parser.parseKeyword(&signStr))
     return failure();
-  if (signStr.compare_insensitive("u") ||
-      signStr.compare_insensitive("unsigned"))
+  if (signStr.equals_insensitive("u") || signStr.equals_insensitive("unsigned"))
     result = TestIntegerType::SignednessSemantics::Unsigned;
-  else if (signStr.compare_insensitive("s") ||
-           signStr.compare_insensitive("signed"))
+  else if (signStr.equals_insensitive("s") ||
+           signStr.equals_insensitive("signed"))
     result = TestIntegerType::SignednessSemantics::Signed;
-  else if (signStr.compare_insensitive("n") ||
-           signStr.compare_insensitive("none"))
+  else if (signStr.equals_insensitive("n") ||
+           signStr.equals_insensitive("none"))
     result = TestIntegerType::SignednessSemantics::Signless;
   else
     return parser.emitError(loc, "expected signed, unsigned, or none");

diff  --git a/mlir/test/mlir-tblgen/testdialect-typedefs.mlir b/mlir/test/mlir-tblgen/testdialect-typedefs.mlir
index c8500e47b695..783b4be704c7 100644
--- a/mlir/test/mlir-tblgen/testdialect-typedefs.mlir
+++ b/mlir/test/mlir-tblgen/testdialect-typedefs.mlir
@@ -13,12 +13,12 @@ func @compoundA(%A : !test.cmpnd_a<1, !test.smpla, [5, 6]>)-> () {
   return
 }
 
-// CHECK: @testInt(%arg0: !test.int<unsigned, 8>, %arg1: !test.int<unsigned, 2>, %arg2: !test.int<unsigned, 1>)
+// CHECK: @testInt(%arg0: !test.int<signed, 8>, %arg1: !test.int<unsigned, 2>, %arg2: !test.int<none, 1>)
 func @testInt(%A : !test.int<s, 8>, %B : !test.int<unsigned, 2>, %C : !test.int<n, 1>) {
   return
 }
 
-// CHECK: @structTest(%arg0: !test.struct<{field1,!test.smpla},{field2,!test.int<unsigned, 3>}>)
+// CHECK: @structTest(%arg0: !test.struct<{field1,!test.smpla},{field2,!test.int<none, 3>}>)
 func @structTest (%A : !test.struct< {field1, !test.smpla}, {field2, !test.int<none, 3>} > ) {
   return
 }


        


More information about the Mlir-commits mailing list